diff --git a/README.md b/README.md index 413c199..5438767 100644 --- a/README.md +++ b/README.md @@ -9,4 +9,3 @@ Work in progress. ``` $ paddle help ``` - diff --git a/cli/data/cmd.go b/cli/data/cmd.go index 7ba51e0..7c2907e 100644 --- a/cli/data/cmd.go +++ b/cli/data/cmd.go @@ -27,13 +27,5 @@ var DataCmd = &cobra.Command{ func init() { DataCmd.AddCommand(commitCmd) - // Here you will define your flags and configuration settings. - - // Cobra supports Persistent Flags which will work for this command - // and all subcommands, e.g.: - // dataCmd.PersistentFlags().String("foo", "", "A help for foo") - - // Cobra supports local flags which will only run when this command - // is called directly, e.g.: - // dataCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") + DataCmd.AddCommand(getCmd) } diff --git a/cli/data/commit.go b/cli/data/commit.go index 1675471..b61e5a6 100644 --- a/cli/data/commit.go +++ b/cli/data/commit.go @@ -29,7 +29,7 @@ import ( "time" ) -var branch string +var commitBranch string var commitCmd = &cobra.Command{ Use: "commit [source path] [version]", @@ -39,23 +39,18 @@ var commitCmd = &cobra.Command{ Example: -$ paddle data commit -b experimantal source/path version1 +$ paddle data commit -b experimental source/path trained-model/version1 `, Run: func(cmd *cobra.Command, args []string) { if !viper.IsSet("bucket") { exitErrorf("Bucket not defined. Please define 'bucket' in your config file.") } - commitPath(args[0], viper.GetString("bucket"), args[1], branch) + commitPath(args[0], viper.GetString("bucket"), args[1], commitBranch) }, } func init() { - commitCmd.Flags().StringVarP(&branch, "branch", "b", "master", "Branch to work on") -} - -func exitErrorf(msg string, args ...interface{}) { - fmt.Fprintf(os.Stderr, msg+"\n", args...) - os.Exit(1) + commitCmd.Flags().StringVarP(&commitBranch, "branch", "b", "master", "Branch to work on") } func commitPath(path string, bucket string, version string, branch string) { diff --git a/cli/data/common.go b/cli/data/common.go new file mode 100644 index 0000000..06dd8ec --- /dev/null +++ b/cli/data/common.go @@ -0,0 +1,11 @@ +package data + +import ( + "fmt" + "os" +) + +func exitErrorf(msg string, args ...interface{}) { + fmt.Fprintf(os.Stderr, msg+"\n", args...) + os.Exit(1) +} diff --git a/cli/data/get.go b/cli/data/get.go new file mode 100644 index 0000000..24e0b8b --- /dev/null +++ b/cli/data/get.go @@ -0,0 +1,133 @@ +// Copyright © 2017 RooFoods LTD +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package data + +import ( + "bytes" + "fmt" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/spf13/cobra" + "github.com/spf13/viper" + "io" + "os" + "path/filepath" + "strings" +) + +var getBranch string +var getCommitPath string + +var getCmd = &cobra.Command{ + Use: "get [version] [destination path]", + Short: "Fetch data from S3", + Args: cobra.ExactArgs(2), + Long: `Fetch data from a S3 versioned path. + +Example: + +$ paddle data get -b experimental trained-model/version1 dest/path +`, + Run: func(cmd *cobra.Command, args []string) { + if !viper.IsSet("bucket") { + exitErrorf("Bucket not defined. Please define 'bucket' in your config file.") + } + fetchPath(viper.GetString("bucket"), args[0], getBranch, getCommitPath, args[1]) + }, +} + +func init() { + getCmd.Flags().StringVarP(&getBranch, "branch", "b", "master", "Branch to work on") + getCmd.Flags().StringVarP(&getCommitPath, "path", "p", "HEAD", "Path to fetch (instead of HEAD)") +} + +func fetchPath(bucket string, version string, branch string, path string, destination string) { + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + if path == "HEAD" { + svc := s3.New(sess) + headPath := fmt.Sprintf("%s/%s/HEAD", version, branch) + fmt.Println(headPath) + out, err := svc.GetObject(&s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(headPath), + }) + if err != nil { + exitErrorf("%v", err) + } + buf := new(bytes.Buffer) + buf.ReadFrom(out.Body) + path = buf.String() + } else { + path = fmt.Sprintf("%s/%s/%s", version, branch, path) + } + fmt.Println("Fetching " + path) + getBucketObjects(sess, bucket, path, destination) +} + +func getBucketObjects(sess *session.Session, bucket string, prefix string, dest string) { + query := &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket), + Prefix: aws.String(prefix), + } + svc := s3.New(sess) + + truncatedListing := true + + for truncatedListing { + resp, err := svc.ListObjectsV2(query) + + if err != nil { + fmt.Println(err.Error()) + return + } + getObjectsAll(bucket, resp, svc, prefix, dest) + query.ContinuationToken = resp.NextContinuationToken + truncatedListing = *resp.IsTruncated + } +} + +func getObjectsAll(bucket string, bucketObjectsList *s3.ListObjectsV2Output, s3Client *s3.S3, prefix string, dest string) { + for _, key := range bucketObjectsList.Contents { + destFilename := *key.Key + if strings.HasSuffix(*key.Key, "/") { + fmt.Println("Got a directory") + continue + } + out, err := s3Client.GetObject(&s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: key.Key, + }) + if err != nil { + exitErrorf("%v", err) + } + destFilePath := dest + "/" + strings.TrimPrefix(destFilename, prefix+"/") + err = os.MkdirAll(filepath.Dir(destFilePath), 0777) + fmt.Print(destFilePath) + destFile, err := os.Create(destFilePath) + if err != nil { + exitErrorf("%v", err) + } + bytes, err := io.Copy(destFile, out.Body) + if err != nil { + exitErrorf("%v", err) + } + fmt.Printf(" -> %d bytes\n", bytes) + out.Body.Close() + destFile.Close() + } +}