diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go index 2a32904..b299277 100644 --- a/pkg/identity/identity.go +++ b/pkg/identity/identity.go @@ -7,9 +7,11 @@ import ( "crypto/hmac" "crypto/rand" "crypto/sha256" + "encoding/json" "errors" "fmt" "io" + "os" "sync" "time" @@ -92,6 +94,61 @@ func decryptAESGCM(key, ciphertext []byte) ([]byte, error) { return plaintext, nil } +func encryptAESCBC(key, plaintext []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + // Generate IV + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + + // Add PKCS7 padding + padding := aes.BlockSize - len(plaintext)%aes.BlockSize + padtext := make([]byte, len(plaintext)+padding) + copy(padtext, plaintext) + for i := len(plaintext); i < len(padtext); i++ { + padtext[i] = byte(padding) + } + + // Encrypt + mode := cipher.NewCBCEncrypter(block, iv) + ciphertext := make([]byte, len(padtext)) + mode.CryptBlocks(ciphertext, padtext) + + // Prepend IV to ciphertext + return append(iv, ciphertext...), nil +} + +func decryptAESCBC(key, ciphertext []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + if len(ciphertext) < aes.BlockSize { + return nil, errors.New("ciphertext too short") + } + + iv := ciphertext[:aes.BlockSize] + ciphertext = ciphertext[aes.BlockSize:] + + if len(ciphertext)%aes.BlockSize != 0 { + return nil, errors.New("ciphertext is not a multiple of block size") + } + + mode := cipher.NewCBCDecrypter(block, iv) + plaintext := make([]byte, len(ciphertext)) + mode.CryptBlocks(plaintext, ciphertext) + + // Remove PKCS7 padding + padding := int(plaintext[len(plaintext)-1]) + return plaintext[:len(plaintext)-padding], nil +} + func New() (*Identity, error) { i := &Identity{ ratchets: make(map[string][]byte), @@ -350,3 +407,85 @@ func (i *Identity) Decrypt(ciphertext []byte) ([]byte, error) { // Decrypt data return decryptAESGCM(key, encryptedData) } + +func (i *Identity) EncryptWithHMAC(plaintext []byte, key []byte) ([]byte, error) { + // Encrypt with AES-CBC + ciphertext, err := encryptAESCBC(key, plaintext) + if err != nil { + return nil, err + } + + // Generate HMAC + h := hmac.New(sha256.New, key) + h.Write(ciphertext) + mac := h.Sum(nil) + + // Combine ciphertext and HMAC + return append(ciphertext, mac...), nil +} + +func (i *Identity) DecryptWithHMAC(data []byte, key []byte) ([]byte, error) { + if len(data) < sha256.Size { + return nil, errors.New("data too short") + } + + // Split HMAC and ciphertext + macStart := len(data) - sha256.Size + ciphertext := data[:macStart] + messageMAC := data[macStart:] + + // Verify HMAC + h := hmac.New(sha256.New, key) + h.Write(ciphertext) + expectedMAC := h.Sum(nil) + if !hmac.Equal(messageMAC, expectedMAC) { + return nil, errors.New("invalid HMAC") + } + + // Decrypt + return decryptAESCBC(key, ciphertext) +} + +func (i *Identity) ToFile(path string) error { + data := map[string]interface{}{ + "private_key": i.privateKey, + "public_key": i.publicKey, + "signing_key": i.signingKey, + "verification_key": i.verificationKey, + "app_data": i.appData, + } + + file, err := os.Create(path) + if err != nil { + return err + } + defer file.Close() + + return json.NewEncoder(file).Encode(data) +} + +func RecallIdentity(path string) (*Identity, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + var data map[string]interface{} + if err := json.NewDecoder(file).Decode(&data); err != nil { + return nil, err + } + + // Reconstruct identity from saved data + id := &Identity{ + privateKey: data["private_key"].([]byte), + publicKey: data["public_key"].([]byte), + signingKey: data["signing_key"].(ed25519.PrivateKey), + verificationKey: data["verification_key"].(ed25519.PublicKey), + appData: data["app_data"].([]byte), + ratchets: make(map[string][]byte), + ratchetExpiry: make(map[string]int64), + } + + return id, nil +} diff --git a/pkg/link/link.go b/pkg/link/link.go index f6930aa..4a7c95a 100644 --- a/pkg/link/link.go +++ b/pkg/link/link.go @@ -37,10 +37,13 @@ const ( STATUS_CLOSED = 0x02 STATUS_FAILED = 0x03 - // Add packet types PACKET_TYPE_DATA = 0x00 PACKET_TYPE_LINK = 0x01 PACKET_TYPE_IDENTIFY = 0x02 + + PROVE_NONE = 0x00 + PROVE_ALL = 0x01 + PROVE_APP = 0x02 ) type Link struct { @@ -69,7 +72,6 @@ type Link struct { hmacKey []byte transport *transport.Transport - // Add missing fields rssi float64 snr float64 q float64 @@ -77,6 +79,9 @@ type Link struct { resourceStartedCallback func(interface{}) resourceConcludedCallback func(interface{}) resourceStrategy byte + proofStrategy byte + proofCallback func(*packet.Packet) bool + trackPhyStats bool } func NewLink(dest *destination.Destination, transport *transport.Transport, establishedCallback func(*Link), closedCallback func(*Link)) *Link { @@ -269,27 +274,40 @@ func (r *RequestReceipt) Concluded() bool { func (l *Link) TrackPhyStats(rssi float64, snr float64, q float64) { l.mutex.Lock() defer l.mutex.Unlock() - + l.rssi = rssi l.snr = snr l.q = q } +func (l *Link) UpdatePhyStats(rssi float64, snr float64, q float64) { + l.TrackPhyStats(rssi, snr, q) +} + func (l *Link) GetRSSI() float64 { l.mutex.RLock() defer l.mutex.RUnlock() + if !l.trackPhyStats { + return 0 + } return l.rssi } func (l *Link) GetSNR() float64 { l.mutex.RLock() defer l.mutex.RUnlock() + if !l.trackPhyStats { + return 0 + } return l.snr } func (l *Link) GetQ() float64 { l.mutex.RLock() defer l.mutex.RUnlock() + if !l.trackPhyStats { + return 0 + } return l.q } @@ -557,10 +575,6 @@ func (l *Link) decrypt(data []byte) ([]byte, error) { return plaintext[:len(plaintext)-padding], nil } -func (l *Link) UpdatePhyStats(rssi float64, snr float64, q float64) { - l.TrackPhyStats(rssi, snr, q) -} - func (l *Link) GetRTT() float64 { l.mutex.RLock() defer l.mutex.RUnlock() @@ -615,3 +629,68 @@ func (l *Link) SendResource(res *resource.Resource) error { return nil } + +func (l *Link) maintainLink() { + ticker := time.NewTicker(time.Second * KEEPALIVE) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if l.status != STATUS_ACTIVE { + return + } + + if l.InactiveFor() > float64(STALE_TIME) { + l.teardownReason = STATUS_FAILED + l.Teardown() + return + } + + if l.NoDataFor() > float64(KEEPALIVE) { + // Send keepalive packet + l.SendPacket([]byte{}) + } + } + } +} + +func (l *Link) Start() { + go l.maintainLink() +} + +func (l *Link) SetProofStrategy(strategy byte) error { + if strategy != PROVE_NONE && strategy != PROVE_ALL && strategy != PROVE_APP { + return errors.New("invalid proof strategy") + } + + l.mutex.Lock() + defer l.mutex.Unlock() + l.proofStrategy = strategy + return nil +} + +func (l *Link) SetProofCallback(callback func(*packet.Packet) bool) { + l.mutex.Lock() + defer l.mutex.Unlock() + l.proofCallback = callback +} + +func (l *Link) HandleProofRequest(packet *packet.Packet) bool { + l.mutex.RLock() + defer l.mutex.RUnlock() + + switch l.proofStrategy { + case PROVE_NONE: + return false + case PROVE_ALL: + return true + case PROVE_APP: + if l.proofCallback != nil { + return l.proofCallback(packet) + } + return false + default: + return false + } +}