diff --git a/app/src/GlobalStyles.tsx b/app/src/GlobalStyles.tsx index 28acb1fd88..3c08ee7858 100644 --- a/app/src/GlobalStyles.tsx +++ b/app/src/GlobalStyles.tsx @@ -33,6 +33,8 @@ export function GlobalStyles() { --px-primary-color: #9efcfd; --px-primary-color--transparent: rgb(158, 252, 253, 0.2); --px-reference-color: #baa1f9; + + --px-flex-gap-sm: ${theme.spacing.margin4}px; } `} /> diff --git a/app/src/components/canvas/PointCloud.tsx b/app/src/components/canvas/PointCloud.tsx index b12c1c530f..7b7f8ce763 100644 --- a/app/src/components/canvas/PointCloud.tsx +++ b/app/src/components/canvas/PointCloud.tsx @@ -4,7 +4,6 @@ import { css } from "@emotion/react"; import { theme } from "@arizeai/components"; import { Axes, - ColorSchemes, getThreeDimensionalBounds, LassoSelect, ThreeDimensionalBounds, @@ -16,6 +15,7 @@ import { usePointCloudStore } from "@phoenix/store"; import { CanvasMode, CanvasModeRadioGroup } from "./CanvasModeRadioGroup"; import { createColorFn } from "./coloring"; +import { DEFAULT_COLOR_SCHEME } from "./constants"; import { ControlPanel } from "./ControlPanel"; import { PointCloudClusters } from "./PointCloudClusters"; import { PointCloudPoints } from "./PointCloudPoints"; @@ -37,7 +37,6 @@ interface ProjectionProps extends PointCloudProps { } const CONTROL_PANEL_WIDTH = 300; -const DEFAULT_COLOR_SCHEME = ColorSchemes.Discrete2.WhiteLightBlue; /** * Displays the tools available on the point cloud * E.g. move vs select diff --git a/app/src/components/canvas/PointCloudDisplaySettings.tsx b/app/src/components/canvas/PointCloudDisplaySettings.tsx index b3db054ee4..2767822016 100644 --- a/app/src/components/canvas/PointCloudDisplaySettings.tsx +++ b/app/src/components/canvas/PointCloudDisplaySettings.tsx @@ -1,13 +1,19 @@ -import React from "react"; +import React, { ChangeEvent, useCallback, useMemo } from "react"; import { css } from "@emotion/react"; import { Form } from "@arizeai/components"; +import { useDatasets } from "@phoenix/contexts"; import { usePointCloudStore } from "@phoenix/store"; +import { ColoringStrategy } from "@phoenix/types"; +import { assertUnreachable } from "@phoenix/typeUtils"; import { ColoringStrategyPicker } from "./ColoringStrategyPicker"; +import { DEFAULT_COLOR_SCHEME, FALLBACK_COLOR } from "./constants"; +import { Shape, ShapeIcon } from "./ShapeIcon"; export function PointCloudDisplaySettings() { + const { referenceDataset } = useDatasets(); const [coloringStrategy, setColoringStrategy] = usePointCloudStore( (state) => [state.coloringStrategy, state.setColoringStrategy] ); @@ -25,6 +31,93 @@ export function PointCloudDisplaySettings() { onChange={setColoringStrategy} /> + {} + {referenceDataset != null ? : null} ); + + function DatasetVisibilitySettings() { + const { datasetVisibility, setDatasetVisibility, coloringStrategy } = + usePointCloudStore((state) => ({ + datasetVisibility: state.datasetVisibility, + setDatasetVisibility: state.setDatasetVisibility, + coloringStrategy: state.coloringStrategy, + })); + + const handleDatasetVisibilityChange = useCallback( + (event: ChangeEvent) => { + const target = event.target as HTMLInputElement; + const { name, checked } = target; + setDatasetVisibility({ + ...datasetVisibility, + [name]: checked, + }); + }, + [datasetVisibility, setDatasetVisibility] + ); + + const primaryColor = useMemo(() => { + switch (coloringStrategy) { + case ColoringStrategy.dataset: + return DEFAULT_COLOR_SCHEME[0]; + case ColoringStrategy.correctness: + return FALLBACK_COLOR; + default: + assertUnreachable(coloringStrategy); + } + }, [coloringStrategy]); + + const referenceColor = useMemo(() => { + switch (coloringStrategy) { + case ColoringStrategy.dataset: + return DEFAULT_COLOR_SCHEME[1]; + case ColoringStrategy.correctness: + return FALLBACK_COLOR; + default: + assertUnreachable(coloringStrategy); + } + }, [coloringStrategy]); + + const referenceShape = + coloringStrategy === ColoringStrategy.dataset + ? Shape.circle + : Shape.square; + + return ( +
+ + +
+ ); + } } diff --git a/app/src/components/canvas/PointCloudPoints.tsx b/app/src/components/canvas/PointCloudPoints.tsx index 403c1defaf..0c099dfcb8 100644 --- a/app/src/components/canvas/PointCloudPoints.tsx +++ b/app/src/components/canvas/PointCloudPoints.tsx @@ -4,10 +4,19 @@ import { shade } from "polished"; import { PointBaseProps, Points } from "@arizeai/point-cloud"; +import { usePointCloudStore } from "@phoenix/store"; +import { ColoringStrategy } from "@phoenix/types"; + import { PointColor, ThreeDimensionalPointItem } from "./types"; const DIM_AMOUNT = 0.5; +/** + * The amount to multiply the radius by to get the appropriate cube size + * E.g. size = radius * CUBE_RADIUS_MULTIPLIER + */ +const CUBE_RADIUS_MULTIPLIER = 1.7; + /** * Invokes the color function if it is a function, otherwise returns the color * @param point @@ -34,6 +43,11 @@ type PointCloudPointsProps = { selectedIds: Set; radius: number; }; + +/** + * Function component that renders the points in the point cloud + * Split out into it's own component to maximize performance and caching + */ export function PointCloudPoints({ primaryData, referenceData, @@ -42,6 +56,21 @@ export function PointCloudPoints({ referenceColor, radius, }: PointCloudPointsProps) { + const { datasetVisibility, coloringStrategy } = usePointCloudStore( + (state) => { + return { + datasetVisibility: state.datasetVisibility, + coloringStrategy: state.coloringStrategy, + }; + } + ); + + // Only use a cube shape if the coloring strategy is not dataset + const referenceDatasetPointShape = useMemo( + () => (coloringStrategy !== ColoringStrategy.dataset ? "cube" : "sphere"), + [coloringStrategy] + ); + /** Colors to represent a dimmed variant of the color for "un-selected" */ const dimmedPrimaryColor = useMemo(() => { if (typeof primaryColor === "function") { @@ -77,18 +106,27 @@ export function PointCloudPoints({ [referenceColor, selectedIds, dimmedReferenceColor] ); + const showReferencePoints = datasetVisibility.reference && referenceData; + return ( <> - - {referenceData && ( + {datasetVisibility.primary ? ( + + ) : null} + {showReferencePoints ? ( - )} + ) : null} ); } diff --git a/app/src/components/canvas/ShapeIcon.tsx b/app/src/components/canvas/ShapeIcon.tsx new file mode 100644 index 0000000000..5daf9288b4 --- /dev/null +++ b/app/src/components/canvas/ShapeIcon.tsx @@ -0,0 +1,75 @@ +import React, { useMemo } from "react"; +import { css } from "@emotion/react"; + +import { assertUnreachable } from "@phoenix/typeUtils"; + +export enum Shape { + square = "square", + circle = "circle", +} + +type ShapeIconProps = { + /** + * The shape of the icon / symbol + */ + shape: Shape; + /** + * The color of the icon / symbol + */ + color: string; +}; + +const SquareSVG = () => ( + + + +); + +const CircleSVG = () => { + return ( + + + + ); +}; + +export function ShapeIcon(props: ShapeIconProps) { + const { shape, color } = props; + const shapeSVG = useMemo(() => { + switch (shape) { + case Shape.square: + return ; + case Shape.circle: + return ; + default: + assertUnreachable(shape); + } + }, [shape]); + + return ( + + {shapeSVG} + + ); +} diff --git a/app/src/components/canvas/constants.ts b/app/src/components/canvas/constants.ts new file mode 100644 index 0000000000..fc54bda243 --- /dev/null +++ b/app/src/components/canvas/constants.ts @@ -0,0 +1,8 @@ +import { ColorSchemes } from "@arizeai/point-cloud"; + +export const DEFAULT_COLOR_SCHEME = ColorSchemes.Discrete2.WhiteLightBlue; + +/** + * The default color to use when coloringStrategy does not apply. + */ +export const FALLBACK_COLOR = "#555555"; diff --git a/app/src/store/pointCloudStore.ts b/app/src/store/pointCloudStore.ts index 56e486c303..e27a8c16ca 100644 --- a/app/src/store/pointCloudStore.ts +++ b/app/src/store/pointCloudStore.ts @@ -3,6 +3,14 @@ import { devtools } from "zustand/middleware"; import { ColoringStrategy } from "@phoenix/types"; +/** + * The visibility of the two datasets in the point cloud. + */ +type DatasetVisibility = { + primary: boolean; + reference: boolean; +}; + export type PointCloudState = { /** * The IDs of the points that are currently selected. @@ -28,6 +36,17 @@ export type PointCloudState = { * Sets the coloring strategy to the given value. */ setColoringStrategy: (strategy: ColoringStrategy) => void; + /** + * The visibility of the two datasets in the point cloud. + * @default { primary: true, reference: true } + */ + datasetVisibility: DatasetVisibility; + /** + * Sets the dataset visibility to the given value. + * @param {DatasetVisibility} visibility + * @returns {void} + */ + setDatasetVisibility: (visibility: DatasetVisibility) => void; }; const pointCloudStore: StateCreator = (set) => ({ @@ -37,6 +56,8 @@ const pointCloudStore: StateCreator = (set) => ({ setSelectedClusterId: (id) => set({ selectedClusterId: id }), coloringStrategy: ColoringStrategy.dataset, setColoringStrategy: (strategy) => set({ coloringStrategy: strategy }), + datasetVisibility: { primary: true, reference: true }, + setDatasetVisibility: (visibility) => set({ datasetVisibility: visibility }), }); export const usePointCloudStore = create()(