diff --git a/internal/api/handlers.go b/internal/api/handlers.go index d63e267..a70b650 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -1,6 +1,7 @@ package api import ( + "bytes" "crypto/rand" "crypto/sha256" "encoding/hex" @@ -56,26 +57,24 @@ type Server struct { Stats *stats.Service urlMap map[string]string urlMapMu sync.RWMutex - rssCache atomic.Value // stores string - rssLastMod atomic.Value // stores time.Time + proxiedCache []models.Software + proxiedMu sync.RWMutex + rssCache atomic.Value + rssLastMod atomic.Value avatarCache string salt []byte } func NewServer(token string, initialSoftware []models.Software, statsService *stats.Service) *Server { - salt := make([]byte, 32) - if _, err := rand.Read(salt); err != nil { - log.Fatalf("Failed to generate random salt: %v", err) - } - s := &Server{ GiteaToken: token, SoftwareList: &SoftwareCache{data: initialSoftware}, Stats: statsService, urlMap: make(map[string]string), avatarCache: ".cache/avatars", - salt: salt, } + + s.loadSalt() s.rssCache.Store("") s.rssLastMod.Store(time.Time{}) @@ -83,11 +82,93 @@ func NewServer(token string, initialSoftware []models.Software, statsService *st log.Printf("Warning: failed to create avatar cache directory: %v", err) } + s.RefreshProxiedList() + go s.startAvatarCleanup() return s } +func (s *Server) loadSalt() { + saltPath := ".salt" + data, err := os.ReadFile(saltPath) + if err == nil && len(data) == 32 { + s.salt = data + return + } + + s.salt = make([]byte, 32) + if _, err := rand.Read(s.salt); err != nil { + log.Fatalf("Failed to generate random salt: %v", err) + } + + if err := os.WriteFile(saltPath, s.salt, 0600); err != nil { + log.Printf("Warning: failed to save salt to %s: %v", saltPath, err) + } +} + +func (s *Server) RefreshProxiedList() { + softwareList := s.SoftwareList.Get() + newProxied := make([]models.Software, len(softwareList)) + newUrlMap := make(map[string]string) + + for i, sw := range softwareList { + newProxied[i] = sw + if sw.IsPrivate { + newProxied[i].GiteaURL = "" + } + + if sw.AvatarURL != "" { + hash := s.computeHash(sw.AvatarURL) + newUrlMap[hash] = sw.AvatarURL + newProxied[i].AvatarURL = hash + } + + newProxied[i].Releases = make([]models.Release, len(sw.Releases)) + for j, rel := range sw.Releases { + newProxied[i].Releases[j] = rel + newProxied[i].Releases[j].Assets = make([]models.Asset, len(rel.Assets)) + for k, asset := range rel.Assets { + newProxied[i].Releases[j].Assets[k] = asset + hash := s.computeHash(asset.URL) + newUrlMap[hash] = asset.URL + newProxied[i].Releases[j].Assets[k].URL = hash + } + } + } + + s.urlMapMu.Lock() + s.urlMap = newUrlMap + s.urlMapMu.Unlock() + + s.proxiedMu.Lock() + s.proxiedCache = newProxied + s.proxiedMu.Unlock() + + // Invalidate RSS cache as well + s.rssCache.Store("") +} + +func (s *Server) computeHash(targetURL string) string { + h := sha256.New() + h.Write(s.salt) + h.Write([]byte(targetURL)) + return hex.EncodeToString(h.Sum(nil)) +} + +func (s *Server) UpdateSoftwareList(newList []models.Software) { + s.SoftwareList.Set(newList) + s.RefreshProxiedList() +} + +func (s *Server) RegisterURL(targetURL string) string { + hash := s.computeHash(targetURL) + s.urlMapMu.Lock() + s.urlMap[hash] = targetURL + s.urlMapMu.Unlock() + return hash +} + func (s *Server) startAvatarCleanup() { ticker := time.NewTicker(AvatarCacheInterval) for range ticker.C { @@ -146,22 +227,12 @@ func (s *Server) cleanupAvatarCache() { log.Printf("Avatar cache cleaned up. Current size: %v bytes", totalSize) } -func (s *Server) RegisterURL(targetURL string) string { - h := sha256.New() - h.Write(s.salt) - h.Write([]byte(targetURL)) - hash := hex.EncodeToString(h.Sum(nil)) - - s.urlMapMu.Lock() - s.urlMap[hash] = targetURL - s.urlMapMu.Unlock() - return hash -} - func (s *Server) APISoftwareHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - softwareList := s.SoftwareList.Get() + s.proxiedMu.RLock() + proxiedList := s.proxiedCache + s.proxiedMu.RUnlock() host := r.Host scheme := "http" @@ -169,33 +240,27 @@ func (s *Server) APISoftwareHandler(w http.ResponseWriter, r *http.Request) { scheme = "https" } - proxiedList := make([]models.Software, len(softwareList)) - for i, sw := range softwareList { - proxiedList[i] = sw - // If private, hide Gitea URL completely. If public, we can show it for the repo link. - if sw.IsPrivate { - proxiedList[i].GiteaURL = "" - } - - // Proxy avatar if it exists + finalList := make([]models.Software, len(proxiedList)) + for i, sw := range proxiedList { + finalList[i] = sw if sw.AvatarURL != "" { - hash := s.RegisterURL(sw.AvatarURL) - proxiedList[i].AvatarURL = fmt.Sprintf("%s://%s/api/avatar?id=%s", scheme, host, hash) + finalList[i].AvatarURL = fmt.Sprintf("%s://%s/api/avatar?id=%s", scheme, host, sw.AvatarURL) } - proxiedList[i].Releases = make([]models.Release, len(sw.Releases)) + finalList[i].Releases = make([]models.Release, len(sw.Releases)) for j, rel := range sw.Releases { - proxiedList[i].Releases[j] = rel - proxiedList[i].Releases[j].Assets = make([]models.Asset, len(rel.Assets)) + finalList[i].Releases[j] = rel + finalList[i].Releases[j].Assets = make([]models.Asset, len(rel.Assets)) for k, asset := range rel.Assets { - proxiedList[i].Releases[j].Assets[k] = asset - hash := s.RegisterURL(asset.URL) - proxiedList[i].Releases[j].Assets[k].URL = fmt.Sprintf("%s://%s/api/download?id=%s", scheme, host, hash) + finalList[i].Releases[j].Assets[k] = asset + if asset.URL != "" { + finalList[i].Releases[j].Assets[k].URL = fmt.Sprintf("%s://%s/api/download?id=%s", scheme, host, asset.URL) + } } } } - if err := json.NewEncoder(w).Encode(proxiedList); err != nil { + if err := json.NewEncoder(w).Encode(finalList); err != nil { log.Printf("Error encoding software list: %v", err) } } @@ -364,11 +429,9 @@ func (s *Server) AvatarHandler(w http.ResponseWriter, r *http.Request) { cachePath := filepath.Join(s.avatarCache, id) if _, err := os.Stat(cachePath); err == nil { - // Update modification time for LRU cleanup now := time.Now() _ = os.Chtimes(cachePath, now, now) - // Serve from cache w.Header().Set("Cache-Control", "public, max-age=86400") http.ServeFile(w, r, cachePath) return @@ -406,7 +469,6 @@ func (s *Server) AvatarHandler(w http.ResponseWriter, r *http.Request) { return } - // Copy data to cache and response simultaneously data, err := io.ReadAll(resp.Body) if err != nil { http.Error(w, "Failed to read avatar data", http.StatusInternalServerError) @@ -463,9 +525,26 @@ type rssGUID struct { } func (s *Server) RSSHandler(w http.ResponseWriter, r *http.Request) { - softwareList := s.SoftwareList.Get() targetSoftware := r.URL.Query().Get("software") + if targetSoftware == "" { + if cached := s.rssCache.Load().(string); cached != "" { + w.Header().Set("Content-Type", "application/rss+xml; charset=utf-8") + w.Header().Set("Cache-Control", "public, max-age=300") + lastMod := s.rssLastMod.Load().(time.Time) + if !lastMod.IsZero() { + w.Header().Set("Last-Modified", lastMod.Format(http.TimeFormat)) + if r.Header.Get("If-Modified-Since") == lastMod.Format(http.TimeFormat) { + w.WriteHeader(http.StatusNotModified) + return + } + } + _, _ = w.Write([]byte(cached)) + return + } + } + + softwareList := s.SoftwareList.Get() host := r.Host scheme := "http" if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { @@ -473,7 +552,6 @@ func (s *Server) RSSHandler(w http.ResponseWriter, r *http.Request) { } baseURL := fmt.Sprintf("%s://%s", scheme, host) - // Collect all releases and sort by date type item struct { Software models.Software Release models.Release @@ -543,13 +621,23 @@ func (s *Server) RSSHandler(w http.ResponseWriter, r *http.Request) { }) } - w.Header().Set("Content-Type", "application/rss+xml; charset=utf-8") - w.Header().Set("Cache-Control", "public, max-age=300") - - fmt.Fprint(w, xml.Header) - enc := xml.NewEncoder(w) + var buf bytes.Buffer + buf.WriteString(xml.Header) + enc := xml.NewEncoder(&buf) enc.Indent("", " ") if err := enc.Encode(feed); err != nil { log.Printf("Error encoding RSS feed: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return } + + xmlData := buf.String() + if targetSoftware == "" { + s.rssCache.Store(xmlData) + s.rssLastMod.Store(time.Now()) + } + + w.Header().Set("Content-Type", "application/rss+xml; charset=utf-8") + w.Header().Set("Cache-Control", "public, max-age=300") + _, _ = w.Write([]byte(xmlData)) } diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go new file mode 100644 index 0000000..247f025 --- /dev/null +++ b/internal/api/handlers_test.go @@ -0,0 +1,110 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "software-station/internal/models" + "software-station/internal/stats" + "strings" + "testing" +) + +func TestHandlers(t *testing.T) { + os.Setenv("ALLOW_LOOPBACK", "true") + defer os.Unsetenv("ALLOW_LOOPBACK") + + tempHashes := "test_handlers_hashes.json" + defer os.Remove(tempHashes) + os.RemoveAll(".cache") + + statsService := stats.NewService(tempHashes) + initialSoftware := []models.Software{ + { + Name: "test-app", + Releases: []models.Release{ + { + TagName: "v1.0.0", + Assets: []models.Asset{ + {Name: "test.exe", URL: "http://example.com/test.exe"}, + }, + }, + }, + AvatarURL: "http://example.com/logo.png", + }, + } + server := NewServer("token", initialSoftware, statsService) + + t.Run("APISoftwareHandler", func(t *testing.T) { + req := httptest.NewRequest("GET", "/api/software", nil) + rr := httptest.NewRecorder() + server.APISoftwareHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } + + var sw []models.Software + if err := json.Unmarshal(rr.Body.Bytes(), &sw); err != nil { + t.Fatal(err) + } + if len(sw) != 1 || sw[0].Name != "test-app" { + t.Errorf("unexpected response: %v", sw) + } + if !strings.Contains(sw[0].AvatarURL, "/api/avatar?id=") { + t.Errorf("AvatarURL not proxied: %s", sw[0].AvatarURL) + } + }) + + t.Run("AvatarHandler", func(t *testing.T) { + // Mock upstream + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/png") + w.Write([]byte("fake-image")) + })) + defer upstream.Close() + + hash := server.RegisterURL(upstream.URL) + req := httptest.NewRequest("GET", "/api/avatar?id="+hash, nil) + rr := httptest.NewRecorder() + server.AvatarHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } + if rr.Header().Get("Content-Type") != "image/png" { + t.Errorf("expected image/png, got %s", rr.Header().Get("Content-Type")) + } + }) + + t.Run("DownloadProxyHandler", func(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("fake-binary")) + })) + defer upstream.Close() + + hash := server.RegisterURL(upstream.URL) + req := httptest.NewRequest("GET", "/api/download?id="+hash, nil) + rr := httptest.NewRecorder() + server.DownloadProxyHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } + }) + + t.Run("RSSHandler", func(t *testing.T) { + req := httptest.NewRequest("GET", "/api/rss", nil) + rr := httptest.NewRecorder() + server.RSSHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } + if !strings.Contains(rr.Body.String(), " 0 { - mu.Lock() - *softwareList = newList - mu.Unlock() + onUpdate(newList) log.Printf("Software list updated with %d items", len(newList)) } } diff --git a/internal/gitea/client.go b/internal/gitea/client.go index db4616d..7f00b23 100644 --- a/internal/gitea/client.go +++ b/internal/gitea/client.go @@ -118,10 +118,8 @@ func FetchRepoInfo(server, token, owner, repo string) (string, []string, string, func detectLicenseFromFile(server, token, owner, repo, defaultBranch string) string { branches := []string{"main", "master"} if defaultBranch != "" { - // Put default branch first branches = append([]string{defaultBranch}, "main", "master") } - // Deduplicate seen := make(map[string]bool) var finalBranches []string for _, b := range branches { @@ -149,7 +147,6 @@ func detectLicenseFromFile(server, token, owner, repo, defaultBranch string) str defer resp.Body.Close() if resp.StatusCode == http.StatusOK { - // Read first few lines to guess license scanner := bufio.NewScanner(resp.Body) for i := 0; i < 5 && scanner.Scan(); i++ { line := strings.ToUpper(scanner.Text()) @@ -166,7 +163,7 @@ func detectLicenseFromFile(server, token, owner, repo, defaultBranch string) str return "BSD" } } - return "LICENSE" // Found file but couldn't detect type + return "LICENSE" } } return "" @@ -223,7 +220,6 @@ func FetchReleases(server, token, owner, repo string) ([]models.Release, error) var assets []models.Asset var checksumsURL string - // First pass: identify assets and look for checksum file for _, ga := range gr.Assets { if ga.Name == "SHA256SUMS" { checksumsURL = ga.URL @@ -238,7 +234,6 @@ func FetchReleases(server, token, owner, repo string) ([]models.Release, error) }) } - // Second pass: if checksum file exists, fetch and parse it if checksumsURL != "" { checksums, err := fetchAndParseChecksums(checksumsURL, token) if err == nil { @@ -290,7 +285,6 @@ func fetchAndParseChecksums(url, token string) (map[string]string, error) { } parts := strings.Fields(line) if len(parts) >= 2 { - // Format is usually: hash filename checksums[parts[1]] = parts[0] } } diff --git a/internal/security/security.go b/internal/security/security.go index c82d802..cd2607b 100644 --- a/internal/security/security.go +++ b/internal/security/security.go @@ -90,10 +90,7 @@ func GetRequestFingerprint(r *http.Request, s *stats.Service) string { ipStr = ip.String() } - // Improve fingerprinting with more entropy ua := r.Header.Get("User-Agent") - lang := r.Header.Get("Accept-Language") - enc := r.Header.Get("Accept-Encoding") chUA := r.Header.Get("Sec-CH-UA") hash := sha256.New() @@ -101,20 +98,18 @@ func GetRequestFingerprint(r *http.Request, s *stats.Service) string { hash.Write([]byte("|")) hash.Write([]byte(ua)) hash.Write([]byte("|")) - hash.Write([]byte(lang)) - hash.Write([]byte("|")) - hash.Write([]byte(enc)) - hash.Write([]byte("|")) hash.Write([]byte(chUA)) fingerprint := hex.EncodeToString(hash.Sum(nil)) s.KnownHashes.Lock() if _, exists := s.KnownHashes.Data[fingerprint]; !exists { - s.KnownHashes.Data[fingerprint] = &models.FingerprintData{ - Known: true, + if len(s.KnownHashes.Data) < 10000 { + s.KnownHashes.Data[fingerprint] = &models.FingerprintData{ + Known: true, + } + s.SaveHashes() } - s.SaveHashes() } s.KnownHashes.Unlock() @@ -129,7 +124,6 @@ func IsPrivateIP(ip net.IP) bool { return true } - // Private IP ranges privateRanges := []struct { start net.IP end net.IP diff --git a/internal/stats/stats.go b/internal/stats/stats.go index 4efaa9c..6a1a756 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -59,13 +59,17 @@ func NewService(hashesFile string) *Service { func (s *Service) Start() { go func() { ticker := time.NewTicker(10 * time.Second) + statsTicker := time.NewTicker(24 * time.Hour) defer ticker.Stop() + defer statsTicker.Stop() for { select { case <-ticker.C: if atomic.CompareAndSwapInt32(&s.hashesDirty, 1, 0) { s.FlushHashes() } + case <-statsTicker.C: + s.ResetGlobalStats() case <-s.stopChan: s.FlushHashes() return @@ -74,6 +78,19 @@ func (s *Service) Start() { }() } +func (s *Service) ResetGlobalStats() { + s.GlobalStats.Lock() + defer s.GlobalStats.Unlock() + s.GlobalStats.UniqueRequests = make(map[string]bool) + s.GlobalStats.SuccessDownloads = make(map[string]bool) + s.GlobalStats.BlockedRequests = make(map[string]bool) + s.GlobalStats.LimitedRequests = make(map[string]bool) + s.GlobalStats.WebRequests = make(map[string]bool) + s.GlobalStats.TotalRequests = 0 + s.GlobalStats.TotalResponseTime = 0 + s.GlobalStats.StartTime = time.Now() +} + func (s *Service) Stop() { close(s.stopChan) } diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index ec3e2b4..7fef2f5 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -51,4 +51,18 @@ func TestStats(t *testing.T) { if stats["status"] != "healthy" { t.Errorf("expected healthy status, got %v", stats["status"]) } + + // Test ResetGlobalStats + service.GlobalStats.Lock() + service.GlobalStats.TotalRequests = 10 + service.GlobalStats.Unlock() + service.ResetGlobalStats() + if service.GlobalStats.TotalRequests != 0 { + t.Error("ResetGlobalStats did not reset TotalRequests") + } + + // Test Start/Stop/SaveHashes + service.Start() + service.SaveHashes() + service.Stop() } diff --git a/main.go b/main.go index 317aa7a..987712e 100644 --- a/main.go +++ b/main.go @@ -54,7 +54,7 @@ func main() { initialSoftware := config.LoadSoftware(configPath, giteaServer, giteaToken) apiServer := api.NewServer(giteaToken, initialSoftware, statsService) - config.StartBackgroundUpdater(configPath, giteaServer, giteaToken, apiServer.SoftwareList.GetLock(), apiServer.SoftwareList.GetDataPtr(), *updateInterval) + config.StartBackgroundUpdater(configPath, giteaServer, giteaToken, *updateInterval, apiServer.UpdateSoftwareList) r := chi.NewRouter() @@ -122,13 +122,11 @@ func main() { f, err := contentStatic.Open(path) if err != nil { - // If it's an API request, return a proper 404 if strings.HasPrefix(r.URL.Path, "/api") { http.Error(w, "Not Found", http.StatusNotFound) return } - // For SPA, serve index.html for unknown frontend routes indexData, err := fs.ReadFile(contentStatic, "index.html") if err != nil { http.Error(w, "Index not found", http.StatusInternalServerError)