package main import ( "embed" "encoding/json" "flag" "fmt" "io/fs" "log" "net" "net/http" "os" "strings" "time" "git.quad4.io/Quad4-Software/webnews/internal/api" "git.quad4.io/Quad4-Software/webnews/internal/storage" "golang.org/x/time/rate" ) //go:embed build/* var buildAssets embed.FS func corsMiddleware(allowedOrigins []string) func(http.HandlerFunc) http.HandlerFunc { return func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") if origin == "" { next.ServeHTTP(w, r) return } allowed := false if len(allowedOrigins) == 0 { allowed = true } else { for _, o := range allowedOrigins { if o == "*" || o == origin { allowed = true break } } } if allowed { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type") } if r.Method == "OPTIONS" { if allowed { w.WriteHeader(http.StatusOK) } else { w.WriteHeader(http.StatusForbidden) } return } if !allowed && len(allowedOrigins) > 0 { log.Printf("Blocked CORS request from origin: %s", origin) http.Error(w, "CORS Origin Not Allowed", http.StatusForbidden) return } next.ServeHTTP(w, r) } } } func main() { frontendPath := flag.String("frontend", "", "Path to custom frontend build directory (overrides embedded assets)") host := flag.String("host", "0.0.0.0", "Host to bind the server to") port := flag.String("port", "", "Port to listen on (overrides PORT env var)") allowedOriginsStr := flag.String("allowed-origins", os.Getenv("ALLOWED_ORIGINS"), "Comma-separated list of allowed CORS origins") // Auth flags defaultAuthMode := os.Getenv("AUTH_MODE") if defaultAuthMode == "" { defaultAuthMode = "none" } authMode := flag.String("auth-mode", defaultAuthMode, "Authentication mode: none, token, multi") authToken := flag.String("auth-token", os.Getenv("AUTH_TOKEN"), "Master token for 'token' auth mode") defaultAuthFile := os.Getenv("AUTH_FILE") if defaultAuthFile == "" { defaultAuthFile = "accounts.json" } authFile := flag.String("auth-file", defaultAuthFile, "File to store accounts for 'multi' auth mode") defaultAllowReg := true if os.Getenv("ALLOW_REGISTRATION") == "false" { defaultAllowReg = false } allowReg := flag.Bool("allow-registration", defaultAllowReg, "Allow new account generation in 'multi' mode") defaultHashesFile := os.Getenv("HASHES_FILE") if defaultHashesFile == "" { defaultHashesFile = "client_hashes.json" } hashesFile := flag.String("hashes-file", defaultHashesFile, "File to store IP+UA hashes for rate limiting") rateLimit := flag.Float64("rate-limit", 50.0, "Rate limit in requests per second (env: RATE_LIMIT)") rateBurst := flag.Int("rate-burst", 100, "Rate limit burst size (env: RATE_BURST)") disableProtection := flag.Bool("disable-protection", os.Getenv("DISABLE_PROTECTION") == "true", "Disable rate limiting and bot protection") publicInstance := flag.Bool("public-instance", os.Getenv("PUBLIC_INSTANCE") == "true", "Enable optimizations for public instances (caching, etc.)") cacheEnabled := flag.Bool("cache-enabled", os.Getenv("CACHE_ENABLED") == "true", "Explicitly enable/disable caching") cacheTTL := flag.Duration("cache-ttl", 10*time.Minute, "Cache TTL (env: CACHE_TTL)") cacheFile := flag.String("cache-file", os.Getenv("CACHE_FILE"), "SQLite file for caching (reduces memory load)") flag.Parse() // Handle cache config if envTTL := os.Getenv("CACHE_TTL"); envTTL != "" { if d, err := time.ParseDuration(envTTL); err == nil { *cacheTTL = d } } api.FeedCache.TTL = *cacheTTL api.FullTextCache.TTL = *cacheTTL * 6 // Full text stays longer if *cacheFile != "" { db, err := storage.NewSQLiteDB(*cacheFile) if err != nil { log.Fatalf("Failed to initialize cache database: %v", err) } api.FeedCache.Storage = db api.FullTextCache.Storage = db log.Printf("Using SQLite for caching: %s\n", *cacheFile) // Background cleanup of expired items go func() { for { time.Sleep(1 * time.Hour) if err := db.PurgeExpiredCaches(); err != nil { log.Printf("Error purging expired caches: %v", err) } } }() } if *publicInstance { api.FeedCache.Enabled = true api.FullTextCache.Enabled = true log.Printf("Public instance optimizations enabled (caching enabled, TTL: %v)\n", *cacheTTL) } if os.Getenv("CACHE_ENABLED") != "" { api.FeedCache.Enabled = *cacheEnabled api.FullTextCache.Enabled = *cacheEnabled log.Printf("Caching explicitly %v (TTL: %v)\n", map[bool]string{true: "enabled", false: "disabled"}[*cacheEnabled], *cacheTTL) } // Override rate limits from environment if set if envRate := os.Getenv("RATE_LIMIT"); envRate != "" { var r float64 if _, err := fmt.Sscanf(envRate, "%f", &r); err == nil { *rateLimit = r } } if envBurst := os.Getenv("RATE_BURST"); envBurst != "" { var b int if _, err := fmt.Sscanf(envBurst, "%d", &b); err == nil { *rateBurst = b } } api.Limiter.SetLimit(rate.Limit(*rateLimit), *rateBurst) if *hashesFile != "" { api.Limiter.File = *hashesFile api.Limiter.LoadHashes() } am := api.NewAuthManager(*authMode, *authToken, *authFile, *allowReg) var allowedOrigins []string if *allowedOriginsStr != "" { origins := strings.Split(*allowedOriginsStr, ",") for _, o := range origins { allowedOrigins = append(allowedOrigins, strings.TrimSpace(o)) } } if *port == "" { *port = os.Getenv("PORT") if *port == "" { *port = "8080" } } // Middleware chains cors := corsMiddleware(allowedOrigins) auth := func(h http.HandlerFunc) http.HandlerFunc { return api.AuthMiddleware(am, h) } // Setup handlers with optional protection bot := func(h http.HandlerFunc) http.HandlerFunc { if *disableProtection { return h } return api.BotBlockerMiddleware(h) } limit := func(h http.HandlerFunc) http.HandlerFunc { if *disableProtection { return h } return api.LimitMiddleware(h) } apiHandler := cors(auth(bot(limit(api.HandleFeedProxy)))) proxyHandler := cors(auth(bot(limit(api.HandleProxy)))) fullTextHandler := cors(auth(bot(limit(api.HandleFullText)))) pingHandler := cors(bot(limit(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") // Include auth info in ping if no specific origin check is needed authRequired := am.Mode != "none" canRegister := am.Mode == "multi" && am.AllowRegistration if err := json.NewEncoder(w).Encode(map[string]any{ "status": "ok", "auth": map[string]any{ "required": authRequired, "mode": am.Mode, "canReg": canRegister, }, }); err != nil { log.Printf("Error encoding ping response: %v", err) } }))) // Auth Routes http.HandleFunc("/api/auth/register", cors(func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } token, err := am.Register() if err != nil { http.Error(w, "Registration disabled", http.StatusForbidden) return } if err := json.NewEncoder(w).Encode(map[string]string{"accountNumber": token}); err != nil { log.Printf("Error encoding registration response: %v", err) } })) http.HandleFunc("/api/auth/verify", cors(auth(func(w http.ResponseWriter, r *http.Request) { if err := json.NewEncoder(w).Encode(map[string]bool{"valid": true}); err != nil { log.Printf("Error encoding verification response: %v", err) } }))) http.HandleFunc("/api/feed", apiHandler) http.HandleFunc("/api/proxy", proxyHandler) http.HandleFunc("/api/fulltext", fullTextHandler) http.HandleFunc("/api/ping", pingHandler) // Static Assets var staticFS fs.FS if *frontendPath != "" { log.Printf("Using custom frontend from: %s\n", *frontendPath) staticFS = os.DirFS(*frontendPath) } else { sub, err := fs.Sub(buildAssets, "build") if err != nil { log.Fatal(err) } staticFS = sub } fileServer := http.FileServer(http.FS(staticFS)) // SPA Handler http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { path := strings.TrimPrefix(r.URL.Path, "/") if path == "" { path = "index.html" } _, err := staticFS.Open(path) if err != nil { r.URL.Path = "/" } fileServer.ServeHTTP(w, r) }) addr := net.JoinHostPort(*host, *port) log.Printf("Web News server starting on %s...\n", addr) server := &http.Server{ Addr: addr, Handler: nil, ReadTimeout: 15 * time.Second, WriteTimeout: 15 * time.Second, IdleTimeout: 60 * time.Second, } if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatal(err) } }