Files

293 lines
6.8 KiB
Go

package query
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"git.quad4.io/quad4-software/osv-server/internal/indexer"
)
// QueryRequest describes a single vulnerability lookup.
//
// QueryDatabase uses Commit when it is non-empty and ignores the other
// fields in that case; otherwise it uses Package. Version is only consulted
// together with Package. A request with neither Commit nor Package matches
// nothing.
type QueryRequest struct {
	Package *PackageQuery `json:"package,omitempty"`
	Version string        `json:"version,omitempty"`
	Commit  string        `json:"commit,omitempty"`
}
// PackageQuery identifies the package being queried. QueryDatabase compares
// both Name and Ecosystem with SQL equality against the indexed affected
// table, so both must be supplied exactly as indexed.
type PackageQuery struct {
	Name      string `json:"name"`
	Ecosystem string `json:"ecosystem"`
}
// Vulnerability is a decoded OSV vulnerability record.
//
// QueryDatabase reads records from the vulnerabilities.vuln_json column,
// decodes them into this type and returns the decoded values. Any OSV field
// not declared here is therefore silently dropped from responses. The
// commonly used top-level OSV fields are declared explicitly for that
// reason; "credits" is still not modeled.
type Vulnerability struct {
	SchemaVersion string `json:"schema_version,omitempty"`
	ID            string `json:"id"`
	Summary       string `json:"summary,omitempty"`
	Details       string `json:"details,omitempty"`
	// Aliases and Related hold IDs of the same or related issues in other
	// databases (e.g. CVE IDs).
	Aliases  []string `json:"aliases,omitempty"`
	Related  []string `json:"related,omitempty"`
	Upstream []string `json:"upstream,omitempty"`
	// Timestamps are kept as the raw strings found in the stored record.
	Modified  string `json:"modified,omitempty"`
	Published string `json:"published,omitempty"`
	Withdrawn string `json:"withdrawn,omitempty"`
	// Severity entries pair a scoring system name with its score string.
	Severity []struct {
		Type  string `json:"type"`
		Score string `json:"score"`
	} `json:"severity,omitempty"`
	Affected   []Affected  `json:"affected,omitempty"`
	References []Reference `json:"references,omitempty"`
	// DatabaseSpecific is kept as a generic map; JSON numbers inside it
	// decode as float64.
	DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"`
}
// Affected describes how one package is affected by a vulnerability.
//
// matchesVersion treats a version as affected when it appears verbatim in
// Versions or falls inside one of the ECOSYSTEM or SEMVER entries of Ranges.
type Affected struct {
	Package  PackageInfo `json:"package"`
	Ranges   []Range     `json:"ranges,omitempty"`
	Versions []string    `json:"versions,omitempty"`
	// EcosystemSpecific is kept as a generic map and is not interpreted here.
	EcosystemSpecific map[string]interface{} `json:"ecosystem_specific,omitempty"`
}
// PackageInfo names the package an Affected entry applies to. matches
// compares Name and Ecosystem exactly against the query; PURL is carried
// through in responses but is not used for matching in this file.
type PackageInfo struct {
	Name      string `json:"name"`
	Ecosystem string `json:"ecosystem"`
	PURL      string `json:"purl,omitempty"`
}
// Range is one affected range of a package: a versioning scheme (Type, e.g.
// "GIT", "ECOSYSTEM" or "SEMVER") plus the events bounding the affected span.
//
// Repo names the source repository and is what makes the commit hashes of a
// GIT range meaningful; DatabaseSpecific carries per-database extras. Both
// are declared so they survive the decode/encode round trip of a record.
type Range struct {
	Type             string                 `json:"type"`
	Repo             string                 `json:"repo,omitempty"`
	Events           []Event                `json:"events"`
	DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"`
}
// Event is one boundary of a Range. In OSV records each event sets exactly
// one of these fields.
//
// LastAffected (OSV "last_affected") is an inclusive upper bound used by
// records that have no fixed version. It is declared so that the bound is
// not dropped when records are decoded and re-encoded.
type Event struct {
	Introduced   string `json:"introduced,omitempty"`
	Fixed        string `json:"fixed,omitempty"`
	LastAffected string `json:"last_affected,omitempty"`
	Limit        string `json:"limit,omitempty"`
}
// Reference is a link attached to a vulnerability record. Type is carried
// as a plain string and is not validated here.
// NOTE(review): OSV defines values such as ADVISORY, FIX, WEB — confirm
// whether validation is wanted.
type Reference struct {
	Type string `json:"type"`
	URL  string `json:"url"`
}
// QueryResponse is the result returned by QueryDatabase. The "vulns" key has
// no omitempty, so it is always emitted: a nil slice encodes as null and an
// empty slice as [].
type QueryResponse struct {
	Vulns []Vulnerability `json:"vulns"`
}
// QueryDatabase returns the vulnerabilities in the index that match req.
//
// Lookup modes, in order of precedence:
//   - req.Commit != "": vulnerabilities linked to that commit hash in the
//     git_commits table; req.Package and req.Version are ignored.
//   - req.Package != nil: vulnerabilities listed for that package name and
//     ecosystem in the affected table. When req.Version is also set, each
//     decoded record is further filtered with matches.
//
// A request with neither a commit nor a package yields an empty response.
// Results are ordered by vulnerability ID. Vulns is never nil, so it always
// encodes as a JSON array.
//
// Rows whose stored JSON cannot be scanned or decoded are skipped (best
// effort). Query and iteration errors are returned to the caller.
func QueryDatabase(idx *indexer.Indexer, req *QueryRequest) (*QueryResponse, error) {
	resp := &QueryResponse{Vulns: []Vulnerability{}}

	// A single query with an IN-subquery replaces "fetch IDs, then fetch
	// records". It avoids a second round trip and also avoids building an IN
	// list that could exceed SQLite's bound-variable limit for packages with
	// many vulnerabilities.
	var (
		query string
		args  []interface{}
	)
	switch {
	case req.Commit != "":
		query = `SELECT vuln_json FROM vulnerabilities
			WHERE id IN (SELECT vuln_id FROM git_commits WHERE commit_hash = ?)
			ORDER BY id`
		args = []interface{}{req.Commit}
	case req.Package != nil:
		query = `SELECT vuln_json FROM vulnerabilities
			WHERE id IN (SELECT vuln_id FROM affected WHERE package_name = ? AND package_ecosystem = ?)
			ORDER BY id`
		args = []interface{}{req.Package.Name, req.Package.Ecosystem}
	default:
		return resp, nil
	}

	// Versions are filtered in Go, not SQL. SQL comparison operators order
	// version strings lexically ("1.10.0" < "1.9.0") and cannot express
	// ranges that have no "fixed" event.
	filterByVersion := req.Commit == "" && req.Version != ""

	rows, err := idx.GetDB().Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to query vulnerabilities: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var vulnJSON string
		if err := rows.Scan(&vulnJSON); err != nil {
			continue // best effort: skip unreadable rows
		}
		var vuln Vulnerability
		if err := json.Unmarshal([]byte(vulnJSON), &vuln); err != nil {
			continue // best effort: skip corrupt records
		}
		if filterByVersion && !matches(req, &vuln) {
			continue
		}
		resp.Vulns = append(resp.Vulns, vuln)
	}
	// Without this check, a failure mid-iteration would be reported as a
	// successful but truncated result.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to read vulnerabilities: %w", err)
	}
	return resp, nil
}
// matches reports whether the decoded record vuln satisfies req.
//
// A non-empty req.Commit takes precedence and is delegated to
// matchesCommit. Otherwise req.Package must be set, and some affected entry
// must carry exactly the same package name and ecosystem. When req.Version
// is non-empty, that entry must also cover the version according to
// matchesVersion.
func matches(req *QueryRequest, vuln *Vulnerability) bool {
	switch {
	case req.Commit != "":
		return matchesCommit(req.Commit, vuln)
	case req.Package == nil:
		return false
	}
	want := req.Package
	for i := range vuln.Affected {
		entry := &vuln.Affected[i]
		samePackage := entry.Package.Name == want.Name &&
			entry.Package.Ecosystem == want.Ecosystem
		if !samePackage {
			continue
		}
		if req.Version == "" || matchesVersion(req.Version, entry) {
			return true
		}
	}
	return false
}
// matchesCommit reports whether commit appears verbatim as the Introduced
// or Fixed value of any event in a range of type "GIT" within vuln.
// Commits that merely lie between those boundaries are not detected here.
func matchesCommit(commit string, vuln *Vulnerability) bool {
	for _, a := range vuln.Affected {
		for _, rng := range a.Ranges {
			if rng.Type != "GIT" {
				continue
			}
			for _, ev := range rng.Events {
				switch commit {
				case ev.Introduced, ev.Fixed:
					return true
				}
			}
		}
	}
	return false
}
// matchesVersion reports whether version is affected according to a single
// Affected entry. The explicit Versions list is checked first. After that,
// only ranges of type "ECOSYSTEM" or "SEMVER" are evaluated with
// isVersionAffected; ranges of any other type (such as "GIT") are ignored.
func matchesVersion(version string, affected *Affected) bool {
	listed := false
	for _, candidate := range affected.Versions {
		if candidate == version {
			listed = true
			break
		}
	}
	if listed {
		return true
	}
	for _, rng := range affected.Ranges {
		switch rng.Type {
		case "ECOSYSTEM", "SEMVER":
			if isVersionAffected(version, rng.Events) {
				return true
			}
		}
	}
	return false
}
// isVersionAffected reports whether version lies inside the span described
// by one range's events.
//
// A range may hold several introduced/fixed pairs, e.g. a bug that was
// fixed, reintroduced and fixed again. The events are not assumed to be
// sorted. version is affected when some "introduced" event is at or below
// it and no "fixed" event lies between that introduction and version
// (introduced <= fixed <= version).
//
// An introduced value of "0" means "from the first release" and matches
// every version. Only the Introduced and Fixed fields are consulted; other
// event kinds (such as limit) are ignored. A range with no introduced event
// affects nothing.
func isVersionAffected(version string, events []Event) bool {
	for _, start := range events {
		if start.Introduced == "" {
			continue
		}
		fromBeginning := start.Introduced == "0"
		if !fromBeginning && compareVersions(version, start.Introduced) < 0 {
			continue // this introduction comes after version
		}
		closed := false
		for _, end := range events {
			if end.Fixed == "" {
				continue
			}
			// Only a fix at or after this introduction can close it.
			afterStart := fromBeginning || compareVersions(end.Fixed, start.Introduced) >= 0
			if afterStart && compareVersions(version, end.Fixed) >= 0 {
				closed = true
				break
			}
		}
		if !closed {
			return true
		}
	}
	return false
}
// compareVersions orders two version strings. It returns -1 if v1 sorts
// before v2, 0 if they are equivalent, and +1 if v1 sorts after v2.
//
// This is a best-effort, ecosystem-agnostic comparison:
//   - one leading "v" is ignored, so "v1.2.0" equals "1.2.0";
//   - versions are split on "." and missing trailing components count as
//     "0", so "1.0" equals "1.0.0";
//   - each component pair is ordered by compareVersionComponent.
//
// It does not implement any ecosystem's full rules (PEP 440, dpkg, ...).
func compareVersions(v1, v2 string) int {
	if v1 == v2 {
		return 0
	}
	a := strings.Split(strings.TrimPrefix(v1, "v"), ".")
	b := strings.Split(strings.TrimPrefix(v2, "v"), ".")
	n := len(a)
	if len(b) > n {
		n = len(b)
	}
	for i := 0; i < n; i++ {
		x, y := "0", "0"
		if i < len(a) {
			x = a[i]
		}
		if i < len(b) {
			y = b[i]
		}
		if c := compareVersionComponent(x, y); c != 0 {
			return c
		}
	}
	return 0
}

