python/pyiceberg/io/pyarrow.py (182 changes: 166 additions & 16 deletions)
@@ -28,6 +28,7 @@
import os
from abc import ABC, abstractmethod
from functools import lru_cache, singledispatch
from itertools import chain
from multiprocessing.pool import ThreadPool
from multiprocessing.sharedctypes import Synchronized
from typing import (
@@ -36,6 +37,7 @@
Callable,
Generic,
Iterable,
Iterator,
List,
Optional,
Set,
@@ -49,6 +51,7 @@
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
from pyarrow import RecordBatchReader
from pyarrow.fs import (
FileInfo,
FileSystem,
@@ -118,6 +121,7 @@
TimeType,
UUIDType,
)
from pyiceberg.utils.pyarrow_dataset import Dataset, Fragment, Scanner
from pyiceberg.utils.singleton import Singleton

if TYPE_CHECKING:
@@ -741,6 +745,23 @@ def _file_to_table(
return None


def _get_fs_from_table(table: Table) -> FileSystem:
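    """Resolve a PyArrow FileSystem from the table's FileIO.

    A PyArrowFileIO is used directly; an FsspecFileIO is wrapped through FSSpecHandler.
    """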
scheme, _ = PyArrowFileIO.parse_location(table.location())
if isinstance(table.io, PyArrowFileIO):
return table.io.get_fs(scheme)
else:
try:
from pyiceberg.io.fsspec import FsspecFileIO

if isinstance(table.io, FsspecFileIO):
return PyFileSystem(FSSpecHandler(table.io.get_fs(scheme)))
else:
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}")
except ModuleNotFoundError as e:
# When FsSpec is not installed
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}") from e


def project_table(
tasks: Iterable[FileScanTask],
table: Table,
@@ -761,21 +782,6 @@ def project_table(
Raises:
ResolveError: When an incompatible query is done.
"""
scheme, _ = PyArrowFileIO.parse_location(table.location())
if isinstance(table.io, PyArrowFileIO):
fs = table.io.get_fs(scheme)
else:
try:
from pyiceberg.io.fsspec import FsspecFileIO

if isinstance(table.io, FsspecFileIO):
fs = PyFileSystem(FSSpecHandler(table.io.get_fs(scheme)))
else:
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}")
except ModuleNotFoundError as e:
# When FsSpec is not installed
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}") from e

bound_row_filter = bind(table.schema(), row_filter, case_sensitive=case_sensitive)

projected_field_ids = {
@@ -793,7 +799,8 @@
(fs, task, bound_row_filter, projected_schema, projected_field_ids, case_sensitive, rows_counter, limit)
for task in tasks
],
chunksize=None, # we could use this to control how to materialize the generator of tasks (we should also make the expression above lazy)
chunksize=None,
# we could use this to control how to materialize the generator of tasks (we should also make the expression above lazy)
)
if table is not None
]
@@ -915,3 +922,146 @@ def map_key_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]

def map_value_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]:
return partner_map.items if isinstance(partner_map, pa.MapArray) else None


def _task_to_scanner(task: FileScanTask, fs: FileSystem, filter: BooleanExpression) -> ds.Scanner:
"""Converts a task into an actual scanner.

    Does the heavy lifting, such as fetching the physical schema, applying positional deletes, etc.

Args:
task: The task that points to a single Parquet file
fs: The filesystem for the IO
        filter: Optional user-provided filter

Returns:
A scanner
"""
_, path = PyArrowFileIO.parse_location(task.file.file_path)

arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
    # Note: we cannot use `with fs.open_input_file(path) as fin` here, because the
    # fragment reads the file lazily, so the handle has to stay open
    fin = fs.open_input_file(path)
fragment = arrow_format.make_fragment(fin)
physical_schema = fragment.physical_schema
schema_raw = None
if metadata := physical_schema.metadata:
schema_raw = metadata.get(ICEBERG_SCHEMA)
    # Prefer the Iceberg schema embedded in the Parquet metadata; otherwise convert the physical schema
    file_schema = Schema.parse_raw(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema)

    if file_schema is None:
        raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")

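    # Translate the column names in the filter to those of the file schema, re-bind,
    # and convert to a PyArrow expression so the predicate can be pushed down to the scan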
pyarrow_filter = None
if filter is not AlwaysTrue():
translated_row_filter = translate_column_names(filter, file_schema, case_sensitive=True)
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=True)
pyarrow_filter = expression_to_pyarrow(bound_file_filter)

# Leave this out for now, this will only fetch the relevant columns
# file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)

# Building up the fragment would also handle positional deletes
# see https://github.com/apache/iceberg/pull/6775
    # filter = filter & ~isin("__row_index", pa.array([1, 2, 3, 4], type=pa.int64()))
# As suggested by https://github.com/apache/arrow/issues/35301#issuecomment-1542536407

return ds.Scanner.from_fragment(
fragment=fragment,
schema=physical_schema,
filter=pyarrow_filter,
# columns=[col.name for col in file_project_schema.columns],
)


class IcebergFragment(Fragment):
"""A fragment of the dataset.

This should do the heavy lifting, such as fetching files, and
applying positional deletes

"""

_task: FileScanTask
_fs: FileSystem

def __init__(self, task: FileScanTask, fs: FileSystem):
self._task = task
self._fs = fs

def scanner(
self,
columns: Optional[List[str]] = None,
filter: Optional[pc.Expression] = None,
batch_size: Optional[int] = None,
use_threads: bool = True,
**kwargs,
) -> ds.Scanner:
        # TODO: `filter` arrives as a PyArrow expression, while _task_to_scanner expects
        # an Iceberg BooleanExpression; the conversion still has to be implemented
        return _task_to_scanner(self._task, self._fs, filter=filter)


class IcebergScanner(Scanner):
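    """A Scanner that fans out over the per-file scanners created for each task."""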
    _schema: Schema  # the Iceberg schema; converted to pyarrow in to_reader()
    _columns: Optional[List[str]]
_scanners: List[ds.Scanner]
_filter: pc.Expression

    def __init__(
        self, schema: Schema, columns: Optional[List[str]], filter: pc.Expression, scanners: List[ds.Scanner]
    ) -> None:
self._schema = schema
self._columns = columns
self._filter = filter
self._scanners = scanners

def count_rows(self) -> int:
return sum(scanner.count_rows() for scanner in self._scanners)

def head(self, num_rows: int) -> pa.Table:
        # Combine all the tables, because a single fragment can have fewer than num_rows rows
        # We could stop early here, but that is still a WIP
table = pa.concat_tables([scanner.head(num_rows) for scanner in self._scanners])
return table.slice(0, num_rows)

def to_reader(self) -> RecordBatchReader:
        # Chain the record batches of all per-file scanners into a single stream
batches = chain(*[scanner.to_batches() for scanner in self._scanners])
return RecordBatchReader.from_batches(batches=batches, schema=schema_to_pyarrow(self._schema))


class IcebergDataset(Dataset):
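    """Expose an Iceberg table as a PyArrow Dataset, with one fragment per data file."""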
_table: Table
_expression: BooleanExpression

def __init__(self, table: Table, expression: BooleanExpression) -> None:
self._table = table
# the filter is temporary, until we have Arrow -> Iceberg expression conversion
self._expression = expression

    def get_fragments(self, filter: Optional[pc.Expression] = None, **kwargs) -> Iterator[Fragment]:
        # The PyArrow `filter` is ignored until we have Arrow -> Iceberg expression
        # conversion; the Iceberg expression given at construction is used instead
        fs = _get_fs_from_table(self._table)
        return (IcebergFragment(task, fs) for task in self._table.scan(row_filter=self._expression).plan_files())

@property
def schema(self) -> pa.Schema:
return schema_to_pyarrow(self._table.schema())

def scanner(
self,
columns: Optional[List[str]] = None,
        filter: Optional[pc.Expression] = None,  # Ignored until we have Arrow -> Iceberg expression conversion
batch_size: Optional[int] = None,
use_threads: bool = True,
**kwargs,
) -> Scanner:
fs = _get_fs_from_table(self._table)
tasks = self._table.scan(row_filter=self._expression).plan_files()
with ThreadPool() as pool:
scanners = pool.starmap(
func=_task_to_scanner,
iterable=[(task, fs, self._expression) for task in tasks],
)

return IcebergScanner(
columns=columns, filter=expression_to_pyarrow(self._expression), scanners=scanners, schema=self._table.schema()
)
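
For reference, a minimal usage sketch of the new dataset interface. This is hypothetical: the catalog name and table identifier are placeholders, and it assumes the diff above is merged so that IcebergDataset is importable from pyiceberg.io.pyarrow:

from pyiceberg.catalog import load_catalog
from pyiceberg.expressions import AlwaysTrue
from pyiceberg.io.pyarrow import IcebergDataset

catalog = load_catalog("default")             # placeholder catalog name
table = catalog.load_table("examples.taxis")  # placeholder table identifier

# AlwaysTrue() scans all rows; a real Iceberg expression would be pushed down per file
dataset = IcebergDataset(table, AlwaysTrue())

reader = dataset.scanner().to_reader()  # a pa.RecordBatchReader streaming every data file
print(reader.read_all().num_rows)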