Skip to content

Commit

Permalink
fix(embeddings): add color by correctness
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeldking committed Feb 7, 2023
1 parent 438323b commit 1d2de97
Show file tree
Hide file tree
Showing 8 changed files with 252 additions and 150 deletions.
43 changes: 43 additions & 0 deletions app/src/components/canvas/CanvasModeRadioGroup.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { Icon, Icons, Radio, RadioGroup } from "@arizeai/components";
import React from "react";

export enum CanvasMode {
move = "move",
select = "select",
}

/**
* TypeGuard for the canvas mode
*/
function isCanvasMode(m: unknown): m is CanvasMode {
return typeof m === "string" && m in CanvasMode;
}

type CanvasModeRadioGroupProps = {
mode: CanvasMode;
onChange: (mode: CanvasMode) => void;
};

export function CanvasModeRadioGroup(props: CanvasModeRadioGroupProps) {
return (
<RadioGroup
defaultValue={props.mode}
variant="inline-button"
size="normal"
onChange={(v) => {
if (isCanvasMode(v)) {
props.onChange(v);
} else {
throw new Error(`Unknown canvas mode: ${v}`);
}
}}
>
<Radio label="Move" value={CanvasMode.move}>
<Icon svg={<Icons.MoveFilled />} />
</Radio>
<Radio label="Select" value={CanvasMode.select}>
<Icon svg={<Icons.LassoOutline />} />
</Radio>
</RadioGroup>
);
}
33 changes: 33 additions & 0 deletions app/src/components/canvas/ColoringStrategyPicker.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import React from "react";
import { Picker, Item } from "@arizeai/components";
import { ColoringStrategy } from "./types";

function isColoringStrategy(strategy: unknown): strategy is ColoringStrategy {
return typeof strategy === "string" && strategy in ColoringStrategy;
}

const ColoringStrategies = Object.values(ColoringStrategy);

type ColoringStrategyPickerProps = {
strategy: ColoringStrategy;
onChange: (strategy: ColoringStrategy) => void;
};
export function ColoringStrategyPicker(props: ColoringStrategyPickerProps) {
const { strategy, onChange } = props;
return (
<Picker
defaultSelectedKey={strategy}
aria-label="Coloring strategy"
addonBefore="Color by"
onSelectionChange={(key) => {
if (isColoringStrategy(key)) {
onChange(key);
}
}}
>
{ColoringStrategies.map((item) => (
<Item key={item}>{item}</Item>
))}
</Picker>
);
}
196 changes: 49 additions & 147 deletions app/src/components/canvas/PointCloud.tsx
Original file line number Diff line number Diff line change
@@ -1,65 +1,46 @@
/* eslint-disable no-unused-vars */
import React, { ReactNode, useCallback, useMemo, useState } from "react";
import React, { useMemo, useState } from "react";
import {
ThreeDimensionalCanvas,
ThreeDimensionalControls,
Points,
getThreeDimensionalBounds,
ThreeDimensionalPoint,
ThreeDimensionalBounds,
LassoSelect,
PointBaseProps,
ColorSchemes,
} from "@arizeai/point-cloud";
import { ErrorBoundary } from "../ErrorBoundary";
import {
Accordion,
AccordionItem,
Form,
Icon,
Icons,
Item,
Picker,
Radio,
RadioGroup,
theme,
} from "@arizeai/components";
import { theme } from "@arizeai/components";
import { css } from "@emotion/react";
import { shade } from "polished";
import { ControlPanel } from "./ControlPanel";

const DIM_AMOUNT = 0.5;

export type ThreeDimensionalPointItem = {
position: ThreeDimensionalPoint;
metaData: unknown;
};
import { ColoringStrategyPicker } from "./ColoringStrategyPicker";
import { CanvasMode, CanvasModeRadioGroup } from "./CanvasModeRadioGroup";
import { PointCloudPoints } from "./PointCloudPoints";
import { ThreeDimensionalPointItem } from "./types";
import { ColoringStrategy } from "./types";
import { createColorFn } from "./coloring";

export type PointCloudProps = {
primaryData: ThreeDimensionalPointItem[];
referenceData?: ThreeDimensionalPointItem[];
};

enum CanvasMode {
move = "move",
select = "select",
}

/**
* TypeGuard for the canvas mode
*/
function isCanvasMode(m: unknown): m is CanvasMode {
return typeof m === "string" && m in CanvasMode;
}

const CONTROL_PANEL_WIDTH = 300;
const DEFAULT_COLOR_SCHEME = ColorSchemes.Discrete2.WhiteLightBlue;
/**
* Displays the tools available on the point cloud
* E.g. move vs select
*/
function CanvasTools(props: {
mode: CanvasMode;
onChange: (mode: CanvasMode) => void;
coloringStrategy: ColoringStrategy;
onColoringStrategyChange: (strategy: ColoringStrategy) => void;
canvasMode: CanvasMode;
onCanvasModeChange: (mode: CanvasMode) => void;
}) {
const {
coloringStrategy,
onColoringStrategyChange,
canvasMode,
onCanvasModeChange,
} = props;
return (
<div
css={css`
Expand All @@ -68,62 +49,20 @@ function CanvasTools(props: {
left: ${theme.spacing.margin8}px;
top: ${theme.spacing.margin8}px;
z-index: 1;
display: flex;
flex-direction: row;
gap: ${theme.spacing.margin8}px;
`}
>
<RadioGroup
defaultValue={props.mode}
variant="inline-button"
size="compact"
onChange={(v) => {
if (isCanvasMode(v)) {
props.onChange(v);
} else {
throw new Error(`Unknown canvas mode: ${v}`);
}
}}
>
<Radio label="Move" value={CanvasMode.move}>
<Icon svg={<Icons.MoveFilled />} />
</Radio>
<Radio label="Select" value={CanvasMode.select}>
<Icon svg={<Icons.LassoOutline />} />
</Radio>
</RadioGroup>
<ColoringStrategyPicker
strategy={coloringStrategy}
onChange={onColoringStrategyChange}
/>
<CanvasModeRadioGroup mode={canvasMode} onChange={onCanvasModeChange} />
</div>
);
}

