Files
software-station/internal/security/security.go
Sudo-Ivan d954d7fe4b
All checks were successful
CI / build (push) Successful in 1m15s
renovate / renovate (push) Successful in 1m19s
Update security middleware and update Docker configurations
- Added a new parameter to the SecurityMiddleware function to allow custom handling of forbidden requests.
- Updated Docker configurations to enable asset caching for improved performance.
- Bumped version number in the Dockerfile to 0.3.0 and refined the image description for clarity.
- Adjusted various frontend components and error handling to support new rate limiting and forbidden access messages.
- Improved documentation in multiple languages to reflect recent changes in features and security measures.
2025-12-27 21:53:10 -06:00

270 lines
6.5 KiB
Go

package security
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strings"
"sync/atomic"
"syscall"
"time"
"software-station/internal/models"
"software-station/internal/stats"
"golang.org/x/time/rate"
)
// ThrottledReader wraps an io.ReadCloser, pacing reads through Limiter and
// accounting the bytes read in Stats. Throttling and accounting only happen
// when both Limiter and Stats are non-nil (see Read).
type ThrottledReader struct {
	// R is the underlying stream; Read and Close are forwarded to it.
	R           io.ReadCloser
	// Limiter paces reads. Read lowers its limit to HeavyDownloaderLimit
	// once the fingerprint's byte total exceeds HeavyDownloaderThreshold.
	Limiter     *rate.Limiter
	// Fingerprint selects the entry in Stats.KnownHashes whose TotalBytes
	// is incremented as data is read.
	Fingerprint string
	// Stats receives the per-fingerprint and global byte counts.
	Stats       *stats.Service
}
// Read reads from the wrapped reader R. When bytes were read and both
// Limiter and Stats are set, it:
//   - adds them to the TotalBytes of tr.Fingerprint's entry in
//     Stats.KnownHashes (only if that fingerprint is registered) and to
//     Stats.GlobalStats.TotalBytes;
//   - lowers the limiter to HeavyDownloaderLimit once the fingerprint's
//     running total exceeds HeavyDownloaderThreshold;
//   - blocks until the limiter admits the bytes.
//
// If waiting fails, the n bytes already read are returned together with the
// limiter's error.
func (tr *ThrottledReader) Read(p []byte) (n int, err error) {
	n, err = tr.R.Read(p)
	if n <= 0 || tr.Limiter == nil || tr.Stats == nil {
		return n, err
	}

	tr.Stats.KnownHashes.RLock()
	data, exists := tr.Stats.KnownHashes.Data[tr.Fingerprint]
	tr.Stats.KnownHashes.RUnlock()

	var total int64
	if exists && data != nil {
		total = atomic.AddInt64(&data.TotalBytes, int64(n))
	}
	atomic.AddInt64(&tr.Stats.GlobalStats.TotalBytes, int64(n))

	if total > HeavyDownloaderThreshold {
		tr.Limiter.SetLimit(HeavyDownloaderLimit)
	}

	// WaitN rejects any request larger than the limiter's burst, so a read
	// bigger than the burst (e.g. io.Copy's 32 KiB buffer against a smaller
	// burst) would fail on every call. Admit the bytes in burst-sized chunks
	// instead. A non-positive burst falls back to a single WaitN(n), exactly
	// as before.
	burst := tr.Limiter.Burst()
	for remaining := n; remaining > 0; {
		chunk := remaining
		if burst > 0 && chunk > burst {
			chunk = burst
		}
		if waitErr := tr.Limiter.WaitN(context.Background(), chunk); waitErr != nil {
			return n, waitErr
		}
		remaining -= chunk
	}
	return n, err
}
// Close closes the wrapped reader R and reports its error.
func (tr *ThrottledReader) Close() error {
	underlying := tr.R
	return underlying.Close()
}
// contextKey is an unexported key type for values this package stores in a
// request context, so its keys cannot collide with other packages' keys.
type contextKey string

