diff --git a/pkg/link/link.go b/pkg/link/link.go index 746d63c..3cc9497 100644 --- a/pkg/link/link.go +++ b/pkg/link/link.go @@ -5,6 +5,7 @@ import ( "crypto/cipher" "crypto/ed25519" "crypto/rand" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -22,6 +23,7 @@ import ( "github.com/Sudo-Ivan/reticulum-go/pkg/resolver" "github.com/Sudo-Ivan/reticulum-go/pkg/resource" "github.com/Sudo-Ivan/reticulum-go/pkg/transport" + "github.com/vmihailenco/msgpack/v5" ) const ( @@ -38,6 +40,7 @@ const ( STALE_GRACE = 2 KEEPALIVE = 360 STALE_TIME = 720 + TRAFFIC_TIMEOUT_FACTOR = 6 ACCEPT_NONE = 0x00 ACCEPT_ALL = 0x01 @@ -46,8 +49,9 @@ const ( STATUS_PENDING = 0x00 STATUS_HANDSHAKE = 0x01 STATUS_ACTIVE = 0x02 - STATUS_CLOSED = 0x03 - STATUS_FAILED = 0x04 + STATUS_STALE = 0x03 + STATUS_CLOSED = 0x04 + STATUS_FAILED = 0x05 PROVE_NONE = 0x00 PROVE_ALL = 0x01 @@ -117,8 +121,12 @@ type Link struct { derivedKey []byte mode byte mtu int + mdu int requestTime time.Time requestPacket *packet.Packet + + pendingRequests []*RequestReceipt + requestMutex sync.RWMutex } func NewLink(dest *destination.Destination, transport *transport.Transport, networkIface common.NetworkInterface, establishedCallback func(*Link), closedCallback func(*Link)) *Link { @@ -142,6 +150,7 @@ func NewLink(dest *destination.Destination, transport *transport.Transport, netw keepalive: time.Duration(KEEPALIVE * float64(time.Second)), staleTime: time.Duration(STALE_TIME * float64(time.Second)), initiator: false, + pendingRequests: make([]*RequestReceipt, 0), } } @@ -177,42 +186,25 @@ func (l *Link) Establish() error { return errors.New("link already established or failed") } - destPublicKey := l.destination.GetPublicKey() - if destPublicKey == nil { - debug.Log(debug.DEBUG_INFO, "Cannot establish link: destination has no public key") - return errors.New("destination has no public key") + if l.destination == nil { + return errors.New("destination is nil") } - // Generate link ID for this connection - l.linkID = make([]byte, 16) - if _, err := rand.Read(l.linkID); err != nil { - debug.Log(debug.DEBUG_INFO, "Failed to generate link ID", "error", err) - return fmt.Errorf("failed to generate link ID: %w", err) - } l.initiator = true + l.status = STATUS_PENDING + l.requestTime = time.Now() - debug.Log(debug.DEBUG_VERBOSE, "Creating link request packet for destination", "dest_public_key", fmt.Sprintf("%x", destPublicKey[:8]), "link_id", fmt.Sprintf("%x", l.linkID[:8])) - - p := &packet.Packet{ - HeaderType: packet.HeaderType1, - PacketType: packet.PacketTypeLinkReq, - TransportType: 0, - Context: packet.ContextLinkIdentify, - ContextFlag: packet.FlagUnset, - Hops: 0, - DestinationType: l.destination.GetType(), - DestinationHash: l.destination.GetHash(), - Data: l.linkID, - CreateReceipt: true, - } - - if err := p.Pack(); err != nil { - debug.Log(debug.DEBUG_INFO, "Failed to pack link request packet", "error", err) + if err := l.SendLinkRequest(); err != nil { return err } - debug.Log(debug.DEBUG_VERBOSE, "Sending link request packet", "link_id", fmt.Sprintf("%x", l.linkID[:8])) - return l.transport.SendPacket(p) + if l.transport != nil { + l.transport.RegisterLink(l.linkID, l) + } + + go l.startWatchdog() + + return nil } func (l *Link) Identify(id *identity.Identity) error { @@ -285,53 +277,83 @@ func (l *Link) Request(path string, data []byte, timeout time.Duration) (*Reques return nil, errors.New("link not active") } - requestID := make([]byte, 16) - if _, err := rand.Read(requestID); err != nil { - return nil, err - } - - // Create request message - reqMsg := make([]byte, 0) - reqMsg = append(reqMsg, requestID...) - reqMsg = append(reqMsg, []byte(path)...) - if data != nil { - reqMsg = append(reqMsg, data...) - } - - receipt := &RequestReceipt{ - requestID: requestID, - status: STATUS_PENDING, - sentAt: time.Now(), - } - - // Send request - err := l.SendPacket(reqMsg) + pathHash := identity.TruncatedHash([]byte(path)) + requestData := []interface{}{time.Now().Unix(), pathHash, data} + packedRequest, err := msgpack.Marshal(requestData) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to pack request: %w", err) } - // Set timeout - if timeout > 0 { - go func() { - time.Sleep(timeout) - l.mutex.Lock() - if receipt.status == STATUS_PENDING { - receipt.status = STATUS_FAILED - } - l.mutex.Unlock() - }() + if timeout <= 0 { + timeout = time.Duration(l.rtt*TRAFFIC_TIMEOUT_FACTOR*float64(time.Second)) + time.Duration(resource.RESPONSE_MAX_GRACE_TIME*1.125*float64(time.Second)) } - return receipt, nil + requestID := identity.TruncatedHash(packedRequest) + + if len(packedRequest) <= l.mdu { + reqPkt := &packet.Packet{ + HeaderType: packet.HeaderType1, + PacketType: packet.PacketTypeData, + TransportType: 0, + Context: packet.ContextRequest, + ContextFlag: packet.FlagUnset, + Hops: 0, + DestinationType: 0x03, + DestinationHash: l.linkID, + Data: packedRequest, + CreateReceipt: false, + } + + if err := reqPkt.Pack(); err != nil { + return nil, err + } + + encrypted, err := l.encrypt(packedRequest) + if err != nil { + return nil, err + } + + reqPkt.Data = encrypted + if err := reqPkt.Pack(); err != nil { + return nil, err + } + + if err := l.transport.SendPacket(reqPkt); err != nil { + return nil, err + } + + receipt := &RequestReceipt{ + link: l, + requestID: requestID, + status: STATUS_PENDING, + sentAt: time.Now(), + timeout: timeout, + } + + l.requestMutex.Lock() + l.pendingRequests = append(l.pendingRequests, receipt) + l.requestMutex.Unlock() + + go receipt.startTimeout() + + return receipt, nil + } + + return nil, errors.New("request too large, resource transfer not yet implemented") } type RequestReceipt struct { - mutex sync.RWMutex - requestID []byte - status byte - sentAt time.Time - receivedAt time.Time - response []byte + link *Link + mutex sync.RWMutex + requestID []byte + status byte + sentAt time.Time + receivedAt time.Time + response []byte + timeout time.Duration + responseCb func(*RequestReceipt) + failedCb func(*RequestReceipt) + progressCb func(*RequestReceipt) } func (r *RequestReceipt) GetRequestID() []byte { @@ -369,6 +391,36 @@ func (r *RequestReceipt) Concluded() bool { return status == STATUS_ACTIVE || status == STATUS_FAILED } +func (r *RequestReceipt) startTimeout() { + time.Sleep(r.timeout) + r.mutex.Lock() + if r.status == STATUS_PENDING { + r.status = STATUS_FAILED + if r.failedCb != nil { + go r.failedCb(r) + } + } + r.mutex.Unlock() +} + +func (r *RequestReceipt) SetResponseCallback(cb func(*RequestReceipt)) { + r.mutex.Lock() + defer r.mutex.Unlock() + r.responseCb = cb +} + +func (r *RequestReceipt) SetFailedCallback(cb func(*RequestReceipt)) { + r.mutex.Lock() + defer r.mutex.Unlock() + r.failedCb = cb +} + +func (r *RequestReceipt) SetProgressCallback(cb func(*RequestReceipt)) { + r.mutex.Lock() + defer r.mutex.Unlock() + r.progressCb = cb +} + func (l *Link) TrackPhyStats(track bool) { l.mutex.Lock() defer l.mutex.Unlock() @@ -576,38 +628,398 @@ func (l *Link) SendPacket(data []byte) error { return l.transport.SendPacket(p) } -func (l *Link) HandleInbound(data []byte) error { +func (l *Link) HandleInbound(pkt *packet.Packet) error { l.mutex.Lock() defer l.mutex.Unlock() - if l.status != STATUS_ACTIVE { - debug.Log(debug.DEBUG_INFO, "Dropping inbound packet: link not active", "status", l.status) + l.watchdogLock = true + defer func() { + l.watchdogLock = false + }() + + if l.status == STATUS_CLOSED { + return errors.New("link is closed") + } + + l.lastInbound = time.Now() + if pkt.Context != packet.ContextKeepalive { + l.lastDataReceived = time.Now() + } + + if l.status == STATUS_STALE { + l.status = STATUS_ACTIVE + } + + if pkt.PacketType == packet.PacketTypeData { + return l.handleDataPacket(pkt) + } else if pkt.PacketType == packet.PacketTypeProof { + if pkt.Context == packet.ContextLRProof { + return l.handleLinkProof(pkt) + } else if pkt.Context == packet.ContextLRRTT { + return l.handleRTTPacket(pkt) + } + } + + return nil +} + +func (l *Link) handleDataPacket(pkt *packet.Packet) error { + if l.status != STATUS_ACTIVE && l.status != STATUS_HANDSHAKE { return errors.New("link not active") } - // Decode and log packet details - l.decodePacket(data) + var plaintext []byte + var err error - // Decrypt if we have a session key if l.sessionKey != nil { - decrypted, err := l.decrypt(data) + plaintext, err = l.decrypt(pkt.Data) if err != nil { debug.Log(debug.DEBUG_INFO, "Failed to decrypt packet", "error", err) return err } - data = decrypted + } else { + plaintext = pkt.Data } - l.lastInbound = time.Now() - l.lastDataReceived = time.Now() - - if l.packetCallback != nil { - l.packetCallback(data, nil) + switch pkt.Context { + case packet.ContextNone: + if l.packetCallback != nil { + l.packetCallback(plaintext, pkt) + } + case packet.ContextRequest: + return l.handleRequest(plaintext, pkt) + case packet.ContextResponse: + return l.handleResponse(plaintext) + case packet.ContextLinkIdentify: + return l.handleIdentification(plaintext) + case packet.ContextKeepalive: + if !l.initiator && len(plaintext) == 1 && plaintext[0] == 0xFF { + keepaliveResp := []byte{0xFE} + keepalivePkt := &packet.Packet{ + HeaderType: packet.HeaderType1, + PacketType: packet.PacketTypeData, + TransportType: 0, + Context: packet.ContextKeepalive, + ContextFlag: packet.FlagUnset, + Hops: 0, + DestinationType: 0x03, + DestinationHash: l.linkID, + Data: keepaliveResp, + CreateReceipt: false, + } + if err := keepalivePkt.Pack(); err != nil { + return err + } + encrypted, err := l.encrypt(keepaliveResp) + if err != nil { + return err + } + keepalivePkt.Data = encrypted + if err := keepalivePkt.Pack(); err != nil { + return err + } + l.lastOutbound = time.Now() + return l.transport.SendPacket(keepalivePkt) + } + case packet.ContextLinkClose: + return l.handleTeardown(plaintext) + case packet.ContextLRRTT: + return l.handleRTTPacket(pkt) + case packet.ContextResourceAdv: + return l.handleResourceAdvertisement(pkt) + case packet.ContextResourceReq: + return l.handleResourceRequest(pkt) + case packet.ContextResourceHMU: + return l.handleResourceHashmapUpdate(pkt) + case packet.ContextResourceICL: + return l.handleResourceCancel(pkt) + case packet.ContextResourceRCL: + return l.handleResourceReject(pkt) + case packet.ContextResource: + return l.handleResourcePart(pkt) } return nil } +func (l *Link) handleResourceAdvertisement(pkt *packet.Packet) error { + plaintext, err := l.decrypt(pkt.Data) + if err != nil { + return err + } + + if l.resourceStrategy == ACCEPT_NONE { + return nil + } + + allowed := false + if l.resourceStrategy == ACCEPT_ALL { + allowed = true + } else if l.resourceStrategy == ACCEPT_APP && l.resourceCallback != nil { + allowed = l.resourceCallback(plaintext) + } + + if allowed { + if l.resourceStartedCallback != nil { + l.resourceStartedCallback(plaintext) + } + } else { + debug.Log(debug.DEBUG_INFO, "Resource advertisement rejected") + } + + return nil +} + +func (l *Link) handleResourceRequest(pkt *packet.Packet) error { + plaintext, err := l.decrypt(pkt.Data) + if err != nil { + return err + } + + if l.resourceStartedCallback != nil { + l.resourceStartedCallback(plaintext) + } + + return nil +} + +func (l *Link) handleResourceHashmapUpdate(pkt *packet.Packet) error { + plaintext, err := l.decrypt(pkt.Data) + if err != nil { + return err + } + + if l.resourceStartedCallback != nil { + l.resourceStartedCallback(plaintext) + } + + return nil +} + +func (l *Link) handleResourceCancel(pkt *packet.Packet) error { + plaintext, err := l.decrypt(pkt.Data) + if err != nil { + return err + } + + if l.resourceConcludedCallback != nil { + l.resourceConcludedCallback(plaintext) + } + + return nil +} + +func (l *Link) handleResourceReject(pkt *packet.Packet) error { + plaintext, err := l.decrypt(pkt.Data) + if err != nil { + return err + } + + if l.resourceConcludedCallback != nil { + l.resourceConcludedCallback(plaintext) + } + + return nil +} + +func (l *Link) handleResourcePart(pkt *packet.Packet) error { + if l.resourceStartedCallback != nil { + l.resourceStartedCallback(pkt.Data) + } + + return nil +} + +func (l *Link) handleRequest(plaintext []byte, pkt *packet.Packet) error { + if l.destination == nil { + return errors.New("no destination for request handling") + } + + var requestData []interface{} + if err := msgpack.Unmarshal(plaintext, &requestData); err != nil { + return fmt.Errorf("failed to unpack request: %w", err) + } + + if len(requestData) < 3 { + return errors.New("invalid request format") + } + + requestedAt := time.Unix(int64(requestData[0].(int64)), 0) + pathHash := requestData[1].([]byte) + requestPayload := requestData[2].([]byte) + + requestID := identity.TruncatedHash(plaintext) + + if l.destination != nil { + handler := l.destination.GetRequestHandler(pathHash) + if handler != nil { + response := handler(pathHash, requestPayload, requestID, l.remoteIdentity, requestedAt) + if response != nil { + return l.sendResponse(requestID, response) + } + } + } + + return nil +} + +func (l *Link) handleResponse(plaintext []byte) error { + var responseData []interface{} + if err := msgpack.Unmarshal(plaintext, &responseData); err != nil { + return fmt.Errorf("failed to unpack response: %w", err) + } + + if len(responseData) < 2 { + return errors.New("invalid response format") + } + + requestID := responseData[0].([]byte) + responsePayload := responseData[1].([]byte) + + l.requestMutex.Lock() + for i, req := range l.pendingRequests { + if string(req.requestID) == string(requestID) { + req.mutex.Lock() + req.status = STATUS_ACTIVE + req.response = responsePayload + req.receivedAt = time.Now() + req.mutex.Unlock() + + if req.responseCb != nil { + go req.responseCb(req) + } + + l.pendingRequests = append(l.pendingRequests[:i], l.pendingRequests[i+1:]...) + break + } + } + l.requestMutex.Unlock() + + return nil +} + +func (l *Link) sendResponse(requestID []byte, response interface{}) error { + responseData := []interface{}{requestID, response} + packedResponse, err := msgpack.Marshal(responseData) + if err != nil { + return fmt.Errorf("failed to pack response: %w", err) + } + + if len(packedResponse) <= l.mdu { + encrypted, err := l.encrypt(packedResponse) + if err != nil { + return err + } + + respPkt := &packet.Packet{ + HeaderType: packet.HeaderType1, + PacketType: packet.PacketTypeData, + TransportType: 0, + Context: packet.ContextResponse, + ContextFlag: packet.FlagUnset, + Hops: 0, + DestinationType: 0x03, + DestinationHash: l.linkID, + Data: encrypted, + CreateReceipt: false, + } + + if err := respPkt.Pack(); err != nil { + return err + } + + l.lastOutbound = time.Now() + l.lastDataSent = time.Now() + return l.transport.SendPacket(respPkt) + } + + return errors.New("response too large, resource transfer not yet implemented") +} + +func (l *Link) handleRTTPacket(pkt *packet.Packet) error { + if !l.initiator { + measuredRTT := time.Since(l.requestTime).Seconds() + plaintext, err := l.decrypt(pkt.Data) + if err != nil { + return err + } + + var rtt float64 + if len(plaintext) >= 8 { + rtt = float64(binary.BigEndian.Uint64(plaintext[:8])) / 1e9 + } + + l.rtt = max(measuredRTT, rtt) + l.status = STATUS_ACTIVE + l.establishedAt = time.Now() + + if l.transport != nil { + l.transport.RegisterLink(l.linkID, l) + } + + if l.rtt > 0 { + l.updateKeepalive() + } + + if l.establishedCallback != nil { + go l.establishedCallback(l) + } + + debug.Log(debug.DEBUG_INFO, "Link established (responder) after RTT", "link_id", fmt.Sprintf("%x", l.linkID), "rtt", fmt.Sprintf("%.3fs", l.rtt)) + } + return nil +} + +func (l *Link) updateKeepalive() { + if l.rtt <= 0 { + return + } + + keepaliveMaxRTT := 1.75 + keepaliveMax := float64(KEEPALIVE) + keepaliveMin := 5.0 + + calculatedKeepalive := l.rtt * (keepaliveMax / keepaliveMaxRTT) + if calculatedKeepalive > keepaliveMax { + calculatedKeepalive = keepaliveMax + } + if calculatedKeepalive < keepaliveMin { + calculatedKeepalive = keepaliveMin + } + + l.keepalive = time.Duration(calculatedKeepalive * float64(time.Second)) + l.staleTime = time.Duration(float64(l.keepalive) * 2.0) +} + +func (l *Link) handleLinkProof(pkt *packet.Packet) error { + if l.initiator { + return l.ValidateLinkProof(pkt) + } + return nil +} + +func (l *Link) handleTeardown(plaintext []byte) error { + if len(plaintext) == len(l.linkID) && string(plaintext) == string(l.linkID) { + l.status = STATUS_CLOSED + if l.initiator { + l.teardownReason = STATUS_FAILED + } else { + l.teardownReason = STATUS_FAILED + } + if l.closedCallback != nil { + l.closedCallback(l) + } + } + return nil +} + +func max(a, b float64) float64 { + if a > b { + return a + } + return b +} + func (l *Link) encrypt(data []byte) ([]byte, error) { if l.sessionKey == nil { return nil, errors.New("no session key available") @@ -893,34 +1305,92 @@ func (l *Link) watchdog() { for l.status != STATUS_CLOSED { l.mutex.Lock() if l.watchdogLock { + rttWait := 0.025 + if l.rtt > 0 { + rttWait = l.rtt + } + if rttWait < 0.025 { + rttWait = 0.025 + } l.mutex.Unlock() - time.Sleep(time.Duration(WATCHDOG_MIN_SLEEP * float64(time.Second))) + time.Sleep(time.Duration(rttWait * float64(time.Second))) continue } var sleepTime = WATCHDOG_INTERVAL - if l.status == STATUS_ACTIVE { - lastActivity := l.lastInbound - if l.lastOutbound.After(lastActivity) { - lastActivity = l.lastOutbound + if l.status == STATUS_PENDING { + nextCheck := l.requestTime.Add(l.establishmentTimeout) + sleepTime = time.Until(nextCheck).Seconds() + if time.Now().After(nextCheck) { + debug.Log(debug.DEBUG_INFO, "Link establishment timed out") + l.status = STATUS_CLOSED + l.teardownReason = STATUS_FAILED + if l.closedCallback != nil { + l.closedCallback(l) + } + sleepTime = 0.001 } - - if time.Since(lastActivity) > l.keepalive { + } else if l.status == STATUS_HANDSHAKE { + nextCheck := l.requestTime.Add(l.establishmentTimeout) + sleepTime = time.Until(nextCheck).Seconds() + if time.Now().After(nextCheck) { if l.initiator { - if err := l.SendPacket([]byte{}); err != nil { // #nosec G104 - debug.Log(debug.DEBUG_INFO, "Failed to send keepalive packet", "error", err) + debug.Log(debug.DEBUG_INFO, "Timeout waiting for link request proof") + } else { + debug.Log(debug.DEBUG_INFO, "Timeout waiting for RTT packet from link initiator") + } + l.status = STATUS_CLOSED + l.teardownReason = STATUS_FAILED + if l.closedCallback != nil { + l.closedCallback(l) + } + sleepTime = 0.001 + } + } else if l.status == STATUS_ACTIVE { + activatedAt := l.establishedAt + if activatedAt.IsZero() { + activatedAt = time.Time{} + } + lastInbound := l.lastInbound + if lastInbound.Before(activatedAt) { + lastInbound = activatedAt + } + now := time.Now() + + if now.After(lastInbound.Add(l.keepalive)) { + if l.initiator { + lastKeepalive := l.lastOutbound + if now.After(lastKeepalive.Add(l.keepalive)) { + l.sendKeepalive() } } - if time.Since(lastActivity) > l.staleTime { - l.status = STATUS_CLOSED - l.teardownReason = STATUS_FAILED - if l.closedCallback != nil { - l.closedCallback(l) - } + if now.After(lastInbound.Add(l.staleTime)) { + sleepTime = l.rtt*KEEPALIVE_TIMEOUT_FACTOR + STALE_GRACE + l.status = STATUS_STALE + } else { + sleepTime = float64(l.keepalive) / float64(time.Second) } + } else { + nextKeepalive := lastInbound.Add(l.keepalive) + sleepTime = time.Until(nextKeepalive).Seconds() } + } else if l.status == STATUS_STALE { + sleepTime = 0.001 + l.sendTeardownPacket() + l.status = STATUS_CLOSED + l.teardownReason = STATUS_FAILED + if l.closedCallback != nil { + l.closedCallback(l) + } + } + + if sleepTime <= 0 { + sleepTime = 0.1 + } + if sleepTime > 5.0 { + sleepTime = 5.0 } l.mutex.Unlock() @@ -929,6 +1399,57 @@ func (l *Link) watchdog() { l.watchdogActive = false } +func (l *Link) sendKeepalive() error { + keepaliveData := []byte{0xFF} + keepalivePkt := &packet.Packet{ + HeaderType: packet.HeaderType1, + PacketType: packet.PacketTypeData, + TransportType: 0, + Context: packet.ContextKeepalive, + ContextFlag: packet.FlagUnset, + Hops: 0, + DestinationType: 0x03, + DestinationHash: l.linkID, + Data: keepaliveData, + CreateReceipt: false, + } + encrypted, err := l.encrypt(keepaliveData) + if err != nil { + return err + } + keepalivePkt.Data = encrypted + if err := keepalivePkt.Pack(); err != nil { + return err + } + l.lastOutbound = time.Now() + return l.transport.SendPacket(keepalivePkt) +} + +func (l *Link) sendTeardownPacket() error { + teardownPkt := &packet.Packet{ + HeaderType: packet.HeaderType1, + PacketType: packet.PacketTypeData, + TransportType: 0, + Context: packet.ContextLinkClose, + ContextFlag: packet.FlagUnset, + Hops: 0, + DestinationType: 0x03, + DestinationHash: l.linkID, + Data: l.linkID, + CreateReceipt: false, + } + encrypted, err := l.encrypt(l.linkID) + if err != nil { + return err + } + teardownPkt.Data = encrypted + if err := teardownPkt.Pack(); err != nil { + return err + } + l.lastOutbound = time.Now() + return l.transport.SendPacket(teardownPkt) +} + func (l *Link) Validate(signature, message []byte) bool { l.mutex.RLock() defer l.mutex.RUnlock() @@ -974,6 +1495,7 @@ func (l *Link) SendLinkRequest() error { l.mode = MODE_DEFAULT l.mtu = 500 + l.updateMDU() signalling := signallingBytes(l.mtu, l.mode) requestData := make([]byte, 0, ECPUBSIZE+LINK_MTU_SIZE) @@ -1062,21 +1584,37 @@ func (l *Link) HandleLinkRequest(pkt *packet.Packet, ownerIdentity *identity.Ide return fmt.Errorf("handshake failed: %w", err) } + l.updateMDU() + if err := l.sendLinkProof(ownerIdentity); err != nil { return fmt.Errorf("failed to send link proof: %w", err) } - l.status = STATUS_ACTIVE - l.establishedAt = time.Now() - debug.Log(debug.DEBUG_INFO, "Link established (responder)", "link_id", fmt.Sprintf("%x", l.linkID)) - - if l.establishedCallback != nil { - go l.establishedCallback(l) + l.status = STATUS_HANDSHAKE + l.lastInbound = time.Now() + l.requestTime = time.Now() + + if l.transport != nil { + l.transport.RegisterLink(l.linkID, l) } + + debug.Log(debug.DEBUG_INFO, "Link proof sent (responder), waiting for RTT", "link_id", fmt.Sprintf("%x", l.linkID)) return nil } +func (l *Link) updateMDU() { + headerMaxSize := 64 + ifacMinSize := 4 + tokenOverhead := 16 + aesBlockSize := 16 + + l.mdu = int(float64(l.mtu-headerMaxSize-ifacMinSize-tokenOverhead)/float64(aesBlockSize)) * aesBlockSize - 1 + if l.mdu < 0 { + l.mdu = 100 + } +} + func (l *Link) performHandshake() error { if len(l.peerPub) != KEYSIZE { return errors.New("invalid peer public key length") @@ -1191,7 +1729,7 @@ func (l *Link) GenerateLinkProof(ownerIdentity *identity.Identity) (*packet.Pack } func (l *Link) ValidateLinkProof(pkt *packet.Packet) error { - if l.status != STATUS_PENDING { + if l.status != STATUS_PENDING && l.status != STATUS_HANDSHAKE { return fmt.Errorf("invalid link status for proof validation: %d", l.status) } @@ -1207,9 +1745,9 @@ func (l *Link) ValidateLinkProof(pkt *packet.Packet) error { signalling = pkt.Data[identity.SIGLENGTH/8+KEYSIZE : identity.SIGLENGTH/8+KEYSIZE+LINK_MTU_SIZE] mtu := (int(signalling[0]&0x1F) << 16) | (int(signalling[1]) << 8) | int(signalling[2]) mode := (signalling[0] & MODE_BYTEMASK) >> 5 - l.mtu = mtu - l.mode = mode - debug.Log(debug.DEBUG_VERBOSE, "Link proof includes MTU", "mtu", mtu, "mode", mode) + l.mtu = mtu + l.mode = mode + debug.Log(debug.DEBUG_VERBOSE, "Link proof includes MTU", "mtu", mtu, "mode", mode) } l.peerPub = peerPub @@ -1239,10 +1777,43 @@ func (l *Link) ValidateLinkProof(pkt *packet.Packet) error { return fmt.Errorf("handshake failed: %w", err) } + l.updateMDU() + l.rtt = time.Since(l.requestTime).Seconds() l.status = STATUS_ACTIVE l.establishedAt = time.Now() + if l.rtt > 0 { + l.updateKeepalive() + } + + rttData := make([]byte, 8) + binary.BigEndian.PutUint64(rttData, uint64(l.rtt*1e9)) + rttPkt := &packet.Packet{ + HeaderType: packet.HeaderType1, + PacketType: packet.PacketTypeData, + TransportType: 0, + Context: packet.ContextLRRTT, + ContextFlag: packet.FlagUnset, + Hops: 0, + DestinationType: 0x03, + DestinationHash: l.linkID, + Data: rttData, + CreateReceipt: false, + } + encrypted, err := l.encrypt(rttData) + if err == nil { + rttPkt.Data = encrypted + if err := rttPkt.Pack(); err == nil { + l.transport.SendPacket(rttPkt) + l.lastOutbound = time.Now() + } + } + + if l.transport != nil { + l.transport.RegisterLink(l.linkID, l) + } + debug.Log(debug.DEBUG_INFO, "Link established (initiator)", "link_id", fmt.Sprintf("%x", l.linkID), "rtt", fmt.Sprintf("%.3fs", l.rtt)) if l.establishedCallback != nil {