Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve batch job submit validation efficiency #2179

Merged
merged 4 commits into from
May 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions pkg/enqueuer/enqueuer.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ type ItemList struct {
}

type S3Lister struct {
S3Paths []string `json:"s3_paths"` // s3://<bucket_name>/key
Includes []string `json:"includes"`
Excludes []string `json:"excludes"`
S3Paths []string `json:"s3_paths"` // s3://<bucket_name>/key
Includes []string `json:"includes"`
Excludes []string `json:"excludes"`
MaxResults *int64 `json:"-"` // this is not currently exposed to the user (it's used for validations)
}

type FilePathLister struct {
Expand Down Expand Up @@ -246,7 +247,7 @@ func (e *Enqueuer) enqueueS3Paths(s3PathsLister *FilePathLister) (int, error) {
var s3PathList []string
uploader := newSQSBatchUploader(e.envConfig.APIName, e.envConfig.JobID, e.queueURL, e.aws.SQS())

err := s3IteratorFromLister(e.aws, s3PathsLister.S3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) {
_, err := s3IteratorFromLister(e.aws, s3PathsLister.S3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) {
s3Path := awslib.S3Path(bucket, *s3Obj.Key)

s3PathList = append(s3PathList, s3Path)
Expand Down Expand Up @@ -290,7 +291,7 @@ func (e *Enqueuer) enqueueS3FileContents(delimitedFiles *DelimitedFiles) (int, e
uploader := newSQSBatchUploader(e.envConfig.APIName, e.envConfig.JobID, e.queueURL, e.aws.SQS())

bytesBuffer := bytes.NewBuffer([]byte{})
err := s3IteratorFromLister(e.aws, delimitedFiles.S3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) {
_, err := s3IteratorFromLister(e.aws, delimitedFiles.S3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) {
s3Path := awslib.S3Path(bucket, *s3Obj.Key)
log.Info("enqueuing contents from file", zap.String("path", s3Path))

Expand Down
27 changes: 19 additions & 8 deletions pkg/enqueuer/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ func addJSONObjectsToQueue(uploader *sqsBatchUploader, jsonMessageList *jsonBuff
return nil
}

func s3IteratorFromLister(awsClient *awslib.Client, s3Lister S3Lister, fn func(string, *s3.Object) (bool, error)) error {
func s3IteratorFromLister(awsClient *awslib.Client, s3Lister S3Lister, fn func(string, *s3.Object) (bool, error)) (int64, error) {
includeGlobPatterns := make([]glob.Glob, 0, len(s3Lister.Includes))

for _, includePattern := range s3Lister.Includes {
globExpression, err := glob.Compile(includePattern, '/')
if err != nil {
return errors.Wrap(err, "failed to interpret glob pattern", includePattern)
return 0, errors.Wrap(err, "failed to interpret glob pattern", includePattern)
}
includeGlobPatterns = append(includeGlobPatterns, globExpression)
}
Expand All @@ -79,20 +79,22 @@ func s3IteratorFromLister(awsClient *awslib.Client, s3Lister S3Lister, fn func(s
for _, excludePattern := range s3Lister.Excludes {
globExpression, err := glob.Compile(excludePattern, '/')
if err != nil {
return errors.Wrap(err, "failed to interpret glob pattern", excludePattern)
return 0, errors.Wrap(err, "failed to interpret glob pattern", excludePattern)
}
excludeGlobPatterns = append(excludeGlobPatterns, globExpression)
}

var numResults int64

for _, s3Path := range s3Lister.S3Paths {
bucket, key, err := awslib.SplitS3Path(s3Path)
if err != nil {
return err
return 0, err
}

awsClientForBucket, err := awslib.NewFromClientS3Path(s3Path, awsClient)
if err != nil {
return err
return 0, err
}

err = awsClientForBucket.S3Iterator(bucket, key, false, nil, nil, func(s3Obj *s3.Object) (bool, error) {
Expand All @@ -117,15 +119,24 @@ func s3IteratorFromLister(awsClient *awslib.Client, s3Lister S3Lister, fn func(s
}

if !shouldSkip {
return fn(bucket, s3Obj)
shouldContinue, err := fn(bucket, s3Obj)
numResults++
if s3Lister.MaxResults != nil && numResults >= *s3Lister.MaxResults {
shouldContinue = false
}
return shouldContinue, err
}

return true, nil
})
if err != nil {
return err
return 0, err
}

if s3Lister.MaxResults != nil && numResults >= *s3Lister.MaxResults {
return numResults, nil
}
}

return nil
return numResults, nil
}
27 changes: 19 additions & 8 deletions pkg/operator/resources/job/batchapi/s3_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ import (
)

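// Takes in a function(bucketName, s3.Object) that is called for each S3 object that is not skipped by the include/exclude patterns; it returns (shouldContinue, error)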
func s3IteratorFromLister(s3Lister schema.S3Lister, fn func(string, *s3.Object) (bool, error)) error {
func s3IteratorFromLister(s3Lister schema.S3Lister, fn func(string, *s3.Object) (bool, error)) (int64, error) {
includeGlobPatterns := make([]glob.Glob, 0, len(s3Lister.Includes))

for _, includePattern := range s3Lister.Includes {
globExpression, err := glob.Compile(includePattern, '/')
if err != nil {
return errors.Wrap(err, "failed to interpret glob pattern", includePattern)
return 0, errors.Wrap(err, "failed to interpret glob pattern", includePattern)
}
includeGlobPatterns = append(includeGlobPatterns, globExpression)
}
Expand All @@ -41,20 +41,22 @@ func s3IteratorFromLister(s3Lister schema.S3Lister, fn func(string, *s3.Object)
for _, excludePattern := range s3Lister.Excludes {
globExpression, err := glob.Compile(excludePattern, '/')
if err != nil {
return errors.Wrap(err, "failed to interpret glob pattern", excludePattern)
return 0, errors.Wrap(err, "failed to interpret glob pattern", excludePattern)
}
excludeGlobPatterns = append(excludeGlobPatterns, globExpression)
}

var numResults int64

for _, s3Path := range s3Lister.S3Paths {
bucket, key, err := aws.SplitS3Path(s3Path)
if err != nil {
return err
return 0, err
}

awsClientForBucket, err := aws.NewFromClientS3Path(s3Path, config.AWS)
if err != nil {
return err
return 0, err
}

err = awsClientForBucket.S3Iterator(bucket, key, false, nil, nil, func(s3Obj *s3.Object) (bool, error) {
Expand All @@ -79,15 +81,24 @@ func s3IteratorFromLister(s3Lister schema.S3Lister, fn func(string, *s3.Object)
}

if !shouldSkip {
return fn(bucket, s3Obj)
shouldContinue, err := fn(bucket, s3Obj)
numResults++
if s3Lister.MaxResults != nil && numResults >= *s3Lister.MaxResults {
shouldContinue = false
}
return shouldContinue, err
}

return true, nil
})
if err != nil {
return err
return 0, err
}

if s3Lister.MaxResults != nil && numResults >= *s3Lister.MaxResults {
return numResults, nil
}
}

return nil
return numResults, nil
}
44 changes: 24 additions & 20 deletions pkg/operator/resources/job/batchapi/validations.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
awslib "github.com/cortexlabs/cortex/pkg/lib/aws"
cr "github.com/cortexlabs/cortex/pkg/lib/configreader"
"github.com/cortexlabs/cortex/pkg/lib/errors"
"github.com/cortexlabs/cortex/pkg/lib/pointer"
"github.com/cortexlabs/cortex/pkg/operator/resources/job"
"github.com/cortexlabs/cortex/pkg/operator/schema"
"github.com/gobwas/glob"
Expand Down Expand Up @@ -143,26 +144,30 @@ func validateS3Lister(s3Lister *schema.S3Lister) error {
}
}

filesFound := 0
for _, s3Path := range s3Lister.S3Paths {
if !awslib.IsValidS3Path(s3Path) {
return awslib.ErrorInvalidS3Path(s3Path)
}
}

err := s3IteratorFromLister(*s3Lister, func(objPath string, s3Obj *s3.Object) (bool, error) {
filesFound++
return false, nil
})
if err != nil {
return errors.Wrap(err, s3Path)
}
shortCircuitLister := schema.S3Lister{
S3Paths: s3Lister.S3Paths,
Includes: s3Lister.Includes,
Excludes: s3Lister.Excludes,
MaxResults: pointer.Int64(1),
}
numResults, err := s3IteratorFromLister(shortCircuitLister, func(objPath string, s3Obj *s3.Object) (bool, error) {
return false, nil
})
if err != nil {
return err
}

if filesFound > 0 {
return nil
}
if numResults == 0 {
return ErrorNoS3FilesFound()
}

return ErrorNoS3FilesFound()
return nil
}

func listFilesDryRun(s3Lister *schema.S3Lister) ([]string, error) {
Expand All @@ -171,15 +176,14 @@ func listFilesDryRun(s3Lister *schema.S3Lister) ([]string, error) {
if !awslib.IsValidS3Path(s3Path) {
return nil, awslib.ErrorInvalidS3Path(s3Path)
}
}

err := s3IteratorFromLister(*s3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) {
s3Files = append(s3Files, awslib.S3Path(bucket, *s3Obj.Key))
return true, nil
})

if err != nil {
return nil, errors.Wrap(err, s3Path)
}
_, err := s3IteratorFromLister(*s3Lister, func(bucket string, s3Obj *s3.Object) (bool, error) {
s3Files = append(s3Files, awslib.S3Path(bucket, *s3Obj.Key))
return true, nil
})
if err != nil {
return nil, err
}

if len(s3Files) == 0 {
Expand Down
7 changes: 4 additions & 3 deletions pkg/operator/schema/job_submission.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ type ItemList struct {
}

type S3Lister struct {
S3Paths []string `json:"s3_paths"` // s3://<bucket_name>/key
Includes []string `json:"includes"`
Excludes []string `json:"excludes"`
S3Paths []string `json:"s3_paths"` // s3://<bucket_name>/key
Includes []string `json:"includes"`
Excludes []string `json:"excludes"`
MaxResults *int64 `json:"-"` // this is not currently exposed to the user (it's used for validations)
}

type FilePathLister struct {
Expand Down