Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,55 +1,58 @@
import { modelChanged } from 'features/parameters/store/generationSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import {
MainModelConfigEntity,
modelsApi,
} from 'services/api/endpoints/models';
import { mainModelsAdapter, modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..';

export const addTabChangedListener = () => {
startAppListening({
actionCreator: setActiveTab,
effect: (action, { getState, dispatch }) => {
effect: async (action, { getState, dispatch }) => {
const activeTabName = action.payload;
if (activeTabName === 'unifiedCanvas') {
// grab the models from RTK Query cache
const { data } = modelsApi.endpoints.getMainModels.select(
NON_REFINER_BASE_MODELS
)(getState());
const currentBaseModel = getState().generation.model?.base_model;

if (!data) {
// no models yet, so we can't do anything
dispatch(modelChanged(null));
if (currentBaseModel && ['sd-1', 'sd-2'].includes(currentBaseModel)) {
// if we're already on a valid model, no change needed
return;
}

// need to filter out all the invalid canvas models (currently, this is just sdxl)
const validCanvasModels: MainModelConfigEntity[] = [];

forEach(data.entities, (entity) => {
if (!entity) {
try {
// just grab fresh models
const modelsRequest = dispatch(
modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS)
);
const models = await modelsRequest.unwrap();
// cancel this cache subscription
modelsRequest.unsubscribe();

if (!models.ids.length) {
// no valid canvas models
dispatch(modelChanged(null));
return;
}
if (['sd-1', 'sd-2'].includes(entity.base_model)) {
validCanvasModels.push(entity);

// need to filter out all the invalid canvas models (currently sdxl & refiner)
const validCanvasModels = mainModelsAdapter
.getSelectors()
.selectAll(models)
.filter((model) => ['sd-1', 'sd-2'].includes(model.base_model));

const firstValidCanvasModel = validCanvasModels[0];

if (!firstValidCanvasModel) {
// no valid canvas models
dispatch(modelChanged(null));
return;
}
});

// this could still be undefined even tho TS doesn't say so
const firstValidCanvasModel = validCanvasModels[0];
const { base_model, model_name, model_type } = firstValidCanvasModel;

if (!firstValidCanvasModel) {
// uh oh, we have no models that are valid for canvas
dispatch(modelChanged({ base_model, model_name, model_type }));
} catch {
// network request failed, bail
dispatch(modelChanged(null));
return;
}

// only store the model name and base model in redux
const { base_model, model_name, model_type } = firstValidCanvasModel;

dispatch(modelChanged({ base_model, model_name, model_type }));
}
},
});
Expand Down