refactor: update encryption and decryption processes by deriving key material for HMAC and encryption separately, and improve HMAC validation logic

This commit is contained in:
2025-12-29 23:46:33 -06:00
parent 220912989e
commit 8e777bef03

View File

@@ -133,20 +133,25 @@ func (i *Identity) Encrypt(plaintext []byte, ratchet []byte) ([]byte, error) {
return nil, err return nil, err
} }
// Derive encryption key // Derive key material (64 bytes: first 32 for HMAC, last 32 for encryption)
key, err := cryptography.DeriveKey(sharedSecret, i.GetSalt(), i.GetContext(), 32) salt := i.GetSalt()
debug.Log(debug.DEBUG_ALL, "Encrypt: using salt", "salt", fmt.Sprintf("%x", salt), "identity_hash", fmt.Sprintf("%x", i.Hash()))
key, err := cryptography.DeriveKey(sharedSecret, salt, i.GetContext(), 64)
if err != nil { if err != nil {
return nil, err return nil, err
} }
hmacKey := key[:32]
encryptionKey := key[32:64]
// Encrypt data // Encrypt data
ciphertext, err := cryptography.EncryptAES256CBC(key[:32], plaintext) ciphertext, err := cryptography.EncryptAES256CBC(encryptionKey, plaintext)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Calculate HMAC // Calculate HMAC over ciphertext only (iv + encrypted_data)
mac := cryptography.ComputeHMAC(key, append(ephemeralPubKey, ciphertext...)) mac := cryptography.ComputeHMAC(hmacKey, ciphertext)
// Combine components // Combine components
token := make([]byte, 0, len(ephemeralPubKey)+len(ciphertext)+len(mac)) token := make([]byte, 0, len(ephemeralPubKey)+len(ciphertext)+len(mac))
@@ -221,13 +226,18 @@ func FromPublicKey(publicKey []byte) *Identity {
return nil return nil
} }
return &Identity{ id := &Identity{
publicKey: publicKey[:KEYSIZE/16], publicKey: publicKey[:KEYSIZE/16],
verificationKey: publicKey[KEYSIZE/16:], verificationKey: publicKey[KEYSIZE/16:],
ratchets: make(map[string][]byte), ratchets: make(map[string][]byte),
ratchetExpiry: make(map[string]int64), ratchetExpiry: make(map[string]int64),
mutex: &sync.RWMutex{}, mutex: &sync.RWMutex{},
} }
hash := cryptography.Hash(id.GetPublicKey())
id.hash = hash[:TRUNCATED_HASHLENGTH/8]
return id
} }
func (i *Identity) Hex() string { func (i *Identity) Hex() string {
@@ -335,7 +345,7 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat
// Try decryption with ratchets first if provided // Try decryption with ratchets first if provided
if len(ratchets) > 0 { if len(ratchets) > 0 {
for _, ratchet := range ratchets { for _, ratchet := range ratchets {
if decrypted, ratchetID, err := i.tryRatchetDecryption(peerPubBytes, ciphertext, ratchet); err == nil { if decrypted, ratchetID, err := i.tryRatchetDecryption(peerPubBytes, ciphertext, mac, ratchet); err == nil {
if ratchetIDReceiver != nil { if ratchetIDReceiver != nil {
ratchetIDReceiver.LatestRatchetID = ratchetID ratchetIDReceiver.LatestRatchetID = ratchetID
} }
@@ -357,20 +367,25 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat
return nil, fmt.Errorf("failed to generate shared key: %v", err) return nil, fmt.Errorf("failed to generate shared key: %v", err)
} }
// Derive key using HKDF // Derive key material (64 bytes: first 32 for HMAC, last 32 for encryption)
hkdfReader := hkdf.New(sha256.New, sharedKey, i.GetSalt(), i.GetContext()) salt := i.GetSalt()
derivedKey := make([]byte, 32) debug.Log(debug.DEBUG_ALL, "Decrypt: using salt", "salt", fmt.Sprintf("%x", salt), "identity_hash", fmt.Sprintf("%x", i.Hash()))
hkdfReader := hkdf.New(sha256.New, sharedKey, salt, i.GetContext())
derivedKey := make([]byte, 64)
if _, err := io.ReadFull(hkdfReader, derivedKey); err != nil { if _, err := io.ReadFull(hkdfReader, derivedKey); err != nil {
return nil, fmt.Errorf("failed to derive key: %v", err) return nil, fmt.Errorf("failed to derive key: %v", err)
} }
// Validate HMAC hmacKey := derivedKey[:32]
if !cryptography.ValidateHMAC(derivedKey, append(peerPubBytes, ciphertext...), mac) { encryptionKey := derivedKey[32:64]
// Validate HMAC over ciphertext only (iv + encrypted_data)
if !cryptography.ValidateHMAC(hmacKey, ciphertext, mac) {
return nil, errors.New("invalid HMAC") return nil, errors.New("invalid HMAC")
} }
// Create AES cipher // Create AES cipher
block, err := aes.NewCipher(derivedKey) block, err := aes.NewCipher(encryptionKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create cipher: %v", err) return nil, fmt.Errorf("failed to create cipher: %v", err)
} }
@@ -412,7 +427,7 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat
} }
// Helper function to attempt decryption using a ratchet // Helper function to attempt decryption using a ratchet
func (i *Identity) tryRatchetDecryption(peerPubBytes, ciphertext, ratchet []byte) (plaintext, ratchetID []byte, err error) { func (i *Identity) tryRatchetDecryption(peerPubBytes, ciphertext, mac, ratchet []byte) (plaintext, ratchetID []byte, err error) {
// Convert ratchet to private key // Convert ratchet to private key
ratchetPriv := ratchet ratchetPriv := ratchet
@@ -429,12 +444,20 @@ func (i *Identity) tryRatchetDecryption(peerPubBytes, ciphertext, ratchet []byte
return nil, nil, err return nil, nil, err
} }
key, err := cryptography.DeriveKey(sharedSecret, i.GetSalt(), i.GetContext(), 32) key, err := cryptography.DeriveKey(sharedSecret, i.GetSalt(), i.GetContext(), 64)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
plaintext, err = cryptography.DecryptAES256CBC(key, ciphertext) hmacKey := key[:32]
encryptionKey := key[32:64]
// Validate HMAC over ciphertext only (iv + encrypted_data)
if !cryptography.ValidateHMAC(hmacKey, ciphertext, mac) {
return nil, nil, errors.New("invalid HMAC")
}
plaintext, err = cryptography.DecryptAES256CBC(encryptionKey, ciphertext)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -718,7 +741,7 @@ func RecallIdentity(path string) (*Identity, error) {
copy(combinedPub[:KEYSIZE/16], id.publicKey) copy(combinedPub[:KEYSIZE/16], id.publicKey)
copy(combinedPub[KEYSIZE/16:], id.verificationKey) copy(combinedPub[KEYSIZE/16:], id.verificationKey)
hash := sha256.Sum256(combinedPub) hash := sha256.Sum256(combinedPub)
id.hash = hash[:] id.hash = hash[:TRUNCATED_HASHLENGTH/8]
debug.Log(debug.DEBUG_ALL, "Successfully recalled identity", "hash", id.GetHexHash()) debug.Log(debug.DEBUG_ALL, "Successfully recalled identity", "hash", id.GetHexHash())
return id, nil return id, nil