From 5c19b763675256520bbf1c5d030be9c76b7818a7 Mon Sep 17 00:00:00 2001 From: Sudo-Ivan Date: Tue, 30 Dec 2025 01:27:30 -0600 Subject: [PATCH] fix: improve error handling in DeriveKey function to return errors for empty secret and zero length key requests --- pkg/cryptography/hkdf.go | 47 +++++++++++++++++++++++++++++------ pkg/cryptography/hkdf_test.go | 13 ++++------ 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/pkg/cryptography/hkdf.go b/pkg/cryptography/hkdf.go index e53582a..b6fb669 100644 --- a/pkg/cryptography/hkdf.go +++ b/pkg/cryptography/hkdf.go @@ -1,17 +1,48 @@ package cryptography import ( + "crypto/hmac" "crypto/sha256" - "io" - - "golang.org/x/crypto/hkdf" + "errors" + "math" ) func DeriveKey(secret, salt, info []byte, length int) ([]byte, error) { - hkdfReader := hkdf.New(sha256.New, secret, salt, info) - key := make([]byte, length) - if _, err := io.ReadFull(hkdfReader, key); err != nil { - return nil, err + hashLen := 32 + + if length < 1 { + return nil, errors.New("invalid output key length") } - return key, nil + + if len(secret) == 0 { + return nil, errors.New("cannot derive key from empty input material") + } + + if len(salt) == 0 { + salt = make([]byte, hashLen) + } + + if info == nil { + info = []byte{} + } + + pseudorandomKey := hmac.New(sha256.New, salt) + pseudorandomKey.Write(secret) + prk := pseudorandomKey.Sum(nil) + + block := []byte{} + derived := []byte{} + + iterations := int(math.Ceil(float64(length) / float64(hashLen))) + for i := 0; i < iterations; i++ { + h := hmac.New(sha256.New, prk) + h.Write(block) + h.Write(info) + counter := byte((i + 1) % (0xFF + 1)) + h.Write([]byte{counter}) + block = h.Sum(nil) + derived = append(derived, block...) + } + + return derived[:length], nil } diff --git a/pkg/cryptography/hkdf_test.go b/pkg/cryptography/hkdf_test.go index 13b8dfe..346098f 100644 --- a/pkg/cryptography/hkdf_test.go +++ b/pkg/cryptography/hkdf_test.go @@ -77,8 +77,8 @@ func TestDeriveKeyEdgeCases(t *testing.T) { t.Run("EmptySecret", func(t *testing.T) { _, err := DeriveKey([]byte{}, salt, info, 32) - if err != nil { - t.Errorf("DeriveKey failed with empty secret: %v", err) + if err == nil { + t.Errorf("DeriveKey should fail with empty secret") } }) @@ -97,12 +97,9 @@ func TestDeriveKeyEdgeCases(t *testing.T) { }) t.Run("ZeroLength", func(t *testing.T) { - key, err := DeriveKey(secret, salt, info, 0) - if err != nil { - t.Errorf("DeriveKey failed with zero length: %v", err) - } - if len(key) != 0 { - t.Errorf("DeriveKey with zero length returned non-empty key: %x", key) + _, err := DeriveKey(secret, salt, info, 0) + if err == nil { + t.Errorf("DeriveKey should fail with zero length") } }) }