18 changes: 9 additions & 9 deletions arbalister/arrow.py
@@ -2,7 +2,7 @@
import pathlib
from typing import Any, Callable, Self

import datafusion as dtfn
import datafusion as dn
import pyarrow as pa


@@ -35,7 +35,7 @@ def from_filename(cls, file: pathlib.Path | str) -> Self:
raise ValueError(f"Unknown file type {file_type}")


ReadCallable = Callable[..., dtfn.DataFrame]
ReadCallable = Callable[..., dn.DataFrame]


def _arrow_to_avro_type(field: pa.Field) -> str | dict[str, Any]:
@@ -93,17 +93,17 @@ def get_table_reader(format: FileFormat) -> ReadCallable:
out: ReadCallable
match format:
case FileFormat.Avro:
out = dtfn.SessionContext.read_avro
out = dn.SessionContext.read_avro
case FileFormat.Csv:
out = dtfn.SessionContext.read_csv
out = dn.SessionContext.read_csv
case FileFormat.Parquet:
out = dtfn.SessionContext.read_parquet
out = dn.SessionContext.read_parquet
case FileFormat.Ipc:
import pyarrow.feather

def read_ipc(
ctx: dtfn.SessionContext, path: str | pathlib.Path, **kwargs: dict[str, Any]
) -> dtfn.DataFrame:
ctx: dn.SessionContext, path: str | pathlib.Path, **kwargs: dict[str, Any]
) -> dn.DataFrame:
# table = pyarrow.feather.read_table(path, {**{"memory_map": True}, **kwargs})
table = pyarrow.feather.read_table(path, **kwargs)
return ctx.from_arrow(table)
@@ -115,8 +115,8 @@ def read_ipc(
import pyarrow.orc

def read_orc(
ctx: dtfn.SessionContext, path: str | pathlib.Path, **kwargs: dict[str, Any]
) -> dtfn.DataFrame:
ctx: dn.SessionContext, path: str | pathlib.Path, **kwargs: dict[str, Any]
) -> dn.DataFrame:
table = pyarrow.orc.read_table(path, **kwargs)
return ctx.from_arrow(table)

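For reference, a minimal sketch of how the reader lookup above might be used, assuming the returned callables are invoked as unbound functions taking a SessionContext and a path (as the read_ipc and read_orc wrappers suggest); the input file path is hypothetical.

import datafusion as dn

import arbalister.arrow as arrow

ctx = dn.SessionContext()
path = "data/example.parquet"  # hypothetical input file
fmt = arrow.FileFormat.from_filename(path)  # infer the format from the extension
reader = arrow.get_table_reader(fmt)  # pick the matching DataFusion/pyarrow reader
df = reader(ctx, path)  # lazy for Parquet/CSV/Avro; read eagerly for IPC and ORC
print(df.schema())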
52 changes: 29 additions & 23 deletions arbalister/routes.py
@@ -2,8 +2,8 @@
import os
import pathlib

import datafusion as dtfn
import datafusion.functions as fn
import datafusion as dn
import datafusion.functions as dnf
import jupyter_server.base.handlers
import jupyter_server.serverapp
import pyarrow as pa
@@ -17,24 +17,17 @@
class BaseRouteHandler(jupyter_server.base.handlers.APIHandler):
"""A base handler to share common methods."""

def initialize(self, context: dtfn.SessionContext) -> None:
def initialize(self, context: dn.SessionContext) -> None:
"""Process custom constructor arguments."""
super().initialize()
self.context = context

def make(self) -> dtfn.SessionConfig:
"""Return the datafusion config."""
config = dtfn.SessionConfig()
# String views do not get written properly to IPC
config.set("datafusion.execution.parquet.schema_force_view_types", "false")
return config

def data_file(self, path: str) -> pathlib.Path:
"""Return the file that is requested by the URL path."""
root_dir = pathlib.Path(os.path.expanduser(self.settings["server_root_dir"])).resolve()
return root_dir / path

def dataframe(self, path: str) -> dtfn.DataFrame:
def dataframe(self, path: str) -> dn.DataFrame:
"""Return the DataFusion lazy DataFrame.

Note: For some file types, the file is read eagerly when calling this method.
@@ -52,8 +45,10 @@ def get_query_params_as[T](self, dataclass_type: type[T]) -> T:
class IpcParams:
"""Query parameter for IPC data."""

per_chunk: int | None = None
chunk: int | None = None
row_chunk_size: int | None = None
row_chunk: int | None = None
col_chunk_size: int | None = None
col_chunk: int | None = None


class IpcRouteHandler(BaseRouteHandler):
@@ -66,11 +61,17 @@ async def get(self, path: str) -> None:

self.set_header("Content-Type", "application/vnd.apache.arrow.stream")

df: dtfn.DataFrame = self.dataframe(path)
df: dn.DataFrame = self.dataframe(path)

if params.row_chunk_size is not None and params.row_chunk is not None:
offset: int = params.row_chunk * params.row_chunk_size
df = df.limit(count=params.row_chunk_size, offset=offset)

if params.per_chunk is not None and params.chunk is not None:
offset: int = params.chunk * params.per_chunk
df = df.limit(count=params.per_chunk, offset=offset)
if params.col_chunk_size is not None and params.col_chunk is not None:
col_names = df.schema().names
start: int = params.col_chunk * params.col_chunk_size
end: int = start + params.col_chunk_size
df = df.select(*col_names[start:end])

table: pa.Table = df.to_arrow_table()

@@ -113,18 +114,23 @@ async def get(self, path: str) -> None:
# No dedicated exception type coming from DataFusion
if str(e).startswith("DataFusion"):
first_col: str = schema.names[0]
batches = df.aggregate([], [fn.count(dtfn.col(first_col))]).collect()
batches = df.aggregate([], [dnf.count(dn.col(first_col))]).collect()
num_rows = batches[0].column(0)[0].as_py()

response = StatsResponse(num_cols=len(schema), num_rows=num_rows)
await self.finish(dataclasses.asdict(response))


def make_datafusion_config() -> dtfn.SessionConfig:
def make_datafusion_config() -> dn.SessionConfig:
"""Return the datafusion config."""
config = dtfn.SessionConfig()
# String views do not get written properly to IPC
config.set("datafusion.execution.parquet.schema_force_view_types", "false")
config = (
dn.SessionConfig()
# Must use a single partition; otherwise applying LIMIT across parallel partitions can return arbitrary rows
.with_target_partitions(1)
# String views do not get written properly to IPC
.set("datafusion.execution.parquet.schema_force_view_types", "false")
)

return config


@@ -133,7 +139,7 @@ def setup_route_handlers(web_app: jupyter_server.serverapp.ServerWebApplication)
host_pattern = ".*$"
base_url = web_app.settings["base_url"]

context = dtfn.SessionContext(make_datafusion_config())
context = dn.SessionContext(make_datafusion_config())

handlers = [
(url_path_join(base_url, r"arrow/stream/([^?]*)"), IpcRouteHandler, {"context": context}),
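A sketch of how a client might request one row/column chunk from the arrow/stream endpoint registered above and decode the IPC payload with pyarrow; the server address, authentication, and file path are hypothetical and auth handling is omitted.

import urllib.parse
import urllib.request

import pyarrow as pa

base_url = "http://localhost:8888"  # hypothetical Jupyter server, token omitted
query = urllib.parse.urlencode(
    {"row_chunk": 0, "row_chunk_size": 100, "col_chunk": 0, "col_chunk_size": 5}
)
url = f"{base_url}/arrow/stream/data/example.parquet?{query}"  # hypothetical file path

with urllib.request.urlopen(url) as response:
    # The handler streams Arrow IPC; read it back into a pyarrow Table.
    table = pa.ipc.open_stream(response.read()).read_all()

print(table.num_rows, table.schema.names)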
104 changes: 77 additions & 27 deletions arbalister/tests/test_routes.py
@@ -1,6 +1,9 @@
import dataclasses
import json
import pathlib
from typing import Awaitable, Callable
import random
import string
from typing import Awaitable, Callable, Final

import pyarrow as pa
import pytest
@@ -9,21 +12,29 @@
import arbalister as arb


@pytest.fixture(params=list(arb.arrow.FileFormat))
@pytest.fixture(params=list(arb.arrow.FileFormat), scope="session")
def file_format(request: pytest.FixtureRequest) -> arb.arrow.FileFormat:
"""Parametrize the file format used in the test."""
out: arb.arrow.FileFormat = request.param
return out


@pytest.fixture
DUMMY_TABLE_ROW_COUNT: Final = 10
DUMMY_TABLE_COL_COUNT: Final = 4


@pytest.fixture(scope="module")
def dummy_table() -> pa.Table:
"""Generate a table with fake data."""
data = {
"letter": list("abcdefghij"),
"number": list(range(10)),
"lower": random.choices(string.ascii_lowercase, k=DUMMY_TABLE_ROW_COUNT),
"sequence": list(range(DUMMY_TABLE_ROW_COUNT)),
"upper": random.choices(string.ascii_uppercase, k=DUMMY_TABLE_ROW_COUNT),
"number": [random.random() for _ in range(DUMMY_TABLE_ROW_COUNT)],
}
return pa.table(data)
table = pa.table(data)
assert len(table.schema) == DUMMY_TABLE_COL_COUNT
return table


@pytest.fixture
@@ -40,36 +51,75 @@ def dummy_table_file(
JpFetch = Callable[..., Awaitable[tornado.httpclient.HTTPResponse]]


async def test_ipc_route(jp_fetch: JpFetch, dummy_table: pa.Table, dummy_table_file: pathlib.Path) -> None:
"""Test fetching a file returns the correct data in IPC."""
response = await jp_fetch("arrow/stream/", str(dummy_table_file))

assert response.code == 200
assert response.headers["Content-Type"] == "application/vnd.apache.arrow.stream"

payload = pa.ipc.open_stream(response.body).read_all()
assert dummy_table.num_rows == payload.num_rows
assert dummy_table.cast(payload.schema) == payload


async def test_ipc_route_limit(
jp_fetch: JpFetch, dummy_table: pa.Table, dummy_table_file: pathlib.Path
@pytest.mark.parametrize(
"params",
[
arb.routes.IpcParams(),
# Limit only number of rows
arb.routes.IpcParams(row_chunk=0, row_chunk_size=3),
arb.routes.IpcParams(row_chunk=1, row_chunk_size=2),
arb.routes.IpcParams(row_chunk=0, row_chunk_size=DUMMY_TABLE_ROW_COUNT),
arb.routes.IpcParams(row_chunk=1, row_chunk_size=DUMMY_TABLE_ROW_COUNT // 2 + 1),
# Limit only number of cols
arb.routes.IpcParams(col_chunk=0, col_chunk_size=3),
arb.routes.IpcParams(col_chunk=1, col_chunk_size=2),
arb.routes.IpcParams(col_chunk=0, col_chunk_size=DUMMY_TABLE_COL_COUNT),
arb.routes.IpcParams(col_chunk=1, col_chunk_size=DUMMY_TABLE_COL_COUNT // 2 + 1),
# Limit both
arb.routes.IpcParams(
row_chunk=0,
row_chunk_size=3,
col_chunk=1,
col_chunk_size=DUMMY_TABLE_COL_COUNT // 2 + 1,
),
arb.routes.IpcParams(
row_chunk=0,
row_chunk_size=DUMMY_TABLE_ROW_COUNT,
col_chunk=1,
col_chunk_size=2,
),
# Schema only
arb.routes.IpcParams(
row_chunk=0,
row_chunk_size=0,
),
],
)
async def test_ipc_route_limit_row(
jp_fetch: JpFetch,
dummy_table: pa.Table,
dummy_table_file: pathlib.Path,
params: arb.routes.IpcParams,
) -> None:
"""Test fetching a file returns the limited data in IPC."""
num_rows = 2
chunk = 1
"""Test fetching a file returns the limited rows and columns in IPC."""
response = await jp_fetch(
"arrow/stream",
str(dummy_table_file),
params={"per_chunk": num_rows, "chunk": chunk},
params={k: v for k, v in dataclasses.asdict(params).items() if v is not None},
)

assert response.code == 200
assert response.headers["Content-Type"] == "application/vnd.apache.arrow.stream"

payload = pa.ipc.open_stream(response.body).read_all()
assert payload.num_rows == num_rows
assert dummy_table.slice(chunk * num_rows, num_rows).cast(payload.schema) == payload

expected = dummy_table

# Row slicing
if (size := params.row_chunk_size) is not None and (cidx := params.row_chunk) is not None:
expected_num_rows = min((size * (cidx + 1)), expected.num_rows) - (size * cidx)
assert payload.num_rows == expected_num_rows
expected = expected.slice(cidx * size, size)

# Col slicing
if (size := params.col_chunk_size) is not None and (cidx := params.col_chunk) is not None:
expected_num_cols = min((size * (cidx + 1)), len(expected.schema)) - (size * cidx)
assert len(payload.schema) == expected_num_cols
col_names = expected.schema.names
start = cidx * size
end = start + size
expected = expected.select(col_names[start:end])

assert expected.cast(payload.schema) == payload


async def test_stats_route(jp_fetch: JpFetch, dummy_table: pa.Table, dummy_table_file: pathlib.Path) -> None:
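A worked check of the chunk-boundary arithmetic the row-slicing assertion above relies on, using the DUMMY_TABLE_ROW_COUNT constant from the test module; the concrete values mirror the last row-only parametrized case.

DUMMY_TABLE_ROW_COUNT = 10

# row_chunk_size = DUMMY_TABLE_ROW_COUNT // 2 + 1 and row_chunk = 1, as in the parametrized case
size, cidx = DUMMY_TABLE_ROW_COUNT // 2 + 1, 1
expected_num_rows = min(size * (cidx + 1), DUMMY_TABLE_ROW_COUNT) - size * cidx

assert size == 6
assert expected_num_rows == 4  # the second chunk only holds the remaining 4 rows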