Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Client.parse_url() #435

Merged
merged 4 commits into from
Sep 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 18 additions & 29 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
@@ -621,10 +621,6 @@
code_ast.body[-1:] = new_expressions
return code_ast

def parse_url(self, uri: str, **config: Any) -> tuple[Client, str]:
config = config or self.client_config
return Client.parse_url(uri, self.cache, **config)

def get_client(self, uri: StorageURI, **config: Any) -> Client:
"""
Return the client corresponding to the given source `uri`.
@@ -651,17 +647,16 @@
partial_path: Optional[str]

client_config = client_config or self.client_config
client, path = self.parse_url(source, **client_config)
uri, path = Client.parse_url(source)
client = Client.get_client(source, self.cache, **client_config)
stem = os.path.basename(os.path.normpath(path))
prefix = (
posixpath.dirname(path)
if glob.has_magic(stem) or client.fs.isfile(source)
else path
)
storage_dataset_name = Storage.dataset_name(
client.uri, posixpath.join(prefix, "")
)
source_metastore = self.metastore.clone(client.uri)
storage_dataset_name = Storage.dataset_name(uri, posixpath.join(prefix, ""))
source_metastore = self.metastore.clone(uri)

columns = [
Column("path", String),
@@ -675,15 +670,13 @@
]

if skip_indexing:
source_metastore.create_storage_if_not_registered(client.uri)
storage = source_metastore.get_storage(client.uri)
source_metastore.init_partial_id(client.uri)
partial_id = source_metastore.get_next_partial_id(client.uri)
source_metastore.create_storage_if_not_registered(uri)
storage = source_metastore.get_storage(uri)
source_metastore.init_partial_id(uri)
partial_id = source_metastore.get_next_partial_id(uri)

source_metastore = self.metastore.clone(
uri=client.uri, partial_id=partial_id
)
source_metastore.init(client.uri)
source_metastore = self.metastore.clone(uri=uri, partial_id=partial_id)
source_metastore.init(uri)

source_warehouse = self.warehouse.clone()
dataset = self.create_dataset(
@@ -701,20 +694,16 @@
in_progress,
partial_id,
partial_path,
) = source_metastore.register_storage_for_indexing(
client.uri, force_update, prefix
)
) = source_metastore.register_storage_for_indexing(uri, force_update, prefix)
if in_progress:
raise PendingIndexingError(f"Pending indexing operation: uri={storage.uri}")

if not need_index:
assert partial_id is not None
assert partial_path is not None
source_metastore = self.metastore.clone(
uri=client.uri, partial_id=partial_id
)
source_metastore = self.metastore.clone(uri=uri, partial_id=partial_id)
source_warehouse = self.warehouse.clone()
dataset = self.get_dataset(Storage.dataset_name(client.uri, partial_path))
dataset = self.get_dataset(Storage.dataset_name(uri, partial_path))
lst = Listing(storage, source_metastore, source_warehouse, client, dataset)
logger.debug(
"Using cached listing %s. Valid till: %s",
@@ -731,11 +720,11 @@

return lst, path

source_metastore.init_partial_id(client.uri)
partial_id = source_metastore.get_next_partial_id(client.uri)
source_metastore.init_partial_id(uri)
partial_id = source_metastore.get_next_partial_id(uri)

source_metastore.init(client.uri)
source_metastore = self.metastore.clone(uri=client.uri, partial_id=partial_id)
source_metastore.init(uri)
source_metastore = self.metastore.clone(uri=uri, partial_id=partial_id)

source_warehouse = self.warehouse.clone()

@@ -1370,7 +1359,7 @@

def signed_url(self, source: str, path: str, client_config=None) -> str:
client_config = client_config or self.client_config
client, _ = self.parse_url(source, **client_config)
client = Client.get_client(source, self.cache, **client_config)

Check warning on line 1362 in src/datachain/catalog/catalog.py

Codecov / codecov/patch

src/datachain/catalog/catalog.py#L1362

Added line #L1362 was not covered by tests
return client.url(path)

def export_dataset_table(
17 changes: 9 additions & 8 deletions src/datachain/client/fsspec.py
Original file line number Diff line number Diff line change
@@ -116,15 +116,16 @@ def is_data_source_uri(name: str) -> bool:
return DATA_SOURCE_URI_PATTERN.match(name) is not None

@staticmethod
def parse_url(
source: str,
cache: DataChainCache,
**kwargs,
) -> tuple["Client", str]:
def parse_url(source: str) -> tuple[StorageURI, str]:
cls = Client.get_implementation(source)
storage_name, rel_path = cls.split_url(source)
return cls.get_uri(storage_name), rel_path

@staticmethod
def get_client(source: str, cache: DataChainCache, **kwargs) -> "Client":
cls = Client.get_implementation(source)
storage_url, rel_path = cls.split_url(source)
client = cls.from_name(storage_url, cache, kwargs)
return client, rel_path
storage_url, _ = cls.split_url(source)
return cls.from_name(storage_url, cache, kwargs)

@classmethod
def create_fs(cls, **kwargs) -> "AbstractFileSystem":
6 changes: 5 additions & 1 deletion src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
@@ -408,7 +408,11 @@ def from_storage(
in_memory=in_memory,
)
.gen(
list_bucket(list_uri, client_config=session.catalog.client_config),
list_bucket(
list_uri,
session.catalog.cache,
client_config=session.catalog.client_config,
),
output={f"{object_name}": File},
)
.save(list_dataset_name, listing=True)
12 changes: 7 additions & 5 deletions src/datachain/lib/listing.py
Original file line number Diff line number Diff line change
@@ -20,15 +20,16 @@
LISTING_PREFIX = "lst__" # listing datasets start with this name


def list_bucket(uri: str, client_config=None) -> Callable:
def list_bucket(uri: str, cache, client_config=None) -> Callable:
"""
Returns a generator function that yields File objects from the bucket,
where each File represents one bucket entry.
"""

def list_func() -> Iterator[File]:
config = client_config or {}
client, path = Client.parse_url(uri, None, **config) # type: ignore[arg-type]
client = Client.get_client(uri, cache, **config) # type: ignore[arg-type]
_, path = Client.parse_url(uri)
for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()):
yield from entries

@@ -76,16 +77,17 @@ def parse_listing_uri(uri: str, cache, client_config) -> tuple[str, str, str]:
"""
Parses the uri and returns the listing dataset name, listing uri and listing path.
"""
client, path = Client.parse_url(uri, cache, **client_config)
client = Client.get_client(uri, cache, **client_config)
storage_uri, path = Client.parse_url(uri)

# clean path without globs
lst_uri_path = (
posixpath.dirname(path) if uses_glob(path) or client.fs.isfile(uri) else path
)

lst_uri = f"{client.uri}/{lst_uri_path.lstrip('/')}"
lst_uri = f"{storage_uri}/{lst_uri_path.lstrip('/')}"
ds_name = (
f"{LISTING_PREFIX}{client.uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
f"{LISTING_PREFIX}{storage_uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
)

return ds_name, lst_uri, path
4 changes: 2 additions & 2 deletions src/datachain/lib/listing_info.py
Original file line number Diff line number Diff line change
@@ -13,8 +13,8 @@ def uri(self) -> str:

@property
def storage_uri(self) -> str:
client, _ = Client.parse_url(self.uri, None) # type: ignore[arg-type]
return client.uri
uri, _ = Client.parse_url(self.uri)
return uri

@property
def expires(self) -> Optional[datetime]:
8 changes: 2 additions & 6 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
@@ -37,6 +37,7 @@

from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE, get_catalog
from datachain.client import Client
from datachain.data_storage.schema import (
PARTITION_COLUMN_ID,
partition_col_names,
@@ -194,7 +195,7 @@ class IndexingStep(StartingStep):

def apply(self):
self.catalog.index([self.path], **self.kwargs)
uri, path = self.parse_path()
uri, path = Client.parse_url(self.path)
_partial_id, partial_path = self.catalog.metastore.get_valid_partial_id(
uri, path
)
@@ -216,11 +217,6 @@ def q(*columns):

return step_result(q, dataset_rows.c, dependencies=[storage.uri])

def parse_path(self):
client_config = self.kwargs.get("client_config") or {}
client, path = self.catalog.parse_url(self.path, **client_config)
return client.uri, path


def generator_then_call(generator, func: Callable):
"""
3 changes: 2 additions & 1 deletion tests/func/test_listing.py
Original file line number Diff line number Diff line change
@@ -6,11 +6,12 @@

