Merge main into tinygo and fix conflicts
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package announce
|
||||
|
||||
import (
|
||||
@@ -6,12 +8,12 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/common"
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/identity"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/identity"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
|
||||
@@ -49,12 +51,6 @@ const (
|
||||
MAX_RETRIES = 3
|
||||
)
|
||||
|
||||
type AnnounceHandler interface {
|
||||
AspectFilter() []string
|
||||
ReceivedAnnounce(destinationHash []byte, announcedIdentity interface{}, appData []byte) error
|
||||
ReceivePathResponses() bool
|
||||
}
|
||||
|
||||
type Announce struct {
|
||||
mutex *sync.RWMutex
|
||||
destinationHash []byte
|
||||
@@ -67,7 +63,7 @@ type Announce struct {
|
||||
signature []byte
|
||||
pathResponse bool
|
||||
retries int
|
||||
handlers []AnnounceHandler
|
||||
handlers []Handler
|
||||
ratchetID []byte
|
||||
packet []byte
|
||||
hash []byte
|
||||
@@ -87,17 +83,17 @@ func New(dest *identity.Identity, destinationHash []byte, destinationName string
|
||||
}
|
||||
|
||||
a := &Announce{
|
||||
mutex: &sync.RWMutex{},
|
||||
identity: dest,
|
||||
destinationHash: destinationHash,
|
||||
destinationName: destinationName,
|
||||
appData: appData,
|
||||
config: config,
|
||||
hops: 0,
|
||||
timestamp: time.Now().Unix(),
|
||||
pathResponse: pathResponse,
|
||||
retries: 0,
|
||||
handlers: make([]AnnounceHandler, 0),
|
||||
mutex: &sync.RWMutex{},
|
||||
identity: dest,
|
||||
destinationHash: destinationHash,
|
||||
destinationName: destinationName,
|
||||
appData: appData,
|
||||
config: config,
|
||||
hops: 0,
|
||||
timestamp: time.Now().Unix(),
|
||||
pathResponse: pathResponse,
|
||||
retries: 0,
|
||||
handlers: make([]Handler, 0),
|
||||
}
|
||||
|
||||
// Get current ratchet ID if enabled
|
||||
@@ -123,46 +119,46 @@ func (a *Announce) Propagate(interfaces []common.NetworkInterface) error {
|
||||
a.mutex.RLock()
|
||||
defer a.mutex.RUnlock()
|
||||
|
||||
log.Printf("[DEBUG-7] Propagating announce across %d interfaces", len(interfaces))
|
||||
debug.Log(debug.DEBUG_TRACE, "Propagating announce across interfaces", "count", len(interfaces))
|
||||
|
||||
var packet []byte
|
||||
if a.packet != nil {
|
||||
log.Printf("[DEBUG-7] Using cached packet (%d bytes)", len(a.packet))
|
||||
debug.Log(debug.DEBUG_TRACE, "Using cached packet", "bytes", len(a.packet))
|
||||
packet = a.packet
|
||||
} else {
|
||||
log.Printf("[DEBUG-7] Creating new packet")
|
||||
debug.Log(debug.DEBUG_TRACE, "Creating new packet")
|
||||
packet = a.CreatePacket()
|
||||
a.packet = packet
|
||||
}
|
||||
|
||||
for _, iface := range interfaces {
|
||||
if !iface.IsEnabled() {
|
||||
log.Printf("[DEBUG-7] Skipping disabled interface: %s", iface.GetName())
|
||||
debug.Log(debug.DEBUG_TRACE, "Skipping disabled interface", "name", iface.GetName())
|
||||
continue
|
||||
}
|
||||
if !iface.GetBandwidthAvailable() {
|
||||
log.Printf("[DEBUG-7] Skipping interface with insufficient bandwidth: %s", iface.GetName())
|
||||
debug.Log(debug.DEBUG_TRACE, "Skipping interface with insufficient bandwidth", "name", iface.GetName())
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG-7] Sending announce on interface %s", iface.GetName())
|
||||
debug.Log(debug.DEBUG_TRACE, "Sending announce on interface", "name", iface.GetName())
|
||||
if err := iface.Send(packet, ""); err != nil {
|
||||
log.Printf("[DEBUG-7] Failed to send on interface %s: %v", iface.GetName(), err)
|
||||
debug.Log(debug.DEBUG_TRACE, "Failed to send on interface", "name", iface.GetName(), "error", err)
|
||||
return fmt.Errorf("failed to propagate on interface %s: %w", iface.GetName(), err)
|
||||
}
|
||||
log.Printf("[DEBUG-7] Successfully sent announce on interface %s", iface.GetName())
|
||||
debug.Log(debug.DEBUG_TRACE, "Successfully sent announce on interface", "name", iface.GetName())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Announce) RegisterHandler(handler AnnounceHandler) {
|
||||
func (a *Announce) RegisterHandler(handler Handler) {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
a.handlers = append(a.handlers, handler)
|
||||
}
|
||||
|
||||
func (a *Announce) DeregisterHandler(handler AnnounceHandler) {
|
||||
func (a *Announce) DeregisterHandler(handler Handler) {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
for i, h := range a.handlers {
|
||||
@@ -177,13 +173,13 @@ func (a *Announce) HandleAnnounce(data []byte) error {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
log.Printf("[DEBUG-7] Handling announce packet of %d bytes", len(data))
|
||||
debug.Log(debug.DEBUG_TRACE, "Handling announce packet", "bytes", len(data))
|
||||
|
||||
// Minimum packet size validation
|
||||
// header(2) + desthash(16) + context(1) + enckey(32) + signkey(32) + namehash(10) +
|
||||
// randomhash(10) + signature(64) + min app data(3)
|
||||
if len(data) < 170 {
|
||||
log.Printf("[DEBUG-7] Invalid announce data length: %d bytes (minimum 170)", len(data))
|
||||
debug.Log(debug.DEBUG_TRACE, "Invalid announce data length", "bytes", len(data), "minimum", 170)
|
||||
return errors.New("invalid announce data length")
|
||||
}
|
||||
|
||||
@@ -196,7 +192,7 @@ func (a *Announce) HandleAnnounce(data []byte) error {
|
||||
// Get hop count
|
||||
hopCount := header[1]
|
||||
if hopCount > MAX_HOPS {
|
||||
log.Printf("[DEBUG-7] Announce exceeded max hops: %d", hopCount)
|
||||
debug.Log(debug.DEBUG_TRACE, "Announce exceeded max hops", "hops", hopCount)
|
||||
return errors.New("announce exceeded maximum hop count")
|
||||
}
|
||||
|
||||
@@ -215,8 +211,7 @@ func (a *Announce) HandleAnnounce(data []byte) error {
|
||||
contextByte = data[34]
|
||||
packetData = data[35:]
|
||||
|
||||
log.Printf("[DEBUG-7] Header type 2 announce: destHash=%x, transportID=%x, context=%d",
|
||||
destHash, transportID, contextByte)
|
||||
debug.Log(debug.DEBUG_TRACE, "Header type 2 announce", "destHash", fmt.Sprintf("%x", destHash), "transportID", fmt.Sprintf("%x", transportID), "context", contextByte)
|
||||
} else {
|
||||
// Header type 1 format: header(2) + desthash(16) + context(1) + data
|
||||
if len(data) < 19 {
|
||||
@@ -226,8 +221,7 @@ func (a *Announce) HandleAnnounce(data []byte) error {
|
||||
contextByte = data[18]
|
||||
packetData = data[19:]
|
||||
|
||||
log.Printf("[DEBUG-7] Header type 1 announce: destHash=%x, context=%d",
|
||||
destHash, contextByte)
|
||||
debug.Log(debug.DEBUG_TRACE, "Header type 1 announce", "destHash", fmt.Sprintf("%x", destHash), "context", contextByte)
|
||||
}
|
||||
|
||||
// Now parse the data portion according to the spec
|
||||
@@ -246,10 +240,10 @@ func (a *Announce) HandleAnnounce(data []byte) error {
|
||||
signature := packetData[116:180]
|
||||
appData := packetData[180:]
|
||||
|
||||
log.Printf("[DEBUG-7] Announce fields: encKey=%x, signKey=%x", encKey, signKey)
|
||||
log.Printf("[DEBUG-7] Name hash=%x, random hash=%x", nameHash, randomHash)
|
||||
log.Printf("[DEBUG-7] Ratchet=%x", ratchetData[:8])
|
||||
log.Printf("[DEBUG-7] Signature=%x, appDataLen=%d", signature[:8], len(appData))
|
||||
debug.Log(debug.DEBUG_TRACE, "Announce fields", "encKey", fmt.Sprintf("%x", encKey), "signKey", fmt.Sprintf("%x", signKey))
|
||||
debug.Log(debug.DEBUG_TRACE, "Name and random hash", "nameHash", fmt.Sprintf("%x", nameHash), "randomHash", fmt.Sprintf("%x", randomHash))
|
||||
debug.Log(debug.DEBUG_TRACE, "Ratchet data", "ratchet", fmt.Sprintf("%x", ratchetData[:8]))
|
||||
debug.Log(debug.DEBUG_TRACE, "Signature and app data", "signature", fmt.Sprintf("%x", signature[:8]), "appDataLen", len(appData))
|
||||
|
||||
// Get the destination hash from header
|
||||
var destHash []byte
|
||||
@@ -285,7 +279,7 @@ func (a *Announce) HandleAnnounce(data []byte) error {
|
||||
// Process with handlers
|
||||
for _, handler := range a.handlers {
|
||||
if handler.ReceivePathResponses() || !a.pathResponse {
|
||||
if err := handler.ReceivedAnnounce(destHash, announcedIdentity, appData); err != nil {
|
||||
if err := handler.ReceivedAnnounce(destHash, announcedIdentity, appData, hopCount); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -304,11 +298,7 @@ func (a *Announce) RequestPath(destHash []byte, onInterface common.NetworkInterf
|
||||
packet = append(packet, byte(0)) // Initial hop count
|
||||
|
||||
// Send path request
|
||||
if err := onInterface.Send(packet, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return onInterface.Send(packet, "")
|
||||
}
|
||||
|
||||
// CreateHeader creates a Reticulum packet header according to spec
|
||||
@@ -328,36 +318,40 @@ func CreateHeader(ifacFlag byte, headerType byte, contextFlag byte, propType byt
|
||||
func (a *Announce) CreatePacket() []byte {
|
||||
// This function creates the complete announce packet according to the Reticulum specification.
|
||||
// Announce Packet Structure:
|
||||
// [Header (2 bytes)][Dest Hash (16 bytes)][Transport ID (16 bytes)][Context (1 byte)][Announce Data]
|
||||
// [Header (2 bytes)][Dest Hash (16 bytes)][Context (1 byte)][Announce Data]
|
||||
// Announce Data Structure:
|
||||
// [Public Key (32 bytes)][Signing Key (32 bytes)][Name Hash (10 bytes)][Random Hash (10 bytes)][Ratchet (32 bytes)][Signature (64 bytes)][App Data]
|
||||
// [Public Key (64 bytes)][Name Hash (10 bytes)][Random Hash (10 bytes)][Ratchet (32 bytes optional)][Signature (64 bytes)][App Data]
|
||||
|
||||
// 2. Destination Hash
|
||||
destHash := a.destinationHash
|
||||
if len(destHash) == 0 {
|
||||
if len(destHash) > 16 {
|
||||
destHash = destHash[:16]
|
||||
}
|
||||
|
||||
// 3. Transport ID (zeros for broadcast announce)
|
||||
transportID := make([]byte, 16)
|
||||
|
||||
// 5. Announce Data
|
||||
// 5.1 Public Keys
|
||||
// 3. Announce Data
|
||||
// 3.1 Public Key (full 64 bytes - not split into enc/sign keys in packet)
|
||||
pubKey := a.identity.GetPublicKey()
|
||||
encKey := pubKey[:32]
|
||||
signKey := pubKey[32:]
|
||||
if len(pubKey) != 64 {
|
||||
debug.Log(debug.DEBUG_TRACE, "Invalid public key length", "expected", 64, "got", len(pubKey))
|
||||
}
|
||||
|
||||
// 5.2 Name Hash
|
||||
// 3.2 Name Hash
|
||||
nameHash := sha256.Sum256([]byte(a.destinationName))
|
||||
nameHash10 := nameHash[:10]
|
||||
|
||||
// 5.3 Random Hash
|
||||
// 3.3 Random Hash (5 bytes random + 5 bytes timestamp)
|
||||
randomHash := make([]byte, 10)
|
||||
_, err := rand.Read(randomHash)
|
||||
_, err := rand.Read(randomHash[:5])
|
||||
if err != nil {
|
||||
log.Printf("Error reading random bytes for announce: %v", err)
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to read random bytes for announce", "error", err)
|
||||
}
|
||||
// Add 5 bytes of timestamp
|
||||
timeBytes := make([]byte, 8)
|
||||
// #nosec G115 - Unix timestamp is always positive, no overflow risk
|
||||
binary.BigEndian.PutUint64(timeBytes, uint64(time.Now().Unix()))
|
||||
copy(randomHash[5:], timeBytes[:5])
|
||||
|
||||
// 5.4 Ratchet (only include if exists)
|
||||
// 3.4 Ratchet (only include if exists)
|
||||
var ratchetData []byte
|
||||
currentRatchetKey := a.identity.GetCurrentRatchetKey()
|
||||
if currentRatchetKey != nil {
|
||||
@@ -367,17 +361,17 @@ func (a *Announce) CreatePacket() []byte {
|
||||
copy(ratchetData, ratchetPub)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Determine context flag based on whether ratchet exists
|
||||
contextFlag := byte(0)
|
||||
if len(ratchetData) > 0 {
|
||||
contextFlag = 1 // FLAG_SET
|
||||
}
|
||||
|
||||
// 1. Create Header (now that we know context flag)
|
||||
// 1. Create Header - Use HEADER_TYPE_1
|
||||
header := CreateHeader(
|
||||
IFAC_NONE,
|
||||
HEADER_TYPE_2,
|
||||
HEADER_TYPE_1,
|
||||
contextFlag,
|
||||
PROP_TYPE_BROADCAST,
|
||||
DEST_TYPE_SINGLE,
|
||||
@@ -387,13 +381,15 @@ func (a *Announce) CreatePacket() []byte {
|
||||
|
||||
// 4. Context Byte
|
||||
contextByte := byte(0)
|
||||
if a.pathResponse {
|
||||
contextByte = 0x0B // PATH_RESPONSE context
|
||||
}
|
||||
|
||||
// 5.5 Signature
|
||||
// The signature is calculated over: Dest Hash + Public Keys + Name Hash + Random Hash + Ratchet (if exists) + App Data
|
||||
// 3.5 Signature
|
||||
// The signature is calculated over: Dest Hash + Public Key (64 bytes) + Name Hash + Random Hash + Ratchet (if exists) + App Data
|
||||
validationData := make([]byte, 0)
|
||||
validationData = append(validationData, destHash...)
|
||||
validationData = append(validationData, encKey...)
|
||||
validationData = append(validationData, signKey...)
|
||||
validationData = append(validationData, pubKey...)
|
||||
validationData = append(validationData, nameHash10...)
|
||||
validationData = append(validationData, randomHash...)
|
||||
if len(ratchetData) > 0 {
|
||||
@@ -402,14 +398,14 @@ func (a *Announce) CreatePacket() []byte {
|
||||
validationData = append(validationData, a.appData...)
|
||||
signature := a.identity.Sign(validationData)
|
||||
|
||||
// 6. Assemble the packet
|
||||
debug.Log(debug.DEBUG_TRACE, "Creating announce packet", "destHash", fmt.Sprintf("%x", destHash), "pubKeyLen", len(pubKey), "nameHash", fmt.Sprintf("%x", nameHash10), "randomHash", fmt.Sprintf("%x", randomHash), "ratchetLen", len(ratchetData), "sigLen", len(signature), "appDataLen", len(a.appData))
|
||||
|
||||
// 5. Assemble the packet (HEADER_TYPE_1 format)
|
||||
packet := make([]byte, 0)
|
||||
packet = append(packet, header...)
|
||||
packet = append(packet, destHash...)
|
||||
packet = append(packet, transportID...)
|
||||
packet = append(packet, contextByte)
|
||||
packet = append(packet, encKey...)
|
||||
packet = append(packet, signKey...)
|
||||
packet = append(packet, pubKey...)
|
||||
packet = append(packet, nameHash10...)
|
||||
packet = append(packet, randomHash...)
|
||||
if len(ratchetData) > 0 {
|
||||
@@ -418,6 +414,8 @@ func (a *Announce) CreatePacket() []byte {
|
||||
packet = append(packet, signature...)
|
||||
packet = append(packet, a.appData...)
|
||||
|
||||
debug.Log(debug.DEBUG_TRACE, "Final announce packet", "totalBytes", len(packet), "ratchetLen", len(ratchetData), "appDataLen", len(a.appData))
|
||||
|
||||
return packet
|
||||
}
|
||||
|
||||
@@ -452,11 +450,10 @@ func NewAnnouncePacket(pubKey []byte, appData []byte, announceID []byte) *Announ
|
||||
|
||||
// NewAnnounce creates a new announce packet for a destination
|
||||
func NewAnnounce(identity *identity.Identity, destinationHash []byte, appData []byte, ratchetID []byte, pathResponse bool, config *common.ReticulumConfig) (*Announce, error) {
|
||||
log.Printf("[DEBUG-7] Creating new announce: destHash=%x, appDataLen=%d, hasRatchet=%v, pathResponse=%v",
|
||||
destinationHash, len(appData), ratchetID != nil, pathResponse)
|
||||
debug.Log(debug.DEBUG_TRACE, "Creating new announce", "destHash", fmt.Sprintf("%x", destinationHash), "appDataLen", len(appData), "hasRatchet", ratchetID != nil, "pathResponse", pathResponse)
|
||||
|
||||
if identity == nil {
|
||||
log.Printf("[DEBUG-7] Error: nil identity provided")
|
||||
debug.Log(debug.DEBUG_ERROR, "Nil identity provided")
|
||||
return nil, errors.New("identity cannot be nil")
|
||||
}
|
||||
|
||||
@@ -469,7 +466,7 @@ func NewAnnounce(identity *identity.Identity, destinationHash []byte, appData []
|
||||
}
|
||||
|
||||
destHash := destinationHash
|
||||
log.Printf("[DEBUG-7] Using provided destination hash: %x", destHash)
|
||||
debug.Log(debug.DEBUG_TRACE, "Using provided destination hash", "destHash", fmt.Sprintf("%x", destHash))
|
||||
|
||||
a := &Announce{
|
||||
identity: identity,
|
||||
@@ -479,12 +476,11 @@ func NewAnnounce(identity *identity.Identity, destinationHash []byte, appData []
|
||||
destinationHash: destHash,
|
||||
hops: 0,
|
||||
mutex: &sync.RWMutex{},
|
||||
handlers: make([]AnnounceHandler, 0),
|
||||
handlers: make([]Handler, 0),
|
||||
config: config,
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG-7] Created announce object: destHash=%x, hops=%d",
|
||||
a.destinationHash, a.hops)
|
||||
debug.Log(debug.DEBUG_TRACE, "Created announce object", "destHash", fmt.Sprintf("%x", a.destinationHash), "hops", a.hops)
|
||||
|
||||
// Create initial packet
|
||||
packet := a.CreatePacket()
|
||||
@@ -492,7 +488,7 @@ func NewAnnounce(identity *identity.Identity, destinationHash []byte, appData []
|
||||
|
||||
// Generate hash
|
||||
hash := a.Hash()
|
||||
log.Printf("[DEBUG-7] Generated announce hash: %x", hash)
|
||||
debug.Log(debug.DEBUG_TRACE, "Generated announce hash", "hash", fmt.Sprintf("%x", hash))
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
123
pkg/announce/announce_test.go
Normal file
123
pkg/announce/announce_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package announce
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/identity"
|
||||
)
|
||||
|
||||
type mockAnnounceHandler struct {
|
||||
received bool
|
||||
}
|
||||
|
||||
func (m *mockAnnounceHandler) AspectFilter() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAnnounceHandler) ReceivedAnnounce(destinationHash []byte, announcedIdentity interface{}, appData []byte, hops uint8) error {
|
||||
m.received = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAnnounceHandler) ReceivePathResponses() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type mockInterface struct {
|
||||
common.BaseInterface
|
||||
sent bool
|
||||
}
|
||||
|
||||
func (m *mockInterface) Send(data []byte, address string) error {
|
||||
m.sent = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInterface) GetBandwidthAvailable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *mockInterface) IsEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func TestNewAnnounce(t *testing.T) {
|
||||
id, _ := identity.New()
|
||||
destHash := make([]byte, 16)
|
||||
config := &common.ReticulumConfig{}
|
||||
|
||||
ann, err := New(id, destHash, "testapp", []byte("appdata"), false, config)
|
||||
if err != nil {
|
||||
t.Fatalf("New failed: %v", err)
|
||||
}
|
||||
if ann == nil {
|
||||
t.Fatal("New returned nil")
|
||||
}
|
||||
|
||||
if !bytes.Equal(ann.destinationHash, destHash) {
|
||||
t.Error("Destination hash doesn't match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAndHandleAnnounce(t *testing.T) {
|
||||
id, _ := identity.New()
|
||||
destHash := make([]byte, 16)
|
||||
config := &common.ReticulumConfig{}
|
||||
|
||||
ann, _ := New(id, destHash, "testapp", []byte("appdata"), false, config)
|
||||
packet := ann.CreatePacket()
|
||||
|
||||
handler := &mockAnnounceHandler{}
|
||||
ann.RegisterHandler(handler)
|
||||
|
||||
err := ann.HandleAnnounce(packet)
|
||||
if err != nil {
|
||||
t.Fatalf("HandleAnnounce failed: %v", err)
|
||||
}
|
||||
|
||||
if !handler.received {
|
||||
t.Error("Handler did not receive announce")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPropagate(t *testing.T) {
|
||||
id, _ := identity.New()
|
||||
destHash := make([]byte, 16)
|
||||
config := &common.ReticulumConfig{}
|
||||
|
||||
ann, _ := New(id, destHash, "testapp", []byte("appdata"), false, config)
|
||||
|
||||
iface := &mockInterface{}
|
||||
iface.Name = "testiface"
|
||||
iface.Online = true
|
||||
iface.Enabled = true
|
||||
|
||||
err := ann.Propagate([]common.NetworkInterface{iface})
|
||||
if err != nil {
|
||||
t.Fatalf("Propagate failed: %v", err)
|
||||
}
|
||||
|
||||
if !iface.sent {
|
||||
t.Error("Packet was not sent on interface")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerRegistration(t *testing.T) {
|
||||
ann := &Announce{
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
handler := &mockAnnounceHandler{}
|
||||
|
||||
ann.RegisterHandler(handler)
|
||||
if len(ann.handlers) != 1 {
|
||||
t.Errorf("Expected 1 handler, got %d", len(ann.handlers))
|
||||
}
|
||||
|
||||
ann.DeregisterHandler(handler)
|
||||
if len(ann.handlers) != 0 {
|
||||
t.Errorf("Expected 0 handlers, got %d", len(ann.handlers))
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package announce
|
||||
|
||||
type Handler interface {
|
||||
AspectFilter() []string
|
||||
ReceivedAnnounce(destHash []byte, identity interface{}, appData []byte) error
|
||||
ReceivedAnnounce(destHash []byte, identity interface{}, appData []byte, hops uint8) error
|
||||
ReceivePathResponses() bool
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package buffer
|
||||
|
||||
import (
|
||||
@@ -8,7 +10,7 @@ import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/channel"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/channel"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -16,6 +18,19 @@ const (
|
||||
MaxChunkLen = 16 * 1024
|
||||
MaxDataLen = 457 // MDU - 2 - 6 (2 for stream header, 6 for channel envelope)
|
||||
CompressTries = 4
|
||||
|
||||
// Stream header flags
|
||||
StreamHeaderEOF = 0x8000
|
||||
StreamHeaderCompressed = 0x4000
|
||||
|
||||
// Message type
|
||||
StreamDataMessageType = 0x01
|
||||
|
||||
// Header size
|
||||
StreamHeaderSize = 2
|
||||
|
||||
// Compression threshold
|
||||
CompressThreshold = 32
|
||||
)
|
||||
|
||||
type StreamDataMessage struct {
|
||||
@@ -28,10 +43,10 @@ type StreamDataMessage struct {
|
||||
func (m *StreamDataMessage) Pack() ([]byte, error) {
|
||||
headerVal := uint16(m.StreamID & StreamIDMax)
|
||||
if m.EOF {
|
||||
headerVal |= 0x8000
|
||||
headerVal |= StreamHeaderEOF
|
||||
}
|
||||
if m.Compressed {
|
||||
headerVal |= 0x4000
|
||||
headerVal |= StreamHeaderCompressed
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
@@ -43,30 +58,32 @@ func (m *StreamDataMessage) Pack() ([]byte, error) {
|
||||
}
|
||||
|
||||
func (m *StreamDataMessage) GetType() uint16 {
|
||||
return 0x01 // Assign appropriate message type constant
|
||||
return StreamDataMessageType
|
||||
}
|
||||
|
||||
func (m *StreamDataMessage) Unpack(data []byte) error {
|
||||
if len(data) < 2 {
|
||||
if len(data) < StreamHeaderSize {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
|
||||
header := binary.BigEndian.Uint16(data[:2])
|
||||
header := binary.BigEndian.Uint16(data[:StreamHeaderSize])
|
||||
m.StreamID = header & StreamIDMax
|
||||
m.EOF = (header & 0x8000) != 0
|
||||
m.Compressed = (header & 0x4000) != 0
|
||||
m.Data = data[2:]
|
||||
m.EOF = (header & StreamHeaderEOF) != 0
|
||||
m.Compressed = (header & StreamHeaderCompressed) != 0
|
||||
m.Data = data[StreamHeaderSize:]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type RawChannelReader struct {
|
||||
streamID int
|
||||
channel *channel.Channel
|
||||
buffer *bytes.Buffer
|
||||
eof bool
|
||||
callbacks []func(int)
|
||||
mutex sync.RWMutex
|
||||
streamID int
|
||||
channel *channel.Channel
|
||||
buffer *bytes.Buffer
|
||||
eof bool
|
||||
callbacks map[int]func(int)
|
||||
nextCallbackID int
|
||||
messageHandlerID int
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewRawChannelReader(streamID int, ch *channel.Channel) *RawChannelReader {
|
||||
@@ -74,28 +91,26 @@ func NewRawChannelReader(streamID int, ch *channel.Channel) *RawChannelReader {
|
||||
streamID: streamID,
|
||||
channel: ch,
|
||||
buffer: bytes.NewBuffer(nil),
|
||||
callbacks: make([]func(int), 0),
|
||||
callbacks: make(map[int]func(int)),
|
||||
}
|
||||
|
||||
ch.AddMessageHandler(reader.HandleMessage)
|
||||
reader.messageHandlerID = ch.AddMessageHandler(reader.HandleMessage)
|
||||
return reader
|
||||
}
|
||||
|
||||
func (r *RawChannelReader) AddReadyCallback(cb func(int)) {
|
||||
func (r *RawChannelReader) AddReadyCallback(cb func(int)) int {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.callbacks = append(r.callbacks, cb)
|
||||
id := r.nextCallbackID
|
||||
r.nextCallbackID++
|
||||
r.callbacks[id] = cb
|
||||
return id
|
||||
}
|
||||
|
||||
func (r *RawChannelReader) RemoveReadyCallback(cb func(int)) {
|
||||
func (r *RawChannelReader) RemoveReadyCallback(id int) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
for i, fn := range r.callbacks {
|
||||
if &fn == &cb {
|
||||
r.callbacks = append(r.callbacks[:i], r.callbacks[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
delete(r.callbacks, id)
|
||||
}
|
||||
|
||||
func (r *RawChannelReader) Read(p []byte) (n int, err error) {
|
||||
@@ -110,11 +125,11 @@ func (r *RawChannelReader) Read(p []byte) (n int, err error) {
|
||||
if err == io.EOF && !r.eof {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *RawChannelReader) HandleMessage(msg channel.MessageBase) bool { // #nosec G115
|
||||
if streamMsg, ok := msg.(*StreamDataMessage); ok && streamMsg.StreamID == uint16(r.streamID) {
|
||||
if streamMsg, ok := msg.(*StreamDataMessage); ok && streamMsg.StreamID == uint16(r.streamID) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
@@ -163,7 +178,7 @@ func (w *RawChannelWriter) Write(p []byte) (n int, err error) {
|
||||
EOF: w.eof,
|
||||
}
|
||||
|
||||
if len(p) > 32 {
|
||||
if len(p) > CompressThreshold {
|
||||
for try := 1; try < CompressTries; try++ {
|
||||
chunkLen := len(p) / try
|
||||
compressed := compressData(p[:chunkLen])
|
||||
@@ -201,10 +216,7 @@ func (b *Buffer) Read(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (b *Buffer) Close() error {
|
||||
if err := b.ReadWriter.Writer.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return b.ReadWriter.Writer.Flush()
|
||||
}
|
||||
|
||||
func CreateReader(streamID int, ch *channel.Channel, readyCallback func(int)) *bufio.Reader {
|
||||
@@ -230,6 +242,7 @@ func compressData(data []byte) []byte {
|
||||
var compressed bytes.Buffer
|
||||
w := bytes.NewBuffer(data)
|
||||
r := bzip2.NewReader(w)
|
||||
// bearer:disable go_gosec_filesystem_decompression_bomb
|
||||
_, err := io.Copy(&compressed, r) // #nosec G104 #nosec G110
|
||||
if err != nil {
|
||||
// Handle error, e.g., log it or return an error
|
||||
@@ -243,6 +256,7 @@ func decompressData(data []byte) []byte {
|
||||
var decompressed bytes.Buffer
|
||||
// Limit the amount of data read to prevent decompression bombs
|
||||
limitedReader := io.LimitReader(reader, MaxChunkLen) // #nosec G110
|
||||
// bearer:disable go_gosec_filesystem_decompression_bomb
|
||||
_, err := io.Copy(&decompressed, limitedReader)
|
||||
if err != nil {
|
||||
// Handle error, e.g., log it or return an error
|
||||
|
||||
449
pkg/buffer/buffer_test.go
Normal file
449
pkg/buffer/buffer_test.go
Normal file
@@ -0,0 +1,449 @@
|
||||
package buffer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/channel"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/packet"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/transport"
|
||||
)
|
||||
|
||||
func TestStreamDataMessage_Pack(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
streamID uint16
|
||||
data []byte
|
||||
eof bool
|
||||
compressed bool
|
||||
}{
|
||||
{
|
||||
name: "NormalMessage",
|
||||
streamID: 123,
|
||||
data: []byte("test data"),
|
||||
eof: false,
|
||||
compressed: false,
|
||||
},
|
||||
{
|
||||
name: "EOFMessage",
|
||||
streamID: 456,
|
||||
data: []byte("final data"),
|
||||
eof: true,
|
||||
compressed: false,
|
||||
},
|
||||
{
|
||||
name: "CompressedMessage",
|
||||
streamID: 789,
|
||||
data: []byte("compressed data"),
|
||||
eof: false,
|
||||
compressed: true,
|
||||
},
|
||||
{
|
||||
name: "EmptyData",
|
||||
streamID: 0,
|
||||
data: []byte{},
|
||||
eof: false,
|
||||
compressed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msg := &StreamDataMessage{
|
||||
StreamID: tt.streamID,
|
||||
Data: tt.data,
|
||||
EOF: tt.eof,
|
||||
Compressed: tt.compressed,
|
||||
}
|
||||
|
||||
packed, err := msg.Pack()
|
||||
if err != nil {
|
||||
t.Fatalf("Pack() failed: %v", err)
|
||||
}
|
||||
|
||||
if len(packed) < 2 {
|
||||
t.Error("Packed message too short")
|
||||
}
|
||||
|
||||
unpacked := &StreamDataMessage{}
|
||||
if err := unpacked.Unpack(packed); err != nil {
|
||||
t.Fatalf("Unpack() failed: %v", err)
|
||||
}
|
||||
|
||||
if unpacked.StreamID != tt.streamID {
|
||||
t.Errorf("StreamID = %d, want %d", unpacked.StreamID, tt.streamID)
|
||||
}
|
||||
if unpacked.EOF != tt.eof {
|
||||
t.Errorf("EOF = %v, want %v", unpacked.EOF, tt.eof)
|
||||
}
|
||||
if unpacked.Compressed != tt.compressed {
|
||||
t.Errorf("Compressed = %v, want %v", unpacked.Compressed, tt.compressed)
|
||||
}
|
||||
if !bytes.Equal(unpacked.Data, tt.data) {
|
||||
t.Errorf("Data = %v, want %v", unpacked.Data, tt.data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamDataMessage_Unpack(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "ValidMessage",
|
||||
data: []byte{0x00, 0x7B, 'h', 'e', 'l', 'l', 'o'},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "TooShort",
|
||||
data: []byte{0x00},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty",
|
||||
data: []byte{},
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msg := &StreamDataMessage{}
|
||||
err := msg.Unpack(tt.data)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("Unpack() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamDataMessage_GetType(t *testing.T) {
|
||||
msg := &StreamDataMessage{}
|
||||
if msg.GetType() != 0x01 {
|
||||
t.Errorf("GetType() = %d, want 0x01", msg.GetType())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawChannelReader_AddCallback(t *testing.T) {
|
||||
reader := &RawChannelReader{
|
||||
streamID: 1,
|
||||
buffer: bytes.NewBuffer(nil),
|
||||
callbacks: make(map[int]func(int)),
|
||||
}
|
||||
|
||||
cb := func(int) {}
|
||||
|
||||
reader.AddReadyCallback(cb)
|
||||
if len(reader.callbacks) != 1 {
|
||||
t.Error("Callback should be added")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuffer_Write(t *testing.T) {
|
||||
buf := &Buffer{
|
||||
ReadWriter: bufio.NewReadWriter(bufio.NewReader(bytes.NewBuffer(nil)), bufio.NewWriter(bytes.NewBuffer(nil))),
|
||||
}
|
||||
|
||||
data := []byte("test")
|
||||
n, err := buf.Write(data)
|
||||
if err != nil {
|
||||
t.Errorf("Write() error = %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("Write() = %d bytes, want %d", n, len(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuffer_Read(t *testing.T) {
|
||||
buf := &Buffer{
|
||||
ReadWriter: bufio.NewReadWriter(bufio.NewReader(bytes.NewBuffer([]byte("test data"))), bufio.NewWriter(bytes.NewBuffer(nil))),
|
||||
}
|
||||
|
||||
data := make([]byte, 10)
|
||||
n, err := buf.Read(data)
|
||||
if err != nil && err != io.EOF {
|
||||
t.Errorf("Read() error = %v", err)
|
||||
}
|
||||
if n <= 0 {
|
||||
t.Errorf("Read() = %d bytes, want > 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuffer_Close(t *testing.T) {
|
||||
buf := &Buffer{
|
||||
ReadWriter: bufio.NewReadWriter(bufio.NewReader(bytes.NewBuffer(nil)), bufio.NewWriter(bytes.NewBuffer(nil))),
|
||||
}
|
||||
|
||||
if err := buf.Close(); err != nil {
|
||||
t.Errorf("Close() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamIDMax(t *testing.T) {
|
||||
if StreamIDMax != 0x3fff {
|
||||
t.Errorf("StreamIDMax = %d, want %d", StreamIDMax, 0x3fff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxChunkLen(t *testing.T) {
|
||||
if MaxChunkLen != 16*1024 {
|
||||
t.Errorf("MaxChunkLen = %d, want %d", MaxChunkLen, 16*1024)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxDataLen(t *testing.T) {
|
||||
if MaxDataLen != 457 {
|
||||
t.Errorf("MaxDataLen = %d, want %d", MaxDataLen, 457)
|
||||
}
|
||||
}
|
||||
|
||||
type mockLink struct {
|
||||
status byte
|
||||
rtt float64
|
||||
}
|
||||
|
||||
func (m *mockLink) GetStatus() byte { return m.status }
|
||||
func (m *mockLink) GetRTT() float64 { return m.rtt }
|
||||
func (m *mockLink) RTT() float64 { return m.rtt }
|
||||
func (m *mockLink) GetLinkID() []byte { return []byte("testlink") }
|
||||
func (m *mockLink) Send(data []byte) interface{} { return &packet.Packet{Raw: data} }
|
||||
func (m *mockLink) Resend(p interface{}) error { return nil }
|
||||
func (m *mockLink) SetPacketTimeout(p interface{}, cb func(interface{}), t time.Duration) {}
|
||||
func (m *mockLink) SetPacketDelivered(p interface{}, cb func(interface{})) {}
|
||||
func (m *mockLink) HandleInbound(pkt *packet.Packet) error { return nil }
|
||||
func (m *mockLink) ValidateLinkProof(pkt *packet.Packet, networkIface common.NetworkInterface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewRawChannelReader(t *testing.T) {
|
||||
link := &mockLink{status: transport.STATUS_ACTIVE}
|
||||
ch := channel.NewChannel(link)
|
||||
reader := NewRawChannelReader(123, ch)
|
||||
|
||||
if reader.streamID != 123 {
|
||||
t.Errorf("streamID = %d, want %d", reader.streamID, 123)
|
||||
}
|
||||
if reader.channel != ch {
|
||||
t.Error("channel not set correctly")
|
||||
}
|
||||
if reader.buffer == nil {
|
||||
t.Error("buffer is nil")
|
||||
}
|
||||
if reader.callbacks == nil {
|
||||
t.Error("callbacks is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawChannelReader_RemoveReadyCallback(t *testing.T) {
|
||||
reader := &RawChannelReader{
|
||||
streamID: 1,
|
||||
buffer: bytes.NewBuffer(nil),
|
||||
callbacks: make(map[int]func(int)),
|
||||
}
|
||||
|
||||
cb1 := func(int) {}
|
||||
cb2 := func(int) {}
|
||||
|
||||
id1 := reader.AddReadyCallback(cb1)
|
||||
reader.AddReadyCallback(cb2)
|
||||
|
||||
if len(reader.callbacks) != 2 {
|
||||
t.Errorf("callbacks length = %d, want 2", len(reader.callbacks))
|
||||
}
|
||||
|
||||
reader.RemoveReadyCallback(id1)
|
||||
|
||||
if len(reader.callbacks) != 1 {
|
||||
t.Errorf("RemoveReadyCallback did not remove callback, length = %d", len(reader.callbacks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawChannelReader_Read(t *testing.T) {
|
||||
reader := &RawChannelReader{
|
||||
streamID: 1,
|
||||
buffer: bytes.NewBuffer([]byte("test data")),
|
||||
eof: false,
|
||||
}
|
||||
|
||||
data := make([]byte, 10)
|
||||
n, err := reader.Read(data)
|
||||
if err != nil {
|
||||
t.Errorf("Read() error = %v", err)
|
||||
}
|
||||
if n == 0 {
|
||||
t.Error("Read() returned 0 bytes")
|
||||
}
|
||||
|
||||
reader.eof = true
|
||||
reader.buffer = bytes.NewBuffer(nil)
|
||||
n, err = reader.Read(data)
|
||||
if err != io.EOF {
|
||||
t.Errorf("Read() error = %v, want io.EOF", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("Read() = %d bytes, want 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawChannelReader_HandleMessage(t *testing.T) {
|
||||
reader := &RawChannelReader{
|
||||
streamID: 1,
|
||||
buffer: bytes.NewBuffer(nil),
|
||||
callbacks: make(map[int]func(int)),
|
||||
}
|
||||
|
||||
msg := &StreamDataMessage{
|
||||
StreamID: 1,
|
||||
Data: []byte("test"),
|
||||
EOF: false,
|
||||
Compressed: false,
|
||||
}
|
||||
|
||||
called := false
|
||||
reader.AddReadyCallback(func(int) {
|
||||
called = true
|
||||
})
|
||||
|
||||
result := reader.HandleMessage(msg)
|
||||
if !result {
|
||||
t.Error("HandleMessage() = false, want true")
|
||||
}
|
||||
if !called {
|
||||
t.Error("callback was not called")
|
||||
}
|
||||
if reader.buffer.Len() == 0 {
|
||||
t.Error("buffer is empty after HandleMessage")
|
||||
}
|
||||
|
||||
msg.StreamID = 2
|
||||
result = reader.HandleMessage(msg)
|
||||
if result {
|
||||
t.Error("HandleMessage() = true, want false for different streamID")
|
||||
}
|
||||
|
||||
msg.StreamID = 1
|
||||
msg.EOF = true
|
||||
reader.HandleMessage(msg)
|
||||
if !reader.eof {
|
||||
t.Error("EOF not set after HandleMessage with EOF flag")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRawChannelWriter(t *testing.T) {
|
||||
link := &mockLink{status: transport.STATUS_ACTIVE}
|
||||
ch := channel.NewChannel(link)
|
||||
writer := NewRawChannelWriter(456, ch)
|
||||
|
||||
if writer.streamID != 456 {
|
||||
t.Errorf("streamID = %d, want %d", writer.streamID, 456)
|
||||
}
|
||||
if writer.channel != ch {
|
||||
t.Error("channel not set correctly")
|
||||
}
|
||||
if writer.eof {
|
||||
t.Error("eof should be false initially")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawChannelWriter_Write(t *testing.T) {
|
||||
link := &mockLink{status: transport.STATUS_ACTIVE}
|
||||
ch := channel.NewChannel(link)
|
||||
writer := NewRawChannelWriter(1, ch)
|
||||
|
||||
data := []byte("test data")
|
||||
n, err := writer.Write(data)
|
||||
if err != nil {
|
||||
t.Errorf("Write() error = %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("Write() = %d bytes, want %d", n, len(data))
|
||||
}
|
||||
|
||||
largeData := make([]byte, MaxChunkLen+100)
|
||||
n, err = writer.Write(largeData)
|
||||
if err != nil {
|
||||
t.Errorf("Write() error = %v", err)
|
||||
}
|
||||
if n != MaxChunkLen {
|
||||
t.Errorf("Write() = %d bytes, want %d", n, MaxChunkLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawChannelWriter_Close(t *testing.T) {
|
||||
link := &mockLink{status: transport.STATUS_ACTIVE}
|
||||
ch := channel.NewChannel(link)
|
||||
writer := NewRawChannelWriter(1, ch)
|
||||
|
||||
if writer.eof {
|
||||
t.Error("EOF should be false before Close()")
|
||||
}
|
||||
|
||||
err := writer.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close() error = %v", err)
|
||||
}
|
||||
if !writer.eof {
|
||||
t.Error("EOF should be true after Close()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateReader(t *testing.T) {
|
||||
link := &mockLink{status: transport.STATUS_ACTIVE}
|
||||
ch := channel.NewChannel(link)
|
||||
callback := func(int) {}
|
||||
reader := CreateReader(789, ch, callback)
|
||||
|
||||
if reader == nil {
|
||||
t.Error("CreateReader() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateWriter(t *testing.T) {
|
||||
link := &mockLink{status: transport.STATUS_ACTIVE}
|
||||
ch := channel.NewChannel(link)
|
||||
writer := CreateWriter(101, ch)
|
||||
|
||||
if writer == nil {
|
||||
t.Error("CreateWriter() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateBidirectionalBuffer(t *testing.T) {
|
||||
link := &mockLink{status: transport.STATUS_ACTIVE}
|
||||
ch := channel.NewChannel(link)
|
||||
callback := func(int) {}
|
||||
buf := CreateBidirectionalBuffer(1, 2, ch, callback)
|
||||
|
||||
if buf == nil {
|
||||
t.Error("CreateBidirectionalBuffer() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompressData(t *testing.T) {
|
||||
data := []byte("test data for compression")
|
||||
compressed := compressData(data)
|
||||
|
||||
if compressed == nil {
|
||||
t.Skip("compressData() returned nil (compression implementation may be incomplete)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecompressData(t *testing.T) {
|
||||
data := []byte("test data")
|
||||
compressed := compressData(data)
|
||||
if compressed == nil {
|
||||
t.Skip("compression not working, skipping decompression test")
|
||||
}
|
||||
|
||||
decompressed := decompressData(compressed)
|
||||
if decompressed == nil {
|
||||
t.Error("decompressData() returned nil")
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,16 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package channel
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/transport"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/transport"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,6 +36,19 @@ const (
|
||||
SeqModulus uint16 = SeqMax
|
||||
|
||||
FastRateThreshold = 10
|
||||
|
||||
// Timeout calculation constants
|
||||
RTTMinThreshold = 0.025
|
||||
TimeoutBaseMultiplier = 1.5
|
||||
TimeoutRingMultiplier = 2.5
|
||||
TimeoutRingOffset = 2
|
||||
|
||||
// Packet header constants
|
||||
ChannelHeaderSize = 6
|
||||
ChannelHeaderBits = 8
|
||||
|
||||
// Default retry count
|
||||
DefaultMaxTries = 3
|
||||
)
|
||||
|
||||
// MessageState represents the state of a message
|
||||
@@ -67,7 +83,13 @@ type Channel struct {
|
||||
maxTries int
|
||||
fastRateRounds int
|
||||
medRateRounds int
|
||||
messageHandlers []func(MessageBase) bool
|
||||
messageHandlers []messageHandlerEntry
|
||||
nextHandlerID int
|
||||
}
|
||||
|
||||
type messageHandlerEntry struct {
|
||||
id int
|
||||
handler func(MessageBase) bool
|
||||
}
|
||||
|
||||
// Envelope wraps a message with metadata for transmission
|
||||
@@ -84,12 +106,12 @@ type Envelope struct {
|
||||
func NewChannel(link transport.LinkInterface) *Channel {
|
||||
return &Channel{
|
||||
link: link,
|
||||
messageHandlers: make([]func(MessageBase) bool, 0),
|
||||
messageHandlers: make([]messageHandlerEntry, 0),
|
||||
mutex: sync.RWMutex{},
|
||||
windowMax: WindowMaxSlow,
|
||||
windowMin: WindowMinSlow,
|
||||
window: WindowInitial,
|
||||
maxTries: 3,
|
||||
maxTries: DefaultMaxTries,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,7 +128,7 @@ func (c *Channel) Send(msg MessageBase) error {
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.nextSequence = (c.nextSequence + 1) % SeqModulus
|
||||
c.nextSequence = (c.nextSequence + common.ONE) % SeqModulus
|
||||
c.txRing = append(c.txRing, env)
|
||||
c.mutex.Unlock()
|
||||
|
||||
@@ -141,7 +163,7 @@ func (c *Channel) handleTimeout(packet interface{}) {
|
||||
env.Tries++
|
||||
if err := c.link.Resend(packet); err != nil { // #nosec G104
|
||||
// Handle resend error, e.g., log it or mark envelope as failed
|
||||
log.Printf("Failed to resend packet: %v", err)
|
||||
debug.Log(debug.DEBUG_INFO, "Failed to resend packet", "error", err)
|
||||
// Optionally, mark the envelope as failed or remove it from txRing
|
||||
// env.State = MsgStateFailed
|
||||
// c.txRing = append(c.txRing[:i], c.txRing[i+1:]...)
|
||||
@@ -169,25 +191,28 @@ func (c *Channel) handleDelivered(packet interface{}) {
|
||||
|
||||
func (c *Channel) getPacketTimeout(tries int) time.Duration {
|
||||
rtt := c.link.GetRTT()
|
||||
if rtt < 0.025 {
|
||||
rtt = 0.025
|
||||
if rtt < RTTMinThreshold {
|
||||
rtt = RTTMinThreshold
|
||||
}
|
||||
|
||||
timeout := math.Pow(1.5, float64(tries-1)) * rtt * 2.5 * float64(len(c.txRing)+2)
|
||||
timeout := math.Pow(TimeoutBaseMultiplier, float64(tries-common.ONE)) * rtt * TimeoutRingMultiplier * float64(len(c.txRing)+TimeoutRingOffset)
|
||||
return time.Duration(timeout * float64(time.Second))
|
||||
}
|
||||
|
||||
func (c *Channel) AddMessageHandler(handler func(MessageBase) bool) {
|
||||
func (c *Channel) AddMessageHandler(handler func(MessageBase) bool) int {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.messageHandlers = append(c.messageHandlers, handler)
|
||||
id := c.nextHandlerID
|
||||
c.nextHandlerID++
|
||||
c.messageHandlers = append(c.messageHandlers, messageHandlerEntry{id: id, handler: handler})
|
||||
return id
|
||||
}
|
||||
|
||||
func (c *Channel) RemoveMessageHandler(handler func(MessageBase) bool) {
|
||||
func (c *Channel) RemoveMessageHandler(id int) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
for i, h := range c.messageHandlers {
|
||||
if &h == &handler {
|
||||
for i, entry := range c.messageHandlers {
|
||||
if entry.id == id {
|
||||
c.messageHandlers = append(c.messageHandlers[:i], c.messageHandlers[i+1:]...)
|
||||
break
|
||||
}
|
||||
@@ -198,10 +223,10 @@ func (c *Channel) updateRateThresholds() {
|
||||
rtt := c.link.RTT()
|
||||
|
||||
if rtt > RTTFast {
|
||||
c.fastRateRounds = 0
|
||||
c.fastRateRounds = common.ZERO
|
||||
|
||||
if rtt > RTTMedium {
|
||||
c.medRateRounds = 0
|
||||
c.medRateRounds = common.ZERO
|
||||
} else {
|
||||
c.medRateRounds++
|
||||
if c.windowMax < WindowMaxMedium && c.medRateRounds == FastRateThreshold {
|
||||
@@ -218,6 +243,59 @@ func (c *Channel) updateRateThresholds() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Channel) HandleInbound(data []byte) error {
|
||||
if len(data) < ChannelHeaderSize {
|
||||
return errors.New("channel packet too short")
|
||||
}
|
||||
|
||||
msgType := uint16(data[0])<<ChannelHeaderBits | uint16(data[1])
|
||||
sequence := uint16(data[2])<<ChannelHeaderBits | uint16(data[3])
|
||||
length := uint16(data[4])<<ChannelHeaderBits | uint16(data[5])
|
||||
|
||||
if len(data) < ChannelHeaderSize+int(length) {
|
||||
return errors.New("channel packet incomplete")
|
||||
}
|
||||
|
||||
msgData := data[ChannelHeaderSize : ChannelHeaderSize+length]
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
for _, entry := range c.messageHandlers {
|
||||
if entry.handler != nil {
|
||||
msg := &GenericMessage{
|
||||
Type: msgType,
|
||||
Data: msgData,
|
||||
Seq: sequence,
|
||||
}
|
||||
if entry.handler(msg) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type GenericMessage struct {
|
||||
Type uint16
|
||||
Data []byte
|
||||
Seq uint16
|
||||
}
|
||||
|
||||
func (g *GenericMessage) Pack() ([]byte, error) {
|
||||
return g.Data, nil
|
||||
}
|
||||
|
||||
func (g *GenericMessage) Unpack(data []byte) error {
|
||||
g.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GenericMessage) GetType() uint16 {
|
||||
return g.Type
|
||||
}
|
||||
|
||||
func (c *Channel) Close() error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
130
pkg/channel/channel_test.go
Normal file
130
pkg/channel/channel_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/packet"
|
||||
)
|
||||
|
||||
type mockLink struct {
|
||||
status byte
|
||||
rtt float64
|
||||
sent [][]byte
|
||||
timeouts map[interface{}]func(interface{})
|
||||
delivered map[interface{}]func(interface{})
|
||||
}
|
||||
|
||||
func (m *mockLink) GetStatus() byte { return m.status }
|
||||
func (m *mockLink) GetRTT() float64 { return m.rtt }
|
||||
func (m *mockLink) RTT() float64 { return m.rtt }
|
||||
func (m *mockLink) GetLinkID() []byte { return []byte("testlink") }
|
||||
func (m *mockLink) Send(data []byte) interface{} {
|
||||
m.sent = append(m.sent, data)
|
||||
p := &packet.Packet{Raw: data}
|
||||
return p
|
||||
}
|
||||
func (m *mockLink) Resend(p interface{}) error { return nil }
|
||||
func (m *mockLink) SetPacketTimeout(p interface{}, cb func(interface{}), t time.Duration) {
|
||||
if m.timeouts == nil {
|
||||
m.timeouts = make(map[interface{}]func(interface{}))
|
||||
}
|
||||
m.timeouts[p] = cb
|
||||
}
|
||||
func (m *mockLink) SetPacketDelivered(p interface{}, cb func(interface{})) {
|
||||
if m.delivered == nil {
|
||||
m.delivered = make(map[interface{}]func(interface{}))
|
||||
}
|
||||
m.delivered[p] = cb
|
||||
}
|
||||
func (m *mockLink) HandleInbound(pkt *packet.Packet) error { return nil }
|
||||
func (m *mockLink) ValidateLinkProof(pkt *packet.Packet, networkIface common.NetworkInterface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testMessage struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (m *testMessage) Pack() ([]byte, error) { return m.data, nil }
|
||||
func (m *testMessage) Unpack(data []byte) error { m.data = data; return nil }
|
||||
func (m *testMessage) GetType() uint16 { return 1 }
|
||||
|
||||
func TestNewChannel(t *testing.T) {
|
||||
link := &mockLink{}
|
||||
c := NewChannel(link)
|
||||
if c == nil {
|
||||
t.Fatal("NewChannel returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelSend(t *testing.T) {
|
||||
link := &mockLink{status: 1} // STATUS_ACTIVE
|
||||
c := NewChannel(link)
|
||||
|
||||
msg := &testMessage{data: []byte("test")}
|
||||
err := c.Send(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("Send failed: %v", err)
|
||||
}
|
||||
|
||||
if len(link.sent) != 1 {
|
||||
t.Errorf("Expected 1 packet sent, got %d", len(link.sent))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleInbound(t *testing.T) {
|
||||
link := &mockLink{}
|
||||
c := NewChannel(link)
|
||||
|
||||
received := false
|
||||
c.AddMessageHandler(func(m MessageBase) bool {
|
||||
received = true
|
||||
return true
|
||||
})
|
||||
|
||||
// Packet format: [type 2][seq 2][len 2][data]
|
||||
data := []byte{0, 1, 0, 1, 0, 4, 't', 'e', 's', 't'}
|
||||
err := c.HandleInbound(data)
|
||||
if err != nil {
|
||||
t.Fatalf("HandleInbound failed: %v", err)
|
||||
}
|
||||
|
||||
if !received {
|
||||
t.Error("Message handler was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageHandlers(t *testing.T) {
|
||||
c := &Channel{
|
||||
messageHandlers: make([]messageHandlerEntry, 0),
|
||||
}
|
||||
h := func(m MessageBase) bool { return true }
|
||||
|
||||
id := c.AddMessageHandler(h)
|
||||
if len(c.messageHandlers) != 1 {
|
||||
t.Errorf("Expected 1 handler, got %d", len(c.messageHandlers))
|
||||
}
|
||||
|
||||
c.RemoveMessageHandler(id)
|
||||
if len(c.messageHandlers) != 0 {
|
||||
t.Errorf("Expected 0 handlers, got %d", len(c.messageHandlers))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericMessage(t *testing.T) {
|
||||
msg := &GenericMessage{Type: 1, Data: []byte("test")}
|
||||
if msg.GetType() != 1 {
|
||||
t.Error("Wrong type")
|
||||
}
|
||||
p, _ := msg.Pack()
|
||||
if !bytes.Equal(p, []byte("test")) {
|
||||
t.Error("Pack failed")
|
||||
}
|
||||
msg.Unpack([]byte("new"))
|
||||
if !bytes.Equal(msg.Data, []byte("new")) {
|
||||
t.Error("Unpack failed")
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package common
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package common
|
||||
|
||||
const (
|
||||
@@ -58,4 +60,87 @@ const (
|
||||
STALE_TIME = 720
|
||||
PATH_REQUEST_TTL = 300
|
||||
ANNOUNCE_TIMEOUT = 15
|
||||
|
||||
// Common Numeric Constants
|
||||
ZERO = 0
|
||||
ONE = 1
|
||||
TWO = 2
|
||||
THREE = 3
|
||||
FOUR = 4
|
||||
FIVE = 5
|
||||
SIX = 6
|
||||
SEVEN = 7
|
||||
EIGHT = 8
|
||||
FIFTEEN = 15
|
||||
|
||||
// Common Size Constants
|
||||
SIZE_16 = 16
|
||||
SIZE_32 = 32
|
||||
SIZE_48 = 48
|
||||
SIZE_64 = 64
|
||||
SIXTY_SEVEN = 67
|
||||
TOKEN_OVERHEAD = 48
|
||||
|
||||
// Common Hex Constants
|
||||
HEX_0x00 = 0x00
|
||||
HEX_0x01 = 0x01
|
||||
HEX_0x02 = 0x02
|
||||
HEX_0x03 = 0x03
|
||||
HEX_0x04 = 0x04
|
||||
HEX_0x92 = 0x92
|
||||
HEX_0x93 = 0x93
|
||||
HEX_0xC2 = 0xC2
|
||||
HEX_0xC3 = 0xC3
|
||||
HEX_0xC4 = 0xC4
|
||||
HEX_0xD1 = 0xD1
|
||||
HEX_0xD2 = 0xD2
|
||||
HEX_0xFE = 0xFE
|
||||
HEX_0xFF = 0xFF
|
||||
|
||||
// Common Numeric Constants
|
||||
NUM_11 = 11
|
||||
NUM_100 = 100
|
||||
NUM_500 = 500
|
||||
NUM_1024 = 1024
|
||||
NUM_1064 = 1064
|
||||
NUM_4242 = 4242
|
||||
NUM_0700 = 0700
|
||||
|
||||
// Common Float Constants
|
||||
FLOAT_ZERO = 0.0
|
||||
FLOAT_0_001 = 0.001
|
||||
FLOAT_0_025 = 0.025
|
||||
FLOAT_0_1 = 0.1
|
||||
FLOAT_1_0 = 1.0
|
||||
FLOAT_1_75 = 1.75
|
||||
FLOAT_5_0 = 5.0
|
||||
FLOAT_1E9 = 1e9
|
||||
|
||||
// Common String Constants
|
||||
STR_LINK_ID = "link_id"
|
||||
STR_BYTES = "bytes"
|
||||
STR_FMT_HEX = "0x%02x"
|
||||
STR_FMT_HEX_LOW = "%x"
|
||||
STR_FMT_DEC = "%d"
|
||||
STR_TEST = "test"
|
||||
STR_LINK = "link"
|
||||
STR_ERROR = "error"
|
||||
STR_HASH = "hash"
|
||||
STR_NAME = "name"
|
||||
STR_TYPE = "type"
|
||||
STR_STORAGE = "storage"
|
||||
STR_PATH = "path"
|
||||
STR_COUNT = "count"
|
||||
STR_HOME = "HOME"
|
||||
STR_PUBLIC_KEY = "public_key"
|
||||
STR_TCP_CLIENT = "TCPClientInterface"
|
||||
STR_UDP = "udp"
|
||||
STR_UDP6 = "udp6"
|
||||
STR_TCP = "tcp"
|
||||
STR_ETH0 = "eth0"
|
||||
STR_INTERFACE = "interface"
|
||||
STR_PEER = "peer"
|
||||
STR_ADDR = "addr"
|
||||
STR_LINK_NOT_ACTIVE = "link not active"
|
||||
STR_INTERFACE_OFFLINE = "interface offline or detached"
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package common
|
||||
|
||||
import (
|
||||
@@ -181,12 +183,10 @@ func (i *BaseInterface) SendLinkPacket(dest []byte, data []byte, timestamp time.
|
||||
packet = append(packet, 0x02) // Link packet type
|
||||
packet = append(packet, dest...)
|
||||
|
||||
// Add timestamp
|
||||
ts := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(ts, uint64(timestamp.Unix())) // #nosec G115
|
||||
packet = append(packet, ts...)
|
||||
|
||||
// Add data
|
||||
packet = append(packet, data...)
|
||||
|
||||
return i.Send(packet, "")
|
||||
|
||||
288
pkg/common/interfaces_test.go
Normal file
288
pkg/common/interfaces_test.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewBaseInterface(t *testing.T) {
|
||||
iface := NewBaseInterface("test0", IF_TYPE_UDP, true)
|
||||
|
||||
if iface.Name != "test0" {
|
||||
t.Errorf("Name = %q, want %q", iface.Name, "test0")
|
||||
}
|
||||
if iface.Type != IF_TYPE_UDP {
|
||||
t.Errorf("Type = %v, want %v", iface.Type, IF_TYPE_UDP)
|
||||
}
|
||||
if iface.Mode != IF_MODE_FULL {
|
||||
t.Errorf("Mode = %v, want %v", iface.Mode, IF_MODE_FULL)
|
||||
}
|
||||
if !iface.Enabled {
|
||||
t.Errorf("Enabled = %v, want true", iface.Enabled)
|
||||
}
|
||||
if iface.MTU != DEFAULT_MTU {
|
||||
t.Errorf("MTU = %d, want %d", iface.MTU, DEFAULT_MTU)
|
||||
}
|
||||
if iface.Bitrate != BITRATE_MINIMUM {
|
||||
t.Errorf("Bitrate = %d, want %d", iface.Bitrate, BITRATE_MINIMUM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_GetType(t *testing.T) {
|
||||
iface := NewBaseInterface("test1", IF_TYPE_TCP, true)
|
||||
if iface.GetType() != IF_TYPE_TCP {
|
||||
t.Errorf("GetType() = %v, want %v", iface.GetType(), IF_TYPE_TCP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_GetMode(t *testing.T) {
|
||||
iface := NewBaseInterface("test2", IF_TYPE_UDP, true)
|
||||
if iface.GetMode() != IF_MODE_FULL {
|
||||
t.Errorf("GetMode() = %v, want %v", iface.GetMode(), IF_MODE_FULL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_GetMTU(t *testing.T) {
|
||||
iface := NewBaseInterface("test3", IF_TYPE_UDP, true)
|
||||
if iface.GetMTU() != DEFAULT_MTU {
|
||||
t.Errorf("GetMTU() = %d, want %d", iface.GetMTU(), DEFAULT_MTU)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_GetName(t *testing.T) {
|
||||
iface := NewBaseInterface("test4", IF_TYPE_UDP, true)
|
||||
if iface.GetName() != "test4" {
|
||||
t.Errorf("GetName() = %q, want %q", iface.GetName(), "test4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_IsEnabled(t *testing.T) {
|
||||
iface := NewBaseInterface("test5", IF_TYPE_UDP, true)
|
||||
iface.Online = true
|
||||
iface.Detached = false
|
||||
|
||||
if !iface.IsEnabled() {
|
||||
t.Error("IsEnabled() = false, want true")
|
||||
}
|
||||
|
||||
iface.Enabled = false
|
||||
if iface.IsEnabled() {
|
||||
t.Error("IsEnabled() = true, want false when disabled")
|
||||
}
|
||||
|
||||
iface.Enabled = true
|
||||
iface.Online = false
|
||||
if iface.IsEnabled() {
|
||||
t.Error("IsEnabled() = true, want false when offline")
|
||||
}
|
||||
|
||||
iface.Online = true
|
||||
iface.Detached = true
|
||||
if iface.IsEnabled() {
|
||||
t.Error("IsEnabled() = true, want false when detached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_IsOnline(t *testing.T) {
|
||||
iface := NewBaseInterface("test6", IF_TYPE_UDP, true)
|
||||
iface.Online = true
|
||||
|
||||
if !iface.IsOnline() {
|
||||
t.Error("IsOnline() = false, want true")
|
||||
}
|
||||
|
||||
iface.Online = false
|
||||
if iface.IsOnline() {
|
||||
t.Error("IsOnline() = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_IsDetached(t *testing.T) {
|
||||
iface := NewBaseInterface("test7", IF_TYPE_UDP, true)
|
||||
iface.Detached = true
|
||||
|
||||
if !iface.IsDetached() {
|
||||
t.Error("IsDetached() = false, want true")
|
||||
}
|
||||
|
||||
iface.Detached = false
|
||||
if iface.IsDetached() {
|
||||
t.Error("IsDetached() = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_SetPacketCallback(t *testing.T) {
|
||||
iface := NewBaseInterface("test8", IF_TYPE_UDP, true)
|
||||
|
||||
callback := func(data []byte, ni NetworkInterface) {}
|
||||
iface.SetPacketCallback(callback)
|
||||
|
||||
if iface.GetPacketCallback() == nil {
|
||||
t.Error("GetPacketCallback() = nil, want callback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_GetPacketCallback(t *testing.T) {
|
||||
iface := NewBaseInterface("test9", IF_TYPE_UDP, true)
|
||||
|
||||
if iface.GetPacketCallback() != nil {
|
||||
t.Error("GetPacketCallback() != nil, want nil")
|
||||
}
|
||||
|
||||
callback := func(data []byte, ni NetworkInterface) {}
|
||||
iface.SetPacketCallback(callback)
|
||||
|
||||
if iface.GetPacketCallback() == nil {
|
||||
t.Error("GetPacketCallback() = nil, want callback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_Detach(t *testing.T) {
|
||||
iface := NewBaseInterface("test10", IF_TYPE_UDP, true)
|
||||
iface.Online = true
|
||||
iface.Detached = false
|
||||
|
||||
iface.Detach()
|
||||
|
||||
if !iface.IsDetached() {
|
||||
t.Error("IsDetached() = false, want true after Detach()")
|
||||
}
|
||||
if iface.IsOnline() {
|
||||
t.Error("IsOnline() = true, want false after Detach()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_Enable(t *testing.T) {
|
||||
iface := NewBaseInterface("test11", IF_TYPE_UDP, false)
|
||||
iface.Online = false
|
||||
|
||||
iface.Enable()
|
||||
|
||||
if !iface.Enabled {
|
||||
t.Error("Enabled = false, want true after Enable()")
|
||||
}
|
||||
if !iface.IsOnline() {
|
||||
t.Error("IsOnline() = false, want true after Enable()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_Disable(t *testing.T) {
|
||||
iface := NewBaseInterface("test12", IF_TYPE_UDP, true)
|
||||
iface.Online = true
|
||||
|
||||
iface.Disable()
|
||||
|
||||
if iface.Enabled {
|
||||
t.Error("Enabled = true, want false after Disable()")
|
||||
}
|
||||
if iface.IsOnline() {
|
||||
t.Error("IsOnline() = true, want false after Disable()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_Start(t *testing.T) {
|
||||
iface := NewBaseInterface("test13", IF_TYPE_UDP, true)
|
||||
if err := iface.Start(); err != nil {
|
||||
t.Errorf("Start() error = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_Stop(t *testing.T) {
|
||||
iface := NewBaseInterface("test14", IF_TYPE_UDP, true)
|
||||
if err := iface.Stop(); err != nil {
|
||||
t.Errorf("Stop() error = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_GetConn(t *testing.T) {
|
||||
iface := NewBaseInterface("test15", IF_TYPE_UDP, true)
|
||||
if iface.GetConn() != nil {
|
||||
t.Error("GetConn() != nil, want nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_Send(t *testing.T) {
|
||||
iface := NewBaseInterface("test16", IF_TYPE_UDP, true)
|
||||
data := []byte("test data")
|
||||
|
||||
if err := iface.Send(data, ""); err != nil {
|
||||
t.Errorf("Send() error = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_ProcessIncoming(t *testing.T) {
|
||||
iface := NewBaseInterface("test17", IF_TYPE_UDP, true)
|
||||
|
||||
called := false
|
||||
callback := func(data []byte, ni NetworkInterface) {
|
||||
called = true
|
||||
}
|
||||
iface.SetPacketCallback(callback)
|
||||
|
||||
data := []byte("test")
|
||||
iface.ProcessIncoming(data)
|
||||
|
||||
if !called {
|
||||
t.Error("ProcessIncoming() did not call callback")
|
||||
}
|
||||
|
||||
iface.SetPacketCallback(nil)
|
||||
iface.ProcessIncoming(data)
|
||||
}
|
||||
|
||||
func TestBaseInterface_ProcessOutgoing(t *testing.T) {
|
||||
iface := NewBaseInterface("test18", IF_TYPE_UDP, true)
|
||||
data := []byte("test data")
|
||||
|
||||
if err := iface.ProcessOutgoing(data); err != nil {
|
||||
t.Errorf("ProcessOutgoing() error = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_SendPathRequest(t *testing.T) {
|
||||
iface := NewBaseInterface("test19", IF_TYPE_UDP, true)
|
||||
data := []byte("path request")
|
||||
|
||||
if err := iface.SendPathRequest(data); err != nil {
|
||||
t.Errorf("SendPathRequest() error = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_SendLinkPacket(t *testing.T) {
|
||||
iface := NewBaseInterface("test20", IF_TYPE_UDP, true)
|
||||
dest := []byte("destination")
|
||||
data := []byte("link data")
|
||||
timestamp := time.Now()
|
||||
|
||||
if err := iface.SendLinkPacket(dest, data, timestamp); err != nil {
|
||||
t.Errorf("SendLinkPacket() error = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseInterface_GetBandwidthAvailable(t *testing.T) {
|
||||
iface := NewBaseInterface("test21", IF_TYPE_UDP, true)
|
||||
|
||||
if !iface.GetBandwidthAvailable() {
|
||||
t.Error("GetBandwidthAvailable() = false, want true when no recent transmission")
|
||||
}
|
||||
|
||||
iface.lastTx = time.Now()
|
||||
iface.TxBytes = 0
|
||||
if !iface.GetBandwidthAvailable() {
|
||||
t.Error("GetBandwidthAvailable() = false, want true when TxBytes is 0")
|
||||
}
|
||||
|
||||
iface.lastTx = time.Now().Add(-500 * time.Millisecond)
|
||||
iface.TxBytes = 1000
|
||||
iface.Bitrate = 1000000
|
||||
|
||||
if !iface.GetBandwidthAvailable() {
|
||||
t.Error("GetBandwidthAvailable() = false, want true when usage is below threshold")
|
||||
}
|
||||
|
||||
iface.TxBytes = 10000000
|
||||
iface.Bitrate = 1000
|
||||
if iface.GetBandwidthAvailable() {
|
||||
t.Error("GetBandwidthAvailable() = true, want false when usage exceeds threshold")
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package common
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package config
|
||||
|
||||
import (
|
||||
@@ -39,6 +41,7 @@ type Config struct {
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
// bearer:disable go_gosec_filesystem_filereadtaint
|
||||
file, err := os.Open(path) // #nosec G304
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -222,7 +225,6 @@ func InitConfig() (*Config, error) {
|
||||
cfg.Logging.Level = "info"
|
||||
cfg.Logging.File = filepath.Join(GetConfigDir(), "reticulum.log")
|
||||
|
||||
// Add default interfaces
|
||||
cfg.Interfaces = append(cfg.Interfaces, struct {
|
||||
Name string
|
||||
Type string
|
||||
|
||||
192
pkg/config/config_test.go
Normal file
192
pkg/config/config_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "test_config")
|
||||
|
||||
configContent := `[identity]
|
||||
name = test-identity
|
||||
storage_path = /tmp/test-storage
|
||||
|
||||
[transport]
|
||||
announce_interval = 300
|
||||
path_request_timeout = 15
|
||||
max_hops = 8
|
||||
bitrate_limit = 1000000
|
||||
|
||||
[logging]
|
||||
level = info
|
||||
file = /tmp/test.log
|
||||
|
||||
[interface test-interface]
|
||||
type = UDPInterface
|
||||
enabled = true
|
||||
listen_ip = 127.0.0.1
|
||||
listen_port = 37696
|
||||
`
|
||||
|
||||
if err := os.WriteFile(configPath, []byte(configContent), 0600); err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
t.Fatal("LoadConfig() returned nil")
|
||||
}
|
||||
|
||||
if len(cfg.Interfaces) == 0 {
|
||||
t.Error("No interfaces loaded")
|
||||
}
|
||||
|
||||
iface := cfg.Interfaces[0]
|
||||
if iface.Type != "UDPInterface" {
|
||||
t.Errorf("Interface type = %s, want UDPInterface", iface.Type)
|
||||
}
|
||||
if !iface.Enabled {
|
||||
t.Error("Interface should be enabled")
|
||||
}
|
||||
if iface.ListenIP != "127.0.0.1" {
|
||||
t.Errorf("Interface ListenIP = %s, want 127.0.0.1", iface.ListenIP)
|
||||
}
|
||||
if iface.ListenPort != 37696 {
|
||||
t.Errorf("Interface ListenPort = %d, want 37696", iface.ListenPort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_NonexistentFile(t *testing.T) {
|
||||
_, err := LoadConfig("/nonexistent/path/config")
|
||||
if err == nil {
|
||||
t.Error("LoadConfig() should return error for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_EmptyFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "empty_config")
|
||||
|
||||
if err := os.WriteFile(configPath, []byte(""), 0600); err != nil {
|
||||
t.Fatalf("Failed to write empty config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
t.Fatal("LoadConfig() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_CommentsAndEmptyLines(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "test_config")
|
||||
|
||||
configContent := `# Comment line
|
||||
|
||||
[identity]
|
||||
name = test
|
||||
# Another comment
|
||||
|
||||
[interface test-interface]
|
||||
# Interface comment
|
||||
type = UDPInterface
|
||||
enabled = true
|
||||
`
|
||||
|
||||
if err := os.WriteFile(configPath, []byte(configContent), 0600); err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
t.Fatal("LoadConfig() returned nil")
|
||||
}
|
||||
|
||||
if cfg.Identity.Name != "test" {
|
||||
t.Errorf("Identity.Name = %s, want test", cfg.Identity.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "test_config")
|
||||
|
||||
cfg := &Config{}
|
||||
cfg.Identity.Name = "test-identity"
|
||||
cfg.Identity.StoragePath = "/tmp/test"
|
||||
cfg.Transport.AnnounceInterval = 600
|
||||
cfg.Logging.Level = "debug"
|
||||
cfg.Logging.File = "/tmp/test.log"
|
||||
|
||||
if err := SaveConfig(cfg, configPath); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if loaded.Identity.Name != "test-identity" {
|
||||
t.Errorf("Identity.Name = %s, want test-identity", loaded.Identity.Name)
|
||||
}
|
||||
if loaded.Transport.AnnounceInterval != 600 {
|
||||
t.Errorf("Transport.AnnounceInterval = %d, want 600", loaded.Transport.AnnounceInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConfigDir(t *testing.T) {
|
||||
dir := GetConfigDir()
|
||||
if dir == "" {
|
||||
t.Error("GetConfigDir() returned empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultConfigPath(t *testing.T) {
|
||||
path := GetDefaultConfigPath()
|
||||
if path == "" {
|
||||
t.Error("GetDefaultConfigPath() returned empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureConfigDir(t *testing.T) {
|
||||
if err := EnsureConfigDir(); err != nil {
|
||||
t.Fatalf("EnsureConfigDir() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer func() {
|
||||
if originalHome != "" {
|
||||
os.Setenv("HOME", originalHome)
|
||||
}
|
||||
}()
|
||||
|
||||
os.Setenv("HOME", tmpDir)
|
||||
|
||||
cfg, err := InitConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("InitConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
t.Fatal("InitConfig() returned nil")
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package cryptography
|
||||
|
||||
import (
|
||||
|
||||
@@ -86,7 +86,6 @@ func TestAES256CBC_InvalidKeySize(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestDecryptAES256CBCErrorCases(t *testing.T) {
|
||||
key, err := GenerateAES256Key()
|
||||
if err != nil {
|
||||
@@ -119,10 +118,16 @@ func TestDecryptAES256CBCErrorCases(t *testing.T) {
|
||||
t.Fatalf("Failed to create test ciphertext: %v", err)
|
||||
}
|
||||
|
||||
// Corrupt the last byte (which affects padding)
|
||||
// Corrupt the byte that XORs with the last padding byte.
|
||||
// In CBC, P[i] = D(C[i]) ^ C[i-1].
|
||||
// The last byte of plaintext P[len-1] depends on C[len-1] and C[len-1-BlockSize].
|
||||
// If we modify C[len-1-BlockSize], we flip the bits of P[len-1] predictably.
|
||||
// If we modify C[len-1] (the last byte of ciphertext), we scramble the whole block D(C[len-1]),
|
||||
// which might accidentally result in valid padding (e.g. 0x01).
|
||||
// So we corrupt the IV (or previous block) corresponding to the last byte.
|
||||
corruptedCiphertext := make([]byte, len(ciphertext))
|
||||
copy(corruptedCiphertext, ciphertext)
|
||||
corruptedCiphertext[len(corruptedCiphertext)-1] ^= 0xFF
|
||||
corruptedCiphertext[len(ciphertext)-aes.BlockSize-1] ^= 0xFF
|
||||
|
||||
_, err = DecryptAES256CBC(key, corruptedCiphertext)
|
||||
if err == nil {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package cryptography
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package cryptography
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package cryptography
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,17 +1,50 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package cryptography
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
"errors"
|
||||
"math"
|
||||
)
|
||||
|
||||
func DeriveKey(secret, salt, info []byte, length int) ([]byte, error) {
|
||||
hkdfReader := hkdf.New(sha256.New, secret, salt, info)
|
||||
key := make([]byte, length)
|
||||
if _, err := io.ReadFull(hkdfReader, key); err != nil {
|
||||
return nil, err
|
||||
hashLen := 32
|
||||
|
||||
if length < 1 {
|
||||
return nil, errors.New("invalid output key length")
|
||||
}
|
||||
return key, nil
|
||||
|
||||
if len(secret) == 0 {
|
||||
return nil, errors.New("cannot derive key from empty input material")
|
||||
}
|
||||
|
||||
if len(salt) == 0 {
|
||||
salt = make([]byte, hashLen)
|
||||
}
|
||||
|
||||
if info == nil {
|
||||
info = []byte{}
|
||||
}
|
||||
|
||||
pseudorandomKey := hmac.New(sha256.New, salt)
|
||||
pseudorandomKey.Write(secret)
|
||||
prk := pseudorandomKey.Sum(nil)
|
||||
|
||||
block := []byte{}
|
||||
derived := []byte{}
|
||||
|
||||
iterations := int(math.Ceil(float64(length) / float64(hashLen)))
|
||||
for i := 0; i < iterations; i++ {
|
||||
h := hmac.New(sha256.New, prk)
|
||||
h.Write(block)
|
||||
h.Write(info)
|
||||
counter := byte((i + 1) % (0xFF + 1))
|
||||
h.Write([]byte{counter})
|
||||
block = h.Sum(nil)
|
||||
derived = append(derived, block...)
|
||||
}
|
||||
|
||||
return derived[:length], nil
|
||||
}
|
||||
|
||||
@@ -77,8 +77,8 @@ func TestDeriveKeyEdgeCases(t *testing.T) {
|
||||
|
||||
t.Run("EmptySecret", func(t *testing.T) {
|
||||
_, err := DeriveKey([]byte{}, salt, info, 32)
|
||||
if err != nil {
|
||||
t.Errorf("DeriveKey failed with empty secret: %v", err)
|
||||
if err == nil {
|
||||
t.Errorf("DeriveKey should fail with empty secret")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -97,12 +97,9 @@ func TestDeriveKeyEdgeCases(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("ZeroLength", func(t *testing.T) {
|
||||
key, err := DeriveKey(secret, salt, info, 0)
|
||||
if err != nil {
|
||||
t.Errorf("DeriveKey failed with zero length: %v", err)
|
||||
}
|
||||
if len(key) != 0 {
|
||||
t.Errorf("DeriveKey with zero length returned non-empty key: %x", key)
|
||||
_, err := DeriveKey(secret, salt, info, 0)
|
||||
if err == nil {
|
||||
t.Errorf("DeriveKey should fail with zero length")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package cryptography
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package debug
|
||||
|
||||
import (
|
||||
@@ -18,8 +20,8 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
debugLevel = flag.Int("debug", 3, "debug level (1-7)")
|
||||
logger *slog.Logger
|
||||
debugLevel = flag.Int("debug", 3, "debug level (1-7)")
|
||||
logger *slog.Logger
|
||||
initialized bool
|
||||
)
|
||||
|
||||
@@ -113,4 +115,3 @@ func SetDebugLevel(level int) {
|
||||
func GetDebugLevel() int {
|
||||
return *debugLevel
|
||||
}
|
||||
|
||||
|
||||
185
pkg/debug/debug_test.go
Normal file
185
pkg/debug/debug_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
originalFlag := flag.CommandLine
|
||||
defer func() {
|
||||
flag.CommandLine = originalFlag
|
||||
initialized = false
|
||||
}()
|
||||
|
||||
flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
debugLevel = flag.Int("debug", 3, "debug level")
|
||||
|
||||
Init()
|
||||
|
||||
if !initialized {
|
||||
t.Error("Init() should set initialized to true")
|
||||
}
|
||||
|
||||
if GetLogger() == nil {
|
||||
t.Error("GetLogger() should return non-nil logger after Init()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLogger(t *testing.T) {
|
||||
originalFlag := flag.CommandLine
|
||||
defer func() {
|
||||
flag.CommandLine = originalFlag
|
||||
initialized = false
|
||||
}()
|
||||
|
||||
flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
debugLevel = flag.Int("debug", 3, "debug level")
|
||||
initialized = false
|
||||
|
||||
logger := GetLogger()
|
||||
if logger == nil {
|
||||
t.Error("GetLogger() should return non-nil logger")
|
||||
}
|
||||
|
||||
if !initialized {
|
||||
t.Error("GetLogger() should initialize if not already initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLog(t *testing.T) {
|
||||
originalFlag := flag.CommandLine
|
||||
defer func() {
|
||||
flag.CommandLine = originalFlag
|
||||
initialized = false
|
||||
}()
|
||||
|
||||
flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
debugLevel = flag.Int("debug", 7, "debug level")
|
||||
initialized = false
|
||||
|
||||
Log(DEBUG_INFO, "test message", "key", "value")
|
||||
}
|
||||
|
||||
func TestSetDebugLevel(t *testing.T) {
|
||||
originalFlag := flag.CommandLine
|
||||
defer func() {
|
||||
flag.CommandLine = originalFlag
|
||||
initialized = false
|
||||
}()
|
||||
|
||||
flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
debugLevel = flag.Int("debug", 3, "debug level")
|
||||
initialized = false
|
||||
|
||||
SetDebugLevel(5)
|
||||
if GetDebugLevel() != 5 {
|
||||
t.Errorf("SetDebugLevel(5) did not set level correctly, got %d", GetDebugLevel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDebugLevel(t *testing.T) {
|
||||
originalFlag := flag.CommandLine
|
||||
defer func() {
|
||||
flag.CommandLine = originalFlag
|
||||
initialized = false
|
||||
}()
|
||||
|
||||
flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
debugLevel = flag.Int("debug", 4, "debug level")
|
||||
|
||||
level := GetDebugLevel()
|
||||
if level != 4 {
|
||||
t.Errorf("GetDebugLevel() = %d, want 4", level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLog_LevelFiltering(t *testing.T) {
|
||||
originalFlag := flag.CommandLine
|
||||
defer func() {
|
||||
flag.CommandLine = originalFlag
|
||||
initialized = false
|
||||
}()
|
||||
|
||||
flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
debugLevel = flag.Int("debug", 3, "debug level")
|
||||
initialized = false
|
||||
|
||||
Log(DEBUG_TRACE, "trace message")
|
||||
Log(DEBUG_INFO, "info message")
|
||||
Log(DEBUG_ERROR, "error message")
|
||||
}
|
||||
|
||||
func TestConstants(t *testing.T) {
|
||||
if DEBUG_CRITICAL != 1 {
|
||||
t.Errorf("DEBUG_CRITICAL = %d, want 1", DEBUG_CRITICAL)
|
||||
}
|
||||
if DEBUG_ERROR != 2 {
|
||||
t.Errorf("DEBUG_ERROR = %d, want 2", DEBUG_ERROR)
|
||||
}
|
||||
if DEBUG_INFO != 3 {
|
||||
t.Errorf("DEBUG_INFO = %d, want 3", DEBUG_INFO)
|
||||
}
|
||||
if DEBUG_VERBOSE != 4 {
|
||||
t.Errorf("DEBUG_VERBOSE = %d, want 4", DEBUG_VERBOSE)
|
||||
}
|
||||
if DEBUG_TRACE != 5 {
|
||||
t.Errorf("DEBUG_TRACE = %d, want 5", DEBUG_TRACE)
|
||||
}
|
||||
if DEBUG_PACKETS != 6 {
|
||||
t.Errorf("DEBUG_PACKETS = %d, want 6", DEBUG_PACKETS)
|
||||
}
|
||||
if DEBUG_ALL != 7 {
|
||||
t.Errorf("DEBUG_ALL = %d, want 7", DEBUG_ALL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLog_WithArgs(t *testing.T) {
|
||||
originalFlag := flag.CommandLine
|
||||
defer func() {
|
||||
flag.CommandLine = originalFlag
|
||||
initialized = false
|
||||
}()
|
||||
|
||||
flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
debugLevel = flag.Int("debug", 7, "debug level")
|
||||
initialized = false
|
||||
|
||||
Log(DEBUG_INFO, "test message", "key1", "value1", "key2", "value2")
|
||||
}
|
||||
|
||||
func TestInit_MultipleCalls(t *testing.T) {
|
||||
originalFlag := flag.CommandLine
|
||||
defer func() {
|
||||
flag.CommandLine = originalFlag
|
||||
initialized = false
|
||||
}()
|
||||
|
||||
flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
debugLevel = flag.Int("debug", 3, "debug level")
|
||||
initialized = false
|
||||
|
||||
Init()
|
||||
firstLogger := GetLogger()
|
||||
|
||||
Init()
|
||||
secondLogger := GetLogger()
|
||||
|
||||
if firstLogger != secondLogger {
|
||||
t.Error("Multiple Init() calls should not create new loggers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLog_DisabledLevel(t *testing.T) {
|
||||
originalFlag := flag.CommandLine
|
||||
defer func() {
|
||||
flag.CommandLine = originalFlag
|
||||
initialized = false
|
||||
}()
|
||||
|
||||
flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError)
|
||||
debugLevel = flag.Int("debug", 1, "debug level")
|
||||
initialized = false
|
||||
|
||||
Log(DEBUG_TRACE, "this should be filtered")
|
||||
}
|
||||
@@ -1,16 +1,24 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package destination
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/announce"
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/common"
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/identity"
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/transport"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/announce"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/identity"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/packet"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -36,15 +44,6 @@ const (
|
||||
|
||||
RATCHET_COUNT = 512 // Default number of retained ratchet keys
|
||||
RATCHET_INTERVAL = 1800 // Minimum interval between ratchet rotations in seconds
|
||||
|
||||
// Debug levels
|
||||
DEBUG_CRITICAL = 1 // Critical errors
|
||||
DEBUG_ERROR = 2 // Non-critical errors
|
||||
DEBUG_INFO = 3 // Important information
|
||||
DEBUG_VERBOSE = 4 // Detailed information
|
||||
DEBUG_TRACE = 5 // Very detailed tracing
|
||||
DEBUG_PACKETS = 6 // Packet-level details
|
||||
DEBUG_ALL = 7 // Everything
|
||||
)
|
||||
|
||||
type PacketCallback = common.PacketCallback
|
||||
@@ -56,6 +55,21 @@ type RequestHandler struct {
|
||||
ResponseGenerator func(path string, data []byte, requestID []byte, linkID []byte, remoteIdentity *identity.Identity, requestedAt int64) []byte
|
||||
AllowMode byte
|
||||
AllowedList [][]byte
|
||||
AutoCompress bool
|
||||
}
|
||||
|
||||
type Transport interface {
|
||||
GetConfig() *common.ReticulumConfig
|
||||
GetInterfaces() map[string]common.NetworkInterface
|
||||
RegisterDestination(hash []byte, dest interface{})
|
||||
}
|
||||
|
||||
type IncomingLinkHandler func(pkt *packet.Packet, dest *Destination, transport interface{}, networkIface common.NetworkInterface) (interface{}, error)
|
||||
|
||||
var incomingLinkHandler IncomingLinkHandler
|
||||
|
||||
func RegisterIncomingLinkHandler(handler IncomingLinkHandler) {
|
||||
incomingLinkHandler = handler
|
||||
}
|
||||
|
||||
type Destination struct {
|
||||
@@ -65,7 +79,7 @@ type Destination struct {
|
||||
appName string
|
||||
aspects []string
|
||||
hashValue []byte
|
||||
transport *transport.Transport
|
||||
transport Transport
|
||||
|
||||
acceptsLinks bool
|
||||
proofStrategy byte
|
||||
@@ -74,11 +88,15 @@ type Destination struct {
|
||||
proofCallback ProofRequestedCallback
|
||||
linkCallback LinkEstablishedCallback
|
||||
|
||||
ratchetsEnabled bool
|
||||
ratchetPath string
|
||||
ratchetCount int
|
||||
ratchetInterval int
|
||||
enforceRatchets bool
|
||||
ratchetsEnabled bool
|
||||
ratchetPath string
|
||||
ratchetCount int
|
||||
ratchetInterval int
|
||||
enforceRatchets bool
|
||||
latestRatchetTime time.Time
|
||||
latestRatchetID []byte
|
||||
ratchets [][]byte
|
||||
ratchetFileLock sync.Mutex
|
||||
|
||||
defaultAppData []byte
|
||||
mutex sync.RWMutex
|
||||
@@ -86,15 +104,11 @@ type Destination struct {
|
||||
requestHandlers map[string]*RequestHandler
|
||||
}
|
||||
|
||||
func debugLog(level int, format string, v ...interface{}) {
|
||||
log.Printf("[DEBUG-%d] %s", level, fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
func New(id *identity.Identity, direction byte, destType byte, appName string, transport *transport.Transport, aspects ...string) (*Destination, error) {
|
||||
debugLog(DEBUG_INFO, "Creating new destination: app=%s type=%d direction=%d", appName, destType, direction)
|
||||
func New(id *identity.Identity, direction byte, destType byte, appName string, transport Transport, aspects ...string) (*Destination, error) {
|
||||
debug.Log(debug.DEBUG_INFO, "Creating new destination", "app", appName, "type", destType, "direction", direction)
|
||||
|
||||
if id == nil {
|
||||
debugLog(DEBUG_ERROR, "Cannot create destination: identity is nil")
|
||||
debug.Log(debug.DEBUG_ERROR, "Cannot create destination: identity is nil")
|
||||
return nil, errors.New("identity cannot be nil")
|
||||
}
|
||||
|
||||
@@ -114,18 +128,24 @@ func New(id *identity.Identity, direction byte, destType byte, appName string, t
|
||||
|
||||
// Generate destination hash
|
||||
d.hashValue = d.calculateHash()
|
||||
debugLog(DEBUG_VERBOSE, "Created destination with hash: %x", d.hashValue)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Created destination with hash", "hash", fmt.Sprintf("%x", d.hashValue))
|
||||
|
||||
// Auto-register with transport if direction is IN
|
||||
if (direction & IN) != 0 {
|
||||
transport.RegisterDestination(d.hashValue, d)
|
||||
debug.Log(debug.DEBUG_INFO, "Destination auto-registered with transport", "hash", fmt.Sprintf("%x", d.hashValue))
|
||||
}
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// FromHash creates a destination from a known hash (e.g., from an announce).
|
||||
// This is used by clients to create destination objects for servers they've discovered.
|
||||
func FromHash(hash []byte, id *identity.Identity, destType byte, transport *transport.Transport) (*Destination, error) {
|
||||
debugLog(DEBUG_INFO, "Creating destination from hash: %x", hash)
|
||||
func FromHash(hash []byte, id *identity.Identity, destType byte, transport Transport) (*Destination, error) {
|
||||
debug.Log(debug.DEBUG_INFO, "Creating destination from hash", "hash", fmt.Sprintf("%x", hash))
|
||||
|
||||
if id == nil {
|
||||
debugLog(DEBUG_ERROR, "Cannot create destination: identity is nil")
|
||||
debug.Log(debug.DEBUG_ERROR, "Cannot create destination: identity is nil")
|
||||
return nil, errors.New("identity cannot be nil")
|
||||
}
|
||||
|
||||
@@ -142,32 +162,32 @@ func FromHash(hash []byte, id *identity.Identity, destType byte, transport *tran
|
||||
requestHandlers: make(map[string]*RequestHandler),
|
||||
}
|
||||
|
||||
debugLog(DEBUG_VERBOSE, "Created destination from hash: %x", hash)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Created destination from hash", "hash", fmt.Sprintf("%x", hash))
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d *Destination) calculateHash() []byte {
|
||||
debugLog(DEBUG_TRACE, "Calculating hash for destination %s", d.ExpandName())
|
||||
debug.Log(debug.DEBUG_TRACE, "Calculating hash for destination", "name", d.ExpandName())
|
||||
|
||||
// destination_hash = SHA256(name_hash_10bytes + identity_hash_16bytes)[:16]
|
||||
// Identity hash is the truncated hash of the public key (16 bytes)
|
||||
identityHash := identity.TruncatedHash(d.identity.GetPublicKey())
|
||||
|
||||
|
||||
// Name hash is the FULL 32-byte SHA256, then we take first 10 bytes for concatenation
|
||||
nameHashFull := sha256.Sum256([]byte(d.ExpandName()))
|
||||
nameHash10 := nameHashFull[:10] // Only use 10 bytes
|
||||
nameHash10 := nameHashFull[:10] // Only use 10 bytes
|
||||
|
||||
debugLog(DEBUG_ALL, "Identity hash: %x", identityHash)
|
||||
debugLog(DEBUG_ALL, "Name hash (10 bytes): %x", nameHash10)
|
||||
debug.Log(debug.DEBUG_ALL, "Identity hash", "hash", fmt.Sprintf("%x", identityHash))
|
||||
debug.Log(debug.DEBUG_ALL, "Name hash (10 bytes)", "hash", fmt.Sprintf("%x", nameHash10))
|
||||
|
||||
// Concatenate name_hash (10 bytes) + identity_hash (16 bytes) = 26 bytes
|
||||
combined := append(nameHash10, identityHash...)
|
||||
|
||||
|
||||
// Then hash again and truncate to 16 bytes
|
||||
finalHashFull := sha256.Sum256(combined)
|
||||
finalHash := finalHashFull[:16]
|
||||
|
||||
debugLog(DEBUG_VERBOSE, "Calculated destination hash: %x", finalHash)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Calculated destination hash", "hash", fmt.Sprintf("%x", finalHash))
|
||||
|
||||
return finalHash
|
||||
}
|
||||
@@ -180,50 +200,52 @@ func (d *Destination) ExpandName() string {
|
||||
return name
|
||||
}
|
||||
|
||||
func (d *Destination) Announce(appData []byte) error {
|
||||
func (d *Destination) Announce(pathResponse bool, tag []byte, attachedInterface common.NetworkInterface) error {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
|
||||
log.Printf("[DEBUG-4] Announcing destination %s", d.ExpandName())
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Announcing destination", "name", d.ExpandName(), "path_response", pathResponse)
|
||||
|
||||
if appData == nil {
|
||||
appData = d.defaultAppData
|
||||
}
|
||||
appData := d.defaultAppData
|
||||
|
||||
// Create announce packet using announce package
|
||||
// Pass the destination hash, name, and app data
|
||||
announce, err := announce.New(d.identity, d.hashValue, d.ExpandName(), appData, false, d.transport.GetConfig())
|
||||
announceObj, err := announce.New(d.identity, d.hashValue, d.ExpandName(), appData, pathResponse, d.transport.GetConfig())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create announce: %w", err)
|
||||
}
|
||||
|
||||
packet := announce.GetPacket()
|
||||
packet := announceObj.GetPacket()
|
||||
if packet == nil {
|
||||
return errors.New("failed to create announce packet")
|
||||
}
|
||||
|
||||
// Send announce packet to all interfaces
|
||||
log.Printf("[DEBUG-4] Sending announce packet to all interfaces")
|
||||
if pathResponse && tag != nil {
|
||||
debug.Log(debug.DEBUG_INFO, "Sending path response announce", "tag", fmt.Sprintf("%x", tag))
|
||||
}
|
||||
|
||||
if d.transport == nil {
|
||||
return errors.New("transport not initialized")
|
||||
}
|
||||
|
||||
interfaces := d.transport.GetInterfaces()
|
||||
log.Printf("[DEBUG-7] Got %d interfaces from transport", len(interfaces))
|
||||
|
||||
var lastErr error
|
||||
for name, iface := range interfaces {
|
||||
log.Printf("[DEBUG-7] Checking interface %s: enabled=%v, online=%v", name, iface.IsEnabled(), iface.IsOnline())
|
||||
if iface.IsEnabled() && iface.IsOnline() {
|
||||
log.Printf("[DEBUG-7] Sending announce to interface %s (%d bytes)", name, len(packet))
|
||||
if err := iface.Send(packet, ""); err != nil {
|
||||
log.Printf("[ERROR] Failed to send announce on interface %s: %v", name, err)
|
||||
if attachedInterface != nil {
|
||||
if attachedInterface.IsEnabled() && attachedInterface.IsOnline() {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Sending announce to attached interface", "name", attachedInterface.GetName())
|
||||
if err := attachedInterface.Send(packet, ""); err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to send announce on attached interface", "error", err)
|
||||
lastErr = err
|
||||
} else {
|
||||
log.Printf("[DEBUG-7] Successfully sent announce to interface %s", name)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[DEBUG-7] Skipping interface %s (not enabled or not online)", name)
|
||||
}
|
||||
} else {
|
||||
interfaces := d.transport.GetInterfaces()
|
||||
for name, iface := range interfaces {
|
||||
if iface.IsEnabled() && iface.IsOnline() {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Sending announce to interface", "name", name)
|
||||
if err := iface.Send(packet, ""); err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to send announce on interface", "name", name, "error", err)
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,11 +256,11 @@ func (d *Destination) AcceptsLinks(accepts bool) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
d.acceptsLinks = accepts
|
||||
|
||||
|
||||
// Register with transport if accepting links
|
||||
if accepts && d.transport != nil {
|
||||
d.transport.RegisterDestination(d.hashValue, d)
|
||||
debugLog(DEBUG_VERBOSE, "Destination %x registered with transport for link requests", d.hashValue)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Destination registered with transport for link requests", "hash", fmt.Sprintf("%x", d.hashValue))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -254,22 +276,28 @@ func (d *Destination) GetLinkCallback() common.LinkEstablishedCallback {
|
||||
return d.linkCallback
|
||||
}
|
||||
|
||||
func (d *Destination) HandleIncomingLinkRequest(linkID []byte, transport interface{}, networkIface common.NetworkInterface) error {
|
||||
debugLog(DEBUG_INFO, "Handling incoming link request for destination %x", d.GetHash())
|
||||
|
||||
// Import link package here to avoid circular dependency at package level
|
||||
// We'll use dynamic import by having the caller create the link
|
||||
// For now, just call the callback with a placeholder
|
||||
|
||||
if d.linkCallback != nil {
|
||||
debugLog(DEBUG_INFO, "Calling link established callback")
|
||||
// Pass linkID as the link object for now
|
||||
// The callback will need to handle creating the actual link
|
||||
d.linkCallback(linkID)
|
||||
} else {
|
||||
debugLog(DEBUG_VERBOSE, "No link callback set")
|
||||
func (d *Destination) HandleIncomingLinkRequest(pkt interface{}, transport interface{}, networkIface common.NetworkInterface) error {
|
||||
debug.Log(debug.DEBUG_INFO, "Handling incoming link request for destination", "hash", fmt.Sprintf("%x", d.GetHash()))
|
||||
|
||||
pktObj, ok := pkt.(*packet.Packet)
|
||||
if !ok {
|
||||
return errors.New("invalid packet type")
|
||||
}
|
||||
|
||||
|
||||
if incomingLinkHandler == nil {
|
||||
return errors.New("no incoming link handler registered")
|
||||
}
|
||||
|
||||
linkIface, err := incomingLinkHandler(pktObj, d, transport, networkIface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to handle link request: %w", err)
|
||||
}
|
||||
|
||||
if d.linkCallback != nil && linkIface != nil {
|
||||
debug.Log(debug.DEBUG_INFO, "Calling link established callback")
|
||||
d.linkCallback(linkIface)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -279,6 +307,35 @@ func (d *Destination) SetPacketCallback(callback common.PacketCallback) {
|
||||
d.packetCallback = callback
|
||||
}
|
||||
|
||||
func (d *Destination) Receive(pkt *packet.Packet, iface common.NetworkInterface) {
|
||||
d.mutex.RLock()
|
||||
callback := d.packetCallback
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if callback == nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "No packet callback set for destination")
|
||||
return
|
||||
}
|
||||
|
||||
if pkt.PacketType == packet.PacketTypeLinkReq {
|
||||
debug.Log(debug.DEBUG_INFO, "Received link request for destination")
|
||||
if err := d.HandleIncomingLinkRequest(pkt, d.transport, iface); err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to handle incoming link request", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
plaintext, err := d.Decrypt(pkt.Data)
|
||||
if err != nil {
|
||||
debug.Log(debug.DEBUG_INFO, "Failed to decrypt packet data", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "Destination received packet", "bytes", len(plaintext))
|
||||
|
||||
callback(plaintext, iface)
|
||||
}
|
||||
|
||||
func (d *Destination) SetProofRequestedCallback(callback common.ProofRequestedCallback) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
@@ -295,8 +352,27 @@ func (d *Destination) EnableRatchets(path string) bool {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
|
||||
if path == "" {
|
||||
debug.Log(debug.DEBUG_ERROR, "No ratchet file path specified")
|
||||
return false
|
||||
}
|
||||
|
||||
d.ratchetsEnabled = true
|
||||
d.ratchetPath = path
|
||||
d.latestRatchetTime = time.Time{} // Zero time to force rotation
|
||||
|
||||
// Load or initialize ratchets
|
||||
if err := d.reloadRatchets(); err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to load ratchets", "error", err)
|
||||
// Initialize empty ratchet list
|
||||
d.ratchets = make([][]byte, 0)
|
||||
if err := d.persistRatchets(); err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to create initial ratchet file", "error", err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "Ratchets enabled", "path", path)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -377,34 +453,88 @@ func (d *Destination) DeregisterRequestHandler(path string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *Destination) GetRequestHandler(pathHash []byte) func([]byte, []byte, []byte, []byte, *identity.Identity, time.Time) interface{} {
|
||||
d.mutex.RLock()
|
||||
defer d.mutex.RUnlock()
|
||||
|
||||
for _, handler := range d.requestHandlers {
|
||||
handlerPathHash := identity.TruncatedHash([]byte(handler.Path))
|
||||
if string(handlerPathHash) == string(pathHash) {
|
||||
return func(pathHash []byte, data []byte, requestID []byte, linkID []byte, remoteIdentity *identity.Identity, requestedAt time.Time) interface{} {
|
||||
allowed := false
|
||||
if handler.AllowMode == ALLOW_ALL {
|
||||
allowed = true
|
||||
} else if handler.AllowMode == ALLOW_LIST && remoteIdentity != nil {
|
||||
remoteHash := remoteIdentity.Hash()
|
||||
for _, allowedHash := range handler.AllowedList {
|
||||
if string(remoteHash) == string(allowedHash) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := handler.ResponseGenerator(handler.Path, data, requestID, linkID, remoteIdentity, requestedAt.Unix())
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Destination) HandleRequest(path string, data []byte, requestID []byte, linkID []byte, remoteIdentity *identity.Identity, requestedAt int64) []byte {
|
||||
d.mutex.RLock()
|
||||
handler, exists := d.requestHandlers[path]
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
debug.Log(debug.DEBUG_INFO, "No handler registered for path", "path", path)
|
||||
return []byte(">Not Found\n\nThe requested resource was not found.")
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Calling request handler", "path", path)
|
||||
result := handler.ResponseGenerator(path, data, requestID, linkID, remoteIdentity, requestedAt)
|
||||
if result == nil {
|
||||
return []byte(">Not Found\n\nThe requested resource was not found.")
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (d *Destination) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
if d.destType == PLAIN {
|
||||
log.Printf("[DEBUG-4] Using plaintext transmission for PLAIN destination")
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Using plaintext transmission for PLAIN destination")
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
if d.identity == nil {
|
||||
log.Printf("[DEBUG-3] Cannot encrypt: no identity available")
|
||||
debug.Log(debug.DEBUG_INFO, "Cannot encrypt: no identity available")
|
||||
return nil, errors.New("no identity available for encryption")
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG-4] Encrypting %d bytes for destination type %d", len(plaintext), d.destType)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Encrypting bytes for destination", "bytes", len(plaintext), "destType", d.destType)
|
||||
|
||||
switch d.destType {
|
||||
case SINGLE:
|
||||
recipientKey := d.identity.GetPublicKey()
|
||||
log.Printf("[DEBUG-4] Encrypting for single recipient with key %x", recipientKey[:8])
|
||||
recipientKey := d.identity.GetEncryptionKey()
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Encrypting for single recipient", "key", fmt.Sprintf("%x", recipientKey[:8]))
|
||||
return d.identity.Encrypt(plaintext, recipientKey)
|
||||
case GROUP:
|
||||
key := d.identity.GetCurrentRatchetKey()
|
||||
if key == nil {
|
||||
log.Printf("[DEBUG-3] Cannot encrypt: no ratchet key available")
|
||||
debug.Log(debug.DEBUG_INFO, "Cannot encrypt: no ratchet key available")
|
||||
return nil, errors.New("no ratchet key available")
|
||||
}
|
||||
log.Printf("[DEBUG-4] Encrypting for group with ratchet key %x", key[:8])
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Encrypting for group with ratchet key", "key", fmt.Sprintf("%x", key[:8]))
|
||||
return d.identity.EncryptWithHMAC(plaintext, key)
|
||||
default:
|
||||
log.Printf("[DEBUG-3] Unsupported destination type %d for encryption", d.destType)
|
||||
debug.Log(debug.DEBUG_INFO, "Unsupported destination type for encryption", "destType", d.destType)
|
||||
return nil, errors.New("unsupported destination type for encryption")
|
||||
}
|
||||
}
|
||||
@@ -465,3 +595,186 @@ func (d *Destination) GetHash() []byte {
|
||||
}
|
||||
return d.hashValue
|
||||
}
|
||||
|
||||
func (d *Destination) persistRatchets() error {
|
||||
d.ratchetFileLock.Lock()
|
||||
defer d.ratchetFileLock.Unlock()
|
||||
|
||||
if !d.ratchetsEnabled || d.ratchetPath == "" {
|
||||
return errors.New("ratchets not enabled or no path specified")
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_PACKETS, "Persisting ratchets", "count", len(d.ratchets), "path", d.ratchetPath)
|
||||
|
||||
// Pack ratchets using msgpack
|
||||
packedRatchets, err := msgpack.Marshal(d.ratchets)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pack ratchets: %w", err)
|
||||
}
|
||||
|
||||
// Sign the packed ratchets
|
||||
signature, err := d.Sign(packedRatchets)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign ratchets: %w", err)
|
||||
}
|
||||
|
||||
// Create structure
|
||||
persistedData := map[string][]byte{
|
||||
"signature": signature,
|
||||
"ratchets": packedRatchets,
|
||||
}
|
||||
|
||||
// Pack the entire structure
|
||||
finalData, err := msgpack.Marshal(persistedData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pack ratchet data: %w", err)
|
||||
}
|
||||
|
||||
// Write to temporary file first, then rename (atomic operation)
|
||||
tempPath := d.ratchetPath + ".tmp"
|
||||
file, err := os.Create(tempPath) // #nosec G304
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp ratchet file: %w", err)
|
||||
}
|
||||
|
||||
if _, err := file.Write(finalData); err != nil {
|
||||
// #nosec G104 - Error already being handled, cleanup errors are non-critical
|
||||
file.Close()
|
||||
// #nosec G104 - Error already being handled, cleanup errors are non-critical
|
||||
os.Remove(tempPath)
|
||||
return fmt.Errorf("failed to write ratchet data: %w", err)
|
||||
}
|
||||
// #nosec G104 - File is being closed after successful write, error is non-critical
|
||||
file.Close()
|
||||
|
||||
// Remove old file if exists
|
||||
if _, err := os.Stat(d.ratchetPath); err == nil {
|
||||
// #nosec G104 - Removing old file, error is non-critical if it doesn't exist
|
||||
os.Remove(d.ratchetPath)
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := os.Rename(tempPath, d.ratchetPath); err != nil {
|
||||
// #nosec G104 - Error already being handled, cleanup errors are non-critical
|
||||
os.Remove(tempPath)
|
||||
return fmt.Errorf("failed to rename ratchet file: %w", err)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_PACKETS, "Ratchets persisted successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Destination) reloadRatchets() error {
|
||||
d.ratchetFileLock.Lock()
|
||||
defer d.ratchetFileLock.Unlock()
|
||||
|
||||
if _, err := os.Stat(d.ratchetPath); os.IsNotExist(err) {
|
||||
debug.Log(debug.DEBUG_INFO, "No existing ratchet data found, initializing new ratchet file")
|
||||
d.ratchets = make([][]byte, 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
file, err := os.Open(d.ratchetPath) // #nosec G304
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open ratchet file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Read all data
|
||||
fileData, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read ratchet file: %w", err)
|
||||
}
|
||||
|
||||
// Unpack outer structure
|
||||
var persistedData map[string][]byte
|
||||
if err := msgpack.Unmarshal(fileData, &persistedData); err != nil {
|
||||
return fmt.Errorf("failed to unpack ratchet data: %w", err)
|
||||
}
|
||||
|
||||
signature, hasSignature := persistedData["signature"]
|
||||
packedRatchets, hasRatchets := persistedData["ratchets"]
|
||||
|
||||
if !hasSignature || !hasRatchets {
|
||||
return fmt.Errorf("invalid ratchet file format")
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !d.identity.Verify(packedRatchets, signature) {
|
||||
return fmt.Errorf("invalid ratchet file signature")
|
||||
}
|
||||
|
||||
// Unpack ratchet list
|
||||
if err := msgpack.Unmarshal(packedRatchets, &d.ratchets); err != nil {
|
||||
return fmt.Errorf("failed to unpack ratchet list: %w", err)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "Ratchets reloaded successfully", "count", len(d.ratchets))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Destination) RotateRatchets() error {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
|
||||
if !d.ratchetsEnabled {
|
||||
return errors.New("ratchets not enabled")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if !d.latestRatchetTime.IsZero() && now.Before(d.latestRatchetTime.Add(time.Duration(d.ratchetInterval)*time.Second)) {
|
||||
debug.Log(debug.DEBUG_TRACE, "Ratchet rotation interval not reached")
|
||||
return nil
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "Rotating ratchets", "destination", d.ExpandName())
|
||||
|
||||
// Generate new ratchet key (32 bytes for X25519 private key)
|
||||
newRatchet := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, newRatchet); err != nil {
|
||||
return fmt.Errorf("failed to generate new ratchet: %w", err)
|
||||
}
|
||||
|
||||
// Insert at beginning (most recent first)
|
||||
d.ratchets = append([][]byte{newRatchet}, d.ratchets...)
|
||||
d.latestRatchetTime = now
|
||||
|
||||
// Get ratchet public key for ID
|
||||
ratchetPub, err := curve25519.X25519(newRatchet, curve25519.Basepoint)
|
||||
if err == nil {
|
||||
d.latestRatchetID = identity.TruncatedHash(ratchetPub)[:identity.NAME_HASH_LENGTH/8]
|
||||
}
|
||||
|
||||
// Clean old ratchets
|
||||
d.cleanRatchets()
|
||||
|
||||
// Persist to disk
|
||||
if err := d.persistRatchets(); err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to persist ratchets after rotation", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "Ratchet rotation completed", "total_ratchets", len(d.ratchets))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Destination) cleanRatchets() {
|
||||
if len(d.ratchets) > d.ratchetCount {
|
||||
debug.Log(debug.DEBUG_TRACE, "Cleaning old ratchets", "before", len(d.ratchets), "keeping", d.ratchetCount)
|
||||
d.ratchets = d.ratchets[:d.ratchetCount]
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Destination) GetRatchets() [][]byte {
|
||||
d.mutex.RLock()
|
||||
defer d.mutex.RUnlock()
|
||||
|
||||
if !d.ratchetsEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return copy to prevent external modification
|
||||
ratchetsCopy := make([][]byte, len(d.ratchets))
|
||||
copy(ratchetsCopy, d.ratchets)
|
||||
return ratchetsCopy
|
||||
}
|
||||
|
||||
152
pkg/destination/destination_test.go
Normal file
152
pkg/destination/destination_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package destination
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/identity"
|
||||
)
|
||||
|
||||
type mockTransport struct {
|
||||
config *common.ReticulumConfig
|
||||
interfaces map[string]common.NetworkInterface
|
||||
}
|
||||
|
||||
func (m *mockTransport) GetConfig() *common.ReticulumConfig {
|
||||
return m.config
|
||||
}
|
||||
|
||||
func (m *mockTransport) GetInterfaces() map[string]common.NetworkInterface {
|
||||
return m.interfaces
|
||||
}
|
||||
|
||||
func (m *mockTransport) RegisterDestination(hash []byte, dest interface{}) {
|
||||
}
|
||||
|
||||
type mockInterface struct {
|
||||
common.BaseInterface
|
||||
}
|
||||
|
||||
func (m *mockInterface) Send(data []byte, address string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewDestination(t *testing.T) {
|
||||
id, _ := identity.New()
|
||||
transport := &mockTransport{config: &common.ReticulumConfig{}}
|
||||
|
||||
dest, err := New(id, IN|OUT, SINGLE, "testapp", transport, "testaspect")
|
||||
if err != nil {
|
||||
t.Fatalf("New failed: %v", err)
|
||||
}
|
||||
if dest == nil {
|
||||
t.Fatal("New returned nil")
|
||||
}
|
||||
|
||||
if dest.ExpandName() != "testapp.testaspect" {
|
||||
t.Errorf("Expected name testapp.testaspect, got %s", dest.ExpandName())
|
||||
}
|
||||
|
||||
hash := dest.GetHash()
|
||||
if len(hash) != 16 {
|
||||
t.Errorf("Expected hash length 16, got %d", len(hash))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromHash(t *testing.T) {
|
||||
id, _ := identity.New()
|
||||
transport := &mockTransport{}
|
||||
hash := make([]byte, 16)
|
||||
|
||||
dest, err := FromHash(hash, id, SINGLE, transport)
|
||||
if err != nil {
|
||||
t.Fatalf("FromHash failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dest.GetHash(), hash) {
|
||||
t.Error("Hashes don't match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestHandlers(t *testing.T) {
|
||||
id, _ := identity.New()
|
||||
dest, _ := New(id, IN, SINGLE, "test", &mockTransport{})
|
||||
|
||||
path := "test/path"
|
||||
response := []byte("hello")
|
||||
|
||||
err := dest.RegisterRequestHandler(path, func(p string, d []byte, rid []byte, lid []byte, ri *identity.Identity, ra int64) []byte {
|
||||
return response
|
||||
}, ALLOW_ALL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("RegisterRequestHandler failed: %v", err)
|
||||
}
|
||||
|
||||
result := dest.HandleRequest(path, nil, nil, nil, nil, 0)
|
||||
if !bytes.Equal(result, response) {
|
||||
t.Errorf("Expected response %q, got %q", response, result)
|
||||
}
|
||||
|
||||
if !dest.DeregisterRequestHandler(path) {
|
||||
t.Error("DeregisterRequestHandler failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
id, _ := identity.New()
|
||||
dest, _ := New(id, IN|OUT, SINGLE, "test", &mockTransport{})
|
||||
|
||||
plaintext := []byte("hello world")
|
||||
ciphertext, err := dest.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt failed: %v", err)
|
||||
}
|
||||
|
||||
decrypted, err := dest.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(plaintext, decrypted) {
|
||||
t.Errorf("Decrypted data doesn't match: %q vs %q", decrypted, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRatchets(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
ratchetPath := filepath.Join(tmpDir, "ratchets")
|
||||
|
||||
id, _ := identity.New()
|
||||
dest, _ := New(id, IN|OUT, SINGLE, "test", &mockTransport{})
|
||||
|
||||
if !dest.EnableRatchets(ratchetPath) {
|
||||
t.Fatal("EnableRatchets failed")
|
||||
}
|
||||
|
||||
err := dest.RotateRatchets()
|
||||
if err != nil {
|
||||
t.Fatalf("RotateRatchets failed: %v", err)
|
||||
}
|
||||
|
||||
ratchets := dest.GetRatchets()
|
||||
if len(ratchets) != 1 {
|
||||
t.Errorf("Expected 1 ratchet, got %d", len(ratchets))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlainDestination(t *testing.T) {
|
||||
id, _ := identity.New()
|
||||
dest, _ := New(id, IN|OUT, PLAIN, "test", &mockTransport{})
|
||||
|
||||
plaintext := []byte("plain text")
|
||||
ciphertext, _ := dest.Encrypt(plaintext)
|
||||
if !bytes.Equal(plaintext, ciphertext) {
|
||||
t.Error("Plain destination should not encrypt")
|
||||
}
|
||||
|
||||
decrypted, _ := dest.Decrypt(ciphertext)
|
||||
if !bytes.Equal(plaintext, decrypted) {
|
||||
t.Error("Plain destination should not decrypt")
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package identity
|
||||
|
||||
import (
|
||||
@@ -8,17 +10,17 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/common"
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/cryptography"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/cryptography"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
@@ -44,7 +46,7 @@ const (
|
||||
type Identity struct {
|
||||
privateKey []byte
|
||||
publicKey []byte
|
||||
signingSeed []byte // 32-byte Ed25519 seed (compatible with Python RNS)
|
||||
signingSeed []byte // 32-byte Ed25519 seed
|
||||
verificationKey ed25519.PublicKey
|
||||
hash []byte
|
||||
hexHash string
|
||||
@@ -56,9 +58,10 @@ type Identity struct {
|
||||
}
|
||||
|
||||
var (
|
||||
knownDestinations = make(map[string][]interface{})
|
||||
knownRatchets = make(map[string][]byte)
|
||||
ratchetPersistLock sync.Mutex
|
||||
knownDestinations = make(map[string][]interface{})
|
||||
knownDestinationsLock sync.RWMutex
|
||||
knownRatchets = make(map[string][]byte)
|
||||
ratchetPersistLock sync.Mutex
|
||||
)
|
||||
|
||||
func New() (*Identity, error) {
|
||||
@@ -76,7 +79,7 @@ func New() (*Identity, error) {
|
||||
i.privateKey = privKey
|
||||
i.publicKey = pubKey
|
||||
|
||||
// Generate 32-byte Ed25519 seed (compatible with Python RNS)
|
||||
// Generate 32-byte Ed25519 seed
|
||||
var ed25519Seed [32]byte
|
||||
if _, err := io.ReadFull(rand.Reader, ed25519Seed[:]); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate Ed25519 seed: %v", err)
|
||||
@@ -105,7 +108,7 @@ func (i *Identity) GetPrivateKey() []byte {
|
||||
}
|
||||
|
||||
func (i *Identity) Sign(data []byte) []byte {
|
||||
// Derive Ed25519 private key from seed (compatible with Python RNS)
|
||||
// Derive Ed25519 private key from seed
|
||||
privKey := ed25519.NewKeyFromSeed(i.signingSeed)
|
||||
return cryptography.Sign(privKey, data)
|
||||
}
|
||||
@@ -133,20 +136,25 @@ func (i *Identity) Encrypt(plaintext []byte, ratchet []byte) ([]byte, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Derive encryption key
|
||||
key, err := cryptography.DeriveKey(sharedSecret, i.GetSalt(), i.GetContext(), 32)
|
||||
// Derive key material (64 bytes: first 32 for HMAC, last 32 for encryption)
|
||||
salt := i.GetSalt()
|
||||
debug.Log(debug.DEBUG_ALL, "Encrypt: using salt", "salt", fmt.Sprintf("%x", salt), "identity_hash", fmt.Sprintf("%x", i.Hash()))
|
||||
key, err := cryptography.DeriveKey(sharedSecret, salt, i.GetContext(), 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hmacKey := key[:32]
|
||||
encryptionKey := key[32:64]
|
||||
|
||||
// Encrypt data
|
||||
ciphertext, err := cryptography.EncryptAES256CBC(key[:32], plaintext)
|
||||
ciphertext, err := cryptography.EncryptAES256CBC(encryptionKey, plaintext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Calculate HMAC
|
||||
mac := cryptography.ComputeHMAC(key, append(ephemeralPubKey, ciphertext...))
|
||||
// Calculate HMAC over ciphertext only (iv + encrypted_data)
|
||||
mac := cryptography.ComputeHMAC(hmacKey, ciphertext)
|
||||
|
||||
// Combine components
|
||||
token := make([]byte, 0, len(ephemeralPubKey)+len(ciphertext)+len(mac))
|
||||
@@ -173,7 +181,7 @@ func GetRandomHash() []byte {
|
||||
randomData := make([]byte, TRUNCATED_HASHLENGTH/8)
|
||||
_, err := rand.Read(randomData) // #nosec G104
|
||||
if err != nil {
|
||||
log.Printf("[DEBUG-1] Failed to read random data for hash: %v", err)
|
||||
debug.Log(debug.DEBUG_CRITICAL, "Failed to read random data for hash", "error", err)
|
||||
return nil // Or handle the error appropriately
|
||||
}
|
||||
return TruncatedHash(randomData)
|
||||
@@ -184,12 +192,14 @@ func Remember(packet []byte, destHash []byte, publicKey []byte, appData []byte)
|
||||
|
||||
// Store destination data as [packet, destHash, identity, appData]
|
||||
id := FromPublicKey(publicKey)
|
||||
knownDestinationsLock.Lock()
|
||||
knownDestinations[hashStr] = []interface{}{
|
||||
packet,
|
||||
destHash,
|
||||
id,
|
||||
appData,
|
||||
}
|
||||
knownDestinationsLock.Unlock()
|
||||
}
|
||||
|
||||
func ValidateAnnounce(packet []byte, destHash []byte, publicKey []byte, signature []byte, appData []byte) bool {
|
||||
@@ -221,13 +231,18 @@ func FromPublicKey(publicKey []byte) *Identity {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &Identity{
|
||||
id := &Identity{
|
||||
publicKey: publicKey[:KEYSIZE/16],
|
||||
verificationKey: publicKey[KEYSIZE/16:],
|
||||
ratchets: make(map[string][]byte),
|
||||
ratchetExpiry: make(map[string]int64),
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
|
||||
hash := cryptography.Hash(id.GetPublicKey())
|
||||
id.hash = hash[:TRUNCATED_HASHLENGTH/8]
|
||||
|
||||
return id
|
||||
}
|
||||
|
||||
func (i *Identity) Hex() string {
|
||||
@@ -240,8 +255,12 @@ func (i *Identity) String() string {
|
||||
|
||||
func Recall(hash []byte) (*Identity, error) {
|
||||
hashStr := hex.EncodeToString(hash)
|
||||
|
||||
if data, exists := knownDestinations[hashStr]; exists {
|
||||
|
||||
knownDestinationsLock.RLock()
|
||||
data, exists := knownDestinations[hashStr]
|
||||
knownDestinationsLock.RUnlock()
|
||||
|
||||
if exists {
|
||||
// data is [packet, destHash, identity, appData]
|
||||
if len(data) >= 3 {
|
||||
if id, ok := data[2].(*Identity); ok {
|
||||
@@ -249,7 +268,7 @@ func Recall(hash []byte) (*Identity, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return nil, fmt.Errorf("identity not found for hash %x", hash)
|
||||
}
|
||||
|
||||
@@ -279,13 +298,13 @@ func (i *Identity) GetCurrentRatchetKey() []byte {
|
||||
if len(i.ratchets) == 0 {
|
||||
// If no ratchets exist, generate one.
|
||||
// This should ideally be handled by an explicit setup process.
|
||||
log.Println("[DEBUG-5] No ratchets found, generating a new one on-the-fly.")
|
||||
debug.Log(debug.DEBUG_TRACE, "No ratchets found, generating a new one on-the-fly")
|
||||
// Temporarily unlock to call RotateRatchet, which locks internally.
|
||||
i.mutex.RUnlock()
|
||||
newRatchet, err := i.RotateRatchet()
|
||||
i.mutex.RLock()
|
||||
if err != nil {
|
||||
log.Printf("[DEBUG-1] Failed to generate initial ratchet key: %v", err)
|
||||
debug.Log(debug.DEBUG_CRITICAL, "Failed to generate initial ratchet key", "error", err)
|
||||
return nil
|
||||
}
|
||||
return newRatchet
|
||||
@@ -293,7 +312,7 @@ func (i *Identity) GetCurrentRatchetKey() []byte {
|
||||
|
||||
// Return the most recently generated ratchet key
|
||||
var latestKey []byte
|
||||
var latestTime int64 = 0
|
||||
var latestTime int64
|
||||
for id, expiry := range i.ratchetExpiry {
|
||||
if expiry > latestTime {
|
||||
latestTime = expiry
|
||||
@@ -302,7 +321,7 @@ func (i *Identity) GetCurrentRatchetKey() []byte {
|
||||
}
|
||||
|
||||
if latestKey == nil {
|
||||
log.Printf("[DEBUG-2] Could not determine the latest ratchet key from %d ratchets.", len(i.ratchets))
|
||||
debug.Log(debug.DEBUG_ERROR, "Could not determine the latest ratchet key", "ratchet_count", len(i.ratchets))
|
||||
}
|
||||
|
||||
return latestKey
|
||||
@@ -310,13 +329,13 @@ func (i *Identity) GetCurrentRatchetKey() []byte {
|
||||
|
||||
func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRatchets bool, ratchetIDReceiver *common.RatchetIDReceiver) ([]byte, error) {
|
||||
if i.privateKey == nil {
|
||||
log.Printf("[DEBUG-1] Decryption failed: identity has no private key")
|
||||
debug.Log(debug.DEBUG_CRITICAL, "Decryption failed: identity has no private key")
|
||||
return nil, errors.New("decryption failed because identity does not hold a private key")
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG-7] Starting decryption for identity %s", i.GetHexHash())
|
||||
debug.Log(debug.DEBUG_ALL, "Starting decryption for identity", "hash", i.GetHexHash())
|
||||
if len(ratchets) > 0 {
|
||||
log.Printf("[DEBUG-7] Attempting decryption with %d ratchets", len(ratchets))
|
||||
debug.Log(debug.DEBUG_ALL, "Attempting decryption with ratchets", "count", len(ratchets))
|
||||
}
|
||||
|
||||
if len(ciphertextToken) <= KEYSIZE/8/2 {
|
||||
@@ -335,7 +354,7 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat
|
||||
// Try decryption with ratchets first if provided
|
||||
if len(ratchets) > 0 {
|
||||
for _, ratchet := range ratchets {
|
||||
if decrypted, ratchetID, err := i.tryRatchetDecryption(peerPubBytes, ciphertext, ratchet); err == nil {
|
||||
if decrypted, ratchetID, err := i.tryRatchetDecryption(peerPubBytes, ciphertext, mac, ratchet); err == nil {
|
||||
if ratchetIDReceiver != nil {
|
||||
ratchetIDReceiver.LatestRatchetID = ratchetID
|
||||
}
|
||||
@@ -357,20 +376,25 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat
|
||||
return nil, fmt.Errorf("failed to generate shared key: %v", err)
|
||||
}
|
||||
|
||||
// Derive key using HKDF
|
||||
hkdfReader := hkdf.New(sha256.New, sharedKey, i.GetSalt(), i.GetContext())
|
||||
derivedKey := make([]byte, 32)
|
||||
// Derive key material (64 bytes: first 32 for HMAC, last 32 for encryption)
|
||||
salt := i.GetSalt()
|
||||
debug.Log(debug.DEBUG_ALL, "Decrypt: using salt", "salt", fmt.Sprintf("%x", salt), "identity_hash", fmt.Sprintf("%x", i.Hash()))
|
||||
hkdfReader := hkdf.New(sha256.New, sharedKey, salt, i.GetContext())
|
||||
derivedKey := make([]byte, 64)
|
||||
if _, err := io.ReadFull(hkdfReader, derivedKey); err != nil {
|
||||
return nil, fmt.Errorf("failed to derive key: %v", err)
|
||||
}
|
||||
|
||||
// Validate HMAC
|
||||
if !cryptography.ValidateHMAC(derivedKey, append(peerPubBytes, ciphertext...), mac) {
|
||||
hmacKey := derivedKey[:32]
|
||||
encryptionKey := derivedKey[32:64]
|
||||
|
||||
// Validate HMAC over ciphertext only (iv + encrypted_data)
|
||||
if !cryptography.ValidateHMAC(hmacKey, ciphertext, mac) {
|
||||
return nil, errors.New("invalid HMAC")
|
||||
}
|
||||
|
||||
// Create AES cipher
|
||||
block, err := aes.NewCipher(derivedKey)
|
||||
block, err := aes.NewCipher(encryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %v", err)
|
||||
}
|
||||
@@ -407,34 +431,42 @@ func (i *Identity) Decrypt(ciphertextToken []byte, ratchets [][]byte, enforceRat
|
||||
ratchetIDReceiver.LatestRatchetID = nil
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG-7] Decryption completed successfully")
|
||||
debug.Log(debug.DEBUG_ALL, "Decryption completed successfully")
|
||||
return plaintext[:len(plaintext)-padding], nil
|
||||
}
|
||||
|
||||
// Helper function to attempt decryption using a ratchet
|
||||
func (i *Identity) tryRatchetDecryption(peerPubBytes, ciphertext, ratchet []byte) ([]byte, []byte, error) {
|
||||
func (i *Identity) tryRatchetDecryption(peerPubBytes, ciphertext, mac, ratchet []byte) (plaintext, ratchetID []byte, err error) {
|
||||
// Convert ratchet to private key
|
||||
ratchetPriv := ratchet
|
||||
|
||||
// Get ratchet ID
|
||||
ratchetPubBytes, err := curve25519.X25519(ratchetPriv, cryptography.GetBasepoint())
|
||||
if err != nil {
|
||||
log.Printf("[DEBUG-7] Failed to generate ratchet public key: %v", err)
|
||||
debug.Log(debug.DEBUG_ALL, "Failed to generate ratchet public key", "error", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
ratchetID := i.GetRatchetID(ratchetPubBytes)
|
||||
ratchetID = i.GetRatchetID(ratchetPubBytes)
|
||||
|
||||
sharedSecret, err := cryptography.DeriveSharedSecret(ratchet, peerPubBytes)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
key, err := cryptography.DeriveKey(sharedSecret, i.GetSalt(), i.GetContext(), 32)
|
||||
key, err := cryptography.DeriveKey(sharedSecret, i.GetSalt(), i.GetContext(), 64)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
plaintext, err := cryptography.DecryptAES256CBC(key, ciphertext)
|
||||
hmacKey := key[:32]
|
||||
encryptionKey := key[32:64]
|
||||
|
||||
// Validate HMAC over ciphertext only (iv + encrypted_data)
|
||||
if !cryptography.ValidateHMAC(hmacKey, ciphertext, mac) {
|
||||
return nil, nil, errors.New("invalid HMAC")
|
||||
}
|
||||
|
||||
plaintext, err = cryptography.DecryptAES256CBC(encryptionKey, ciphertext)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -443,12 +475,23 @@ func (i *Identity) tryRatchetDecryption(peerPubBytes, ciphertext, ratchet []byte
|
||||
}
|
||||
|
||||
func (i *Identity) EncryptWithHMAC(plaintext []byte, key []byte) ([]byte, error) {
|
||||
ciphertext, err := cryptography.EncryptAES256CBC(key, plaintext)
|
||||
var hmacKey, encryptionKey []byte
|
||||
if len(key) == 64 {
|
||||
hmacKey = key[:32]
|
||||
encryptionKey = key[32:64]
|
||||
} else if len(key) == 32 {
|
||||
hmacKey = key[:16]
|
||||
encryptionKey = key[16:32]
|
||||
} else {
|
||||
return nil, errors.New("invalid key length for EncryptWithHMAC")
|
||||
}
|
||||
|
||||
ciphertext, err := cryptography.EncryptAES256CBC(encryptionKey, plaintext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mac := cryptography.ComputeHMAC(key, ciphertext)
|
||||
mac := cryptography.ComputeHMAC(hmacKey, ciphertext)
|
||||
return append(ciphertext, mac...), nil
|
||||
}
|
||||
|
||||
@@ -457,48 +500,158 @@ func (i *Identity) DecryptWithHMAC(data []byte, key []byte) ([]byte, error) {
|
||||
return nil, errors.New("data too short")
|
||||
}
|
||||
|
||||
var hmacKey, encryptionKey []byte
|
||||
if len(key) == 64 {
|
||||
hmacKey = key[:32]
|
||||
encryptionKey = key[32:64]
|
||||
} else if len(key) == 32 {
|
||||
hmacKey = key[:16]
|
||||
encryptionKey = key[16:32]
|
||||
} else {
|
||||
return nil, errors.New("invalid key length for DecryptWithHMAC")
|
||||
}
|
||||
|
||||
macStart := len(data) - cryptography.SHA256Size
|
||||
ciphertext := data[:macStart]
|
||||
messageMAC := data[macStart:]
|
||||
|
||||
if !cryptography.ValidateHMAC(key, ciphertext, messageMAC) {
|
||||
if !cryptography.ValidateHMAC(hmacKey, ciphertext, messageMAC) {
|
||||
return nil, errors.New("invalid HMAC")
|
||||
}
|
||||
|
||||
return cryptography.DecryptAES256CBC(key, ciphertext)
|
||||
return cryptography.DecryptAES256CBC(encryptionKey, ciphertext)
|
||||
}
|
||||
|
||||
func (i *Identity) ToFile(path string) error {
|
||||
log.Printf("[DEBUG-7] Saving identity %s to file: %s", i.GetHexHash(), path)
|
||||
debug.Log(debug.DEBUG_ALL, "Saving identity to file", "hash", i.GetHexHash(), "path", path)
|
||||
|
||||
// Persist ratchets to a separate file
|
||||
ratchetPath := path + ".ratchets"
|
||||
if err := i.saveRatchets(ratchetPath); err != nil {
|
||||
log.Printf("[DEBUG-1] Failed to save ratchets: %v", err)
|
||||
// Continue saving the main identity file even if ratchets fail
|
||||
if i.privateKey == nil || i.signingSeed == nil {
|
||||
return errors.New("cannot save identity without private keys")
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"private_key": i.privateKey,
|
||||
"public_key": i.publicKey,
|
||||
"signing_seed": i.signingSeed,
|
||||
"verification_key": i.verificationKey,
|
||||
"app_data": i.appData,
|
||||
}
|
||||
// Store private keys as raw bytes
|
||||
// Format: [X25519 PrivKey (32 bytes)][Ed25519 PrivKey (32 bytes)]
|
||||
// Total: 64 bytes
|
||||
privateKeyBytes := make([]byte, 64)
|
||||
copy(privateKeyBytes[:32], i.privateKey)
|
||||
copy(privateKeyBytes[32:], i.signingSeed)
|
||||
|
||||
// Write raw bytes to file
|
||||
file, err := os.Create(path) // #nosec G304
|
||||
if err != nil {
|
||||
log.Printf("[DEBUG-1] Failed to create identity file: %v", err)
|
||||
debug.Log(debug.DEBUG_CRITICAL, "Failed to create identity file", "error", err)
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if err := json.NewEncoder(file).Encode(data); err != nil {
|
||||
log.Printf("[DEBUG-1] Failed to encode identity data: %v", err)
|
||||
if _, err := file.Write(privateKeyBytes); err != nil {
|
||||
debug.Log(debug.DEBUG_CRITICAL, "Failed to write identity data", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG-7] Identity saved successfully")
|
||||
debug.Log(debug.DEBUG_ALL, "Identity saved successfully", "bytes", len(privateKeyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
func FromFile(path string) (*Identity, error) {
|
||||
debug.Log(debug.DEBUG_ALL, "Loading identity from file", "path", path)
|
||||
|
||||
// Read the private key bytes from file
|
||||
// bearer:disable go_gosec_filesystem_filereadtaint
|
||||
data, err := os.ReadFile(path) // #nosec G304
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read identity file: %w", err)
|
||||
}
|
||||
|
||||
if len(data) != 64 {
|
||||
return nil, fmt.Errorf("invalid identity file: expected 64 bytes, got %d", len(data))
|
||||
}
|
||||
|
||||
// Parse the private keys
|
||||
// Format: [X25519 PrivKey (32 bytes)][Ed25519 PrivKey (32 bytes)]
|
||||
privateKey := data[:32]
|
||||
signingSeed := data[32:64]
|
||||
|
||||
// Create identity with initialized maps and mutex
|
||||
ident := &Identity{
|
||||
ratchets: make(map[string][]byte),
|
||||
ratchetExpiry: make(map[string]int64),
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
|
||||
if err := ident.loadPrivateKey(privateKey, signingSeed); err != nil {
|
||||
return nil, fmt.Errorf("failed to load private key: %w", err)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "Identity loaded from file", "hash", ident.GetHexHash())
|
||||
return ident, nil
|
||||
}
|
||||
|
||||
func LoadOrCreateTransportIdentity() (*Identity, error) {
|
||||
storagePath := os.Getenv("RETICULUM_STORAGE_PATH")
|
||||
if storagePath == "" {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
storagePath = fmt.Sprintf("%s/.reticulum/storage", homeDir)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(storagePath, 0700); err != nil {
|
||||
return nil, fmt.Errorf("failed to create storage directory: %w", err)
|
||||
}
|
||||
|
||||
transportIdentityPath := fmt.Sprintf("%s/transport_identity", storagePath)
|
||||
|
||||
if ident, err := FromFile(transportIdentityPath); err == nil {
|
||||
debug.Log(debug.DEBUG_INFO, "Loaded transport identity from storage")
|
||||
return ident, nil
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "No valid transport identity in storage, creating new one")
|
||||
ident, err := New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create transport identity: %w", err)
|
||||
}
|
||||
|
||||
if err := ident.ToFile(transportIdentityPath); err != nil {
|
||||
return nil, fmt.Errorf("failed to save transport identity: %w", err)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "Created and saved transport identity")
|
||||
return ident, nil
|
||||
}
|
||||
|
||||
func (i *Identity) loadPrivateKey(privateKey, signingSeed []byte) error {
|
||||
if len(privateKey) != 32 || len(signingSeed) != 32 {
|
||||
return errors.New("invalid private key length")
|
||||
}
|
||||
|
||||
// Load X25519 private key
|
||||
i.privateKey = make([]byte, 32)
|
||||
copy(i.privateKey, privateKey)
|
||||
|
||||
// Load Ed25519 signing seed
|
||||
i.signingSeed = make([]byte, 32)
|
||||
copy(i.signingSeed, signingSeed)
|
||||
|
||||
// Derive public keys from private keys
|
||||
var err error
|
||||
i.publicKey, err = curve25519.X25519(i.privateKey, curve25519.Basepoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to derive X25519 public key: %w", err)
|
||||
}
|
||||
|
||||
signingKey := ed25519.NewKeyFromSeed(i.signingSeed)
|
||||
i.verificationKey = signingKey.Public().(ed25519.PublicKey)
|
||||
|
||||
publicKeyBytes := make([]byte, 0, len(i.publicKey)+len(i.verificationKey))
|
||||
publicKeyBytes = append(publicKeyBytes, i.publicKey...)
|
||||
publicKeyBytes = append(publicKeyBytes, i.verificationKey...)
|
||||
i.hash = TruncatedHash(publicKeyBytes)[:TRUNCATED_HASHLENGTH/8]
|
||||
i.hexHash = hex.EncodeToString(i.hash)
|
||||
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Private key loaded successfully", "hash", i.GetHexHash())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -510,70 +663,117 @@ func (i *Identity) saveRatchets(path string) error {
|
||||
return nil // Nothing to save
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG-6] Saving %d ratchets to %s", len(i.ratchets), path)
|
||||
data := map[string]interface{}{
|
||||
"ratchets": i.ratchets,
|
||||
"ratchet_expiry": i.ratchetExpiry,
|
||||
debug.Log(debug.DEBUG_PACKETS, "Saving ratchets", "count", len(i.ratchets), "path", path)
|
||||
|
||||
// Convert ratchets to list format for msgpack
|
||||
ratchetList := make([][]byte, 0, len(i.ratchets))
|
||||
for _, ratchet := range i.ratchets {
|
||||
ratchetList = append(ratchetList, ratchet)
|
||||
}
|
||||
|
||||
file, err := os.Create(path) // #nosec G304
|
||||
// Pack ratchets using msgpack
|
||||
packedRatchets, err := msgpack.Marshal(ratchetList)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create ratchet file: %w", err)
|
||||
return fmt.Errorf("failed to pack ratchets: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
return json.NewEncoder(file).Encode(data)
|
||||
// Sign the packed ratchets
|
||||
signature := i.Sign(packedRatchets)
|
||||
|
||||
// Create structure: {"signature": ..., "ratchets": ...}
|
||||
persistedData := map[string][]byte{
|
||||
"signature": signature,
|
||||
"ratchets": packedRatchets,
|
||||
}
|
||||
|
||||
// Pack the entire structure
|
||||
finalData, err := msgpack.Marshal(persistedData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pack ratchet data: %w", err)
|
||||
}
|
||||
|
||||
// Write to temporary file first, then rename (atomic operation)
|
||||
tempPath := path + ".tmp"
|
||||
file, err := os.Create(tempPath) // #nosec G304
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp ratchet file: %w", err)
|
||||
}
|
||||
|
||||
if _, err := file.Write(finalData); err != nil {
|
||||
// #nosec G104 - Error already being handled, cleanup errors are non-critical
|
||||
file.Close()
|
||||
// #nosec G104 - Error already being handled, cleanup errors are non-critical
|
||||
os.Remove(tempPath)
|
||||
return fmt.Errorf("failed to write ratchet data: %w", err)
|
||||
}
|
||||
// #nosec G104 - File is being closed after successful write, error is non-critical
|
||||
file.Close()
|
||||
|
||||
// Atomic rename
|
||||
if err := os.Rename(tempPath, path); err != nil {
|
||||
// #nosec G104 - Error already being handled, cleanup errors are non-critical
|
||||
os.Remove(tempPath)
|
||||
return fmt.Errorf("failed to rename ratchet file: %w", err)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_PACKETS, "Ratchets saved successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func RecallIdentity(path string) (*Identity, error) {
|
||||
log.Printf("[DEBUG-7] Attempting to recall identity from: %s", path)
|
||||
debug.Log(debug.DEBUG_ALL, "Attempting to recall identity", "path", path)
|
||||
|
||||
// bearer:disable go_gosec_filesystem_filereadtaint
|
||||
file, err := os.Open(path) // #nosec G304
|
||||
if err != nil {
|
||||
log.Printf("[DEBUG-1] Failed to open identity file: %v", err)
|
||||
debug.Log(debug.DEBUG_CRITICAL, "Failed to open identity file", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := json.NewDecoder(file).Decode(&data); err != nil {
|
||||
log.Printf("[DEBUG-1] Failed to decode identity data: %v", err)
|
||||
// Read raw bytes
|
||||
// Format: [X25519 PrivKey (32 bytes)][Ed25519 PrivKey (32 bytes)]
|
||||
privateKeyBytes := make([]byte, 64)
|
||||
n, err := io.ReadFull(file, privateKeyBytes)
|
||||
if err != nil {
|
||||
debug.Log(debug.DEBUG_CRITICAL, "Failed to read identity data", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var signingSeed []byte
|
||||
var verificationKey ed25519.PublicKey
|
||||
|
||||
if seedData, exists := data["signing_seed"]; exists {
|
||||
signingSeed = seedData.([]byte)
|
||||
verificationKey = data["verification_key"].(ed25519.PublicKey)
|
||||
} else if keyData, exists := data["signing_key"]; exists {
|
||||
oldKey := keyData.(ed25519.PrivateKey)
|
||||
signingSeed = oldKey[:32]
|
||||
verificationKey = data["verification_key"].(ed25519.PublicKey)
|
||||
} else {
|
||||
return nil, fmt.Errorf("no signing key data found in identity file")
|
||||
if n != 64 {
|
||||
return nil, fmt.Errorf("invalid identity file: expected 64 bytes, got %d", n)
|
||||
}
|
||||
|
||||
// Extract keys
|
||||
x25519PrivKey := privateKeyBytes[:32]
|
||||
ed25519Seed := privateKeyBytes[32:]
|
||||
|
||||
// Derive public keys
|
||||
x25519PubKey, err := curve25519.X25519(x25519PrivKey, curve25519.Basepoint)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to derive X25519 public key: %v", err)
|
||||
}
|
||||
|
||||
ed25519PrivKey := ed25519.NewKeyFromSeed(ed25519Seed)
|
||||
ed25519PubKey := ed25519PrivKey.Public().(ed25519.PublicKey)
|
||||
|
||||
id := &Identity{
|
||||
privateKey: data["private_key"].([]byte),
|
||||
publicKey: data["public_key"].([]byte),
|
||||
signingSeed: signingSeed,
|
||||
verificationKey: verificationKey,
|
||||
appData: data["app_data"].([]byte),
|
||||
privateKey: x25519PrivKey,
|
||||
publicKey: x25519PubKey,
|
||||
signingSeed: ed25519Seed,
|
||||
verificationKey: ed25519PubKey,
|
||||
ratchets: make(map[string][]byte),
|
||||
ratchetExpiry: make(map[string]int64),
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
|
||||
// Load ratchets if they exist
|
||||
ratchetPath := path + ".ratchets"
|
||||
if err := id.loadRatchets(ratchetPath); err != nil {
|
||||
log.Printf("[DEBUG-2] Could not load ratchets for identity %s: %v", id.GetHexHash(), err)
|
||||
// This is not a fatal error, the identity can still function
|
||||
}
|
||||
// Generate hash
|
||||
combinedPub := make([]byte, KEYSIZE/8)
|
||||
copy(combinedPub[:KEYSIZE/16], id.publicKey)
|
||||
copy(combinedPub[KEYSIZE/16:], id.verificationKey)
|
||||
hash := sha256.Sum256(combinedPub)
|
||||
id.hash = hash[:TRUNCATED_HASHLENGTH/8]
|
||||
|
||||
log.Printf("[DEBUG-7] Successfully recalled identity with hash: %s", id.GetHexHash())
|
||||
debug.Log(debug.DEBUG_ALL, "Successfully recalled identity", "hash", id.GetHexHash())
|
||||
return id, nil
|
||||
}
|
||||
|
||||
@@ -581,38 +781,62 @@ func (i *Identity) loadRatchets(path string) error {
|
||||
i.mutex.Lock()
|
||||
defer i.mutex.Unlock()
|
||||
|
||||
// bearer:disable go_gosec_filesystem_filereadtaint
|
||||
file, err := os.Open(path) // #nosec G304
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
log.Printf("[DEBUG-6] No ratchet file found at %s, skipping.", path)
|
||||
debug.Log(debug.DEBUG_PACKETS, "No ratchet file found, skipping", "path", path)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to open ratchet file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := json.NewDecoder(file).Decode(&data); err != nil {
|
||||
return fmt.Errorf("failed to decode ratchet data: %w", err)
|
||||
// Read all data
|
||||
fileData, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read ratchet file: %w", err)
|
||||
}
|
||||
|
||||
if ratchets, ok := data["ratchets"].(map[string]interface{}); ok {
|
||||
for id, key := range ratchets {
|
||||
if keyStr, ok := key.(string); ok {
|
||||
i.ratchets[id] = []byte(keyStr)
|
||||
}
|
||||
// Unpack outer structure: {"signature": ..., "ratchets": ...}
|
||||
var persistedData map[string][]byte
|
||||
if err := msgpack.Unmarshal(fileData, &persistedData); err != nil {
|
||||
return fmt.Errorf("failed to unpack ratchet data: %w", err)
|
||||
}
|
||||
|
||||
signature, hasSignature := persistedData["signature"]
|
||||
packedRatchets, hasRatchets := persistedData["ratchets"]
|
||||
|
||||
if !hasSignature || !hasRatchets {
|
||||
return fmt.Errorf("invalid ratchet file format: missing signature or ratchets")
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !i.Verify(packedRatchets, signature) {
|
||||
return fmt.Errorf("invalid ratchet file signature")
|
||||
}
|
||||
|
||||
// Unpack ratchet list
|
||||
var ratchetList [][]byte
|
||||
if err := msgpack.Unmarshal(packedRatchets, &ratchetList); err != nil {
|
||||
return fmt.Errorf("failed to unpack ratchet list: %w", err)
|
||||
}
|
||||
|
||||
// Store ratchets with generated IDs
|
||||
now := time.Now().Unix()
|
||||
for _, ratchet := range ratchetList {
|
||||
// Generate ratchet public key to create ID
|
||||
ratchetPub, err := curve25519.X25519(ratchet, curve25519.Basepoint)
|
||||
if err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to generate ratchet public key", "error", err)
|
||||
continue
|
||||
}
|
||||
ratchetID := i.GetRatchetID(ratchetPub)
|
||||
i.ratchets[string(ratchetID)] = ratchet
|
||||
i.ratchetExpiry[string(ratchetID)] = now + RATCHET_EXPIRY
|
||||
}
|
||||
|
||||
if expiry, ok := data["ratchet_expiry"].(map[string]interface{}); ok {
|
||||
for id, timeVal := range expiry {
|
||||
if timeFloat, ok := timeVal.(float64); ok {
|
||||
i.ratchetExpiry[id] = int64(timeFloat)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG-6] Loaded %d ratchets from %s", len(i.ratchets), path)
|
||||
debug.Log(debug.DEBUG_PACKETS, "Loaded ratchets", "count", len(i.ratchets), "path", path)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -638,7 +862,10 @@ func (i *Identity) GetRatchetID(ratchetPubBytes []byte) []byte {
|
||||
}
|
||||
|
||||
func GetKnownDestination(hash string) ([]interface{}, bool) {
|
||||
if data, exists := knownDestinations[hash]; exists {
|
||||
knownDestinationsLock.RLock()
|
||||
data, exists := knownDestinations[hash]
|
||||
knownDestinationsLock.RUnlock()
|
||||
if exists {
|
||||
return data, true
|
||||
}
|
||||
return nil, false
|
||||
@@ -668,7 +895,7 @@ func (i *Identity) SetRatchetKey(id string, key []byte) {
|
||||
|
||||
// NewIdentity creates a new Identity instance with fresh keys
|
||||
func NewIdentity() (*Identity, error) {
|
||||
// Generate 32-byte Ed25519 seed (compatible with Python RNS)
|
||||
// Generate 32-byte Ed25519 seed
|
||||
var ed25519Seed [32]byte
|
||||
if _, err := io.ReadFull(rand.Reader, ed25519Seed[:]); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate Ed25519 seed: %v", err)
|
||||
@@ -704,28 +931,50 @@ func NewIdentity() (*Identity, error) {
|
||||
copy(combinedPub[:KEYSIZE/16], i.publicKey)
|
||||
copy(combinedPub[KEYSIZE/16:], i.verificationKey)
|
||||
hash := sha256.Sum256(combinedPub)
|
||||
i.hash = hash[:]
|
||||
i.hash = hash[:TRUNCATED_HASHLENGTH/8]
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// FromBytes creates an Identity from a 64-byte private key representation
|
||||
func FromBytes(data []byte) (*Identity, error) {
|
||||
if len(data) != 64 {
|
||||
return nil, fmt.Errorf("invalid identity data: expected 64 bytes, got %d", len(data))
|
||||
}
|
||||
|
||||
privateKey := data[:32]
|
||||
signingSeed := data[32:64]
|
||||
|
||||
ident := &Identity{
|
||||
ratchets: make(map[string][]byte),
|
||||
ratchetExpiry: make(map[string]int64),
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
|
||||
if err := ident.loadPrivateKey(privateKey, signingSeed); err != nil {
|
||||
return nil, fmt.Errorf("failed to load private key: %w", err)
|
||||
}
|
||||
|
||||
return ident, nil
|
||||
}
|
||||
|
||||
func (i *Identity) RotateRatchet() ([]byte, error) {
|
||||
i.mutex.Lock()
|
||||
defer i.mutex.Unlock()
|
||||
|
||||
log.Printf("[DEBUG-7] Rotating ratchet for identity %s", i.GetHexHash())
|
||||
debug.Log(debug.DEBUG_ALL, "Rotating ratchet for identity", "hash", i.GetHexHash())
|
||||
|
||||
// Generate new ratchet key
|
||||
newRatchet := make([]byte, RATCHETSIZE/8)
|
||||
if _, err := io.ReadFull(rand.Reader, newRatchet); err != nil {
|
||||
log.Printf("[DEBUG-1] Failed to generate new ratchet: %v", err)
|
||||
debug.Log(debug.DEBUG_CRITICAL, "Failed to generate new ratchet", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get public key for ratchet ID
|
||||
ratchetPub, err := curve25519.X25519(newRatchet, curve25519.Basepoint)
|
||||
if err != nil {
|
||||
log.Printf("[DEBUG-1] Failed to generate ratchet public key: %v", err)
|
||||
debug.Log(debug.DEBUG_CRITICAL, "Failed to generate ratchet public key", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -736,7 +985,7 @@ func (i *Identity) RotateRatchet() ([]byte, error) {
|
||||
i.ratchets[string(ratchetID)] = newRatchet
|
||||
i.ratchetExpiry[string(ratchetID)] = expiry
|
||||
|
||||
log.Printf("[DEBUG-7] New ratchet generated with ID: %x, expiry: %d", ratchetID, expiry)
|
||||
debug.Log(debug.DEBUG_ALL, "New ratchet generated", "id", fmt.Sprintf("%x", ratchetID), "expiry", expiry)
|
||||
|
||||
// Cleanup old ratchets if we exceed max retained
|
||||
if len(i.ratchets) > MAX_RETAINED_RATCHETS {
|
||||
@@ -752,10 +1001,10 @@ func (i *Identity) RotateRatchet() ([]byte, error) {
|
||||
|
||||
delete(i.ratchets, oldestID)
|
||||
delete(i.ratchetExpiry, oldestID)
|
||||
log.Printf("[DEBUG-7] Cleaned up oldest ratchet with ID: %x", []byte(oldestID))
|
||||
debug.Log(debug.DEBUG_ALL, "Cleaned up oldest ratchet", "id", fmt.Sprintf("%x", []byte(oldestID)))
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG-7] Current number of active ratchets: %d", len(i.ratchets))
|
||||
debug.Log(debug.DEBUG_ALL, "Current number of active ratchets", "count", len(i.ratchets))
|
||||
return newRatchet, nil
|
||||
}
|
||||
|
||||
@@ -763,7 +1012,7 @@ func (i *Identity) GetRatchets() [][]byte {
|
||||
i.mutex.RLock()
|
||||
defer i.mutex.RUnlock()
|
||||
|
||||
log.Printf("[DEBUG-7] Getting ratchets for identity %s", i.GetHexHash())
|
||||
debug.Log(debug.DEBUG_ALL, "Getting ratchets for identity", "hash", i.GetHexHash())
|
||||
|
||||
ratchets := make([][]byte, 0, len(i.ratchets))
|
||||
now := time.Now().Unix()
|
||||
@@ -781,7 +1030,7 @@ func (i *Identity) GetRatchets() [][]byte {
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG-7] Retrieved %d active ratchets, cleaned up %d expired", len(ratchets), expired)
|
||||
debug.Log(debug.DEBUG_ALL, "Retrieved active ratchets", "active", len(ratchets), "expired", expired)
|
||||
return ratchets
|
||||
}
|
||||
|
||||
@@ -789,7 +1038,7 @@ func (i *Identity) CleanupExpiredRatchets() {
|
||||
i.mutex.Lock()
|
||||
defer i.mutex.Unlock()
|
||||
|
||||
log.Printf("[DEBUG-7] Starting ratchet cleanup for identity %s", i.GetHexHash())
|
||||
debug.Log(debug.DEBUG_ALL, "Starting ratchet cleanup for identity", "hash", i.GetHexHash())
|
||||
|
||||
now := time.Now().Unix()
|
||||
cleaned := 0
|
||||
@@ -801,7 +1050,7 @@ func (i *Identity) CleanupExpiredRatchets() {
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG-7] Cleaned up %d expired ratchets, %d remaining", cleaned, len(i.ratchets))
|
||||
debug.Log(debug.DEBUG_ALL, "Cleaned up expired ratchets", "cleaned", cleaned, "remaining", len(i.ratchets))
|
||||
}
|
||||
|
||||
// ValidateAnnounce validates an announce packet's signature
|
||||
|
||||
148
pkg/identity/identity_test.go
Normal file
148
pkg/identity/identity_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package identity
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewIdentity(t *testing.T) {
|
||||
id, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
if id == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
|
||||
pubKey := id.GetPublicKey()
|
||||
if len(pubKey) != 64 {
|
||||
t.Errorf("Expected public key length 64, got %d", len(pubKey))
|
||||
}
|
||||
|
||||
privKey := id.GetPrivateKey()
|
||||
if len(privKey) != 64 {
|
||||
t.Errorf("Expected private key length 64, got %d", len(privKey))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignVerify(t *testing.T) {
|
||||
id, _ := New()
|
||||
data := []byte("test data")
|
||||
sig := id.Sign(data)
|
||||
|
||||
if !id.Verify(data, sig) {
|
||||
t.Error("Verification failed for valid signature")
|
||||
}
|
||||
|
||||
if id.Verify([]byte("wrong data"), sig) {
|
||||
t.Error("Verification succeeded for wrong data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
id, _ := New()
|
||||
plaintext := []byte("secret message")
|
||||
|
||||
ciphertext, err := id.Encrypt(plaintext, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt failed: %v", err)
|
||||
}
|
||||
|
||||
decrypted, err := id.Decrypt(ciphertext, nil, false, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(plaintext, decrypted) {
|
||||
t.Errorf("Decrypted data doesn't match plaintext: %q vs %q", decrypted, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdentityHash(t *testing.T) {
|
||||
id, _ := New()
|
||||
h := id.Hash()
|
||||
if len(h) != TRUNCATED_HASHLENGTH/8 {
|
||||
t.Errorf("Expected hash length %d, got %d", TRUNCATED_HASHLENGTH/8, len(h))
|
||||
}
|
||||
|
||||
hexHash := id.Hex()
|
||||
if len(hexHash) != TRUNCATED_HASHLENGTH/4 {
|
||||
t.Errorf("Expected hex hash length %d, got %d", TRUNCATED_HASHLENGTH/4, len(hexHash))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileOperations(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
idPath := filepath.Join(tmpDir, "identity")
|
||||
|
||||
id, _ := New()
|
||||
err := id.ToFile(idPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ToFile failed: %v", err)
|
||||
}
|
||||
|
||||
loadedID, err := FromFile(idPath)
|
||||
if err != nil {
|
||||
t.Fatalf("FromFile failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(id.GetPublicKey(), loadedID.GetPublicKey()) {
|
||||
t.Error("Loaded identity public key doesn't match original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRatchets(t *testing.T) {
|
||||
id, _ := New()
|
||||
|
||||
ratchet, err := id.RotateRatchet()
|
||||
if err != nil {
|
||||
t.Fatalf("RotateRatchet failed: %v", err)
|
||||
}
|
||||
if len(ratchet) != RATCHETSIZE/8 {
|
||||
t.Errorf("Expected ratchet size %d, got %d", RATCHETSIZE/8, len(ratchet))
|
||||
}
|
||||
|
||||
ratchets := id.GetRatchets()
|
||||
if len(ratchets) != 1 {
|
||||
t.Errorf("Expected 1 ratchet, got %d", len(ratchets))
|
||||
}
|
||||
|
||||
id.CleanupExpiredRatchets()
|
||||
// Should still be there since it's not expired
|
||||
if len(id.GetRatchets()) != 1 {
|
||||
t.Error("Ratchet unexpectedly cleaned up")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecallIdentity(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
idPath := filepath.Join(tmpDir, "identity_recall")
|
||||
|
||||
id, _ := New()
|
||||
_ = id.ToFile(idPath)
|
||||
|
||||
recalledID, err := RecallIdentity(idPath)
|
||||
if err != nil {
|
||||
t.Fatalf("RecallIdentity failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(id.GetPublicKey(), recalledID.GetPublicKey()) {
|
||||
t.Error("Recalled identity public key doesn't match original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncatedHash(t *testing.T) {
|
||||
data := []byte("some data")
|
||||
h := TruncatedHash(data)
|
||||
if len(h) != TRUNCATED_HASHLENGTH/8 {
|
||||
t.Errorf("Expected length %d, got %d", TRUNCATED_HASHLENGTH/8, len(h))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRandomHash(t *testing.T) {
|
||||
h := GetRandomHash()
|
||||
if len(h) != TRUNCATED_HASHLENGTH/8 {
|
||||
t.Errorf("Expected length %d, got %d", TRUNCATED_HASHLENGTH/8, len(h))
|
||||
}
|
||||
}
|
||||
@@ -1,49 +1,112 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build !tinygo
|
||||
// +build !tinygo
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
)
|
||||
|
||||
const (
|
||||
HW_MTU = 1196
|
||||
DEFAULT_DISCOVERY_PORT = 29716
|
||||
DEFAULT_DATA_PORT = 42671
|
||||
DEFAULT_GROUP_ID = "reticulum"
|
||||
BITRATE_GUESS = 10 * 1000 * 1000
|
||||
PEERING_TIMEOUT = 7500 * time.Millisecond
|
||||
SCOPE_LINK = "2"
|
||||
SCOPE_ADMIN = "4"
|
||||
SCOPE_SITE = "5"
|
||||
SCOPE_ORGANISATION = "8"
|
||||
SCOPE_GLOBAL = "e"
|
||||
PEERING_TIMEOUT = 22 * time.Second
|
||||
ANNOUNCE_INTERVAL = 1600 * time.Millisecond
|
||||
PEER_JOB_INTERVAL = 4 * time.Second
|
||||
MCAST_ECHO_TIMEOUT = 6500 * time.Millisecond
|
||||
|
||||
SCOPE_LINK = "2"
|
||||
SCOPE_ADMIN = "4"
|
||||
SCOPE_SITE = "5"
|
||||
SCOPE_ORGANISATION = "8"
|
||||
SCOPE_GLOBAL = "e"
|
||||
|
||||
MCAST_ADDR_TYPE_PERMANENT = "0"
|
||||
MCAST_ADDR_TYPE_TEMPORARY = "1"
|
||||
)
|
||||
|
||||
type AutoInterface struct {
|
||||
BaseInterface
|
||||
groupID []byte
|
||||
discoveryPort int
|
||||
dataPort int
|
||||
discoveryScope string
|
||||
peers map[string]*Peer
|
||||
linkLocalAddrs []string
|
||||
adoptedInterfaces map[string]string
|
||||
interfaceServers map[string]net.Conn
|
||||
multicastEchoes map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
outboundConn net.Conn
|
||||
groupID []byte
|
||||
groupHash []byte
|
||||
discoveryPort int
|
||||
dataPort int
|
||||
discoveryScope string
|
||||
multicastAddrType string
|
||||
mcastDiscoveryAddr string
|
||||
ifacNetname string
|
||||
peers map[string]*Peer
|
||||
linkLocalAddrs []string
|
||||
adoptedInterfaces map[string]*AdoptedInterface
|
||||
interfaceServers map[string]*net.UDPConn
|
||||
discoveryServers map[string]*net.UDPConn
|
||||
multicastEchoes map[string]time.Time
|
||||
timedOutInterfaces map[string]time.Time
|
||||
allowedInterfaces []string
|
||||
ignoredInterfaces []string
|
||||
outboundConn *net.UDPConn
|
||||
announceInterval time.Duration
|
||||
peerJobInterval time.Duration
|
||||
peeringTimeout time.Duration
|
||||
mcastEchoTimeout time.Duration
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
type AdoptedInterface struct {
|
||||
name string
|
||||
linkLocalAddr string
|
||||
index int
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
ifaceName string
|
||||
lastHeard time.Time
|
||||
conn net.PacketConn
|
||||
addr *net.UDPAddr
|
||||
}
|
||||
|
||||
func NewAutoInterface(name string, config *common.InterfaceConfig) (*AutoInterface, error) {
|
||||
groupID := DEFAULT_GROUP_ID
|
||||
if config.GroupID != "" {
|
||||
groupID = config.GroupID
|
||||
}
|
||||
|
||||
discoveryScope := SCOPE_LINK
|
||||
if config.DiscoveryScope != "" {
|
||||
discoveryScope = normalizeScope(config.DiscoveryScope)
|
||||
}
|
||||
|
||||
multicastAddrType := MCAST_ADDR_TYPE_TEMPORARY
|
||||
|
||||
discoveryPort := DEFAULT_DISCOVERY_PORT
|
||||
if config.DiscoveryPort != 0 {
|
||||
discoveryPort = config.DiscoveryPort
|
||||
}
|
||||
|
||||
dataPort := DEFAULT_DATA_PORT
|
||||
if config.DataPort != 0 {
|
||||
dataPort = config.DataPort
|
||||
}
|
||||
|
||||
groupHash := sha256.Sum256([]byte(groupID))
|
||||
|
||||
ifacNetname := hex.EncodeToString(groupHash[:])[:16]
|
||||
mcastAddr := fmt.Sprintf("ff%s%s::%s", discoveryScope, multicastAddrType, ifacNetname)
|
||||
|
||||
ai := &AutoInterface{
|
||||
BaseInterface: BaseInterface{
|
||||
Name: name,
|
||||
@@ -52,74 +115,303 @@ func NewAutoInterface(name string, config *common.InterfaceConfig) (*AutoInterfa
|
||||
Online: false,
|
||||
Enabled: config.Enabled,
|
||||
Detached: false,
|
||||
IN: false,
|
||||
IN: true,
|
||||
OUT: false,
|
||||
MTU: common.DEFAULT_MTU,
|
||||
Bitrate: BITRATE_MINIMUM,
|
||||
MTU: HW_MTU,
|
||||
Bitrate: BITRATE_GUESS,
|
||||
},
|
||||
discoveryPort: DEFAULT_DISCOVERY_PORT,
|
||||
dataPort: DEFAULT_DATA_PORT,
|
||||
discoveryScope: SCOPE_LINK,
|
||||
peers: make(map[string]*Peer),
|
||||
linkLocalAddrs: make([]string, 0),
|
||||
adoptedInterfaces: make(map[string]string),
|
||||
interfaceServers: make(map[string]net.Conn),
|
||||
multicastEchoes: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
if config.Port != 0 {
|
||||
ai.discoveryPort = config.Port
|
||||
}
|
||||
|
||||
if config.GroupID != "" {
|
||||
ai.groupID = []byte(config.GroupID)
|
||||
} else {
|
||||
ai.groupID = []byte("reticulum")
|
||||
groupID: []byte(groupID),
|
||||
groupHash: groupHash[:],
|
||||
discoveryPort: discoveryPort,
|
||||
dataPort: dataPort,
|
||||
discoveryScope: discoveryScope,
|
||||
multicastAddrType: multicastAddrType,
|
||||
mcastDiscoveryAddr: mcastAddr,
|
||||
ifacNetname: ifacNetname,
|
||||
peers: make(map[string]*Peer),
|
||||
linkLocalAddrs: make([]string, 0),
|
||||
adoptedInterfaces: make(map[string]*AdoptedInterface),
|
||||
interfaceServers: make(map[string]*net.UDPConn),
|
||||
discoveryServers: make(map[string]*net.UDPConn),
|
||||
multicastEchoes: make(map[string]time.Time),
|
||||
timedOutInterfaces: make(map[string]time.Time),
|
||||
allowedInterfaces: make([]string, 0),
|
||||
ignoredInterfaces: make([]string, 0),
|
||||
announceInterval: ANNOUNCE_INTERVAL,
|
||||
peerJobInterval: PEER_JOB_INTERVAL,
|
||||
peeringTimeout: PEERING_TIMEOUT,
|
||||
mcastEchoTimeout: MCAST_ECHO_TIMEOUT,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "AutoInterface configured", "name", name, "group", groupID, "mcast_addr", mcastAddr)
|
||||
return ai, nil
|
||||
}
|
||||
|
||||
func normalizeScope(scope string) string {
|
||||
switch scope {
|
||||
case "link", "2":
|
||||
return SCOPE_LINK
|
||||
case "admin", "4":
|
||||
return SCOPE_ADMIN
|
||||
case "site", "5":
|
||||
return SCOPE_SITE
|
||||
case "organisation", "organization", "8":
|
||||
return SCOPE_ORGANISATION
|
||||
case "global", "e":
|
||||
return SCOPE_GLOBAL
|
||||
default:
|
||||
return SCOPE_LINK
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeMulticastType(mtype string) string {
|
||||
switch mtype {
|
||||
case "permanent", "0":
|
||||
return MCAST_ADDR_TYPE_PERMANENT
|
||||
case "temporary", "1":
|
||||
return MCAST_ADDR_TYPE_TEMPORARY
|
||||
default:
|
||||
return MCAST_ADDR_TYPE_TEMPORARY
|
||||
}
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) Start() error {
|
||||
// TinyGo doesn't support net.Interfaces() or multicast UDP
|
||||
// AutoInterface requires these features, so return an error
|
||||
return fmt.Errorf("AutoInterface not supported in TinyGo - requires interface enumeration and multicast UDP")
|
||||
ai.Mutex.Lock()
|
||||
// Only recreate done if it's nil or was closed
|
||||
select {
|
||||
case <-ai.done:
|
||||
ai.done = make(chan struct{})
|
||||
ai.stopOnce = sync.Once{}
|
||||
default:
|
||||
if ai.done == nil {
|
||||
ai.done = make(chan struct{})
|
||||
ai.stopOnce = sync.Once{}
|
||||
}
|
||||
}
|
||||
ai.Mutex.Unlock()
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list interfaces: %v", err)
|
||||
}
|
||||
|
||||
for _, iface := range interfaces {
|
||||
if ai.shouldIgnoreInterface(iface.Name) {
|
||||
debug.Log(debug.DEBUG_TRACE, "Ignoring interface", "name", iface.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(ai.allowedInterfaces) > 0 && !ai.isAllowedInterface(iface.Name) {
|
||||
debug.Log(debug.DEBUG_TRACE, "Interface not in allowed list", "name", iface.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
ifaceCopy := iface
|
||||
// bearer:disable go_gosec_memory_memory_aliasing
|
||||
if err := ai.configureInterface(&ifaceCopy); err != nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Failed to configure interface", "name", iface.Name, "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if len(ai.adoptedInterfaces) == 0 {
|
||||
return fmt.Errorf("no suitable interfaces found")
|
||||
}
|
||||
|
||||
ai.Online = true
|
||||
ai.IN = true
|
||||
ai.OUT = true
|
||||
|
||||
go ai.peerJobs()
|
||||
go ai.announceLoop()
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "AutoInterface started", "adopted", len(ai.adoptedInterfaces))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) shouldIgnoreInterface(name string) bool {
|
||||
ignoreList := []string{"lo", "lo0", "tun0", "awdl0", "llw0", "en5", "dummy0"}
|
||||
|
||||
for _, ignored := range ai.ignoredInterfaces {
|
||||
if name == ignored {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, ignored := range ignoreList {
|
||||
if name == ignored {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) isAllowedInterface(name string) bool {
|
||||
for _, allowed := range ai.allowedInterfaces {
|
||||
if name == allowed {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) configureInterface(iface *net.Interface) error {
|
||||
// Not supported in TinyGo
|
||||
return fmt.Errorf("configureInterface not supported in TinyGo")
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
return fmt.Errorf("interface is down")
|
||||
}
|
||||
|
||||
if iface.Flags&net.FlagLoopback != 0 {
|
||||
return fmt.Errorf("loopback interface")
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var linkLocalAddr string
|
||||
for _, addr := range addrs {
|
||||
if ipnet, ok := addr.(*net.IPNet); ok {
|
||||
if ipnet.IP.To4() == nil && ipnet.IP.IsLinkLocalUnicast() {
|
||||
linkLocalAddr = ipnet.IP.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if linkLocalAddr == "" {
|
||||
return fmt.Errorf("no link-local IPv6 address found")
|
||||
}
|
||||
|
||||
ai.Mutex.Lock()
|
||||
ai.adoptedInterfaces[iface.Name] = &AdoptedInterface{
|
||||
name: iface.Name,
|
||||
linkLocalAddr: linkLocalAddr,
|
||||
index: iface.Index,
|
||||
}
|
||||
ai.linkLocalAddrs = append(ai.linkLocalAddrs, linkLocalAddr)
|
||||
ai.multicastEchoes[iface.Name] = time.Now()
|
||||
ai.Mutex.Unlock()
|
||||
|
||||
if err := ai.startDiscoveryListener(iface); err != nil {
|
||||
return fmt.Errorf("failed to start discovery listener: %v", err)
|
||||
}
|
||||
|
||||
if err := ai.startDataListener(iface); err != nil {
|
||||
return fmt.Errorf("failed to start data listener: %v", err)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "Configured interface", "name", iface.Name, "addr", linkLocalAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) startDiscoveryListener(iface *net.Interface) error {
|
||||
// Multicast UDP not supported in TinyGo
|
||||
return fmt.Errorf("startDiscoveryListener not supported in TinyGo - requires multicast UDP")
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.ParseIP(ai.mcastDiscoveryAddr),
|
||||
Port: ai.discoveryPort,
|
||||
Zone: iface.Name,
|
||||
}
|
||||
|
||||
conn, err := net.ListenMulticastUDP("udp6", iface, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := conn.SetReadBuffer(common.NUM_1024); err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to set discovery read buffer", "error", err)
|
||||
}
|
||||
|
||||
ai.Mutex.Lock()
|
||||
ai.discoveryServers[iface.Name] = conn
|
||||
ai.Mutex.Unlock()
|
||||
|
||||
go ai.handleDiscovery(conn, iface.Name)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Discovery listener started", "interface", iface.Name, "addr", ai.mcastDiscoveryAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) startDataListener(iface *net.Interface) error {
|
||||
// TinyGo doesn't support UDP servers
|
||||
return fmt.Errorf("startDataListener not supported in TinyGo")
|
||||
adoptedIface, exists := ai.adoptedInterfaces[iface.Name]
|
||||
if !exists {
|
||||
return fmt.Errorf("interface not adopted")
|
||||
}
|
||||
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.ParseIP(adoptedIface.linkLocalAddr),
|
||||
Port: ai.dataPort,
|
||||
Zone: iface.Name,
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp6", addr)
|
||||
if err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to listen on data port", "addr", addr, "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := conn.SetReadBuffer(ai.MTU); err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to set data read buffer", "error", err)
|
||||
}
|
||||
|
||||
ai.Mutex.Lock()
|
||||
ai.interfaceServers[iface.Name] = conn
|
||||
ai.Mutex.Unlock()
|
||||
|
||||
go ai.handleData(conn, iface.Name)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Data listener started", "interface", iface.Name, "addr", addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) handleDiscovery(conn net.Conn, ifaceName string) {
|
||||
// Not used in TinyGo
|
||||
buf := make([]byte, 1024)
|
||||
func (ai *AutoInterface) handleDiscovery(conn *net.UDPConn, ifaceName string) {
|
||||
buf := make([]byte, common.NUM_1024)
|
||||
for {
|
||||
_, err := conn.Read(buf)
|
||||
ai.Mutex.RLock()
|
||||
done := ai.done
|
||||
ai.Mutex.RUnlock()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
n, remoteAddr, err := conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
log.Printf("Discovery read error: %v", err)
|
||||
continue
|
||||
if ai.IsOnline() {
|
||||
debug.Log(debug.DEBUG_ERROR, "Discovery read error", "interface", ifaceName, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if n >= len(ai.groupHash) {
|
||||
receivedHash := buf[:len(ai.groupHash)]
|
||||
if bytes.Equal(receivedHash, ai.groupHash) {
|
||||
ai.handlePeerAnnounce(remoteAddr, ifaceName)
|
||||
} else {
|
||||
debug.Log(debug.DEBUG_TRACE, "Received discovery with mismatched group hash", "interface", ifaceName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) handleData(conn net.Conn) {
|
||||
func (ai *AutoInterface) handleData(conn *net.UDPConn, ifaceName string) {
|
||||
buf := make([]byte, ai.GetMTU())
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
ai.Mutex.RLock()
|
||||
done := ai.done
|
||||
ai.Mutex.RUnlock()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
n, _, err := conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
if !ai.IsDetached() {
|
||||
log.Printf("Data read error: %v", err)
|
||||
if ai.IsOnline() {
|
||||
debug.Log(debug.DEBUG_ERROR, "Data read error", "interface", ifaceName, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -131,62 +423,186 @@ func (ai *AutoInterface) handleData(conn net.Conn) {
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) handlePeerAnnounce(addr *net.UDPAddr, ifaceName string) {
|
||||
ai.mutex.Lock()
|
||||
defer ai.mutex.Unlock()
|
||||
ai.Mutex.Lock()
|
||||
defer ai.Mutex.Unlock()
|
||||
|
||||
peerAddr := addr.IP.String()
|
||||
peerIP := addr.IP.String()
|
||||
|
||||
for _, localAddr := range ai.linkLocalAddrs {
|
||||
if peerAddr == localAddr {
|
||||
if peerIP == localAddr {
|
||||
ai.multicastEchoes[ifaceName] = time.Now()
|
||||
debug.Log(debug.DEBUG_TRACE, "Received own multicast echo", "interface", ifaceName)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if _, exists := ai.peers[peerAddr]; !exists {
|
||||
ai.peers[peerAddr] = &Peer{
|
||||
peerKey := peerIP + "%" + ifaceName
|
||||
|
||||
if peer, exists := ai.peers[peerKey]; exists {
|
||||
peer.lastHeard = time.Now()
|
||||
debug.Log(debug.DEBUG_TRACE, "Updated peer", "peer", peerIP, "interface", ifaceName)
|
||||
} else {
|
||||
ai.peers[peerKey] = &Peer{
|
||||
ifaceName: ifaceName,
|
||||
lastHeard: time.Now(),
|
||||
addr: addr,
|
||||
}
|
||||
debug.Log(debug.DEBUG_INFO, "Discovered new peer", "peer", peerIP, "interface", ifaceName)
|
||||
}
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) announceLoop() {
|
||||
ticker := time.NewTicker(ai.announceInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if !ai.IsOnline() {
|
||||
return
|
||||
}
|
||||
ai.sendPeerAnnounce()
|
||||
case <-ai.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) sendPeerAnnounce() {
|
||||
ai.Mutex.RLock()
|
||||
defer ai.Mutex.RUnlock()
|
||||
|
||||
for ifaceName, adoptedIface := range ai.adoptedInterfaces {
|
||||
mcastAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP(ai.mcastDiscoveryAddr),
|
||||
Port: ai.discoveryPort,
|
||||
Zone: ifaceName,
|
||||
}
|
||||
|
||||
if ai.outboundConn == nil {
|
||||
var err error
|
||||
ai.outboundConn, err = net.ListenUDP("udp6", &net.UDPAddr{Port: 0})
|
||||
if err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to create outbound socket", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := ai.outboundConn.WriteToUDP(ai.groupHash, mcastAddr); err != nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Failed to send peer announce", "interface", ifaceName, "error", err)
|
||||
} else {
|
||||
debug.Log(debug.DEBUG_TRACE, "Sent peer announce", "interface", adoptedIface.name)
|
||||
}
|
||||
log.Printf("Added peer %s on %s", peerAddr, ifaceName)
|
||||
} else {
|
||||
ai.peers[peerAddr].lastHeard = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) peerJobs() {
|
||||
ticker := time.NewTicker(PEERING_TIMEOUT)
|
||||
for range ticker.C {
|
||||
ai.mutex.Lock()
|
||||
now := time.Now()
|
||||
ticker := time.NewTicker(ai.peerJobInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for addr, peer := range ai.peers {
|
||||
if now.Sub(peer.lastHeard) > PEERING_TIMEOUT {
|
||||
delete(ai.peers, addr)
|
||||
log.Printf("Removed timed out peer %s", addr)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if !ai.IsOnline() {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ai.mutex.Unlock()
|
||||
ai.Mutex.Lock()
|
||||
now := time.Now()
|
||||
|
||||
for peerKey, peer := range ai.peers {
|
||||
if now.Sub(peer.lastHeard) > ai.peeringTimeout {
|
||||
delete(ai.peers, peerKey)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Removed timed out peer", "peer", peerKey)
|
||||
}
|
||||
}
|
||||
|
||||
for ifaceName, echoTime := range ai.multicastEchoes {
|
||||
if now.Sub(echoTime) > ai.mcastEchoTimeout {
|
||||
if _, exists := ai.timedOutInterfaces[ifaceName]; !exists {
|
||||
debug.Log(debug.DEBUG_INFO, "Interface timed out", "interface", ifaceName)
|
||||
ai.timedOutInterfaces[ifaceName] = now
|
||||
}
|
||||
} else {
|
||||
delete(ai.timedOutInterfaces, ifaceName)
|
||||
}
|
||||
}
|
||||
|
||||
ai.Mutex.Unlock()
|
||||
case <-ai.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) Send(data []byte, address string) error {
|
||||
// TinyGo doesn't support UDP outbound connections for auto-discovery
|
||||
return fmt.Errorf("Send not supported in TinyGo - requires UDP client connections")
|
||||
if !ai.IsOnline() {
|
||||
return fmt.Errorf("interface offline")
|
||||
}
|
||||
|
||||
ai.Mutex.RLock()
|
||||
defer ai.Mutex.RUnlock()
|
||||
|
||||
if len(ai.peers) == 0 {
|
||||
debug.Log(debug.DEBUG_TRACE, "No peers available for sending")
|
||||
return nil
|
||||
}
|
||||
|
||||
if ai.outboundConn == nil {
|
||||
var err error
|
||||
ai.outboundConn, err = net.ListenUDP("udp6", &net.UDPAddr{Port: 0})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create outbound socket: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
sentCount := 0
|
||||
for _, peer := range ai.peers {
|
||||
targetAddr := &net.UDPAddr{
|
||||
IP: peer.addr.IP,
|
||||
Port: ai.dataPort,
|
||||
Zone: peer.ifaceName,
|
||||
}
|
||||
|
||||
if _, err := ai.outboundConn.WriteToUDP(data, targetAddr); err != nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Failed to send to peer", "interface", peer.ifaceName, "error", err)
|
||||
continue
|
||||
}
|
||||
sentCount++
|
||||
}
|
||||
|
||||
if sentCount > 0 {
|
||||
debug.Log(debug.DEBUG_TRACE, "Sent data to peers", "count", sentCount, "bytes", len(data))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) Stop() error {
|
||||
ai.mutex.Lock()
|
||||
defer ai.mutex.Unlock()
|
||||
ai.Mutex.Lock()
|
||||
ai.Online = false
|
||||
ai.IN = false
|
||||
ai.OUT = false
|
||||
|
||||
for _, server := range ai.interfaceServers {
|
||||
server.Close() // #nosec G104
|
||||
}
|
||||
|
||||
for _, server := range ai.discoveryServers {
|
||||
server.Close() // #nosec G104
|
||||
}
|
||||
|
||||
if ai.outboundConn != nil {
|
||||
ai.outboundConn.Close() // #nosec G104
|
||||
}
|
||||
ai.Mutex.Unlock()
|
||||
|
||||
ai.stopOnce.Do(func() {
|
||||
if ai.done != nil {
|
||||
close(ai.done)
|
||||
}
|
||||
})
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "AutoInterface stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
)
|
||||
|
||||
func TestNewAutoInterface(t *testing.T) {
|
||||
@@ -44,9 +44,10 @@ func TestNewAutoInterface(t *testing.T) {
|
||||
|
||||
t.Run("CustomConfig", func(t *testing.T) {
|
||||
config := &common.InterfaceConfig{
|
||||
Enabled: true,
|
||||
Port: 12345, // Custom discovery port
|
||||
GroupID: "customGroup",
|
||||
Enabled: true,
|
||||
DiscoveryPort: 12345,
|
||||
DataPort: 54321,
|
||||
GroupID: "customGroup",
|
||||
}
|
||||
ai, err := NewAutoInterface("autoCustom", config)
|
||||
if err != nil {
|
||||
@@ -59,6 +60,9 @@ func TestNewAutoInterface(t *testing.T) {
|
||||
if ai.discoveryPort != 12345 {
|
||||
t.Errorf("discoveryPort = %d; want 12345", ai.discoveryPort)
|
||||
}
|
||||
if ai.dataPort != 54321 {
|
||||
t.Errorf("dataPort = %d; want 54321", ai.dataPort)
|
||||
}
|
||||
if string(ai.groupID) != "customGroup" {
|
||||
t.Errorf("groupID = %s; want customGroup", string(ai.groupID))
|
||||
}
|
||||
@@ -79,9 +83,11 @@ func newMockAutoInterface(name string, config *common.InterfaceConfig) (*mockAut
|
||||
// Initialize maps that would normally be initialized in Start()
|
||||
ai.peers = make(map[string]*Peer)
|
||||
ai.linkLocalAddrs = make([]string, 0)
|
||||
ai.adoptedInterfaces = make(map[string]string)
|
||||
ai.adoptedInterfaces = make(map[string]*AdoptedInterface)
|
||||
ai.interfaceServers = make(map[string]*net.UDPConn)
|
||||
ai.discoveryServers = make(map[string]*net.UDPConn)
|
||||
ai.multicastEchoes = make(map[string]time.Time)
|
||||
ai.timedOutInterfaces = make(map[string]time.Time)
|
||||
|
||||
return &mockAutoInterface{AutoInterface: ai}, nil
|
||||
}
|
||||
@@ -138,14 +144,14 @@ func TestAutoInterfacePeerManagement(t *testing.T) {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ai.mutex.Lock()
|
||||
ai.Mutex.Lock()
|
||||
now := time.Now()
|
||||
for addr, peer := range ai.peers {
|
||||
if now.Sub(peer.lastHeard) > testTimeout {
|
||||
delete(ai.peers, addr)
|
||||
}
|
||||
}
|
||||
ai.mutex.Unlock()
|
||||
ai.Mutex.Unlock()
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
@@ -167,27 +173,26 @@ func TestAutoInterfacePeerManagement(t *testing.T) {
|
||||
peer2Addr := &net.UDPAddr{IP: net.ParseIP("fe80::2"), Zone: "eth0"}
|
||||
localAddr := &net.UDPAddr{IP: net.ParseIP("fe80::aaaa"), Zone: "eth0"}
|
||||
|
||||
// Add a simulated local address to avoid adding it as a peer
|
||||
ai.mutex.Lock()
|
||||
ai.Mutex.Lock()
|
||||
ai.linkLocalAddrs = append(ai.linkLocalAddrs, localAddrStr)
|
||||
ai.mutex.Unlock()
|
||||
ai.Mutex.Unlock()
|
||||
|
||||
t.Run("AddPeer1", func(t *testing.T) {
|
||||
ai.mutex.Lock()
|
||||
ai.Mutex.Lock()
|
||||
ai.mockHandlePeerAnnounce(peer1Addr, "eth0")
|
||||
ai.mutex.Unlock()
|
||||
ai.Mutex.Unlock()
|
||||
|
||||
// Give a small amount of time for the peer to be processed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
ai.mutex.RLock()
|
||||
ai.Mutex.RLock()
|
||||
count := len(ai.peers)
|
||||
peer, exists := ai.peers[peer1AddrStr]
|
||||
var ifaceName string
|
||||
if exists {
|
||||
ifaceName = peer.ifaceName
|
||||
}
|
||||
ai.mutex.RUnlock()
|
||||
ai.Mutex.RUnlock()
|
||||
|
||||
if count != 1 {
|
||||
t.Fatalf("Expected 1 peer, got %d", count)
|
||||
@@ -201,17 +206,17 @@ func TestAutoInterfacePeerManagement(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("AddPeer2", func(t *testing.T) {
|
||||
ai.mutex.Lock()
|
||||
ai.Mutex.Lock()
|
||||
ai.mockHandlePeerAnnounce(peer2Addr, "eth0")
|
||||
ai.mutex.Unlock()
|
||||
ai.Mutex.Unlock()
|
||||
|
||||
// Give a small amount of time for the peer to be processed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
ai.mutex.RLock()
|
||||
ai.Mutex.RLock()
|
||||
count := len(ai.peers)
|
||||
_, exists := ai.peers[peer2AddrStr]
|
||||
ai.mutex.RUnlock()
|
||||
ai.Mutex.RUnlock()
|
||||
|
||||
if count != 2 {
|
||||
t.Fatalf("Expected 2 peers, got %d", count)
|
||||
@@ -222,16 +227,16 @@ func TestAutoInterfacePeerManagement(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("IgnoreLocalAnnounce", func(t *testing.T) {
|
||||
ai.mutex.Lock()
|
||||
ai.Mutex.Lock()
|
||||
ai.mockHandlePeerAnnounce(localAddr, "eth0")
|
||||
ai.mutex.Unlock()
|
||||
ai.Mutex.Unlock()
|
||||
|
||||
// Give a small amount of time for the peer to be processed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
ai.mutex.RLock()
|
||||
ai.Mutex.RLock()
|
||||
count := len(ai.peers)
|
||||
ai.mutex.RUnlock()
|
||||
ai.Mutex.RUnlock()
|
||||
|
||||
if count != 2 {
|
||||
t.Fatalf("Expected 2 peers after local announce, got %d", count)
|
||||
@@ -239,32 +244,32 @@ func TestAutoInterfacePeerManagement(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("UpdatePeerTimestamp", func(t *testing.T) {
|
||||
ai.mutex.RLock()
|
||||
ai.Mutex.RLock()
|
||||
peer, exists := ai.peers[peer1AddrStr]
|
||||
var initialTime time.Time
|
||||
if exists {
|
||||
initialTime = peer.lastHeard
|
||||
}
|
||||
ai.mutex.RUnlock()
|
||||
ai.Mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
t.Fatalf("Peer %s not found before timestamp update", peer1AddrStr)
|
||||
}
|
||||
|
||||
ai.mutex.Lock()
|
||||
ai.Mutex.Lock()
|
||||
ai.mockHandlePeerAnnounce(peer1Addr, "eth0")
|
||||
ai.mutex.Unlock()
|
||||
ai.Mutex.Unlock()
|
||||
|
||||
// Give a small amount of time for the peer to be processed
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
ai.mutex.RLock()
|
||||
ai.Mutex.RLock()
|
||||
peer, exists = ai.peers[peer1AddrStr]
|
||||
var updatedTime time.Time
|
||||
if exists {
|
||||
updatedTime = peer.lastHeard
|
||||
}
|
||||
ai.mutex.RUnlock()
|
||||
ai.Mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
t.Fatalf("Peer %s not found after timestamp update", peer1AddrStr)
|
||||
@@ -279,9 +284,9 @@ func TestAutoInterfacePeerManagement(t *testing.T) {
|
||||
// Wait for peer timeout
|
||||
time.Sleep(testTimeout * 2)
|
||||
|
||||
ai.mutex.RLock()
|
||||
ai.Mutex.RLock()
|
||||
count := len(ai.peers)
|
||||
ai.mutex.RUnlock()
|
||||
ai.Mutex.RUnlock()
|
||||
|
||||
if count != 0 {
|
||||
t.Errorf("Expected all peers to timeout, got %d peers", count)
|
||||
|
||||
97
pkg/interfaces/auto_tinygo.go
Normal file
97
pkg/interfaces/auto_tinygo.go
Normal file
@@ -0,0 +1,97 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build tinygo
|
||||
// +build tinygo
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
)
|
||||
|
||||
const (
|
||||
HW_MTU = 1196
|
||||
DEFAULT_DISCOVERY_PORT = 29716
|
||||
DEFAULT_DATA_PORT = 42671
|
||||
DEFAULT_GROUP_ID = "reticulum"
|
||||
BITRATE_GUESS = 10 * 1000 * 1000
|
||||
)
|
||||
|
||||
type AutoInterface struct {
|
||||
BaseInterface
|
||||
groupID []byte
|
||||
discoveryPort int
|
||||
dataPort int
|
||||
discoveryScope string
|
||||
peers map[string]*Peer
|
||||
linkLocalAddrs []string
|
||||
adoptedInterfaces map[string]string
|
||||
interfaceServers map[string]net.Conn
|
||||
multicastEchoes map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
outboundConn net.Conn
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
ifaceName string
|
||||
lastHeard time.Time
|
||||
conn net.PacketConn
|
||||
}
|
||||
|
||||
func NewAutoInterface(name string, config *common.InterfaceConfig) (*AutoInterface, error) {
|
||||
ai := &AutoInterface{
|
||||
BaseInterface: BaseInterface{
|
||||
Name: name,
|
||||
Mode: common.IF_MODE_FULL,
|
||||
Type: common.IF_TYPE_AUTO,
|
||||
Online: false,
|
||||
Enabled: config.Enabled,
|
||||
Detached: false,
|
||||
IN: true,
|
||||
OUT: false,
|
||||
MTU: HW_MTU,
|
||||
Bitrate: BITRATE_GUESS,
|
||||
},
|
||||
discoveryPort: DEFAULT_DISCOVERY_PORT,
|
||||
dataPort: DEFAULT_DATA_PORT,
|
||||
peers: make(map[string]*Peer),
|
||||
linkLocalAddrs: make([]string, 0),
|
||||
adoptedInterfaces: make(map[string]string),
|
||||
interfaceServers: make(map[string]net.Conn),
|
||||
multicastEchoes: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
if config.Port != 0 {
|
||||
ai.discoveryPort = config.Port
|
||||
}
|
||||
|
||||
if config.GroupID != "" {
|
||||
ai.groupID = []byte(config.GroupID)
|
||||
} else {
|
||||
ai.groupID = []byte("reticulum")
|
||||
}
|
||||
|
||||
return ai, nil
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) Start() error {
|
||||
// TinyGo doesn't support net.Interfaces() or multicast UDP
|
||||
return fmt.Errorf("AutoInterface not supported in TinyGo - requires interface enumeration and multicast UDP")
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) Send(data []byte, address string) error {
|
||||
return fmt.Errorf("Send not supported in TinyGo - requires UDP client connections")
|
||||
}
|
||||
|
||||
func (ai *AutoInterface) Stop() error {
|
||||
ai.Mutex.Lock()
|
||||
defer ai.Mutex.Unlock()
|
||||
ai.Online = false
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -26,17 +28,6 @@ const (
|
||||
TYPE_TCP = 0x02
|
||||
|
||||
PROPAGATION_RATE = 0.02 // 2% of interface bandwidth
|
||||
|
||||
DEBUG_LEVEL = 4 // Default debug level for interface logging
|
||||
|
||||
// Debug levels
|
||||
DEBUG_CRITICAL = 1
|
||||
DEBUG_ERROR = 2
|
||||
DEBUG_INFO = 3
|
||||
DEBUG_VERBOSE = 4
|
||||
DEBUG_TRACE = 5
|
||||
DEBUG_PACKETS = 6
|
||||
DEBUG_ALL = 7
|
||||
)
|
||||
|
||||
type Interface interface {
|
||||
@@ -78,8 +69,9 @@ type BaseInterface struct {
|
||||
TxBytes uint64
|
||||
RxBytes uint64
|
||||
lastTx time.Time
|
||||
lastRx time.Time
|
||||
|
||||
mutex sync.RWMutex
|
||||
Mutex sync.RWMutex
|
||||
packetCallback common.PacketCallback
|
||||
}
|
||||
|
||||
@@ -96,29 +88,30 @@ func NewBaseInterface(name string, ifType common.InterfaceType, enabled bool) Ba
|
||||
MTU: common.DEFAULT_MTU,
|
||||
Bitrate: BITRATE_MINIMUM,
|
||||
lastTx: time.Now(),
|
||||
lastRx: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (i *BaseInterface) SetPacketCallback(callback common.PacketCallback) {
|
||||
i.mutex.Lock()
|
||||
defer i.mutex.Unlock()
|
||||
i.Mutex.Lock()
|
||||
defer i.Mutex.Unlock()
|
||||
i.packetCallback = callback
|
||||
}
|
||||
|
||||
func (i *BaseInterface) GetPacketCallback() common.PacketCallback {
|
||||
i.mutex.RLock()
|
||||
defer i.mutex.RUnlock()
|
||||
i.Mutex.RLock()
|
||||
defer i.Mutex.RUnlock()
|
||||
return i.packetCallback
|
||||
}
|
||||
|
||||
func (i *BaseInterface) ProcessIncoming(data []byte) {
|
||||
i.mutex.Lock()
|
||||
i.Mutex.Lock()
|
||||
i.RxBytes += uint64(len(data))
|
||||
i.mutex.Unlock()
|
||||
i.Mutex.Unlock()
|
||||
|
||||
i.mutex.RLock()
|
||||
i.Mutex.RLock()
|
||||
callback := i.packetCallback
|
||||
i.mutex.RUnlock()
|
||||
i.Mutex.RUnlock()
|
||||
|
||||
if callback != nil {
|
||||
callback(data, i)
|
||||
@@ -127,15 +120,15 @@ func (i *BaseInterface) ProcessIncoming(data []byte) {
|
||||
|
||||
func (i *BaseInterface) ProcessOutgoing(data []byte) error {
|
||||
if !i.Online || i.Detached {
|
||||
log.Printf("[DEBUG-1] Interface %s: Cannot process outgoing packet - interface offline or detached", i.Name)
|
||||
debug.Log(debug.DEBUG_CRITICAL, "Interface cannot process outgoing packet - interface offline or detached", "name", i.Name)
|
||||
return fmt.Errorf("interface offline or detached")
|
||||
}
|
||||
|
||||
i.mutex.Lock()
|
||||
i.Mutex.Lock()
|
||||
i.TxBytes += uint64(len(data))
|
||||
i.mutex.Unlock()
|
||||
i.Mutex.Unlock()
|
||||
|
||||
log.Printf("[DEBUG-%d] Interface %s: Processed outgoing packet of %d bytes, total TX: %d", DEBUG_LEVEL, i.Name, len(data), i.TxBytes)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Interface processed outgoing packet", "name", i.Name, "bytes", len(data), "total_tx", i.TxBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -145,7 +138,7 @@ func (i *BaseInterface) SendPathRequest(packet []byte) error {
|
||||
}
|
||||
|
||||
frame := make([]byte, 0, len(packet)+1)
|
||||
frame = append(frame, 0x01)
|
||||
frame = append(frame, common.HEX_0x01)
|
||||
frame = append(frame, packet...)
|
||||
|
||||
return i.ProcessOutgoing(frame)
|
||||
@@ -157,7 +150,7 @@ func (i *BaseInterface) SendLinkPacket(dest []byte, data []byte, timestamp time.
|
||||
}
|
||||
|
||||
frame := make([]byte, 0, len(dest)+len(data)+9)
|
||||
frame = append(frame, 0x02)
|
||||
frame = append(frame, common.HEX_0x02)
|
||||
frame = append(frame, dest...)
|
||||
|
||||
ts := make([]byte, 8)
|
||||
@@ -169,35 +162,35 @@ func (i *BaseInterface) SendLinkPacket(dest []byte, data []byte, timestamp time.
|
||||
}
|
||||
|
||||
func (i *BaseInterface) Detach() {
|
||||
i.mutex.Lock()
|
||||
defer i.mutex.Unlock()
|
||||
i.Mutex.Lock()
|
||||
defer i.Mutex.Unlock()
|
||||
i.Detached = true
|
||||
i.Online = false
|
||||
}
|
||||
|
||||
func (i *BaseInterface) IsEnabled() bool {
|
||||
i.mutex.RLock()
|
||||
defer i.mutex.RUnlock()
|
||||
i.Mutex.RLock()
|
||||
defer i.Mutex.RUnlock()
|
||||
return i.Enabled && i.Online && !i.Detached
|
||||
}
|
||||
|
||||
func (i *BaseInterface) Enable() {
|
||||
i.mutex.Lock()
|
||||
defer i.mutex.Unlock()
|
||||
i.Mutex.Lock()
|
||||
defer i.Mutex.Unlock()
|
||||
|
||||
prevState := i.Enabled
|
||||
i.Enabled = true
|
||||
i.Online = true
|
||||
|
||||
log.Printf("[DEBUG-%d] Interface %s: State changed - Enabled: %v->%v, Online: %v->%v", DEBUG_INFO, i.Name, prevState, i.Enabled, !i.Online, i.Online)
|
||||
debug.Log(debug.DEBUG_INFO, "Interface state changed", "name", i.Name, "enabled_prev", prevState, "enabled", i.Enabled, "online_prev", !i.Online, "online", i.Online)
|
||||
}
|
||||
|
||||
func (i *BaseInterface) Disable() {
|
||||
i.mutex.Lock()
|
||||
defer i.mutex.Unlock()
|
||||
i.Mutex.Lock()
|
||||
defer i.Mutex.Unlock()
|
||||
i.Enabled = false
|
||||
i.Online = false
|
||||
log.Printf("[DEBUG-2] Interface %s: Disabled and offline", i.Name)
|
||||
debug.Log(debug.DEBUG_ERROR, "Interface disabled and offline", "name", i.Name)
|
||||
}
|
||||
|
||||
func (i *BaseInterface) GetName() string {
|
||||
@@ -217,14 +210,14 @@ func (i *BaseInterface) GetMTU() int {
|
||||
}
|
||||
|
||||
func (i *BaseInterface) IsOnline() bool {
|
||||
i.mutex.RLock()
|
||||
defer i.mutex.RUnlock()
|
||||
i.Mutex.RLock()
|
||||
defer i.Mutex.RUnlock()
|
||||
return i.Online
|
||||
}
|
||||
|
||||
func (i *BaseInterface) IsDetached() bool {
|
||||
i.mutex.RLock()
|
||||
defer i.mutex.RUnlock()
|
||||
i.Mutex.RLock()
|
||||
defer i.Mutex.RUnlock()
|
||||
return i.Detached
|
||||
}
|
||||
|
||||
@@ -237,11 +230,11 @@ func (i *BaseInterface) Stop() error {
|
||||
}
|
||||
|
||||
func (i *BaseInterface) Send(data []byte, address string) error {
|
||||
log.Printf("[DEBUG-%d] Interface %s: Sending %d bytes to %s", DEBUG_LEVEL, i.Name, len(data), address)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Interface sending bytes", "name", i.Name, "bytes", len(data), "address", address)
|
||||
|
||||
err := i.ProcessOutgoing(data)
|
||||
if err != nil {
|
||||
log.Printf("[DEBUG-1] Interface %s: Failed to send data: %v", i.Name, err)
|
||||
debug.Log(debug.DEBUG_CRITICAL, "Interface failed to send data", "name", i.Name, "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -254,14 +247,14 @@ func (i *BaseInterface) GetConn() net.Conn {
|
||||
}
|
||||
|
||||
func (i *BaseInterface) GetBandwidthAvailable() bool {
|
||||
i.mutex.RLock()
|
||||
defer i.mutex.RUnlock()
|
||||
i.Mutex.RLock()
|
||||
defer i.Mutex.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
timeSinceLastTx := now.Sub(i.lastTx)
|
||||
|
||||
if timeSinceLastTx > time.Second {
|
||||
log.Printf("[DEBUG-%d] Interface %s: Bandwidth available (idle for %.2fs)", DEBUG_VERBOSE, i.Name, timeSinceLastTx.Seconds())
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Interface bandwidth available", "name", i.Name, "idle_seconds", timeSinceLastTx.Seconds())
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -270,19 +263,19 @@ func (i *BaseInterface) GetBandwidthAvailable() bool {
|
||||
maxUsage := float64(i.Bitrate) * PROPAGATION_RATE
|
||||
|
||||
available := currentUsage < maxUsage
|
||||
log.Printf("[DEBUG-%d] Interface %s: Bandwidth stats - Current: %.2f bps, Max: %.2f bps, Usage: %.1f%%, Available: %v", DEBUG_VERBOSE, i.Name, currentUsage, maxUsage, (currentUsage/maxUsage)*100, available)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Interface bandwidth stats", "name", i.Name, "current_bps", currentUsage, "max_bps", maxUsage, "usage_percent", (currentUsage/maxUsage)*100, "available", available)
|
||||
|
||||
return available
|
||||
}
|
||||
|
||||
func (i *BaseInterface) updateBandwidthStats(bytes uint64) {
|
||||
i.mutex.Lock()
|
||||
defer i.mutex.Unlock()
|
||||
i.Mutex.Lock()
|
||||
defer i.Mutex.Unlock()
|
||||
|
||||
i.TxBytes += bytes
|
||||
i.lastTx = time.Now()
|
||||
|
||||
log.Printf("[DEBUG-%d] Interface %s: Updated bandwidth stats - TX bytes: %d, Last TX: %v", DEBUG_LEVEL, i.Name, i.TxBytes, i.lastTx)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Interface updated bandwidth stats", "name", i.Name, "tx_bytes", i.TxBytes, "last_tx", i.lastTx)
|
||||
}
|
||||
|
||||
type InterceptedInterface struct {
|
||||
@@ -305,7 +298,7 @@ func (i *InterceptedInterface) Send(data []byte, addr string) error {
|
||||
// Call interceptor if provided
|
||||
if i.interceptor != nil && len(data) > 0 {
|
||||
if err := i.interceptor(data, i); err != nil {
|
||||
log.Printf("[DEBUG-2] Failed to intercept outgoing packet: %v", err)
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to intercept outgoing packet", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
)
|
||||
|
||||
func TestBaseInterfaceStateChanges(t *testing.T) {
|
||||
@@ -183,7 +183,6 @@ func (m *mockInterface) Send(data []byte, addr string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add other methods to satisfy the Interface interface (can be minimal/panic)
|
||||
func (m *mockInterface) GetType() common.InterfaceType { return common.IF_TYPE_NONE }
|
||||
func (m *mockInterface) GetMode() common.InterfaceMode { return common.IF_MODE_FULL }
|
||||
func (m *mockInterface) ProcessIncoming(data []byte) {}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -21,14 +23,26 @@ const (
|
||||
KISS_TFEND = 0xDC
|
||||
KISS_TFESC = 0xDD
|
||||
|
||||
TCP_USER_TIMEOUT = 24
|
||||
TCP_PROBE_AFTER = 5
|
||||
TCP_PROBE_INTERVAL = 2
|
||||
TCP_PROBES = 12
|
||||
RECONNECT_WAIT = 5
|
||||
INITIAL_TIMEOUT = 5
|
||||
INITIAL_BACKOFF = time.Second
|
||||
MAX_BACKOFF = time.Minute * 5
|
||||
DEFAULT_MTU = 1064
|
||||
BITRATE_GUESS_VAL = 10 * 1000 * 1000
|
||||
RECONNECT_WAIT = 5
|
||||
INITIAL_TIMEOUT = 5
|
||||
INITIAL_BACKOFF = time.Second
|
||||
MAX_BACKOFF = time.Minute * 5
|
||||
|
||||
TCP_USER_TIMEOUT_SEC = 24
|
||||
TCP_PROBE_AFTER_SEC = 5
|
||||
TCP_PROBE_INTERVAL_SEC = 2
|
||||
TCP_PROBES_COUNT = 12
|
||||
TCP_CONNECT_TIMEOUT = 10 * time.Second
|
||||
TCP_MILLISECONDS = 1000
|
||||
|
||||
I2P_USER_TIMEOUT_SEC = 45
|
||||
I2P_PROBE_AFTER_SEC = 10
|
||||
I2P_PROBE_INTERVAL_SEC = 9
|
||||
I2P_PROBES_COUNT = 5
|
||||
|
||||
SO_KEEPALIVE_ENABLE = 1
|
||||
)
|
||||
|
||||
type TCPClientInterface struct {
|
||||
@@ -45,12 +59,8 @@ type TCPClientInterface struct {
|
||||
maxReconnectTries int
|
||||
packetBuffer []byte
|
||||
packetType byte
|
||||
mutex sync.RWMutex
|
||||
enabled bool
|
||||
TxBytes uint64
|
||||
RxBytes uint64
|
||||
lastTx time.Time
|
||||
lastRx time.Time
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
func NewTCPClientInterface(name string, targetHost string, targetPort int, kissFraming bool, i2pTunneled bool, enabled bool) (*TCPClientInterface, error) {
|
||||
@@ -61,10 +71,10 @@ func NewTCPClientInterface(name string, targetHost string, targetPort int, kissF
|
||||
kissFraming: kissFraming,
|
||||
i2pTunneled: i2pTunneled,
|
||||
initiator: true,
|
||||
enabled: enabled,
|
||||
maxReconnectTries: TCP_PROBES,
|
||||
maxReconnectTries: RECONNECT_WAIT * TCP_PROBES_COUNT,
|
||||
packetBuffer: make([]byte, 0),
|
||||
neverConnected: true,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if enabled {
|
||||
@@ -82,43 +92,81 @@ func NewTCPClientInterface(name string, targetHost string, targetPort int, kissF
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) Start() error {
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
|
||||
if !tc.Enabled {
|
||||
return fmt.Errorf("interface not enabled")
|
||||
tc.Mutex.Lock()
|
||||
if !tc.Enabled || tc.Detached {
|
||||
tc.Mutex.Unlock()
|
||||
return fmt.Errorf("interface not enabled or detached")
|
||||
}
|
||||
|
||||
if tc.conn != nil {
|
||||
tc.Online = true
|
||||
go tc.readLoop()
|
||||
tc.Mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only recreate done if it's nil or was closed
|
||||
select {
|
||||
case <-tc.done:
|
||||
tc.done = make(chan struct{})
|
||||
tc.stopOnce = sync.Once{}
|
||||
default:
|
||||
if tc.done == nil {
|
||||
tc.done = make(chan struct{})
|
||||
tc.stopOnce = sync.Once{}
|
||||
}
|
||||
}
|
||||
tc.Mutex.Unlock()
|
||||
|
||||
addr := net.JoinHostPort(tc.targetAddr, fmt.Sprintf("%d", tc.targetPort))
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
conn, err := net.DialTimeout("tcp", addr, TCP_CONNECT_TIMEOUT)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tc.Mutex.Lock()
|
||||
tc.conn = conn
|
||||
tc.Mutex.Unlock()
|
||||
|
||||
// Set platform-specific timeouts
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
if err := tc.setTimeoutsLinux(); err != nil {
|
||||
log.Printf("[DEBUG-2] Failed to set Linux TCP timeouts: %v", err)
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to set Linux TCP timeouts", "error", err)
|
||||
}
|
||||
case "darwin":
|
||||
if err := tc.setTimeoutsOSX(); err != nil {
|
||||
log.Printf("[DEBUG-2] Failed to set OSX TCP timeouts: %v", err)
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to set OSX TCP timeouts", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
tc.Mutex.Lock()
|
||||
tc.Online = true
|
||||
tc.Mutex.Unlock()
|
||||
|
||||
go tc.readLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) Stop() error {
|
||||
tc.Mutex.Lock()
|
||||
tc.Enabled = false
|
||||
tc.Online = false
|
||||
if tc.conn != nil {
|
||||
_ = tc.conn.Close()
|
||||
tc.conn = nil
|
||||
}
|
||||
tc.Mutex.Unlock()
|
||||
|
||||
tc.stopOnce.Do(func() {
|
||||
if tc.done != nil {
|
||||
close(tc.done)
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) readLoop() {
|
||||
buffer := make([]byte, tc.MTU)
|
||||
inFrame := false
|
||||
@@ -126,10 +174,30 @@ func (tc *TCPClientInterface) readLoop() {
|
||||
dataBuffer := make([]byte, 0)
|
||||
|
||||
for {
|
||||
n, err := tc.conn.Read(buffer)
|
||||
tc.Mutex.RLock()
|
||||
conn := tc.conn
|
||||
done := tc.done
|
||||
tc.Mutex.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
tc.Mutex.Lock()
|
||||
tc.Online = false
|
||||
if tc.initiator && !tc.Detached {
|
||||
detached := tc.Detached
|
||||
initiator := tc.initiator
|
||||
tc.Mutex.Unlock()
|
||||
|
||||
if initiator && !detached {
|
||||
go tc.reconnect()
|
||||
} else {
|
||||
tc.teardown()
|
||||
@@ -137,7 +205,6 @@ func (tc *TCPClientInterface) readLoop() {
|
||||
return
|
||||
}
|
||||
|
||||
// Update RX bytes for raw received data
|
||||
tc.UpdateStats(uint64(n), true) // #nosec G115
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
@@ -169,46 +236,47 @@ func (tc *TCPClientInterface) readLoop() {
|
||||
|
||||
func (tc *TCPClientInterface) handlePacket(data []byte) {
|
||||
if len(data) < 1 {
|
||||
log.Printf("[DEBUG-7] Received invalid packet: empty")
|
||||
debug.Log(debug.DEBUG_ALL, "Received invalid packet: empty")
|
||||
return
|
||||
}
|
||||
|
||||
tc.mutex.Lock()
|
||||
tc.Mutex.Lock()
|
||||
tc.RxBytes += uint64(len(data))
|
||||
lastRx := time.Now()
|
||||
tc.lastRx = lastRx
|
||||
tc.mutex.Unlock()
|
||||
callback := tc.packetCallback
|
||||
tc.Mutex.Unlock()
|
||||
|
||||
log.Printf("[DEBUG-7] Received packet: type=0x%02x, size=%d bytes", data[0], len(data))
|
||||
debug.Log(debug.DEBUG_ALL, "Received packet", "type", fmt.Sprintf("0x%02x", data[0]), "size", len(data))
|
||||
|
||||
// For RNS packets, call the packet callback directly
|
||||
if callback := tc.GetPacketCallback(); callback != nil {
|
||||
log.Printf("[DEBUG-7] Calling packet callback for RNS packet")
|
||||
if callback != nil {
|
||||
debug.Log(debug.DEBUG_ALL, "Calling packet callback for RNS packet")
|
||||
callback(data, tc)
|
||||
} else {
|
||||
log.Printf("[DEBUG-7] No packet callback set for TCP interface")
|
||||
debug.Log(debug.DEBUG_ALL, "No packet callback set for TCP interface")
|
||||
}
|
||||
}
|
||||
|
||||
// Send implements the interface Send method for TCP interface
|
||||
func (tc *TCPClientInterface) Send(data []byte, address string) error {
|
||||
log.Printf("[DEBUG-7] TCP interface %s: Sending %d bytes", tc.Name, len(data))
|
||||
|
||||
debug.Log(debug.DEBUG_ALL, "TCP interface sending bytes", "name", tc.Name, "bytes", len(data))
|
||||
|
||||
if !tc.IsEnabled() || !tc.IsOnline() {
|
||||
return fmt.Errorf("TCP interface %s is not online", tc.Name)
|
||||
}
|
||||
|
||||
// For TCP interface, we need to prepend a packet type byte for announce packets
|
||||
// RNS TCP protocol expects: [packet_type][data]
|
||||
frame := make([]byte, 0, len(data)+1)
|
||||
frame = append(frame, 0x01) // Announce packet type
|
||||
frame = append(frame, data...)
|
||||
|
||||
return tc.ProcessOutgoing(frame)
|
||||
// Send data directly - packet type is already in the first byte of data
|
||||
// TCP interface uses HDLC framing around the raw packet
|
||||
return tc.ProcessOutgoing(data)
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) ProcessOutgoing(data []byte) error {
|
||||
if !tc.Online {
|
||||
tc.Mutex.RLock()
|
||||
online := tc.Online
|
||||
tc.Mutex.RUnlock()
|
||||
|
||||
if !online {
|
||||
return fmt.Errorf("interface offline")
|
||||
}
|
||||
|
||||
@@ -220,13 +288,21 @@ func (tc *TCPClientInterface) ProcessOutgoing(data []byte) error {
|
||||
frame = append([]byte{HDLC_FLAG}, escapeHDLC(data)...)
|
||||
frame = append(frame, HDLC_FLAG)
|
||||
|
||||
// Update TX stats before sending
|
||||
tc.UpdateStats(uint64(len(frame)), false)
|
||||
tc.UpdateStats(uint64(len(frame)), false) // #nosec G115
|
||||
|
||||
log.Printf("[DEBUG-7] TCP interface %s: Writing %d bytes to network", tc.Name, len(frame))
|
||||
_, err := tc.conn.Write(frame)
|
||||
debug.Log(debug.DEBUG_ALL, "TCP interface writing to network", "name", tc.Name, "bytes", len(frame))
|
||||
|
||||
tc.Mutex.RLock()
|
||||
conn := tc.conn
|
||||
tc.Mutex.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return fmt.Errorf("connection closed")
|
||||
}
|
||||
|
||||
_, err := conn.Write(frame)
|
||||
if err != nil {
|
||||
log.Printf("[DEBUG-1] TCP interface %s: Write failed: %v", tc.Name, err)
|
||||
debug.Log(debug.DEBUG_CRITICAL, "TCP interface write failed", "name", tc.Name, "error", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -236,7 +312,7 @@ func (tc *TCPClientInterface) teardown() {
|
||||
tc.IN = false
|
||||
tc.OUT = false
|
||||
if tc.conn != nil {
|
||||
tc.conn.Close() // #nosec G104
|
||||
_ = tc.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,9 +348,9 @@ func (tc *TCPClientInterface) SetPacketCallback(cb common.PacketCallback) {
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) IsEnabled() bool {
|
||||
tc.mutex.RLock()
|
||||
defer tc.mutex.RUnlock()
|
||||
return tc.enabled && tc.Online && !tc.Detached
|
||||
tc.Mutex.RLock()
|
||||
defer tc.Mutex.RUnlock()
|
||||
return tc.Enabled && tc.Online && !tc.Detached
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) GetName() string {
|
||||
@@ -282,31 +358,31 @@ func (tc *TCPClientInterface) GetName() string {
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) GetPacketCallback() common.PacketCallback {
|
||||
tc.mutex.RLock()
|
||||
defer tc.mutex.RUnlock()
|
||||
tc.Mutex.RLock()
|
||||
defer tc.Mutex.RUnlock()
|
||||
return tc.packetCallback
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) IsDetached() bool {
|
||||
tc.mutex.RLock()
|
||||
defer tc.mutex.RUnlock()
|
||||
tc.Mutex.RLock()
|
||||
defer tc.Mutex.RUnlock()
|
||||
return tc.Detached
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) IsOnline() bool {
|
||||
tc.mutex.RLock()
|
||||
defer tc.mutex.RUnlock()
|
||||
tc.Mutex.RLock()
|
||||
defer tc.Mutex.RUnlock()
|
||||
return tc.Online
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) reconnect() {
|
||||
tc.mutex.Lock()
|
||||
tc.Mutex.Lock()
|
||||
if tc.reconnecting {
|
||||
tc.mutex.Unlock()
|
||||
tc.Mutex.Unlock()
|
||||
return
|
||||
}
|
||||
tc.reconnecting = true
|
||||
tc.mutex.Unlock()
|
||||
tc.Mutex.Unlock()
|
||||
|
||||
backoff := time.Second
|
||||
maxBackoff := time.Minute * 5
|
||||
@@ -319,21 +395,19 @@ func (tc *TCPClientInterface) reconnect() {
|
||||
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err == nil {
|
||||
tc.mutex.Lock()
|
||||
tc.Mutex.Lock()
|
||||
tc.conn = conn
|
||||
tc.Online = true
|
||||
|
||||
tc.neverConnected = false
|
||||
tc.reconnecting = false
|
||||
tc.mutex.Unlock()
|
||||
tc.Mutex.Unlock()
|
||||
|
||||
go tc.readLoop()
|
||||
return
|
||||
}
|
||||
|
||||
// Log reconnection attempt
|
||||
fmt.Printf("Failed to reconnect to %s (attempt %d/%d): %v\n",
|
||||
addr, retries+1, tc.maxReconnectTries, err)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Failed to reconnect", "target", net.JoinHostPort(tc.targetAddr, fmt.Sprintf("%d", tc.targetPort)), "attempt", retries+1, "maxTries", tc.maxReconnectTries, "error", err)
|
||||
|
||||
// Wait with exponential backoff
|
||||
time.Sleep(backoff)
|
||||
@@ -347,50 +421,48 @@ func (tc *TCPClientInterface) reconnect() {
|
||||
retries++
|
||||
}
|
||||
|
||||
tc.mutex.Lock()
|
||||
tc.Mutex.Lock()
|
||||
tc.reconnecting = false
|
||||
tc.mutex.Unlock()
|
||||
tc.Mutex.Unlock()
|
||||
|
||||
// If we've exhausted all retries, perform final teardown
|
||||
tc.teardown()
|
||||
fmt.Printf("Failed to reconnect to %s after %d attempts\n",
|
||||
fmt.Sprintf("%s:%d", tc.targetAddr, tc.targetPort), tc.maxReconnectTries)
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to reconnect after all attempts", "target", net.JoinHostPort(tc.targetAddr, fmt.Sprintf("%d", tc.targetPort)), "maxTries", tc.maxReconnectTries)
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) Enable() {
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
tc.Mutex.Lock()
|
||||
defer tc.Mutex.Unlock()
|
||||
tc.Online = true
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) Disable() {
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
tc.Mutex.Lock()
|
||||
defer tc.Mutex.Unlock()
|
||||
tc.Online = false
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) IsConnected() bool {
|
||||
tc.mutex.RLock()
|
||||
defer tc.mutex.RUnlock()
|
||||
tc.Mutex.RLock()
|
||||
defer tc.Mutex.RUnlock()
|
||||
return tc.conn != nil && tc.Online && !tc.reconnecting
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) GetRTT() time.Duration {
|
||||
tc.mutex.RLock()
|
||||
defer tc.mutex.RUnlock()
|
||||
tc.Mutex.RLock()
|
||||
defer tc.Mutex.RUnlock()
|
||||
|
||||
if !tc.IsConnected() {
|
||||
return 0
|
||||
}
|
||||
|
||||
if tcpConn, ok := tc.conn.(*net.TCPConn); ok {
|
||||
var rtt time.Duration = 0
|
||||
var rtt time.Duration
|
||||
if runtime.GOOS == "linux" {
|
||||
if info, err := tcpConn.SyscallConn(); err == nil {
|
||||
if err := info.Control(func(fd uintptr) { // #nosec G104
|
||||
rtt = platformGetRTT(fd)
|
||||
}); err != nil {
|
||||
log.Printf("[DEBUG-2] Error in SyscallConn Control: %v", err)
|
||||
debug.Log(debug.DEBUG_ERROR, "Error in SyscallConn Control", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -401,84 +473,50 @@ func (tc *TCPClientInterface) GetRTT() time.Duration {
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) GetTxBytes() uint64 {
|
||||
tc.mutex.RLock()
|
||||
defer tc.mutex.RUnlock()
|
||||
tc.Mutex.RLock()
|
||||
defer tc.Mutex.RUnlock()
|
||||
return tc.TxBytes
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) GetRxBytes() uint64 {
|
||||
tc.mutex.RLock()
|
||||
defer tc.mutex.RUnlock()
|
||||
tc.Mutex.RLock()
|
||||
defer tc.Mutex.RUnlock()
|
||||
return tc.RxBytes
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) UpdateStats(bytes uint64, isRx bool) {
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
tc.Mutex.Lock()
|
||||
defer tc.Mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if isRx {
|
||||
tc.RxBytes += bytes
|
||||
tc.lastRx = now
|
||||
log.Printf("[DEBUG-5] Interface %s RX stats: bytes=%d total=%d last=%v",
|
||||
tc.Name, bytes, tc.RxBytes, tc.lastRx)
|
||||
debug.Log(debug.DEBUG_TRACE, "Interface RX stats", "name", tc.Name, "bytes", bytes, "total", tc.RxBytes, "last", tc.lastRx)
|
||||
} else {
|
||||
tc.TxBytes += bytes
|
||||
tc.lastTx = now
|
||||
log.Printf("[DEBUG-5] Interface %s TX stats: bytes=%d total=%d last=%v",
|
||||
tc.Name, bytes, tc.TxBytes, tc.lastTx)
|
||||
debug.Log(debug.DEBUG_TRACE, "Interface TX stats", "name", tc.Name, "bytes", bytes, "total", tc.TxBytes, "last", tc.lastTx)
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) GetStats() (tx uint64, rx uint64, lastTx time.Time, lastRx time.Time) {
|
||||
tc.mutex.RLock()
|
||||
defer tc.mutex.RUnlock()
|
||||
tc.Mutex.RLock()
|
||||
defer tc.Mutex.RUnlock()
|
||||
return tc.TxBytes, tc.RxBytes, tc.lastTx, tc.lastRx
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsLinux() error {
|
||||
tcpConn, ok := tc.conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return fmt.Errorf("not a TCP connection")
|
||||
}
|
||||
|
||||
if !tc.i2pTunneled {
|
||||
if err := tcpConn.SetKeepAlive(true); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tcpConn.SetKeepAlivePeriod(time.Duration(TCP_PROBE_INTERVAL) * time.Second); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsOSX() error {
|
||||
tcpConn, ok := tc.conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return fmt.Errorf("not a TCP connection")
|
||||
}
|
||||
|
||||
if err := tcpConn.SetKeepAlive(true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type TCPServerInterface struct {
|
||||
BaseInterface
|
||||
connections map[string]net.Conn
|
||||
mutex sync.RWMutex
|
||||
bindAddr string
|
||||
bindPort int
|
||||
preferIPv6 bool
|
||||
kissFraming bool
|
||||
i2pTunneled bool
|
||||
packetCallback common.PacketCallback
|
||||
TxBytes uint64
|
||||
RxBytes uint64
|
||||
connections map[string]net.Conn
|
||||
listener net.Listener
|
||||
bindAddr string
|
||||
bindPort int
|
||||
preferIPv6 bool
|
||||
kissFraming bool
|
||||
i2pTunneled bool
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
func NewTCPServerInterface(name string, bindAddr string, bindPort int, kissFraming bool, i2pTunneled bool, preferIPv6 bool) (*TCPServerInterface, error) {
|
||||
@@ -489,6 +527,7 @@ func NewTCPServerInterface(name string, bindAddr string, bindPort int, kissFrami
|
||||
Type: common.IF_TYPE_TCP,
|
||||
Online: false,
|
||||
MTU: common.DEFAULT_MTU,
|
||||
Enabled: true,
|
||||
Detached: false,
|
||||
},
|
||||
connections: make(map[string]net.Conn),
|
||||
@@ -497,6 +536,7 @@ func NewTCPServerInterface(name string, bindAddr string, bindPort int, kissFrami
|
||||
preferIPv6: preferIPv6,
|
||||
kissFraming: kissFraming,
|
||||
i2pTunneled: i2pTunneled,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
return ts, nil
|
||||
@@ -515,21 +555,21 @@ func (ts *TCPServerInterface) String() string {
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) SetPacketCallback(callback common.PacketCallback) {
|
||||
ts.mutex.Lock()
|
||||
defer ts.mutex.Unlock()
|
||||
ts.Mutex.Lock()
|
||||
defer ts.Mutex.Unlock()
|
||||
ts.packetCallback = callback
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) GetPacketCallback() common.PacketCallback {
|
||||
ts.mutex.RLock()
|
||||
defer ts.mutex.RUnlock()
|
||||
ts.Mutex.RLock()
|
||||
defer ts.Mutex.RUnlock()
|
||||
return ts.packetCallback
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) IsEnabled() bool {
|
||||
ts.mutex.RLock()
|
||||
defer ts.mutex.RUnlock()
|
||||
return ts.BaseInterface.Enabled && ts.BaseInterface.Online && !ts.BaseInterface.Detached
|
||||
ts.Mutex.RLock()
|
||||
defer ts.Mutex.RUnlock()
|
||||
return ts.Enabled && ts.Online && !ts.Detached
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) GetName() string {
|
||||
@@ -537,50 +577,81 @@ func (ts *TCPServerInterface) GetName() string {
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) IsDetached() bool {
|
||||
ts.mutex.RLock()
|
||||
defer ts.mutex.RUnlock()
|
||||
return ts.BaseInterface.Detached
|
||||
ts.Mutex.RLock()
|
||||
defer ts.Mutex.RUnlock()
|
||||
return ts.Detached
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) IsOnline() bool {
|
||||
ts.mutex.RLock()
|
||||
defer ts.mutex.RUnlock()
|
||||
ts.Mutex.RLock()
|
||||
defer ts.Mutex.RUnlock()
|
||||
return ts.Online
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) Enable() {
|
||||
ts.mutex.Lock()
|
||||
defer ts.mutex.Unlock()
|
||||
ts.Mutex.Lock()
|
||||
defer ts.Mutex.Unlock()
|
||||
ts.Online = true
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) Disable() {
|
||||
ts.mutex.Lock()
|
||||
defer ts.mutex.Unlock()
|
||||
ts.Mutex.Lock()
|
||||
defer ts.Mutex.Unlock()
|
||||
ts.Online = false
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) Start() error {
|
||||
ts.mutex.Lock()
|
||||
defer ts.mutex.Unlock()
|
||||
ts.Mutex.Lock()
|
||||
if ts.listener != nil {
|
||||
ts.Mutex.Unlock()
|
||||
return fmt.Errorf("TCP server already started")
|
||||
}
|
||||
// Only recreate done if it's nil or was closed
|
||||
select {
|
||||
case <-ts.done:
|
||||
ts.done = make(chan struct{})
|
||||
ts.stopOnce = sync.Once{}
|
||||
default:
|
||||
if ts.done == nil {
|
||||
ts.done = make(chan struct{})
|
||||
ts.stopOnce = sync.Once{}
|
||||
}
|
||||
}
|
||||
ts.Mutex.Unlock()
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", ts.bindAddr, ts.bindPort)
|
||||
addr := net.JoinHostPort(ts.bindAddr, fmt.Sprintf("%d", ts.bindPort))
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start TCP server: %w", err)
|
||||
}
|
||||
|
||||
ts.Mutex.Lock()
|
||||
ts.listener = listener
|
||||
ts.Online = true
|
||||
ts.Mutex.Unlock()
|
||||
|
||||
// Accept connections in a goroutine
|
||||
go func() {
|
||||
for {
|
||||
ts.Mutex.RLock()
|
||||
done := ts.done
|
||||
ts.Mutex.RUnlock()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
if !ts.Online {
|
||||
ts.Mutex.RLock()
|
||||
online := ts.Online
|
||||
ts.Mutex.RUnlock()
|
||||
if !online {
|
||||
return // Normal shutdown
|
||||
}
|
||||
log.Printf("[DEBUG-2] Error accepting connection: %v", err)
|
||||
debug.Log(debug.DEBUG_ERROR, "Error accepting connection", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -593,60 +664,87 @@ func (ts *TCPServerInterface) Start() error {
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) Stop() error {
|
||||
ts.mutex.Lock()
|
||||
defer ts.mutex.Unlock()
|
||||
|
||||
ts.Mutex.Lock()
|
||||
ts.Online = false
|
||||
if ts.listener != nil {
|
||||
_ = ts.listener.Close()
|
||||
ts.listener = nil
|
||||
}
|
||||
// Close all client connections
|
||||
for addr, conn := range ts.connections {
|
||||
_ = conn.Close()
|
||||
delete(ts.connections, addr)
|
||||
}
|
||||
ts.Mutex.Unlock()
|
||||
|
||||
ts.stopOnce.Do(func() {
|
||||
if ts.done != nil {
|
||||
close(ts.done)
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) GetTxBytes() uint64 {
|
||||
ts.mutex.RLock()
|
||||
defer ts.mutex.RUnlock()
|
||||
ts.Mutex.RLock()
|
||||
defer ts.Mutex.RUnlock()
|
||||
return ts.TxBytes
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) GetRxBytes() uint64 {
|
||||
ts.mutex.RLock()
|
||||
defer ts.mutex.RUnlock()
|
||||
ts.Mutex.RLock()
|
||||
defer ts.Mutex.RUnlock()
|
||||
return ts.RxBytes
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) handleConnection(conn net.Conn) {
|
||||
addr := conn.RemoteAddr().String()
|
||||
ts.mutex.Lock()
|
||||
ts.Mutex.Lock()
|
||||
ts.connections[addr] = conn
|
||||
ts.mutex.Unlock()
|
||||
ts.Mutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
ts.mutex.Lock()
|
||||
ts.Mutex.Lock()
|
||||
delete(ts.connections, addr)
|
||||
ts.mutex.Unlock()
|
||||
conn.Close() // #nosec G104
|
||||
ts.Mutex.Unlock()
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
buffer := make([]byte, ts.MTU)
|
||||
for {
|
||||
ts.Mutex.RLock()
|
||||
done := ts.done
|
||||
ts.Mutex.RUnlock()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ts.mutex.Lock()
|
||||
ts.Mutex.Lock()
|
||||
ts.RxBytes += uint64(n) // #nosec G115
|
||||
ts.mutex.Unlock()
|
||||
callback := ts.packetCallback
|
||||
ts.Mutex.Unlock()
|
||||
|
||||
if ts.packetCallback != nil {
|
||||
ts.packetCallback(buffer[:n], ts)
|
||||
if callback != nil {
|
||||
callback(buffer[:n], ts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *TCPServerInterface) ProcessOutgoing(data []byte) error {
|
||||
ts.mutex.RLock()
|
||||
defer ts.mutex.RUnlock()
|
||||
ts.Mutex.RLock()
|
||||
online := ts.Online
|
||||
ts.Mutex.RUnlock()
|
||||
|
||||
if !ts.Online {
|
||||
if !online {
|
||||
return fmt.Errorf("interface offline")
|
||||
}
|
||||
|
||||
@@ -659,12 +757,17 @@ func (ts *TCPServerInterface) ProcessOutgoing(data []byte) error {
|
||||
frame = append(frame, HDLC_FLAG)
|
||||
}
|
||||
|
||||
ts.TxBytes += uint64(len(frame))
|
||||
|
||||
ts.Mutex.Lock()
|
||||
ts.TxBytes += uint64(len(frame)) // #nosec G115
|
||||
conns := make([]net.Conn, 0, len(ts.connections))
|
||||
for _, conn := range ts.connections {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
ts.Mutex.Unlock()
|
||||
|
||||
for _, conn := range conns {
|
||||
if _, err := conn.Write(frame); err != nil {
|
||||
log.Printf("[DEBUG-4] Error writing to connection %s: %v",
|
||||
conn.RemoteAddr(), err)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Error writing to connection", "address", conn.RemoteAddr(), "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build !linux || tinygo
|
||||
// +build !linux tinygo
|
||||
|
||||
|
||||
61
pkg/interfaces/tcp_darwin.go
Normal file
61
pkg/interfaces/tcp_darwin.go
Normal file
@@ -0,0 +1,61 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build darwin
|
||||
// +build darwin
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
)
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsLinux() error {
|
||||
return tc.setTimeoutsOSX()
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsOSX() error {
|
||||
tcpConn, ok := tc.conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return fmt.Errorf("not a TCP connection")
|
||||
}
|
||||
|
||||
rawConn, err := tcpConn.SyscallConn()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get raw connection: %v", err)
|
||||
}
|
||||
|
||||
var sockoptErr error
|
||||
err = rawConn.Control(func(fd uintptr) {
|
||||
const TCP_KEEPALIVE = 0x10
|
||||
|
||||
var probeAfter int
|
||||
if tc.i2pTunneled {
|
||||
probeAfter = I2P_PROBE_AFTER_SEC
|
||||
} else {
|
||||
probeAfter = TCP_PROBE_AFTER_SEC
|
||||
}
|
||||
|
||||
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, SO_KEEPALIVE_ENABLE); err != nil {
|
||||
sockoptErr = fmt.Errorf("failed to enable SO_KEEPALIVE: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, TCP_KEEPALIVE, probeAfter); err != nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Failed to set TCP_KEEPALIVE", "error", err)
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("control failed: %v", err)
|
||||
}
|
||||
if sockoptErr != nil {
|
||||
return sockoptErr
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_VERBOSE, "TCP keepalive configured (OSX)", "i2p", tc.i2pTunneled)
|
||||
return nil
|
||||
}
|
||||
41
pkg/interfaces/tcp_freebsd.go
Normal file
41
pkg/interfaces/tcp_freebsd.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build freebsd
|
||||
// +build freebsd
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
)
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsLinux() error {
|
||||
tcpConn, ok := tc.conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return fmt.Errorf("not a TCP connection")
|
||||
}
|
||||
|
||||
if err := tcpConn.SetKeepAlive(true); err != nil {
|
||||
return fmt.Errorf("failed to enable keepalive: %v", err)
|
||||
}
|
||||
|
||||
keepalivePeriod := TCP_PROBE_INTERVAL_SEC * time.Second
|
||||
if tc.i2pTunneled {
|
||||
keepalivePeriod = I2P_PROBE_INTERVAL_SEC * time.Second
|
||||
}
|
||||
|
||||
if err := tcpConn.SetKeepAlivePeriod(keepalivePeriod); err != nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Failed to set keepalive period", "error", err)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_VERBOSE, "TCP keepalive configured (FreeBSD)", "i2p", tc.i2pTunneled)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsOSX() error {
|
||||
return tc.setTimeoutsLinux()
|
||||
}
|
||||
@@ -1,32 +1,111 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build linux && !tinygo
|
||||
// +build linux,!tinygo
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
)
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsLinux() error {
|
||||
tcpConn, ok := tc.conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return fmt.Errorf("not a TCP connection")
|
||||
}
|
||||
|
||||
rawConn, err := tcpConn.SyscallConn()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get raw connection: %v", err)
|
||||
}
|
||||
|
||||
var sockoptErr error
|
||||
err = rawConn.Control(func(fd uintptr) {
|
||||
var userTimeout, probeAfter, probeInterval, probeCount int
|
||||
|
||||
if tc.i2pTunneled {
|
||||
userTimeout = I2P_USER_TIMEOUT_SEC * TCP_MILLISECONDS
|
||||
probeAfter = I2P_PROBE_AFTER_SEC
|
||||
probeInterval = I2P_PROBE_INTERVAL_SEC
|
||||
probeCount = I2P_PROBES_COUNT
|
||||
} else {
|
||||
userTimeout = TCP_USER_TIMEOUT_SEC * TCP_MILLISECONDS
|
||||
probeAfter = TCP_PROBE_AFTER_SEC
|
||||
probeInterval = TCP_PROBE_INTERVAL_SEC
|
||||
probeCount = TCP_PROBES_COUNT
|
||||
}
|
||||
|
||||
const TCP_USER_TIMEOUT = 18
|
||||
const TCP_KEEPIDLE = 4
|
||||
const TCP_KEEPINTVL = 5
|
||||
const TCP_KEEPCNT = 6
|
||||
|
||||
if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, TCP_USER_TIMEOUT, userTimeout); err != nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Failed to set TCP_USER_TIMEOUT", "error", err)
|
||||
}
|
||||
|
||||
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, SO_KEEPALIVE_ENABLE); err != nil {
|
||||
sockoptErr = fmt.Errorf("failed to enable SO_KEEPALIVE: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, TCP_KEEPIDLE, probeAfter); err != nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Failed to set TCP_KEEPIDLE", "error", err)
|
||||
}
|
||||
|
||||
if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, TCP_KEEPINTVL, probeInterval); err != nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Failed to set TCP_KEEPINTVL", "error", err)
|
||||
}
|
||||
|
||||
if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, TCP_KEEPCNT, probeCount); err != nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Failed to set TCP_KEEPCNT", "error", err)
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("control failed: %v", err)
|
||||
}
|
||||
if sockoptErr != nil {
|
||||
return sockoptErr
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_VERBOSE, "TCP keepalive configured (Linux)", "i2p", tc.i2pTunneled)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsOSX() error {
|
||||
return tc.setTimeoutsLinux()
|
||||
}
|
||||
|
||||
func platformGetRTT(fd uintptr) time.Duration {
|
||||
var info syscall.TCPInfo
|
||||
size := uint32(syscall.SizeofTCPInfo)
|
||||
// bearer:disable go_gosec_unsafe_unsafe
|
||||
infoLen := uint32(unsafe.Sizeof(info))
|
||||
|
||||
_, _, err := syscall.Syscall6(
|
||||
const TCP_INFO = 11
|
||||
// #nosec G103
|
||||
_, _, errno := syscall.Syscall6(
|
||||
syscall.SYS_GETSOCKOPT,
|
||||
fd,
|
||||
syscall.SOL_TCP,
|
||||
syscall.TCP_INFO,
|
||||
uintptr(unsafe.Pointer(&info)), // #nosec G103
|
||||
uintptr(unsafe.Pointer(&size)), // #nosec G103
|
||||
syscall.IPPROTO_TCP,
|
||||
TCP_INFO,
|
||||
// bearer:disable go_gosec_unsafe_unsafe
|
||||
uintptr(unsafe.Pointer(&info)),
|
||||
// bearer:disable go_gosec_unsafe_unsafe
|
||||
uintptr(unsafe.Pointer(&infoLen)),
|
||||
0,
|
||||
)
|
||||
|
||||
if err != 0 {
|
||||
if errno != 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// RTT is in microseconds, convert to Duration
|
||||
return time.Duration(info.Rtt) * time.Microsecond
|
||||
}
|
||||
|
||||
41
pkg/interfaces/tcp_netbsd.go
Normal file
41
pkg/interfaces/tcp_netbsd.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build netbsd
|
||||
// +build netbsd
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
)
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsLinux() error {
|
||||
tcpConn, ok := tc.conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return fmt.Errorf("not a TCP connection")
|
||||
}
|
||||
|
||||
if err := tcpConn.SetKeepAlive(true); err != nil {
|
||||
return fmt.Errorf("failed to enable keepalive: %v", err)
|
||||
}
|
||||
|
||||
keepalivePeriod := TCP_PROBE_INTERVAL_SEC * time.Second
|
||||
if tc.i2pTunneled {
|
||||
keepalivePeriod = I2P_PROBE_INTERVAL_SEC * time.Second
|
||||
}
|
||||
|
||||
if err := tcpConn.SetKeepAlivePeriod(keepalivePeriod); err != nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Failed to set keepalive period", "error", err)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_VERBOSE, "TCP keepalive configured (NetBSD)", "i2p", tc.i2pTunneled)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsOSX() error {
|
||||
return tc.setTimeoutsLinux()
|
||||
}
|
||||
41
pkg/interfaces/tcp_openbsd.go
Normal file
41
pkg/interfaces/tcp_openbsd.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build openbsd
|
||||
// +build openbsd
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
)
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsLinux() error {
|
||||
tcpConn, ok := tc.conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return fmt.Errorf("not a TCP connection")
|
||||
}
|
||||
|
||||
if err := tcpConn.SetKeepAlive(true); err != nil {
|
||||
return fmt.Errorf("failed to enable keepalive: %v", err)
|
||||
}
|
||||
|
||||
keepalivePeriod := TCP_PROBE_INTERVAL_SEC * time.Second
|
||||
if tc.i2pTunneled {
|
||||
keepalivePeriod = I2P_PROBE_INTERVAL_SEC * time.Second
|
||||
}
|
||||
|
||||
if err := tcpConn.SetKeepAlivePeriod(keepalivePeriod); err != nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Failed to set keepalive period", "error", err)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_VERBOSE, "TCP keepalive configured (OpenBSD)", "i2p", tc.i2pTunneled)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsOSX() error {
|
||||
return tc.setTimeoutsLinux()
|
||||
}
|
||||
14
pkg/interfaces/tcp_wasm.go
Normal file
14
pkg/interfaces/tcp_wasm.go
Normal file
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build js && wasm
|
||||
// +build js,wasm
|
||||
|
||||
package interfaces
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsLinux() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsOSX() error {
|
||||
return nil
|
||||
}
|
||||
42
pkg/interfaces/tcp_windows.go
Normal file
42
pkg/interfaces/tcp_windows.go
Normal file
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
)
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsLinux() error {
|
||||
return tc.setTimeoutsWindows()
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsOSX() error {
|
||||
return tc.setTimeoutsWindows()
|
||||
}
|
||||
|
||||
func (tc *TCPClientInterface) setTimeoutsWindows() error {
|
||||
tcpConn, ok := tc.conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return fmt.Errorf("not a TCP connection")
|
||||
}
|
||||
|
||||
if err := tcpConn.SetKeepAlive(true); err != nil {
|
||||
return fmt.Errorf("failed to enable keepalive: %v", err)
|
||||
}
|
||||
|
||||
keepalivePeriod := TCP_PROBE_INTERVAL_SEC * time.Second
|
||||
if tc.i2pTunneled {
|
||||
keepalivePeriod = I2P_PROBE_INTERVAL_SEC * time.Second
|
||||
}
|
||||
|
||||
if err := tcpConn.SetKeepAlivePeriod(keepalivePeriod); err != nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Failed to set keepalive period", "error", err)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_VERBOSE, "TCP keepalive configured (Windows)", "i2p", tc.i2pTunneled)
|
||||
return nil
|
||||
}
|
||||
@@ -1,21 +1,27 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build !tinygo
|
||||
// +build !tinygo
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
)
|
||||
|
||||
type UDPInterface struct {
|
||||
BaseInterface
|
||||
conn net.Conn
|
||||
conn *net.UDPConn
|
||||
addr *net.UDPAddr
|
||||
targetAddr *net.UDPAddr
|
||||
mutex sync.RWMutex
|
||||
readBuffer []byte
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
func NewUDPInterface(name string, addr string, target string, enabled bool) (*UDPInterface, error) {
|
||||
@@ -36,66 +42,52 @@ func NewUDPInterface(name string, addr string, target string, enabled bool) (*UD
|
||||
BaseInterface: NewBaseInterface(name, common.IF_TYPE_UDP, enabled),
|
||||
addr: udpAddr,
|
||||
targetAddr: targetAddr,
|
||||
readBuffer: make([]byte, common.DEFAULT_MTU),
|
||||
readBuffer: make([]byte, common.NUM_1064),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
ui.MTU = common.NUM_1064
|
||||
|
||||
return ui, nil
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) GetName() string {
|
||||
return ui.Name
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) GetType() common.InterfaceType {
|
||||
return ui.Type
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) GetMode() common.InterfaceMode {
|
||||
return ui.Mode
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) IsOnline() bool {
|
||||
ui.mutex.RLock()
|
||||
defer ui.mutex.RUnlock()
|
||||
return ui.Online
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) IsDetached() bool {
|
||||
ui.mutex.RLock()
|
||||
defer ui.mutex.RUnlock()
|
||||
return ui.Detached
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) Detach() {
|
||||
ui.mutex.Lock()
|
||||
defer ui.mutex.Unlock()
|
||||
ui.Mutex.Lock()
|
||||
defer ui.Mutex.Unlock()
|
||||
ui.Detached = true
|
||||
ui.Online = false
|
||||
if ui.conn != nil {
|
||||
ui.conn.Close() // #nosec G104
|
||||
}
|
||||
ui.stopOnce.Do(func() {
|
||||
if ui.done != nil {
|
||||
close(ui.done)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) Send(data []byte, addr string) error {
|
||||
// TinyGo doesn't support UDP sending
|
||||
return fmt.Errorf("UDPInterface Send not supported in TinyGo - requires UDP client functionality")
|
||||
}
|
||||
debug.Log(debug.DEBUG_ALL, "UDP interface sending bytes", "name", ui.Name, "bytes", len(data))
|
||||
|
||||
func (ui *UDPInterface) SetPacketCallback(callback common.PacketCallback) {
|
||||
ui.mutex.Lock()
|
||||
defer ui.mutex.Unlock()
|
||||
ui.packetCallback = callback
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) GetPacketCallback() common.PacketCallback {
|
||||
ui.mutex.RLock()
|
||||
defer ui.mutex.RUnlock()
|
||||
return ui.packetCallback
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) ProcessIncoming(data []byte) {
|
||||
if callback := ui.GetPacketCallback(); callback != nil {
|
||||
callback(data, ui)
|
||||
if !ui.IsEnabled() {
|
||||
return fmt.Errorf("interface not enabled")
|
||||
}
|
||||
|
||||
if ui.targetAddr == nil {
|
||||
return fmt.Errorf("no target address configured")
|
||||
}
|
||||
|
||||
ui.Mutex.Lock()
|
||||
ui.TxBytes += uint64(len(data))
|
||||
ui.Mutex.Unlock()
|
||||
|
||||
_, err := ui.conn.WriteTo(data, ui.targetAddr)
|
||||
if err != nil {
|
||||
debug.Log(debug.DEBUG_CRITICAL, "UDP interface write failed", "name", ui.Name, "error", err)
|
||||
} else {
|
||||
debug.Log(debug.DEBUG_ALL, "UDP interface sent bytes successfully", "name", ui.Name, "bytes", len(data))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) ProcessOutgoing(data []byte) error {
|
||||
@@ -108,9 +100,9 @@ func (ui *UDPInterface) ProcessOutgoing(data []byte) error {
|
||||
return fmt.Errorf("UDP write failed: %v", err)
|
||||
}
|
||||
|
||||
ui.mutex.Lock()
|
||||
ui.Mutex.Lock()
|
||||
ui.TxBytes += uint64(len(data))
|
||||
ui.mutex.Unlock()
|
||||
ui.Mutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -119,82 +111,102 @@ func (ui *UDPInterface) GetConn() net.Conn {
|
||||
return ui.conn
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) GetTxBytes() uint64 {
|
||||
ui.mutex.RLock()
|
||||
defer ui.mutex.RUnlock()
|
||||
return ui.TxBytes
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) GetRxBytes() uint64 {
|
||||
ui.mutex.RLock()
|
||||
defer ui.mutex.RUnlock()
|
||||
return ui.RxBytes
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) GetMTU() int {
|
||||
return ui.MTU
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) GetBitrate() int {
|
||||
return int(ui.Bitrate)
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) Enable() {
|
||||
ui.mutex.Lock()
|
||||
defer ui.mutex.Unlock()
|
||||
ui.Online = true
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) Disable() {
|
||||
ui.mutex.Lock()
|
||||
defer ui.mutex.Unlock()
|
||||
ui.Online = false
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) Start() error {
|
||||
// TinyGo doesn't support UDP servers, only clients
|
||||
return fmt.Errorf("UDPInterface not supported in TinyGo - UDP server functionality requires net.ListenUDP")
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) readLoop() {
|
||||
// This method is not used in TinyGo since UDP servers are not supported
|
||||
buffer := make([]byte, common.DEFAULT_MTU)
|
||||
for ui.IsOnline() && !ui.IsDetached() {
|
||||
n, err := ui.conn.Read(buffer)
|
||||
if err != nil {
|
||||
if ui.IsOnline() {
|
||||
log.Printf("Error reading from UDP interface %s: %v", ui.Name, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if ui.packetCallback != nil {
|
||||
ui.packetCallback(buffer[:n], ui)
|
||||
ui.Mutex.Lock()
|
||||
if ui.conn != nil {
|
||||
ui.Mutex.Unlock()
|
||||
return fmt.Errorf("UDP interface already started")
|
||||
}
|
||||
// Only recreate done if it's nil or was closed
|
||||
select {
|
||||
case <-ui.done:
|
||||
ui.done = make(chan struct{})
|
||||
ui.stopOnce = sync.Once{}
|
||||
default:
|
||||
if ui.done == nil {
|
||||
ui.done = make(chan struct{})
|
||||
ui.stopOnce = sync.Once{}
|
||||
}
|
||||
}
|
||||
ui.Mutex.Unlock()
|
||||
|
||||
conn, err := net.ListenUDP("udp", ui.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ui.conn = conn
|
||||
|
||||
// Enable broadcast mode if we have a target address
|
||||
if ui.targetAddr != nil {
|
||||
// Get the raw connection file descriptor to set SO_BROADCAST
|
||||
if err := conn.SetReadBuffer(common.NUM_1064); err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to set read buffer size", "error", err)
|
||||
}
|
||||
if err := conn.SetWriteBuffer(common.NUM_1064); err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to set write buffer size", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
ui.Mutex.Lock()
|
||||
ui.Online = true
|
||||
ui.Mutex.Unlock()
|
||||
|
||||
// Start the read loop in a goroutine
|
||||
go ui.readLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) Stop() error {
|
||||
ui.Detach()
|
||||
return nil
|
||||
}
|
||||
|
||||
/*
|
||||
func (ui *UDPInterface) readLoop() {
|
||||
buffer := make([]byte, ui.MTU)
|
||||
buffer := make([]byte, common.NUM_1064)
|
||||
for {
|
||||
n, _, err := ui.conn.ReadFromUDP(buffer)
|
||||
ui.Mutex.RLock()
|
||||
online := ui.Online
|
||||
detached := ui.Detached
|
||||
conn := ui.conn
|
||||
done := ui.done
|
||||
ui.Mutex.RUnlock()
|
||||
|
||||
if !online || detached || conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
n, remoteAddr, err := conn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
if ui.Online {
|
||||
log.Printf("Error reading from UDP interface %s: %v", ui.Name, err)
|
||||
ui.Stop() // Consider if stopping is the right action or just log and continue
|
||||
ui.Mutex.RLock()
|
||||
stillOnline := ui.Online
|
||||
ui.Mutex.RUnlock()
|
||||
if stillOnline {
|
||||
debug.Log(debug.DEBUG_ERROR, "Error reading from UDP interface", "name", ui.Name, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if ui.packetCallback != nil {
|
||||
ui.packetCallback(buffer[:n], ui)
|
||||
|
||||
ui.Mutex.Lock()
|
||||
// #nosec G115 - Network read sizes are always positive and within safe range
|
||||
ui.RxBytes += uint64(n)
|
||||
|
||||
// Auto-discover target address from first packet if not set
|
||||
if ui.targetAddr == nil {
|
||||
debug.Log(debug.DEBUG_ALL, "UDP interface discovered peer", "name", ui.Name, "peer", remoteAddr.String())
|
||||
ui.targetAddr = remoteAddr
|
||||
}
|
||||
callback := ui.packetCallback
|
||||
ui.Mutex.Unlock()
|
||||
|
||||
if callback != nil {
|
||||
callback(buffer[:n], ui)
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
func (ui *UDPInterface) IsEnabled() bool {
|
||||
ui.mutex.RLock()
|
||||
defer ui.mutex.RUnlock()
|
||||
return ui.Enabled && ui.Online && !ui.Detached
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package interfaces
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
)
|
||||
|
||||
func TestNewUDPInterface(t *testing.T) {
|
||||
@@ -25,11 +25,6 @@ func TestNewUDPInterface(t *testing.T) {
|
||||
if ui.GetType() != common.IF_TYPE_UDP {
|
||||
t.Errorf("GetType() = %v; want %v", ui.GetType(), common.IF_TYPE_UDP)
|
||||
}
|
||||
if ui.addr.String() != validAddr && ui.addr.Port == 0 { // Check if address resolved, port 0 is special
|
||||
// Allow OS-assigned port if 0 was specified
|
||||
} else if ui.addr.String() != validAddr {
|
||||
// t.Errorf("Resolved addr = %s; want %s", ui.addr.String(), validAddr) //This check is flaky with port 0
|
||||
}
|
||||
if ui.targetAddr.String() != validTarget {
|
||||
t.Errorf("Resolved targetAddr = %s; want %s", ui.targetAddr.String(), validTarget)
|
||||
}
|
||||
@@ -71,7 +66,6 @@ func TestNewUDPInterface(t *testing.T) {
|
||||
|
||||
func TestUDPInterfaceState(t *testing.T) {
|
||||
// Basic state tests are covered by BaseInterface tests
|
||||
// Add specific UDP ones if needed, e.g., involving the conn
|
||||
addr := "127.0.0.1:0"
|
||||
ui, _ := NewUDPInterface("udpState", addr, "", true)
|
||||
|
||||
|
||||
69
pkg/interfaces/udp_tinygo.go
Normal file
69
pkg/interfaces/udp_tinygo.go
Normal file
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build tinygo
|
||||
// +build tinygo
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
)
|
||||
|
||||
type UDPInterface struct {
|
||||
BaseInterface
|
||||
conn net.Conn
|
||||
addr *net.UDPAddr
|
||||
targetAddr *net.UDPAddr
|
||||
readBuffer []byte
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
func NewUDPInterface(name string, addr string, target string, enabled bool) (*UDPInterface, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var targetAddr *net.UDPAddr
|
||||
if target != "" {
|
||||
targetAddr, err = net.ResolveUDPAddr("udp", target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
ui := &UDPInterface{
|
||||
BaseInterface: NewBaseInterface(name, common.IF_TYPE_UDP, enabled),
|
||||
addr: udpAddr,
|
||||
targetAddr: targetAddr,
|
||||
readBuffer: make([]byte, common.NUM_1064),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
ui.MTU = common.NUM_1064
|
||||
|
||||
return ui, nil
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) Start() error {
|
||||
// TinyGo doesn't support UDP servers, only clients
|
||||
return fmt.Errorf("UDPInterface not supported in TinyGo - UDP server functionality requires net.ListenUDP")
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) Send(data []byte, addr string) error {
|
||||
// TinyGo doesn't support UDP sending
|
||||
return fmt.Errorf("UDPInterface Send not supported in TinyGo - requires UDP client functionality")
|
||||
}
|
||||
|
||||
func (ui *UDPInterface) Stop() error {
|
||||
ui.Mutex.Lock()
|
||||
defer ui.Mutex.Unlock()
|
||||
ui.Online = false
|
||||
return nil
|
||||
}
|
||||
|
||||
714
pkg/interfaces/websocket_native.go
Normal file
714
pkg/interfaces/websocket_native.go
Normal file
@@ -0,0 +1,714 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build !js
|
||||
// +build !js
|
||||
|
||||
// WebSocketInterface is a native implementation of the WebSocket interface.
|
||||
// It is used to connect to the WebSocket server and send/receive data.
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
// bearer:disable go_gosec_blocklist_sha1
|
||||
"crypto/sha1" // #nosec G505
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
)
|
||||
|
||||
const (
|
||||
wsGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
WS_BUFFER_SIZE = 4096
|
||||
WS_MTU = 1064
|
||||
WS_BITRATE = 10000000
|
||||
WS_HTTPS_PORT = 443
|
||||
WS_HTTP_PORT = 80
|
||||
WS_VERSION = "13"
|
||||
WS_CONNECT_TIMEOUT = 10 * time.Second
|
||||
WS_RECONNECT_DELAY = 2 * time.Second
|
||||
WS_KEY_SIZE = 16
|
||||
WS_MASK_KEY_SIZE = 4
|
||||
WS_HEADER_SIZE = 2
|
||||
WS_PAYLOAD_LEN_16BIT = 126
|
||||
WS_PAYLOAD_LEN_64BIT = 127
|
||||
WS_MAX_PAYLOAD_16BIT = 65536
|
||||
WS_FRAME_HEADER_FIN = 0x80
|
||||
WS_FRAME_HEADER_OPCODE = 0x0F
|
||||
WS_FRAME_HEADER_MASKED = 0x80
|
||||
WS_FRAME_HEADER_LEN = 0x7F
|
||||
WS_OPCODE_CONTINUATION = 0x00
|
||||
WS_OPCODE_TEXT = 0x01
|
||||
WS_OPCODE_BINARY = 0x02
|
||||
WS_OPCODE_CLOSE = 0x08
|
||||
WS_OPCODE_PING = 0x09
|
||||
WS_OPCODE_PONG = 0x0A
|
||||
)
|
||||
|
||||
type WebSocketInterface struct {
|
||||
BaseInterface
|
||||
wsURL string
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
connected bool
|
||||
messageQueue [][]byte
|
||||
readBuffer []byte
|
||||
writeBuffer []byte
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
func NewWebSocketInterface(name string, wsURL string, enabled bool) (*WebSocketInterface, error) {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "NewWebSocketInterface called", "name", name, "url", wsURL, "enabled", enabled)
|
||||
ws := &WebSocketInterface{
|
||||
BaseInterface: NewBaseInterface(name, common.IF_TYPE_UDP, enabled),
|
||||
wsURL: wsURL,
|
||||
messageQueue: make([][]byte, 0),
|
||||
readBuffer: make([]byte, WS_BUFFER_SIZE),
|
||||
writeBuffer: make([]byte, WS_BUFFER_SIZE),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
ws.MTU = WS_MTU
|
||||
ws.Bitrate = WS_BITRATE
|
||||
|
||||
debug.Log(debug.DEBUG_VERBOSE, "WebSocket interface initialized", "name", name, "mtu", ws.MTU, "bitrate", ws.Bitrate)
|
||||
return ws, nil
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) GetName() string {
|
||||
return wsi.Name
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) GetType() common.InterfaceType {
|
||||
return wsi.Type
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) GetMode() common.InterfaceMode {
|
||||
return wsi.Mode
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) IsOnline() bool {
|
||||
wsi.Mutex.RLock()
|
||||
defer wsi.Mutex.RUnlock()
|
||||
return wsi.Online && wsi.connected
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) IsDetached() bool {
|
||||
wsi.Mutex.RLock()
|
||||
defer wsi.Mutex.RUnlock()
|
||||
return wsi.Detached
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) Detach() {
|
||||
wsi.Mutex.Lock()
|
||||
defer wsi.Mutex.Unlock()
|
||||
wsi.Detached = true
|
||||
wsi.Online = false
|
||||
wsi.closeWebSocketLocked()
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) Enable() {
|
||||
wsi.Mutex.Lock()
|
||||
defer wsi.Mutex.Unlock()
|
||||
wsi.Enabled = true
|
||||
wsi.Online = true
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) Disable() {
|
||||
wsi.Mutex.Lock()
|
||||
defer wsi.Mutex.Unlock()
|
||||
wsi.Enabled = false
|
||||
wsi.closeWebSocketLocked()
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) Start() error {
|
||||
wsi.Mutex.Lock()
|
||||
if !wsi.Enabled || wsi.Detached {
|
||||
wsi.Mutex.Unlock()
|
||||
debug.Log(debug.DEBUG_INFO, "WebSocket interface not enabled or detached", "name", wsi.Name)
|
||||
return fmt.Errorf("interface not enabled or detached")
|
||||
}
|
||||
if wsi.conn != nil {
|
||||
wsi.Mutex.Unlock()
|
||||
debug.Log(debug.DEBUG_INFO, "WebSocket already started", "name", wsi.Name)
|
||||
return fmt.Errorf("WebSocket already started")
|
||||
}
|
||||
// Only recreate done if it's nil or was closed
|
||||
select {
|
||||
case <-wsi.done:
|
||||
wsi.done = make(chan struct{})
|
||||
wsi.stopOnce = sync.Once{}
|
||||
default:
|
||||
if wsi.done == nil {
|
||||
wsi.done = make(chan struct{})
|
||||
wsi.stopOnce = sync.Once{}
|
||||
}
|
||||
}
|
||||
wsi.Mutex.Unlock()
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "Starting WebSocket connection", "name", wsi.Name, "url", wsi.wsURL)
|
||||
|
||||
u, err := url.Parse(wsi.wsURL)
|
||||
if err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Invalid WebSocket URL", "name", wsi.Name, "url", wsi.wsURL, "error", err)
|
||||
return fmt.Errorf("invalid WebSocket URL: %v", err)
|
||||
}
|
||||
|
||||
var conn net.Conn
|
||||
var host string
|
||||
|
||||
if u.Scheme == "wss" {
|
||||
host = u.Host
|
||||
if !strings.Contains(host, ":") {
|
||||
host += fmt.Sprintf(":%d", WS_HTTPS_PORT)
|
||||
}
|
||||
tcpConn, err := net.DialTimeout("tcp", host, WS_CONNECT_TIMEOUT)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect: %v", err)
|
||||
}
|
||||
tlsConn := tls.Client(tcpConn, &tls.Config{
|
||||
ServerName: u.Hostname(),
|
||||
InsecureSkipVerify: false,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
_ = tcpConn.Close()
|
||||
debug.Log(debug.DEBUG_ERROR, "TLS handshake failed", "name", wsi.Name, "host", host, "error", err)
|
||||
return fmt.Errorf("TLS handshake failed: %v", err)
|
||||
}
|
||||
conn = tlsConn
|
||||
} else if u.Scheme == "ws" {
|
||||
host = u.Host
|
||||
if !strings.Contains(host, ":") {
|
||||
host += fmt.Sprintf(":%d", WS_HTTP_PORT)
|
||||
}
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Connecting to WebSocket server", "name", wsi.Name, "host", host)
|
||||
tcpConn, err := net.DialTimeout("tcp", host, WS_CONNECT_TIMEOUT)
|
||||
if err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to connect to WebSocket server", "name", wsi.Name, "host", host, "error", err)
|
||||
return fmt.Errorf("failed to connect: %v", err)
|
||||
}
|
||||
conn = tcpConn
|
||||
} else {
|
||||
debug.Log(debug.DEBUG_ERROR, "Unsupported WebSocket scheme", "name", wsi.Name, "scheme", u.Scheme)
|
||||
return fmt.Errorf("unsupported scheme: %s (use ws:// or wss://)", u.Scheme)
|
||||
}
|
||||
|
||||
key, err := generateWebSocketKey()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
path := u.Path
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
if u.RawQuery != "" {
|
||||
path += "?" + u.RawQuery
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", path, nil)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
req.Host = u.Host
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Sec-WebSocket-Key", key)
|
||||
req.Header.Set("Sec-WebSocket-Version", WS_VERSION)
|
||||
req.Header.Set("User-Agent", "Reticulum-Go/1.0")
|
||||
|
||||
if err := req.Write(conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to send handshake: %v", err)
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("failed to read handshake response: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
_ = conn.Close()
|
||||
debug.Log(debug.DEBUG_ERROR, "WebSocket handshake failed", "name", wsi.Name, "status", resp.StatusCode)
|
||||
return fmt.Errorf("handshake failed: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("invalid upgrade header")
|
||||
}
|
||||
|
||||
accept := resp.Header.Get("Sec-WebSocket-Accept")
|
||||
expectedAccept := computeAcceptKey(key)
|
||||
if accept != expectedAccept {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("invalid accept key")
|
||||
}
|
||||
|
||||
wsi.Mutex.Lock()
|
||||
wsi.conn = conn
|
||||
wsi.reader = bufio.NewReader(conn)
|
||||
wsi.connected = true
|
||||
wsi.Online = true
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "WebSocket connected", "name", wsi.Name, "url", wsi.wsURL)
|
||||
|
||||
queue := make([][]byte, len(wsi.messageQueue))
|
||||
copy(queue, wsi.messageQueue)
|
||||
wsi.messageQueue = wsi.messageQueue[:0]
|
||||
wsi.Mutex.Unlock() // Unlock after copying queue, before I/O
|
||||
|
||||
for _, msg := range queue {
|
||||
_ = wsi.sendWebSocketMessage(msg)
|
||||
}
|
||||
|
||||
go wsi.readLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) Stop() error {
|
||||
wsi.Mutex.Lock()
|
||||
defer wsi.Mutex.Unlock()
|
||||
|
||||
wsi.Enabled = false
|
||||
wsi.Online = false
|
||||
|
||||
wsi.stopOnce.Do(func() {
|
||||
if wsi.done != nil {
|
||||
close(wsi.done)
|
||||
}
|
||||
})
|
||||
|
||||
wsi.closeWebSocketLocked()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) closeWebSocket() {
|
||||
wsi.Mutex.Lock()
|
||||
defer wsi.Mutex.Unlock()
|
||||
wsi.closeWebSocketLocked()
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) closeWebSocketLocked() {
|
||||
if wsi.conn != nil {
|
||||
wsi.sendCloseFrameLocked()
|
||||
_ = wsi.conn.Close()
|
||||
wsi.conn = nil
|
||||
wsi.reader = nil
|
||||
}
|
||||
wsi.connected = false
|
||||
wsi.Online = false
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) readLoop() {
|
||||
for {
|
||||
wsi.Mutex.RLock()
|
||||
conn := wsi.conn
|
||||
reader := wsi.reader
|
||||
done := wsi.done
|
||||
wsi.Mutex.RUnlock()
|
||||
|
||||
if conn == nil || reader == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
data, err := wsi.readFrame()
|
||||
if err != nil {
|
||||
wsi.Mutex.Lock()
|
||||
wsi.connected = false
|
||||
wsi.Online = false
|
||||
if wsi.conn != nil {
|
||||
_ = wsi.conn.Close()
|
||||
wsi.conn = nil
|
||||
wsi.reader = nil
|
||||
}
|
||||
wsi.Mutex.Unlock()
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "WebSocket closed", "name", wsi.Name, "error", err)
|
||||
|
||||
time.Sleep(WS_RECONNECT_DELAY)
|
||||
|
||||
wsi.Mutex.RLock()
|
||||
stillEnabled := wsi.Enabled && !wsi.Detached
|
||||
wsi.Mutex.RUnlock()
|
||||
|
||||
if stillEnabled {
|
||||
go wsi.Start()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(data) > 0 {
|
||||
wsi.Mutex.Lock()
|
||||
wsi.RxBytes += uint64(len(data))
|
||||
wsi.Mutex.Unlock()
|
||||
|
||||
wsi.ProcessIncoming(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) readFrame() ([]byte, error) {
|
||||
wsi.Mutex.RLock()
|
||||
reader := wsi.reader
|
||||
wsi.Mutex.RUnlock()
|
||||
|
||||
if reader == nil {
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
header := make([]byte, WS_HEADER_SIZE)
|
||||
if _, err := io.ReadFull(reader, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fin := (header[0] & WS_FRAME_HEADER_FIN) != 0
|
||||
opcode := header[0] & WS_FRAME_HEADER_OPCODE
|
||||
masked := (header[1] & WS_FRAME_HEADER_MASKED) != 0
|
||||
payloadLen := int(header[1] & WS_FRAME_HEADER_LEN)
|
||||
|
||||
if opcode == WS_OPCODE_CLOSE {
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
if opcode == WS_OPCODE_PING {
|
||||
return wsi.handlePingFrame(reader, payloadLen, masked)
|
||||
}
|
||||
|
||||
if opcode == WS_OPCODE_PONG {
|
||||
return wsi.handlePongFrame(reader, payloadLen, masked)
|
||||
}
|
||||
|
||||
if opcode != WS_OPCODE_BINARY {
|
||||
return nil, fmt.Errorf("unsupported opcode: %d", opcode)
|
||||
}
|
||||
|
||||
if payloadLen == WS_PAYLOAD_LEN_16BIT {
|
||||
lenBytes := make([]byte, 2)
|
||||
if _, err := io.ReadFull(reader, lenBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadLen = int(binary.BigEndian.Uint16(lenBytes))
|
||||
} else if payloadLen == WS_PAYLOAD_LEN_64BIT {
|
||||
lenBytes := make([]byte, 8)
|
||||
if _, err := io.ReadFull(reader, lenBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
val := binary.BigEndian.Uint64(lenBytes)
|
||||
if val > uint64(math.MaxInt) {
|
||||
return nil, fmt.Errorf("payload length exceeds maximum integer value")
|
||||
}
|
||||
payloadLen = int(val) // #nosec G115
|
||||
}
|
||||
|
||||
maskKey := make([]byte, WS_MASK_KEY_SIZE)
|
||||
if masked {
|
||||
if _, err := io.ReadFull(reader, maskKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
payload := make([]byte, payloadLen)
|
||||
if _, err := io.ReadFull(reader, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if masked {
|
||||
for i := 0; i < payloadLen; i++ {
|
||||
payload[i] ^= maskKey[i%WS_MASK_KEY_SIZE]
|
||||
}
|
||||
}
|
||||
|
||||
if !fin {
|
||||
nextFrame, err := wsi.readFrame()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(payload, nextFrame...), nil
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) Send(data []byte, addr string) error {
|
||||
wsi.Mutex.RLock()
|
||||
enabled := wsi.Enabled
|
||||
detached := wsi.Detached
|
||||
connected := wsi.connected
|
||||
wsi.Mutex.RUnlock()
|
||||
|
||||
if !enabled || detached {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "WebSocket interface not enabled or detached, dropping packet", "name", wsi.Name, "bytes", len(data))
|
||||
return fmt.Errorf("interface not enabled")
|
||||
}
|
||||
|
||||
wsi.Mutex.Lock()
|
||||
wsi.TxBytes += uint64(len(data))
|
||||
wsi.Mutex.Unlock()
|
||||
|
||||
if !connected {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "WebSocket not connected, queuing packet", "name", wsi.Name, "bytes", len(data), "queue_size", len(wsi.messageQueue))
|
||||
wsi.Mutex.Lock()
|
||||
wsi.messageQueue = append(wsi.messageQueue, data)
|
||||
wsi.Mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
packetType := "unknown"
|
||||
if len(data) > 0 {
|
||||
switch data[0] {
|
||||
case 0x01:
|
||||
packetType = "announce"
|
||||
case 0x02:
|
||||
packetType = "link"
|
||||
default:
|
||||
packetType = fmt.Sprintf("0x%02x", data[0])
|
||||
}
|
||||
}
|
||||
debug.Log(debug.DEBUG_INFO, "Sending packet over WebSocket", "name", wsi.Name, "bytes", len(data), "packet_type", packetType)
|
||||
return wsi.sendWebSocketMessage(data)
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) sendWebSocketMessage(data []byte) error {
|
||||
wsi.Mutex.RLock()
|
||||
conn := wsi.conn
|
||||
wsi.Mutex.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return fmt.Errorf("WebSocket not initialized")
|
||||
}
|
||||
|
||||
frame := wsi.createFrame(data, WS_OPCODE_BINARY, true)
|
||||
wsi.Mutex.Lock()
|
||||
_, err := conn.Write(frame)
|
||||
wsi.Mutex.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send: %v", err)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "WebSocket sent packet successfully", "name", wsi.Name, "bytes", len(data), "frame_bytes", len(frame))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) sendCloseFrame() {
|
||||
wsi.Mutex.RLock()
|
||||
defer wsi.Mutex.RUnlock()
|
||||
wsi.sendCloseFrameLocked()
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) sendCloseFrameLocked() {
|
||||
conn := wsi.conn
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
frame := wsi.createFrame(nil, WS_OPCODE_CLOSE, true)
|
||||
_, _ = conn.Write(frame)
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) handlePingFrame(reader *bufio.Reader, payloadLen int, masked bool) ([]byte, error) {
|
||||
if payloadLen == WS_PAYLOAD_LEN_16BIT {
|
||||
lenBytes := make([]byte, 2)
|
||||
if _, err := io.ReadFull(reader, lenBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadLen = int(binary.BigEndian.Uint16(lenBytes))
|
||||
} else if payloadLen == WS_PAYLOAD_LEN_64BIT {
|
||||
lenBytes := make([]byte, 8)
|
||||
if _, err := io.ReadFull(reader, lenBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
val := binary.BigEndian.Uint64(lenBytes)
|
||||
if val > uint64(math.MaxInt) {
|
||||
return nil, fmt.Errorf("payload length exceeds maximum integer value")
|
||||
}
|
||||
payloadLen = int(val) // #nosec G115
|
||||
}
|
||||
|
||||
maskKey := make([]byte, WS_MASK_KEY_SIZE)
|
||||
if masked {
|
||||
if _, err := io.ReadFull(reader, maskKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
payload := make([]byte, payloadLen)
|
||||
if payloadLen > 0 {
|
||||
if _, err := io.ReadFull(reader, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if masked {
|
||||
for i := 0; i < payloadLen; i++ {
|
||||
payload[i] ^= maskKey[i%WS_MASK_KEY_SIZE]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wsi.sendPongFrame(payload)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) handlePongFrame(reader *bufio.Reader, payloadLen int, masked bool) ([]byte, error) {
|
||||
if payloadLen == WS_PAYLOAD_LEN_16BIT {
|
||||
lenBytes := make([]byte, 2)
|
||||
if _, err := io.ReadFull(reader, lenBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadLen = int(binary.BigEndian.Uint16(lenBytes))
|
||||
} else if payloadLen == WS_PAYLOAD_LEN_64BIT {
|
||||
lenBytes := make([]byte, 8)
|
||||
if _, err := io.ReadFull(reader, lenBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
val := binary.BigEndian.Uint64(lenBytes)
|
||||
if val > uint64(math.MaxInt) {
|
||||
return nil, fmt.Errorf("payload length exceeds maximum integer value")
|
||||
}
|
||||
payloadLen = int(val) // #nosec G115
|
||||
}
|
||||
|
||||
maskKey := make([]byte, WS_MASK_KEY_SIZE)
|
||||
if masked {
|
||||
if _, err := io.ReadFull(reader, maskKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if payloadLen > 0 {
|
||||
payload := make([]byte, payloadLen)
|
||||
if _, err := io.ReadFull(reader, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) sendPongFrame(data []byte) {
|
||||
wsi.Mutex.RLock()
|
||||
conn := wsi.conn
|
||||
wsi.Mutex.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
frame := wsi.createFrame(data, WS_OPCODE_PONG, true)
|
||||
wsi.Mutex.Lock()
|
||||
_, _ = conn.Write(frame)
|
||||
wsi.Mutex.Unlock()
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) createFrame(data []byte, opcode byte, fin bool) []byte {
|
||||
payloadLen := len(data)
|
||||
frame := make([]byte, WS_HEADER_SIZE)
|
||||
|
||||
if fin {
|
||||
frame[0] |= WS_FRAME_HEADER_FIN
|
||||
}
|
||||
frame[0] |= opcode
|
||||
|
||||
if payloadLen < WS_PAYLOAD_LEN_16BIT {
|
||||
frame[1] = byte(payloadLen)
|
||||
frame = append(frame, data...)
|
||||
} else if payloadLen < WS_MAX_PAYLOAD_16BIT {
|
||||
frame[1] = WS_PAYLOAD_LEN_16BIT // #nosec G602
|
||||
lenBytes := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(lenBytes, uint16(payloadLen)) // #nosec G115
|
||||
frame = append(frame, lenBytes...)
|
||||
frame = append(frame, data...)
|
||||
} else {
|
||||
frame[1] = WS_PAYLOAD_LEN_64BIT // #nosec G602
|
||||
lenBytes := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(lenBytes, uint64(payloadLen)) // #nosec G115
|
||||
frame = append(frame, lenBytes...)
|
||||
frame = append(frame, data...)
|
||||
}
|
||||
|
||||
return frame
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) ProcessOutgoing(data []byte) error {
|
||||
return wsi.Send(data, "")
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) GetConn() net.Conn {
|
||||
wsi.Mutex.RLock()
|
||||
defer wsi.Mutex.RUnlock()
|
||||
return wsi.conn
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) GetMTU() int {
|
||||
return wsi.MTU
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) IsEnabled() bool {
|
||||
wsi.Mutex.RLock()
|
||||
defer wsi.Mutex.RUnlock()
|
||||
return wsi.Enabled && wsi.Online && !wsi.Detached
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) SendPathRequest(packet []byte) error {
|
||||
return wsi.Send(packet, "")
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) SendLinkPacket(dest []byte, data []byte, timestamp time.Time) error {
|
||||
frame := make([]byte, 0, len(dest)+len(data)+9)
|
||||
frame = append(frame, WS_OPCODE_BINARY)
|
||||
frame = append(frame, dest...)
|
||||
ts := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(ts, uint64(timestamp.Unix())) // #nosec G115
|
||||
frame = append(frame, ts...)
|
||||
frame = append(frame, data...)
|
||||
return wsi.Send(frame, "")
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) GetBandwidthAvailable() bool {
|
||||
return wsi.BaseInterface.GetBandwidthAvailable()
|
||||
}
|
||||
|
||||
func generateWebSocketKey() (string, error) {
|
||||
key := make([]byte, WS_KEY_SIZE)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
func computeAcceptKey(key string) string {
|
||||
// bearer:disable go_gosec_crypto_weak_crypto
|
||||
h := sha1.New() // #nosec G401
|
||||
h.Write([]byte(key))
|
||||
h.Write([]byte(wsGUID))
|
||||
// bearer:disable go_lang_weak_hash_sha1
|
||||
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
280
pkg/interfaces/websocket_native_test.go
Normal file
280
pkg/interfaces/websocket_native_test.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
)
|
||||
|
||||
func TestWebSocketGUID(t *testing.T) {
|
||||
if wsGUID != "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" {
|
||||
t.Errorf("wsGUID mismatch: expected RFC 6455 GUID, got %s", wsGUID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateWebSocketKey(t *testing.T) {
|
||||
key1, err := generateWebSocketKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
key2, err := generateWebSocketKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
if key1 == key2 {
|
||||
t.Error("Generated keys should be unique")
|
||||
}
|
||||
|
||||
if len(key1) != 24 {
|
||||
t.Errorf("Expected base64-encoded key length 24, got %d", len(key1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeAcceptKey(t *testing.T) {
|
||||
testKey := "dGhlIHNhbXBsZSBub25jZQ=="
|
||||
expectedAccept := "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
|
||||
|
||||
accept := computeAcceptKey(testKey)
|
||||
if accept != expectedAccept {
|
||||
t.Errorf("Accept key mismatch: expected %s, got %s", expectedAccept, accept)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSocketInterface(t *testing.T) {
|
||||
ws, err := NewWebSocketInterface("test", "wss://socket.quad4.io/ws", true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WebSocket interface: %v", err)
|
||||
}
|
||||
|
||||
if ws.GetName() != "test" {
|
||||
t.Errorf("Expected name 'test', got %s", ws.GetName())
|
||||
}
|
||||
|
||||
if ws.GetType() != common.IF_TYPE_UDP {
|
||||
t.Errorf("Expected type IF_TYPE_UDP, got %v", ws.GetType())
|
||||
}
|
||||
|
||||
if ws.GetMTU() != 1064 {
|
||||
t.Errorf("Expected MTU 1064, got %d", ws.GetMTU())
|
||||
}
|
||||
|
||||
if ws.IsOnline() {
|
||||
t.Error("Interface should not be online before Start()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketConnection(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping network test in short mode")
|
||||
}
|
||||
|
||||
ws, err := NewWebSocketInterface("test", "wss://socket.quad4.io/ws", true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WebSocket interface: %v", err)
|
||||
}
|
||||
|
||||
ws.SetPacketCallback(func(data []byte, ni common.NetworkInterface) {
|
||||
t.Logf("Received packet: %d bytes", len(data))
|
||||
})
|
||||
|
||||
err = ws.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start WebSocket: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
if !ws.IsOnline() {
|
||||
t.Error("WebSocket should be online after Start()")
|
||||
}
|
||||
|
||||
testData := []byte{0x01, 0x02, 0x03, 0x04}
|
||||
err = ws.Send(testData, "")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to send data: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if err := ws.Stop(); err != nil {
|
||||
t.Errorf("Failed to stop WebSocket: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
if ws.IsOnline() {
|
||||
t.Error("WebSocket should be offline after Stop()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketReconnection(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping network test in short mode")
|
||||
}
|
||||
|
||||
ws, err := NewWebSocketInterface("test", "wss://socket.quad4.io/ws", true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WebSocket interface: %v", err)
|
||||
}
|
||||
|
||||
err = ws.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start WebSocket: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if !ws.IsOnline() {
|
||||
t.Error("WebSocket should be online")
|
||||
}
|
||||
|
||||
conn := ws.GetConn()
|
||||
if conn == nil {
|
||||
t.Error("GetConn() should return a connection")
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
if ws.IsOnline() {
|
||||
t.Log("WebSocket reconnected successfully")
|
||||
}
|
||||
|
||||
if err := ws.Stop(); err != nil {
|
||||
t.Errorf("Failed to stop WebSocket: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestWebSocketMessageQueue(t *testing.T) {
|
||||
ws, err := NewWebSocketInterface("test", "wss://socket.quad4.io/ws", true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WebSocket interface: %v", err)
|
||||
}
|
||||
|
||||
ws.Enable()
|
||||
|
||||
testData := []byte{0x01, 0x02, 0x03}
|
||||
err = ws.Send(testData, "")
|
||||
if err != nil {
|
||||
t.Errorf("Send should queue message when offline, got error: %v", err)
|
||||
}
|
||||
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
|
||||
err = ws.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start WebSocket: %v", err)
|
||||
}
|
||||
|
||||
// Wait for interface to be online (up to 10 seconds)
|
||||
for i := 0; i < 100; i++ {
|
||||
if ws.IsOnline() {
|
||||
break
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
if !ws.IsOnline() {
|
||||
t.Error("WebSocket should be online")
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
if err := ws.Stop(); err != nil {
|
||||
t.Errorf("Failed to stop WebSocket: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestWebSocketFrameEncoding(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping frame encoding test in short mode")
|
||||
}
|
||||
|
||||
ws, err := NewWebSocketInterface("test", "wss://socket.quad4.io/ws", true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WebSocket interface: %v", err)
|
||||
}
|
||||
|
||||
err = ws.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start WebSocket: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
data []byte
|
||||
}{
|
||||
{"small frame", []byte{0x01, 0x02, 0x03}},
|
||||
{"medium frame", make([]byte, 200)},
|
||||
{"large frame", make([]byte, 1000)},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ws.Send(tc.data, "")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to send %s: %v", tc.name, err)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
})
|
||||
}
|
||||
|
||||
if err := ws.Stop(); err != nil {
|
||||
t.Errorf("Failed to stop WebSocket: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestWebSocketEnableDisable(t *testing.T) {
|
||||
ws, err := NewWebSocketInterface("test", "wss://socket.quad4.io/ws", false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WebSocket interface: %v", err)
|
||||
}
|
||||
|
||||
if ws.IsEnabled() {
|
||||
t.Error("Interface should not be enabled initially")
|
||||
}
|
||||
|
||||
ws.Enable()
|
||||
if !ws.IsEnabled() {
|
||||
t.Error("Interface should be enabled after Enable()")
|
||||
}
|
||||
|
||||
ws.Disable()
|
||||
if ws.IsEnabled() {
|
||||
t.Error("Interface should not be enabled after Disable()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketDetach(t *testing.T) {
|
||||
ws, err := NewWebSocketInterface("test", "wss://socket.quad4.io/ws", true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create WebSocket interface: %v", err)
|
||||
}
|
||||
|
||||
if ws.IsDetached() {
|
||||
t.Error("Interface should not be detached initially")
|
||||
}
|
||||
|
||||
ws.Detach()
|
||||
if !ws.IsDetached() {
|
||||
t.Error("Interface should be detached after Detach()")
|
||||
}
|
||||
|
||||
if ws.IsOnline() {
|
||||
t.Error("Interface should be offline after Detach()")
|
||||
}
|
||||
}
|
||||
253
pkg/interfaces/websocket_wasm.go
Normal file
253
pkg/interfaces/websocket_wasm.go
Normal file
@@ -0,0 +1,253 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build js && wasm
|
||||
// +build js,wasm
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
)
|
||||
|
||||
const (
|
||||
WS_MTU = 1064
|
||||
WS_BITRATE = 10000000
|
||||
WS_RECONNECT_DELAY = 2 * time.Second
|
||||
)
|
||||
|
||||
type WebSocketInterface struct {
|
||||
BaseInterface
|
||||
wsURL string
|
||||
ws js.Value
|
||||
connected bool
|
||||
messageQueue [][]byte
|
||||
}
|
||||
|
||||
func NewWebSocketInterface(name string, wsURL string, enabled bool) (*WebSocketInterface, error) {
|
||||
ws := &WebSocketInterface{
|
||||
BaseInterface: NewBaseInterface(name, common.IF_TYPE_UDP, enabled),
|
||||
wsURL: wsURL,
|
||||
messageQueue: make([][]byte, 0),
|
||||
}
|
||||
|
||||
ws.MTU = WS_MTU
|
||||
ws.Bitrate = WS_BITRATE
|
||||
|
||||
return ws, nil
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) GetName() string {
|
||||
return wsi.Name
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) GetType() common.InterfaceType {
|
||||
return wsi.Type
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) GetMode() common.InterfaceMode {
|
||||
return wsi.Mode
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) IsOnline() bool {
|
||||
wsi.Mutex.RLock()
|
||||
defer wsi.Mutex.RUnlock()
|
||||
return wsi.Online && wsi.connected
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) IsDetached() bool {
|
||||
wsi.Mutex.RLock()
|
||||
defer wsi.Mutex.RUnlock()
|
||||
return wsi.Detached
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) Detach() {
|
||||
wsi.Mutex.Lock()
|
||||
defer wsi.Mutex.Unlock()
|
||||
wsi.Detached = true
|
||||
wsi.Online = false
|
||||
wsi.closeWebSocket()
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) Enable() {
|
||||
wsi.Mutex.Lock()
|
||||
defer wsi.Mutex.Unlock()
|
||||
wsi.Enabled = true
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) Disable() {
|
||||
wsi.Mutex.Lock()
|
||||
defer wsi.Mutex.Unlock()
|
||||
wsi.Enabled = false
|
||||
wsi.closeWebSocket()
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) Start() error {
|
||||
wsi.Mutex.Lock()
|
||||
defer wsi.Mutex.Unlock()
|
||||
|
||||
if wsi.ws.Truthy() {
|
||||
return fmt.Errorf("WebSocket already started")
|
||||
}
|
||||
|
||||
ws := js.Global().Get("WebSocket").New(wsi.wsURL)
|
||||
ws.Set("binaryType", "arraybuffer")
|
||||
|
||||
ws.Set("onopen", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
wsi.Mutex.Lock()
|
||||
wsi.connected = true
|
||||
wsi.Online = true
|
||||
wsi.Mutex.Unlock()
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "WebSocket connected", "name", wsi.Name, "url", wsi.wsURL)
|
||||
|
||||
wsi.Mutex.Lock()
|
||||
queue := make([][]byte, len(wsi.messageQueue))
|
||||
copy(queue, wsi.messageQueue)
|
||||
wsi.messageQueue = wsi.messageQueue[:0]
|
||||
wsi.Mutex.Unlock()
|
||||
|
||||
for _, msg := range queue {
|
||||
wsi.sendWebSocketMessage(msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onmessage", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
event := args[0]
|
||||
data := event.Get("data")
|
||||
|
||||
var packet []byte
|
||||
if data.Type() == js.TypeString {
|
||||
packet = []byte(data.String())
|
||||
} else if data.Type() == js.TypeObject {
|
||||
array := js.Global().Get("Uint8Array").New(data)
|
||||
length := array.Get("length").Int()
|
||||
packet = make([]byte, length)
|
||||
js.CopyBytesToGo(packet, array)
|
||||
} else {
|
||||
debug.Log(debug.DEBUG_ERROR, "Unknown WebSocket message type", "type", data.Type().String())
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(packet) < 1 {
|
||||
debug.Log(debug.DEBUG_ERROR, "WebSocket message empty")
|
||||
return nil
|
||||
}
|
||||
|
||||
wsi.Mutex.Lock()
|
||||
wsi.RxBytes += uint64(len(packet))
|
||||
wsi.Mutex.Unlock()
|
||||
|
||||
wsi.ProcessIncoming(packet)
|
||||
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onerror", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
debug.Log(debug.DEBUG_ERROR, "WebSocket error", "name", wsi.Name)
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onclose", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
wsi.Mutex.Lock()
|
||||
wsi.connected = false
|
||||
wsi.Online = false
|
||||
wsi.Mutex.Unlock()
|
||||
|
||||
debug.Log(debug.DEBUG_INFO, "WebSocket closed", "name", wsi.Name)
|
||||
|
||||
if wsi.Enabled && !wsi.Detached {
|
||||
time.Sleep(WS_RECONNECT_DELAY)
|
||||
go wsi.Start()
|
||||
}
|
||||
|
||||
return nil
|
||||
}))
|
||||
|
||||
wsi.ws = ws
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) Stop() error {
|
||||
wsi.Mutex.Lock()
|
||||
defer wsi.Mutex.Unlock()
|
||||
wsi.Enabled = false
|
||||
wsi.closeWebSocket()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) closeWebSocket() {
|
||||
if wsi.ws.Truthy() {
|
||||
wsi.ws.Call("close")
|
||||
wsi.ws = js.Value{}
|
||||
}
|
||||
wsi.connected = false
|
||||
wsi.Online = false
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) Send(data []byte, addr string) error {
|
||||
if !wsi.IsEnabled() {
|
||||
return fmt.Errorf("interface not enabled")
|
||||
}
|
||||
|
||||
wsi.Mutex.Lock()
|
||||
wsi.TxBytes += uint64(len(data))
|
||||
wsi.Mutex.Unlock()
|
||||
|
||||
if !wsi.connected {
|
||||
wsi.Mutex.Lock()
|
||||
wsi.messageQueue = append(wsi.messageQueue, data)
|
||||
wsi.Mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
return wsi.sendWebSocketMessage(data)
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) sendWebSocketMessage(data []byte) error {
|
||||
if !wsi.ws.Truthy() {
|
||||
return fmt.Errorf("WebSocket not initialized")
|
||||
}
|
||||
|
||||
if wsi.ws.Get("readyState").Int() != 1 {
|
||||
return fmt.Errorf("WebSocket not open")
|
||||
}
|
||||
|
||||
array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(array, data)
|
||||
|
||||
wsi.ws.Call("send", array)
|
||||
|
||||
debug.Log(debug.DEBUG_VERBOSE, "WebSocket sent packet", "name", wsi.Name, "bytes", len(data))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) ProcessOutgoing(data []byte) error {
|
||||
return wsi.Send(data, "")
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) GetConn() net.Conn {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) GetMTU() int {
|
||||
return wsi.MTU
|
||||
}
|
||||
|
||||
func (wsi *WebSocketInterface) IsEnabled() bool {
|
||||
wsi.Mutex.RLock()
|
||||
defer wsi.Mutex.RUnlock()
|
||||
return wsi.Enabled && wsi.Online && !wsi.Detached
|
||||
}
|
||||
364
pkg/link/establishment_test.go
Normal file
364
pkg/link/establishment_test.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package link
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/destination"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/identity"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/packet"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/transport"
|
||||
)
|
||||
|
||||
func TestEphemeralKeyGeneration(t *testing.T) {
|
||||
link := &Link{}
|
||||
|
||||
if err := link.generateEphemeralKeys(); err != nil {
|
||||
t.Fatalf("Failed to generate ephemeral keys: %v", err)
|
||||
}
|
||||
|
||||
if len(link.prv) != KEYSIZE {
|
||||
t.Errorf("Expected private key length %d, got %d", KEYSIZE, len(link.prv))
|
||||
}
|
||||
|
||||
if len(link.pub) != KEYSIZE {
|
||||
t.Errorf("Expected public key length %d, got %d", KEYSIZE, len(link.pub))
|
||||
}
|
||||
|
||||
if len(link.sigPriv) != 64 {
|
||||
t.Errorf("Expected signing private key length 64, got %d", len(link.sigPriv))
|
||||
}
|
||||
|
||||
if len(link.sigPub) != 32 {
|
||||
t.Errorf("Expected signing public key length 32, got %d", len(link.sigPub))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignallingBytes(t *testing.T) {
|
||||
mtu := 500
|
||||
mode := byte(MODE_AES256_CBC)
|
||||
|
||||
bytes := signallingBytes(mtu, mode)
|
||||
|
||||
if len(bytes) != LINK_MTU_SIZE {
|
||||
t.Errorf("Expected signalling bytes length %d, got %d", LINK_MTU_SIZE, len(bytes))
|
||||
}
|
||||
|
||||
extractedMTU := (int(bytes[0]&0x1F) << 16) | (int(bytes[1]) << 8) | int(bytes[2])
|
||||
if extractedMTU != mtu {
|
||||
t.Errorf("Expected MTU %d, got %d", mtu, extractedMTU)
|
||||
}
|
||||
|
||||
extractedMode := (bytes[0] & MODE_BYTEMASK) >> 5
|
||||
if extractedMode != mode {
|
||||
t.Errorf("Expected mode %d, got %d", mode, extractedMode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinkIDGeneration(t *testing.T) {
|
||||
responderIdent, err := identity.NewIdentity()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create responder identity: %v", err)
|
||||
}
|
||||
|
||||
cfg := &common.ReticulumConfig{}
|
||||
transportInstance := transport.NewTransport(cfg)
|
||||
|
||||
dest, err := destination.New(responderIdent, destination.IN, destination.SINGLE, "test", transportInstance, "link")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create destination: %v", err)
|
||||
}
|
||||
|
||||
link := &Link{
|
||||
destination: dest,
|
||||
transport: transportInstance,
|
||||
initiator: true,
|
||||
}
|
||||
|
||||
if err := link.generateEphemeralKeys(); err != nil {
|
||||
t.Fatalf("Failed to generate keys: %v", err)
|
||||
}
|
||||
|
||||
link.mode = MODE_DEFAULT
|
||||
link.mtu = 500
|
||||
|
||||
signalling := signallingBytes(link.mtu, link.mode)
|
||||
requestData := make([]byte, 0, ECPUBSIZE+LINK_MTU_SIZE)
|
||||
requestData = append(requestData, link.pub...)
|
||||
requestData = append(requestData, link.sigPub...)
|
||||
requestData = append(requestData, signalling...)
|
||||
|
||||
pkt := &packet.Packet{
|
||||
HeaderType: packet.HeaderType1,
|
||||
PacketType: packet.PacketTypeLinkReq,
|
||||
TransportType: 0,
|
||||
Context: packet.ContextNone,
|
||||
ContextFlag: packet.FlagUnset,
|
||||
Hops: 0,
|
||||
DestinationType: dest.GetType(),
|
||||
DestinationHash: dest.GetHash(),
|
||||
Data: requestData,
|
||||
}
|
||||
|
||||
if err := pkt.Pack(); err != nil {
|
||||
t.Fatalf("Failed to pack packet: %v", err)
|
||||
}
|
||||
|
||||
linkID := linkIDFromPacket(pkt)
|
||||
|
||||
if len(linkID) != 16 {
|
||||
t.Errorf("Expected link ID length 16, got %d", len(linkID))
|
||||
}
|
||||
|
||||
t.Logf("Generated link ID: %x", linkID)
|
||||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
link1 := &Link{}
|
||||
link2 := &Link{}
|
||||
|
||||
if err := link1.generateEphemeralKeys(); err != nil {
|
||||
t.Fatalf("Failed to generate keys for link1: %v", err)
|
||||
}
|
||||
|
||||
if err := link2.generateEphemeralKeys(); err != nil {
|
||||
t.Fatalf("Failed to generate keys for link2: %v", err)
|
||||
}
|
||||
|
||||
link1.peerPub = link2.pub
|
||||
link2.peerPub = link1.pub
|
||||
|
||||
link1.linkID = []byte("test-link-id-abc")
|
||||
link2.linkID = []byte("test-link-id-abc")
|
||||
|
||||
link1.mode = MODE_AES256_CBC
|
||||
link2.mode = MODE_AES256_CBC
|
||||
|
||||
if err := link1.performHandshake(); err != nil {
|
||||
t.Fatalf("Link1 handshake failed: %v", err)
|
||||
}
|
||||
|
||||
if err := link2.performHandshake(); err != nil {
|
||||
t.Fatalf("Link2 handshake failed: %v", err)
|
||||
}
|
||||
|
||||
if string(link1.sharedKey) != string(link2.sharedKey) {
|
||||
t.Error("Shared keys do not match")
|
||||
}
|
||||
|
||||
if string(link1.derivedKey) != string(link2.derivedKey) {
|
||||
t.Error("Derived keys do not match")
|
||||
}
|
||||
|
||||
if link1.status != STATUS_HANDSHAKE {
|
||||
t.Errorf("Expected link1 status HANDSHAKE, got %d", link1.status)
|
||||
}
|
||||
|
||||
if link2.status != STATUS_HANDSHAKE {
|
||||
t.Errorf("Expected link2 status HANDSHAKE, got %d", link2.status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinkEstablishment(t *testing.T) {
|
||||
responderIdent, err := identity.NewIdentity()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create responder identity: %v", err)
|
||||
}
|
||||
|
||||
cfg := &common.ReticulumConfig{}
|
||||
transportInstance := transport.NewTransport(cfg)
|
||||
|
||||
dest, err := destination.New(responderIdent, destination.IN, destination.SINGLE, "test", transportInstance, "link")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create destination: %v", err)
|
||||
}
|
||||
|
||||
initiatorLink := &Link{
|
||||
destination: dest,
|
||||
transport: transportInstance,
|
||||
initiator: true,
|
||||
}
|
||||
|
||||
responderLink := &Link{
|
||||
transport: transportInstance,
|
||||
initiator: false,
|
||||
}
|
||||
|
||||
if err := initiatorLink.generateEphemeralKeys(); err != nil {
|
||||
t.Fatalf("Failed to generate initiator keys: %v", err)
|
||||
}
|
||||
|
||||
initiatorLink.mode = MODE_DEFAULT
|
||||
initiatorLink.mtu = 500
|
||||
|
||||
signalling := signallingBytes(initiatorLink.mtu, initiatorLink.mode)
|
||||
requestData := make([]byte, 0, ECPUBSIZE+LINK_MTU_SIZE)
|
||||
requestData = append(requestData, initiatorLink.pub...)
|
||||
requestData = append(requestData, initiatorLink.sigPub...)
|
||||
requestData = append(requestData, signalling...)
|
||||
|
||||
linkRequestPkt := &packet.Packet{
|
||||
HeaderType: packet.HeaderType1,
|
||||
PacketType: packet.PacketTypeLinkReq,
|
||||
TransportType: 0,
|
||||
Context: packet.ContextNone,
|
||||
ContextFlag: packet.FlagUnset,
|
||||
Hops: 0,
|
||||
DestinationType: dest.GetType(),
|
||||
DestinationHash: dest.GetHash(),
|
||||
Data: requestData,
|
||||
}
|
||||
|
||||
if err := linkRequestPkt.Pack(); err != nil {
|
||||
t.Fatalf("Failed to pack link request: %v", err)
|
||||
}
|
||||
|
||||
initiatorLink.linkID = linkIDFromPacket(linkRequestPkt)
|
||||
initiatorLink.requestTime = time.Now()
|
||||
initiatorLink.status = STATUS_PENDING
|
||||
|
||||
t.Logf("Initiator link request created, link_id=%x", initiatorLink.linkID)
|
||||
|
||||
responderLink.peerPub = linkRequestPkt.Data[0:KEYSIZE]
|
||||
responderLink.peerSigPub = linkRequestPkt.Data[KEYSIZE:ECPUBSIZE]
|
||||
responderLink.linkID = linkIDFromPacket(linkRequestPkt)
|
||||
responderLink.initiator = false
|
||||
|
||||
t.Logf("Responder link ID=%x (len=%d)", responderLink.linkID, len(responderLink.linkID))
|
||||
|
||||
if len(responderLink.linkID) == 0 {
|
||||
t.Fatal("Responder link ID is empty!")
|
||||
}
|
||||
|
||||
if len(linkRequestPkt.Data) >= ECPUBSIZE+LINK_MTU_SIZE {
|
||||
mtuBytes := linkRequestPkt.Data[ECPUBSIZE : ECPUBSIZE+LINK_MTU_SIZE]
|
||||
responderLink.mtu = (int(mtuBytes[0]&0x1F) << 16) | (int(mtuBytes[1]) << 8) | int(mtuBytes[2])
|
||||
responderLink.mode = (mtuBytes[0] & MODE_BYTEMASK) >> 5
|
||||
}
|
||||
|
||||
if err := responderLink.generateEphemeralKeys(); err != nil {
|
||||
t.Fatalf("Failed to generate responder keys: %v", err)
|
||||
}
|
||||
|
||||
if err := responderLink.performHandshake(); err != nil {
|
||||
t.Fatalf("Responder handshake failed: %v", err)
|
||||
}
|
||||
|
||||
responderLink.status = STATUS_ACTIVE
|
||||
responderLink.establishedAt = time.Now()
|
||||
|
||||
if string(responderLink.linkID) != string(initiatorLink.linkID) {
|
||||
t.Error("Link IDs do not match between initiator and responder")
|
||||
}
|
||||
|
||||
t.Logf("Responder handshake successful, shared_key_len=%d", len(responderLink.sharedKey))
|
||||
}
|
||||
|
||||
func TestLinkProofValidation(t *testing.T) {
|
||||
responderIdent, err := identity.NewIdentity()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create responder identity: %v", err)
|
||||
}
|
||||
|
||||
cfg := &common.ReticulumConfig{}
|
||||
transportInstance := transport.NewTransport(cfg)
|
||||
|
||||
dest, err := destination.New(responderIdent, destination.IN, destination.SINGLE, "test", transportInstance, "link")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create destination: %v", err)
|
||||
}
|
||||
|
||||
initiatorLink := &Link{
|
||||
destination: dest,
|
||||
transport: transportInstance,
|
||||
initiator: true,
|
||||
}
|
||||
|
||||
responderLink := &Link{
|
||||
transport: transportInstance,
|
||||
initiator: false,
|
||||
}
|
||||
|
||||
if err := initiatorLink.generateEphemeralKeys(); err != nil {
|
||||
t.Fatalf("Failed to generate initiator keys: %v", err)
|
||||
}
|
||||
|
||||
initiatorLink.mode = MODE_DEFAULT
|
||||
initiatorLink.mtu = 500
|
||||
|
||||
signalling := signallingBytes(initiatorLink.mtu, initiatorLink.mode)
|
||||
requestData := make([]byte, 0, ECPUBSIZE+LINK_MTU_SIZE)
|
||||
requestData = append(requestData, initiatorLink.pub...)
|
||||
requestData = append(requestData, initiatorLink.sigPub...)
|
||||
requestData = append(requestData, signalling...)
|
||||
|
||||
linkRequestPkt := &packet.Packet{
|
||||
HeaderType: packet.HeaderType1,
|
||||
PacketType: packet.PacketTypeLinkReq,
|
||||
TransportType: 0,
|
||||
Context: packet.ContextNone,
|
||||
ContextFlag: packet.FlagUnset,
|
||||
Hops: 0,
|
||||
DestinationType: dest.GetType(),
|
||||
DestinationHash: dest.GetHash(),
|
||||
Data: requestData,
|
||||
}
|
||||
|
||||
if err := linkRequestPkt.Pack(); err != nil {
|
||||
t.Fatalf("Failed to pack link request: %v", err)
|
||||
}
|
||||
|
||||
initiatorLink.linkID = linkIDFromPacket(linkRequestPkt)
|
||||
initiatorLink.requestTime = time.Now()
|
||||
initiatorLink.status = STATUS_PENDING
|
||||
|
||||
responderLink.peerPub = linkRequestPkt.Data[0:KEYSIZE]
|
||||
responderLink.peerSigPub = linkRequestPkt.Data[KEYSIZE:ECPUBSIZE]
|
||||
responderLink.linkID = linkIDFromPacket(linkRequestPkt)
|
||||
responderLink.initiator = false
|
||||
|
||||
if len(linkRequestPkt.Data) >= ECPUBSIZE+LINK_MTU_SIZE {
|
||||
mtuBytes := linkRequestPkt.Data[ECPUBSIZE : ECPUBSIZE+LINK_MTU_SIZE]
|
||||
responderLink.mtu = (int(mtuBytes[0]&0x1F) << 16) | (int(mtuBytes[1]) << 8) | int(mtuBytes[2])
|
||||
responderLink.mode = (mtuBytes[0] & MODE_BYTEMASK) >> 5
|
||||
} else {
|
||||
responderLink.mtu = 500
|
||||
responderLink.mode = MODE_DEFAULT
|
||||
}
|
||||
|
||||
if err := responderLink.generateEphemeralKeys(); err != nil {
|
||||
t.Fatalf("Failed to generate responder keys: %v", err)
|
||||
}
|
||||
|
||||
if err := responderLink.performHandshake(); err != nil {
|
||||
t.Fatalf("Responder handshake failed: %v", err)
|
||||
}
|
||||
|
||||
proofPkt, err := responderLink.GenerateLinkProof(responderIdent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate link proof: %v", err)
|
||||
}
|
||||
|
||||
if err := initiatorLink.ValidateLinkProof(proofPkt, nil); err != nil {
|
||||
t.Fatalf("Initiator failed to validate link proof: %v", err)
|
||||
}
|
||||
|
||||
if initiatorLink.status != STATUS_ACTIVE {
|
||||
t.Errorf("Expected initiator status ACTIVE, got %d", initiatorLink.status)
|
||||
}
|
||||
|
||||
if string(initiatorLink.sharedKey) != string(responderLink.sharedKey) {
|
||||
t.Error("Shared keys do not match after full handshake")
|
||||
}
|
||||
|
||||
if string(initiatorLink.derivedKey) != string(responderLink.derivedKey) {
|
||||
t.Error("Derived keys do not match after full handshake")
|
||||
}
|
||||
|
||||
t.Logf("Full link establishment successful")
|
||||
t.Logf("Link ID: %x", initiatorLink.linkID)
|
||||
t.Logf("Shared key length: %d", len(initiatorLink.sharedKey))
|
||||
t.Logf("Derived key length: %d", len(initiatorLink.derivedKey))
|
||||
t.Logf("RTT: %.3f seconds", initiatorLink.rtt)
|
||||
}
|
||||
1699
pkg/link/link.go
1699
pkg/link/link.go
File diff suppressed because it is too large
Load Diff
218
pkg/link/link_test.go
Normal file
218
pkg/link/link_test.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package link
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/destination"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/identity"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/packet"
|
||||
)
|
||||
|
||||
type mockTransport struct {
|
||||
sentPackets []*packet.Packet
|
||||
}
|
||||
|
||||
func (m *mockTransport) SendPacket(pkt *packet.Packet) error {
|
||||
m.sentPackets = append(m.sentPackets, pkt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockTransport) RegisterLink(linkID []byte, link interface{}) {
|
||||
}
|
||||
|
||||
func (m *mockTransport) GetConfig() *common.ReticulumConfig {
|
||||
return &common.ReticulumConfig{}
|
||||
}
|
||||
|
||||
func (m *mockTransport) GetInterfaces() map[string]common.NetworkInterface {
|
||||
return make(map[string]common.NetworkInterface)
|
||||
}
|
||||
|
||||
func (m *mockTransport) RegisterDestination(hash []byte, dest interface{}) {
|
||||
}
|
||||
|
||||
type mockInterface struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (m *mockInterface) GetName() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *mockInterface) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInterface) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInterface) Send(data []byte, address string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInterface) ProcessIncoming(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInterface) SetPacketCallback(cb func([]byte, common.NetworkInterface)) {
|
||||
}
|
||||
|
||||
func (m *mockInterface) GetType() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
func (m *mockInterface) GetMTU() int {
|
||||
return 500
|
||||
}
|
||||
|
||||
func (m *mockInterface) Detach() {
|
||||
}
|
||||
|
||||
func (m *mockInterface) Enable() {
|
||||
}
|
||||
|
||||
func (m *mockInterface) Disable() {
|
||||
}
|
||||
|
||||
func (m *mockInterface) IsEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *mockInterface) IsOnline() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *mockInterface) IsDetached() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockInterface) GetPacketCallback() func([]byte, common.NetworkInterface) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInterface) GetConn() interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInterface) ProcessOutgoing(data []byte) ([]byte, error) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (m *mockInterface) SendPathRequest(destHash []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInterface) SendLinkPacket(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInterface) GetBandwidthAvailable() float64 {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
func TestLinkRequestResponse(t *testing.T) {
|
||||
serverIdent, err := identity.New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server identity: %v", err)
|
||||
}
|
||||
|
||||
clientIdent, err := identity.New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client identity: %v", err)
|
||||
}
|
||||
|
||||
mockTrans := &mockTransport{
|
||||
sentPackets: make([]*packet.Packet, 0),
|
||||
}
|
||||
|
||||
serverDest, err := destination.New(serverIdent, destination.IN, destination.SINGLE, "testapp", mockTrans, "server")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server destination: %v", err)
|
||||
}
|
||||
|
||||
expectedResponse := []byte("response data")
|
||||
testPath := "/test/path"
|
||||
|
||||
err = serverDest.RegisterRequestHandler(testPath, func(path string, data []byte, requestID []byte, linkID []byte, remoteIdentity *identity.Identity, requestedAt int64) []byte {
|
||||
if path != testPath {
|
||||
t.Errorf("Expected path %s, got %s", testPath, path)
|
||||
}
|
||||
return expectedResponse
|
||||
}, destination.ALLOW_ALL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register request handler: %v", err)
|
||||
}
|
||||
|
||||
// Test the handler is registered correctly
|
||||
pathHash := identity.TruncatedHash([]byte(testPath))
|
||||
handler := serverDest.GetRequestHandler(pathHash)
|
||||
if handler == nil {
|
||||
t.Fatal("Handler not found after registration")
|
||||
}
|
||||
|
||||
// Call the handler
|
||||
testLinkID := make([]byte, 16)
|
||||
result := handler(pathHash, []byte("test data"), []byte("request-id"), testLinkID, clientIdent, time.Now())
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Handler returned nil")
|
||||
}
|
||||
|
||||
responseBytes, ok := result.([]byte)
|
||||
if !ok {
|
||||
t.Fatalf("Handler returned unexpected type: %T", result)
|
||||
}
|
||||
|
||||
if !bytes.Equal(responseBytes, expectedResponse) {
|
||||
t.Errorf("Expected response %q, got %q", expectedResponse, responseBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinkRequestHandlerNotFound(t *testing.T) {
|
||||
serverIdent, _ := identity.New()
|
||||
mockTrans := &mockTransport{sentPackets: make([]*packet.Packet, 0)}
|
||||
|
||||
serverDest, _ := destination.New(serverIdent, destination.IN, destination.SINGLE, "testapp", mockTrans, "server")
|
||||
|
||||
nonExistentPath := "/does/not/exist"
|
||||
pathHash := identity.TruncatedHash([]byte(nonExistentPath))
|
||||
|
||||
handler := serverDest.GetRequestHandler(pathHash)
|
||||
if handler != nil {
|
||||
t.Error("Expected no handler for non-existent path, but found one")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinkResponseHandling(t *testing.T) {
|
||||
// This test verifies the basic structure for response handling
|
||||
// Full integration testing would require a proper transport setup
|
||||
|
||||
requestID := []byte("test-request-id-")
|
||||
responseData := []byte("response payload")
|
||||
|
||||
receipt := &RequestReceipt{
|
||||
requestID: requestID,
|
||||
status: STATUS_PENDING,
|
||||
}
|
||||
|
||||
// Verify initial state
|
||||
if receipt.status != STATUS_PENDING {
|
||||
t.Errorf("Expected initial status PENDING, got %d", receipt.status)
|
||||
}
|
||||
|
||||
// Simulate setting response
|
||||
receipt.response = responseData
|
||||
receipt.status = STATUS_ACTIVE
|
||||
|
||||
if !bytes.Equal(receipt.response, responseData) {
|
||||
t.Errorf("Expected response %q, got %q", responseData, receipt.response)
|
||||
}
|
||||
|
||||
if receipt.status != STATUS_ACTIVE {
|
||||
t.Errorf("Expected status ACTIVE after response, got %d", receipt.status)
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package packet
|
||||
|
||||
const (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package packet
|
||||
|
||||
import (
|
||||
@@ -6,10 +8,10 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/identity"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/identity"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -67,6 +69,7 @@ type Packet struct {
|
||||
|
||||
DestinationType byte
|
||||
DestinationHash []byte
|
||||
Destination interface{}
|
||||
TransportID []byte
|
||||
Data []byte
|
||||
|
||||
@@ -85,6 +88,21 @@ type Packet struct {
|
||||
Q *float64
|
||||
|
||||
Addresses []byte
|
||||
Link interface{}
|
||||
|
||||
receipt *PacketReceipt
|
||||
}
|
||||
|
||||
type PacketConfig struct {
|
||||
DestType byte
|
||||
Data []byte
|
||||
PacketType byte
|
||||
Context byte
|
||||
TransportType byte
|
||||
HeaderType byte
|
||||
TransportID []byte
|
||||
CreateReceipt bool
|
||||
ContextFlag byte
|
||||
}
|
||||
|
||||
func NewPacket(destType byte, data []byte, packetType byte, context byte,
|
||||
@@ -113,7 +131,7 @@ func (p *Packet) Pack() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG-6] Packing packet: type=%d, header=%d", p.PacketType, p.HeaderType)
|
||||
debug.Log(debug.DEBUG_PACKETS, "Packing packet", "type", p.PacketType, "header", p.HeaderType)
|
||||
|
||||
// Create header byte (Corrected order)
|
||||
flags := byte(0)
|
||||
@@ -124,23 +142,23 @@ func (p *Packet) Pack() error {
|
||||
flags |= p.PacketType & 0b00000011
|
||||
|
||||
header := []byte{flags, p.Hops}
|
||||
log.Printf("[DEBUG-5] Created packet header: flags=%08b, hops=%d", flags, p.Hops)
|
||||
debug.Log(debug.DEBUG_TRACE, "Created packet header", "flags", fmt.Sprintf("%08b", flags), "hops", p.Hops)
|
||||
|
||||
header = append(header, p.DestinationHash...)
|
||||
|
||||
|
||||
if p.HeaderType == HeaderType2 {
|
||||
if p.TransportID == nil {
|
||||
return errors.New("transport ID required for header type 2")
|
||||
}
|
||||
header = append(header, p.TransportID...)
|
||||
log.Printf("[DEBUG-7] Added transport ID to header: %x", p.TransportID)
|
||||
debug.Log(debug.DEBUG_ALL, "Added transport ID to header", "transport_id", fmt.Sprintf("%x", p.TransportID))
|
||||
}
|
||||
|
||||
header = append(header, p.Context)
|
||||
log.Printf("[DEBUG-6] Final header length: %d bytes", len(header))
|
||||
debug.Log(debug.DEBUG_PACKETS, "Final header length", "bytes", len(header))
|
||||
|
||||
p.Raw = append(header, p.Data...)
|
||||
log.Printf("[DEBUG-5] Final packet size: %d bytes", len(p.Raw))
|
||||
debug.Log(debug.DEBUG_TRACE, "Final packet size", "bytes", len(p.Raw))
|
||||
|
||||
if len(p.Raw) > MTU {
|
||||
return errors.New("packet size exceeds MTU")
|
||||
@@ -148,7 +166,7 @@ func (p *Packet) Pack() error {
|
||||
|
||||
p.Packed = true
|
||||
p.updateHash()
|
||||
log.Printf("[DEBUG-7] Packet hash: %x", p.PacketHash)
|
||||
debug.Log(debug.DEBUG_ALL, "Packet hash", "hash", fmt.Sprintf("%x", p.PacketHash))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -173,8 +191,8 @@ func (p *Packet) Unpack() error {
|
||||
if len(p.Raw) < 2*dstLen+3 {
|
||||
return errors.New("packet too short for header type 2")
|
||||
}
|
||||
p.DestinationHash = p.Raw[2 : dstLen+2] // Destination hash first
|
||||
p.TransportID = p.Raw[dstLen+2 : 2*dstLen+2] // Transport ID second
|
||||
p.DestinationHash = p.Raw[2 : dstLen+2] // Destination hash first
|
||||
p.TransportID = p.Raw[dstLen+2 : 2*dstLen+2] // Transport ID second
|
||||
p.Context = p.Raw[2*dstLen+2]
|
||||
p.Data = p.Raw[2*dstLen+3:]
|
||||
} else {
|
||||
@@ -202,14 +220,14 @@ func (p *Packet) GetHash() []byte {
|
||||
func (p *Packet) getHashablePart() []byte {
|
||||
hashable := []byte{p.Raw[0] & 0b00001111} // Lower 4 bits of flags
|
||||
if p.HeaderType == HeaderType2 {
|
||||
// Match Python: Start hash from DestHash (index 18), skipping TransportID
|
||||
// Start hash from DestHash (index 18), skipping TransportID
|
||||
dstLen := 16 // RNS.Identity.TRUNCATED_HASHLENGTH / 8
|
||||
startIndex := dstLen + 2
|
||||
if len(p.Raw) > startIndex {
|
||||
hashable = append(hashable, p.Raw[startIndex:]...)
|
||||
}
|
||||
} else {
|
||||
// Match Python: Start hash from DestHash (index 2)
|
||||
// Start hash from DestHash (index 2)
|
||||
if len(p.Raw) > 2 {
|
||||
hashable = append(hashable, p.Raw[2:]...)
|
||||
}
|
||||
@@ -221,6 +239,18 @@ func (p *Packet) updateHash() {
|
||||
p.PacketHash = p.GetHash()
|
||||
}
|
||||
|
||||
func (p *Packet) Hash() []byte {
|
||||
return p.GetHash()
|
||||
}
|
||||
|
||||
func (p *Packet) TruncatedHash() []byte {
|
||||
hash := p.GetHash()
|
||||
if len(hash) >= 16 {
|
||||
return hash[:16]
|
||||
}
|
||||
return hash
|
||||
}
|
||||
|
||||
func (p *Packet) Serialize() ([]byte, error) {
|
||||
if !p.Packed {
|
||||
if err := p.Pack(); err != nil {
|
||||
@@ -234,13 +264,13 @@ func (p *Packet) Serialize() ([]byte, error) {
|
||||
}
|
||||
|
||||
func NewAnnouncePacket(destHash []byte, identity *identity.Identity, appData []byte, transportID []byte) (*Packet, error) {
|
||||
log.Printf("[DEBUG-7] Creating new announce packet: destHash=%x, appData=%s", destHash, fmt.Sprintf("%x", appData))
|
||||
debug.Log(debug.DEBUG_ALL, "Creating new announce packet", "dest_hash", fmt.Sprintf("%x", destHash), "app_data", fmt.Sprintf("%x", appData))
|
||||
|
||||
// Get public key separated into encryption and signing keys
|
||||
pubKey := identity.GetPublicKey()
|
||||
encKey := pubKey[:32]
|
||||
signKey := pubKey[32:]
|
||||
log.Printf("[DEBUG-6] Using public keys: encKey=%x, signKey=%x", encKey, signKey)
|
||||
debug.Log(debug.DEBUG_PACKETS, "Using public keys", "enc_key", fmt.Sprintf("%x", encKey), "sign_key", fmt.Sprintf("%x", signKey))
|
||||
|
||||
// Parse app name from first msgpack element if possible
|
||||
// For nodes, we'll use "reticulum.node" as the name hash
|
||||
@@ -265,19 +295,19 @@ func NewAnnouncePacket(destHash []byte, identity *identity.Identity, appData []b
|
||||
// Create name hash (10 bytes)
|
||||
nameHash := sha256.Sum256([]byte(appName))
|
||||
nameHash10 := nameHash[:10]
|
||||
log.Printf("[DEBUG-6] Using name hash for '%s': %x", appName, nameHash10)
|
||||
debug.Log(debug.DEBUG_PACKETS, "Using name hash", "name", appName, "hash", fmt.Sprintf("%x", nameHash10))
|
||||
|
||||
// Create random hash (10 bytes) - 5 bytes random + 5 bytes time
|
||||
randomHash := make([]byte, 10)
|
||||
_, err := rand.Read(randomHash[:5]) // #nosec G104
|
||||
if err != nil {
|
||||
log.Printf("[DEBUG-6] Failed to read random bytes for hash: %v", err)
|
||||
debug.Log(debug.DEBUG_PACKETS, "Failed to read random bytes for hash", "error", err)
|
||||
return nil, err // Or handle the error appropriately
|
||||
}
|
||||
timeBytes := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(timeBytes, uint64(time.Now().Unix())) // #nosec G115
|
||||
copy(randomHash[5:], timeBytes[:5])
|
||||
log.Printf("[DEBUG-6] Generated random hash: %x", randomHash)
|
||||
debug.Log(debug.DEBUG_PACKETS, "Generated random hash", "hash", fmt.Sprintf("%x", randomHash))
|
||||
|
||||
// Prepare ratchet ID if available (not yet implemented)
|
||||
var ratchetID []byte
|
||||
@@ -291,11 +321,11 @@ func NewAnnouncePacket(destHash []byte, identity *identity.Identity, appData []b
|
||||
signedData = append(signedData, nameHash10...)
|
||||
signedData = append(signedData, randomHash...)
|
||||
signedData = append(signedData, appData...)
|
||||
log.Printf("[DEBUG-5] Created signed data (%d bytes)", len(signedData))
|
||||
debug.Log(debug.DEBUG_TRACE, "Created signed data", "bytes", len(signedData))
|
||||
|
||||
// Sign the data
|
||||
signature := identity.Sign(signedData)
|
||||
log.Printf("[DEBUG-6] Generated signature: %x", signature)
|
||||
debug.Log(debug.DEBUG_PACKETS, "Generated signature", "signature", fmt.Sprintf("%x", signature))
|
||||
|
||||
// Combine all fields according to spec
|
||||
// Data structure: Public Key (32) + Signing Key (32) + Name Hash (10) + Random Hash (10) + Ratchet (optional) + Signature (64) + App Data
|
||||
@@ -310,7 +340,7 @@ func NewAnnouncePacket(destHash []byte, identity *identity.Identity, appData []b
|
||||
data = append(data, signature...) // Signature (64 bytes)
|
||||
data = append(data, appData...) // Application data (variable)
|
||||
|
||||
log.Printf("[DEBUG-5] Combined packet data (%d bytes)", len(data))
|
||||
debug.Log(debug.DEBUG_TRACE, "Combined packet data", "bytes", len(data))
|
||||
|
||||
// Create the packet with header type 2 (two address fields)
|
||||
p := &Packet{
|
||||
@@ -321,6 +351,6 @@ func NewAnnouncePacket(destHash []byte, identity *identity.Identity, appData []b
|
||||
Data: data,
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG-4] Created announce packet: type=%d, header=%d", p.PacketType, p.HeaderType)
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Created announce packet", "type", p.PacketType, "header", p.HeaderType)
|
||||
return p, nil
|
||||
}
|
||||
|
||||
344
pkg/packet/receipt.go
Normal file
344
pkg/packet/receipt.go
Normal file
@@ -0,0 +1,344 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package packet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/identity"
|
||||
)
|
||||
|
||||
const (
|
||||
RECEIPT_FAILED = 0x00
|
||||
RECEIPT_SENT = 0x01
|
||||
RECEIPT_DELIVERED = 0x02
|
||||
RECEIPT_CULLED = 0xFF
|
||||
|
||||
EXPL_LENGTH = (identity.HASHLENGTH + identity.SIGLENGTH) / 8
|
||||
IMPL_LENGTH = identity.SIGLENGTH / 8
|
||||
)
|
||||
|
||||
type PacketReceipt struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
hash []byte
|
||||
truncatedHash []byte
|
||||
sent bool
|
||||
sentAt time.Time
|
||||
proved bool
|
||||
status byte
|
||||
destination interface{}
|
||||
timeout time.Duration
|
||||
concludedAt time.Time
|
||||
proofPacket *Packet
|
||||
|
||||
deliveryCallback func(*PacketReceipt)
|
||||
timeoutCallback func(*PacketReceipt)
|
||||
|
||||
link interface{}
|
||||
destinationHash []byte
|
||||
destinationIdent *identity.Identity
|
||||
timeoutCheckDone chan bool
|
||||
}
|
||||
|
||||
func NewPacketReceipt(pkt *Packet) *PacketReceipt {
|
||||
hash := pkt.Hash()
|
||||
receipt := &PacketReceipt{
|
||||
hash: hash,
|
||||
truncatedHash: pkt.TruncatedHash(),
|
||||
sent: true,
|
||||
sentAt: time.Now(),
|
||||
proved: false,
|
||||
status: RECEIPT_SENT,
|
||||
destination: pkt.Destination,
|
||||
timeout: calculateTimeout(pkt),
|
||||
timeoutCheckDone: make(chan bool, 1),
|
||||
}
|
||||
|
||||
go receipt.timeoutWatchdog()
|
||||
|
||||
debug.Log(debug.DEBUG_PACKETS, "Created packet receipt", "hash", fmt.Sprintf("%x", receipt.truncatedHash))
|
||||
return receipt
|
||||
}
|
||||
|
||||
func calculateTimeout(pkt *Packet) time.Duration {
|
||||
baseTimeout := 15 * time.Second
|
||||
|
||||
if pkt.Hops > 0 {
|
||||
baseTimeout += time.Duration(pkt.Hops) * (3 * time.Second)
|
||||
}
|
||||
|
||||
return baseTimeout
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) GetStatus() byte {
|
||||
pr.mutex.RLock()
|
||||
defer pr.mutex.RUnlock()
|
||||
return pr.status
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) GetHash() []byte {
|
||||
pr.mutex.RLock()
|
||||
defer pr.mutex.RUnlock()
|
||||
return pr.hash
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) IsDelivered() bool {
|
||||
pr.mutex.RLock()
|
||||
defer pr.mutex.RUnlock()
|
||||
return pr.status == RECEIPT_DELIVERED
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) IsFailed() bool {
|
||||
pr.mutex.RLock()
|
||||
defer pr.mutex.RUnlock()
|
||||
return pr.status == RECEIPT_FAILED
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) ValidateProofPacket(proofPacket *Packet) bool {
|
||||
if proofPacket.Link != nil {
|
||||
return pr.ValidateLinkProof(proofPacket.Data, proofPacket.Link, proofPacket)
|
||||
}
|
||||
return pr.ValidateProof(proofPacket.Data, proofPacket)
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) ValidateLinkProof(proof []byte, link interface{}, proofPacket *Packet) bool {
|
||||
if len(proof) == EXPL_LENGTH {
|
||||
proofHash := proof[:identity.HASHLENGTH/8]
|
||||
signature := proof[identity.HASHLENGTH/8 : identity.HASHLENGTH/8+identity.SIGLENGTH/8]
|
||||
|
||||
pr.mutex.RLock()
|
||||
hashMatch := string(proofHash) == string(pr.hash)
|
||||
pr.mutex.RUnlock()
|
||||
|
||||
if !hashMatch {
|
||||
return false
|
||||
}
|
||||
|
||||
proofValid := pr.validateLinkSignature(signature, link)
|
||||
if proofValid {
|
||||
pr.mutex.Lock()
|
||||
pr.status = RECEIPT_DELIVERED
|
||||
pr.proved = true
|
||||
pr.concludedAt = time.Now()
|
||||
pr.proofPacket = proofPacket
|
||||
callback := pr.deliveryCallback
|
||||
pr.mutex.Unlock()
|
||||
|
||||
if callback != nil {
|
||||
go callback(pr)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_PACKETS, "Link proof validated", "hash", fmt.Sprintf("%x", pr.truncatedHash))
|
||||
return true
|
||||
}
|
||||
} else if len(proof) == IMPL_LENGTH {
|
||||
debug.Log(debug.DEBUG_TRACE, "Implicit link proof not yet implemented")
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) ValidateProof(proof []byte, proofPacket *Packet) bool {
|
||||
if len(proof) == EXPL_LENGTH {
|
||||
proofHash := proof[:identity.HASHLENGTH/8]
|
||||
signature := proof[identity.HASHLENGTH/8 : identity.HASHLENGTH/8+identity.SIGLENGTH/8]
|
||||
|
||||
pr.mutex.RLock()
|
||||
hashMatch := string(proofHash) == string(pr.hash)
|
||||
ident := pr.destinationIdent
|
||||
pr.mutex.RUnlock()
|
||||
|
||||
debug.Log(debug.DEBUG_PACKETS, "Explicit proof validation", "len", len(proof), "hashMatch", hashMatch, "hasIdent", ident != nil)
|
||||
|
||||
if !hashMatch {
|
||||
debug.Log(debug.DEBUG_PACKETS, "Proof hash mismatch")
|
||||
return false
|
||||
}
|
||||
|
||||
if ident == nil {
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Cannot validate proof without destination identity")
|
||||
return false
|
||||
}
|
||||
|
||||
proofValid := ident.Verify(pr.hash, signature)
|
||||
debug.Log(debug.DEBUG_PACKETS, "Signature verification result", "valid", proofValid)
|
||||
if proofValid {
|
||||
pr.mutex.Lock()
|
||||
pr.status = RECEIPT_DELIVERED
|
||||
pr.proved = true
|
||||
pr.concludedAt = time.Now()
|
||||
pr.proofPacket = proofPacket
|
||||
callback := pr.deliveryCallback
|
||||
pr.mutex.Unlock()
|
||||
|
||||
if callback != nil {
|
||||
go callback(pr)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_PACKETS, "Proof validated", "hash", fmt.Sprintf("%x", pr.truncatedHash))
|
||||
return true
|
||||
}
|
||||
} else if len(proof) == IMPL_LENGTH {
|
||||
signature := proof[:identity.SIGLENGTH/8]
|
||||
|
||||
pr.mutex.RLock()
|
||||
ident := pr.destinationIdent
|
||||
pr.mutex.RUnlock()
|
||||
|
||||
if ident == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
proofValid := ident.Verify(pr.hash, signature)
|
||||
if proofValid {
|
||||
pr.mutex.Lock()
|
||||
pr.status = RECEIPT_DELIVERED
|
||||
pr.proved = true
|
||||
pr.concludedAt = time.Now()
|
||||
pr.proofPacket = proofPacket
|
||||
callback := pr.deliveryCallback
|
||||
pr.mutex.Unlock()
|
||||
|
||||
if callback != nil {
|
||||
go callback(pr)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_PACKETS, "Implicit proof validated", "hash", fmt.Sprintf("%x", pr.truncatedHash))
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) validateLinkSignature(signature []byte, link interface{}) bool {
|
||||
type linkValidator interface {
|
||||
Validate(signature, message []byte) bool
|
||||
}
|
||||
|
||||
if validator, ok := link.(linkValidator); ok {
|
||||
return validator.Validate(signature, pr.hash)
|
||||
}
|
||||
|
||||
debug.Log(debug.DEBUG_TRACE, "Link does not implement Validate method")
|
||||
return false
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) GetRTT() time.Duration {
|
||||
pr.mutex.RLock()
|
||||
defer pr.mutex.RUnlock()
|
||||
|
||||
if pr.concludedAt.IsZero() {
|
||||
return 0
|
||||
}
|
||||
|
||||
return pr.concludedAt.Sub(pr.sentAt)
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) IsTimedOut() bool {
|
||||
pr.mutex.RLock()
|
||||
defer pr.mutex.RUnlock()
|
||||
|
||||
return time.Since(pr.sentAt) > pr.timeout
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) checkTimeout() {
|
||||
pr.mutex.Lock()
|
||||
|
||||
if pr.status != RECEIPT_SENT {
|
||||
pr.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if time.Since(pr.sentAt) <= pr.timeout {
|
||||
pr.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if pr.timeout < 0 {
|
||||
pr.status = RECEIPT_CULLED
|
||||
} else {
|
||||
pr.status = RECEIPT_FAILED
|
||||
}
|
||||
|
||||
pr.concludedAt = time.Now()
|
||||
callback := pr.timeoutCallback
|
||||
pr.mutex.Unlock()
|
||||
|
||||
debug.Log(debug.DEBUG_VERBOSE, "Packet receipt timed out", "hash", fmt.Sprintf("%x", pr.truncatedHash))
|
||||
|
||||
if callback != nil {
|
||||
go callback(pr)
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) timeoutWatchdog() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
pr.checkTimeout()
|
||||
|
||||
pr.mutex.RLock()
|
||||
status := pr.status
|
||||
pr.mutex.RUnlock()
|
||||
|
||||
if status != RECEIPT_SENT {
|
||||
return
|
||||
}
|
||||
case <-pr.timeoutCheckDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) SetTimeout(timeout time.Duration) {
|
||||
pr.mutex.Lock()
|
||||
defer pr.mutex.Unlock()
|
||||
pr.timeout = timeout
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) SetDeliveryCallback(callback func(*PacketReceipt)) {
|
||||
pr.mutex.Lock()
|
||||
defer pr.mutex.Unlock()
|
||||
pr.deliveryCallback = callback
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) SetTimeoutCallback(callback func(*PacketReceipt)) {
|
||||
pr.mutex.Lock()
|
||||
defer pr.mutex.Unlock()
|
||||
pr.timeoutCallback = callback
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) SetDestinationIdentity(ident *identity.Identity) {
|
||||
pr.mutex.Lock()
|
||||
defer pr.mutex.Unlock()
|
||||
pr.destinationIdent = ident
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) SetLink(link interface{}) {
|
||||
pr.mutex.Lock()
|
||||
defer pr.mutex.Unlock()
|
||||
pr.link = link
|
||||
}
|
||||
|
||||
func (pr *PacketReceipt) Cancel() {
|
||||
pr.mutex.Lock()
|
||||
defer pr.mutex.Unlock()
|
||||
|
||||
if pr.status == RECEIPT_SENT {
|
||||
pr.status = RECEIPT_CULLED
|
||||
pr.concludedAt = time.Now()
|
||||
}
|
||||
|
||||
select {
|
||||
case pr.timeoutCheckDone <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
209
pkg/packet/receipt_test.go
Normal file
209
pkg/packet/receipt_test.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/identity"
|
||||
)
|
||||
|
||||
func TestPacketReceiptCreation(t *testing.T) {
|
||||
testIdent, err := identity.NewIdentity()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create identity: %v", err)
|
||||
}
|
||||
|
||||
destHash := testIdent.Hash()
|
||||
data := []byte("test packet data")
|
||||
|
||||
pkt := &Packet{
|
||||
HeaderType: HeaderType1,
|
||||
PacketType: PacketTypeData,
|
||||
TransportType: 0,
|
||||
Context: ContextNone,
|
||||
ContextFlag: FlagUnset,
|
||||
Hops: 0,
|
||||
DestinationType: 0x00,
|
||||
DestinationHash: destHash,
|
||||
Data: data,
|
||||
CreateReceipt: true,
|
||||
}
|
||||
|
||||
if err := pkt.Pack(); err != nil {
|
||||
t.Fatalf("Failed to pack packet: %v", err)
|
||||
}
|
||||
|
||||
receipt := NewPacketReceipt(pkt)
|
||||
if receipt == nil {
|
||||
t.Fatal("Receipt creation failed")
|
||||
}
|
||||
|
||||
if receipt.GetStatus() != RECEIPT_SENT {
|
||||
t.Errorf("Expected status SENT, got %d", receipt.GetStatus())
|
||||
}
|
||||
|
||||
hash := receipt.GetHash()
|
||||
if len(hash) == 0 {
|
||||
t.Error("Receipt hash is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketReceiptTimeout(t *testing.T) {
|
||||
testIdent, err := identity.NewIdentity()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create identity: %v", err)
|
||||
}
|
||||
|
||||
destHash := testIdent.Hash()
|
||||
data := []byte("test data")
|
||||
|
||||
pkt := &Packet{
|
||||
HeaderType: HeaderType1,
|
||||
PacketType: PacketTypeData,
|
||||
TransportType: 0,
|
||||
Context: ContextNone,
|
||||
ContextFlag: FlagUnset,
|
||||
Hops: 0,
|
||||
DestinationType: 0x00,
|
||||
DestinationHash: destHash,
|
||||
Data: data,
|
||||
CreateReceipt: true,
|
||||
}
|
||||
|
||||
if err := pkt.Pack(); err != nil {
|
||||
t.Fatalf("Failed to pack packet: %v", err)
|
||||
}
|
||||
|
||||
receipt := NewPacketReceipt(pkt)
|
||||
receipt.SetTimeout(100 * time.Millisecond)
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
if !receipt.IsTimedOut() {
|
||||
t.Error("Receipt should be timed out")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketReceiptProofValidation(t *testing.T) {
|
||||
testIdent, err := identity.NewIdentity()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create identity: %v", err)
|
||||
}
|
||||
|
||||
destHash := testIdent.Hash()
|
||||
data := []byte("test data")
|
||||
|
||||
pkt := &Packet{
|
||||
HeaderType: HeaderType1,
|
||||
PacketType: PacketTypeData,
|
||||
TransportType: 0,
|
||||
Context: ContextNone,
|
||||
ContextFlag: FlagUnset,
|
||||
Hops: 0,
|
||||
DestinationType: 0x00,
|
||||
DestinationHash: destHash,
|
||||
Data: data,
|
||||
CreateReceipt: true,
|
||||
}
|
||||
|
||||
if err := pkt.Pack(); err != nil {
|
||||
t.Fatalf("Failed to pack packet: %v", err)
|
||||
}
|
||||
|
||||
receipt := NewPacketReceipt(pkt)
|
||||
receipt.SetDestinationIdentity(testIdent)
|
||||
|
||||
packetHash := pkt.GetHash()
|
||||
t.Logf("Packet hash: %x", packetHash)
|
||||
|
||||
signature := testIdent.Sign(packetHash)
|
||||
|
||||
t.Logf("PacketHash length: %d", len(packetHash))
|
||||
t.Logf("Signature length: %d", len(signature))
|
||||
t.Logf("EXPL_LENGTH constant: %d", EXPL_LENGTH)
|
||||
|
||||
if testIdent.Verify(packetHash, signature) {
|
||||
t.Log("Direct verification succeeded")
|
||||
} else {
|
||||
t.Error("Direct verification failed")
|
||||
}
|
||||
|
||||
proof := make([]byte, 0, EXPL_LENGTH)
|
||||
proof = append(proof, packetHash...)
|
||||
proof = append(proof, signature...)
|
||||
|
||||
t.Logf("Proof length: %d", len(proof))
|
||||
|
||||
proofPacket := &Packet{
|
||||
PacketType: PacketTypeProof,
|
||||
Data: proof,
|
||||
}
|
||||
|
||||
if !receipt.ValidateProof(proof, proofPacket) {
|
||||
t.Errorf("Valid proof was rejected. Proof len=%d, expected=%d", len(proof), EXPL_LENGTH)
|
||||
}
|
||||
|
||||
if receipt.GetStatus() != RECEIPT_DELIVERED {
|
||||
t.Errorf("Expected status DELIVERED, got %d", receipt.GetStatus())
|
||||
}
|
||||
|
||||
if !receipt.IsDelivered() {
|
||||
t.Error("Receipt should be marked as delivered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketReceiptCallbacks(t *testing.T) {
|
||||
testIdent, err := identity.NewIdentity()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create identity: %v", err)
|
||||
}
|
||||
|
||||
destHash := testIdent.Hash()
|
||||
data := []byte("test data")
|
||||
|
||||
pkt := &Packet{
|
||||
HeaderType: HeaderType1,
|
||||
PacketType: PacketTypeData,
|
||||
TransportType: 0,
|
||||
Context: ContextNone,
|
||||
ContextFlag: FlagUnset,
|
||||
Hops: 0,
|
||||
DestinationType: 0x00,
|
||||
DestinationHash: destHash,
|
||||
Data: data,
|
||||
CreateReceipt: true,
|
||||
}
|
||||
|
||||
if err := pkt.Pack(); err != nil {
|
||||
t.Fatalf("Failed to pack packet: %v", err)
|
||||
}
|
||||
|
||||
receipt := NewPacketReceipt(pkt)
|
||||
receipt.SetDestinationIdentity(testIdent)
|
||||
|
||||
deliveryCalled := make(chan bool, 1)
|
||||
receipt.SetDeliveryCallback(func(r *PacketReceipt) {
|
||||
deliveryCalled <- true
|
||||
})
|
||||
|
||||
packetHash := pkt.GetHash()
|
||||
signature := testIdent.Sign(packetHash)
|
||||
|
||||
proof := make([]byte, 0, EXPL_LENGTH)
|
||||
proof = append(proof, packetHash...)
|
||||
proof = append(proof, signature...)
|
||||
|
||||
proofPacket := &Packet{
|
||||
PacketType: PacketTypeProof,
|
||||
Data: proof,
|
||||
}
|
||||
|
||||
receipt.ValidateProof(proof, proofPacket)
|
||||
|
||||
select {
|
||||
case <-deliveryCalled:
|
||||
// Success
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Delivery callback was not called")
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package pathfinder
|
||||
|
||||
import "time"
|
||||
|
||||
134
pkg/pathfinder/pathfinder_test.go
Normal file
134
pkg/pathfinder/pathfinder_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package pathfinder
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewPathFinder(t *testing.T) {
|
||||
pf := NewPathFinder()
|
||||
if pf == nil {
|
||||
t.Fatal("NewPathFinder() returned nil")
|
||||
}
|
||||
if pf.paths == nil {
|
||||
t.Error("NewPathFinder() paths map is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathFinder_AddPath(t *testing.T) {
|
||||
pf := NewPathFinder()
|
||||
|
||||
destHash := "test-dest-hash"
|
||||
nextHop := []byte{0x01, 0x02, 0x03, 0x04}
|
||||
iface := "eth0"
|
||||
hops := byte(5)
|
||||
|
||||
pf.AddPath(destHash, nextHop, iface, hops)
|
||||
|
||||
path, exists := pf.GetPath(destHash)
|
||||
if !exists {
|
||||
t.Fatal("GetPath() returned false after AddPath()")
|
||||
}
|
||||
|
||||
if string(path.NextHop) != string(nextHop) {
|
||||
t.Errorf("NextHop = %v, want %v", path.NextHop, nextHop)
|
||||
}
|
||||
if path.Interface != iface {
|
||||
t.Errorf("Interface = %s, want %s", path.Interface, iface)
|
||||
}
|
||||
if path.HopCount != hops {
|
||||
t.Errorf("HopCount = %d, want %d", path.HopCount, hops)
|
||||
}
|
||||
if path.LastUpdated == 0 {
|
||||
t.Error("LastUpdated should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathFinder_GetPath(t *testing.T) {
|
||||
pf := NewPathFinder()
|
||||
|
||||
destHash := "test-dest-hash"
|
||||
_, exists := pf.GetPath(destHash)
|
||||
if exists {
|
||||
t.Error("GetPath() should return false for non-existent path")
|
||||
}
|
||||
|
||||
nextHop := []byte{0x01, 0x02}
|
||||
pf.AddPath(destHash, nextHop, "eth0", 3)
|
||||
|
||||
path, exists := pf.GetPath(destHash)
|
||||
if !exists {
|
||||
t.Fatal("GetPath() returned false for existing path")
|
||||
}
|
||||
if string(path.NextHop) != string(nextHop) {
|
||||
t.Errorf("NextHop = %v, want %v", path.NextHop, nextHop)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathFinder_UpdatePath(t *testing.T) {
|
||||
pf := NewPathFinder()
|
||||
|
||||
destHash := "test-dest-hash"
|
||||
nextHop1 := []byte{0x01, 0x02}
|
||||
nextHop2 := []byte{0x03, 0x04}
|
||||
|
||||
pf.AddPath(destHash, nextHop1, "eth0", 3)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
firstUpdate := time.Now().Unix()
|
||||
|
||||
pf.AddPath(destHash, nextHop2, "eth1", 5)
|
||||
|
||||
path, exists := pf.GetPath(destHash)
|
||||
if !exists {
|
||||
t.Fatal("GetPath() returned false")
|
||||
}
|
||||
|
||||
if string(path.NextHop) != string(nextHop2) {
|
||||
t.Errorf("NextHop = %v, want %v", path.NextHop, nextHop2)
|
||||
}
|
||||
if path.Interface != "eth1" {
|
||||
t.Errorf("Interface = %s, want eth1", path.Interface)
|
||||
}
|
||||
if path.HopCount != 5 {
|
||||
t.Errorf("HopCount = %d, want 5", path.HopCount)
|
||||
}
|
||||
if path.LastUpdated < firstUpdate {
|
||||
t.Error("LastUpdated should be updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathFinder_MultiplePaths(t *testing.T) {
|
||||
pf := NewPathFinder()
|
||||
|
||||
paths := []struct {
|
||||
hash string
|
||||
nextHop []byte
|
||||
iface string
|
||||
hops byte
|
||||
}{
|
||||
{"hash1", []byte{0x01}, "eth0", 1},
|
||||
{"hash2", []byte{0x02}, "eth1", 2},
|
||||
{"hash3", []byte{0x03}, "eth2", 3},
|
||||
}
|
||||
|
||||
for _, p := range paths {
|
||||
pf.AddPath(p.hash, p.nextHop, p.iface, p.hops)
|
||||
}
|
||||
|
||||
for _, p := range paths {
|
||||
path, exists := pf.GetPath(p.hash)
|
||||
if !exists {
|
||||
t.Errorf("GetPath() returned false for %s", p.hash)
|
||||
continue
|
||||
}
|
||||
if string(path.NextHop) != string(p.nextHop) {
|
||||
t.Errorf("NextHop for %s = %v, want %v", p.hash, path.NextHop, p.nextHop)
|
||||
}
|
||||
if path.Interface != p.iface {
|
||||
t.Errorf("Interface for %s = %s, want %s", p.hash, path.Interface, p.iface)
|
||||
}
|
||||
if path.HopCount != p.hops {
|
||||
t.Errorf("HopCount for %s = %d, want %d", p.hash, path.HopCount, p.hops)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,12 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package rate
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -15,22 +19,29 @@ const (
|
||||
DefaultBurstPenalty = 300 // Default seconds penalty after burst
|
||||
DefaultMaxHeldAnnounces = 256 // Default max announces in hold queue
|
||||
DefaultHeldReleaseInterval = 30 // Default seconds between releasing held announces
|
||||
|
||||
// Allowance thresholds
|
||||
AllowanceMinThreshold = 1.0
|
||||
AllowanceDecrement = 1.0
|
||||
|
||||
// History check threshold
|
||||
HistoryGraceThreshold = 1
|
||||
)
|
||||
|
||||
type Limiter struct {
|
||||
rate float64
|
||||
interval time.Duration
|
||||
capacity float64
|
||||
lastUpdate time.Time
|
||||
allowance float64
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewLimiter(rate float64, interval time.Duration) *Limiter {
|
||||
func NewLimiter(rate float64, capacity float64) *Limiter {
|
||||
return &Limiter{
|
||||
rate: rate,
|
||||
interval: interval,
|
||||
capacity: capacity,
|
||||
lastUpdate: time.Now(),
|
||||
allowance: rate,
|
||||
allowance: capacity,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,15 +54,15 @@ func (l *Limiter) Allow() bool {
|
||||
l.lastUpdate = now
|
||||
|
||||
l.allowance += elapsed.Seconds() * l.rate
|
||||
if l.allowance > l.rate {
|
||||
l.allowance = l.rate
|
||||
if l.allowance > l.capacity {
|
||||
l.allowance = l.capacity
|
||||
}
|
||||
|
||||
if l.allowance < 1.0 {
|
||||
if l.allowance < AllowanceMinThreshold {
|
||||
return false
|
||||
}
|
||||
|
||||
l.allowance -= 1.0
|
||||
l.allowance -= AllowanceDecrement
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -100,7 +111,7 @@ func (arc *AnnounceRateControl) AllowAnnounce(destHash string) bool {
|
||||
// Check rate
|
||||
lastAnnounce := history[len(history)-1]
|
||||
waitTime := arc.rateTarget
|
||||
if len(history) > arc.rateGrace {
|
||||
if len(history) > arc.rateGrace+HistoryGraceThreshold {
|
||||
waitTime += arc.ratePenalty
|
||||
}
|
||||
|
||||
@@ -155,7 +166,7 @@ func (ic *IngressControl) ProcessAnnounce(announceHash string, announceData []by
|
||||
|
||||
// Reset counter if enough time has passed
|
||||
if elapsed > ic.burstHold+ic.burstPenalty {
|
||||
ic.announceCount = 0
|
||||
ic.announceCount = common.ZERO
|
||||
ic.lastBurst = now
|
||||
}
|
||||
|
||||
@@ -166,7 +177,13 @@ func (ic *IngressControl) ProcessAnnounce(announceHash string, announceData []by
|
||||
}
|
||||
|
||||
ic.announceCount++
|
||||
burstFreq := float64(ic.announceCount) / elapsed.Seconds()
|
||||
|
||||
// Avoid division by zero and handle very small elapsed times
|
||||
seconds := elapsed.Seconds()
|
||||
if seconds < 0.01 {
|
||||
seconds = 0.01
|
||||
}
|
||||
burstFreq := float64(ic.announceCount) / seconds
|
||||
|
||||
// Hold announce if burst frequency exceeded
|
||||
if burstFreq > maxFreq {
|
||||
|
||||
150
pkg/rate/rate_test.go
Normal file
150
pkg/rate/rate_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package rate
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewLimiter(t *testing.T) {
|
||||
limiter := NewLimiter(10.0, 1.0)
|
||||
if limiter == nil {
|
||||
t.Fatal("NewLimiter() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiter_Allow(t *testing.T) {
|
||||
limiter := NewLimiter(10.0, 1.0)
|
||||
|
||||
if !limiter.Allow() {
|
||||
t.Error("Allow() should return true initially")
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
limiter.Allow()
|
||||
}
|
||||
|
||||
if limiter.Allow() {
|
||||
t.Error("Allow() should return false after exceeding rate")
|
||||
}
|
||||
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
|
||||
if !limiter.Allow() {
|
||||
t.Error("Allow() should return true after waiting")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAnnounceRateControl(t *testing.T) {
|
||||
arc := NewAnnounceRateControl(3600.0, 3, 7200.0)
|
||||
if arc == nil {
|
||||
t.Fatal("NewAnnounceRateControl() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnnounceRateControl_AllowAnnounce(t *testing.T) {
|
||||
arc := NewAnnounceRateControl(1.0, 2, 2.0)
|
||||
|
||||
hash := "test-dest-hash"
|
||||
|
||||
if !arc.AllowAnnounce(hash) {
|
||||
t.Error("AllowAnnounce() should return true for first announce")
|
||||
}
|
||||
|
||||
if !arc.AllowAnnounce(hash) {
|
||||
t.Error("AllowAnnounce() should return true for second announce (within grace)")
|
||||
}
|
||||
|
||||
if arc.AllowAnnounce(hash) {
|
||||
t.Error("AllowAnnounce() should return false for third announce (exceeds grace)")
|
||||
}
|
||||
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
|
||||
if !arc.AllowAnnounce(hash) {
|
||||
t.Error("AllowAnnounce() should return true after waiting")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnnounceRateControl_AllowAnnounce_DifferentHashes(t *testing.T) {
|
||||
arc := NewAnnounceRateControl(1.0, 1, 1.0)
|
||||
|
||||
hash1 := "hash1"
|
||||
hash2 := "hash2"
|
||||
|
||||
if !arc.AllowAnnounce(hash1) {
|
||||
t.Error("AllowAnnounce() should return true for hash1")
|
||||
}
|
||||
|
||||
if !arc.AllowAnnounce(hash2) {
|
||||
t.Error("AllowAnnounce() should return true for hash2 (different hash)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewIngressControl(t *testing.T) {
|
||||
ic := NewIngressControl(true)
|
||||
if ic == nil {
|
||||
t.Fatal("NewIngressControl() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIngressControl_ProcessAnnounce(t *testing.T) {
|
||||
ic := NewIngressControl(true)
|
||||
|
||||
hash := "test-hash"
|
||||
data := []byte("announce data")
|
||||
|
||||
ic.mutex.Lock()
|
||||
ic.lastBurst = time.Now().Add(-time.Second)
|
||||
ic.mutex.Unlock()
|
||||
|
||||
if !ic.ProcessAnnounce(hash, data, false) {
|
||||
t.Error("ProcessAnnounce() should return true for first announce")
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
for i := 0; i < 200; i++ {
|
||||
ic.ProcessAnnounce(hash, data, false)
|
||||
}
|
||||
|
||||
result := ic.ProcessAnnounce(hash, data, false)
|
||||
if result {
|
||||
t.Error("ProcessAnnounce() should return false when burst frequency exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIngressControl_ProcessAnnounce_Disabled(t *testing.T) {
|
||||
ic := NewIngressControl(false)
|
||||
|
||||
hash := "test-hash"
|
||||
data := []byte("announce data")
|
||||
|
||||
if !ic.ProcessAnnounce(hash, data, false) {
|
||||
t.Error("ProcessAnnounce() should return true when disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIngressControl_ReleaseHeldAnnounce(t *testing.T) {
|
||||
ic := NewIngressControl(true)
|
||||
|
||||
hash, data, found := ic.ReleaseHeldAnnounce()
|
||||
if found {
|
||||
t.Error("ReleaseHeldAnnounce() should return false when no announces held")
|
||||
}
|
||||
|
||||
ic.ProcessAnnounce("hash1", []byte("data1"), false)
|
||||
for i := 0; i < 200; i++ {
|
||||
ic.ProcessAnnounce("hash1", []byte("data1"), false)
|
||||
}
|
||||
|
||||
hash, data, found = ic.ReleaseHeldAnnounce()
|
||||
if !found {
|
||||
t.Error("ReleaseHeldAnnounce() should return true when announces are held")
|
||||
}
|
||||
if hash == "" {
|
||||
t.Error("ReleaseHeldAnnounce() should return non-empty hash")
|
||||
}
|
||||
if len(data) == 0 {
|
||||
t.Error("ReleaseHeldAnnounce() should return non-empty data")
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package resolver
|
||||
|
||||
import (
|
||||
@@ -7,7 +9,18 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/identity"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/identity"
|
||||
)
|
||||
|
||||
const (
|
||||
// Hash length conversion (bits to bytes)
|
||||
BitsPerByte = 8
|
||||
|
||||
// Known destination data index
|
||||
KnownDestIdentityIndex = 2
|
||||
|
||||
// Minimum name parts for hierarchical resolution
|
||||
MinNameParts = 2
|
||||
)
|
||||
|
||||
type Resolver struct {
|
||||
@@ -36,12 +49,12 @@ func (r *Resolver) ResolveIdentity(fullName string) (*identity.Identity, error)
|
||||
// Hash the full name to create a deterministic identity
|
||||
h := sha256.New()
|
||||
h.Write([]byte(fullName))
|
||||
nameHash := h.Sum(nil)[:identity.NAME_HASH_LENGTH/8]
|
||||
nameHash := h.Sum(nil)[:identity.NAME_HASH_LENGTH/BitsPerByte]
|
||||
hashStr := hex.EncodeToString(nameHash)
|
||||
|
||||
// Check if this identity is known
|
||||
if knownData, exists := identity.GetKnownDestination(hashStr); exists {
|
||||
if id, ok := knownData[2].(*identity.Identity); ok {
|
||||
if id, ok := knownData[KnownDestIdentityIndex].(*identity.Identity); ok {
|
||||
r.cacheLock.Lock()
|
||||
r.cache[fullName] = id
|
||||
r.cacheLock.Unlock()
|
||||
@@ -51,7 +64,7 @@ func (r *Resolver) ResolveIdentity(fullName string) (*identity.Identity, error)
|
||||
|
||||
// Split name into parts for hierarchical resolution
|
||||
parts := strings.Split(fullName, ".")
|
||||
if len(parts) < 2 {
|
||||
if len(parts) < MinNameParts {
|
||||
return nil, errors.New("invalid identity name format")
|
||||
}
|
||||
|
||||
|
||||
118
pkg/resolver/resolver_test.go
Normal file
118
pkg/resolver/resolver_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
r := New()
|
||||
if r == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
if r.cache == nil {
|
||||
t.Error("New() cache map is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolver_ResolveIdentity(t *testing.T) {
|
||||
r := New()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fullName string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "ValidName",
|
||||
fullName: "app.aspect",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "EmptyName",
|
||||
fullName: "",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidFormat",
|
||||
fullName: "app",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "MultiPartName",
|
||||
fullName: "app.aspect1.aspect2",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
id, err := r.ResolveIdentity(tt.fullName)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ResolveIdentity() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError && id == nil {
|
||||
t.Error("ResolveIdentity() returned nil identity for valid name")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolver_ResolveIdentity_Caching(t *testing.T) {
|
||||
r := New()
|
||||
|
||||
fullName := "app.aspect"
|
||||
id1, err := r.ResolveIdentity(fullName)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveIdentity() error = %v", err)
|
||||
}
|
||||
|
||||
id2, err := r.ResolveIdentity(fullName)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveIdentity() error = %v", err)
|
||||
}
|
||||
|
||||
if id1 == nil || id2 == nil {
|
||||
t.Fatal("ResolveIdentity() returned nil")
|
||||
}
|
||||
|
||||
if id1.GetPublicKey() == nil || id2.GetPublicKey() == nil {
|
||||
t.Fatal("Identity public key is nil")
|
||||
}
|
||||
|
||||
if string(id1.GetPublicKey()) != string(id2.GetPublicKey()) {
|
||||
t.Error("ResolveIdentity() should return cached identity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveIdentity(t *testing.T) {
|
||||
id, err := ResolveIdentity("app.aspect")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveIdentity() error = %v", err)
|
||||
}
|
||||
if id == nil {
|
||||
t.Error("ResolveIdentity() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolver_ResolveIdentity_Concurrent(t *testing.T) {
|
||||
r := New()
|
||||
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
id, err := r.ResolveIdentity("app.aspect")
|
||||
if err != nil {
|
||||
t.Errorf("ResolveIdentity() error = %v", err)
|
||||
}
|
||||
if id == nil {
|
||||
t.Error("ResolveIdentity() returned nil")
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
255
pkg/resource/advertisement.go
Normal file
255
pkg/resource/advertisement.go
Normal file
@@ -0,0 +1,255 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package resource
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
OVERHEAD = 134
|
||||
COLLISION_GUARD_SIZE = 2*WINDOW_MAX + 100
|
||||
)
|
||||
|
||||
type ResourceAdvertisement struct {
|
||||
TransferSize int64
|
||||
DataSize int64
|
||||
Parts int
|
||||
Hash []byte
|
||||
RandomHash []byte
|
||||
OriginalHash []byte
|
||||
Hashmap []byte
|
||||
Compressed bool
|
||||
Encrypted bool
|
||||
Split bool
|
||||
HasMetadata bool
|
||||
SegmentIndex uint16
|
||||
TotalSegments uint16
|
||||
RequestID []byte
|
||||
IsRequest bool
|
||||
IsResponse bool
|
||||
Flags byte
|
||||
}
|
||||
|
||||
func NewResourceAdvertisement(res *Resource) *ResourceAdvertisement {
|
||||
if res == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
flags := byte(0x00)
|
||||
if res.HasMetadata() {
|
||||
flags |= 0x20
|
||||
}
|
||||
if res.IsResponse() {
|
||||
flags |= 0x10
|
||||
}
|
||||
if res.IsRequest() {
|
||||
flags |= 0x08
|
||||
}
|
||||
|
||||
res.mutex.RLock()
|
||||
split := res.split
|
||||
compressed := res.compressed
|
||||
encrypted := res.encrypted
|
||||
randomHash := res.randomHash
|
||||
originalHash := res.originalHash
|
||||
segmentIndex := res.segmentIndex
|
||||
totalSegments := res.totalSegments
|
||||
res.mutex.RUnlock()
|
||||
|
||||
if split {
|
||||
flags |= 0x04
|
||||
}
|
||||
if compressed {
|
||||
flags |= 0x02
|
||||
}
|
||||
if encrypted {
|
||||
flags |= 0x01
|
||||
}
|
||||
|
||||
hashmap := res.getHashmap()
|
||||
|
||||
return &ResourceAdvertisement{
|
||||
TransferSize: res.GetTransferSize(),
|
||||
DataSize: res.GetDataSize(),
|
||||
Parts: int(res.GetSegments()),
|
||||
Hash: res.GetHash(),
|
||||
RandomHash: randomHash,
|
||||
OriginalHash: originalHash,
|
||||
Hashmap: hashmap,
|
||||
Compressed: compressed,
|
||||
Encrypted: encrypted,
|
||||
Split: split,
|
||||
HasMetadata: res.HasMetadata(),
|
||||
SegmentIndex: segmentIndex,
|
||||
TotalSegments: totalSegments,
|
||||
RequestID: res.GetRequestID(),
|
||||
IsRequest: res.IsRequest(),
|
||||
IsResponse: res.IsResponse(),
|
||||
Flags: flags,
|
||||
}
|
||||
}
|
||||
|
||||
func (ra *ResourceAdvertisement) Pack(segment int) ([]byte, error) {
|
||||
hashmapMaxLen := getHashmapMaxLen()
|
||||
hashmapStart := segment * hashmapMaxLen
|
||||
hashmapEnd := hashmapStart + hashmapMaxLen
|
||||
if hashmapEnd > len(ra.Hashmap)/MAPHASH_LEN {
|
||||
hashmapEnd = len(ra.Hashmap) / MAPHASH_LEN
|
||||
}
|
||||
|
||||
hashmap := ra.Hashmap[hashmapStart*MAPHASH_LEN : hashmapEnd*MAPHASH_LEN]
|
||||
|
||||
dict := map[string]interface{}{
|
||||
"t": ra.TransferSize,
|
||||
"d": ra.DataSize,
|
||||
"n": ra.Parts,
|
||||
"h": ra.Hash,
|
||||
"r": ra.RandomHash,
|
||||
"o": ra.OriginalHash,
|
||||
"i": ra.SegmentIndex,
|
||||
"l": ra.TotalSegments,
|
||||
"q": ra.RequestID,
|
||||
"f": ra.Flags,
|
||||
"m": hashmap,
|
||||
}
|
||||
|
||||
return msgpack.Marshal(dict)
|
||||
}
|
||||
|
||||
func UnpackResourceAdvertisement(data []byte) (*ResourceAdvertisement, error) {
|
||||
var dict map[string]interface{}
|
||||
if err := msgpack.Unmarshal(data, &dict); err != nil {
|
||||
return nil, fmt.Errorf("failed to unpack advertisement: %w", err)
|
||||
}
|
||||
|
||||
ra := &ResourceAdvertisement{}
|
||||
|
||||
if t, ok := dict["t"].(int64); ok {
|
||||
ra.TransferSize = t
|
||||
} else if t, ok := dict["t"].(uint64); ok {
|
||||
if t > uint64(math.MaxInt64) {
|
||||
return nil, fmt.Errorf("transfer size overflow")
|
||||
}
|
||||
ra.TransferSize = int64(t) // #nosec G115 - checked for overflow
|
||||
}
|
||||
|
||||
if d, ok := dict["d"].(int64); ok {
|
||||
ra.DataSize = d
|
||||
} else if d, ok := dict["d"].(uint64); ok {
|
||||
if d > uint64(math.MaxInt64) {
|
||||
return nil, fmt.Errorf("data size overflow")
|
||||
}
|
||||
ra.DataSize = int64(d) // #nosec G115 - checked for overflow
|
||||
}
|
||||
|
||||
if n, ok := dict["n"].(int); ok {
|
||||
ra.Parts = n
|
||||
} else if n, ok := dict["n"].(uint64); ok {
|
||||
if n > uint64(math.MaxInt32) {
|
||||
return nil, fmt.Errorf("parts count overflow")
|
||||
}
|
||||
ra.Parts = int(n) // #nosec G115 - checked for overflow
|
||||
}
|
||||
|
||||
if h, ok := dict["h"].([]byte); ok {
|
||||
ra.Hash = h
|
||||
}
|
||||
|
||||
if r, ok := dict["r"].([]byte); ok {
|
||||
ra.RandomHash = r
|
||||
}
|
||||
|
||||
if o, ok := dict["o"].([]byte); ok {
|
||||
ra.OriginalHash = o
|
||||
}
|
||||
|
||||
if m, ok := dict["m"].([]byte); ok {
|
||||
ra.Hashmap = m
|
||||
}
|
||||
|
||||
if f, ok := dict["f"].(byte); ok {
|
||||
ra.Flags = f
|
||||
} else if f, ok := dict["f"].(uint64); ok {
|
||||
ra.Flags = byte(f)
|
||||
}
|
||||
|
||||
ra.Encrypted = (ra.Flags & 0x01) == 0x01
|
||||
ra.Compressed = ((ra.Flags >> 1) & 0x01) == 0x01
|
||||
ra.Split = ((ra.Flags >> 2) & 0x01) == 0x01
|
||||
ra.IsRequest = ((ra.Flags >> 3) & 0x01) == 0x01
|
||||
ra.IsResponse = ((ra.Flags >> 4) & 0x01) == 0x01
|
||||
ra.HasMetadata = ((ra.Flags >> 5) & 0x01) == 0x01
|
||||
|
||||
if i, ok := dict["i"].(uint16); ok {
|
||||
ra.SegmentIndex = i
|
||||
} else if i, ok := dict["i"].(uint64); ok {
|
||||
if i > uint64(math.MaxUint16) {
|
||||
return nil, fmt.Errorf("segment index overflow")
|
||||
}
|
||||
ra.SegmentIndex = uint16(i) // #nosec G115 - checked for overflow
|
||||
}
|
||||
|
||||
if l, ok := dict["l"].(uint16); ok {
|
||||
ra.TotalSegments = l
|
||||
} else if l, ok := dict["l"].(uint64); ok {
|
||||
if l > uint64(math.MaxUint16) {
|
||||
return nil, fmt.Errorf("total segments overflow")
|
||||
}
|
||||
ra.TotalSegments = uint16(l) // #nosec G115 - checked for overflow
|
||||
}
|
||||
|
||||
if q, ok := dict["q"].([]byte); ok {
|
||||
ra.RequestID = q
|
||||
}
|
||||
|
||||
return ra, nil
|
||||
}
|
||||
|
||||
func getHashmapMaxLen() int {
|
||||
mdu := 384
|
||||
return (mdu - OVERHEAD) / MAPHASH_LEN
|
||||
}
|
||||
|
||||
func IsRequestAdvertisement(data []byte) bool {
|
||||
adv, err := UnpackResourceAdvertisement(data)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return adv.IsRequest && adv.RequestID != nil
|
||||
}
|
||||
|
||||
func IsResponseAdvertisement(data []byte) bool {
|
||||
adv, err := UnpackResourceAdvertisement(data)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return adv.IsResponse && adv.RequestID != nil
|
||||
}
|
||||
|
||||
func ReadRequestID(data []byte) []byte {
|
||||
adv, err := UnpackResourceAdvertisement(data)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return adv.RequestID
|
||||
}
|
||||
|
||||
func ReadTransferSize(data []byte) int64 {
|
||||
adv, err := UnpackResourceAdvertisement(data)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return adv.TransferSize
|
||||
}
|
||||
|
||||
func ReadSize(data []byte) int64 {
|
||||
adv, err := UnpackResourceAdvertisement(data)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return adv.DataSize
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package resource
|
||||
|
||||
import (
|
||||
@@ -58,6 +60,7 @@ const (
|
||||
PROCESSING_GRACE = 1.0
|
||||
RETRY_GRACE_TIME = 0.25
|
||||
PER_RETRY_DELAY = 0.5
|
||||
RESPONSE_MAX_GRACE_TIME = 10.0
|
||||
)
|
||||
|
||||
type Resource struct {
|
||||
@@ -92,6 +95,10 @@ type Resource struct {
|
||||
callback func(*Resource)
|
||||
progressCallback func(*Resource)
|
||||
readOffset int64
|
||||
requestID []byte
|
||||
isResponse bool
|
||||
hashmap []byte
|
||||
parts [][]byte
|
||||
}
|
||||
|
||||
func New(data interface{}, autoCompress bool) (*Resource, error) {
|
||||
@@ -219,12 +226,6 @@ func (r *Resource) GetSegments() uint16 {
|
||||
return r.segments
|
||||
}
|
||||
|
||||
func (r *Resource) IsCompressed() bool {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
return r.compressed
|
||||
}
|
||||
|
||||
func (r *Resource) Cancel() {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
@@ -421,3 +422,97 @@ func (r *Resource) GetSize() int64 {
|
||||
defer r.mutex.RUnlock()
|
||||
return r.dataSize
|
||||
}
|
||||
|
||||
func (r *Resource) HasMetadata() bool {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *Resource) IsRequest() bool {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
return r.requestID != nil && !r.isResponse
|
||||
}
|
||||
|
||||
func (r *Resource) IsResponse() bool {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
return r.isResponse
|
||||
}
|
||||
|
||||
func (r *Resource) GetRequestID() []byte {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
if r.requestID == nil {
|
||||
return nil
|
||||
}
|
||||
return append([]byte{}, r.requestID...)
|
||||
}
|
||||
|
||||
func (r *Resource) SetRequestID(id []byte) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
if id == nil {
|
||||
r.requestID = nil
|
||||
return
|
||||
}
|
||||
r.requestID = append([]byte{}, id...)
|
||||
}
|
||||
|
||||
func (r *Resource) SetIsResponse(isResponse bool) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.isResponse = isResponse
|
||||
}
|
||||
|
||||
func (r *Resource) getHashmap() []byte {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
if r.hashmap == nil {
|
||||
return nil
|
||||
}
|
||||
return append([]byte{}, r.hashmap...)
|
||||
}
|
||||
|
||||
func (r *Resource) GetRandomHash() []byte {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
if r.randomHash == nil {
|
||||
return nil
|
||||
}
|
||||
return append([]byte{}, r.randomHash...)
|
||||
}
|
||||
|
||||
func (r *Resource) GetOriginalHash() []byte {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
if r.originalHash == nil {
|
||||
return nil
|
||||
}
|
||||
return append([]byte{}, r.originalHash...)
|
||||
}
|
||||
|
||||
func (r *Resource) GetSegmentIndex() uint16 {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
return r.segmentIndex
|
||||
}
|
||||
|
||||
func (r *Resource) GetTotalSegments() uint16 {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
return r.totalSegments
|
||||
}
|
||||
|
||||
func (r *Resource) IsEncrypted() bool {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
return r.encrypted
|
||||
}
|
||||
|
||||
func (r *Resource) IsSplit() bool {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
return r.split
|
||||
}
|
||||
|
||||
153
pkg/resource/resource_test.go
Normal file
153
pkg/resource/resource_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package resource
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewResourceFromBytes(t *testing.T) {
|
||||
data := []byte("hello world")
|
||||
r, err := New(data, false)
|
||||
if err != nil {
|
||||
t.Fatalf("New failed: %v", err)
|
||||
}
|
||||
if r.GetDataSize() != int64(len(data)) {
|
||||
t.Errorf("Expected size %d, got %d", len(data), r.GetDataSize())
|
||||
}
|
||||
if r.GetSegments() != 1 {
|
||||
t.Errorf("Expected 1 segment, got %d", r.GetSegments())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResourceFromFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
data := []byte("file data")
|
||||
err := os.WriteFile(tmpFile, data, 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(tmpFile, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
r, err := New(f, false)
|
||||
if err != nil {
|
||||
t.Fatalf("New failed: %v", err)
|
||||
}
|
||||
if r.GetDataSize() != int64(len(data)) {
|
||||
t.Errorf("Expected size %d, got %d", len(data), r.GetDataSize())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSegmentData(t *testing.T) {
|
||||
data := make([]byte, DEFAULT_SEGMENT_SIZE+100)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
r, _ := New(data, false)
|
||||
if r.GetSegments() != 2 {
|
||||
t.Fatalf("Expected 2 segments, got %d", r.GetSegments())
|
||||
}
|
||||
|
||||
seg0, err := r.GetSegmentData(0)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSegmentData(0) failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(seg0, data[:DEFAULT_SEGMENT_SIZE]) {
|
||||
t.Error("Segment 0 data mismatch")
|
||||
}
|
||||
|
||||
seg1, err := r.GetSegmentData(1)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSegmentData(1) failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(seg1, data[DEFAULT_SEGMENT_SIZE:]) {
|
||||
t.Error("Segment 1 data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkSegmentComplete(t *testing.T) {
|
||||
data := make([]byte, DEFAULT_SEGMENT_SIZE*2)
|
||||
r, _ := New(data, false)
|
||||
|
||||
callbackCalled := false
|
||||
r.SetCallback(func(res *Resource) {
|
||||
callbackCalled = true
|
||||
})
|
||||
|
||||
r.MarkSegmentComplete(0)
|
||||
if r.GetProgress() != 0.5 {
|
||||
t.Errorf("Expected progress 0.5, got %f", r.GetProgress())
|
||||
}
|
||||
if r.GetStatus() != STATUS_PENDING && r.GetStatus() != STATUS_ACTIVE {
|
||||
t.Errorf("Wrong status: %v", r.GetStatus())
|
||||
}
|
||||
|
||||
r.MarkSegmentComplete(1)
|
||||
if r.GetProgress() != 1.0 {
|
||||
t.Errorf("Expected progress 1.0, got %f", r.GetProgress())
|
||||
}
|
||||
if r.GetStatus() != STATUS_COMPLETE {
|
||||
t.Errorf("Expected status COMPLETE, got %v", r.GetStatus())
|
||||
}
|
||||
if !callbackCalled {
|
||||
t.Error("Callback was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
data := []byte("hello world")
|
||||
r, _ := New(data, false)
|
||||
|
||||
buf := make([]byte, 5)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
if n != 5 || !bytes.Equal(buf, []byte("hello")) {
|
||||
t.Errorf("Read wrong data: %q", buf)
|
||||
}
|
||||
|
||||
buf = make([]byte, 10)
|
||||
n, err = r.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
if n != 6 || !bytes.Equal(buf[:n], []byte(" world")) {
|
||||
t.Errorf("Read wrong data: %q", buf[:n])
|
||||
}
|
||||
|
||||
n, err = r.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Errorf("Expected EOF, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelActivateFailed(t *testing.T) {
|
||||
data := []byte("test")
|
||||
r, _ := New(data, false)
|
||||
|
||||
r.Activate()
|
||||
if r.GetStatus() != STATUS_ACTIVE {
|
||||
t.Errorf("Expected ACTIVE, got %v", r.GetStatus())
|
||||
}
|
||||
|
||||
r.SetFailed()
|
||||
if r.GetStatus() != STATUS_FAILED {
|
||||
t.Errorf("Expected FAILED, got %v", r.GetStatus())
|
||||
}
|
||||
|
||||
r2, _ := New(data, false)
|
||||
r2.Cancel()
|
||||
if r2.GetStatus() != STATUS_CANCELLED {
|
||||
t.Errorf("Expected CANCELLED, got %v", r2.GetStatus())
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
package transport
|
||||
|
||||
import (
|
||||
@@ -7,7 +9,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/rate"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/rate"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -39,7 +41,7 @@ func NewAnnounceManager() *AnnounceManager {
|
||||
return &AnnounceManager{
|
||||
announces: make(map[string]*AnnounceEntry),
|
||||
announceQueue: make(map[string][]*AnnounceEntry),
|
||||
rateLimiter: rate.NewLimiter(DefaultPropagationRate, 1),
|
||||
rateLimiter: rate.NewLimiter(rate.DefaultBurstFreq, 10.0),
|
||||
mutex: sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,88 +1,120 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/Sudo-Ivan/reticulum-go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
)
|
||||
|
||||
func randomBytes(n int) []byte {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
type mockInterface struct {
|
||||
common.BaseInterface
|
||||
sent [][]byte
|
||||
}
|
||||
|
||||
func (m *mockInterface) Send(data []byte, address string) error {
|
||||
m.sent = append(m.sent, data)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInterface) GetName() string {
|
||||
return m.Name
|
||||
}
|
||||
|
||||
func (m *mockInterface) IsEnabled() bool {
|
||||
return m.Enabled
|
||||
}
|
||||
|
||||
func TestNewTransport(t *testing.T) {
|
||||
config := &common.ReticulumConfig{}
|
||||
tr := NewTransport(config)
|
||||
if tr == nil {
|
||||
t.Fatal("NewTransport returned nil")
|
||||
}
|
||||
defer tr.Close()
|
||||
}
|
||||
|
||||
func TestRegisterInterface(t *testing.T) {
|
||||
tr := NewTransport(&common.ReticulumConfig{})
|
||||
defer tr.Close()
|
||||
|
||||
iface := &mockInterface{}
|
||||
iface.Name = "test"
|
||||
err := tr.RegisterInterface("test", iface)
|
||||
if err != nil {
|
||||
panic("Failed to generate random bytes: " + err.Error())
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// BenchmarkTransportDestinationCreation benchmarks destination creation
|
||||
func BenchmarkTransportDestinationCreation(b *testing.B) {
|
||||
// Create a basic config for transport
|
||||
config := &common.ReticulumConfig{
|
||||
ConfigPath: "/tmp/test_config",
|
||||
t.Fatalf("RegisterInterface failed: %v", err)
|
||||
}
|
||||
|
||||
transport := NewTransport(config)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create destination (this allocates and initializes destination objects)
|
||||
dest := transport.NewDestination(nil, OUT, SINGLE, "test_app")
|
||||
_ = dest // Use the destination to avoid optimization
|
||||
retrieved, err := tr.GetInterface("test")
|
||||
if err != nil {
|
||||
t.Fatalf("GetInterface failed: %v", err)
|
||||
}
|
||||
if retrieved != iface {
|
||||
t.Error("Retrieved interface doesn't match")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkTransportPathLookup benchmarks path lookup operations
|
||||
func BenchmarkTransportPathLookup(b *testing.B) {
|
||||
// Create a basic config for transport
|
||||
config := &common.ReticulumConfig{
|
||||
ConfigPath: "/tmp/test_config",
|
||||
func TestPathManagement(t *testing.T) {
|
||||
tr := NewTransport(&common.ReticulumConfig{})
|
||||
defer tr.Close()
|
||||
|
||||
destHash := []byte("test-destination-hash")
|
||||
nextHop := []byte("next-hop")
|
||||
iface := &mockInterface{}
|
||||
iface.Name = "iface1"
|
||||
_ = tr.RegisterInterface("iface1", iface)
|
||||
|
||||
tr.UpdatePath(destHash, nextHop, "iface1", 2)
|
||||
|
||||
if !tr.HasPath(destHash) {
|
||||
t.Error("Path not found after update")
|
||||
}
|
||||
|
||||
transport := NewTransport(config)
|
||||
if tr.HopsTo(destHash) != 2 {
|
||||
t.Errorf("Expected 2 hops, got %d", tr.HopsTo(destHash))
|
||||
}
|
||||
|
||||
// Pre-populate with some destinations
|
||||
destHash1 := randomBytes(16)
|
||||
destHash2 := randomBytes(16)
|
||||
destHash3 := randomBytes(16)
|
||||
if !bytes.Equal(tr.NextHop(destHash), nextHop) {
|
||||
t.Error("Next hop mismatch")
|
||||
}
|
||||
|
||||
// Create some destinations
|
||||
transport.NewDestination(nil, OUT, SINGLE, "test_app")
|
||||
transport.NewDestination(nil, OUT, SINGLE, "test_app")
|
||||
transport.NewDestination(nil, OUT, SINGLE, "test_app")
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Test path lookup operations (these involve map lookups and allocations)
|
||||
_ = transport.HasPath(destHash1)
|
||||
_ = transport.HasPath(destHash2)
|
||||
_ = transport.HasPath(destHash3)
|
||||
if tr.NextHopInterface(destHash) != "iface1" {
|
||||
t.Errorf("Expected iface1, got %s", tr.NextHopInterface(destHash))
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkTransportHopsCalculation benchmarks hops calculation
|
||||
func BenchmarkTransportHopsCalculation(b *testing.B) {
|
||||
// Create a basic config for transport
|
||||
config := &common.ReticulumConfig{
|
||||
ConfigPath: "/tmp/test_config",
|
||||
}
|
||||
func TestDestinationRegistration(t *testing.T) {
|
||||
tr := NewTransport(&common.ReticulumConfig{})
|
||||
defer tr.Close()
|
||||
|
||||
transport := NewTransport(config)
|
||||
destHash := []byte("dest")
|
||||
tr.RegisterDestination(destHash, "test-dest")
|
||||
|
||||
// Create some destinations
|
||||
destHash := randomBytes(16)
|
||||
transport.NewDestination(nil, OUT, SINGLE, "test_app")
|
||||
tr.mutex.RLock()
|
||||
dest, ok := tr.destinations[string(destHash)]
|
||||
tr.mutex.RUnlock()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Test hops calculation (involves internal data structure access)
|
||||
_ = transport.HopsTo(destHash)
|
||||
if !ok || dest != "test-dest" {
|
||||
t.Error("Destination not registered correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportStatus(t *testing.T) {
|
||||
tr := NewTransport(&common.ReticulumConfig{})
|
||||
defer tr.Close()
|
||||
|
||||
destHash := []byte("dest")
|
||||
if tr.PathIsUnresponsive(destHash) {
|
||||
t.Error("Path should not be unresponsive initially")
|
||||
}
|
||||
|
||||
tr.MarkPathUnresponsive(destHash)
|
||||
if !tr.PathIsUnresponsive(destHash) {
|
||||
t.Error("Path should be unresponsive")
|
||||
}
|
||||
|
||||
tr.MarkPathResponsive(destHash)
|
||||
if tr.PathIsUnresponsive(destHash) {
|
||||
t.Error("Path should be responsive again")
|
||||
}
|
||||
}
|
||||
|
||||
476
pkg/wasm/wasm.go
Normal file
476
pkg/wasm/wasm.go
Normal file
@@ -0,0 +1,476 @@
|
||||
// SPDX-License-Identifier: 0BSD
|
||||
// Copyright (c) 2024-2026 Sudo-Ivan / Quad4.io
|
||||
//go:build js && wasm
|
||||
// +build js,wasm
|
||||
|
||||
package wasm
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"syscall/js"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/common"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/debug"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/destination"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/identity"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/interfaces"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/packet"
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/transport"
|
||||
)
|
||||
|
||||
var (
|
||||
reticulumTransport *transport.Transport
|
||||
reticulumDest *destination.Destination
|
||||
reticulumIdentity *identity.Identity
|
||||
stats = struct {
|
||||
packetsSent int
|
||||
packetsReceived int
|
||||
bytesSent int
|
||||
bytesReceived int
|
||||
}{}
|
||||
packetCallback js.Value
|
||||
announceHandler js.Value
|
||||
)
|
||||
|
||||
// RegisterJSFunctions registers the Reticulum WASM API to the JavaScript global scope.
|
||||
func RegisterJSFunctions() {
|
||||
js.Global().Set("reticulum", js.ValueOf(map[string]interface{}{
|
||||
"init": js.FuncOf(InitReticulum),
|
||||
"getIdentity": js.FuncOf(GetIdentity),
|
||||
"getDestination": js.FuncOf(GetDestination),
|
||||
"connect": js.FuncOf(ConnectWebSocket),
|
||||
"disconnect": js.FuncOf(DisconnectWebSocket),
|
||||
"isConnected": js.FuncOf(IsConnected),
|
||||
"requestPath": js.FuncOf(RequestPath),
|
||||
"getStats": js.FuncOf(GetStats),
|
||||
"setPacketCallback": js.FuncOf(SetPacketCallback),
|
||||
"setAnnounceCallback": js.FuncOf(SetAnnounceCallback),
|
||||
"sendData": js.FuncOf(SendDataJS),
|
||||
"announce": js.FuncOf(SendAnnounceJS),
|
||||
}))
|
||||
}
|
||||
|
||||
func SetPacketCallback(this js.Value, args []js.Value) interface{} {
|
||||
if len(args) > 0 && args[0].Type() == js.TypeFunction {
|
||||
packetCallback = args[0]
|
||||
return js.ValueOf(true)
|
||||
}
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
|
||||
func SetAnnounceCallback(this js.Value, args []js.Value) interface{} {
|
||||
if len(args) > 0 && args[0].Type() == js.TypeFunction {
|
||||
announceHandler = args[0]
|
||||
return js.ValueOf(true)
|
||||
}
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
|
||||
func RequestPath(this js.Value, args []js.Value) interface{} {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": "Destination hash required",
|
||||
})
|
||||
}
|
||||
|
||||
destHashHex := args[0].String()
|
||||
destHash, err := hex.DecodeString(destHashHex)
|
||||
if err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Invalid destination hash: %v", err),
|
||||
})
|
||||
}
|
||||
|
||||
if reticulumTransport == nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": "Reticulum not initialized",
|
||||
})
|
||||
}
|
||||
|
||||
if err := reticulumTransport.RequestPath(destHash, "", nil, true); err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Failed to request path: %v", err),
|
||||
})
|
||||
}
|
||||
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
func GetStats(this js.Value, args []js.Value) interface{} {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"packetsSent": stats.packetsSent,
|
||||
"packetsReceived": stats.packetsReceived,
|
||||
"bytesSent": stats.bytesSent,
|
||||
"bytesReceived": stats.bytesReceived,
|
||||
})
|
||||
}
|
||||
|
||||
func InitReticulum(this js.Value, args []js.Value) interface{} {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": "WebSocket URL required",
|
||||
})
|
||||
}
|
||||
|
||||
if reticulumTransport != nil {
|
||||
reticulumTransport.Close()
|
||||
reticulumTransport = nil
|
||||
}
|
||||
|
||||
wsURL := args[0].String()
|
||||
appName := "wasm_core"
|
||||
if len(args) >= 2 && args[1].Type() == js.TypeString {
|
||||
appName = args[1].String()
|
||||
}
|
||||
|
||||
var id *identity.Identity
|
||||
var err error
|
||||
|
||||
// Check for existing identity in args
|
||||
if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() {
|
||||
idHex := args[2].String()
|
||||
idBytes, decodeErr := hex.DecodeString(idHex)
|
||||
if decodeErr == nil && len(idBytes) == 64 {
|
||||
id, err = identity.FromBytes(idBytes)
|
||||
if err != nil {
|
||||
debug.Log(debug.DEBUG_ERROR, "Failed to load provided identity, generating new one", "error", err)
|
||||
id, err = identity.NewIdentity()
|
||||
}
|
||||
} else {
|
||||
id, err = identity.NewIdentity()
|
||||
}
|
||||
} else {
|
||||
id, err = identity.NewIdentity()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Failed to handle identity: %v", err),
|
||||
})
|
||||
}
|
||||
|
||||
cfg := common.DefaultConfig()
|
||||
t := transport.NewTransport(cfg)
|
||||
|
||||
dest, err := destination.New(
|
||||
id,
|
||||
destination.IN,
|
||||
destination.SINGLE,
|
||||
appName,
|
||||
t,
|
||||
"browser",
|
||||
)
|
||||
if err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Failed to create destination: %v", err),
|
||||
})
|
||||
}
|
||||
|
||||
dest.SetPacketCallback(func(data []byte, ni common.NetworkInterface) {
|
||||
stats.packetsReceived++
|
||||
stats.bytesReceived += len(data)
|
||||
|
||||
if !packetCallback.IsUndefined() {
|
||||
// Convert bytes to JS Uint8Array for performance and compatibility
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
packetCallback.Invoke(uint8Array)
|
||||
}
|
||||
})
|
||||
|
||||
dest.SetProofStrategy(destination.PROVE_ALL)
|
||||
|
||||
t.RegisterAnnounceHandler(&genericAnnounceHandler{})
|
||||
|
||||
wsInterface, err := interfaces.NewWebSocketInterface("wasm0", wsURL, true)
|
||||
if err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Failed to create WebSocket interface: %v", err),
|
||||
})
|
||||
}
|
||||
|
||||
wsInterface.SetPacketCallback(func(data []byte, ni common.NetworkInterface) {
|
||||
msg := fmt.Sprintf("Received packet: %d bytes (type: 0x%02x)", len(data), data[0])
|
||||
js.Global().Call("log", msg, "success")
|
||||
debug.Log(debug.DEBUG_INFO, "WASM received packet", "bytes", len(data), "type", fmt.Sprintf("0x%02x", data[0]))
|
||||
t.HandlePacket(data, ni)
|
||||
})
|
||||
|
||||
if err := t.RegisterInterface("wasm0", wsInterface); err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Failed to register interface: %v", err),
|
||||
})
|
||||
}
|
||||
|
||||
if err := t.Start(); err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Failed to start transport: %v", err),
|
||||
})
|
||||
}
|
||||
|
||||
reticulumTransport = t
|
||||
reticulumDest = dest
|
||||
reticulumIdentity = id
|
||||
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"success": true,
|
||||
"identity": id.GetHexHash(),
|
||||
"privateKey": hex.EncodeToString(id.GetPrivateKey()),
|
||||
"destination": fmt.Sprintf("%x", dest.GetHash()),
|
||||
})
|
||||
}
|
||||
|
||||
// GetTransport returns the internal transport pointer.
|
||||
func GetTransport() *transport.Transport {
|
||||
return reticulumTransport
|
||||
}
|
||||
|
||||
// GetDestinationPointer returns the internal destination pointer.
|
||||
func GetDestinationPointer() *destination.Destination {
|
||||
return reticulumDest
|
||||
}
|
||||
|
||||
func GetIdentity(this js.Value, args []js.Value) interface{} {
|
||||
if reticulumIdentity == nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": "Reticulum not initialized",
|
||||
})
|
||||
}
|
||||
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"hash": reticulumIdentity.GetHexHash(),
|
||||
})
|
||||
}
|
||||
|
||||
func GetDestination(this js.Value, args []js.Value) interface{} {
|
||||
if reticulumDest == nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": "Reticulum not initialized",
|
||||
})
|
||||
}
|
||||
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"hash": fmt.Sprintf("%x", reticulumDest.GetHash()),
|
||||
})
|
||||
}
|
||||
|
||||
func IsConnected(this js.Value, args []js.Value) interface{} {
|
||||
if reticulumTransport == nil {
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
|
||||
ifaces := reticulumTransport.GetInterfaces()
|
||||
for _, iface := range ifaces {
|
||||
if iface.IsOnline() {
|
||||
return js.ValueOf(true)
|
||||
}
|
||||
}
|
||||
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
|
||||
func ConnectWebSocket(this js.Value, args []js.Value) interface{} {
|
||||
if reticulumTransport == nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": "Reticulum not initialized",
|
||||
})
|
||||
}
|
||||
|
||||
ifaces := reticulumTransport.GetInterfaces()
|
||||
for name, iface := range ifaces {
|
||||
if iface.IsOnline() {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"success": true,
|
||||
"interface": name,
|
||||
})
|
||||
}
|
||||
if err := iface.Start(); err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Failed to connect: %v", err),
|
||||
})
|
||||
}
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"success": true,
|
||||
"interface": name,
|
||||
})
|
||||
}
|
||||
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": "WebSocket interface not found",
|
||||
})
|
||||
}
|
||||
|
||||
func DisconnectWebSocket(this js.Value, args []js.Value) interface{} {
|
||||
if reticulumTransport == nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": "Reticulum not initialized",
|
||||
})
|
||||
}
|
||||
|
||||
ifaces := reticulumTransport.GetInterfaces()
|
||||
for _, iface := range ifaces {
|
||||
if err := iface.Stop(); err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Failed to stop interface: %v", err),
|
||||
})
|
||||
}
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": "WebSocket interface not found",
|
||||
})
|
||||
}
|
||||
|
||||
type genericAnnounceHandler struct{}
|
||||
|
||||
func (h *genericAnnounceHandler) AspectFilter() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *genericAnnounceHandler) ReceivePathResponses() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *genericAnnounceHandler) ReceivedAnnounce(destHash []byte, ident interface{}, appData []byte, hops uint8) error {
|
||||
if !announceHandler.IsUndefined() {
|
||||
hashStr := hex.EncodeToString(destHash)
|
||||
announceHandler.Invoke(js.ValueOf(map[string]interface{}{
|
||||
"hash": hashStr,
|
||||
"appData": string(appData),
|
||||
"hops": int(hops),
|
||||
}))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendDataJS is the JS-facing wrapper for SendData
|
||||
func SendDataJS(this js.Value, args []js.Value) interface{} {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": "Destination hash and data required",
|
||||
})
|
||||
}
|
||||
|
||||
destHashHex := args[0].String()
|
||||
destHash, err := hex.DecodeString(destHashHex)
|
||||
if err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Invalid destination hash: %v", err),
|
||||
})
|
||||
}
|
||||
|
||||
// Support both string and Uint8Array data from JS
|
||||
var data []byte
|
||||
if args[1].Type() == js.TypeString {
|
||||
data = []byte(args[1].String())
|
||||
} else {
|
||||
data = make([]byte, args[1].Length())
|
||||
js.CopyBytesToGo(data, args[1])
|
||||
}
|
||||
|
||||
return SendData(destHash, data)
|
||||
}
|
||||
|
||||
// SendData is a generic function to send raw bytes to a destination
|
||||
func SendData(destHash []byte, data []byte) interface{} {
|
||||
if reticulumTransport == nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": "Reticulum not initialized",
|
||||
})
|
||||
}
|
||||
|
||||
remoteIdentity, err := identity.Recall(destHash)
|
||||
if err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Identity not found. Wait for an announce from this peer!"),
|
||||
})
|
||||
}
|
||||
|
||||
targetDest, err := destination.FromHash(destHash, remoteIdentity, destination.SINGLE, reticulumTransport)
|
||||
if err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Failed to create target destination: %v", err),
|
||||
})
|
||||
}
|
||||
|
||||
encrypted, err := targetDest.Encrypt(data)
|
||||
if err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Encryption failed: %v", err),
|
||||
})
|
||||
}
|
||||
|
||||
pkt := packet.NewPacket(
|
||||
packet.DestinationSingle,
|
||||
encrypted,
|
||||
packet.PacketTypeData,
|
||||
packet.ContextNone,
|
||||
packet.PropagationBroadcast,
|
||||
packet.HeaderType1,
|
||||
nil,
|
||||
true,
|
||||
packet.FlagUnset,
|
||||
)
|
||||
pkt.DestinationHash = destHash
|
||||
|
||||
if err := pkt.Pack(); err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Packet packing failed: %v", err),
|
||||
})
|
||||
}
|
||||
|
||||
if err := reticulumTransport.SendPacket(pkt); err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Packet sending failed: %v", err),
|
||||
})
|
||||
}
|
||||
|
||||
stats.packetsSent++
|
||||
stats.bytesSent += len(data)
|
||||
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
// SendAnnounceJS is the JS-facing wrapper for SendAnnounce
|
||||
func SendAnnounceJS(this js.Value, args []js.Value) interface{} {
|
||||
var appData []byte
|
||||
if len(args) >= 1 && args[0].Type() == js.TypeString {
|
||||
appData = []byte(args[0].String())
|
||||
} else if len(args) >= 1 && args[0].Type() == js.TypeObject {
|
||||
appData = make([]byte, args[0].Length())
|
||||
js.CopyBytesToGo(appData, args[0])
|
||||
}
|
||||
return SendAnnounce(appData)
|
||||
}
|
||||
|
||||
// SendAnnounce is a generic function to send an announce
|
||||
func SendAnnounce(appData []byte) interface{} {
|
||||
if reticulumDest == nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": "Reticulum not initialized",
|
||||
})
|
||||
}
|
||||
|
||||
if len(appData) > 0 {
|
||||
reticulumDest.SetDefaultAppData(appData)
|
||||
}
|
||||
|
||||
if err := reticulumDest.Announce(false, nil, nil); err != nil {
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"error": fmt.Sprintf("Failed to send announce: %v", err),
|
||||
})
|
||||
}
|
||||
|
||||
return js.ValueOf(map[string]interface{}{
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
147
pkg/wasm/wasm_test.go
Normal file
147
pkg/wasm/wasm_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
//go:build js && wasm
|
||||
// +build js,wasm
|
||||
|
||||
package wasm
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"syscall/js"
|
||||
"testing"
|
||||
|
||||
"git.quad4.io/Networks/Reticulum-Go/pkg/identity"
|
||||
)
|
||||
|
||||
func TestRegisterJSFunctions(t *testing.T) {
|
||||
RegisterJSFunctions()
|
||||
|
||||
reticulum := js.Global().Get("reticulum")
|
||||
if reticulum.IsUndefined() {
|
||||
t.Fatal("reticulum object not registered in global scope")
|
||||
}
|
||||
|
||||
functions := []string{
|
||||
"init", "getIdentity", "getDestination", "announce",
|
||||
"connect", "disconnect", "isConnected", "requestPath", "getStats",
|
||||
"setPacketCallback", "setAnnounceCallback", "sendData",
|
||||
}
|
||||
|
||||
for _, fn := range functions {
|
||||
if reticulum.Get(fn).Type() != js.TypeFunction {
|
||||
t.Errorf("function %s not registered or not a function", fn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStats(t *testing.T) {
|
||||
// Reset stats
|
||||
stats.packetsSent = 10
|
||||
stats.packetsReceived = 5
|
||||
stats.bytesSent = 100
|
||||
stats.bytesReceived = 50
|
||||
|
||||
result := GetStats(js.Undefined(), nil)
|
||||
val := result.(js.Value)
|
||||
|
||||
if val.Get("packetsSent").Int() != 10 {
|
||||
t.Errorf("expected packetsSent 10, got %d", val.Get("packetsSent").Int())
|
||||
}
|
||||
if val.Get("packetsReceived").Int() != 5 {
|
||||
t.Errorf("expected packetsReceived 5, got %d", val.Get("packetsReceived").Int())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsConnected(t *testing.T) {
|
||||
reticulumTransport = nil
|
||||
connected := IsConnected(js.Undefined(), nil).(js.Value).Bool()
|
||||
if connected {
|
||||
t.Error("expected connected to be false when transport is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitReticulum(t *testing.T) {
|
||||
// Mock JS global functions
|
||||
js.Global().Set("log", js.FuncOf(func(this js.Value, args []js.Value) interface{} { return nil }))
|
||||
|
||||
// Test without arguments
|
||||
result := InitReticulum(js.Undefined(), []js.Value{})
|
||||
val := result.(js.Value)
|
||||
if val.Get("error").IsUndefined() || val.Get("error").String() != "WebSocket URL required" {
|
||||
t.Errorf("expected error 'WebSocket URL required', got %v", val.Get("error"))
|
||||
}
|
||||
|
||||
// Test with valid URL and app name
|
||||
wsURL := "ws://localhost:8080"
|
||||
appName := "test_app"
|
||||
result = InitReticulum(js.Undefined(), []js.Value{js.ValueOf(wsURL), js.ValueOf(appName)})
|
||||
val = result.(js.Value)
|
||||
|
||||
if !val.Get("success").Bool() {
|
||||
t.Errorf("InitReticulum failed: %v", val.Get("error"))
|
||||
}
|
||||
|
||||
if reticulumIdentity == nil {
|
||||
t.Fatal("reticulumIdentity should not be nil after successful init")
|
||||
}
|
||||
|
||||
// Test with provided identity
|
||||
id, _ := identity.NewIdentity()
|
||||
idHex := id.GetHexHash()
|
||||
// InitReticulum expects the FULL identity bytes in hex (64 bytes).
|
||||
idBytes := id.GetPrivateKey()
|
||||
idHexFull := hex.EncodeToString(idBytes)
|
||||
|
||||
result = InitReticulum(js.Undefined(), []js.Value{js.ValueOf(wsURL), js.ValueOf(appName), js.ValueOf(idHexFull)})
|
||||
val = result.(js.Value)
|
||||
|
||||
if !val.Get("success").Bool() {
|
||||
t.Errorf("InitReticulum with identity failed: %v", val.Get("error"))
|
||||
}
|
||||
|
||||
if reticulumIdentity.GetHexHash() != idHex {
|
||||
t.Errorf("expected identity hash %s, got %s", idHex, reticulumIdentity.GetHexHash())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdentityAndDestination(t *testing.T) {
|
||||
// Ensure initialized
|
||||
js.Global().Set("log", js.FuncOf(func(this js.Value, args []js.Value) interface{} { return nil }))
|
||||
InitReticulum(js.Undefined(), []js.Value{js.ValueOf("ws://localhost")})
|
||||
|
||||
idResult := GetIdentity(js.Undefined(), nil).(js.Value)
|
||||
if idResult.Get("hash").String() != reticulumIdentity.GetHexHash() {
|
||||
t.Error("GetIdentity returned wrong hash")
|
||||
}
|
||||
|
||||
destResult := GetDestination(js.Undefined(), nil).(js.Value)
|
||||
expectedDest := fmt.Sprintf("%x", reticulumDest.GetHash())
|
||||
if destResult.Get("hash").String() != expectedDest {
|
||||
t.Errorf("GetDestination returned wrong hash, expected %s got %s", expectedDest, destResult.Get("hash").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendDataJS(t *testing.T) {
|
||||
// Ensure initialized
|
||||
InitReticulum(js.Undefined(), []js.Value{js.ValueOf("ws://localhost")})
|
||||
|
||||
// Create a mock peer
|
||||
peerId, _ := identity.NewIdentity()
|
||||
peerHash := peerId.Hash()
|
||||
peerHashHex := hex.EncodeToString(peerHash)
|
||||
|
||||
// Manually add to known destinations so Recall works
|
||||
identity.Remember([]byte("mock_packet"), peerHash, peerId.GetPublicKey(), []byte("peer_app_data"))
|
||||
|
||||
// Test SendDataJS with string
|
||||
data := "Hello Peer!"
|
||||
result := SendDataJS(js.Undefined(), []js.Value{js.ValueOf(peerHashHex), js.ValueOf(data)}).(js.Value)
|
||||
|
||||
if !result.Get("error").IsUndefined() {
|
||||
errStr := result.Get("error").String()
|
||||
if errStr != "Packet sending failed: no path to destination" {
|
||||
t.Errorf("SendDataJS failed with unexpected error: %s", errStr)
|
||||
}
|
||||
} else if !result.Get("success").Bool() {
|
||||
t.Errorf("SendDataJS failed without error message")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user