Skip to content

Commit

Permalink
refactor: propagate time through context values (#5858)
Browse files Browse the repository at this point in the history
Signed-off-by: knqyf263 <knqyf263@gmail.com>
  • Loading branch information
knqyf263 authored Jan 3, 2024
1 parent 1607eee commit da597c4
Show file tree
Hide file tree
Showing 77 changed files with 384 additions and 546 deletions.
4 changes: 0 additions & 4 deletions integration/client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/stretchr/testify/require"
testcontainers "github.com/testcontainers/testcontainers-go"

"github.com/aquasecurity/trivy/pkg/clock"
"github.com/aquasecurity/trivy/pkg/report"
"github.com/aquasecurity/trivy/pkg/uuid"
)
Expand Down Expand Up @@ -364,8 +363,6 @@ func TestClientServerWithFormat(t *testing.T) {
}

fakeTime := time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC)
clock.SetFakeTime(t, fakeTime)

report.CustomTemplateFuncMap = map[string]interface{}{
"now": func() time.Time {
return fakeTime
Expand Down Expand Up @@ -428,7 +425,6 @@ func TestClientServerWithCycloneDX(t *testing.T) {
addr, cacheDir := setup(t, setupOptions{})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clock.SetFakeTime(t, time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))
uuid.SetFakeUUID(t, "3ff14136-e09f-4df9-80ea-%012d")

osArgs, outputFile := setupClient(t, tt.args, addr, cacheDir, tt.golden)
Expand Down
11 changes: 6 additions & 5 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"flag"
"fmt"
"github.com/aquasecurity/trivy/pkg/clock"
"io"
"net"
"os"
Expand All @@ -27,7 +28,6 @@ import (

"github.com/aquasecurity/trivy-db/pkg/db"
"github.com/aquasecurity/trivy-db/pkg/metadata"
"github.com/aquasecurity/trivy/pkg/clock"
"github.com/aquasecurity/trivy/pkg/commands"
"github.com/aquasecurity/trivy/pkg/dbtest"
"github.com/aquasecurity/trivy/pkg/types"
Expand All @@ -44,8 +44,6 @@ func initDB(t *testing.T) string {
entries, err := os.ReadDir(fixtureDir)
require.NoError(t, err)

clock.SetFakeTime(t, time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))

var fixtures []string
for _, entry := range entries {
if entry.IsDir() {
Expand Down Expand Up @@ -193,13 +191,16 @@ func readSpdxJson(t *testing.T, filePath string) *spdx.Document {
}

func execute(osArgs []string) error {
// Set a fake time
ctx := clock.With(context.Background(), time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))

// Setup CLI App
app := commands.NewApp()
app.SetOut(io.Discard)
app.SetArgs(osArgs)

// Run Trivy
app.SetArgs(osArgs)
return app.Execute()
return app.ExecuteContext(ctx)
}

func compareReports(t *testing.T, wantFile, gotFile string, override func(*types.Report)) {
Expand Down
17 changes: 8 additions & 9 deletions integration/repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@ package integration

import (
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"os"
"path/filepath"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/aquasecurity/trivy/pkg/clock"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/uuid"
Expand Down Expand Up @@ -416,12 +413,15 @@ func TestRepository(t *testing.T) {

osArgs := []string{
"-q",
"--cache-dir", cacheDir,
"--cache-dir",
cacheDir,
command,
"--skip-db-update",
"--skip-policy-update",
"--format", string(format),
"--parallel", fmt.Sprint(tt.args.parallel),
"--format",
string(format),
"--parallel",
fmt.Sprint(tt.args.parallel),
"--offline-scan",
}

Expand Down Expand Up @@ -499,7 +499,6 @@ func TestRepository(t *testing.T) {
osArgs = append(osArgs, "--output", outputFile)
osArgs = append(osArgs, tt.args.input)

clock.SetFakeTime(t, time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))
uuid.SetFakeUUID(t, "3ff14136-e09f-4df9-80ea-%012d")

// Run "trivy repo"
Expand Down
3 changes: 1 addition & 2 deletions integration/testdata/debian-stretch.json.golden
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
"Metadata": {
"OS": {
"Family": "debian",
"Name": "9.9",
"EOSL": true
"Name": "9.9"
},
"ImageID": "sha256:f26939cc87ef44a6fc554eedd0a976ab30b5bc2769d65d2e986b6c5f1fd4053d",
"DiffIDs": [
Expand Down
3 changes: 1 addition & 2 deletions integration/testdata/distroless-base.json.golden
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
"Metadata": {
"OS": {
"Family": "debian",
"Name": "9.9",
"EOSL": true
"Name": "9.9"
},
"ImageID": "sha256:7f04a8d247173b1f2546d22913af637bbab4e7411e00ae6207da8d94c445750d",
"DiffIDs": [
Expand Down
3 changes: 1 addition & 2 deletions integration/testdata/distroless-python27.json.golden
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
"Metadata": {
"OS": {
"Family": "debian",
"Name": "9.9",
"EOSL": true
"Name": "9.9"
},
"ImageID": "sha256:6fcac2cc8a710f21577b5bbd534e0bfc841c0cca569b57182ba19054696cddda",
"DiffIDs": [
Expand Down
3 changes: 1 addition & 2 deletions integration/testdata/ubuntu-1804-ignore-unfixed.json.golden
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
"Metadata": {
"OS": {
"Family": "ubuntu",
"Name": "18.04",
"EOSL": true
"Name": "18.04"
},
"ImageID": "sha256:a2a15febcdf362f6115e801d37b5e60d6faaeedcb9896155e5fe9d754025be12",
"DiffIDs": [
Expand Down
3 changes: 1 addition & 2 deletions integration/testdata/ubuntu-1804.json.golden
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
"Metadata": {
"OS": {
"Family": "ubuntu",
"Name": "18.04",
"EOSL": true
"Name": "18.04"
},
"ImageID": "sha256:a2a15febcdf362f6115e801d37b5e60d6faaeedcb9896155e5fe9d754025be12",
"DiffIDs": [
Expand Down
30 changes: 20 additions & 10 deletions pkg/clock/clock.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
package clock

import (
"testing"
"context"
"time"

"k8s.io/utils/clock"
clocktesting "k8s.io/utils/clock/testing"
)

var c clock.Clock = clock.RealClock{}
// clockKey is the context key for clock. It is unexported to prevent collisions with context keys defined in
// other packages.
type clockKey struct{}

// SetFakeTime sets a fake time for testing.
func SetFakeTime(t *testing.T, fakeTime time.Time) {
c = clocktesting.NewFakeClock(fakeTime)
t.Cleanup(func() {
c = clock.RealClock{}
})
// With returns a new context with the given time.
func With(ctx context.Context, t time.Time) context.Context {
c := clocktesting.NewFakeClock(t)
return context.WithValue(ctx, clockKey{}, c)
}

func Now() time.Time {
return c.Now()
// Now returns the current time.
func Now(ctx context.Context) time.Time {
return Clock(ctx).Now()
}

// Clock returns the clock from the context.
func Clock(ctx context.Context) clock.Clock {
t, ok := ctx.Value(clockKey{}).(clock.Clock)
if !ok {
return clock.RealClock{}
}
return t
}
31 changes: 22 additions & 9 deletions pkg/cloud/aws/commands/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1068,8 +1068,11 @@ Summary Report for compliance: my-custom-spec
MisconfOptions: flag.MisconfOptions{IncludeNonFailures: true},
},
cacheContent: "testdata/s3andcloudtrailcache.json",
allServices: []string{"s3", "cloudtrail"},
want: expectedS3AndCloudTrailResult,
allServices: []string{
"s3",
"cloudtrail",
},
want: expectedS3AndCloudTrailResult,
},
{
name: "skip certain services and include specific services",
Expand All @@ -1087,7 +1090,10 @@ Summary Report for compliance: my-custom-spec
MisconfOptions: flag.MisconfOptions{IncludeNonFailures: true},
},
cacheContent: "testdata/s3andcloudtrailcache.json",
allServices: []string{"s3", "cloudtrail"},
allServices: []string{
"s3",
"cloudtrail",
},
// we skip cloudtrail but still expect results from it as it is cached
want: expectedS3AndCloudTrailResult,
},
Expand All @@ -1096,16 +1102,23 @@ Summary Report for compliance: my-custom-spec
options: flag.Options{
RegoOptions: flag.RegoOptions{SkipPolicyUpdate: true},
AWSOptions: flag.AWSOptions{
Region: "us-east-1",
SkipServices: []string{"cloudtrail", "iam"},
Account: "12345678",
Region: "us-east-1",
SkipServices: []string{
"cloudtrail",
"iam",
},
Account: "12345678",
},
CloudOptions: flag.CloudOptions{
MaxCacheAge: time.Hour * 24 * 365 * 100,
},
MisconfOptions: flag.MisconfOptions{IncludeNonFailures: true},
},
allServices: []string{"s3", "cloudtrail", "iam"},
allServices: []string{
"s3",
"cloudtrail",
"iam",
},
cacheContent: "testdata/s3onlycache.json",
want: expectedS3ScanResult,
},
Expand All @@ -1129,7 +1142,7 @@ Summary Report for compliance: my-custom-spec
},
}

clock.SetFakeTime(t, time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))
ctx := clock.With(context.Background(), time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.allServices != nil {
Expand Down Expand Up @@ -1179,7 +1192,7 @@ Summary Report for compliance: my-custom-spec
require.NoError(t, os.WriteFile(cacheFile, cacheData, 0600))
}

err := Run(context.Background(), test.options)
err := Run(ctx, test.options)
if test.expectErr {
assert.Error(t, err)
return
Expand Down
8 changes: 4 additions & 4 deletions pkg/cloud/report/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func Write(ctx context.Context, rep *Report, opt flag.Options, fromCache bool) e
defer cleanup()

if opt.Compliance.Spec.ID != "" {
return writeCompliance(rep, opt, output)
return writeCompliance(ctx, rep, opt, output)
}

var filtered []types.Result
Expand All @@ -93,7 +93,7 @@ func Write(ctx context.Context, rep *Report, opt flag.Options, fromCache bool) e
})

base := types.Report{
CreatedAt: clock.Now(),
CreatedAt: clock.Now(ctx),
ArtifactName: rep.AccountID,
ArtifactType: ftypes.ArtifactAWSAccount,
Results: filtered,
Expand Down Expand Up @@ -139,7 +139,7 @@ func Write(ctx context.Context, rep *Report, opt flag.Options, fromCache bool) e
}
}

func writeCompliance(rep *Report, opt flag.Options, output io.Writer) error {
func writeCompliance(ctx context.Context, rep *Report, opt flag.Options, output io.Writer) error {
var crr []types.Results
for _, r := range rep.Results {
crr = append(crr, r.Results)
Expand All @@ -150,7 +150,7 @@ func writeCompliance(rep *Report, opt flag.Options, output io.Writer) error {
return xerrors.Errorf("compliance report build error: %w", err)
}

return cr.Write(complianceReport, cr.Option{
return cr.Write(ctx, complianceReport, cr.Option{
Format: opt.Format,
Report: opt.ReportFormat,
Output: output,
Expand Down
15 changes: 11 additions & 4 deletions pkg/cloud/report/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ This scan report was loaded from cached results. If you'd like to run a fresh sc
},
},
AWSOptions: flag.AWSOptions{
Services: []string{"s3", "ec2"},
Services: []string{
"s3",
"ec2",
},
},
},
fromCache: false,
Expand Down Expand Up @@ -117,7 +120,11 @@ Scan Overview for AWS Account
},
},
AWSOptions: flag.AWSOptions{
Services: []string{"ec2", "s3", "iam"},
Services: []string{
"ec2",
"s3",
"iam",
},
},
},
fromCache: false,
Expand Down Expand Up @@ -310,7 +317,7 @@ Scan Overview for AWS Account
}`,
},
}
clock.SetFakeTime(t, time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))
ctx := clock.With(context.Background(), time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
report := New(
Expand All @@ -323,7 +330,7 @@ Scan Overview for AWS Account

output := bytes.NewBuffer(nil)
tt.options.SetOutputWriter(output)
require.NoError(t, Write(context.Background(), report, tt.options, tt.fromCache))
require.NoError(t, Write(ctx, report, tt.options, tt.fromCache))

assert.Equal(t, "AWS", report.Provider)
assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID)
Expand Down
2 changes: 1 addition & 1 deletion pkg/commands/server/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,5 @@ func Run(ctx context.Context, opts flag.Options) (err error) {

server := rpcServer.NewServer(opts.AppVersion, opts.Listen, opts.CacheDir, opts.Token, opts.TokenHeader,
opts.DBRepository, opts.RegistryOpts())
return server.ListenAndServe(cache, opts.SkipDBUpdate)
return server.ListenAndServe(ctx, cache, opts.SkipDBUpdate)
}
7 changes: 4 additions & 3 deletions pkg/compliance/report/report.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package report

import (
"context"
"io"

"golang.org/x/xerrors"
Expand Down Expand Up @@ -63,8 +64,8 @@ type Writer interface {
Write(ComplianceReport) error
}

// Write writes the results in the give format
func Write(report *ComplianceReport, option Option) error {
// Write writes the results in the given format
func Write(ctx context.Context, report *ComplianceReport, option Option) error {
switch option.Format {
case types.FormatJSON:
jwriter := JSONWriter{
Expand All @@ -79,7 +80,7 @@ func Write(report *ComplianceReport, option Option) error {
Report: option.Report,
Severities: option.Severities,
}
err := complianceWriter.Write(report)
err := complianceWriter.Write(ctx, report)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit da597c4

Please sign in to comment.