fix(from_storage): no listing / cache on a single file (#734)
* fix(from_storage): no listing / cache on a single file

* fix single file handling case for cp/ls
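A minimal usage sketch of the behavior this change targets (the bucket and object names are hypothetical, not taken from the commit):

from datachain.lib.dc import DataChain

# Single object: with this fix, no listing dataset is created and nothing is
# cached for the prefix - the file's metadata is read directly.
single = DataChain.from_storage("s3://my-bucket/reports/2024.parquet")

# Directory-like URI: behavior is unchanged - the prefix is listed (or an
# existing listing dataset is reused) and can be refreshed with update=True.
listed = DataChain.from_storage("s3://my-bucket/reports/", update=True)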
shcheklein authored Dec 28, 2024
1 parent bbd44e3 commit 31d96ab
Showing 11 changed files with 250 additions and 82 deletions.
144 changes: 92 additions & 52 deletions src/datachain/catalog/catalog.py
@@ -240,7 +240,8 @@ def do_task(self, urls):
class NodeGroup:
"""Class for a group of nodes from the same source"""

listing: "Listing"
listing: Optional["Listing"]
client: "Client"
sources: list[DataSource]

# The source path within the bucket
@@ -268,9 +269,7 @@ def download(self, recursive: bool = False, pbar=None) -> None:
Download this node group to cache.
"""
if self.sources:
self.listing.client.fetch_nodes(
self.iternodes(recursive), shared_progress_bar=pbar
)
self.client.fetch_nodes(self.iternodes(recursive), shared_progress_bar=pbar)


def check_output_dataset_file(
@@ -375,14 +374,15 @@ def collect_nodes_for_cp(

# Collect all sources to process
for node_group in node_groups:
listing: Listing = node_group.listing
listing: Optional[Listing] = node_group.listing
valid_sources: list[DataSource] = []
for dsrc in node_group.sources:
if dsrc.is_single_object():
total_size += dsrc.node.size
total_files += 1
valid_sources.append(dsrc)
else:
assert listing
node = dsrc.node
if not recursive:
print(f"{node.full_path} is a directory (not copied).")
@@ -433,37 +433,51 @@ def instantiate_node_groups(
)

output_dir = output
output_file = None
if copy_to_filename:
output_dir = os.path.dirname(output)
if not output_dir:
output_dir = "."
output_file = os.path.basename(output)

# Instantiate these nodes
for node_group in node_groups:
if not node_group.sources:
continue
listing: Listing = node_group.listing
listing: Optional[Listing] = node_group.listing
source_path: str = node_group.source_path

copy_dir_contents = always_copy_dir_contents or source_path.endswith("/")
instantiated_nodes = listing.collect_nodes_to_instantiate(
node_group.sources,
copy_to_filename,
recursive,
copy_dir_contents,
source_path,
node_group.is_edatachain,
node_group.is_dataset,
)
if not virtual_only:
listing.instantiate_nodes(
instantiated_nodes,
output_dir,
total_files,
force=force,
shared_progress_bar=instantiate_progress_bar,
if not listing:
source = node_group.sources[0]
client = source.client
node = NodeWithPath(source.node, [output_file or source.node.path])
instantiated_nodes = [node]
if not virtual_only:
node.instantiate(
client, output_dir, instantiate_progress_bar, force=force
)
else:
instantiated_nodes = listing.collect_nodes_to_instantiate(
node_group.sources,
copy_to_filename,
recursive,
copy_dir_contents,
source_path,
node_group.is_edatachain,
node_group.is_dataset,
)
if not virtual_only:
listing.instantiate_nodes(
instantiated_nodes,
output_dir,
total_files,
force=force,
shared_progress_bar=instantiate_progress_bar,
)

node_group.instantiated_nodes = instantiated_nodes

if instantiate_progress_bar:
instantiate_progress_bar.close()

@@ -592,7 +606,7 @@ def enlist_source(
client_config=None,
object_name="file",
skip_indexing=False,
) -> tuple["Listing", str]:
) -> tuple[Optional["Listing"], "Client", str]:
from datachain.lib.dc import DataChain
from datachain.listing import Listing

@@ -603,16 +617,19 @@
list_ds_name, list_uri, list_path, _ = get_listing(
source, self.session, update=update
)
lst = None
client = Client.get_client(list_uri, self.cache, **self.client_config)

if list_ds_name:
lst = Listing(
self.metastore.clone(),
self.warehouse.clone(),
client,
dataset_name=list_ds_name,
object_name=object_name,
)

lst = Listing(
self.metastore.clone(),
self.warehouse.clone(),
Client.get_client(list_uri, self.cache, **self.client_config),
dataset_name=list_ds_name,
object_name=object_name,
)

return lst, list_path
return lst, client, list_path

def _remove_dataset_rows_and_warehouse_info(
self, dataset: DatasetRecord, version: int, **kwargs
@@ -635,24 +652,30 @@ def enlist_sources(
) -> Optional[list["DataSource"]]:
enlisted_sources = []
for src in sources: # Opt: parallel
listing, file_path = self.enlist_source(
listing, client, file_path = self.enlist_source(
src,
update,
client_config=client_config or self.client_config,
skip_indexing=skip_indexing,
)
enlisted_sources.append((listing, file_path))
enlisted_sources.append((listing, client, file_path))

if only_index:
# sometimes we don't really need listing result (e.g on indexing process)
# so this is to improve performance
return None

dsrc_all: list[DataSource] = []
for listing, file_path in enlisted_sources:
nodes = listing.expand_path(file_path)
dir_only = file_path.endswith("/")
dsrc_all.extend(DataSource(listing, node, dir_only) for node in nodes)
for listing, client, file_path in enlisted_sources:
if not listing:
nodes = [Node.from_file(client.get_file_info(file_path))]
dir_only = False
else:
nodes = listing.expand_path(file_path)
dir_only = file_path.endswith("/")
dsrc_all.extend(
DataSource(listing, client, node, dir_only) for node in nodes
)
return dsrc_all

def enlist_sources_grouped(
@@ -667,7 +690,7 @@ def enlist_sources_grouped(

def _row_to_node(d: dict[str, Any]) -> Node:
del d["file__source"]
return Node.from_dict(d)
return Node.from_row(d)

enlisted_sources: list[tuple[bool, bool, Any]] = []
client_config = client_config or self.client_config
@@ -677,7 +700,7 @@ def _row_to_node(d: dict[str, Any]) -> Node:
edatachain_data = parse_edatachain_file(src)
indexed_sources = []
for ds in edatachain_data:
listing, source_path = self.enlist_source(
listing, _, source_path = self.enlist_source(
ds["data-source"]["uri"],
update,
client_config=client_config,
@@ -701,6 +724,7 @@ def _row_to_node(d: dict[str, Any]) -> Node:
client = self.get_client(source, **client_config)
uri = client.uri
dataset_name, _, _, _ = get_listing(uri, self.session)
assert dataset_name
listing = Listing(
self.metastore.clone(),
self.warehouse.clone(),
@@ -713,6 +737,7 @@ def _row_to_node(d: dict[str, Any]) -> Node:
indexed_sources.append(
(
listing,
client,
source,
[_row_to_node(r) for r in rows],
ds_name,
@@ -722,25 +747,28 @@ def _row_to_node(d: dict[str, Any]) -> Node:

enlisted_sources.append((False, True, indexed_sources))
else:
listing, source_path = self.enlist_source(
listing, client, source_path = self.enlist_source(
src, update, client_config=client_config
)
enlisted_sources.append((False, False, (listing, source_path)))
enlisted_sources.append((False, False, (listing, client, source_path)))

node_groups = []
for is_datachain, is_dataset, payload in enlisted_sources: # Opt: parallel
if is_dataset:
for (
listing,
client,
source_path,
nodes,
dataset_name,
dataset_version,
) in payload:
dsrc = [DataSource(listing, node) for node in nodes]
assert listing
dsrc = [DataSource(listing, client, node) for node in nodes]
node_groups.append(
NodeGroup(
listing,
client,
dsrc,
source_path,
dataset_name=dataset_name,
@@ -749,18 +777,30 @@ def _row_to_node(d: dict[str, Any]) -> Node:
)
elif is_datachain:
for listing, source_path, paths in payload:
dsrc = [DataSource(listing, listing.resolve_path(p)) for p in paths]
assert listing
dsrc = [
DataSource(listing, listing.client, listing.resolve_path(p))
for p in paths
]
node_groups.append(
NodeGroup(listing, dsrc, source_path, is_edatachain=True)
NodeGroup(
listing,
listing.client,
dsrc,
source_path,
is_edatachain=True,
)
)
else:
listing, source_path = payload
as_container = source_path.endswith("/")
dsrc = [
DataSource(listing, n, as_container)
for n in listing.expand_path(source_path, use_glob=not no_glob)
]
node_groups.append(NodeGroup(listing, dsrc, source_path))
listing, client, source_path = payload
if not listing:
nodes = [Node.from_file(client.get_file_info(source_path))]
as_container = False
else:
as_container = source_path.endswith("/")
nodes = listing.expand_path(source_path, use_glob=not no_glob)
dsrc = [DataSource(listing, client, n, as_container) for n in nodes]
node_groups.append(NodeGroup(listing, client, dsrc, source_path))

return node_groups

10 changes: 4 additions & 6 deletions src/datachain/catalog/datasource.py
@@ -4,21 +4,19 @@


class DataSource:
def __init__(self, listing, node, as_container=False):
def __init__(self, listing, client, node, as_container=False):
self.listing = listing
self.client = client
self.node = node
self.as_container = (
as_container # Indicates whether a .tar file is handled as a container
)

def get_full_path(self):
return self.get_node_full_path(self.node)

def get_node_full_path(self, node):
return self.listing.client.get_full_path(node.full_path)
return self.client.get_full_path(node.full_path)

def get_node_full_path_from_path(self, full_path):
return self.listing.client.get_full_path(full_path)
return self.client.get_full_path(full_path)

def is_single_object(self):
return self.node.dir_type == DirType.FILE or (
4 changes: 4 additions & 0 deletions src/datachain/client/fsspec.py
@@ -204,6 +204,10 @@ async def get_current_etag(self, file: "File") -> str:
info = await self.fs._info(self.get_full_path(file.path))
return self.info_to_file(info, "").etag

def get_file_info(self, path: str) -> "File":
info = self.fs.info(self.get_full_path(path))
return self.info_to_file(info, path)

async def get_size(self, path: str) -> int:
return await self.fs._size(path)

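The catalog uses this new helper when a source resolves to a single object and there is no listing. A small illustrative sketch of that flow follows (the helper function and the import paths are assumptions for illustration, not part of the commit):

from typing import TYPE_CHECKING

from datachain.node import Node  # import path assumed

if TYPE_CHECKING:
    from datachain.client.fsspec import Client  # import path assumed


def single_object_node(client: "Client", path: str) -> Node:
    # With no listing dataset available, fetch the object's metadata from the
    # filesystem (a File) and wrap it in a single Node, as enlist_sources does.
    file = client.get_file_info(path)
    return Node.from_file(file)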
1 change: 1 addition & 0 deletions src/datachain/dataset.py
@@ -92,6 +92,7 @@ def dataset_name(self) -> str:
return self.name

list_dataset_name, _, _ = parse_listing_uri(self.name.strip("/"), None, {})
assert list_dataset_name
return list_dataset_name

@classmethod
18 changes: 15 additions & 3 deletions src/datachain/lib/dc.py
@@ -32,7 +32,7 @@
from datachain.lib.dataset_info import DatasetInfo
from datachain.lib.file import ArrowRow, File, FileType, get_file_type
from datachain.lib.file import ExportPlacement as FileExportPlacement
from datachain.lib.listing import get_listing, list_bucket, ls
from datachain.lib.listing import get_file_info, get_listing, list_bucket, ls
from datachain.lib.listing_info import ListingInfo
from datachain.lib.meta_formats import read_meta
from datachain.lib.model_store import ModelStore
@@ -438,6 +438,18 @@ def from_storage(
uri, session, update=update
)

# ds_name is None if object is a file, we don't want to use cache
# or do listing in that case - just read that single object
if not list_ds_name:
dc = cls.from_values(
session=session,
settings=settings,
in_memory=in_memory,
file=[get_file_info(list_uri, cache, client_config=client_config)],
)
dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
return dc

if update or not list_ds_exists:
(
cls.from_records(
@@ -1634,7 +1646,7 @@ def from_values(
output: OutputType = None,
object_name: str = "",
**fr_map,
) -> "DataChain":
) -> "Self":
"""Generate chain from list of values.
Example:
@@ -1647,7 +1659,7 @@
def _func_fr() -> Iterator[tuple_type]: # type: ignore[valid-type]
yield from tuples

chain = DataChain.from_records(
chain = cls.from_records(
DataChain.DEFAULT_FILE_RECORD,
session=session,
settings=settings,