fix: improve error handling in DeriveKey function to return errors for empty secret and zero length key requests
This commit is contained in:
@@ -1,17 +1,48 @@
|
||||
package cryptography
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
"errors"
|
||||
"math"
|
||||
)
|
||||
|
||||
func DeriveKey(secret, salt, info []byte, length int) ([]byte, error) {
|
||||
hkdfReader := hkdf.New(sha256.New, secret, salt, info)
|
||||
key := make([]byte, length)
|
||||
if _, err := io.ReadFull(hkdfReader, key); err != nil {
|
||||
return nil, err
|
||||
hashLen := 32
|
||||
|
||||
if length < 1 {
|
||||
return nil, errors.New("invalid output key length")
|
||||
}
|
||||
return key, nil
|
||||
|
||||
if len(secret) == 0 {
|
||||
return nil, errors.New("cannot derive key from empty input material")
|
||||
}
|
||||
|
||||
if len(salt) == 0 {
|
||||
salt = make([]byte, hashLen)
|
||||
}
|
||||
|
||||
if info == nil {
|
||||
info = []byte{}
|
||||
}
|
||||
|
||||
pseudorandomKey := hmac.New(sha256.New, salt)
|
||||
pseudorandomKey.Write(secret)
|
||||
prk := pseudorandomKey.Sum(nil)
|
||||
|
||||
block := []byte{}
|
||||
derived := []byte{}
|
||||
|
||||
iterations := int(math.Ceil(float64(length) / float64(hashLen)))
|
||||
for i := 0; i < iterations; i++ {
|
||||
h := hmac.New(sha256.New, prk)
|
||||
h.Write(block)
|
||||
h.Write(info)
|
||||
counter := byte((i + 1) % (0xFF + 1))
|
||||
h.Write([]byte{counter})
|
||||
block = h.Sum(nil)
|
||||
derived = append(derived, block...)
|
||||
}
|
||||
|
||||
return derived[:length], nil
|
||||
}
|
||||
|
||||
@@ -77,8 +77,8 @@ func TestDeriveKeyEdgeCases(t *testing.T) {
|
||||
|
||||
t.Run("EmptySecret", func(t *testing.T) {
|
||||
_, err := DeriveKey([]byte{}, salt, info, 32)
|
||||
if err != nil {
|
||||
t.Errorf("DeriveKey failed with empty secret: %v", err)
|
||||
if err == nil {
|
||||
t.Errorf("DeriveKey should fail with empty secret")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -97,12 +97,9 @@ func TestDeriveKeyEdgeCases(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("ZeroLength", func(t *testing.T) {
|
||||
key, err := DeriveKey(secret, salt, info, 0)
|
||||
if err != nil {
|
||||
t.Errorf("DeriveKey failed with zero length: %v", err)
|
||||
}
|
||||
if len(key) != 0 {
|
||||
t.Errorf("DeriveKey with zero length returned non-empty key: %x", key)
|
||||
_, err := DeriveKey(secret, salt, info, 0)
|
||||
if err == nil {
|
||||
t.Errorf("DeriveKey should fail with zero length")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user