From 10e90c5c25e8f34d560b2958469d7f83b15e4812 Mon Sep 17 00:00:00 2001 From: Ivan Longin Date: Tue, 17 Dec 2024 11:46:21 +0100 Subject: [PATCH] making cp non default when pulling dataset (#720) --- src/datachain/catalog/catalog.py | 6 +++--- src/datachain/cli.py | 6 +++--- tests/func/test_pull.py | 24 ++++++++++++------------ 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index b66a60f1c..701a4c244 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -1297,7 +1297,7 @@ def pull_dataset( # noqa: PLR0915 output: Optional[str] = None, local_ds_name: Optional[str] = None, local_ds_version: Optional[int] = None, - no_cp: bool = False, + cp: bool = False, force: bool = False, edatachain: bool = False, edatachain_file: Optional[str] = None, @@ -1305,7 +1305,7 @@ def pull_dataset( # noqa: PLR0915 client_config=None, ) -> None: def _instantiate(ds_uri: str) -> None: - if no_cp: + if not cp: return assert output self.cp( @@ -1318,7 +1318,7 @@ def _instantiate(ds_uri: str) -> None: ) print(f"Dataset {ds_uri} instantiated locally to {output}") - if not output and not no_cp: + if cp and not output: raise ValueError("Please provide output directory for instantiation") studio_client = StudioClient() diff --git a/src/datachain/cli.py b/src/datachain/cli.py index 0c8425dca..cc956eec3 100644 --- a/src/datachain/cli.py +++ b/src/datachain/cli.py @@ -479,10 +479,10 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915 help="Copy directories recursively", ) parse_pull.add_argument( - "--no-cp", + "--cp", default=False, action="store_true", - help="Do not copy files, just pull a remote dataset into local DB", + help="Copy actual files after pulling remote dataset into local DB", ) parse_pull.add_argument( "--edatachain", @@ -1322,7 +1322,7 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09 args.output, local_ds_name=args.local_name, local_ds_version=args.local_version, - no_cp=args.no_cp, + cp=args.cp, force=bool(args.force), edatachain=args.edatachain, edatachain_file=args.edatachain_file, diff --git a/tests/func/test_pull.py b/tests/func/test_pull.py index efc6c2d2b..6e0ce98a8 100644 --- a/tests/func/test_pull.py +++ b/tests/func/test_pull.py @@ -215,7 +215,7 @@ def test_pull_dataset_success( output=str(dest), local_ds_name=local_ds_name, local_ds_version=local_ds_version, - no_cp=False, + cp=True, ) else: # trying to pull multiple times since that should work as well @@ -224,7 +224,7 @@ def test_pull_dataset_success( dataset_uri, local_ds_name=local_ds_name, local_ds_version=local_ds_version, - no_cp=True, + cp=False, ) dataset = catalog.get_dataset(local_ds_name or "dogs") @@ -279,7 +279,7 @@ def test_pull_dataset_wrong_dataset_uri_format( catalog = cloud_test_catalog.catalog with pytest.raises(DataChainError) as exc_info: - catalog.pull_dataset("wrong", no_cp=True) + catalog.pull_dataset("wrong") assert str(exc_info.value) == "Error when parsing dataset uri" @@ -293,7 +293,7 @@ def test_pull_dataset_wrong_version( catalog = cloud_test_catalog.catalog with pytest.raises(DataChainError) as exc_info: - catalog.pull_dataset("ds://dogs@v5", no_cp=True) + catalog.pull_dataset("ds://dogs@v5") assert str(exc_info.value) == "Dataset dogs doesn't have version 5 on server" @@ -311,7 +311,7 @@ def test_pull_dataset_not_found_in_remote( catalog = cloud_test_catalog.catalog with pytest.raises(DataChainError) as exc_info: - catalog.pull_dataset("ds://dogs@v1", no_cp=True) + catalog.pull_dataset("ds://dogs@v1") assert str(exc_info.value) == "Error from server: Dataset not found" @@ -330,7 +330,7 @@ def test_pull_dataset_error_on_fetching_stats( catalog = cloud_test_catalog.catalog with pytest.raises(DataChainError) as exc_info: - catalog.pull_dataset("ds://dogs@v1", no_cp=True) + catalog.pull_dataset("ds://dogs@v1") assert str(exc_info.value) == "Error from server: Internal error" @@ -353,7 +353,7 @@ def test_pull_dataset_exporting_dataset_failed_in_remote( catalog = cloud_test_catalog.catalog with pytest.raises(DataChainError) as exc_info: - catalog.pull_dataset("ds://dogs@v1", no_cp=True) + catalog.pull_dataset("ds://dogs@v1") assert str(exc_info.value) == ( f"Error from server: Dataset export {export_status} in Studio" ) @@ -374,7 +374,7 @@ def test_pull_dataset_empty_parquet( catalog = cloud_test_catalog.catalog with pytest.raises(RuntimeError): - catalog.pull_dataset("ds://dogs@v1", no_cp=True) + catalog.pull_dataset("ds://dogs@v1") @pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True) @@ -389,8 +389,8 @@ def test_pull_dataset_already_exists_locally( ): catalog = cloud_test_catalog.catalog - catalog.pull_dataset("ds://dogs@v1", local_ds_name="other", no_cp=True) - catalog.pull_dataset("ds://dogs@v1", no_cp=True) + catalog.pull_dataset("ds://dogs@v1", local_ds_name="other") + catalog.pull_dataset("ds://dogs@v1") other = catalog.get_dataset("other") other_version = other.get_version(1) @@ -422,7 +422,7 @@ def test_pull_dataset_local_name_already_exists( local_ds_name or "dogs", [f"{src_uri}/dogs/*"], recursive=True ) with pytest.raises(DataChainError) as exc_info: - catalog.pull_dataset("ds://dogs@v1", local_ds_name=local_ds_name, no_cp=True) + catalog.pull_dataset("ds://dogs@v1", local_ds_name=local_ds_name) assert str(exc_info.value) == ( f'Local dataset ds://{local_ds_name or "dogs"}@v1 already exists with different' @@ -431,5 +431,5 @@ def test_pull_dataset_local_name_already_exists( # able to save it as version 2 of local dataset name catalog.pull_dataset( - "ds://dogs@v1", local_ds_name=local_ds_name, local_ds_version=2, no_cp=True + "ds://dogs@v1", local_ds_name=local_ds_name, local_ds_version=2 )