310 lines
8.6 KiB
Go
310 lines
8.6 KiB
Go
package main
|
|
|
|
import (
|
|
"embed"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"io/fs"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.quad4.io/Quad4-Software/webnews/internal/api"
|
|
"git.quad4.io/Quad4-Software/webnews/internal/storage"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// buildAssets is the embedded file system populated by the go:embed directive
// below with the files matching build/*. main serves its "build" subtree as
// the frontend unless a custom directory is supplied via the -frontend flag.
//
//go:embed build/*
|
|
var buildAssets embed.FS
|
|
|
|
// corsMiddleware returns a decorator that applies a simple CORS policy to an
// http.HandlerFunc.
//
// An origin is allowed when allowedOrigins is empty, contains "*", or contains
// the request's Origin value exactly. Behavior per request:
//
//   - No Origin header: the request is passed to next untouched (apart from
//     the Vary header described below).
//   - Allowed origin: the Access-Control-Allow-* headers are set, echoing the
//     request's origin, and the request is passed to next.
//   - OPTIONS with an Origin header: answered here and never forwarded;
//     200 when the origin is allowed, 403 otherwise.
//   - Any other request from a disallowed origin: logged and rejected with 403.
//
// Every response carries "Vary: Origin" because the CORS headers differ per
// requesting origin; without it a shared cache could replay a response that
// carries another origin's Access-Control-Allow-Origin value (or none at all).
func corsMiddleware(allowedOrigins []string) func(http.HandlerFunc) http.HandlerFunc {
	// originAllowed reports whether origin passes the configured allow list.
	originAllowed := func(origin string) bool {
		if len(allowedOrigins) == 0 {
			return true
		}
		for _, o := range allowedOrigins {
			if o == "*" || o == origin {
				return true
			}
		}
		return false
	}

	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			// Add (not Set) so a Vary value written by an outer layer survives.
			w.Header().Add("Vary", "Origin")

			origin := r.Header.Get("Origin")
			if origin == "" {
				next.ServeHTTP(w, r)
				return
			}

			allowed := originAllowed(origin)
			if allowed {
				w.Header().Set("Access-Control-Allow-Origin", origin)
				w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
				w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
			}

			// Preflight requests are terminated here in both outcomes.
			if r.Method == http.MethodOptions {
				if allowed {
					w.WriteHeader(http.StatusOK)
				} else {
					w.WriteHeader(http.StatusForbidden)
				}
				return
			}

			if !allowed {
				log.Printf("Blocked CORS request from origin: %s", origin)
				http.Error(w, "CORS Origin Not Allowed", http.StatusForbidden)
				return
			}

			next.ServeHTTP(w, r)
		}
	}
}
|
|
|
|
func main() {
|
|
frontendPath := flag.String("frontend", "", "Path to custom frontend build directory (overrides embedded assets)")
|
|
host := flag.String("host", "0.0.0.0", "Host to bind the server to")
|
|
port := flag.String("port", "", "Port to listen on (overrides PORT env var)")
|
|
allowedOriginsStr := flag.String("allowed-origins", os.Getenv("ALLOWED_ORIGINS"), "Comma-separated list of allowed CORS origins")
|
|
|
|
// Auth flags
|
|
defaultAuthMode := os.Getenv("AUTH_MODE")
|
|
if defaultAuthMode == "" {
|
|
defaultAuthMode = "none"
|
|
}
|
|
authMode := flag.String("auth-mode", defaultAuthMode, "Authentication mode: none, token, multi")
|
|
|
|
authToken := flag.String("auth-token", os.Getenv("AUTH_TOKEN"), "Master token for 'token' auth mode")
|
|
|
|
defaultAuthFile := os.Getenv("AUTH_FILE")
|
|
if defaultAuthFile == "" {
|
|
defaultAuthFile = "accounts.json"
|
|
}
|
|
authFile := flag.String("auth-file", defaultAuthFile, "File to store accounts for 'multi' auth mode")
|
|
|
|
defaultAllowReg := true
|
|
if os.Getenv("ALLOW_REGISTRATION") == "false" {
|
|
defaultAllowReg = false
|
|
}
|
|
allowReg := flag.Bool("allow-registration", defaultAllowReg, "Allow new account generation in 'multi' mode")
|
|
|
|
defaultHashesFile := os.Getenv("HASHES_FILE")
|
|
if defaultHashesFile == "" {
|
|
defaultHashesFile = "client_hashes.json"
|
|
}
|
|
hashesFile := flag.String("hashes-file", defaultHashesFile, "File to store IP+UA hashes for rate limiting")
|
|
|
|
rateLimit := flag.Float64("rate-limit", 50.0, "Rate limit in requests per second (env: RATE_LIMIT)")
|
|
rateBurst := flag.Int("rate-burst", 100, "Rate limit burst size (env: RATE_BURST)")
|
|
|
|
disableProtection := flag.Bool("disable-protection", os.Getenv("DISABLE_PROTECTION") == "true", "Disable rate limiting and bot protection")
|
|
|
|
publicInstance := flag.Bool("public-instance", os.Getenv("PUBLIC_INSTANCE") == "true", "Enable optimizations for public instances (caching, etc.)")
|
|
cacheEnabled := flag.Bool("cache-enabled", os.Getenv("CACHE_ENABLED") == "true", "Explicitly enable/disable caching")
|
|
cacheTTL := flag.Duration("cache-ttl", 10*time.Minute, "Cache TTL (env: CACHE_TTL)")
|
|
cacheFile := flag.String("cache-file", os.Getenv("CACHE_FILE"), "SQLite file for caching (reduces memory load)")
|
|
|
|
flag.Parse()
|
|
|
|
// Handle cache config
|
|
if envTTL := os.Getenv("CACHE_TTL"); envTTL != "" {
|
|
if d, err := time.ParseDuration(envTTL); err == nil {
|
|
*cacheTTL = d
|
|
}
|
|
}
|
|
|
|
api.FeedCache.TTL = *cacheTTL
|
|
api.FullTextCache.TTL = *cacheTTL * 6 // Full text stays longer
|
|
|
|
if *cacheFile != "" {
|
|
db, err := storage.NewSQLiteDB(*cacheFile)
|
|
if err != nil {
|
|
log.Fatalf("Failed to initialize cache database: %v", err)
|
|
}
|
|
api.FeedCache.Storage = db
|
|
api.FullTextCache.Storage = db
|
|
log.Printf("Using SQLite for caching: %s\n", *cacheFile)
|
|
|
|
// Background cleanup of expired items
|
|
go func() {
|
|
for {
|
|
time.Sleep(1 * time.Hour)
|
|
if err := db.PurgeExpiredCaches(); err != nil {
|
|
log.Printf("Error purging expired caches: %v", err)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
if *publicInstance {
|
|
api.FeedCache.Enabled = true
|
|
api.FullTextCache.Enabled = true
|
|
log.Printf("Public instance optimizations enabled (caching enabled, TTL: %v)\n", *cacheTTL)
|
|
}
|
|
if os.Getenv("CACHE_ENABLED") != "" {
|
|
api.FeedCache.Enabled = *cacheEnabled
|
|
api.FullTextCache.Enabled = *cacheEnabled
|
|
log.Printf("Caching explicitly %v (TTL: %v)\n", map[bool]string{true: "enabled", false: "disabled"}[*cacheEnabled], *cacheTTL)
|
|
}
|
|
|
|
// Override rate limits from environment if set
|
|
if envRate := os.Getenv("RATE_LIMIT"); envRate != "" {
|
|
var r float64
|
|
if _, err := fmt.Sscanf(envRate, "%f", &r); err == nil {
|
|
*rateLimit = r
|
|
}
|
|
}
|
|
if envBurst := os.Getenv("RATE_BURST"); envBurst != "" {
|
|
var b int
|
|
if _, err := fmt.Sscanf(envBurst, "%d", &b); err == nil {
|
|
*rateBurst = b
|
|
}
|
|
}
|
|
|
|
api.Limiter.SetLimit(rate.Limit(*rateLimit), *rateBurst)
|
|
|
|
if *hashesFile != "" {
|
|
api.Limiter.File = *hashesFile
|
|
api.Limiter.LoadHashes()
|
|
}
|
|
|
|
am := api.NewAuthManager(*authMode, *authToken, *authFile, *allowReg)
|
|
|
|
var allowedOrigins []string
|
|
if *allowedOriginsStr != "" {
|
|
origins := strings.Split(*allowedOriginsStr, ",")
|
|
for _, o := range origins {
|
|
allowedOrigins = append(allowedOrigins, strings.TrimSpace(o))
|
|
}
|
|
}
|
|
|
|
if *port == "" {
|
|
*port = os.Getenv("PORT")
|
|
if *port == "" {
|
|
*port = "8080"
|
|
}
|
|
}
|
|
|
|
// Middleware chains
|
|
cors := corsMiddleware(allowedOrigins)
|
|
auth := func(h http.HandlerFunc) http.HandlerFunc {
|
|
return api.AuthMiddleware(am, h)
|
|
}
|
|
|
|
// Setup handlers with optional protection
|
|
bot := func(h http.HandlerFunc) http.HandlerFunc {
|
|
if *disableProtection {
|
|
return h
|
|
}
|
|
return api.BotBlockerMiddleware(h)
|
|
}
|
|
|
|
limit := func(h http.HandlerFunc) http.HandlerFunc {
|
|
if *disableProtection {
|
|
return h
|
|
}
|
|
return api.LimitMiddleware(h)
|
|
}
|
|
|
|
apiHandler := cors(auth(bot(limit(api.HandleFeedProxy))))
|
|
proxyHandler := cors(auth(bot(limit(api.HandleProxy))))
|
|
fullTextHandler := cors(auth(bot(limit(api.HandleFullText))))
|
|
|
|
pingHandler := cors(bot(limit(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
// Include auth info in ping if no specific origin check is needed
|
|
authRequired := am.Mode != "none"
|
|
canRegister := am.Mode == "multi" && am.AllowRegistration
|
|
|
|
if err := json.NewEncoder(w).Encode(map[string]any{
|
|
"status": "ok",
|
|
"auth": map[string]any{
|
|
"required": authRequired,
|
|
"mode": am.Mode,
|
|
"canReg": canRegister,
|
|
},
|
|
}); err != nil {
|
|
log.Printf("Error encoding ping response: %v", err)
|
|
}
|
|
})))
|
|
|
|
// Auth Routes
|
|
http.HandleFunc("/api/auth/register", cors(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
token, err := am.Register()
|
|
if err != nil {
|
|
http.Error(w, "Registration disabled", http.StatusForbidden)
|
|
return
|
|
}
|
|
if err := json.NewEncoder(w).Encode(map[string]string{"accountNumber": token}); err != nil {
|
|
log.Printf("Error encoding registration response: %v", err)
|
|
}
|
|
}))
|
|
|
|
http.HandleFunc("/api/auth/verify", cors(auth(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := json.NewEncoder(w).Encode(map[string]bool{"valid": true}); err != nil {
|
|
log.Printf("Error encoding verification response: %v", err)
|
|
}
|
|
})))
|
|
|
|
http.HandleFunc("/api/feed", apiHandler)
|
|
http.HandleFunc("/api/proxy", proxyHandler)
|
|
http.HandleFunc("/api/fulltext", fullTextHandler)
|
|
http.HandleFunc("/api/ping", pingHandler)
|
|
|
|
// Static Assets
|
|
var staticFS fs.FS
|
|
if *frontendPath != "" {
|
|
log.Printf("Using custom frontend from: %s\n", *frontendPath)
|
|
staticFS = os.DirFS(*frontendPath)
|
|
} else {
|
|
sub, err := fs.Sub(buildAssets, "build")
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
staticFS = sub
|
|
}
|
|
|
|
fileServer := http.FileServer(http.FS(staticFS))
|
|
|
|
// SPA Handler
|
|
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
path := strings.TrimPrefix(r.URL.Path, "/")
|
|
if path == "" {
|
|
path = "index.html"
|
|
}
|
|
|
|
_, err := staticFS.Open(path)
|
|
if err != nil {
|
|
r.URL.Path = "/"
|
|
}
|
|
fileServer.ServeHTTP(w, r)
|
|
})
|
|
|
|
addr := net.JoinHostPort(*host, *port)
|
|
log.Printf("Web News server starting on %s...\n", addr)
|
|
|
|
server := &http.Server{
|
|
Addr: addr,
|
|
Handler: nil,
|
|
ReadTimeout: 15 * time.Second,
|
|
WriteTimeout: 15 * time.Second,
|
|
IdleTimeout: 60 * time.Second,
|
|
}
|
|
|
|
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
log.Fatal(err)
|
|
}
|
|
}
|