feat: add link establishment tests and implement key generation, handshake, and proof validation in the Link module

This commit is contained in:
2025-11-20 21:31:18 -06:00
parent b0669954a4
commit 9e7e9a71ca
2 changed files with 697 additions and 4 deletions

View File

@@ -14,6 +14,7 @@ import (
"time"
"github.com/Sudo-Ivan/reticulum-go/pkg/common"
"github.com/Sudo-Ivan/reticulum-go/pkg/cryptography"
"github.com/Sudo-Ivan/reticulum-go/pkg/destination"
"github.com/Sudo-Ivan/reticulum-go/pkg/identity"
"github.com/Sudo-Ivan/reticulum-go/pkg/packet"
@@ -26,6 +27,12 @@ import (
const (
CURVE = "Curve25519"
ECPUBSIZE = 64
KEYSIZE = 32
LINK_MTU_SIZE = 3
MTU_BYTEMASK = 0xFFFFFF
MODE_BYTEMASK = 0xE0
ESTABLISHMENT_TIMEOUT_PER_HOP = 6
KEEPALIVE_TIMEOUT_FACTOR = 4
STALE_GRACE = 2
@@ -36,15 +43,20 @@ const (
ACCEPT_ALL = 0x01
ACCEPT_APP = 0x02
STATUS_PENDING = 0x00
STATUS_ACTIVE = 0x01
STATUS_CLOSED = 0x02
STATUS_FAILED = 0x03
STATUS_PENDING = 0x00
STATUS_HANDSHAKE = 0x01
STATUS_ACTIVE = 0x02
STATUS_CLOSED = 0x03
STATUS_FAILED = 0x04
PROVE_NONE = 0x00
PROVE_ALL = 0x01
PROVE_APP = 0x02
MODE_AES128_CBC = 0x00
MODE_AES256_CBC = 0x01
MODE_DEFAULT = MODE_AES256_CBC
WATCHDOG_MIN_SLEEP = 0.025
WATCHDOG_INTERVAL = 0.1
)
@@ -94,6 +106,19 @@ type Link struct {
keepalive time.Duration
staleTime time.Duration
initiator bool
prv []byte
sigPriv ed25519.PrivateKey
pub []byte
sigPub ed25519.PublicKey
peerPub []byte
peerSigPub ed25519.PublicKey
sharedKey []byte
derivedKey []byte
mode byte
mtu int
requestTime time.Time
requestPacket *packet.Packet
}
func NewLink(dest *destination.Destination, transport *transport.Transport, networkIface common.NetworkInterface, establishedCallback func(*Link), closedCallback func(*Link)) *Link {
@@ -892,3 +917,306 @@ func (l *Link) watchdog() {
}
l.watchdogActive = false
}
func (l *Link) Validate(signature, message []byte) bool {
l.mutex.RLock()
defer l.mutex.RUnlock()
if l.remoteIdentity == nil {
return false
}
return l.remoteIdentity.Verify(message, signature)
}
func (l *Link) generateEphemeralKeys() error {
priv, pub, err := cryptography.GenerateKeyPair()
if err != nil {
return fmt.Errorf("failed to generate X25519 keypair: %w", err)
}
l.prv = priv
l.pub = pub
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return fmt.Errorf("failed to generate Ed25519 keypair: %w", err)
}
l.sigPriv = privKey
l.sigPub = pubKey
return nil
}
func signallingBytes(mtu int, mode byte) []byte {
bytes := make([]byte, LINK_MTU_SIZE)
bytes[0] = byte((mtu >> 16) & 0xFF)
bytes[1] = byte((mtu >> 8) & 0xFF)
bytes[2] = byte(mtu & 0xFF)
bytes[0] |= (mode << 5)
return bytes
}
func (l *Link) SendLinkRequest() error {
if err := l.generateEphemeralKeys(); err != nil {
return err
}
l.mode = MODE_DEFAULT
l.mtu = 500
signalling := signallingBytes(l.mtu, l.mode)
requestData := make([]byte, 0, ECPUBSIZE+LINK_MTU_SIZE)
requestData = append(requestData, l.pub...)
requestData = append(requestData, l.sigPub...)
requestData = append(requestData, signalling...)
pkt := &packet.Packet{
HeaderType: packet.HeaderType1,
PacketType: packet.PacketTypeLinkReq,
TransportType: 0,
Context: packet.ContextNone,
ContextFlag: packet.FlagUnset,
Hops: 0,
DestinationType: l.destination.GetType(),
DestinationHash: l.destination.GetHash(),
Data: requestData,
CreateReceipt: false,
}
if err := pkt.Pack(); err != nil {
return fmt.Errorf("failed to pack link request: %w", err)
}
l.linkID = linkIDFromPacket(pkt)
l.requestPacket = pkt
l.requestTime = time.Now()
l.status = STATUS_PENDING
if err := l.transport.SendPacket(pkt); err != nil {
return fmt.Errorf("failed to send link request: %w", err)
}
log.Printf("[DEBUG-3] Link request sent, link_id=%x", l.linkID)
return nil
}
func linkIDFromPacket(pkt *packet.Packet) []byte {
hashablePart := make([]byte, 0, 1+16+1+ECPUBSIZE)
hashablePart = append(hashablePart, pkt.Raw[0])
if pkt.HeaderType == packet.HeaderType2 {
startIndex := 18
endIndex := startIndex + 16 + 1 + ECPUBSIZE
if len(pkt.Raw) >= endIndex {
hashablePart = append(hashablePart, pkt.Raw[startIndex:endIndex]...)
}
} else {
startIndex := 2
endIndex := startIndex + 16 + 1 + ECPUBSIZE
if len(pkt.Raw) >= endIndex {
hashablePart = append(hashablePart, pkt.Raw[startIndex:endIndex]...)
}
}
return identity.TruncatedHash(hashablePart)
}
func (l *Link) HandleLinkRequest(pkt *packet.Packet, ownerIdentity *identity.Identity) error {
if len(pkt.Data) < ECPUBSIZE {
return errors.New("link request data too short")
}
peerPub := pkt.Data[0:KEYSIZE]
peerSigPub := pkt.Data[KEYSIZE : ECPUBSIZE]
l.peerPub = peerPub
l.peerSigPub = peerSigPub
l.linkID = linkIDFromPacket(pkt)
l.initiator = false
if len(pkt.Data) >= ECPUBSIZE+LINK_MTU_SIZE {
mtuBytes := pkt.Data[ECPUBSIZE : ECPUBSIZE+LINK_MTU_SIZE]
l.mtu = (int(mtuBytes[0]&0x1F) << 16) | (int(mtuBytes[1]) << 8) | int(mtuBytes[2])
l.mode = (mtuBytes[0] & MODE_BYTEMASK) >> 5
log.Printf("[DEBUG-4] Link request includes MTU: %d, mode: %d", l.mtu, l.mode)
} else {
l.mtu = 500
l.mode = MODE_DEFAULT
}
if err := l.generateEphemeralKeys(); err != nil {
return err
}
if err := l.performHandshake(); err != nil {
return fmt.Errorf("handshake failed: %w", err)
}
if err := l.sendLinkProof(ownerIdentity); err != nil {
return fmt.Errorf("failed to send link proof: %w", err)
}
l.status = STATUS_ACTIVE
l.establishedAt = time.Now()
log.Printf("[DEBUG-3] Link established (responder), link_id=%x", l.linkID)
if l.establishedCallback != nil {
go l.establishedCallback(l)
}
return nil
}
func (l *Link) performHandshake() error {
if len(l.peerPub) != KEYSIZE {
return errors.New("invalid peer public key length")
}
sharedSecret, err := cryptography.DeriveSharedSecret(l.prv, l.peerPub)
if err != nil {
return fmt.Errorf("ECDH failed: %w", err)
}
l.sharedKey = sharedSecret
var derivedKeyLength int
if l.mode == MODE_AES128_CBC {
derivedKeyLength = 32
} else if l.mode == MODE_AES256_CBC {
derivedKeyLength = 64
} else {
return fmt.Errorf("invalid link mode: %d", l.mode)
}
derivedKey, err := cryptography.DeriveKey(l.sharedKey, l.linkID, nil, derivedKeyLength)
if err != nil {
return fmt.Errorf("HKDF failed: %w", err)
}
l.derivedKey = derivedKey
if len(derivedKey) >= 32 {
l.sessionKey = derivedKey[0:32]
}
if len(derivedKey) >= 64 {
l.hmacKey = derivedKey[32:64]
}
l.status = STATUS_HANDSHAKE
log.Printf("[DEBUG-4] Handshake completed, derived %d bytes of key material", len(derivedKey))
return nil
}
func (l *Link) sendLinkProof(ownerIdentity *identity.Identity) error {
proofPkt, err := l.GenerateLinkProof(ownerIdentity)
if err != nil {
return err
}
if l.transport != nil {
if err := l.transport.SendPacket(proofPkt); err != nil {
return fmt.Errorf("failed to send link proof: %w", err)
}
log.Printf("[DEBUG-3] Link proof sent, link_id=%x", l.linkID)
}
return nil
}
func (l *Link) GenerateLinkProof(ownerIdentity *identity.Identity) (*packet.Packet, error) {
signalling := signallingBytes(l.mtu, l.mode)
ownerSigPub := ownerIdentity.GetPublicKey()[KEYSIZE:ECPUBSIZE]
signedData := make([]byte, 0, len(l.linkID)+KEYSIZE+len(ownerSigPub)+len(signalling))
signedData = append(signedData, l.linkID...)
signedData = append(signedData, l.pub...)
signedData = append(signedData, ownerSigPub...)
signedData = append(signedData, signalling...)
signature := ownerIdentity.Sign(signedData)
proofData := make([]byte, 0, len(signature)+KEYSIZE+len(signalling))
proofData = append(proofData, signature...)
proofData = append(proofData, l.pub...)
proofData = append(proofData, signalling...)
proofPkt := &packet.Packet{
HeaderType: packet.HeaderType1,
PacketType: packet.PacketTypeProof,
TransportType: 0,
Context: packet.ContextLRProof,
ContextFlag: packet.FlagUnset,
Hops: 0,
DestinationType: 0x03,
DestinationHash: l.linkID,
Data: proofData,
CreateReceipt: false,
Link: l,
}
if err := proofPkt.Pack(); err != nil {
return nil, fmt.Errorf("failed to pack link proof: %w", err)
}
return proofPkt, nil
}
func (l *Link) ValidateLinkProof(pkt *packet.Packet) error {
if l.status != STATUS_PENDING {
return fmt.Errorf("invalid link status for proof validation: %d", l.status)
}
if len(pkt.Data) < identity.SIGLENGTH/8+KEYSIZE {
return errors.New("link proof data too short")
}
signature := pkt.Data[0 : identity.SIGLENGTH/8]
peerPub := pkt.Data[identity.SIGLENGTH/8 : identity.SIGLENGTH/8+KEYSIZE]
signalling := []byte{0, 0, 0}
if len(pkt.Data) >= identity.SIGLENGTH/8+KEYSIZE+LINK_MTU_SIZE {
signalling = pkt.Data[identity.SIGLENGTH/8+KEYSIZE : identity.SIGLENGTH/8+KEYSIZE+LINK_MTU_SIZE]
mtu := (int(signalling[0]&0x1F) << 16) | (int(signalling[1]) << 8) | int(signalling[2])
mode := (signalling[0] & MODE_BYTEMASK) >> 5
l.mtu = mtu
l.mode = mode
log.Printf("[DEBUG-4] Link proof includes MTU: %d, mode: %d", mtu, mode)
}
l.peerPub = peerPub
if l.destination != nil && l.destination.GetIdentity() != nil {
destIdent := l.destination.GetIdentity()
pubKey := destIdent.GetPublicKey()
if len(pubKey) >= ECPUBSIZE {
l.peerSigPub = pubKey[KEYSIZE:ECPUBSIZE]
}
}
signedData := make([]byte, 0, len(l.linkID)+KEYSIZE+len(l.peerSigPub)+len(signalling))
signedData = append(signedData, l.linkID...)
signedData = append(signedData, peerPub...)
signedData = append(signedData, l.peerSigPub...)
signedData = append(signedData, signalling...)
if l.destination == nil || l.destination.GetIdentity() == nil {
return errors.New("no destination identity for proof validation")
}
if !l.destination.GetIdentity().Verify(signedData, signature) {
return errors.New("link proof signature validation failed")
}
if err := l.performHandshake(); err != nil {
return fmt.Errorf("handshake failed: %w", err)
}
l.rtt = time.Since(l.requestTime).Seconds()
l.status = STATUS_ACTIVE
l.establishedAt = time.Now()
log.Printf("[DEBUG-3] Link established (initiator), link_id=%x, RTT=%.3fs", l.linkID, l.rtt)
if l.establishedCallback != nil {
go l.establishedCallback(l)
}
return nil
}