Skip to content

Commit 2564301

Browse files
fix(ui): fix canvas model switching (#4221)
## What type of PR is this? (check all applicable) - [ ] Refactor - [ ] Feature - [x] Bug Fix - [ ] Optimization - [ ] Documentation Update - [ ] Community Node Submission ## Have you discussed this change with the InvokeAI team? - [x] Yes - [ ] No, because: ## Description There was no check at all to see if the canvas had a valid model already selected. The first model in the list was selected every time. Now, we check if its valid. If not, we go through the logic to try and pick the first valid model. If there are no valid models, or there was a problem listing models, the model selection is cleared. ## Related Tickets & Documents <!-- For pull requests that relate or close an issue, please include them below. For example having the text: "closes #1234" would connect the current pull request to issue 1234. And when we merge the pull request, Github will automatically close the issue. --> - Closes #4125 ## QA Instructions, Screenshots, Recordings <!-- Please provide steps on how to test changes, any hardware or software specifications as well as any other pertinent information. --> - Go to Canvas tab - Select a model other than the first one in the list - Go to a different tab - Go back to Canvas tab - The model should be the same as you selected
2 parents 49cce1e + da0efea commit 2564301

File tree

1 file changed

+34
-31
lines changed
  • invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners

1 file changed

+34
-31
lines changed

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,58 @@
11
import { modelChanged } from 'features/parameters/store/generationSlice';
22
import { setActiveTab } from 'features/ui/store/uiSlice';
3-
import { forEach } from 'lodash-es';
43
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
5-
import {
6-
MainModelConfigEntity,
7-
modelsApi,
8-
} from 'services/api/endpoints/models';
4+
import { mainModelsAdapter, modelsApi } from 'services/api/endpoints/models';
95
import { startAppListening } from '..';
106

117
export const addTabChangedListener = () => {
128
startAppListening({
139
actionCreator: setActiveTab,
14-
effect: (action, { getState, dispatch }) => {
10+
effect: async (action, { getState, dispatch }) => {
1511
const activeTabName = action.payload;
1612
if (activeTabName === 'unifiedCanvas') {
17-
// grab the models from RTK Query cache
18-
const { data } = modelsApi.endpoints.getMainModels.select(
19-
NON_REFINER_BASE_MODELS
20-
)(getState());
13+
const currentBaseModel = getState().generation.model?.base_model;
2114

22-
if (!data) {
23-
// no models yet, so we can't do anything
24-
dispatch(modelChanged(null));
15+
if (currentBaseModel && ['sd-1', 'sd-2'].includes(currentBaseModel)) {
16+
// if we're already on a valid model, no change needed
2517
return;
2618
}
2719

28-
// need to filter out all the invalid canvas models (currently, this is just sdxl)
29-
const validCanvasModels: MainModelConfigEntity[] = [];
30-
31-
forEach(data.entities, (entity) => {
32-
if (!entity) {
20+
try {
21+
// just grab fresh models
22+
const modelsRequest = dispatch(
23+
modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS)
24+
);
25+
const models = await modelsRequest.unwrap();
26+
// cancel this cache subscription
27+
modelsRequest.unsubscribe();
28+
29+
if (!models.ids.length) {
30+
// no valid canvas models
31+
dispatch(modelChanged(null));
3332
return;
3433
}
35-
if (['sd-1', 'sd-2'].includes(entity.base_model)) {
36-
validCanvasModels.push(entity);
34+
35+
// need to filter out all the invalid canvas models (currently sdxl & refiner)
36+
const validCanvasModels = mainModelsAdapter
37+
.getSelectors()
38+
.selectAll(models)
39+
.filter((model) => ['sd-1', 'sd-2'].includes(model.base_model));
40+
41+
const firstValidCanvasModel = validCanvasModels[0];
42+
43+
if (!firstValidCanvasModel) {
44+
// no valid canvas models
45+
dispatch(modelChanged(null));
46+
return;
3747
}
38-
});
3948

40-
// this could still be undefined even tho TS doesn't say so
41-
const firstValidCanvasModel = validCanvasModels[0];
49+
const { base_model, model_name, model_type } = firstValidCanvasModel;
4250

43-
if (!firstValidCanvasModel) {
44-
// uh oh, we have no models that are valid for canvas
51+
dispatch(modelChanged({ base_model, model_name, model_type }));
52+
} catch {
53+
// network request failed, bail
4554
dispatch(modelChanged(null));
46-
return;
4755
}
48-
49-
// only store the model name and base model in redux
50-
const { base_model, model_name, model_type } = firstValidCanvasModel;
51-
52-
dispatch(modelChanged({ base_model, model_name, model_type }));
5356
}
5457
},
5558
});

0 commit comments

Comments
 (0)