From 656bef32ed58b0b28a6118b1282f44398d2beca7 Mon Sep 17 00:00:00 2001 From: AntoinePrv Date: Fri, 19 Dec 2025 17:58:38 +0100 Subject: [PATCH] Add file info route --- arbalister/adbc.py | 23 +++++++++++++++-------- arbalister/routes.py | 30 ++++++++++++++++++++++++++++++ arbalister/tests/test_routes.py | 20 ++++++++++++++++++++ 3 files changed, 65 insertions(+), 8 deletions(-) diff --git a/arbalister/adbc.py b/arbalister/adbc.py index 54579d7..dd71599 100644 --- a/arbalister/adbc.py +++ b/arbalister/adbc.py @@ -37,19 +37,26 @@ class SqliteDataFrame: _offset: int | None = None _select: list[str] | None = None - @classmethod - def read_sqlite(cls, context: Any, path: pathlib.Path | str, table_name: str | None = None) -> Self: - """Read an Sqlite file metadata and start a new DataFrame plan.""" + @staticmethod + def get_table_names(path: pathlib.Path | str) -> list[str]: + """Get the list of table names in a SQLite database.""" with adbc_sqlite.connect(str(path)) as connection: with connection.cursor() as cursor: cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") - tables = [row for (row,) in cursor.fetchall()] + return [row for (row,) in cursor.fetchall()] - if table_name is None and len(tables) > 0: - table_name = tables[0] - if table_name not in tables or table_name is None: - raise ValueError(f"Invalid table name {table_name}") + @classmethod + def read_sqlite(cls, context: Any, path: pathlib.Path | str, table_name: str | None = None) -> Self: + """Read an Sqlite file metadata and start a new DataFrame plan.""" + tables = cls.get_table_names(path) + + if table_name is None and len(tables) > 0: + table_name = tables[0] + if table_name not in tables or table_name is None: + raise ValueError(f"Invalid table name {table_name}") + with adbc_sqlite.connect(str(path)) as connection: + with connection.cursor() as cursor: cursor.execute(f'SELECT COUNT(*) FROM "{table_name}"') num_rows = cursor.fetchone()[0] # type: ignore[index] diff --git a/arbalister/routes.py b/arbalister/routes.py index ccd306e..3b0fc8d 100644 --- a/arbalister/routes.py +++ b/arbalister/routes.py @@ -139,6 +139,13 @@ class StatsResponse: num_cols: int = 0 +@dataclasses.dataclass(frozen=True, slots=True) +class SqliteFileInfo: + """File-specific information returned in the file info route.""" + + table_names: list[str] | None = None + + class StatsRouteHandler(BaseRouteHandler): """An handler to get file in IPC.""" @@ -180,6 +187,28 @@ async def get(self, path: str) -> None: await self.finish(dataclasses.asdict(response)) +class FileInfoRouteHandler(BaseRouteHandler): + """A handler to get file-specific information.""" + + @tornado.web.authenticated + async def get(self, path: str) -> None: + """HTTP GET return file-specific information.""" + file = self.data_file(path) + file_format = ff.FileFormat.from_filename(file) + + table_names: list[str] | None = None + + if file_format == ff.FileFormat.Sqlite: + from . import adbc + + table_names = adbc.SqliteDataFrame.get_table_names(file) + + response = SqliteFileInfo(table_names=table_names) + await self.finish(dataclasses.asdict(response)) + + await self.finish({}) + + def make_datafusion_config() -> dn.SessionConfig: """Return the datafusion config.""" config = ( @@ -203,6 +232,7 @@ def setup_route_handlers(web_app: jupyter_server.serverapp.ServerWebApplication) handlers = [ (url_path_join(base_url, r"arrow/stream/([^?]*)"), IpcRouteHandler, {"context": context}), (url_path_join(base_url, r"arrow/stats/([^?]*)"), StatsRouteHandler, {"context": context}), + (url_path_join(base_url, r"file/info/([^?]*)"), FileInfoRouteHandler, {"context": context}), ] web_app.add_handlers(host_pattern, handlers) # type: ignore[no-untyped-call] diff --git a/arbalister/tests/test_routes.py b/arbalister/tests/test_routes.py index 651b0b9..4397d7d 100644 --- a/arbalister/tests/test_routes.py +++ b/arbalister/tests/test_routes.py @@ -227,3 +227,23 @@ async def test_stats_route( table = pa.ipc.open_stream(table_64).read_all() assert table.num_rows == 0 assert table.schema.names == full_table.schema.names + + +async def test_file_info_route_sqlite( + jp_fetch: JpFetch, + table_file: pathlib.Path, + file_format: ff.FileFormat, +) -> None: + """Test fetching file info for SQLite files returns table names.""" + response = await jp_fetch("file/info/", str(table_file)) + + assert response.code == 200 + assert response.headers["Content-Type"] == "application/json; charset=UTF-8" + + payload = json.loads(response.body) + + if file_format == ff.FileFormat.Sqlite: + assert payload["table_names"] is not None + assert isinstance(payload["table_names"], list) + assert "dummy_table_1" in payload["table_names"] + assert "dummy_table_2" in payload["table_names"]