Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pull dataset from studio if not available locally #901

Merged
merged 5 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,31 @@
def get_dataset(self, name: str) -> DatasetRecord:
    """Look up a dataset by name in the local metastore.

    Purely local: no remote/Studio fallback is attempted here
    (see get_dataset_with_remote_fallback for that behavior).
    Presumably raises DatasetNotFoundError when the name is unknown —
    callers below catch that exception.
    """
    return self.metastore.get_dataset(name)

def get_dataset_with_remote_fallback(
    self, name: str, version: Optional[int] = None
) -> DatasetRecord:
    """Return a local dataset, pulling it from Studio when missing.

    First tries the local catalog; when the dataset (or the requested
    version) is absent locally, pulls it from Studio and retries the
    local lookup.
    """
    try:
        local_ds = self.get_dataset(name)
        # Requested version present (or no specific version asked for):
        # the local copy is good enough.
        if not version or local_ds.has_version(version):
            return local_ds
        raise DatasetVersionNotFoundError(
            f"Dataset {name} does not have version {version}"
        )
    except (DatasetNotFoundError, DatasetVersionNotFoundError):
        # Fall through to the Studio pull below.
        pass

    # Deliberately a print (not a logger call) so users know to expect
    # the extra latency of a remote fetch.
    print("Dataset not found in local catalog, trying to get from studio")

    remote_ds_uri = f"{DATASET_PREFIX}{name}"
    if version:
        remote_ds_uri = f"{remote_ds_uri}@v{version}"

    self.pull_dataset(
        remote_ds_uri=remote_ds_uri,
        local_ds_name=name,
        local_ds_version=version,
    )
    return self.get_dataset(name)

def get_dataset_with_version_uuid(self, uuid: str) -> DatasetRecord:
"""Returns dataset that contains version with specific uuid"""
for dataset in self.ls_datasets():
Expand Down
2 changes: 2 additions & 0 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ def from_dataset(
version: Optional[int] = None,
session: Optional[Session] = None,
settings: Optional[dict] = None,
fallback_to_remote: bool = True,
) -> "Self":
"""Get data from a saved Dataset. It returns the chain itself.

Expand All @@ -498,6 +499,7 @@ def from_dataset(
version=version,
session=session,
indexing_column_types=File._datachain_column_types,
fallback_to_remote=fallback_to_remote,
)
telemetry.send_event_once("class", "datachain_init", name=name, version=version)
if settings:
Expand Down
31 changes: 28 additions & 3 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@
partition_col_names,
partition_columns,
)
from datachain.dataset import DatasetStatus, RowDict
from datachain.error import DatasetNotFoundError, QueryScriptCancelError
from datachain.dataset import DATASET_PREFIX, DatasetStatus, RowDict
from datachain.error import (
DatasetNotFoundError,
QueryScriptCancelError,
)
from datachain.func.base import Function
from datachain.lib.udf import UDFAdapter, _get_cache
from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
from datachain.query.schema import C, UDFParamSpec, normalize_param
from datachain.query.session import Session
from datachain.remote.studio import is_token_set
from datachain.sql.functions.random import rand
from datachain.utils import (
batched,
Expand Down Expand Up @@ -1081,6 +1085,7 @@
session: Optional[Session] = None,
indexing_column_types: Optional[dict[str, Any]] = None,
in_memory: bool = False,
fallback_to_remote: bool = True,
) -> None:
self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
self.catalog = catalog or self.session.catalog
Expand All @@ -1097,7 +1102,12 @@
self.column_types: Optional[dict[str, Any]] = None

self.name = name
ds = self.catalog.get_dataset(name)

if fallback_to_remote and is_token_set():
ds = self.catalog.get_dataset_with_remote_fallback(name, version)
else:
ds = self.catalog.get_dataset(name)

self.version = version or ds.latest_version
self.feature_schema = ds.get_version(self.version).feature_schema
self.column_types = copy(ds.schema)
Expand All @@ -1112,6 +1122,21 @@
def __or__(self, other):
return self.union(other)

def pull_dataset(self, name: str, version: Optional[int] = None) -> "DatasetRecord":
    """Fetch the named dataset from Studio into the local catalog.

    Builds the remote dataset URI (optionally version-qualified), asks the
    catalog to pull it, then returns the freshly registered local record.
    """
    # Print rather than log: the user should know a remote fetch is
    # happening and expect the delay.
    print("Dataset not found in local catalog, trying to get from studio")

    uri = f"{DATASET_PREFIX}{name}"
    if version:
        uri = f"{uri}@v{version}"

    self.catalog.pull_dataset(
        remote_ds_uri=uri,
        local_ds_name=name,
        local_ds_version=version,
    )
    return self.catalog.get_dataset(name)

Check warning on line 1138 in src/datachain/query/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/query/dataset.py#L1138

Added line #L1138 was not covered by tests

@staticmethod
def get_table() -> "TableClause":
table_name = "".join(
Expand Down
7 changes: 7 additions & 0 deletions src/datachain/remote/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ def _is_server_error(status_code: int) -> bool:
return str(status_code).startswith("5")


def is_token_set() -> bool:
    """Report whether a Studio access token is configured.

    True when the DVC_STUDIO_TOKEN environment variable is non-empty,
    or when a "token" key exists in the "studio" section of the config
    file (any value, including falsy ones, counts there).
    """
    if os.environ.get("DVC_STUDIO_TOKEN"):
        return True
    studio_config = Config().read().get("studio", {})
    return studio_config.get("token") is not None


def _parse_dates(obj: dict, date_fields: list[str]):
"""
Function that converts string ISO dates to datetime.datetime instances in object
Expand Down
40 changes: 40 additions & 0 deletions tests/func/test_pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from datachain.config import Config, ConfigLevel
from datachain.dataset import DatasetStatus
from datachain.error import DataChainError, DatasetNotFoundError
from datachain.lib.dc import DataChain
from datachain.query.session import Session
from datachain.utils import STUDIO_URL, JSONSerialize
from tests.data import ENTRIES
from tests.utils import assert_row_names, skip_if_not_sqlite, tree_from_path
Expand Down Expand Up @@ -267,6 +269,44 @@ def test_pull_dataset_success(
}


@pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
@skip_if_not_sqlite
def test_datachain_from_dataset_pull(
    mocker,
    cloud_test_catalog,
    remote_dataset_info,
    dataset_export,
    dataset_export_status,
    dataset_export_data_chunk,
):
    # Verify that DataChain.from_dataset() pulls the dataset from Studio
    # when it is not available in the local catalog.
    # NOTE(review): the remote_dataset_* / dataset_export_* fixtures
    # presumably stub out the Studio API responses — confirm in conftest.
    mocker.patch(
        "datachain.catalog.catalog.DatasetRowsFetcher.should_check_for_status",
        return_value=True,
    )

    catalog = cloud_test_catalog.catalog

    # Make sure the dataset is not available locally at first.
    with pytest.raises(DatasetNotFoundError):
        catalog.get_dataset("dogs")

    with Session("testSession", catalog=catalog):
        ds = DataChain.from_dataset(
            name="dogs",
            version=1,
            fallback_to_remote=True,
        )

        assert ds.dataset.name == "dogs"
        assert ds.dataset.latest_version == 1
        assert ds.dataset.status == DatasetStatus.COMPLETE

    # Check that the dataset is available locally after pulling.
    dataset = catalog.get_dataset("dogs")
    assert dataset.name == "dogs"


@pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
@skip_if_not_sqlite
def test_pull_dataset_wrong_dataset_uri_format(
Expand Down