diff --git a/src/datasets/packaged_modules/parquet/parquet.py b/src/datasets/packaged_modules/parquet/parquet.py
index f6ec2e06cc0..858f045807a 100644
--- a/src/datasets/packaged_modules/parquet/parquet.py
+++ b/src/datasets/packaged_modules/parquet/parquet.py
@@ -1,8 +1,9 @@
 import itertools
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import pyarrow as pa
+import pyarrow.dataset as ds
 import pyarrow.parquet as pq
 
 import datasets
@@ -19,6 +20,7 @@ class ParquetConfig(datasets.BuilderConfig):
     batch_size: Optional[int] = None
     columns: Optional[List[str]] = None
     features: Optional[datasets.Features] = None
+    filters: Optional[Union[ds.Expression, List[tuple], List[List[tuple]]]] = None
 
     def __post_init__(self):
         super().__post_init__()
@@ -77,14 +79,25 @@ def _generate_tables(self, files):
                 raise ValueError(
                     f"Tried to load parquet data with columns '{self.config.columns}' with mismatching features '{self.info.features}'"
                 )
+        filter_expr = (
+            pq.filters_to_expression(self.config.filters)
+            if isinstance(self.config.filters, list)
+            else self.config.filters
+        )
         for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
             with open(file, "rb") as f:
-                parquet_file = pq.ParquetFile(f)
-                if parquet_file.metadata.num_row_groups > 0:
-                    batch_size = self.config.batch_size or parquet_file.metadata.row_group(0).num_rows
+                parquet_fragment = ds.ParquetFileFormat().make_fragment(f)
+                if parquet_fragment.row_groups:
+                    batch_size = self.config.batch_size or parquet_fragment.row_groups[0].num_rows
                     try:
                         for batch_idx, record_batch in enumerate(
-                            parquet_file.iter_batches(batch_size=batch_size, columns=self.config.columns)
+                            parquet_fragment.to_batches(
+                                batch_size=batch_size,
+                                columns=self.config.columns,
+                                filter=filter_expr,
+                                batch_readahead=0,
+                                fragment_readahead=0,
+                            )
                         ):
                             pa_table = pa.Table.from_batches([record_batch])
                             # Uncomment for debugging (will print the Arrow table size and elements)
diff --git a/tests/io/test_parquet.py b/tests/io/test_parquet.py
index 5466e633f04..cdc55c9e18e 100644
--- a/tests/io/test_parquet.py
+++ b/tests/io/test_parquet.py
@@ -89,6 +89,16 @@ def test_parquet_read_geoparquet(geoparquet_path, tmp_path):
         assert dataset.features[feature].dtype == expected_dtype
 
 
+def test_parquet_read_filters(parquet_path, tmp_path):
+    cache_dir = tmp_path / "cache"
+    filters = [("col_2", "==", 1)]
+    dataset = ParquetDatasetReader(path_or_paths=parquet_path, cache_dir=cache_dir, filters=filters).read()
+
+    assert isinstance(dataset, Dataset)
+    assert all(example["col_2"] == 1 for example in dataset)
+    assert dataset.num_rows == 1
+
+
 def _check_parquet_datasetdict(dataset_dict, expected_features, splits=("train",)):
     assert isinstance(dataset_dict, (DatasetDict, IterableDatasetDict))
     for split in splits:
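
For reference, a minimal sketch of how the new `filters` config option could be exercised from the public API, assuming the kwarg is forwarded to `ParquetConfig` through `load_dataset` the same way `columns` and `batch_size` are; the file path and column name below are made up for illustration:

```python
import pyarrow.dataset as ds
from datasets import load_dataset

# DNF tuple syntax, converted internally via pyarrow.parquet.filters_to_expression:
# keep only the rows where col_2 == 1.
dataset = load_dataset(
    "parquet",
    data_files="data/train.parquet",  # hypothetical local file
    filters=[("col_2", "==", 1)],
    split="train",
)

# Equivalent pyarrow.dataset.Expression form, passed to Fragment.to_batches() as-is.
dataset = load_dataset(
    "parquet",
    data_files="data/train.parquet",  # hypothetical local file
    filters=ds.field("col_2") == 1,
    split="train",
)
```

Since the filter is applied per record batch inside `_generate_tables`, rows that don't match never reach the Arrow writer or the cache.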