- Implemented structured RSS feed generation using XML encoding.
- Enhanced URL registration by incorporating a random salt for hash generation.
- Introduced a bot blocker to the security middleware for improved bot detection.
- Updated security middleware to utilize the new bot blocker and added more entropy to request fingerprinting.
178 lines
4.5 KiB
Go
178 lines
4.5 KiB
Go
package security
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"software-station/internal/models"
|
|
"software-station/internal/stats"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
func TestThrottledReader(t *testing.T) {
|
|
content := []byte("hello world")
|
|
inner := io.NopCloser(bytes.NewReader(content))
|
|
limiter := rate.NewLimiter(rate.Limit(100), 100)
|
|
fp := "test-fp"
|
|
|
|
statsService := stats.NewService("test-hashes.json")
|
|
statsService.KnownHashes.Lock()
|
|
statsService.KnownHashes.Data[fp] = &models.FingerprintData{Known: true}
|
|
statsService.KnownHashes.Unlock()
|
|
|
|
tr := &ThrottledReader{
|
|
R: inner,
|
|
Limiter: limiter,
|
|
Fingerprint: fp,
|
|
Stats: statsService,
|
|
}
|
|
|
|
p := make([]byte, 5)
|
|
n, err := tr.Read(p)
|
|
if err != nil {
|
|
t.Fatalf("Read failed: %v", err)
|
|
}
|
|
if n != 5 {
|
|
t.Errorf("expected 5 bytes, got %d", n)
|
|
}
|
|
|
|
statsService.KnownHashes.RLock()
|
|
data, ok := statsService.KnownHashes.Data["test-fp"]
|
|
statsService.KnownHashes.RUnlock()
|
|
if !ok {
|
|
t.Fatal("fingerprint data not found")
|
|
}
|
|
total := data.TotalBytes
|
|
if total != 5 {
|
|
t.Errorf("expected 5 bytes in stats, got %d", total)
|
|
}
|
|
|
|
tr.Close()
|
|
}
|
|
|
|
func TestGetRequestFingerprint(t *testing.T) {
|
|
statsService := stats.NewService("test-hashes.json")
|
|
|
|
// IPv4
|
|
req1 := httptest.NewRequest("GET", "/", nil)
|
|
req1.RemoteAddr = "1.2.3.4:1234"
|
|
f1 := GetRequestFingerprint(req1, statsService)
|
|
|
|
req2 := httptest.NewRequest("GET", "/", nil)
|
|
req2.RemoteAddr = "1.2.3.4:5678"
|
|
f2 := GetRequestFingerprint(req2, statsService)
|
|
if f1 != f2 {
|
|
t.Error("fingerprints should match for same IPv4")
|
|
}
|
|
|
|
// X-Forwarded-For
|
|
req3 := httptest.NewRequest("GET", "/", nil)
|
|
req3.Header.Set("X-Forwarded-For", "5.6.7.8, 1.2.3.4")
|
|
f3 := GetRequestFingerprint(req3, statsService)
|
|
if f1 == f3 {
|
|
t.Error("fingerprints should differ for different IPs")
|
|
}
|
|
|
|
// IPv6 masking
|
|
req4 := httptest.NewRequest("GET", "/", nil)
|
|
req4.RemoteAddr = "[2001:db8::1]:1234"
|
|
f4 := GetRequestFingerprint(req4, statsService)
|
|
|
|
req5 := httptest.NewRequest("GET", "/", nil)
|
|
req5.RemoteAddr = "[2001:db8::2]:1234"
|
|
f5 := GetRequestFingerprint(req5, statsService)
|
|
if f4 != f5 {
|
|
t.Error("fingerprints should match for same IPv6 /64 prefix")
|
|
}
|
|
}
|
|
|
|
func TestSecurityMiddleware(t *testing.T) {
|
|
statsService := stats.NewService("test-hashes.json")
|
|
botBlocker := NewBotBlocker("")
|
|
handler := SecurityMiddleware(statsService, botBlocker)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
// Test bot blocking
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
req.Header.Set("User-Agent", "Googlebot")
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusForbidden {
|
|
t.Errorf("expected 403 for bot, got %d", rr.Code)
|
|
}
|
|
|
|
// Test forbidden pattern
|
|
req = httptest.NewRequest("GET", "/.git/config", nil)
|
|
rr = httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusForbidden {
|
|
t.Errorf("expected 403 for forbidden pattern, got %d", rr.Code)
|
|
}
|
|
|
|
// Test normal request
|
|
req = httptest.NewRequest("GET", "/api/software", nil)
|
|
rr = httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected 200 for normal request, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestIsPrivateIP_Extended(t *testing.T) {
|
|
tests := []struct {
|
|
ip string
|
|
isPrivate bool
|
|
}{
|
|
{"127.0.0.1", true},
|
|
{"10.0.0.1", true},
|
|
{"172.16.0.1", true},
|
|
{"192.168.1.1", true},
|
|
{"0.0.0.0", true},
|
|
{"::1", true},
|
|
{"::", true},
|
|
{"fd00::1", true},
|
|
{"8.8.8.8", false},
|
|
{"1.1.1.1", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
ip := net.ParseIP(tt.ip)
|
|
if got := IsPrivateIP(ip); got != tt.isPrivate {
|
|
t.Errorf("IsPrivateIP(%s) = %v; want %v", tt.ip, got, tt.isPrivate)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGetSafeHTTPClient_BlocksPrivateIPs(t *testing.T) {
|
|
client := GetSafeHTTPClient(1 * time.Second)
|
|
|
|
// Test 127.0.0.1
|
|
_, err := client.Get("http://127.0.0.1:12345")
|
|
if err == nil {
|
|
t.Error("Expected error for 127.0.0.1, got nil")
|
|
} else if !strings.Contains(err.Error(), "SSRF protection") {
|
|
t.Logf("Got error for 127.0.0.1: %v", err)
|
|
if !strings.Contains(err.Error(), "SSRF protection") {
|
|
t.Errorf("Expected 'SSRF protection' error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// Test 0.0.0.0
|
|
_, err = client.Get("http://0.0.0.0:12345")
|
|
if err == nil {
|
|
t.Error("Expected error for 0.0.0.0, got nil")
|
|
} else if !strings.Contains(err.Error(), "SSRF protection") {
|
|
t.Logf("Got error for 0.0.0.0: %v", err)
|
|
if !strings.Contains(err.Error(), "SSRF protection") {
|
|
t.Errorf("Expected 'SSRF protection' error, got: %v", err)
|
|
}
|
|
}
|
|
}
|