109 changes: 58 additions & 51 deletions arbalister/arrow.py
@@ -1,3 +1,4 @@
import codecs
import pathlib
from typing import Any, Callable

@@ -9,6 +10,63 @@
ReadCallable = Callable[..., dn.DataFrame]


def _read_csv(
ctx: dn.SessionContext, path: str | pathlib.Path, delimiter: str, **kwargs: dict[str, Any]
) -> dn.DataFrame:
if len(delimiter) > 1:
delimiter = codecs.decode(delimiter, "unicode_escape")
return ctx.read_csv(path, delimiter=delimiter, **kwargs) # type: ignore[arg-type]


def _read_ipc(ctx: dn.SessionContext, path: str | pathlib.Path, **kwargs: dict[str, Any]) -> dn.DataFrame:
import pyarrow.feather

# table = pyarrow.feather.read_table(path, {**{"memory_map": True}, **kwargs})
table = pyarrow.feather.read_table(path, **kwargs)
return ctx.from_arrow(table)


def _read_orc(ctx: dn.SessionContext, path: str | pathlib.Path, **kwargs: dict[str, Any]) -> dn.DataFrame:
# Watch https://github.com/datafusion-contrib/datafusion-orc for a native DataFusion reader
import pyarrow.orc

table = pyarrow.orc.read_table(path, **kwargs)
return ctx.from_arrow(table)


def get_table_reader(format: ff.FileFormat) -> ReadCallable:
"""Get the datafusion reader factory function for the given format."""
# TODO: datafusion >= 50.0
# def read(ctx: dtfn.SessionContext, path: str | pathlib.Path, *args, **kwargs) -> dtfn.DataFrame:
# ds = pads.dataset(source=path, format=format.value)
# return ctx.read_table(ds, *args, **kwargs)
out: ReadCallable
match format:
case ff.FileFormat.Avro:
out = dn.SessionContext.read_avro
case ff.FileFormat.Csv:
out = _read_csv
case ff.FileFormat.Parquet:
out = dn.SessionContext.read_parquet
case ff.FileFormat.Ipc:
out = _read_ipc
case ff.FileFormat.Orc:
out = _read_orc
case ff.FileFormat.Sqlite:
from . import adbc as adbc

# FIXME: For now we pretend SqliteDataFrame is a DataFusion DataFrame.
# Either integrate it properly into DataFusion, or define a DataFrame typing.Protocol.
out = adbc.SqliteDataFrame.read_sqlite # type: ignore[assignment]

return out


WriteCallable = Callable[..., None]


def _arrow_to_avro_type(field: pa.Field) -> str | dict[str, Any]:
t = field.type
if pa.types.is_integer(t):
@@ -55,57 +113,6 @@ def _write_avro(
writer.close()


def get_table_reader(format: ff.FileFormat) -> ReadCallable:
"""Get the datafusion reader factory function for the given format."""
# TODO: datafusion >= 50.0
# def read(ctx: dtfn.SessionContext, path: str | pathlib.Path, *args, **kwargs) -> dtfn.DataFrame:
# ds = pads.dataset(source=path, format=format.value)
# return ctx.read_table(ds, *args, **kwargs)
out: ReadCallable
match format:
case ff.FileFormat.Avro:
out = dn.SessionContext.read_avro
case ff.FileFormat.Csv:
out = dn.SessionContext.read_csv
case ff.FileFormat.Parquet:
out = dn.SessionContext.read_parquet
case ff.FileFormat.Ipc:
import pyarrow.feather

def read_ipc(
ctx: dn.SessionContext, path: str | pathlib.Path, **kwargs: dict[str, Any]
) -> dn.DataFrame:
# table = pyarrow.feather.read_table(path, {**{"memory_map": True}, **kwargs})
table = pyarrow.feather.read_table(path, **kwargs)
return ctx.from_arrow(table)

out = read_ipc
case ff.FileFormat.Orc:
# Watch for https://github.com/datafusion-contrib/datafusion-orc
# Evolution for native datafusion reader
import pyarrow.orc

def read_orc(
ctx: dn.SessionContext, path: str | pathlib.Path, **kwargs: dict[str, Any]
) -> dn.DataFrame:
table = pyarrow.orc.read_table(path, **kwargs)
return ctx.from_arrow(table)

out = read_orc
case ff.FileFormat.Sqlite:
from . import adbc as adbc

# FIXME: For now we just pretend SqliteDataFrame is a datafision DataFrame
# Either we integrate it properly into Datafusion, or we create a DataFrame as a
# typing.protocol.
out = adbc.SqliteDataFrame.read_sqlite # type: ignore[assignment]

return out


WriteCallable = Callable[..., None]


def get_table_writer(format: ff.FileFormat) -> WriteCallable:
"""Get the arrow writer factory function for the given format."""
out: WriteCallable
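For context, a minimal usage sketch of the new reader dispatch (not part of the PR): the `arbalister.file_format` import path and the data file names are assumptions; only `get_table_reader`, the `FileFormat` members, and the `datafusion` alias `dn` come from the code above.

import datafusion as dn

import arbalister.file_format as ff  # assumed module path for the FileFormat enum
from arbalister.arrow import get_table_reader

ctx = dn.SessionContext()

# CSV: a delimiter arriving as the two characters backslash + "t" is unescaped
# to a real tab by _read_csv (codecs.decode(..., "unicode_escape")) before the
# call is forwarded to DataFusion.
read_csv = get_table_reader(ff.FileFormat.Csv)
df = read_csv(ctx, "data.tsv", delimiter="\\t")

# Parquet dispatches to the unbound SessionContext method, so the context is
# passed explicitly as the first argument.
read_parquet = get_table_reader(ff.FileFormat.Parquet)
df = read_parquet(ctx, "data.parquet")

print(df.limit(5))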
6 changes: 3 additions & 3 deletions src/__tests__/model.spec.ts
@@ -15,14 +15,14 @@ const MOCK_TABLE = tableFromArrays({
score: [85, 90, 78, 92, 88, 76, 95, 81, 89, 93],
});

async function fetchStatsMocked(_params: Req.StatsParams): Promise<Req.StatsResponse> {
async function fetchStatsMocked(_params: Req.StatsOptions): Promise<Req.StatsResponse> {
return {
num_rows: MOCK_TABLE.numRows,
num_cols: MOCK_TABLE.numCols,
};
}

async function fetchTableMocked(params: Req.TableParams): Promise<Arrow.Table> {
async function fetchTableMocked(params: Req.TableOptions): Promise<Arrow.Table> {
let table: Arrow.Table = MOCK_TABLE;

if (params.row_chunk !== undefined && params.row_chunk_size !== undefined) {
@@ -51,7 +51,7 @@ describe("ArrowModel", () => {
(fetchTable as jest.Mock).mockImplementation(fetchTableMocked);
(fetchStats as jest.Mock).mockImplementation(fetchStatsMocked);

const model = new ArrowModel({ path: "test/path.parquet" });
const model = new ArrowModel({ path: "test/path.parquet" }, {});

it("should initialize data", async () => {
await model.ready;
9 changes: 9 additions & 0 deletions src/file_options.ts
@@ -0,0 +1,9 @@
export interface CsvOptions {
delimiter?: string;
}

export const DEFAULT_CSV_OPTIONS: Required<CsvOptions> = {
delimiter: ",",
};

export type FileOptions = CsvOptions;
49 changes: 34 additions & 15 deletions src/filetypes.ts
@@ -1,86 +1,105 @@
import type { DocumentRegistry } from "@jupyterlab/docregistry";

export enum FileType {
Avro = "apache-avro",
Csv = "csv",
Ipc = "apache-arrow-ipc-avro",
Orc = "apache-orc",
Parquet = "apache-parquet",
Sqlite = "sqlite",
}

export function ensureCsvFileType(docRegistry: DocumentRegistry): DocumentRegistry.IFileType {
const ft = docRegistry.getFileType(FileType.Csv);
if (ft) {
return ft;
}
docRegistry.addFileType({
name: FileType.Csv,
displayName: "CSV",
mimeTypes: ["text/csv"],
extensions: [".csv"],
contentType: "file",
});
return docRegistry.getFileType(FileType.Csv)!;
}

export function addAvroFileType(
docRegistry: DocumentRegistry,
options: Partial<DocumentRegistry.IFileType> = {},
): DocumentRegistry.IFileType {
const name = "apache-avro";
docRegistry.addFileType({
...options,
name,
name: FileType.Avro,
displayName: "Avro",
mimeTypes: ["application/avro-binary"],
extensions: [".avro"],
contentType: "file",
fileFormat: "base64",
});
return docRegistry.getFileType(name)!;
return docRegistry.getFileType(FileType.Avro)!;
}

export function addParquetFileType(
docRegistry: DocumentRegistry,
options: Partial<DocumentRegistry.IFileType> = {},
): DocumentRegistry.IFileType {
const name = "apache-parquet";
docRegistry.addFileType({
...options,
name,
name: FileType.Parquet,
displayName: "Parquet",
mimeTypes: ["application/vnd.apache.parquet"],
extensions: [".parquet"],
contentType: "file",
fileFormat: "base64",
});
return docRegistry.getFileType(name)!;
return docRegistry.getFileType(FileType.Parquet)!;
}

export function addIpcFileType(
docRegistry: DocumentRegistry,
options: Partial<DocumentRegistry.IFileType> = {},
): DocumentRegistry.IFileType {
const name = "apache-arrow-ipc";
docRegistry.addFileType({
...options,
name,
name: FileType.Ipc,
displayName: "Arrow IPC",
mimeTypes: ["application/vnd.apache.arrow.file"],
extensions: [".ipc", ".feather", ".arrow"],
contentType: "file",
fileFormat: "base64",
});
return docRegistry.getFileType(name)!;
return docRegistry.getFileType(FileType.Ipc)!;
}

export function addOrcFileType(
docRegistry: DocumentRegistry,
options: Partial<DocumentRegistry.IFileType> = {},
): DocumentRegistry.IFileType {
const name = "apache-orc";
docRegistry.addFileType({
...options,
name,
name: FileType.Orc,
displayName: "Arrow ORC",
mimeTypes: ["application/octet-stream"],
extensions: [".orc"],
contentType: "file",
fileFormat: "base64",
});
return docRegistry.getFileType(name)!;
return docRegistry.getFileType(FileType.Orc)!;
}

export function addSqliteFileType(
docRegistry: DocumentRegistry,
options: Partial<DocumentRegistry.IFileType> = {},
): DocumentRegistry.IFileType {
const name = "sqlite";
docRegistry.addFileType({
...options,
name,
name: FileType.Sqlite,
displayName: "SQLite",
mimeTypes: ["application/vnd.sqlite3"],
extensions: [".sqlite", ".sqlite3", ".db", ".db3", ".s3db", ".sl3"],
contentType: "file",
fileFormat: "base64",
});
return docRegistry.getFileType(name)!;
return docRegistry.getFileType(FileType.Sqlite)!;
}
45 changes: 14 additions & 31 deletions src/index.ts
@@ -3,7 +3,7 @@ import { IThemeManager, showErrorMessage, WidgetTracker } from "@jupyterlab/appu
import { IDefaultDrive } from "@jupyterlab/services";
import { ITranslator } from "@jupyterlab/translation";
import type { JupyterFrontEnd, JupyterFrontEndPlugin } from "@jupyterlab/application";
import type { DocumentRegistry, IDocumentWidget } from "@jupyterlab/docregistry";
import type { IDocumentWidget } from "@jupyterlab/docregistry";
import type * as services from "@jupyterlab/services";
import type { Contents } from "@jupyterlab/services";
import type { DataGrid } from "@lumino/datagrid";
@@ -14,6 +14,7 @@ import {
addOrcFileType,
addParquetFileType,
addSqliteFileType,
ensureCsvFileType,
} from "./filetypes";
import {
getArrowIPCIcon,
@@ -77,22 +78,6 @@ const arrowGrid: JupyterFrontEndPlugin<void> = {
autoStart: true,
};

function ensureCsvFileType(docRegistry: DocumentRegistry): DocumentRegistry.IFileType {
const name = "csv";
const ft = docRegistry.getFileType(name)!;
if (ft) {
return ft;
}
docRegistry.addFileType({
name,
displayName: "CSV",
mimeTypes: ["text/csv"],
extensions: [".csv"],
contentType: "file",
});
return docRegistry.getFileType(name)!;
}

function activateArrowGrid(
app: JupyterFrontEnd,
translator: ITranslator,
@@ -128,15 +113,18 @@ function activateArrowGrid(
let orc_ft = addOrcFileType(app.docRegistry, { icon: getORCIcon(isLight) });
let sqlite_ft = addSqliteFileType(app.docRegistry, { icon: getSqliteIcon(isLight) });

const factory = new ArrowGridViewerFactory({
name: factory_arrow,
label: trans.__("Arrow Dataframe Viewer"),
fileTypes: [csv_ft.name, avo_ft.name, prq_ft.name, ipc_ft.name, orc_ft.name, sqlite_ft.name],
defaultFor: [csv_ft.name, avo_ft.name, prq_ft.name, ipc_ft.name, orc_ft.name, sqlite_ft.name],
readOnly: true,
translator,
contentProviderId: NOOP_CONTENT_PROVIDER_ID,
});
const factory = new ArrowGridViewerFactory(
{
name: factory_arrow,
label: trans.__("Arrow Dataframe Viewer"),
fileTypes: [csv_ft.name, avo_ft.name, prq_ft.name, ipc_ft.name, orc_ft.name, sqlite_ft.name],
defaultFor: [csv_ft.name, avo_ft.name, prq_ft.name, ipc_ft.name, orc_ft.name, sqlite_ft.name],
readOnly: true,
translator,
contentProviderId: NOOP_CONTENT_PROVIDER_ID,
},
app.docRegistry,
);
const tracker = new WidgetTracker<IDocumentWidget<ArrowGridViewer>>({
namespace: "arrowviewer",
});
@@ -162,11 +150,6 @@
void tracker.save(widget);
});

if (csv_ft) {
widget.title.icon = csv_ft.icon;
widget.title.iconClass = csv_ft.iconClass!;
widget.title.iconLabel = csv_ft.iconLabel!;
}
await widget.content.ready;
widget.content.style = style;
widget.content.rendererConfig = rendererConfig;