Files
software-station/main_test.go
Sudo-Ivan d954d7fe4b
All checks were successful
CI / build (push) Successful in 1m15s
renovate / renovate (push) Successful in 1m19s
Update security middleware and update Docker configurations
- Added a new parameter to the SecurityMiddleware function to allow custom handling of forbidden requests.
- Updated Docker configurations to enable asset caching for improved performance.
- Bumped version number in the Dockerfile to 0.3.0 and refined the image description for clarity.
- Adjusted various frontend components and error handling to support new rate limiting and forbidden access messages.
- Improved documentation in multiple languages to reflect recent changes in features and security measures.
2025-12-27 21:53:10 -06:00

315 lines
10 KiB
Go

package main
import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"sync/atomic"
	"testing"

	"software-station/internal/api"
	"software-station/internal/config"
	"software-station/internal/gitea"
	"software-station/internal/models"
	"software-station/internal/security"
	"software-station/internal/stats"

	"github.com/go-chi/chi/v5"
)
// TestMainHandlers is an integration test for the HTTP API. It builds a chi
// router with the security middleware and the /api/* handlers, backed by a
// mock Gitea server, and exercises each endpoint through subtests.
//
// Side effects: it writes test_software.txt, clears the .cache directory in
// the working directory, and temporarily overrides the package-level
// configPath and giteaServer variables (restored via t.Cleanup).
func TestMainHandlers(t *testing.T) {
	// NOTE(review): presumably lets the download/avatar proxies reach the
	// loopback httptest servers used below — confirm in the security/api
	// packages. t.Setenv restores the previous value when the test ends.
	t.Setenv("ALLOW_LOOPBACK", "true")

	// Restore package-level globals so later tests see the original values.
	prevConfigPath, prevGiteaServer := configPath, giteaServer
	t.Cleanup(func() {
		configPath, giteaServer = prevConfigPath, prevGiteaServer
	})

	// Setup mock software.txt
	configPath = "test_software.txt"
	if err := os.WriteFile(configPath, []byte("Quad4-Software/software-station"), 0644); err != nil {
		t.Fatalf("writing %s: %v", configPath, err)
	}
	defer os.Remove(configPath)

	// The stats services in this test are constructed with test-hashes.json,
	// so remove that file as well. NOTE(review): confirm whether anything here
	// still writes hashes.json; removing it may delete a real file in the
	// package directory.
	defer os.Remove("hashes.json")
	defer os.Remove("test-hashes.json")
	os.RemoveAll(".cache") // start from an empty on-disk cache

	// Mock Gitea: release listings, a SHA256SUMS file, and repository metadata
	// for any other path. The releases payload embeds mockGitea.URL, so the
	// variable is declared before the handler closure captures it.
	var mockGitea *httptest.Server
	mockGitea = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case strings.Contains(r.URL.Path, "releases"):
			fmt.Fprintf(w, `[{"tag_name": "v1.0.0", "body": "Release notes for v1.0.0", "created_at": "2025-12-27T10:00:00Z", "assets": [{"name": "test.exe", "size": 100, "browser_download_url": "%s/test.exe"}, {"name": "SHA256SUMS", "size": 50, "browser_download_url": "%s/SHA256SUMS"}]}]`, mockGitea.URL, mockGitea.URL)
		case strings.Contains(r.URL.Path, "SHA256SUMS"):
			w.Write([]byte(`b380cbb6489437721e1674e3b2736c699f5ffe8827e83e749b4f72417ea7e12c test.exe`))
		default:
			w.Write([]byte(`{"description": "Test Repo", "topics": ["test", "mock"], "licenses": ["MIT"], "private": false, "avatar_url": "https://example.com/logo.png"}`))
		}
	}))
	defer mockGitea.Close()
	giteaServer = mockGitea.URL

	statsService := stats.NewService("test-hashes.json")
	botBlocker := security.NewBotBlocker("")
	initialSoftware := config.LoadSoftware(configPath, giteaServer, "")
	apiServer := api.NewServer("", initialSoftware, statsService, true)

	r := chi.NewRouter()
	r.Use(security.SecurityMiddleware(statsService, botBlocker, nil))
	r.Get("/api/software", apiServer.APISoftwareHandler)
	r.Get("/api/stats", statsService.APIStatsHandler)
	r.Get("/api/download", apiServer.DownloadProxyHandler)
	r.Get("/api/avatar", apiServer.AvatarHandler)
	r.Get("/api/rss", apiServer.RSSHandler)

	t.Run("API Software", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "/api/software", nil)
		rr := httptest.NewRecorder()
		r.ServeHTTP(rr, req)
		if rr.Code != http.StatusOK {
			t.Fatalf("expected 200, got %d", rr.Code)
		}
		var sw []models.Software
		if err := json.Unmarshal(rr.Body.Bytes(), &sw); err != nil {
			t.Fatalf("decoding software list: %v", err)
		}
		// Fatal (not Error) so the indexing below cannot panic on an empty list.
		if len(sw) == 0 || sw[0].Name != "software-station" {
			t.Fatalf("unexpected software list: %v", sw)
		}
		if len(sw[0].Releases) == 0 {
			t.Fatal("expected at least one release")
		}
		if got := sw[0].Releases[0].Body; got != "Release notes for v1.0.0" {
			t.Errorf("expected release body, got %s", got)
		}
	})

	t.Run("RSS Feed", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "/api/rss", nil)
		rr := httptest.NewRecorder()
		r.ServeHTTP(rr, req)
		if rr.Code != http.StatusOK {
			t.Errorf("expected 200, got %d", rr.Code)
		}
		body := rr.Body.String()
		if !strings.Contains(body, "<rss") || !strings.Contains(body, "software-station v1.0.0") {
			t.Errorf("RSS feed content missing: %s", body)
		}
		if !strings.Contains(body, "Release notes for v1.0.0") {
			t.Errorf("RSS feed missing release notes: %s", body)
		}

		// Test per-software RSS
		req = httptest.NewRequest(http.MethodGet, "/api/rss?software=software-station", nil)
		rr = httptest.NewRecorder()
		r.ServeHTTP(rr, req)
		if rr.Code != http.StatusOK {
			t.Errorf("expected 200, got %d", rr.Code)
		}
		if !strings.Contains(rr.Body.String(), "software-station Updates") {
			t.Error("Specific software RSS title missing")
		}

		// Unknown software should yield an empty channel (no <item> elements).
		req = httptest.NewRequest(http.MethodGet, "/api/rss?software=non-existent", nil)
		rr = httptest.NewRecorder()
		r.ServeHTTP(rr, req)
		if body := rr.Body.String(); strings.Contains(body, "<item>") {
			t.Errorf("expected no items for unknown software, got: %s", body)
		}
	})

	t.Run("Avatar Proxy & Cache", func(t *testing.T) {
		avatarData := []byte("fake-image-data")
		avatarServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Type", "image/png")
			w.Write(avatarData)
		}))
		defer avatarServer.Close()

		hash := apiServer.RegisterURL(avatarServer.URL)
		req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/avatar?id=%s", hash), nil)
		rr := httptest.NewRecorder()
		r.ServeHTTP(rr, req)
		if rr.Code != http.StatusOK {
			t.Errorf("expected 200, got %d", rr.Code)
		}
		if ct := rr.Header().Get("Content-Type"); ct != "image/png" {
			t.Errorf("expected image/png, got %s", ct)
		}

		// Verify it was cached
		cachePath := filepath.Join(".cache", "avatars", hash)
		if _, err := os.Stat(cachePath); err != nil {
			t.Errorf("avatar was not cached to disk: %v", err)
		}
	})

	t.Run("API Stats", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "/api/stats", nil)
		rr := httptest.NewRecorder()
		r.ServeHTTP(rr, req)
		if rr.Code != http.StatusOK {
			t.Errorf("expected 200, got %d", rr.Code)
		}
	})

	t.Run("Download Proxy Throttling", func(t *testing.T) {
		assetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Write(make([]byte, 1024))
		}))
		defer assetServer.Close()

		hash := apiServer.RegisterURL(assetServer.URL)
		req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/download?id=%s", hash), nil)
		req.Header.Set("User-Agent", "aria2/1.35.0")
		rr := httptest.NewRecorder()
		r.ServeHTTP(rr, req)
		if rr.Code != http.StatusOK {
			t.Errorf("expected 200, got %d", rr.Code)
		}
	})

	t.Run("Download Range Request", func(t *testing.T) {
		content := "0123456789"
		assetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.Header.Get("Range") == "bytes=2-5" {
				w.Header().Set("Content-Range", "bytes 2-5/10")
				w.WriteHeader(http.StatusPartialContent)
				w.Write([]byte(content[2:6]))
				return
			}
			w.Write([]byte(content))
		}))
		defer assetServer.Close()

		hash := apiServer.RegisterURL(assetServer.URL)
		req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/download?id=%s", hash), nil)
		req.Header.Set("Range", "bytes=2-5")
		rr := httptest.NewRecorder()
		r.ServeHTTP(rr, req)
		if rr.Code != http.StatusPartialContent {
			t.Errorf("expected 206, got %d", rr.Code)
		}
	})

	t.Run("Asset Caching Integration", func(t *testing.T) {
		content := []byte("integration-cache-test")
		// The handler runs on the httptest server's goroutine, so the counter
		// is accessed atomically to stay race-free under `go test -race`.
		var upstreamCalls int32
		assetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			atomic.AddInt32(&upstreamCalls, 1)
			w.Write(content)
		}))
		defer assetServer.Close()

		hash := apiServer.RegisterURL(assetServer.URL)
		url := fmt.Sprintf("/api/download?id=%s", hash)

		// First call
		rr := httptest.NewRecorder()
		r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil))
		if rr.Code != http.StatusOK {
			t.Errorf("first call: expected 200, got %d", rr.Code)
		}

		// Second call (should be cached)
		rr = httptest.NewRecorder()
		r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil))
		if rr.Code != http.StatusOK {
			t.Errorf("second call: expected 200, got %d", rr.Code)
		}
		if n := atomic.LoadInt32(&upstreamCalls); n != 1 {
			t.Errorf("expected 1 upstream call, got %d", n)
		}
		if rr.Body.String() != string(content) {
			t.Errorf("expected content %q, got %q", string(content), rr.Body.String())
		}
	})

	t.Run("Security - Path Traversal", func(t *testing.T) {
		for _, p := range []string{"/.git/config", "/etc/passwd"} {
			req := httptest.NewRequest(http.MethodGet, p, nil)
			rr := httptest.NewRecorder()
			r.ServeHTTP(rr, req)
			if rr.Code != http.StatusForbidden {
				t.Errorf("expected 403 for %s, got %d", p, rr.Code)
			}
		}
	})

	t.Run("Security - XSS in API", func(t *testing.T) {
		// NOTE(review): this only checks that the payload survives the round
		// trip. It passes whether or not "<" and ">" are escaped, because
		// "script" appears in both forms; tighten once the intended encoding
		// behavior of APISoftwareHandler is confirmed.
		malicious := []models.Software{{Name: "<script>alert(1)</script>"}}
		testStatsService := stats.NewService("test-hashes.json")
		srv := api.NewServer("", malicious, testStatsService, true)

		req := httptest.NewRequest(http.MethodGet, "/api/software", nil)
		rr := httptest.NewRecorder()
		srv.APISoftwareHandler(rr, req)
		if !strings.Contains(rr.Body.String(), "script") {
			t.Error("XSS payload missing")
		}
	})

	t.Run("Download Proxy - Speed Downloader", func(t *testing.T) {
		assetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Write(make([]byte, 100))
		}))
		defer assetServer.Close()

		hash := apiServer.RegisterURL(assetServer.URL)
		req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/download?id=%s", hash), nil)
		req.Header.Set("User-Agent", "aria2/1.35.0")
		rr := httptest.NewRecorder()
		r.ServeHTTP(rr, req)
		if rr.Code != http.StatusOK {
			t.Errorf("expected 200, got %d", rr.Code)
		}
	})

	t.Run("Static Files", func(t *testing.T) {
		// No static handler is registered on r (only /api/* routes), so this
		// only exercises the middleware and the router's unmatched-path
		// behavior. The status is logged rather than asserted.
		req := httptest.NewRequest(http.MethodGet, "/unknown-path", nil)
		rr := httptest.NewRecorder()
		r.ServeHTTP(rr, req)
		t.Logf("Static handler returned %d", rr.Code)
	})

	t.Run("API Stats - Status unhealthy", func(t *testing.T) {
		statsService.GlobalStats.Lock()
		statsService.GlobalStats.TotalRequests = 100
		statsService.GlobalStats.BlockedRequests["bad"] = true
		// Make it more than 50% blocked
		for i := 0; i < 60; i++ {
			statsService.GlobalStats.BlockedRequests[fmt.Sprintf("bad%d", i)] = true
		}
		statsService.GlobalStats.Unlock()

		// Reset even if an assertion below aborts the subtest.
		defer func() {
			statsService.GlobalStats.Lock()
			statsService.GlobalStats.BlockedRequests = make(map[string]bool)
			statsService.GlobalStats.TotalRequests = 0
			statsService.GlobalStats.Unlock()
		}()

		req := httptest.NewRequest(http.MethodGet, "/api/stats", nil)
		rr := httptest.NewRecorder()
		r.ServeHTTP(rr, req)

		var s map[string]interface{}
		if err := json.Unmarshal(rr.Body.Bytes(), &s); err != nil {
			t.Fatalf("decoding stats (status %d): %v", rr.Code, err)
		}
		if s["status"] != "unhealthy" {
			t.Errorf("expected unhealthy status, got %v", s["status"])
		}
	})
}
func TestOSDetection(t *testing.T) {
if gitea.DetectOS("test.exe") != models.OSWindows {
t.Error("expected windows")
}
}