- Updated `GetRequestFingerprint` to include additional headers (`Sec-CH-UA-Platform`, `Sec-CH-UA-Mobile`) and the UID cookie for improved uniqueness.
- Modified `SecurityMiddleware` to set a new UID cookie if one is not present, enhancing user tracking and security.
- Adjusted test cases to reflect the changes in fingerprinting logic and to ensure accurate validation of request parameters.
193 lines
5.1 KiB
Go
193 lines
5.1 KiB
Go
package security
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"software-station/internal/models"
|
|
"software-station/internal/stats"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
func TestThrottledReader(t *testing.T) {
|
|
content := []byte("hello world")
|
|
inner := io.NopCloser(bytes.NewReader(content))
|
|
limiter := rate.NewLimiter(rate.Limit(100), 100)
|
|
fp := "test-fp"
|
|
|
|
statsService := stats.NewService("test-hashes.json")
|
|
statsService.KnownHashes.Lock()
|
|
statsService.KnownHashes.Data[fp] = &models.FingerprintData{Known: true}
|
|
statsService.KnownHashes.Unlock()
|
|
|
|
tr := &ThrottledReader{
|
|
R: inner,
|
|
Limiter: limiter,
|
|
Fingerprint: fp,
|
|
Stats: statsService,
|
|
}
|
|
|
|
p := make([]byte, 5)
|
|
n, err := tr.Read(p)
|
|
if err != nil {
|
|
t.Fatalf("Read failed: %v", err)
|
|
}
|
|
if n != 5 {
|
|
t.Errorf("expected 5 bytes, got %d", n)
|
|
}
|
|
|
|
statsService.KnownHashes.RLock()
|
|
data, ok := statsService.KnownHashes.Data["test-fp"]
|
|
statsService.KnownHashes.RUnlock()
|
|
if !ok {
|
|
t.Fatal("fingerprint data not found")
|
|
}
|
|
total := data.TotalBytes
|
|
if total != 5 {
|
|
t.Errorf("expected 5 bytes in stats, got %d", total)
|
|
}
|
|
|
|
tr.Close()
|
|
}
|
|
|
|
func TestGetRequestFingerprint(t *testing.T) {
|
|
statsService := stats.NewService("test-hashes.json")
|
|
|
|
// Same IP, same headers
|
|
req1 := httptest.NewRequest("GET", "/", nil)
|
|
req1.RemoteAddr = "1.2.3.4:1234"
|
|
req1.Header.Set("User-Agent", "Mozilla/5.0")
|
|
req1.Header.Set("Sec-CH-UA", `"Google Chrome";v="123"`)
|
|
f1 := GetRequestFingerprint(req1, statsService)
|
|
|
|
req2 := httptest.NewRequest("GET", "/", nil)
|
|
req2.RemoteAddr = "1.2.3.4:5678"
|
|
req2.Header.Set("User-Agent", "Mozilla/5.0")
|
|
req2.Header.Set("Sec-CH-UA", `"Google Chrome";v="123"`)
|
|
f2 := GetRequestFingerprint(req2, statsService)
|
|
if f1 != f2 {
|
|
t.Error("fingerprints should match for same parameters")
|
|
}
|
|
|
|
// Different UID cookie
|
|
req3 := httptest.NewRequest("GET", "/", nil)
|
|
req3.RemoteAddr = "1.2.3.4:1234"
|
|
req3.Header.Set("User-Agent", "Mozilla/5.0")
|
|
req3.Header.Set("Sec-CH-UA", `"Google Chrome";v="123"`)
|
|
req3.AddCookie(&http.Cookie{Name: "_ss_uid", Value: "uid1"})
|
|
f3 := GetRequestFingerprint(req3, statsService)
|
|
if f1 == f3 {
|
|
t.Error("fingerprints should differ with different UID cookies")
|
|
}
|
|
|
|
// Different Client Hint
|
|
req4 := httptest.NewRequest("GET", "/", nil)
|
|
req4.RemoteAddr = "1.2.3.4:1234"
|
|
req4.Header.Set("User-Agent", "Mozilla/5.0")
|
|
req4.Header.Set("Sec-CH-UA", `"Brave";v="123"`)
|
|
f4 := GetRequestFingerprint(req4, statsService)
|
|
if f1 == f4 {
|
|
t.Error("fingerprints should differ with different Client Hints")
|
|
}
|
|
}
|
|
|
|
func TestSecurityMiddleware(t *testing.T) {
|
|
statsService := stats.NewService("test-hashes.json")
|
|
botBlocker := NewBotBlocker("")
|
|
handler := SecurityMiddleware(statsService, botBlocker)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Test bot blocking
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
req.Header.Set("User-Agent", "Googlebot")
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusForbidden {
|
|
t.Errorf("expected 403 for bot, got %d", rr.Code)
|
|
}
|
|
|
|
// Test forbidden pattern
|
|
req = httptest.NewRequest("GET", "/.git/config", nil)
|
|
rr = httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusForbidden {
|
|
t.Errorf("expected 403 for forbidden pattern, got %d", rr.Code)
|
|
}
|
|
|
|
// Test normal request and cookie setting
|
|
req = httptest.NewRequest("GET", "/api/software", nil)
|
|
rr = httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected 200 for normal request, got %d", rr.Code)
|
|
}
|
|
cookieFound := false
|
|
for _, c := range rr.Result().Cookies() {
|
|
if c.Name == "_ss_uid" {
|
|
cookieFound = true
|
|
break
|
|
}
|
|
}
|
|
if !cookieFound {
|
|
t.Error("expected _ss_uid cookie to be set")
|
|
}
|
|
}
|
|
|
|
func TestIsPrivateIP_Extended(t *testing.T) {
|
|
tests := []struct {
|
|
ip string
|
|
isPrivate bool
|
|
}{
|
|
{"127.0.0.1", true},
|
|
{"10.0.0.1", true},
|
|
{"172.16.0.1", true},
|
|
{"192.168.1.1", true},
|
|
{"0.0.0.0", true},
|
|
{"::1", true},
|
|
{"::", true},
|
|
{"fd00::1", true},
|
|
{"8.8.8.8", false},
|
|
{"1.1.1.1", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
ip := net.ParseIP(tt.ip)
|
|
if got := IsPrivateIP(ip); got != tt.isPrivate {
|
|
t.Errorf("IsPrivateIP(%s) = %v; want %v", tt.ip, got, tt.isPrivate)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGetSafeHTTPClient_BlocksPrivateIPs(t *testing.T) {
|
|
client := GetSafeHTTPClient(1 * time.Second)
|
|
|
|
// Test 127.0.0.1
|
|
_, err := client.Get("http://127.0.0.1:12345")
|
|
if err == nil {
|
|
t.Error("Expected error for 127.0.0.1, got nil")
|
|
} else if !strings.Contains(err.Error(), "SSRF protection") {
|
|
t.Logf("Got error for 127.0.0.1: %v", err)
|
|
if !strings.Contains(err.Error(), "SSRF protection") {
|
|
t.Errorf("Expected 'SSRF protection' error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// Test 0.0.0.0
|
|
_, err = client.Get("http://0.0.0.0:12345")
|
|
if err == nil {
|
|
t.Error("Expected error for 0.0.0.0, got nil")
|
|
} else if !strings.Contains(err.Error(), "SSRF protection") {
|
|
t.Logf("Got error for 0.0.0.0: %v", err)
|
|
if !strings.Contains(err.Error(), "SSRF protection") {
|
|
t.Errorf("Expected 'SSRF protection' error, got: %v", err)
|
|
}
|
|
}
|
|
}
|