package security

import (
	"bytes"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"golang.org/x/time/rate"

	"software-station/internal/models"
	"software-station/internal/stats"
)

// TestThrottledReader reads five bytes through a ThrottledReader and checks
// that the byte count is reported both by Read and in the stats entry that
// was registered for the reader's fingerprint.
func TestThrottledReader(t *testing.T) {
	content := []byte("hello world")
	inner := io.NopCloser(bytes.NewReader(content))
	limiter := rate.NewLimiter(rate.Limit(100), 100)
	fp := "test-fp"

	// Pre-register the fingerprint so the test can look up its entry after
	// the read.
	statsService := stats.NewService("test-hashes.json")
	statsService.KnownHashes.Lock()
	statsService.KnownHashes.Data[fp] = &models.FingerprintData{Known: true}
	statsService.KnownHashes.Unlock()

	tr := &ThrottledReader{
		R:           inner,
		Limiter:     limiter,
		Fingerprint: fp,
		Stats:       statsService,
	}
	// Deferred so the reader is closed even when an assertion below aborts
	// the test with Fatal/Fatalf.
	defer tr.Close()

	p := make([]byte, 5)
	n, err := tr.Read(p)
	if err != nil {
		t.Fatalf("Read failed: %v", err)
	}
	if n != 5 {
		t.Errorf("expected 5 bytes, got %d", n)
	}

	statsService.KnownHashes.RLock()
	data, ok := statsService.KnownHashes.Data[fp]
	statsService.KnownHashes.RUnlock()
	if !ok {
		t.Fatal("fingerprint data not found")
	}

	total := data.TotalBytes
	if total != 5 {
		t.Errorf("expected 5 bytes in stats, got %d", total)
	}
}

// TestGetRequestFingerprint checks that GetRequestFingerprint ignores the
// source port, yields a different fingerprint when X-Forwarded-For is set to
// another address, and yields the same fingerprint for two IPv6 addresses
// that share a /64 prefix.
func TestGetRequestFingerprint(t *testing.T) {
	statsService := stats.NewService("test-hashes.json")

	// IPv4: same address, different ports.
	req1 := httptest.NewRequest("GET", "/", nil)
	req1.RemoteAddr = "1.2.3.4:1234"
	f1 := GetRequestFingerprint(req1, statsService)

	req2 := httptest.NewRequest("GET", "/", nil)
	req2.RemoteAddr = "1.2.3.4:5678"
	f2 := GetRequestFingerprint(req2, statsService)

	if f1 != f2 {
		t.Error("fingerprints should match for same IPv4")
	}

	// X-Forwarded-For
	req3 := httptest.NewRequest("GET", "/", nil)
	req3.Header.Set("X-Forwarded-For", "5.6.7.8, 1.2.3.4")
	f3 := GetRequestFingerprint(req3, statsService)
	if f1 == f3 {
		t.Error("fingerprints should differ for different IPs")
	}

	// IPv6 masking: two hosts inside the same /64.
	req4 := httptest.NewRequest("GET", "/", nil)
	req4.RemoteAddr = "[2001:db8::1]:1234"
	f4 := GetRequestFingerprint(req4, statsService)

	req5 := httptest.NewRequest("GET", "/", nil)
	req5.RemoteAddr = "[2001:db8::2]:1234"
	f5 := GetRequestFingerprint(req5, statsService)

	if f4 != f5 {
		t.Error("fingerprints should match for same IPv6 /64 prefix")
	}
}

// TestSecurityMiddleware wraps a handler that always answers 200 and checks
// that the middleware answers 403 for a "Googlebot" User-Agent and for a
// "/.git/config" path, while an ordinary API request still gets 200.
func TestSecurityMiddleware(t *testing.T) {
	statsService := stats.NewService("test-hashes.json")
	handler := SecurityMiddleware(statsService)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// Test bot blocking
	req := httptest.NewRequest("GET", "/", nil)
	req.Header.Set("User-Agent", "Googlebot")
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)
	if rr.Code != http.StatusForbidden {
		t.Errorf("expected 403 for bot, got %d", rr.Code)
	}

	// Test forbidden pattern
	req = httptest.NewRequest("GET", "/.git/config", nil)
	rr = httptest.NewRecorder()
	handler.ServeHTTP(rr, req)
	if rr.Code != http.StatusForbidden {
		t.Errorf("expected 403 for forbidden pattern, got %d", rr.Code)
	}

	// Test normal request
	req = httptest.NewRequest("GET", "/api/software", nil)
	rr = httptest.NewRecorder()
	handler.ServeHTTP(rr, req)
	if rr.Code != http.StatusOK {
		t.Errorf("expected 200 for normal request, got %d", rr.Code)
	}
}

// TestIsPrivateIP_Extended runs IsPrivateIP over a table of loopback,
// RFC 1918, unspecified, IPv6 unique-local and public addresses and compares
// the result with the expected classification.
func TestIsPrivateIP_Extended(t *testing.T) {
	tests := []struct {
		ip        string
		isPrivate bool
	}{
		{"127.0.0.1", true},
		{"10.0.0.1", true},
		{"172.16.0.1", true},
		{"192.168.1.1", true},
		{"0.0.0.0", true},
		{"::1", true},
		{"::", true},
		{"fd00::1", true},
		{"8.8.8.8", false},
		{"1.1.1.1", false},
	}

	for _, tt := range tests {
		ip := net.ParseIP(tt.ip)
		if ip == nil {
			// A typo in the table must fail loudly rather than silently
			// exercising IsPrivateIP with a nil address.
			t.Errorf("test table contains unparsable IP %q", tt.ip)
			continue
		}
		if got := IsPrivateIP(ip); got != tt.isPrivate {
			t.Errorf("IsPrivateIP(%s) = %v; want %v", tt.ip, got, tt.isPrivate)
		}
	}
}

// TestGetSafeHTTPClient_BlocksPrivateIPs checks that requests made through
// the client returned by GetSafeHTTPClient to 127.0.0.1 and 0.0.0.0 fail with
// an error whose text contains "SSRF protection".
func TestGetSafeHTTPClient_BlocksPrivateIPs(t *testing.T) {
	client := GetSafeHTTPClient(1 * time.Second)

	for _, host := range []string{"127.0.0.1", "0.0.0.0"} {
		resp, err := client.Get("http://" + host + ":12345")
		if err == nil {
			// The request was not blocked; release the connection before
			// reporting the failure.
			resp.Body.Close()
			t.Errorf("Expected error for %s, got nil", host)
			continue
		}
		if !strings.Contains(err.Error(), "SSRF protection") {
			t.Errorf("Expected 'SSRF protection' error for %s, got: %v", host, err)
		}
	}
}