package api import ( "crypto/rand" "crypto/sha256" "encoding/hex" "encoding/json" "encoding/xml" "fmt" "io" "log" "net/http" "os" "path/filepath" "sort" "strings" "sync" "sync/atomic" "time" "software-station/internal/models" "software-station/internal/security" "software-station/internal/stats" "golang.org/x/time/rate" ) type SoftwareCache struct { mu sync.RWMutex data []models.Software } func (c *SoftwareCache) Get() []models.Software { c.mu.RLock() defer c.mu.RUnlock() return c.data } func (c *SoftwareCache) Set(data []models.Software) { c.mu.Lock() defer c.mu.Unlock() c.data = data } func (c *SoftwareCache) GetLock() *sync.RWMutex { return &c.mu } func (c *SoftwareCache) GetDataPtr() *[]models.Software { return &c.data } type Server struct { GiteaToken string SoftwareList *SoftwareCache Stats *stats.Service urlMap map[string]string urlMapMu sync.RWMutex rssCache atomic.Value // stores string rssLastMod atomic.Value // stores time.Time avatarCache string salt []byte } func NewServer(token string, initialSoftware []models.Software, statsService *stats.Service) *Server { salt := make([]byte, 32) if _, err := rand.Read(salt); err != nil { log.Fatalf("Failed to generate random salt: %v", err) } s := &Server{ GiteaToken: token, SoftwareList: &SoftwareCache{data: initialSoftware}, Stats: statsService, urlMap: make(map[string]string), avatarCache: ".cache/avatars", salt: salt, } s.rssCache.Store("") s.rssLastMod.Store(time.Time{}) if err := os.MkdirAll(s.avatarCache, 0750); err != nil { log.Printf("Warning: failed to create avatar cache directory: %v", err) } return s } func (s *Server) RegisterURL(targetURL string) string { h := sha256.New() h.Write(s.salt) h.Write([]byte(targetURL)) hash := hex.EncodeToString(h.Sum(nil)) s.urlMapMu.Lock() s.urlMap[hash] = targetURL s.urlMapMu.Unlock() return hash } func (s *Server) APISoftwareHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") softwareList := s.SoftwareList.Get() host := r.Host scheme := "http" if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { scheme = "https" } proxiedList := make([]models.Software, len(softwareList)) for i, sw := range softwareList { proxiedList[i] = sw // If private, hide Gitea URL completely. If public, we can show it for the repo link. if sw.IsPrivate { proxiedList[i].GiteaURL = "" } // Proxy avatar if it exists if sw.AvatarURL != "" { hash := s.RegisterURL(sw.AvatarURL) proxiedList[i].AvatarURL = fmt.Sprintf("%s://%s/api/avatar?id=%s", scheme, host, hash) } proxiedList[i].Releases = make([]models.Release, len(sw.Releases)) for j, rel := range sw.Releases { proxiedList[i].Releases[j] = rel proxiedList[i].Releases[j].Assets = make([]models.Asset, len(rel.Assets)) for k, asset := range rel.Assets { proxiedList[i].Releases[j].Assets[k] = asset hash := s.RegisterURL(asset.URL) proxiedList[i].Releases[j].Assets[k].URL = fmt.Sprintf("%s://%s/api/download?id=%s", scheme, host, hash) } } } if err := json.NewEncoder(w).Encode(proxiedList); err != nil { log.Printf("Error encoding software list: %v", err) } } func (s *Server) DownloadProxyHandler(w http.ResponseWriter, r *http.Request) { id := r.URL.Query().Get("id") if id == "" { http.Error(w, "Missing id parameter", http.StatusBadRequest) return } s.urlMapMu.RLock() targetURL, exists := s.urlMap[id] s.urlMapMu.RUnlock() if !exists { http.Error(w, "Invalid or expired download ID", http.StatusNotFound) return } fingerprint := security.GetRequestFingerprint(r, s.Stats) ua := strings.ToLower(r.UserAgent()) isSpeedDownloader := false for _, sd := range security.SpeedDownloaders { if strings.Contains(ua, sd) { isSpeedDownloader = true break } } limit := security.DefaultDownloadLimit burst := int(security.DefaultDownloadBurst) if isSpeedDownloader { limit = security.SpeedDownloaderLimit burst = int(security.SpeedDownloaderBurst) s.Stats.GlobalStats.Lock() s.Stats.GlobalStats.LimitedRequests[fingerprint] = true s.Stats.GlobalStats.Unlock() } s.Stats.DownloadStats.Lock() limiter, exists := s.Stats.DownloadStats.Limiters[fingerprint] if !exists { limiter = rate.NewLimiter(limit, burst) s.Stats.DownloadStats.Limiters[fingerprint] = limiter } s.Stats.KnownHashes.RLock() data, exists := s.Stats.KnownHashes.Data[fingerprint] s.Stats.KnownHashes.RUnlock() var totalDownloaded int64 if exists { totalDownloaded = atomic.LoadInt64(&data.TotalBytes) } if totalDownloaded > security.HeavyDownloaderThreshold { limiter.SetLimit(security.HeavyDownloaderLimit) s.Stats.GlobalStats.Lock() s.Stats.GlobalStats.LimitedRequests[fingerprint] = true s.Stats.GlobalStats.Unlock() } else { limiter.SetLimit(limit) } s.Stats.DownloadStats.Unlock() req, err := http.NewRequest("GET", targetURL, nil) if err != nil { http.Error(w, "Failed to create request", http.StatusInternalServerError) return } // Forward Range headers for resumable downloads if rangeHeader := r.Header.Get("Range"); rangeHeader != "" { req.Header.Set("Range", rangeHeader) } if ifRangeHeader := r.Header.Get("If-Range"); ifRangeHeader != "" { req.Header.Set("If-Range", ifRangeHeader) } if s.GiteaToken != "" { req.Header.Set("Authorization", "token "+s.GiteaToken) } client := security.GetSafeHTTPClient(0) resp, err := client.Do(req) if err != nil { http.Error(w, "Failed to fetch asset: "+err.Error(), http.StatusBadGateway) return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { http.Error(w, "Gitea returned error: "+resp.Status, http.StatusBadGateway) return } // Copy all headers from the upstream response for k, vv := range resp.Header { for _, v := range vv { w.Header().Add(k, v) } } w.WriteHeader(resp.StatusCode) tr := &security.ThrottledReader{ R: resp.Body, Limiter: limiter, Fingerprint: fingerprint, Stats: s.Stats, } n, err := io.Copy(w, tr) if err != nil { log.Printf("Error copying proxy response: %v", err) } if n > 0 { s.Stats.GlobalStats.Lock() s.Stats.GlobalStats.SuccessDownloads[fingerprint] = true s.Stats.GlobalStats.Unlock() } s.Stats.SaveHashes() } func (s *Server) LegalHandler(w http.ResponseWriter, r *http.Request) { doc := r.URL.Query().Get("doc") var filename string switch doc { case "privacy": filename = PrivacyFile case "terms": filename = TermsFile case "disclaimer": filename = DisclaimerFile default: http.Error(w, "Invalid document", http.StatusBadRequest) return } path := filepath.Join(LegalDir, filename) data, err := os.ReadFile(path) // #nosec G304 if err != nil { http.Error(w, "Document not found", http.StatusNotFound) return } w.Header().Set("Content-Type", "text/plain; charset=utf-8") // If download parameter is present, set Content-Disposition if r.URL.Query().Get("download") == "true" { w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename)) } if _, err := w.Write(data); err != nil { log.Printf("Error writing legal document: %v", err) } } func (s *Server) AvatarHandler(w http.ResponseWriter, r *http.Request) { id := r.URL.Query().Get("id") if id == "" { http.Error(w, "Missing id parameter", http.StatusBadRequest) return } cachePath := filepath.Join(s.avatarCache, id) if _, err := os.Stat(cachePath); err == nil { // Serve from cache w.Header().Set("Cache-Control", "public, max-age=86400") http.ServeFile(w, r, cachePath) return } s.urlMapMu.RLock() targetURL, exists := s.urlMap[id] s.urlMapMu.RUnlock() if !exists { http.Error(w, "Invalid or expired avatar ID", http.StatusNotFound) return } req, err := http.NewRequest("GET", targetURL, nil) if err != nil { http.Error(w, "Failed to create request", http.StatusInternalServerError) return } if s.GiteaToken != "" { req.Header.Set("Authorization", "token "+s.GiteaToken) } client := security.GetSafeHTTPClient(0) resp, err := client.Do(req) if err != nil { http.Error(w, "Failed to fetch avatar: "+err.Error(), http.StatusBadGateway) return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { http.Error(w, "Gitea returned error: "+resp.Status, http.StatusBadGateway) return } // Copy data to cache and response simultaneously data, err := io.ReadAll(resp.Body) if err != nil { http.Error(w, "Failed to read avatar data", http.StatusInternalServerError) return } if err := os.WriteFile(cachePath, data, 0600); err != nil { log.Printf("Warning: failed to cache avatar: %v", err) } w.Header().Set("Content-Type", resp.Header.Get("Content-Type")) w.Header().Set("Cache-Control", "public, max-age=86400") _, _ = w.Write(data) } type rssFeed struct { XMLName xml.Name `xml:"rss"` Version string `xml:"version,attr"` Atom string `xml:"xmlns:atom,attr"` Channel rssChannel `xml:"channel"` } type rssChannel struct { Title string `xml:"title"` Link string `xml:"link"` Description string `xml:"description"` Language string `xml:"language"` LastBuildDate string `xml:"lastBuildDate"` AtomLink rssLink `xml:"atom:link"` Items []rssItem `xml:"item"` } type rssLink struct { Href string `xml:"href,attr"` Rel string `xml:"rel,attr"` Type string `xml:"type,attr"` } type rssItem struct { Title string `xml:"title"` Link string `xml:"link"` Description rssDescription `xml:"description"` GUID rssGUID `xml:"guid"` PubDate string `xml:"pubDate"` } type rssDescription struct { Content string `xml:",cdata"` } type rssGUID struct { Content string `xml:",chardata"` IsPermaLink bool `xml:"isPermaLink,attr"` } func (s *Server) RSSHandler(w http.ResponseWriter, r *http.Request) { softwareList := s.SoftwareList.Get() targetSoftware := r.URL.Query().Get("software") host := r.Host scheme := "http" if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { scheme = "https" } baseURL := fmt.Sprintf("%s://%s", scheme, host) // Collect all releases and sort by date type item struct { Software models.Software Release models.Release } var items []item for _, sw := range softwareList { if targetSoftware != "" && sw.Name != targetSoftware { continue } for _, rel := range sw.Releases { items = append(items, item{Software: sw, Release: rel}) } } sort.Slice(items, func(i, j int) bool { return items[i].Release.CreatedAt.After(items[j].Release.CreatedAt) }) feedTitle := "Software Updates - Software Station" feedDescription := "Latest software releases and updates" selfLink := baseURL + "/api/rss" if targetSoftware != "" { feedTitle = fmt.Sprintf("%s Updates - Software Station", targetSoftware) feedDescription = fmt.Sprintf("Latest releases and updates for %s", targetSoftware) selfLink = fmt.Sprintf("%s/api/rss?software=%s", baseURL, targetSoftware) } feed := rssFeed{ Version: "2.0", Atom: "http://www.w3.org/2005/Atom", Channel: rssChannel{ Title: feedTitle, Link: baseURL, Description: feedDescription, Language: "en-us", LastBuildDate: time.Now().Format(time.RFC1123Z), AtomLink: rssLink{ Href: selfLink, Rel: "self", Type: "application/rss+xml", }, }, } for i, it := range items { if i >= 50 { break } title := fmt.Sprintf("%s %s", it.Software.Name, it.Release.TagName) link := baseURL description := it.Software.Description if it.Release.Body != "" { description = it.Release.Body } feed.Channel.Items = append(feed.Channel.Items, rssItem{ Title: title, Link: link, Description: rssDescription{ Content: description, }, GUID: rssGUID{ Content: fmt.Sprintf("%s-%s", it.Software.Name, it.Release.TagName), IsPermaLink: false, }, PubDate: it.Release.CreatedAt.Format(time.RFC1123Z), }) } w.Header().Set("Content-Type", "application/rss+xml; charset=utf-8") w.Header().Set("Cache-Control", "public, max-age=300") fmt.Fprint(w, xml.Header) enc := xml.NewEncoder(w) enc.Indent("", " ") if err := enc.Encode(feed); err != nil { log.Printf("Error encoding RSS feed: %v", err) } }