// FingerprintKey is the context key under which SecurityMiddleware stores
// the request fingerprint; GetRequestFingerprint returns the stored value
// when it is present.
const FingerprintKey = contextKey("fingerprint")
// GetRequestFingerprint returns a hex-encoded SHA-256 fingerprint for the
// client behind r and registers it in s.KnownHashes.
//
// If a fingerprint is already stored in the request context under
// FingerprintKey, it is returned unchanged and s is not touched.
//
// The hash covers, in order:
//   - a "v2" format tag;
//   - the client IP, with IPv6 addresses masked to their /64 prefix;
//   - the User-Agent, Sec-CH-UA, Sec-CH-UA-Platform and Sec-CH-UA-Mobile
//     headers;
//   - the TLS version and cipher suite for TLS connections;
//   - the "_ss_uid" cookie when present.
//
// New fingerprints are added to s.KnownHashes (followed by s.SaveHashes)
// while the table holds fewer than 10000 entries.
//
// NOTE(review): the client IP comes from the leftmost X-Forwarded-For entry
// (or X-Real-IP) without checking that the peer is a trusted proxy, and the
// _ss_uid cookie is client-supplied. A client can vary either per request to
// obtain fresh fingerprints. Consider honouring these headers only from
// known proxy addresses.
func GetRequestFingerprint(r *http.Request, s *stats.Service) string {
	if f, ok := r.Context().Value(FingerprintKey).(string); ok {
		return f
	}

	// Client address: first X-Forwarded-For hop, else X-Real-IP, else the
	// connection's peer address.
	remoteAddr := r.RemoteAddr
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		if comma := strings.IndexByte(xff, ','); comma != -1 {
			xff = xff[:comma]
		}
		remoteAddr = strings.TrimSpace(xff)
	} else if xri := r.Header.Get("X-Real-IP"); xri != "" {
		remoteAddr = strings.TrimSpace(xri)
	}

	ipStr, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		ipStr = remoteAddr
	}
	if ip := net.ParseIP(ipStr); ip != nil {
		if ip.To4() == nil {
			// Collapse IPv6 addresses to their /64 prefix so addresses in the
			// same /64 share a fingerprint.
			ip = ip.Mask(net.CIDRMask(64, 128))
		}
		ipStr = ip.String()
	}

	// The byte layout written here defines the "v2" fingerprint format.
	// Changing it changes every fingerprint and orphans existing
	// s.KnownHashes entries.
	h := sha256.New()
	fmt.Fprintf(h, "v2|%s|%s|%s|%s|%s",
		ipStr,
		r.Header.Get("User-Agent"),
		r.Header.Get("Sec-CH-UA"),
		r.Header.Get("Sec-CH-UA-Platform"),
		r.Header.Get("Sec-CH-UA-Mobile"),
	)
	if r.TLS != nil {
		fmt.Fprintf(h, "|%d|%d", r.TLS.Version, r.TLS.CipherSuite)
	}
	if cookie, err := r.Cookie("_ss_uid"); err == nil {
		io.WriteString(h, "|"+cookie.Value)
	}
	fingerprint := hex.EncodeToString(h.Sum(nil))

	const maxKnownFingerprints = 10000

	// Fast path under the read lock. Once a fingerprint is registered, or
	// the table is full, nothing needs writing. The exclusive lock would also
	// block ThrottledReader.Read's RLock, so take it only to insert.
	s.KnownHashes.RLock()
	_, known := s.KnownHashes.Data[fingerprint]
	full := len(s.KnownHashes.Data) >= maxKnownFingerprints
	s.KnownHashes.RUnlock()
	if known || full {
		return fingerprint
	}

	s.KnownHashes.Lock()
	defer s.KnownHashes.Unlock()
	// Re-check: another request may have registered this fingerprint, or
	// filled the table, between the two lock acquisitions.
	if _, exists := s.KnownHashes.Data[fingerprint]; !exists && len(s.KnownHashes.Data) < maxKnownFingerprints {
		s.KnownHashes.Data[fingerprint] = &models.FingerprintData{
			Known: true,
		}
		s.SaveHashes()
	}
	return fingerprint
}
// IsPrivateIP reports whether ip is a non-public destination that outbound
// requests must not reach. That covers loopback, link-local unicast and
// multicast, and the unspecified address. It also covers the RFC 1918 IPv4
// ranges, the RFC 6598 shared address space (100.64.0.0/10) and the RFC 4193
// unique-local IPv6 range (fc00::/7). A nil or malformed ip is reported as
// not private.
//
// When the environment variable ALLOW_LOOPBACK is exactly "true", loopback
// addresses are reported as not private.
func IsPrivateIP(ip net.IP) bool {
	if ip.IsLoopback() {
		return os.Getenv("ALLOW_LOOPBACK") != "true"
	}
	if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
		return true
	}

	// net.ParseIP yields the 16-byte form, so normalise ip the same way.
	// Otherwise a 4-byte IPv4 value (e.g. from To4) compares byte-wise
	// against the IPv4-mapped bounds and is never matched.
	ip16 := ip.To16()
	if ip16 == nil {
		return false
	}

	ranges := [...]struct {
		start net.IP
		end   net.IP
	}{
		{net.ParseIP("10.0.0.0"), net.ParseIP("10.255.255.255")},                        // RFC 1918
		{net.ParseIP("100.64.0.0"), net.ParseIP("100.127.255.255")},                     // RFC 6598 shared address space
		{net.ParseIP("172.16.0.0"), net.ParseIP("172.31.255.255")},                      // RFC 1918
		{net.ParseIP("192.168.0.0"), net.ParseIP("192.168.255.255")},                    // RFC 1918
		{net.ParseIP("fc00::"), net.ParseIP("fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")}, // RFC 4193 fc00::/7
	}
	for _, rg := range ranges {
		if bytes.Compare(ip16, rg.start) >= 0 && bytes.Compare(ip16, rg.end) <= 0 {
			return true
		}
	}
	return false
}
// GetSafeHTTPClient returns an *http.Client that refuses to connect to
// destinations IsPrivateIP classifies as private, as SSRF protection.
//
// timeout becomes http.Client.Timeout (the whole-request limit). Dialing is
// additionally capped at 30s and TLS handshakes at 10s.
//
// The check runs in net.Dialer.Control, which sees the resolved destination
// address of every connection attempt. This includes connections made while
// following redirects, so hostnames resolving to internal addresses are
// rejected as well. Addresses that cannot be parsed are refused (fail closed).
func GetSafeHTTPClient(timeout time.Duration) *http.Client {
	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
		Control: func(network, address string, c syscall.RawConn) error {
			host, _, err := net.SplitHostPort(address)
			if err != nil {
				return err
			}
			// Scoped IPv6 destinations are dialed as "fe80::1%eth0".
			// net.ParseIP rejects the zone suffix and returns nil, so strip it
			// to classify the address instead of letting it slip past.
			if i := strings.LastIndexByte(host, '%'); i != -1 {
				host = host[:i]
			}
			ip := net.ParseIP(host)
			if ip == nil {
				return fmt.Errorf("SSRF protection: unrecognized address %q", address)
			}
			if IsPrivateIP(ip) {
				return fmt.Errorf("SSRF protection: forbidden IP %s", ip)
			}
			return nil
		},
	}
	transport := &http.Transport{
		// NOTE(review): when a proxy is configured via the environment,
		// Control only sees the proxy's address, never the final target. The
		// SSRF check then does not constrain where requests end up, and a
		// proxy on a private address is itself refused. Confirm this is
		// intended.
		Proxy:                 http.ProxyFromEnvironment,
		DialContext:           dialer.DialContext,
		ForceAttemptHTTP2:     true,
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
	return &http.Client{
		Transport: transport,
		Timeout:   timeout,
	}
}
// SecurityMiddleware returns middleware that identifies each request,
// records it in s.GlobalStats, and rejects bots and suspicious paths.
//
// For every request it:
//   - issues a random, HttpOnly "_ss_uid" cookie valid for one year when the
//     client sent none. The cookie is also added to r, so this first
//     request's fingerprint already includes it.
//   - computes the fingerprint with GetRequestFingerprint and stores it in
//     the request context under FingerprintKey;
//   - marks the fingerprint in UniqueRequests, and in WebRequests when the
//     lower-cased path does not start with "/api";
//   - adds the handling time to TotalResponseTime and increments
//     TotalRequests when the request finishes, rejected or not.
//
// A request is rejected when bb is non-nil and classifies its User-Agent as
// a bot, or when its lower-cased path contains any entry of
// ForbiddenPatterns. Rejections mark the fingerprint in BlockedRequests. They
// are then served by forbiddenHandler, or by http.Error with status 403 when
// forbiddenHandler is nil.
//
// NOTE(review): clients that never send cookies back get a new _ss_uid, and
// therefore a new fingerprint, on every request. This inflates
// UniqueRequests and restarts per-fingerprint byte accounting. Nothing here
// bounds the GlobalStats maps either; confirm they are pruned elsewhere.
func SecurityMiddleware(s *stats.Service, bb *BotBlocker, forbiddenHandler http.HandlerFunc) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			start := time.Now()
			path := strings.ToLower(r.URL.Path)
			ua := r.UserAgent()

			if _, err := r.Cookie("_ss_uid"); err != nil {
				uid := make([]byte, 16)
				// Best effort: if no randomness is available the request just
				// proceeds without a uid cookie.
				if _, err := rand.Read(uid); err == nil {
					uidStr := hex.EncodeToString(uid)
					http.SetCookie(w, &http.Cookie{
						Name:     "_ss_uid",
						Value:    uidStr,
						Path:     "/",
						Expires:  time.Now().Add(365 * 24 * time.Hour),
						HttpOnly: true,
						Secure:   r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https",
						SameSite: http.SameSiteLaxMode,
					})
					r.AddCookie(&http.Cookie{Name: "_ss_uid", Value: uidStr})
				}
			}

			fingerprint := GetRequestFingerprint(r, s)
			r = r.WithContext(context.WithValue(r.Context(), FingerprintKey, fingerprint))

			s.GlobalStats.Lock()
			s.GlobalStats.UniqueRequests[fingerprint] = true
			if !strings.HasPrefix(path, "/api") {
				s.GlobalStats.WebRequests[fingerprint] = true
			}
			s.GlobalStats.Unlock()

			defer func() {
				s.GlobalStats.Lock()
				s.GlobalStats.TotalResponseTime += time.Since(start)
				s.GlobalStats.TotalRequests++
				s.GlobalStats.Unlock()
			}()

			// reject records the block and writes the rejection response,
			// using fallbackMsg only when no forbiddenHandler was supplied.
			reject := func(fallbackMsg string) {
				s.GlobalStats.Lock()
				s.GlobalStats.BlockedRequests[fingerprint] = true
				s.GlobalStats.Unlock()
				if forbiddenHandler != nil {
					forbiddenHandler(w, r)
					return
				}
				http.Error(w, fallbackMsg, http.StatusForbidden)
			}

			if bb != nil && bb.IsBot(ua) {
				reject("Bots are not allowed")
				return
			}

			for _, pattern := range ForbiddenPatterns {
				if strings.Contains(path, pattern) {
					// %q escapes control characters in the client-controlled
					// URL and User-Agent so they cannot forge or corrupt log lines.
					log.Printf("Blocked suspicious request: %q from %s (%q)", r.URL.String(), r.RemoteAddr, ua)
					reject("Forbidden")
					return
				}
			}

			next.ServeHTTP(w, r)
		})
	}
}