Files
software-station/internal/api/handlers.go
Sudo-Ivan ab3c188e91 Add RSS feed generation and improve security features
- Implemented structured RSS feed generation using XML encoding.
- Enhanced URL registration by incorporating a random salt for hash generation.
- Introduced a bot blocker to the security middleware for improved bot detection.
- Updated security middleware to utilize the new bot blocker and added more entropy to request fingerprinting.
2025-12-27 03:15:42 -06:00

492 lines
12 KiB
Go

package api
import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"encoding/xml"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/time/rate"

	"software-station/internal/models"
	"software-station/internal/security"
	"software-station/internal/stats"
)
// SoftwareCache is a concurrency-safe holder for the current software list.
// The zero value is usable. Get returns the stored slice without copying, so
// callers must treat the result as read-only.
type SoftwareCache struct {
	mu   sync.RWMutex
	data []models.Software
}

// Get returns the cached software list under a shared lock.
// The returned slice is shared with the cache; do not mutate it.
func (c *SoftwareCache) Get() []models.Software {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.data
}

// Set replaces the cached software list under an exclusive lock.
func (c *SoftwareCache) Set(data []models.Software) {
	c.mu.Lock()
	c.data = data
	c.mu.Unlock()
}

// GetLock exposes the cache's mutex for callers that need to coordinate
// multi-step access externally (e.g. together with GetDataPtr).
func (c *SoftwareCache) GetLock() *sync.RWMutex {
	return &c.mu
}

