1 change: 1 addition & 0 deletions arbalister/__init__.py
@@ -15,6 +15,7 @@
import jupyterlab.labapp

from . import arrow as arrow
from . import file_format as file_format
from . import params as params
from . import routes as routes

88 changes: 88 additions & 0 deletions arbalister/adbc.py
@@ -0,0 +1,88 @@
import dataclasses
import pathlib
from typing import Any, Literal, Self

import adbc_driver_sqlite.dbapi as adbc_sqlite
import pyarrow as pa


def write_sqlite(
table: pa.Table,
path: str | pathlib.Path,
table_name: str = "Table",
catalog_name: str | None = None,
mode: Literal["append", "create", "replace", "create_append"] = "create",
) -> None:
"""Write a table as an Sqlite file."""
with adbc_sqlite.connect(str(path)) as connection:
with connection.cursor() as cursor:
cursor.adbc_ingest(table_name=table_name, data=table, mode=mode, catalog_name=catalog_name)
connection.commit()


@dataclasses.dataclass
class SqliteDataFrame:
"""A DataFrame plan on a Sqlite file.

This plan is made to resemble a Apache Datafusion DataFrame despite some performances penalty.
This is because in the Rust Datafusion implementation of as a Sqlite contrib which we would
like to eventually use so this class is only filling in temporarily.
"""

_path: str
_table_name: str
_schema: pa.Schema
_num_rows: int
_limit: int | None = None
_offset: int | None = None
_select: list[str] | None = None

@classmethod
def read_sqlite(cls, context: Any, path: pathlib.Path | str, table_name: str | None = None) -> Self:
"""Read an Sqlite file metadata and start a new DataFrame plan."""
with adbc_sqlite.connect(str(path)) as connection:
with connection.cursor() as cursor:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
tables = [row for (row,) in cursor.fetchall()]

                if table_name is None and tables:
                    table_name = tables[0]
                if table_name is None or table_name not in tables:
                    raise ValueError(f"Invalid table name {table_name}")

cursor.execute(f'SELECT COUNT(*) FROM "{table_name}"')
num_rows = cursor.fetchone()[0] # type: ignore[index]

schema = connection.adbc_get_table_schema(table_name)

return cls(_path=str(path), _table_name=table_name, _schema=schema, _num_rows=num_rows)

def schema(self) -> pa.Schema:
"""Return the :py:class:`pyarrow.Schema` of this DataFrame."""
return self._schema

def limit(self, count: int, offset: int = 0) -> Self:
"""Return a new DataFrame with a limited number of rows."""
return dataclasses.replace(self, _limit=count, _offset=offset)

def select(self, *columns: str) -> Self:
"""Project arbitrary expressions into a new `DataFrame`."""
return dataclasses.replace(self, _select=list(columns))

def count(self) -> int:
"""Return the total number of rows in this DataFrame."""
return self._num_rows

def to_arrow_table(self) -> pa.Table:
"""Execute the DataFrame and convert it into an Arrow Table."""
limit = (
f"LIMIT {self._limit} OFFSET {self._offset}"
if self._limit is not None and self._offset is not None
else ""
)
columns = self._select if self._select is not None else ["*"]

with adbc_sqlite.connect(self._path) as connection:
with connection.cursor() as cursor:
cursor.execute(f'SELECT {",".join(columns)} FROM "{self._table_name}" {limit}')
return cursor.fetch_arrow_table()
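Reviewer note (not part of the diff): a minimal usage sketch of the new module; the events.db path and the events table name are invented for illustration.

import pyarrow as pa

from arbalister import adbc

# Write a small table, replacing any existing table of the same name.
table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
adbc.write_sqlite(table, "events.db", table_name="events", mode="replace")

# Start a plan, slice it lazily, then execute. The context argument is
# unused by this reader, so None is fine here.
df = adbc.SqliteDataFrame.read_sqlite(None, "events.db", table_name="events")
assert df.count() == 3
subset = df.select("id", "name").limit(2, offset=1).to_arrow_table()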
68 changes: 25 additions & 43 deletions arbalister/arrow.py
@@ -1,39 +1,10 @@
import enum
import pathlib
from typing import Any, Callable, Self
from typing import Any, Callable

import datafusion as dn
import pyarrow as pa


class FileFormat(enum.StrEnum):
"""Known file format that we can read into an Arrow format.

Todo:
- ADBC (Sqlite/Postgres)

"""

Avro = "avro"
Csv = "csv"
Ipc = "ipc"
Orc = "orc"
Parquet = "parquet"

@classmethod
def from_filename(cls, file: pathlib.Path | str) -> Self:
"""Get the file format from a filename extension."""
file_type = pathlib.Path(file).suffix.removeprefix(".").strip().lower()

# Match again their default value
if ft := next((ft for ft in FileFormat if str(ft) == file_type), None):
return ft
# Match other known values
match file_type:
case "ipc" | "feather":
return cls.Ipc
raise ValueError(f"Unknown file type {file_type}")

from . import file_format as ff

ReadCallable = Callable[..., dn.DataFrame]

@@ -84,21 +55,21 @@ def _write_avro(
writer.close()


def get_table_reader(format: FileFormat) -> ReadCallable:
def get_table_reader(format: ff.FileFormat) -> ReadCallable:
"""Get the datafusion reader factory function for the given format."""
# TODO: datafusion >= 50.0
    # def read(ctx: dn.SessionContext, path: str | pathlib.Path, *args, **kwargs) -> dn.DataFrame:
    #     ds = pyarrow.dataset.dataset(source=path, format=format.value)
# return ctx.read_table(ds, *args, **kwargs)
out: ReadCallable
match format:
case FileFormat.Avro:
case ff.FileFormat.Avro:
out = dn.SessionContext.read_avro
case FileFormat.Csv:
case ff.FileFormat.Csv:
out = dn.SessionContext.read_csv
case FileFormat.Parquet:
case ff.FileFormat.Parquet:
out = dn.SessionContext.read_parquet
case FileFormat.Ipc:
case ff.FileFormat.Ipc:
import pyarrow.feather

def read_ipc(
@@ -109,7 +80,7 @@ def read_ipc(
return ctx.from_arrow(table)

out = read_ipc
case FileFormat.Orc:
case ff.FileFormat.Orc:
# Watch for https://github.com/datafusion-contrib/datafusion-orc
# Evolution for native datafusion reader
import pyarrow.orc
@@ -121,33 +92,44 @@ def read_orc(
return ctx.from_arrow(table)

out = read_orc
case ff.FileFormat.Sqlite:
from . import adbc as adbc

            # FIXME: For now we just pretend SqliteDataFrame is a datafusion DataFrame.
            # Either we integrate it properly into Datafusion, or we define a DataFrame
            # as a typing.Protocol.
out = adbc.SqliteDataFrame.read_sqlite # type: ignore[assignment]

return out


WriteCallable = Callable[..., None]


def get_table_writer(format: FileFormat) -> WriteCallable:
def get_table_writer(format: ff.FileFormat) -> WriteCallable:
"""Get the arrow writer factory function for the given format."""
out: WriteCallable
match format:
case FileFormat.Avro:
case ff.FileFormat.Avro:
out = _write_avro
case FileFormat.Csv:
case ff.FileFormat.Csv:
import pyarrow.csv

out = pyarrow.csv.write_csv
case FileFormat.Parquet:
case ff.FileFormat.Parquet:
import pyarrow.parquet

out = pyarrow.parquet.write_table
case FileFormat.Ipc:
case ff.FileFormat.Ipc:
import pyarrow.feather

out = pyarrow.feather.write_feather
case FileFormat.Orc:
case ff.FileFormat.Orc:
import pyarrow.orc

out = pyarrow.orc.write_table
case ff.FileFormat.Sqlite:
from . import adbc as adbc

out = adbc.write_sqlite
return out
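Reviewer note (not part of the diff): a sketch of how the two dispatchers compose; data.parquet is a made-up path.

import datafusion as dn
import pyarrow as pa

from arbalister import arrow as abw
from arbalister import file_format as ff

fmt = ff.FileFormat.from_filename("data.parquet")  # FileFormat.Parquet
write_table = abw.get_table_writer(fmt)            # pyarrow.parquet.write_table
write_table(pa.table({"x": [1, 2]}), "data.parquet")

# Readers are unbound SessionContext methods (or look-alikes such as
# SqliteDataFrame.read_sqlite), hence the explicit context argument.
read_table = abw.get_table_reader(fmt)             # dn.SessionContext.read_parquet
df = read_table(dn.SessionContext(), "data.parquet")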
35 changes: 35 additions & 0 deletions arbalister/file_format.py
@@ -0,0 +1,35 @@
import enum
import pathlib
from typing import Self


class FileFormat(enum.StrEnum):
"""Known file format that we can read into an Arrow format.

Todo:
- ADBC (Sqlite/Postgres)

"""

Avro = "avro"
Csv = "csv"
Ipc = "ipc"
Orc = "orc"
Parquet = "parquet"
Sqlite = "sqlite"

@classmethod
def from_filename(cls, file: pathlib.Path | str) -> Self:
"""Get the file format from a filename extension."""
file_type = pathlib.Path(file).suffix.removeprefix(".").strip().lower()

        # Match against their default value
if ft := next((ft for ft in FileFormat if str(ft) == file_type), None):
return ft
# Match other known values
match file_type:
case "ipc" | "feather":
return cls.Ipc
case "sqlite3" | "db" | ".db3", ".s3db", ".sl3":
return cls.Sqlite
raise ValueError(f"Unknown file type {file_type}")
29 changes: 27 additions & 2 deletions arbalister/routes.py
@@ -11,9 +11,25 @@
from jupyter_server.utils import url_path_join

from . import arrow as abw
from . import file_format as ff
from . import params as params


@dataclasses.dataclass(frozen=True, slots=True)
class SqliteReadParams:
"""Query parameter for the Sqlite reader."""

table_name: str | None = None


@dataclasses.dataclass(frozen=True, slots=True)
class NoReadParams:
"""Query parameter for readers with no parameters."""


FileReadParams = SqliteReadParams | NoReadParams


class BaseRouteHandler(jupyter_server.base.handlers.APIHandler):
"""A base handler to share common methods."""

@@ -33,13 +49,22 @@ def dataframe(self, path: str) -> dn.DataFrame:
        Note: For some file types, the file is read eagerly when calling this method.
"""
file = self.data_file(path)
read_table = abw.get_table_reader(format=abw.FileFormat.from_filename(file))
return read_table(self.context, file)
file_format = ff.FileFormat.from_filename(file)
file_params = self.get_file_read_params(file_format)
read_table = abw.get_table_reader(format=file_format)
return read_table(self.context, file, **dataclasses.asdict(file_params))

def get_query_params_as[T](self, dataclass_type: type[T]) -> T:
"""Extract query parameters into a dataclass type."""
return params.build_dataclass(dataclass_type, self.get_query_argument)

def get_file_read_params(self, file_format: ff.FileFormat) -> FileReadParams:
"""Read the parameters associated with the relevant file format."""
match file_format:
case ff.FileFormat.Sqlite:
return self.get_query_params_as(SqliteReadParams)
return NoReadParams()


@dataclasses.dataclass(frozen=True, slots=True)
class IpcParams:
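Reviewer note (not part of the diff): end to end, the new parameter plumbing works roughly as follows; the request path and table name are invented, and the dataclass names assume the routes module scope.

# GET /arbalister/.../chinook.sqlite?table_name=Invoice
#
# dataframe() then does, in effect:
#   file_format = ff.FileFormat.from_filename(file)       # FileFormat.Sqlite
#   file_params = self.get_file_read_params(file_format)  # SqliteReadParams(table_name="Invoice")
#   read_table(self.context, file, **dataclasses.asdict(file_params))
#
# Non-SQLite formats take the NoReadParams branch, whose asdict() is {}:
import dataclasses

assert dataclasses.asdict(SqliteReadParams(table_name="Invoice")) == {"table_name": "Invoice"}
assert dataclasses.asdict(NoReadParams()) == {}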