diff --git a/arbalister/arrow.py b/arbalister/arrow.py
index d24882d..4a4dc5b 100644
--- a/arbalister/arrow.py
+++ b/arbalister/arrow.py
@@ -115,7 +115,20 @@ def get_table_writer(format: ff.FileFormat) -> WriteCallable:
         case ff.FileFormat.Csv:
             import pyarrow.csv
 
-            out = pyarrow.csv.write_csv
+            def write_csv(
+                data: pa.Table,
+                output_file: str | pathlib.Path,
+                memory_pool: pa.MemoryPool | None = None,
+                **kwargs: Any,
+            ) -> None:
+                pyarrow.csv.write_csv(
+                    data=data,
+                    output_file=str(output_file),
+                    memory_pool=memory_pool,
+                    write_options=pyarrow.csv.WriteOptions(**kwargs),
+                )
+
+            out = write_csv
         case ff.FileFormat.Parquet:
             import pyarrow.parquet
 
diff --git a/arbalister/routes.py b/arbalister/routes.py
index 0cb2ce1..7a75bf1 100644
--- a/arbalister/routes.py
+++ b/arbalister/routes.py
@@ -22,12 +22,19 @@ class SqliteReadParams:
     table_name: str | None = None
 
 
+@dataclasses.dataclass(frozen=True, slots=True)
+class CSVReadParams:
+    """Query parameter for the CSV reader."""
+
+    delimiter: str | None = ","
+
+
 @dataclasses.dataclass(frozen=True, slots=True)
 class NoReadParams:
     """Query parameter for readers with no parameters."""
 
 
-FileReadParams = SqliteReadParams | NoReadParams
+FileReadParams = SqliteReadParams | CSVReadParams | NoReadParams
 
 
 class BaseRouteHandler(jupyter_server.base.handlers.APIHandler):
@@ -63,6 +70,8 @@ def get_file_read_params(self, file_format: ff.FileFormat) -> FileReadParams:
         match file_format:
             case ff.FileFormat.Sqlite:
                 return self.get_query_params_as(SqliteReadParams)
+            case ff.FileFormat.Csv:
+                return self.get_query_params_as(CSVReadParams)
         return NoReadParams()
 
 
diff --git a/arbalister/tests/test_routes.py b/arbalister/tests/test_routes.py
index 9b571b0..8bfc045 100644
--- a/arbalister/tests/test_routes.py
+++ b/arbalister/tests/test_routes.py
@@ -17,6 +17,7 @@
     params=[
         (ff.FileFormat.Avro, arb.routes.NoReadParams()),
         (ff.FileFormat.Csv, arb.routes.NoReadParams()),
+        (ff.FileFormat.Csv, arb.routes.CSVReadParams(delimiter=";")),
         (ff.FileFormat.Ipc, arb.routes.NoReadParams()),
         (ff.FileFormat.Orc, arb.routes.NoReadParams()),
         (ff.FileFormat.Parquet, arb.routes.NoReadParams()),
@@ -101,6 +102,8 @@ def table_file(
     table_path = jp_root_dir / f"test.{str(file_format).lower()}"
 
     match file_format:
+        case ff.FileFormat.Csv:
+            write_table(dummy_table_1, table_path, delimiter=getattr(file_params, "delimiter", ","))
         case ff.FileFormat.Sqlite:
             write_table(dummy_table_1, table_path, table_name="dummy_table_1", mode="create_append")
             write_table(dummy_table_2, table_path, table_name="dummy_table_2", mode="create_append")
@@ -159,6 +162,7 @@ async def test_ipc_route_limit(
     table_file: pathlib.Path,
     ipc_params: arb.routes.IpcParams,
     file_params: arb.routes.SqliteReadParams,
+    file_format: ff.FileFormat,
 ) -> None:
     """Test fetching a file returns the limited rows and columns in IPC."""
     response = await jp_fetch(
@@ -200,6 +204,7 @@ async def test_stats_route(
     full_table: pa.Table,
     table_file: pathlib.Path,
     file_params: arb.routes.SqliteReadParams,
+    file_format: ff.FileFormat,
 ) -> None:
     """Test fetching a file returns the correct metadata in Json."""
     response = await jp_fetch(
@@ -212,5 +217,6 @@
     assert response.headers["Content-Type"] == "application/json; charset=UTF-8"
 
     payload = json.loads(response.body)
 
+    assert payload["num_cols"] == len(full_table.schema)
     assert payload["num_rows"] == full_table.num_rows