diff --git a/arbalister/arrow.py b/arbalister/arrow.py index 4a4dc5b..b4b7c3e 100644 --- a/arbalister/arrow.py +++ b/arbalister/arrow.py @@ -1,3 +1,4 @@ +import codecs import pathlib from typing import Any, Callable @@ -9,6 +10,63 @@ ReadCallable = Callable[..., dn.DataFrame] +def _read_csv( + ctx: dn.SessionContext, path: str | pathlib.Path, delimiter: str, **kwargs: dict[str, Any] +) -> dn.DataFrame: + if len(delimiter) > 1: + delimiter = codecs.decode(delimiter, "unicode_escape") + return ctx.read_csv(path, delimiter=delimiter, **kwargs) # type: ignore[arg-type] + + +def _read_ipc(ctx: dn.SessionContext, path: str | pathlib.Path, **kwargs: dict[str, Any]) -> dn.DataFrame: + import pyarrow.feather + + # table = pyarrow.feather.read_table(path, {**{"memory_map": True}, **kwargs}) + table = pyarrow.feather.read_table(path, **kwargs) + return ctx.from_arrow(table) + + +def _read_orc(ctx: dn.SessionContext, path: str | pathlib.Path, **kwargs: dict[str, Any]) -> dn.DataFrame: + # Watch for https://github.com/datafusion-contrib/datafusion-orc + # Evolution for native datafusion reader + import pyarrow.orc + + table = pyarrow.orc.read_table(path, **kwargs) + return ctx.from_arrow(table) + + +def get_table_reader(format: ff.FileFormat) -> ReadCallable: + """Get the datafusion reader factory function for the given format.""" + # TODO: datafusion >= 50.0 + # def read(ctx: dtfn.SessionContext, path: str | pathlib.Path, *args, **kwargs) -> dtfn.DataFrame: + # ds = pads.dataset(source=path, format=format.value) + # return ctx.read_table(ds, *args, **kwargs) + out: ReadCallable + match format: + case ff.FileFormat.Avro: + out = dn.SessionContext.read_avro + case ff.FileFormat.Csv: + out = _read_csv + case ff.FileFormat.Parquet: + out = dn.SessionContext.read_parquet + case ff.FileFormat.Ipc: + out = _read_ipc + case ff.FileFormat.Orc: + out = _read_orc + case ff.FileFormat.Sqlite: + from . import adbc as adbc + + # FIXME: For now we just pretend SqliteDataFrame is a datafusion DataFrame + # Either we integrate it properly into Datafusion, or we create a DataFrame as a + # typing.protocol. + out = adbc.SqliteDataFrame.read_sqlite # type: ignore[assignment] + + return out + + +WriteCallable = Callable[..., None] + + def _arrow_to_avro_type(field: pa.Field) -> str | dict[str, Any]: t = field.type if pa.types.is_integer(t): @@ -55,57 +113,6 @@ def _write_avro( writer.close() -def get_table_reader(format: ff.FileFormat) -> ReadCallable: - """Get the datafusion reader factory function for the given format.""" - # TODO: datafusion >= 50.0 - # def read(ctx: dtfn.SessionContext, path: str | pathlib.Path, *args, **kwargs) -> dtfn.DataFrame: - # ds = pads.dataset(source=path, format=format.value) - # return ctx.read_table(ds, *args, **kwargs) - out: ReadCallable - match format: - case ff.FileFormat.Avro: - out = dn.SessionContext.read_avro - case ff.FileFormat.Csv: - out = dn.SessionContext.read_csv - case ff.FileFormat.Parquet: - out = dn.SessionContext.read_parquet - case ff.FileFormat.Ipc: - import pyarrow.feather - - def read_ipc( - ctx: dn.SessionContext, path: str | pathlib.Path, **kwargs: dict[str, Any] - ) -> dn.DataFrame: - # table = pyarrow.feather.read_table(path, {**{"memory_map": True}, **kwargs}) - table = pyarrow.feather.read_table(path, **kwargs) - return ctx.from_arrow(table) - - out = read_ipc - case ff.FileFormat.Orc: - # Watch for https://github.com/datafusion-contrib/datafusion-orc - # Evolution for native datafusion reader - import pyarrow.orc - - def read_orc( - ctx: dn.SessionContext, path: str | pathlib.Path, **kwargs: dict[str, Any] - ) -> dn.DataFrame: - table = pyarrow.orc.read_table(path, **kwargs) - return ctx.from_arrow(table) - - out = read_orc - case ff.FileFormat.Sqlite: - from . import adbc as adbc - - # FIXME: For now we just pretend SqliteDataFrame is a datafision DataFrame - # Either we integrate it properly into Datafusion, or we create a DataFrame as a - # typing.protocol. - out = adbc.SqliteDataFrame.read_sqlite # type: ignore[assignment] - - return out - - -WriteCallable = Callable[..., None] - - def get_table_writer(format: ff.FileFormat) -> WriteCallable: """Get the arrow writer factory function for the given format.""" out: WriteCallable diff --git a/src/__tests__/model.spec.ts b/src/__tests__/model.spec.ts index ea3c71c..e4230b9 100644 --- a/src/__tests__/model.spec.ts +++ b/src/__tests__/model.spec.ts @@ -15,14 +15,14 @@ const MOCK_TABLE = tableFromArrays({ score: [85, 90, 78, 92, 88, 76, 95, 81, 89, 93], }); -async function fetchStatsMocked(_params: Req.StatsParams): Promise { +async function fetchStatsMocked(_params: Req.StatsOptions): Promise { return { num_rows: MOCK_TABLE.numRows, num_cols: MOCK_TABLE.numCols, }; } -async function fetchTableMocked(params: Req.TableParams): Promise { +async function fetchTableMocked(params: Req.TableOptions): Promise { let table: Arrow.Table = MOCK_TABLE; if (params.row_chunk !== undefined && params.row_chunk_size !== undefined) { @@ -51,7 +51,7 @@ describe("ArrowModel", () => { (fetchTable as jest.Mock).mockImplementation(fetchTableMocked); (fetchStats as jest.Mock).mockImplementation(fetchStatsMocked); - const model = new ArrowModel({ path: "test/path.parquet" }); + const model = new ArrowModel({ path: "test/path.parquet" }, {}); it("should initialize data", async () => { await model.ready; diff --git a/src/file_options.ts b/src/file_options.ts new file mode 100644 index 0000000..cf81573 --- /dev/null +++ b/src/file_options.ts @@ -0,0 +1,9 @@ +export interface CsvOptions { + delimiter?: string; +} + +export const DEFAULT_CSV_OPTIONS: Required = { + delimiter: ",", +}; + +export type FileOptions = CsvOptions; diff --git a/src/filetypes.ts b/src/filetypes.ts index 5358859..3716a91 100644 --- a/src/filetypes.ts +++ b/src/filetypes.ts @@ -1,86 +1,105 @@ import type { DocumentRegistry } from "@jupyterlab/docregistry"; +export enum FileType { + Avro = "apache-avro", + Csv = "csv", + Ipc = "apache-arrow-ipc-avro", + Orc = "apache-orc", + Parquet = "apache-parquet", + Sqlite = "sqlite", +} + +export function ensureCsvFileType(docRegistry: DocumentRegistry): DocumentRegistry.IFileType { + const ft = docRegistry.getFileType(FileType.Csv); + if (ft) { + return ft; + } + docRegistry.addFileType({ + name: FileType.Csv, + displayName: "CSV", + mimeTypes: ["text/csv"], + extensions: [".csv"], + contentType: "file", + }); + return docRegistry.getFileType(FileType.Csv)!; +} + export function addAvroFileType( docRegistry: DocumentRegistry, options: Partial = {}, ): DocumentRegistry.IFileType { - const name = "apache-avro"; docRegistry.addFileType({ ...options, - name, + name: FileType.Avro, displayName: "Avro", mimeTypes: ["application/avro-binary"], extensions: [".avro"], contentType: "file", fileFormat: "base64", }); - return docRegistry.getFileType(name)!; + return docRegistry.getFileType(FileType.Avro)!; } export function addParquetFileType( docRegistry: DocumentRegistry, options: Partial = {}, ): DocumentRegistry.IFileType { - const name = "apache-parquet"; docRegistry.addFileType({ ...options, - name, + name: FileType.Parquet, displayName: "Parquet", mimeTypes: ["application/vnd.apache.parquet"], extensions: [".parquet"], contentType: "file", fileFormat: "base64", }); - return docRegistry.getFileType(name)!; + return docRegistry.getFileType(FileType.Parquet)!; } export function addIpcFileType( docRegistry: DocumentRegistry, options: Partial = {}, ): DocumentRegistry.IFileType { - const name = "apache-arrow-ipc"; docRegistry.addFileType({ ...options, - name, + name: FileType.Ipc, displayName: "Arrow IPC", mimeTypes: ["application/vnd.apache.arrow.file"], extensions: [".ipc", ".feather", ".arrow"], contentType: "file", fileFormat: "base64", }); - return docRegistry.getFileType(name)!; + return docRegistry.getFileType(FileType.Ipc)!; } export function addOrcFileType( docRegistry: DocumentRegistry, options: Partial = {}, ): DocumentRegistry.IFileType { - const name = "apache-orc"; docRegistry.addFileType({ ...options, - name, + name: FileType.Orc, displayName: "Arrow ORC", mimeTypes: ["application/octet-stream"], extensions: [".orc"], contentType: "file", fileFormat: "base64", }); - return docRegistry.getFileType(name)!; + return docRegistry.getFileType(FileType.Orc)!; } export function addSqliteFileType( docRegistry: DocumentRegistry, options: Partial = {}, ): DocumentRegistry.IFileType { - const name = "sqlite"; docRegistry.addFileType({ ...options, - name, + name: FileType.Sqlite, displayName: "SQLite", mimeTypes: ["application/vnd.sqlite3"], extensions: [".sqlite", ".sqlite3", ".db", ".db3", ".s3db", ".sl3"], contentType: "file", fileFormat: "base64", }); - return docRegistry.getFileType(name)!; + return docRegistry.getFileType(FileType.Sqlite)!; } diff --git a/src/index.ts b/src/index.ts index 0128567..3f7a0c2 100644 --- a/src/index.ts +++ b/src/index.ts @@ -3,7 +3,7 @@ import { IThemeManager, showErrorMessage, WidgetTracker } from "@jupyterlab/appu import { IDefaultDrive } from "@jupyterlab/services"; import { ITranslator } from "@jupyterlab/translation"; import type { JupyterFrontEnd, JupyterFrontEndPlugin } from "@jupyterlab/application"; -import type { DocumentRegistry, IDocumentWidget } from "@jupyterlab/docregistry"; +import type { IDocumentWidget } from "@jupyterlab/docregistry"; import type * as services from "@jupyterlab/services"; import type { Contents } from "@jupyterlab/services"; import type { DataGrid } from "@lumino/datagrid"; @@ -14,6 +14,7 @@ import { addOrcFileType, addParquetFileType, addSqliteFileType, + ensureCsvFileType, } from "./filetypes"; import { getArrowIPCIcon, @@ -77,22 +78,6 @@ const arrowGrid: JupyterFrontEndPlugin = { autoStart: true, }; -function ensureCsvFileType(docRegistry: DocumentRegistry): DocumentRegistry.IFileType { - const name = "csv"; - const ft = docRegistry.getFileType(name)!; - if (ft) { - return ft; - } - docRegistry.addFileType({ - name, - displayName: "CSV", - mimeTypes: ["text/csv"], - extensions: [".csv"], - contentType: "file", - }); - return docRegistry.getFileType(name)!; -} - function activateArrowGrid( app: JupyterFrontEnd, translator: ITranslator, @@ -128,15 +113,18 @@ function activateArrowGrid( let orc_ft = addOrcFileType(app.docRegistry, { icon: getORCIcon(isLight) }); let sqlite_ft = addSqliteFileType(app.docRegistry, { icon: getSqliteIcon(isLight) }); - const factory = new ArrowGridViewerFactory({ - name: factory_arrow, - label: trans.__("Arrow Dataframe Viewer"), - fileTypes: [csv_ft.name, avo_ft.name, prq_ft.name, ipc_ft.name, orc_ft.name, sqlite_ft.name], - defaultFor: [csv_ft.name, avo_ft.name, prq_ft.name, ipc_ft.name, orc_ft.name, sqlite_ft.name], - readOnly: true, - translator, - contentProviderId: NOOP_CONTENT_PROVIDER_ID, - }); + const factory = new ArrowGridViewerFactory( + { + name: factory_arrow, + label: trans.__("Arrow Dataframe Viewer"), + fileTypes: [csv_ft.name, avo_ft.name, prq_ft.name, ipc_ft.name, orc_ft.name, sqlite_ft.name], + defaultFor: [csv_ft.name, avo_ft.name, prq_ft.name, ipc_ft.name, orc_ft.name, sqlite_ft.name], + readOnly: true, + translator, + contentProviderId: NOOP_CONTENT_PROVIDER_ID, + }, + app.docRegistry, + ); const tracker = new WidgetTracker>({ namespace: "arrowviewer", }); @@ -162,11 +150,6 @@ function activateArrowGrid( void tracker.save(widget); }); - if (csv_ft) { - widget.title.icon = csv_ft.icon; - widget.title.iconClass = csv_ft.iconClass!; - widget.title.iconLabel = csv_ft.iconLabel!; - } await widget.content.ready; widget.content.style = style; widget.content.rendererConfig = rendererConfig; diff --git a/src/model.ts b/src/model.ts index 1630c7d..1a20073 100644 --- a/src/model.ts +++ b/src/model.ts @@ -4,9 +4,10 @@ import type * as Arrow from "apache-arrow"; import { PairMap } from "./collection"; import { fetchStats, fetchTable } from "./requests"; +import type { FileOptions } from "./file_options"; export namespace ArrowModel { - export interface IOptions { + export interface LoadingOptions { path: string; rowChunkSize?: number; colChunkSize?: number; @@ -15,13 +16,16 @@ export namespace ArrowModel { } export class ArrowModel extends DataModel { - constructor(options: ArrowModel.IOptions) { + constructor(loadingOptions: ArrowModel.LoadingOptions, fileOptions: FileOptions) { super(); - this._path = options.path; - this._rowChunkSize = options.rowChunkSize ?? 512; - this._colChunkSize = options.colChunkSize ?? 24; - this._loadingRepr = options.loadingRepr ?? ""; + this._loadingParams = { + rowChunkSize: 512, + colChunkSize: 24, + loadingRepr: "", + ...loadingOptions, + }; + this._fileOptions = fileOptions; this._ready = this.initialize(); } @@ -29,7 +33,7 @@ export class ArrowModel extends DataModel { protected async initialize(): Promise { const [schema, stats, chunk00] = await Promise.all([ this.fetchSchema(), - fetchStats({ path: this._path }), + fetchStats({ path: this._loadingParams.path, ...this._fileOptions }), this.fetchChunk([0, 0]), ]); @@ -83,12 +87,12 @@ export class ArrowModel extends DataModel { const chunk = this._chunks.get(chunk_idx)!; if (chunk instanceof Promise) { // Wait for Promise to complete and mark data as modified - return this._loadingRepr; + return this._loadingParams.loadingRepr; } // We have data - const row_idx_in_chunk = row % this._rowChunkSize; - const col_idx_in_chunk = col % this._colChunkSize; + const row_idx_in_chunk = row % this._loadingParams.rowChunkSize; + const col_idx_in_chunk = col % this._loadingParams.colChunkSize; const out = chunk.getChildAt(col_idx_in_chunk)?.get(row_idx_in_chunk).toString(); // Prefetch next chunks only once we have data for the current chunk. @@ -110,17 +114,18 @@ export class ArrowModel extends DataModel { }); this._chunks.set(chunk_idx, promise); - return this._loadingRepr; + return this._loadingParams.loadingRepr; } private async fetchChunk(chunk_idx: [number, number]) { const [row_chunk, col_chunk] = chunk_idx; return await fetchTable({ - path: this._path, - row_chunk_size: this._rowChunkSize, + path: this._loadingParams.path, + row_chunk_size: this._loadingParams.rowChunkSize, row_chunk: row_chunk, - col_chunk_size: this._colChunkSize, + col_chunk_size: this._loadingParams.colChunkSize, col_chunk: col_chunk, + ...this._fileOptions, }); } @@ -129,10 +134,10 @@ export class ArrowModel extends DataModel { this.emitChanged({ type: "cells-changed", region: "body", - row: row_chunk * this._rowChunkSize, - rowSpan: this._rowChunkSize, - column: col_chunk * this._colChunkSize, - columnSpan: this._colChunkSize, + row: row_chunk * this._loadingParams.rowChunkSize, + rowSpan: this._loadingParams.rowChunkSize, + column: col_chunk * this._loadingParams.colChunkSize, + columnSpan: this._loadingParams.colChunkSize, }); } @@ -149,15 +154,19 @@ export class ArrowModel extends DataModel { private async fetchSchema() { const table = await fetchTable({ - path: this._path, + path: this._loadingParams.path, row_chunk_size: 0, row_chunk: 0, + ...this._fileOptions, }); return table.schema; } private chunkIdx(row: number, col: number): [number, number] { - return [Math.floor(row / this._rowChunkSize), Math.floor(col / this._colChunkSize)]; + return [ + Math.floor(row / this._loadingParams.rowChunkSize), + Math.floor(col / this._loadingParams.colChunkSize), + ]; } private chunkIsValid(chunk_idx: [number, number]): boolean { @@ -168,10 +177,8 @@ export class ArrowModel extends DataModel { ); } - private _path: string; - private _rowChunkSize: number; - private _colChunkSize: number; - private _loadingRepr: string; + private readonly _loadingParams: Required; + private readonly _fileOptions: FileOptions; private _numRows: number = 0; private _numCols: number = 0; diff --git a/src/requests.ts b/src/requests.ts index d3f4cc9..b15ec0d 100644 --- a/src/requests.ts +++ b/src/requests.ts @@ -1,44 +1,71 @@ import { tableFromIPC } from "apache-arrow"; import type * as Arrow from "apache-arrow"; -export interface StatsParams { - readonly path: string; +import type { FileOptions } from "./file_options"; + +export interface StatsOptions { + path: string; } export interface StatsResponse { - readonly num_rows: number; - readonly num_cols: number; + num_rows: number; + num_cols: number; } -export async function fetchStats(params: StatsParams): Promise { - const response = await fetch(`/arrow/stats/${params.path}`); +/** + * Transform a union into a union where every member is optionally present. + */ +type OptionalizeUnion = { + [K in T extends unknown ? keyof T : never]?: T extends Record ? V : never; +}; + +export async function fetchStats( + params: Readonly, +): Promise { + const queryKeys = ["path", "delimiter"] as const; + + const query = new URLSearchParams(); + + for (const key of queryKeys) { + const value = (params as Readonly & OptionalizeUnion)[key]; + if (value !== undefined && value != null) { + query.set(key, value.toString()); + } + } + + const response = await fetch(`/arrow/stats/${params.path}?${query.toString()}`); const data = await response.json(); return data; } -export interface TableParams { - readonly path: string; - readonly row_chunk_size?: number; - readonly row_chunk?: number; - readonly col_chunk_size?: number; - readonly col_chunk?: number; +export interface TableOptions { + path: string; + row_chunk_size?: number; + row_chunk?: number; + col_chunk_size?: number; + col_chunk?: number; } -export async function fetchTable(params: TableParams): Promise { - const query: string[] = []; - if (params.row_chunk_size !== undefined) { - query.push(`row_chunk_size=${encodeURIComponent(params.row_chunk_size)}`); - } - if (params.row_chunk !== undefined) { - query.push(`row_chunk=${encodeURIComponent(params.row_chunk)}`); - } - if (params.col_chunk_size !== undefined) { - query.push(`col_chunk_size=${encodeURIComponent(params.col_chunk_size)}`); - } - if (params.col_chunk !== undefined) { - query.push(`col_chunk=${encodeURIComponent(params.col_chunk)}`); +export async function fetchTable( + params: Readonly, +): Promise { + const queryKeys = [ + "row_chunk_size", + "row_chunk", + "col_chunk_size", + "col_chunk", + "delimiter", + ] as const; + + const query = new URLSearchParams(); + + for (const key of queryKeys) { + const value = (params as Readonly & OptionalizeUnion)[key]; + if (value !== undefined && value != null) { + query.set(key, value.toString()); + } } - const queryString = query.length ? `?${query.join("&")}` : ""; - const url = `/arrow/stream/${params.path}${queryString}`; + + const url = `/arrow/stream/${params.path}?${query.toString()}`; return await tableFromIPC(fetch(url)); } diff --git a/src/toolbar.ts b/src/toolbar.ts new file mode 100644 index 0000000..c04e3b8 --- /dev/null +++ b/src/toolbar.ts @@ -0,0 +1,107 @@ +// Copyright (c) Jupyter Development Team. +// Distributed under the terms of the Modified BSD License. + +import { nullTranslator } from "@jupyterlab/translation"; +import { Styling } from "@jupyterlab/ui-components"; +import { Widget } from "@lumino/widgets"; +import type { ITranslator } from "@jupyterlab/translation"; +import type { Message } from "@lumino/messaging"; + +import type { CsvOptions } from "./file_options"; +import type { ArrowGridViewer } from "./widget"; + +export namespace CsvToolbar { + export interface Options { + gridViewer: ArrowGridViewer; + translator?: ITranslator; + } +} + +export class CsvToolbar extends Widget { + constructor(options: CsvToolbar.Options, fileOptions: Required) { + super({ + node: Private.createDelimiterNode(fileOptions.delimiter, options.translator), + }); + this._gridViewer = options.gridViewer; + this.addClass("arrow-viewer-toolbar"); + } + + get fileOptions(): CsvOptions { + return { + delimiter: this.delimiterNode.value, + }; + } + + get delimiterNode(): HTMLSelectElement { + return this.node.getElementsByTagName("select")![0]; + } + + /** + * Handle the DOM events for the widget. + * + * @param event - The DOM event sent to the widget. + * + * #### Notes + * This method implements the DOM `EventListener` interface and is + * called in response to events on the dock panel's node. It should + * not be called directly by user code. + */ + handleEvent(event: Event): void { + switch (event.type) { + case "change": + this._gridViewer.updateFileOptions(this.fileOptions); + break; + default: + break; + } + } + + protected onAfterAttach(_msg: Message): void { + this.delimiterNode.addEventListener("change", this); + } + + protected onBeforeDetach(_msg: Message): void { + this.delimiterNode.removeEventListener("change", this); + } + + protected _gridViewer: ArrowGridViewer; +} + +namespace Private { + /** + * Create the node for the delimiter switcher. + */ + export function createDelimiterNode(selected: string, translator?: ITranslator): HTMLElement { + translator = translator || nullTranslator; + const trans = translator?.load("jupyterlab"); + + // The supported parsing delimiters and labels. + const delimiters = [ + [",", ","], + [";", ";"], + ["\\t", trans.__("tab")], + ["|", trans.__("pipe")], + ["#", trans.__("hash")], + ]; + + const div = document.createElement("div"); + const label = document.createElement("span"); + const select = document.createElement("select"); + label.textContent = trans.__("Delimiter: "); + label.className = "toolbar-label"; + for (const [delimiter, label] of delimiters) { + const option = document.createElement("option"); + option.value = delimiter; + option.textContent = label; + if (delimiter === selected) { + option.selected = true; + } + select.appendChild(option); + } + div.appendChild(label); + const node = Styling.wrapSelect(select); + node.classList.add("toolbar-dropdown"); + div.appendChild(node); + return div; + } +} diff --git a/src/widget.ts b/src/widget.ts index 9f76e33..e9986a4 100644 --- a/src/widget.ts +++ b/src/widget.ts @@ -6,18 +6,23 @@ import { Panel } from "@lumino/widgets"; import type { DocumentRegistry, IDocumentWidget } from "@jupyterlab/docregistry"; import type * as DataGridModule from "@lumino/datagrid"; +import { DEFAULT_CSV_OPTIONS } from "./file_options"; +import { FileType } from "./filetypes"; import { ArrowModel } from "./model"; +import { CsvToolbar } from "./toolbar"; +import type { FileOptions } from "./file_options"; export namespace ArrowGridViewer { - export interface IOptions { + export interface Options { path: string; } } export class ArrowGridViewer extends Panel { - constructor(options: ArrowGridViewer.IOptions) { + constructor(options: ArrowGridViewer.Options, fileOptions: FileOptions) { super(); this._options = options; + this._fileOptions = fileOptions; this.addClass("arrow-viewer"); @@ -49,12 +54,29 @@ export class ArrowGridViewer extends Panel { return this._options.path; } + get fileOptions(): Readonly { + return this._fileOptions; + } + + set fileOptions(fileOptions: FileOptions) { + this._fileOptions = fileOptions; + this._updateGrid(); + } + + updateFileOptions(fileOptionsUpdate: Partial) { + this.fileOptions = { + ...this.fileOptions, + ...fileOptionsUpdate, + }; + } + /** * The style used by the data grid. */ get style(): DataGridModule.DataGrid.Style { return this._grid.style; } + set style(value: DataGridModule.DataGrid.Style) { this._grid.style = { ...this._defaultStyle, ...value }; } @@ -75,7 +97,7 @@ export class ArrowGridViewer extends Panel { private async _updateGrid() { try { - const model = new ArrowModel({ path: this.path }); + const model = new ArrowModel({ path: this.path }, this.fileOptions); await model.ready; this._grid.dataModel = model; } catch (error) { @@ -115,7 +137,8 @@ export class ArrowGridViewer extends Panel { }); } - private _options: ArrowGridViewer.IOptions; + private _options: ArrowGridViewer.Options; + private _fileOptions: FileOptions; private _grid: DataGridModule.DataGrid; private _revealed = new PromiseDelegate(); private _ready: Promise; @@ -130,9 +153,9 @@ export namespace ArrowGridDocumentWidget { } export class ArrowGridDocumentWidget extends DocumentWidget { - constructor(options: ArrowGridDocumentWidget.IOptions) { + constructor(options: ArrowGridDocumentWidget.IOptions, fileOptions: FileOptions) { let { content, context, reveal, ...other } = options; - content = content || ArrowGridDocumentWidget._createContent(context); + content = content || ArrowGridDocumentWidget._createContent(context, fileOptions); reveal = Promise.all([reveal, content.revealed, context.ready]); super({ content, context, reveal, ...other }); this.addClass("arrow-viewer-base"); @@ -140,16 +163,82 @@ export class ArrowGridDocumentWidget extends DocumentWidget { private static _createContent( context: DocumentRegistry.IContext, + fileOptions: FileOptions, ): ArrowGridViewer { - return new ArrowGridViewer({ path: context.path }); + return new ArrowGridViewer({ path: context.path }, fileOptions); } } export class ArrowGridViewerFactory extends ABCWidgetFactory> { + constructor( + options: DocumentRegistry.IWidgetFactoryOptions>, + docRegistry: DocumentRegistry, + ) { + super(options); + this._docRegistry = docRegistry; + } + protected createNewWidget(context: DocumentRegistry.Context): IDocumentWidget { const translator = this.translator; - return new ArrowGridDocumentWidget({ context, translator }); + const ft = this.fileType(context.path); + + let fileOption: FileOptions = {}; + if (ft?.name === FileType.Csv) { + fileOption = DEFAULT_CSV_OPTIONS; + } + const widget = new ArrowGridDocumentWidget({ context, translator }, fileOption); + this.updateIcon(widget); + return widget; + } + + /** + * Default factory for toolbar items to be added after the widget is created. + */ + protected defaultToolbarFactory( + widget: IDocumentWidget, + ): DocumentRegistry.IToolbarItem[] { + const ft = this.fileType(widget.context.path); + if (ft?.name === FileType.Csv) { + return [ + { + name: "arbalister:csv-toolbar", + widget: new CsvToolbar( + { + gridViewer: widget.content, + translator: this.translator, + }, + DEFAULT_CSV_OPTIONS, + ), + }, + ]; + } + return []; } + + updateIcon(widget: IDocumentWidget) { + const ft = this.fileType(widget.context.path); + if (ft !== undefined) { + widget.title.icon = ft.icon; + if (ft.iconClass) { + widget.title.iconClass = ft.iconClass; + } + if (ft.iconLabel) { + widget.title.iconLabel = ft.iconLabel; + } + } + } + + private fileType(path: string): DocumentRegistry.IFileType | undefined { + const fileTypes = this._docRegistry + .getFileTypesForPath(path) + .filter((ft) => Object.values(FileType).includes(ft.name as FileType)); + if (fileTypes.length === 1) { + return fileTypes[0]; + } + return undefined; + } + + private _docRegistry: DocumentRegistry; } export interface ITextRenderConfig { diff --git a/style/base.css b/style/base.css index 7ff470c..83823bb 100644 --- a/style/base.css +++ b/style/base.css @@ -11,4 +11,39 @@ .arrow-grid-viewer { flex: 1 1 auto; } + + .arrow-viewer-toolbar { + display: flex; + flex: 0 0 auto; + flex-direction: row; + border: none; + min-height: 24px; + background: var(--jp-toolbar-background); + z-index: 1; + + .toolbar-label { + color: var(--jp-ui-font-color1); + font-size: var(--jp-ui-font-size1); + padding-left: 8px; + padding-right: 8px; + } + + .toolbar-dropdown { + flex: 0 0 auto; + vertical-align: middle; + border-radius: 0; + outline: none; + height: 20px; + margin-top: 2px; + margin-bottom: 2px; + + select.jp-mod-styled { + color: var(--jp-ui-font-color1); + background: var(--jp-layout-color1); + font-size: var(--jp-ui-font-size1); + height: 20px; + padding-right: 20px; + } + } + } }