diff --git a/pkg/enqueuer/enqueuer.go b/pkg/enqueuer/enqueuer.go index 06c1664c91..bfd5f6e73f 100644 --- a/pkg/enqueuer/enqueuer.go +++ b/pkg/enqueuer/enqueuer.go @@ -56,9 +56,10 @@ type ItemList struct { } type S3Lister struct { - S3Paths []string `json:"s3_paths"` // s3:///key - Includes []string `json:"includes"` - Excludes []string `json:"excludes"` + S3Paths []string `json:"s3_paths"` // s3:///key + Includes []string `json:"includes"` + Excludes []string `json:"excludes"` + MaxResults *int64 `json:"-"` // this is not currently exposed to the user (it's used for validations) } type FilePathLister struct { @@ -246,7 +247,7 @@ func (e *Enqueuer) enqueueS3Paths(s3PathsLister *FilePathLister) (int, error) { var s3PathList []string uploader := newSQSBatchUploader(e.envConfig.APIName, e.envConfig.JobID, e.queueURL, e.aws.SQS()) - err := s3IteratorFromLister(e.aws, s3PathsLister.S3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) { + _, err := s3IteratorFromLister(e.aws, s3PathsLister.S3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) { s3Path := awslib.S3Path(bucket, *s3Obj.Key) s3PathList = append(s3PathList, s3Path) @@ -290,7 +291,7 @@ func (e *Enqueuer) enqueueS3FileContents(delimitedFiles *DelimitedFiles) (int, e uploader := newSQSBatchUploader(e.envConfig.APIName, e.envConfig.JobID, e.queueURL, e.aws.SQS()) bytesBuffer := bytes.NewBuffer([]byte{}) - err := s3IteratorFromLister(e.aws, delimitedFiles.S3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) { + _, err := s3IteratorFromLister(e.aws, delimitedFiles.S3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) { s3Path := awslib.S3Path(bucket, *s3Obj.Key) log.Info("enqueuing contents from file", zap.String("path", s3Path)) diff --git a/pkg/enqueuer/helpers.go b/pkg/enqueuer/helpers.go index d7c2246852..a8df5d43d6 100644 --- a/pkg/enqueuer/helpers.go +++ b/pkg/enqueuer/helpers.go @@ -64,13 +64,13 @@ func addJSONObjectsToQueue(uploader *sqsBatchUploader, jsonMessageList *jsonBuff return nil } -func s3IteratorFromLister(awsClient *awslib.Client, s3Lister S3Lister, fn func(string, *s3.Object) (bool, error)) error { +func s3IteratorFromLister(awsClient *awslib.Client, s3Lister S3Lister, fn func(string, *s3.Object) (bool, error)) (int64, error) { includeGlobPatterns := make([]glob.Glob, 0, len(s3Lister.Includes)) for _, includePattern := range s3Lister.Includes { globExpression, err := glob.Compile(includePattern, '/') if err != nil { - return errors.Wrap(err, "failed to interpret glob pattern", includePattern) + return 0, errors.Wrap(err, "failed to interpret glob pattern", includePattern) } includeGlobPatterns = append(includeGlobPatterns, globExpression) } @@ -79,20 +79,22 @@ func s3IteratorFromLister(awsClient *awslib.Client, s3Lister S3Lister, fn func(s for _, excludePattern := range s3Lister.Excludes { globExpression, err := glob.Compile(excludePattern, '/') if err != nil { - return errors.Wrap(err, "failed to interpret glob pattern", excludePattern) + return 0, errors.Wrap(err, "failed to interpret glob pattern", excludePattern) } excludeGlobPatterns = append(excludeGlobPatterns, globExpression) } + var numResults int64 + for _, s3Path := range s3Lister.S3Paths { bucket, key, err := awslib.SplitS3Path(s3Path) if err != nil { - return err + return 0, err } awsClientForBucket, err := awslib.NewFromClientS3Path(s3Path, awsClient) if err != nil { - return err + return 0, err } err = awsClientForBucket.S3Iterator(bucket, key, false, nil, nil, func(s3Obj *s3.Object) (bool, error) { @@ -117,15 +119,24 @@ func s3IteratorFromLister(awsClient *awslib.Client, s3Lister S3Lister, fn func(s } if !shouldSkip { - return fn(bucket, s3Obj) + shouldContinue, err := fn(bucket, s3Obj) + numResults++ + if s3Lister.MaxResults != nil && numResults >= *s3Lister.MaxResults { + shouldContinue = false + } + return shouldContinue, err } return true, nil }) if err != nil { - return err + return 0, err + } + + if s3Lister.MaxResults != nil && numResults >= *s3Lister.MaxResults { + return numResults, nil } } - return nil + return numResults, nil } diff --git a/pkg/operator/resources/job/batchapi/s3_iterator.go b/pkg/operator/resources/job/batchapi/s3_iterator.go index 80848f32c2..2860c27c5b 100644 --- a/pkg/operator/resources/job/batchapi/s3_iterator.go +++ b/pkg/operator/resources/job/batchapi/s3_iterator.go @@ -26,13 +26,13 @@ import ( ) // Takes in a function(shouldSkip, bucketName, s3.Object) -func s3IteratorFromLister(s3Lister schema.S3Lister, fn func(string, *s3.Object) (bool, error)) error { +func s3IteratorFromLister(s3Lister schema.S3Lister, fn func(string, *s3.Object) (bool, error)) (int64, error) { includeGlobPatterns := make([]glob.Glob, 0, len(s3Lister.Includes)) for _, includePattern := range s3Lister.Includes { globExpression, err := glob.Compile(includePattern, '/') if err != nil { - return errors.Wrap(err, "failed to interpret glob pattern", includePattern) + return 0, errors.Wrap(err, "failed to interpret glob pattern", includePattern) } includeGlobPatterns = append(includeGlobPatterns, globExpression) } @@ -41,20 +41,22 @@ func s3IteratorFromLister(s3Lister schema.S3Lister, fn func(string, *s3.Object) for _, excludePattern := range s3Lister.Excludes { globExpression, err := glob.Compile(excludePattern, '/') if err != nil { - return errors.Wrap(err, "failed to interpret glob pattern", excludePattern) + return 0, errors.Wrap(err, "failed to interpret glob pattern", excludePattern) } excludeGlobPatterns = append(excludeGlobPatterns, globExpression) } + var numResults int64 + for _, s3Path := range s3Lister.S3Paths { bucket, key, err := aws.SplitS3Path(s3Path) if err != nil { - return err + return 0, err } awsClientForBucket, err := aws.NewFromClientS3Path(s3Path, config.AWS) if err != nil { - return err + return 0, err } err = awsClientForBucket.S3Iterator(bucket, key, false, nil, nil, func(s3Obj *s3.Object) (bool, error) { @@ -79,15 +81,24 @@ func s3IteratorFromLister(s3Lister schema.S3Lister, fn func(string, *s3.Object) } if !shouldSkip { - return fn(bucket, s3Obj) + shouldContinue, err := fn(bucket, s3Obj) + numResults++ + if s3Lister.MaxResults != nil && numResults >= *s3Lister.MaxResults { + shouldContinue = false + } + return shouldContinue, err } return true, nil }) if err != nil { - return err + return 0, err + } + + if s3Lister.MaxResults != nil && numResults >= *s3Lister.MaxResults { + return numResults, nil } } - return nil + return numResults, nil } diff --git a/pkg/operator/resources/job/batchapi/validations.go b/pkg/operator/resources/job/batchapi/validations.go index fe51cd8d27..c28439491a 100644 --- a/pkg/operator/resources/job/batchapi/validations.go +++ b/pkg/operator/resources/job/batchapi/validations.go @@ -24,6 +24,7 @@ import ( awslib "github.com/cortexlabs/cortex/pkg/lib/aws" cr "github.com/cortexlabs/cortex/pkg/lib/configreader" "github.com/cortexlabs/cortex/pkg/lib/errors" + "github.com/cortexlabs/cortex/pkg/lib/pointer" "github.com/cortexlabs/cortex/pkg/operator/resources/job" "github.com/cortexlabs/cortex/pkg/operator/schema" "github.com/gobwas/glob" @@ -143,26 +144,30 @@ func validateS3Lister(s3Lister *schema.S3Lister) error { } } - filesFound := 0 for _, s3Path := range s3Lister.S3Paths { if !awslib.IsValidS3Path(s3Path) { return awslib.ErrorInvalidS3Path(s3Path) } + } - err := s3IteratorFromLister(*s3Lister, func(objPath string, s3Obj *s3.Object) (bool, error) { - filesFound++ - return false, nil - }) - if err != nil { - return errors.Wrap(err, s3Path) - } + shortCircuitLister := schema.S3Lister{ + S3Paths: s3Lister.S3Paths, + Includes: s3Lister.Includes, + Excludes: s3Lister.Excludes, + MaxResults: pointer.Int64(1), + } + numResults, err := s3IteratorFromLister(shortCircuitLister, func(objPath string, s3Obj *s3.Object) (bool, error) { + return false, nil + }) + if err != nil { + return err + } - if filesFound > 0 { - return nil - } + if numResults == 0 { + return ErrorNoS3FilesFound() } - return ErrorNoS3FilesFound() + return nil } func listFilesDryRun(s3Lister *schema.S3Lister) ([]string, error) { @@ -171,15 +176,14 @@ func listFilesDryRun(s3Lister *schema.S3Lister) ([]string, error) { if !awslib.IsValidS3Path(s3Path) { return nil, awslib.ErrorInvalidS3Path(s3Path) } + } - err := s3IteratorFromLister(*s3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) { - s3Files = append(s3Files, awslib.S3Path(bucket, *s3Obj.Key)) - return true, nil - }) - - if err != nil { - return nil, errors.Wrap(err, s3Path) - } + _, err := s3IteratorFromLister(*s3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) { + s3Files = append(s3Files, awslib.S3Path(bucket, *s3Obj.Key)) + return true, nil + }) + if err != nil { + return nil, err } if len(s3Files) == 0 { diff --git a/pkg/operator/schema/job_submission.go b/pkg/operator/schema/job_submission.go index 7fadf3b163..ad51932a20 100644 --- a/pkg/operator/schema/job_submission.go +++ b/pkg/operator/schema/job_submission.go @@ -28,9 +28,10 @@ type ItemList struct { } type S3Lister struct { - S3Paths []string `json:"s3_paths"` // s3:///key - Includes []string `json:"includes"` - Excludes []string `json:"excludes"` + S3Paths []string `json:"s3_paths"` // s3:///key + Includes []string `json:"includes"` + Excludes []string `json:"excludes"` + MaxResults *int64 `json:"-"` // this is not currently exposed to the user (it's used for validations) } type FilePathLister struct {