Skip to content

Commit

Permalink
Support non-seekable stdin (- arg) in "fs upload" command
Browse files Browse the repository at this point in the history
Parity with `aws s3 cp - ...` on this.
  • Loading branch information
arielshaqed committed Mar 24, 2021
1 parent 1740409 commit e0db727
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 17 deletions.
4 changes: 3 additions & 1 deletion cmd/lakectl/cmd/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ var fsCatCmd = &cobra.Command{
func upload(ctx context.Context, client api.Client, sourcePathname string, destURI *uri.URI, direct bool) (*models.ObjectStats, error) {
fp := OpenByPath(sourcePathname)
defer func() {
_ = fp.Close()
if err := fp.Close(); err != nil {
DieErr(fmt.Errorf("close: %w", err))
}
}()

if direct {
Expand Down
45 changes: 43 additions & 2 deletions cmd/lakectl/cmd/input.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"fmt"
"io"
"os"

Expand Down Expand Up @@ -55,10 +56,50 @@ func (nc *nopCloser) Close() error {
return nil
}

// OpenByPath returns a reader from the given path. If path is "-", it'll return Stdin
// deleteOnClose wraps a File to be a ReadSeekCloser that deletes itself when closed.
//
// Read and Seek are promoted from the embedded *os.File, so only Close needs to be
// overridden here.
type deleteOnClose struct {
	*os.File
}

// Close closes the underlying file and then removes it from the filesystem.
//
// The handle is closed BEFORE the file is removed: this wrapper is meant for
// platforms that refuse to remove a file that is still open, so removing first
// would fail there and leave the data behind.
//
// If removal fails, the returned error wraps the removal error (the file is
// still closed). Otherwise any error from closing the file is returned.
func (d *deleteOnClose) Close() error {
	name := d.Name()
	closeErr := d.File.Close()
	if err := os.Remove(name); err != nil {
		// The data definitely stays on disk; this matters more than a close failure.
		return fmt.Errorf("delete on close: %w", err)
	}
	return closeErr
}

// OpenByPath returns a reader from the given path. If path is "-", it reads from Stdin; when
// Stdin is not seekable, it first copies Stdin into a temporary file and returns that copy,
// which is either already deleted (POSIX) or deletes itself on close (non-POSIX, notably
// Windows).
func OpenByPath(path string) io.ReadSeekCloser {
if path == StdinFileName {
// read from stdin
if !isSeekable(os.Stdin) {
temp, err := os.CreateTemp("", "lakectl-stdin")
if err != nil {
DieErr(fmt.Errorf("create temporary file to buffer stdin: %w", err))
}
if _, err = io.Copy(temp, os.Stdin); err != nil {
DieErr(fmt.Errorf("copy stdin to temporary file: %w", err))
}
if _, err = temp.Seek(0, io.SeekStart); err != nil {
DieErr(fmt.Errorf("rewind temporary copied file: %w", err))
}
// Try to delete the file. This will fail on Windows, we shall try to
// delete on close anyway.
if os.Remove(temp.Name()) != nil {
return &deleteOnClose{temp}
}
return temp
}
return &nopCloser{os.Stdin}
}
fp, err := os.Open(path)
Expand Down
6 changes: 6 additions & 0 deletions cmd/lakectl/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ func getClient() api.Client {
return client
}

// isSeekable reports whether f can be seeked, judged by probing it with a
// zero-offset seek relative to the current position (which does not move it).
// The probe is a heuristic: any error is taken to mean "not seekable".
func isSeekable(f io.Seeker) bool {
	if _, seekErr := f.Seek(0, io.SeekCurrent); seekErr != nil {
		return false
	}
	return true
}

// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
Expand Down
5 changes: 0 additions & 5 deletions cmd/lakectl/cmd/sst.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@ import (
"google.golang.org/protobuf/proto"
)

// isSeekable reports whether seeking on f appears to work. It probes with a
// no-op seek (offset 0 from the current position) and treats any error as
// "not seekable" -- a simple heuristic, but sufficient for its purpose.
func isSeekable(f io.Seeker) bool {
	_, probeErr := f.Seek(0, io.SeekCurrent)
	seekable := probeErr == nil
	return seekable
}

func readStdin() (pebblesst.ReadableFile, error) {
// test if stdin is seekable
if isSeekable(os.Stdin) {
Expand Down
28 changes: 19 additions & 9 deletions pkg/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,20 @@ func (c *client) StageObject(ctx context.Context, repoID, branchID, path string,
return resp.GetPayload(), nil
}

// readSize returns the size of r, measured as the absolute offset of its end,
// and leaves r positioned where it was when readSize was called.
//
// It returns a wrapped error (and a zero size) if r's current position cannot
// be determined, if seeking to the end fails, or if the original position
// cannot be restored afterwards.
func readSize(r io.Seeker) (int64, error) {
	cur, err := r.Seek(0, io.SeekCurrent)
	if err != nil {
		return 0, fmt.Errorf("tell: %w", err)
	}
	end, err := r.Seek(0, io.SeekEnd)
	if err != nil {
		return 0, fmt.Errorf("seek to end: %w", err)
	}
	// Put the cursor back so the caller can keep reading from where it left off.
	if _, err := r.Seek(cur, io.SeekStart); err != nil {
		return 0, fmt.Errorf("restore position: %w", err)
	}
	return end, nil
}

func (c *client) ClientUpload(ctx context.Context, repoID, branchID, path string, metadata map[string]string, contents io.ReadSeeker) (*models.ObjectStats, error) {
resp, err := c.remote.Staging.GetPhysicalAddress(&staging.GetPhysicalAddressParams{
Repository: repoID,
Expand All @@ -754,6 +768,11 @@ func (c *client) ClientUpload(ctx context.Context, repoID, branchID, path string
}
stagingLocation := resp.GetPayload()

size, err := readSize(contents)
if err != nil {
return nil, fmt.Errorf("readSize: %w", err)
}

for { // Return from inside loop
physicalAddress, err := url.Parse(stagingLocation.PhysicalAddress)

Expand Down Expand Up @@ -786,15 +805,6 @@ func (c *client) ClientUpload(ctx context.Context, repoID, branchID, path string
return nil, fmt.Errorf("upload to backing store %v: %w", physicalAddress, err)
}

size, err := contents.Seek(0, io.SeekEnd)
if err != nil {
return nil, fmt.Errorf("read size: %w", err)
}
_, err = contents.Seek(0, io.SeekStart)
if err != nil {
return nil, fmt.Errorf("rewind: %w", err)
}

_, err = c.remote.Staging.LinkPhysicalAddress(&staging.LinkPhysicalAddressParams{
Repository: repoID,
Branch: branchID,
Expand Down

0 comments on commit e0db727

Please sign in to comment.