diff --git a/app/src/GlobalStyles.tsx b/app/src/GlobalStyles.tsx
index 28acb1fd88..3c08ee7858 100644
--- a/app/src/GlobalStyles.tsx
+++ b/app/src/GlobalStyles.tsx
@@ -33,6 +33,8 @@ export function GlobalStyles() {
--px-primary-color: #9efcfd;
--px-primary-color--transparent: rgb(158, 252, 253, 0.2);
--px-reference-color: #baa1f9;
+
+ --px-flex-gap-sm: ${theme.spacing.margin4}px;
}
`}
/>
diff --git a/app/src/components/canvas/PointCloud.tsx b/app/src/components/canvas/PointCloud.tsx
index b12c1c530f..7b7f8ce763 100644
--- a/app/src/components/canvas/PointCloud.tsx
+++ b/app/src/components/canvas/PointCloud.tsx
@@ -4,7 +4,6 @@ import { css } from "@emotion/react";
import { theme } from "@arizeai/components";
import {
Axes,
- ColorSchemes,
getThreeDimensionalBounds,
LassoSelect,
ThreeDimensionalBounds,
@@ -16,6 +15,7 @@ import { usePointCloudStore } from "@phoenix/store";
import { CanvasMode, CanvasModeRadioGroup } from "./CanvasModeRadioGroup";
import { createColorFn } from "./coloring";
+import { DEFAULT_COLOR_SCHEME } from "./constants";
import { ControlPanel } from "./ControlPanel";
import { PointCloudClusters } from "./PointCloudClusters";
import { PointCloudPoints } from "./PointCloudPoints";
@@ -37,7 +37,6 @@ interface ProjectionProps extends PointCloudProps {
}
const CONTROL_PANEL_WIDTH = 300;
-const DEFAULT_COLOR_SCHEME = ColorSchemes.Discrete2.WhiteLightBlue;
/**
* Displays the tools available on the point cloud
* E.g. move vs select
diff --git a/app/src/components/canvas/PointCloudDisplaySettings.tsx b/app/src/components/canvas/PointCloudDisplaySettings.tsx
index b3db054ee4..2767822016 100644
--- a/app/src/components/canvas/PointCloudDisplaySettings.tsx
+++ b/app/src/components/canvas/PointCloudDisplaySettings.tsx
@@ -1,13 +1,19 @@
-import React from "react";
+import React, { ChangeEvent, useCallback, useMemo } from "react";
import { css } from "@emotion/react";
import { Form } from "@arizeai/components";
+import { useDatasets } from "@phoenix/contexts";
import { usePointCloudStore } from "@phoenix/store";
+import { ColoringStrategy } from "@phoenix/types";
+import { assertUnreachable } from "@phoenix/typeUtils";
import { ColoringStrategyPicker } from "./ColoringStrategyPicker";
+import { DEFAULT_COLOR_SCHEME, FALLBACK_COLOR } from "./constants";
+import { Shape, ShapeIcon } from "./ShapeIcon";
export function PointCloudDisplaySettings() {
+ const { referenceDataset } = useDatasets();
const [coloringStrategy, setColoringStrategy] = usePointCloudStore(
(state) => [state.coloringStrategy, state.setColoringStrategy]
);
@@ -25,6 +31,93 @@ export function PointCloudDisplaySettings() {
onChange={setColoringStrategy}
/>
+ {}
+ {referenceDataset != null ? : null}
);
+
+ function DatasetVisibilitySettings() {
+ const { datasetVisibility, setDatasetVisibility, coloringStrategy } =
+ usePointCloudStore((state) => ({
+ datasetVisibility: state.datasetVisibility,
+ setDatasetVisibility: state.setDatasetVisibility,
+ coloringStrategy: state.coloringStrategy,
+ }));
+
+ const handleDatasetVisibilityChange = useCallback(
+ (event: ChangeEvent) => {
+ const target = event.target as HTMLInputElement;
+ const { name, checked } = target;
+ setDatasetVisibility({
+ ...datasetVisibility,
+ [name]: checked,
+ });
+ },
+ [datasetVisibility, setDatasetVisibility]
+ );
+
+ const primaryColor = useMemo(() => {
+ switch (coloringStrategy) {
+ case ColoringStrategy.dataset:
+ return DEFAULT_COLOR_SCHEME[0];
+ case ColoringStrategy.correctness:
+ return FALLBACK_COLOR;
+ default:
+ assertUnreachable(coloringStrategy);
+ }
+ }, [coloringStrategy]);
+
+ const referenceColor = useMemo(() => {
+ switch (coloringStrategy) {
+ case ColoringStrategy.dataset:
+ return DEFAULT_COLOR_SCHEME[1];
+ case ColoringStrategy.correctness:
+ return FALLBACK_COLOR;
+ default:
+ assertUnreachable(coloringStrategy);
+ }
+ }, [coloringStrategy]);
+
+ const referenceShape =
+ coloringStrategy === ColoringStrategy.dataset
+ ? Shape.circle
+ : Shape.square;
+
+ return (
+
+ );
+ }
}
diff --git a/app/src/components/canvas/PointCloudPoints.tsx b/app/src/components/canvas/PointCloudPoints.tsx
index 403c1defaf..0c099dfcb8 100644
--- a/app/src/components/canvas/PointCloudPoints.tsx
+++ b/app/src/components/canvas/PointCloudPoints.tsx
@@ -4,10 +4,19 @@ import { shade } from "polished";
import { PointBaseProps, Points } from "@arizeai/point-cloud";
+import { usePointCloudStore } from "@phoenix/store";
+import { ColoringStrategy } from "@phoenix/types";
+
import { PointColor, ThreeDimensionalPointItem } from "./types";
const DIM_AMOUNT = 0.5;
+/**
+ * The amount to multiply the radius by to get the appropriate cube size
+ * E.g. size = radius * CUBE_RADIUS_MULTIPLIER
+ */
+const CUBE_RADIUS_MULTIPLIER = 1.7;
+
/**
* Invokes the color function if it is a function, otherwise returns the color
* @param point
@@ -34,6 +43,11 @@ type PointCloudPointsProps = {
selectedIds: Set;
radius: number;
};
+
+/**
+ * Function component that renders the points in the point cloud
+ * Split out into it's own component to maximize performance and caching
+ */
export function PointCloudPoints({
primaryData,
referenceData,
@@ -42,6 +56,21 @@ export function PointCloudPoints({
referenceColor,
radius,
}: PointCloudPointsProps) {
+ const { datasetVisibility, coloringStrategy } = usePointCloudStore(
+ (state) => {
+ return {
+ datasetVisibility: state.datasetVisibility,
+ coloringStrategy: state.coloringStrategy,
+ };
+ }
+ );
+
+ // Only use a cube shape if the coloring strategy is not dataset
+ const referenceDatasetPointShape = useMemo(
+ () => (coloringStrategy !== ColoringStrategy.dataset ? "cube" : "sphere"),
+ [coloringStrategy]
+ );
+
/** Colors to represent a dimmed variant of the color for "un-selected" */
const dimmedPrimaryColor = useMemo(() => {
if (typeof primaryColor === "function") {
@@ -77,18 +106,27 @@ export function PointCloudPoints({
[referenceColor, selectedIds, dimmedReferenceColor]
);
+ const showReferencePoints = datasetVisibility.reference && referenceData;
+
return (
<>
-
- {referenceData && (
+ {datasetVisibility.primary ? (
+
+ ) : null}
+ {showReferencePoints ? (
- )}
+ ) : null}
>
);
}
diff --git a/app/src/components/canvas/ShapeIcon.tsx b/app/src/components/canvas/ShapeIcon.tsx
new file mode 100644
index 0000000000..5daf9288b4
--- /dev/null
+++ b/app/src/components/canvas/ShapeIcon.tsx
@@ -0,0 +1,75 @@
+import React, { useMemo } from "react";
+import { css } from "@emotion/react";
+
+import { assertUnreachable } from "@phoenix/typeUtils";
+
+export enum Shape {
+ square = "square",
+ circle = "circle",
+}
+
+type ShapeIconProps = {
+ /**
+ * The shape of the icon / symbol
+ */
+ shape: Shape;
+ /**
+ * The color of the icon / symbol
+ */
+ color: string;
+};
+
+const SquareSVG = () => (
+
+);
+
+const CircleSVG = () => {
+ return (
+
+ );
+};
+
+export function ShapeIcon(props: ShapeIconProps) {
+ const { shape, color } = props;
+ const shapeSVG = useMemo(() => {
+ switch (shape) {
+ case Shape.square:
+ return ;
+ case Shape.circle:
+ return ;
+ default:
+ assertUnreachable(shape);
+ }
+ }, [shape]);
+
+ return (
+
+ {shapeSVG}
+
+ );
+}
diff --git a/app/src/components/canvas/constants.ts b/app/src/components/canvas/constants.ts
new file mode 100644
index 0000000000..fc54bda243
--- /dev/null
+++ b/app/src/components/canvas/constants.ts
@@ -0,0 +1,8 @@
+import { ColorSchemes } from "@arizeai/point-cloud";
+
+export const DEFAULT_COLOR_SCHEME = ColorSchemes.Discrete2.WhiteLightBlue;
+
+/**
+ * The default color to use when coloringStrategy does not apply.
+ */
+export const FALLBACK_COLOR = "#555555";
diff --git a/app/src/store/pointCloudStore.ts b/app/src/store/pointCloudStore.ts
index 56e486c303..e27a8c16ca 100644
--- a/app/src/store/pointCloudStore.ts
+++ b/app/src/store/pointCloudStore.ts
@@ -3,6 +3,14 @@ import { devtools } from "zustand/middleware";
import { ColoringStrategy } from "@phoenix/types";
+/**
+ * The visibility of the two datasets in the point cloud.
+ */
+type DatasetVisibility = {
+ primary: boolean;
+ reference: boolean;
+};
+
export type PointCloudState = {
/**
* The IDs of the points that are currently selected.
@@ -28,6 +36,17 @@ export type PointCloudState = {
* Sets the coloring strategy to the given value.
*/
setColoringStrategy: (strategy: ColoringStrategy) => void;
+ /**
+ * The visibility of the two datasets in the point cloud.
+ * @default { primary: true, reference: true }
+ */
+ datasetVisibility: DatasetVisibility;
+ /**
+ * Sets the dataset visibility to the given value.
+ * @param {DatasetVisibility} visibility
+ * @returns {void}
+ */
+ setDatasetVisibility: (visibility: DatasetVisibility) => void;
};
const pointCloudStore: StateCreator = (set) => ({
@@ -37,6 +56,8 @@ const pointCloudStore: StateCreator = (set) => ({
setSelectedClusterId: (id) => set({ selectedClusterId: id }),
coloringStrategy: ColoringStrategy.dataset,
setColoringStrategy: (strategy) => set({ coloringStrategy: strategy }),
+ datasetVisibility: { primary: true, reference: true },
+ setDatasetVisibility: (visibility) => set({ datasetVisibility: visibility }),
});
export const usePointCloudStore = create()(