Files
software-station/internal/security/security.go
2025-12-27 02:57:25 -06:00

221 lines
5.2 KiB
Go

package security
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strings"
"sync/atomic"
"syscall"
"time"
"software-station/internal/models"
"software-station/internal/stats"
"golang.org/x/time/rate"
)
// ThrottledReader wraps an io.ReadCloser, pacing reads through a token-bucket
// rate limiter and recording the number of bytes transferred in a stats
// service.
type ThrottledReader struct {
	// R is the underlying stream; Read reads from it and Close closes it.
	R io.ReadCloser
	// Limiter paces reads; nil disables throttling.
	Limiter *rate.Limiter
	// Fingerprint is the key of this client in Stats.KnownHashes. Its
	// per-client byte total decides when HeavyDownloaderLimit is applied.
	Fingerprint string
	// Stats receives per-client and global byte counts; nil disables
	// accounting.
	Stats *stats.Service
}

// Compile-time check that ThrottledReader satisfies io.ReadCloser.
var _ io.ReadCloser = (*ThrottledReader)(nil)

// Read reads from the underlying stream. For the n > 0 bytes read it then:
//   - adds n to the client's TotalBytes (if the fingerprint is registered in
//     Stats.KnownHashes) and to Stats.GlobalStats.TotalBytes, when Stats is set;
//   - switches Limiter to HeavyDownloaderLimit once the client's total
//     exceeds HeavyDownloaderThreshold, when Limiter is set;
//   - blocks until Limiter grants n tokens, when Limiter is set.
//
// Accounting and throttling are independent: either field may be nil.
// If waiting on the limiter fails, that error is returned together with n.
func (tr *ThrottledReader) Read(p []byte) (n int, err error) {
	n, err = tr.R.Read(p)
	if n <= 0 {
		return n, err
	}

	var clientTotal int64
	if tr.Stats != nil {
		tr.Stats.KnownHashes.RLock()
		data, exists := tr.Stats.KnownHashes.Data[tr.Fingerprint]
		tr.Stats.KnownHashes.RUnlock()
		if exists {
			clientTotal = atomic.AddInt64(&data.TotalBytes, int64(n))
		}
		atomic.AddInt64(&tr.Stats.GlobalStats.TotalBytes, int64(n))
	}

	if tr.Limiter == nil {
		return n, err
	}
	// Only reconfigure when the limit actually changes; SetLimit takes the
	// limiter's mutex on every call.
	if clientTotal > HeavyDownloaderThreshold && tr.Limiter.Limit() != HeavyDownloaderLimit {
		tr.Limiter.SetLimit(HeavyDownloaderLimit)
	}
	if waitErr := tr.waitTokens(n); waitErr != nil {
		return n, waitErr
	}
	return n, err
}

// waitTokens blocks until Limiter grants n tokens. rate.Limiter.WaitN rejects
// any single request larger than the burst size (unless the limit is
// rate.Inf), which would abort a read whose buffer exceeds the burst. To avoid
// that, n is requested in burst-sized pieces.
//
// NOTE(review): this uses context.Background(), so a client that disconnects
// mid-wait is not noticed until the wait completes.
func (tr *ThrottledReader) waitTokens(n int) error {
	ctx := context.Background()
	burst := tr.Limiter.Burst()
	if burst <= 0 {
		// Cannot split into zero-sized pieces; let WaitN decide (it only
		// succeeds here when the limit is rate.Inf).
		return tr.Limiter.WaitN(ctx, n)
	}
	for n > 0 {
		chunk := n
		if chunk > burst {
			chunk = burst
		}
		if err := tr.Limiter.WaitN(ctx, chunk); err != nil {
			return err
		}
		n -= chunk
	}
	return nil
}

// Close closes the underlying stream.
func (tr *ThrottledReader) Close() error {
	return tr.R.Close()
}
// contextKey is an unexported string type used for request-context keys set
// by this package. Because the type is unexported, its keys cannot collide
// with plain-string keys or keys from other packages.
type contextKey string

