diff --git a/src/filetypes.ts b/src/filetypes.ts index 3716a91..b5c4ad5 100644 --- a/src/filetypes.ts +++ b/src/filetypes.ts @@ -1,5 +1,16 @@ +import { LabIcon } from "@jupyterlab/ui-components"; import type { DocumentRegistry } from "@jupyterlab/docregistry"; +import arrowIpcSvg from "../style/icons/arrow.svg"; +import arrowIpcDarkSvg from "../style/icons/arrow_dark.svg"; +import avroSvg from "../style/icons/avro.svg"; +import orcLightSvg from "../style/icons/orc.svg"; +import orcDarkSvg from "../style/icons/orc_dark.svg"; +import parquetSvgLight from "../style/icons/parquet.svg"; +import parquetSvgDark from "../style/icons/parquet_dark.svg"; +import sqliteSvgLight from "../style/icons/sqlite.svg"; +import sqliteSvgDark from "../style/icons/sqlite_dark.svg"; + export enum FileType { Avro = "apache-avro", Csv = "csv", @@ -9,17 +20,54 @@ export enum FileType { Sqlite = "sqlite", } -export function ensureCsvFileType(docRegistry: DocumentRegistry): DocumentRegistry.IFileType { - const ft = docRegistry.getFileType(FileType.Csv); - if (ft) { - return ft; +export namespace FileType { + export function all(): FileType[] { + return Object.values(FileType).filter((v): v is FileType => typeof v === "string"); + } +} + +function _getIconSvg(fileType: FileType, isLight: boolean): string { + switch (fileType) { + case FileType.Parquet: + return isLight ? parquetSvgLight : parquetSvgDark; + case FileType.Ipc: + return isLight ? arrowIpcSvg : arrowIpcDarkSvg; + case FileType.Orc: + return isLight ? orcLightSvg : orcDarkSvg; + case FileType.Avro: + return avroSvg; + case FileType.Sqlite: + return isLight ? sqliteSvgLight : sqliteSvgDark; + case FileType.Csv: + throw new Error(`CSV file type does not have an icon`); + default: + throw new Error(`Unknown file type: ${fileType}`); } +} + +function _makeIcon(fileType: FileType, isLight: boolean): LabIcon { + return new LabIcon({ + name: `arbalister:${fileType}`, + svgstr: _getIconSvg(fileType, isLight), + }); +} + +function _updateIcon(icon: LabIcon, fileType: FileType, isLight: boolean) { + icon.svgstr = _getIconSvg(fileType, isLight); +} + +export function addCsvFileType( + docRegistry: DocumentRegistry, + options: Partial = {}, +): DocumentRegistry.IFileType { docRegistry.addFileType({ + ...options, name: FileType.Csv, displayName: "CSV", mimeTypes: ["text/csv"], extensions: [".csv"], contentType: "file", + fileFormat: "text", }); return docRegistry.getFileType(FileType.Csv)!; } @@ -103,3 +151,45 @@ export function addSqliteFileType( }); return docRegistry.getFileType(FileType.Sqlite)!; } + +export function ensureFileType( + docRegistry: DocumentRegistry, + fileType: FileType, + isLight: boolean, +): DocumentRegistry.IFileType { + const ft = docRegistry.getFileType(fileType); + if (ft) { + return ft; + } + switch (fileType) { + case FileType.Avro: + return addAvroFileType(docRegistry, { icon: _makeIcon(FileType.Avro, isLight) }); + case FileType.Parquet: + return addParquetFileType(docRegistry, { icon: _makeIcon(FileType.Parquet, isLight) }); + case FileType.Ipc: + return addIpcFileType(docRegistry, { icon: _makeIcon(FileType.Ipc, isLight) }); + case FileType.Orc: + return addOrcFileType(docRegistry, { icon: _makeIcon(FileType.Orc, isLight) }); + case FileType.Sqlite: + return addSqliteFileType(docRegistry, { icon: _makeIcon(FileType.Sqlite, isLight) }); + case FileType.Csv: + return addCsvFileType(docRegistry); + default: + throw new Error(`Unknown file type: ${fileType}`); + } +} + +export function updateIcon( + docRegistry: DocumentRegistry, + fileType: FileType, + isLight: boolean, +): void { + const ft = docRegistry.getFileType(fileType); + // We most likely we did not set the Csv file type + if (ft?.name === FileType.Csv) { + return; + } + if (ft?.icon) { + _updateIcon(ft?.icon, fileType, isLight); + } +} diff --git a/src/index.ts b/src/index.ts index 3f7a0c2..9eac098 100644 --- a/src/index.ts +++ b/src/index.ts @@ -8,21 +8,7 @@ import type * as services from "@jupyterlab/services"; import type { Contents } from "@jupyterlab/services"; import type { DataGrid } from "@lumino/datagrid"; -import { - addAvroFileType, - addIpcFileType, - addOrcFileType, - addParquetFileType, - addSqliteFileType, - ensureCsvFileType, -} from "./filetypes"; -import { - getArrowIPCIcon, - getAvroIcon, - getORCIcon, - getParquetIcon, - getSqliteIcon, -} from "./labicons"; +import { ensureFileType, FileType, updateIcon } from "./filetypes"; import { ArrowGridViewerFactory } from "./widget"; import type { ArrowGridViewer, ITextRenderConfig } from "./widget"; @@ -106,19 +92,15 @@ function activateArrowGrid( isLight = currentTheme ? themeManager?.isLight(currentTheme as string) : true; } - const csv_ft = ensureCsvFileType(app.docRegistry); - let prq_ft = addParquetFileType(app.docRegistry, { icon: getParquetIcon(isLight) }); - let avo_ft = addAvroFileType(app.docRegistry, { icon: getAvroIcon(isLight) }); - let ipc_ft = addIpcFileType(app.docRegistry, { icon: getArrowIPCIcon(isLight) }); - let orc_ft = addOrcFileType(app.docRegistry, { icon: getORCIcon(isLight) }); - let sqlite_ft = addSqliteFileType(app.docRegistry, { icon: getSqliteIcon(isLight) }); + const fileTypes = FileType.all().map((ft) => ensureFileType(app.docRegistry, ft, isLight)); + const fileTypesNames = fileTypes.map((ft) => ft.name); const factory = new ArrowGridViewerFactory( { name: factory_arrow, label: trans.__("Arrow Dataframe Viewer"), - fileTypes: [csv_ft.name, avo_ft.name, prq_ft.name, ipc_ft.name, orc_ft.name, sqlite_ft.name], - defaultFor: [csv_ft.name, avo_ft.name, prq_ft.name, ipc_ft.name, orc_ft.name, sqlite_ft.name], + fileTypes: fileTypesNames, + defaultFor: fileTypesNames, readOnly: true, translator, contentProviderId: NOOP_CONTENT_PROVIDER_ID, @@ -173,12 +155,13 @@ function activateArrowGrid( widget.content.style = style; widget.content.rendererConfig = rendererConfig; }); - prq_ft = addParquetFileType(app.docRegistry, { icon: getParquetIcon(isLightNew) }); - avo_ft = addAvroFileType(app.docRegistry, { icon: getAvroIcon(isLightNew) }); - ipc_ft = addIpcFileType(app.docRegistry, { icon: getArrowIPCIcon(isLightNew) }); - orc_ft = addOrcFileType(app.docRegistry, { icon: getORCIcon(isLightNew) }); - sqlite_ft = addSqliteFileType(app.docRegistry, { icon: getSqliteIcon(isLightNew) }); + + // Update the file icons to match theme + FileType.all().forEach((ft) => { + updateIcon(app.docRegistry, ft, isLightNew); + }); }; + if (themeManager) { themeManager.themeChanged.connect((_, args) => { try { diff --git a/src/labicons.ts b/src/labicons.ts deleted file mode 100644 index 91a3b09..0000000 --- a/src/labicons.ts +++ /dev/null @@ -1,35 +0,0 @@ -import { LabIcon } from "@jupyterlab/ui-components"; - -import arrowIPCSvg from "../style/icons/arrow.svg"; -import arrowIPCDarkSvg from "../style/icons/arrow_dark.svg"; -import avroSvg from "../style/icons/avro.svg"; -import orcLightSvg from "../style/icons/orc.svg"; -import orcDarkSvg from "../style/icons/orc_dark.svg"; -import parquetSvgLight from "../style/icons/parquet.svg"; -import parquetSvgDark from "../style/icons/parquet_dark.svg"; -import sqliteSvgLight from "../style/icons/sqlite.svg"; -import sqliteSvgDark from "../style/icons/sqlite_dark.svg"; - -export const getLabIcon = (labIconName: string, iconSvg: string) => { - return new LabIcon({ - name: `arbalister:${labIconName}`, - svgstr: iconSvg, - }); -}; - -export const getParquetIcon = (isLight: boolean) => { - return getLabIcon("parquet", isLight ? parquetSvgLight : parquetSvgDark); -}; - -export const getArrowIPCIcon = (isLight: boolean) => { - return getLabIcon("arrowipc", isLight ? arrowIPCSvg : arrowIPCDarkSvg); -}; -export const getORCIcon = (isLight: boolean) => { - return getLabIcon("orc", isLight ? orcLightSvg : orcDarkSvg); -}; -export const getAvroIcon = (isLight: boolean) => { - return getLabIcon("avro", isLight ? avroSvg : avroSvg); -}; -export const getSqliteIcon = (isLight: boolean) => { - return getLabIcon("sqlite", isLight ? sqliteSvgLight : sqliteSvgDark); -}; diff --git a/src/widget.ts b/src/widget.ts index 8646b7c..fcbf6b7 100644 --- a/src/widget.ts +++ b/src/widget.ts @@ -244,10 +244,11 @@ export class ArrowGridViewerFactory extends ABCWidgetFactory Object.values(FileType).includes(ft.name as FileType)); - if (fileTypes.length === 1) { + .filter((ft) => knowFileTypes.includes(ft.name as FileType)); + if (fileTypes.length >= 1) { return fileTypes[0]; } return undefined;