def test_listing_generator(cloud_test_catalog, cloud_type):
ctc = cloud_test_catalog
catalog = ctc.catalog

uri = f"{ctc.src_uri}/cats"

dc = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD).gen(
file=list_bucket(uri, client_config=ctc.catalog.client_config)
file=list_bucket(uri, catalog.cache, client_config=catalog.client_config)
)
assert dc.count() == 2

4 changes: 2 additions & 2 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
@@ -261,7 +261,7 @@ def test_listings(test_session, tmp_dir):
df.to_parquet(tmp_dir / "df.parquet")

uri = tmp_dir.as_uri()
client, _ = Client.parse_url(uri, test_session.catalog.cache)
client = Client.get_client(uri, test_session.catalog.cache)

DataChain.from_storage(uri, session=test_session)

@@ -292,7 +292,7 @@ def test_listings_reindex(test_session, tmp_dir):
df.to_parquet(tmp_dir / "df.parquet")

uri = tmp_dir.as_uri()
client, _ = Client.parse_url(uri, test_session.catalog.cache)
client = Client.get_client(uri, test_session.catalog.cache)

DataChain.from_storage(uri, session=test_session)
assert len(list(DataChain.listings(session=test_session).collect("listing"))) == 1
56 changes: 32 additions & 24 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
@@ -40,15 +40,29 @@ def test_parse_url(cloud_test_catalog, rel_path, cloud_type):
assume(not rel_path.startswith("/"))
bucket_uri = cloud_test_catalog.src_uri
url = f"{bucket_uri}/{rel_path}"
uri, rel_part = Client.parse_url(url)
if cloud_type == "file":
root_uri = FileClient.root_path().as_uri()
assert uri == root_uri
assert rel_part == url[len(root_uri) :]
else:
assert uri == bucket_uri
assert rel_part == rel_path


@settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
@given(rel_path=non_null_text)
def test_get_client(cloud_test_catalog, rel_path, cloud_type):
catalog = cloud_test_catalog.catalog
client, rel_part = catalog.parse_url(url)
bucket_uri = cloud_test_catalog.src_uri
url = f"{bucket_uri}/{rel_path}"
client = Client.get_client(url, catalog.cache)
assert client
if cloud_type == "file":
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert rel_part == url[len(root_uri) :]
else:
assert client.uri == bucket_uri
assert rel_part == rel_path


@settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
@@ -59,35 +73,32 @@ def test_parse_url_uppercase_scheme(cloud_test_catalog, rel_path, cloud_type):
bucket_uri = cloud_test_catalog.src_uri
bucket_uri_upper = uppercase_scheme(bucket_uri)
url = f"{bucket_uri_upper}/{rel_path}"
catalog = cloud_test_catalog.catalog
client, rel_part = catalog.parse_url(url)
uri, rel_part = Client.parse_url(url)
if cloud_type == "file":
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert uri == root_uri
assert rel_part == url[len(root_uri) :]
else:
assert client.uri == bucket_uri
assert uri == bucket_uri
assert rel_part == rel_path


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
def test_parse_file_absolute_path_without_protocol(cloud_test_catalog):
catalog = cloud_test_catalog.catalog
working_dir = Path().absolute()
root_uri = FileClient.root_path().as_uri()
client, rel_part = catalog.parse_url(str(working_dir / Path("animals")))
uri, rel_part = Client.parse_url(str(working_dir / Path("animals")))
working_dir = Path().absolute()
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert uri == root_uri
assert rel_part == (working_dir / Path("animals")).as_uri()[len(root_uri) :]


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
def test_parse_file_relative_path_multiple_dirs_back(cloud_test_catalog):
catalog = cloud_test_catalog.catalog
client, rel_part = catalog.parse_url("../../animals".replace("/", os.sep))
uri, rel_part = Client.parse_url("../../animals".replace("/", os.sep))
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert uri == root_uri
assert (
rel_part
== (Path().absolute().parents[1] / Path("animals")).as_uri()[len(root_uri) :]
@@ -97,10 +108,9 @@ def test_parse_file_relative_path_multiple_dirs_back(cloud_test_catalog):
@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
@pytest.mark.parametrize("url", ["./animals".replace("/", os.sep), "animals"])
def test_parse_file_relative_path_working_dir(cloud_test_catalog, url):
catalog = cloud_test_catalog.catalog
client, rel_part = catalog.parse_url(url)
uri, rel_part = Client.parse_url(url)
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert uri == root_uri
assert rel_part == (Path().absolute() / Path("animals")).as_uri()[len(root_uri) :]


@@ -109,18 +119,17 @@ def test_parse_file_relative_path_home_dir(cloud_test_catalog):
if sys.platform == "win32":
# the home dir shortcut (~) is not available on Windows
pytest.skip()
catalog = cloud_test_catalog.catalog
client, rel_part = catalog.parse_url("~/animals")
uri, rel_part = Client.parse_url("~/animals")
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert uri == root_uri
assert rel_part == (Path().home() / Path("animals")).as_uri()[len(root_uri) :]


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
def test_parse_file_path_ends_with_slash(cloud_type):
client, rel_part = Client.parse_url("./animals/".replace("/", os.sep), None)
uri, rel_part = Client.parse_url("./animals/".replace("/", os.sep))
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert uri == root_uri
assert (
rel_part
== ((Path().absolute() / Path("animals")).as_uri())[len(root_uri) :] + "/"
@@ -130,7 +139,6 @@ def test_parse_file_path_ends_with_slash(cloud_type):
@pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True)
def test_parse_cloud_path_ends_with_slash(cloud_test_catalog):
uri = f"{cloud_test_catalog.src_uri}/animals/"
catalog = cloud_test_catalog.catalog
client, rel_part = catalog.parse_url(uri)
assert client.uri == cloud_test_catalog.src_uri
uri, rel_part = Client.parse_url(uri)
assert uri == cloud_test_catalog.src_uri
assert rel_part == "animals/"