Add tests
All checks were successful
CI / build (push) Successful in 27s
Tests / test (push) Successful in 1m2s

This commit is contained in:
2025-12-26 22:34:48 -06:00
parent 53fb9fafb5
commit 8445ec934a
9 changed files with 731 additions and 0 deletions

View File

@@ -0,0 +1,60 @@
package api
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"git.quad4.io/Quad4-Software/osv-server/internal/config"
"git.quad4.io/Quad4-Software/osv-server/internal/indexer"
"git.quad4.io/Quad4-Software/osv-server/internal/osv"
"git.quad4.io/Quad4-Software/osv-server/internal/query"
)
func BenchmarkAPIServer(b *testing.B) {
dbPath := "../../data"
if _, err := os.Stat(filepath.Join(dbPath, indexer.IndexDBName)); os.IsNotExist(err) {
b.Skip("Real database not found at data/osv.db, skipping high-load benchmark")
}
cfg := &config.Config{
DataDir: dbPath,
UpdateInterval: 24 * time.Hour,
MaxRetries: 1,
RetryDelay: 1 * time.Second,
}
m, err := osv.NewManager(cfg)
if err != nil {
b.Fatalf("Failed to create manager: %v", err)
}
server := NewServer(m, cfg)
queryReq := query.QueryRequest{
Package: &query.PackageQuery{
Name: "requests",
Ecosystem: "PyPI",
},
Version: "2.28.0",
}
body, _ := json.Marshal(queryReq)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
req := httptest.NewRequest("POST", "/v1/query", bytes.NewBuffer(body))
rr := httptest.NewRecorder()
server.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
// We expect 200 since data/osv.db exists
}
}
})
}

144
internal/api/api_test.go Normal file
View File

@@ -0,0 +1,144 @@
package api
import (
"archive/zip"
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"git.quad4.io/Quad4-Software/osv-server/internal/config"
"git.quad4.io/Quad4-Software/osv-server/internal/osv"
"git.quad4.io/Quad4-Software/osv-server/internal/query"
)
func TestAPIServer(t *testing.T) {
tmpDir, _ := os.MkdirTemp("", "api-test-*")
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
DataDir: tmpDir,
UpdateInterval: 1 * time.Hour,
MaxRetries: 1,
RetryDelay: 10 * time.Millisecond,
EnableDownloadEndpoint: true,
}
m, _ := osv.NewManager(cfg)
server := NewServer(m, cfg)
// Test /health
req, _ := http.NewRequest("GET", "/health", nil)
rr := httptest.NewRecorder()
server.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
// Test /stats
req, _ = http.NewRequest("GET", "/stats", nil)
rr = httptest.NewRecorder()
server.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
// Test /v1/query (POST) - even if empty result
queryReq := query.QueryRequest{
Package: &query.PackageQuery{
Name: "test",
Ecosystem: "test",
},
}
body, _ := json.Marshal(queryReq)
req, _ = http.NewRequest("POST", "/v1/query", bytes.NewBuffer(body))
rr = httptest.NewRecorder()
server.ServeHTTP(rr, req)
// If data doesn't exist yet, it should return 503
if rr.Code != http.StatusServiceUnavailable {
t.Errorf("Expected status 503 (downloading), got %d", rr.Code)
}
// Create dummy data to test 200 responses
dataPath := filepath.Join(tmpDir, "all.zip")
_ = os.WriteFile(dataPath, []byte("fake zip"), 0644)
// Test /health again (should be healthy now)
req, _ = http.NewRequest("GET", "/health", nil)
rr = httptest.NewRecorder()
server.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
// Test /v1/query (GET)
req, _ = http.NewRequest("GET", "/v1/query", nil)
rr = httptest.NewRecorder()
server.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
// Test /v1/download (GET)
req, _ = http.NewRequest("GET", "/v1/download", nil)
rr = httptest.NewRecorder()
server.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
// Test /v1/query (POST) success
// We need to index some data first in the mock indexer
idx := m.GetIndexer()
f, _ := os.CreateTemp("", "v1-*.zip")
zw := zip.NewWriter(f)
fw, _ := zw.Create("test.json")
_, _ = fw.Write([]byte(`{"id":"GHSA-test","affected":[{"package":{"name":"test-pkg","ecosystem":"test-eco"},"versions":["1.0.0"]}]}`))
zw.Close()
f.Close()
_ = idx.IndexZip(f.Name())
os.Remove(f.Name())
queryReq = query.QueryRequest{
Package: &query.PackageQuery{
Name: "test-pkg",
Ecosystem: "test-eco",
},
Version: "1.0.0",
}
body, _ = json.Marshal(queryReq)
req, _ = http.NewRequest("POST", "/v1/query", bytes.NewBuffer(body))
rr = httptest.NewRecorder()
server.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
var queryResp query.QueryResponse
_ = json.NewDecoder(rr.Body).Decode(&queryResp)
if len(queryResp.Vulns) != 1 || queryResp.Vulns[0].ID != "GHSA-test" {
t.Errorf("Expected GHSA-test, got %+v", queryResp)
}
// Test /v1/querybatch (POST)
batchReq := struct {
Queries []query.QueryRequest `json:"queries"`
}{
Queries: []query.QueryRequest{queryReq},
}
body, _ = json.Marshal(batchReq)
req, _ = http.NewRequest("POST", "/v1/querybatch", bytes.NewBuffer(body))
rr = httptest.NewRecorder()
server.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
}