// compareVersionComponent orders one dot-separated version component.
//
// The leading run of ASCII digits is compared by numeric value, using
// compareDigitRuns; a component with no leading digits counts as 0. On a
// tie, the remainders decide. A remainder starting with "-" (a SemVer-style
// pre-release, like the "0-rc1" in "1.0.0-rc1") sorts before one that does
// not. Otherwise the remainders are compared as text.
func compareVersionComponent(x, y string) int {
	i := 0
	for i < len(x) && '0' <= x[i] && x[i] <= '9' {
		i++
	}
	j := 0
	for j < len(y) && '0' <= y[j] && y[j] <= '9' {
		j++
	}
	if c := compareDigitRuns(x[:i], y[:j]); c != 0 {
		return c
	}
	xRest, yRest := x[i:], y[j:]
	xPre := strings.HasPrefix(xRest, "-")
	yPre := strings.HasPrefix(yRest, "-")
	switch {
	case xPre && !yPre:
		return -1
	case yPre && !xPre:
		return 1
	}
	return strings.Compare(xRest, yRest)
}

// compareDigitRuns compares two (possibly empty) strings of ASCII digits by
// numeric value without overflowing. Runs that fit in a uint64 are parsed
// directly. Longer or empty runs are compared by magnitude: after dropping
// leading zeros, the longer run is larger and equal lengths order lexically.
func compareDigitRuns(p, q string) int {
	if a, err := strconv.ParseUint(p, 10, 64); err == nil {
		if b, err := strconv.ParseUint(q, 10, 64); err == nil {
			switch {
			case a < b:
				return -1
			case a > b:
				return 1
			}
			return 0
		}
	}
	p = strings.TrimLeft(p, "0")
	q = strings.TrimLeft(q, "0")
	if len(p) != len(q) {
		if len(p) < len(q) {
			return -1
		}
		return 1
	}
	return strings.Compare(p, q)
}