Files
software-station/internal/security/security_test.go
2025-12-27 02:57:25 -06:00

177 lines
4.5 KiB
Go

package security
import (
"bytes"
"io"
"net"
"net/http"
"net/http/httptest"
"software-station/internal/models"
"software-station/internal/stats"
"strings"
"testing"
"time"
"golang.org/x/time/rate"
)
// TestThrottledReader checks that a ThrottledReader passes bytes through from
// its underlying reader and credits the bytes read to the fingerprint's
// TotalBytes counter in the stats service.
func TestThrottledReader(t *testing.T) {
	content := []byte("hello world")
	inner := io.NopCloser(bytes.NewReader(content))
	limiter := rate.NewLimiter(rate.Limit(100), 100)
	fp := "test-fp"

	statsService := stats.NewService("test-hashes.json")
	// Pre-register the fingerprint so the reader has an entry to update.
	statsService.KnownHashes.Lock()
	statsService.KnownHashes.Data[fp] = &models.FingerprintData{Known: true}
	statsService.KnownHashes.Unlock()

	tr := &ThrottledReader{
		R:           inner,
		Limiter:     limiter,
		Fingerprint: fp,
		Stats:       statsService,
	}
	// Deferred so the reader is closed even when a t.Fatal aborts the test early.
	defer tr.Close()

	p := make([]byte, 5)
	n, err := tr.Read(p)
	if err != nil {
		t.Fatalf("Read failed: %v", err)
	}
	if n != 5 {
		t.Errorf("expected 5 bytes, got %d", n)
	}
	// The throttled reader must not alter the data it forwards.
	if got, want := p[:n], content[:n]; !bytes.Equal(got, want) {
		t.Errorf("expected data %q, got %q", want, got)
	}

	statsService.KnownHashes.RLock()
	data, ok := statsService.KnownHashes.Data[fp]
	statsService.KnownHashes.RUnlock()
	if !ok {
		t.Fatal("fingerprint data not found")
	}
	total := data.TotalBytes
	if total != 5 {
		t.Errorf("expected 5 bytes in stats, got %d", total)
	}
}
// TestGetRequestFingerprint covers three cases for GetRequestFingerprint:
//   - requests from one IPv4 address on different ports produce the same fingerprint;
//   - a request carrying an X-Forwarded-For header (left with httptest's default
//     RemoteAddr) produces a different fingerprint from the first one;
//   - two IPv6 addresses sharing a /64 prefix produce the same fingerprint.
func TestGetRequestFingerprint(t *testing.T) {
	statsService := stats.NewService("test-hashes.json")

	// fingerprintFor builds a GET / request. An empty remoteAddr keeps the
	// httptest default, and an empty forwardedFor leaves the header unset.
	fingerprintFor := func(remoteAddr, forwardedFor string) string {
		r := httptest.NewRequest("GET", "/", nil)
		if remoteAddr != "" {
			r.RemoteAddr = remoteAddr
		}
		if forwardedFor != "" {
			r.Header.Set("X-Forwarded-For", forwardedFor)
		}
		return GetRequestFingerprint(r, statsService)
	}

	// Same IPv4 host, different source ports.
	v4a := fingerprintFor("1.2.3.4:1234", "")
	v4b := fingerprintFor("1.2.3.4:5678", "")
	if v4a != v4b {
		t.Error("fingerprints should match for same IPv4")
	}

	// Proxied request.
	proxied := fingerprintFor("", "5.6.7.8, 1.2.3.4")
	if proxied == v4a {
		t.Error("fingerprints should differ for different IPs")
	}

	// Two hosts inside the same IPv6 /64.
	v6a := fingerprintFor("[2001:db8::1]:1234", "")
	v6b := fingerprintFor("[2001:db8::2]:1234", "")
	if v6a != v6b {
		t.Error("fingerprints should match for same IPv6 /64 prefix")
	}
}
// TestSecurityMiddleware sends requests through SecurityMiddleware wrapped
// around a handler that always replies 200. It expects 403 for the
// "Googlebot" User-Agent and for the "/.git/config" path, and 200 for
// "/api/software".
func TestSecurityMiddleware(t *testing.T) {
	statsService := stats.NewService("test-hashes.json")
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler := SecurityMiddleware(statsService)(next)

	// statusFor serves req through the middleware and reports the status code.
	statusFor := func(req *http.Request) int {
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)
		return rec.Code
	}

	botReq := httptest.NewRequest("GET", "/", nil)
	botReq.Header.Set("User-Agent", "Googlebot")
	if code := statusFor(botReq); code != http.StatusForbidden {
		t.Errorf("expected 403 for bot, got %d", code)
	}

	if code := statusFor(httptest.NewRequest("GET", "/.git/config", nil)); code != http.StatusForbidden {
		t.Errorf("expected 403 for forbidden pattern, got %d", code)
	}

	if code := statusFor(httptest.NewRequest("GET", "/api/software", nil)); code != http.StatusOK {
		t.Errorf("expected 200 for normal request, got %d", code)
	}
}
// TestIsPrivateIP_Extended is a table of addresses with the result expected
// from IsPrivateIP. It covers IPv4 loopback, the RFC 1918 ranges, the
// unspecified addresses, IPv6 loopback and unique-local, and two public
// resolvers.
func TestIsPrivateIP_Extended(t *testing.T) {
	cases := []struct {
		addr string
		want bool
	}{
		{addr: "127.0.0.1", want: true},
		{addr: "10.0.0.1", want: true},
		{addr: "172.16.0.1", want: true},
		{addr: "192.168.1.1", want: true},
		{addr: "0.0.0.0", want: true},
		{addr: "::1", want: true},
		{addr: "::", want: true},
		{addr: "fd00::1", want: true},
		{addr: "8.8.8.8", want: false},
		{addr: "1.1.1.1", want: false},
	}
	for i := range cases {
		c := cases[i]
		got := IsPrivateIP(net.ParseIP(c.addr))
		if got == c.want {
			continue
		}
		t.Errorf("IsPrivateIP(%s) = %v; want %v", c.addr, got, c.want)
	}
}
// TestGetSafeHTTPClient_BlocksPrivateIPs checks that the client returned by
// GetSafeHTTPClient refuses to connect to loopback and unspecified addresses.
// Each attempt must fail with an error that mentions "SSRF protection".
func TestGetSafeHTTPClient_BlocksPrivateIPs(t *testing.T) {
	client := GetSafeHTTPClient(1 * time.Second)

	for _, host := range []string{"127.0.0.1", "0.0.0.0"} {
		resp, err := client.Get("http://" + host + ":12345")
		if err == nil {
			// Unexpected success: close the body so the connection is not leaked.
			resp.Body.Close()
			t.Errorf("Expected error for %s, got nil", host)
			continue
		}
		if !strings.Contains(err.Error(), "SSRF protection") {
			t.Errorf("Expected 'SSRF protection' error for %s, got: %v", host, err)
		}
	}
}