From f119a3acac25a25257d3ecf7d072ad57cfac60ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Mon, 22 Jul 2024 14:03:58 +0545 Subject: [PATCH] datachain: generalize data access functions into collect(), and collect_flatten --- .../2-working-with-image-datachains.ipynb | 8 +- .../fashion_product_images/4-inference.ipynb | 8 +- examples/json-csv-reader.py | 2 +- examples/multimodal/clip_fine_tuning.ipynb | 2 +- src/datachain/lib/arrow.py | 2 +- src/datachain/lib/dc.py | 97 ++++++++++++------- src/datachain/lib/pytorch.py | 3 +- tests/examples/test_wds_e2e.py | 10 +- tests/func/test_datachain.py | 8 +- tests/unit/lib/test_datachain.py | 84 ++++++++-------- tests/unit/lib/test_datachain_bootstrap.py | 10 +- tests/unit/lib/test_datachain_merge.py | 12 +-- tests/unit/lib/test_feature_utils.py | 4 +- 13 files changed, 139 insertions(+), 111 deletions(-) diff --git a/examples/computer_vision/fashion_product_images/2-working-with-image-datachains.ipynb b/examples/computer_vision/fashion_product_images/2-working-with-image-datachains.ipynb index ebfa424d7..40e3b9651 100644 --- a/examples/computer_vision/fashion_product_images/2-working-with-image-datachains.ipynb +++ b/examples/computer_vision/fashion_product_images/2-working-with-image-datachains.ipynb @@ -2990,7 +2990,7 @@ "# Preview the target image \n", "\n", "sample = DataChain.from_dataset(\"fashion-embeddings\").filter(C(\"file.name\") == TARGET_NAME).save()\n", - "img = next(sample.iterate_one(\"file\")).read()\n", + "img = next(sample.collect(\"file\")).read()\n", "img" ] }, @@ -3985,7 +3985,7 @@ " print(name)\n", " try: \n", " sample = sim_ds.filter(C(\"file.name\") == name).save()\n", - " img = next(sample.iterate_one(\"file\")).read()\n", + " img = next(sample.collect(\"file\")).read()\n", " images.append({name: img})\n", " display(img)\n", " except: \n", @@ -4540,7 +4540,7 @@ "for row in (\n", " DataChain.from_dataset(\"fashion-curated\")\n", " .select(\"file.source\", \"file.parent\", \"file.name\", \"embeddings\")\n", - " .iterate()\n", + " .collect()\n", "):\n", " image_paths.append(os.path.join(row[0], row[1], row[2]))\n", " image_embeddings.append(row[3])\n" @@ -4588,7 +4588,7 @@ "\n", " # Extract image from DataChain\n", " sample = ds.filter(C(\"file.name\") == name).save()\n", - " img = next(sample.iterate_one(\"file\")).read()\n", + " img = next(sample.collect(\"file\")).read()\n", "\n", " # Attach thumbnail to the point\n", " img.thumbnail((50, 50), Image.Resampling.LANCZOS) # Updated line here\n", diff --git a/examples/computer_vision/fashion_product_images/4-inference.ipynb b/examples/computer_vision/fashion_product_images/4-inference.ipynb index 6254179b5..a64df264b 100644 --- a/examples/computer_vision/fashion_product_images/4-inference.ipynb +++ b/examples/computer_vision/fashion_product_images/4-inference.ipynb @@ -484,7 +484,7 @@ }, "outputs": [], "source": [ - "sample_results = high_conf_mist.collect()\n", + "sample_results = list(high_conf_mist.collect())\n", "sample_results[0]" ] }, @@ -575,7 +575,7 @@ "# for item in low_conf_mist.limit(TOP).collect():\n", "# display(item[0].read())\n", " \n", - "items = low_conf_mist.limit(TOP).collect()\n", + "items = list(low_conf_mist.limit(TOP).collect())\n", "display_image_matrix(items, TOP)" ] }, @@ -600,7 +600,7 @@ "# for item in high_conf_mist.limit(TOP).collect():\n", "# display(item[0].read())\n", "\n", - "items = high_conf_mist.limit(TOP).collect()\n", + "items = 
list(high_conf_mist.limit(TOP).collect())\n", "display_image_matrix(items, TOP)" ] }, @@ -651,7 +651,7 @@ "# for item in high_conf_mist.limit(TOP).collect():\n", "# display(item[0].read())\n", "\n", - "items = correct_preds.limit(TOP).collect()\n", + "items = list(correct_preds.limit(TOP).collect())\n", "display_image_matrix(items, TOP)" ] }, diff --git a/examples/json-csv-reader.py b/examples/json-csv-reader.py index dce4ecb53..d617507f6 100644 --- a/examples/json-csv-reader.py +++ b/examples/json-csv-reader.py @@ -69,7 +69,7 @@ def main(): uri, schema_from=schema_uri, jmespath="@", model_name="OpenImage" ) print(json_pairs_ds.to_pandas()) - # print(json_pairs_ds.collect()[0]) + # print(list(json_pairs_ds.collect())[0]) uri = "gs://datachain-demo/coco2017/annotations_captions/" diff --git a/examples/multimodal/clip_fine_tuning.ipynb b/examples/multimodal/clip_fine_tuning.ipynb index 79c822e99..4f613b13d 100644 --- a/examples/multimodal/clip_fine_tuning.ipynb +++ b/examples/multimodal/clip_fine_tuning.ipynb @@ -949,7 +949,7 @@ } ], "source": [ - "sample_results = sample.collect(\"file\", \"caption_choices\", \"label\")\n", + "sample_results = list(sample.collect(\"file\", \"caption_choices\", \"label\"))\n", "sample_results" ] }, diff --git a/src/datachain/lib/arrow.py b/src/datachain/lib/arrow.py index e9b4ed776..a199ab133 100644 --- a/src/datachain/lib/arrow.py +++ b/src/datachain/lib/arrow.py @@ -39,7 +39,7 @@ def process(self, file: File): def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema: schemas = [] - for file in chain.iterate_one("file"): + for file in chain.collect("file"): ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs) # type: ignore[union-attr] schemas.append(ds.schema) return pa.unify_schemas(schemas) diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index a6c7ea95f..0014ebbdb 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -1,7 +1,7 @@ import copy import os import re -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Iterator, Sequence from functools import wraps from typing import ( TYPE_CHECKING, @@ -410,7 +410,7 @@ def datasets( from datachain import DataChain chain = DataChain.datasets() - for ds in chain.iterate_one("dataset"): + for ds in chain.collect("dataset"): print(f"{ds.name}@v{ds.version}") ``` """ @@ -713,21 +713,28 @@ def mutate(self, **kwargs) -> "Self": @property def _effective_signals_schema(self) -> "SignalSchema": - """Effective schema used for user-facing API like iterate, to_pandas, etc.""" + """Effective schema used for user-facing API like collect, to_pandas, etc.""" signals_schema = self.signals_schema if not self._sys: return signals_schema.clone_without_sys_signals() return signals_schema @overload - def iterate_flatten(self) -> Iterable[tuple[Any, ...]]: ... + def collect_flatten(self) -> Iterator[tuple[Any, ...]]: ... @overload - def iterate_flatten( - self, row_factory: Callable[[list[str], tuple[Any, ...]], _T] - ) -> Iterable[_T]: ... + def collect_flatten( + self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T] + ) -> Iterator[_T]: ... - def iterate_flatten(self, row_factory=None): # noqa: D102 + def collect_flatten(self, *, row_factory=None): + """Yields flattened rows of values as a tuple. + + Parameters: + row_factory : A callable to convert row to a custom format. + It should accept two arguments, a list of column names and + tuple of row values. 
+ """ db_signals = self._effective_signals_schema.db_signals() with super().select(*db_signals).as_iterable() as rows: if row_factory: @@ -739,48 +746,70 @@ def results(self) -> list[tuple[Any, ...]]: ... @overload def results( - self, row_factory: Callable[[list[str], tuple[Any, ...]], _T] + self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T] ) -> list[_T]: ... - def results(self, row_factory=None, **kwargs): # noqa: D102 - return list(self.iterate_flatten(row_factory=row_factory)) + def results(self, *, row_factory=None): # noqa: D102 + if row_factory is None: + return list(self.collect_flatten()) + return list(self.collect_flatten(row_factory=row_factory)) + + def to_records(self) -> list[dict[str, Any]]: + """Convert every row to a dictionary.""" - def to_records(self) -> list[dict[str, Any]]: # noqa: D102 def to_dict(cols: list[str], row: tuple[Any, ...]) -> dict[str, Any]: return dict(zip(cols, row)) return self.results(row_factory=to_dict) - def iterate(self, *cols: str) -> Iterator[list[DataType]]: - """Iterate over rows. + @overload + def collect(self) -> Iterator[tuple[DataType, ...]]: ... + + @overload + def collect(self, col: str) -> Iterator[DataType]: ... # type: ignore[overload-overlap] + + @overload + def collect(self, *cols: str) -> Iterator[tuple[DataType, ...]]: ... + + def collect(self, *cols: str) -> Iterator[Union[DataType, tuple[DataType, ...]]]: # type: ignore[overload-overlap] + """Yields rows of values, optionally limited to the specified columns. - If columns are specified - limit them to specified columns. + Parameters: + *cols: Limit to the specified columns. By default, all columns are selected. + + Yields: + (DataType): Yields a single item if a column is selected. + (tuple[DataType, ...]): Yields a tuple of items if multiple columns are + selected. + + Example: + Iterating over all rows: + ```py + for row in dc.collect(): + print(row) + ``` + + Iterating over all rows with selected columns: + ```py + for name, size in dc.collect("file.name", "file.size"): + print(name, size) + ``` + + Iterating over a single column: + ```py + for file in dc.collect("file.name"): + print(file) + ``` """ chain = self.select(*cols) if cols else self signals_schema = chain._effective_signals_schema db_signals = signals_schema.db_signals() with super().select(*db_signals).as_iterable() as rows: for row in rows: - yield signals_schema.row_to_features( + ret = signals_schema.row_to_features( row, catalog=chain.session.catalog, cache=chain._settings.cache ) - - def iterate_one(self, col: str) -> Iterator[DataType]: - """Iterate over one column.""" - for item in self.iterate(col): - yield item[0] - - def collect(self, *cols: str) -> list[list[DataType]]: - """Collect results from all rows. - - If columns are specified - limit them to specified - columns. 
- """ - return list(self.iterate(*cols)) - - def collect_one(self, col: str) -> list[DataType]: - """Collect results from one column.""" - return list(self.iterate_one(col)) + yield ret[0] if len(cols) == 1 else tuple(ret) def to_pytorch( self, transform=None, tokenizer=None, tokenizer_kwargs=None, num_samples=0 @@ -1328,7 +1357,7 @@ def export_files( if self.distinct(f"{signal}.name").count() != self.count(): raise ValueError("Files with the same name found") - for file in self.collect_one(signal): + for file in self.collect(signal): file.export(output, placement, use_cache) # type: ignore[union-attr] def shuffle(self) -> "Self": diff --git a/src/datachain/lib/pytorch.py b/src/datachain/lib/pytorch.py index 55b5949e0..58bb9e433 100644 --- a/src/datachain/lib/pytorch.py +++ b/src/datachain/lib/pytorch.py @@ -94,8 +94,7 @@ def __iter__(self) -> Iterator[Any]: if self.num_samples > 0: ds = ds.sample(self.num_samples) ds = ds.chunk(total_rank, total_workers) - stream = ds.iterate() - for row_features in stream: + for row_features in ds.collect(): row = [] for fr in row_features: if isinstance(fr, File): diff --git a/tests/examples/test_wds_e2e.py b/tests/examples/test_wds_e2e.py index a7d887335..d12dff201 100644 --- a/tests/examples/test_wds_e2e.py +++ b/tests/examples/test_wds_e2e.py @@ -75,7 +75,7 @@ def test_wds(catalog, webdataset_tars): ) num_rows = 0 - for laion_wds in res.iterate_one("laion"): + for laion_wds in res.collect("laion"): num_rows += 1 assert isinstance(laion_wds, WDSLaion) idx, data = next( @@ -105,7 +105,7 @@ def test_wds_merge_with_parquet_meta(catalog, webdataset_tars, webdataset_metada res = wds.merge(meta, on="laion.json.uid", right_on="uid") num_rows = 0 - for r in res.collect_one("laion"): + for r in res.collect("laion"): num_rows += 1 assert isinstance(r, WDSLaion) assert isinstance(r.file, File) @@ -116,7 +116,7 @@ def test_wds_merge_with_parquet_meta(catalog, webdataset_tars, webdataset_metada assert num_rows == len(WDS_TAR_SHARDS) - meta_res = res.collect(*WDS_META.keys()) + meta_res = list(res.collect(*WDS_META.keys())) for field_name_idx, rows_values in enumerate(WDS_META.values()): assert sorted(rows_values.values()) == sorted( @@ -124,7 +124,7 @@ def test_wds_merge_with_parquet_meta(catalog, webdataset_tars, webdataset_metada ) # validate correct merge - for laion_uid, uid in res.iterate("laion.json.uid", "uid"): + for laion_uid, uid in res.collect("laion.json.uid", "uid"): assert laion_uid == uid - for caption, text in res.iterate("laion.json.caption", "text"): + for caption, text in res.collect("laion.json.caption", "text"): assert caption == text diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index 5a74424be..1ec811bbe 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -57,8 +57,8 @@ def new_signal(file: File) -> str: "dog3 -> bark", "dog4 -> ruff", } - assert set(dc.iterate_one("signal")) == expected - for file in dc.iterate_one("file"): + assert set(dc.collect("signal")) == expected + for file in dc.collect("file"): assert bool(file.get_local_path()) is use_cache @@ -67,7 +67,7 @@ def test_read_file(cloud_test_catalog, use_cache): ctc = cloud_test_catalog dc = DataChain.from_storage(ctc.src_uri, catalog=ctc.catalog) - for file in dc.settings(cache=use_cache).iterate_one("file"): + for file in dc.settings(cache=use_cache).collect("file"): assert file.get_local_path() is None file.read() assert bool(file.get_local_path()) is use_cache @@ -100,7 +100,7 @@ def test_export_files(tmp_dir, 
cloud_test_catalog, placement, use_map, use_cache "dog4": "ruff", } - for file in df.collect_one("file"): + for file in df.collect("file"): if placement == "filename": file_path = file.name else: diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 1d3f6c937..5bccb7045 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -113,24 +113,24 @@ def test_from_features(catalog): params="parent", output={"file": File, "t1": MyFr}, ) - for i, (_, t1) in enumerate(ds.iterate()): + for i, (_, t1) in enumerate(ds.collect()): assert t1 == features[i] def test_datasets(catalog): ds = DataChain.datasets() - datasets = [d for d in ds.iterate_one("dataset") if d.name == "fibonacci"] + datasets = [d for d in ds.collect("dataset") if d.name == "fibonacci"] assert len(datasets) == 0 DataChain.from_values(fib=[1, 1, 2, 3, 5, 8]).save("fibonacci") ds = DataChain.datasets() - datasets = [d for d in ds.iterate_one("dataset") if d.name == "fibonacci"] + datasets = [d for d in ds.collect("dataset") if d.name == "fibonacci"] assert len(datasets) == 1 assert datasets[0].num_objects == 6 ds = DataChain.datasets(object_name="foo") - datasets = [d for d in ds.iterate_one("foo") if d.name == "fibonacci"] + datasets = [d for d in ds.collect("foo") if d.name == "fibonacci"] assert len(datasets) == 1 assert datasets[0].num_objects == 6 @@ -206,7 +206,7 @@ def test_file_list(catalog): ds = DataChain.from_values(file=files) - for i, values in enumerate(ds.iterate()): + for i, values in enumerate(ds.collect()): assert values[0] == files[i] @@ -229,7 +229,7 @@ class _TestFr(BaseModel): output={"x": _TestFr}, ) - for i, (x,) in enumerate(ds.iterate()): + for i, (x,) in enumerate(ds.collect()): assert isinstance(x, _TestFr) fr = features[i] @@ -253,7 +253,7 @@ class _TestFr(BaseModel): output={"x": _TestFr}, ) - x_list = dc.collect_one("x") + x_list = list(dc.collect("x")) test_frs = [ _TestFr(sqrt=math.sqrt(fr.count), my_name=fr.nnn + "_suf") for fr in features ] @@ -284,7 +284,7 @@ class _TestFr(BaseModel): output={"x": _TestFr}, ) - assert dc.collect_one("x") == [ + assert list(dc.collect("x")) == [ _TestFr( f=File(name=""), cnt=sum(fr.count for fr in features if fr.nnn == "n1"), @@ -323,8 +323,8 @@ class _TestFr(BaseModel): output={"x": _TestFr}, ) - assert ds.collect_one("x.my_name") == ["n1-n1", "n2"] - assert ds.collect_one("x.cnt") == [12, 15] + assert list(ds.collect("x.my_name")) == ["n1-n1", "n2"] + assert list(ds.collect("x.cnt")) == [12, 15] def test_agg_simple_iterator(catalog): @@ -383,8 +383,8 @@ def func(key, val) -> Iterator[tuple[File, _ImageGroup]]: values = [1, 5, 9] ds = DataChain.from_values(key=keys, val=values).agg(x=func, partition_by=C("key")) - assert ds.collect_one("x_1.name") == ["n1-n1", "n2"] - assert ds.collect_one("x_1.size") == [10, 5] + assert list(ds.collect("x_1.name")) == ["n1-n1", "n2"] + assert list(ds.collect("x_1.size")) == [10, 5] def test_agg_tuple_result_generator(catalog): @@ -401,15 +401,15 @@ def func(key, val) -> Generator[tuple[File, _ImageGroup], None, None]: values = [1, 5, 9] ds = DataChain.from_values(key=keys, val=values).agg(x=func, partition_by=C("key")) - assert ds.collect_one("x_1.name") == ["n1-n1", "n2"] - assert ds.collect_one("x_1.size") == [10, 5] + assert list(ds.collect("x_1.name")) == ["n1-n1", "n2"] + assert list(ds.collect("x_1.size")) == [10, 5] -def test_iterate(catalog): +def test_collect(catalog): dc = DataChain.from_values(f1=features, num=range(len(features))) n = 0 - for sample in 
dc.iterate(): + for sample in dc.collect(): assert len(sample) == 2 fr, num = sample @@ -423,10 +423,10 @@ def test_iterate(catalog): assert n == len(features) -def test_iterate_nested_feature(catalog): +def test_collect_nested_feature(catalog): dc = DataChain.from_values(sign1=features_nested) - for n, sample in enumerate(dc.iterate()): + for n, sample in enumerate(dc.collect()): assert len(sample) == 1 nested = sample[0] @@ -437,21 +437,21 @@ def test_iterate_nested_feature(catalog): def test_select_feature(catalog): dc = DataChain.from_values(my_n=features_nested) - samples = dc.select("my_n").iterate() + samples = dc.select("my_n").collect() n = 0 for sample in samples: assert sample[0] == features_nested[n] n += 1 assert n == len(features_nested) - samples = dc.select("my_n.fr").iterate() + samples = dc.select("my_n.fr").collect() n = 0 for sample in samples: assert sample[0] == features[n] n += 1 assert n == len(features_nested) - samples = dc.select("my_n.label", "my_n.fr.count").iterate() + samples = dc.select("my_n.label", "my_n.fr.count").collect() n = 0 for sample in samples: label, count = sample @@ -464,7 +464,7 @@ def test_select_feature(catalog): def test_select_columns_intersection(catalog): dc = DataChain.from_values(my_n=features_nested) - samples = dc.select("my_n.fr", "my_n.fr.count").iterate() + samples = dc.select("my_n.fr", "my_n.fr.count").collect() n = 0 for sample in samples: fr, count = sample @@ -477,7 +477,7 @@ def test_select_columns_intersection(catalog): def test_select_except(catalog): dc = DataChain.from_values(fr1=features_nested, fr2=features) - samples = dc.select_except("fr2").iterate() + samples = dc.select_except("fr2").collect() n = 0 for sample in samples: fr = sample[0] @@ -490,20 +490,20 @@ def test_select_wrong_type(catalog): dc = DataChain.from_values(fr1=features_nested, fr2=features) with pytest.raises(SignalResolvingTypeError): - list(dc.select(4).iterate()) + list(dc.select(4).collect()) with pytest.raises(SignalResolvingTypeError): - list(dc.select_except(features[0]).iterate()) + list(dc.select_except(features[0]).collect()) def test_select_except_error(catalog): dc = DataChain.from_values(fr1=features_nested, fr2=features) with pytest.raises(SignalResolvingError): - list(dc.select_except("not_exist", "file").iterate()) + list(dc.select_except("not_exist", "file").collect()) with pytest.raises(SignalResolvingError): - list(dc.select_except("fr1.label", "file").iterate()) + list(dc.select_except("fr1.label", "file").collect()) def test_select_restore_from_saving(catalog): @@ -514,7 +514,7 @@ def test_select_restore_from_saving(catalog): restored = DataChain.from_dataset(name) n = 0 - restored_sorted = sorted(restored.iterate(), key=lambda x: x[0].count) + restored_sorted = sorted(restored.collect(), key=lambda x: x[0].count) features_sorted = sorted(features, key=lambda x: x.count) for sample in restored_sorted: assert sample[0] == features_sorted[n] @@ -593,7 +593,7 @@ def get_vector(key) -> list[np.float64]: DataChain.from_values(key=[123]).map(emd=get_vector) -def test_collect_one(catalog): +def test_collect_single_item(catalog): names = ["f1.jpg", "f1.json", "f1.txt", "f2.jpg", "f2.json"] sizes = [1, 2, 3, 4, 5] files = [File(name=name, size=size) for name, size in zip(names, sizes)] @@ -602,11 +602,11 @@ def test_collect_one(catalog): chain = DataChain.from_values(file=files, score=scores) - assert chain.collect_one("file") == files - assert chain.collect_one("file.name") == names - assert chain.collect_one("file.size") == sizes - 
assert chain.collect_one("file.source") == [""] * len(names) - assert np.allclose(chain.collect_one("score"), scores) + assert list(chain.collect("file")) == files + assert list(chain.collect("file.name")) == names + assert list(chain.collect("file.size")) == sizes + assert list(chain.collect("file.source")) == [""] * len(names) + assert np.allclose(list(chain.collect("score")), scores) for actual, expected in zip( chain.collect("file.size", "score"), [[x, y] for x, y in zip(sizes, scores)] @@ -622,7 +622,7 @@ def test_default_output_type(catalog): chain = DataChain.from_values(name=names).map(res1=lambda name: name + suffix) - assert chain.collect_one("res1") == [t + suffix for t in names] + assert list(chain.collect("res1")) == [t + suffix for t in names] def test_parse_tabular(tmp_dir, catalog): @@ -850,11 +850,11 @@ def test_parallel(processes, catalog): prefix = "t & " vals = ["a", "b", "c", "d", "e", "f", "g", "h", "i"] - res = ( + res = list( DataChain.from_values(key=vals) .settings(parallel=processes) .map(res=lambda key: prefix + key) - .collect_one("res") + .collect("res") ) assert res == [prefix + v for v in vals] @@ -956,16 +956,16 @@ def test_mutate(): assert chain.signals_schema.values["place"] is str expected = [fr.count * 2 * 3.14 for fr in features] - np.testing.assert_allclose(chain.collect_one("circle"), expected) + np.testing.assert_allclose(list(chain.collect("circle")), expected) def test_order_by_with_nested_columns(): names = ["a.txt", "c.txt", "d.txt", "a.txt", "b.txt"] - assert ( + assert list( DataChain.from_values(file=[File(name=name) for name in names]) .order_by("file.name") - .collect_one("file.name") + .collect("file.name") ) == ["a.txt", "a.txt", "b.txt", "c.txt", "d.txt"] @@ -974,8 +974,8 @@ def test_order_by_with_func(): from datachain.sql.functions import rand - assert ( + assert list( DataChain.from_values(file=[File(name=name) for name in names]) .order_by("file.name", rand()) - .collect_one("file.name") + .collect("file.name") ) == ["a.txt", "a.txt", "b.txt", "c.txt", "d.txt"] diff --git a/tests/unit/lib/test_datachain_bootstrap.py b/tests/unit/lib/test_datachain_bootstrap.py index f4838312f..4ecbdaa3a 100644 --- a/tests/unit/lib/test_datachain_bootstrap.py +++ b/tests/unit/lib/test_datachain_bootstrap.py @@ -29,7 +29,7 @@ def test_udf(): chain = DataChain.from_values(key=vals) udf = MyMapper() - res = chain.map(res=udf).collect_one("res") + res = list(chain.map(res=udf).collect("res")) assert res == [MyMapper.BOOTSTRAP_VALUE] * len(vals) assert udf.value == MyMapper.TEARDOWN_VALUE @@ -40,7 +40,7 @@ def test_udf_parallel(): vals = ["a", "b", "c", "d", "e", "f"] chain = DataChain.from_values(key=vals) - res = chain.settings(parallel=4).map(res=MyMapper()).collect_one("res") + res = list(chain.settings(parallel=4).map(res=MyMapper()).collect("res")) assert res == [MyMapper.BOOTSTRAP_VALUE] * len(vals) @@ -63,7 +63,7 @@ def teardown(self): udf = MyMapper() chain = DataChain.from_values(key=["a", "b", "c"]) - chain.map(res=udf).collect() + list(chain.map(res=udf).collect()) assert udf._had_bootstrap is False assert udf._had_teardown is False @@ -73,11 +73,11 @@ def test_bootstrap_in_chain(): base = 1278 prime = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] - res = ( + res = list( DataChain.from_values(val=prime) .setup(init_val=lambda: base) .map(x=lambda val, init_val: val + init_val, output=int) - .collect_one("x") + .collect("x") ) assert res == [base + val for val in prime] diff --git a/tests/unit/lib/test_datachain_merge.py 
b/tests/unit/lib/test_datachain_merge.py index 2545940bc..d332029e7 100644 --- a/tests/unit/lib/test_datachain_merge.py +++ b/tests/unit/lib/test_datachain_merge.py @@ -54,7 +54,7 @@ def test_merge_objects(catalog): i = 0 j = 0 - for items in ch.iterate(): + for items in ch.collect(): assert len(items) == 2 empl, player = items @@ -94,12 +94,12 @@ def test_merge_similar_objects(catalog): assert list(ch.signals_schema.values.keys()) == ["emp", rname + "emp"] - empl = list(ch.iterate()) + empl = list(ch.collect()) assert len(empl) == 4 assert len(empl[0]) == 2 ch_inner = ch1.merge(ch2, "emp.person.name", rname=rname, inner=True) - assert len(list(ch_inner.iterate())) == 2 + assert len(list(ch_inner.collect())) == 2 def test_merge_values(catalog): @@ -120,7 +120,7 @@ def test_merge_values(catalog): i = 0 j = 0 - sorted_items_list = sorted(ch.iterate(), key=lambda x: x[0]) + sorted_items_list = sorted(ch.collect(), key=lambda x: x[0]) for items in sorted_items_list: assert len(items) == 4 id, name, _right_id, time = items @@ -154,7 +154,7 @@ def test_merge_multi_conditions(catalog): ch = ch1.merge(ch2, ("id", "name"), ("id", "d_name")) - res = list(ch.iterate()) + res = list(ch.collect()) assert len(res) == max(len(employees), len(team)) success_ids = set() @@ -189,7 +189,7 @@ def test_merge_with_itself(catalog): merged = ch.merge(ch, "emp.id") count = 0 - for left, right in merged.iterate(): + for left, right in merged.collect(): assert isinstance(left, TestEmployee) assert isinstance(right, TestEmployee) assert left == right == employees[count] diff --git a/tests/unit/lib/test_feature_utils.py b/tests/unit/lib/test_feature_utils.py index ca2619954..47066001f 100644 --- a/tests/unit/lib/test_feature_utils.py +++ b/tests/unit/lib/test_feature_utils.py @@ -31,7 +31,7 @@ def test_e2e(catalog): dc = DataChain.from_values(fib=fib, odds=values) - vals = list(dc.iterate()) + vals = list(dc.collect()) lst1 = [item[0] for item in vals] lst2 = [item[1] for item in vals] @@ -53,7 +53,7 @@ def test_single_e2e(catalog): dc = DataChain.from_values(fib=fib) - vals = list(dc.iterate()) + vals = list(dc.collect()) flattened = [item for sublist in vals for item in sublist] assert flattened == fib
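Usage sketch (reviewer note): this patch collapses iterate(), iterate_one(), collect_one() and the old list-returning collect() into a single lazy collect(), and renames iterate_flatten() to collect_flatten(); both now return iterators, which is why the notebook and test call sites above wrap them in list(). Below is a minimal sketch of the resulting API, assuming a toy in-memory chain: the signal names and values are illustrative only, the method names and signatures come from this diff.

```py
from datachain import DataChain

# Toy chain for illustration; any signals created via from_values() would do.
chain = DataChain.from_values(
    key=["a", "b", "c"],
    val=[1, 2, 3],
)

# collect() with several columns yields a tuple of typed values per row.
for key, val in chain.collect("key", "val"):
    print(key, val)

# A single selected column now yields bare values
# (previously iterate_one() / collect_one()); materialize with list() if needed.
vals = list(chain.collect("val"))

# collect_flatten() yields flat tuples of raw DB values; row_factory is
# keyword-only and converts each row, e.g. into a dict.
records = list(
    chain.collect_flatten(row_factory=lambda cols, row: dict(zip(cols, row)))
)

# to_records() wraps the same conversion and returns a list of dicts.
assert records == chain.to_records()
```

The main behavioral change to watch for when reviewing call sites: collect() is now a generator, so code that previously indexed or compared its result directly (as several tests above did) must wrap it in list() first.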