diff --git a/flytecopilot/data/download.go b/flytecopilot/data/download.go index 0fd1f10bd9..5f62ff627b 100644 --- a/flytecopilot/data/download.go +++ b/flytecopilot/data/download.go @@ -8,8 +8,10 @@ import ( "io/ioutil" "os" "path" + "path/filepath" "reflect" "strconv" + "sync" "github.com/ghodss/yaml" "github.com/golang/protobuf/jsonpb" @@ -31,40 +33,120 @@ type Downloader struct { mode core.IOStrategy_DownloadMode } -// TODO add support for multipart blobs -func (d Downloader) handleBlob(ctx context.Context, blob *core.Blob, toFilePath string) (interface{}, error) { - ref := storage.DataReference(blob.Uri) - scheme, _, _, err := ref.Split() +func (d Downloader) handleBlob(ctx context.Context, blob *core.Blob, toPath string) (interface{}, error) { + blobRef := storage.DataReference(blob.Uri) + scheme, c, _, err := blobRef.Split() if err != nil { return nil, errors.Wrapf(err, "Blob uri incorrectly formatted") } - var reader io.ReadCloser - if scheme == "http" || scheme == "https" { - reader, err = DownloadFileFromHTTP(ctx, ref) - } else { - if blob.GetMetadata().GetType().Dimensionality == core.BlobType_MULTIPART { - logger.Warnf(ctx, "Currently only single part blobs are supported, we will force multipart to be 'path/00000'") - ref, err = d.store.ConstructReference(ctx, ref, "000000") - if err != nil { + + if blob.GetMetadata().GetType().Dimensionality == core.BlobType_MULTIPART { + maxItems := 100 + cursor := storage.NewCursorAtStart() + var items []storage.DataReference + var keys []string + for { + items, cursor, err = d.store.List(ctx, blobRef, maxItems, cursor) + if err != nil || len(items) == 0 { + logger.Errorf(ctx, "failed to collect items from multipart blob [%s]", blobRef) return nil, err } + for _, item := range items { + keys = append(keys, item.String()) + } + if storage.IsCursorEnd(cursor) { + break + } } - reader, err = DownloadFileFromStorage(ctx, ref, d.store) + + success := 0 + var mu sync.Mutex + var wg sync.WaitGroup + for _, k := range keys { + absPath := fmt.Sprintf("%s://%s/%s", scheme, c, k) + + wg.Add(1) + go func() { + defer wg.Done() + defer func() { + if err := recover(); err != nil { + logger.Errorf(ctx, "recover receives error: %s", err) + } + }() + + ref := storage.DataReference(absPath) + reader, err := DownloadFileFromStorage(ctx, ref, d.store) + if err != nil { + logger.Errorf(ctx, "Failed to download from ref [%s]", ref) + return + } + defer func() { + err := reader.Close() + if err != nil { + logger.Errorf(ctx, "failed to close Blob read stream @ref [%s]. Error: %s", ref, err) + } + }() + + _, _, k, err := ref.Split() + if err != nil { + logger.Errorf(ctx, "Failed to parse ref [%s]", ref) + return + } + newPath := filepath.Join(toPath, k) + dir := filepath.Dir(newPath) + + mu.Lock() + // 0755: the directory can be read by anyone but can only be written by the owner + os.MkdirAll(dir, 0755) + mu.Unlock() + writer, err := os.Create(newPath) + if err != nil { + logger.Errorf(ctx, "failed to open file at path %s", newPath) + return + } + defer func() { + err := writer.Close() + if err != nil { + logger.Errorf(ctx, "failed to close File write stream. Error: %s", err) + } + }() + + _, err = io.Copy(writer, reader) + if err != nil { + logger.Errorf(ctx, "failed to write remote data to local filesystem") + return + } + mu.Lock() + success += 1 + mu.Unlock() + }() + } + wg.Wait() + logger.Infof(ctx, "Successfully copied [%d] remote files from [%s] to local [%s]", success, blobRef, toPath) + return toPath, nil + } + + // reader should be declared here (avoid being shared across all goroutines) + var reader io.ReadCloser + if scheme == "http" || scheme == "https" { + reader, err = DownloadFileFromHTTP(ctx, blobRef) + } else { + reader, err = DownloadFileFromStorage(ctx, blobRef, d.store) } if err != nil { - logger.Errorf(ctx, "Failed to download from ref [%s]", ref) + logger.Errorf(ctx, "Failed to download from ref [%s]", blobRef) return nil, err } defer func() { err := reader.Close() if err != nil { - logger.Errorf(ctx, "failed to close Blob read stream @ref [%s]. Error: %s", ref, err) + logger.Errorf(ctx, "failed to close Blob read stream @ref [%s]. Error: %s", blobRef, err) } }() - writer, err := os.Create(toFilePath) + writer, err := os.Create(toPath) if err != nil { - return nil, errors.Wrapf(err, "failed to open file at path %s", toFilePath) + return nil, errors.Wrapf(err, "failed to open file at path %s", toPath) } defer func() { err := writer.Close() @@ -76,12 +158,11 @@ func (d Downloader) handleBlob(ctx context.Context, blob *core.Blob, toFilePath if err != nil { return nil, errors.Wrapf(err, "failed to write remote data to local filesystem") } - logger.Infof(ctx, "Successfully copied [%d] bytes remote data from [%s] to local [%s]", v, ref, toFilePath) - return toFilePath, nil + logger.Infof(ctx, "Successfully copied [%d] bytes remote data from [%s] to local [%s]", v, blobRef, toPath) + return toPath, nil } func (d Downloader) handleSchema(ctx context.Context, schema *core.Schema, toFilePath string) (interface{}, error) { - // TODO Handle schema type return d.handleBlob(ctx, &core.Blob{Uri: schema.Uri, Metadata: &core.BlobMetadata{Type: &core.BlobType{Dimensionality: core.BlobType_MULTIPART}}}, toFilePath) } diff --git a/flytecopilot/data/download_test.go b/flytecopilot/data/download_test.go new file mode 100644 index 0000000000..1f3b3a7be6 --- /dev/null +++ b/flytecopilot/data/download_test.go @@ -0,0 +1,151 @@ +package data + +import ( + "bytes" + "context" + "os" + "path/filepath" + "testing" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flytestdlib/promutils" + "github.com/flyteorg/flyte/flytestdlib/storage" + + "github.com/stretchr/testify/assert" +) + +func TestHandleBlobMultipart(t *testing.T) { + t.Run("Successful Query", func(t *testing.T) { + s, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + ref := storage.DataReference("s3://container/folder/file1") + s.WriteRaw(context.Background(), ref, 0, storage.Options{}, bytes.NewReader([]byte{})) + ref = storage.DataReference("s3://container/folder/file2") + s.WriteRaw(context.Background(), ref, 0, storage.Options{}, bytes.NewReader([]byte{})) + + d := Downloader{store: s} + + blob := &core.Blob{ + Uri: "s3://container/folder", + Metadata: &core.BlobMetadata{ + Type: &core.BlobType{ + Dimensionality: core.BlobType_MULTIPART, + }, + }, + } + + toPath := "./inputs" + defer func() { + err := os.RemoveAll(toPath) + if err != nil { + t.Errorf("Failed to delete directory: %v", err) + } + }() + + result, err := d.handleBlob(context.Background(), blob, toPath) + assert.NoError(t, err) + assert.Equal(t, toPath, result) + + // Check if files were created and data written + for _, file := range []string{"file1", "file2"} { + if _, err := os.Stat(filepath.Join(toPath, "folder", file)); os.IsNotExist(err) { + t.Errorf("expected file %s to exist", file) + } + } + }) + + t.Run("No Items", func(t *testing.T) { + s, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + + d := Downloader{store: s} + + blob := &core.Blob{ + Uri: "s3://container/folder", + Metadata: &core.BlobMetadata{ + Type: &core.BlobType{ + Dimensionality: core.BlobType_MULTIPART, + }, + }, + } + + toPath := "./inputs" + defer func() { + err := os.RemoveAll(toPath) + if err != nil { + t.Errorf("Failed to delete directory: %v", err) + } + }() + + result, err := d.handleBlob(context.Background(), blob, toPath) + assert.Error(t, err) + assert.Nil(t, result) + }) +} + +func TestHandleBlobSinglePart(t *testing.T) { + s, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + ref := storage.DataReference("s3://container/file") + s.WriteRaw(context.Background(), ref, 0, storage.Options{}, bytes.NewReader([]byte{})) + + d := Downloader{store: s} + + blob := &core.Blob{ + Uri: "s3://container/file", + Metadata: &core.BlobMetadata{ + Type: &core.BlobType{ + Dimensionality: core.BlobType_SINGLE, + }, + }, + } + + toPath := "./input" + defer func() { + err := os.RemoveAll(toPath) + if err != nil { + t.Errorf("Failed to delete file: %v", err) + } + }() + + result, err := d.handleBlob(context.Background(), blob, toPath) + assert.NoError(t, err) + assert.Equal(t, toPath, result) + + // Check if files were created and data written + if _, err := os.Stat(toPath); os.IsNotExist(err) { + t.Errorf("expected file %s to exist", toPath) + } +} + +func TestHandleBlobHTTP(t *testing.T) { + s, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + d := Downloader{store: s} + + blob := &core.Blob{ + Uri: "https://raw.githubusercontent.com/flyteorg/flyte/master/README.md", + Metadata: &core.BlobMetadata{ + Type: &core.BlobType{ + Dimensionality: core.BlobType_SINGLE, + }, + }, + } + + toPath := "./input" + defer func() { + err := os.RemoveAll(toPath) + if err != nil { + t.Errorf("Failed to delete file: %v", err) + } + }() + + result, err := d.handleBlob(context.Background(), blob, toPath) + assert.NoError(t, err) + assert.Equal(t, toPath, result) + + // Check if files were created and data written + if _, err := os.Stat(toPath); os.IsNotExist(err) { + t.Errorf("expected file %s to exist", toPath) + } +} diff --git a/flytestdlib/storage/mem_store.go b/flytestdlib/storage/mem_store.go index 94083f6646..4b2a4b7077 100644 --- a/flytestdlib/storage/mem_store.go +++ b/flytestdlib/storage/mem_store.go @@ -9,6 +9,9 @@ import ( "io" "io/ioutil" "os" + "strings" + + "github.com/flyteorg/flyte/flytestdlib/logger" ) type rawFile = []byte @@ -55,7 +58,25 @@ func (s *InMemoryStore) Head(ctx context.Context, reference DataReference) (Meta } func (s *InMemoryStore) List(ctx context.Context, reference DataReference, maxItems int, cursor Cursor) ([]DataReference, Cursor, error) { - return nil, NewCursorAtEnd(), fmt.Errorf("Not implemented yet") + var items []DataReference + prefix := strings.TrimSuffix(string(reference), "/") + "/" + + for ref := range s.cache { + if strings.HasPrefix(ref.String(), prefix) { + _, _, k, err := ref.Split() + if err != nil { + logger.Errorf(ctx, "failed to split reference [%s]", ref) + continue + } + items = append(items, DataReference(k)) + } + } + + if len(items) == 0 { + return nil, NewCursorAtEnd(), os.ErrNotExist + } + + return items, NewCursorAtEnd(), nil } func (s *InMemoryStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { diff --git a/flytestdlib/storage/storage.go b/flytestdlib/storage/storage.go index 52e6905513..0779beffe9 100644 --- a/flytestdlib/storage/storage.go +++ b/flytestdlib/storage/storage.go @@ -75,6 +75,10 @@ func NewCursorFromCustomPosition(customPosition string) Cursor { } } +func IsCursorEnd(cursor Cursor) bool { + return cursor.cursorState == AtEndCursorState +} + // DataStore is a simplified interface for accessing and storing data in one of the Cloud stores. // Today we rely on Stow for multi-cloud support, but this interface abstracts that part type DataStore struct { @@ -113,7 +117,7 @@ type RawStore interface { // Head gets metadata about the reference. This should generally be a light weight operation. Head(ctx context.Context, reference DataReference) (Metadata, error) - // List gets a list of items given a prefix, using a paginated API + // List gets a list of items (relative path to the reference input) given a prefix, using a paginated API List(ctx context.Context, reference DataReference, maxItems int, cursor Cursor) ([]DataReference, Cursor, error) // ReadRaw retrieves a byte array from the Blob store or an error