From 0eea7431b27a7ec6835da0fa166b3708e1d932a2 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Thu, 15 Jun 2023 11:07:01 +0200 Subject: [PATCH] Python: Give the Dataset protocol a try --- python/pyiceberg/io/pyarrow.py | 182 +++++++++++++++++++-- python/pyiceberg/utils/pyarrow_dataset.py | 184 ++++++++++++++++++++++ 2 files changed, 350 insertions(+), 16 deletions(-) create mode 100644 python/pyiceberg/utils/pyarrow_dataset.py diff --git a/python/pyiceberg/io/pyarrow.py b/python/pyiceberg/io/pyarrow.py index 7454c6f744bd..649ed2993ecd 100644 --- a/python/pyiceberg/io/pyarrow.py +++ b/python/pyiceberg/io/pyarrow.py @@ -28,6 +28,7 @@ import os from abc import ABC, abstractmethod from functools import lru_cache, singledispatch +from itertools import chain from multiprocessing.pool import ThreadPool from multiprocessing.sharedctypes import Synchronized from typing import ( @@ -36,6 +37,7 @@ Callable, Generic, Iterable, + Iterator, List, Optional, Set, @@ -49,6 +51,7 @@ import pyarrow as pa import pyarrow.compute as pc import pyarrow.dataset as ds +from pyarrow import RecordBatchReader from pyarrow.fs import ( FileInfo, FileSystem, @@ -118,6 +121,7 @@ TimeType, UUIDType, ) +from pyiceberg.utils.pyarrow_dataset import Dataset, Fragment, Scanner from pyiceberg.utils.singleton import Singleton if TYPE_CHECKING: @@ -741,6 +745,23 @@ def _file_to_table( return None +def _get_fs_from_table(table: Table) -> FileSystem: + scheme, _ = PyArrowFileIO.parse_location(table.location()) + if isinstance(table.io, PyArrowFileIO): + return table.io.get_fs(scheme) + else: + try: + from pyiceberg.io.fsspec import FsspecFileIO + + if isinstance(table.io, FsspecFileIO): + return PyFileSystem(FSSpecHandler(table.io.get_fs(scheme))) + else: + raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}") + except ModuleNotFoundError as e: + # When FsSpec is not installed + raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}") from e + + def project_table( tasks: Iterable[FileScanTask], table: Table, @@ -761,21 +782,6 @@ def project_table( Raises: ResolveError: When an incompatible query is done. 
""" - scheme, _ = PyArrowFileIO.parse_location(table.location()) - if isinstance(table.io, PyArrowFileIO): - fs = table.io.get_fs(scheme) - else: - try: - from pyiceberg.io.fsspec import FsspecFileIO - - if isinstance(table.io, FsspecFileIO): - fs = PyFileSystem(FSSpecHandler(table.io.get_fs(scheme))) - else: - raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}") - except ModuleNotFoundError as e: - # When FsSpec is not installed - raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}") from e - bound_row_filter = bind(table.schema(), row_filter, case_sensitive=case_sensitive) projected_field_ids = { @@ -793,7 +799,8 @@ def project_table( (fs, task, bound_row_filter, projected_schema, projected_field_ids, case_sensitive, rows_counter, limit) for task in tasks ], - chunksize=None, # we could use this to control how to materialize the generator of tasks (we should also make the expression above lazy) + chunksize=None, + # we could use this to control how to materialize the generator of tasks (we should also make the expression above lazy) ) if table is not None ] @@ -915,3 +922,146 @@ def map_key_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array] def map_value_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]: return partner_map.items if isinstance(partner_map, pa.MapArray) else None + + +def _task_to_scanner(task: FileScanTask, fs: FileSystem, filter: BooleanExpression) -> ds.Scanner: + """Converts a task into an actual scanner. + + Does all the heavy lifting, such as fetching the physical schema, positional deletes, etc + + Args: + task: The task that points to a single Parquet file + fs: The filesystem for the IO + filter: Optional user provided filters + + Returns: + A scanner + """ + _, path = PyArrowFileIO.parse_location(task.file.file_path) + + arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) + # with fs.open_input_file(path) as fin: + fin = fs.open_input_file(path) + fragment = arrow_format.make_fragment(fin) + physical_schema = fragment.physical_schema + schema_raw = None + if metadata := physical_schema.metadata: + schema_raw = metadata.get(ICEBERG_SCHEMA) + file_schema = Schema.parse_raw(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema) + + pyarrow_filter = None + if filter is not AlwaysTrue(): + translated_row_filter = translate_column_names(filter, file_schema, case_sensitive=True) + bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=True) + pyarrow_filter = expression_to_pyarrow(bound_file_filter) + + # Leave this out for now, this will only fetch the relevant columns + # file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) + + # Building up the fragment would also handle positional deletes + # see https://github.com/apache/iceberg/pull/6775 + # filter = filter & ~isin("__row_index", pa.Array([1,2,3,4], type=pa.int64) + # As suggested by https://github.com/apache/arrow/issues/35301#issuecomment-1542536407 + + if file_schema is None: + raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}") + + return ds.Scanner.from_fragment( + fragment=fragment, + schema=physical_schema, + filter=pyarrow_filter, + # columns=[col.name for col in file_project_schema.columns], + ) + + +class IcebergFragment(Fragment): + """A fragment of the dataset. 
+
+    This does the heavy lifting of fetching the file and, eventually,
+    applying positional deletes.
+    """
+
+    _task: FileScanTask
+    _fs: FileSystem
+
+    def __init__(self, task: FileScanTask, fs: FileSystem):
+        self._task = task
+        self._fs = fs
+
+    def scanner(
+        self,
+        columns: Optional[List[str]] = None,
+        filter: Optional[pc.Expression] = None,
+        batch_size: Optional[int] = None,
+        use_threads: bool = True,
+        **kwargs,
+    ) -> ds.Scanner:
+        # TODO: translate the Arrow `filter` into an Iceberg expression; it is ignored for now
+        return _task_to_scanner(self._task, self._fs, filter=AlwaysTrue())
+
+
+class IcebergScanner(Scanner):
+    _schema: Schema
+    _columns: List[str]
+    _scanners: List[ds.Scanner]
+    _filter: pc.Expression
+
+    def __init__(
+        self, schema: Schema, columns: Optional[List[str]], filter: pc.Expression, scanners: List[ds.Scanner]
+    ) -> None:
+        self._schema = schema
+        self._columns = columns
+        self._filter = filter
+        self._scanners = scanners
+
+    def count_rows(self) -> int:
+        return sum(scanner.count_rows() for scanner in self._scanners)
+
+    def head(self, num_rows: int) -> pa.Table:
+        # Combine the tables because a single fragment can hold fewer than num_rows rows
+        # TODO: stop early once num_rows rows have been collected
+        table = pa.concat_tables([scanner.head(num_rows) for scanner in self._scanners])
+        return table.slice(0, num_rows)
+
+    def to_reader(self) -> RecordBatchReader:
+        # Chain the batches of the per-file scanners into a single stream
+        batches = chain.from_iterable(scanner.to_batches() for scanner in self._scanners)
+        return RecordBatchReader.from_batches(batches=batches, schema=schema_to_pyarrow(self._schema))
+
+
+class IcebergDataset(Dataset):
+    _table: Table
+    _expression: BooleanExpression
+
+    def __init__(self, table: Table, expression: BooleanExpression) -> None:
+        self._table = table
+        # The filter is passed in as an Iceberg expression until we have Arrow -> Iceberg expression conversion
+        self._expression = expression
+
+    def get_fragments(self, filter: Optional[pc.Expression] = None, **kwargs) -> Iterator[Fragment]:
+        # The Arrow-level filter is ignored for now; fragments are planned with the Iceberg expression
+        fs = _get_fs_from_table(self._table)
+        return (IcebergFragment(task, fs) for task in self._table.scan(row_filter=self._expression).plan_files())
+
+    @property
+    def schema(self) -> pa.Schema:
+        return schema_to_pyarrow(self._table.schema())
+
+    def scanner(
+        self,
+        columns: Optional[List[str]] = None,
+        filter: Optional[pc.Expression] = None,  # Ignored until we have Arrow -> Iceberg expression conversion
+        batch_size: Optional[int] = None,
+        use_threads: bool = True,
+        **kwargs,
+    ) -> Scanner:
+        fs = _get_fs_from_table(self._table)
+        tasks = self._table.scan(row_filter=self._expression).plan_files()
+        with ThreadPool() as pool:
+            scanners = pool.starmap(
+                func=_task_to_scanner,
+                iterable=[(task, fs, self._expression) for task in tasks],
+            )
+
+        bound_filter = bind(self._table.schema(), self._expression, case_sensitive=True)
+        return IcebergScanner(
+            columns=columns, filter=expression_to_pyarrow(bound_filter), scanners=scanners, schema=self._table.schema()
+        )
diff --git a/python/pyiceberg/utils/pyarrow_dataset.py b/python/pyiceberg/utils/pyarrow_dataset.py
new file mode 100644
index 000000000000..4548741304cd
--- /dev/null
+++ b/python/pyiceberg/utils/pyarrow_dataset.py
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Protocol definitions for pyarrow.dataset + +These provide the abstract interface for a dataset. Other libraries may implement +this interface to expose their data, without having to extend PyArrow's classes. + +Applications and libraries that want to consume datasets should accept datasets +that implement these protocols, rather than requiring the specific +PyArrow classes. + +See Extending PyArrow Datasets for more information: + +https://arrow.apache.org/docs/python/integration/dataset.html +""" +from abc import abstractmethod +from typing import ( + Iterator, + List, + Optional, + Protocol, + runtime_checkable, +) + +from pyarrow import RecordBatchReader, Schema, Table +from pyarrow.dataset import Expression + + +@runtime_checkable +class Scanner(Protocol): + """ + A scanner implementation for a dataset. + + This may be a scan of a whole dataset, or a scan of a single fragment. + """ + + @abstractmethod + def count_rows(self) -> int: + """ + Count the number of rows in this dataset. + + Implementors may provide optimized code paths that compute this from metadata. + + Returns + ------- + int + The number of rows in the dataset. + """ + ... + + @abstractmethod + def head(self, num_rows: int) -> Table: + """ + Get the first ``num_rows`` rows of the dataset. + + Parameters + ---------- + num_rows : int + The number of rows to return. + + Returns + ------- + Table + A table containing the first ``num_rows`` rows of the dataset. + """ + ... + + @abstractmethod + def to_reader(self) -> RecordBatchReader: + """ + Create a Record Batch Reader for this scan. + + This is used to read the data in chunks. + + Returns + ------- + RecordBatchReader + """ + ... + + +@runtime_checkable +class Scannable(Protocol): + @abstractmethod + def scanner( + self, + columns: Optional[List[str]] = None, + filter: Optional[Expression] = None, + batch_size: Optional[int] = None, + use_threads: bool = True, + **kwargs, + ) -> Scanner: + """Create a scanner for this dataset. + + Parameters + ---------- + columns : List[str], optional + Names of columns to include in the scan. If None, all columns are + included. + filter : Expression, optional + Filter expression to apply to the scan. If None, no filter is applied. + batch_size : int, optional + The number of rows to include in each batch. If None, the default + value is used. The default value is implementation specific. + use_threads : bool, default True + Whether to use multiple threads to read the rows. It is expected + that consumers reading a whole dataset in one scanner will keep this + as True, while consumers reading a single fragment per worker will + typically set this to False. + + Notes + ----- + The filters must be fully satisfied. If the dataset cannot satisfy the + filter, it should raise an error. + + Only the following expressions are allowed in the filter: + - Equality / inequalities (==, !=, <, >, <=, >=) + - Conjunctions (and, or) + - Field references (e.g. "a" or "a.b.c") + - Literals (e.g. 1, 1.0, "a", True) + - cast + - is_null / not_null + - isin + - between + - negation (not) + + """ + ... 
+
+
+@runtime_checkable
+class Fragment(Scannable, Protocol):
+    """A fragment of a dataset.
+
+    This might be a partition, a file, a file chunk, etc.
+
+    This class should be pickleable so that it can be used in a distributed scan.
+    """
+
+    ...
+
+
+@runtime_checkable
+class Dataset(Scannable, Protocol):
+    @abstractmethod
+    def get_fragments(self, filter: Optional[Expression] = None, **kwargs) -> Iterator[Fragment]:
+        """Get the fragments of this dataset.
+
+        Parameters
+        ----------
+        filter : Expression, optional
+            Filter expression to use to prune which fragments are selected.
+            See Scannable.scanner for details on allowed filters. The filter is
+            only used to prune the selected fragments; the implementation does
+            not need to apply it to the data in the returned fragments, as that
+            is handled by the scanner.
+        **kwargs : dict
+            Additional arguments to pass to the underlying implementation.
+        """
+        ...
+
+    @property
+    @abstractmethod
+    def schema(self) -> Schema:
+        """
+        Get the schema of this dataset.
+
+        Returns
+        -------
+        Schema
+        """
+        ...
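
Illustrative note: because the protocols above are structural and runtime-checkable, a source can participate without subclassing anything from PyArrow. A minimal in-memory Scanner could look like the sketch below; the class name and sample data are made up for illustration and are not part of this patch.

    import pyarrow as pa
    from pyarrow import RecordBatchReader

    from pyiceberg.utils.pyarrow_dataset import Scanner


    class InMemoryScanner:
        """Scan over a pyarrow.Table that is already materialized in memory."""

        def __init__(self, table: pa.Table) -> None:
            self._table = table

        def count_rows(self) -> int:
            # The row count is known without touching the data
            return self._table.num_rows

        def head(self, num_rows: int) -> pa.Table:
            return self._table.slice(0, num_rows)

        def to_reader(self) -> RecordBatchReader:
            # Stream the table back out as record batches
            return RecordBatchReader.from_batches(self._table.schema, self._table.to_batches())


    # Structural typing: no inheritance is needed for the protocol check to pass
    assert isinstance(InMemoryScanner(pa.table({"a": [1, 2, 3]})), Scanner)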
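
A possible consumption sketch for the IcebergDataset added to pyiceberg/io/pyarrow.py, assuming a configured catalog; the catalog name, table identifier, and filter column below are hypothetical.

    from pyiceberg.catalog import load_catalog
    from pyiceberg.expressions import EqualTo
    from pyiceberg.io.pyarrow import IcebergDataset

    catalog = load_catalog("default")                  # hypothetical catalog configuration
    table = catalog.load_table("examples.taxi_trips")  # hypothetical table identifier

    # The row filter is an Iceberg expression for now; the Arrow-level `filter`
    # argument of scanner() is ignored until expression conversion is in place.
    dataset = IcebergDataset(table, EqualTo("vendor_id", 1))

    scanner = dataset.scanner()
    print(scanner.count_rows())        # sums the per-file row counts
    for batch in scanner.to_reader():  # iterate pyarrow.RecordBatch objects
        ...                            # process each batch

Since to_reader() returns a plain pyarrow.RecordBatchReader, the stream can also be handed to engines that accept Arrow streams (for example DuckDB) without materializing the whole table.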