From 6cca258d61cc0b051d1996045cdb0ff3262bc172 Mon Sep 17 00:00:00 2001 From: Rex P Date: Mon, 23 Dec 2024 12:19:06 +1100 Subject: [PATCH] Add more tests --- internal/osvdev/osvdev_test.go | 67 +++++++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/internal/osvdev/osvdev_test.go b/internal/osvdev/osvdev_test.go index 012a918ca2..d605e1ac3d 100644 --- a/internal/osvdev/osvdev_test.go +++ b/internal/osvdev/osvdev_test.go @@ -92,6 +92,25 @@ func TestOSVClient_QueryBatch(t *testing.T) { {}, }, }, + { + name: "multiple queries with invalid", + queries: []*osvdev.Query{ + { + Package: osvdev.Package{ + Name: "faker", + Ecosystem: string(osvschema.EcosystemNPM), + }, + Version: "6.6.6", + }, + { + Package: osvdev.Package{ + Name: "abcd-definitely-does-not-exist", + }, + }, + }, + wantIDs: [][]string{}, + wantErrContains: `client error: status="400 Bad Request" body={"code":3,"message":"Invalid query."}`, + }, } for _, tt := range tests { @@ -204,5 +223,51 @@ func TestOSVClient_Query(t *testing.T) { } func TestOSVClient_ExperimentalDetermineVersion(t *testing.T) { - // TODO + t.Parallel() + + tests := []struct { + name string + query osvdev.DetermineVersionsRequest + wantPkgs []string + wantErrContains string + }{ + { + name: "Simple non existent package query", + query: osvdev.DetermineVersionsRequest{ + Name: "test file", + FileHashes: []osvdev.DetermineVersionHash{ + { + Path: "test file/file", + Hash: []byte{}, + }, + }, + }, + wantPkgs: []string{}, + }, + // TODO: Add query for an actual package, this is not added at the moment as it requires too many hashes + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + c := osvdev.DefaultClient() + c.Config.UserAgent = "osv-scanner-api-test" + got, err := c.ExperimentalDetermineVersion(context.Background(), &tt.query) + if err != nil { + if tt.wantErrContains == "" || !strings.Contains(err.Error(), tt.wantErrContains) { + t.Errorf("OSVClient.GetVulnsByID() error = %v, wantErr %q", err, tt.wantErrContains) + } + return + } + + gotPkgInfo := make([]string, 0, len(got.Matches)) + for _, vuln := range got.Matches { + gotPkgInfo = append(gotPkgInfo, vuln.RepoInfo.Address+"@"+vuln.RepoInfo.Version) + } + + if diff := cmp.Diff(tt.wantPkgs, gotPkgInfo); diff != "" { + t.Errorf("Unexpected vuln IDs (-want +got):\n%s", diff) + } + }) + } }