Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix storage dependencies #421

Merged
merged 11 commits into from
Sep 16, 2024
6 changes: 5 additions & 1 deletion src/datachain/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,11 @@ def parse(

if is_listing_dataset(dataset_name):
dependency_type = DatasetDependencyType.STORAGE # type: ignore[arg-type]
dependency_name = listing_uri_from_name(dataset_name)
client, _ = Client.parse_url(
listing_uri_from_name(dataset_name),
None, # type: ignore[arg-type]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we fix this instead of ignoring?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this has been on my plate for a long time. I will see whether it makes sense to do it in a separate PR or in this one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix in separate PR #435

)
dependency_name = client.uri

return cls(
id,
Expand Down
18 changes: 17 additions & 1 deletion tests/func/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from PIL import Image
from sqlalchemy import Column

from datachain.client.local import FileClient
from datachain.data_storage.sqlite import SQLiteWarehouse
from datachain.dataset import DatasetStats
from datachain.dataset import DatasetDependencyType, DatasetStats
from datachain.lib.dc import DataChain, DataChainColumnError
from datachain.lib.file import File, ImageFile
from datachain.lib.listing import (
Expand Down Expand Up @@ -175,6 +176,21 @@ def _list_dataset_name(uri: str) -> str:
)


def test_from_storage_dependencies(cloud_test_catalog, cloud_type):
    """Saving a chain built with ``from_storage`` records one STORAGE dependency.

    For the local filesystem the dependency name is the file client's root
    URI; for remote clouds it is the source bucket URI itself.
    """
    session = cloud_test_catalog.session
    source = cloud_test_catalog.src_uri
    dataset_name = "dep"

    DataChain.from_storage(f"{source}/cats", session=session).save(dataset_name)

    deps = session.catalog.get_dataset_dependencies(dataset_name, 1)
    assert len(deps) == 1
    (dep,) = deps
    assert dep.type == DatasetDependencyType.STORAGE
    # Local "file" storage normalizes the dependency name to the client root.
    expected_name = (
        FileClient.root_path().as_uri() if cloud_type == "file" else source
    )
    assert dep.name == expected_name


@pytest.mark.parametrize("use_cache", [True, False])
def test_map_file(cloud_test_catalog, use_cache):
ctc = cloud_test_catalog
Expand Down
7 changes: 4 additions & 3 deletions tests/func/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import sqlalchemy as sa

from datachain.catalog.catalog import DATASET_INTERNAL_ERROR_MESSAGE
from datachain.client.local import FileClient
from datachain.data_storage.sqlite import SQLiteWarehouse
from datachain.dataset import LISTING_PREFIX, DatasetDependencyType, DatasetStatus
from datachain.dataset import DatasetDependencyType, DatasetStatus
from datachain.error import DatasetInvalidVersionError, DatasetNotFoundError
from datachain.lib.dc import DataChain
from datachain.lib.listing import parse_listing_uri
Expand Down Expand Up @@ -805,7 +806,7 @@ def test_dataset_stats_registered_ds(cloud_test_catalog, dogs_dataset):


@pytest.mark.parametrize("indirect", [True, False])
def test_dataset_storage_dependencies(cloud_test_catalog, indirect):
def test_dataset_storage_dependencies(cloud_test_catalog, cloud_type, indirect):
ctc = cloud_test_catalog
session = ctc.session
catalog = session.catalog
Expand All @@ -824,7 +825,7 @@ def test_dataset_storage_dependencies(cloud_test_catalog, indirect):
{
"id": ANY,
"type": DatasetDependencyType.STORAGE,
"name": lst_dataset.name.removeprefix(LISTING_PREFIX),
"name": uri if cloud_type != "file" else FileClient.root_path().as_uri(),
"version": "1",
"created_at": lst_dataset.get_version(1).created_at,
"dependencies": [],
Expand Down
Loading