diff --git a/.changeset/odd-bees-pay.md b/.changeset/odd-bees-pay.md new file mode 100644 index 0000000000000..44673cebb20e1 --- /dev/null +++ b/.changeset/odd-bees-pay.md @@ -0,0 +1,8 @@ +--- +"@gradio/atoms": minor +"@gradio/core": minor +"@gradio/dataframe": minor +"gradio": minor +--- + +feat:Allow sorting by multiple columns in dataframe diff --git a/js/atoms/src/IconButton.svelte b/js/atoms/src/IconButton.svelte index 7769eec5056e3..53a1581dfba0d 100644 --- a/js/atoms/src/IconButton.svelte +++ b/js/atoms/src/IconButton.svelte @@ -4,7 +4,7 @@ export let label = ""; export let show_label = false; export let pending = false; - export let size: "small" | "large" | "medium" = "small"; + export let size: "x-small" | "small" | "large" | "medium" = "small"; export let padded = true; export let highlight = false; export let disabled = false; @@ -30,6 +30,7 @@ > {#if show_label}{label}{/if}
setTimeout(resolve, 500)); }} /> + + { + const canvas = within(canvasElement); + const user = userEvent.setup(); + + const header_1 = canvas.getAllByText("A")[1]; + await userEvent.click(header_1); + + const cell_menu_button = canvas.getAllByLabelText("Open cell menu")[0]; + await userEvent.click(cell_menu_button); + + const sort_ascending_button = canvas.getByRole("button", { + name: "Sort ascending" + }); + await userEvent.click(sort_ascending_button); + + const header_2 = canvas.getAllByText("B")[1]; + await userEvent.click(header_2); + + const cell_menu_button_2 = canvas.getAllByLabelText("Open cell menu")[1]; + await userEvent.click(cell_menu_button_2); + + const sort_descending_button = canvas.getByRole("button", { + name: "Sort descending" + }); + await userEvent.click(sort_descending_button); + + const header_3 = canvas.getAllByText("C")[1]; + await userEvent.click(header_3); + + const cell_menu_button_3 = canvas.getAllByLabelText("Open cell menu")[2]; + await userEvent.click(cell_menu_button_3); + + const sort_ascending_button_3 = canvas.getByRole("button", { + name: "Sort ascending" + }); + await userEvent.click(sort_ascending_button_3); + + await userEvent.click(header_3); + await userEvent.click(cell_menu_button_3); + await userEvent.click(canvas.getByText("Clear sort")); + }} +/> diff --git a/js/dataframe/shared/CellMenu.svelte b/js/dataframe/shared/CellMenu.svelte index 0a7b3b0707c2e..e1fc3ab869810 100644 --- a/js/dataframe/shared/CellMenu.svelte +++ b/js/dataframe/shared/CellMenu.svelte @@ -2,6 +2,7 @@ import { onMount } from "svelte"; import CellMenuIcons from "./CellMenuIcons.svelte"; import type { I18nFormatter } from "js/utils/src"; + import type { SortDirection } from "./context/table_context"; export let x: number; export let y: number; @@ -16,6 +17,10 @@ export let on_delete_col: () => void; export let can_delete_rows: boolean; export let can_delete_cols: boolean; + export let on_sort: (direction: SortDirection) => void = () => {}; + export let on_clear_sort: () => void = () => {}; + export let sort_direction: SortDirection | null = null; + export let sort_priority: number | null = null; export let i18n: I18nFormatter; let menu_element: HTMLDivElement; @@ -52,6 +57,33 @@
+ {#if is_header} + + + + {/if} + {#if !is_header && can_add_rows} - + {/if} + { + event.stopPropagation(); + dispatch("sort", "asc"); + }} + > + { + event.stopPropagation(); + dispatch("sort", "desc"); + }} + >
diff --git a/js/dataframe/shared/utils/sort_utils.test.ts b/js/dataframe/shared/utils/sort_utils.test.ts index f89000f7584b2..8612646218192 100644 --- a/js/dataframe/shared/utils/sort_utils.test.ts +++ b/js/dataframe/shared/utils/sort_utils.test.ts @@ -1,24 +1,89 @@ import { describe, test, expect } from "vitest"; -import { get_sort_status, sort_data } from "./sort_utils"; +import { get_sort_status, sort_data, SortDirection } from "./sort_utils"; describe("sort_utils", () => { describe("get_sort_status", () => { const headers = ["A", "B", "C"]; test("returns none when no sort is active", () => { - expect(get_sort_status("A", undefined, undefined, headers)).toBe("none"); + expect(get_sort_status("A", [], headers)).toBe("none"); }); test("returns ascending when column is sorted ascending", () => { - expect(get_sort_status("A", 0, "asc", headers)).toBe("asc"); + expect( + get_sort_status( + "A", + [{ col: 0, direction: "asc" as SortDirection }], + headers + ) + ).toBe("asc"); }); test("returns descending when column is sorted descending", () => { - expect(get_sort_status("B", 1, "desc", headers)).toBe("desc"); + expect( + get_sort_status( + "B", + [{ col: 1, direction: "desc" as SortDirection }], + headers + ) + ).toBe("desc"); }); test("returns none for non-matching column", () => { - expect(get_sort_status("A", 1, "asc", headers)).toBe("none"); + expect( + get_sort_status( + "A", + [{ col: 1, direction: "asc" as SortDirection }], + headers + ) + ).toBe("none"); + }); + + test("handles multiple sort columns", () => { + const sort_columns = [ + { col: 0, direction: "asc" as SortDirection }, + { col: 1, direction: "desc" as SortDirection } + ]; + expect(get_sort_status("A", sort_columns, headers)).toBe("asc"); + expect(get_sort_status("B", sort_columns, headers)).toBe("desc"); + expect(get_sort_status("C", sort_columns, headers)).toBe("none"); + }); + + test("handles invalid column indices", () => { + expect( + get_sort_status( + "A", + [{ col: -1, direction: "asc" as SortDirection }], + headers + ) + ).toBe("none"); + + expect( + get_sort_status( + "A", + [{ col: 999, direction: "asc" as SortDirection }], + headers + ) + ).toBe("none"); + }); + + test("handles empty headers", () => { + expect( + get_sort_status( + "A", + [{ col: 0, direction: "asc" as SortDirection }], + [] + ) + ).toBe("none"); + }); + + test("prioritizes first matching column in sort_columns", () => { + const sort_columns = [ + { col: 0, direction: "asc" as SortDirection }, + { col: 0, direction: "desc" as SortDirection } + ]; + + expect(get_sort_status("A", sort_columns, headers)).toBe("asc"); }); }); @@ -39,33 +104,216 @@ describe("sort_utils", () => { ]; test("sorts strings ascending", () => { - const indices = sort_data(data, 0, "asc"); + const indices = sort_data(data, [ + { col: 0, direction: "asc" as SortDirection } + ]); expect(indices).toEqual([1, 0, 2]); // A, B, C }); test("sorts numbers ascending", () => { - const indices = sort_data(data, 1, "asc"); - expect(indices).toEqual([1, 0, 2]); // 1, 2, 3 + const indices = sort_data(data, [ + { col: 1, direction: "asc" as SortDirection } + ]); + expect(indices).toEqual([1, 0, 2]); }); test("sorts strings descending", () => { - const indices = sort_data(data, 0, "desc"); - expect(indices).toEqual([2, 0, 1]); // C, B, A + const indices = sort_data(data, [ + { col: 0, direction: "desc" as SortDirection } + ]); + expect(indices).toEqual([2, 0, 1]); }); - test("returns original order when sort params are invalid", () => { - const indices = sort_data(data, undefined, undefined); + test("returns original order when sort params are empty", () => { + const indices = sort_data(data, []); expect(indices).toEqual([0, 1, 2]); }); test("handles empty data", () => { - const indices = sort_data([], 0, "asc"); + const indices = sort_data( + [], + [{ col: 0, direction: "asc" as SortDirection }] + ); expect(indices).toEqual([]); }); test("handles invalid column index", () => { - const indices = sort_data(data, 999, "asc"); + const indices = sort_data(data, [ + { col: 999, direction: "asc" as SortDirection } + ]); + expect(indices).toEqual([0, 1, 2]); + }); + + test("sorts by multiple columns", () => { + const test_data = [ + [ + { id: "1", value: "A" }, + { id: "2", value: 2 } + ], + [ + { id: "3", value: "A" }, + { id: "4", value: 1 } + ], + [ + { id: "5", value: "B" }, + { id: "6", value: 3 } + ] + ]; + + const indices = sort_data(test_data, [ + { col: 0, direction: "asc" as SortDirection }, + { col: 1, direction: "asc" as SortDirection } + ]); + + expect(indices).toEqual([1, 0, 2]); + }); + + test("respects sort direction for each column", () => { + const test_data = [ + [ + { id: "1", value: "A" }, + { id: "2", value: 2 } + ], + [ + { id: "3", value: "A" }, + { id: "4", value: 1 } + ], + [ + { id: "5", value: "B" }, + { id: "6", value: 3 } + ] + ]; + + const indices = sort_data(test_data, [ + { col: 0, direction: "asc" as SortDirection }, + { col: 1, direction: "desc" as SortDirection } + ]); + + expect(indices).toEqual([0, 1, 2]); + }); + + test("handles mixed data types in sort columns", () => { + const mixed_data = [ + [ + { id: "1", value: "A" }, + { id: "2", value: 2 } + ], + [ + { id: "3", value: "A" }, + { id: "4", value: 1 } + ], + [ + { id: "5", value: "B" }, + { id: "6", value: 2 } + ] + ]; + + const indices = sort_data(mixed_data, [ + { col: 0, direction: "asc" as SortDirection }, + { col: 1, direction: "asc" as SortDirection } + ]); + + expect(indices).toEqual([1, 0, 2]); + }); + + test("handles more than two sort columns", () => { + const complex_data = [ + [ + { id: "1", value: "A" }, + { id: "2", value: 1 }, + { id: "3", value: "X" } + ], + [ + { id: "4", value: "A" }, + { id: "5", value: 1 }, + { id: "6", value: "Y" } + ], + [ + { id: "7", value: "B" }, + { id: "8", value: 2 }, + { id: "9", value: "Z" } + ] + ]; + + const indices = sort_data(complex_data, [ + { col: 0, direction: "asc" as SortDirection }, + { col: 1, direction: "asc" as SortDirection }, + { col: 2, direction: "asc" as SortDirection } + ]); + + expect(indices).toEqual([0, 1, 2]); + }); + + test("ignores invalid sort columns", () => { + const indices = sort_data(data, [ + { col: -1, direction: "asc" as SortDirection }, + { col: 0, direction: "asc" as SortDirection } + ]); + + expect(indices).toEqual([1, 0, 2]); + }); + + test("maintains original order when all values are equal", () => { + const equal_data = [ + [ + { id: "1", value: "A" }, + { id: "2", value: 1 } + ], + [ + { id: "3", value: "A" }, + { id: "4", value: 1 } + ], + [ + { id: "5", value: "A" }, + { id: "6", value: 1 } + ] + ]; + + const indices = sort_data(equal_data, [ + { col: 0, direction: "asc" as SortDirection }, + { col: 1, direction: "asc" as SortDirection } + ]); + expect(indices).toEqual([0, 1, 2]); }); + + test("handles undefined values in data rows", () => { + const data_with_undefined = [ + [ + { id: "1", value: "A" }, + { id: "2", value: "" } + ], + [ + { id: "3", value: "B" }, + { id: "4", value: 2 } + ] + ]; + + const indices = sort_data(data_with_undefined, [ + { col: 0, direction: "asc" as SortDirection }, + { col: 1, direction: "asc" as SortDirection } + ]); + + expect(indices).toEqual([0, 1]); + }); + + test("handles missing values in data", () => { + const data_with_missing = [ + [ + { id: "1", value: "" }, + { id: "2", value: 1 } + ], + [ + { id: "3", value: "A" }, + { id: "4", value: 2 } + ] + ]; + + const indices = sort_data(data_with_missing, [ + { col: 0, direction: "asc" as SortDirection } + ]); + + expect(indices).toEqual([0, 1]); + }); }); }); diff --git a/js/dataframe/shared/utils/sort_utils.ts b/js/dataframe/shared/utils/sort_utils.ts index bbbaeddfd0f47..20f48aa5e7759 100644 --- a/js/dataframe/shared/utils/sort_utils.ts +++ b/js/dataframe/shared/utils/sort_utils.ts @@ -4,50 +4,58 @@ export type SortDirection = "asc" | "desc"; export function get_sort_status( name: string, - sort_by: number | undefined, - direction: SortDirection | undefined, + sort_columns: { col: number; direction: SortDirection }[], headers: Headers ): "none" | "asc" | "desc" { - if (typeof sort_by !== "number") return "none"; - if (sort_by < 0 || sort_by >= headers.length) return "none"; - if (headers[sort_by] === name) { - if (direction === "asc") return "asc"; - if (direction === "desc") return "desc"; - } - return "none"; + if (!sort_columns.length) return "none"; + + const sort_item = sort_columns.find((item) => { + const col = item.col; + if (col < 0 || col >= headers.length) return false; + return headers[col] === name; + }); + + if (!sort_item) return "none"; + return sort_item.direction; } export function sort_data( data: { id: string; value: string | number }[][], - sort_by: number | undefined, - sort_direction: SortDirection | undefined + sort_columns: { col: number; direction: SortDirection }[] ): number[] { if (!data || !data.length || !data[0]) { return []; } - if ( - typeof sort_by === "number" && - sort_direction && - sort_by >= 0 && - sort_by < data[0].length - ) { + if (sort_columns.length > 0) { const row_indices = [...Array(data.length)].map((_, i) => i); row_indices.sort((row_a_idx, row_b_idx) => { const row_a = data[row_a_idx]; const row_b = data[row_b_idx]; - if ( - !row_a || - !row_b || - sort_by >= row_a.length || - sort_by >= row_b.length - ) - return 0; - - const val_a = row_a[sort_by].value; - const val_b = row_b[sort_by].value; - const comparison = val_a < val_b ? -1 : val_a > val_b ? 1 : 0; - return sort_direction === "asc" ? comparison : -comparison; + + for (const { col: sort_by, direction } of sort_columns) { + if ( + !row_a || + !row_b || + sort_by < 0 || + sort_by >= row_a.length || + sort_by >= row_b.length || + !row_a[sort_by] || + !row_b[sort_by] + ) { + continue; + } + + const val_a = row_a[sort_by].value; + const val_b = row_b[sort_by].value; + const comparison = val_a < val_b ? -1 : val_a > val_b ? 1 : 0; + + if (comparison !== 0) { + return direction === "asc" ? comparison : -comparison; + } + } + + return 0; }); return row_indices; } diff --git a/js/dataframe/shared/utils/table_utils.ts b/js/dataframe/shared/utils/table_utils.ts index 3e3d8f4f24c60..89625b1d48f90 100644 --- a/js/dataframe/shared/utils/table_utils.ts +++ b/js/dataframe/shared/utils/table_utils.ts @@ -28,12 +28,12 @@ export function sort_table_data( data: TableData, display_value: string[][] | null, styling: string[][] | null, - col: number | undefined, - dir: SortDirection | undefined + sort_columns: { col: number; direction: SortDirection }[] ): void { - if (col === undefined || dir === undefined) return; + if (!sort_columns.length) return; + if (!data || !data.length) return; - const indices = sort_data(data, col, dir); + const indices = sort_data(data, sort_columns); const new_data = indices.map((i: number) => data[i]); data.splice(0, data.length, ...new_data);