Skip to content
This repository has been archived by the owner on Dec 9, 2022. It is now read-only.

Commit

Permalink
Refactor of get
Browse files Browse the repository at this point in the history
  • Loading branch information
Pedro Cunha committed Sep 25, 2017
1 parent 595d01b commit 8bf1d6b
Showing 1 changed file with 66 additions and 33 deletions.
99 changes: 66 additions & 33 deletions cli/data/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,21 @@ import (
"strings"
)

// S3Target identifies a location in S3 as a bucket, a key prefix and a
// path beneath that prefix.
type S3Target struct {
	bucket string // name of the S3 bucket
	prefix string // key prefix under which the path lives
	path   string // path beneath prefix; may be the literal "HEAD"
}

// copy returns a pointer to a shallow copy of t, so the caller can change
// fields on the result without modifying the original target.
func (t *S3Target) copy() *S3Target {
	clone := *t
	return &clone
}

// fullPath returns the bucket, prefix and path joined by "/" separators.
func (t *S3Target) fullPath() string {
	return fmt.Sprintf("%s/%s/%s", t.bucket, t.prefix, t.path)
}

var getBranch string
var getCommitPath string

Expand All @@ -44,7 +59,14 @@ $ paddle data get -b experimental trained-model/version1 dest/path
if !viper.IsSet("bucket") {
exitErrorf("Bucket not defined. Please define 'bucket' in your config file.")
}
fetchPath(viper.GetString("bucket"), args[0], getBranch, getCommitPath, args[1])

source := S3Target{
bucket: viper.GetString("bucket"),
prefix: fmt.Sprintf("%s/%s", args[0], getBranch),
path: getCommitPath,
}

copyPathToDestination(&source, args[1])
},
}

Expand All @@ -53,69 +75,80 @@ func init() {
getCmd.Flags().StringVarP(&getCommitPath, "path", "p", "HEAD", "Path to fetch (instead of HEAD)")
}

func fetchPath(bucket string, version string, branch string, path string, destination string) {
sess := session.Must(session.NewSessionWithOptions(session.Options{
func copyPathToDestination(source *S3Target, destination string) {
session := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))

if path == "HEAD" {
svc := s3.New(sess)
headPath := fmt.Sprintf("%s/%s/HEAD", version, branch)
fmt.Println(headPath)
out, err := svc.GetObject(&s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(headPath),
})
if err != nil {
exitErrorf("%v", err)
}
buf := new(bytes.Buffer)
buf.ReadFrom(out.Body)
path = buf.String()
} else {
path = fmt.Sprintf("%s/%s/%s", version, branch, path)
/*
* HEAD contains the path to latest folder
*/
if source.path == "HEAD" {
source = source.copy()
source.path = readHEAD(session, source)
}

fmt.Println("Copying " + source.fullPath() + " to " + destination)
copy(session, source, destination)
}

func readHEAD(session *session.Session, source *S3Target) string {
svc := s3.New(session)
key := fmt.Sprintf("%s/HEAD", source.prefix)

out, err := svc.GetObject(&s3.GetObjectInput{
Bucket: aws.String(source.bucket),
Key: aws.String(key),
})

if err != nil {
exitErrorf("%v", err)
}
fmt.Println("Fetching " + path)
getBucketObjects(sess, bucket, path, destination)

buf := new(bytes.Buffer)
buf.ReadFrom(out.Body)
return buf.String()
}

func getBucketObjects(sess *session.Session, bucket string, prefix string, dest string) {
func copy(session *session.Session, source *S3Target, destination string) {
query := &s3.ListObjectsV2Input{
Bucket: aws.String(bucket),
Prefix: aws.String(prefix),
Bucket: aws.String(source.bucket),
Prefix: aws.String(source.prefix + "/" + source.path),
}
svc := s3.New(sess)
svc := s3.New(session)

truncatedListing := true

for truncatedListing {
resp, err := svc.ListObjectsV2(query)
response, err := svc.ListObjectsV2(query)

if err != nil {
fmt.Println(err.Error())
return
}
getObjectsAll(bucket, resp, svc, prefix, dest)
query.ContinuationToken = resp.NextContinuationToken
truncatedListing = *resp.IsTruncated
copyToLocalFiles(svc, response.Contents, source, destination)

// Check if more results
query.ContinuationToken = response.NextContinuationToken
truncatedListing = *response.IsTruncated
}
}

func getObjectsAll(bucket string, bucketObjectsList *s3.ListObjectsV2Output, s3Client *s3.S3, prefix string, dest string) {
for _, key := range bucketObjectsList.Contents {
func copyToLocalFiles(s3Client *s3.S3, objects []*s3.Object, source *S3Target, destination string) {
for _, key := range objects {
destFilename := *key.Key
if strings.HasSuffix(*key.Key, "/") {
fmt.Println("Got a directory")
continue
}
out, err := s3Client.GetObject(&s3.GetObjectInput{
Bucket: aws.String(bucket),
Bucket: aws.String(source.bucket),
Key: key.Key,
})
if err != nil {
exitErrorf("%v", err)
}
destFilePath := dest + "/" + strings.TrimPrefix(destFilename, prefix+"/")
destFilePath := destination + "/" + strings.TrimPrefix(destFilename, source.prefix + "/")
err = os.MkdirAll(filepath.Dir(destFilePath), 0777)
fmt.Print(destFilePath)
destFile, err := os.Create(destFilePath)
Expand Down

0 comments on commit 8bf1d6b

Please sign in to comment.