package security import ( "bytes" "context" "crypto/sha256" "encoding/hex" "fmt" "io" "log" "net" "net/http" "os" "strings" "sync/atomic" "syscall" "time" "software-station/internal/models" "software-station/internal/stats" "golang.org/x/time/rate" ) type ThrottledReader struct { R io.ReadCloser Limiter *rate.Limiter Fingerprint string Stats *stats.Service } func (tr *ThrottledReader) Read(p []byte) (n int, err error) { n, err = tr.R.Read(p) if n > 0 && tr.Limiter != nil && tr.Stats != nil { tr.Stats.KnownHashes.RLock() data, exists := tr.Stats.KnownHashes.Data[tr.Fingerprint] tr.Stats.KnownHashes.RUnlock() var total int64 if exists { total = atomic.AddInt64(&data.TotalBytes, int64(n)) } atomic.AddInt64(&tr.Stats.GlobalStats.TotalBytes, int64(n)) if total > HeavyDownloaderThreshold { tr.Limiter.SetLimit(HeavyDownloaderLimit) } if err := tr.Limiter.WaitN(context.Background(), n); err != nil { return n, err } } return n, err } func (tr *ThrottledReader) Close() error { return tr.R.Close() } type contextKey string const FingerprintKey contextKey = "fingerprint" func GetRequestFingerprint(r *http.Request, s *stats.Service) string { if f, ok := r.Context().Value(FingerprintKey).(string); ok { return f } remoteAddr := r.RemoteAddr if xff := r.Header.Get("X-Forwarded-For"); xff != "" { if comma := strings.IndexByte(xff, ','); comma != -1 { remoteAddr = strings.TrimSpace(xff[:comma]) } else { remoteAddr = strings.TrimSpace(xff) } } else if xri := r.Header.Get("X-Real-IP"); xri != "" { remoteAddr = strings.TrimSpace(xri) } ipStr, _, err := net.SplitHostPort(remoteAddr) if err != nil { ipStr = remoteAddr } ip := net.ParseIP(ipStr) if ip != nil { if ip.To4() == nil { ip = ip.Mask(net.CIDRMask(64, 128)) } ipStr = ip.String() } // Improve fingerprinting with more entropy ua := r.Header.Get("User-Agent") lang := r.Header.Get("Accept-Language") enc := r.Header.Get("Accept-Encoding") chUA := r.Header.Get("Sec-CH-UA") hash := sha256.New() hash.Write([]byte(ipStr)) hash.Write([]byte("|")) hash.Write([]byte(ua)) hash.Write([]byte("|")) hash.Write([]byte(lang)) hash.Write([]byte("|")) hash.Write([]byte(enc)) hash.Write([]byte("|")) hash.Write([]byte(chUA)) fingerprint := hex.EncodeToString(hash.Sum(nil)) s.KnownHashes.Lock() if _, exists := s.KnownHashes.Data[fingerprint]; !exists { s.KnownHashes.Data[fingerprint] = &models.FingerprintData{ Known: true, } s.SaveHashes() } s.KnownHashes.Unlock() return fingerprint } func IsPrivateIP(ip net.IP) bool { if os.Getenv("ALLOW_LOOPBACK") == "true" && ip.IsLoopback() { return false } if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() { return true } // Private IP ranges privateRanges := []struct { start net.IP end net.IP }{ {net.ParseIP("10.0.0.0"), net.ParseIP("10.255.255.255")}, {net.ParseIP("172.16.0.0"), net.ParseIP("172.31.255.255")}, {net.ParseIP("192.168.0.0"), net.ParseIP("192.168.255.255")}, {net.ParseIP("fd00::"), net.ParseIP("fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")}, } for _, r := range privateRanges { if bytes.Compare(ip, r.start) >= 0 && bytes.Compare(ip, r.end) <= 0 { return true } } return false } func GetSafeHTTPClient(timeout time.Duration) *http.Client { dialer := &net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, Control: func(network, address string, c syscall.RawConn) error { host, _, err := net.SplitHostPort(address) if err != nil { return err } ip := net.ParseIP(host) if ip != nil && IsPrivateIP(ip) { return fmt.Errorf("SSRF protection: forbidden IP %s", ip) } return nil }, } transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: dialer.DialContext, ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } return &http.Client{ Transport: transport, Timeout: timeout, } } func SecurityMiddleware(s *stats.Service, bb *BotBlocker) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() path := strings.ToLower(r.URL.Path) ua := r.UserAgent() fingerprint := GetRequestFingerprint(r, s) ctx := context.WithValue(r.Context(), FingerprintKey, fingerprint) r = r.WithContext(ctx) s.GlobalStats.Lock() s.GlobalStats.UniqueRequests[fingerprint] = true if !strings.HasPrefix(path, "/api") { s.GlobalStats.WebRequests[fingerprint] = true } s.GlobalStats.Unlock() defer func() { s.GlobalStats.Lock() s.GlobalStats.TotalResponseTime += time.Since(start) s.GlobalStats.TotalRequests++ s.GlobalStats.Unlock() }() if bb != nil && bb.IsBot(ua) { s.GlobalStats.Lock() s.GlobalStats.BlockedRequests[fingerprint] = true s.GlobalStats.Unlock() http.Error(w, "Bots are not allowed", http.StatusForbidden) return } for _, pattern := range ForbiddenPatterns { if strings.Contains(path, pattern) { s.GlobalStats.Lock() s.GlobalStats.BlockedRequests[fingerprint] = true s.GlobalStats.Unlock() log.Printf("Blocked suspicious request: %s from %s (%s)", r.URL.String(), r.RemoteAddr, r.UserAgent()) http.Error(w, "Forbidden", http.StatusForbidden) return } } next.ServeHTTP(w, r) }) } }