- Updated the StartBackgroundUpdater function to accept a callback for software list updates, improving flexibility.
- Refactored the API handlers to utilize a proxied software list, enhancing data handling and response efficiency.
- Introduced a new method for refreshing the proxied software list, ensuring accurate data representation.
- Added unit tests for API handlers to validate functionality and response correctness.
package api

import (
	"bytes"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"encoding/xml"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"software-station/internal/models"
	"software-station/internal/security"
	"software-station/internal/stats"

	"golang.org/x/time/rate"
)
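
// SoftwareCache guards the canonical software list with an RWMutex so request
// handlers can read it while a background updater replaces it.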
type SoftwareCache struct {
	mu   sync.RWMutex
	data []models.Software
}
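
// Get returns the cached software list under a read lock. The returned slice
// is shared, so callers should treat it as read-only.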
func (c *SoftwareCache) Get() []models.Software {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.data
}
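
// Set replaces the cached software list under the write lock.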
func (c *SoftwareCache) Set(data []models.Software) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.data = data
}
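
// GetLock and GetDataPtr expose the cache's mutex and backing slice for
// callers that need to read or update the data in place while holding the
// lock themselves.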
func (c *SoftwareCache) GetLock() *sync.RWMutex {
	return &c.mu
}

func (c *SoftwareCache) GetDataPtr() *[]models.Software {
	return &c.data
}
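
// Server holds the HTTP handler state: the raw and proxied software lists,
// the hash-to-URL map used by the download and avatar proxies, the RSS
// caches, and the salt used to derive opaque asset IDs.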
type Server struct {
	GiteaToken   string
	SoftwareList *SoftwareCache
	Stats        *stats.Service
	urlMap       map[string]string
	urlMapMu     sync.RWMutex
	proxiedCache []models.Software
	proxiedMu    sync.RWMutex
	rssCache     atomic.Value
	rssLastMod   atomic.Value
	avatarCache  string
	salt         []byte
}
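
// NewServer builds a Server around the initial software list, loads (or
// generates) the hashing salt, prepares the avatar cache directory, and
// starts the background avatar-cache cleanup goroutine.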
func NewServer(token string, initialSoftware []models.Software, statsService *stats.Service) *Server {
	s := &Server{
		GiteaToken:   token,
		SoftwareList: &SoftwareCache{data: initialSoftware},
		Stats:        statsService,
		urlMap:       make(map[string]string),
		avatarCache:  ".cache/avatars",
	}

	s.loadSalt()
	s.rssCache.Store("")
	s.rssLastMod.Store(time.Time{})

	if err := os.MkdirAll(s.avatarCache, 0750); err != nil {
		log.Printf("Warning: failed to create avatar cache directory: %v", err)
	}

	s.RefreshProxiedList()

	go s.startAvatarCleanup()

	return s
}
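
// loadSalt reads the persisted hashing salt from disk, or generates and saves
// a new 32-byte salt if none exists, so that hashed IDs stay stable across
// restarts.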
func (s *Server) loadSalt() {
	saltPath := ".salt"
	data, err := os.ReadFile(saltPath)
	if err == nil && len(data) == 32 {
		s.salt = data
		return
	}

	s.salt = make([]byte, 32)
	if _, err := rand.Read(s.salt); err != nil {
		log.Fatalf("Failed to generate random salt: %v", err)
	}

	if err := os.WriteFile(saltPath, s.salt, 0600); err != nil {
		log.Printf("Warning: failed to save salt to %s: %v", saltPath, err)
	}
}
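
// RefreshProxiedList rebuilds the proxied copy of the software list: private
// repository URLs are stripped, and avatar and asset URLs are replaced with
// salted hashes that the proxy handlers later resolve via urlMap. It also
// invalidates the cached RSS feed.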
func (s *Server) RefreshProxiedList() {
	softwareList := s.SoftwareList.Get()
	newProxied := make([]models.Software, len(softwareList))
	newUrlMap := make(map[string]string)

	for i, sw := range softwareList {
		newProxied[i] = sw
		if sw.IsPrivate {
			newProxied[i].GiteaURL = ""
		}

		if sw.AvatarURL != "" {
			hash := s.computeHash(sw.AvatarURL)
			newUrlMap[hash] = sw.AvatarURL
			newProxied[i].AvatarURL = hash
		}

		newProxied[i].Releases = make([]models.Release, len(sw.Releases))
		for j, rel := range sw.Releases {
			newProxied[i].Releases[j] = rel
			newProxied[i].Releases[j].Assets = make([]models.Asset, len(rel.Assets))
			for k, asset := range rel.Assets {
				newProxied[i].Releases[j].Assets[k] = asset
				hash := s.computeHash(asset.URL)
				newUrlMap[hash] = asset.URL
				newProxied[i].Releases[j].Assets[k].URL = hash
			}
		}
	}

	s.urlMapMu.Lock()
	s.urlMap = newUrlMap
	s.urlMapMu.Unlock()

	s.proxiedMu.Lock()
	s.proxiedCache = newProxied
	s.proxiedMu.Unlock()

	// Invalidate RSS cache as well
	s.rssCache.Store("")
}
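
// computeHash returns the hex-encoded SHA-256 of the server salt concatenated
// with the target URL, producing a stable, opaque ID.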
func (s *Server) computeHash(targetURL string) string {
	h := sha256.New()
	h.Write(s.salt)
	h.Write([]byte(targetURL))
	return hex.EncodeToString(h.Sum(nil))
}
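
// UpdateSoftwareList stores a new software list and rebuilds the proxied view
// and URL map.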
func (s *Server) UpdateSoftwareList(newList []models.Software) {
	s.SoftwareList.Set(newList)
	s.RefreshProxiedList()
}
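
// RegisterURL maps a target URL to its salted hash so the proxy handlers can
// resolve it, and returns the hash.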
func (s *Server) RegisterURL(targetURL string) string {
	hash := s.computeHash(targetURL)
	s.urlMapMu.Lock()
	s.urlMap[hash] = targetURL
	s.urlMapMu.Unlock()
	return hash
}
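
// startAvatarCleanup runs cleanupAvatarCache on every AvatarCacheInterval
// tick for the lifetime of the server.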
func (s *Server) startAvatarCleanup() {
	ticker := time.NewTicker(AvatarCacheInterval)
	defer ticker.Stop()
	for range ticker.C {
		s.cleanupAvatarCache()
	}
}
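
// cleanupAvatarCache removes the oldest cached avatars (by modification time)
// until the total cache size is back under AvatarCacheLimit.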
func (s *Server) cleanupAvatarCache() {
	files, err := os.ReadDir(s.avatarCache)
	if err != nil {
		return
	}

	var totalSize int64
	type fileInfo struct {
		path string
		size int64
		mod  time.Time
	}
	var infos []fileInfo

	for _, f := range files {
		if f.IsDir() {
			continue
		}
		path := filepath.Join(s.avatarCache, f.Name())
		info, err := f.Info()
		if err != nil {
			continue
		}
		totalSize += info.Size()
		infos = append(infos, fileInfo{
			path: path,
			size: info.Size(),
			mod:  info.ModTime(),
		})
	}

	if totalSize <= AvatarCacheLimit {
		return
	}

	// Sort by modification time (oldest first)
	sort.Slice(infos, func(i, j int) bool {
		return infos[i].mod.Before(infos[j].mod)
	})

	for _, info := range infos {
		if totalSize <= AvatarCacheLimit {
			break
		}
		if err := os.Remove(info.path); err == nil {
			totalSize -= info.size
		}
	}
	log.Printf("Avatar cache cleaned up. Current size: %v bytes", totalSize)
}
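
// APISoftwareHandler serves the proxied software list as JSON, rewriting the
// hashed avatar and asset IDs into absolute proxy URLs for the requesting
// host.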
func (s *Server) APISoftwareHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	s.proxiedMu.RLock()
	proxiedList := s.proxiedCache
	s.proxiedMu.RUnlock()

	host := r.Host
	scheme := "http"
	if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
		scheme = "https"
	}

	finalList := make([]models.Software, len(proxiedList))
	for i, sw := range proxiedList {
		finalList[i] = sw
		if sw.AvatarURL != "" {
			finalList[i].AvatarURL = fmt.Sprintf("%s://%s/api/avatar?id=%s", scheme, host, sw.AvatarURL)
		}

		finalList[i].Releases = make([]models.Release, len(sw.Releases))
		for j, rel := range sw.Releases {
			finalList[i].Releases[j] = rel
			finalList[i].Releases[j].Assets = make([]models.Asset, len(rel.Assets))
			for k, asset := range rel.Assets {
				finalList[i].Releases[j].Assets[k] = asset
				if asset.URL != "" {
					finalList[i].Releases[j].Assets[k].URL = fmt.Sprintf("%s://%s/api/download?id=%s", scheme, host, asset.URL)
				}
			}
		}
	}

	if err := json.NewEncoder(w).Encode(finalList); err != nil {
		log.Printf("Error encoding software list: %v", err)
	}
}
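
// DownloadProxyHandler resolves a hashed asset ID to its upstream URL and
// streams the asset through a per-client rate limiter. Known download
// accelerators and heavy downloaders get stricter limits, and Range headers
// are forwarded so resumable downloads keep working.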
func (s *Server) DownloadProxyHandler(w http.ResponseWriter, r *http.Request) {
	id := r.URL.Query().Get("id")
	if id == "" {
		http.Error(w, "Missing id parameter", http.StatusBadRequest)
		return
	}

	s.urlMapMu.RLock()
	targetURL, exists := s.urlMap[id]
	s.urlMapMu.RUnlock()

	if !exists {
		http.Error(w, "Invalid or expired download ID", http.StatusNotFound)
		return
	}

	fingerprint := security.GetRequestFingerprint(r, s.Stats)
	ua := strings.ToLower(r.UserAgent())

	isSpeedDownloader := false
	for _, sd := range security.SpeedDownloaders {
		if strings.Contains(ua, sd) {
			isSpeedDownloader = true
			break
		}
	}

	limit := security.DefaultDownloadLimit
	burst := int(security.DefaultDownloadBurst)
	if isSpeedDownloader {
		limit = security.SpeedDownloaderLimit
		burst = int(security.SpeedDownloaderBurst)
		s.Stats.GlobalStats.Lock()
		s.Stats.GlobalStats.LimitedRequests[fingerprint] = true
		s.Stats.GlobalStats.Unlock()
	}

	s.Stats.DownloadStats.Lock()
	limiter, exists := s.Stats.DownloadStats.Limiters[fingerprint]
	if !exists {
		limiter = rate.NewLimiter(limit, burst)
		s.Stats.DownloadStats.Limiters[fingerprint] = limiter
	}

	s.Stats.KnownHashes.RLock()
	data, exists := s.Stats.KnownHashes.Data[fingerprint]
	s.Stats.KnownHashes.RUnlock()

	var totalDownloaded int64
	if exists {
		totalDownloaded = atomic.LoadInt64(&data.TotalBytes)
	}

	if totalDownloaded > security.HeavyDownloaderThreshold {
		limiter.SetLimit(security.HeavyDownloaderLimit)
		s.Stats.GlobalStats.Lock()
		s.Stats.GlobalStats.LimitedRequests[fingerprint] = true
		s.Stats.GlobalStats.Unlock()
	} else {
		limiter.SetLimit(limit)
	}
	s.Stats.DownloadStats.Unlock()

	req, err := http.NewRequest("GET", targetURL, nil)
	if err != nil {
		http.Error(w, "Failed to create request", http.StatusInternalServerError)
		return
	}

	// Forward Range headers for resumable downloads
	if rangeHeader := r.Header.Get("Range"); rangeHeader != "" {
		req.Header.Set("Range", rangeHeader)
	}
	if ifRangeHeader := r.Header.Get("If-Range"); ifRangeHeader != "" {
		req.Header.Set("If-Range", ifRangeHeader)
	}

	if s.GiteaToken != "" {
		req.Header.Set("Authorization", "token "+s.GiteaToken)
	}

	client := security.GetSafeHTTPClient(0)
	resp, err := client.Do(req)
	if err != nil {
		http.Error(w, "Failed to fetch asset: "+err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
		http.Error(w, "Gitea returned error: "+resp.Status, http.StatusBadGateway)
		return
	}

	// Copy all headers from the upstream response
	for k, vv := range resp.Header {
		for _, v := range vv {
			w.Header().Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)

	tr := &security.ThrottledReader{
		R:           resp.Body,
		Limiter:     limiter,
		Fingerprint: fingerprint,
		Stats:       s.Stats,
	}

	n, err := io.Copy(w, tr)
	if err != nil {
		log.Printf("Error copying proxy response: %v", err)
	}
	if n > 0 {
		s.Stats.GlobalStats.Lock()
		s.Stats.GlobalStats.SuccessDownloads[fingerprint] = true
		s.Stats.GlobalStats.Unlock()
	}

	s.Stats.SaveHashes()
}
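
// LegalHandler serves one of the legal documents (privacy, terms, or
// disclaimer) as plain text, optionally as an attachment when download=true.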
func (s *Server) LegalHandler(w http.ResponseWriter, r *http.Request) {
	doc := r.URL.Query().Get("doc")
	var filename string

	switch doc {
	case "privacy":
		filename = PrivacyFile
	case "terms":
		filename = TermsFile
	case "disclaimer":
		filename = DisclaimerFile
	default:
		http.Error(w, "Invalid document", http.StatusBadRequest)
		return
	}

	path := filepath.Join(LegalDir, filename)
	data, err := os.ReadFile(path) // #nosec G304
	if err != nil {
		http.Error(w, "Document not found", http.StatusNotFound)
		return
	}

	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
	// If download parameter is present, set Content-Disposition
	if r.URL.Query().Get("download") == "true" {
		w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename))
	}
	if _, err := w.Write(data); err != nil {
		log.Printf("Error writing legal document: %v", err)
	}
}
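
// AvatarHandler serves avatars by hashed ID, preferring the on-disk cache
// (touching the file's mtime for LRU cleanup) and otherwise fetching the
// image from upstream and caching the result.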
func (s *Server) AvatarHandler(w http.ResponseWriter, r *http.Request) {
	id := r.URL.Query().Get("id")
	if id == "" {
		http.Error(w, "Missing id parameter", http.StatusBadRequest)
		return
	}
	// IDs are salted SHA-256 hashes (64 hex characters). Rejecting anything
	// else prevents path traversal when the id is joined into the cache path.
	if _, err := hex.DecodeString(id); err != nil || len(id) != 64 {
		http.Error(w, "Invalid id parameter", http.StatusBadRequest)
		return
	}

	cachePath := filepath.Join(s.avatarCache, id)
	if _, err := os.Stat(cachePath); err == nil {
		now := time.Now()
		_ = os.Chtimes(cachePath, now, now)

		w.Header().Set("Cache-Control", "public, max-age=86400")
		http.ServeFile(w, r, cachePath)
		return
	}

	s.urlMapMu.RLock()
	targetURL, exists := s.urlMap[id]
	s.urlMapMu.RUnlock()

	if !exists {
		http.Error(w, "Invalid or expired avatar ID", http.StatusNotFound)
		return
	}

	req, err := http.NewRequest("GET", targetURL, nil)
	if err != nil {
		http.Error(w, "Failed to create request", http.StatusInternalServerError)
		return
	}

	if s.GiteaToken != "" {
		req.Header.Set("Authorization", "token "+s.GiteaToken)
	}

	client := security.GetSafeHTTPClient(0)
	resp, err := client.Do(req)
	if err != nil {
		http.Error(w, "Failed to fetch avatar: "+err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		http.Error(w, "Gitea returned error: "+resp.Status, http.StatusBadGateway)
		return
	}

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		http.Error(w, "Failed to read avatar data", http.StatusInternalServerError)
		return
	}

	if err := os.WriteFile(cachePath, data, 0600); err != nil {
		log.Printf("Warning: failed to cache avatar: %v", err)
	}

	w.Header().Set("Content-Type", resp.Header.Get("Content-Type"))
	w.Header().Set("Cache-Control", "public, max-age=86400")
	_, _ = w.Write(data)
}
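
// rssFeed and the related types below describe the RSS 2.0 document layout
// for encoding/xml, including the atom:link self reference.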
type rssFeed struct {
	XMLName xml.Name   `xml:"rss"`
	Version string     `xml:"version,attr"`
	Atom    string     `xml:"xmlns:atom,attr"`
	Channel rssChannel `xml:"channel"`
}

type rssChannel struct {
	Title         string    `xml:"title"`
	Link          string    `xml:"link"`
	Description   string    `xml:"description"`
	Language      string    `xml:"language"`
	LastBuildDate string    `xml:"lastBuildDate"`
	AtomLink      rssLink   `xml:"atom:link"`
	Items         []rssItem `xml:"item"`
}

type rssLink struct {
	Href string `xml:"href,attr"`
	Rel  string `xml:"rel,attr"`
	Type string `xml:"type,attr"`
}

type rssItem struct {
	Title       string         `xml:"title"`
	Link        string         `xml:"link"`
	Description rssDescription `xml:"description"`
	GUID        rssGUID        `xml:"guid"`
	PubDate     string         `xml:"pubDate"`
}

type rssDescription struct {
	Content string `xml:",cdata"`
}

type rssGUID struct {
	Content     string `xml:",chardata"`
	IsPermaLink bool   `xml:"isPermaLink,attr"`
}
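
// RSSHandler renders an RSS 2.0 feed of the 50 most recent releases. The
// unfiltered feed is cached and served with Last-Modified / If-Modified-Since
// support; per-software feeds (?software=Name) are built on demand.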
func (s *Server) RSSHandler(w http.ResponseWriter, r *http.Request) {
	targetSoftware := r.URL.Query().Get("software")

	if targetSoftware == "" {
		if cached := s.rssCache.Load().(string); cached != "" {
			w.Header().Set("Content-Type", "application/rss+xml; charset=utf-8")
			w.Header().Set("Cache-Control", "public, max-age=300")
			lastMod := s.rssLastMod.Load().(time.Time)
			if !lastMod.IsZero() {
				w.Header().Set("Last-Modified", lastMod.Format(http.TimeFormat))
				if r.Header.Get("If-Modified-Since") == lastMod.Format(http.TimeFormat) {
					w.WriteHeader(http.StatusNotModified)
					return
				}
			}
			_, _ = w.Write([]byte(cached))
			return
		}
	}

	softwareList := s.SoftwareList.Get()
	host := r.Host
	scheme := "http"
	if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
		scheme = "https"
	}
	baseURL := fmt.Sprintf("%s://%s", scheme, host)

	type item struct {
		Software models.Software
		Release  models.Release
	}
	var items []item
	for _, sw := range softwareList {
		if targetSoftware != "" && sw.Name != targetSoftware {
			continue
		}
		for _, rel := range sw.Releases {
			items = append(items, item{Software: sw, Release: rel})
		}
	}

	sort.Slice(items, func(i, j int) bool {
		return items[i].Release.CreatedAt.After(items[j].Release.CreatedAt)
	})

	feedTitle := "Software Updates - Software Station"
	feedDescription := "Latest software releases and updates"
	selfLink := baseURL + "/api/rss"
	if targetSoftware != "" {
		feedTitle = fmt.Sprintf("%s Updates - Software Station", targetSoftware)
		feedDescription = fmt.Sprintf("Latest releases and updates for %s", targetSoftware)
		selfLink = fmt.Sprintf("%s/api/rss?software=%s", baseURL, targetSoftware)
	}

	feed := rssFeed{
		Version: "2.0",
		Atom:    "http://www.w3.org/2005/Atom",
		Channel: rssChannel{
			Title:         feedTitle,
			Link:          baseURL,
			Description:   feedDescription,
			Language:      "en-us",
			LastBuildDate: time.Now().Format(time.RFC1123Z),
			AtomLink: rssLink{
				Href: selfLink,
				Rel:  "self",
				Type: "application/rss+xml",
			},
		},
	}

	for i, it := range items {
		if i >= 50 {
			break
		}
		title := fmt.Sprintf("%s %s", it.Software.Name, it.Release.TagName)
		link := baseURL
		description := it.Software.Description
		if it.Release.Body != "" {
			description = it.Release.Body
		}

		feed.Channel.Items = append(feed.Channel.Items, rssItem{
			Title: title,
			Link:  link,
			Description: rssDescription{
				Content: description,
			},
			GUID: rssGUID{
				Content:     fmt.Sprintf("%s-%s", it.Software.Name, it.Release.TagName),
				IsPermaLink: false,
			},
			PubDate: it.Release.CreatedAt.Format(time.RFC1123Z),
		})
	}

	var buf bytes.Buffer
	buf.WriteString(xml.Header)
	enc := xml.NewEncoder(&buf)
	enc.Indent("", " ")
	if err := enc.Encode(feed); err != nil {
		log.Printf("Error encoding RSS feed: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	xmlData := buf.String()
	if targetSoftware == "" {
		s.rssCache.Store(xmlData)
		s.rssLastMod.Store(time.Now())
	}

	w.Header().Set("Content-Type", "application/rss+xml; charset=utf-8")
	w.Header().Set("Cache-Control", "public, max-age=300")
	_, _ = w.Write([]byte(xmlData))
}