// FingerprintKey is the context key under which SecurityMiddleware stores the
// request fingerprint and from which GetRequestFingerprint reads it back.
const (
	FingerprintKey contextKey = "fingerprint"
)
// GetRequestFingerprint returns a hex-encoded SHA-256 fingerprint for the
// client behind r. The first time a fingerprint is seen, it is registered in
// s.KnownHashes and s.SaveHashes is called (presumably to persist it).
//
// If r's context already carries a string under FingerprintKey (as set by
// SecurityMiddleware), that value is returned unchanged.
//
// Otherwise the client address is taken from, in order of preference:
//   - the first X-Forwarded-For entry;
//   - X-Real-IP;
//   - r.RemoteAddr.
//
// Any port is stripped, and IPv6 addresses are masked to their /64 prefix.
// The fingerprint is SHA-256 of the address string followed by the
// User-Agent header.
func GetRequestFingerprint(r *http.Request, s *stats.Service) string {
	if f, ok := r.Context().Value(FingerprintKey).(string); ok {
		return f
	}

	remoteAddr := r.RemoteAddr
	// NOTE(review): X-Forwarded-For / X-Real-IP are client-controlled unless a
	// trusted reverse proxy overwrites them. Without such a proxy, a client
	// can choose its own fingerprint. Confirm the deployment.
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// Keep only the left-most entry of the comma-separated list.
		if comma := strings.IndexByte(xff, ','); comma != -1 {
			xff = xff[:comma]
		}
		remoteAddr = strings.TrimSpace(xff)
	} else if xri := r.Header.Get("X-Real-IP"); xri != "" {
		remoteAddr = strings.TrimSpace(xri)
	}

	ipStr, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		// No port present (e.g. a bare address from a header): use it as-is.
		ipStr = remoteAddr
	}
	if ip := net.ParseIP(ipStr); ip != nil {
		if ip.To4() == nil {
			// Collapse IPv6 to its /64 so a client rotating addresses within
			// one prefix keeps a single fingerprint.
			ip = ip.Mask(net.CIDRMask(64, 128))
		}
		ipStr = ip.String()
	}

	// NOTE(review): there is no separator between the address and the
	// User-Agent, so distinct pairs can hash identically. It is left as-is
	// because changing it would alter every existing fingerprint, including
	// any already saved via SaveHashes.
	sum := sha256.Sum256([]byte(ipStr + r.Header.Get("User-Agent")))
	fingerprint := hex.EncodeToString(sum[:])

	// Fast path: a shared read lock is enough for already-known fingerprints,
	// so the common case does not serialize every request on the write lock.
	s.KnownHashes.RLock()
	_, known := s.KnownHashes.Data[fingerprint]
	s.KnownHashes.RUnlock()
	if known {
		return fingerprint
	}

	s.KnownHashes.Lock()
	defer s.KnownHashes.Unlock()
	// Re-check: another request may have registered it after RUnlock.
	if _, exists := s.KnownHashes.Data[fingerprint]; !exists {
		s.KnownHashes.Data[fingerprint] = &models.FingerprintData{
			Known: true,
		}
		s.SaveHashes()
	}
	return fingerprint
}
// IsPrivateIP reports whether ip lies in a range that an outbound fetch must
// not reach:
//   - loopback;
//   - link-local unicast and multicast;
//   - the unspecified address;
//   - RFC 1918 private IPv4;
//   - RFC 6598 shared address space (100.64.0.0/10);
//   - IPv6 unique-local addresses (fc00::/7).
//
// When the environment variable ALLOW_LOOPBACK is "true", loopback addresses
// are reported as not private. A nil ip, or one whose length is neither 4 nor
// 16 bytes, is reported as not private. Both the 4-byte and the 16-byte forms
// of IPv4 addresses are handled.
func IsPrivateIP(ip net.IP) bool {
	if os.Getenv("ALLOW_LOOPBACK") == "true" && ip.IsLoopback() {
		return false
	}
	if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
		return true
	}

	// Compare in the 16-byte form: net.ParseIP yields 16-byte bounds, so a
	// 4-byte IPv4 slice (e.g. from To4) would otherwise never fall inside them.
	ip16 := ip.To16()
	if ip16 == nil {
		return false
	}

	privateRanges := []struct {
		start net.IP
		end   net.IP
	}{
		{net.ParseIP("10.0.0.0"), net.ParseIP("10.255.255.255")},     // RFC 1918
		{net.ParseIP("100.64.0.0"), net.ParseIP("100.127.255.255")},  // RFC 6598 shared address space
		{net.ParseIP("172.16.0.0"), net.ParseIP("172.31.255.255")},   // RFC 1918
		{net.ParseIP("192.168.0.0"), net.ParseIP("192.168.255.255")}, // RFC 1918
		// RFC 4193 unique-local addresses: the whole fc00::/7 block.
		{net.ParseIP("fc00::"), net.ParseIP("fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")},
	}
	for _, r := range privateRanges {
		if bytes.Compare(ip16, r.start) >= 0 && bytes.Compare(ip16, r.end) <= 0 {
			return true
		}
	}
	return false
}
// GetSafeHTTPClient builds an *http.Client meant for fetching URLs that may
// point anywhere.
//
// Its dialer installs a Control hook that runs just before each connection is
// established, on the resolved IP:port being dialed. The hook refuses any IP
// that IsPrivateIP reports as private. timeout becomes the client's overall
// per-request limit (http.Client.Timeout); individual dials are capped at 30s.
//
// NOTE(review): Proxy is read from the environment. When a proxy is in use,
// the transport dials the proxy, so the check applies to the proxy's address
// rather than the requested host. Confirm this is acceptable for the
// deployment.
func GetSafeHTTPClient(timeout time.Duration) *http.Client {
	guard := func(network, address string, _ syscall.RawConn) error {
		host, _, splitErr := net.SplitHostPort(address)
		if splitErr != nil {
			return splitErr
		}
		target := net.ParseIP(host)
		if target == nil || !IsPrivateIP(target) {
			return nil
		}
		return fmt.Errorf("SSRF protection: forbidden IP %s", target)
	}

	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
		Control:   guard,
	}

	return &http.Client{
		Timeout: timeout,
		Transport: &http.Transport{
			Proxy:                 http.ProxyFromEnvironment,
			DialContext:           dialer.DialContext,
			ForceAttemptHTTP2:     true,
			MaxIdleConns:          100,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		},
	}
}
// SecurityMiddleware returns HTTP middleware that does the following for each
// request:
//   - computes the client fingerprint and stores it in the request context
//     under FingerprintKey;
//   - records the fingerprint in s.GlobalStats.UniqueRequests, and in
//     WebRequests when the lower-cased path does not start with "/api";
//   - rejects the request with 403 if the lower-cased User-Agent contains any
//     BotUserAgents entry, or the lower-cased URL path contains any
//     ForbiddenPatterns entry (the latter is also logged). Rejected
//     fingerprints are added to BlockedRequests;
//   - otherwise passes the request to next.
//
// Every request that reaches the stats step, blocked or not, adds its elapsed
// time to TotalResponseTime and increments TotalRequests.
func SecurityMiddleware(s *stats.Service) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()
			lowerPath := strings.ToLower(r.URL.Path)
			lowerUA := strings.ToLower(r.UserAgent())

			fp := GetRequestFingerprint(r, s)
			r = r.WithContext(context.WithValue(r.Context(), FingerprintKey, fp))

			s.GlobalStats.Lock()
			s.GlobalStats.UniqueRequests[fp] = true
			if isAPI := strings.HasPrefix(lowerPath, "/api"); !isAPI {
				s.GlobalStats.WebRequests[fp] = true
			}
			s.GlobalStats.Unlock()

			// Timing is recorded on every exit path, including rejections.
			defer func() {
				s.GlobalStats.Lock()
				defer s.GlobalStats.Unlock()
				s.GlobalStats.TotalResponseTime += time.Since(began)
				s.GlobalStats.TotalRequests++
			}()

			markBlocked := func() {
				s.GlobalStats.Lock()
				s.GlobalStats.BlockedRequests[fp] = true
				s.GlobalStats.Unlock()
			}

			for _, bot := range BotUserAgents {
				if !strings.Contains(lowerUA, bot) {
					continue
				}
				markBlocked()
				http.Error(w, "Bots are not allowed", http.StatusForbidden)
				return
			}

			for _, pattern := range ForbiddenPatterns {
				if !strings.Contains(lowerPath, pattern) {
					continue
				}
				markBlocked()
				log.Printf("Blocked suspicious request: %s from %s (%s)", r.URL.String(), r.RemoteAddr, r.UserAgent())
				http.Error(w, "Forbidden", http.StatusForbidden)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}