From c37ec42b037478feb8337e9eddd52b6378b34a8f Mon Sep 17 00:00:00 2001 From: Andrea Papaleo Date: Mon, 22 Jul 2024 11:05:03 -0400 Subject: [PATCH] [fix] closes #456 - disposes tf model on setDefaults - clears history when model status changed to ininitialized --- .../ExampleProjectCard/ExampleProjectCard.tsx | 6 ++-- .../useClassifierModelAgain.tsx | 32 ++++++++++++++++--- src/store/classifier/classifierSlice.ts | 8 ++++- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/components/cards/ExampleProjectCard/ExampleProjectCard.tsx b/src/components/cards/ExampleProjectCard/ExampleProjectCard.tsx index 23506abb..ac1aaf2e 100644 --- a/src/components/cards/ExampleProjectCard/ExampleProjectCard.tsx +++ b/src/components/cards/ExampleProjectCard/ExampleProjectCard.tsx @@ -77,7 +77,7 @@ export const ExampleProjectCard = ({ break; case ExampleProject.HumanU2OSCells: exampleProjectFilePath = - process.env.NODE_ENV !== "production" + process.env.NODE_ENV === "production" ? `${domain}/${rootPath}/HumanU2OSCellsExampleProject.${ext}` : ( await import( @@ -134,14 +134,12 @@ export const ExampleProjectCard = ({ // loadPercent will be set to 1 here dispatch(dataSlice.actions.initializeState({ data })); dispatch(projectSlice.actions.setProject({ project })); - + dispatch(classifierSlice.actions.setDefaults({})); dispatch( classifierSlice.actions.setClassifier({ classifier, }) ); - - dispatch(classifierSlice.actions.setDefaults({})); }); } catch (err) { const error: Error = err as Error; diff --git a/src/hooks/useLearningModel/useClassifierModelAgain.tsx b/src/hooks/useLearningModel/useClassifierModelAgain.tsx index d813c3a2..dcc25c66 100644 --- a/src/hooks/useLearningModel/useClassifierModelAgain.tsx +++ b/src/hooks/useLearningModel/useClassifierModelAgain.tsx @@ -1,10 +1,9 @@ -import { useEffect, useState } from "react"; +import { useEffect, useMemo, useState } from "react"; import { useDispatch, useSelector } from "react-redux"; import { selectAlertState } from "store/applicationSettings/selectors"; import { classifierSlice } from "store/classifier"; import { selectClassifierFitOptions, - selectClassifierHistory, selectClassifierModelStatus, selectClassifierSelectedModel, selectClassifierTrainingPercentage, @@ -48,9 +47,22 @@ export const useClassificationModelAgain = () => { const alertState = useSelector(selectAlertState); const fitOptions = useSelector(selectClassifierFitOptions); const trainingPercentage = useSelector(selectClassifierTrainingPercentage); - const modelHistory = useSelector((state) => { - return selectClassifierHistory(state, historyItems); - }); + + const modelHistory = useMemo(() => { + const fullHistory = selectedModel.history.history; + const selectedHistory: { [key: string]: number[] } = {}; + for (const k of historyItems) { + if (k === "epochs") { + selectedHistory[k] = selectedModel.history.epochs; + } else { + selectedHistory[k] = fullHistory.flatMap( + (cycleHistory) => cycleHistory[k] + ); + } + } + + return selectedHistory; + }, [selectedModel]); const noLabeledThingsAlert: AlertState = { alertType: AlertType.Info, name: "No labeled images", @@ -122,6 +134,7 @@ export const useClassificationModelAgain = () => { setShowWarning(true); } }, [labeledThingsCount, selectedModel]); + useEffect(() => { setTrainingAccuracy( modelHistory.categoricalAccuracy.map((y, i) => ({ x: i + 0.5, y })) @@ -174,6 +187,15 @@ export const useClassificationModelAgain = () => { } }, [fitOptions.batchSize, trainingPercentage, labeledThingsCount]); + useEffect(() => { + if (modelStatus === ModelStatus.Uninitialized) { + setTrainingAccuracy([]); + setValidationAccuracy([]); + setTrainingLoss([]); + setValidationLoss([]); + setShowPlots(false); + } + }, [modelStatus]); return { showWarning, setShowWarning, diff --git a/src/store/classifier/classifierSlice.ts b/src/store/classifier/classifierSlice.ts index 28d42d69..1d37ace1 100644 --- a/src/store/classifier/classifierSlice.ts +++ b/src/store/classifier/classifierSlice.ts @@ -59,7 +59,10 @@ export const classifierSlice = createSlice({ name: "classifier", initialState: initialState, reducers: { - resetClassifier: () => initialState, + resetClassifier: (state) => { + availableClassifierModels[state.selectedModelIdx].dispose(); + return initialState; + }, setClassifier( state, action: PayloadAction<{ classifier: ClassifierState }> @@ -70,6 +73,9 @@ export const classifierSlice = createSlice({ }, setDefaults(state, action: PayloadAction<{}>) { // TODO - segmenter: dispose() and state.selectedModel = SimpleCNN(), or whatever + + availableClassifierModels[state.selectedModelIdx].dispose(); + state.modelStatus = ModelStatus.Uninitialized; state.evaluationResult = { confusionMatrix: [],