Skip to content

Commit

Permalink
Refactor Client.parse_url() (#435)
Browse files Browse the repository at this point in the history
* refactoring parse_url

* removed old method
  • Loading branch information
ilongin authored Sep 13, 2024
1 parent 944defc commit 1ef7556
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 78 deletions.
47 changes: 18 additions & 29 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,10 +621,6 @@ def attach_query_wrapper(self, code_ast):
code_ast.body[-1:] = new_expressions
return code_ast

def parse_url(self, uri: str, **config: Any) -> tuple[Client, str]:
config = config or self.client_config
return Client.parse_url(uri, self.cache, **config)

def get_client(self, uri: StorageURI, **config: Any) -> Client:
"""
Return the client corresponding to the given source `uri`.
Expand All @@ -651,17 +647,16 @@ def enlist_source(
partial_path: Optional[str]

client_config = client_config or self.client_config
client, path = self.parse_url(source, **client_config)
uri, path = Client.parse_url(source)
client = Client.get_client(source, self.cache, **client_config)
stem = os.path.basename(os.path.normpath(path))
prefix = (
posixpath.dirname(path)
if glob.has_magic(stem) or client.fs.isfile(source)
else path
)
storage_dataset_name = Storage.dataset_name(
client.uri, posixpath.join(prefix, "")
)
source_metastore = self.metastore.clone(client.uri)
storage_dataset_name = Storage.dataset_name(uri, posixpath.join(prefix, ""))
source_metastore = self.metastore.clone(uri)

columns = [
Column("path", String),
Expand All @@ -675,15 +670,13 @@ def enlist_source(
]

if skip_indexing:
source_metastore.create_storage_if_not_registered(client.uri)
storage = source_metastore.get_storage(client.uri)
source_metastore.init_partial_id(client.uri)
partial_id = source_metastore.get_next_partial_id(client.uri)
source_metastore.create_storage_if_not_registered(uri)
storage = source_metastore.get_storage(uri)
source_metastore.init_partial_id(uri)
partial_id = source_metastore.get_next_partial_id(uri)

source_metastore = self.metastore.clone(
uri=client.uri, partial_id=partial_id
)
source_metastore.init(client.uri)
source_metastore = self.metastore.clone(uri=uri, partial_id=partial_id)
source_metastore.init(uri)

source_warehouse = self.warehouse.clone()
dataset = self.create_dataset(
Expand All @@ -701,20 +694,16 @@ def enlist_source(
in_progress,
partial_id,
partial_path,
) = source_metastore.register_storage_for_indexing(
client.uri, force_update, prefix
)
) = source_metastore.register_storage_for_indexing(uri, force_update, prefix)
if in_progress:
raise PendingIndexingError(f"Pending indexing operation: uri={storage.uri}")

if not need_index:
assert partial_id is not None
assert partial_path is not None
source_metastore = self.metastore.clone(
uri=client.uri, partial_id=partial_id
)
source_metastore = self.metastore.clone(uri=uri, partial_id=partial_id)
source_warehouse = self.warehouse.clone()
dataset = self.get_dataset(Storage.dataset_name(client.uri, partial_path))
dataset = self.get_dataset(Storage.dataset_name(uri, partial_path))
lst = Listing(storage, source_metastore, source_warehouse, client, dataset)
logger.debug(
"Using cached listing %s. Valid till: %s",
Expand All @@ -731,11 +720,11 @@ def enlist_source(

return lst, path

source_metastore.init_partial_id(client.uri)
partial_id = source_metastore.get_next_partial_id(client.uri)
source_metastore.init_partial_id(uri)
partial_id = source_metastore.get_next_partial_id(uri)

source_metastore.init(client.uri)
source_metastore = self.metastore.clone(uri=client.uri, partial_id=partial_id)
source_metastore.init(uri)
source_metastore = self.metastore.clone(uri=uri, partial_id=partial_id)

source_warehouse = self.warehouse.clone()

Expand Down Expand Up @@ -1370,7 +1359,7 @@ def ls_dataset_rows(

def signed_url(self, source: str, path: str, client_config=None) -> str:
client_config = client_config or self.client_config
client, _ = self.parse_url(source, **client_config)
client = Client.get_client(source, self.cache, **client_config)
return client.url(path)

def export_dataset_table(
Expand Down
17 changes: 9 additions & 8 deletions src/datachain/client/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,16 @@ def is_data_source_uri(name: str) -> bool:
return DATA_SOURCE_URI_PATTERN.match(name) is not None

@staticmethod
def parse_url(
source: str,
cache: DataChainCache,
**kwargs,
) -> tuple["Client", str]:
def parse_url(source: str) -> tuple[StorageURI, str]:
cls = Client.get_implementation(source)
storage_name, rel_path = cls.split_url(source)
return cls.get_uri(storage_name), rel_path

@staticmethod
def get_client(source: str, cache: DataChainCache, **kwargs) -> "Client":
cls = Client.get_implementation(source)
storage_url, rel_path = cls.split_url(source)
client = cls.from_name(storage_url, cache, kwargs)
return client, rel_path
storage_url, _ = cls.split_url(source)
return cls.from_name(storage_url, cache, kwargs)

@classmethod
def create_fs(cls, **kwargs) -> "AbstractFileSystem":
Expand Down
6 changes: 5 additions & 1 deletion src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,11 @@ def from_storage(
in_memory=in_memory,
)
.gen(
list_bucket(list_uri, client_config=session.catalog.client_config),
list_bucket(
list_uri,
session.catalog.cache,
client_config=session.catalog.client_config,
),
output={f"{object_name}": File},
)
.save(list_dataset_name, listing=True)
Expand Down
12 changes: 7 additions & 5 deletions src/datachain/lib/listing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
LISTING_PREFIX = "lst__" # listing datasets start with this name


def list_bucket(uri: str, client_config=None) -> Callable:
def list_bucket(uri: str, cache, client_config=None) -> Callable:
"""
Return a generator function that yields File objects from a bucket,
where each File represents one bucket entry.
"""

def list_func() -> Iterator[File]:
config = client_config or {}
client, path = Client.parse_url(uri, None, **config) # type: ignore[arg-type]
client = Client.get_client(uri, cache, **config) # type: ignore[arg-type]
_, path = Client.parse_url(uri)
for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()):
yield from entries

Expand Down Expand Up @@ -76,16 +77,17 @@ def parse_listing_uri(uri: str, cache, client_config) -> tuple[str, str, str]:
"""
Parse the URI and return the listing dataset name, listing URI and listing path.
"""
client, path = Client.parse_url(uri, cache, **client_config)
client = Client.get_client(uri, cache, **client_config)
storage_uri, path = Client.parse_url(uri)

# clean path without globs
lst_uri_path = (
posixpath.dirname(path) if uses_glob(path) or client.fs.isfile(uri) else path
)

lst_uri = f"{client.uri}/{lst_uri_path.lstrip('/')}"
lst_uri = f"{storage_uri}/{lst_uri_path.lstrip('/')}"
ds_name = (
f"{LISTING_PREFIX}{client.uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
f"{LISTING_PREFIX}{storage_uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
)

return ds_name, lst_uri, path
Expand Down
4 changes: 2 additions & 2 deletions src/datachain/lib/listing_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def uri(self) -> str:

@property
def storage_uri(self) -> str:
client, _ = Client.parse_url(self.uri, None) # type: ignore[arg-type]
return client.uri
uri, _ = Client.parse_url(self.uri)
return uri

@property
def expires(self) -> Optional[datetime]:
Expand Down
8 changes: 2 additions & 6 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE, get_catalog
from datachain.client import Client
from datachain.data_storage.schema import (
PARTITION_COLUMN_ID,
partition_col_names,
Expand Down Expand Up @@ -194,7 +195,7 @@ class IndexingStep(StartingStep):

def apply(self):
self.catalog.index([self.path], **self.kwargs)
uri, path = self.parse_path()
uri, path = Client.parse_url(self.path)
_partial_id, partial_path = self.catalog.metastore.get_valid_partial_id(
uri, path
)
Expand All @@ -216,11 +217,6 @@ def q(*columns):

return step_result(q, dataset_rows.c, dependencies=[storage.uri])

def parse_path(self):
client_config = self.kwargs.get("client_config") or {}
client, path = self.catalog.parse_url(self.path, **client_config)
return client.uri, path


def generator_then_call(generator, func: Callable):
"""
Expand Down
3 changes: 2 additions & 1 deletion tests/func/test_listing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

def test_listing_generator(cloud_test_catalog, cloud_type):
ctc = cloud_test_catalog
catalog = ctc.catalog

uri = f"{ctc.src_uri}/cats"

dc = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD).gen(
file=list_bucket(uri, client_config=ctc.catalog.client_config)
file=list_bucket(uri, catalog.cache, client_config=catalog.client_config)
)
assert dc.count() == 2

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def test_listings(test_session, tmp_dir):
df.to_parquet(tmp_dir / "df.parquet")

