diff --git a/app/src/components/canvas/CanvasModeRadioGroup.tsx b/app/src/components/canvas/CanvasModeRadioGroup.tsx
new file mode 100644
index 0000000000..cead29f7ac
--- /dev/null
+++ b/app/src/components/canvas/CanvasModeRadioGroup.tsx
@@ -0,0 +1,43 @@
+import { Icon, Icons, Radio, RadioGroup } from "@arizeai/components";
+import React from "react";
+
+export enum CanvasMode {
+ move = "move",
+ select = "select",
+}
+
+/**
+ * TypeGuard for the canvas mode
+ */
+function isCanvasMode(m: unknown): m is CanvasMode {
+ return typeof m === "string" && m in CanvasMode;
+}
+
+type CanvasModeRadioGroupProps = {
+ mode: CanvasMode;
+ onChange: (mode: CanvasMode) => void;
+};
+
+export function CanvasModeRadioGroup(props: CanvasModeRadioGroupProps) {
+ return (
+ {
+ if (isCanvasMode(v)) {
+ props.onChange(v);
+ } else {
+ throw new Error(`Unknown canvas mode: ${v}`);
+ }
+ }}
+ >
+
+ } />
+
+
+ } />
+
+
+ );
+}
diff --git a/app/src/components/canvas/ColoringStrategyPicker.tsx b/app/src/components/canvas/ColoringStrategyPicker.tsx
new file mode 100644
index 0000000000..1c4a52ea6b
--- /dev/null
+++ b/app/src/components/canvas/ColoringStrategyPicker.tsx
@@ -0,0 +1,33 @@
+import React from "react";
+import { Picker, Item } from "@arizeai/components";
+import { ColoringStrategy } from "./types";
+
+function isColoringStrategy(strategy: unknown): strategy is ColoringStrategy {
+ return typeof strategy === "string" && strategy in ColoringStrategy;
+}
+
+const ColoringStrategies = Object.values(ColoringStrategy);
+
+type ColoringStrategyPickerProps = {
+ strategy: ColoringStrategy;
+ onChange: (strategy: ColoringStrategy) => void;
+};
+export function ColoringStrategyPicker(props: ColoringStrategyPickerProps) {
+ const { strategy, onChange } = props;
+ return (
+ {
+ if (isColoringStrategy(key)) {
+ onChange(key);
+ }
+ }}
+ >
+ {ColoringStrategies.map((item) => (
+ - {item}
+ ))}
+
+ );
+}
diff --git a/app/src/components/canvas/PointCloud.tsx b/app/src/components/canvas/PointCloud.tsx
index 28d593ff79..ebd125b19f 100644
--- a/app/src/components/canvas/PointCloud.tsx
+++ b/app/src/components/canvas/PointCloud.tsx
@@ -1,65 +1,46 @@
-/* eslint-disable no-unused-vars */
-import React, { ReactNode, useCallback, useMemo, useState } from "react";
+import React, { useMemo, useState } from "react";
import {
ThreeDimensionalCanvas,
ThreeDimensionalControls,
- Points,
getThreeDimensionalBounds,
- ThreeDimensionalPoint,
ThreeDimensionalBounds,
LassoSelect,
- PointBaseProps,
+ ColorSchemes,
} from "@arizeai/point-cloud";
import { ErrorBoundary } from "../ErrorBoundary";
-import {
- Accordion,
- AccordionItem,
- Form,
- Icon,
- Icons,
- Item,
- Picker,
- Radio,
- RadioGroup,
- theme,
-} from "@arizeai/components";
+import { theme } from "@arizeai/components";
import { css } from "@emotion/react";
-import { shade } from "polished";
import { ControlPanel } from "./ControlPanel";
-
-const DIM_AMOUNT = 0.5;
-
-export type ThreeDimensionalPointItem = {
- position: ThreeDimensionalPoint;
- metaData: unknown;
-};
+import { ColoringStrategyPicker } from "./ColoringStrategyPicker";
+import { CanvasMode, CanvasModeRadioGroup } from "./CanvasModeRadioGroup";
+import { PointCloudPoints } from "./PointCloudPoints";
+import { ThreeDimensionalPointItem } from "./types";
+import { ColoringStrategy } from "./types";
+import { createColorFn } from "./coloring";
export type PointCloudProps = {
primaryData: ThreeDimensionalPointItem[];
- referenceData?: ThreeDimensionalPointItem[];
+ referenceData: ThreeDimensionalPointItem[] | null;
};
-enum CanvasMode {
- move = "move",
- select = "select",
-}
-
-/**
- * TypeGuard for the canvas mode
- */
-function isCanvasMode(m: unknown): m is CanvasMode {
- return typeof m === "string" && m in CanvasMode;
-}
-
const CONTROL_PANEL_WIDTH = 300;
+const DEFAULT_COLOR_SCHEME = ColorSchemes.Discrete2.WhiteLightBlue;
/**
* Displays the tools available on the point cloud
* E.g. move vs select
*/
function CanvasTools(props: {
- mode: CanvasMode;
- onChange: (mode: CanvasMode) => void;
+ coloringStrategy: ColoringStrategy;
+ onColoringStrategyChange: (strategy: ColoringStrategy) => void;
+ canvasMode: CanvasMode;
+ onCanvasModeChange: (mode: CanvasMode) => void;
}) {
+ const {
+ coloringStrategy,
+ onColoringStrategyChange,
+ canvasMode,
+ onCanvasModeChange,
+ } = props;
return (
- {
- if (isCanvasMode(v)) {
- props.onChange(v);
- } else {
- throw new Error(`Unknown canvas mode: ${v}`);
- }
- }}
- >
-
- } />
-
-
- } />
-
-
+
+
);
}
-function AccordionSection({ children }: { children: ReactNode }) {
- return (
-
- );
-}
-
-// eslint-disable-next-line @typescript-eslint/no-unused-vars
-function DisplayControlPanel() {
- return (
-
-
-
-
-
-
-
-
-
- );
-}
-
// eslint-disable-next-line @typescript-eslint/no-unused-vars
function SelectionControlPanel({ selectedIds }: { selectedIds: Set }) {
return (
@@ -137,64 +76,11 @@ function SelectionControlPanel({ selectedIds }: { selectedIds: Set }) {
);
}
-function UMAPPoints({
- primaryData,
- referenceData,
- selectedIds,
-}: PointCloudProps & { selectedIds: Set }) {
- const primaryColor = "#7BFFFF";
- const referenceColor = "#d57bff";
- /** Colors to represent a dimmed variant of the color for "un-selected" */
- const dimmedPrimaryColor = useMemo(() => {
- // if (typeof primaryColor === "function") {
- // return (p: PointBaseProps) => shade(DIM_AMOUNT)(primaryColor(p));
- // }
- return shade(DIM_AMOUNT, primaryColor);
- }, [primaryColor]);
-
- const dimmedReferenceColor = useMemo(() => {
- // if (typeof referenceColor === "function") {
- // return (p: PointBaseProps) => shade(DIM_AMOUNT)(referenceColor(p));
- // }
- return shade(DIM_AMOUNT, referenceColor);
- }, [referenceColor]);
-
- const primaryColorByFn = useCallback(
- (p: PointBaseProps) => {
- if (!selectedIds.has(p.metaData.id) && selectedIds.size > 0) {
- return dimmedPrimaryColor;
- }
- return primaryColor;
- },
- [selectedIds, primaryColor, dimmedPrimaryColor]
- );
-
- const referenceColorByFn = useCallback(
- (p: PointBaseProps) => {
- if (!selectedIds.has(p.metaData.id) && selectedIds.size > 0) {
- return dimmedReferenceColor;
- }
- return referenceColor;
- },
- [referenceColor, selectedIds, dimmedReferenceColor]
- );
-
- return (
- <>
-
- {referenceData && (
-
- )}
- >
- );
-}
-
export function PointCloud({ primaryData, referenceData }: PointCloudProps) {
// AutoRotate the canvas on initial load
const [autoRotate, setAutoRotate] = useState(true);
+ const [coloringStrategy, onColoringStrategyChange] =
+ useState(ColoringStrategy.dataset);
const [canvasMode, setCanvasMode] = useState(CanvasMode.move);
const [selectedIds, setSelectedIds] = useState>(new Set());
const allPoints = useMemo(() => {
@@ -204,11 +90,25 @@ export function PointCloud({ primaryData, referenceData }: PointCloudProps) {
return getThreeDimensionalBounds(allPoints.map((p) => p.position));
}, []);
const isMoveMode = canvasMode === CanvasMode.move;
+
+ // Determine the color of the points based on the strategy
+ const primaryColor = createColorFn({
+ coloringStrategy,
+ defaultColor: DEFAULT_COLOR_SCHEME[0],
+ });
+ const referenceColor = createColorFn({
+ coloringStrategy,
+ defaultColor: DEFAULT_COLOR_SCHEME[1],
+ });
return (
- {/* */}
{/* */}
-
+
-
diff --git a/app/src/components/canvas/PointCloudPoints.tsx b/app/src/components/canvas/PointCloudPoints.tsx
new file mode 100644
index 0000000000..f5336240ab
--- /dev/null
+++ b/app/src/components/canvas/PointCloudPoints.tsx
@@ -0,0 +1,87 @@
+import React from "react";
+import { shade } from "polished";
+import { Points, PointBaseProps } from "@arizeai/point-cloud";
+import { useCallback, useMemo } from "react";
+import { PointColor, ThreeDimensionalPointItem } from "./types";
+
+const DIM_AMOUNT = 0.5;
+
+/**
+ * Invokes the color function if it is a function, otherwise returns the color
+ * @param point
+ * @param color
+ * @returns {string} colorString
+ */
+function invokeColor(point: PointBaseProps, color: PointColor) {
+ if (typeof color === "function") {
+ return color(point);
+ }
+ return color;
+}
+type PointCloudPointsProps = {
+ /**
+ * The primary data to display in the point cloud
+ */
+ primaryData: ThreeDimensionalPointItem[];
+ /**
+ * Optional second set of data to display in the point cloud
+ */
+ referenceData: ThreeDimensionalPointItem[] | null;
+ primaryColor: PointColor;
+ referenceColor: PointColor;
+ selectedIds: Set;
+};
+export function PointCloudPoints({
+ primaryData,
+ referenceData,
+ selectedIds,
+ primaryColor,
+ referenceColor,
+}: PointCloudPointsProps) {
+ /** Colors to represent a dimmed variant of the color for "un-selected" */
+ const dimmedPrimaryColor = useMemo(() => {
+ if (typeof primaryColor === "function") {
+ return (p: PointBaseProps) => shade(DIM_AMOUNT)(primaryColor(p));
+ }
+ return shade(DIM_AMOUNT, primaryColor);
+ }, [primaryColor]);
+
+ const dimmedReferenceColor = useMemo(() => {
+ if (typeof referenceColor === "function") {
+ return (p: PointBaseProps) => shade(DIM_AMOUNT)(referenceColor(p));
+ }
+ return shade(DIM_AMOUNT, referenceColor);
+ }, [referenceColor]);
+
+ const primaryColorByFn = useCallback(
+ (point: PointBaseProps) => {
+ if (!selectedIds.has(point.metaData.id) && selectedIds.size > 0) {
+ return invokeColor(point, dimmedPrimaryColor);
+ }
+ return invokeColor(point, primaryColor);
+ },
+ [selectedIds, primaryColor, dimmedPrimaryColor]
+ );
+
+ const referenceColorByFn = useCallback(
+ (point: PointBaseProps) => {
+ if (!selectedIds.has(point.metaData.id) && selectedIds.size > 0) {
+ return invokeColor(point, dimmedReferenceColor);
+ }
+ return invokeColor(point, referenceColor);
+ },
+ [referenceColor, selectedIds, dimmedReferenceColor]
+ );
+
+ return (
+ <>
+
+ {referenceData && (
+
+ )}
+ >
+ );
+}
diff --git a/app/src/components/canvas/coloring.ts b/app/src/components/canvas/coloring.ts
new file mode 100644
index 0000000000..eeff6d8bb8
--- /dev/null
+++ b/app/src/components/canvas/coloring.ts
@@ -0,0 +1,23 @@
+import { ColoringStrategy, PointColor } from "./types";
+
+type ColoringConfig = {
+ coloringStrategy: ColoringStrategy;
+ defaultColor: string;
+};
+
+/**
+ * A curried function that generates a color function based on the given config.
+ * @param {ColoringConfig} config
+ * @returns {ColorFn}
+ */
+export const createColorFn =
+ (config: ColoringConfig): PointColor =>
+ (_point) => {
+ const { coloringStrategy, defaultColor } = config;
+ switch (coloringStrategy) {
+ case ColoringStrategy.dataset:
+ return defaultColor;
+ case ColoringStrategy.correctness:
+ return "green";
+ }
+ };
diff --git a/app/src/components/canvas/index.tsx b/app/src/components/canvas/index.tsx
index 48d19d32b4..53fb5bfbd6 100644
--- a/app/src/components/canvas/index.tsx
+++ b/app/src/components/canvas/index.tsx
@@ -1 +1,2 @@
export * from "./PointCloud";
+export * from "./types";
diff --git a/app/src/components/canvas/types.ts b/app/src/components/canvas/types.ts
new file mode 100644
index 0000000000..82717f9ca4
--- /dev/null
+++ b/app/src/components/canvas/types.ts
@@ -0,0 +1,13 @@
+import { PointsProps, ThreeDimensionalPoint } from "@arizeai/point-cloud";
+
+export type ThreeDimensionalPointItem = {
+ position: ThreeDimensionalPoint;
+ metaData: unknown;
+};
+
+export enum ColoringStrategy {
+ dataset = "dataset",
+ correctness = "correctness",
+}
+
+export type PointColor = PointsProps["pointProps"]["color"];
diff --git a/app/src/pages/Embedding.tsx b/app/src/pages/Embedding.tsx
index 0759508d40..225b420d2c 100644
--- a/app/src/pages/Embedding.tsx
+++ b/app/src/pages/Embedding.tsx
@@ -34,6 +34,34 @@ const EmbeddingUMAPQuery = graphql`
y
z
}
+ ... on Point2D {
+ x
+ y
+ }
+ }
+ embeddingMetadata {
+ linkToData
+ rawData
+ }
+ eventMetadata {
+ predictionLabel
+ actualLabel
+ predictionScore
+ actualScore
+ }
+ }
+ referenceData {
+ coordinates {
+ __typename
+ ... on Point3D {
+ x
+ y
+ z
+ }
+ ... on Point2D {
+ x
+ y
+ }
}
embeddingMetadata {
linkToData
@@ -55,6 +83,7 @@ const EmbeddingUMAPQuery = graphql`
}
}
`;
+
export function Embedding() {
const embeddingDimensionId = useEmbeddingDimensionId();
const [queryReference, loadQuery] =
@@ -114,6 +143,7 @@ function umapDataEntryToThreeDimensionalPointItem(
metaData: {},
};
}
+
/**
* Fetches the umap data for the embedding dimension and passes the data to the point cloud
*/
@@ -127,12 +157,21 @@ const PointCloudDisplay = ({
queryReference
);
- const primaryData =
- data.embedding?.UMAPPoints?.data?.map(
- umapDataEntryToThreeDimensionalPointItem
- ) ?? [];
+ const sourceData = data.embedding?.UMAPPoints?.data ?? [];
+ const referenceSourceData = data.embedding?.UMAPPoints?.referenceData;
- return ;
+ return (
+
+ );
};
export async function embeddingLoader(args: LoaderFunctionArgs) {
diff --git a/app/src/pages/__generated__/EmbeddingUMAPQuery.graphql.ts b/app/src/pages/__generated__/EmbeddingUMAPQuery.graphql.ts
index 966f117094..9e6d8ec05b 100644
--- a/app/src/pages/__generated__/EmbeddingUMAPQuery.graphql.ts
+++ b/app/src/pages/__generated__/EmbeddingUMAPQuery.graphql.ts
@@ -1,5 +1,5 @@
/**
- * @generated SignedSource<<802b773fafd62d25ba7fe368c0f7b299>>
+ * @generated SignedSource<<872ae35d7c0e3b32b7782657f7614a42>>
* @lightSyntaxTransform
* @nogrep
*/
@@ -26,6 +26,36 @@ export type EmbeddingUMAPQuery$data = {
}>;
readonly data: ReadonlyArray<{
readonly coordinates: {
+ readonly __typename: "Point2D";
+ readonly x: number;
+ readonly y: number;
+ } | {
+ readonly __typename: "Point3D";
+ readonly x: number;
+ readonly y: number;
+ readonly z: number;
+ } | {
+ // This will never be '%other', but we need some
+ // value in case none of the concrete values match.
+ readonly __typename: "%other";
+ };
+ readonly embeddingMetadata: {
+ readonly linkToData: string | null;
+ readonly rawData: string | null;
+ };
+ readonly eventMetadata: {
+ readonly actualLabel: string | null;
+ readonly actualScore: number | null;
+ readonly predictionLabel: string | null;
+ readonly predictionScore: number | null;
+ };
+ }>;
+ readonly referenceData: ReadonlyArray<{
+ readonly coordinates: {
+ readonly __typename: "Point2D";
+ readonly x: number;
+ readonly y: number;
+ } | {
readonly __typename: "Point3D";
readonly x: number;
readonly y: number;
@@ -89,6 +119,20 @@ v3 = {
"storageKey": null
},
v4 = {
+ "alias": null,
+ "args": null,
+ "kind": "ScalarField",
+ "name": "x",
+ "storageKey": null
+},
+v5 = {
+ "alias": null,
+ "args": null,
+ "kind": "ScalarField",
+ "name": "y",
+ "storageKey": null
+},
+v6 = {
"alias": null,
"args": null,
"concreteType": null,
@@ -100,20 +144,8 @@ v4 = {
{
"kind": "InlineFragment",
"selections": [
- {
- "alias": null,
- "args": null,
- "kind": "ScalarField",
- "name": "x",
- "storageKey": null
- },
- {
- "alias": null,
- "args": null,
- "kind": "ScalarField",
- "name": "y",
- "storageKey": null
- },
+ (v4/*: any*/),
+ (v5/*: any*/),
{
"alias": null,
"args": null,
@@ -124,11 +156,20 @@ v4 = {
],
"type": "Point3D",
"abstractKey": null
+ },
+ {
+ "kind": "InlineFragment",
+ "selections": [
+ (v4/*: any*/),
+ (v5/*: any*/)
+ ],
+ "type": "Point2D",
+ "abstractKey": null
}
],
"storageKey": null
},
-v5 = {
+v7 = {
"alias": null,
"args": null,
"concreteType": "EmbeddingMetadata",
@@ -153,7 +194,7 @@ v5 = {
],
"storageKey": null
},
-v6 = {
+v8 = {
"alias": null,
"args": null,
"concreteType": "EventMetadata",
@@ -192,14 +233,19 @@ v6 = {
],
"storageKey": null
},
-v7 = {
+v9 = [
+ (v6/*: any*/),
+ (v7/*: any*/),
+ (v8/*: any*/)
+],
+v10 = {
"alias": null,
"args": null,
"kind": "ScalarField",
"name": "id",
"storageKey": null
},
-v8 = {
+v11 = {
"alias": null,
"args": null,
"concreteType": "Cluster",
@@ -207,7 +253,7 @@ v8 = {
"name": "clusters",
"plural": true,
"selections": [
- (v7/*: any*/),
+ (v10/*: any*/),
{
"alias": null,
"args": null,
@@ -217,7 +263,13 @@ v8 = {
}
],
"storageKey": null
-};
+},
+v12 = [
+ (v6/*: any*/),
+ (v7/*: any*/),
+ (v8/*: any*/),
+ (v10/*: any*/)
+];
return {
"fragment": {
"argumentDefinitions": (v0/*: any*/),
@@ -251,14 +303,20 @@ return {
"kind": "LinkedField",
"name": "data",
"plural": true,
- "selections": [
- (v4/*: any*/),
- (v5/*: any*/),
- (v6/*: any*/)
- ],
+ "selections": (v9/*: any*/),
+ "storageKey": null
+ },
+ {
+ "alias": null,
+ "args": null,
+ "concreteType": "UMAPPoint",
+ "kind": "LinkedField",
+ "name": "referenceData",
+ "plural": true,
+ "selections": (v9/*: any*/),
"storageKey": null
},
- (v8/*: any*/)
+ (v11/*: any*/)
],
"storageKey": null
}
@@ -306,15 +364,20 @@ return {
"kind": "LinkedField",
"name": "data",
"plural": true,
- "selections": [
- (v4/*: any*/),
- (v5/*: any*/),
- (v6/*: any*/),
- (v7/*: any*/)
- ],
+ "selections": (v12/*: any*/),
+ "storageKey": null
+ },
+ {
+ "alias": null,
+ "args": null,
+ "concreteType": "UMAPPoint",
+ "kind": "LinkedField",
+ "name": "referenceData",
+ "plural": true,
+ "selections": (v12/*: any*/),
"storageKey": null
},
- (v8/*: any*/)
+ (v11/*: any*/)
],
"storageKey": null
}
@@ -326,23 +389,23 @@ return {
"kind": "TypeDiscriminator",
"abstractKey": "__isNode"
},
- (v7/*: any*/)
+ (v10/*: any*/)
],
"storageKey": null
}
]
},
"params": {
- "cacheID": "b55c61f5395270230ee21b6d270f7ea5",
+ "cacheID": "01483fd42cd9c8074f0f1904871c19b5",
"id": null,
"metadata": {},
"name": "EmbeddingUMAPQuery",
"operationKind": "query",
- "text": "query EmbeddingUMAPQuery(\n $id: GlobalID!\n $timeRange: TimeRange!\n) {\n embedding: node(id: $id) {\n __typename\n ... on EmbeddingDimension {\n UMAPPoints(timeRange: $timeRange) {\n data {\n coordinates {\n __typename\n ... on Point3D {\n x\n y\n z\n }\n }\n embeddingMetadata {\n linkToData\n rawData\n }\n eventMetadata {\n predictionLabel\n actualLabel\n predictionScore\n actualScore\n }\n id\n }\n clusters {\n id\n pointIds\n }\n }\n }\n __isNode: __typename\n id\n }\n}\n"
+ "text": "query EmbeddingUMAPQuery(\n $id: GlobalID!\n $timeRange: TimeRange!\n) {\n embedding: node(id: $id) {\n __typename\n ... on EmbeddingDimension {\n UMAPPoints(timeRange: $timeRange) {\n data {\n coordinates {\n __typename\n ... on Point3D {\n x\n y\n z\n }\n ... on Point2D {\n x\n y\n }\n }\n embeddingMetadata {\n linkToData\n rawData\n }\n eventMetadata {\n predictionLabel\n actualLabel\n predictionScore\n actualScore\n }\n id\n }\n referenceData {\n coordinates {\n __typename\n ... on Point3D {\n x\n y\n z\n }\n ... on Point2D {\n x\n y\n }\n }\n embeddingMetadata {\n linkToData\n rawData\n }\n eventMetadata {\n predictionLabel\n actualLabel\n predictionScore\n actualScore\n }\n id\n }\n clusters {\n id\n pointIds\n }\n }\n }\n __isNode: __typename\n id\n }\n}\n"
}
};
})();
-(node as any).hash = "5570eb0dc63ef2fdb5f36a1099faf424";
+(node as any).hash = "f807ff579322f11dc73d2131be4b9091";
export default node;