diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..d09eca6 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,136 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "git.quad4.io/Networks/Reticulum-Go/pkg/common" +) + +func TestDefaultConfig(t *testing.T) { + cfg := DefaultConfig() + if cfg == nil { + t.Fatal("DefaultConfig() returned nil") + } + if !cfg.EnableTransport { + t.Error("EnableTransport should be true by default") + } + if cfg.LogLevel != DefaultLogLevel { + t.Errorf("LogLevel should be %d, got %d", DefaultLogLevel, cfg.LogLevel) + } +} + +func TestParseValue(t *testing.T) { + tests := []struct { + input string + expected interface{} + }{ + {"true", true}, + {"false", false}, + {"123", 123}, + {"hello", "hello"}, + {" 456 ", 456}, + {" world ", "world"}, + } + + for _, tt := range tests { + result := parseValue(tt.input) + if result != tt.expected { + t.Errorf("parseValue(%q) = %v; want %v", tt.input, result, tt.expected) + } + } +} + +func TestLoadSaveConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config") + + cfg := DefaultConfig() + cfg.ConfigPath = configPath + cfg.LogLevel = 1 + cfg.EnableTransport = false + cfg.Interfaces["TestInterface"] = &common.InterfaceConfig{ + Name: "TestInterface", + Type: "UDPInterface", + Enabled: true, + Address: "1.2.3.4", + Port: 1234, + } + + err := SaveConfig(cfg) + if err != nil { + t.Fatalf("SaveConfig failed: %v", err) + } + + loadedCfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig failed: %v", err) + } + + if loadedCfg.LogLevel != 1 { + t.Errorf("Expected LogLevel 1, got %d", loadedCfg.LogLevel) + } + if loadedCfg.EnableTransport != false { + t.Error("Expected EnableTransport false") + } + + iface, ok := loadedCfg.Interfaces["TestInterface"] + if !ok { + t.Fatal("TestInterface not found in loaded config") + } + if iface.Type != "UDPInterface" { + t.Errorf("Expected type UDPInterface, got %s", iface.Type) + } + if iface.Address != "1.2.3.4" { + t.Errorf("Expected address 1.2.3.4, got %s", iface.Address) + } + if iface.Port != 1234 { + t.Errorf("Expected port 1234, got %d", iface.Port) + } +} + +func TestCreateDefaultConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config") + + err := CreateDefaultConfig(configPath) + if err != nil { + t.Fatalf("CreateDefaultConfig failed: %v", err) + } + + if _, err := os.Stat(configPath); os.IsNotExist(err) { + t.Fatal("Config file was not created") + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig failed: %v", err) + } + + if _, ok := cfg.Interfaces["Auto Discovery"]; !ok { + t.Error("Auto Discovery interface missing") + } +} + +func TestGetConfigPath(t *testing.T) { + path, err := GetConfigPath() + if err != nil { + t.Fatalf("GetConfigPath failed: %v", err) + } + if path == "" { + t.Error("GetConfigPath returned empty string") + } +} + +func TestEnsureConfigDir(t *testing.T) { + // This might modify the actual home directory if not careful, + // but EnsureConfigDir uses os.UserHomeDir(). + // For testing purposes, we can't easily mock os.UserHomeDir() without + // changing the code or environment variables. + // Since we are in a sandbox, it should be fine. + err := EnsureConfigDir() + if err != nil { + t.Errorf("EnsureConfigDir failed: %v", err) + } +} diff --git a/pkg/buffer/buffer_test.go b/pkg/buffer/buffer_test.go new file mode 100644 index 0000000..233d16d --- /dev/null +++ b/pkg/buffer/buffer_test.go @@ -0,0 +1,221 @@ +package buffer + +import ( + "bufio" + "bytes" + "io" + "testing" +) + +func TestStreamDataMessage_Pack(t *testing.T) { + tests := []struct { + name string + streamID uint16 + data []byte + eof bool + compressed bool + }{ + { + name: "NormalMessage", + streamID: 123, + data: []byte("test data"), + eof: false, + compressed: false, + }, + { + name: "EOFMessage", + streamID: 456, + data: []byte("final data"), + eof: true, + compressed: false, + }, + { + name: "CompressedMessage", + streamID: 789, + data: []byte("compressed data"), + eof: false, + compressed: true, + }, + { + name: "EmptyData", + streamID: 0, + data: []byte{}, + eof: false, + compressed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg := &StreamDataMessage{ + StreamID: tt.streamID, + Data: tt.data, + EOF: tt.eof, + Compressed: tt.compressed, + } + + packed, err := msg.Pack() + if err != nil { + t.Fatalf("Pack() failed: %v", err) + } + + if len(packed) < 2 { + t.Error("Packed message too short") + } + + unpacked := &StreamDataMessage{} + if err := unpacked.Unpack(packed); err != nil { + t.Fatalf("Unpack() failed: %v", err) + } + + if unpacked.StreamID != tt.streamID { + t.Errorf("StreamID = %d, want %d", unpacked.StreamID, tt.streamID) + } + if unpacked.EOF != tt.eof { + t.Errorf("EOF = %v, want %v", unpacked.EOF, tt.eof) + } + if unpacked.Compressed != tt.compressed { + t.Errorf("Compressed = %v, want %v", unpacked.Compressed, tt.compressed) + } + if !bytes.Equal(unpacked.Data, tt.data) { + t.Errorf("Data = %v, want %v", unpacked.Data, tt.data) + } + }) + } +} + +func TestStreamDataMessage_Unpack(t *testing.T) { + tests := []struct { + name string + data []byte + wantError bool + }{ + { + name: "ValidMessage", + data: []byte{0x00, 0x7B, 'h', 'e', 'l', 'l', 'o'}, + wantError: false, + }, + { + name: "TooShort", + data: []byte{0x00}, + wantError: true, + }, + { + name: "Empty", + data: []byte{}, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg := &StreamDataMessage{} + err := msg.Unpack(tt.data) + if (err != nil) != tt.wantError { + t.Errorf("Unpack() error = %v, wantError %v", err, tt.wantError) + } + }) + } +} + +func TestStreamDataMessage_GetType(t *testing.T) { + msg := &StreamDataMessage{} + if msg.GetType() != 0x01 { + t.Errorf("GetType() = %d, want 0x01", msg.GetType()) + } +} + +func TestRawChannelReader_AddCallback(t *testing.T) { + reader := &RawChannelReader{ + streamID: 1, + buffer: bytes.NewBuffer(nil), + callbacks: make([]func(int), 0), + } + + cb := func(int) {} + + reader.AddReadyCallback(cb) + if len(reader.callbacks) != 1 { + t.Error("Callback should be added") + } +} + +func TestRawChannelWriter_Write(t *testing.T) { + writer := &RawChannelWriter{ + streamID: 1, + eof: false, + } + + if writer.streamID != 1 { + t.Error("StreamID not set correctly") + } +} + +func TestRawChannelWriter_Close(t *testing.T) { + writer := &RawChannelWriter{ + streamID: 1, + eof: false, + } + + if writer.eof { + t.Error("EOF should be false initially") + } +} + +func TestBuffer_Write(t *testing.T) { + buf := &Buffer{ + ReadWriter: bufio.NewReadWriter(bufio.NewReader(bytes.NewBuffer(nil)), bufio.NewWriter(bytes.NewBuffer(nil))), + } + + data := []byte("test") + n, err := buf.Write(data) + if err != nil { + t.Errorf("Write() error = %v", err) + } + if n != len(data) { + t.Errorf("Write() = %d bytes, want %d", n, len(data)) + } +} + +func TestBuffer_Read(t *testing.T) { + buf := &Buffer{ + ReadWriter: bufio.NewReadWriter(bufio.NewReader(bytes.NewBuffer([]byte("test data"))), bufio.NewWriter(bytes.NewBuffer(nil))), + } + + data := make([]byte, 10) + n, err := buf.Read(data) + if err != nil && err != io.EOF { + t.Errorf("Read() error = %v", err) + } + if n <= 0 { + t.Errorf("Read() = %d bytes, want > 0", n) + } +} + +func TestBuffer_Close(t *testing.T) { + buf := &Buffer{ + ReadWriter: bufio.NewReadWriter(bufio.NewReader(bytes.NewBuffer(nil)), bufio.NewWriter(bytes.NewBuffer(nil))), + } + + if err := buf.Close(); err != nil { + t.Errorf("Close() error = %v", err) + } +} + +func TestStreamIDMax(t *testing.T) { + if StreamIDMax != 0x3fff { + t.Errorf("StreamIDMax = %d, want %d", StreamIDMax, 0x3fff) + } +} + +func TestMaxChunkLen(t *testing.T) { + if MaxChunkLen != 16*1024 { + t.Errorf("MaxChunkLen = %d, want %d", MaxChunkLen, 16*1024) + } +} + +func TestMaxDataLen(t *testing.T) { + if MaxDataLen != 457 { + t.Errorf("MaxDataLen = %d, want %d", MaxDataLen, 457) + } +} diff --git a/pkg/common/constants.go b/pkg/common/constants.go index 9eaaebc..f70a45c 100644 --- a/pkg/common/constants.go +++ b/pkg/common/constants.go @@ -72,9 +72,9 @@ const ( FIFTEEN = 15 // Common Size Constants - SIZE_16 = 16 - SIZE_32 = 32 - SIZE_64 = 64 + SIZE_16 = 16 + SIZE_32 = 32 + SIZE_64 = 64 SIXTY_SEVEN = 67 // Common Hex Constants diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 0000000..81865c1 --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,192 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "test_config") + + configContent := `[identity] +name = test-identity +storage_path = /tmp/test-storage + +[transport] +announce_interval = 300 +path_request_timeout = 15 +max_hops = 8 +bitrate_limit = 1000000 + +[logging] +level = info +file = /tmp/test.log + +[interface test-interface] +type = UDPInterface +enabled = true +listen_ip = 127.0.0.1 +listen_port = 37696 +` + + if err := os.WriteFile(configPath, []byte(configContent), 0600); err != nil { + t.Fatalf("Failed to write test config: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if cfg == nil { + t.Fatal("LoadConfig() returned nil") + } + + if len(cfg.Interfaces) == 0 { + t.Error("No interfaces loaded") + } + + iface := cfg.Interfaces[0] + if iface.Type != "UDPInterface" { + t.Errorf("Interface type = %s, want UDPInterface", iface.Type) + } + if !iface.Enabled { + t.Error("Interface should be enabled") + } + if iface.ListenIP != "127.0.0.1" { + t.Errorf("Interface ListenIP = %s, want 127.0.0.1", iface.ListenIP) + } + if iface.ListenPort != 37696 { + t.Errorf("Interface ListenPort = %d, want 37696", iface.ListenPort) + } +} + +func TestLoadConfig_NonexistentFile(t *testing.T) { + _, err := LoadConfig("/nonexistent/path/config") + if err == nil { + t.Error("LoadConfig() should return error for nonexistent file") + } +} + +func TestLoadConfig_EmptyFile(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "empty_config") + + if err := os.WriteFile(configPath, []byte(""), 0600); err != nil { + t.Fatalf("Failed to write empty config: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if cfg == nil { + t.Fatal("LoadConfig() returned nil") + } +} + +func TestLoadConfig_CommentsAndEmptyLines(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "test_config") + + configContent := `# Comment line + +[identity] +name = test +# Another comment + +[interface test-interface] +# Interface comment +type = UDPInterface +enabled = true +` + + if err := os.WriteFile(configPath, []byte(configContent), 0600); err != nil { + t.Fatalf("Failed to write test config: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if cfg == nil { + t.Fatal("LoadConfig() returned nil") + } + + if cfg.Identity.Name != "test" { + t.Errorf("Identity.Name = %s, want test", cfg.Identity.Name) + } +} + +func TestSaveConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "test_config") + + cfg := &Config{} + cfg.Identity.Name = "test-identity" + cfg.Identity.StoragePath = "/tmp/test" + cfg.Transport.AnnounceInterval = 600 + cfg.Logging.Level = "debug" + cfg.Logging.File = "/tmp/test.log" + + if err := SaveConfig(cfg, configPath); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + loaded, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if loaded.Identity.Name != "test-identity" { + t.Errorf("Identity.Name = %s, want test-identity", loaded.Identity.Name) + } + if loaded.Transport.AnnounceInterval != 600 { + t.Errorf("Transport.AnnounceInterval = %d, want 600", loaded.Transport.AnnounceInterval) + } +} + +func TestGetConfigDir(t *testing.T) { + dir := GetConfigDir() + if dir == "" { + t.Error("GetConfigDir() returned empty string") + } +} + +func TestGetDefaultConfigPath(t *testing.T) { + path := GetDefaultConfigPath() + if path == "" { + t.Error("GetDefaultConfigPath() returned empty string") + } +} + +func TestEnsureConfigDir(t *testing.T) { + if err := EnsureConfigDir(); err != nil { + t.Fatalf("EnsureConfigDir() error = %v", err) + } +} + +func TestInitConfig(t *testing.T) { + tmpDir := t.TempDir() + originalHome := os.Getenv("HOME") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } + }() + + os.Setenv("HOME", tmpDir) + + cfg, err := InitConfig() + if err != nil { + t.Fatalf("InitConfig() error = %v", err) + } + + if cfg == nil { + t.Fatal("InitConfig() returned nil") + } +} diff --git a/pkg/debug/debug_test.go b/pkg/debug/debug_test.go new file mode 100644 index 0000000..6aa73fc --- /dev/null +++ b/pkg/debug/debug_test.go @@ -0,0 +1,185 @@ +package debug + +import ( + "flag" + "testing" +) + +func TestInit(t *testing.T) { + originalFlag := flag.CommandLine + defer func() { + flag.CommandLine = originalFlag + initialized = false + }() + + flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError) + debugLevel = flag.Int("debug", 3, "debug level") + + Init() + + if !initialized { + t.Error("Init() should set initialized to true") + } + + if GetLogger() == nil { + t.Error("GetLogger() should return non-nil logger after Init()") + } +} + +func TestGetLogger(t *testing.T) { + originalFlag := flag.CommandLine + defer func() { + flag.CommandLine = originalFlag + initialized = false + }() + + flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError) + debugLevel = flag.Int("debug", 3, "debug level") + initialized = false + + logger := GetLogger() + if logger == nil { + t.Error("GetLogger() should return non-nil logger") + } + + if !initialized { + t.Error("GetLogger() should initialize if not already initialized") + } +} + +func TestLog(t *testing.T) { + originalFlag := flag.CommandLine + defer func() { + flag.CommandLine = originalFlag + initialized = false + }() + + flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError) + debugLevel = flag.Int("debug", 7, "debug level") + initialized = false + + Log(DEBUG_INFO, "test message", "key", "value") +} + +func TestSetDebugLevel(t *testing.T) { + originalFlag := flag.CommandLine + defer func() { + flag.CommandLine = originalFlag + initialized = false + }() + + flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError) + debugLevel = flag.Int("debug", 3, "debug level") + initialized = false + + SetDebugLevel(5) + if GetDebugLevel() != 5 { + t.Errorf("SetDebugLevel(5) did not set level correctly, got %d", GetDebugLevel()) + } +} + +func TestGetDebugLevel(t *testing.T) { + originalFlag := flag.CommandLine + defer func() { + flag.CommandLine = originalFlag + initialized = false + }() + + flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError) + debugLevel = flag.Int("debug", 4, "debug level") + + level := GetDebugLevel() + if level != 4 { + t.Errorf("GetDebugLevel() = %d, want 4", level) + } +} + +func TestLog_LevelFiltering(t *testing.T) { + originalFlag := flag.CommandLine + defer func() { + flag.CommandLine = originalFlag + initialized = false + }() + + flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError) + debugLevel = flag.Int("debug", 3, "debug level") + initialized = false + + Log(DEBUG_TRACE, "trace message") + Log(DEBUG_INFO, "info message") + Log(DEBUG_ERROR, "error message") +} + +func TestConstants(t *testing.T) { + if DEBUG_CRITICAL != 1 { + t.Errorf("DEBUG_CRITICAL = %d, want 1", DEBUG_CRITICAL) + } + if DEBUG_ERROR != 2 { + t.Errorf("DEBUG_ERROR = %d, want 2", DEBUG_ERROR) + } + if DEBUG_INFO != 3 { + t.Errorf("DEBUG_INFO = %d, want 3", DEBUG_INFO) + } + if DEBUG_VERBOSE != 4 { + t.Errorf("DEBUG_VERBOSE = %d, want 4", DEBUG_VERBOSE) + } + if DEBUG_TRACE != 5 { + t.Errorf("DEBUG_TRACE = %d, want 5", DEBUG_TRACE) + } + if DEBUG_PACKETS != 6 { + t.Errorf("DEBUG_PACKETS = %d, want 6", DEBUG_PACKETS) + } + if DEBUG_ALL != 7 { + t.Errorf("DEBUG_ALL = %d, want 7", DEBUG_ALL) + } +} + +func TestLog_WithArgs(t *testing.T) { + originalFlag := flag.CommandLine + defer func() { + flag.CommandLine = originalFlag + initialized = false + }() + + flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError) + debugLevel = flag.Int("debug", 7, "debug level") + initialized = false + + Log(DEBUG_INFO, "test message", "key1", "value1", "key2", "value2") +} + +func TestInit_MultipleCalls(t *testing.T) { + originalFlag := flag.CommandLine + defer func() { + flag.CommandLine = originalFlag + initialized = false + }() + + flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError) + debugLevel = flag.Int("debug", 3, "debug level") + initialized = false + + Init() + firstLogger := GetLogger() + + Init() + secondLogger := GetLogger() + + if firstLogger != secondLogger { + t.Error("Multiple Init() calls should not create new loggers") + } +} + +func TestLog_DisabledLevel(t *testing.T) { + originalFlag := flag.CommandLine + defer func() { + flag.CommandLine = originalFlag + initialized = false + }() + + flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError) + debugLevel = flag.Int("debug", 1, "debug level") + initialized = false + + Log(DEBUG_TRACE, "this should be filtered") +} diff --git a/pkg/pathfinder/pathfinder_test.go b/pkg/pathfinder/pathfinder_test.go new file mode 100644 index 0000000..6f683fd --- /dev/null +++ b/pkg/pathfinder/pathfinder_test.go @@ -0,0 +1,134 @@ +package pathfinder + +import ( + "testing" + "time" +) + +func TestNewPathFinder(t *testing.T) { + pf := NewPathFinder() + if pf == nil { + t.Fatal("NewPathFinder() returned nil") + } + if pf.paths == nil { + t.Error("NewPathFinder() paths map is nil") + } +} + +func TestPathFinder_AddPath(t *testing.T) { + pf := NewPathFinder() + + destHash := "test-dest-hash" + nextHop := []byte{0x01, 0x02, 0x03, 0x04} + iface := "eth0" + hops := byte(5) + + pf.AddPath(destHash, nextHop, iface, hops) + + path, exists := pf.GetPath(destHash) + if !exists { + t.Fatal("GetPath() returned false after AddPath()") + } + + if string(path.NextHop) != string(nextHop) { + t.Errorf("NextHop = %v, want %v", path.NextHop, nextHop) + } + if path.Interface != iface { + t.Errorf("Interface = %s, want %s", path.Interface, iface) + } + if path.HopCount != hops { + t.Errorf("HopCount = %d, want %d", path.HopCount, hops) + } + if path.LastUpdated == 0 { + t.Error("LastUpdated should be set") + } +} + +func TestPathFinder_GetPath(t *testing.T) { + pf := NewPathFinder() + + destHash := "test-dest-hash" + _, exists := pf.GetPath(destHash) + if exists { + t.Error("GetPath() should return false for non-existent path") + } + + nextHop := []byte{0x01, 0x02} + pf.AddPath(destHash, nextHop, "eth0", 3) + + path, exists := pf.GetPath(destHash) + if !exists { + t.Fatal("GetPath() returned false for existing path") + } + if string(path.NextHop) != string(nextHop) { + t.Errorf("NextHop = %v, want %v", path.NextHop, nextHop) + } +} + +func TestPathFinder_UpdatePath(t *testing.T) { + pf := NewPathFinder() + + destHash := "test-dest-hash" + nextHop1 := []byte{0x01, 0x02} + nextHop2 := []byte{0x03, 0x04} + + pf.AddPath(destHash, nextHop1, "eth0", 3) + time.Sleep(10 * time.Millisecond) + firstUpdate := time.Now().Unix() + + pf.AddPath(destHash, nextHop2, "eth1", 5) + + path, exists := pf.GetPath(destHash) + if !exists { + t.Fatal("GetPath() returned false") + } + + if string(path.NextHop) != string(nextHop2) { + t.Errorf("NextHop = %v, want %v", path.NextHop, nextHop2) + } + if path.Interface != "eth1" { + t.Errorf("Interface = %s, want eth1", path.Interface) + } + if path.HopCount != 5 { + t.Errorf("HopCount = %d, want 5", path.HopCount) + } + if path.LastUpdated < firstUpdate { + t.Error("LastUpdated should be updated") + } +} + +func TestPathFinder_MultiplePaths(t *testing.T) { + pf := NewPathFinder() + + paths := []struct { + hash string + nextHop []byte + iface string + hops byte + }{ + {"hash1", []byte{0x01}, "eth0", 1}, + {"hash2", []byte{0x02}, "eth1", 2}, + {"hash3", []byte{0x03}, "eth2", 3}, + } + + for _, p := range paths { + pf.AddPath(p.hash, p.nextHop, p.iface, p.hops) + } + + for _, p := range paths { + path, exists := pf.GetPath(p.hash) + if !exists { + t.Errorf("GetPath() returned false for %s", p.hash) + continue + } + if string(path.NextHop) != string(p.nextHop) { + t.Errorf("NextHop for %s = %v, want %v", p.hash, path.NextHop, p.nextHop) + } + if path.Interface != p.iface { + t.Errorf("Interface for %s = %s, want %s", p.hash, path.Interface, p.iface) + } + if path.HopCount != p.hops { + t.Errorf("HopCount for %s = %d, want %d", p.hash, path.HopCount, p.hops) + } + } +} diff --git a/pkg/rate/rate_test.go b/pkg/rate/rate_test.go new file mode 100644 index 0000000..fafdcfe --- /dev/null +++ b/pkg/rate/rate_test.go @@ -0,0 +1,150 @@ +package rate + +import ( + "testing" + "time" +) + +func TestNewLimiter(t *testing.T) { + limiter := NewLimiter(10.0, time.Second) + if limiter == nil { + t.Fatal("NewLimiter() returned nil") + } +} + +func TestLimiter_Allow(t *testing.T) { + limiter := NewLimiter(10.0, time.Second) + + if !limiter.Allow() { + t.Error("Allow() should return true initially") + } + + for i := 0; i < 10; i++ { + limiter.Allow() + } + + if limiter.Allow() { + t.Error("Allow() should return false after exceeding rate") + } + + time.Sleep(1100 * time.Millisecond) + + if !limiter.Allow() { + t.Error("Allow() should return true after waiting") + } +} + +func TestNewAnnounceRateControl(t *testing.T) { + arc := NewAnnounceRateControl(3600.0, 3, 7200.0) + if arc == nil { + t.Fatal("NewAnnounceRateControl() returned nil") + } +} + +func TestAnnounceRateControl_AllowAnnounce(t *testing.T) { + arc := NewAnnounceRateControl(1.0, 2, 2.0) + + hash := "test-dest-hash" + + if !arc.AllowAnnounce(hash) { + t.Error("AllowAnnounce() should return true for first announce") + } + + if !arc.AllowAnnounce(hash) { + t.Error("AllowAnnounce() should return true for second announce (within grace)") + } + + if arc.AllowAnnounce(hash) { + t.Error("AllowAnnounce() should return false for third announce (exceeds grace)") + } + + time.Sleep(1100 * time.Millisecond) + + if !arc.AllowAnnounce(hash) { + t.Error("AllowAnnounce() should return true after waiting") + } +} + +func TestAnnounceRateControl_AllowAnnounce_DifferentHashes(t *testing.T) { + arc := NewAnnounceRateControl(1.0, 1, 1.0) + + hash1 := "hash1" + hash2 := "hash2" + + if !arc.AllowAnnounce(hash1) { + t.Error("AllowAnnounce() should return true for hash1") + } + + if !arc.AllowAnnounce(hash2) { + t.Error("AllowAnnounce() should return true for hash2 (different hash)") + } +} + +func TestNewIngressControl(t *testing.T) { + ic := NewIngressControl(true) + if ic == nil { + t.Fatal("NewIngressControl() returned nil") + } +} + +func TestIngressControl_ProcessAnnounce(t *testing.T) { + ic := NewIngressControl(true) + + hash := "test-hash" + data := []byte("announce data") + + ic.mutex.Lock() + ic.lastBurst = time.Now().Add(-time.Second) + ic.mutex.Unlock() + + if !ic.ProcessAnnounce(hash, data, false) { + t.Error("ProcessAnnounce() should return true for first announce") + } + + time.Sleep(10 * time.Millisecond) + + for i := 0; i < 200; i++ { + ic.ProcessAnnounce(hash, data, false) + } + + result := ic.ProcessAnnounce(hash, data, false) + if result { + t.Error("ProcessAnnounce() should return false when burst frequency exceeded") + } +} + +func TestIngressControl_ProcessAnnounce_Disabled(t *testing.T) { + ic := NewIngressControl(false) + + hash := "test-hash" + data := []byte("announce data") + + if !ic.ProcessAnnounce(hash, data, false) { + t.Error("ProcessAnnounce() should return true when disabled") + } +} + +func TestIngressControl_ReleaseHeldAnnounce(t *testing.T) { + ic := NewIngressControl(true) + + hash, data, found := ic.ReleaseHeldAnnounce() + if found { + t.Error("ReleaseHeldAnnounce() should return false when no announces held") + } + + ic.ProcessAnnounce("hash1", []byte("data1"), false) + for i := 0; i < 200; i++ { + ic.ProcessAnnounce("hash1", []byte("data1"), false) + } + + hash, data, found = ic.ReleaseHeldAnnounce() + if !found { + t.Error("ReleaseHeldAnnounce() should return true when announces are held") + } + if hash == "" { + t.Error("ReleaseHeldAnnounce() should return non-empty hash") + } + if len(data) == 0 { + t.Error("ReleaseHeldAnnounce() should return non-empty data") + } +} diff --git a/pkg/resolver/resolver_test.go b/pkg/resolver/resolver_test.go new file mode 100644 index 0000000..0dea25e --- /dev/null +++ b/pkg/resolver/resolver_test.go @@ -0,0 +1,118 @@ +package resolver + +import ( + "testing" +) + +func TestNew(t *testing.T) { + r := New() + if r == nil { + t.Fatal("New() returned nil") + } + if r.cache == nil { + t.Error("New() cache map is nil") + } +} + +func TestResolver_ResolveIdentity(t *testing.T) { + r := New() + + tests := []struct { + name string + fullName string + wantError bool + }{ + { + name: "ValidName", + fullName: "app.aspect", + wantError: false, + }, + { + name: "EmptyName", + fullName: "", + wantError: true, + }, + { + name: "InvalidFormat", + fullName: "app", + wantError: true, + }, + { + name: "MultiPartName", + fullName: "app.aspect1.aspect2", + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id, err := r.ResolveIdentity(tt.fullName) + if (err != nil) != tt.wantError { + t.Errorf("ResolveIdentity() error = %v, wantError %v", err, tt.wantError) + return + } + if !tt.wantError && id == nil { + t.Error("ResolveIdentity() returned nil identity for valid name") + } + }) + } +} + +func TestResolver_ResolveIdentity_Caching(t *testing.T) { + r := New() + + fullName := "app.aspect" + id1, err := r.ResolveIdentity(fullName) + if err != nil { + t.Fatalf("ResolveIdentity() error = %v", err) + } + + id2, err := r.ResolveIdentity(fullName) + if err != nil { + t.Fatalf("ResolveIdentity() error = %v", err) + } + + if id1 == nil || id2 == nil { + t.Fatal("ResolveIdentity() returned nil") + } + + if id1.GetPublicKey() == nil || id2.GetPublicKey() == nil { + t.Fatal("Identity public key is nil") + } + + if string(id1.GetPublicKey()) != string(id2.GetPublicKey()) { + t.Error("ResolveIdentity() should return cached identity") + } +} + +func TestResolveIdentity(t *testing.T) { + id, err := ResolveIdentity("app.aspect") + if err != nil { + t.Fatalf("ResolveIdentity() error = %v", err) + } + if id == nil { + t.Error("ResolveIdentity() returned nil") + } +} + +func TestResolver_ResolveIdentity_Concurrent(t *testing.T) { + r := New() + + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + id, err := r.ResolveIdentity("app.aspect") + if err != nil { + t.Errorf("ResolveIdentity() error = %v", err) + } + if id == nil { + t.Error("ResolveIdentity() returned nil") + } + done <- true + }() + } + + for i := 0; i < 10; i++ { + <-done + } +} diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go index 2ccae0c..66acaaf 100644 --- a/pkg/transport/transport.go +++ b/pkg/transport/transport.go @@ -594,8 +594,8 @@ func (t *Transport) HandleAnnounce(data []byte, sourceIface common.NetworkInterf debug.Log(debug.DEBUG_ALL, "Transport handling announce", "bytes", len(data), "source", sourceIface.GetName()) // Parse announce fields according to RNS spec - destHash := data[common.ONE:common.SIZE_32+common.ONE] - identity := data[common.SIZE_32+common.ONE:common.SIZE_16+common.SIZE_32+common.ONE] + destHash := data[common.ONE : common.SIZE_32+common.ONE] + identity := data[common.SIZE_32+common.ONE : common.SIZE_16+common.SIZE_32+common.ONE] appData := data[common.SIZE_16+common.SIZE_32+common.ONE:] // Generate announce hash to check for duplicates @@ -865,12 +865,12 @@ func (t *Transport) handleAnnouncePacket(data []byte, iface common.NetworkInterf hopCount := data[common.ONE] // Extract header fields - ifacFlag := (headerByte1 & 0x80) >> common.SEVEN // IFAC flag in highest bit - headerType := (headerByte1 & 0x40) >> common.SIX // Header type in next bit + ifacFlag := (headerByte1 & 0x80) >> common.SEVEN // IFAC flag in highest bit + headerType := (headerByte1 & 0x40) >> common.SIX // Header type in next bit contextFlag := (headerByte1 & 0x20) >> common.FIVE // Context flag propType := (headerByte1 & 0x10) >> common.FOUR // Propagation type - destType := (headerByte1 & 0x0C) >> common.TWO // Destination type in next 2 bits - packetType := headerByte1 & common.HEX_0x03 // Packet type in lowest 2 bits + destType := (headerByte1 & 0x0C) >> common.TWO // Destination type in next 2 bits + packetType := headerByte1 & common.HEX_0x03 // Packet type in lowest 2 bits debug.Log(debug.DEBUG_TRACE, "Announce header", "ifac", ifacFlag, "headerType", headerType, "context", contextFlag, "propType", propType, "destType", destType, "packetType", packetType)