View File

@@ -0,0 +1,93 @@
package config
import (
"os"
"testing"
"time"
)
func TestLoad(t *testing.T) {
// Test defaults
cfg, err := LoadWithArgs([]string{})
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
if cfg.Port != DefaultPort {
t.Errorf("Expected port %d, got %d", DefaultPort, cfg.Port)
}
// Test environment variables
os.Setenv(EnvPort, "9090")
defer os.Unsetenv(EnvPort)
cfg, _ = LoadWithArgs([]string{})
if cfg.Port != 9090 {
t.Errorf("Expected port 9090 from env, got %d", cfg.Port)
}
}
func TestLoadFromFile(t *testing.T) {
tmpFile := "test_config.ini"
content := `
port = 9999
data_dir = /tmp/data
update_interval = 12h
osv_base_url = https://example.com
max_retries = 5
retry_delay = 10s
enable_download_endpoint = true
peers = peer1,peer2
`
if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil {
t.Fatalf("Failed to write test config file: %v", err)
}
defer os.Remove(tmpFile)
cfg, err := LoadWithArgs([]string{"-config", tmpFile})
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
if cfg.Port != 9999 {
t.Errorf("Expected port 9999, got %d", cfg.Port)
}
if cfg.DataDir != "/tmp/data" {
t.Errorf("Expected data_dir /tmp/data, got %s", cfg.DataDir)
}
if cfg.UpdateInterval != 12*time.Hour {
t.Errorf("Expected update_interval 12h, got %v", cfg.UpdateInterval)
}
if cfg.OSVBaseURL != "https://example.com" {
t.Errorf("Expected osv_base_url https://example.com, got %s", cfg.OSVBaseURL)
}
if cfg.MaxRetries != 5 {
t.Errorf("Expected max_retries 5, got %d", cfg.MaxRetries)
}
if cfg.RetryDelay != 10*time.Second {
t.Errorf("Expected retry_delay 10s, got %v", cfg.RetryDelay)
}
if !cfg.EnableDownloadEndpoint {
t.Errorf("Expected enable_download_endpoint true")
}
if len(cfg.Peers) != 2 || cfg.Peers[0] != "peer1" || cfg.Peers[1] != "peer2" {
t.Errorf("Unexpected peers: %v", cfg.Peers)
}
}
func TestSplitComma(t *testing.T) {
input := "url1, url2,url3 "
expected := []string{"url1", "url2", "url3"}
result := splitComma(input)
if len(result) != len(expected) {
t.Fatalf("Expected %d parts, got %d", len(expected), len(result))
}
for i, v := range result {
if v != expected[i] {
t.Errorf("Expected part %d to be %s, got %s", i, expected[i], v)
}
}
}

