- Updated GetRequestFingerprint to include additional client-hint headers (Sec-CH-UA-Platform, Sec-CH-UA-Mobile) and the `_ss_uid` cookie for improved uniqueness. - Modified SecurityMiddleware to set a new `_ss_uid` cookie when one is not present, enhancing user tracking and security. - Adjusted test cases to reflect the changes in fingerprinting logic and to ensure accurate validation of request parameters.
262 lines
6.3 KiB
Go
262 lines
6.3 KiB
Go
package security
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"sync/atomic"
|
|
"syscall"
|
|
"time"
|
|
|
|
"software-station/internal/models"
|
|
"software-station/internal/stats"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
type ThrottledReader struct {
|
|
R io.ReadCloser
|
|
Limiter *rate.Limiter
|
|
Fingerprint string
|
|
Stats *stats.Service
|
|
}
|
|
|
|
func (tr *ThrottledReader) Read(p []byte) (n int, err error) {
|
|
n, err = tr.R.Read(p)
|
|
if n > 0 && tr.Limiter != nil && tr.Stats != nil {
|
|
tr.Stats.KnownHashes.RLock()
|
|
data, exists := tr.Stats.KnownHashes.Data[tr.Fingerprint]
|
|
tr.Stats.KnownHashes.RUnlock()
|
|
|
|
var total int64
|
|
if exists {
|
|
total = atomic.AddInt64(&data.TotalBytes, int64(n))
|
|
}
|
|
atomic.AddInt64(&tr.Stats.GlobalStats.TotalBytes, int64(n))
|
|
|
|
if total > HeavyDownloaderThreshold {
|
|
tr.Limiter.SetLimit(HeavyDownloaderLimit)
|
|
}
|
|
|
|
if err := tr.Limiter.WaitN(context.Background(), n); err != nil {
|
|
return n, err
|
|
}
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
func (tr *ThrottledReader) Close() error {
|
|
return tr.R.Close()
|
|
}
|
|
|
|
// contextKey is an unexported key type for values this package stores in a
// request context; a distinct type keeps them from colliding with keys
// defined by other packages.
type contextKey string

const (
	// FingerprintKey is the context key under which SecurityMiddleware stores
	// the request fingerprint. GetRequestFingerprint returns the stored value
	// instead of recomputing it when the key is present.
	FingerprintKey contextKey = "fingerprint"
)
|
|
|
|
func GetRequestFingerprint(r *http.Request, s *stats.Service) string {
|
|
if f, ok := r.Context().Value(FingerprintKey).(string); ok {
|
|
return f
|
|
}
|
|
|
|
remoteAddr := r.RemoteAddr
|
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
|
if comma := strings.IndexByte(xff, ','); comma != -1 {
|
|
remoteAddr = strings.TrimSpace(xff[:comma])
|
|
} else {
|
|
remoteAddr = strings.TrimSpace(xff)
|
|
}
|
|
} else if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
|
remoteAddr = strings.TrimSpace(xri)
|
|
}
|
|
|
|
ipStr, _, err := net.SplitHostPort(remoteAddr)
|
|
if err != nil {
|
|
ipStr = remoteAddr
|
|
}
|
|
|
|
ip := net.ParseIP(ipStr)
|
|
if ip != nil {
|
|
if ip.To4() == nil {
|
|
ip = ip.Mask(net.CIDRMask(64, 128))
|
|
}
|
|
ipStr = ip.String()
|
|
}
|
|
|
|
ua := r.Header.Get("User-Agent")
|
|
chUA := r.Header.Get("Sec-CH-UA")
|
|
chPlatform := r.Header.Get("Sec-CH-UA-Platform")
|
|
chMobile := r.Header.Get("Sec-CH-UA-Mobile")
|
|
|
|
hash := sha256.New()
|
|
hash.Write([]byte("v2|"))
|
|
hash.Write([]byte(ipStr))
|
|
hash.Write([]byte("|"))
|
|
hash.Write([]byte(ua))
|
|
hash.Write([]byte("|"))
|
|
hash.Write([]byte(chUA))
|
|
hash.Write([]byte("|"))
|
|
hash.Write([]byte(chPlatform))
|
|
hash.Write([]byte("|"))
|
|
hash.Write([]byte(chMobile))
|
|
|
|
if r.TLS != nil {
|
|
hash.Write([]byte(fmt.Sprintf("|%d|%d", r.TLS.Version, r.TLS.CipherSuite)))
|
|
}
|
|
|
|
if cookie, err := r.Cookie("_ss_uid"); err == nil {
|
|
hash.Write([]byte("|"))
|
|
hash.Write([]byte(cookie.Value))
|
|
}
|
|
|
|
fingerprint := hex.EncodeToString(hash.Sum(nil))
|
|
|
|
s.KnownHashes.Lock()
|
|
if _, exists := s.KnownHashes.Data[fingerprint]; !exists {
|
|
if len(s.KnownHashes.Data) < 10000 {
|
|
s.KnownHashes.Data[fingerprint] = &models.FingerprintData{
|
|
Known: true,
|
|
}
|
|
s.SaveHashes()
|
|
}
|
|
}
|
|
s.KnownHashes.Unlock()
|
|
|
|
return fingerprint
|
|
}
|
|
|
|
// IsPrivateIP reports whether ip is an address that outbound requests must
// not reach. That covers:
//   - loopback, link-local unicast, link-local multicast and unspecified
//     addresses;
//   - 0.0.0.0/8;
//   - the RFC 1918 ranges (10/8, 172.16/12, 192.168/16);
//   - the RFC 6598 shared range 100.64.0.0/10;
//   - IPv6 unique-local addresses fc00::/7.
//
// If the ALLOW_LOOPBACK environment variable is "true", loopback addresses
// are reported as not private.
//
// Both the 4-byte and 16-byte representations of IPv4 addresses are handled;
// a nil or malformed ip is reported as not private.
func IsPrivateIP(ip net.IP) bool {
	if ip.IsLoopback() {
		return os.Getenv("ALLOW_LOOPBACK") != "true"
	}
	if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
		return true
	}

	// net.ParseIP yields 16-byte bounds, so normalise the input to the same
	// form; a 4-byte IPv4 value would otherwise never fall inside a range.
	ip16 := ip.To16()
	if ip16 == nil {
		return false
	}

	privateRanges := [...]struct {
		start net.IP
		end   net.IP
	}{
		{net.ParseIP("0.0.0.0"), net.ParseIP("0.255.255.255")},
		{net.ParseIP("10.0.0.0"), net.ParseIP("10.255.255.255")},
		{net.ParseIP("100.64.0.0"), net.ParseIP("100.127.255.255")},
		{net.ParseIP("172.16.0.0"), net.ParseIP("172.31.255.255")},
		{net.ParseIP("192.168.0.0"), net.ParseIP("192.168.255.255")},
		{net.ParseIP("fc00::"), net.ParseIP("fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")},
	}

	for _, rng := range privateRanges {
		if bytes.Compare(ip16, rng.start) >= 0 && bytes.Compare(ip16, rng.end) <= 0 {
			return true
		}
	}
	return false
}
|
|
|
|
func GetSafeHTTPClient(timeout time.Duration) *http.Client {
|
|
dialer := &net.Dialer{
|
|
Timeout: 30 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
Control: func(network, address string, c syscall.RawConn) error {
|
|
host, _, err := net.SplitHostPort(address)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ip := net.ParseIP(host)
|
|
if ip != nil && IsPrivateIP(ip) {
|
|
return fmt.Errorf("SSRF protection: forbidden IP %s", ip)
|
|
}
|
|
return nil
|
|
},
|
|
}
|
|
|
|
transport := &http.Transport{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
DialContext: dialer.DialContext,
|
|
ForceAttemptHTTP2: true,
|
|
MaxIdleConns: 100,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
}
|
|
|
|
return &http.Client{
|
|
Transport: transport,
|
|
Timeout: timeout,
|
|
}
|
|
}
|
|
|
|
func SecurityMiddleware(s *stats.Service, bb *BotBlocker) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
path := strings.ToLower(r.URL.Path)
|
|
ua := r.UserAgent()
|
|
|
|
if _, err := r.Cookie("_ss_uid"); err != nil {
|
|
uid := make([]byte, 16)
|
|
if _, err := rand.Read(uid); err == nil {
|
|
uidStr := hex.EncodeToString(uid)
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "_ss_uid",
|
|
Value: uidStr,
|
|
Path: "/",
|
|
Expires: time.Now().Add(365 * 24 * time.Hour),
|
|
HttpOnly: true,
|
|
Secure: r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https",
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
r.AddCookie(&http.Cookie{Name: "_ss_uid", Value: uidStr})
|
|
}
|
|
}
|
|
|
|
fingerprint := GetRequestFingerprint(r, s)
|
|
|
|
ctx := context.WithValue(r.Context(), FingerprintKey, fingerprint)
|
|
r = r.WithContext(ctx)
|
|
|
|
s.GlobalStats.Lock()
|
|
s.GlobalStats.UniqueRequests[fingerprint] = true
|
|
if !strings.HasPrefix(path, "/api") {
|
|
s.GlobalStats.WebRequests[fingerprint] = true
|
|
}
|
|
s.GlobalStats.Unlock()
|
|
|
|
defer func() {
|
|
s.GlobalStats.Lock()
|
|
s.GlobalStats.TotalResponseTime += time.Since(start)
|
|
s.GlobalStats.TotalRequests++
|
|
s.GlobalStats.Unlock()
|
|
}()
|
|
|
|
if bb != nil && bb.IsBot(ua) {
|
|
s.GlobalStats.Lock()
|
|
s.GlobalStats.BlockedRequests[fingerprint] = true
|
|
s.GlobalStats.Unlock()
|
|
http.Error(w, "Bots are not allowed", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
for _, pattern := range ForbiddenPatterns {
|
|
if strings.Contains(path, pattern) {
|
|
s.GlobalStats.Lock()
|
|
s.GlobalStats.BlockedRequests[fingerprint] = true
|
|
s.GlobalStats.Unlock()
|
|
log.Printf("Blocked suspicious request: %s from %s (%s)", r.URL.String(), r.RemoteAddr, r.UserAgent())
|
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|