making cp non default when pulling dataset (#720)
ilongin authored Dec 17, 2024
1 parent 34a5796 commit 10e90c5
Showing 3 changed files with 18 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/datachain/catalog/catalog.py
@@ -1297,15 +1297,15 @@ def pull_dataset(  # noqa: PLR0915
         output: Optional[str] = None,
         local_ds_name: Optional[str] = None,
         local_ds_version: Optional[int] = None,
-        no_cp: bool = False,
+        cp: bool = False,
         force: bool = False,
         edatachain: bool = False,
         edatachain_file: Optional[str] = None,
         *,
         client_config=None,
     ) -> None:
         def _instantiate(ds_uri: str) -> None:
-            if no_cp:
+            if not cp:
                 return
             assert output
             self.cp(
@@ -1318,7 +1318,7 @@ def _instantiate(ds_uri: str) -> None:
             )
             print(f"Dataset {ds_uri} instantiated locally to {output}")

-        if not output and not no_cp:
+        if cp and not output:
             raise ValueError("Please provide output directory for instantiation")

         studio_client = StudioClient()
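
Taken together, the catalog.py hunks flip the default: pull_dataset now only registers the remote dataset in the local DB unless file instantiation is explicitly requested. A rough caller-side sketch of the new behavior (the `catalog` handle and paths are illustrative, mirroring the calls in tests/func/test_pull.py):

# Sketch only; `catalog` stands for a Catalog instance such as the
# cloud_test_catalog.catalog fixture used in the tests below.

# New default: pull the remote dataset into the local DB, no file copy.
catalog.pull_dataset("ds://dogs@v1")

# Opting in: cp=True also requires an output directory, otherwise
# "Please provide output directory for instantiation" is raised.
catalog.pull_dataset("ds://dogs@v1", output="./dogs", cp=True)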
6 changes: 3 additions & 3 deletions src/datachain/cli.py
@@ -479,10 +479,10 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="Copy directories recursively",
     )
     parse_pull.add_argument(
-        "--no-cp",
+        "--cp",
         default=False,
         action="store_true",
-        help="Do not copy files, just pull a remote dataset into local DB",
+        help="Copy actual files after pulling remote dataset into local DB",
     )
     parse_pull.add_argument(
         "--edatachain",
@@ -1322,7 +1322,7 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
             args.output,
             local_ds_name=args.local_name,
             local_ds_version=args.local_version,
-            no_cp=args.no_cp,
+            cp=args.cp,
             force=bool(args.force),
             edatachain=args.edatachain,
             edatachain_file=args.edatachain_file,
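
The CLI change follows the standard opt-in boolean flag pattern. A minimal, self-contained argparse sketch of the same idea (the parser below is illustrative, not the actual datachain parser):

from argparse import ArgumentParser

# Absent flag -> False (DB-only pull); passing --cp -> True (also copy files).
parser = ArgumentParser(prog="demo-pull")
parser.add_argument(
    "--cp",
    default=False,
    action="store_true",
    help="Copy actual files after pulling remote dataset into local DB",
)

print(parser.parse_args([]).cp)        # False
print(parser.parse_args(["--cp"]).cp)  # True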
24 changes: 12 additions & 12 deletions tests/func/test_pull.py
@@ -215,7 +215,7 @@ def test_pull_dataset_success(
             output=str(dest),
             local_ds_name=local_ds_name,
             local_ds_version=local_ds_version,
-            no_cp=False,
+            cp=True,
         )
     else:
         # trying to pull multiple times since that should work as well
@@ -224,7 +224,7 @@
                 dataset_uri,
                 local_ds_name=local_ds_name,
                 local_ds_version=local_ds_version,
-                no_cp=True,
+                cp=False,
             )

     dataset = catalog.get_dataset(local_ds_name or "dogs")
@@ -279,7 +279,7 @@ def test_pull_dataset_wrong_dataset_uri_format(
     catalog = cloud_test_catalog.catalog

     with pytest.raises(DataChainError) as exc_info:
-        catalog.pull_dataset("wrong", no_cp=True)
+        catalog.pull_dataset("wrong")
     assert str(exc_info.value) == "Error when parsing dataset uri"


@@ -293,7 +293,7 @@ def test_pull_dataset_wrong_version(
     catalog = cloud_test_catalog.catalog

     with pytest.raises(DataChainError) as exc_info:
-        catalog.pull_dataset("ds://dogs@v5", no_cp=True)
+        catalog.pull_dataset("ds://dogs@v5")
     assert str(exc_info.value) == "Dataset dogs doesn't have version 5 on server"


@@ -311,7 +311,7 @@ def test_pull_dataset_not_found_in_remote(
     catalog = cloud_test_catalog.catalog

     with pytest.raises(DataChainError) as exc_info:
-        catalog.pull_dataset("ds://dogs@v1", no_cp=True)
+        catalog.pull_dataset("ds://dogs@v1")
     assert str(exc_info.value) == "Error from server: Dataset not found"


@@ -330,7 +330,7 @@ def test_pull_dataset_error_on_fetching_stats(
     catalog = cloud_test_catalog.catalog

     with pytest.raises(DataChainError) as exc_info:
-        catalog.pull_dataset("ds://dogs@v1", no_cp=True)
+        catalog.pull_dataset("ds://dogs@v1")
     assert str(exc_info.value) == "Error from server: Internal error"


@@ -353,7 +353,7 @@ def test_pull_dataset_exporting_dataset_failed_in_remote(
     catalog = cloud_test_catalog.catalog

     with pytest.raises(DataChainError) as exc_info:
-        catalog.pull_dataset("ds://dogs@v1", no_cp=True)
+        catalog.pull_dataset("ds://dogs@v1")
     assert str(exc_info.value) == (
         f"Error from server: Dataset export {export_status} in Studio"
     )
@@ -374,7 +374,7 @@ def test_pull_dataset_empty_parquet(
     catalog = cloud_test_catalog.catalog

     with pytest.raises(RuntimeError):
-        catalog.pull_dataset("ds://dogs@v1", no_cp=True)
+        catalog.pull_dataset("ds://dogs@v1")


 @pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
@@ -389,8 +389,8 @@ def test_pull_dataset_already_exists_locally(
 ):
     catalog = cloud_test_catalog.catalog

-    catalog.pull_dataset("ds://dogs@v1", local_ds_name="other", no_cp=True)
-    catalog.pull_dataset("ds://dogs@v1", no_cp=True)
+    catalog.pull_dataset("ds://dogs@v1", local_ds_name="other")
+    catalog.pull_dataset("ds://dogs@v1")

     other = catalog.get_dataset("other")
     other_version = other.get_version(1)
@@ -422,7 +422,7 @@ def test_pull_dataset_local_name_already_exists(
         local_ds_name or "dogs", [f"{src_uri}/dogs/*"], recursive=True
     )
     with pytest.raises(DataChainError) as exc_info:
-        catalog.pull_dataset("ds://dogs@v1", local_ds_name=local_ds_name, no_cp=True)
+        catalog.pull_dataset("ds://dogs@v1", local_ds_name=local_ds_name)

     assert str(exc_info.value) == (
         f'Local dataset ds://{local_ds_name or "dogs"}@v1 already exists with different'
@@ -431,5 +431,5 @@

     # able to save it as version 2 of local dataset name
     catalog.pull_dataset(
-        "ds://dogs@v1", local_ds_name=local_ds_name, local_ds_version=2, no_cp=True
+        "ds://dogs@v1", local_ds_name=local_ds_name, local_ds_version=2
     )