View File

@@ -0,0 +1,62 @@
package downloader
import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
)
func TestDownload(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Range") != "" {
w.WriteHeader(http.StatusPartialContent)
fmt.Fprintf(w, "world")
return
}
fmt.Fprintf(w, "hello")
}))
defer server.Close()
tmpDir, _ := os.MkdirTemp("", "dl-test-*")
defer os.RemoveAll(tmpDir)
dl := New(1, 100*time.Millisecond)
dest := filepath.Join(tmpDir, "test.txt")
// Full download
err := dl.Download(server.URL, dest)
if err != nil {
t.Fatalf("Download failed: %v", err)
}
data, _ := os.ReadFile(dest)
if string(data) != "hello" {
t.Errorf("Expected 'hello', got '%s'", string(data))
}
// Mocking resume would require more complex server logic, but let's test basic success
}
func TestGetLastModified(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Last-Modified", now.Format(http.TimeFormat))
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
dl := New(1, 100*time.Millisecond)
lastMod, err := dl.GetLastModified(server.URL)
if err != nil {
t.Fatalf("Failed to get last modified: %v", err)
}
if !lastMod.UTC().Equal(now) {
t.Errorf("Expected %v, got %v", now, lastMod.UTC())
}
}

View File

@@ -0,0 +1,46 @@
package indexer
import (
"archive/zip"
"os"
"path/filepath"
"testing"
)
func TestIndexer(t *testing.T) {
tmpDir, _ := os.MkdirTemp("", "idx-test-*")
defer os.RemoveAll(tmpDir)
idx, err := New(tmpDir)
if err != nil {
t.Fatalf("Failed to create indexer: %v", err)
}
defer idx.Close()
// Create mock zip
zipPath := filepath.Join(tmpDir, "test.zip")
f, _ := os.Create(zipPath)
zw := zip.NewWriter(f)
w, _ := zw.Create("vuln1.json")
_, _ = w.Write([]byte(`{"id":"VULN-1","summary":"Test summary","affected":[{"package":{"name":"pkg1","ecosystem":"eco1"},"versions":["1.0.0"]}]}`))
zw.Close()
f.Close()
// Test IndexZip
if err := idx.IndexZip(zipPath); err != nil {
t.Fatalf("Failed to index zip: %v", err)
}
// Verify data in DB
db := idx.GetDB()
var summary string
err = db.QueryRow("SELECT summary FROM vulnerabilities WHERE id = 'VULN-1'").Scan(&summary)
if err != nil {
t.Fatalf("Failed to query DB: %v", err)
}
if summary != "Test summary" {
t.Errorf("Expected 'Test summary', got '%s'", summary)
}
}

60
internal/osv/osv_test.go Normal file
View File

@@ -0,0 +1,60 @@
package osv
import (
"archive/zip"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"git.quad4.io/Quad4-Software/osv-server/internal/config"
)
func TestManager(t *testing.T) {
tmpDir, _ := os.MkdirTemp("", "osv-test-*")
defer os.RemoveAll(tmpDir)
// Mock server to serve the ZIP
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
f, _ := os.CreateTemp("", "mock-*.zip")
zw := zip.NewWriter(f)
fw, _ := zw.Create("v1.json")
_, _ = fw.Write([]byte(`{"id":"V1"}`))
zw.Close()
f.Close()
http.ServeFile(w, r, f.Name())
os.Remove(f.Name())
}))
defer server.Close()
cfg := &config.Config{
DataDir: tmpDir,
OSVBaseURL: server.URL,
UpdateInterval: 1 * time.Hour,
MaxRetries: 1,
RetryDelay: 10 * time.Millisecond,
}
m, err := NewManager(cfg)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
// Test downloadAll (internal but testable via side effects)
err = m.downloadAll()
if err != nil {
t.Fatalf("downloadAll failed: %v", err)
}
if !m.DataExists() {
t.Errorf("Expected data to exist")
}
// Stats are updated in Update(), but we can check if data is there
meta, err := m.GetMetadata()
if err != nil || meta == nil {
t.Errorf("Expected metadata to be set, got %v", err)
}
}

