diff --git a/examples/get_started/json-csv-reader.py b/examples/get_started/json-csv-reader.py
index feedbdb28..df1c6c4b8 100644
--- a/examples/get_started/json-csv-reader.py
+++ b/examples/get_started/json-csv-reader.py
@@ -89,6 +89,15 @@ def main():
     static_csv_ds.print_schema()
     static_csv_ds.show()
 
+    uri = "gs://datachain-demo/laion-aesthetics-csv/laion_aesthetics_1024_33M_1.csv"
+    print()
+    print("========================================================================")
+    print("dynamic CSV with header schema test parsing 3/3M objects")
+    print("========================================================================")
+    dynamic_csv_ds = DataChain.from_csv(uri, object_name="laion", nrows=3)
+    dynamic_csv_ds.print_schema()
+    dynamic_csv_ds.show()
+
 
 if __name__ == "__main__":
     main()
diff --git a/src/datachain/lib/arrow.py b/src/datachain/lib/arrow.py
index 1158f0459..f56843b2e 100644
--- a/src/datachain/lib/arrow.py
+++ b/src/datachain/lib/arrow.py
@@ -1,5 +1,6 @@
 import re
 from collections.abc import Sequence
+from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Optional
 
 import pyarrow as pa
@@ -43,13 +44,17 @@ def __init__(
         self.kwargs = kwargs
 
     def process(self, file: File):
-        path = file.get_path()
-        ds = dataset(
-            path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
-        )
+        if self.nrows:
+            path = _nrows_file(file, self.nrows)
+            ds = dataset(path, schema=self.input_schema, **self.kwargs)
+        else:
+            path = file.get_path()
+            ds = dataset(
+                path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
+            )
         index = 0
         with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
-            for record_batch in ds.to_batches(use_threads=False):
+            for record_batch in ds.to_batches():
                 for record in record_batch.to_pylist():
                     vals = list(record.values())
                     if self.output_schema:
@@ -60,8 +65,6 @@ def process(self, file: File):
                     else:
                         yield vals
                     index += 1
-                    if self.nrows and index >= self.nrows:
-                        return
                 pbar.update(len(record_batch))
 
 
@@ -125,3 +128,15 @@ def _arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
     if isinstance(col_type, pa.lib.DictionaryType):
         return _arrow_type_mapper(col_type.value_type)  # type: ignore[return-value]
     raise TypeError(f"{col_type!r} datatypes not supported")
+
+
+def _nrows_file(file: File, nrows: int) -> str:
+    tf = NamedTemporaryFile(delete=False)
+    with file.open(mode="r") as reader:
+        with open(tf.name, "a") as writer:
+            for row, line in enumerate(reader):
+                if row >= nrows:
+                    break
+                writer.write(line)
+                writer.write("\n")
+    return tf.name
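The arrow.py change moves row limiting out of the batch loop. The removed early
return only stopped iteration after pyarrow had already begun scanning the full
(possibly remote) file; `_nrows_file` instead copies just the first `nrows`
lines into a local temporary file and builds the dataset from that. Dropping
`use_threads=False` fits the same change, since limiting rows no longer depends
on stopping a sequential scan early. Line-based truncation only makes sense for
line-delimited formats, which is why `parse_tabular` (next file) restricts
`nrows` to csv and json. A minimal standalone sketch of the truncation idea,
assuming a text-mode reader whose lines keep their trailing newlines;
`head_lines` is an illustrative name, not part of the patch:

    from tempfile import NamedTemporaryFile

    def head_lines(reader, nrows: int) -> str:
        # Copy at most `nrows` lines from `reader` into a temp file and
        # return its path. delete=False mirrors `_nrows_file`: the caller
        # owns the file and is responsible for removing it.
        tf = NamedTemporaryFile(mode="w", delete=False)
        with tf:
            for row, line in enumerate(reader):
                if row >= nrows:
                    break
                tf.write(line)
        return tf.name
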
diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index a91e2fa3d..91e674444 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -1261,8 +1261,21 @@ def parse_tabular(
             dc = dc.parse_tabular(format="json")
             ```
         """
+        from pyarrow.dataset import CsvFileFormat, JsonFileFormat
+
         from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output
 
+        if nrows:
+            format = kwargs.get("format")
+            if format not in ["csv", "json"] and not isinstance(
+                format, (CsvFileFormat, JsonFileFormat)
+            ):
+                raise DatasetPrepareError(
+                    self.name,
+                    "error in `parse_tabular` - "
+                    "`nrows` only supported for csv and json formats.",
+                )
+
         schema = None
         col_names = output if isinstance(output, Sequence) else None
         if col_names or not output:
@@ -1360,6 +1373,8 @@ def from_csv(
         else:
             msg = f"error parsing csv - incompatible output type {type(output)}"
             raise DatasetPrepareError(chain.name, msg)
+    elif nrows:
+        nrows += 1
 
     parse_options = ParseOptions(delimiter=delimiter)
     read_options = ReadOptions(column_names=column_names)
@@ -1382,7 +1397,6 @@ def from_parquet(
         object_name: str = "",
         model_name: str = "",
         source: bool = True,
-        nrows=None,
         **kwargs,
     ) -> "DataChain":
         """Generate chain from parquet files.
@@ -1395,7 +1409,6 @@
             object_name : Created object column name.
             model_name : Generated model name.
             source : Whether to include info about the source file.
-            nrows : Optional row limit.
 
         Example:
             Reading a single file:
@@ -1414,7 +1427,6 @@
             object_name=object_name,
             model_name=model_name,
             source=source,
-            nrows=None,
             format="parquet",
             partitioning=partitioning,
         )
diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py
index 4c1ad6b85..855480da7 100644
--- a/src/datachain/lib/file.py
+++ b/src/datachain/lib/file.py
@@ -317,9 +317,9 @@ class TextFile(File):
     """`DataModel` for reading text files."""
 
     @contextmanager
-    def open(self):
-        """Open the file and return a file object in text mode."""
-        with super().open(mode="r") as stream:
+    def open(self, mode: Literal["rb", "r"] = "r"):
+        """Open the file and return a file object (default to text mode)."""
+        with super().open(mode=mode) as stream:
             yield stream
 
     def read_text(self):
diff --git a/tests/unit/lib/test_arrow.py b/tests/unit/lib/test_arrow.py
index 9bf06a311..eaaf5fa79 100644
--- a/tests/unit/lib/test_arrow.py
+++ b/tests/unit/lib/test_arrow.py
@@ -36,23 +36,6 @@ def test_arrow_generator(tmp_path, catalog):
     assert o[2] == text
 
 
-def test_arrow_generator_nrows(tmp_path, catalog):
-    ids = [12345, 67890, 34, 0xF0123]
-    texts = ["28", "22", "we", "hello world"]
-    df = pd.DataFrame({"id": ids, "text": texts})
-
-    name = "111.parquet"
-    pq_path = tmp_path / name
-    df.to_parquet(pq_path)
-    stream = File(name=name, parent=tmp_path.as_posix(), source="file:///")
-    stream._set_stream(catalog, caching_enabled=False)
-
-    func = ArrowGenerator(nrows=2)
-    objs = list(func.process(stream))
-
-    assert len(objs) == 2
-
-
 def test_arrow_generator_no_source(tmp_path, catalog):
     ids = [12345, 67890, 34, 0xF0123]
     texts = ["28", "22", "we", "hello world"]
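Two details above are easy to miss. First, `TextFile.open` gains a `mode`
parameter because `_nrows_file` calls `file.open(mode="r")`, which the previous
zero-argument signature would have rejected. Second, `_nrows_file` counts raw
lines, not records: when `from_csv` is left to infer the schema from a header
row (no explicit `output`), the header occupies the first line, so the
`elif nrows: nrows += 1` branch keeps the header plus `nrows` data rows. A small
worked illustration of that off-by-one; the sample rows are made up:

    csv_lines = ["name,age", "alice,34", "bob,29", "carol,51"]
    nrows = 2
    kept = csv_lines[: nrows + 1]  # header line + 2 data rows survive truncation
    assert kept == ["name,age", "alice,34", "bob,29"]
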
diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py
index 2db6f1f57..8b845740e 100644
--- a/tests/unit/lib/test_datachain.py
+++ b/tests/unit/lib/test_datachain.py
@@ -860,6 +860,24 @@ def test_parse_tabular_output_list(tmp_dir, catalog):
     assert df1.equals(df)
 
 
+def test_parse_tabular_nrows(tmp_dir, catalog):
+    df = pd.DataFrame(DF_DATA)
+    path = tmp_dir / "test.parquet"
+    df.to_json(path, orient="records", lines=True)
+    dc = DataChain.from_storage(path.as_uri()).parse_tabular(nrows=2, format="json")
+    df1 = dc.select("first_name", "age", "city").to_pandas()
+
+    assert df1.equals(df[:2])
+
+
+def test_parse_tabular_nrows_invalid(tmp_dir, catalog):
+    df = pd.DataFrame(DF_DATA)
+    path = tmp_dir / "test.parquet"
+    df.to_parquet(path)
+    with pytest.raises(DataChainParamsError):
+        DataChain.from_storage(path.as_uri()).parse_tabular(nrows=2)
+
+
 def test_from_csv(tmp_dir, catalog):
     df = pd.DataFrame(DF_DATA)
     path = tmp_dir / "test.csv"
@@ -936,6 +954,15 @@ def test_from_csv_null_collect(tmp_dir, catalog):
         assert row[1].height == height[i]
 
 
+def test_from_csv_nrows(tmp_dir, catalog):
+    df = pd.DataFrame(DF_DATA)
+    path = tmp_dir / "test.csv"
+    df.to_csv(path, index=False)
+    dc = DataChain.from_csv(path.as_uri(), nrows=2)
+    df1 = dc.select("first_name", "age", "city").to_pandas()
+    assert df1.equals(df[:2])
+
+
 def test_from_parquet(tmp_dir, catalog):
     df = pd.DataFrame(DF_DATA)
     path = tmp_dir / "test.parquet"
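Taken together, the tests pin down the user-facing contract: `nrows` caps how
many rows are parsed, and only csv and json formats accept it. A usage sketch
along the lines of the tests above, assuming a local file and default session
setup; the file name and data are illustrative:

    from pathlib import Path

    import pandas as pd

    from datachain.lib.dc import DataChain

    path = Path("people.csv").resolve()
    pd.DataFrame({"first_name": ["ann", "bob", "cat"], "age": [34, 29, 51]}).to_csv(
        path, index=False
    )

    # Only the first two data rows are parsed; the header still drives the schema.
    dc = DataChain.from_csv(path.as_uri(), nrows=2)
    dc.show()

    # Any non-csv/json format now raises DatasetPrepareError, e.g.:
    # DataChain.from_storage(path.as_uri()).parse_tabular(nrows=2, format="parquet")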