From fcdd7ab3d90b2b989515a3b14accada45a6ed395 Mon Sep 17 00:00:00 2001 From: Ahrav Dutta Date: Mon, 18 Nov 2024 14:20:16 -0800 Subject: [PATCH] remove context cancellation logic --- pkg/sources/s3/progress_tracker.go | 17 ++--- pkg/sources/s3/s3.go | 73 +++++--------------- pkg/sources/s3/s3_integration_test.go | 97 ++++----------------------- 3 files changed, 38 insertions(+), 149 deletions(-) diff --git a/pkg/sources/s3/progress_tracker.go b/pkg/sources/s3/progress_tracker.go index a6123caf83ee..485868f48aea 100644 --- a/pkg/sources/s3/progress_tracker.go +++ b/pkg/sources/s3/progress_tracker.go @@ -42,7 +42,8 @@ import ( type ProgressTracker struct { enabled bool - sync.Mutex + mu sync.Mutex // protects concurrent access to completion state. + // completedObjects tracks which indices in the current page have been processed. completedObjects []bool completionOrder []int // Track the order in which objects complete @@ -74,8 +75,8 @@ func (p *ProgressTracker) Reset() { return } - p.Lock() - defer p.Unlock() + p.mu.Lock() + defer p.mu.Unlock() // Store the current completed count before moving to next page. p.completedObjects = make([]bool, defaultMaxObjectsPerPage) p.completionOrder = make([]int, 0, defaultMaxObjectsPerPage) @@ -143,9 +144,9 @@ func (p *ProgressTracker) Complete(_ context.Context, message string) error { // - Objects completed: [0,1,2,3,4,5,7,8] // - The checkpoint will only include objects 0-5 since they are consecutive // - If scanning is interrupted and resumed: -// - Scan resumes after object 5 (the last checkpoint) -// - Objects 7-8 will be re-scanned even though they completed before -// - This ensures object 6 is not missed +// - Scan resumes after object 5 (the last checkpoint) +// - Objects 7-8 will be re-scanned even though they completed before +// - This ensures object 6 is not missed func (p *ProgressTracker) UpdateObjectProgress( ctx context.Context, completedIdx int, @@ -163,8 +164,8 @@ func (p *ProgressTracker) UpdateObjectProgress( return fmt.Errorf("completed index %d exceeds maximum page size", completedIdx) } - p.Lock() - defer p.Unlock() + p.mu.Lock() + defer p.mu.Unlock() // Only track completion if this is the first time this index is marked complete. if !p.completedObjects[completedIdx] { diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index d20ddd86dd89..dc4e8c2d80da 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -106,12 +106,11 @@ func (s *Source) Init( func (s *Source) Validate(ctx context.Context) []error { var errs []error - visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error { + visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) { roleErrs := s.validateBucketAccess(c, defaultRegionClient, roleArn, buckets) if len(roleErrs) > 0 { errs = append(errs, roleErrs...) } - return nil } if err := s.visitRoles(ctx, visitor); err != nil { @@ -214,30 +213,6 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) { return bucketsToScan, nil } -// workerSignal provides thread-safe tracking of cancellation state across multiple -// goroutines processing S3 bucket pages. It ensures graceful shutdown when the context -// is cancelled during bucket scanning operations. -// -// This type serves several key purposes: -// 1. AWS ListObjectsV2PagesWithContext requires a callback that can only return bool, -// not error. workerSignal bridges this gap by providing a way to communicate -// cancellation back to the caller. -// 2. 
The pageChunker spawns multiple concurrent workers to process objects within -// each page. workerSignal enables these workers to detect and respond to -// cancellation signals. -// 3. Ensures proper progress tracking by allowing the main scanning loop to detect -// when workers have been cancelled and handle cleanup appropriately. -type workerSignal struct{ cancelled atomic.Bool } - -// newWorkerSignal creates a new workerSignal -func newWorkerSignal() *workerSignal { return new(workerSignal) } - -// MarkCancelled marks that a context cancellation was detected. -func (ws *workerSignal) MarkCancelled() { ws.cancelled.Store(true) } - -// WasCancelled returns true if context cancellation was detected. -func (ws *workerSignal) WasCancelled() bool { return ws.cancelled.Load() } - // pageMetadata contains metadata about a single page of S3 objects being scanned. type pageMetadata struct { bucket string // The name of the S3 bucket being scanned @@ -248,9 +223,8 @@ type pageMetadata struct { // processingState tracks the state of concurrent S3 object processing. type processingState struct { - errorCount *sync.Map // Thread-safe map tracking errors per prefix - objectCount *uint64 // Total number of objects processed - workerSignal *workerSignal // Coordinates cancellation across worker goroutines + errorCount *sync.Map // Thread-safe map tracking errors per prefix + objectCount *uint64 // Total number of objects processed } func (s *Source) scanBuckets( @@ -259,7 +233,7 @@ func (s *Source) scanBuckets( role string, bucketsToScan []string, chunksChan chan *sources.Chunk, -) error { +) { if role != "" { ctx = context.WithValue(ctx, "role", role) } @@ -268,21 +242,20 @@ func (s *Source) scanBuckets( // Determine starting point for resuming scan. resumePoint, err := s.progressTracker.GetResumePoint(ctx) if err != nil { - return fmt.Errorf("failed to get resume point :%w", err) + ctx.Logger().Error(err, "failed to get resume point") + return } startIdx, _ := slices.BinarySearch(bucketsToScan, resumePoint.CurrentBucket) - // Create worker signal to track cancellation across page processing. - workerSignal := newWorkerSignal() - bucketsToScanCount := len(bucketsToScan) for i := startIdx; i < bucketsToScanCount; i++ { bucket := bucketsToScan[i] ctx := context.WithValue(ctx, "bucket", bucket) if common.IsDone(ctx) { - return ctx.Err() + ctx.Logger().Error(ctx.Err(), "context done, while scanning bucket") + return } ctx.Logger().V(3).Info("Scanning bucket") @@ -291,7 +264,7 @@ func (s *Source) scanBuckets( i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), - s.Progress.EncodedResumeInfo, // Do not set, resume handled by progressTracker + s.Progress.EncodedResumeInfo, ) regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket) @@ -323,25 +296,15 @@ func (s *Source) scanBuckets( page: page, } processingState := processingState{ - errorCount: &errorCount, - objectCount: &objectCount, - workerSignal: workerSignal, + errorCount: &errorCount, + objectCount: &objectCount, } s.pageChunker(ctx, pageMetadata, processingState, chunksChan) - if workerSignal.WasCancelled() { - return false // Stop pagination - } - pageNumber++ return true }) - // Check if we stopped due to cancellation. - if workerSignal.WasCancelled() { - return ctx.Err() - } - if err != nil { if role == "" { ctx.Logger().Error(err, "could not list objects in bucket") @@ -361,14 +324,12 @@ func (s *Source) scanBuckets( fmt.Sprintf("Completed scanning source %s. 
%d objects scanned.", s.name, objectCount), "", ) - - return nil } // Chunks emits chunks of bytes over a channel. func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error { - visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error { - return s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan) + visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) { + s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan) } return s.visitRoles(ctx, visitor) @@ -418,7 +379,6 @@ func (s *Source) pageChunker( ctx = context.WithValues(ctx, "key", *obj.Key, "size", *obj.Size) if common.IsDone(ctx) { - state.workerSignal.MarkCancelled() return } @@ -461,7 +421,6 @@ func (s *Source) pageChunker( s.jobPool.Go(func() error { defer common.RecoverWithExit(ctx) if common.IsDone(ctx) { - state.workerSignal.MarkCancelled() return ctx.Err() } @@ -617,7 +576,7 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr // If no roles are configured, it will call the function with an empty role ARN. func (s *Source) visitRoles( ctx context.Context, - f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error, + f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string), ) error { roles := s.conn.GetRoles() if len(roles) == 0 { @@ -635,9 +594,7 @@ func (s *Source) visitRoles( return fmt.Errorf("role %q could not list any s3 buckets for scanning: %w", role, err) } - if err := f(ctx, client, role, bucketsToScan); err != nil { - return err - } + f(ctx, client, role, bucketsToScan) } return nil diff --git a/pkg/sources/s3/s3_integration_test.go b/pkg/sources/s3/s3_integration_test.go index ef45956e007e..ea555b303bca 100644 --- a/pkg/sources/s3/s3_integration_test.go +++ b/pkg/sources/s3/s3_integration_test.go @@ -4,7 +4,6 @@ package s3 import ( - "encoding/json" "fmt" "sync" "testing" @@ -250,11 +249,16 @@ func TestSource_Validate(t *testing.T) { } func TestSourceChunksResumption(t *testing.T) { - // First scan - simulate interruption. ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() src := new(Source) + src.Progress = sources.Progress{ + Message: "Bucket: trufflesec-ahrav-test-2", + EncodedResumeInfo: "{\"current_bucket\":\"trufflesec-ahrav-test-2\",\"start_after\":\"test-dir/\"}", + SectionsCompleted: 0, + SectionsRemaining: 1, + } connection := &sourcespb.S3{ Credential: &sourcespb.S3_Unauthenticated{}, Buckets: []string{"trufflesec-ahrav-test-2"}, @@ -267,97 +271,24 @@ func TestSourceChunksResumption(t *testing.T) { require.NoError(t, err) chunksCh := make(chan *sources.Chunk) - var firstScanCount int64 - const cancelAfterChunks = 15_000 + var count int cancelCtx, ctxCancel := context.WithCancel(ctx) defer ctxCancel() - // Start first scan and collect chunks until chunk limit. go func() { defer close(chunksCh) err = src.Chunks(cancelCtx, chunksCh) - assert.Error(t, err, "Expected context cancellation error") + assert.NoError(t, err, "Should not error during scan") }() - // Process chunks until we hit our limit for range chunksCh { - firstScanCount++ - if firstScanCount >= cancelAfterChunks { - ctxCancel() // Cancel context after processing desired number of chunks - break - } - } - - // Verify we processed exactly the number of chunks we wanted. 
- assert.Equal(t, int64(cancelAfterChunks), firstScanCount, - "Should have processed exactly %d chunks in first scan", cancelAfterChunks) - - // Verify we have processed some chunks and have resumption info. - assert.Greater(t, firstScanCount, int64(0), "Should have processed some chunks in first scan") - - progress := src.GetProgress() - assert.NotEmpty(t, progress.EncodedResumeInfo, "Progress.EncodedResumeInfo should not be empty") - - firstScanCompletedIndex := progress.SectionsCompleted - - var resumeInfo ResumeInfo - err = json.Unmarshal([]byte(progress.EncodedResumeInfo), &resumeInfo) - require.NoError(t, err, "Should be able to decode resume info") - - // Verify resume info contains expected fields. - assert.Equal(t, "trufflesec-ahrav-test-2", resumeInfo.CurrentBucket, "Resume info should contain correct bucket") - assert.NotEmpty(t, resumeInfo.StartAfter, "Resume info should contain a StartAfter key") - - // Store the key where first scan stopped. - firstScanLastKey := resumeInfo.StartAfter - - // Second scan - should resume from where first scan left off. - ctx2 := context.Background() - src2 := &Source{Progress: *src.GetProgress()} - err = src2.Init(ctx2, "test name", 0, 0, false, conn, 4) - require.NoError(t, err) - - chunksCh2 := make(chan *sources.Chunk) - var secondScanCount int64 - - go func() { - defer close(chunksCh2) - err = src2.Chunks(ctx2, chunksCh2) - assert.NoError(t, err) - }() - - // Process second scan chunks and verify progress. - for range chunksCh2 { - secondScanCount++ - - // Get current progress during scan. - currentProgress := src2.GetProgress() - assert.GreaterOrEqual(t, currentProgress.SectionsCompleted, firstScanCompletedIndex, - "Progress should be greater or equal to first scan") - if currentProgress.EncodedResumeInfo != "" { - var currentResumeInfo ResumeInfo - err := json.Unmarshal([]byte(currentProgress.EncodedResumeInfo), ¤tResumeInfo) - require.NoError(t, err) - - // Verify that we're always scanning forward from where we left off. - assert.GreaterOrEqual(t, currentResumeInfo.StartAfter, firstScanLastKey, - "Second scan should never process keys before where first scan ended") - } + count++ } - // Verify total coverage. - expectedTotal := int64(19787) - actualTotal := firstScanCount + secondScanCount - - // Because of our resumption logic favoring completeness over speed, we can - // re-scan some objects. - assert.GreaterOrEqual(t, actualTotal, expectedTotal, - "Total processed chunks should meet or exceed expected count") - assert.Less(t, actualTotal, 2*expectedTotal, - "Total processed chunks should not be more than double expected count") - - finalProgress := src2.GetProgress() - assert.Equal(t, 1, int(finalProgress.SectionsCompleted), "Should have completed sections") - assert.Equal(t, 1, int(finalProgress.SectionsRemaining), "Should have remaining sections") + // Verify that we processed all remaining data on resume. + // Also verify that we processed less than the total number of chunks for the source. + sourceTotalChunkCount := 19787 + assert.Equal(t, 9638, count, "Should have processed all remaining data on resume") + assert.Less(t, count, sourceTotalChunkCount, "Should have processed less than total chunks on resume") }
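
Reviewer note (not part of the patch): a minimal, self-contained Go sketch of the contiguous-completion checkpoint rule described in the UpdateObjectProgress doc comment above. The names here (highestContiguousCompleted, completed) are illustrative only and are not identifiers from this repository.

    package main

    import "fmt"

    // highestContiguousCompleted returns the largest index i such that every
    // object at index <= i has completed, or -1 if the first object is still
    // pending. Only this index is safe to persist as a checkpoint: resuming
    // "after" it can never skip an unfinished object.
    func highestContiguousCompleted(completed []bool) int {
        last := -1
        for i, done := range completed {
            if !done {
                break // a gap: nothing past this point may be checkpointed yet
            }
            last = i
        }
        return last
    }

    func main() {
        // Mirrors the example in the UpdateObjectProgress doc comment:
        // objects 0-5, 7, and 8 are done, but 6 is not.
        completed := make([]bool, 10)
        for _, idx := range []int{0, 1, 2, 3, 4, 5, 7, 8} {
            completed[idx] = true
        }

        fmt.Println(highestContiguousCompleted(completed)) // 5
        // A resumed scan starts after object 5, so 7 and 8 are re-scanned
        // rather than risking object 6 being missed.
    }

Because resumption only trusts this contiguous prefix, an abrupt stop (including cancellation) at worst re-scans the few objects past the last checkpoint, which helps explain why the explicit workerSignal plumbing can be removed without losing correctness.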