diff --git a/app/src/components/canvas/CanvasModeRadioGroup.tsx b/app/src/components/canvas/CanvasModeRadioGroup.tsx new file mode 100644 index 0000000000..cead29f7ac --- /dev/null +++ b/app/src/components/canvas/CanvasModeRadioGroup.tsx @@ -0,0 +1,43 @@ +import { Icon, Icons, Radio, RadioGroup } from "@arizeai/components"; +import React from "react"; + +export enum CanvasMode { + move = "move", + select = "select", +} + +/** + * TypeGuard for the canvas mode + */ +function isCanvasMode(m: unknown): m is CanvasMode { + return typeof m === "string" && m in CanvasMode; +} + +type CanvasModeRadioGroupProps = { + mode: CanvasMode; + onChange: (mode: CanvasMode) => void; +}; + +export function CanvasModeRadioGroup(props: CanvasModeRadioGroupProps) { + return ( + { + if (isCanvasMode(v)) { + props.onChange(v); + } else { + throw new Error(`Unknown canvas mode: ${v}`); + } + }} + > + + } /> + + + } /> + + + ); +} diff --git a/app/src/components/canvas/ColoringStrategyPicker.tsx b/app/src/components/canvas/ColoringStrategyPicker.tsx new file mode 100644 index 0000000000..1c4a52ea6b --- /dev/null +++ b/app/src/components/canvas/ColoringStrategyPicker.tsx @@ -0,0 +1,33 @@ +import React from "react"; +import { Picker, Item } from "@arizeai/components"; +import { ColoringStrategy } from "./types"; + +function isColoringStrategy(strategy: unknown): strategy is ColoringStrategy { + return typeof strategy === "string" && strategy in ColoringStrategy; +} + +const ColoringStrategies = Object.values(ColoringStrategy); + +type ColoringStrategyPickerProps = { + strategy: ColoringStrategy; + onChange: (strategy: ColoringStrategy) => void; +}; +export function ColoringStrategyPicker(props: ColoringStrategyPickerProps) { + const { strategy, onChange } = props; + return ( + { + if (isColoringStrategy(key)) { + onChange(key); + } + }} + > + {ColoringStrategies.map((item) => ( + {item} + ))} + + ); +} diff --git a/app/src/components/canvas/PointCloud.tsx b/app/src/components/canvas/PointCloud.tsx index 28d593ff79..ebd125b19f 100644 --- a/app/src/components/canvas/PointCloud.tsx +++ b/app/src/components/canvas/PointCloud.tsx @@ -1,65 +1,46 @@ -/* eslint-disable no-unused-vars */ -import React, { ReactNode, useCallback, useMemo, useState } from "react"; +import React, { useMemo, useState } from "react"; import { ThreeDimensionalCanvas, ThreeDimensionalControls, - Points, getThreeDimensionalBounds, - ThreeDimensionalPoint, ThreeDimensionalBounds, LassoSelect, - PointBaseProps, + ColorSchemes, } from "@arizeai/point-cloud"; import { ErrorBoundary } from "../ErrorBoundary"; -import { - Accordion, - AccordionItem, - Form, - Icon, - Icons, - Item, - Picker, - Radio, - RadioGroup, - theme, -} from "@arizeai/components"; +import { theme } from "@arizeai/components"; import { css } from "@emotion/react"; -import { shade } from "polished"; import { ControlPanel } from "./ControlPanel"; - -const DIM_AMOUNT = 0.5; - -export type ThreeDimensionalPointItem = { - position: ThreeDimensionalPoint; - metaData: unknown; -}; +import { ColoringStrategyPicker } from "./ColoringStrategyPicker"; +import { CanvasMode, CanvasModeRadioGroup } from "./CanvasModeRadioGroup"; +import { PointCloudPoints } from "./PointCloudPoints"; +import { ThreeDimensionalPointItem } from "./types"; +import { ColoringStrategy } from "./types"; +import { createColorFn } from "./coloring"; export type PointCloudProps = { primaryData: ThreeDimensionalPointItem[]; - referenceData?: ThreeDimensionalPointItem[]; + referenceData: ThreeDimensionalPointItem[] | null; }; -enum CanvasMode { - move = "move", - select = "select", -} - -/** - * TypeGuard for the canvas mode - */ -function isCanvasMode(m: unknown): m is CanvasMode { - return typeof m === "string" && m in CanvasMode; -} - const CONTROL_PANEL_WIDTH = 300; +const DEFAULT_COLOR_SCHEME = ColorSchemes.Discrete2.WhiteLightBlue; /** * Displays the tools available on the point cloud * E.g. move vs select */ function CanvasTools(props: { - mode: CanvasMode; - onChange: (mode: CanvasMode) => void; + coloringStrategy: ColoringStrategy; + onColoringStrategyChange: (strategy: ColoringStrategy) => void; + canvasMode: CanvasMode; + onCanvasModeChange: (mode: CanvasMode) => void; }) { + const { + coloringStrategy, + onColoringStrategyChange, + canvasMode, + onCanvasModeChange, + } = props; return (
- { - if (isCanvasMode(v)) { - props.onChange(v); - } else { - throw new Error(`Unknown canvas mode: ${v}`); - } - }} - > - - } /> - - - } /> - - + +
); } -function AccordionSection({ children }: { children: ReactNode }) { - return ( -
- {children} -
- ); -} - -// eslint-disable-next-line @typescript-eslint/no-unused-vars -function DisplayControlPanel() { - return ( - - - - -
- - Dataset - -
-
-
-
-
- ); -} - // eslint-disable-next-line @typescript-eslint/no-unused-vars function SelectionControlPanel({ selectedIds }: { selectedIds: Set }) { return ( @@ -137,64 +76,11 @@ function SelectionControlPanel({ selectedIds }: { selectedIds: Set }) { ); } -function UMAPPoints({ - primaryData, - referenceData, - selectedIds, -}: PointCloudProps & { selectedIds: Set }) { - const primaryColor = "#7BFFFF"; - const referenceColor = "#d57bff"; - /** Colors to represent a dimmed variant of the color for "un-selected" */ - const dimmedPrimaryColor = useMemo(() => { - // if (typeof primaryColor === "function") { - // return (p: PointBaseProps) => shade(DIM_AMOUNT)(primaryColor(p)); - // } - return shade(DIM_AMOUNT, primaryColor); - }, [primaryColor]); - - const dimmedReferenceColor = useMemo(() => { - // if (typeof referenceColor === "function") { - // return (p: PointBaseProps) => shade(DIM_AMOUNT)(referenceColor(p)); - // } - return shade(DIM_AMOUNT, referenceColor); - }, [referenceColor]); - - const primaryColorByFn = useCallback( - (p: PointBaseProps) => { - if (!selectedIds.has(p.metaData.id) && selectedIds.size > 0) { - return dimmedPrimaryColor; - } - return primaryColor; - }, - [selectedIds, primaryColor, dimmedPrimaryColor] - ); - - const referenceColorByFn = useCallback( - (p: PointBaseProps) => { - if (!selectedIds.has(p.metaData.id) && selectedIds.size > 0) { - return dimmedReferenceColor; - } - return referenceColor; - }, - [referenceColor, selectedIds, dimmedReferenceColor] - ); - - return ( - <> - - {referenceData && ( - - )} - - ); -} - export function PointCloud({ primaryData, referenceData }: PointCloudProps) { // AutoRotate the canvas on initial load const [autoRotate, setAutoRotate] = useState(true); + const [coloringStrategy, onColoringStrategyChange] = + useState(ColoringStrategy.dataset); const [canvasMode, setCanvasMode] = useState(CanvasMode.move); const [selectedIds, setSelectedIds] = useState>(new Set()); const allPoints = useMemo(() => { @@ -204,11 +90,25 @@ export function PointCloud({ primaryData, referenceData }: PointCloudProps) { return getThreeDimensionalBounds(allPoints.map((p) => p.position)); }, []); const isMoveMode = canvasMode === CanvasMode.move; + + // Determine the color of the points based on the strategy + const primaryColor = createColorFn({ + coloringStrategy, + defaultColor: DEFAULT_COLOR_SCHEME[0], + }); + const referenceColor = createColorFn({ + coloringStrategy, + defaultColor: DEFAULT_COLOR_SCHEME[1], + }); return ( - {/* */} {/* */} - + - diff --git a/app/src/components/canvas/PointCloudPoints.tsx b/app/src/components/canvas/PointCloudPoints.tsx new file mode 100644 index 0000000000..f5336240ab --- /dev/null +++ b/app/src/components/canvas/PointCloudPoints.tsx @@ -0,0 +1,87 @@ +import React from "react"; +import { shade } from "polished"; +import { Points, PointBaseProps } from "@arizeai/point-cloud"; +import { useCallback, useMemo } from "react"; +import { PointColor, ThreeDimensionalPointItem } from "./types"; + +const DIM_AMOUNT = 0.5; + +/** + * Invokes the color function if it is a function, otherwise returns the color + * @param point + * @param color + * @returns {string} colorString + */ +function invokeColor(point: PointBaseProps, color: PointColor) { + if (typeof color === "function") { + return color(point); + } + return color; +} +type PointCloudPointsProps = { + /** + * The primary data to display in the point cloud + */ + primaryData: ThreeDimensionalPointItem[]; + /** + * Optional second set of data to display in the point cloud + */ + referenceData: ThreeDimensionalPointItem[] | null; + primaryColor: PointColor; + referenceColor: PointColor; + selectedIds: Set; +}; +export function PointCloudPoints({ + primaryData, + referenceData, + selectedIds, + primaryColor, + referenceColor, +}: PointCloudPointsProps) { + /** Colors to represent a dimmed variant of the color for "un-selected" */ + const dimmedPrimaryColor = useMemo(() => { + if (typeof primaryColor === "function") { + return (p: PointBaseProps) => shade(DIM_AMOUNT)(primaryColor(p)); + } + return shade(DIM_AMOUNT, primaryColor); + }, [primaryColor]); + + const dimmedReferenceColor = useMemo(() => { + if (typeof referenceColor === "function") { + return (p: PointBaseProps) => shade(DIM_AMOUNT)(referenceColor(p)); + } + return shade(DIM_AMOUNT, referenceColor); + }, [referenceColor]); + + const primaryColorByFn = useCallback( + (point: PointBaseProps) => { + if (!selectedIds.has(point.metaData.id) && selectedIds.size > 0) { + return invokeColor(point, dimmedPrimaryColor); + } + return invokeColor(point, primaryColor); + }, + [selectedIds, primaryColor, dimmedPrimaryColor] + ); + + const referenceColorByFn = useCallback( + (point: PointBaseProps) => { + if (!selectedIds.has(point.metaData.id) && selectedIds.size > 0) { + return invokeColor(point, dimmedReferenceColor); + } + return invokeColor(point, referenceColor); + }, + [referenceColor, selectedIds, dimmedReferenceColor] + ); + + return ( + <> + + {referenceData && ( + + )} + + ); +} diff --git a/app/src/components/canvas/coloring.ts b/app/src/components/canvas/coloring.ts new file mode 100644 index 0000000000..eeff6d8bb8 --- /dev/null +++ b/app/src/components/canvas/coloring.ts @@ -0,0 +1,23 @@ +import { ColoringStrategy, PointColor } from "./types"; + +type ColoringConfig = { + coloringStrategy: ColoringStrategy; + defaultColor: string; +}; + +/** + * A curried function that generates a color function based on the given config. + * @param {ColoringConfig} config + * @returns {ColorFn} + */ +export const createColorFn = + (config: ColoringConfig): PointColor => + (_point) => { + const { coloringStrategy, defaultColor } = config; + switch (coloringStrategy) { + case ColoringStrategy.dataset: + return defaultColor; + case ColoringStrategy.correctness: + return "green"; + } + }; diff --git a/app/src/components/canvas/index.tsx b/app/src/components/canvas/index.tsx index 48d19d32b4..53fb5bfbd6 100644 --- a/app/src/components/canvas/index.tsx +++ b/app/src/components/canvas/index.tsx @@ -1 +1,2 @@ export * from "./PointCloud"; +export * from "./types"; diff --git a/app/src/components/canvas/types.ts b/app/src/components/canvas/types.ts new file mode 100644 index 0000000000..82717f9ca4 --- /dev/null +++ b/app/src/components/canvas/types.ts @@ -0,0 +1,13 @@ +import { PointsProps, ThreeDimensionalPoint } from "@arizeai/point-cloud"; + +export type ThreeDimensionalPointItem = { + position: ThreeDimensionalPoint; + metaData: unknown; +}; + +export enum ColoringStrategy { + dataset = "dataset", + correctness = "correctness", +} + +export type PointColor = PointsProps["pointProps"]["color"]; diff --git a/app/src/pages/Embedding.tsx b/app/src/pages/Embedding.tsx index 0759508d40..225b420d2c 100644 --- a/app/src/pages/Embedding.tsx +++ b/app/src/pages/Embedding.tsx @@ -34,6 +34,34 @@ const EmbeddingUMAPQuery = graphql` y z } + ... on Point2D { + x + y + } + } + embeddingMetadata { + linkToData + rawData + } + eventMetadata { + predictionLabel + actualLabel + predictionScore + actualScore + } + } + referenceData { + coordinates { + __typename + ... on Point3D { + x + y + z + } + ... on Point2D { + x + y + } } embeddingMetadata { linkToData @@ -55,6 +83,7 @@ const EmbeddingUMAPQuery = graphql` } } `; + export function Embedding() { const embeddingDimensionId = useEmbeddingDimensionId(); const [queryReference, loadQuery] = @@ -114,6 +143,7 @@ function umapDataEntryToThreeDimensionalPointItem( metaData: {}, }; } + /** * Fetches the umap data for the embedding dimension and passes the data to the point cloud */ @@ -127,12 +157,21 @@ const PointCloudDisplay = ({ queryReference ); - const primaryData = - data.embedding?.UMAPPoints?.data?.map( - umapDataEntryToThreeDimensionalPointItem - ) ?? []; + const sourceData = data.embedding?.UMAPPoints?.data ?? []; + const referenceSourceData = data.embedding?.UMAPPoints?.referenceData; - return ; + return ( + + ); }; export async function embeddingLoader(args: LoaderFunctionArgs) { diff --git a/app/src/pages/__generated__/EmbeddingUMAPQuery.graphql.ts b/app/src/pages/__generated__/EmbeddingUMAPQuery.graphql.ts index 966f117094..9e6d8ec05b 100644 --- a/app/src/pages/__generated__/EmbeddingUMAPQuery.graphql.ts +++ b/app/src/pages/__generated__/EmbeddingUMAPQuery.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<<802b773fafd62d25ba7fe368c0f7b299>> + * @generated SignedSource<<872ae35d7c0e3b32b7782657f7614a42>> * @lightSyntaxTransform * @nogrep */ @@ -26,6 +26,36 @@ export type EmbeddingUMAPQuery$data = { }>; readonly data: ReadonlyArray<{ readonly coordinates: { + readonly __typename: "Point2D"; + readonly x: number; + readonly y: number; + } | { + readonly __typename: "Point3D"; + readonly x: number; + readonly y: number; + readonly z: number; + } | { + // This will never be '%other', but we need some + // value in case none of the concrete values match. + readonly __typename: "%other"; + }; + readonly embeddingMetadata: { + readonly linkToData: string | null; + readonly rawData: string | null; + }; + readonly eventMetadata: { + readonly actualLabel: string | null; + readonly actualScore: number | null; + readonly predictionLabel: string | null; + readonly predictionScore: number | null; + }; + }>; + readonly referenceData: ReadonlyArray<{ + readonly coordinates: { + readonly __typename: "Point2D"; + readonly x: number; + readonly y: number; + } | { readonly __typename: "Point3D"; readonly x: number; readonly y: number; @@ -89,6 +119,20 @@ v3 = { "storageKey": null }, v4 = { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "x", + "storageKey": null +}, +v5 = { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "y", + "storageKey": null +}, +v6 = { "alias": null, "args": null, "concreteType": null, @@ -100,20 +144,8 @@ v4 = { { "kind": "InlineFragment", "selections": [ - { - "alias": null, - "args": null, - "kind": "ScalarField", - "name": "x", - "storageKey": null - }, - { - "alias": null, - "args": null, - "kind": "ScalarField", - "name": "y", - "storageKey": null - }, + (v4/*: any*/), + (v5/*: any*/), { "alias": null, "args": null, @@ -124,11 +156,20 @@ v4 = { ], "type": "Point3D", "abstractKey": null + }, + { + "kind": "InlineFragment", + "selections": [ + (v4/*: any*/), + (v5/*: any*/) + ], + "type": "Point2D", + "abstractKey": null } ], "storageKey": null }, -v5 = { +v7 = { "alias": null, "args": null, "concreteType": "EmbeddingMetadata", @@ -153,7 +194,7 @@ v5 = { ], "storageKey": null }, -v6 = { +v8 = { "alias": null, "args": null, "concreteType": "EventMetadata", @@ -192,14 +233,19 @@ v6 = { ], "storageKey": null }, -v7 = { +v9 = [ + (v6/*: any*/), + (v7/*: any*/), + (v8/*: any*/) +], +v10 = { "alias": null, "args": null, "kind": "ScalarField", "name": "id", "storageKey": null }, -v8 = { +v11 = { "alias": null, "args": null, "concreteType": "Cluster", @@ -207,7 +253,7 @@ v8 = { "name": "clusters", "plural": true, "selections": [ - (v7/*: any*/), + (v10/*: any*/), { "alias": null, "args": null, @@ -217,7 +263,13 @@ v8 = { } ], "storageKey": null -}; +}, +v12 = [ + (v6/*: any*/), + (v7/*: any*/), + (v8/*: any*/), + (v10/*: any*/) +]; return { "fragment": { "argumentDefinitions": (v0/*: any*/), @@ -251,14 +303,20 @@ return { "kind": "LinkedField", "name": "data", "plural": true, - "selections": [ - (v4/*: any*/), - (v5/*: any*/), - (v6/*: any*/) - ], + "selections": (v9/*: any*/), + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "UMAPPoint", + "kind": "LinkedField", + "name": "referenceData", + "plural": true, + "selections": (v9/*: any*/), "storageKey": null }, - (v8/*: any*/) + (v11/*: any*/) ], "storageKey": null } @@ -306,15 +364,20 @@ return { "kind": "LinkedField", "name": "data", "plural": true, - "selections": [ - (v4/*: any*/), - (v5/*: any*/), - (v6/*: any*/), - (v7/*: any*/) - ], + "selections": (v12/*: any*/), + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "UMAPPoint", + "kind": "LinkedField", + "name": "referenceData", + "plural": true, + "selections": (v12/*: any*/), "storageKey": null }, - (v8/*: any*/) + (v11/*: any*/) ], "storageKey": null } @@ -326,23 +389,23 @@ return { "kind": "TypeDiscriminator", "abstractKey": "__isNode" }, - (v7/*: any*/) + (v10/*: any*/) ], "storageKey": null } ] }, "params": { - "cacheID": "b55c61f5395270230ee21b6d270f7ea5", + "cacheID": "01483fd42cd9c8074f0f1904871c19b5", "id": null, "metadata": {}, "name": "EmbeddingUMAPQuery", "operationKind": "query", - "text": "query EmbeddingUMAPQuery(\n $id: GlobalID!\n $timeRange: TimeRange!\n) {\n embedding: node(id: $id) {\n __typename\n ... on EmbeddingDimension {\n UMAPPoints(timeRange: $timeRange) {\n data {\n coordinates {\n __typename\n ... on Point3D {\n x\n y\n z\n }\n }\n embeddingMetadata {\n linkToData\n rawData\n }\n eventMetadata {\n predictionLabel\n actualLabel\n predictionScore\n actualScore\n }\n id\n }\n clusters {\n id\n pointIds\n }\n }\n }\n __isNode: __typename\n id\n }\n}\n" + "text": "query EmbeddingUMAPQuery(\n $id: GlobalID!\n $timeRange: TimeRange!\n) {\n embedding: node(id: $id) {\n __typename\n ... on EmbeddingDimension {\n UMAPPoints(timeRange: $timeRange) {\n data {\n coordinates {\n __typename\n ... on Point3D {\n x\n y\n z\n }\n ... on Point2D {\n x\n y\n }\n }\n embeddingMetadata {\n linkToData\n rawData\n }\n eventMetadata {\n predictionLabel\n actualLabel\n predictionScore\n actualScore\n }\n id\n }\n referenceData {\n coordinates {\n __typename\n ... on Point3D {\n x\n y\n z\n }\n ... on Point2D {\n x\n y\n }\n }\n embeddingMetadata {\n linkToData\n rawData\n }\n eventMetadata {\n predictionLabel\n actualLabel\n predictionScore\n actualScore\n }\n id\n }\n clusters {\n id\n pointIds\n }\n }\n }\n __isNode: __typename\n id\n }\n}\n" } }; })(); -(node as any).hash = "5570eb0dc63ef2fdb5f36a1099faf424"; +(node as any).hash = "f807ff579322f11dc73d2131be4b9091"; export default node;