View File

@@ -0,0 +1,68 @@
package query
import (
"os"
"path/filepath"
"testing"
"git.quad4.io/Quad4-Software/osv-server/internal/indexer"
)
func BenchmarkQueryDatabase(b *testing.B) {
// Use real database if it exists, otherwise skip or use mock
dbPath := "../../data"
if _, err := os.Stat(filepath.Join(dbPath, indexer.IndexDBName)); os.IsNotExist(err) {
b.Skip("Real database not found at data/osv.db, skipping high-load benchmark")
}
idx, err := indexer.New(dbPath)
if err != nil {
b.Fatalf("Failed to open indexer: %v", err)
}
defer idx.Close()
// Representative queries
queries := []*QueryRequest{
{
Package: &PackageQuery{
Name: "requests",
Ecosystem: "PyPI",
},
Version: "2.28.0",
},
{
Package: &PackageQuery{
Name: "lodash",
Ecosystem: "npm",
},
Version: "4.17.20",
},
{
Commit: "6879efc2c1596d11a6a6ad296f800f2ee0a2c0db",
},
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
req := queries[i%len(queries)]
_, err := QueryDatabase(idx, req)
if err != nil {
// Don't fail the whole benchmark if one query fails, but log it
continue
}
i++
}
})
}
func BenchmarkCompareVersions(b *testing.B) {
v1 := "1.2.3"
v2 := "1.2.4"
b.ResetTimer()
for i := 0; i < b.N; i++ {
compareVersions(v1, v2)
}
}

View File