// GetDataPtr exposes a pointer to the underlying slice; callers are expected
// to hold the lock from GetLock while dereferencing it.
func (c *SoftwareCache) GetDataPtr() *[]models.Software {
	return &c.data
}
// Server bundles the HTTP handler state: the Gitea API token, the cached
// software catalogue, runtime statistics, and the salted URL-indirection map
// that backs the download and avatar proxy endpoints.
type Server struct {
	GiteaToken   string            // Gitea API token forwarded on proxied upstream requests
	SoftwareList *SoftwareCache    // current software catalogue
	Stats        *stats.Service    // request/download statistics and rate-limiter state
	urlMap       map[string]string // hash (from RegisterURL) -> upstream URL
	urlMapMu     sync.RWMutex      // guards urlMap
	rssCache     atomic.Value      // stores string; NOTE(review): never read in this file — confirm use elsewhere
	rssLastMod   atomic.Value      // stores time.Time; NOTE(review): never read in this file — confirm use elsewhere
	avatarCache  string            // directory for on-disk cached avatar files
	salt         []byte            // per-process random salt mixed into RegisterURL hashes
}
// NewServer constructs a Server with a fresh 32-byte random salt, the given
// initial software list, and an on-disk avatar cache directory. It terminates
// the process if the salt cannot be generated (the proxy IDs would be
// predictable otherwise) and only logs a warning if the cache directory
// cannot be created.
func NewServer(token string, initialSoftware []models.Software, statsService *stats.Service) *Server {
	salt := make([]byte, 32)
	if _, err := rand.Read(salt); err != nil {
		log.Fatalf("Failed to generate random salt: %v", err)
	}
	srv := &Server{
		GiteaToken:   token,
		SoftwareList: &SoftwareCache{data: initialSoftware},
		Stats:        statsService,
		urlMap:       map[string]string{},
		avatarCache:  ".cache/avatars",
		salt:         salt,
	}
	// Seed the atomic values so later Load calls never see a nil interface.
	srv.rssCache.Store("")
	srv.rssLastMod.Store(time.Time{})
	if err := os.MkdirAll(srv.avatarCache, 0750); err != nil {
		log.Printf("Warning: failed to create avatar cache directory: %v", err)
	}
	return srv
}
// RegisterURL maps targetURL to a stable, unguessable ID — the hex SHA-256
// of the per-process salt followed by the URL — and records the mapping for
// later lookup by the proxy handlers. Registering the same URL twice yields
// the same ID, so the map does not grow with repeated registrations.
func (s *Server) RegisterURL(targetURL string) string {
	// Hashing the concatenation salt||url is identical to streaming both
	// through a sha256 writer.
	sum := sha256.Sum256(append(append([]byte{}, s.salt...), targetURL...))
	id := hex.EncodeToString(sum[:])
	s.urlMapMu.Lock()
	defer s.urlMapMu.Unlock()
	s.urlMap[id] = targetURL
	return id
}
// APISoftwareHandler serves the software catalogue as JSON. Before encoding,
// it rewrites every avatar and asset URL to point at this server's proxy
// endpoints (so upstream Gitea URLs are never exposed) and blanks the Gitea
// URL of private repositories. The cached list itself is never mutated:
// releases and assets are deep-copied before rewriting.
func (s *Server) APISoftwareHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	// Honor X-Forwarded-Proto so links are correct behind a TLS-terminating
	// reverse proxy.
	scheme := "http"
	if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
		scheme = "https"
	}
	base := fmt.Sprintf("%s://%s", scheme, r.Host)

	source := s.SoftwareList.Get()
	out := make([]models.Software, len(source))
	for i := range source {
		out[i] = source[i]
		// If private, hide the Gitea URL completely. If public, it can be
		// shown for the repo link.
		if source[i].IsPrivate {
			out[i].GiteaURL = ""
		}
		// Route avatars through the caching avatar proxy.
		if source[i].AvatarURL != "" {
			out[i].AvatarURL = fmt.Sprintf("%s/api/avatar?id=%s", base, s.RegisterURL(source[i].AvatarURL))
		}
		// Deep-copy releases and assets so URL rewriting cannot leak back
		// into the shared cache.
		out[i].Releases = make([]models.Release, len(source[i].Releases))
		for j := range source[i].Releases {
			out[i].Releases[j] = source[i].Releases[j]
			assets := source[i].Releases[j].Assets
			out[i].Releases[j].Assets = make([]models.Asset, len(assets))
			for k := range assets {
				out[i].Releases[j].Assets[k] = assets[k]
				out[i].Releases[j].Assets[k].URL = fmt.Sprintf("%s/api/download?id=%s", base, s.RegisterURL(assets[k].URL))
			}
		}
	}
	if err := json.NewEncoder(w).Encode(out); err != nil {
		log.Printf("Error encoding software list: %v", err)
	}
}
// DownloadProxyHandler streams a release asset from Gitea to the client with
// per-client rate limiting. The id parameter must be a hash previously issued
// by RegisterURL; unknown ids are rejected, so a client can never make this
// server fetch an arbitrary URL. Range/If-Range headers are forwarded so
// downloads are resumable, and the Gitea token is attached upstream without
// ever reaching the client.
func (s *Server) DownloadProxyHandler(w http.ResponseWriter, r *http.Request) {
	id := r.URL.Query().Get("id")
	if id == "" {
		http.Error(w, "Missing id parameter", http.StatusBadRequest)
		return
	}
	// Resolve the opaque id back to the upstream URL.
	s.urlMapMu.RLock()
	targetURL, exists := s.urlMap[id]
	s.urlMapMu.RUnlock()
	if !exists {
		http.Error(w, "Invalid or expired download ID", http.StatusNotFound)
		return
	}
	// Fingerprint identifies the client for rate limiting and statistics.
	fingerprint := security.GetRequestFingerprint(r, s.Stats)
	// Clients whose User-Agent matches a known download accelerator get a
	// stricter limit and are flagged in the global stats.
	ua := strings.ToLower(r.UserAgent())
	isSpeedDownloader := false
	for _, sd := range security.SpeedDownloaders {
		if strings.Contains(ua, sd) {
			isSpeedDownloader = true
			break
		}
	}
	limit := security.DefaultDownloadLimit
	burst := int(security.DefaultDownloadBurst)
	if isSpeedDownloader {
		limit = security.SpeedDownloaderLimit
		burst = int(security.SpeedDownloaderBurst)
		s.Stats.GlobalStats.Lock()
		s.Stats.GlobalStats.LimitedRequests[fingerprint] = true
		s.Stats.GlobalStats.Unlock()
	}
	// Fetch or create this client's limiter under the DownloadStats lock.
	// The limit is (re)applied below on every request, so a client's class
	// can change between downloads even though the limiter is reused.
	s.Stats.DownloadStats.Lock()
	limiter, exists := s.Stats.DownloadStats.Limiters[fingerprint]
	if !exists {
		limiter = rate.NewLimiter(limit, burst)
		s.Stats.DownloadStats.Limiters[fingerprint] = limiter
	}
	// Clients that have already transferred more than the heavy-downloader
	// threshold get throttled further, regardless of User-Agent.
	s.Stats.KnownHashes.RLock()
	data, exists := s.Stats.KnownHashes.Data[fingerprint]
	s.Stats.KnownHashes.RUnlock()
	var totalDownloaded int64
	if exists {
		totalDownloaded = atomic.LoadInt64(&data.TotalBytes)
	}
	if totalDownloaded > security.HeavyDownloaderThreshold {
		limiter.SetLimit(security.HeavyDownloaderLimit)
		s.Stats.GlobalStats.Lock()
		s.Stats.GlobalStats.LimitedRequests[fingerprint] = true
		s.Stats.GlobalStats.Unlock()
	} else {
		limiter.SetLimit(limit)
	}
	s.Stats.DownloadStats.Unlock()
	req, err := http.NewRequest("GET", targetURL, nil)
	if err != nil {
		http.Error(w, "Failed to create request", http.StatusInternalServerError)
		return
	}
	// Forward Range headers for resumable downloads.
	if rangeHeader := r.Header.Get("Range"); rangeHeader != "" {
		req.Header.Set("Range", rangeHeader)
	}
	if ifRangeHeader := r.Header.Get("If-Range"); ifRangeHeader != "" {
		req.Header.Set("If-Range", ifRangeHeader)
	}
	// Authenticate upstream; the token never reaches the client.
	if s.GiteaToken != "" {
		req.Header.Set("Authorization", "token "+s.GiteaToken)
	}
	client := security.GetSafeHTTPClient(0)
	resp, err := client.Do(req)
	if err != nil {
		http.Error(w, "Failed to fetch asset: "+err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()
	// 200 (full body) and 206 (partial, for Range requests) are the only
	// acceptable upstream statuses.
	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
		http.Error(w, "Gitea returned error: "+resp.Status, http.StatusBadGateway)
		return
	}
	// Copy all headers from the upstream response (Content-Length,
	// Content-Type, Content-Range, ...).
	for k, vv := range resp.Header {
		for _, v := range vv {
			w.Header().Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	// Stream the body through the throttled reader, which enforces the
	// limiter and accounts transferred bytes against this fingerprint.
	tr := &security.ThrottledReader{
		R:           resp.Body,
		Limiter:     limiter,
		Fingerprint: fingerprint,
		Stats:       s.Stats,
	}
	n, err := io.Copy(w, tr)
	if err != nil {
		log.Printf("Error copying proxy response: %v", err)
	}
	if n > 0 {
		s.Stats.GlobalStats.Lock()
		s.Stats.GlobalStats.SuccessDownloads[fingerprint] = true
		s.Stats.GlobalStats.Unlock()
	}
	// Persist the per-client byte counters.
	s.Stats.SaveHashes()
}
// LegalHandler serves one of the static legal documents (privacy policy,
// terms of service, or disclaimer) as plain text. The ?doc= parameter is
// mapped through a fixed switch, so only the three configured files can ever
// be read — the client never controls the path. With ?download=true the
// response also carries a Content-Disposition attachment header.
func (s *Server) LegalHandler(w http.ResponseWriter, r *http.Request) {
	var filename string
	switch r.URL.Query().Get("doc") {
	case "privacy":
		filename = PrivacyFile
	case "terms":
		filename = TermsFile
	case "disclaimer":
		filename = DisclaimerFile
	default:
		http.Error(w, "Invalid document", http.StatusBadRequest)
		return
	}
	path := filepath.Join(LegalDir, filename)
	data, err := os.ReadFile(path) // #nosec G304 -- filename comes from the fixed switch above
	if err != nil {
		http.Error(w, "Document not found", http.StatusNotFound)
		return
	}
	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
	if r.URL.Query().Get("download") == "true" {
		// Quote the filename (RFC 6266): an unquoted value breaks if the
		// configured file name ever contains spaces or special characters.
		w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
	}
	if _, err := w.Write(data); err != nil {
		log.Printf("Error writing legal document: %v", err)
	}
}
// AvatarHandler serves a proxied avatar image. The id must be a hash issued
// by RegisterURL. Fetched avatars are cached on disk under s.avatarCache and
// served from there on subsequent requests, even across restarts.
func (s *Server) AvatarHandler(w http.ResponseWriter, r *http.Request) {
	id := r.URL.Query().Get("id")
	if id == "" {
		http.Error(w, "Missing id parameter", http.StatusBadRequest)
		return
	}
	// SECURITY: id is used as a filename below, and the cache is consulted
	// before the urlMap (so disk-cached avatars survive restarts). Valid ids
	// are exactly the 64-character lowercase-hex SHA-256 digests produced by
	// RegisterURL; reject anything else so a crafted id like "../../secret"
	// can never escape the cache directory.
	if _, err := hex.DecodeString(id); err != nil || len(id) != 64 {
		http.Error(w, "Invalid id parameter", http.StatusBadRequest)
		return
	}
	cachePath := filepath.Join(s.avatarCache, id)
	if _, err := os.Stat(cachePath); err == nil {
		// Serve from cache
		w.Header().Set("Cache-Control", "public, max-age=86400")
		http.ServeFile(w, r, cachePath)
		return
	}
	s.urlMapMu.RLock()
	targetURL, exists := s.urlMap[id]
	s.urlMapMu.RUnlock()
	if !exists {
		http.Error(w, "Invalid or expired avatar ID", http.StatusNotFound)
		return
	}
	req, err := http.NewRequest("GET", targetURL, nil)
	if err != nil {
		http.Error(w, "Failed to create request", http.StatusInternalServerError)
		return
	}
	// Authenticate upstream; the token never reaches the client.
	if s.GiteaToken != "" {
		req.Header.Set("Authorization", "token "+s.GiteaToken)
	}
	client := security.GetSafeHTTPClient(0)
	resp, err := client.Do(req)
	if err != nil {
		http.Error(w, "Failed to fetch avatar: "+err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		http.Error(w, "Gitea returned error: "+resp.Status, http.StatusBadGateway)
		return
	}
	// Read the full image once, then write it to both the cache and the
	// response. Avatars are small, so buffering in memory is acceptable.
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		http.Error(w, "Failed to read avatar data", http.StatusInternalServerError)
		return
	}
	if err := os.WriteFile(cachePath, data, 0600); err != nil {
		// Best-effort cache; serving the response still succeeds.
		log.Printf("Warning: failed to cache avatar: %v", err)
	}
	w.Header().Set("Content-Type", resp.Header.Get("Content-Type"))
	w.Header().Set("Cache-Control", "public, max-age=86400")
	_, _ = w.Write(data)
}
// rssFeed is the root <rss> element of an RSS 2.0 document, including the
// Atom namespace declaration required for the <atom:link rel="self"> child.
type rssFeed struct {
	XMLName xml.Name   `xml:"rss"`
	Version string     `xml:"version,attr"`
	Atom    string     `xml:"xmlns:atom,attr"`
	Channel rssChannel `xml:"channel"`
}

// rssChannel holds the feed-level metadata and the list of items.
type rssChannel struct {
	Title         string    `xml:"title"`
	Link          string    `xml:"link"`
	Description   string    `xml:"description"`
	Language      string    `xml:"language"`
	LastBuildDate string    `xml:"lastBuildDate"`
	AtomLink      rssLink   `xml:"atom:link"`
	Items         []rssItem `xml:"item"`
}

// rssLink is the <atom:link> self-reference element.
type rssLink struct {
	Href string `xml:"href,attr"`
	Rel  string `xml:"rel,attr"`
	Type string `xml:"type,attr"`
}

// rssItem is a single feed entry (one software release).
type rssItem struct {
	Title       string         `xml:"title"`
	Link        string         `xml:"link"`
	Description rssDescription `xml:"description"`
	GUID        rssGUID        `xml:"guid"`
	PubDate     string         `xml:"pubDate"`
}

// rssDescription wraps the item body so the encoder emits it as a CDATA
// section, letting release notes contain markup without double-escaping.
type rssDescription struct {
	Content string `xml:",cdata"`
}

// rssGUID carries the item's unique id; IsPermaLink=false signals to readers
// that the guid is an opaque string, not a fetchable URL.
type rssGUID struct {
	Content     string `xml:",chardata"`
	IsPermaLink bool   `xml:"isPermaLink,attr"`
}
// RSSHandler renders an RSS 2.0 feed of software releases. Without query
// parameters it covers all software; ?software=<name> restricts the feed to
// a single entry. Releases are sorted newest-first and capped at 50 items.
func (s *Server) RSSHandler(w http.ResponseWriter, r *http.Request) {
	softwareList := s.SoftwareList.Get()
	targetSoftware := r.URL.Query().Get("software")
	// Honor X-Forwarded-Proto so links are correct behind a reverse proxy.
	scheme := "http"
	if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
		scheme = "https"
	}
	baseURL := fmt.Sprintf("%s://%s", scheme, r.Host)
	// Collect all (software, release) pairs, then sort by date, newest first.
	type item struct {
		Software models.Software
		Release  models.Release
	}
	var items []item
	for _, sw := range softwareList {
		if targetSoftware != "" && sw.Name != targetSoftware {
			continue
		}
		for _, rel := range sw.Releases {
			items = append(items, item{Software: sw, Release: rel})
		}
	}
	sort.Slice(items, func(i, j int) bool {
		return items[i].Release.CreatedAt.After(items[j].Release.CreatedAt)
	})
	feedTitle := "Software Updates - Software Station"
	feedDescription := "Latest software releases and updates"
	selfLink := baseURL + "/api/rss"
	if targetSoftware != "" {
		feedTitle = fmt.Sprintf("%s Updates - Software Station", targetSoftware)
		feedDescription = fmt.Sprintf("Latest releases and updates for %s", targetSoftware)
		// Escape the user-supplied name so the self link stays a valid URL
		// ('&', '#', spaces, etc. would otherwise corrupt the query string).
		selfLink = fmt.Sprintf("%s/api/rss?software=%s", baseURL, url.QueryEscape(targetSoftware))
	}
	feed := rssFeed{
		Version: "2.0",
		Atom:    "http://www.w3.org/2005/Atom",
		Channel: rssChannel{
			Title:         feedTitle,
			Link:          baseURL,
			Description:   feedDescription,
			Language:      "en-us",
			LastBuildDate: time.Now().Format(time.RFC1123Z),
			AtomLink: rssLink{
				Href: selfLink,
				Rel:  "self",
				Type: "application/rss+xml",
			},
		},
	}
	// Cap the feed at the 50 most recent releases.
	for i, it := range items {
		if i >= 50 {
			break
		}
		title := fmt.Sprintf("%s %s", it.Software.Name, it.Release.TagName)
		link := baseURL
		// Prefer the release notes as the description; fall back to the
		// software's general description when the release body is empty.
		description := it.Software.Description
		if it.Release.Body != "" {
			description = it.Release.Body
		}
		feed.Channel.Items = append(feed.Channel.Items, rssItem{
			Title: title,
			Link:  link,
			Description: rssDescription{
				Content: description,
			},
			GUID: rssGUID{
				Content:     fmt.Sprintf("%s-%s", it.Software.Name, it.Release.TagName),
				IsPermaLink: false,
			},
			PubDate: it.Release.CreatedAt.Format(time.RFC1123Z),
		})
	}
	w.Header().Set("Content-Type", "application/rss+xml; charset=utf-8")
	w.Header().Set("Cache-Control", "public, max-age=300")
	fmt.Fprint(w, xml.Header)
	enc := xml.NewEncoder(w)
	enc.Indent("", "  ")
	if err := enc.Encode(feed); err != nil {
		log.Printf("Error encoding RSS feed: %v", err)
	}
}