Skip to content

Commit 6648dbd

Browse files
fix(ui): post-onnx fixes
1 parent e86925d commit 6648dbd

File tree

18 files changed

+228
-105
lines changed

18 files changed

+228
-105
lines changed
Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
#!/usr/bin/env sh
22
. "$(dirname -- "$0")/_/husky.sh"
33

4-
python -m black . --check
5-
64
cd invokeai/frontend/web/ && npm run lint-staged

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import {
66
modelChanged,
77
vaeSelected,
88
} from 'features/parameters/store/generationSlice';
9-
import { zMainModel } from 'features/parameters/types/parameterSchemas';
9+
import { zMainOrOnnxModel } from 'features/parameters/types/parameterSchemas';
1010
import { addToast } from 'features/system/store/systemSlice';
1111
import { makeToast } from 'features/system/util/makeToast';
1212
import { forEach } from 'lodash-es';
@@ -19,7 +19,7 @@ export const addModelSelectedListener = () => {
1919
const log = logger('models');
2020

2121
const state = getState();
22-
const result = zMainModel.safeParse(action.payload);
22+
const result = zMainOrOnnxModel.safeParse(action.payload);
2323

2424
if (!result.success) {
2525
log.error(

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import {
66
vaeSelected,
77
} from 'features/parameters/store/generationSlice';
88
import {
9-
zMainModel,
9+
zMainOrOnnxModel,
10+
zSDXLRefinerModel,
1011
zVaeModel,
1112
} from 'features/parameters/types/parameterSchemas';
1213
import {
@@ -53,7 +54,7 @@ export const addModelsLoadedListener = () => {
5354
return;
5455
}
5556

56-
const result = zMainModel.safeParse(firstModel);
57+
const result = zMainOrOnnxModel.safeParse(firstModel);
5758

5859
if (!result.success) {
5960
log.error(
@@ -102,7 +103,7 @@ export const addModelsLoadedListener = () => {
102103
return;
103104
}
104105

105-
const result = zMainModel.safeParse(firstModel);
106+
const result = zSDXLRefinerModel.safeParse(firstModel);
106107

107108
if (!result.success) {
108109
log.error(

invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ export const addLoRAsToGraph = (
1919
state: RootState,
2020
graph: NonNullableGraph,
2121
baseNodeId: string,
22-
modelLoader: string = MAIN_MODEL_LOADER
22+
modelLoaderNodeId: string = MAIN_MODEL_LOADER
2323
): void => {
2424
/**
2525
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
@@ -85,7 +85,7 @@ export const addLoRAsToGraph = (
8585
// first lora = start the lora chain, attach directly to model loader
8686
graph.edges.push({
8787
source: {
88-
node_id: modelLoader,
88+
node_id: modelLoaderNodeId,
8989
field: 'unet',
9090
},
9191
destination: {

invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import {
1717
export const addVAEToGraph = (
1818
state: RootState,
1919
graph: NonNullableGraph,
20-
modelLoader: string = MAIN_MODEL_LOADER
20+
modelLoaderNodeId: string = MAIN_MODEL_LOADER
2121
): void => {
2222
const { vae } = state.generation;
2323

@@ -34,11 +34,11 @@ export const addVAEToGraph = (
3434
vae_model: vae,
3535
};
3636
}
37-
const isOnnxModel = modelLoader == ONNX_MODEL_LOADER;
37+
const isOnnxModel = modelLoaderNodeId == ONNX_MODEL_LOADER;
3838
if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) {
3939
graph.edges.push({
4040
source: {
41-
node_id: isAutoVae ? modelLoader : VAE_LOADER,
41+
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
4242
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
4343
},
4444
destination: {
@@ -51,7 +51,7 @@ export const addVAEToGraph = (
5151
if (graph.id === IMAGE_TO_IMAGE_GRAPH) {
5252
graph.edges.push({
5353
source: {
54-
node_id: isAutoVae ? modelLoader : VAE_LOADER,
54+
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
5555
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
5656
},
5757
destination: {
@@ -64,7 +64,7 @@ export const addVAEToGraph = (
6464
if (graph.id === INPAINT_GRAPH) {
6565
graph.edges.push({
6666
source: {
67-
node_id: isAutoVae ? modelLoader : VAE_LOADER,
67+
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
6868
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
6969
},
7070
destination: {

invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ import {
2020
TEXT_TO_IMAGE_GRAPH,
2121
TEXT_TO_LATENTS,
2222
} from './constants';
23+
import {
24+
ONNXTextToLatentsInvocation,
25+
TextToLatentsInvocation,
26+
} from 'services/api/types';
2327

2428
/**
2529
* Builds the Canvas tab's Text to Image graph.
@@ -53,8 +57,31 @@ export const buildCanvasTextToImageGraph = (
5357
const use_cpu = shouldUseNoiseSettings
5458
? shouldUseCpuNoise
5559
: initialGenerationState.shouldUseCpuNoise;
56-
const onnx_model_type = model.model_type.includes('onnx');
57-
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
60+
const isUsingOnnxModel = model.model_type === 'onnx';
61+
const modelLoaderNodeId = isUsingOnnxModel
62+
? ONNX_MODEL_LOADER
63+
: MAIN_MODEL_LOADER;
64+
const modelLoaderNodeType = isUsingOnnxModel
65+
? 'onnx_model_loader'
66+
: 'main_model_loader';
67+
const t2lNode: TextToLatentsInvocation | ONNXTextToLatentsInvocation =
68+
isUsingOnnxModel
69+
? {
70+
type: 't2l_onnx',
71+
id: TEXT_TO_LATENTS,
72+
is_intermediate: true,
73+
cfg_scale,
74+
scheduler,
75+
steps,
76+
}
77+
: {
78+
type: 't2l',
79+
id: TEXT_TO_LATENTS,
80+
is_intermediate: true,
81+
cfg_scale,
82+
scheduler,
83+
steps,
84+
};
5885
/**
5986
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
6087
* full graph here as a template. Then use the parameters from app state and set friendlier node
@@ -70,13 +97,13 @@ export const buildCanvasTextToImageGraph = (
7097
id: TEXT_TO_IMAGE_GRAPH,
7198
nodes: {
7299
[POSITIVE_CONDITIONING]: {
73-
type: onnx_model_type ? 'prompt_onnx' : 'compel',
100+
type: isUsingOnnxModel ? 'prompt_onnx' : 'compel',
74101
id: POSITIVE_CONDITIONING,
75102
is_intermediate: true,
76103
prompt: positivePrompt,
77104
},
78105
[NEGATIVE_CONDITIONING]: {
79-
type: onnx_model_type ? 'prompt_onnx' : 'compel',
106+
type: isUsingOnnxModel ? 'prompt_onnx' : 'compel',
80107
id: NEGATIVE_CONDITIONING,
81108
is_intermediate: true,
82109
prompt: negativePrompt,
@@ -89,17 +116,10 @@ export const buildCanvasTextToImageGraph = (
89116
height,
90117
use_cpu,
91118
},
92-
[TEXT_TO_LATENTS]: {
93-
type: onnx_model_type ? 't2l_onnx' : 't2l',
94-
id: TEXT_TO_LATENTS,
95-
is_intermediate: true,
96-
cfg_scale,
97-
scheduler,
98-
steps,
99-
},
100-
[model_loader]: {
101-
type: model_loader,
102-
id: model_loader,
119+
[t2lNode.id]: t2lNode,
120+
[modelLoaderNodeId]: {
121+
type: modelLoaderNodeType,
122+
id: modelLoaderNodeId,
103123
is_intermediate: true,
104124
model,
105125
},
@@ -110,7 +130,7 @@ export const buildCanvasTextToImageGraph = (
110130
skipped_layers: clipSkip,
111131
},
112132
[LATENTS_TO_IMAGE]: {
113-
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
133+
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
114134
id: LATENTS_TO_IMAGE,
115135
is_intermediate: !shouldAutoSave,
116136
},
@@ -138,7 +158,7 @@ export const buildCanvasTextToImageGraph = (
138158
},
139159
{
140160
source: {
141-
node_id: model_loader,
161+
node_id: modelLoaderNodeId,
142162
field: 'clip',
143163
},
144164
destination: {
@@ -168,7 +188,7 @@ export const buildCanvasTextToImageGraph = (
168188
},
169189
{
170190
source: {
171-
node_id: model_loader,
191+
node_id: modelLoaderNodeId,
172192
field: 'unet',
173193
},
174194
destination: {
@@ -232,10 +252,10 @@ export const buildCanvasTextToImageGraph = (
232252
});
233253

234254
// add LoRA support
235-
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS, model_loader);
255+
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS, modelLoaderNodeId);
236256

237257
// optionally add custom VAE
238-
addVAEToGraph(state, graph, model_loader);
258+
addVAEToGraph(state, graph, modelLoaderNodeId);
239259

240260
// add dynamic prompts - also sets up core iteration and seed
241261
addDynamicPromptsToGraph(state, graph);

invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ import {
2020
TEXT_TO_IMAGE_GRAPH,
2121
TEXT_TO_LATENTS,
2222
} from './constants';
23+
import {
24+
ONNXTextToLatentsInvocation,
25+
TextToLatentsInvocation,
26+
} from 'services/api/types';
2327

2428
export const buildLinearTextToImageGraph = (
2529
state: RootState
@@ -49,8 +53,31 @@ export const buildLinearTextToImageGraph = (
4953
throw new Error('No model found in state');
5054
}
5155

52-
const onnx_model_type = model.model_type.includes('onnx');
53-
const model_loader = onnx_model_type ? ONNX_MODEL_LOADER : MAIN_MODEL_LOADER;
56+
const isUsingOnnxModel = model.model_type === 'onnx';
57+
const modelLoaderNodeId = isUsingOnnxModel
58+
? ONNX_MODEL_LOADER
59+
: MAIN_MODEL_LOADER;
60+
const modelLoaderNodeType = isUsingOnnxModel
61+
? 'onnx_model_loader'
62+
: 'main_model_loader';
63+
const t2lNode: TextToLatentsInvocation | ONNXTextToLatentsInvocation =
64+
isUsingOnnxModel
65+
? {
66+
type: 't2l_onnx',
67+
id: TEXT_TO_LATENTS,
68+
is_intermediate: true,
69+
cfg_scale,
70+
scheduler,
71+
steps,
72+
}
73+
: {
74+
type: 't2l',
75+
id: TEXT_TO_LATENTS,
76+
is_intermediate: true,
77+
cfg_scale,
78+
scheduler,
79+
steps,
80+
};
5481
/**
5582
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
5683
* full graph here as a template. Then use the parameters from app state and set friendlier node
@@ -66,50 +93,49 @@ export const buildLinearTextToImageGraph = (
6693
const graph: NonNullableGraph = {
6794
id: TEXT_TO_IMAGE_GRAPH,
6895
nodes: {
69-
[model_loader]: {
70-
type: model_loader,
71-
id: model_loader,
72-
model,
73-
},
74-
[CLIP_SKIP]: {
75-
type: 'clip_skip',
76-
id: CLIP_SKIP,
77-
skipped_layers: clipSkip,
78-
},
7996
[POSITIVE_CONDITIONING]: {
80-
type: onnx_model_type ? 'prompt_onnx' : 'compel',
97+
type: isUsingOnnxModel ? 'prompt_onnx' : 'compel',
8198
id: POSITIVE_CONDITIONING,
8299
prompt: positivePrompt,
100+
is_intermediate: true,
83101
},
84102
[NEGATIVE_CONDITIONING]: {
85-
type: onnx_model_type ? 'prompt_onnx' : 'compel',
103+
type: isUsingOnnxModel ? 'prompt_onnx' : 'compel',
86104
id: NEGATIVE_CONDITIONING,
87105
prompt: negativePrompt,
106+
is_intermediate: true,
88107
},
89108
[NOISE]: {
90109
type: 'noise',
91110
id: NOISE,
92111
width,
93112
height,
94113
use_cpu,
114+
is_intermediate: true,
115+
},
116+
[t2lNode.id]: t2lNode,
117+
[modelLoaderNodeId]: {
118+
type: modelLoaderNodeType,
119+
id: modelLoaderNodeId,
120+
is_intermediate: true,
121+
model,
95122
},
96-
[TEXT_TO_LATENTS]: {
97-
type: onnx_model_type ? 't2l_onnx' : 't2l',
98-
id: TEXT_TO_LATENTS,
99-
cfg_scale,
100-
scheduler,
101-
steps,
123+
[CLIP_SKIP]: {
124+
type: 'clip_skip',
125+
id: CLIP_SKIP,
126+
skipped_layers: clipSkip,
127+
is_intermediate: true,
102128
},
103129
[LATENTS_TO_IMAGE]: {
104-
type: onnx_model_type ? 'l2i_onnx' : 'l2i',
130+
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
105131
id: LATENTS_TO_IMAGE,
106132
fp32: vaePrecision === 'fp32' ? true : false,
107133
},
108134
},
109135
edges: [
110136
{
111137
source: {
112-
node_id: model_loader,
138+
node_id: modelLoaderNodeId,
113139
field: 'clip',
114140
},
115141
destination: {
@@ -119,7 +145,7 @@ export const buildLinearTextToImageGraph = (
119145
},
120146
{
121147
source: {
122-
node_id: model_loader,
148+
node_id: modelLoaderNodeId,
123149
field: 'unet',
124150
},
125151
destination: {
@@ -223,10 +249,10 @@ export const buildLinearTextToImageGraph = (
223249
});
224250

225251
// add LoRA support
226-
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS, model_loader);
252+
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS, modelLoaderNodeId);
227253

228254
// optionally add custom VAE
229-
addVAEToGraph(state, graph, model_loader);
255+
addVAEToGraph(state, graph, modelLoaderNodeId);
230256

231257
// add dynamic prompts - also sets up core iteration and seed
232258
addDynamicPromptsToGraph(state, graph);

invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import {
3535
isValidSDXLNegativeStylePrompt,
3636
isValidSDXLPositiveStylePrompt,
3737
isValidSDXLRefinerAestheticScore,
38+
isValidSDXLRefinerModel,
3839
isValidSDXLRefinerStart,
3940
isValidScheduler,
4041
isValidSeed,
@@ -381,7 +382,7 @@ export const useRecallParameters = () => {
381382
dispatch(setNegativeStylePromptSDXL(negative_style_prompt));
382383
}
383384

384-
if (isValidMainModel(refiner_model)) {
385+
if (isValidSDXLRefinerModel(refiner_model)) {
385386
dispatch(refinerModelChanged(refiner_model));
386387
}
387388

0 commit comments

Comments
 (0)