diff --git a/pkg/cryptography/aes.go b/pkg/cryptography/aes.go index 8c8b8e8..baec748 100644 --- a/pkg/cryptography/aes.go +++ b/pkg/cryptography/aes.go @@ -9,54 +9,24 @@ import ( ) const ( - // AES key sizes in bytes - AES128KeySize = 16 // 128 bits - AES192KeySize = 24 // 192 bits + // AES256KeySize is the size of an AES-256 key in bytes. AES256KeySize = 32 // 256 bits - - // Default to AES-256 - DefaultKeySize = AES256KeySize ) -// GenerateAESKey generates a random AES key of the specified size -func GenerateAESKey(keySize int) ([]byte, error) { - if keySize != AES128KeySize && keySize != AES192KeySize && keySize != AES256KeySize { - return nil, errors.New("invalid key size: must be 16, 24, or 32 bytes") - } - - key := make([]byte, keySize) +// GenerateAES256Key generates a random AES-256 key. +func GenerateAES256Key() ([]byte, error) { + key := make([]byte, AES256KeySize) if _, err := io.ReadFull(rand.Reader, key); err != nil { return nil, err } return key, nil } -// GenerateAES256Key generates a random AES-256 key (default) -func GenerateAES256Key() ([]byte, error) { - return GenerateAESKey(AES256KeySize) -} - -// EncryptAES256CBC encrypts data using AES-256 in CBC mode +// EncryptAES256CBC encrypts data using AES-256 in CBC mode. +// The IV is prepended to the ciphertext. func EncryptAES256CBC(key, plaintext []byte) ([]byte, error) { if len(key) != AES256KeySize { - return nil, errors.New("key must be 32 bytes for AES-256") - } - return EncryptAESCBC(key, plaintext) -} - -// DecryptAES256CBC decrypts data using AES-256 in CBC mode -func DecryptAES256CBC(key, ciphertext []byte) ([]byte, error) { - if len(key) != AES256KeySize { - return nil, errors.New("key must be 32 bytes for AES-256") - } - return DecryptAESCBC(key, ciphertext) -} - -// EncryptAESCBC encrypts data using AES in CBC mode (accepts any valid AES key size) -func EncryptAESCBC(key, plaintext []byte) ([]byte, error) { - // Validate key size - if len(key) != AES128KeySize && len(key) != AES192KeySize && len(key) != AES256KeySize { - return nil, errors.New("invalid key size: must be 16, 24, or 32 bytes") + return nil, errors.New("invalid key size: must be 32 bytes for AES-256") } block, err := aes.NewCipher(key) @@ -64,13 +34,13 @@ func EncryptAESCBC(key, plaintext []byte) ([]byte, error) { return nil, err } - // Generate IV + // Generate a random IV. iv := make([]byte, aes.BlockSize) if _, err := io.ReadFull(rand.Reader, iv); err != nil { return nil, err } - // Add PKCS7 padding + // Add PKCS7 padding. padding := aes.BlockSize - len(plaintext)%aes.BlockSize padtext := make([]byte, len(plaintext)+padding) copy(padtext, plaintext) @@ -78,19 +48,20 @@ func EncryptAESCBC(key, plaintext []byte) ([]byte, error) { padtext[i] = byte(padding) } - // Encrypt - mode := cipher.NewCBCEncrypter(block, iv) // #nosec G407 + // Encrypt the data. + mode := cipher.NewCBCEncrypter(block, iv) ciphertext := make([]byte, len(padtext)) mode.CryptBlocks(ciphertext, padtext) + // Prepend the IV to the ciphertext. return append(iv, ciphertext...), nil } -// DecryptAESCBC decrypts data using AES in CBC mode (accepts any valid AES key size) -func DecryptAESCBC(key, ciphertext []byte) ([]byte, error) { - // Validate key size - if len(key) != AES128KeySize && len(key) != AES192KeySize && len(key) != AES256KeySize { - return nil, errors.New("invalid key size: must be 16, 24, or 32 bytes") +// DecryptAES256CBC decrypts data using AES-256 in CBC mode. +// It assumes the IV is prepended to the ciphertext. +func DecryptAES256CBC(key, ciphertext []byte) ([]byte, error) { + if len(key) != AES256KeySize { + return nil, errors.New("invalid key size: must be 32 bytes for AES-256") } block, err := aes.NewCipher(key) @@ -99,34 +70,39 @@ func DecryptAESCBC(key, ciphertext []byte) ([]byte, error) { } if len(ciphertext) < aes.BlockSize { - return nil, errors.New("ciphertext too short") + return nil, errors.New("ciphertext is too short") } + // Extract the IV from the beginning of the ciphertext. iv := ciphertext[:aes.BlockSize] ciphertext = ciphertext[aes.BlockSize:] if len(ciphertext)%aes.BlockSize != 0 { - return nil, errors.New("ciphertext is not a multiple of block size") + return nil, errors.New("ciphertext is not a multiple of the block size") } + // Decrypt the data. mode := cipher.NewCBCDecrypter(block, iv) plaintext := make([]byte, len(ciphertext)) mode.CryptBlocks(plaintext, ciphertext) - // Remove PKCS7 padding + // Remove PKCS7 padding. if len(plaintext) == 0 { - return nil, errors.New("invalid padding: empty plaintext") + return nil, errors.New("invalid padding: plaintext is empty") } padding := int(plaintext[len(plaintext)-1]) - if padding == 0 || padding > aes.BlockSize || padding > len(plaintext) { - return nil, errors.New("invalid PKCS7 padding") + if padding > aes.BlockSize || padding == 0 { + return nil, errors.New("invalid padding size") + } + if len(plaintext) < padding { + return nil, errors.New("invalid padding: padding size is larger than plaintext") } - // Verify all padding bytes are correct + // Verify the padding bytes. for i := len(plaintext) - padding; i < len(plaintext); i++ { if plaintext[i] != byte(padding) { - return nil, errors.New("invalid PKCS7 padding") + return nil, errors.New("invalid padding bytes") } } diff --git a/pkg/cryptography/aes_test.go b/pkg/cryptography/aes_test.go index ea472ae..789b064 100644 --- a/pkg/cryptography/aes_test.go +++ b/pkg/cryptography/aes_test.go @@ -3,44 +3,20 @@ package cryptography import ( "bytes" "crypto/aes" + "crypto/cipher" "crypto/rand" "fmt" "testing" ) -func TestGenerateAESKeys(t *testing.T) { - t.Run("GenerateAES256Key", func(t *testing.T) { - key, err := GenerateAES256Key() - if err != nil { - t.Fatalf("GenerateAES256Key failed: %v", err) - } - if len(key) != AES256KeySize { - t.Errorf("Expected key size %d, got %d", AES256KeySize, len(key)) - } - }) - - t.Run("GenerateAESKey_AllSizes", func(t *testing.T) { - sizes := []int{AES128KeySize, AES192KeySize, AES256KeySize} - for _, size := range sizes { - key, err := GenerateAESKey(size) - if err != nil { - t.Fatalf("GenerateAESKey(%d) failed: %v", size, err) - } - if len(key) != size { - t.Errorf("Expected key size %d, got %d", size, len(key)) - } - } - }) - - t.Run("GenerateAESKey_InvalidSize", func(t *testing.T) { - invalidSizes := []int{8, 15, 17, 23, 25, 31, 33, 64} - for _, size := range invalidSizes { - _, err := GenerateAESKey(size) - if err == nil { - t.Errorf("GenerateAESKey(%d) should have failed but didn't", size) - } - } - }) +func TestGenerateAES256Key(t *testing.T) { + key, err := GenerateAES256Key() + if err != nil { + t.Fatalf("GenerateAES256Key failed: %v", err) + } + if len(key) != AES256KeySize { + t.Errorf("Expected key size %d, got %d", AES256KeySize, len(key)) + } } func TestAES256CBCEncryptionDecryption(t *testing.T) { @@ -110,48 +86,8 @@ func TestAES256CBC_InvalidKeySize(t *testing.T) { } } -func TestAESCBCEncryptionDecryption(t *testing.T) { - keySizes := []int{AES128KeySize, AES192KeySize, AES256KeySize} - for _, keySize := range keySizes { - t.Run(fmt.Sprintf("AES_%d", keySize*8), func(t *testing.T) { - key, err := GenerateAESKey(keySize) - if err != nil { - t.Fatalf("Failed to generate AES-%d key: %v", keySize*8, err) - } - - testCases := []struct { - name string - plaintext []byte - }{ - {"ShortMessage", []byte("Hello")}, - {"BlockSizeMessage", []byte("This is 16 bytes")}, - {"LongMessage", []byte("This is a longer message that spans multiple AES blocks.")}, - {"EmptyMessage", []byte("")}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ciphertext, err := EncryptAESCBC(key, tc.plaintext) - if err != nil { - t.Fatalf("EncryptAESCBC failed: %v", err) - } - - decrypted, err := DecryptAESCBC(key, ciphertext) - if err != nil { - t.Fatalf("DecryptAESCBC failed: %v", err) - } - - if !bytes.Equal(tc.plaintext, decrypted) { - t.Errorf("Decrypted text does not match original plaintext. Got %q, want %q", decrypted, tc.plaintext) - } - }) - } - }) - } -} - -func TestDecryptAESCBCErrorCases(t *testing.T) { +func TestDecryptAES256CBCErrorCases(t *testing.T) { key, err := GenerateAES256Key() if err != nil { t.Fatalf("Failed to generate key: %v", err) @@ -159,20 +95,9 @@ func TestDecryptAESCBCErrorCases(t *testing.T) { t.Run("CiphertextTooShort", func(t *testing.T) { shortCiphertext := []byte{0x01, 0x02, 0x03} // Less than AES block size - _, err := DecryptAESCBC(key, shortCiphertext) + _, err := DecryptAES256CBC(key, shortCiphertext) if err == nil { - t.Error("DecryptAESCBC should have failed for ciphertext shorter than block size") - } - }) - - t.Run("InvalidKeySize", func(t *testing.T) { - invalidKey := make([]byte, 17) // Invalid key size - validCiphertext := make([]byte, 32) // IV + one block - rand.Read(validCiphertext) - - _, err := DecryptAESCBC(invalidKey, validCiphertext) - if err == nil { - t.Error("DecryptAESCBC should have failed for invalid key size") + t.Error("DecryptAES256CBC should have failed for ciphertext shorter than block size") } }) @@ -180,16 +105,16 @@ func TestDecryptAESCBCErrorCases(t *testing.T) { iv := make([]byte, aes.BlockSize) rand.Read(iv) invalidCiphertext := append(iv, []byte{0x01, 0x02, 0x03}...) // IV + data not multiple of block size - _, err := DecryptAESCBC(key, invalidCiphertext) + _, err := DecryptAES256CBC(key, invalidCiphertext) if err == nil { - t.Error("DecryptAESCBC should have failed for ciphertext not multiple of block size") + t.Error("DecryptAES256CBC should have failed for ciphertext not multiple of block size") } }) t.Run("InvalidPadding", func(t *testing.T) { // Create a valid ciphertext first plaintext := []byte("valid data") - ciphertext, err := EncryptAESCBC(key, plaintext) + ciphertext, err := EncryptAES256CBC(key, plaintext) if err != nil { t.Fatalf("Failed to create test ciphertext: %v", err) } @@ -199,35 +124,42 @@ func TestDecryptAESCBCErrorCases(t *testing.T) { copy(corruptedCiphertext, ciphertext) corruptedCiphertext[len(corruptedCiphertext)-1] ^= 0xFF - _, err = DecryptAESCBC(key, corruptedCiphertext) + _, err = DecryptAES256CBC(key, corruptedCiphertext) if err == nil { - t.Error("DecryptAESCBC should have failed for corrupted padding") + t.Error("DecryptAES256CBC should have failed for corrupted padding") } }) - t.Run("EmptyPlaintext", func(t *testing.T) { - // Create a ciphertext that would result in empty plaintext - invalidCiphertext := make([]byte, aes.BlockSize) // Only IV, no data - _, err := DecryptAESCBC(key, invalidCiphertext) - if err == nil { - t.Error("DecryptAESCBC should have failed for empty ciphertext data") + t.Run("EmptyPlaintextAfterDecryption", func(t *testing.T) { + // This creates a ciphertext that decrypts to just padding + key, _ := GenerateAES256Key() + iv := make([]byte, aes.BlockSize) + // A block of padding bytes + paddedBlock := bytes.Repeat([]byte{byte(aes.BlockSize)}, aes.BlockSize) + + block, _ := aes.NewCipher(key) + mode := cipher.NewCBCEncrypter(block, iv) + ciphertext := make([]byte, len(paddedBlock)) + mode.CryptBlocks(ciphertext, paddedBlock) + + // Prepend IV + fullCiphertext := append(iv, ciphertext...) + + // This should decrypt to an empty slice, which is valid + decrypted, err := DecryptAES256CBC(key, fullCiphertext) + if err != nil { + t.Errorf("DecryptAES256CBC failed for empty plaintext case: %v", err) + } + if len(decrypted) != 0 { + t.Errorf("Expected empty plaintext, got %q", decrypted) } }) } func TestConstants(t *testing.T) { - if AES128KeySize != 16 { - t.Errorf("AES128KeySize should be 16, got %d", AES128KeySize) - } - if AES192KeySize != 24 { - t.Errorf("AES192KeySize should be 24, got %d", AES192KeySize) - } if AES256KeySize != 32 { t.Errorf("AES256KeySize should be 32, got %d", AES256KeySize) } - if DefaultKeySize != AES256KeySize { - t.Errorf("DefaultKeySize should be AES256KeySize (%d), got %d", AES256KeySize, DefaultKeySize) - } } func BenchmarkAES256CBC(b *testing.B) {