Files
Sudo-Ivan 2c6bee84b4
Some checks failed
renovate / renovate (push) Failing after 15s
CI / build-frontend (push) Successful in 51s
OSV-Scanner Scheduled Scan / scan-scheduled (push) Successful in 9m31s
CI / build-backend (push) Successful in 9m36s
Add caching support in main.go and related files, including SQLite integration for cache storage. Update Docker configurations to include cache settings and enable public instance optimizations.
2025-12-27 20:25:38 -06:00

310 lines
8.6 KiB
Go

package main
import (
"embed"
"encoding/json"
"flag"
"fmt"
"io/fs"
"log"
"net"
"net/http"
"os"
"strings"
"time"
"git.quad4.io/Quad4-Software/webnews/internal/api"
"git.quad4.io/Quad4-Software/webnews/internal/storage"
"golang.org/x/time/rate"
)
// buildAssets holds the files matched by build/* that are embedded into the
// binary at compile time. main serves them (rooted at "build") unless a
// custom frontend directory is supplied via the -frontend flag.
//
//go:embed build/*
var buildAssets embed.FS
// corsMiddleware returns a middleware constructor that enforces a CORS origin
// allow-list on the wrapped handler.
//
// allowedOrigins is the list of origins that may call the handler from a
// browser. An empty list, or a list containing "*", allows every origin.
//
// Behavior per request:
//   - No Origin header: the request is passed through untouched (same-origin
//     or non-browser client).
//   - Allowed origin: the CORS response headers are set; an OPTIONS preflight
//     is answered with 200 and everything else is forwarded to next.
//   - Disallowed origin: an OPTIONS preflight is answered with a bare 403;
//     any other method is logged and rejected with 403.
//
// Every response carries "Vary: Origin" because the outcome depends on the
// request's Origin header; without it a shared cache could replay a response
// produced for one origin (or for no origin at all) to a different origin.
func corsMiddleware(allowedOrigins []string) func(http.HandlerFunc) http.HandlerFunc {
	// originAllowed reports whether origin passes the allow-list.
	originAllowed := func(origin string) bool {
		if len(allowedOrigins) == 0 {
			return true
		}
		for _, o := range allowedOrigins {
			if o == "*" || o == origin {
				return true
			}
		}
		return false
	}

	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			// Add (not Set) so any Vary value written elsewhere is preserved.
			w.Header().Add("Vary", "Origin")

			origin := r.Header.Get("Origin")
			if origin == "" {
				next.ServeHTTP(w, r)
				return
			}

			if !originAllowed(origin) {
				if r.Method == http.MethodOptions {
					w.WriteHeader(http.StatusForbidden)
					return
				}
				log.Printf("Blocked CORS request from origin: %s", origin)
				http.Error(w, "CORS Origin Not Allowed", http.StatusForbidden)
				return
			}

			w.Header().Set("Access-Control-Allow-Origin", origin)
			// POST is advertised because the POST-only registration route is
			// served through this middleware as well.
			w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
			// NOTE(review): only Content-Type is allowed; if the auth
			// middleware reads a custom request header it may need to be
			// listed here too — confirm against the api package.
			w.Header().Set("Access-Control-Allow-Headers", "Content-Type")

			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusOK)
				return
			}
			next.ServeHTTP(w, r)
		}
	}
}
// envOrDefault returns the value of the environment variable key, or fallback
// when the variable is unset or empty.
func envOrDefault(key, fallback string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return fallback
}

// flagProvided reports whether the named flag was set explicitly on the
// command line. It must be called after flag.Parse.
func flagProvided(name string) bool {
	provided := false
	flag.Visit(func(f *flag.Flag) {
		if f.Name == name {
			provided = true
		}
	})
	return provided
}

// spaHandler serves files from staticFS and falls back to the root document
// for any path that cannot be opened, so client-side routes resolve to the
// single-page application.
func spaHandler(staticFS fs.FS) http.HandlerFunc {
	fileServer := http.FileServer(http.FS(staticFS))
	return func(w http.ResponseWriter, r *http.Request) {
		name := strings.TrimPrefix(r.URL.Path, "/")
		if name == "" {
			name = "index.html"
		}
		f, err := staticFS.Open(name)
		if err != nil {
			// Unknown path: let the file server answer with the root document.
			r.URL.Path = "/"
		} else {
			// The file was opened only as an existence probe; close it so a
			// descriptor is not leaked on every request. A close error on a
			// read-only probe carries no useful information.
			_ = f.Close()
		}
		fileServer.ServeHTTP(w, r)
	}
}

