144 lines
3.8 KiB
Go
144 lines
3.8 KiB
Go
package api
|
|
|
|
import (
|
|
"archive/zip"
|
|
"bytes"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.quad4.io/quad4-software/osv-server/internal/config"
|
|
"git.quad4.io/quad4-software/osv-server/internal/osv"
|
|
"git.quad4.io/quad4-software/osv-server/internal/query"
|
|
)
|
|
|
|
func TestAPIServer(t *testing.T) {
|
|
tmpDir, _ := os.MkdirTemp("", "api-test-*")
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
DataDir: tmpDir,
|
|
UpdateInterval: 1 * time.Hour,
|
|
MaxRetries: 1,
|
|
RetryDelay: 10 * time.Millisecond,
|
|
EnableDownloadEndpoint: true,
|
|
}
|
|
|
|
m, _ := osv.NewManager(cfg)
|
|
server := NewServer(m, cfg)
|
|
|
|
// Test /health
|
|
req, _ := http.NewRequest("GET", "/health", nil)
|
|
rr := httptest.NewRecorder()
|
|
server.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
|
|
// Test /stats
|
|
req, _ = http.NewRequest("GET", "/stats", nil)
|
|
rr = httptest.NewRecorder()
|
|
server.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
|
|
// Test /v1/query (POST) - even if empty result
|
|
queryReq := query.QueryRequest{
|
|
Package: &query.PackageQuery{
|
|
Name: "test",
|
|
Ecosystem: "test",
|
|
},
|
|
}
|
|
body, _ := json.Marshal(queryReq)
|
|
req, _ = http.NewRequest("POST", "/v1/query", bytes.NewBuffer(body))
|
|
rr = httptest.NewRecorder()
|
|
server.ServeHTTP(rr, req)
|
|
|
|
// If data doesn't exist yet, it should return 503
|
|
if rr.Code != http.StatusServiceUnavailable {
|
|
t.Errorf("Expected status 503 (downloading), got %d", rr.Code)
|
|
}
|
|
|
|
// Create dummy data to test 200 responses
|
|
dataPath := filepath.Join(tmpDir, "all.zip")
|
|
_ = os.WriteFile(dataPath, []byte("fake zip"), 0644)
|
|
|
|
// Test /health again (should be healthy now)
|
|
req, _ = http.NewRequest("GET", "/health", nil)
|
|
rr = httptest.NewRecorder()
|
|
server.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
|
|
// Test /v1/query (GET)
|
|
req, _ = http.NewRequest("GET", "/v1/query", nil)
|
|
rr = httptest.NewRecorder()
|
|
server.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
|
|
// Test /v1/download (GET)
|
|
req, _ = http.NewRequest("GET", "/v1/download", nil)
|
|
rr = httptest.NewRecorder()
|
|
server.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
|
|
// Test /v1/query (POST) success
|
|
// We need to index some data first in the mock indexer
|
|
idx := m.GetIndexer()
|
|
f, _ := os.CreateTemp("", "v1-*.zip")
|
|
zw := zip.NewWriter(f)
|
|
fw, _ := zw.Create("test.json")
|
|
_, _ = fw.Write([]byte(`{"id":"GHSA-test","affected":[{"package":{"name":"test-pkg","ecosystem":"test-eco"},"versions":["1.0.0"]}]}`))
|
|
zw.Close()
|
|
f.Close()
|
|
_ = idx.IndexZip(f.Name())
|
|
os.Remove(f.Name())
|
|
|
|
queryReq = query.QueryRequest{
|
|
Package: &query.PackageQuery{
|
|
Name: "test-pkg",
|
|
Ecosystem: "test-eco",
|
|
},
|
|
Version: "1.0.0",
|
|
}
|
|
body, _ = json.Marshal(queryReq)
|
|
req, _ = http.NewRequest("POST", "/v1/query", bytes.NewBuffer(body))
|
|
rr = httptest.NewRecorder()
|
|
server.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
|
|
var queryResp query.QueryResponse
|
|
_ = json.NewDecoder(rr.Body).Decode(&queryResp)
|
|
if len(queryResp.Vulns) != 1 || queryResp.Vulns[0].ID != "GHSA-test" {
|
|
t.Errorf("Expected GHSA-test, got %+v", queryResp)
|
|
}
|
|
|
|
// Test /v1/querybatch (POST)
|
|
batchReq := struct {
|
|
Queries []query.QueryRequest `json:"queries"`
|
|
}{
|
|
Queries: []query.QueryRequest{queryReq},
|
|
}
|
|
body, _ = json.Marshal(batchReq)
|
|
req, _ = http.NewRequest("POST", "/v1/querybatch", bytes.NewBuffer(body))
|
|
rr = httptest.NewRecorder()
|
|
server.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
}
|
|
}
|