function AccordionSection({ children }: { children: ReactNode }) {
return (
<section
css={css`
margin: ${theme.spacing.margin8};
`}
>
{children}
</section>
);
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
function DisplayControlPanel() {
return (
<ControlPanel position="top-left" width={CONTROL_PANEL_WIDTH}>
<Accordion variant="compact">
<AccordionItem title="Display" id="display">
<AccordionSection>
<Form>
<Picker label="Color by" defaultSelectedKey={"dataset"}>
<Item key="dataset">Dataset</Item>
</Picker>
</Form>
</AccordionSection>
</AccordionItem>
</Accordion>
</ControlPanel>
);
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
function SelectionControlPanel({ selectedIds }: { selectedIds: Set<string> }) {
return (
Expand All @@ -137,64 +76,11 @@ function SelectionControlPanel({ selectedIds }: { selectedIds: Set<string> }) {
);
}

function UMAPPoints({
primaryData,
referenceData,
selectedIds,
}: PointCloudProps & { selectedIds: Set<string> }) {
const primaryColor = "#7BFFFF";
const referenceColor = "#d57bff";
/** Colors to represent a dimmed variant of the color for "un-selected" */
const dimmedPrimaryColor = useMemo(() => {
// if (typeof primaryColor === "function") {
// return (p: PointBaseProps) => shade(DIM_AMOUNT)(primaryColor(p));
// }
return shade(DIM_AMOUNT, primaryColor);
}, [primaryColor]);

const dimmedReferenceColor = useMemo(() => {
// if (typeof referenceColor === "function") {
// return (p: PointBaseProps) => shade(DIM_AMOUNT)(referenceColor(p));
// }
return shade(DIM_AMOUNT, referenceColor);
}, [referenceColor]);

const primaryColorByFn = useCallback(
(p: PointBaseProps) => {
if (!selectedIds.has(p.metaData.id) && selectedIds.size > 0) {
return dimmedPrimaryColor;
}
return primaryColor;
},
[selectedIds, primaryColor, dimmedPrimaryColor]
);

const referenceColorByFn = useCallback(
(p: PointBaseProps) => {
if (!selectedIds.has(p.metaData.id) && selectedIds.size > 0) {
return dimmedReferenceColor;
}
return referenceColor;
},
[referenceColor, selectedIds, dimmedReferenceColor]
);

return (
<>
<Points data={primaryData} pointProps={{ color: primaryColorByFn }} />
{referenceData && (
<Points
data={referenceData}
pointProps={{ color: referenceColorByFn }}
/>
)}
</>
);
}

export function PointCloud({ primaryData, referenceData }: PointCloudProps) {
// AutoRotate the canvas on initial load
const [autoRotate, setAutoRotate] = useState<boolean>(true);
const [coloringStrategy, onColoringStrategyChange] =
useState<ColoringStrategy>(ColoringStrategy.dataset);
const [canvasMode, setCanvasMode] = useState<CanvasMode>(CanvasMode.move);
const [selectedIds, setSelectedIds] = useState<Set<string>>(new Set());
const allPoints = useMemo(() => {
Expand All @@ -204,11 +90,25 @@ export function PointCloud({ primaryData, referenceData }: PointCloudProps) {
return getThreeDimensionalBounds(allPoints.map((p) => p.position));
}, []);
const isMoveMode = canvasMode === CanvasMode.move;

// Determine the color of the points based on the strategy
const primaryColor = createColorFn({
coloringStrategy,
defaultColor: DEFAULT_COLOR_SCHEME[0],
});
const referenceColor = createColorFn({
coloringStrategy,
defaultColor: DEFAULT_COLOR_SCHEME[1],
});
return (
<ErrorBoundary>
{/* <DisplayControlPanel /> */}
{/* <SelectionControlPanel selectedIds={selectedIds} /> */}
<CanvasTools mode={canvasMode} onChange={setCanvasMode} />
<CanvasTools
coloringStrategy={coloringStrategy}
onColoringStrategyChange={onColoringStrategyChange}
canvasMode={canvasMode}
onCanvasModeChange={setCanvasMode}
/>
<ThreeDimensionalCanvas camera={{ position: [0, 0, 10] }}>
<ThreeDimensionalControls
autoRotate={autoRotate}
Expand All @@ -229,10 +129,12 @@ export function PointCloud({ primaryData, referenceData }: PointCloudProps) {
enabled={canvasMode === CanvasMode.select}
/>

<UMAPPoints
<PointCloudPoints
primaryData={primaryData}
referenceData={referenceData}
selectedIds={selectedIds}
primaryColor={primaryColor}
referenceColor={referenceColor}
/>
</ThreeDimensionalBounds>
</ThreeDimensionalCanvas>
Expand Down
Loading

0 comments on commit 1d2de97

Please sign in to comment.