// main configures and starts the Web News HTTP server.
//
// Settings come from command-line flags. Most flags take their default from
// an environment variable, so an explicit flag wins over the environment.
// The exceptions are CACHE_TTL, RATE_LIMIT and RATE_BURST, which are applied
// after flag parsing and therefore override the corresponding flags when set
// to a valid value.
func main() {
	frontendPath := flag.String("frontend", "", "Path to custom frontend build directory (overrides embedded assets)")
	host := flag.String("host", "0.0.0.0", "Host to bind the server to")
	port := flag.String("port", "", "Port to listen on (overrides PORT env var)")
	allowedOriginsStr := flag.String("allowed-origins", os.Getenv("ALLOWED_ORIGINS"), "Comma-separated list of allowed CORS origins")

	// Auth flags
	authMode := flag.String("auth-mode", envOrDefault("AUTH_MODE", "none"), "Authentication mode: none, token, multi")
	authToken := flag.String("auth-token", os.Getenv("AUTH_TOKEN"), "Master token for 'token' auth mode")
	authFile := flag.String("auth-file", envOrDefault("AUTH_FILE", "accounts.json"), "File to store accounts for 'multi' auth mode")
	// Registration is on unless ALLOW_REGISTRATION is exactly "false".
	allowReg := flag.Bool("allow-registration", os.Getenv("ALLOW_REGISTRATION") != "false", "Allow new account generation in 'multi' mode")

	// Protection flags
	hashesFile := flag.String("hashes-file", envOrDefault("HASHES_FILE", "client_hashes.json"), "File to store IP+UA hashes for rate limiting")
	rateLimit := flag.Float64("rate-limit", 50.0, "Rate limit in requests per second (env: RATE_LIMIT)")
	rateBurst := flag.Int("rate-burst", 100, "Rate limit burst size (env: RATE_BURST)")
	disableProtection := flag.Bool("disable-protection", os.Getenv("DISABLE_PROTECTION") == "true", "Disable rate limiting and bot protection")

	// Cache flags
	publicInstance := flag.Bool("public-instance", os.Getenv("PUBLIC_INSTANCE") == "true", "Enable optimizations for public instances (caching, etc.)")
	cacheEnabled := flag.Bool("cache-enabled", os.Getenv("CACHE_ENABLED") == "true", "Explicitly enable/disable caching")
	cacheTTL := flag.Duration("cache-ttl", 10*time.Minute, "Cache TTL (env: CACHE_TTL)")
	cacheFile := flag.String("cache-file", os.Getenv("CACHE_FILE"), "SQLite file for caching (reduces memory load)")

	flag.Parse()

	// Handle cache config. A valid CACHE_TTL overrides -cache-ttl; an invalid
	// one is reported and the flag value is kept.
	if envTTL := os.Getenv("CACHE_TTL"); envTTL != "" {
		if d, err := time.ParseDuration(envTTL); err == nil {
			*cacheTTL = d
		} else {
			log.Printf("Ignoring invalid CACHE_TTL %q: %v", envTTL, err)
		}
	}
	api.FeedCache.TTL = *cacheTTL
	api.FullTextCache.TTL = *cacheTTL * 6 // Full text stays longer

	if *cacheFile != "" {
		db, err := storage.NewSQLiteDB(*cacheFile)
		if err != nil {
			log.Fatalf("Failed to initialize cache database: %v", err)
		}
		api.FeedCache.Storage = db
		api.FullTextCache.Storage = db
		log.Printf("Using SQLite for caching: %s\n", *cacheFile)

		// Background cleanup of expired items, once per hour for the lifetime
		// of the process.
		go func() {
			ticker := time.NewTicker(1 * time.Hour)
			defer ticker.Stop()
			for range ticker.C {
				if err := db.PurgeExpiredCaches(); err != nil {
					log.Printf("Error purging expired caches: %v", err)
				}
			}
		}()
	}

	if *publicInstance {
		api.FeedCache.Enabled = true
		api.FullTextCache.Enabled = true
		log.Printf("Public instance optimizations enabled (caching enabled, TTL: %v)\n", *cacheTTL)
	}

	// An explicit cache setting, from either the CACHE_ENABLED variable or
	// the -cache-enabled flag, overrides the public-instance default above.
	if os.Getenv("CACHE_ENABLED") != "" || flagProvided("cache-enabled") {
		api.FeedCache.Enabled = *cacheEnabled
		api.FullTextCache.Enabled = *cacheEnabled
		state := "disabled"
		if *cacheEnabled {
			state = "enabled"
		}
		log.Printf("Caching explicitly %v (TTL: %v)\n", state, *cacheTTL)
	}

	// Override rate limits from environment if set; invalid values are
	// reported and the flag values are kept.
	if envRate := os.Getenv("RATE_LIMIT"); envRate != "" {
		var r float64
		if _, err := fmt.Sscanf(envRate, "%f", &r); err == nil {
			*rateLimit = r
		} else {
			log.Printf("Ignoring invalid RATE_LIMIT %q: %v", envRate, err)
		}
	}
	if envBurst := os.Getenv("RATE_BURST"); envBurst != "" {
		var b int
		if _, err := fmt.Sscanf(envBurst, "%d", &b); err == nil {
			*rateBurst = b
		} else {
			log.Printf("Ignoring invalid RATE_BURST %q: %v", envBurst, err)
		}
	}
	api.Limiter.SetLimit(rate.Limit(*rateLimit), *rateBurst)
	if *hashesFile != "" {
		api.Limiter.File = *hashesFile
		api.Limiter.LoadHashes()
	}

	am := api.NewAuthManager(*authMode, *authToken, *authFile, *allowReg)

	// An empty list means corsMiddleware allows every origin.
	var allowedOrigins []string
	if *allowedOriginsStr != "" {
		for _, o := range strings.Split(*allowedOriginsStr, ",") {
			allowedOrigins = append(allowedOrigins, strings.TrimSpace(o))
		}
	}

	// Port precedence: -port flag, then PORT env var, then 8080.
	if *port == "" {
		*port = envOrDefault("PORT", "8080")
	}

	// Middleware chains
	cors := corsMiddleware(allowedOrigins)
	auth := func(h http.HandlerFunc) http.HandlerFunc {
		return api.AuthMiddleware(am, h)
	}
	// Bot blocking and rate limiting become no-ops with -disable-protection.
	bot := func(h http.HandlerFunc) http.HandlerFunc {
		if *disableProtection {
			return h
		}
		return api.BotBlockerMiddleware(h)
	}
	limit := func(h http.HandlerFunc) http.HandlerFunc {
		if *disableProtection {
			return h
		}
		return api.LimitMiddleware(h)
	}

	apiHandler := cors(auth(bot(limit(api.HandleFeedProxy))))
	proxyHandler := cors(auth(bot(limit(api.HandleProxy))))
	fullTextHandler := cors(auth(bot(limit(api.HandleFullText))))

	// The ping endpoint is not behind auth; it reports the auth mode so a
	// client can tell whether it has to authenticate or may register.
	pingHandler := cors(bot(limit(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		authRequired := am.Mode != "none"
		canRegister := am.Mode == "multi" && am.AllowRegistration
		if err := json.NewEncoder(w).Encode(map[string]any{
			"status": "ok",
			"auth": map[string]any{
				"required": authRequired,
				"mode":     am.Mode,
				"canReg":   canRegister,
			},
		}); err != nil {
			log.Printf("Error encoding ping response: %v", err)
		}
	})))

	// Auth Routes
	http.HandleFunc("/api/auth/register", cors(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "POST" {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}
		token, err := am.Register()
		if err != nil {
			http.Error(w, "Registration disabled", http.StatusForbidden)
			return
		}
		// Declare the payload type explicitly, as the ping endpoint does.
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(map[string]string{"accountNumber": token}); err != nil {
			log.Printf("Error encoding registration response: %v", err)
		}
	}))
	// Reaching the inner handler means the auth middleware accepted the
	// request, so the credentials are reported as valid.
	http.HandleFunc("/api/auth/verify", cors(auth(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(map[string]bool{"valid": true}); err != nil {
			log.Printf("Error encoding verification response: %v", err)
		}
	})))

	http.HandleFunc("/api/feed", apiHandler)
	http.HandleFunc("/api/proxy", proxyHandler)
	http.HandleFunc("/api/fulltext", fullTextHandler)
	http.HandleFunc("/api/ping", pingHandler)

	// Static Assets: a custom directory if given, otherwise the embedded build.
	var staticFS fs.FS
	if *frontendPath != "" {
		log.Printf("Using custom frontend from: %s\n", *frontendPath)
		staticFS = os.DirFS(*frontendPath)
	} else {
		sub, err := fs.Sub(buildAssets, "build")
		if err != nil {
			log.Fatal(err)
		}
		staticFS = sub
	}

	// SPA Handler
	http.HandleFunc("/", spaHandler(staticFS))

	addr := net.JoinHostPort(*host, *port)
	log.Printf("Web News server starting on %s...\n", addr)

	server := &http.Server{
		Addr:         addr,
		Handler:      nil, // nil selects http.DefaultServeMux, where the routes above are registered
		ReadTimeout:  15 * time.Second,
		WriteTimeout: 15 * time.Second,
		IdleTimeout:  60 * time.Second,
	}
	if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		log.Fatal(err)
	}
}