Fix for instantiation of pulled dataset #573

Merged 4 commits on Nov 13, 2024
src/datachain/catalog/catalog.py: 15 additions & 3 deletions
@@ -603,9 +603,10 @@ def enlist_source(
         )

         lst = Listing(
+            self.metastore.clone(),
             self.warehouse.clone(),
             Client.get_client(list_uri, self.cache, **self.client_config),
-            self.get_dataset(list_ds_name),
+            dataset_name=list_ds_name,
             object_name=object_name,
         )

@@ -698,9 +699,13 @@ def _row_to_node(d: dict[str, Any]) -> Node:

         client = self.get_client(source, **client_config)
         uri = client.uri
-        st = self.warehouse.clone()
         dataset_name, _, _, _ = DataChain.parse_uri(uri, self.session)
-        listing = Listing(st, client, self.get_dataset(dataset_name))
+        listing = Listing(
+            self.metastore.clone(),
+            self.warehouse.clone(),
+            client,
+            dataset_name=dataset_name,
+        )
         rows = DatasetQuery(
             name=dataset.name, version=ds_version, catalog=self
         ).to_db_records()
@@ -1350,6 +1355,13 @@ def _instantiate_dataset():
             # we will create new one if it doesn't exist
             pass

+        if dataset and version and dataset.has_version(version):
+            """No need to communicate with Studio at all"""
Review thread on the line above:

Contributor: [Q] Is there any validation that confirms dataset versions are "the same" between local and remote? E.g. my_dataset@v1 on the remote is full of PDFs while my_dataset@v1 locally is images.

Contributor (author): Good question. No, there is no such check at the moment; currently it is up to the user to make sure they don't "shadow" a remote dataset. To add this kind of validation we would need an equality function between datasets (I will create an issue for this), and my first feeling is that it is not trivial to implement (maybe I'm wrong).

Contributor (author): Actually, it seems like a simple UUID should be enough for this, which is trivial.
+            dataset_uri = create_dataset_uri(remote_dataset_name, version)
+            print(f"Local copy of dataset {dataset_uri} already present")
+            _instantiate_dataset()
+            return
+
         remote_dataset = self.get_remote_dataset(remote_dataset_name)
         # if version is not specified in uri, take the latest one
         if not version:
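As a follow-up to the thread above, the UUID idea could look roughly like the sketch below. This is a minimal illustration under assumptions not made by this PR: that a local dataset version records a UUID and that the remote dataset info exposes one as well; `is_same_dataset` and `local_version_uuid` are hypothetical names.

```python
from typing import Any, Optional


def is_same_dataset(local_uuid: Optional[str], remote_info: dict[str, Any]) -> bool:
    """Hypothetical check: a local dataset version "shadows" the remote one
    unless both carry the same UUID. Assumes the remote payload has a "uuid"
    key and the local version stored the UUID it was pulled/created with."""
    remote_uuid = remote_info.get("uuid")
    return bool(local_uuid) and local_uuid == remote_uuid


# Rough usage inside pull_dataset (names as in the hunk above; the
# local_version_uuid lookup is an assumption, not an existing attribute):
#
#     if dataset and version and dataset.has_version(version):
#         if not is_same_dataset(local_version_uuid, remote_dataset_info):
#             raise DataChainError(
#                 f"Local {remote_dataset_name}@v{version} differs from remote"
#             )
```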
src/datachain/listing.py: 24 additions & 7 deletions
@@ -1,6 +1,7 @@
 import glob
 import os
 from collections.abc import Iterable, Iterator
+from functools import cached_property
 from itertools import zip_longest
 from typing import TYPE_CHECKING, Optional

@@ -15,28 +16,34 @@
 if TYPE_CHECKING:
     from datachain.catalog.datasource import DataSource
     from datachain.client import Client
-    from datachain.data_storage import AbstractWarehouse
+    from datachain.data_storage import AbstractMetastore, AbstractWarehouse
     from datachain.dataset import DatasetRecord


 class Listing:
     def __init__(
         self,
+        metastore: "AbstractMetastore",
         warehouse: "AbstractWarehouse",
         client: "Client",
-        dataset: Optional["DatasetRecord"],
+        dataset_name: Optional["str"] = None,
+        dataset_version: Optional[int] = None,
         object_name: str = "file",
     ):
+        self.metastore = metastore
         self.warehouse = warehouse
         self.client = client
-        self.dataset = dataset  # dataset representing bucket listing
+        self.dataset_name = dataset_name  # dataset representing bucket listing
+        self.dataset_version = dataset_version  # dataset representing bucket listing
         self.object_name = object_name

     def clone(self) -> "Listing":
         return self.__class__(
+            self.metastore.clone(),
             self.warehouse.clone(),
             self.client,
-            self.dataset,
+            self.dataset_name,
+            self.dataset_version,
             self.object_name,
         )

@@ -53,12 +60,22 @@ def close(self) -> None:
     def uri(self):
         from datachain.lib.listing import listing_uri_from_name

-        return listing_uri_from_name(self.dataset.name)
+        assert self.dataset_name
+
+        return listing_uri_from_name(self.dataset_name)

-    @property
+    @cached_property
+    def dataset(self) -> "DatasetRecord":
+        assert self.dataset_name
+        return self.metastore.get_dataset(self.dataset_name)
+
+    @cached_property
     def dataset_rows(self):
+        dataset = self.dataset
         return self.warehouse.dataset_rows(
-            self.dataset, self.dataset.latest_version, object_name=self.object_name
+            dataset,
+            self.dataset_version or dataset.latest_version,
+            object_name=self.object_name,
         )

     def expand_path(self, path, use_glob=True) -> list[Node]:
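With this change a Listing can be built from just a dataset name: the DatasetRecord is looked up in the metastore lazily on first access and cached, and dataset_rows picks the requested version or falls back to the latest one. A minimal sketch of the new construction, assuming `catalog` and `client` objects from the surrounding setup; `make_listing` is a hypothetical wrapper, not part of the PR:

```python
from datachain.listing import Listing


def make_listing(catalog, client, name: str) -> Listing:
    # No metastore lookup happens here: the Listing is built from the name
    # alone, so it can be created before the listing dataset exists.
    return Listing(
        catalog.metastore.clone(),
        catalog.warehouse.clone(),
        client,
        dataset_name=name,
    )


# listing.dataset is resolved on first access via
# metastore.get_dataset(dataset_name) and cached (functools.cached_property);
# listing.dataset_rows then uses dataset_version if set, else latest_version.
```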
tests/func/test_pull.py: 33 additions & 6 deletions
@@ -6,12 +6,13 @@
 import pandas as pd
 import pytest

+from datachain.client.fsspec import Client
 from datachain.config import Config, ConfigLevel
 from datachain.dataset import DatasetStatus
 from datachain.error import DataChainError
 from datachain.utils import STUDIO_URL, JSONSerialize
 from tests.data import ENTRIES
-from tests.utils import assert_row_names, skip_if_not_sqlite
+from tests.utils import assert_row_names, skip_if_not_sqlite, tree_from_path


 @pytest.fixture(autouse=True)
@@ -38,10 +39,11 @@ def dog_entries():


 @pytest.fixture
-def dog_entries_parquet_lz4(dog_entries) -> bytes:
+def dog_entries_parquet_lz4(dog_entries, cloud_test_catalog) -> bytes:
     """
     Returns dogs entries in lz4 compressed parquet format
     """
+    src_uri = cloud_test_catalog.src_uri

     def _adapt_row(row):
         """
@@ -59,7 +61,7 @@ def _adapt_row(row):
         adapted["sys__id"] = 1
         adapted["sys__rand"] = 1
         adapted["file__location"] = ""
-        adapted["file__source"] = "s3://dogs"
+        adapted["file__source"] = src_uri
         return adapted

     dog_entries = [_adapt_row(e) for e in dog_entries]
@@ -138,14 +140,18 @@ def remote_dataset(remote_dataset_version, schema):

 @pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
 @pytest.mark.parametrize("dataset_uri", ["ds://dogs@v1", "ds://dogs"])
+@pytest.mark.parametrize("instantiate", [True, False])
 @skip_if_not_sqlite
 def test_pull_dataset_success(
     requests_mock,
     cloud_test_catalog,
     remote_dataset,
     dog_entries_parquet_lz4,
     dataset_uri,
+    instantiate,
 ):
+    src_uri = cloud_test_catalog.src_uri
+    working_dir = cloud_test_catalog.working_dir
     data_url = (
         "https://studio-blobvault.s3.amazonaws.com/datachain_ds_export_1_0.parquet.lz4"
     )
@@ -162,9 +168,16 @@ def test_pull_dataset_success(
     requests_mock.get(data_url, content=dog_entries_parquet_lz4)
     catalog = cloud_test_catalog.catalog

-    catalog.pull_dataset(dataset_uri, no_cp=True)
-    # trying to pull multiple times as it should work
-    catalog.pull_dataset(dataset_uri, no_cp=True)
+    dest = None
+
+    if instantiate:
+        dest = working_dir / "data"
+        dest.mkdir()
+        catalog.pull_dataset(dataset_uri, output=str(dest), no_cp=False)
+    else:
+        # trying to pull multiple times since that should work as well
+        catalog.pull_dataset(dataset_uri, no_cp=True)
+        catalog.pull_dataset(dataset_uri, no_cp=True)

     dataset = catalog.get_dataset("dogs")
     assert dataset.versions_values == [1]
@@ -192,6 +205,20 @@ def test_pull_dataset_success(
         },
     )

+    client = Client.get_client(src_uri, None)
+
+    if instantiate:
+        assert tree_from_path(dest) == {
+            f"{client.name}": {
+                "dogs": {
+                    "dog1": "woof",
+                    "dog2": "arf",
+                    "dog3": "bark",
+                    "others": {"dog4": "ruff"},
+                }
+            }
+        }
+

 @pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
 @skip_if_not_sqlite
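The new assertion compares the instantiated output directory against a nested dict of file contents. For reference, a helper like `tests.utils.tree_from_path` presumably does something along these lines; this is a sketch of the idea, not the project's actual implementation:

```python
from pathlib import Path


def tree_from_path(path) -> dict:
    """Build a nested dict mirroring a directory tree.

    Directories map to dicts, files map to their text contents, so the
    result can be compared directly against a literal like the one in
    test_pull_dataset_success above.
    """
    tree = {}
    for entry in Path(path).iterdir():
        if entry.is_dir():
            tree[entry.name] = tree_from_path(entry)
        else:
            tree[entry.name] = entry.read_text()
    return tree
```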
tests/unit/test_listing.py: 2 additions & 1 deletion
@@ -44,9 +44,10 @@ def listing(test_session):
     )

     return Listing(
+        catalog.metastore.clone(),
         catalog.warehouse.clone(),
         Client.get_client("s3://whatever", catalog.cache, **catalog.client_config),
-        catalog.get_dataset(dataset_name),
+        dataset_name=dataset_name,
         object_name="file",
     )