@@ -0,0 +1,122 @@
package query
import (
"archive/zip"
"os"
"path/filepath"
"testing"
"git.quad4.io/Quad4-Software/osv-server/internal/indexer"
)
func TestQueryDatabase(t *testing.T) {
tmpDir, _ := os.MkdirTemp("", "query-test-*")
defer os.RemoveAll(tmpDir)
idx, _ := indexer.New(tmpDir)
defer idx.Close()
// Index some test data
zipPath := filepath.Join(tmpDir, "test.zip")
f, _ := os.Create(zipPath)
zw := zip.NewWriter(f)
w, _ := zw.Create("v1.json")
_, _ = w.Write([]byte(`{"id":"V1","summary":"S1","affected":[{"package":{"name":"p1","ecosystem":"e1"},"versions":["1.0.0"]}]}`))
zw.Close()
f.Close()
_ = idx.IndexZip(zipPath)
// Test simple package query
req := &QueryRequest{
Package: &PackageQuery{
Name: "p1",
Ecosystem: "e1",
},
}
resp, err := QueryDatabase(idx, req)
if err != nil {
t.Fatalf("Query failed: %v", err)
}
if len(resp.Vulns) != 1 || resp.Vulns[0].ID != "V1" {
t.Errorf("Expected V1, got %+v", resp)
}
// Test version query
req.Version = "1.0.0"
resp, _ = QueryDatabase(idx, req)
if len(resp.Vulns) != 1 {
t.Errorf("Expected 1 vuln for version 1.0.0, got %d", len(resp.Vulns))
}
req.Version = "2.0.0"
resp, _ = QueryDatabase(idx, req)
if len(resp.Vulns) != 0 {
t.Errorf("Expected 0 vulns for version 2.0.0, got %d", len(resp.Vulns))
}
}
func TestMatches(t *testing.T) {
vuln := &Vulnerability{
ID: "V1",
Affected: []Affected{
{
Package: PackageInfo{Name: "p1", Ecosystem: "e1"},
Versions: []string{"1.0.0"},
Ranges: []Range{
{
Type: "SEMVER",
Events: []Event{
{Introduced: "2.0.0", Fixed: "2.1.0"},
},
},
{
Type: "GIT",
Events: []Event{
{Introduced: "commit1"},
},
},
},
},
},
}
tests := []struct {
req *QueryRequest
want bool
}{
{&QueryRequest{Package: &PackageQuery{Name: "p1", Ecosystem: "e1"}}, true},
{&QueryRequest{Package: &PackageQuery{Name: "p1", Ecosystem: "e1"}, Version: "1.0.0"}, true},
{&QueryRequest{Package: &PackageQuery{Name: "p1", Ecosystem: "e1"}, Version: "2.0.5"}, true},
{&QueryRequest{Package: &PackageQuery{Name: "p1", Ecosystem: "e1"}, Version: "2.1.0"}, false},
{&QueryRequest{Commit: "commit1"}, true},
{&QueryRequest{Commit: "unknown"}, false},
{&QueryRequest{Package: &PackageQuery{Name: "unknown", Ecosystem: "e1"}}, false},
}
for _, tt := range tests {
if got := matches(tt.req, vuln); got != tt.want {
t.Errorf("matches(%+v) = %v, want %v", tt.req, got, tt.want)
}
}
}
func TestCompareVersions(t *testing.T) {
tests := []struct {
v1, v2 string
want int
}{
{"1.0.0", "1.0.0", 0},
{"1.0.0", "1.1.0", -1},
{"1.1.0", "1.0.0", 1},
{"v1.0.0", "1.0.0", 0},
{"1.2", "1.10", -1}, // String comparison behavior in current implementation
}
for _, tt := range tests {
got := compareVersions(tt.v1, tt.v2)
if got != tt.want {
t.Errorf("compareVersions(%s, %s) = %d, want %d", tt.v1, tt.v2, got, tt.want)
}
}
}

View File

@@ -0,0 +1,76 @@
package storage
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestStorage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "storage-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
s, err := New(tmpDir)
if err != nil {
t.Fatalf("Failed to create storage: %v", err)
}
resource := "test.zip"
meta := &Metadata{
LastModified: time.Now().Truncate(time.Second),
Hash: "abc",
Size: 123,
DownloadedAt: time.Now().Truncate(time.Second),
}
// Test Save/Load Metadata
if err := s.SaveMetadata(resource, meta); err != nil {
t.Fatalf("Failed to save metadata: %v", err)
}
loaded, err := s.LoadMetadata(resource)
if err != nil {
t.Fatalf("Failed to load metadata: %v", err)
}
if loaded.Hash != meta.Hash || loaded.Size != meta.Size {
t.Errorf("Metadata mismatch: expected %+v, got %+v", meta, loaded)
}
// Test FileExists/GetFileSize
testFile := filepath.Join(tmpDir, "test.file")
if err := os.WriteFile(testFile, []byte("hello"), 0644); err != nil {
t.Fatalf("Failed to write test file: %v", err)
}
if !s.FileExists(testFile) {
t.Errorf("Expected file to exist")
}
size, err := s.GetFileSize(testFile)
if err != nil || size != 5 {
t.Errorf("Expected size 5, got %d (err: %v)", size, err)
}
// Test ComputeHash
hash, err := s.ComputeHash(testFile)
if err != nil || hash == "" {
t.Errorf("Failed to compute hash: %v", err)
}
// Test IsStale
stale, err := s.IsStale(resource, "abc", meta.LastModified, 1*time.Hour)
if err != nil || stale {
t.Errorf("Expected not stale, got stale: %v (err: %v)", stale, err)
}
stale, _ = s.IsStale(resource, "wrong", meta.LastModified, 1*time.Hour)
if !stale {
t.Errorf("Expected stale due to wrong hash")
}
}