Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion arbalister/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,20 @@ def get_table_writer(format: ff.FileFormat) -> WriteCallable:
case ff.FileFormat.Csv:
import pyarrow.csv

out = pyarrow.csv.write_csv
def write_csv(
    data: pa.Table,
    output_file: str | pathlib.Path,
    memory_pool: pa.MemoryPool | None = None,
    **kwargs: Any,
) -> None:
    """Write ``data`` to ``output_file`` in CSV format.

    Adapter giving ``pyarrow.csv.write_csv`` the common WriteCallable
    shape used by the other format branches: extra keyword arguments
    are not forwarded directly but packed into a
    ``pyarrow.csv.WriteOptions`` (e.g. ``delimiter=";"``).

    Parameters
    ----------
    data:
        The table to serialize.
    output_file:
        Destination path; coerced to ``str`` since pyarrow expects a
        string path here.
    memory_pool:
        Optional pyarrow memory pool; ``None`` uses the default pool.
    **kwargs:
        Fields for ``pyarrow.csv.WriteOptions`` (each value is ``Any``;
        the previous ``dict[str, Any]`` annotation wrongly claimed every
        keyword value was itself a dict).
    """
    pyarrow.csv.write_csv(
        data=data,
        output_file=str(output_file),
        memory_pool=memory_pool,
        write_options=pyarrow.csv.WriteOptions(**kwargs),
    )

out = write_csv
case ff.FileFormat.Parquet:
import pyarrow.parquet

Expand Down
11 changes: 10 additions & 1 deletion arbalister/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,19 @@ class SqliteReadParams:
table_name: str | None = None


@dataclasses.dataclass(frozen=True, slots=True)
class CSVReadParams:
    """Query parameter for the CSV reader."""

    # Field separator for the CSV parser. Defaults to "," when the query
    # string omits it; None is also allowed by the type — presumably it
    # means "let the reader pick its own default"; TODO confirm how the
    # reader distinguishes None from ",".
    delimiter: str | None = ","


@dataclasses.dataclass(frozen=True, slots=True)
class NoReadParams:
    """Query parameter for readers with no parameters.

    Empty sentinel dataclass returned for file formats that accept no
    extra query parameters (every instance compares equal).
    """


# Union of all per-format reader parameter types accepted by the route
# handlers. (The stale duplicate assignment that omitted CSVReadParams
# was dead code — immediately overwritten — and has been removed.)
FileReadParams = SqliteReadParams | CSVReadParams | NoReadParams


class BaseRouteHandler(jupyter_server.base.handlers.APIHandler):
Expand Down Expand Up @@ -63,6 +70,8 @@ def get_file_read_params(self, file_format: ff.FileFormat) -> FileReadParams:
match file_format:
case ff.FileFormat.Sqlite:
return self.get_query_params_as(SqliteReadParams)
case ff.FileFormat.Csv:
return self.get_query_params_as(CSVReadParams)
return NoReadParams()


Expand Down
6 changes: 6 additions & 0 deletions arbalister/tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
params=[
(ff.FileFormat.Avro, arb.routes.NoReadParams()),
(ff.FileFormat.Csv, arb.routes.NoReadParams()),
(ff.FileFormat.Csv, arb.routes.CSVReadParams(delimiter=";")),
(ff.FileFormat.Ipc, arb.routes.NoReadParams()),
(ff.FileFormat.Orc, arb.routes.NoReadParams()),
(ff.FileFormat.Parquet, arb.routes.NoReadParams()),
Expand Down Expand Up @@ -101,6 +102,8 @@ def table_file(
table_path = jp_root_dir / f"test.{str(file_format).lower()}"

match file_format:
case ff.FileFormat.Csv:
write_table(dummy_table_1, table_path, delimiter=getattr(file_params, "delimiter", ","))
case ff.FileFormat.Sqlite:
write_table(dummy_table_1, table_path, table_name="dummy_table_1", mode="create_append")
write_table(dummy_table_2, table_path, table_name="dummy_table_2", mode="create_append")
Expand Down Expand Up @@ -159,6 +162,7 @@ async def test_ipc_route_limit(
table_file: pathlib.Path,
ipc_params: arb.routes.IpcParams,
file_params: arb.routes.SqliteReadParams,
file_format: ff.FileFormat,
) -> None:
"""Test fetching a file returns the limited rows and columns in IPC."""
response = await jp_fetch(
Expand Down Expand Up @@ -200,6 +204,7 @@ async def test_stats_route(
full_table: pa.Table,
table_file: pathlib.Path,
file_params: arb.routes.SqliteReadParams,
file_format: ff.FileFormat,
) -> None:
"""Test fetching a file returns the correct metadata in Json."""
response = await jp_fetch(
Expand All @@ -212,5 +217,6 @@ async def test_stats_route(
assert response.headers["Content-Type"] == "application/json; charset=UTF-8"

payload = json.loads(response.body)

assert payload["num_cols"] == len(full_table.schema)
assert payload["num_rows"] == full_table.num_rows
Loading