diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go index 5e0f6ce..419cd30 100644 --- a/pkg/transport/transport.go +++ b/pkg/transport/transport.go @@ -114,6 +114,8 @@ type Transport struct { pathfinder *pathfinder.PathFinder announceHandlers []announce.Handler paths map[string]*common.Path + receipts []*packet.PacketReceipt + receiptsMutex sync.RWMutex } type Path struct { @@ -133,6 +135,8 @@ func NewTransport(cfg *common.ReticulumConfig) *Transport { links: make(map[string]*Link), destinations: make(map[string]interface{}), pathfinder: pathfinder.NewPathFinder(), + receipts: make([]*packet.PacketReceipt, 0), + receiptsMutex: sync.RWMutex{}, } return t } @@ -676,9 +680,15 @@ func (t *Transport) HandlePacket(data []byte, iface common.NetworkInterface) { case PACKET_TYPE_LINK: debug.Log(debug.DEBUG_VERBOSE, "Processing link packet") t.handleLinkPacket(data[1:], iface) - case 0x03: - debug.Log(debug.DEBUG_VERBOSE, "Processing path response") - t.handlePathResponse(data[1:], iface) + case packet.PacketTypeProof: + debug.Log(debug.DEBUG_VERBOSE, "Processing proof packet") + fullData := append([]byte{packet.PacketTypeProof}, data[1:]...) + pkt := &packet.Packet{Raw: fullData} + if err := pkt.Unpack(); err != nil { + debug.Log(debug.DEBUG_INFO, "Failed to unpack proof packet", "error", err) + return + } + t.handleProofPacket(pkt, iface) case 0x00: debug.Log(debug.DEBUG_VERBOSE, "Processing transport packet") t.handleTransportPacket(data[1:], iface) @@ -1084,6 +1094,15 @@ func (t *Transport) SendPacket(p *packet.Packet) error { return fmt.Errorf("failed to send packet: %w", err) } + p.Sent = true + p.SentAt = time.Now() + + if p.CreateReceipt { + receipt := packet.NewPacketReceipt(p) + t.RegisterReceipt(receipt) + debug.Log(debug.DEBUG_PACKETS, "Created packet receipt") + } + debug.Log(debug.DEBUG_ALL, "Packet sent successfully") return nil } @@ -1356,3 +1375,61 @@ func (t *Transport) GetInterfaces() map[string]common.NetworkInterface { func (t *Transport) GetConfig() *common.ReticulumConfig { return t.config } + +func (t *Transport) RegisterReceipt(receipt *packet.PacketReceipt) { + t.receiptsMutex.Lock() + defer t.receiptsMutex.Unlock() + t.receipts = append(t.receipts, receipt) + debug.Log(debug.DEBUG_PACKETS, "Registered packet receipt", "hash", fmt.Sprintf("%x", receipt.GetHash()[:8])) +} + +func (t *Transport) UnregisterReceipt(receipt *packet.PacketReceipt) { + t.receiptsMutex.Lock() + defer t.receiptsMutex.Unlock() + + for i, r := range t.receipts { + if r == receipt { + t.receipts = append(t.receipts[:i], t.receipts[i+1:]...) + debug.Log(debug.DEBUG_PACKETS, "Unregistered packet receipt") + return + } + } +} + +func (t *Transport) handleProofPacket(pkt *packet.Packet, iface common.NetworkInterface) { + debug.Log(debug.DEBUG_PACKETS, "Processing proof packet", "size", len(pkt.Data)) + + var proofHash []byte + if len(pkt.Data) == packet.EXPL_LENGTH { + proofHash = pkt.Data[:identity.HASHLENGTH/8] + debug.Log(debug.DEBUG_PACKETS, "Explicit proof", "hash", fmt.Sprintf("%x", proofHash[:8])) + } else { + debug.Log(debug.DEBUG_PACKETS, "Implicit proof") + } + + t.receiptsMutex.RLock() + receipts := make([]*packet.PacketReceipt, len(t.receipts)) + copy(receipts, t.receipts) + t.receiptsMutex.RUnlock() + + for _, receipt := range receipts { + receiptValidated := false + + if proofHash != nil { + receiptHash := receipt.GetHash() + if string(receiptHash) == string(proofHash) { + receiptValidated = receipt.ValidateProofPacket(pkt) + } + } else { + receiptValidated = receipt.ValidateProofPacket(pkt) + } + + if receiptValidated { + debug.Log(debug.DEBUG_PACKETS, "Proof validated for receipt") + t.UnregisterReceipt(receipt) + return + } + } + + debug.Log(debug.DEBUG_PACKETS, "No matching receipt for proof") +}