
Commit

datachain: generalize data access functions into collect(), and collect_flatten
skshetry committed Jul 22, 2024
1 parent c3ea4b3 commit f119a3a
Showing 13 changed files with 139 additions and 111 deletions.
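Before the file-by-file diff, a minimal before/after sketch of the API change this commit makes, based only on the hunks below; `chain` and `sample` are placeholder DataChain instances used for illustration:

```py
# Old accessors (removed in this commit):
#   sample.iterate_one("file")   -> iterator over one column
#   chain.iterate("a", "b")      -> iterator over rows as lists
#   chain.collect(...)           -> eagerly returned a list
#   chain.collect_one("a")       -> list of values for a single column
#   chain.iterate_flatten()      -> flattened raw rows

# New unified accessors (added in this commit):
img = next(sample.collect("file")).read()        # one column -> yields single values
for name, size in chain.collect("file.name", "file.size"):
    print(name, size)                             # several columns -> yields tuples
rows = list(chain.collect())                      # collect() is now lazy; materialize explicitly
flat = list(chain.collect_flatten())              # flattened tuples of raw values
```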
@@ -2990,7 +2990,7 @@
"# Preview the target image \n",
"\n",
"sample = DataChain.from_dataset(\"fashion-embeddings\").filter(C(\"file.name\") == TARGET_NAME).save()\n",
"img = next(sample.iterate_one(\"file\")).read()\n",
"img = next(sample.collect(\"file\")).read()\n",
"img"
]
},
@@ -3985,7 +3985,7 @@
" print(name)\n",
" try: \n",
" sample = sim_ds.filter(C(\"file.name\") == name).save()\n",
" img = next(sample.iterate_one(\"file\")).read()\n",
" img = next(sample.collect(\"file\")).read()\n",
" images.append({name: img})\n",
" display(img)\n",
" except: \n",
@@ -4540,7 +4540,7 @@
"for row in (\n",
" DataChain.from_dataset(\"fashion-curated\")\n",
" .select(\"file.source\", \"file.parent\", \"file.name\", \"embeddings\")\n",
" .iterate()\n",
" .collect()\n",
"):\n",
" image_paths.append(os.path.join(row[0], row[1], row[2]))\n",
" image_embeddings.append(row[3])\n"
@@ -4588,7 +4588,7 @@
"\n",
" # Extract image from DataChain\n",
" sample = ds.filter(C(\"file.name\") == name).save()\n",
" img = next(sample.iterate_one(\"file\")).read()\n",
" img = next(sample.collect(\"file\")).read()\n",
"\n",
" # Attach thumbnail to the point\n",
" img.thumbnail((50, 50), Image.Resampling.LANCZOS) # Updated line here\n",
@@ -484,7 +484,7 @@
},
"outputs": [],
"source": [
"sample_results = high_conf_mist.collect()\n",
"sample_results = list(high_conf_mist.collect())\n",
"sample_results[0]"
]
},
@@ -575,7 +575,7 @@
"# for item in low_conf_mist.limit(TOP).collect():\n",
"# display(item[0].read())\n",
" \n",
"items = low_conf_mist.limit(TOP).collect()\n",
"items = list(low_conf_mist.limit(TOP).collect())\n",
"display_image_matrix(items, TOP)"
]
},
@@ -600,7 +600,7 @@
"# for item in high_conf_mist.limit(TOP).collect():\n",
"# display(item[0].read())\n",
"\n",
"items = high_conf_mist.limit(TOP).collect()\n",
"items = list(high_conf_mist.limit(TOP).collect())\n",
"display_image_matrix(items, TOP)"
]
},
@@ -651,7 +651,7 @@
"# for item in high_conf_mist.limit(TOP).collect():\n",
"# display(item[0].read())\n",
"\n",
"items = correct_preds.limit(TOP).collect()\n",
"items = list(correct_preds.limit(TOP).collect())\n",
"display_image_matrix(items, TOP)"
]
},
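The list(...) wrappers introduced throughout this notebook exist because collect() now yields rows lazily instead of returning a list; a small illustration, with names taken from the cells above:

```py
items = list(high_conf_mist.limit(TOP).collect())  # materialize the generator
display_image_matrix(items, TOP)                    # indexing and len() need an explicit list
```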
2 changes: 1 addition & 1 deletion examples/json-csv-reader.py
@@ -69,7 +69,7 @@ def main():
uri, schema_from=schema_uri, jmespath="@", model_name="OpenImage"
)
print(json_pairs_ds.to_pandas())
# print(json_pairs_ds.collect()[0])
# print(list(json_pairs_ds.collect())[0])

uri = "gs://datachain-demo/coco2017/annotations_captions/"

2 changes: 1 addition & 1 deletion examples/multimodal/clip_fine_tuning.ipynb
@@ -949,7 +949,7 @@
}
],
"source": [
"sample_results = sample.collect(\"file\", \"caption_choices\", \"label\")\n",
"sample_results = list(sample.collect(\"file\", \"caption_choices\", \"label\"))\n",
"sample_results"
]
},
2 changes: 1 addition & 1 deletion src/datachain/lib/arrow.py
@@ -39,7 +39,7 @@ def process(self, file: File):

def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
schemas = []
for file in chain.iterate_one("file"):
for file in chain.collect("file"):
ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs) # type: ignore[union-attr]
schemas.append(ds.schema)
return pa.unify_schemas(schemas)
97 changes: 63 additions & 34 deletions src/datachain/lib/dc.py
@@ -1,7 +1,7 @@
import copy
import os
import re
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Iterator, Sequence
from functools import wraps
from typing import (
TYPE_CHECKING,
@@ -410,7 +410,7 @@ def datasets(
from datachain import DataChain
chain = DataChain.datasets()
for ds in chain.iterate_one("dataset"):
for ds in chain.collect("dataset"):
print(f"{ds.name}@v{ds.version}")
```
"""
@@ -713,21 +713,28 @@ def mutate(self, **kwargs) -> "Self":

@property
def _effective_signals_schema(self) -> "SignalSchema":
"""Effective schema used for user-facing API like iterate, to_pandas, etc."""
"""Effective schema used for user-facing API like collect, to_pandas, etc."""
signals_schema = self.signals_schema
if not self._sys:
return signals_schema.clone_without_sys_signals()
return signals_schema

@overload
def iterate_flatten(self) -> Iterable[tuple[Any, ...]]: ...
def collect_flatten(self) -> Iterator[tuple[Any, ...]]: ...

@overload
def iterate_flatten(
self, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
) -> Iterable[_T]: ...
def collect_flatten(
self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
) -> Iterator[_T]: ...

def iterate_flatten(self, row_factory=None): # noqa: D102
def collect_flatten(self, *, row_factory=None):
"""Yields flattened rows of values as a tuple.
Parameters:
row_factory : A callable to convert row to a custom format.
It should accept two arguments, a list of column names and
tuple of row values.
"""
db_signals = self._effective_signals_schema.db_signals()
with super().select(*db_signals).as_iterable() as rows:
if row_factory:
@@ -739,48 +746,70 @@ def results(self) -> list[tuple[Any, ...]]: ...

@overload
def results(
self, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
) -> list[_T]: ...

def results(self, row_factory=None, **kwargs): # noqa: D102
return list(self.iterate_flatten(row_factory=row_factory))
def results(self, *, row_factory=None): # noqa: D102
if row_factory is None:
return list(self.collect_flatten())
return list(self.collect_flatten(row_factory=row_factory))

def to_records(self) -> list[dict[str, Any]]:
"""Convert every row to a dictionary."""

def to_records(self) -> list[dict[str, Any]]: # noqa: D102
def to_dict(cols: list[str], row: tuple[Any, ...]) -> dict[str, Any]:
return dict(zip(cols, row))

return self.results(row_factory=to_dict)

def iterate(self, *cols: str) -> Iterator[list[DataType]]:
"""Iterate over rows.
@overload
def collect(self) -> Iterator[tuple[DataType, ...]]: ...

@overload
def collect(self, col: str) -> Iterator[DataType]: ... # type: ignore[overload-overlap]

@overload
def collect(self, *cols: str) -> Iterator[tuple[DataType, ...]]: ...

def collect(self, *cols: str) -> Iterator[Union[DataType, tuple[DataType, ...]]]: # type: ignore[overload-overlap]
"""Yields rows of values, optionally limited to the specified columns.
If columns are specified - limit them to specified columns.
Parameters:
*cols: Limit to the specified columns. By default, all columns are selected.
Yields:
(DataType): Yields a single item if a column is selected.
(tuple[DataType, ...]): Yields a tuple of items if multiple columns are
selected.
Example:
Iterating over all rows:
```py
for row in dc.collect():
print(row)
```
Iterating over all rows with selected columns:
```py
for name, size in dc.collect("file.name", "file.size"):
print(name, size)
```
Iterating over a single column:
```py
for file in dc.collect("file.name"):
print(file)
```
"""
chain = self.select(*cols) if cols else self
signals_schema = chain._effective_signals_schema
db_signals = signals_schema.db_signals()
with super().select(*db_signals).as_iterable() as rows:
for row in rows:
yield signals_schema.row_to_features(
ret = signals_schema.row_to_features(
row, catalog=chain.session.catalog, cache=chain._settings.cache
)

def iterate_one(self, col: str) -> Iterator[DataType]:
"""Iterate over one column."""
for item in self.iterate(col):
yield item[0]

def collect(self, *cols: str) -> list[list[DataType]]:
"""Collect results from all rows.
If columns are specified - limit them to specified
columns.
"""
return list(self.iterate(*cols))

def collect_one(self, col: str) -> list[DataType]:
"""Collect results from one column."""
return list(self.iterate_one(col))
yield ret[0] if len(cols) == 1 else tuple(ret)

def to_pytorch(
self, transform=None, tokenizer=None, tokenizer_kwargs=None, num_samples=0
@@ -1328,7 +1357,7 @@ def export_files(
if self.distinct(f"{signal}.name").count() != self.count():
raise ValueError("Files with the same name found")

for file in self.collect_one(signal):
for file in self.collect(signal):
file.export(output, placement, use_cache) # type: ignore[union-attr]

def shuffle(self) -> "Self":
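A hedged usage sketch of the accessors defined in this file; `dc` stands for an arbitrary DataChain, and the row_factory callable follows the two-argument contract documented above:

```py
# One selected column: collect() yields the values themselves, not 1-tuples.
for file in dc.collect("file"):
    print(file.get_local_path())

# Several columns: collect() yields tuples in the order selected.
for name, size in dc.collect("file.name", "file.size"):
    print(name, size)

# collect_flatten() yields raw flattened rows; row_factory is keyword-only.
records = list(dc.collect_flatten(row_factory=lambda cols, row: dict(zip(cols, row))))
assert records == dc.to_records()  # to_records() is the built-in shortcut for exactly this
```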
3 changes: 1 addition & 2 deletions src/datachain/lib/pytorch.py
@@ -94,8 +94,7 @@ def __iter__(self) -> Iterator[Any]:
if self.num_samples > 0:
ds = ds.sample(self.num_samples)
ds = ds.chunk(total_rank, total_workers)
stream = ds.iterate()
for row_features in stream:
for row_features in ds.collect():
row = []
for fr in row_features:
if isinstance(fr, File):
10 changes: 5 additions & 5 deletions tests/examples/test_wds_e2e.py
@@ -75,7 +75,7 @@ def test_wds(catalog, webdataset_tars):
)

num_rows = 0
for laion_wds in res.iterate_one("laion"):
for laion_wds in res.collect("laion"):
num_rows += 1
assert isinstance(laion_wds, WDSLaion)
idx, data = next(
@@ -105,7 +105,7 @@ def test_wds_merge_with_parquet_meta(catalog, webdataset_tars, webdataset_metada
res = wds.merge(meta, on="laion.json.uid", right_on="uid")

num_rows = 0
for r in res.collect_one("laion"):
for r in res.collect("laion"):
num_rows += 1
assert isinstance(r, WDSLaion)
assert isinstance(r.file, File)
@@ -116,15 +116,15 @@ def test_wds_merge_with_parquet_meta(catalog, webdataset_tars, webdataset_metada

assert num_rows == len(WDS_TAR_SHARDS)

meta_res = res.collect(*WDS_META.keys())
meta_res = list(res.collect(*WDS_META.keys()))

for field_name_idx, rows_values in enumerate(WDS_META.values()):
assert sorted(rows_values.values()) == sorted(
[r[field_name_idx] for r in meta_res]
)

# validate correct merge
for laion_uid, uid in res.iterate("laion.json.uid", "uid"):
for laion_uid, uid in res.collect("laion.json.uid", "uid"):
assert laion_uid == uid
for caption, text in res.iterate("laion.json.caption", "text"):
for caption, text in res.collect("laion.json.caption", "text"):
assert caption == text
8 changes: 4 additions & 4 deletions tests/func/test_datachain.py
@@ -57,8 +57,8 @@ def new_signal(file: File) -> str:
"dog3 -> bark",
"dog4 -> ruff",
}
assert set(dc.iterate_one("signal")) == expected
for file in dc.iterate_one("file"):
assert set(dc.collect("signal")) == expected
for file in dc.collect("file"):
assert bool(file.get_local_path()) is use_cache


@@ -67,7 +67,7 @@ def test_read_file(cloud_test_catalog, use_cache):
ctc = cloud_test_catalog

dc = DataChain.from_storage(ctc.src_uri, catalog=ctc.catalog)
for file in dc.settings(cache=use_cache).iterate_one("file"):
for file in dc.settings(cache=use_cache).collect("file"):
assert file.get_local_path() is None
file.read()
assert bool(file.get_local_path()) is use_cache
@@ -100,7 +100,7 @@ def test_export_files(tmp_dir, cloud_test_catalog, placement, use_map, use_cache
"dog4": "ruff",
}

for file in df.collect_one("file"):
for file in df.collect("file"):
if placement == "filename":
file_path = file.name
else:
(4 more changed files not shown)
