Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(embeddings): display reference umap data #225

Merged
merged 2 commits into from
Feb 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions app/src/components/canvas/CanvasModeRadioGroup.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { Icon, Icons, Radio, RadioGroup } from "@arizeai/components";
import React from "react";

export enum CanvasMode {
move = "move",
select = "select",
}

/**
* TypeGuard for the canvas mode
*/
function isCanvasMode(m: unknown): m is CanvasMode {
return typeof m === "string" && m in CanvasMode;
}

type CanvasModeRadioGroupProps = {
mode: CanvasMode;
onChange: (mode: CanvasMode) => void;
};

export function CanvasModeRadioGroup(props: CanvasModeRadioGroupProps) {
return (
<RadioGroup
defaultValue={props.mode}
variant="inline-button"
size="normal"
onChange={(v) => {
if (isCanvasMode(v)) {
props.onChange(v);
} else {
throw new Error(`Unknown canvas mode: ${v}`);
}
}}
>
<Radio label="Move" value={CanvasMode.move}>
<Icon svg={<Icons.MoveFilled />} />
</Radio>
<Radio label="Select" value={CanvasMode.select}>
<Icon svg={<Icons.LassoOutline />} />
</Radio>
</RadioGroup>
);
}
33 changes: 33 additions & 0 deletions app/src/components/canvas/ColoringStrategyPicker.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import React from "react";
import { Picker, Item } from "@arizeai/components";
import { ColoringStrategy } from "./types";

function isColoringStrategy(strategy: unknown): strategy is ColoringStrategy {
return typeof strategy === "string" && strategy in ColoringStrategy;
}

const ColoringStrategies = Object.values(ColoringStrategy);

type ColoringStrategyPickerProps = {
strategy: ColoringStrategy;
onChange: (strategy: ColoringStrategy) => void;
};
export function ColoringStrategyPicker(props: ColoringStrategyPickerProps) {
const { strategy, onChange } = props;
return (
<Picker
defaultSelectedKey={strategy}
aria-label="Coloring strategy"
addonBefore="Color by"
onSelectionChange={(key) => {
if (isColoringStrategy(key)) {
onChange(key);
}
}}
>
{ColoringStrategies.map((item) => (
<Item key={item}>{item}</Item>
))}
</Picker>
);
}
198 changes: 50 additions & 148 deletions app/src/components/canvas/PointCloud.tsx
Original file line number Diff line number Diff line change
@@ -1,65 +1,46 @@
/* eslint-disable no-unused-vars */
import React, { ReactNode, useCallback, useMemo, useState } from "react";
import React, { useMemo, useState } from "react";
import {
ThreeDimensionalCanvas,
ThreeDimensionalControls,
Points,
getThreeDimensionalBounds,
ThreeDimensionalPoint,
ThreeDimensionalBounds,
LassoSelect,
PointBaseProps,
ColorSchemes,
} from "@arizeai/point-cloud";
import { ErrorBoundary } from "../ErrorBoundary";
import {
Accordion,
AccordionItem,
Form,
Icon,
Icons,
Item,
Picker,
Radio,
RadioGroup,
theme,
} from "@arizeai/components";
import { theme } from "@arizeai/components";
import { css } from "@emotion/react";
import { shade } from "polished";
import { ControlPanel } from "./ControlPanel";

const DIM_AMOUNT = 0.5;

export type ThreeDimensionalPointItem = {
position: ThreeDimensionalPoint;
metaData: unknown;
};
import { ColoringStrategyPicker } from "./ColoringStrategyPicker";
import { CanvasMode, CanvasModeRadioGroup } from "./CanvasModeRadioGroup";
import { PointCloudPoints } from "./PointCloudPoints";
import { ThreeDimensionalPointItem } from "./types";
import { ColoringStrategy } from "./types";
import { createColorFn } from "./coloring";

export type PointCloudProps = {
primaryData: ThreeDimensionalPointItem[];
referenceData?: ThreeDimensionalPointItem[];
referenceData: ThreeDimensionalPointItem[] | null;
};

enum CanvasMode {
move = "move",
select = "select",
}

/**
* TypeGuard for the canvas mode
*/
function isCanvasMode(m: unknown): m is CanvasMode {
return typeof m === "string" && m in CanvasMode;
}

const CONTROL_PANEL_WIDTH = 300;
const DEFAULT_COLOR_SCHEME = ColorSchemes.Discrete2.WhiteLightBlue;
/**
* Displays the tools available on the point cloud
* E.g. move vs select
*/
function CanvasTools(props: {
mode: CanvasMode;
onChange: (mode: CanvasMode) => void;
coloringStrategy: ColoringStrategy;
onColoringStrategyChange: (strategy: ColoringStrategy) => void;
canvasMode: CanvasMode;
onCanvasModeChange: (mode: CanvasMode) => void;
}) {
const {
coloringStrategy,
onColoringStrategyChange,
canvasMode,
onCanvasModeChange,
} = props;
return (
<div
css={css`
Expand All @@ -68,62 +49,20 @@ function CanvasTools(props: {
left: ${theme.spacing.margin8}px;
top: ${theme.spacing.margin8}px;
z-index: 1;
display: flex;
flex-direction: row;
gap: ${theme.spacing.margin8}px;
`}
>
<RadioGroup
defaultValue={props.mode}
variant="inline-button"
size="compact"
onChange={(v) => {
if (isCanvasMode(v)) {
props.onChange(v);
} else {
throw new Error(`Unknown canvas mode: ${v}`);
}
}}
>
<Radio label="Move" value={CanvasMode.move}>
<Icon svg={<Icons.MoveFilled />} />
</Radio>
<Radio label="Select" value={CanvasMode.select}>
<Icon svg={<Icons.LassoOutline />} />
</Radio>
</RadioGroup>
<ColoringStrategyPicker
strategy={coloringStrategy}
onChange={onColoringStrategyChange}
/>
<CanvasModeRadioGroup mode={canvasMode} onChange={onCanvasModeChange} />
</div>
);
}

