Skip to content

Commit

Permalink
Pull dataset from studio if not available locally (#901)
Browse files Browse the repository at this point in the history
* Pull dataset from studio if not available locally

If the following conditions are met, this will pull the dataset from Studio:
- The user is logged in to Studio.
- The dataset or version doesn't exist locally.
- The user has not passed studio=False to from_dataset.

In that case, the dataset is pulled from Studio before continuing
further.

The test is added to check for such behavior.

Closes #874

* Move token check to util

* Move to catalog
  • Loading branch information
amritghimire authored Feb 10, 2025
1 parent 0ff6d54 commit ea9a904
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 3 deletions.
25 changes: 25 additions & 0 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,31 @@ def register_dataset(
def get_dataset(self, name: str) -> DatasetRecord:
    """Return the dataset record with the given name from the local metastore.

    Raises:
        DatasetNotFoundError: if no dataset with this name exists locally
            (raised by the metastore; relied upon by
            ``get_dataset_with_remote_fallback``).
    """
    return self.metastore.get_dataset(name)

def get_dataset_with_remote_fallback(
    self, name: str, version: Optional[int] = None
) -> DatasetRecord:
    """Return a local dataset, pulling it from Studio when missing locally.

    First looks the dataset up in the local catalog. If the dataset (or the
    requested version of it) is not found locally, it is pulled from Studio
    and then looked up locally again.

    Args:
        name: dataset name.
        version: optional specific version to require; when omitted, any
            locally present dataset satisfies the lookup.

    Returns:
        The (possibly freshly pulled) local dataset record.
    """
    try:
        local_ds = self.get_dataset(name)
        # A locally present dataset without the requested version is
        # treated the same as a missing dataset: fall back to Studio.
        if version and not local_ds.has_version(version):
            raise DatasetVersionNotFoundError(
                f"Dataset {name} does not have version {version}"
            )
    except (DatasetNotFoundError, DatasetVersionNotFoundError):
        print("Dataset not found in local catalog, trying to get from studio")

        remote_ds_uri = f"{DATASET_PREFIX}{name}"
        if version:
            remote_ds_uri = f"{remote_ds_uri}@v{version}"

        self.pull_dataset(
            remote_ds_uri=remote_ds_uri,
            local_ds_name=name,
            local_ds_version=version,
        )
        # The pull registered the dataset locally; fetch the local record.
        return self.get_dataset(name)
    else:
        return local_ds

def get_dataset_with_version_uuid(self, uuid: str) -> DatasetRecord:
"""Returns dataset that contains version with specific uuid"""
for dataset in self.ls_datasets():
Expand Down
2 changes: 2 additions & 0 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ def from_dataset(
version: Optional[int] = None,
session: Optional[Session] = None,
settings: Optional[dict] = None,
fallback_to_remote: bool = True,
) -> "Self":
"""Get data from a saved Dataset. It returns the chain itself.
Expand All @@ -498,6 +499,7 @@ def from_dataset(
version=version,
session=session,
indexing_column_types=File._datachain_column_types,
fallback_to_remote=fallback_to_remote,
)
telemetry.send_event_once("class", "datachain_init", name=name, version=version)
if settings:
Expand Down
31 changes: 28 additions & 3 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@
partition_col_names,
partition_columns,
)
from datachain.dataset import DatasetStatus, RowDict
from datachain.error import DatasetNotFoundError, QueryScriptCancelError
from datachain.dataset import DATASET_PREFIX, DatasetStatus, RowDict
from datachain.error import (
DatasetNotFoundError,
QueryScriptCancelError,
)
from datachain.func.base import Function
from datachain.lib.udf import UDFAdapter, _get_cache
from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
from datachain.query.schema import C, UDFParamSpec, normalize_param
from datachain.query.session import Session
from datachain.remote.studio import is_token_set
from datachain.sql.functions.random import rand
from datachain.utils import (
batched,
Expand Down Expand Up @@ -1081,6 +1085,7 @@ def __init__(
session: Optional[Session] = None,
indexing_column_types: Optional[dict[str, Any]] = None,
in_memory: bool = False,
fallback_to_remote: bool = True,
) -> None:
self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
self.catalog = catalog or self.session.catalog
Expand All @@ -1097,7 +1102,12 @@ def __init__(
self.column_types: Optional[dict[str, Any]] = None

self.name = name
ds = self.catalog.get_dataset(name)

if fallback_to_remote and is_token_set():
ds = self.catalog.get_dataset_with_remote_fallback(name, version)
else:
ds = self.catalog.get_dataset(name)

self.version = version or ds.latest_version
self.feature_schema = ds.get_version(self.version).feature_schema
self.column_types = copy(ds.schema)
Expand All @@ -1112,6 +1122,21 @@ def __iter__(self):
def __or__(self, other):
    # `a | b` is shorthand for `a.union(b)` on dataset queries.
    return self.union(other)

def pull_dataset(self, name: str, version: Optional[int] = None) -> "DatasetRecord":
    """Pull a dataset from Studio into the local catalog and return it.

    Builds the remote dataset URI from *name* (and *version*, when given),
    delegates the actual transfer to the catalog, then returns the local
    record of the freshly pulled dataset.

    Args:
        name: dataset name.
        version: optional specific version to pull.

    Returns:
        The local DatasetRecord after the pull completes.
    """
    print("Dataset not found in local catalog, trying to get from studio")

    version_suffix = f"@v{version}" if version else ""
    remote_ds_uri = f"{DATASET_PREFIX}{name}{version_suffix}"

    catalog = self.catalog
    catalog.pull_dataset(
        remote_ds_uri=remote_ds_uri,
        local_ds_name=name,
        local_ds_version=version,
    )
    return catalog.get_dataset(name)

@staticmethod
def get_table() -> "TableClause":
table_name = "".join(
Expand Down
7 changes: 7 additions & 0 deletions src/datachain/remote/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ def _is_server_error(status_code: int) -> bool:
return str(status_code).startswith("5")


def is_token_set() -> bool:
    """Return True when a Studio auth token is available.

    A token counts as set when either the ``DVC_STUDIO_TOKEN`` environment
    variable is non-empty, or the ``studio.token`` entry exists in the
    config (any value, including an empty string, as long as the key is
    present).
    """
    if os.environ.get("DVC_STUDIO_TOKEN"):
        return True
    studio_conf = Config().read().get("studio", {})
    return studio_conf.get("token") is not None


def _parse_dates(obj: dict, date_fields: list[str]):
"""
Function that converts string ISO dates to datetime.datetime instances in object
Expand Down
40 changes: 40 additions & 0 deletions tests/func/test_pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from datachain.config import Config, ConfigLevel
from datachain.dataset import DatasetStatus
from datachain.error import DataChainError, DatasetNotFoundError
from datachain.lib.dc import DataChain
from datachain.query.session import Session
from datachain.utils import STUDIO_URL, JSONSerialize
from tests.data import ENTRIES
from tests.utils import assert_row_names, skip_if_not_sqlite, tree_from_path
Expand Down Expand Up @@ -267,6 +269,44 @@ def test_pull_dataset_success(
}


@pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
@skip_if_not_sqlite
def test_datachain_from_dataset_pull(
    mocker,
    cloud_test_catalog,
    remote_dataset_info,
    dataset_export,
    dataset_export_status,
    dataset_export_data_chunk,
):
    """DataChain.from_dataset pulls a missing dataset from Studio.

    The dataset is absent from the local catalog; with
    ``fallback_to_remote=True`` the chain should fetch it from the mocked
    Studio endpoints (provided by the ``remote_dataset_info`` /
    ``dataset_export*`` fixtures) and register it locally.
    """
    # Force the row fetcher to check export status every time, presumably so
    # the mocked Studio status responses drive the pull to completion —
    # NOTE(review): confirm against DatasetRowsFetcher internals.
    mocker.patch(
        "datachain.catalog.catalog.DatasetRowsFetcher.should_check_for_status",
        return_value=True,
    )

    catalog = cloud_test_catalog.catalog

    # Precondition: the dataset must not exist locally before the pull.
    with pytest.raises(DatasetNotFoundError):
        catalog.get_dataset("dogs")

    with Session("testSession", catalog=catalog):
        ds = DataChain.from_dataset(
            name="dogs",
            version=1,
            fallback_to_remote=True,
        )

        assert ds.dataset.name == "dogs"
        assert ds.dataset.latest_version == 1
        assert ds.dataset.status == DatasetStatus.COMPLETE

    # Postcondition: the pull registered the dataset in the local catalog.
    dataset = catalog.get_dataset("dogs")
    assert dataset.name == "dogs"


@pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
@skip_if_not_sqlite
def test_pull_dataset_wrong_dataset_uri_format(
Expand Down

0 comments on commit ea9a904

Please sign in to comment.