- Introduced a bot blocker initialized with a user-agent blocklist.
- Updated the security middleware to use the new bot blocker in both the main and test files.
- Improved error handling for API requests so that a proper 404 response is returned when content is not found.
280 lines
9.0 KiB
Go
280 lines
9.0 KiB
Go
package main
|
|
|
|
import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"software-station/internal/api"
	"software-station/internal/config"
	"software-station/internal/gitea"
	"software-station/internal/models"
	"software-station/internal/security"
	"software-station/internal/stats"

	"github.com/go-chi/chi/v5"
)
|
|
|
|
func TestMainHandlers(t *testing.T) {
|
|
os.Setenv("ALLOW_LOOPBACK", "true")
|
|
defer os.Unsetenv("ALLOW_LOOPBACK")
|
|
|
|
// Setup mock software.txt
|
|
configPath = "test_software.txt"
|
|
os.WriteFile(configPath, []byte("Quad4-Software/software-station"), 0644)
|
|
defer os.Remove(configPath)
|
|
defer os.Remove("hashes.json")
|
|
os.RemoveAll(".cache") // Clear cache for tests
|
|
|
|
// Mock Gitea Server
|
|
var mockGitea *httptest.Server
|
|
mockGitea = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.Contains(r.URL.Path, "releases") {
|
|
w.Write([]byte(fmt.Sprintf(`[{"tag_name": "v1.0.0", "body": "Release notes for v1.0.0", "created_at": "2025-12-27T10:00:00Z", "assets": [{"name": "test.exe", "size": 100, "browser_download_url": "%s/test.exe"}, {"name": "SHA256SUMS", "size": 50, "browser_download_url": "%s/SHA256SUMS"}]}]`, mockGitea.URL, mockGitea.URL)))
|
|
} else if strings.Contains(r.URL.Path, "SHA256SUMS") {
|
|
w.Write([]byte(`b380cbb6489437721e1674e3b2736c699f5ffe8827e83e749b4f72417ea7e12c test.exe`))
|
|
} else {
|
|
w.Write([]byte(`{"description": "Test Repo", "topics": ["test", "mock"], "licenses": ["MIT"], "private": false, "avatar_url": "https://example.com/logo.png"}`))
|
|
}
|
|
}))
|
|
defer mockGitea.Close()
|
|
|
|
giteaServer = mockGitea.URL
|
|
statsService := stats.NewService("test-hashes.json")
|
|
botBlocker := security.NewBotBlocker("")
|
|
initialSoftware := config.LoadSoftware(configPath, giteaServer, "")
|
|
apiServer := api.NewServer("", initialSoftware, statsService)
|
|
|
|
r := chi.NewRouter()
|
|
r.Use(security.SecurityMiddleware(statsService, botBlocker))
|
|
r.Get("/api/software", apiServer.APISoftwareHandler)
|
|
r.Get("/api/stats", statsService.APIStatsHandler)
|
|
r.Get("/api/download", apiServer.DownloadProxyHandler)
|
|
r.Get("/api/avatar", apiServer.AvatarHandler)
|
|
r.Get("/api/rss", apiServer.RSSHandler)
|
|
|
|
t.Run("API Software", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/api/software", nil)
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d", rr.Code)
|
|
}
|
|
|
|
var sw []models.Software
|
|
json.Unmarshal(rr.Body.Bytes(), &sw)
|
|
if len(sw) == 0 || sw[0].Name != "software-station" {
|
|
t.Errorf("unexpected software list: %v", sw)
|
|
}
|
|
|
|
if sw[0].Releases[0].Body != "Release notes for v1.0.0" {
|
|
t.Errorf("expected release body, got %s", sw[0].Releases[0].Body)
|
|
}
|
|
})
|
|
|
|
t.Run("RSS Feed", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/api/rss", nil)
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d", rr.Code)
|
|
}
|
|
|
|
body := rr.Body.String()
|
|
if !strings.Contains(body, "<rss") || !strings.Contains(body, "software-station v1.0.0") {
|
|
t.Errorf("RSS feed content missing: %s", body)
|
|
}
|
|
if !strings.Contains(body, "Release notes for v1.0.0") {
|
|
t.Errorf("RSS feed missing release notes: %s", body)
|
|
}
|
|
|
|
// Test per-software RSS
|
|
req = httptest.NewRequest("GET", "/api/rss?software=software-station", nil)
|
|
rr = httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d", rr.Code)
|
|
}
|
|
if !strings.Contains(rr.Body.String(), "software-station Updates") {
|
|
t.Error("Specific software RSS title missing")
|
|
}
|
|
|
|
// Test non-existent software RSS
|
|
req = httptest.NewRequest("GET", "/api/rss?software=non-existent", nil)
|
|
rr = httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
if !strings.Contains(rr.Body.String(), "<item>") == false {
|
|
// Should be empty channel
|
|
}
|
|
})
|
|
|
|
t.Run("Avatar Proxy & Cache", func(t *testing.T) {
|
|
avatarData := []byte("fake-image-data")
|
|
avatarServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "image/png")
|
|
w.Write(avatarData)
|
|
}))
|
|
defer avatarServer.Close()
|
|
|
|
hash := apiServer.RegisterURL(avatarServer.URL)
|
|
req := httptest.NewRequest("GET", fmt.Sprintf("/api/avatar?id=%s", hash), nil)
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d", rr.Code)
|
|
}
|
|
if rr.Header().Get("Content-Type") != "image/png" {
|
|
t.Errorf("expected image/png, got %s", rr.Header().Get("Content-Type"))
|
|
}
|
|
|
|
// Verify it was cached
|
|
cachePath := filepath.Join(".cache/avatars", hash)
|
|
if _, err := os.Stat(cachePath); os.IsNotExist(err) {
|
|
t.Error("avatar was not cached to disk")
|
|
}
|
|
})
|
|
|
|
t.Run("API Stats", func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/api/stats", nil)
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d", rr.Code)
|
|
}
|
|
})
|
|
|
|
t.Run("Download Proxy Throttling", func(t *testing.T) {
|
|
assetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write(make([]byte, 1024))
|
|
}))
|
|
defer assetServer.Close()
|
|
|
|
hash := apiServer.RegisterURL(assetServer.URL)
|
|
req := httptest.NewRequest("GET", fmt.Sprintf("/api/download?id=%s", hash), nil)
|
|
req.Header.Set("User-Agent", "aria2/1.35.0")
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d", rr.Code)
|
|
}
|
|
})
|
|
|
|
t.Run("Download Range Request", func(t *testing.T) {
|
|
content := "0123456789"
|
|
assetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Header.Get("Range") == "bytes=2-5" {
|
|
w.Header().Set("Content-Range", "bytes 2-5/10")
|
|
w.WriteHeader(http.StatusPartialContent)
|
|
w.Write([]byte(content[2:6]))
|
|
} else {
|
|
w.Write([]byte(content))
|
|
}
|
|
}))
|
|
defer assetServer.Close()
|
|
|
|
hash := apiServer.RegisterURL(assetServer.URL)
|
|
req := httptest.NewRequest("GET", fmt.Sprintf("/api/download?id=%s", hash), nil)
|
|
req.Header.Set("Range", "bytes=2-5")
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusPartialContent {
|
|
t.Errorf("expected 206, got %d", rr.Code)
|
|
}
|
|
})
|
|
|
|
t.Run("Security - Path Traversal", func(t *testing.T) {
|
|
patterns := []string{"/.git/config", "/etc/passwd"}
|
|
for _, p := range patterns {
|
|
req := httptest.NewRequest("GET", p, nil)
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusForbidden {
|
|
t.Errorf("expected 403 for %s, got %d", p, rr.Code)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("Security - XSS in API", func(t *testing.T) {
|
|
malicious := []models.Software{{Name: "<script>alert(1)</script>"}}
|
|
testStatsService := stats.NewService("test-hashes.json")
|
|
srv := api.NewServer("", malicious, testStatsService)
|
|
req := httptest.NewRequest("GET", "/api/software", nil)
|
|
rr := httptest.NewRecorder()
|
|
srv.APISoftwareHandler(rr, req)
|
|
if !strings.Contains(rr.Body.String(), "script") {
|
|
t.Error("XSS payload missing")
|
|
}
|
|
})
|
|
|
|
t.Run("Download Proxy - Speed Downloader", func(t *testing.T) {
|
|
assetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write(make([]byte, 100))
|
|
}))
|
|
defer assetServer.Close()
|
|
hash := apiServer.RegisterURL(assetServer.URL)
|
|
req := httptest.NewRequest("GET", fmt.Sprintf("/api/download?id=%s", hash), nil)
|
|
req.Header.Set("User-Agent", "aria2/1.35.0")
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d", rr.Code)
|
|
}
|
|
})
|
|
|
|
t.Run("Static Files", func(t *testing.T) {
|
|
// Mock the filesystem behavior for the static handler
|
|
// We'll just test if the router handles unknown paths correctly
|
|
req := httptest.NewRequest("GET", "/unknown-path", nil)
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
// It should try to serve index.html if file not found, but since build might be empty in tests, it might 500 or 404
|
|
if rr.Code == http.StatusOK || rr.Code == http.StatusNotFound || rr.Code == http.StatusInternalServerError {
|
|
t.Logf("Static handler returned %d", rr.Code)
|
|
}
|
|
})
|
|
|
|
t.Run("API Stats - Status unhealthy", func(t *testing.T) {
|
|
statsService.GlobalStats.Lock()
|
|
statsService.GlobalStats.TotalRequests = 100
|
|
statsService.GlobalStats.BlockedRequests["bad"] = true
|
|
// Make it more than 50% blocked
|
|
for i := 0; i < 60; i++ {
|
|
statsService.GlobalStats.BlockedRequests[fmt.Sprintf("bad%d", i)] = true
|
|
}
|
|
statsService.GlobalStats.Unlock()
|
|
|
|
req := httptest.NewRequest("GET", "/api/stats", nil)
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
|
|
var s map[string]interface{}
|
|
json.Unmarshal(rr.Body.Bytes(), &s)
|
|
if s["status"] != "unhealthy" {
|
|
t.Errorf("expected unhealthy status, got %v", s["status"])
|
|
}
|
|
|
|
// Reset
|
|
statsService.GlobalStats.Lock()
|
|
statsService.GlobalStats.BlockedRequests = make(map[string]bool)
|
|
statsService.GlobalStats.TotalRequests = 0
|
|
statsService.GlobalStats.Unlock()
|
|
})
|
|
}
|
|
|
|
func TestOSDetection(t *testing.T) {
|
|
if gitea.DetectOS("test.exe") != models.OSWindows {
|
|
t.Error("expected windows")
|
|
}
|
|
}
|