Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: propagate time through context values #5858

Merged
merged 2 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions integration/client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/stretchr/testify/require"
testcontainers "github.com/testcontainers/testcontainers-go"

"github.com/aquasecurity/trivy/pkg/clock"
"github.com/aquasecurity/trivy/pkg/report"
"github.com/aquasecurity/trivy/pkg/uuid"
)
Expand Down Expand Up @@ -364,8 +363,6 @@ func TestClientServerWithFormat(t *testing.T) {
}

fakeTime := time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC)
clock.SetFakeTime(t, fakeTime)

report.CustomTemplateFuncMap = map[string]interface{}{
"now": func() time.Time {
return fakeTime
Expand Down Expand Up @@ -428,7 +425,6 @@ func TestClientServerWithCycloneDX(t *testing.T) {
addr, cacheDir := setup(t, setupOptions{})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clock.SetFakeTime(t, time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))
uuid.SetFakeUUID(t, "3ff14136-e09f-4df9-80ea-%012d")

osArgs, outputFile := setupClient(t, tt.args, addr, cacheDir, tt.golden)
Expand Down
11 changes: 6 additions & 5 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"flag"
"fmt"
"github.com/aquasecurity/trivy/pkg/clock"
"io"
"net"
"os"
Expand All @@ -27,7 +28,6 @@ import (

"github.com/aquasecurity/trivy-db/pkg/db"
"github.com/aquasecurity/trivy-db/pkg/metadata"
"github.com/aquasecurity/trivy/pkg/clock"
"github.com/aquasecurity/trivy/pkg/commands"
"github.com/aquasecurity/trivy/pkg/dbtest"
"github.com/aquasecurity/trivy/pkg/types"
Expand All @@ -44,8 +44,6 @@ func initDB(t *testing.T) string {
entries, err := os.ReadDir(fixtureDir)
require.NoError(t, err)

clock.SetFakeTime(t, time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))

var fixtures []string
for _, entry := range entries {
if entry.IsDir() {
Expand Down Expand Up @@ -193,13 +191,16 @@ func readSpdxJson(t *testing.T, filePath string) *spdx.Document {
}

func execute(osArgs []string) error {
// Set a fake time
ctx := clock.With(context.Background(), time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))

// Setup CLI App
app := commands.NewApp()
app.SetOut(io.Discard)
app.SetArgs(osArgs)

// Run Trivy
app.SetArgs(osArgs)
return app.Execute()
return app.ExecuteContext(ctx)
}

func compareReports(t *testing.T, wantFile, gotFile string, override func(*types.Report)) {
Expand Down
17 changes: 8 additions & 9 deletions integration/repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@ package integration

import (
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"os"
"path/filepath"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/aquasecurity/trivy/pkg/clock"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/uuid"
Expand Down Expand Up @@ -416,12 +413,15 @@ func TestRepository(t *testing.T) {

osArgs := []string{
"-q",
"--cache-dir", cacheDir,
"--cache-dir",
cacheDir,
command,
"--skip-db-update",
"--skip-policy-update",
"--format", string(format),
"--parallel", fmt.Sprint(tt.args.parallel),
"--format",
string(format),
"--parallel",
fmt.Sprint(tt.args.parallel),
"--offline-scan",
}

Expand Down Expand Up @@ -499,7 +499,6 @@ func TestRepository(t *testing.T) {
osArgs = append(osArgs, "--output", outputFile)
osArgs = append(osArgs, tt.args.input)

clock.SetFakeTime(t, time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))
uuid.SetFakeUUID(t, "3ff14136-e09f-4df9-80ea-%012d")

// Run "trivy repo"
Expand Down
3 changes: 1 addition & 2 deletions integration/testdata/debian-stretch.json.golden
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
"Metadata": {
"OS": {
"Family": "debian",
"Name": "9.9",
"EOSL": true
"Name": "9.9"
},
"ImageID": "sha256:f26939cc87ef44a6fc554eedd0a976ab30b5bc2769d65d2e986b6c5f1fd4053d",
"DiffIDs": [
Expand Down
3 changes: 1 addition & 2 deletions integration/testdata/distroless-base.json.golden
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
"Metadata": {
"OS": {
"Family": "debian",
"Name": "9.9",
"EOSL": true
"Name": "9.9"
},
"ImageID": "sha256:7f04a8d247173b1f2546d22913af637bbab4e7411e00ae6207da8d94c445750d",
"DiffIDs": [
Expand Down
3 changes: 1 addition & 2 deletions integration/testdata/distroless-python27.json.golden
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
"Metadata": {
"OS": {
"Family": "debian",
"Name": "9.9",
"EOSL": true
"Name": "9.9"
},
"ImageID": "sha256:6fcac2cc8a710f21577b5bbd534e0bfc841c0cca569b57182ba19054696cddda",
"DiffIDs": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
"Metadata": {
"OS": {
"Family": "ubuntu",
"Name": "18.04",
"EOSL": true
"Name": "18.04"
},
"ImageID": "sha256:a2a15febcdf362f6115e801d37b5e60d6faaeedcb9896155e5fe9d754025be12",
"DiffIDs": [
Expand Down
3 changes: 1 addition & 2 deletions integration/testdata/ubuntu-1804.json.golden
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
"Metadata": {
"OS": {
"Family": "ubuntu",
"Name": "18.04",
"EOSL": true
"Name": "18.04"
},
"ImageID": "sha256:a2a15febcdf362f6115e801d37b5e60d6faaeedcb9896155e5fe9d754025be12",
"DiffIDs": [
Expand Down
30 changes: 20 additions & 10 deletions pkg/clock/clock.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
package clock

import (
"testing"
"context"
"time"

"k8s.io/utils/clock"
clocktesting "k8s.io/utils/clock/testing"
)

var c clock.Clock = clock.RealClock{}
// clockKey is the context key for clock. It is unexported to prevent collisions with context keys defined in
// other packages.
type clockKey struct{}

// SetFakeTime sets a fake time for testing.
func SetFakeTime(t *testing.T, fakeTime time.Time) {
c = clocktesting.NewFakeClock(fakeTime)
t.Cleanup(func() {
c = clock.RealClock{}
})
// With returns a new context with the given time.
func With(ctx context.Context, t time.Time) context.Context {
c := clocktesting.NewFakeClock(t)
return context.WithValue(ctx, clockKey{}, c)
}

func Now() time.Time {
return c.Now()
// Now returns the current time.
func Now(ctx context.Context) time.Time {
return Clock(ctx).Now()
}

// Clock returns the clock from the context.
func Clock(ctx context.Context) clock.Clock {
t, ok := ctx.Value(clockKey{}).(clock.Clock)
if !ok {
return clock.RealClock{}
}
return t
}
31 changes: 22 additions & 9 deletions pkg/cloud/aws/commands/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1068,8 +1068,11 @@ Summary Report for compliance: my-custom-spec
MisconfOptions: flag.MisconfOptions{IncludeNonFailures: true},
},
cacheContent: "testdata/s3andcloudtrailcache.json",
allServices: []string{"s3", "cloudtrail"},
want: expectedS3AndCloudTrailResult,
allServices: []string{
"s3",
"cloudtrail",
},
want: expectedS3AndCloudTrailResult,
},
{
name: "skip certain services and include specific services",
Expand All @@ -1087,7 +1090,10 @@ Summary Report for compliance: my-custom-spec
MisconfOptions: flag.MisconfOptions{IncludeNonFailures: true},
},
cacheContent: "testdata/s3andcloudtrailcache.json",
allServices: []string{"s3", "cloudtrail"},
allServices: []string{
"s3",
"cloudtrail",
},
// we skip cloudtrail but still expect results from it as it is cached
want: expectedS3AndCloudTrailResult,
},
Expand All @@ -1096,16 +1102,23 @@ Summary Report for compliance: my-custom-spec
options: flag.Options{
RegoOptions: flag.RegoOptions{SkipPolicyUpdate: true},
AWSOptions: flag.AWSOptions{
Region: "us-east-1",
SkipServices: []string{"cloudtrail", "iam"},
Account: "12345678",
Region: "us-east-1",
SkipServices: []string{
"cloudtrail",
"iam",
},
Account: "12345678",
},
CloudOptions: flag.CloudOptions{
MaxCacheAge: time.Hour * 24 * 365 * 100,
},
MisconfOptions: flag.MisconfOptions{IncludeNonFailures: true},
},
allServices: []string{"s3", "cloudtrail", "iam"},
allServices: []string{
"s3",
"cloudtrail",
"iam",
},
cacheContent: "testdata/s3onlycache.json",
want: expectedS3ScanResult,
},
Expand All @@ -1129,7 +1142,7 @@ Summary Report for compliance: my-custom-spec
},
}

clock.SetFakeTime(t, time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))
ctx := clock.With(context.Background(), time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.allServices != nil {
Expand Down Expand Up @@ -1179,7 +1192,7 @@ Summary Report for compliance: my-custom-spec
require.NoError(t, os.WriteFile(cacheFile, cacheData, 0600))
}

err := Run(context.Background(), test.options)
err := Run(ctx, test.options)
if test.expectErr {
assert.Error(t, err)
return
Expand Down
8 changes: 4 additions & 4 deletions pkg/cloud/report/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func Write(ctx context.Context, rep *Report, opt flag.Options, fromCache bool) e
defer cleanup()

if opt.Compliance.Spec.ID != "" {
return writeCompliance(rep, opt, output)
return writeCompliance(ctx, rep, opt, output)
}

var filtered []types.Result
Expand All @@ -93,7 +93,7 @@ func Write(ctx context.Context, rep *Report, opt flag.Options, fromCache bool) e
})

base := types.Report{
CreatedAt: clock.Now(),
CreatedAt: clock.Now(ctx),
ArtifactName: rep.AccountID,
ArtifactType: ftypes.ArtifactAWSAccount,
Results: filtered,
Expand Down Expand Up @@ -139,7 +139,7 @@ func Write(ctx context.Context, rep *Report, opt flag.Options, fromCache bool) e
}
}

func writeCompliance(rep *Report, opt flag.Options, output io.Writer) error {
func writeCompliance(ctx context.Context, rep *Report, opt flag.Options, output io.Writer) error {
var crr []types.Results
for _, r := range rep.Results {
crr = append(crr, r.Results)
Expand All @@ -150,7 +150,7 @@ func writeCompliance(rep *Report, opt flag.Options, output io.Writer) error {
return xerrors.Errorf("compliance report build error: %w", err)
}

return cr.Write(complianceReport, cr.Option{
return cr.Write(ctx, complianceReport, cr.Option{
Format: opt.Format,
Report: opt.ReportFormat,
Output: output,
Expand Down
15 changes: 11 additions & 4 deletions pkg/cloud/report/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ This scan report was loaded from cached results. If you'd like to run a fresh sc
},
},
AWSOptions: flag.AWSOptions{
Services: []string{"s3", "ec2"},
Services: []string{
"s3",
"ec2",
},
},
},
fromCache: false,
Expand Down Expand Up @@ -117,7 +120,11 @@ Scan Overview for AWS Account
},
},
AWSOptions: flag.AWSOptions{
Services: []string{"ec2", "s3", "iam"},
Services: []string{
"ec2",
"s3",
"iam",
},
},
},
fromCache: false,
Expand Down Expand Up @@ -310,7 +317,7 @@ Scan Overview for AWS Account
}`,
},
}
clock.SetFakeTime(t, time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))
ctx := clock.With(context.Background(), time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
report := New(
Expand All @@ -323,7 +330,7 @@ Scan Overview for AWS Account

output := bytes.NewBuffer(nil)
tt.options.SetOutputWriter(output)
require.NoError(t, Write(context.Background(), report, tt.options, tt.fromCache))
require.NoError(t, Write(ctx, report, tt.options, tt.fromCache))

assert.Equal(t, "AWS", report.Provider)
assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID)
Expand Down
2 changes: 1 addition & 1 deletion pkg/commands/server/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,5 @@ func Run(ctx context.Context, opts flag.Options) (err error) {

server := rpcServer.NewServer(opts.AppVersion, opts.Listen, opts.CacheDir, opts.Token, opts.TokenHeader,
opts.DBRepository, opts.RegistryOpts())
return server.ListenAndServe(cache, opts.SkipDBUpdate)
return server.ListenAndServe(ctx, cache, opts.SkipDBUpdate)
}
7 changes: 4 additions & 3 deletions pkg/compliance/report/report.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package report

import (
"context"
"io"

"golang.org/x/xerrors"
Expand Down Expand Up @@ -63,8 +64,8 @@ type Writer interface {
Write(ComplianceReport) error
}

// Write writes the results in the give format
func Write(report *ComplianceReport, option Option) error {
// Write writes the results in the given format
func Write(ctx context.Context, report *ComplianceReport, option Option) error {
switch option.Format {
case types.FormatJSON:
jwriter := JSONWriter{
Expand All @@ -79,7 +80,7 @@ func Write(report *ComplianceReport, option Option) error {
Report: option.Report,
Severities: option.Severities,
}
err := complianceWriter.Write(report)
err := complianceWriter.Write(ctx, report)
if err != nil {
return err
}
Expand Down
Loading