diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go index 5ec909f..ba86104 100644 --- a/pkg/identity/identity.go +++ b/pkg/identity/identity.go @@ -133,20 +133,25 @@ func (i *Identity) Encrypt(plaintext []byte, ratchet []byte) ([]byte, error) { return nil, err } - // Derive encryption key - key, err := cryptography.DeriveKey(sharedSecret, i.GetSalt(), i.GetContext(), 32) + // Derive key material (64 bytes: first 32 for HMAC, last 32 for encryption) + salt := i.GetSalt() + debug.Log(debug.DEBUG_ALL, "Encrypt: using salt", "salt", fmt.Sprintf("%x", salt), "identity_hash", fmt.Sprintf("%x", i.Hash())) + key, err := cryptography.DeriveKey(sharedSecret, salt, i.GetContext(), 64) if err != nil { return nil, err } + hmacKey := key[:32] + encryptionKey := key[32:64] + // Encrypt data - ciphertext, err := cryptography.EncryptAES256CBC(key[:32], plaintext) + ciphertext, err := cryptography.EncryptAES256CBC(encryptionKey, plaintext) if err != nil { return nil, err } - // Calculate HMAC - mac := cryptography.ComputeHMAC(key, append(ephemeralPubKey, ciphertext...)) + // Calculate HMAC over ciphertext only (iv + encrypted_data) + mac := cryptography.ComputeHMAC(hmacKey, ciphertext) // Combine components token := make([]byte, 0, len(ephemeralPubKey)+len(ciphertext)+len(mac)) @@ -221,13 +226,18 @@ func FromPublicKey(publicKey []byte) *Identity { return nil } - return &Identity{ + id := &Identity{ publicKey: publicKey[:KEYSIZE/16], verificationKey: publicKey[KEYSIZE/16:], ratchets: make(map[string][]byte), ratchetExpiry: make(map[string]int64), mutex: &sync.RWMutex{}, } + + hash := cryptography.Hash(id.GetPublicKey()) + id.hash = hash[:TRUNCATED_HASHLENGTH/8] + + return id } func (i *Identity) Hex() string { @@ -335,7 +345,7 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat // Try decryption with ratchets first if provided if len(ratchets) > 0 { for _, ratchet := range ratchets { - if decrypted, ratchetID, err := i.tryRatchetDecryption(peerPubBytes, ciphertext, ratchet); err == nil { + if decrypted, ratchetID, err := i.tryRatchetDecryption(peerPubBytes, ciphertext, mac, ratchet); err == nil { if ratchetIDReceiver != nil { ratchetIDReceiver.LatestRatchetID = ratchetID } @@ -357,20 +367,25 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat return nil, fmt.Errorf("failed to generate shared key: %v", err) } - // Derive key using HKDF - hkdfReader := hkdf.New(sha256.New, sharedKey, i.GetSalt(), i.GetContext()) - derivedKey := make([]byte, 32) + // Derive key material (64 bytes: first 32 for HMAC, last 32 for encryption) + salt := i.GetSalt() + debug.Log(debug.DEBUG_ALL, "Decrypt: using salt", "salt", fmt.Sprintf("%x", salt), "identity_hash", fmt.Sprintf("%x", i.Hash())) + hkdfReader := hkdf.New(sha256.New, sharedKey, salt, i.GetContext()) + derivedKey := make([]byte, 64) if _, err := io.ReadFull(hkdfReader, derivedKey); err != nil { return nil, fmt.Errorf("failed to derive key: %v", err) } - // Validate HMAC - if !cryptography.ValidateHMAC(derivedKey, append(peerPubBytes, ciphertext...), mac) { + hmacKey := derivedKey[:32] + encryptionKey := derivedKey[32:64] + + // Validate HMAC over ciphertext only (iv + encrypted_data) + if !cryptography.ValidateHMAC(hmacKey, ciphertext, mac) { return nil, errors.New("invalid HMAC") } // Create AES cipher - block, err := aes.NewCipher(derivedKey) + block, err := aes.NewCipher(encryptionKey) if err != nil { return nil, fmt.Errorf("failed to create cipher: %v", err) } @@ -412,7 +427,7 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat } // Helper function to attempt decryption using a ratchet -func (i *Identity) tryRatchetDecryption(peerPubBytes, ciphertext, ratchet []byte) (plaintext, ratchetID []byte, err error) { +func (i *Identity) tryRatchetDecryption(peerPubBytes, ciphertext, mac, ratchet []byte) (plaintext, ratchetID []byte, err error) { // Convert ratchet to private key ratchetPriv := ratchet @@ -429,12 +444,20 @@ func (i *Identity) tryRatchetDecryption(peerPubBytes, ciphertext, ratchet []byte return nil, nil, err } - key, err := cryptography.DeriveKey(sharedSecret, i.GetSalt(), i.GetContext(), 32) + key, err := cryptography.DeriveKey(sharedSecret, i.GetSalt(), i.GetContext(), 64) if err != nil { return nil, nil, err } - plaintext, err = cryptography.DecryptAES256CBC(key, ciphertext) + hmacKey := key[:32] + encryptionKey := key[32:64] + + // Validate HMAC over ciphertext only (iv + encrypted_data) + if !cryptography.ValidateHMAC(hmacKey, ciphertext, mac) { + return nil, nil, errors.New("invalid HMAC") + } + + plaintext, err = cryptography.DecryptAES256CBC(encryptionKey, ciphertext) if err != nil { return nil, nil, err } @@ -718,7 +741,7 @@ func RecallIdentity(path string) (*Identity, error) { copy(combinedPub[:KEYSIZE/16], id.publicKey) copy(combinedPub[KEYSIZE/16:], id.verificationKey) hash := sha256.Sum256(combinedPub) - id.hash = hash[:] + id.hash = hash[:TRUNCATED_HASHLENGTH/8] debug.Log(debug.DEBUG_ALL, "Successfully recalled identity", "hash", id.GetHexHash()) return id, nil