From 5c03581849616cd602812fce546ed263fe817f0b Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Wed, 19 Jul 2023 12:35:21 +0300 Subject: [PATCH 1/5] fix(report): close the file --- pkg/cloud/aws/commands/run.go | 23 +------------- pkg/cloud/aws/commands/run_test.go | 14 +++++---- pkg/cloud/report/report.go | 45 ++++++++++++++++++++++++--- pkg/cloud/report/resource_test.go | 12 ++++--- pkg/cloud/report/result_test.go | 12 ++++--- pkg/cloud/report/service_test.go | 14 ++++++--- pkg/commands/app.go | 50 ++++++++++++------------------ pkg/commands/app_test.go | 6 ++-- pkg/flag/options.go | 5 ++- pkg/flag/report_flags.go | 16 ++-------- pkg/flag/report_flags_test.go | 14 ++------- pkg/k8s/commands/run.go | 14 +++++++-- pkg/report/github/github_test.go | 14 +++------ pkg/report/json_test.go | 12 +++---- pkg/report/sarif_test.go | 10 +++--- pkg/report/table/table_test.go | 8 ++--- pkg/report/template_test.go | 9 +++--- pkg/report/writer.go | 37 ++++++++++++++-------- 18 files changed, 162 insertions(+), 153 deletions(-) diff --git a/pkg/cloud/aws/commands/run.go b/pkg/cloud/aws/commands/run.go index 3d4317f38b68..7d9356a31bd7 100644 --- a/pkg/cloud/aws/commands/run.go +++ b/pkg/cloud/aws/commands/run.go @@ -5,10 +5,9 @@ import ( "errors" "strings" - "golang.org/x/exp/slices" - "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sts" + "golang.org/x/exp/slices" "golang.org/x/xerrors" "github.com/aquasecurity/defsec/pkg/errs" @@ -17,10 +16,8 @@ import ( "github.com/aquasecurity/trivy/pkg/cloud/aws/scanner" "github.com/aquasecurity/trivy/pkg/cloud/report" "github.com/aquasecurity/trivy/pkg/commands/operation" - cr "github.com/aquasecurity/trivy/pkg/compliance/report" "github.com/aquasecurity/trivy/pkg/flag" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/types" ) var allSupportedServicesFunc = awsScanner.AllSupportedServices @@ -166,24 +163,6 @@ func Run(ctx context.Context, opt flag.Options) error { } log.Logger.Debug("Writing report to output...") - if opt.Compliance.Spec.ID != "" { - convertedResults := report.ConvertResults(results, cloud.ProviderAWS, opt.Services) - var crr []types.Results - for _, r := range convertedResults { - crr = append(crr, r.Results) - } - - complianceReport, err := cr.BuildComplianceReport(crr, opt.Compliance) - if err != nil { - return xerrors.Errorf("compliance report build error: %w", err) - } - - return cr.Write(complianceReport, cr.Option{ - Format: opt.Format, - Report: opt.ReportFormat, - Output: opt.Output, - }) - } res := results.GetFailed() if opt.MisconfOptions.IncludeNonFailures { diff --git a/pkg/cloud/aws/commands/run_test.go b/pkg/cloud/aws/commands/run_test.go index 5c6eb1962110..ef84d3c05508 100644 --- a/pkg/cloud/aws/commands/run_test.go +++ b/pkg/cloud/aws/commands/run_test.go @@ -1,7 +1,6 @@ package commands import ( - "bytes" "context" "os" "path/filepath" @@ -1243,8 +1242,8 @@ Summary Report for compliance: my-custom-spec }() } - buffer := new(bytes.Buffer) - test.options.Output = buffer + output := filepath.Join(t.TempDir(), "output") + test.options.Output = output test.options.Debug = true test.options.GlobalOptions.Timeout = time.Minute if test.options.Format == "" { @@ -1283,10 +1282,13 @@ Summary Report for compliance: my-custom-spec err := Run(context.Background(), test.options) if test.expectErr { assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, test.want, buffer.String()) + return } + assert.NoError(t, err) + + b, err := os.ReadFile(output) + require.NoError(t, err) + assert.Equal(t, test.want, string(b)) }) } } diff --git a/pkg/cloud/report/report.go b/pkg/cloud/report/report.go index 8441944671e0..c0517d6357d1 100644 --- a/pkg/cloud/report/report.go +++ b/pkg/cloud/report/report.go @@ -2,12 +2,16 @@ package report import ( "context" + "io" "os" "sort" "time" + "golang.org/x/xerrors" + "github.com/aquasecurity/defsec/pkg/scan" "github.com/aquasecurity/tml" + cr "github.com/aquasecurity/trivy/pkg/compliance/report" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/flag" pkgReport "github.com/aquasecurity/trivy/pkg/report" @@ -55,6 +59,19 @@ func (r *Report) Failed() bool { // Write writes the results in the give format func Write(rep *Report, opt flag.Options, fromCache bool) error { + output := os.Stdout + if opt.Output != "" { + f, err := os.Create(opt.Output) + if err != nil { + return xerrors.Errorf("failed to create output file: %w", err) + } + output = f + defer f.Close() + } + + if opt.Compliance.Spec.ID != "" { + return writeCompliance(rep, opt, output) + } var filtered []types.Result @@ -91,7 +108,7 @@ func Write(rep *Report, opt flag.Options, fromCache bool) error { // ensure color/formatting is disabled for pipes/non-pty var useANSI bool - if opt.Output == os.Stdout { + if opt.Output == "" { if o, err := os.Stdout.Stat(); err == nil { useANSI = (o.Mode() & os.ModeCharDevice) == os.ModeCharDevice } @@ -102,22 +119,22 @@ func Write(rep *Report, opt flag.Options, fromCache bool) error { switch { case len(opt.Services) == 1 && opt.ARN == "": - if err := writeResourceTable(rep, filtered, opt.Output, opt.Services[0]); err != nil { + if err := writeResourceTable(rep, filtered, output, opt.Services[0]); err != nil { return err } case len(opt.Services) == 1 && opt.ARN != "": - if err := writeResultsForARN(rep, filtered, opt.Output, opt.Services[0], opt.ARN, opt.Severities); err != nil { + if err := writeResultsForARN(rep, filtered, output, opt.Services[0], opt.ARN, opt.Severities); err != nil { return err } default: - if err := writeServiceTable(rep, filtered, opt.Output); err != nil { + if err := writeServiceTable(rep, filtered, output); err != nil { return err } } // render cache info if fromCache { - _ = tml.Fprintf(opt.Output, "\nThis scan report was loaded from cached results. If you'd like to run a fresh scan, use --update-cache.\n") + _ = tml.Fprintf(output, "\nThis scan report was loaded from cached results. If you'd like to run a fresh scan, use --update-cache.\n") } return nil @@ -132,3 +149,21 @@ func Write(rep *Report, opt flag.Options, fromCache bool) error { }) } } + +func writeCompliance(rep *Report, opt flag.Options, output io.Writer) error { + var crr []types.Results + for _, r := range rep.Results { + crr = append(crr, r.Results) + } + + complianceReport, err := cr.BuildComplianceReport(crr, opt.Compliance) + if err != nil { + return xerrors.Errorf("compliance report build error: %w", err) + } + + return cr.Write(complianceReport, cr.Option{ + Format: opt.Format, + Report: opt.ReportFormat, + Output: output, + }) +} diff --git a/pkg/cloud/report/resource_test.go b/pkg/cloud/report/resource_test.go index dbe070cff93b..07ff85a88c27 100644 --- a/pkg/cloud/report/resource_test.go +++ b/pkg/cloud/report/resource_test.go @@ -1,7 +1,8 @@ package report import ( - "bytes" + "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -109,15 +110,18 @@ No problems detected. tt.options.AWSOptions.Services, ) - buffer := bytes.NewBuffer([]byte{}) - tt.options.Output = buffer + output := filepath.Join(t.TempDir(), "output") + tt.options.Output = output require.NoError(t, Write(report, tt.options, tt.fromCache)) assert.Equal(t, "AWS", report.Provider) assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID) assert.Equal(t, tt.options.AWSOptions.Region, report.Region) assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope) - assert.Equal(t, tt.expected, buffer.String()) + + b, err := os.ReadFile(output) + require.NoError(t, err) + assert.Equal(t, tt.expected, string(b)) }) } } diff --git a/pkg/cloud/report/result_test.go b/pkg/cloud/report/result_test.go index e12f63b19fd8..5b4f669b4650 100644 --- a/pkg/cloud/report/result_test.go +++ b/pkg/cloud/report/result_test.go @@ -1,7 +1,8 @@ package report import ( - "bytes" + "os" + "path/filepath" "strings" "testing" @@ -68,15 +69,18 @@ See https://avd.aquasec.com/misconfig/avd-aws-9999 tt.options.AWSOptions.Services, ) - buffer := bytes.NewBuffer([]byte{}) - tt.options.Output = buffer + output := filepath.Join(t.TempDir(), "output") + tt.options.Output = output require.NoError(t, Write(report, tt.options, tt.fromCache)) + b, err := os.ReadFile(output) + require.NoError(t, err) + assert.Equal(t, "AWS", report.Provider) assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID) assert.Equal(t, tt.options.AWSOptions.Region, report.Region) assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope) - assert.Equal(t, tt.expected, strings.ReplaceAll(buffer.String(), "\r\n", "\n")) + assert.Equal(t, tt.expected, strings.ReplaceAll(string(b), "\r\n", "\n")) }) } } diff --git a/pkg/cloud/report/service_test.go b/pkg/cloud/report/service_test.go index 3b66f5d0fcea..cf14466e4d86 100644 --- a/pkg/cloud/report/service_test.go +++ b/pkg/cloud/report/service_test.go @@ -1,7 +1,8 @@ package report import ( - "bytes" + "os" + "path/filepath" "testing" "github.com/aquasecurity/trivy-db/pkg/types" @@ -320,19 +321,22 @@ Scan Overview for AWS Account tt.options.AWSOptions.Services, ) - buffer := bytes.NewBuffer([]byte{}) - tt.options.Output = buffer + output := filepath.Join(t.TempDir(), "output") + tt.options.Output = output require.NoError(t, Write(report, tt.options, tt.fromCache)) assert.Equal(t, "AWS", report.Provider) assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID) assert.Equal(t, tt.options.AWSOptions.Region, report.Region) assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope) + + b, err := os.ReadFile(output) + require.NoError(t, err) if tt.options.Format == "json" { // json output can be formatted/ordered differently - we just care that the data matches - assert.JSONEq(t, tt.expected, buffer.String()) + assert.JSONEq(t, tt.expected, string(b)) } else { - assert.Equal(t, tt.expected, buffer.String()) + assert.Equal(t, tt.expected, string(b)) } }) } diff --git a/pkg/commands/app.go b/pkg/commands/app.go index cc9991af7d43..294a73f6f454 100644 --- a/pkg/commands/app.go +++ b/pkg/commands/app.go @@ -71,15 +71,6 @@ Use "{{.CommandPath}} [command] --help" for more information about a command.{{e groupPlugin = "plugin" ) -var ( - outputWriter io.Writer = os.Stdout -) - -// SetOut overrides the destination for messages -func SetOut(out io.Writer) { - outputWriter = out -} - // NewApp is the factory method to return Trivy CLI func NewApp(version string) *cobra.Command { globalFlags := flag.NewGlobalFlagGroup() @@ -189,8 +180,6 @@ func NewRootCommand(version string, globalFlags *flag.GlobalFlagGroup) *cobra.Co $ trivy server`, Args: cobra.NoArgs, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - cmd.SetOut(outputWriter) - // Set the Trivy version here so that we can override version printer. cmd.Version = version @@ -224,7 +213,7 @@ func NewRootCommand(version string, globalFlags *flag.GlobalFlagGroup) *cobra.Co globalOptions := globalFlags.ToOptions() if globalOptions.ShowVersion { // Customize version output - showVersion(globalOptions.CacheDir, versionFormat, version, outputWriter) + showVersion(globalOptions.CacheDir, versionFormat, version, cmd.OutOrStdout()) } else { return cmd.Help() } @@ -310,7 +299,7 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { return validateArgs(cmd, args) }, RunE: func(cmd *cobra.Command, args []string) error { - options, err := imageFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + options, err := imageFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -369,7 +358,7 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := fsFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := fsFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + options, err := fsFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -428,7 +417,7 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := rootfsFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := rootfsFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + options, err := rootfsFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -482,7 +471,7 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := repoFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := repoFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + options, err := repoFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -522,7 +511,7 @@ func NewConvertCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := convertFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - opts, err := convertFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + opts, err := convertFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -579,7 +568,7 @@ func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := clientFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := clientFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + options, err := clientFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -620,7 +609,7 @@ func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := serverFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := serverFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + options, err := serverFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -682,7 +671,7 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := configFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := configFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + options, err := configFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -840,7 +829,7 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { } repo := args[0] - opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -864,7 +853,7 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { } repo := args[0] - opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + opts, err := moduleFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -953,7 +942,7 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := k8sFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - opts, err := k8sFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + opts, err := k8sFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -1017,7 +1006,7 @@ The following services are supported: return nil }, RunE: func(cmd *cobra.Command, args []string) error { - opts, err := awsFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + opts, err := awsFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -1081,7 +1070,7 @@ func NewVMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := vmFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := vmFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + options, err := vmFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -1140,7 +1129,7 @@ func NewSBOMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { if err := sbomFlags.Bind(cmd); err != nil { return xerrors.Errorf("flag bind error: %w", err) } - options, err := sbomFlags.ToOptions(cmd.Version, args, globalFlags, outputWriter) + options, err := sbomFlags.ToOptions(cmd.Version, args, globalFlags) if err != nil { return xerrors.Errorf("flag error: %w", err) } @@ -1169,7 +1158,7 @@ func NewVersionCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { options := globalFlags.ToOptions() - showVersion(options.CacheDir, versionFormat, cmd.Version, outputWriter) + showVersion(options.CacheDir, versionFormat, cmd.Version, cmd.OutOrStdout()) return nil }, @@ -1184,7 +1173,7 @@ func NewVersionCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { return cmd } -func showVersion(cacheDir, outputFormat, version string, outputWriter io.Writer) { +func showVersion(cacheDir, outputFormat, version string, w io.Writer) { var dbMeta *metadata.Metadata var javadbMeta *metadata.Metadata @@ -1218,13 +1207,12 @@ func showVersion(cacheDir, outputFormat, version string, outputWriter io.Writer) switch outputFormat { case "json": - b, _ := json.Marshal(VersionInfo{ + _ = json.NewEncoder(w).Encode(VersionInfo{ Version: version, VulnerabilityDB: dbMeta, JavaDB: javadbMeta, PolicyBundle: pbMeta, }) - fmt.Fprintln(outputWriter, string(b)) default: output := fmt.Sprintf("Version: %s\n", version) if dbMeta != nil { @@ -1251,7 +1239,7 @@ func showVersion(cacheDir, outputFormat, version string, outputWriter io.Writer) DownloadedAt: %s `, pbMeta.Digest, pbMeta.DownloadedAt.UTC()) } - fmt.Fprintf(outputWriter, output) + fmt.Fprintf(w, output) } } diff --git a/pkg/commands/app_test.go b/pkg/commands/app_test.go index 5b4661c382e0..21e0d71f71f2 100644 --- a/pkg/commands/app_test.go +++ b/pkg/commands/app_test.go @@ -158,7 +158,7 @@ Policy Bundle: t.Run(test.name, func(t *testing.T) { got := new(bytes.Buffer) app := NewApp("test") - SetOut(got) + app.SetOut(got) app.SetArgs(test.arguments) err := app.Execute() @@ -259,7 +259,7 @@ func TestFlags(t *testing.T) { globalFlags := flag.NewGlobalFlagGroup() rootCmd := NewRootCommand("dev", globalFlags) rootCmd.SetErr(io.Discard) - SetOut(io.Discard) + rootCmd.SetOut(io.Discard) flags := &flag.Flags{ ReportFlagGroup: flag.NewReportFlagGroup(), @@ -270,7 +270,7 @@ func TestFlags(t *testing.T) { // Bind require.NoError(t, flags.Bind(cmd)) - options, err := flags.ToOptions("dev", args, globalFlags, nil) + options, err := flags.ToOptions("dev", args, globalFlags) require.NoError(t, err) assert.Equal(t, tt.want.format, options.Format) diff --git a/pkg/flag/options.go b/pkg/flag/options.go index ae6f076c3873..cfecff4c81db 100644 --- a/pkg/flag/options.go +++ b/pkg/flag/options.go @@ -2,7 +2,6 @@ package flag import ( "fmt" - "io" "os" "strings" "sync" @@ -441,7 +440,7 @@ func (f *Flags) Bind(cmd *cobra.Command) error { } // nolint: gocyclo -func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalFlagGroup, output io.Writer) (Options, error) { +func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalFlagGroup) (Options, error) { var err error opts := Options{ AppVersion: appVersion, @@ -522,7 +521,7 @@ func (f *Flags) ToOptions(appVersion string, args []string, globalFlags *GlobalF } if f.ReportFlagGroup != nil { - opts.ReportOptions, err = f.ReportFlagGroup.ToOptions(output) + opts.ReportOptions, err = f.ReportFlagGroup.ToOptions() if err != nil { return Options{}, xerrors.Errorf("report flag error: %w", err) } diff --git a/pkg/flag/report_flags.go b/pkg/flag/report_flags.go index d180b916ef2d..215fb3f18624 100644 --- a/pkg/flag/report_flags.go +++ b/pkg/flag/report_flags.go @@ -1,8 +1,6 @@ package flag import ( - "io" - "os" "strings" "github.com/samber/lo" @@ -131,7 +129,7 @@ type ReportOptions struct { ExitCode int ExitOnEOL int IgnorePolicy string - Output io.Writer + Output string Severities []dbTypes.Severity Compliance spec.ComplianceSpec } @@ -174,12 +172,11 @@ func (f *ReportFlagGroup) Flags() []*Flag { } } -func (f *ReportFlagGroup) ToOptions(out io.Writer) (ReportOptions, error) { +func (f *ReportFlagGroup) ToOptions() (ReportOptions, error) { format := getString(f.Format) template := getString(f.Template) dependencyTree := getBool(f.DependencyTree) listAllPkgs := getBool(f.ListAllPkgs) - output := getString(f.Output) if template != "" { if format == "" { @@ -214,13 +211,6 @@ func (f *ReportFlagGroup) ToOptions(out io.Writer) (ReportOptions, error) { listAllPkgs = true } - if output != "" { - var err error - if out, err = os.Create(output); err != nil { - return ReportOptions{}, xerrors.Errorf("failed to create an output file: %w", err) - } - } - cs, err := loadComplianceTypes(getString(f.Compliance)) if err != nil { return ReportOptions{}, xerrors.Errorf("unable to load compliance spec: %w", err) @@ -236,7 +226,7 @@ func (f *ReportFlagGroup) ToOptions(out io.Writer) (ReportOptions, error) { ExitCode: getInt(f.ExitCode), ExitOnEOL: getInt(f.ExitOnEOL), IgnorePolicy: getString(f.IgnorePolicy), - Output: out, + Output: getString(f.Output), Severities: toSeverity(getStringSlice(f.Severity)), Compliance: cs, }, nil diff --git a/pkg/flag/report_flags_test.go b/pkg/flag/report_flags_test.go index 2ff9743e5fae..d0b0004b3449 100644 --- a/pkg/flag/report_flags_test.go +++ b/pkg/flag/report_flags_test.go @@ -1,7 +1,6 @@ package flag_test import ( - "os" "testing" defsecTypes "github.com/aquasecurity/defsec/pkg/types" @@ -44,9 +43,7 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { { name: "happy default (without flags)", fields: fields{}, - want: flag.ReportOptions{ - Output: os.Stdout, - }, + want: flag.ReportOptions{}, }, { name: "happy path with an cyclonedx", @@ -56,7 +53,6 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { listAllPkgs: true, }, want: flag.ReportOptions{ - Output: os.Stdout, Severities: []dbTypes.Severity{dbTypes.SeverityCritical}, Format: report.FormatCycloneDX, ListAllPkgs: true, @@ -76,7 +72,6 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { `Severities: ["CRITICAL"]`, }, want: flag.ReportOptions{ - Output: os.Stdout, Severities: []dbTypes.Severity{ dbTypes.SeverityCritical, }, @@ -94,7 +89,6 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { "'--template' is ignored because '--format template' is not specified. Use '--template' option with '--format template' option.", }, want: flag.ReportOptions{ - Output: os.Stdout, Severities: []dbTypes.Severity{dbTypes.SeverityLow}, Template: "@contrib/gitlab.tpl", }, @@ -110,7 +104,6 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { "'--template' is ignored because '--format json' is specified. Use '--template' option with '--format template' option.", }, want: flag.ReportOptions{ - Output: os.Stdout, Format: "json", Severities: []dbTypes.Severity{dbTypes.SeverityLow}, Template: "@contrib/gitlab.tpl", @@ -126,7 +119,6 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { "'--format template' is ignored because '--template' is not specified. Specify '--template' option when you use '--format template'.", }, want: flag.ReportOptions{ - Output: os.Stdout, Format: "template", Severities: []dbTypes.Severity{dbTypes.SeverityLow}, }, @@ -143,7 +135,6 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { }, want: flag.ReportOptions{ Format: "table", - Output: os.Stdout, Severities: []dbTypes.Severity{dbTypes.SeverityLow}, ListAllPkgs: true, }, @@ -155,7 +146,6 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { severities: dbTypes.SeverityLow.String(), }, want: flag.ReportOptions{ - Output: os.Stdout, Compliance: spec.ComplianceSpec{ Spec: defsecTypes.Spec{ ID: "0001", @@ -216,7 +206,7 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { Compliance: &flag.ComplianceFlag, } - got, err := f.ToOptions(os.Stdout) + got, err := f.ToOptions() assert.NoError(t, err) assert.Equalf(t, tt.want, got, "ToOptions()") diff --git a/pkg/k8s/commands/run.go b/pkg/k8s/commands/run.go index 73e26207d0b0..951742d8a398 100644 --- a/pkg/k8s/commands/run.go +++ b/pkg/k8s/commands/run.go @@ -3,6 +3,7 @@ package commands import ( "context" "errors" + "os" "github.com/spf13/viper" "golang.org/x/xerrors" @@ -95,6 +96,15 @@ func (r *runner) run(ctx context.Context, artifacts []*artifacts.Artifact) error return xerrors.Errorf("k8s scan error: %w", err) } + output := os.Stdout + if r.flagOpts.Output != "" { + output, err = os.Create(r.flagOpts.Output) + if err != nil { + return xerrors.Errorf("failed to create output file: %w", err) + } + defer output.Close() + } + if r.flagOpts.Compliance.Spec.ID != "" { var scanResults []types.Results for _, rss := range rpt.Resources { @@ -107,14 +117,14 @@ func (r *runner) run(ctx context.Context, artifacts []*artifacts.Artifact) error return cr.Write(complianceReport, cr.Option{ Format: r.flagOpts.Format, Report: r.flagOpts.ReportFormat, - Output: r.flagOpts.Output, + Output: output, }) } if err := k8sRep.Write(rpt, report.Option{ Format: r.flagOpts.Format, Report: r.flagOpts.ReportFormat, - Output: r.flagOpts.Output, + Output: output, Severities: r.flagOpts.Severities, Components: r.flagOpts.Components, Scanners: r.flagOpts.ScanOptions.Scanners, diff --git a/pkg/report/github/github_test.go b/pkg/report/github/github_test.go index e8efa03eb0af..961ae3c1649e 100644 --- a/pkg/report/github/github_test.go +++ b/pkg/report/github/github_test.go @@ -9,7 +9,6 @@ import ( dbTypes "github.com/aquasecurity/trivy-db/pkg/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" - "github.com/aquasecurity/trivy/pkg/report" "github.com/aquasecurity/trivy/pkg/report/github" "github.com/aquasecurity/trivy/pkg/types" ) @@ -136,22 +135,19 @@ func TestWriter_Write(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - jw := github.Writer{} - written := bytes.Buffer{} - jw.Output = &written + written := bytes.NewBuffer(nil) + w := github.Writer{ + Output: written, + } inputResults := tt.report - err := report.Write(inputResults, report.Option{ - Format: "github", - Output: &written, - }) + err := w.Write(inputResults) assert.NoError(t, err) var got github.DependencySnapshot err = json.Unmarshal(written.Bytes(), &got) assert.NoError(t, err, "invalid github written") - assert.Equal(t, tt.want, got.Manifests, tt.name) }) } diff --git a/pkg/report/json_test.go b/pkg/report/json_test.go index 4b12c361e11b..850afd003545 100644 --- a/pkg/report/json_test.go +++ b/pkg/report/json_test.go @@ -66,9 +66,10 @@ func TestReportWriter_JSON(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - jw := report.JSONWriter{} - jsonWritten := bytes.Buffer{} - jw.Output = &jsonWritten + jsonWritten := bytes.NewBuffer(nil) + jw := report.JSONWriter{ + Output: jsonWritten, + } inputResults := types.Report{ SchemaVersion: 2, @@ -81,10 +82,7 @@ func TestReportWriter_JSON(t *testing.T) { }, } - err := report.Write(inputResults, report.Option{ - Format: "json", - Output: &jsonWritten, - }) + err := jw.Write(inputResults) assert.NoError(t, err) var got types.Report diff --git a/pkg/report/sarif_test.go b/pkg/report/sarif_test.go index 6511287718cf..595d69ba01fb 100644 --- a/pkg/report/sarif_test.go +++ b/pkg/report/sarif_test.go @@ -456,11 +456,11 @@ func TestReportWriter_Sarif(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sarifWritten := bytes.Buffer{} - err := report.Write(tt.input, report.Option{ - Format: "sarif", - Output: &sarifWritten, - }) + sarifWritten := bytes.NewBuffer(nil) + w := report.SarifWriter{ + Output: sarifWritten, + } + err := w.Write(tt.input) assert.NoError(t, err) result := &sarif.Report{} diff --git a/pkg/report/table/table_test.go b/pkg/report/table/table_test.go index 9100c3b52cf3..27cc70a392ff 100644 --- a/pkg/report/table/table_test.go +++ b/pkg/report/table/table_test.go @@ -8,7 +8,7 @@ import ( dbTypes "github.com/aquasecurity/trivy-db/pkg/types" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" - "github.com/aquasecurity/trivy/pkg/report" + "github.com/aquasecurity/trivy/pkg/report/table" "github.com/aquasecurity/trivy/pkg/types" ) @@ -339,8 +339,7 @@ package-lock.json for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { tableWritten := bytes.Buffer{} - err := report.Write(types.Report{Results: tc.results}, report.Option{ - Format: report.FormatTable, + writer := table.Writer{ Output: &tableWritten, Tree: true, IncludeNonFailures: tc.includeNonFailures, @@ -348,7 +347,8 @@ package-lock.json dbTypes.SeverityHigh, dbTypes.SeverityMedium, }, - }) + } + err := writer.Write(types.Report{Results: tc.results}) assert.NoError(t, err) assert.Equal(t, tc.expectedOutput, tableWritten.String(), tc.name) }) diff --git a/pkg/report/template_test.go b/pkg/report/template_test.go index de2ee66c9b32..ded1fa8844c9 100644 --- a/pkg/report/template_test.go +++ b/pkg/report/template_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" dbTypes "github.com/aquasecurity/trivy-db/pkg/types" "github.com/aquasecurity/trivy/pkg/clock" @@ -177,11 +178,9 @@ func TestReportWriter_Template(t *testing.T) { }, } - err := report.Write(inputReport, report.Option{ - Format: "template", - Output: &got, - OutputTemplate: tc.template, - }) + w, err := report.NewTemplateWriter(&got, tc.template) + require.NoError(t, err) + err = w.Write(inputReport) assert.NoError(t, err) assert.Equal(t, tc.expected, got.String()) }) diff --git a/pkg/report/writer.go b/pkg/report/writer.go index 3628c9889c4b..840a62e78c66 100644 --- a/pkg/report/writer.go +++ b/pkg/report/writer.go @@ -2,6 +2,7 @@ package report import ( "io" + "os" "strings" "sync" @@ -61,7 +62,7 @@ type Option struct { Format string Report string - Output io.Writer + Output string Tree bool Severities []dbTypes.Severity OutputTemplate string @@ -78,16 +79,26 @@ type Option struct { // Write writes the result to output, format as passed in argument func Write(report types.Report, option Option) error { + output := os.Stdout + if option.Output != "" { + f, err := os.Create(option.Output) + if err != nil { + return xerrors.Errorf("failed to create a file: %w", err) + } + output = f + defer f.Close() + } + // Compliance report if option.Compliance.Spec.ID != "" { - return complianceWrite(report, option) + return complianceWrite(report, option, output) } var writer Writer switch option.Format { case FormatTable: writer = &table.Writer{ - Output: option.Output, + Output: output, Severities: option.Severities, Tree: option.Tree, ShowMessageOnce: &sync.Once{}, @@ -97,38 +108,38 @@ func Write(report types.Report, option Option) error { IgnoredLicenses: option.IgnoredLicenses, } case FormatJSON: - writer = &JSONWriter{Output: option.Output} + writer = &JSONWriter{Output: output} case FormatGitHub: writer = &github.Writer{ - Output: option.Output, + Output: output, Version: option.AppVersion, } case FormatCycloneDX: // TODO: support xml format option with cyclonedx writer - writer = cyclonedx.NewWriter(option.Output, option.AppVersion) + writer = cyclonedx.NewWriter(output, option.AppVersion) case FormatSPDX, FormatSPDXJSON: - writer = spdx.NewWriter(option.Output, option.AppVersion, option.Format) + writer = spdx.NewWriter(output, option.AppVersion, option.Format) case FormatTemplate: // We keep `sarif.tpl` template working for backward compatibility for a while. if strings.HasPrefix(option.OutputTemplate, "@") && strings.HasSuffix(option.OutputTemplate, "sarif.tpl") { log.Logger.Warn("Using `--template sarif.tpl` is deprecated. Please migrate to `--format sarif`. See https://github.com/aquasecurity/trivy/discussions/1571") writer = &SarifWriter{ - Output: option.Output, + Output: output, Version: option.AppVersion, } break } var err error - if writer, err = NewTemplateWriter(option.Output, option.OutputTemplate); err != nil { + if writer, err = NewTemplateWriter(output, option.OutputTemplate); err != nil { return xerrors.Errorf("failed to initialize template writer: %w", err) } case FormatSarif: writer = &SarifWriter{ - Output: option.Output, + Output: output, Version: option.AppVersion, } case FormatCosignVuln: - writer = predicate.NewVulnWriter(option.Output, option.AppVersion) + writer = predicate.NewVulnWriter(output, option.AppVersion) default: return xerrors.Errorf("unknown format: %v", option.Format) } @@ -139,7 +150,7 @@ func Write(report types.Report, option Option) error { return nil } -func complianceWrite(report types.Report, opt Option) error { +func complianceWrite(report types.Report, opt Option, output io.Writer) error { complianceReport, err := cr.BuildComplianceReport([]types.Results{report.Results}, opt.Compliance) if err != nil { return xerrors.Errorf("compliance report build error: %w", err) @@ -147,7 +158,7 @@ func complianceWrite(report types.Report, opt Option) error { return cr.Write(complianceReport, cr.Option{ Format: opt.Format, Report: opt.Report, - Output: opt.Output, + Output: output, Severities: opt.Severities, }) } From 2f8a848043556606df2d3566ad2305172785cab0 Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Thu, 20 Jul 2023 15:59:10 +0300 Subject: [PATCH 2/5] refactor: add the format type --- pkg/cloud/report/report.go | 21 ++----- pkg/commands/app.go | 12 ++-- pkg/commands/app_test.go | 12 ++-- pkg/commands/artifact/run.go | 8 +-- pkg/commands/convert/run.go | 2 +- pkg/compliance/report/report.go | 9 +-- pkg/fanal/analyzer/fs.go | 6 +- pkg/fanal/types/image.go | 7 --- pkg/flag/image_flags.go | 9 +-- pkg/flag/options.go | 49 ++++++++--------- pkg/flag/report_flags.go | 26 ++++----- pkg/flag/report_flags_test.go | 13 ++--- pkg/flag/scan_flags.go | 9 +-- pkg/k8s/commands/cluster.go | 5 +- pkg/k8s/commands/run.go | 12 ++-- pkg/k8s/report/report.go | 2 +- pkg/k8s/scanner/scanner.go | 10 +--- pkg/k8s/writer.go | 11 ++-- pkg/mapfs/file.go | 6 +- pkg/mapfs/fs.go | 4 +- pkg/report/spdx/spdx.go | 4 +- pkg/report/writer.go | 97 +++++++-------------------------- pkg/rpc/client/client.go | 3 +- pkg/types/report.go | 50 +++++++++++++---- pkg/types/target.go | 7 --- pkg/x/io/io.go | 15 +++++ pkg/x/strings/strings.go | 19 +++++++ pkg/{syncx => x/sync}/sync.go | 2 +- 28 files changed, 199 insertions(+), 231 deletions(-) create mode 100644 pkg/x/io/io.go create mode 100644 pkg/x/strings/strings.go rename pkg/{syncx => x/sync}/sync.go (98%) diff --git a/pkg/cloud/report/report.go b/pkg/cloud/report/report.go index c0517d6357d1..300ca84dc783 100644 --- a/pkg/cloud/report/report.go +++ b/pkg/cloud/report/report.go @@ -59,15 +59,11 @@ func (r *Report) Failed() bool { // Write writes the results in the give format func Write(rep *Report, opt flag.Options, fromCache bool) error { - output := os.Stdout - if opt.Output != "" { - f, err := os.Create(opt.Output) - if err != nil { - return xerrors.Errorf("failed to create output file: %w", err) - } - output = f - defer f.Close() + output, err := opt.OutputWriter() + if err != nil { + return xerrors.Errorf("failed to create output file: %w", err) } + defer output.Close() if opt.Compliance.Spec.ID != "" { return writeCompliance(rep, opt, output) @@ -139,14 +135,7 @@ func Write(rep *Report, opt flag.Options, fromCache bool) error { return nil default: - return pkgReport.Write(base, pkgReport.Option{ - Format: opt.Format, - Output: opt.Output, - Severities: opt.Severities, - OutputTemplate: opt.Template, - IncludeNonFailures: opt.IncludeNonFailures, - Trace: opt.Trace, - }) + return pkgReport.Write(base, opt) } } diff --git a/pkg/commands/app.go b/pkg/commands/app.go index 294a73f6f454..4c0e5b995f7f 100644 --- a/pkg/commands/app.go +++ b/pkg/commands/app.go @@ -28,8 +28,8 @@ import ( "github.com/aquasecurity/trivy/pkg/module" "github.com/aquasecurity/trivy/pkg/plugin" "github.com/aquasecurity/trivy/pkg/policy" - r "github.com/aquasecurity/trivy/pkg/report" "github.com/aquasecurity/trivy/pkg/types" + xstrings "github.com/aquasecurity/trivy/pkg/x/strings" ) // VersionInfo holds the trivy DB version Info @@ -894,11 +894,11 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol' formatFlag := flag.FormatFlag - formatFlag.Values = []string{ - r.FormatTable, - r.FormatJSON, - r.FormatCycloneDX, - } + formatFlag.Values = xstrings.ToStringSlice([]types.Format{ + types.FormatTable, + types.FormatJSON, + types.FormatCycloneDX, + }) reportFlagGroup.Format = &formatFlag k8sFlags := &flag.Flags{ diff --git a/pkg/commands/app_test.go b/pkg/commands/app_test.go index 21e0d71f71f2..52cd1c513335 100644 --- a/pkg/commands/app_test.go +++ b/pkg/commands/app_test.go @@ -11,7 +11,7 @@ import ( dbTypes "github.com/aquasecurity/trivy-db/pkg/types" "github.com/aquasecurity/trivy/pkg/flag" - "github.com/aquasecurity/trivy/pkg/report" + "github.com/aquasecurity/trivy/pkg/types" ) func Test_showVersion(t *testing.T) { @@ -170,7 +170,7 @@ Policy Bundle: func TestFlags(t *testing.T) { type want struct { - format string + format types.Format severities []dbTypes.Severity } tests := []struct { @@ -185,7 +185,7 @@ func TestFlags(t *testing.T) { "test", }, want: want{ - format: report.FormatTable, + format: types.FormatTable, severities: []dbTypes.Severity{ dbTypes.SeverityUnknown, dbTypes.SeverityLow, @@ -203,7 +203,7 @@ func TestFlags(t *testing.T) { "LOW,MEDIUM", }, want: want{ - format: report.FormatTable, + format: types.FormatTable, severities: []dbTypes.Severity{ dbTypes.SeverityLow, dbTypes.SeverityMedium, @@ -220,7 +220,7 @@ func TestFlags(t *testing.T) { "HIGH", }, want: want{ - format: report.FormatTable, + format: types.FormatTable, severities: []dbTypes.Severity{ dbTypes.SeverityLow, dbTypes.SeverityHigh, @@ -237,7 +237,7 @@ func TestFlags(t *testing.T) { "CRITICAL", }, want: want{ - format: report.FormatJSON, + format: types.FormatJSON, severities: []dbTypes.Severity{ dbTypes.SeverityCritical, }, diff --git a/pkg/commands/artifact/run.go b/pkg/commands/artifact/run.go index 23d0b01dc541..1d3444232b17 100644 --- a/pkg/commands/artifact/run.go +++ b/pkg/commands/artifact/run.go @@ -282,7 +282,7 @@ func (r *runner) Filter(ctx context.Context, opts flag.Options, report types.Rep } func (r *runner) Report(opts flag.Options, report types.Report) error { - if err := pkgReport.Write(report, opts.ReportOpts()); err != nil { + if err := pkgReport.Write(report, opts); err != nil { return xerrors.Errorf("unable to write results: %w", err) } @@ -325,7 +325,7 @@ func (r *runner) initJavaDB(opts flag.Options) error { // If vulnerability scanning and SBOM generation are disabled, it doesn't need to download the Java database. if !opts.Scanners.Enabled(types.VulnerabilityScanner) && - !slices.Contains(pkgReport.SupportedSBOMFormats, opts.Format) { + !slices.Contains(types.SupportedSBOMFormats, opts.Format) { return nil } @@ -497,7 +497,7 @@ func disabledAnalyzers(opts flag.Options) []analyzer.Type { // But we don't create client if vulnerability analysis is disabled and SBOM format is not used // We need to disable jar analyzer to avoid errors // TODO disable all languages that don't contain license information for this case - if !opts.Scanners.Enabled(types.VulnerabilityScanner) && !slices.Contains(pkgReport.SupportedSBOMFormats, opts.Format) { + if !opts.Scanners.Enabled(types.VulnerabilityScanner) && !slices.Contains(types.SupportedSBOMFormats, opts.Format) { analyzers = append(analyzers, analyzer.TypeJar) } @@ -611,7 +611,7 @@ func initScannerConfig(opts flag.Options, cacheClient cache.Cache) (ScannerConfi // SPDX needs to calculate digests for package files var fileChecksum bool - if opts.Format == pkgReport.FormatSPDXJSON || opts.Format == pkgReport.FormatSPDX { + if opts.Format == types.FormatSPDXJSON || opts.Format == types.FormatSPDX { fileChecksum = true } diff --git a/pkg/commands/convert/run.go b/pkg/commands/convert/run.go index a9879e0d1765..490864a14bdd 100644 --- a/pkg/commands/convert/run.go +++ b/pkg/commands/convert/run.go @@ -37,7 +37,7 @@ func Run(ctx context.Context, opts flag.Options) (err error) { } log.Logger.Debug("Writing report to output...") - if err = report.Write(r, opts.ReportOpts()); err != nil { + if err = report.Write(r, opts); err != nil { return xerrors.Errorf("unable to write results: %w", err) } diff --git a/pkg/compliance/report/report.go b/pkg/compliance/report/report.go index 09c2090591b7..50ae2460635b 100644 --- a/pkg/compliance/report/report.go +++ b/pkg/compliance/report/report.go @@ -15,13 +15,10 @@ import ( const ( allReport = "all" summaryReport = "summary" - - tableFormat = "table" - jsonFormat = "json" ) type Option struct { - Format string + Format types.Format Report string Output io.Writer Severities []dbTypes.Severity @@ -70,10 +67,10 @@ type Writer interface { // Write writes the results in the give format func Write(report *ComplianceReport, option Option) error { switch option.Format { - case jsonFormat: + case types.FormatJSON: jwriter := JSONWriter{Output: option.Output, Report: option.Report} return jwriter.Write(report) - case tableFormat: + case types.FormatTable: if !report.empty() { complianceWriter := &TableWriter{ Output: option.Output, diff --git a/pkg/fanal/analyzer/fs.go b/pkg/fanal/analyzer/fs.go index 18dd8dc8b9af..d55a7cba7f10 100644 --- a/pkg/fanal/analyzer/fs.go +++ b/pkg/fanal/analyzer/fs.go @@ -10,14 +10,14 @@ import ( "golang.org/x/xerrors" "github.com/aquasecurity/trivy/pkg/mapfs" - "github.com/aquasecurity/trivy/pkg/syncx" + "github.com/aquasecurity/trivy/pkg/x/sync" ) // CompositeFS contains multiple filesystems for post-analyzers type CompositeFS struct { group AnalyzerGroup dir string - files *syncx.Map[Type, *mapfs.FS] + files *sync.Map[Type, *mapfs.FS] } func NewCompositeFS(group AnalyzerGroup) (*CompositeFS, error) { @@ -29,7 +29,7 @@ func NewCompositeFS(group AnalyzerGroup) (*CompositeFS, error) { return &CompositeFS{ group: group, dir: tmpDir, - files: new(syncx.Map[Type, *mapfs.FS]), + files: new(sync.Map[Type, *mapfs.FS]), }, nil } diff --git a/pkg/fanal/types/image.go b/pkg/fanal/types/image.go index ab76cfc086fc..8a69f8786b84 100644 --- a/pkg/fanal/types/image.go +++ b/pkg/fanal/types/image.go @@ -2,7 +2,6 @@ package types import ( v1 "github.com/google/go-containerregistry/pkg/v1" - "github.com/samber/lo" ) const ( @@ -106,9 +105,3 @@ type Credential struct { Username string Password string } - -func (runtimes ImageSources) StringSlice() []string { - return lo.Map(runtimes, func(r ImageSource, _ int) string { - return string(r) - }) -} diff --git a/pkg/flag/image_flags.go b/pkg/flag/image_flags.go index 35fb8e00137c..b94162aeb9a7 100644 --- a/pkg/flag/image_flags.go +++ b/pkg/flag/image_flags.go @@ -6,6 +6,7 @@ import ( ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/types" + xstrings "github.com/aquasecurity/trivy/pkg/x/strings" ) // e.g. config yaml @@ -18,10 +19,10 @@ var ( Name: "image-config-scanners", ConfigName: "image.image-config-scanners", Default: []string{}, - Values: types.Scanners{ + Values: xstrings.ToStringSlice(types.Scanners{ types.MisconfigScanner, types.SecretScanner, - }.StringSlice(), + }), Usage: "comma-separated list of what security issues to detect on container image configurations", } ScanRemovedPkgsFlag = Flag{ @@ -51,8 +52,8 @@ var ( SourceFlag = Flag{ Name: "image-src", ConfigName: "image.source", - Default: ftypes.AllImageSources.StringSlice(), - Values: ftypes.AllImageSources.StringSlice(), + Default: xstrings.ToStringSlice(ftypes.AllImageSources), + Values: xstrings.ToStringSlice(ftypes.AllImageSources), Usage: "image source(s) to use, in priority order", } ) diff --git a/pkg/flag/options.go b/pkg/flag/options.go index cfecff4c81db..94994295d7a9 100644 --- a/pkg/flag/options.go +++ b/pkg/flag/options.go @@ -2,12 +2,12 @@ package flag import ( "fmt" + "io" "os" "strings" "sync" "time" - "github.com/samber/lo" "github.com/spf13/cast" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -17,14 +17,12 @@ import ( "github.com/aquasecurity/trivy/pkg/fanal/analyzer" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/report" "github.com/aquasecurity/trivy/pkg/result" + "github.com/aquasecurity/trivy/pkg/types" + xio "github.com/aquasecurity/trivy/pkg/x/io" + xstrings "github.com/aquasecurity/trivy/pkg/x/strings" ) -type String interface { - ~string -} - type Flag struct { // Name is for CLI flag and environment variable. // If this field is empty, it will be available only in config file. @@ -119,18 +117,18 @@ type Options struct { // Align takes consistency of options func (o *Options) Align() { - if o.Format == report.FormatSPDX || o.Format == report.FormatSPDXJSON { + if o.Format == types.FormatSPDX || o.Format == types.FormatSPDXJSON { log.Logger.Info(`"--format spdx" and "--format spdx-json" disable security scanning`) o.Scanners = nil } // Vulnerability scanning is disabled by default for CycloneDX. - if o.Format == report.FormatCycloneDX && !viper.IsSet(ScannersFlag.ConfigName) && len(o.K8sOptions.Components) == 0 { // remove K8sOptions.Components validation check when vuln scan is supported for k8s report with cycloneDX + if o.Format == types.FormatCycloneDX && !viper.IsSet(ScannersFlag.ConfigName) && len(o.K8sOptions.Components) == 0 { // remove K8sOptions.Components validation check when vuln scan is supported for k8s report with cycloneDX log.Logger.Info(`"--format cyclonedx" disables security scanning. Specify "--scanners vuln" explicitly if you want to include vulnerabilities in the CycloneDX report.`) o.Scanners = nil } - if o.Format == report.FormatCycloneDX && len(o.K8sOptions.Components) > 0 { + if o.Format == types.FormatCycloneDX && len(o.K8sOptions.Components) > 0 { log.Logger.Info(`"k8s with --format cyclonedx" disable security scanning`) o.Scanners = nil } @@ -160,19 +158,17 @@ func (o *Options) FilterOpts() result.FilterOption { } } -func (o *Options) ReportOpts() report.Option { - return report.Option{ - AppVersion: o.AppVersion, - Format: o.Format, - Output: o.Output, - Tree: o.DependencyTree, - Severities: o.Severities, - OutputTemplate: o.Template, - IncludeNonFailures: o.IncludeNonFailures, - Trace: o.Trace, - Report: o.ReportFormat, - Compliance: o.Compliance, +// OutputWriter returns an output writer. +// If the output file is not specified, it returns os.Stdout. +func (o *Options) OutputWriter() (io.WriteCloser, error) { + if o.Output != "" { + f, err := os.Create(o.Output) + if err != nil { + return nil, xerrors.Errorf("failed to create output file: %w", err) + } + return f, nil } + return xio.NopCloser(os.Stdout), nil } func addFlag(cmd *cobra.Command, flag *Flag) { @@ -267,6 +263,11 @@ func getString(flag *Flag) string { return cast.ToString(getValue(flag)) } +func getUnderlyingString[T xstrings.String](flag *Flag) T { + s := getString(flag) + return T(s) +} + func getStringSlice(flag *Flag) []string { // viper always returns a string for ENV // https://github.com/spf13/viper/blob/419fd86e49ef061d0d33f4d1d56d5e2a480df5bb/viper.go#L545-L553 @@ -282,14 +283,12 @@ func getStringSlice(flag *Flag) []string { return v } -func getUnderlyingStringSlice[T String](flag *Flag) []T { +func getUnderlyingStringSlice[T xstrings.String](flag *Flag) []T { ss := getStringSlice(flag) if len(ss) == 0 { return nil } - return lo.Map(ss, func(s string, _ int) T { - return T(s) - }) + return xstrings.ToTSlice[T](ss) } func getInt(flag *Flag) int { diff --git a/pkg/flag/report_flags.go b/pkg/flag/report_flags.go index 215fb3f18624..50304f5d7883 100644 --- a/pkg/flag/report_flags.go +++ b/pkg/flag/report_flags.go @@ -10,9 +10,9 @@ import ( dbTypes "github.com/aquasecurity/trivy-db/pkg/types" "github.com/aquasecurity/trivy/pkg/compliance/spec" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/report" "github.com/aquasecurity/trivy/pkg/result" "github.com/aquasecurity/trivy/pkg/types" + xstrings "github.com/aquasecurity/trivy/pkg/x/strings" ) // e.g. config yaml: @@ -25,8 +25,8 @@ var ( Name: "format", ConfigName: "format", Shorthand: "f", - Default: report.FormatTable, - Values: report.SupportedFormats, + Default: string(types.FormatTable), + Values: xstrings.ToStringSlice(types.SupportedFormats), Usage: "format", } ReportFormatFlag = Flag{ @@ -120,7 +120,7 @@ type ReportFlagGroup struct { } type ReportOptions struct { - Format string + Format types.Format ReportFormat string Template string DependencyTree bool @@ -173,7 +173,7 @@ func (f *ReportFlagGroup) Flags() []*Flag { } func (f *ReportFlagGroup) ToOptions() (ReportOptions, error) { - format := getString(f.Format) + format := getUnderlyingString[types.Format](f.Format) template := getString(f.Template) dependencyTree := getBool(f.DependencyTree) listAllPkgs := getBool(f.ListAllPkgs) @@ -185,14 +185,14 @@ func (f *ReportFlagGroup) ToOptions() (ReportOptions, error) { log.Logger.Warnf("'--template' is ignored because '--format %s' is specified. Use '--template' option with '--format template' option.", format) } } else { - if format == report.FormatTemplate { + if format == types.FormatTemplate { log.Logger.Warn("'--format template' is ignored because '--template' is not specified. Specify '--template' option when you use '--format template'.") } } // "--list-all-pkgs" option is unavailable with "--format table". // If user specifies "--list-all-pkgs" with "--format table", we should warn it. - if listAllPkgs && format == report.FormatTable { + if listAllPkgs && format == types.FormatTable { log.Logger.Warn(`"--list-all-pkgs" cannot be used with "--format table". Try "--format json" or other formats.`) } @@ -201,7 +201,7 @@ func (f *ReportFlagGroup) ToOptions() (ReportOptions, error) { log.Logger.Infof(`"--dependency-tree" only shows the dependents of vulnerable packages. ` + `Note that it is the reverse of the usual dependency tree, which shows the packages that depend on the vulnerable package. ` + `It supports limited package managers. Please see the document for the detail.`) - if format != report.FormatTable { + if format != types.FormatTable { log.Logger.Warn(`"--dependency-tree" can be used only with "--format table".`) } } @@ -233,7 +233,7 @@ func (f *ReportFlagGroup) ToOptions() (ReportOptions, error) { } func loadComplianceTypes(compliance string) (spec.ComplianceSpec, error) { - if len(compliance) > 0 && !slices.Contains(types.Compliances, compliance) && !strings.HasPrefix(compliance, "@") { + if len(compliance) > 0 && !slices.Contains(types.SupportedCompliances, compliance) && !strings.HasPrefix(compliance, "@") { return spec.ComplianceSpec{}, xerrors.Errorf("unknown compliance : %v", compliance) } @@ -245,13 +245,13 @@ func loadComplianceTypes(compliance string) (spec.ComplianceSpec, error) { return cs, nil } -func (f *ReportFlagGroup) forceListAllPkgs(format string, listAllPkgs, dependencyTree bool) bool { - if slices.Contains(report.SupportedSBOMFormats, format) && !listAllPkgs { - log.Logger.Debugf("%q automatically enables '--list-all-pkgs'.", report.SupportedSBOMFormats) +func (f *ReportFlagGroup) forceListAllPkgs(format types.Format, listAllPkgs, dependencyTree bool) bool { + if slices.Contains(types.SupportedSBOMFormats, format) && !listAllPkgs { + log.Logger.Debugf("%q automatically enables '--list-all-pkgs'.", types.SupportedSBOMFormats) return true } // We need this flag to insert dependency locations into Sarif('Package' struct contains 'Locations') - if format == report.FormatSarif && !listAllPkgs { + if format == types.FormatSarif && !listAllPkgs { log.Logger.Debugf("Sarif format automatically enables '--list-all-pkgs' to get locations") return true } diff --git a/pkg/flag/report_flags_test.go b/pkg/flag/report_flags_test.go index d0b0004b3449..9155addbbe64 100644 --- a/pkg/flag/report_flags_test.go +++ b/pkg/flag/report_flags_test.go @@ -3,23 +3,22 @@ package flag_test import ( "testing" - defsecTypes "github.com/aquasecurity/defsec/pkg/types" - "github.com/spf13/viper" "github.com/stretchr/testify/assert" "go.uber.org/zap" "go.uber.org/zap/zaptest/observer" + defsecTypes "github.com/aquasecurity/defsec/pkg/types" dbTypes "github.com/aquasecurity/trivy-db/pkg/types" "github.com/aquasecurity/trivy/pkg/compliance/spec" "github.com/aquasecurity/trivy/pkg/flag" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/report" + "github.com/aquasecurity/trivy/pkg/types" ) func TestReportFlagGroup_ToOptions(t *testing.T) { type fields struct { - format string + format types.Format template string dependencyTree bool listAllPkgs bool @@ -54,7 +53,7 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { }, want: flag.ReportOptions{ Severities: []dbTypes.Severity{dbTypes.SeverityCritical}, - Format: report.FormatCycloneDX, + Format: types.FormatCycloneDX, ListAllPkgs: true, }, }, @@ -75,7 +74,7 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { Severities: []dbTypes.Severity{ dbTypes.SeverityCritical, }, - Format: report.FormatCycloneDX, + Format: types.FormatCycloneDX, ListAllPkgs: true, }, }, @@ -178,7 +177,7 @@ func TestReportFlagGroup_ToOptions(t *testing.T) { core, obs := observer.New(level) log.Logger = zap.New(core).Sugar() - viper.Set(flag.FormatFlag.ConfigName, tt.fields.format) + viper.Set(flag.FormatFlag.ConfigName, string(tt.fields.format)) viper.Set(flag.TemplateFlag.ConfigName, tt.fields.template) viper.Set(flag.DependencyTreeFlag.ConfigName, tt.fields.dependencyTree) viper.Set(flag.ListAllPkgsFlag.ConfigName, tt.fields.listAllPkgs) diff --git a/pkg/flag/scan_flags.go b/pkg/flag/scan_flags.go index 0dd8fb1ca34e..1951daeedcc5 100644 --- a/pkg/flag/scan_flags.go +++ b/pkg/flag/scan_flags.go @@ -2,6 +2,7 @@ package flag import ( "github.com/aquasecurity/trivy/pkg/types" + xstrings "github.com/aquasecurity/trivy/pkg/x/strings" ) var ( @@ -26,16 +27,16 @@ var ( ScannersFlag = Flag{ Name: "scanners", ConfigName: "scan.scanners", - Default: types.Scanners{ + Default: xstrings.ToStringSlice(types.Scanners{ types.VulnerabilityScanner, types.SecretScanner, - }.StringSlice(), - Values: types.Scanners{ + }), + Values: xstrings.ToStringSlice(types.Scanners{ types.VulnerabilityScanner, types.MisconfigScanner, types.SecretScanner, types.LicenseScanner, - }.StringSlice(), + }), Aliases: []Alias{ { Name: "security-checks", diff --git a/pkg/k8s/commands/cluster.go b/pkg/k8s/commands/cluster.go index 6715bbcdc8e5..632b4e430867 100644 --- a/pkg/k8s/commands/cluster.go +++ b/pkg/k8s/commands/cluster.go @@ -11,7 +11,6 @@ import ( "github.com/aquasecurity/trivy-kubernetes/pkg/trivyk8s" "github.com/aquasecurity/trivy/pkg/flag" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/report" "github.com/aquasecurity/trivy/pkg/types" ) @@ -23,12 +22,12 @@ func clusterRun(ctx context.Context, opts flag.Options, cluster k8s.Cluster) err var artifacts []*artifacts.Artifact var err error switch opts.Format { - case report.FormatCycloneDX: + case types.FormatCycloneDX: artifacts, err = trivyk8s.New(cluster, log.Logger).ListBomInfo(ctx) if err != nil { return xerrors.Errorf("get k8s artifacts with node info error: %w", err) } - case report.FormatJSON, report.FormatTable: + case types.FormatJSON, types.FormatTable: if opts.Scanners.AnyEnabled(types.MisconfigScanner) && slices.Contains(opts.Components, "infra") { artifacts, err = trivyk8s.New(cluster, log.Logger).ListArtifactAndNodeInfo(ctx, opts.NodeCollectorNamespace, opts.ExcludeNodes, opts.Tolerations...) if err != nil { diff --git a/pkg/k8s/commands/run.go b/pkg/k8s/commands/run.go index 951742d8a398..8efdd3244cbc 100644 --- a/pkg/k8s/commands/run.go +++ b/pkg/k8s/commands/run.go @@ -3,7 +3,6 @@ package commands import ( "context" "errors" - "os" "github.com/spf13/viper" "golang.org/x/xerrors" @@ -96,14 +95,11 @@ func (r *runner) run(ctx context.Context, artifacts []*artifacts.Artifact) error return xerrors.Errorf("k8s scan error: %w", err) } - output := os.Stdout - if r.flagOpts.Output != "" { - output, err = os.Create(r.flagOpts.Output) - if err != nil { - return xerrors.Errorf("failed to create output file: %w", err) - } - defer output.Close() + output, err := r.flagOpts.OutputWriter() + if err != nil { + return xerrors.Errorf("failed to create output file: %w", err) } + defer output.Close() if r.flagOpts.Compliance.Spec.ID != "" { var scanResults []types.Results diff --git a/pkg/k8s/report/report.go b/pkg/k8s/report/report.go index 4dd36cb97a0a..0db40e0d33dc 100644 --- a/pkg/k8s/report/report.go +++ b/pkg/k8s/report/report.go @@ -25,7 +25,7 @@ const ( ) type Option struct { - Format string + Format types.Format Report string Output io.Writer Severities []dbTypes.Severity diff --git a/pkg/k8s/scanner/scanner.go b/pkg/k8s/scanner/scanner.go index 759377920ad5..2f319b992411 100644 --- a/pkg/k8s/scanner/scanner.go +++ b/pkg/k8s/scanner/scanner.go @@ -7,16 +7,13 @@ import ( "sort" "strings" - "golang.org/x/xerrors" - + cdx "github.com/CycloneDX/cyclonedx-go" ms "github.com/mitchellh/mapstructure" "github.com/package-url/packageurl-go" "github.com/samber/lo" + "golang.org/x/xerrors" "github.com/aquasecurity/go-version/pkg/version" - - cdx "github.com/CycloneDX/cyclonedx-go" - "github.com/aquasecurity/trivy-kubernetes/pkg/artifacts" "github.com/aquasecurity/trivy-kubernetes/pkg/bom" cmd "github.com/aquasecurity/trivy/pkg/commands/artifact" @@ -27,7 +24,6 @@ import ( "github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/parallel" "github.com/aquasecurity/trivy/pkg/purl" - rep "github.com/aquasecurity/trivy/pkg/report" cyc "github.com/aquasecurity/trivy/pkg/sbom/cyclonedx" "github.com/aquasecurity/trivy/pkg/sbom/cyclonedx/core" "github.com/aquasecurity/trivy/pkg/scanner/local" @@ -74,7 +70,7 @@ func (s *Scanner) Scan(ctx context.Context, artifactsData []*artifacts.Artifact) } }() - if s.opts.Format == rep.FormatCycloneDX { + if s.opts.Format == types.FormatCycloneDX { rootComponent, err := clusterInfoToReportResources(artifactsData, s.cluster) if err != nil { return report.Report{}, err diff --git a/pkg/k8s/writer.go b/pkg/k8s/writer.go index f374c6b92520..4decab422699 100644 --- a/pkg/k8s/writer.go +++ b/pkg/k8s/writer.go @@ -3,12 +3,11 @@ package k8s import ( "fmt" - "github.com/aquasecurity/trivy/pkg/k8s/report" - cdx "github.com/CycloneDX/cyclonedx-go" - rp "github.com/aquasecurity/trivy/pkg/report" + "github.com/aquasecurity/trivy/pkg/k8s/report" "github.com/aquasecurity/trivy/pkg/report/table" + "github.com/aquasecurity/trivy/pkg/types" ) type Writer interface { @@ -20,13 +19,13 @@ func Write(k8sreport report.Report, option report.Option) error { k8sreport.PrintErrors() switch option.Format { - case rp.FormatJSON: + case types.FormatJSON: jwriter := report.JSONWriter{ Output: option.Output, Report: option.Report, } return jwriter.Write(k8sreport) - case rp.FormatTable: + case types.FormatTable: separatedReports := report.SeparateMisconfigReports(k8sreport, option.Scanners, option.Components) if option.Report == report.SummaryReport { @@ -48,7 +47,7 @@ func Write(k8sreport report.Report, option report.Option) error { } return nil - case rp.FormatCycloneDX: + case types.FormatCycloneDX: w := report.NewCycloneDXWriter(option.Output, cdx.BOMFileFormatJSON, option.APIVersion) return w.Write(k8sreport.RootComponent) } diff --git a/pkg/mapfs/file.go b/pkg/mapfs/file.go index 7dd990a5c881..cc66ce1ed163 100644 --- a/pkg/mapfs/file.go +++ b/pkg/mapfs/file.go @@ -11,7 +11,7 @@ import ( "golang.org/x/xerrors" - "github.com/aquasecurity/trivy/pkg/syncx" + xsync "github.com/aquasecurity/trivy/pkg/x/sync" ) var separator = "/" @@ -24,7 +24,7 @@ type file struct { underlyingPath string // underlying file path data []byte // virtual file, only either of 'path' or 'data' has a value. stat fileStat - files syncx.Map[string, *file] + files xsync.Map[string, *file] } func (f *file) isVirtual() bool { @@ -187,7 +187,7 @@ func (f *file) MkdirAll(path string, perm fs.FileMode) error { modTime: time.Now(), mode: perm, }, - files: syncx.Map[string, *file]{}, + files: xsync.Map[string, *file]{}, } // Create the directory when the key is not present diff --git a/pkg/mapfs/fs.go b/pkg/mapfs/fs.go index 3bf59f47c2fc..471730cc533e 100644 --- a/pkg/mapfs/fs.go +++ b/pkg/mapfs/fs.go @@ -12,7 +12,7 @@ import ( "golang.org/x/exp/slices" "golang.org/x/xerrors" - "github.com/aquasecurity/trivy/pkg/syncx" + xsync "github.com/aquasecurity/trivy/pkg/x/sync" ) type allFS interface { @@ -56,7 +56,7 @@ func New(opts ...Option) *FS { modTime: time.Now(), mode: 0o0700 | fs.ModeDir, }, - files: syncx.Map[string, *file]{}, + files: xsync.Map[string, *file]{}, }, } for _, opt := range opts { diff --git a/pkg/report/spdx/spdx.go b/pkg/report/spdx/spdx.go index aa253b4a9e88..848e26172dbe 100644 --- a/pkg/report/spdx/spdx.go +++ b/pkg/report/spdx/spdx.go @@ -15,11 +15,11 @@ import ( type Writer struct { output io.Writer version string - format string + format types.Format marshaler *spdx.Marshaler } -func NewWriter(output io.Writer, version string, spdxFormat string) Writer { +func NewWriter(output io.Writer, version string, spdxFormat types.Format) Writer { return Writer{ output: output, version: version, diff --git a/pkg/report/writer.go b/pkg/report/writer.go index 840a62e78c66..f31d9324f277 100644 --- a/pkg/report/writer.go +++ b/pkg/report/writer.go @@ -2,15 +2,13 @@ package report import ( "io" - "os" "strings" "sync" "golang.org/x/xerrors" - dbTypes "github.com/aquasecurity/trivy-db/pkg/types" cr "github.com/aquasecurity/trivy/pkg/compliance/report" - "github.com/aquasecurity/trivy/pkg/compliance/spec" + "github.com/aquasecurity/trivy/pkg/flag" "github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/report/cyclonedx" "github.com/aquasecurity/trivy/pkg/report/github" @@ -22,72 +20,15 @@ import ( const ( SchemaVersion = 2 - - FormatTable = "table" - FormatJSON = "json" - FormatTemplate = "template" - FormatSarif = "sarif" - FormatCycloneDX = "cyclonedx" - FormatSPDX = "spdx" - FormatSPDXJSON = "spdx-json" - FormatGitHub = "github" - FormatCosignVuln = "cosign-vuln" -) - -var ( - SupportedFormats = []string{ - FormatTable, - FormatJSON, - FormatTemplate, - FormatSarif, - FormatCycloneDX, - FormatSPDX, - FormatSPDXJSON, - FormatGitHub, - FormatCosignVuln, - } -) - -var ( - SupportedSBOMFormats = []string{ - FormatCycloneDX, - FormatSPDX, - FormatSPDXJSON, - FormatGitHub, - } ) -type Option struct { - AppVersion string - - Format string - Report string - Output string - Tree bool - Severities []dbTypes.Severity - OutputTemplate string - Compliance spec.ComplianceSpec - - // For misconfigurations - IncludeNonFailures bool - Trace bool - - // For licenses - LicenseRiskThreshold int - IgnoredLicenses []string -} - // Write writes the result to output, format as passed in argument -func Write(report types.Report, option Option) error { - output := os.Stdout - if option.Output != "" { - f, err := os.Create(option.Output) - if err != nil { - return xerrors.Errorf("failed to create a file: %w", err) - } - output = f - defer f.Close() +func Write(report types.Report, option flag.Options) error { + output, err := option.OutputWriter() + if err != nil { + return xerrors.Errorf("failed to create a file: %w", err) } + defer output.Close() // Compliance report if option.Compliance.Spec.ID != "" { @@ -96,32 +37,32 @@ func Write(report types.Report, option Option) error { var writer Writer switch option.Format { - case FormatTable: + case types.FormatTable: writer = &table.Writer{ Output: output, Severities: option.Severities, - Tree: option.Tree, + Tree: option.DependencyTree, ShowMessageOnce: &sync.Once{}, IncludeNonFailures: option.IncludeNonFailures, Trace: option.Trace, LicenseRiskThreshold: option.LicenseRiskThreshold, IgnoredLicenses: option.IgnoredLicenses, } - case FormatJSON: + case types.FormatJSON: writer = &JSONWriter{Output: output} - case FormatGitHub: + case types.FormatGitHub: writer = &github.Writer{ Output: output, Version: option.AppVersion, } - case FormatCycloneDX: + case types.FormatCycloneDX: // TODO: support xml format option with cyclonedx writer writer = cyclonedx.NewWriter(output, option.AppVersion) - case FormatSPDX, FormatSPDXJSON: + case types.FormatSPDX, types.FormatSPDXJSON: writer = spdx.NewWriter(output, option.AppVersion, option.Format) - case FormatTemplate: + case types.FormatTemplate: // We keep `sarif.tpl` template working for backward compatibility for a while. - if strings.HasPrefix(option.OutputTemplate, "@") && strings.HasSuffix(option.OutputTemplate, "sarif.tpl") { + if strings.HasPrefix(option.Template, "@") && strings.HasSuffix(option.Template, "sarif.tpl") { log.Logger.Warn("Using `--template sarif.tpl` is deprecated. Please migrate to `--format sarif`. See https://github.com/aquasecurity/trivy/discussions/1571") writer = &SarifWriter{ Output: output, @@ -130,15 +71,15 @@ func Write(report types.Report, option Option) error { break } var err error - if writer, err = NewTemplateWriter(output, option.OutputTemplate); err != nil { + if writer, err = NewTemplateWriter(output, option.Template); err != nil { return xerrors.Errorf("failed to initialize template writer: %w", err) } - case FormatSarif: + case types.FormatSarif: writer = &SarifWriter{ Output: output, Version: option.AppVersion, } - case FormatCosignVuln: + case types.FormatCosignVuln: writer = predicate.NewVulnWriter(output, option.AppVersion) default: return xerrors.Errorf("unknown format: %v", option.Format) @@ -150,14 +91,14 @@ func Write(report types.Report, option Option) error { return nil } -func complianceWrite(report types.Report, opt Option, output io.Writer) error { +func complianceWrite(report types.Report, opt flag.Options, output io.Writer) error { complianceReport, err := cr.BuildComplianceReport([]types.Results{report.Results}, opt.Compliance) if err != nil { return xerrors.Errorf("compliance report build error: %w", err) } return cr.Write(complianceReport, cr.Option{ Format: opt.Format, - Report: opt.Report, + Report: opt.ReportFormat, Output: output, Severities: opt.Severities, }) diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index 8d9dc2314c8f..e2a967de9f26 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -10,6 +10,7 @@ import ( ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" r "github.com/aquasecurity/trivy/pkg/rpc" "github.com/aquasecurity/trivy/pkg/types" + xstrings "github.com/aquasecurity/trivy/pkg/x/strings" rpc "github.com/aquasecurity/trivy/rpc/scanner" ) @@ -82,7 +83,7 @@ func (s Scanner) Scan(ctx context.Context, target, artifactKey string, blobKeys BlobIds: blobKeys, Options: &rpc.ScanOptions{ VulnType: opts.VulnType, - Scanners: opts.Scanners.StringSlice(), + Scanners: xstrings.ToStringSlice(opts.Scanners), ListAllPackages: opts.ListAllPackages, LicenseCategories: licenseCategories, IncludeDevDeps: opts.IncludeDevDeps, diff --git a/pkg/types/report.go b/pkg/types/report.go index c0bf89844d8d..ad3e32925527 100644 --- a/pkg/types/report.go +++ b/pkg/types/report.go @@ -8,16 +8,6 @@ import ( ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" ) -var Compliances = []string{ - ComplianceK8sNsa, - ComplianceK8sCIS, - ComplianceK8sPSSBaseline, - ComplianceK8sPSSRestricted, - ComplianceAWSCIS12, - ComplianceAWSCIS14, - ComplianceDockerCIS, -} - // Report represents a scan result type Report struct { SchemaVersion int `json:",omitempty"` @@ -48,6 +38,7 @@ type Results []Result type ResultClass string type Compliance = string +type Format string const ( ClassOSPkg = "os-pkgs" // For detected packages and vulnerabilities in OS packages @@ -65,6 +56,45 @@ const ( ComplianceAWSCIS12 = Compliance("aws-cis-1.2") ComplianceAWSCIS14 = Compliance("aws-cis-1.4") ComplianceDockerCIS = Compliance("docker-cis") + + FormatTable Format = "table" + FormatJSON Format = "json" + FormatTemplate Format = "template" + FormatSarif Format = "sarif" + FormatCycloneDX Format = "cyclonedx" + FormatSPDX Format = "spdx" + FormatSPDXJSON Format = "spdx-json" + FormatGitHub Format = "github" + FormatCosignVuln Format = "cosign-vuln" +) + +var ( + SupportedFormats = []Format{ + FormatTable, + FormatJSON, + FormatTemplate, + FormatSarif, + FormatCycloneDX, + FormatSPDX, + FormatSPDXJSON, + FormatGitHub, + FormatCosignVuln, + } + SupportedSBOMFormats = []Format{ + FormatCycloneDX, + FormatSPDX, + FormatSPDXJSON, + FormatGitHub, + } + SupportedCompliances = []string{ + ComplianceK8sNsa, + ComplianceK8sCIS, + ComplianceK8sPSSBaseline, + ComplianceK8sPSSRestricted, + ComplianceAWSCIS12, + ComplianceAWSCIS14, + ComplianceDockerCIS, + } ) // Result holds a target and detected vulnerabilities diff --git a/pkg/types/target.go b/pkg/types/target.go index f0f91eff96e1..f302505bc1b6 100644 --- a/pkg/types/target.go +++ b/pkg/types/target.go @@ -1,7 +1,6 @@ package types import ( - "github.com/samber/lo" "golang.org/x/exp/slices" ) @@ -84,9 +83,3 @@ func (scanners Scanners) AnyEnabled(ss ...Scanner) bool { } return false } - -func (scanners Scanners) StringSlice() []string { - return lo.Map(scanners, func(s Scanner, _ int) string { - return string(s) - }) -} diff --git a/pkg/x/io/io.go b/pkg/x/io/io.go new file mode 100644 index 000000000000..bd3b9da8b4f4 --- /dev/null +++ b/pkg/x/io/io.go @@ -0,0 +1,15 @@ +package io + +import "io" + +// NopCloser returns a WriteCloser with a no-op Close method wrapping +// the provided Writer w. +func NopCloser(w io.Writer) io.WriteCloser { + return nopCloser{w} +} + +type nopCloser struct { + io.Writer +} + +func (nopCloser) Close() error { return nil } diff --git a/pkg/x/strings/strings.go b/pkg/x/strings/strings.go new file mode 100644 index 000000000000..ce534b18d68a --- /dev/null +++ b/pkg/x/strings/strings.go @@ -0,0 +1,19 @@ +package strings + +import "github.com/samber/lo" + +type String interface { + ~string +} + +func ToStringSlice[T String](ss []T) []string { + return lo.Map(ss, func(s T, _ int) string { + return string(s) + }) +} + +func ToTSlice[T String](ss []string) []T { + return lo.Map(ss, func(s string, _ int) T { + return T(s) + }) +} diff --git a/pkg/syncx/sync.go b/pkg/x/sync/sync.go similarity index 98% rename from pkg/syncx/sync.go rename to pkg/x/sync/sync.go index 10f80117a94d..9841779efacf 100644 --- a/pkg/syncx/sync.go +++ b/pkg/x/sync/sync.go @@ -1,4 +1,4 @@ -package syncx +package sync import "sync" From 1417f18582cb3a16edeccc42ccaef4abcfa09a58 Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Sun, 23 Jul 2023 15:13:44 +0300 Subject: [PATCH 3/5] fix: return errors in version printing --- pkg/commands/app.go | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/pkg/commands/app.go b/pkg/commands/app.go index 4c0e5b995f7f..6aefa8beb960 100644 --- a/pkg/commands/app.go +++ b/pkg/commands/app.go @@ -213,7 +213,7 @@ func NewRootCommand(version string, globalFlags *flag.GlobalFlagGroup) *cobra.Co globalOptions := globalFlags.ToOptions() if globalOptions.ShowVersion { // Customize version output - showVersion(globalOptions.CacheDir, versionFormat, version, cmd.OutOrStdout()) + return showVersion(globalOptions.CacheDir, versionFormat, version, cmd.OutOrStdout()) } else { return cmd.Help() } @@ -962,7 +962,10 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { func NewAWSCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { reportFlagGroup := flag.NewReportFlagGroup() compliance := flag.ComplianceFlag - compliance.Values = []string{types.ComplianceAWSCIS12, types.ComplianceAWSCIS14} + compliance.Values = []string{ + types.ComplianceAWSCIS12, + types.ComplianceAWSCIS14, + } reportFlagGroup.Compliance = &compliance // override usage as the accepted values differ for each subcommand. reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol' @@ -1173,12 +1176,15 @@ func NewVersionCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { return cmd } -func showVersion(cacheDir, outputFormat, version string, w io.Writer) { +func showVersion(cacheDir, outputFormat, version string, w io.Writer) error { var dbMeta *metadata.Metadata var javadbMeta *metadata.Metadata mc := metadata.NewClient(cacheDir) - meta, _ := mc.Get() // nolint: errcheck + meta, err := mc.Get() + if err != nil { + return xerrors.Errorf("failed to get db metadata: %w", err) + } if !meta.UpdatedAt.IsZero() && !meta.NextUpdate.IsZero() && meta.Version != 0 { dbMeta = &metadata.Metadata{ Version: meta.Version, @@ -1189,7 +1195,10 @@ func showVersion(cacheDir, outputFormat, version string, w io.Writer) { } mcJava := javadb.NewMetadata(filepath.Join(cacheDir, "java-db")) - metaJava, _ := mcJava.Get() // nolint: errcheck + metaJava, err := mcJava.Get() + if err != nil { + return xerrors.Errorf("failed to get java db metadata: %w", err) + } if !metaJava.UpdatedAt.IsZero() && !metaJava.NextUpdate.IsZero() && metaJava.Version != 0 { javadbMeta = &metadata.Metadata{ Version: metaJava.Version, @@ -1202,17 +1211,23 @@ func showVersion(cacheDir, outputFormat, version string, w io.Writer) { var pbMeta *policy.Metadata pc, err := policy.NewClient(cacheDir, false) if pc != nil && err == nil { - pbMeta, _ = pc.GetMetadata() + pbMeta, err = pc.GetMetadata() + if err != nil { + return xerrors.Errorf("failed to get policy metadata: %w", err) + } } switch outputFormat { case "json": - _ = json.NewEncoder(w).Encode(VersionInfo{ + err = json.NewEncoder(w).Encode(VersionInfo{ Version: version, VulnerabilityDB: dbMeta, JavaDB: javadbMeta, PolicyBundle: pbMeta, }) + if err != nil { + return xerrors.Errorf("json encode error: %w", err) + } default: output := fmt.Sprintf("Version: %s\n", version) if dbMeta != nil { @@ -1241,6 +1256,7 @@ func showVersion(cacheDir, outputFormat, version string, w io.Writer) { } fmt.Fprintf(w, output) } + return nil } func validateArgs(cmd *cobra.Command, args []string) error { From 803870740456d887474d1166fd527a7909d8c191 Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Sun, 23 Jul 2023 15:46:37 +0300 Subject: [PATCH 4/5] fix: lint issues --- pkg/commands/app.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pkg/commands/app.go b/pkg/commands/app.go index 6aefa8beb960..f8c98ba0dfa6 100644 --- a/pkg/commands/app.go +++ b/pkg/commands/app.go @@ -217,7 +217,6 @@ func NewRootCommand(version string, globalFlags *flag.GlobalFlagGroup) *cobra.Co } else { return cmd.Help() } - return nil }, } @@ -1161,9 +1160,7 @@ func NewVersionCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command { Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { options := globalFlags.ToOptions() - showVersion(options.CacheDir, versionFormat, cmd.Version, cmd.OutOrStdout()) - - return nil + return showVersion(options.CacheDir, versionFormat, cmd.Version, cmd.OutOrStdout()) }, SilenceErrors: true, SilenceUsage: true, From defa3c17787341f207014b6d9f96ecf238e62b38 Mon Sep 17 00:00:00 2001 From: knqyf263 Date: Sun, 23 Jul 2023 16:06:02 +0300 Subject: [PATCH 5/5] fix: do not fail on bogus cache dir --- pkg/commands/app.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/commands/app.go b/pkg/commands/app.go index f8c98ba0dfa6..3d4ff093d797 100644 --- a/pkg/commands/app.go +++ b/pkg/commands/app.go @@ -1180,7 +1180,7 @@ func showVersion(cacheDir, outputFormat, version string, w io.Writer) error { mc := metadata.NewClient(cacheDir) meta, err := mc.Get() if err != nil { - return xerrors.Errorf("failed to get db metadata: %w", err) + log.Logger.Debugw("Failed to get DB metadata", "error", err) } if !meta.UpdatedAt.IsZero() && !meta.NextUpdate.IsZero() && meta.Version != 0 { dbMeta = &metadata.Metadata{ @@ -1194,7 +1194,7 @@ func showVersion(cacheDir, outputFormat, version string, w io.Writer) error { mcJava := javadb.NewMetadata(filepath.Join(cacheDir, "java-db")) metaJava, err := mcJava.Get() if err != nil { - return xerrors.Errorf("failed to get java db metadata: %w", err) + log.Logger.Debugw("Failed to get Java DB metadata", "error", err) } if !metaJava.UpdatedAt.IsZero() && !metaJava.NextUpdate.IsZero() && metaJava.Version != 0 { javadbMeta = &metadata.Metadata{ @@ -1210,7 +1210,7 @@ func showVersion(cacheDir, outputFormat, version string, w io.Writer) error { if pc != nil && err == nil { pbMeta, err = pc.GetMetadata() if err != nil { - return xerrors.Errorf("failed to get policy metadata: %w", err) + log.Logger.Debugw("Failed to get policy metadata", "error", err) } }