uri = tmp_dir.as_uri()
client, _ = Client.parse_url(uri, test_session.catalog.cache)
client = Client.get_client(uri, test_session.catalog.cache)

DataChain.from_storage(uri, session=test_session)

Expand Down Expand Up @@ -292,7 +292,7 @@ def test_listings_reindex(test_session, tmp_dir):
df.to_parquet(tmp_dir / "df.parquet")

uri = tmp_dir.as_uri()
client, _ = Client.parse_url(uri, test_session.catalog.cache)
client = Client.get_client(uri, test_session.catalog.cache)

DataChain.from_storage(uri, session=test_session)
assert len(list(DataChain.listings(session=test_session).collect("listing"))) == 1
Expand Down
56 changes: 32 additions & 24 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,29 @@ def test_parse_url(cloud_test_catalog, rel_path, cloud_type):
assume(not rel_path.startswith("/"))
bucket_uri = cloud_test_catalog.src_uri
url = f"{bucket_uri}/{rel_path}"
uri, rel_part = Client.parse_url(url)
if cloud_type == "file":
root_uri = FileClient.root_path().as_uri()
assert uri == root_uri
assert rel_part == url[len(root_uri) :]
else:
assert uri == bucket_uri
assert rel_part == rel_path


@settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
@given(rel_path=non_null_text)
def test_get_client(cloud_test_catalog, rel_path, cloud_type):
catalog = cloud_test_catalog.catalog
client, rel_part = catalog.parse_url(url)
bucket_uri = cloud_test_catalog.src_uri
url = f"{bucket_uri}/{rel_path}"
client = Client.get_client(url, catalog.cache)
assert client
if cloud_type == "file":
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert rel_part == url[len(root_uri) :]
else:
assert client.uri == bucket_uri
assert rel_part == rel_path


@settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
Expand All @@ -59,35 +73,32 @@ def test_parse_url_uppercase_scheme(cloud_test_catalog, rel_path, cloud_type):
bucket_uri = cloud_test_catalog.src_uri
bucket_uri_upper = uppercase_scheme(bucket_uri)
url = f"{bucket_uri_upper}/{rel_path}"
catalog = cloud_test_catalog.catalog
client, rel_part = catalog.parse_url(url)
uri, rel_part = Client.parse_url(url)
if cloud_type == "file":
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert uri == root_uri
assert rel_part == url[len(root_uri) :]
else:
assert client.uri == bucket_uri
assert uri == bucket_uri
assert rel_part == rel_path


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
def test_parse_file_absolute_path_without_protocol(cloud_test_catalog):
catalog = cloud_test_catalog.catalog
working_dir = Path().absolute()
root_uri = FileClient.root_path().as_uri()
client, rel_part = catalog.parse_url(str(working_dir / Path("animals")))
uri, rel_part = Client.parse_url(str(working_dir / Path("animals")))
working_dir = Path().absolute()
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert uri == root_uri
assert rel_part == (working_dir / Path("animals")).as_uri()[len(root_uri) :]


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
def test_parse_file_relative_path_multiple_dirs_back(cloud_test_catalog):
catalog = cloud_test_catalog.catalog
client, rel_part = catalog.parse_url("../../animals".replace("/", os.sep))
uri, rel_part = Client.parse_url("../../animals".replace("/", os.sep))
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert uri == root_uri
assert (
rel_part
== (Path().absolute().parents[1] / Path("animals")).as_uri()[len(root_uri) :]
Expand All @@ -97,10 +108,9 @@ def test_parse_file_relative_path_multiple_dirs_back(cloud_test_catalog):
@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
@pytest.mark.parametrize("url", ["./animals".replace("/", os.sep), "animals"])
def test_parse_file_relative_path_working_dir(cloud_test_catalog, url):
catalog = cloud_test_catalog.catalog
client, rel_part = catalog.parse_url(url)
uri, rel_part = Client.parse_url(url)
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert uri == root_uri
assert rel_part == (Path().absolute() / Path("animals")).as_uri()[len(root_uri) :]


Expand All @@ -109,18 +119,17 @@ def test_parse_file_relative_path_home_dir(cloud_test_catalog):
if sys.platform == "win32":
# home dir shortcut is not available on Windows
pytest.skip()
catalog = cloud_test_catalog.catalog
client, rel_part = catalog.parse_url("~/animals")
uri, rel_part = Client.parse_url("~/animals")
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert uri == root_uri
assert rel_part == (Path().home() / Path("animals")).as_uri()[len(root_uri) :]


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
def test_parse_file_path_ends_with_slash(cloud_type):
client, rel_part = Client.parse_url("./animals/".replace("/", os.sep), None)
uri, rel_part = Client.parse_url("./animals/".replace("/", os.sep))
root_uri = FileClient.root_path().as_uri()
assert client.uri == root_uri
assert uri == root_uri
assert (
rel_part
== ((Path().absolute() / Path("animals")).as_uri())[len(root_uri) :] + "/"
Expand All @@ -130,7 +139,6 @@ def test_parse_file_path_ends_with_slash(cloud_type):
@pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True)
def test_parse_cloud_path_ends_with_slash(cloud_test_catalog):
uri = f"{cloud_test_catalog.src_uri}/animals/"
catalog = cloud_test_catalog.catalog
client, rel_part = catalog.parse_url(uri)
assert client.uri == cloud_test_catalog.src_uri
uri, rel_part = Client.parse_url(uri)
assert uri == cloud_test_catalog.src_uri
assert rel_part == "animals/"

0 comments on commit 1ef7556

Please sign in to comment.