Files
software-station/internal/security/security_test.go
Sudo-Ivan bd7fd93a00 Improve request fingerprinting and security middleware
- Updated GetRequestFingerprint to include additional headers (Sec-CH-UA-Platform, Sec-CH-UA-Mobile) and UID cookie for improved uniqueness.
- Modified SecurityMiddleware to set a new UID cookie if not present, enhancing user tracking and security.
- Adjusted test cases to reflect changes in fingerprinting logic and ensure accurate validation of request parameters.
2025-12-27 03:35:36 -06:00

193 lines
5.1 KiB
Go

package security
import (
"bytes"
"io"
"net"
"net/http"
"net/http/httptest"
"software-station/internal/models"
"software-station/internal/stats"
"strings"
"testing"
"time"
"golang.org/x/time/rate"
)
func TestThrottledReader(t *testing.T) {
content := []byte("hello world")
inner := io.NopCloser(bytes.NewReader(content))
limiter := rate.NewLimiter(rate.Limit(100), 100)
fp := "test-fp"
statsService := stats.NewService("test-hashes.json")
statsService.KnownHashes.Lock()
statsService.KnownHashes.Data[fp] = &models.FingerprintData{Known: true}
statsService.KnownHashes.Unlock()
tr := &ThrottledReader{
R: inner,
Limiter: limiter,
Fingerprint: fp,
Stats: statsService,
}
p := make([]byte, 5)
n, err := tr.Read(p)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if n != 5 {
t.Errorf("expected 5 bytes, got %d", n)
}
statsService.KnownHashes.RLock()
data, ok := statsService.KnownHashes.Data["test-fp"]
statsService.KnownHashes.RUnlock()
if !ok {
t.Fatal("fingerprint data not found")
}
total := data.TotalBytes
if total != 5 {
t.Errorf("expected 5 bytes in stats, got %d", total)
}
tr.Close()
}
// TestGetRequestFingerprint checks how GetRequestFingerprint reacts to changes
// in a request: a different source port must not change the fingerprint, while
// adding a _ss_uid cookie or changing the Sec-CH-UA header must.
func TestGetRequestFingerprint(t *testing.T) {
	svc := stats.NewService("test-hashes.json")

	const chromeHint = `"Google Chrome";v="123"`

	// newReq builds a GET / request with a fixed User-Agent and the given
	// remote address and Sec-CH-UA header value.
	newReq := func(remoteAddr, clientHint string) *http.Request {
		r := httptest.NewRequest(http.MethodGet, "/", nil)
		r.RemoteAddr = remoteAddr
		r.Header.Set("User-Agent", "Mozilla/5.0")
		r.Header.Set("Sec-CH-UA", clientHint)
		return r
	}

	// Baseline, then the same host and headers arriving from another port.
	base := GetRequestFingerprint(newReq("1.2.3.4:1234", chromeHint), svc)
	otherPort := GetRequestFingerprint(newReq("1.2.3.4:5678", chromeHint), svc)
	if base != otherPort {
		t.Error("fingerprints should match for same parameters")
	}

	// Baseline request plus a UID cookie.
	cookieReq := newReq("1.2.3.4:1234", chromeHint)
	cookieReq.AddCookie(&http.Cookie{Name: "_ss_uid", Value: "uid1"})
	withCookie := GetRequestFingerprint(cookieReq, svc)
	if withCookie == base {
		t.Error("fingerprints should differ with different UID cookies")
	}

	// Baseline request with a different client hint.
	otherHint := GetRequestFingerprint(newReq("1.2.3.4:1234", `"Brave";v="123"`), svc)
	if otherHint == base {
		t.Error("fingerprints should differ with different Client Hints")
	}
}
// TestSecurityMiddleware drives three requests through SecurityMiddleware
// wrapped around a handler that always answers 200: a request with a
// "Googlebot" User-Agent and a request for /.git/config must both get 403,
// and a plain API request must get 200 together with a _ss_uid cookie.
func TestSecurityMiddleware(t *testing.T) {
	svc := stats.NewService("test-hashes.json")
	blocker := NewBotBlocker("")

	alwaysOK := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler := SecurityMiddleware(svc, blocker)(alwaysOK)

	// serve runs one request through the middleware and returns the recorder.
	serve := func(r *http.Request) *httptest.ResponseRecorder {
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, r)
		return rec
	}

	// Bot User-Agent.
	botReq := httptest.NewRequest(http.MethodGet, "/", nil)
	botReq.Header.Set("User-Agent", "Googlebot")
	if rec := serve(botReq); rec.Code != http.StatusForbidden {
		t.Errorf("expected 403 for bot, got %d", rec.Code)
	}

	// Forbidden path pattern.
	if rec := serve(httptest.NewRequest(http.MethodGet, "/.git/config", nil)); rec.Code != http.StatusForbidden {
		t.Errorf("expected 403 for forbidden pattern, got %d", rec.Code)
	}

	// Ordinary request: passes through and receives the UID cookie.
	rec := serve(httptest.NewRequest(http.MethodGet, "/api/software", nil))
	if rec.Code != http.StatusOK {
		t.Errorf("expected 200 for normal request, got %d", rec.Code)
	}

	hasUID := false
	for _, ck := range rec.Result().Cookies() {
		hasUID = hasUID || ck.Name == "_ss_uid"
	}
	if !hasUID {
		t.Error("expected _ss_uid cookie to be set")
	}
}
// TestIsPrivateIP_Extended runs IsPrivateIP over a fixed table of IPv4 and
// IPv6 literals and compares each result with the expected classification.
func TestIsPrivateIP_Extended(t *testing.T) {
	type ipCase struct {
		addr string
		want bool
	}

	cases := []ipCase{
		// Addresses expected to be reported as private.
		{addr: "127.0.0.1", want: true},
		{addr: "10.0.0.1", want: true},
		{addr: "172.16.0.1", want: true},
		{addr: "192.168.1.1", want: true},
		{addr: "0.0.0.0", want: true},
		{addr: "::1", want: true},
		{addr: "::", want: true},
		{addr: "fd00::1", want: true},
		// Addresses expected to be reported as public.
		{addr: "8.8.8.8", want: false},
		{addr: "1.1.1.1", want: false},
	}

	for i := range cases {
		c := cases[i]
		got := IsPrivateIP(net.ParseIP(c.addr))
		if got == c.want {
			continue
		}
		t.Errorf("IsPrivateIP(%s) = %v; want %v", c.addr, got, c.want)
	}
}
// TestGetSafeHTTPClient_BlocksPrivateIPs issues GET requests to 127.0.0.1 and
// 0.0.0.0 through the client returned by GetSafeHTTPClient and expects each
// one to fail with an error whose text contains "SSRF protection".
func TestGetSafeHTTPClient_BlocksPrivateIPs(t *testing.T) {
	client := GetSafeHTTPClient(1 * time.Second)

	for _, host := range []string{"127.0.0.1", "0.0.0.0"} {
		resp, err := client.Get("http://" + host + ":12345")
		if err == nil {
			// The request was not blocked: release the response before reporting.
			resp.Body.Close()
			t.Errorf("Expected error for %s, got nil", host)
			continue
		}
		// Any other failure (e.g. a refused connection) means the request got
		// past the guard, so it is reported as a test failure too.
		if !strings.Contains(err.Error(), "SSRF protection") {
			t.Logf("Got error for %s: %v", host, err)
			t.Errorf("Expected 'SSRF protection' error, got: %v", err)
		}
	}
}