diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go index 8685443..7edc929 100644 --- a/pkg/identity/identity.go +++ b/pkg/identity/identity.go @@ -307,9 +307,14 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat return nil, errors.New("decryption failed because the token size was invalid") } - // Extract peer public key and ciphertext - peerPubBytes := ciphertextToken[:KEYSIZE/8/2] - ciphertext := ciphertextToken[KEYSIZE/8/2:] + // Extract components: ephemeralPubKey(32) + ciphertext + mac(32) + if len(ciphertextToken) < 32+32+32 { // minimum sizes + return nil, errors.New("token too short") + } + + peerPubBytes := ciphertextToken[:32] + ciphertext := ciphertextToken[32 : len(ciphertextToken)-32] + mac := ciphertextToken[len(ciphertextToken)-32:] // Try decryption with ratchets first if provided if len(ratchets) > 0 { @@ -343,6 +348,11 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat return nil, fmt.Errorf("failed to derive key: %v", err) } + // Validate HMAC + if !cryptography.ValidateHMAC(derivedKey, append(peerPubBytes, ciphertext...), mac) { + return nil, errors.New("invalid HMAC") + } + // Create AES cipher block, err := aes.NewCipher(derivedKey) if err != nil {