diff --git a/pkg/sources/s3/metrics.go b/pkg/sources/s3/metrics.go
new file mode 100644
index 000000000000..3eb8b7c7bc00
--- /dev/null
+++ b/pkg/sources/s3/metrics.go
@@ -0,0 +1,99 @@
+package s3
+
+import (
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/prometheus/client_golang/prometheus/promauto"
+
+	"github.com/trufflesecurity/trufflehog/v3/pkg/common"
+)
+
+// metricsCollector defines the interface for recording S3 scan metrics.
+type metricsCollector interface {
+	// Object metrics.
+
+	RecordObjectScanned(bucket string)
+	RecordObjectSkipped(bucket, reason string)
+	RecordObjectError(bucket string)
+
+	// Role metrics.
+
+	RecordRoleScanned(roleArn string)
+	RecordBucketForRole(roleArn string)
+}
+
+type collector struct {
+	objectsScanned *prometheus.CounterVec
+	objectsSkipped *prometheus.CounterVec
+	objectsErrors  *prometheus.CounterVec
+	rolesScanned   *prometheus.GaugeVec
+	bucketsPerRole *prometheus.GaugeVec
+}
+
+var metricsInstance metricsCollector
+
+func init() {
+	metricsInstance = &collector{
+		objectsScanned: promauto.NewCounterVec(prometheus.CounterOpts{
+			Namespace: common.MetricsNamespace,
+			Subsystem: common.MetricsSubsystem,
+			Name:      "objects_scanned_total",
+			Help:      "Total number of S3 objects successfully scanned",
+		}, []string{"bucket"}),
+
+		objectsSkipped: promauto.NewCounterVec(prometheus.CounterOpts{
+			Namespace: common.MetricsNamespace,
+			Subsystem: common.MetricsSubsystem,
+			Name:      "objects_skipped_total",
+			Help:      "Total number of S3 objects skipped during scan",
+		}, []string{"bucket", "reason"}),
+
+		objectsErrors: promauto.NewCounterVec(prometheus.CounterOpts{
+			Namespace: common.MetricsNamespace,
+			Subsystem: common.MetricsSubsystem,
+			Name:      "objects_errors_total",
+			Help:      "Total number of errors encountered during S3 scan",
+		}, []string{"bucket"}),
+
+		rolesScanned: promauto.NewGaugeVec(prometheus.GaugeOpts{
+			Namespace: common.MetricsNamespace,
+			Subsystem: common.MetricsSubsystem,
+			Name:      "roles_scanned",
+			Help:      "Number of AWS roles being scanned",
+		}, []string{"role_arn"}),
+
+		bucketsPerRole: promauto.NewGaugeVec(prometheus.GaugeOpts{
+			Namespace: common.MetricsNamespace,
+			Subsystem: common.MetricsSubsystem,
+			Name:      "buckets_per_role",
+			Help:      "Number of buckets accessible per AWS role",
+		}, []string{"role_arn"}),
+	}
+}
+
+func (c *collector) RecordObjectScanned(bucket string) {
+	c.objectsScanned.WithLabelValues(bucket).Inc()
+}
+
+func (c *collector) RecordObjectSkipped(bucket, reason string) {
+	c.objectsSkipped.WithLabelValues(bucket, reason).Inc()
+}
+
+func (c *collector) RecordObjectError(bucket string) {
+	c.objectsErrors.WithLabelValues(bucket).Inc()
+}
+
+const defaultRoleARN = "default"
+
+func (c *collector) RecordRoleScanned(roleArn string) {
+	if roleArn == "" {
+		roleArn = defaultRoleARN
+	}
+	c.rolesScanned.WithLabelValues(roleArn).Set(1)
+}
+
+func (c *collector) RecordBucketForRole(roleArn string) {
+	if roleArn == "" {
+		roleArn = defaultRoleARN
+	}
+	c.bucketsPerRole.WithLabelValues(roleArn).Inc()
+}
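Not part of the diff itself, but because the collector's fields are plain client_golang vectors, a package-internal unit test can verify the recording logic directly. The sketch below is an illustration only: `newTestCollector` and `TestCollectorRecordsSkipReasons` are hypothetical names, it assumes a test file inside `pkg/sources/s3` (so the unexported `collector` fields are reachable), and it builds unregistered vectors with `prometheus.New*Vec` so it cannot collide with the `promauto` registration done in `init()`.

```go
package s3

import (
	"testing"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/testutil"
)

// newTestCollector is a hypothetical helper (not in this change) that builds a
// collector backed by unregistered vectors, avoiding the global registry.
func newTestCollector() *collector {
	return &collector{
		objectsScanned: prometheus.NewCounterVec(prometheus.CounterOpts{Name: "test_objects_scanned_total"}, []string{"bucket"}),
		objectsSkipped: prometheus.NewCounterVec(prometheus.CounterOpts{Name: "test_objects_skipped_total"}, []string{"bucket", "reason"}),
		objectsErrors:  prometheus.NewCounterVec(prometheus.CounterOpts{Name: "test_objects_errors_total"}, []string{"bucket"}),
		rolesScanned:   prometheus.NewGaugeVec(prometheus.GaugeOpts{Name: "test_roles_scanned"}, []string{"role_arn"}),
		bucketsPerRole: prometheus.NewGaugeVec(prometheus.GaugeOpts{Name: "test_buckets_per_role"}, []string{"role_arn"}),
	}
}

func TestCollectorRecordsSkipReasons(t *testing.T) {
	c := newTestCollector()

	// Two skips for the same bucket/reason pair should read back as 2.
	c.RecordObjectSkipped("my-bucket", "size_limit")
	c.RecordObjectSkipped("my-bucket", "size_limit")
	if got := testutil.ToFloat64(c.objectsSkipped.WithLabelValues("my-bucket", "size_limit")); got != 2 {
		t.Fatalf("objects_skipped_total = %v, want 2", got)
	}

	// An empty role ARN is recorded under the "default" label.
	c.RecordRoleScanned("")
	if got := testutil.ToFloat64(c.rolesScanned.WithLabelValues(defaultRoleARN)); got != 1 {
		t.Fatalf("roles_scanned = %v, want 1", got)
	}
}
```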
diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go
index 91970e9fd703..959a0bbfa1d4 100644
--- a/pkg/sources/s3/s3.go
+++ b/pkg/sources/s3/s3.go
@@ -48,6 +48,7 @@ type Source struct {
 	checkpointer *Checkpointer
 	sources.Progress
+	metricsCollector metricsCollector
 
 	errorCount *sync.Map
 	jobPool    *errgroup.Group
@@ -94,6 +95,7 @@ func (s *Source) Init(
 	s.conn = &conn
 
 	s.checkpointer = NewCheckpointer(ctx, conn.GetEnableResumption(), &s.Progress)
+	s.metricsCollector = metricsInstance
 
 	s.setMaxObjectSize(conn.GetMaxObjectSize())
 
@@ -106,11 +108,12 @@ func (s *Source) Init(
 func (s *Source) Validate(ctx context.Context) []error {
 	var errs []error
 
-	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
+	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
 		roleErrs := s.validateBucketAccess(c, defaultRegionClient, roleArn, buckets)
 		if len(roleErrs) > 0 {
 			errs = append(errs, roleErrs...)
 		}
+		return nil
 	}
 
 	if err := s.visitRoles(ctx, visitor); err != nil {
@@ -307,6 +310,7 @@ func (s *Source) scanBuckets(
 
 	bucketsToScanCount := len(bucketsToScan)
 	for bucketIdx := pos.index; bucketIdx < bucketsToScanCount; bucketIdx++ {
+		s.metricsCollector.RecordBucketForRole(role)
 		bucket := bucketsToScan[bucketIdx]
 		ctx := context.WithValue(ctx, "bucket", bucket)
 
@@ -385,8 +389,9 @@ func (s *Source) scanBuckets(
 
 // Chunks emits chunks of bytes over a channel.
 func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
-	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
+	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
 		s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
+		return nil
 	}
 
 	return s.visitRoles(ctx, visitor)
@@ -427,6 +432,7 @@ func (s *Source) pageChunker(
 
 		for objIdx, obj := range metadata.page.Contents {
 			if obj == nil {
+				s.metricsCollector.RecordObjectSkipped(metadata.bucket, "nil_object")
 				if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
 					ctx.Logger().Error(err, "could not update progress for nil object")
 				}
@@ -442,6 +448,7 @@ func (s *Source) pageChunker(
 			// Skip GLACIER and GLACIER_IR objects.
 			if obj.StorageClass == nil || strings.Contains(*obj.StorageClass, "GLACIER") {
 				ctx.Logger().V(5).Info("Skipping object in storage class", "storage_class", *obj.StorageClass)
+				s.metricsCollector.RecordObjectSkipped(metadata.bucket, "storage_class")
 				if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
 					ctx.Logger().Error(err, "could not update progress for glacier object")
 				}
@@ -451,6 +458,7 @@ func (s *Source) pageChunker(
 			// Ignore large files.
 			if *obj.Size > s.maxObjectSize {
 				ctx.Logger().V(5).Info("Skipping %d byte file (over maxObjectSize limit)")
+				s.metricsCollector.RecordObjectSkipped(metadata.bucket, "size_limit")
 				if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
 					ctx.Logger().Error(err, "could not update progress for large file")
 				}
@@ -460,6 +468,7 @@ func (s *Source) pageChunker(
 			// File empty file.
 			if *obj.Size == 0 {
 				ctx.Logger().V(5).Info("Skipping empty file")
+				s.metricsCollector.RecordObjectSkipped(metadata.bucket, "empty_file")
 				if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
 					ctx.Logger().Error(err, "could not update progress for empty file")
 				}
@@ -469,6 +478,7 @@ func (s *Source) pageChunker(
 			// Skip incompatible extensions.
 			if common.SkipFile(*obj.Key) {
 				ctx.Logger().V(5).Info("Skipping file with incompatible extension")
+				s.metricsCollector.RecordObjectSkipped(metadata.bucket, "incompatible_extension")
 				if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
 					ctx.Logger().Error(err, "could not update progress for incompatible file")
 				}
@@ -483,6 +493,7 @@ func (s *Source) pageChunker(
 
 			if strings.HasSuffix(*obj.Key, "/") {
 				ctx.Logger().V(5).Info("Skipping directory")
+				s.metricsCollector.RecordObjectSkipped(metadata.bucket, "directory")
 				return nil
 			}
 
@@ -508,8 +519,12 @@ func (s *Source) pageChunker(
 				Key:    obj.Key,
 			})
 			if err != nil {
-				if !strings.Contains(err.Error(), "AccessDenied") {
+				if strings.Contains(err.Error(), "AccessDenied") {
+					ctx.Logger().Error(err, "could not get S3 object; access denied")
+					s.metricsCollector.RecordObjectSkipped(metadata.bucket, "access_denied")
+				} else {
 					ctx.Logger().Error(err, "could not get S3 object")
+					s.metricsCollector.RecordObjectError(metadata.bucket)
 				}
 				// According to the documentation for GetObjectWithContext,
 				// the response can be non-nil even if there was an error.
@@ -563,6 +578,7 @@ func (s *Source) pageChunker(
 
 			if err := handlers.HandleFile(ctx, res.Body, chunkSkel, sources.ChanReporter{Ch: chunksChan}); err != nil {
 				ctx.Logger().Error(err, "error handling file")
+				s.metricsCollector.RecordObjectError(metadata.bucket)
 				return nil
 			}
 
@@ -580,6 +596,7 @@ func (s *Source) pageChunker(
 			if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
 				ctx.Logger().Error(err, "could not update progress for scanned object")
 			}
+			s.metricsCollector.RecordObjectScanned(metadata.bucket)
 
 			return nil
 		})
@@ -633,7 +650,7 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
 // If no roles are configured, it will call the function with an empty role ARN.
 func (s *Source) visitRoles(
 	ctx context.Context,
-	f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string),
+	f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error,
 ) error {
 	roles := s.conn.GetRoles()
 	if len(roles) == 0 {
@@ -641,6 +658,8 @@ func (s *Source) visitRoles(
 	}
 
 	for _, role := range roles {
+		s.metricsCollector.RecordRoleScanned(role)
+
 		client, err := s.newClient(defaultAWSRegion, role)
 		if err != nil {
 			return fmt.Errorf("could not create s3 client: %w", err)
@@ -651,7 +670,9 @@ func (s *Source) visitRoles(
 			return fmt.Errorf("role %q could not list any s3 buckets for scanning: %w", role, err)
 		}
 
-		f(ctx, client, role, bucketsToScan)
+		if err := f(ctx, client, role, bucketsToScan); err != nil {
+			return err
+		}
 	}
 
 	return nil
diff --git a/pkg/sources/s3/s3_integration_test.go b/pkg/sources/s3/s3_integration_test.go
index 1801e29e20e2..3276e40fdb0e 100644
--- a/pkg/sources/s3/s3_integration_test.go
+++ b/pkg/sources/s3/s3_integration_test.go
@@ -82,6 +82,38 @@ func TestSource_ChunksLarge(t *testing.T) {
 	assert.Equal(t, got, wantChunkCount)
 }
 
+func TestSourceChunksNoResumption(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
+	defer cancel()
+
+	s := Source{}
+	connection := &sourcespb.S3{
+		Credential: &sourcespb.S3_Unauthenticated{},
+		Buckets:    []string{"trufflesec-ahrav-test-2"},
+	}
+	conn, err := anypb.New(connection)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = s.Init(ctx, "test name", 0, 0, false, conn, 1)
+	assert.Nil(t, err)
+	chunksCh := make(chan *sources.Chunk)
+	go func() {
+		defer close(chunksCh)
+		err = s.Chunks(ctx, chunksCh)
+		assert.Nil(t, err)
+	}()
+
+	wantChunkCount := 19787
+	got := 0
+
+	for range chunksCh {
+		got++
+	}
+	assert.Equal(t, got, wantChunkCount)
+}
+
 func TestSource_Validate(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
 	defer cancel()
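A note on the design seam: `Source` stores the small `metricsCollector` interface rather than the concrete `collector`, and `Init` wires in the global `metricsInstance`. Since `pageChunker` and `visitRoles` now call `s.metricsCollector` unconditionally, any code path that built a `Source` without going through `Init` would hit a nil interface. The sketch below is a hypothetical stub under that assumption, not part of this change; `noopCollector` is an invented name, and because the field is unexported it is only usable from tests inside `pkg/sources/s3`.

```go
package s3

// noopCollector satisfies metricsCollector without touching Prometheus.
// Useful in package-internal tests that construct a Source directly.
type noopCollector struct{}

func (noopCollector) RecordObjectScanned(bucket string)         {}
func (noopCollector) RecordObjectSkipped(bucket, reason string) {}
func (noopCollector) RecordObjectError(bucket string)           {}
func (noopCollector) RecordRoleScanned(roleArn string)          {}
func (noopCollector) RecordBucketForRole(roleArn string)        {}

// Example use inside a package-internal test:
//
//	s := &Source{metricsCollector: noopCollector{}}
```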