Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions arbalister/adbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,26 @@ class SqliteDataFrame:
_offset: int | None = None
_select: list[str] | None = None

@classmethod
def read_sqlite(cls, context: Any, path: pathlib.Path | str, table_name: str | None = None) -> Self:
"""Read an Sqlite file metadata and start a new DataFrame plan."""
@staticmethod
def get_table_names(path: pathlib.Path | str) -> list[str]:
    """Get the list of table names in a SQLite database.

    Args:
        path: Location of the SQLite database file; converted with ``str()``
            before being handed to the ADBC SQLite driver.

    Returns:
        The ``name`` column of every ``sqlite_master`` row whose type is
        ``'table'``, in the order the query yields them (the query has no
        ``ORDER BY`` clause).
    """
    with adbc_sqlite.connect(str(path)) as connection, connection.cursor() as cursor:
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        # Fetch exactly once: a second fetchall() on the same cursor would
        # find the result set already consumed and return nothing.
        return [name for (name,) in cursor.fetchall()]

if table_name is None and len(tables) > 0:
table_name = tables[0]
if table_name not in tables or table_name is None:
raise ValueError(f"Invalid table name {table_name}")
@classmethod
def read_sqlite(cls, context: Any, path: pathlib.Path | str, table_name: str | None = None) -> Self:
"""Read an Sqlite file metadata and start a new DataFrame plan."""
tables = cls.get_table_names(path)

if table_name is None and len(tables) > 0:
table_name = tables[0]
if table_name not in tables or table_name is None:
raise ValueError(f"Invalid table name {table_name}")

with adbc_sqlite.connect(str(path)) as connection:
with connection.cursor() as cursor:
cursor.execute(f'SELECT COUNT(*) FROM "{table_name}"')
num_rows = cursor.fetchone()[0] # type: ignore[index]

Expand Down
30 changes: 30 additions & 0 deletions arbalister/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ class StatsResponse:
num_cols: int = 0


@dataclasses.dataclass(frozen=True, slots=True)
class SqliteFileInfo:
"""File-specific information returned in the file info route."""

table_names: list[str] | None = None


class StatsRouteHandler(BaseRouteHandler):
"""An handler to get file in IPC."""

Expand Down Expand Up @@ -180,6 +187,28 @@ async def get(self, path: str) -> None:
await self.finish(dataclasses.asdict(response))


class FileInfoRouteHandler(BaseRouteHandler):
    """A handler to get file-specific information."""

    @tornado.web.authenticated
    async def get(self, path: str) -> None:
        """HTTP GET return file-specific information.

        For SQLite files the response body is a JSON object with a
        ``table_names`` list; for every other file format it is an empty
        JSON object.

        Args:
            path: The path captured from the request URL, passed to
                ``self.data_file`` to obtain the file to inspect.
        """
        file = self.data_file(path)
        file_format = ff.FileFormat.from_filename(file)

        if file_format == ff.FileFormat.Sqlite:
            # Imported lazily so the ADBC module is only loaded for SQLite files.
            from . import adbc

            table_names = adbc.SqliteDataFrame.get_table_names(file)
            response = SqliteFileInfo(table_names=table_names)
            await self.finish(dataclasses.asdict(response))
            # Stop here: Tornado raises RuntimeError("finish() called twice")
            # if the request is finished a second time.
            return

        await self.finish({})


def make_datafusion_config() -> dn.SessionConfig:
"""Return the datafusion config."""
config = (
Expand All @@ -203,6 +232,7 @@ def setup_route_handlers(web_app: jupyter_server.serverapp.ServerWebApplication)
handlers = [
(url_path_join(base_url, r"arrow/stream/([^?]*)"), IpcRouteHandler, {"context": context}),
(url_path_join(base_url, r"arrow/stats/([^?]*)"), StatsRouteHandler, {"context": context}),
(url_path_join(base_url, r"file/info/([^?]*)"), FileInfoRouteHandler, {"context": context}),
]

web_app.add_handlers(host_pattern, handlers) # type: ignore[no-untyped-call]
20 changes: 20 additions & 0 deletions arbalister/tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,23 @@ async def test_stats_route(
table = pa.ipc.open_stream(table_64).read_all()
assert table.num_rows == 0
assert table.schema.names == full_table.schema.names


async def test_file_info_route_sqlite(
    jp_fetch: JpFetch,
    table_file: pathlib.Path,
    file_format: ff.FileFormat,
) -> None:
    """Check that the file info route answers with JSON and lists SQLite tables.

    The status code and content type are verified for every file format; the
    table name checks only run when the fixture file is a SQLite database.
    """
    response = await jp_fetch("file/info/", str(table_file))

    assert response.code == 200
    assert response.headers["Content-Type"] == "application/json; charset=UTF-8"

    payload = json.loads(response.body)

    # Nothing format-specific to verify for non-SQLite files.
    if file_format != ff.FileFormat.Sqlite:
        return

    assert payload["table_names"] is not None
    assert isinstance(payload["table_names"], list)
    for expected_name in ("dummy_table_1", "dummy_table_2"):
        assert expected_name in payload["table_names"]
Loading