refactor: update encryption and decryption processes by deriving key material for HMAC and encryption separately, and improve HMAC validation logic

This commit is contained in:
2025-12-29 23:46:33 -06:00
parent 220912989e
commit 8e777bef03

View File

@@ -133,20 +133,25 @@ func (i *Identity) Encrypt(plaintext []byte, ratchet []byte) ([]byte, error) {
return nil, err
}
// Derive encryption key
key, err := cryptography.DeriveKey(sharedSecret, i.GetSalt(), i.GetContext(), 32)
// Derive key material (64 bytes: first 32 for HMAC, last 32 for encryption)
salt := i.GetSalt()
debug.Log(debug.DEBUG_ALL, "Encrypt: using salt", "salt", fmt.Sprintf("%x", salt), "identity_hash", fmt.Sprintf("%x", i.Hash()))
key, err := cryptography.DeriveKey(sharedSecret, salt, i.GetContext(), 64)
if err != nil {
return nil, err
}
hmacKey := key[:32]
encryptionKey := key[32:64]
// Encrypt data
ciphertext, err := cryptography.EncryptAES256CBC(key[:32], plaintext)
ciphertext, err := cryptography.EncryptAES256CBC(encryptionKey, plaintext)
if err != nil {
return nil, err
}
// Calculate HMAC
mac := cryptography.ComputeHMAC(key, append(ephemeralPubKey, ciphertext...))
// Calculate HMAC over ciphertext only (iv + encrypted_data)
mac := cryptography.ComputeHMAC(hmacKey, ciphertext)
// Combine components
token := make([]byte, 0, len(ephemeralPubKey)+len(ciphertext)+len(mac))
@@ -221,13 +226,18 @@ func FromPublicKey(publicKey []byte) *Identity {
return nil
}
return &Identity{
id := &Identity{
publicKey: publicKey[:KEYSIZE/16],
verificationKey: publicKey[KEYSIZE/16:],
ratchets: make(map[string][]byte),
ratchetExpiry: make(map[string]int64),
mutex: &sync.RWMutex{},
}
hash := cryptography.Hash(id.GetPublicKey())
id.hash = hash[:TRUNCATED_HASHLENGTH/8]
return id
}
func (i *Identity) Hex() string {
@@ -335,7 +345,7 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat
// Try decryption with ratchets first if provided
if len(ratchets) > 0 {
for _, ratchet := range ratchets {
if decrypted, ratchetID, err := i.tryRatchetDecryption(peerPubBytes, ciphertext, ratchet); err == nil {
if decrypted, ratchetID, err := i.tryRatchetDecryption(peerPubBytes, ciphertext, mac, ratchet); err == nil {
if ratchetIDReceiver != nil {
ratchetIDReceiver.LatestRatchetID = ratchetID
}
@@ -357,20 +367,25 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat
return nil, fmt.Errorf("failed to generate shared key: %v", err)
}
// Derive key using HKDF
hkdfReader := hkdf.New(sha256.New, sharedKey, i.GetSalt(), i.GetContext())
derivedKey := make([]byte, 32)
// Derive key material (64 bytes: first 32 for HMAC, last 32 for encryption)
salt := i.GetSalt()
debug.Log(debug.DEBUG_ALL, "Decrypt: using salt", "salt", fmt.Sprintf("%x", salt), "identity_hash", fmt.Sprintf("%x", i.Hash()))
hkdfReader := hkdf.New(sha256.New, sharedKey, salt, i.GetContext())
derivedKey := make([]byte, 64)
if _, err := io.ReadFull(hkdfReader, derivedKey); err != nil {
return nil, fmt.Errorf("failed to derive key: %v", err)
}
// Validate HMAC
if !cryptography.ValidateHMAC(derivedKey, append(peerPubBytes, ciphertext...), mac) {
hmacKey := derivedKey[:32]
encryptionKey := derivedKey[32:64]
// Validate HMAC over ciphertext only (iv + encrypted_data)
if !cryptography.ValidateHMAC(hmacKey, ciphertext, mac) {
return nil, errors.New("invalid HMAC")
}
// Create AES cipher
block, err := aes.NewCipher(derivedKey)
block, err := aes.NewCipher(encryptionKey)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %v", err)
}
@@ -412,7 +427,7 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat
}
// Helper function to attempt decryption using a ratchet
func (i *Identity) tryRatchetDecryption(peerPubBytes, ciphertext, ratchet []byte) (plaintext, ratchetID []byte, err error) {
func (i *Identity) tryRatchetDecryption(peerPubBytes, ciphertext, mac, ratchet []byte) (plaintext, ratchetID []byte, err error) {
// Convert ratchet to private key
ratchetPriv := ratchet
@@ -429,12 +444,20 @@ func (i *Identity) tryRatchetDecryption(peerPubBytes, ciphertext, ratchet []byte
return nil, nil, err
}
key, err := cryptography.DeriveKey(sharedSecret, i.GetSalt(), i.GetContext(), 32)
key, err := cryptography.DeriveKey(sharedSecret, i.GetSalt(), i.GetContext(), 64)
if err != nil {
return nil, nil, err
}
plaintext, err = cryptography.DecryptAES256CBC(key, ciphertext)
hmacKey := key[:32]
encryptionKey := key[32:64]
// Validate HMAC over ciphertext only (iv + encrypted_data)
if !cryptography.ValidateHMAC(hmacKey, ciphertext, mac) {
return nil, nil, errors.New("invalid HMAC")
}
plaintext, err = cryptography.DecryptAES256CBC(encryptionKey, ciphertext)
if err != nil {
return nil, nil, err
}
@@ -718,7 +741,7 @@ func RecallIdentity(path string) (*Identity, error) {
copy(combinedPub[:KEYSIZE/16], id.publicKey)
copy(combinedPub[KEYSIZE/16:], id.verificationKey)
hash := sha256.Sum256(combinedPub)
id.hash = hash[:]
id.hash = hash[:TRUNCATED_HASHLENGTH/8]
debug.Log(debug.DEBUG_ALL, "Successfully recalled identity", "hash", id.GetHexHash())
return id, nil