2025-12-27 02:57:25 -06:00
parent 63a6f8f7dc
commit 1d5d6aacb4
68 changed files with 6884 additions and 0 deletions

internal/api/constants.go Normal file

@@ -0,0 +1,10 @@
package api
const (
CompressionLevel = 5
LegalDir = "legal"
PrivacyFile = "privacy.txt"
TermsFile = "terms.txt"
DisclaimerFile = "disclaimer.txt"
)

internal/api/handlers.go Normal file

@@ -0,0 +1,425 @@
package api
import (
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"software-station/internal/models"
"software-station/internal/security"
"software-station/internal/stats"
"golang.org/x/time/rate"
)
type SoftwareCache struct {
mu sync.RWMutex
data []models.Software
}
func (c *SoftwareCache) Get() []models.Software {
c.mu.RLock()
defer c.mu.RUnlock()
return c.data
}
func (c *SoftwareCache) Set(data []models.Software) {
c.mu.Lock()
defer c.mu.Unlock()
c.data = data
}
func (c *SoftwareCache) GetLock() *sync.RWMutex {
return &c.mu
}
func (c *SoftwareCache) GetDataPtr() *[]models.Software {
return &c.data
}
type Server struct {
GiteaToken string
SoftwareList *SoftwareCache
Stats *stats.Service
urlMap map[string]string
urlMapMu sync.RWMutex
rssCache atomic.Value // stores string
rssLastMod atomic.Value // stores time.Time
avatarCache string
}
func NewServer(token string, initialSoftware []models.Software, statsService *stats.Service) *Server {
s := &Server{
GiteaToken: token,
SoftwareList: &SoftwareCache{data: initialSoftware},
Stats: statsService,
urlMap: make(map[string]string),
avatarCache: ".cache/avatars",
}
s.rssCache.Store("")
s.rssLastMod.Store(time.Time{})
if err := os.MkdirAll(s.avatarCache, 0750); err != nil {
log.Printf("Warning: failed to create avatar cache directory: %v", err)
}
return s
}
func (s *Server) RegisterURL(targetURL string) string {
hash := fmt.Sprintf("%x", sha256.Sum256([]byte(targetURL)))
s.urlMapMu.Lock()
s.urlMap[hash] = targetURL
s.urlMapMu.Unlock()
return hash
}
func (s *Server) APISoftwareHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
softwareList := s.SoftwareList.Get()
host := r.Host
scheme := "http"
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
scheme = "https"
}
proxiedList := make([]models.Software, len(softwareList))
for i, sw := range softwareList {
proxiedList[i] = sw
// If private, hide Gitea URL completely. If public, we can show it for the repo link.
if sw.IsPrivate {
proxiedList[i].GiteaURL = ""
}
// Proxy avatar if it exists
if sw.AvatarURL != "" {
hash := s.RegisterURL(sw.AvatarURL)
proxiedList[i].AvatarURL = fmt.Sprintf("%s://%s/api/avatar?id=%s", scheme, host, hash)
}
proxiedList[i].Releases = make([]models.Release, len(sw.Releases))
for j, rel := range sw.Releases {
proxiedList[i].Releases[j] = rel
proxiedList[i].Releases[j].Assets = make([]models.Asset, len(rel.Assets))
for k, asset := range rel.Assets {
proxiedList[i].Releases[j].Assets[k] = asset
hash := s.RegisterURL(asset.URL)
proxiedList[i].Releases[j].Assets[k].URL = fmt.Sprintf("%s://%s/api/download?id=%s", scheme, host, hash)
}
}
}
if err := json.NewEncoder(w).Encode(proxiedList); err != nil {
log.Printf("Error encoding software list: %v", err)
}
}
func (s *Server) DownloadProxyHandler(w http.ResponseWriter, r *http.Request) {
id := r.URL.Query().Get("id")
if id == "" {
http.Error(w, "Missing id parameter", http.StatusBadRequest)
return
}
s.urlMapMu.RLock()
targetURL, exists := s.urlMap[id]
s.urlMapMu.RUnlock()
if !exists {
http.Error(w, "Invalid or expired download ID", http.StatusNotFound)
return
}
fingerprint := security.GetRequestFingerprint(r, s.Stats)
ua := strings.ToLower(r.UserAgent())
isSpeedDownloader := false
for _, sd := range security.SpeedDownloaders {
if strings.Contains(ua, sd) {
isSpeedDownloader = true
break
}
}
limit := security.DefaultDownloadLimit
burst := int(security.DefaultDownloadBurst)
if isSpeedDownloader {
limit = security.SpeedDownloaderLimit
burst = int(security.SpeedDownloaderBurst)
s.Stats.GlobalStats.Lock()
s.Stats.GlobalStats.LimitedRequests[fingerprint] = true
s.Stats.GlobalStats.Unlock()
}
s.Stats.DownloadStats.Lock()
limiter, exists := s.Stats.DownloadStats.Limiters[fingerprint]
if !exists {
limiter = rate.NewLimiter(limit, burst)
s.Stats.DownloadStats.Limiters[fingerprint] = limiter
}
s.Stats.KnownHashes.RLock()
data, exists := s.Stats.KnownHashes.Data[fingerprint]
s.Stats.KnownHashes.RUnlock()
var totalDownloaded int64
if exists {
totalDownloaded = atomic.LoadInt64(&data.TotalBytes)
}
if totalDownloaded > security.HeavyDownloaderThreshold {
limiter.SetLimit(security.HeavyDownloaderLimit)
s.Stats.GlobalStats.Lock()
s.Stats.GlobalStats.LimitedRequests[fingerprint] = true
s.Stats.GlobalStats.Unlock()
} else {
limiter.SetLimit(limit)
}
s.Stats.DownloadStats.Unlock()
req, err := http.NewRequest("GET", targetURL, nil)
if err != nil {
http.Error(w, "Failed to create request", http.StatusInternalServerError)
return
}
// Forward Range headers for resumable downloads
if rangeHeader := r.Header.Get("Range"); rangeHeader != "" {
req.Header.Set("Range", rangeHeader)
}
if ifRangeHeader := r.Header.Get("If-Range"); ifRangeHeader != "" {
req.Header.Set("If-Range", ifRangeHeader)
}
if s.GiteaToken != "" {
req.Header.Set("Authorization", "token "+s.GiteaToken)
}
client := security.GetSafeHTTPClient(0)
resp, err := client.Do(req)
if err != nil {
http.Error(w, "Failed to fetch asset: "+err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
http.Error(w, "Gitea returned error: "+resp.Status, http.StatusBadGateway)
return
}
// Copy all headers from the upstream response
for k, vv := range resp.Header {
for _, v := range vv {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
tr := &security.ThrottledReader{
R: resp.Body,
Limiter: limiter,
Fingerprint: fingerprint,
Stats: s.Stats,
}
n, err := io.Copy(w, tr)
if err != nil {
log.Printf("Error copying proxy response: %v", err)
}
if n > 0 {
s.Stats.GlobalStats.Lock()
s.Stats.GlobalStats.SuccessDownloads[fingerprint] = true
s.Stats.GlobalStats.Unlock()
}
s.Stats.SaveHashes()
}
func (s *Server) LegalHandler(w http.ResponseWriter, r *http.Request) {
doc := r.URL.Query().Get("doc")
var filename string
switch doc {
case "privacy":
filename = PrivacyFile
case "terms":
filename = TermsFile
case "disclaimer":
filename = DisclaimerFile
default:
http.Error(w, "Invalid document", http.StatusBadRequest)
return
}
path := filepath.Join(LegalDir, filename)
data, err := os.ReadFile(path) // #nosec G304
if err != nil {
http.Error(w, "Document not found", http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
// If download parameter is present, set Content-Disposition
if r.URL.Query().Get("download") == "true" {
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename))
}
if _, err := w.Write(data); err != nil {
log.Printf("Error writing legal document: %v", err)
}
}
func (s *Server) AvatarHandler(w http.ResponseWriter, r *http.Request) {
id := r.URL.Query().Get("id")
if id == "" {
http.Error(w, "Missing id parameter", http.StatusBadRequest)
return
}
cachePath := filepath.Join(s.avatarCache, id)
if _, err := os.Stat(cachePath); err == nil {
// Serve from cache
w.Header().Set("Cache-Control", "public, max-age=86400")
http.ServeFile(w, r, cachePath)
return
}
s.urlMapMu.RLock()
targetURL, exists := s.urlMap[id]
s.urlMapMu.RUnlock()
if !exists {
http.Error(w, "Invalid or expired avatar ID", http.StatusNotFound)
return
}
req, err := http.NewRequest("GET", targetURL, nil)
if err != nil {
http.Error(w, "Failed to create request", http.StatusInternalServerError)
return
}
if s.GiteaToken != "" {
req.Header.Set("Authorization", "token "+s.GiteaToken)
}
client := security.GetSafeHTTPClient(0)
resp, err := client.Do(req)
if err != nil {
http.Error(w, "Failed to fetch avatar: "+err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
http.Error(w, "Gitea returned error: "+resp.Status, http.StatusBadGateway)
return
}
// Read the full avatar, write it to the on-disk cache, then serve it from memory
data, err := io.ReadAll(resp.Body)
if err != nil {
http.Error(w, "Failed to read avatar data", http.StatusInternalServerError)
return
}
if err := os.WriteFile(cachePath, data, 0600); err != nil {
log.Printf("Warning: failed to cache avatar: %v", err)
}
w.Header().Set("Content-Type", resp.Header.Get("Content-Type"))
w.Header().Set("Cache-Control", "public, max-age=86400")
_, _ = w.Write(data)
}
func (s *Server) RSSHandler(w http.ResponseWriter, r *http.Request) {
softwareList := s.SoftwareList.Get()
targetSoftware := r.URL.Query().Get("software")
host := r.Host
scheme := "http"
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
scheme = "https"
}
baseURL := fmt.Sprintf("%s://%s", scheme, host)
// Collect all releases and sort by date
type item struct {
Software models.Software
Release models.Release
}
var items []item
for _, sw := range softwareList {
if targetSoftware != "" && sw.Name != targetSoftware {
continue
}
for _, rel := range sw.Releases {
items = append(items, item{Software: sw, Release: rel})
}
}
sort.Slice(items, func(i, j int) bool {
return items[i].Release.CreatedAt.After(items[j].Release.CreatedAt)
})
feedTitle := "Software Updates - Software Station"
feedDescription := "Latest software releases and updates"
selfLink := baseURL + "/api/rss"
if targetSoftware != "" {
feedTitle = fmt.Sprintf("%s Updates - Software Station", targetSoftware)
feedDescription = fmt.Sprintf("Latest releases and updates for %s", targetSoftware)
selfLink = fmt.Sprintf("%s/api/rss?software=%s", baseURL, targetSoftware)
}
var b strings.Builder
b.WriteString(`<?xml version="1.0" encoding="UTF-8" ?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom">
<channel>
<title>` + feedTitle + `</title>
<link>` + baseURL + `</link>
<description>` + feedDescription + `</description>
<language>en-us</language>
<lastBuildDate>` + time.Now().Format(time.RFC1123Z) + `</lastBuildDate>
<atom:link href="` + selfLink + `" rel="self" type="application/rss+xml" />
`)
for i, it := range items {
if i >= 50 {
break
}
title := fmt.Sprintf("%s %s", it.Software.Name, it.Release.TagName)
link := baseURL
description := it.Software.Description
if it.Release.Body != "" {
description = it.Release.Body
}
fmt.Fprintf(&b, ` <item>
<title>%s</title>
<link>%s</link>
<description><![CDATA[%s]]></description>
<guid isPermaLink="false">%s-%s</guid>
<pubDate>%s</pubDate>
</item>
`, title, link, description, it.Software.Name, it.Release.TagName, it.Release.CreatedAt.Format(time.RFC1123Z))
}
b.WriteString(`</channel>
</rss>`)
w.Header().Set("Content-Type", "application/rss+xml; charset=utf-8")
w.Header().Set("Cache-Control", "public, max-age=300")
_, _ = w.Write([]byte(b.String()))
}
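
For orientation (not part of this diff): a minimal sketch of how these handlers could be wired together. The /api/download, /api/avatar and /api/rss paths match the URLs the handlers generate for themselves; the remaining routes, the listen address, and the GITEA_TOKEN environment variable are assumptions.

package main

import (
    "log"
    "net/http"
    "os"

    "software-station/internal/api"
    "software-station/internal/config"
    "software-station/internal/security"
    "software-station/internal/stats"
)

func main() {
    token := os.Getenv("GITEA_TOKEN") // assumed source of the Gitea token
    statsService := stats.NewService(stats.DefaultHashesFile)
    statsService.LoadHashes()
    statsService.Start()
    defer statsService.Stop()

    softwareList := config.LoadSoftware(config.DefaultConfigPath, config.DefaultGiteaServer, token)
    srv := api.NewServer(token, softwareList, statsService)

    mux := http.NewServeMux()
    mux.HandleFunc("/api/software", srv.APISoftwareHandler) // route path assumed
    mux.HandleFunc("/api/download", srv.DownloadProxyHandler)
    mux.HandleFunc("/api/avatar", srv.AvatarHandler)
    mux.HandleFunc("/api/rss", srv.RSSHandler)
    mux.HandleFunc("/api/legal", srv.LegalHandler) // route path assumed
    mux.HandleFunc("/api/stats", statsService.APIStatsHandler)

    // SecurityMiddleware handles bot blocking, forbidden paths and request stats.
    log.Fatal(http.ListenAndServe(":8080", security.SecurityMiddleware(statsService)(mux)))
}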

internal/cache/cache.go vendored Normal file

@@ -0,0 +1,44 @@
package cache
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"software-station/internal/models"
)
const cacheDir = ".cache"
func init() {
if err := os.MkdirAll(cacheDir, 0750); err != nil {
log.Printf("Warning: failed to create cache directory: %v", err)
}
}
func GetCachePath(owner, repo string) string {
return filepath.Join(cacheDir, filepath.Clean(fmt.Sprintf("%s_%s.json", owner, repo)))
}
func SaveToCache(owner, repo string, software models.Software) error {
path := GetCachePath(owner, repo)
data, err := json.MarshalIndent(software, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0600)
}
func GetFromCache(owner, repo string) (*models.Software, error) {
path := GetCachePath(owner, repo)
data, err := os.ReadFile(path) // #nosec G304
if err != nil {
return nil, err
}
var software models.Software
if err := json.Unmarshal(data, &software); err != nil {
return nil, err
}
return &software, nil
}

internal/cache/cache_test.go vendored Normal file

@@ -0,0 +1,43 @@
package cache
import (
"os"
"software-station/internal/models"
"testing"
)
func TestCache(t *testing.T) {
owner := "test-owner"
repo := "test-repo"
software := models.Software{
Name: "test-repo",
Owner: "test-owner",
Description: "test desc",
}
// Clean up before and after
path := GetCachePath(owner, repo)
os.Remove(path)
defer os.Remove(path)
// Test GetFromCache missing
_, err := GetFromCache(owner, repo)
if err == nil {
t.Error("expected error for missing cache")
}
// Test SaveToCache
err = SaveToCache(owner, repo, software)
if err != nil {
t.Fatalf("failed to save to cache: %v", err)
}
// Test GetFromCache success
cached, err := GetFromCache(owner, repo)
if err != nil {
t.Fatalf("failed to get from cache: %v", err)
}
if cached.Name != software.Name || cached.Owner != software.Owner {
t.Errorf("cached data mismatch: %+v", cached)
}
}

internal/config/config.go Normal file

@@ -0,0 +1,86 @@
package config
import (
"bufio"
"io"
"log"
"net/http"
"os"
"strings"
"software-station/internal/cache"
"software-station/internal/gitea"
"software-station/internal/models"
)
func LoadSoftware(path, server, token string) []models.Software {
return LoadSoftwareExtended(path, server, token, true)
}
func LoadSoftwareExtended(path, server, token string, useCache bool) []models.Software {
var reader io.Reader
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
resp, err := http.Get(path) // #nosec G107
if err != nil {
log.Printf("Error fetching remote config: %v", err)
return nil
}
defer resp.Body.Close()
reader = resp.Body
} else {
file, err := os.Open(path) // #nosec G304
if err != nil {
log.Printf("Warning: config file %s not found", path)
return nil
}
defer file.Close()
reader = file
}
var softwareList []models.Software
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.Split(line, "/")
if len(parts) == 2 {
owner, repo := parts[0], parts[1]
// Try to get from cache first
if useCache {
if cached, err := cache.GetFromCache(owner, repo); err == nil {
softwareList = append(softwareList, *cached)
continue
}
}
desc, topics, license, isPrivate, avatarURL, err := gitea.FetchRepoInfo(server, token, owner, repo)
if err != nil {
log.Printf("Error fetching repo info for %s/%s: %v", owner, repo, err)
}
releases, err := gitea.FetchReleases(server, token, owner, repo)
if err != nil {
log.Printf("Error fetching releases for %s/%s: %v", owner, repo, err)
}
sw := models.Software{
Name: repo,
Owner: owner,
Description: desc,
Releases: releases,
GiteaURL: server,
Topics: topics,
License: license,
IsPrivate: isPrivate,
AvatarURL: avatarURL,
}
softwareList = append(softwareList, sw)
if err := cache.SaveToCache(owner, repo, sw); err != nil {
log.Printf("Error saving to cache for %s/%s: %v", owner, repo, err)
}
}
}
return softwareList
}
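
For reference, the config file this parser expects lists one owner/repo slug per line, with blank lines and #-comments skipped; cached entries are used when available, otherwise repo info and releases are fetched from Gitea and the cache is refreshed. A hedged caller-side sketch (the slugs and env var are illustrative):

    // Hypothetical software.txt contents:
    //
    //   # comments and blank lines are ignored
    //   owner/repo
    //   another-owner/another-repo
    //
    list := config.LoadSoftware("software.txt", config.DefaultGiteaServer, os.Getenv("GITEA_TOKEN"))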


@@ -0,0 +1,62 @@
package config
import (
"net/http"
"net/http/httptest"
"os"
"software-station/internal/models"
"sync"
"testing"
"time"
)
func TestConfig(t *testing.T) {
// Test Local Config
configPath := "test_software.txt"
os.WriteFile(configPath, []byte("Owner/Repo\n#Comment\n\nOwner2/Repo2"), 0644)
defer os.Remove(configPath)
// Mock Gitea
mockGitea := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"description": "Test", "topics": [], "license": {"name": "MIT"}}`))
}))
defer mockGitea.Close()
list := LoadSoftware(configPath, mockGitea.URL, "")
if len(list) != 2 {
t.Errorf("expected 2 repos, got %d", len(list))
}
// Test Remote Config
mockRemote := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Owner3/Repo3"))
}))
defer mockRemote.Close()
listRemote := LoadSoftware(mockRemote.URL, mockGitea.URL, "")
if len(listRemote) != 1 || listRemote[0].Name != "Repo3" {
t.Errorf("expected Repo3, got %v", listRemote)
}
}
func TestBackgroundUpdater(t *testing.T) {
configPath := "test_updater.txt"
os.WriteFile(configPath, []byte("Owner/Repo"), 0644)
defer os.Remove(configPath)
mockGitea := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"description": "Updated", "topics": [], "license": {"name": "MIT"}}`))
}))
defer mockGitea.Close()
var mu sync.RWMutex
softwareList := &[]models.Software{}
StartBackgroundUpdater(configPath, mockGitea.URL, "", &mu, softwareList, 100*time.Millisecond)
// Wait for ticker
time.Sleep(200 * time.Millisecond)
if len(*softwareList) == 0 {
t.Error("softwareList was not updated by background updater")
}
}


@@ -0,0 +1,6 @@
package config
const (
DefaultConfigPath = "software.txt"
DefaultGiteaServer = "https://git.quad4.io"
)


@@ -0,0 +1,30 @@
package config
import (
"log"
"sync"
"time"
"software-station/internal/models"
)
func StartBackgroundUpdater(path, server, token string, mu *sync.RWMutex, softwareList *[]models.Software, interval time.Duration) {
ticker := time.NewTicker(interval)
go func() {
for range ticker.C {
log.Println("Checking for software updates...")
newList := LoadSoftwareFromGitea(path, server, token)
if len(newList) > 0 {
mu.Lock()
*softwareList = newList
mu.Unlock()
log.Printf("Software list updated with %d items", len(newList))
}
}
}()
}
// LoadSoftwareFromGitea always fetches from Gitea and updates cache
func LoadSoftwareFromGitea(path, server, token string) []models.Software {
return LoadSoftwareExtended(path, server, token, false)
}
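
SoftwareCache.GetLock and GetDataPtr (in internal/api) exist so this updater can swap the slice under the cache's own lock. A minimal sketch, assuming srv is the *api.Server returned by NewServer; the one-hour interval and token source are assumptions:

    // Refresh the shared software list in the background.
    cache := srv.SoftwareList
    config.StartBackgroundUpdater(
        config.DefaultConfigPath,
        config.DefaultGiteaServer,
        os.Getenv("GITEA_TOKEN"),
        cache.GetLock(),
        cache.GetDataPtr(),
        time.Hour, // assumed refresh interval
    )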

internal/gitea/client.go Normal file

@@ -0,0 +1,298 @@
package gitea
import (
"bufio"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"software-station/internal/models"
)
func DetectOS(filename string) string {
lower := strings.ToLower(filename)
osMap := []struct {
patterns []string
suffixes []string
os string
}{
{
patterns: []string{"windows"},
suffixes: []string{".exe", ".msi"},
os: models.OSWindows,
},
{
patterns: []string{"linux"},
suffixes: []string{".deb", ".rpm", ".appimage", ".flatpak"},
os: models.OSLinux,
},
{
patterns: []string{"mac", "darwin"},
suffixes: []string{".dmg", ".pkg"},
os: models.OSMacOS,
},
{
patterns: []string{"freebsd"},
os: models.OSFreeBSD,
},
{
patterns: []string{"openbsd"},
os: models.OSOpenBSD,
},
{
patterns: []string{"android"},
suffixes: []string{".apk"},
os: models.OSAndroid,
},
{
patterns: []string{"arm", "aarch64"},
os: models.OSARM,
},
}
for _, entry := range osMap {
for _, p := range entry.patterns {
if strings.Contains(lower, p) {
return entry.os
}
}
for _, s := range entry.suffixes {
if strings.HasSuffix(lower, s) {
return entry.os
}
}
}
return models.OSUnknown
}
func FetchRepoInfo(server, token, owner, repo string) (string, []string, string, bool, string, error) {
url := fmt.Sprintf("%s%s/%s/%s", server, RepoAPIPath, owner, repo)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return "", nil, "", false, "", err
}
if token != "" {
req.Header.Set("Authorization", "token "+token)
}
client := &http.Client{Timeout: DefaultTimeout}
resp, err := client.Do(req)
if err != nil {
return "", nil, "", false, "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", nil, "", false, "", fmt.Errorf("gitea api returned status %d", resp.StatusCode)
}
var info struct {
Description string `json:"description"`
Topics []string `json:"topics"`
DefaultBranch string `json:"default_branch"`
Licenses []string `json:"licenses"`
Private bool `json:"private"`
AvatarURL string `json:"avatar_url"`
}
if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
return "", nil, "", false, "", err
}
license := ""
if len(info.Licenses) > 0 {
license = info.Licenses[0]
}
if license == "" {
// Try to detect license from file if API returns nothing
license = detectLicenseFromFile(server, token, owner, repo, info.DefaultBranch)
}
return info.Description, info.Topics, license, info.Private, info.AvatarURL, nil
}
func detectLicenseFromFile(server, token, owner, repo, defaultBranch string) string {
branches := []string{"main", "master"}
if defaultBranch != "" {
// Put default branch first
branches = append([]string{defaultBranch}, "main", "master")
}
// Deduplicate
seen := make(map[string]bool)
var finalBranches []string
for _, b := range branches {
if !seen[b] {
seen[b] = true
finalBranches = append(finalBranches, b)
}
}
for _, branch := range finalBranches {
url := fmt.Sprintf("%s/%s/%s/raw/branch/%s/LICENSE", server, owner, repo, branch)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
continue
}
if token != "" {
req.Header.Set("Authorization", "token "+token)
}
client := &http.Client{Timeout: DefaultTimeout}
resp, err := client.Do(req)
if err != nil {
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
// Read first few lines to guess license
scanner := bufio.NewScanner(resp.Body)
for i := 0; i < 5 && scanner.Scan(); i++ {
line := strings.ToUpper(scanner.Text())
if strings.Contains(line, "MIT LICENSE") {
return "MIT"
}
if strings.Contains(line, "GNU GENERAL PUBLIC LICENSE") || strings.Contains(line, "GPL") {
return "GPL"
}
if strings.Contains(line, "APACHE LICENSE") {
return "Apache-2.0"
}
if strings.Contains(line, "BSD") {
return "BSD"
}
}
return "LICENSE" // Found file but couldn't detect type
}
}
return ""
}
func IsSBOM(filename string) bool {
lower := strings.ToLower(filename)
return strings.Contains(lower, "sbom") ||
strings.Contains(lower, "cyclonedx") ||
strings.Contains(lower, "spdx") ||
strings.HasSuffix(lower, ".cdx.json") ||
strings.HasSuffix(lower, ".spdx.json")
}
func FetchReleases(server, token, owner, repo string) ([]models.Release, error) {
url := fmt.Sprintf("%s%s/%s/%s%s", server, RepoAPIPath, owner, repo, ReleasesSuffix)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
if token != "" {
req.Header.Set("Authorization", "token "+token)
}
client := &http.Client{Timeout: DefaultTimeout}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("gitea api returned status %d", resp.StatusCode)
}
var giteaReleases []struct {
TagName string `json:"tag_name"`
Body string `json:"body"`
CreatedAt time.Time `json:"created_at"`
Assets []struct {
Name string `json:"name"`
Size int64 `json:"size"`
URL string `json:"browser_download_url"`
} `json:"assets"`
}
if err := json.NewDecoder(resp.Body).Decode(&giteaReleases); err != nil {
return nil, err
}
var releases []models.Release
for _, gr := range giteaReleases {
var assets []models.Asset
var checksumsURL string
// First pass: identify assets and look for checksum file
for _, ga := range gr.Assets {
if ga.Name == "SHA256SUMS" {
checksumsURL = ga.URL
continue
}
assets = append(assets, models.Asset{
Name: ga.Name,
Size: ga.Size,
URL: ga.URL,
OS: DetectOS(ga.Name),
IsSBOM: IsSBOM(ga.Name),
})
}
// Second pass: if checksum file exists, fetch and parse it
if checksumsURL != "" {
checksums, err := fetchAndParseChecksums(checksumsURL, token)
if err == nil {
for i := range assets {
if sha, ok := checksums[assets[i].Name]; ok {
assets[i].SHA256 = sha
}
}
}
}
releases = append(releases, models.Release{
TagName: gr.TagName,
Body: gr.Body,
CreatedAt: gr.CreatedAt,
Assets: assets,
})
}
return releases, nil
}
func fetchAndParseChecksums(url, token string) (map[string]string, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
if token != "" {
req.Header.Set("Authorization", "token "+token)
}
client := &http.Client{Timeout: DefaultTimeout}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch checksums: %d", resp.StatusCode)
}
checksums := make(map[string]string)
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
parts := strings.Fields(line)
if len(parts) >= 2 {
// Format is usually: hash filename
checksums[parts[1]] = parts[0]
}
}
return checksums, nil
}
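
fetchAndParseChecksums expects standard sha256sum output, one "<hex digest> <filename>" pair per line, as in this illustrative SHA256SUMS asset (digests shortened, filenames made up):

    9f86d081884c7d65...  software-station_linux_amd64.deb
    2cf24dba5fb0a30e...  software-station_windows_amd64.exe

FetchReleases then copies the matching digest into Asset.SHA256 for each asset by name.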


@@ -0,0 +1,118 @@
package gitea
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"software-station/internal/models"
)
func TestFetchRepoInfo(t *testing.T) {
mockSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "token test-token" {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Write([]byte(`{"description": "Test Repo", "topics": ["test"], "licenses": ["MIT"], "private": false, "avatar_url": "https://example.com/avatar.png"}`))
}))
defer mockSrv.Close()
desc, topics, license, isPrivate, avatarURL, err := FetchRepoInfo(mockSrv.URL, "test-token", "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if desc != "Test Repo" || len(topics) != 1 || topics[0] != "test" || license != "MIT" || isPrivate || avatarURL != "https://example.com/avatar.png" {
t.Errorf("unexpected results: %s, %v, %s, %v, %s", desc, topics, license, isPrivate, avatarURL)
}
_, _, _, _, _, err = FetchRepoInfo(mockSrv.URL, "wrong-token", "owner", "repo")
if err == nil {
t.Error("expected error for unauthorized request")
}
}
func TestFetchReleases(t *testing.T) {
var srv *httptest.Server
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "releases") {
w.Write([]byte(fmt.Sprintf(`[{"tag_name": "v1.0.0", "assets": [{"name": "test.exe", "size": 100, "browser_download_url": "%s/test.exe"}, {"name": "SHA256SUMS", "size": 50, "browser_download_url": "%s/SHA256SUMS"}]}]`, srv.URL, srv.URL)))
} else if strings.Contains(r.URL.Path, "SHA256SUMS") {
w.Write([]byte(`hash123 test.exe`))
}
}))
defer srv.Close()
releases, err := FetchReleases(srv.URL, "", "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(releases) != 1 || releases[0].TagName != "v1.0.0" {
t.Errorf("unexpected releases: %v", releases)
}
if len(releases[0].Assets) != 1 || releases[0].Assets[0].Name != "test.exe" || releases[0].Assets[0].SHA256 != "hash123" {
t.Errorf("unexpected assets: %v", releases[0].Assets)
}
}
func TestDetectOS(t *testing.T) {
tests := []struct {
filename string
expected string
}{
{"app.exe", models.OSWindows},
{"app.msi", models.OSWindows},
{"app_linux", models.OSLinux},
{"app.deb", models.OSLinux},
{"app.rpm", models.OSLinux},
{"app.dmg", models.OSMacOS},
{"app.pkg", models.OSMacOS},
{"app_freebsd", models.OSFreeBSD},
{"app_openbsd", models.OSOpenBSD},
{"app.apk", models.OSAndroid},
{"app_arm64", models.OSARM},
{"app_aarch64", models.OSARM},
{"unknown", models.OSUnknown},
}
for _, tt := range tests {
got := DetectOS(tt.filename)
if got != tt.expected {
t.Errorf("DetectOS(%s) = %s, expected %s", tt.filename, got, tt.expected)
}
}
}
func TestIsSBOM(t *testing.T) {
tests := []struct {
filename string
expected bool
}{
{"sbom.json", true},
{"cyclonedx.json", true},
{"spdx.json", true},
{"app.exe", false},
{"app.cdx.json", true},
{"app.spdx.json", true},
}
for _, tt := range tests {
if got := IsSBOM(tt.filename); got != tt.expected {
t.Errorf("IsSBOM(%s) = %v, expected %v", tt.filename, got, tt.expected)
}
}
}
func TestFetchAndParseChecksumsError(t *testing.T) {
mockSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer mockSrv.Close()
_, err := fetchAndParseChecksums(mockSrv.URL, "")
if err == nil {
t.Error("expected error for 404")
}
}


@@ -0,0 +1,9 @@
package gitea
import "time"
const (
DefaultTimeout = 10 * time.Second
RepoAPIPath = "/api/v1/repos"
ReleasesSuffix = "/releases"
)


@@ -0,0 +1,12 @@
package models
const (
OSWindows = "windows"
OSLinux = "linux"
OSMacOS = "macos"
OSFreeBSD = "freebsd"
OSOpenBSD = "openbsd"
OSAndroid = "android"
OSARM = "arm"
OSUnknown = "unknown"
)

internal/models/models.go Normal file

@@ -0,0 +1,41 @@
package models
import "time"
type Asset struct {
Name string `json:"name"`
Size int64 `json:"size"`
URL string `json:"url"`
OS string `json:"os"`
SHA256 string `json:"sha256,omitempty"`
IsSBOM bool `json:"is_sbom"`
}
type Release struct {
TagName string `json:"tag_name"`
Body string `json:"body,omitempty"`
CreatedAt time.Time `json:"created_at"`
Assets []Asset `json:"assets"`
}
type Software struct {
Name string `json:"name"`
Owner string `json:"owner"`
Description string `json:"description"`
Releases []Release `json:"releases"`
GiteaURL string `json:"gitea_url"`
Topics []string `json:"topics"`
License string `json:"license,omitempty"`
IsPrivate bool `json:"is_private"`
AvatarURL string `json:"avatar_url,omitempty"`
}
type FingerprintData struct {
Known bool `json:"known"`
TotalBytes int64 `json:"total_bytes"`
}
type SoftwareResponse struct {
GiteaURL string `json:"gitea_url"`
Software []Software `json:"software"`
}


@@ -0,0 +1,50 @@
package security
import (
"time"
"golang.org/x/time/rate"
)
const (
_ = iota
KB = 1 << (10 * iota)
MB
GB
)
const (
// Download Throttling
DefaultDownloadLimit = rate.Limit(5 * MB) // 5MB/s
DefaultDownloadBurst = 2 * MB // 2MB
SpeedDownloaderLimit = rate.Limit(1 * MB) // 1MB/s
SpeedDownloaderBurst = 512 * KB // 512KB
HeavyDownloaderThreshold = 1 * GB // 1GB
HeavyDownloaderLimit = rate.Limit(256 * KB) // 256KB/s
// Rate Limiting
GlobalRateLimit = 100
GlobalRateWindow = 1 * time.Minute
APIRateLimit = 30
APIRateWindow = 1 * time.Minute
)
var ForbiddenPatterns = []string{
".git", ".env", ".aws", ".config", ".ssh",
"wp-admin", "wp-login", "phpinfo", ".php",
"etc/passwd", "cgi-bin", "shell", "cmd",
".sql", ".bak", ".old", ".zip", ".rar",
}
var BotUserAgents = []string{
"bot", "crawl", "spider", "slurp", "googlebot", "bingbot", "yandexbot",
"ahrefsbot", "baiduspider", "duckduckbot", "facebookexternalhit",
"twitterbot", "rogerbot", "linkedinbot", "embedly", "quora link preview",
"showyoubot", "outbrain", "pinterest", "slackbot", "vkShare", "W3C_Validator",
}
var SpeedDownloaders = []string{
"aria2", "wget", "curl", "axel", "transmission", "libcurl",
}


@@ -0,0 +1,220 @@
package security
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strings"
"sync/atomic"
"syscall"
"time"
"software-station/internal/models"
"software-station/internal/stats"
"golang.org/x/time/rate"
)
type ThrottledReader struct {
R io.ReadCloser
Limiter *rate.Limiter
Fingerprint string
Stats *stats.Service
}
func (tr *ThrottledReader) Read(p []byte) (n int, err error) {
n, err = tr.R.Read(p)
if n > 0 && tr.Limiter != nil && tr.Stats != nil {
tr.Stats.KnownHashes.RLock()
data, exists := tr.Stats.KnownHashes.Data[tr.Fingerprint]
tr.Stats.KnownHashes.RUnlock()
var total int64
if exists {
total = atomic.AddInt64(&data.TotalBytes, int64(n))
}
atomic.AddInt64(&tr.Stats.GlobalStats.TotalBytes, int64(n))
if total > HeavyDownloaderThreshold {
tr.Limiter.SetLimit(HeavyDownloaderLimit)
}
if err := tr.Limiter.WaitN(context.Background(), n); err != nil {
return n, err
}
}
return n, err
}
func (tr *ThrottledReader) Close() error {
return tr.R.Close()
}
type contextKey string
const FingerprintKey contextKey = "fingerprint"
func GetRequestFingerprint(r *http.Request, s *stats.Service) string {
if f, ok := r.Context().Value(FingerprintKey).(string); ok {
return f
}
remoteAddr := r.RemoteAddr
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
if comma := strings.IndexByte(xff, ','); comma != -1 {
remoteAddr = strings.TrimSpace(xff[:comma])
} else {
remoteAddr = strings.TrimSpace(xff)
}
} else if xri := r.Header.Get("X-Real-IP"); xri != "" {
remoteAddr = strings.TrimSpace(xri)
}
ipStr, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
ipStr = remoteAddr
}
ip := net.ParseIP(ipStr)
if ip != nil {
if ip.To4() == nil {
ip = ip.Mask(net.CIDRMask(64, 128))
}
ipStr = ip.String()
}
ua := r.Header.Get("User-Agent")
hash := sha256.New()
hash.Write([]byte(ipStr + ua))
fingerprint := hex.EncodeToString(hash.Sum(nil))
s.KnownHashes.Lock()
if _, exists := s.KnownHashes.Data[fingerprint]; !exists {
s.KnownHashes.Data[fingerprint] = &models.FingerprintData{
Known: true,
}
s.SaveHashes()
}
s.KnownHashes.Unlock()
return fingerprint
}
func IsPrivateIP(ip net.IP) bool {
if os.Getenv("ALLOW_LOOPBACK") == "true" && ip.IsLoopback() {
return false
}
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
return true
}
// Private IP ranges
privateRanges := []struct {
start net.IP
end net.IP
}{
{net.ParseIP("10.0.0.0"), net.ParseIP("10.255.255.255")},
{net.ParseIP("172.16.0.0"), net.ParseIP("172.31.255.255")},
{net.ParseIP("192.168.0.0"), net.ParseIP("192.168.255.255")},
{net.ParseIP("fd00::"), net.ParseIP("fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")},
}
for _, r := range privateRanges {
if bytes.Compare(ip, r.start) >= 0 && bytes.Compare(ip, r.end) <= 0 {
return true
}
}
return false
}
func GetSafeHTTPClient(timeout time.Duration) *http.Client {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Control: func(network, address string, c syscall.RawConn) error {
host, _, err := net.SplitHostPort(address)
if err != nil {
return err
}
ip := net.ParseIP(host)
if ip != nil && IsPrivateIP(ip) {
return fmt.Errorf("SSRF protection: forbidden IP %s", ip)
}
return nil
},
}
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
return &http.Client{
Transport: transport,
Timeout: timeout,
}
}
func SecurityMiddleware(s *stats.Service) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
path := strings.ToLower(r.URL.Path)
ua := strings.ToLower(r.UserAgent())
fingerprint := GetRequestFingerprint(r, s)
ctx := context.WithValue(r.Context(), FingerprintKey, fingerprint)
r = r.WithContext(ctx)
s.GlobalStats.Lock()
s.GlobalStats.UniqueRequests[fingerprint] = true
if !strings.HasPrefix(path, "/api") {
s.GlobalStats.WebRequests[fingerprint] = true
}
s.GlobalStats.Unlock()
defer func() {
s.GlobalStats.Lock()
s.GlobalStats.TotalResponseTime += time.Since(start)
s.GlobalStats.TotalRequests++
s.GlobalStats.Unlock()
}()
for _, bot := range BotUserAgents {
if strings.Contains(ua, bot) {
s.GlobalStats.Lock()
s.GlobalStats.BlockedRequests[fingerprint] = true
s.GlobalStats.Unlock()
http.Error(w, "Bots are not allowed", http.StatusForbidden)
return
}
}
for _, pattern := range ForbiddenPatterns {
if strings.Contains(path, pattern) {
s.GlobalStats.Lock()
s.GlobalStats.BlockedRequests[fingerprint] = true
s.GlobalStats.Unlock()
log.Printf("Blocked suspicious request: %s from %s (%s)", r.URL.String(), r.RemoteAddr, r.UserAgent())
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
}
next.ServeHTTP(w, r)
})
}
}
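
Worth noting: the SSRF check runs in the dialer's Control hook, so it applies to every outbound connection made through GetSafeHTTPClient, and a zero timeout (as the download proxy passes) means no overall client timeout, which suits streaming large assets. A short hedged example for a bounded metadata request (URL illustrative):

    client := security.GetSafeHTTPClient(10 * time.Second)
    resp, err := client.Get("https://git.example.com/api/v1/version")
    if err != nil {
        // Requests to loopback/private/link-local targets are rejected by the dialer.
        log.Printf("request blocked or failed: %v", err)
        return
    }
    defer resp.Body.Close()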


@@ -0,0 +1,176 @@
package security
import (
"bytes"
"io"
"net"
"net/http"
"net/http/httptest"
"software-station/internal/models"
"software-station/internal/stats"
"strings"
"testing"
"time"
"golang.org/x/time/rate"
)
func TestThrottledReader(t *testing.T) {
content := []byte("hello world")
inner := io.NopCloser(bytes.NewReader(content))
limiter := rate.NewLimiter(rate.Limit(100), 100)
fp := "test-fp"
statsService := stats.NewService("test-hashes.json")
statsService.KnownHashes.Lock()
statsService.KnownHashes.Data[fp] = &models.FingerprintData{Known: true}
statsService.KnownHashes.Unlock()
tr := &ThrottledReader{
R: inner,
Limiter: limiter,
Fingerprint: fp,
Stats: statsService,
}
p := make([]byte, 5)
n, err := tr.Read(p)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if n != 5 {
t.Errorf("expected 5 bytes, got %d", n)
}
statsService.KnownHashes.RLock()
data, ok := statsService.KnownHashes.Data["test-fp"]
statsService.KnownHashes.RUnlock()
if !ok {
t.Fatal("fingerprint data not found")
}
total := data.TotalBytes
if total != 5 {
t.Errorf("expected 5 bytes in stats, got %d", total)
}
tr.Close()
}
func TestGetRequestFingerprint(t *testing.T) {
statsService := stats.NewService("test-hashes.json")
// IPv4
req1 := httptest.NewRequest("GET", "/", nil)
req1.RemoteAddr = "1.2.3.4:1234"
f1 := GetRequestFingerprint(req1, statsService)
req2 := httptest.NewRequest("GET", "/", nil)
req2.RemoteAddr = "1.2.3.4:5678"
f2 := GetRequestFingerprint(req2, statsService)
if f1 != f2 {
t.Error("fingerprints should match for same IPv4")
}
// X-Forwarded-For
req3 := httptest.NewRequest("GET", "/", nil)
req3.Header.Set("X-Forwarded-For", "5.6.7.8, 1.2.3.4")
f3 := GetRequestFingerprint(req3, statsService)
if f1 == f3 {
t.Error("fingerprints should differ for different IPs")
}
// IPv6 masking
req4 := httptest.NewRequest("GET", "/", nil)
req4.RemoteAddr = "[2001:db8::1]:1234"
f4 := GetRequestFingerprint(req4, statsService)
req5 := httptest.NewRequest("GET", "/", nil)
req5.RemoteAddr = "[2001:db8::2]:1234"
f5 := GetRequestFingerprint(req5, statsService)
if f4 != f5 {
t.Error("fingerprints should match for same IPv6 /64 prefix")
}
}
func TestSecurityMiddleware(t *testing.T) {
statsService := stats.NewService("test-hashes.json")
handler := SecurityMiddleware(statsService)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Test bot blocking
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("User-Agent", "Googlebot")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusForbidden {
t.Errorf("expected 403 for bot, got %d", rr.Code)
}
// Test forbidden pattern
req = httptest.NewRequest("GET", "/.git/config", nil)
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusForbidden {
t.Errorf("expected 403 for forbidden pattern, got %d", rr.Code)
}
// Test normal request
req = httptest.NewRequest("GET", "/api/software", nil)
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected 200 for normal request, got %d", rr.Code)
}
}
func TestIsPrivateIP_Extended(t *testing.T) {
tests := []struct {
ip string
isPrivate bool
}{
{"127.0.0.1", true},
{"10.0.0.1", true},
{"172.16.0.1", true},
{"192.168.1.1", true},
{"0.0.0.0", true},
{"::1", true},
{"::", true},
{"fd00::1", true},
{"8.8.8.8", false},
{"1.1.1.1", false},
}
for _, tt := range tests {
ip := net.ParseIP(tt.ip)
if got := IsPrivateIP(ip); got != tt.isPrivate {
t.Errorf("IsPrivateIP(%s) = %v; want %v", tt.ip, got, tt.isPrivate)
}
}
}
func TestGetSafeHTTPClient_BlocksPrivateIPs(t *testing.T) {
client := GetSafeHTTPClient(1 * time.Second)
// Test 127.0.0.1
_, err := client.Get("http://127.0.0.1:12345")
if err == nil {
t.Error("Expected error for 127.0.0.1, got nil")
} else if !strings.Contains(err.Error(), "SSRF protection") {
t.Errorf("Expected 'SSRF protection' error for 127.0.0.1, got: %v", err)
}
// Test 0.0.0.0
_, err = client.Get("http://0.0.0.0:12345")
if err == nil {
t.Error("Expected error for 0.0.0.0, got nil")
} else if !strings.Contains(err.Error(), "SSRF protection") {
t.Errorf("Expected 'SSRF protection' error for 0.0.0.0, got: %v", err)
}
}


@@ -0,0 +1,5 @@
package stats
const (
DefaultHashesFile = "hashes.json"
)

internal/stats/stats.go Normal file

@@ -0,0 +1,154 @@
package stats
import (
"encoding/json"
"log"
"net/http"
"os"
"sync"
"sync/atomic"
"time"
"software-station/internal/models"
"golang.org/x/time/rate"
)
type Service struct {
HashesFile string
KnownHashes struct {
sync.RWMutex
Data map[string]*models.FingerprintData
}
GlobalStats struct {
sync.RWMutex
UniqueRequests map[string]bool
SuccessDownloads map[string]bool
BlockedRequests map[string]bool
LimitedRequests map[string]bool
WebRequests map[string]bool
TotalResponseTime time.Duration
TotalRequests int64
TotalBytes int64
StartTime time.Time
}
DownloadStats struct {
sync.RWMutex
Limiters map[string]*rate.Limiter
}
hashesDirty int32
stopChan chan struct{}
}
func NewService(hashesFile string) *Service {
s := &Service{
HashesFile: hashesFile,
stopChan: make(chan struct{}),
}
s.KnownHashes.Data = make(map[string]*models.FingerprintData)
s.GlobalStats.UniqueRequests = make(map[string]bool)
s.GlobalStats.SuccessDownloads = make(map[string]bool)
s.GlobalStats.BlockedRequests = make(map[string]bool)
s.GlobalStats.LimitedRequests = make(map[string]bool)
s.GlobalStats.WebRequests = make(map[string]bool)
s.GlobalStats.StartTime = time.Now()
s.DownloadStats.Limiters = make(map[string]*rate.Limiter)
return s
}
func (s *Service) Start() {
go func() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if atomic.CompareAndSwapInt32(&s.hashesDirty, 1, 0) {
s.FlushHashes()
}
case <-s.stopChan:
s.FlushHashes()
return
}
}
}()
}
func (s *Service) Stop() {
close(s.stopChan)
}
func (s *Service) LoadHashes() {
data, err := os.ReadFile(s.HashesFile)
if err != nil {
if !os.IsNotExist(err) {
log.Printf("Error reading hashes file: %v", err)
}
return
}
s.KnownHashes.Lock()
defer s.KnownHashes.Unlock()
if err := json.Unmarshal(data, &s.KnownHashes.Data); err != nil {
log.Printf("Error unmarshaling hashes: %v", err)
return
}
var total int64
for _, d := range s.KnownHashes.Data {
total += atomic.LoadInt64(&d.TotalBytes)
}
atomic.StoreInt64(&s.GlobalStats.TotalBytes, total)
}
func (s *Service) SaveHashes() {
atomic.StoreInt32(&s.hashesDirty, 1)
}
func (s *Service) FlushHashes() {
s.KnownHashes.RLock()
data, err := json.MarshalIndent(s.KnownHashes.Data, "", " ")
s.KnownHashes.RUnlock()
if err != nil {
log.Printf("Error marshaling hashes: %v", err)
return
}
if err := os.WriteFile(s.HashesFile, data, 0600); err != nil {
log.Printf("Error writing hashes file: %v", err)
}
}
func (s *Service) APIStatsHandler(w http.ResponseWriter, r *http.Request) {
s.GlobalStats.RLock()
defer s.GlobalStats.RUnlock()
avgResponse := time.Duration(0)
if s.GlobalStats.TotalRequests > 0 {
avgResponse = s.GlobalStats.TotalResponseTime / time.Duration(s.GlobalStats.TotalRequests)
}
totalBytes := atomic.LoadInt64(&s.GlobalStats.TotalBytes)
uptime := time.Since(s.GlobalStats.StartTime)
avgSpeed := float64(totalBytes) / uptime.Seconds()
status := "healthy"
if s.GlobalStats.TotalRequests > 0 && float64(len(s.GlobalStats.BlockedRequests))/float64(s.GlobalStats.TotalRequests) > 0.5 {
status = "unhealthy"
}
data := map[string]interface{}{
"total_unique_download_requests": len(s.GlobalStats.UniqueRequests),
"total_unique_success_downloads": len(s.GlobalStats.SuccessDownloads),
"total_unique_blocked": len(s.GlobalStats.BlockedRequests),
"total_unique_limited": len(s.GlobalStats.LimitedRequests),
"total_unique_web_requests": len(s.GlobalStats.WebRequests),
"avg_speed_bps": avgSpeed,
"avg_response_time": avgResponse.String(),
"uptime": uptime.String(),
"status": status,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(data); err != nil {
log.Printf("Error encoding stats: %v", err)
}
}
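
The service persists fingerprints lazily: SaveHashes only sets a dirty flag, and the goroutine launched by Start flushes to disk on a 10-second tick, with a final flush on Stop. A minimal lifecycle sketch:

    s := stats.NewService(stats.DefaultHashesFile)
    s.LoadHashes() // restore known fingerprints and the running byte total
    s.Start()      // background flush loop, driven by the dirty flag
    defer s.Stop() // final flush on shutdown

    s.SaveHashes() // marks dirty; the actual write happens on the next tick or at Stop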


@@ -0,0 +1,54 @@
package stats
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"software-station/internal/models"
"testing"
)
func TestStats(t *testing.T) {
tempFile := "test_hashes.json"
defer os.Remove(tempFile)
service := NewService(tempFile)
// Test SaveHashes
service.KnownHashes.Lock()
service.KnownHashes.Data["test"] = &models.FingerprintData{Known: true, TotalBytes: 100}
service.KnownHashes.Unlock()
service.FlushHashes()
if _, err := os.Stat(tempFile); os.IsNotExist(err) {
t.Error("SaveHashes did not create file")
}
// Test LoadHashes
service.KnownHashes.Lock()
delete(service.KnownHashes.Data, "test")
service.KnownHashes.Unlock()
service.LoadHashes()
service.KnownHashes.RLock()
if data, ok := service.KnownHashes.Data["test"]; !ok || data.TotalBytes != 100 {
t.Errorf("LoadHashes did not restore data correctly: %+v", data)
}
service.KnownHashes.RUnlock()
// Test APIStatsHandler
req := httptest.NewRequest("GET", "/api/stats", nil)
rr := httptest.NewRecorder()
service.APIStatsHandler(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected 200, got %d", rr.Code)
}
var stats map[string]interface{}
json.Unmarshal(rr.Body.Bytes(), &stats)
if stats["status"] != "healthy" {
t.Errorf("expected healthy status, got %v", stats["status"])
}
}