function AccordionSection({ children }: { children: ReactNode }) {
return (
<section
css={css`
margin: ${theme.spacing.margin8};
`}
>
{children}
</section>
);
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
function DisplayControlPanel() {
return (
<ControlPanel position="top-left" width={CONTROL_PANEL_WIDTH}>
<Accordion variant="compact">
<AccordionItem title="Display" id="display">
<AccordionSection>
<Form>
<Picker label="Color by" defaultSelectedKey={"dataset"}>
<Item key="dataset">Dataset</Item>
</Picker>
</Form>
</AccordionSection>
</AccordionItem>
</Accordion>
</ControlPanel>
);
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
function SelectionControlPanel({ selectedIds }: { selectedIds: Set<string> }) {
return (
Expand All @@ -137,64 +76,11 @@ function SelectionControlPanel({ selectedIds }: { selectedIds: Set<string> }) {
);
}

function UMAPPoints({
primaryData,
referenceData,
selectedIds,
}: PointCloudProps & { selectedIds: Set<string> }) {
const primaryColor = "#7BFFFF";
const referenceColor = "#d57bff";
/** Colors to represent a dimmed variant of the color for "un-selected" */
const dimmedPrimaryColor = useMemo(() => {
// if (typeof primaryColor === "function") {
// return (p: PointBaseProps) => shade(DIM_AMOUNT)(primaryColor(p));
// }
return shade(DIM_AMOUNT, primaryColor);
}, [primaryColor]);

const dimmedReferenceColor = useMemo(() => {
// if (typeof referenceColor === "function") {
// return (p: PointBaseProps) => shade(DIM_AMOUNT)(referenceColor(p));
// }
return shade(DIM_AMOUNT, referenceColor);
}, [referenceColor]);

const primaryColorByFn = useCallback(
(p: PointBaseProps) => {
if (!selectedIds.has(p.metaData.id) && selectedIds.size > 0) {
return dimmedPrimaryColor;
}
return primaryColor;
},
[selectedIds, primaryColor, dimmedPrimaryColor]
);

const referenceColorByFn = useCallback(
(p: PointBaseProps) => {
if (!selectedIds.has(p.metaData.id) && selectedIds.size > 0) {
return dimmedReferenceColor;
}
return referenceColor;
},
[referenceColor, selectedIds, dimmedReferenceColor]
);

return (
<>
<Points data={primaryData} pointProps={{ color: primaryColorByFn }} />
{referenceData && (
<Points
data={referenceData}
pointProps={{ color: referenceColorByFn }}
/>
)}
</>
);
}

export function PointCloud({ primaryData, referenceData }: PointCloudProps) {
// AutoRotate the canvas on initial load
const [autoRotate, setAutoRotate] = useState<boolean>(true);
const [coloringStrategy, onColoringStrategyChange] =
useState<ColoringStrategy>(ColoringStrategy.dataset);
const [canvasMode, setCanvasMode] = useState<CanvasMode>(CanvasMode.move);
const [selectedIds, setSelectedIds] = useState<Set<string>>(new Set());
const allPoints = useMemo(() => {
Expand All @@ -204,11 +90,25 @@ export function PointCloud({ primaryData, referenceData }: PointCloudProps) {
return getThreeDimensionalBounds(allPoints.map((p) => p.position));
}, []);
const isMoveMode = canvasMode === CanvasMode.move;

// Determine the color of the points based on the strategy
const primaryColor = createColorFn({
coloringStrategy,
defaultColor: DEFAULT_COLOR_SCHEME[0],
});
const referenceColor = createColorFn({
coloringStrategy,
defaultColor: DEFAULT_COLOR_SCHEME[1],
});
return (
<ErrorBoundary>
{/* <DisplayControlPanel /> */}
{/* <SelectionControlPanel selectedIds={selectedIds} /> */}
<CanvasTools mode={canvasMode} onChange={setCanvasMode} />
<CanvasTools
coloringStrategy={coloringStrategy}
onColoringStrategyChange={onColoringStrategyChange}
canvasMode={canvasMode}
onCanvasModeChange={setCanvasMode}
/>
<ThreeDimensionalCanvas camera={{ position: [0, 0, 10] }}>
<ThreeDimensionalControls
autoRotate={autoRotate}
Expand All @@ -229,10 +129,12 @@ export function PointCloud({ primaryData, referenceData }: PointCloudProps) {
enabled={canvasMode === CanvasMode.select}
/>

<UMAPPoints
<PointCloudPoints
primaryData={primaryData}
referenceData={referenceData}
selectedIds={selectedIds}
primaryColor={primaryColor}
referenceColor={referenceColor}
/>
</ThreeDimensionalBounds>
</ThreeDimensionalCanvas>
Expand Down
Loading