diff --git a/frontend/src/components/data-table/TableActions.tsx b/frontend/src/components/data-table/TableActions.tsx index 4cbaf99b0cd..583487d4d34 100644 --- a/frontend/src/components/data-table/TableActions.tsx +++ b/frontend/src/components/data-table/TableActions.tsx @@ -2,10 +2,16 @@ import React from "react"; import { Tooltip } from "../ui/tooltip"; import { Button } from "../ui/button"; -import { SearchIcon } from "lucide-react"; +import { PaletteIcon, SearchIcon, Settings } from "lucide-react"; import { DataTablePagination } from "./pagination"; import { DownloadAs, type DownloadActionProps } from "./download-actions"; import type { Table, RowSelectionState } from "@tanstack/react-table"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from "../ui/dropdown-menu"; interface TableActionsProps { enableSearch: boolean; @@ -44,7 +50,7 @@ export const TableActions = ({ )} - {pagination ? ( + {pagination && ( ({ } table={table} /> - ) : ( -
)} - {downloadAs && } +
+ {downloadAs && } + {table.toggleGlobalHeatmap && ( + + + + + + table.toggleGlobalHeatmap?.()}> + + {table.getGlobalHeatmap?.() ? "Disable" : "Enable"} heatmap + + + + )} +
); }; diff --git a/frontend/src/components/data-table/__test__/cell-heatmap.test.ts b/frontend/src/components/data-table/__test__/cell-heatmap.test.ts new file mode 100644 index 00000000000..eb2a7879e95 --- /dev/null +++ b/frontend/src/components/data-table/__test__/cell-heatmap.test.ts @@ -0,0 +1,99 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +import { describe, it, expect, vi } from "vitest"; +import { CellHeatmapFeature } from "../cell-heatmap/feature"; +import { + type Column, + createTable, + getCoreRowModel, +} from "@tanstack/react-table"; + +describe("CellHeatmapFeature", () => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let state: any = { + columnHeatmap: { + value: true, + }, + cachedMaxValue: null, + cachedMinValue: null, + }; + + const mockTable = createTable({ + _features: [CellHeatmapFeature], + data: [ + { id: 1, value: 10 }, + { id: 2, value: 20 }, + { id: 3, value: 30 }, + ], + state: state, + onStateChange: (updater) => { + state = typeof updater === "function" ? updater(state) : updater; + }, + columns: [ + { id: "id", accessorKey: "id" }, + { id: "value", accessorKey: "value" }, + ], + getCoreRowModel: getCoreRowModel(), + renderFallbackValue: null, + }); + + it("should initialize with correct default state", () => { + const initialState = CellHeatmapFeature.getInitialState?.(); + expect(initialState).toEqual({ + columnHeatmap: {}, + cachedMaxValue: null, + cachedMinValue: null, + }); + }); + + it("should provide default options", () => { + const options = CellHeatmapFeature.getDefaultOptions?.(mockTable) || {}; + expect(options.enableCellHeatmap).toBe(true); + expect(typeof options.onColumnHeatmapChange).toBe("function"); + }); + + it("should add getGlobalHeatmap and toggleGlobalHeatmap methods to table", () => { + CellHeatmapFeature.createTable?.(mockTable); + expect(typeof mockTable.getGlobalHeatmap).toBe("function"); + expect(typeof mockTable.toggleGlobalHeatmap).toBe("function"); + }); + + it("should add methods to column", () => { + const mockColumn = { id: "test" } as Column; + CellHeatmapFeature.createColumn?.(mockColumn, mockTable); + expect(typeof mockColumn.getCellHeatmapColor).toBe("function"); + expect(typeof mockColumn.toggleColumnHeatmap).toBe("function"); + expect(typeof mockColumn.getIsColumnHeatmapEnabled).toBe("function"); + }); + + it("should calculate correct heatmap color", () => { + const mockColumn = { id: "value" } as Column; + CellHeatmapFeature.createColumn?.(mockColumn, mockTable); + mockTable.setState((prev) => { + prev.columnHeatmap = { value: true }; + return prev; + }); + + const color = mockColumn.getCellHeatmapColor?.(20); + expect(color).toMatch(/^hsla\(\d+(?:,\s*\d+%){2},\s*0\.6\)$/); + }); + + it("should toggle column heatmap", () => { + const mockColumn = { id: "value" } as Column; + const onColumnHeatmapChange = vi.fn(); + mockTable.options.onColumnHeatmapChange = onColumnHeatmapChange; + + CellHeatmapFeature.createColumn?.(mockColumn, mockTable); + mockColumn.toggleColumnHeatmap?.(); + + expect(onColumnHeatmapChange).toHaveBeenCalled(); + }); + + it("should handle global heatmap toggle", () => { + CellHeatmapFeature.createTable?.(mockTable); + const onColumnHeatmapChange = vi.fn(); + mockTable.options.onColumnHeatmapChange = onColumnHeatmapChange; + + mockTable.toggleGlobalHeatmap?.(); + expect(onColumnHeatmapChange).toHaveBeenCalled(); + }); +}); diff --git a/frontend/src/components/data-table/cell-heatmap/feature.ts b/frontend/src/components/data-table/cell-heatmap/feature.ts new file mode 100644 index 00000000000..ccb211a3ec7 --- /dev/null +++ b/frontend/src/components/data-table/cell-heatmap/feature.ts @@ -0,0 +1,202 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +import { + type TableFeature, + type RowData, + type Table, + type Column, + type Updater, + makeStateUpdater, +} from "@tanstack/react-table"; +import type { CellHeatmapTableState, CellHeatmapOptions } from "./types"; + +export const CellHeatmapFeature: TableFeature = { + getInitialState: (state): CellHeatmapTableState => { + return { + columnHeatmap: {}, + cachedMaxValue: null, + cachedMinValue: null, + ...state, + }; + }, + + getDefaultOptions: ( + table: Table, + ): CellHeatmapOptions => { + return { + enableCellHeatmap: true, + onColumnHeatmapChange: makeStateUpdater("columnHeatmap", table), + }; + }, + + createTable: (table: Table) => { + table.getGlobalHeatmap = () => { + return Object.values(table.getState().columnHeatmap).some(Boolean); + }; + + table.toggleGlobalHeatmap = () => { + const allColumns = table.getAllColumns(); + const { columnHeatmap } = table.getState(); + const hasAnyEnabled = Object.values(columnHeatmap).some(Boolean); + + if (hasAnyEnabled) { + // Disable all columns + table.options.onColumnHeatmapChange?.( + Object.fromEntries(allColumns.map((column) => [column.id, false])), + ); + } else { + // Enable all columns + table.options.onColumnHeatmapChange?.( + Object.fromEntries(allColumns.map((column) => [column.id, true])), + ); + } + + table.setState((old) => ({ + ...old, + cachedMinValue: null, + cachedMaxValue: null, + })); + }; + }, + + createColumn: ( + column: Column, + table: Table, + ) => { + // Clear min/max cache when a column is added or removed + table.setState((old) => ({ + ...old, + cachedMinValue: null, + cachedMaxValue: null, + })); + + column.getCellHeatmapColor = (cellValue: unknown) => { + const state = table.getState(); + const isColumnHeatmapEnabled = Boolean(state.columnHeatmap[column.id]); + if (!isColumnHeatmapEnabled || typeof cellValue !== "number") { + return ""; + } + + // Get all numeric values from enabled columns + let minValue = state.cachedMinValue; + let maxValue = state.cachedMaxValue; + + if (minValue == null || maxValue == null) { + const { min, max } = getMaxMinValue(table); + minValue = min; + maxValue = max; + table.setState((old) => ({ + ...old, + cachedMinValue: minValue, + cachedMaxValue: maxValue, + })); + } + + const isDarkMode = + typeof window !== "undefined" && + "matchMedia" in window && + typeof window.matchMedia === "function" && + window.matchMedia("(prefers-color-scheme: dark)").matches; + + const colorStops = isDarkMode + ? [ + { hue: 210, saturation: 100, lightness: 30 }, // Darker Blue + { hue: 199, saturation: 95, lightness: 33 }, // Darker Cyan + { hue: 172, saturation: 66, lightness: 30 }, // Darker Teal + { hue: 158, saturation: 64, lightness: 32 }, // Darker Green + { hue: 142, saturation: 71, lightness: 25 }, // Darker Lime + { hue: 47, saturation: 96, lightness: 33 }, // Darker Yellow + { hue: 21, saturation: 90, lightness: 28 }, // Darker Orange + { hue: 0, saturation: 84, lightness: 40 }, // Darker Red + ] + : [ + { hue: 210, saturation: 100, lightness: 50 }, // Blue-500 + { hue: 199, saturation: 95, lightness: 53 }, // Cyan-500 + { hue: 172, saturation: 66, lightness: 50 }, // Teal-500 + { hue: 158, saturation: 64, lightness: 52 }, // Green-500 + { hue: 142, saturation: 71, lightness: 45 }, // Lime-600 + { hue: 47, saturation: 96, lightness: 53 }, // Yellow-400 + { hue: 21, saturation: 90, lightness: 48 }, // Orange-500 + { hue: 0, saturation: 84, lightness: 60 }, // Red-500 + ]; + + // Normalize the cellValue + const normalized = (cellValue - minValue) / (maxValue - minValue); + + const index = Math.min( + Math.floor(normalized * (colorStops.length - 1)), + colorStops.length - 2, + ); + const t = normalized * (colorStops.length - 1) - index; + + const c1 = colorStops[index]; + const c2 = colorStops[index + 1]; + + if (!c1 || !c2) { + return ""; + } + + const hue = Math.round(c1.hue + t * (c2.hue - c1.hue)); + const saturation = Math.round( + c1.saturation + t * (c2.saturation - c1.saturation), + ); + const lightness = Math.round( + c1.lightness + t * (c2.lightness - c1.lightness), + ); + + return `hsla(${hue}, ${saturation}%, ${lightness}%, 0.6)`; + }; + + column.toggleColumnHeatmap = (value?: boolean) => { + const safeUpdater: Updater = ( + old, + ) => { + const prevValue = old[column.id]; + if (value !== undefined) { + return { + ...old, + [column.id]: value, + }; + } + + return { + ...old, + [column.id]: !prevValue, + }; + }; + + table.options.onColumnHeatmapChange?.(safeUpdater); + table.setState((old) => ({ + ...old, + cachedMinValue: null, + cachedMaxValue: null, + })); + }; + + column.getIsColumnHeatmapEnabled = () => { + return table.getState().columnHeatmap[column.id] || false; + }; + }, +}; + +function getMaxMinValue(table: Table) { + const { columnHeatmap } = table.getState(); + const enabledColumnsSet = new Set( + Object.keys(columnHeatmap).filter((key) => columnHeatmap[key]), + ); + const values: number[] = []; + for (const row of table.getRowModel().rows) { + for (const column of table.getAllColumns()) { + if (enabledColumnsSet.has(column.id)) { + const cellValue = row.getValue(column.id); + if (typeof cellValue === "number" && !Number.isNaN(cellValue)) { + values.push(cellValue); + } + } + } + } + + return { + min: Math.min(...values), + max: Math.max(...values), + }; +} diff --git a/frontend/src/components/data-table/cell-heatmap/types.ts b/frontend/src/components/data-table/cell-heatmap/types.ts new file mode 100644 index 00000000000..1349c5b1d2f --- /dev/null +++ b/frontend/src/components/data-table/cell-heatmap/types.ts @@ -0,0 +1,40 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +/* eslint-disable @typescript-eslint/no-empty-interface */ +import type { RowData, Updater } from "@tanstack/react-table"; + +export interface CellHeatmapTableState { + columnHeatmap: Record; + cachedMinValue?: number | null; + cachedMaxValue?: number | null; +} + +export interface CellHeatmapOptions { + enableCellHeatmap?: boolean; + onGlobalHeatmapChange?: (updater: Updater) => void; + onColumnHeatmapChange?: (updater: Updater>) => void; +} + +export interface CellHeatmapState { + global: boolean; + columns: Record; +} + +// Use declaration merging to add our new feature APIs +declare module "@tanstack/react-table" { + interface TableState extends CellHeatmapTableState {} + interface InitialTableState extends CellHeatmapTableState {} + + interface TableOptionsResolved + extends CellHeatmapOptions {} + + interface Table { + getGlobalHeatmap?: () => boolean; + toggleGlobalHeatmap?: (value?: boolean) => void; + } + + interface Column { + getCellHeatmapColor?: (cellValue: unknown) => string; + toggleColumnHeatmap?: (value?: boolean) => void; + getIsColumnHeatmapEnabled?: () => boolean; + } +} diff --git a/frontend/src/components/data-table/column-header.tsx b/frontend/src/components/data-table/column-header.tsx index daf481cb8fa..536b9ba8afc 100644 --- a/frontend/src/components/data-table/column-header.tsx +++ b/frontend/src/components/data-table/column-header.tsx @@ -12,6 +12,7 @@ import { WrapTextIcon, AlignJustifyIcon, PinOffIcon, + PaletteIcon, } from "lucide-react"; import { cn } from "@/utils/cn"; @@ -27,17 +28,17 @@ import { DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; import { Button } from "../ui/button"; -import { useRef, useState } from "react"; import { NumberField } from "../ui/number-field"; import { Input } from "../ui/input"; import { type ColumnFilterForType, Filter } from "./filters"; -import { logNever } from "@/utils/assertNever"; import type { DataType } from "@/core/kernel/messages"; import { formatOptions } from "./column-formatting/types"; import { DATA_TYPE_ICON } from "../datasets/icons"; import { formattingExample } from "./column-formatting/feature"; import { PinLeftIcon, PinRightIcon } from "@radix-ui/react-icons"; import { NAMELESS_COLUMN_PREFIX } from "./columns"; +import { logNever } from "@/utils/assertNever"; +import { useState, useRef } from "react"; interface DataTableColumnHeaderProps extends React.HTMLAttributes { @@ -205,6 +206,40 @@ export const DataTableColumnHeader = ({ ); }; + const renderColorOptions = () => { + if (!column.getIsColumnHeatmapEnabled || !column.toggleColumnHeatmap) { + return null; + } + + // Get type, should be a number + let dataType = column.columnDef.meta?.dataType; + + // HACK: If no dataType, check if column is numeric manually + if (!dataType) { + // Check if column is numeric + // Check if first or last column is numeric + const rows = column.getFacetedRowModel().flatRows; + const first = rows[0]?.getValue(column.id); + const last = rows[rows.length - 1]?.getValue(column.id); + if (typeof first === "number" || typeof last === "number") { + dataType = "number"; + } + } + + if (dataType !== "integer" && dataType !== "number") { + return null; + } + + const enabled = column.getIsColumnHeatmapEnabled(); + + return ( + column.toggleColumnHeatmap?.()}> + + {enabled ? "Disable heatmap" : "Enable heatmap"} + + ); + }; + return ( @@ -258,6 +293,7 @@ export const DataTableColumnHeader = ({ {renderColumnPinning()} {renderColumnWrapping()} {renderFormatOptions()} + {renderColorOptions()} diff --git a/frontend/src/components/data-table/data-table.tsx b/frontend/src/components/data-table/data-table.tsx index 1c0953c6c23..6284e92c41d 100644 --- a/frontend/src/components/data-table/data-table.tsx +++ b/frontend/src/components/data-table/data-table.tsx @@ -25,6 +25,7 @@ import { SearchBar } from "./SearchBar"; import { TableActions } from "./TableActions"; import { ColumnFormattingFeature } from "./column-formatting/feature"; import { ColumnWrappingFeature } from "./column-wrapping/feature"; +import { CellHeatmapFeature } from "./cell-heatmap/feature"; interface DataTableProps extends Partial { wrapperClassName?: string; @@ -56,6 +57,8 @@ interface DataTableProps extends Partial { // Columns freezeColumnsLeft?: string[]; freezeColumnsRight?: string[]; + // Heatmap + initialHeatmap?: boolean | string[]; } const DataTableInternal = ({ @@ -84,8 +87,10 @@ const DataTableInternal = ({ reloading, freezeColumnsLeft, freezeColumnsRight, + initialHeatmap = false, }: DataTableProps) => { const [isSearchEnabled, setIsSearchEnabled] = React.useState(false); + // const [cellHeatmap, setCellHeatmap] = React.useState(false); const { columnPinning, setColumnPinning } = useColumnPinning( freezeColumnsLeft, @@ -93,7 +98,12 @@ const DataTableInternal = ({ ); const table = useReactTable({ - _features: [ColumnPinning, ColumnWrappingFeature, ColumnFormattingFeature], + _features: [ + ColumnPinning, + ColumnWrappingFeature, + ColumnFormattingFeature, + CellHeatmapFeature, + ], data, columns, getCoreRowModel: getCoreRowModel(), @@ -127,11 +137,22 @@ const DataTableInternal = ({ onColumnFiltersChange: onFiltersChange, // selection onRowSelectionChange: onRowSelectionChange, + initialState: { + columnHeatmap: + initialHeatmap === true + ? // Initialize heatmap state for all columns + Object.fromEntries(columns.map((column) => [column.id, true])) + : Array.isArray(initialHeatmap) + ? Object.fromEntries( + initialHeatmap.map((columnId) => [columnId, true]), + ) + : {}, + }, state: { ...(sorting ? { sorting } : {}), columnFilters: filters, - ...// Controlled state - (paginationState + // Controlled state + ...(paginationState ? { pagination: paginationState } : // Uncontrolled state pagination && !paginationState diff --git a/frontend/src/components/data-table/download-actions.tsx b/frontend/src/components/data-table/download-actions.tsx index ac9e692da46..7433f82183f 100644 --- a/frontend/src/components/data-table/download-actions.tsx +++ b/frontend/src/components/data-table/download-actions.tsx @@ -9,15 +9,19 @@ import { } from "../ui/dropdown-menu"; import { toast } from "../ui/use-toast"; import { downloadByURL } from "@/utils/download"; -import { ChevronDownIcon } from "lucide-react"; +import { + ChevronDownIcon, + FileJsonIcon, + FileSpreadsheetIcon, +} from "lucide-react"; export interface DownloadActionProps { downloadAs: (req: { format: "csv" | "json" }) => Promise; } const options = [ - { label: "CSV", format: "csv" }, - { label: "JSON", format: "json" }, + { label: "CSV", format: "csv", Icon: FileSpreadsheetIcon }, + { label: "JSON", format: "json", Icon: FileJsonIcon }, ] as const; export const DownloadAs: React.FC = (props) => { @@ -52,6 +56,7 @@ export const DownloadAs: React.FC = (props) => { downloadByURL(downloadUrl, "download"); }} > + {option.label} ))} diff --git a/frontend/src/components/data-table/renderers.tsx b/frontend/src/components/data-table/renderers.tsx index c889b3efff4..ed3f99b0d00 100644 --- a/frontend/src/components/data-table/renderers.tsx +++ b/frontend/src/components/data-table/renderers.tsx @@ -64,6 +64,7 @@ export function renderTableBody( const renderCells = (row: Row, cells: Array>) => { return cells.map((cell) => { const { className, style } = getPinningStyles(cell.column); + const heatmapColor = cell.column.getCellHeatmapColor?.(cell.getValue()); return ( ( "whitespace-pre-wrap min-w-[200px]", className, )} - style={style} + style={{ + ...style, + backgroundColor: heatmapColor, + }} title={String(cell.getValue())} > {flexRender(cell.column.columnDef.cell, cell.getContext())} diff --git a/frontend/src/plugins/impl/DataTablePlugin.tsx b/frontend/src/plugins/impl/DataTablePlugin.tsx index 2dc844e4748..ada45032d89 100644 --- a/frontend/src/plugins/impl/DataTablePlugin.tsx +++ b/frontend/src/plugins/impl/DataTablePlugin.tsx @@ -67,6 +67,7 @@ interface Data { fieldTypes?: FieldTypesWithExternalType | null; freezeColumnsLeft?: string[]; freezeColumnsRight?: string[]; + heatmap: boolean | string[]; } // eslint-disable-next-line @typescript-eslint/consistent-type-definitions @@ -121,6 +122,7 @@ export const DataTablePlugin = createPlugin("marimo-table") ]), ) .nullish(), + heatmap: z.union([z.boolean(), z.array(z.string())]).default(false), }), ) .withFunctions({ @@ -431,6 +433,7 @@ const DataTableComponent = ({ reloading, freezeColumnsLeft, freezeColumnsRight, + heatmap, }: DataTableProps & DataTableSearchProps & { data: unknown[]; @@ -529,6 +532,7 @@ const DataTableComponent = ({ onRowSelectionChange={handleRowSelectionChange} freezeColumnsLeft={freezeColumnsLeft} freezeColumnsRight={freezeColumnsRight} + initialHeatmap={heatmap} /> diff --git a/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx b/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx index a7a5798795d..874943ef10c 100644 --- a/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx +++ b/frontend/src/plugins/impl/data-frames/DataFramePlugin.tsx @@ -288,6 +288,7 @@ export const DataFrameComponent = memo( value={Arrays.EMPTY} setValue={Functions.NOOP} selection={null} + heatmap={false} /> ); diff --git a/marimo/_plugins/ui/_impl/table.py b/marimo/_plugins/ui/_impl/table.py index 7c7c993999f..0dfc06bf584 100644 --- a/marimo/_plugins/ui/_impl/table.py +++ b/marimo/_plugins/ui/_impl/table.py @@ -201,6 +201,7 @@ def format_name(name): or functions - `freeze_columns_left`: list of column names to freeze on the left - `freeze_columns_right`: list of column names to freeze on the right + - `heatmap`: whether to enable global heatmap for numeric columns - `label`: text label for the element - `on_change`: optional callback to run when this element's value changes """ @@ -226,6 +227,7 @@ def __init__( ] = None, freeze_columns_left: Optional[Sequence[str]] = None, freeze_columns_right: Optional[Sequence[str]] = None, + heatmap: Union[bool, List[str]] = False, *, label: str = "", on_change: Optional[ @@ -360,6 +362,7 @@ def __init__( "row-headers": self._manager.get_row_headers(), "freeze-columns-left": freeze_columns_left, "freeze-columns-right": freeze_columns_right, + "heatmap": heatmap, }, on_change=on_change, functions=( diff --git a/marimo/_smoke_tests/tables/heatmap.py b/marimo/_smoke_tests/tables/heatmap.py new file mode 100644 index 00000000000..d0eea033745 --- /dev/null +++ b/marimo/_smoke_tests/tables/heatmap.py @@ -0,0 +1,54 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "marimo", +# ] +# /// +# Copyright 2024 Marimo. All rights reserved. + +import marimo + +__generated_with = "0.8.15" +app = marimo.App(width="medium") + + +@app.cell +def __(): + import os + import marimo as mo + return mo, os + + +@app.cell +def __(mo): + _size = 10 + mo.ui.table([x for x in range(_size)], page_size=_size, selection=None) + return + + +@app.cell +def __(mo): + _size = 20 + mo.ui.table( + [{"one": x, "two": x * 3} for x in range(0, _size)], + page_size=_size, + heatmap=True, + selection=None, + ) + return + + +@app.cell +def __(mo): + _size = 20 + mo.ui.table( + [{"one": x, "two": x * 3} for x in range(0, _size)], + page_size=_size, + heatmap=["two"], + selection=None, + ) + return + + +if __name__ == "__main__": + app.run()