From b43cfc3b06bc832a569c8b0e9b1dd7a17a754ced Mon Sep 17 00:00:00 2001 From: imanjra Date: Fri, 15 Nov 2024 09:04:30 -0500 Subject: [PATCH] various model evaluation fixes and enhancements --- .../NativeModelEvaluationView/Evaluation.tsx | 394 ++++++++++++++---- .../EvaluationNotes.tsx | 13 +- .../EvaluationPlot.tsx | 2 + .../NativeModelEvaluationView/index.tsx | 3 +- .../NativeModelEvaluationView/utils.ts | 12 +- .../panels/model_evaluation/__init__.py | 123 ++++-- 6 files changed, 427 insertions(+), 120 deletions(-) diff --git a/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx b/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx index 551bb638c3..7bd6a5a11d 100644 --- a/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx +++ b/app/packages/core/src/plugins/SchemaIO/components/NativeModelEvaluationView/Evaluation.tsx @@ -1,4 +1,5 @@ import { Dialog } from "@fiftyone/components"; +import { view } from "@fiftyone/state"; import { ArrowBack, ArrowDropDown, @@ -19,12 +20,15 @@ import { Box, Button, Card, + Checkbox, CircularProgress, + FormControlLabel, IconButton, MenuItem, Select, Stack, styled, + SxProps, Table, TableBody, TableCell, @@ -37,6 +41,7 @@ import { useTheme, } from "@mui/material"; import React, { useEffect, useMemo, useState } from "react"; +import { useRecoilState } from "recoil"; import EvaluationNotes from "./EvaluationNotes"; import EvaluationPlot from "./EvaluationPlot"; import Status from "./Status"; @@ -44,6 +49,7 @@ import { formatValue, getNumericDifference, useTriggerEvent } from "./utils"; const KEY_COLOR = "#ff6d04"; const COMPARE_KEY_COLOR = "#03a9f4"; +const DEFAULT_BAR_CONFIG = { sortBy: "az" }; export default function Evaluation(props: EvaluationProps) { const { @@ -64,20 +70,18 @@ export default function Evaluation(props: EvaluationProps) { const [expanded, setExpanded] = React.useState("summary"); const [mode, setMode] = useState("chart"); const [editNoteState, setEditNoteState] = useState({ open: false, note: "" }); - const [barConfigState, setBarConfigState] = useState({ - sortBy: "", - limit: 0, - }); - const [barConfigDialogState, setBarConfigDialogState] = useState({ - open: false, - sortBy: "best", - limit: 20, - }); + const [classPerformanceConfig, setClassPerformanceConfig] = + useState({}); + const [classPerformanceDialogConfig, setClassPerformanceDialogConfig] = + useState(DEFAULT_BAR_CONFIG); + const [confusionMatrixConfig, setConfusionMatrixConfig] = + useState({ log: true }); + const [confusionMatrixDialogConfig, setConfusionMatrixDialogConfig] = + useState(DEFAULT_BAR_CONFIG); const [metricMode, setMetricMode] = useState("chart"); const [classMode, setClassMode] = useState("chart"); const [performanceClass, setPerformanceClass] = useState("precision"); const [loadingCompare, setLoadingCompare] = useState(false); - const [viewState, setViewState] = useState({ type: "", view: {} }); const evaluation = useMemo(() => { const evaluation = data?.[`evaluation_${name}`]; return evaluation; @@ -87,11 +91,14 @@ export default function Evaluation(props: EvaluationProps) { return evaluation; }, [data]); const confusionMatrix = useMemo(() => { - return evaluation?.confusion_matrix; - }, [evaluation]); + return getMatrix(evaluation?.confusion_matrices, confusionMatrixConfig); + }, [evaluation, confusionMatrixConfig]); const compareConfusionMatrix = useMemo(() => { - return compareEvaluation?.confusion_matrix; - }, [compareEvaluation]); + return getMatrix( + compareEvaluation?.confusion_matrices, + confusionMatrixConfig + ); + }, [compareEvaluation, confusionMatrixConfig]); const compareKeys = useMemo(() => { const keys: string[] = []; const evaluations = data?.evaluations || []; @@ -125,12 +132,16 @@ export default function Evaluation(props: EvaluationProps) { }, [compareEvaluation, compareKey]); const triggerEvent = useTriggerEvent(); + const activeFilter = useActiveFilter(evaluation, compareEvaluation); const closeNoteDialog = () => { setEditNoteState((note) => ({ ...note, open: false })); }; - const closeBarConfigDialog = () => { - setBarConfigDialogState((state) => ({ ...state, open: false })); + const closeClassPerformanceConfigDialog = () => { + setClassPerformanceDialogConfig((state) => ({ ...state, open: false })); + }; + const closeConfusionMatrixConfigDialog = () => { + setConfusionMatrixDialogConfig((state) => ({ ...state, open: false })); }; if (!evaluation) { @@ -360,6 +371,12 @@ export default function Evaluation(props: EvaluationProps) { value: evaluationMetrics.tp, compareValue: compareEvaluationMetrics.tp, filterable: true, + active: + activeFilter?.value === "tp" + ? activeFilter.isCompare + ? "compare" + : "selected" + : false, }, { id: "fp", @@ -368,6 +385,12 @@ export default function Evaluation(props: EvaluationProps) { compareValue: compareEvaluationMetrics.fp, lesserIsBetter: true, filterable: true, + active: + activeFilter?.value === "fp" + ? activeFilter.isCompare + ? "compare" + : "selected" + : false, }, { id: "fn", @@ -376,6 +399,12 @@ export default function Evaluation(props: EvaluationProps) { compareValue: compareEvaluationMetrics.fn, lesserIsBetter: true, filterable: true, + active: + activeFilter?.value === "fn" + ? activeFilter.isCompare + ? "compare" + : "selected" + : false, }, ]; @@ -400,8 +429,12 @@ export default function Evaluation(props: EvaluationProps) { const performanceClasses = Object.keys(perClassPerformance); const classPerformance = formatPerClassPerformance( perClassPerformance[performanceClass], - barConfigState + classPerformanceConfig ); + const selectedPoints = + activeFilter?.type === "label" + ? [classPerformance.findIndex((c) => c.id === activeFilter.value)] + : undefined; return ( @@ -596,6 +629,7 @@ export default function Evaluation(props: EvaluationProps) { lesserIsBetter, filterable, id: rowId, + active, } = row; const difference = getNumericDifference( value, @@ -608,10 +642,20 @@ export default function Evaluation(props: EvaluationProps) { 1 ); const positiveRatio = ratio > 0; - const ratioColor = positiveRatio ? "#8BC18D" : "#FF6464"; + const zeroRatio = ratio === 0; + const negativeRatio = ratio < 0; + const ratioColor = positiveRatio + ? "#8BC18D" + : negativeRatio + ? "#FF6464" + : theme.palette.text.tertiary; const showTrophy = lesserIsBetter ? difference < 0 : difference > 0; + const activeStyle: SxProps = { + backgroundColor: theme.palette.voxel["500"], + color: "#FFFFFF", + }; return ( @@ -635,13 +679,19 @@ export default function Evaluation(props: EvaluationProps) { )} {filterable && ( { loadView("field", { field: rowId }); }} title="Load view" > - + )} @@ -661,7 +711,13 @@ export default function Evaluation(props: EvaluationProps) { {filterable && ( { loadView("field", { field: rowId, @@ -686,13 +742,22 @@ export default function Evaluation(props: EvaluationProps) { direction="row" sx={{ alignItems: "center" }} > - {positiveRatio ? ( + {positiveRatio && ( - ) : ( + )} + {negativeRatio && ( )} + {zeroRatio && ( + + — + + )} {CLASS_LABELS[performanceClass]} Per Class + {getConfigLabel({ + config: classPerformanceConfig, + type: "classPerformance", + dashed: true, + })} { - setBarConfigDialogState((state) => ({ + setClassPerformanceDialogConfig((state) => ({ ...state, open: true, })); @@ -919,7 +989,7 @@ export default function Evaluation(props: EvaluationProps) { color: KEY_COLOR, }, key: name, - selectedpoints: viewState.view.selectedClasses, + selectedpoints: selectedPoints, }, { histfunc: "sum", @@ -933,29 +1003,14 @@ export default function Evaluation(props: EvaluationProps) { color: COMPARE_KEY_COLOR, }, key: compareKey, - selectedpoints: viewState.view.selectedCompareClasses, + selectedpoints: selectedPoints, }, ]} onClick={({ points }) => { - const x = points[0]?.x; - const key = points[0]?.data.key; - const isCompare = key === compareKey; - const index = points[0]?.pointIndices[0]; - const viewStateX = viewState.view.x; - if (viewStateX === x) { - setViewState({ type: "", view: {} }); - loadView("clear", {}); - return; + if (selectedPoints?.[0] === points[0]?.pointIndices[0]) { + return loadView("clear", {}); } - setViewState({ - type: "class", - view: { - x, - selectedClasses: isCompare ? [] : [index], - selectedCompareClasses: isCompare ? [index] : [], - }, - }); - loadView("class", { x }); + loadView("class", { x: points[0]?.x }); }} /> )} @@ -1039,6 +1094,23 @@ export default function Evaluation(props: EvaluationProps) { Confusion Matrices + + + {getConfigLabel({ config: confusionMatrixConfig })} + + + { + setConfusionMatrixDialogConfig((state) => ({ + ...state, + open: true, + })); + }} + > + + + + { @@ -1082,6 +1156,9 @@ export default function Evaluation(props: EvaluationProps) { x: compareConfusionMatrix?.labels, y: compareConfusionMatrix?.labels, type: "heatmap", + colorscale: confusionMatrixConfig.log + ? compareConfusionMatrix?.colorscale || "viridis" + : "viridis", }, ]} /> @@ -1163,7 +1240,7 @@ export default function Evaluation(props: EvaluationProps) { { @@ -1194,9 +1271,9 @@ export default function Evaluation(props: EvaluationProps) { theme.palette.background.level2 }, }} @@ -1211,50 +1288,141 @@ export default function Evaluation(props: EvaluationProps) { Limit bars: { - setBarConfigDialogState((state) => ({ + const newLimit = parseInt(e.target.value); + setClassPerformanceDialogConfig((state) => { + return { + ...state, + limit: isNaN(newLimit) ? undefined : newLimit, + }; + }); + }} + /> + + + + + + + + theme.palette.background.level2 }, + }} + > + + + + Display Options: Confusion Matrix + + + Sort by: + + + + Limit classes: + { + const newLimit = parseInt(e.target.value); + setConfusionMatrixDialogConfig((state) => { + return { + ...state, + limit: isNaN(newLimit) ? undefined : newLimit, + }; + }); + }} /> + { + setConfusionMatrixDialogConfig((state) => ({ + ...state, + log: checked, + })); + }} + /> + } + />