diff --git a/pkg/cloud/aws/commands/run_test.go b/pkg/cloud/aws/commands/run_test.go index 70d8dc672368..a7ac58886f8b 100644 --- a/pkg/cloud/aws/commands/run_test.go +++ b/pkg/cloud/aws/commands/run_test.go @@ -1,6 +1,7 @@ package commands import ( + "bytes" "context" "os" "path/filepath" @@ -1135,8 +1136,8 @@ Summary Report for compliance: my-custom-spec }() } - output := filepath.Join(t.TempDir(), "output") - test.options.Output = output + output := bytes.NewBuffer(nil) + test.options.SetOutputWriter(output) test.options.Debug = true test.options.GlobalOptions.Timeout = time.Minute if test.options.Format == "" { @@ -1178,10 +1179,7 @@ Summary Report for compliance: my-custom-spec return } assert.NoError(t, err) - - b, err := os.ReadFile(output) - require.NoError(t, err) - assert.Equal(t, test.want, string(b)) + assert.Equal(t, test.want, output.String()) }) } } diff --git a/pkg/cloud/report/report.go b/pkg/cloud/report/report.go index 0e3eab9a94ee..6742db9d043a 100644 --- a/pkg/cloud/report/report.go +++ b/pkg/cloud/report/report.go @@ -59,11 +59,11 @@ func (r *Report) Failed() bool { // Write writes the results in the give format func Write(rep *Report, opt flag.Options, fromCache bool) error { - output, err := opt.OutputWriter() + output, cleanup, err := opt.OutputWriter() if err != nil { return xerrors.Errorf("failed to create output file: %w", err) } - defer output.Close() + defer cleanup() if opt.Compliance.Spec.ID != "" { return writeCompliance(rep, opt, output) @@ -104,7 +104,7 @@ func Write(rep *Report, opt flag.Options, fromCache bool) error { // ensure color/formatting is disabled for pipes/non-pty var useANSI bool - if opt.Output == "" { + if output == os.Stdout { if o, err := os.Stdout.Stat(); err == nil { useANSI = (o.Mode() & os.ModeCharDevice) == os.ModeCharDevice } diff --git a/pkg/cloud/report/resource_test.go b/pkg/cloud/report/resource_test.go index 07ff85a88c27..cb17c2658d57 100644 --- a/pkg/cloud/report/resource_test.go +++ b/pkg/cloud/report/resource_test.go @@ -1,8 +1,7 @@ package report import ( - "os" - "path/filepath" + "bytes" "testing" "github.com/stretchr/testify/assert" @@ -110,18 +109,15 @@ No problems detected. tt.options.AWSOptions.Services, ) - output := filepath.Join(t.TempDir(), "output") - tt.options.Output = output + output := bytes.NewBuffer(nil) + tt.options.SetOutputWriter(output) require.NoError(t, Write(report, tt.options, tt.fromCache)) assert.Equal(t, "AWS", report.Provider) assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID) assert.Equal(t, tt.options.AWSOptions.Region, report.Region) assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope) - - b, err := os.ReadFile(output) - require.NoError(t, err) - assert.Equal(t, tt.expected, string(b)) + assert.Equal(t, tt.expected, output.String()) }) } } diff --git a/pkg/cloud/report/result_test.go b/pkg/cloud/report/result_test.go index 5b4f669b4650..f0ef85d4d564 100644 --- a/pkg/cloud/report/result_test.go +++ b/pkg/cloud/report/result_test.go @@ -1,8 +1,7 @@ package report import ( - "os" - "path/filepath" + "bytes" "strings" "testing" @@ -69,18 +68,15 @@ See https://avd.aquasec.com/misconfig/avd-aws-9999 tt.options.AWSOptions.Services, ) - output := filepath.Join(t.TempDir(), "output") - tt.options.Output = output + output := bytes.NewBuffer(nil) + tt.options.SetOutputWriter(output) require.NoError(t, Write(report, tt.options, tt.fromCache)) - b, err := os.ReadFile(output) - require.NoError(t, err) - assert.Equal(t, "AWS", report.Provider) assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID) assert.Equal(t, tt.options.AWSOptions.Region, report.Region) assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope) - assert.Equal(t, tt.expected, strings.ReplaceAll(string(b), "\r\n", "\n")) + assert.Equal(t, tt.expected, strings.ReplaceAll(output.String(), "\r\n", "\n")) }) } } diff --git a/pkg/cloud/report/service_test.go b/pkg/cloud/report/service_test.go index 6e4ae99c2cd4..d520285cabf1 100644 --- a/pkg/cloud/report/service_test.go +++ b/pkg/cloud/report/service_test.go @@ -1,8 +1,7 @@ package report import ( - "os" - "path/filepath" + "bytes" "testing" "github.com/aws/aws-sdk-go-v2/aws/arn" @@ -317,8 +316,8 @@ Scan Overview for AWS Account tt.options.AWSOptions.Services, ) - output := filepath.Join(t.TempDir(), "output") - tt.options.Output = output + output := bytes.NewBuffer(nil) + tt.options.SetOutputWriter(output) require.NoError(t, Write(report, tt.options, tt.fromCache)) assert.Equal(t, "AWS", report.Provider) @@ -326,13 +325,11 @@ Scan Overview for AWS Account assert.Equal(t, tt.options.AWSOptions.Region, report.Region) assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope) - b, err := os.ReadFile(output) - require.NoError(t, err) if tt.options.Format == "json" { // json output can be formatted/ordered differently - we just care that the data matches - assert.JSONEq(t, tt.expected, string(b)) + assert.JSONEq(t, tt.expected, output.String()) } else { - assert.Equal(t, tt.expected, string(b)) + assert.Equal(t, tt.expected, output.String()) } }) } diff --git a/pkg/flag/options.go b/pkg/flag/options.go index caf723fa5bb7..a5c5b2355f82 100644 --- a/pkg/flag/options.go +++ b/pkg/flag/options.go @@ -20,7 +20,6 @@ import ( "github.com/aquasecurity/trivy/pkg/result" "github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/version" - xio "github.com/aquasecurity/trivy/pkg/x/io" xstrings "github.com/aquasecurity/trivy/pkg/x/strings" ) @@ -114,6 +113,10 @@ type Options struct { // We don't want to allow disabled analyzers to be passed by users, but it is necessary for internal use. DisabledAnalyzers []analyzer.Type + + // outputWriter is not initialized via the CLI. + // It is mainly used for testing purposes or by tools that use Trivy as a library. + outputWriter io.Writer } // Align takes consistency of options @@ -159,17 +162,26 @@ func (o *Options) FilterOpts() result.FilterOption { } } +// SetOutputWriter sets an output writer. +func (o *Options) SetOutputWriter(w io.Writer) { + o.outputWriter = w +} + // OutputWriter returns an output writer. // If the output file is not specified, it returns os.Stdout. -func (o *Options) OutputWriter() (io.WriteCloser, error) { +func (o *Options) OutputWriter() (io.Writer, func(), error) { + if o.outputWriter != nil { + return o.outputWriter, func() {}, nil + } + if o.Output != "" { f, err := os.Create(o.Output) if err != nil { - return nil, xerrors.Errorf("failed to create output file: %w", err) + return nil, nil, xerrors.Errorf("failed to create output file: %w", err) } - return f, nil + return f, func() { _ = f.Close() }, nil } - return xio.NopCloser(os.Stdout), nil + return os.Stdout, func() {}, nil } func addFlag(cmd *cobra.Command, flag *Flag) { diff --git a/pkg/k8s/commands/run.go b/pkg/k8s/commands/run.go index 28cda3f67b99..df6836997632 100644 --- a/pkg/k8s/commands/run.go +++ b/pkg/k8s/commands/run.go @@ -95,11 +95,11 @@ func (r *runner) run(ctx context.Context, artifacts []*k8sArtifacts.Artifact) er return xerrors.Errorf("k8s scan error: %w", err) } - output, err := r.flagOpts.OutputWriter() + output, cleanup, err := r.flagOpts.OutputWriter() if err != nil { return xerrors.Errorf("failed to create output file: %w", err) } - defer output.Close() + defer cleanup() if r.flagOpts.Compliance.Spec.ID != "" { var scanResults []types.Results diff --git a/pkg/report/table/table.go b/pkg/report/table/table.go index 94d63784bca1..8be620f36589 100644 --- a/pkg/report/table/table.go +++ b/pkg/report/table/table.go @@ -15,7 +15,6 @@ import ( "github.com/aquasecurity/tml" dbTypes "github.com/aquasecurity/trivy-db/pkg/types" "github.com/aquasecurity/trivy/pkg/types" - xio "github.com/aquasecurity/trivy/pkg/x/io" ) var ( @@ -137,7 +136,7 @@ func IsOutputToTerminal(output io.Writer) bool { return false } - if output != xio.NopCloser(os.Stdout) { + if output != os.Stdout { return false } o, err := os.Stdout.Stat() diff --git a/pkg/report/writer.go b/pkg/report/writer.go index 31c606f1e5bf..648a9372f534 100644 --- a/pkg/report/writer.go +++ b/pkg/report/writer.go @@ -25,11 +25,11 @@ const ( // Write writes the result to output, format as passed in argument func Write(report types.Report, option flag.Options) error { - output, err := option.OutputWriter() + output, cleanup, err := option.OutputWriter() if err != nil { return xerrors.Errorf("failed to create a file: %w", err) } - defer output.Close() + defer cleanup() // Compliance report if option.Compliance.Spec.ID != "" { diff --git a/pkg/x/io/io.go b/pkg/x/io/io.go index 8f935e08ec71..01055778ca12 100644 --- a/pkg/x/io/io.go +++ b/pkg/x/io/io.go @@ -9,18 +9,6 @@ import ( dio "github.com/aquasecurity/go-dep-parser/pkg/io" ) -// NopCloser returns a WriteCloser with a no-op Close method wrapping -// the provided Writer w. -func NopCloser(w io.Writer) io.WriteCloser { - return nopCloser{w} -} - -type nopCloser struct { - io.Writer -} - -func (nopCloser) Close() error { return nil } - func NewReadSeekerAt(r io.Reader) (dio.ReadSeekerAt, error) { if rr, ok := r.(dio.ReadSeekerAt); ok { return rr, nil