Skip to content

Commit 9a1cfad

Browse files
authored
fix: SDXL Metadata not being retrieved (#4057)
## What type of PR is this? (check all applicable) - [x] Bug Fix ## Have you discussed this change with the InvokeAI team? - [x] Yes ## Description - SDXL Metadata was not being retrieved. This PR fixes it.
2 parents 974175b + 6d82a10 commit 9a1cfad

File tree

8 files changed

+493
-329
lines changed

8 files changed

+493
-329
lines changed

invokeai/app/invocations/params.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66

77
from invokeai.app.invocations.prompt import PromptOutput
88

9-
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
10-
InvocationConfig, InvocationContext)
9+
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
1110
from .math import FloatOutput, IntOutput
1211

1312
# Pass-through parameter nodes - used by subgraphs
@@ -68,6 +67,7 @@ class Config(InvocationConfig):
6867
def invoke(self, context: InvocationContext) -> StringOutput:
6968
return StringOutput(text=self.text)
7069

70+
7171
class ParamPromptInvocation(BaseInvocation):
7272
"""A prompt input parameter"""
7373

@@ -80,4 +80,4 @@ class Config(InvocationConfig):
8080
}
8181

8282
def invoke(self, context: InvocationContext) -> PromptOutput:
83-
return PromptOutput(prompt=self.prompt)
83+
return PromptOutput(prompt=self.prompt)

invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
139139
useHotkeys('s', handleUseSeed, [imageDTO]);
140140

141141
const handleUsePrompt = useCallback(() => {
142-
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
143-
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
142+
recallBothPrompts(
143+
metadata?.positive_prompt,
144+
metadata?.negative_prompt,
145+
metadata?.positive_style_prompt,
146+
metadata?.negative_style_prompt
147+
);
148+
}, [
149+
metadata?.negative_prompt,
150+
metadata?.positive_prompt,
151+
metadata?.positive_style_prompt,
152+
metadata?.negative_style_prompt,
153+
recallBothPrompts,
154+
]);
144155

145156
useHotkeys('p', handleUsePrompt, [imageDTO]);
146157

invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,19 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
102102

103103
// Recall parameters handlers
104104
const handleRecallPrompt = useCallback(() => {
105-
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
106-
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
105+
recallBothPrompts(
106+
metadata?.positive_prompt,
107+
metadata?.negative_prompt,
108+
metadata?.positive_style_prompt,
109+
metadata?.negative_style_prompt
110+
);
111+
}, [
112+
metadata?.negative_prompt,
113+
metadata?.positive_prompt,
114+
metadata?.positive_style_prompt,
115+
metadata?.negative_style_prompt,
116+
recallBothPrompts,
117+
]);
107118

108119
const handleRecallSeed = useCallback(() => {
109120
recallSeed(metadata?.seed);

invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
import { useAppToaster } from 'app/components/Toaster';
22
import { useAppDispatch } from 'app/store/storeHooks';
3+
import {
4+
refinerModelChanged,
5+
setNegativeStylePromptSDXL,
6+
setPositiveStylePromptSDXL,
7+
setRefinerAestheticScore,
8+
setRefinerCFGScale,
9+
setRefinerScheduler,
10+
setRefinerStart,
11+
setRefinerSteps,
12+
} from 'features/sdxl/store/sdxlSlice';
313
import { useCallback } from 'react';
414
import { useTranslation } from 'react-i18next';
515
import { UnsafeImageMetadata } from 'services/api/endpoints/images';
@@ -22,6 +32,10 @@ import {
2232
isValidMainModel,
2333
isValidNegativePrompt,
2434
isValidPositivePrompt,
35+
isValidSDXLNegativeStylePrompt,
36+
isValidSDXLPositiveStylePrompt,
37+
isValidSDXLRefinerAestheticScore,
38+
isValidSDXLRefinerStart,
2539
isValidScheduler,
2640
isValidSeed,
2741
isValidSteps,
@@ -74,17 +88,34 @@ export const useRecallParameters = () => {
7488
* Recall both prompts with toast
7589
*/
7690
const recallBothPrompts = useCallback(
77-
(positivePrompt: unknown, negativePrompt: unknown) => {
91+
(
92+
positivePrompt: unknown,
93+
negativePrompt: unknown,
94+
positiveStylePrompt: unknown,
95+
negativeStylePrompt: unknown
96+
) => {
7897
if (
7998
isValidPositivePrompt(positivePrompt) ||
80-
isValidNegativePrompt(negativePrompt)
99+
isValidNegativePrompt(negativePrompt) ||
100+
isValidSDXLPositiveStylePrompt(positiveStylePrompt) ||
101+
isValidSDXLNegativeStylePrompt(negativeStylePrompt)
81102
) {
82103
if (isValidPositivePrompt(positivePrompt)) {
83104
dispatch(setPositivePrompt(positivePrompt));
84105
}
106+
85107
if (isValidNegativePrompt(negativePrompt)) {
86108
dispatch(setNegativePrompt(negativePrompt));
87109
}
110+
111+
if (isValidSDXLPositiveStylePrompt(positiveStylePrompt)) {
112+
dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
113+
}
114+
115+
if (isValidSDXLPositiveStylePrompt(negativeStylePrompt)) {
116+
dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
117+
}
118+
88119
parameterSetToast();
89120
return;
90121
}
@@ -123,6 +154,36 @@ export const useRecallParameters = () => {
123154
[dispatch, parameterSetToast, parameterNotSetToast]
124155
);
125156

157+
/**
158+
* Recall SDXL Positive Style Prompt with toast
159+
*/
160+
const recallSDXLPositiveStylePrompt = useCallback(
161+
(positiveStylePrompt: unknown) => {
162+
if (!isValidSDXLPositiveStylePrompt(positiveStylePrompt)) {
163+
parameterNotSetToast();
164+
return;
165+
}
166+
dispatch(setPositiveStylePromptSDXL(positiveStylePrompt));
167+
parameterSetToast();
168+
},
169+
[dispatch, parameterSetToast, parameterNotSetToast]
170+
);
171+
172+
/**
173+
* Recall SDXL Negative Style Prompt with toast
174+
*/
175+
const recallSDXLNegativeStylePrompt = useCallback(
176+
(negativeStylePrompt: unknown) => {
177+
if (!isValidSDXLNegativeStylePrompt(negativeStylePrompt)) {
178+
parameterNotSetToast();
179+
return;
180+
}
181+
dispatch(setNegativeStylePromptSDXL(negativeStylePrompt));
182+
parameterSetToast();
183+
},
184+
[dispatch, parameterSetToast, parameterNotSetToast]
185+
);
186+
126187
/**
127188
* Recall seed with toast
128189
*/
@@ -271,6 +332,14 @@ export const useRecallParameters = () => {
271332
steps,
272333
width,
273334
strength,
335+
positive_style_prompt,
336+
negative_style_prompt,
337+
refiner_model,
338+
refiner_cfg_scale,
339+
refiner_steps,
340+
refiner_scheduler,
341+
refiner_aesthetic_store,
342+
refiner_start,
274343
} = metadata;
275344

276345
if (isValidCfgScale(cfg_scale)) {
@@ -304,6 +373,38 @@ export const useRecallParameters = () => {
304373
dispatch(setImg2imgStrength(strength));
305374
}
306375

376+
if (isValidSDXLPositiveStylePrompt(positive_style_prompt)) {
377+
dispatch(setPositiveStylePromptSDXL(positive_style_prompt));
378+
}
379+
380+
if (isValidSDXLNegativeStylePrompt(negative_style_prompt)) {
381+
dispatch(setNegativeStylePromptSDXL(negative_style_prompt));
382+
}
383+
384+
if (isValidMainModel(refiner_model)) {
385+
dispatch(refinerModelChanged(refiner_model));
386+
}
387+
388+
if (isValidSteps(refiner_steps)) {
389+
dispatch(setRefinerSteps(refiner_steps));
390+
}
391+
392+
if (isValidCfgScale(refiner_cfg_scale)) {
393+
dispatch(setRefinerCFGScale(refiner_cfg_scale));
394+
}
395+
396+
if (isValidScheduler(refiner_scheduler)) {
397+
dispatch(setRefinerScheduler(refiner_scheduler));
398+
}
399+
400+
if (isValidSDXLRefinerAestheticScore(refiner_aesthetic_store)) {
401+
dispatch(setRefinerAestheticScore(refiner_aesthetic_store));
402+
}
403+
404+
if (isValidSDXLRefinerStart(refiner_start)) {
405+
dispatch(setRefinerStart(refiner_start));
406+
}
407+
307408
allParameterSetToast();
308409
},
309410
[allParameterNotSetToast, allParameterSetToast, dispatch]
@@ -313,6 +414,8 @@ export const useRecallParameters = () => {
313414
recallBothPrompts,
314415
recallPositivePrompt,
315416
recallNegativePrompt,
417+
recallSDXLPositiveStylePrompt,
418+
recallSDXLNegativeStylePrompt,
316419
recallSeed,
317420
recallCfgScale,
318421
recallModel,

invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,39 @@ export type PrecisionParam = z.infer<typeof zPrecision>;
310310
export const isValidPrecision = (val: unknown): val is PrecisionParam =>
311311
zPrecision.safeParse(val).success;
312312

313+
/**
314+
* Zod schema for SDXL refiner aesthetic score parameter
315+
*/
316+
export const zSDXLRefinerAestheticScore = z.number().min(1).max(10);
317+
/**
318+
* Type alias for SDXL refiner aesthetic score parameter, inferred from its zod schema
319+
*/
320+
export type SDXLRefinerAestheticScoreParam = z.infer<
321+
typeof zSDXLRefinerAestheticScore
322+
>;
323+
/**
324+
* Validates/type-guards a value as a SDXL refiner aesthetic score parameter
325+
*/
326+
export const isValidSDXLRefinerAestheticScore = (
327+
val: unknown
328+
): val is SDXLRefinerAestheticScoreParam =>
329+
zSDXLRefinerAestheticScore.safeParse(val).success;
330+
331+
/**
332+
* Zod schema for SDXL start parameter
333+
*/
334+
export const zSDXLRefinerstart = z.number().min(0).max(1);
335+
/**
336+
* Type alias for SDXL start, inferred from its zod schema
337+
*/
338+
export type SDXLRefinerStartParam = z.infer<typeof zSDXLRefinerstart>;
339+
/**
340+
* Validates/type-guards a value as a SDXL refiner aesthetic score parameter
341+
*/
342+
export const isValidSDXLRefinerStart = (
343+
val: unknown
344+
): val is SDXLRefinerStartParam => zSDXLRefinerstart.safeParse(val).success;
345+
313346
// /**
314347
// * Zod schema for BaseModelType
315348
// */

invokeai/frontend/web/src/features/sdxl/components/ParamSDXLConcatButton.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ export default function ParamSDXLConcatButton() {
2121

2222
return (
2323
<IAIIconButton
24-
aria-label="Concat"
25-
tooltip="Concatenates Basic Prompt with Style (Recommended)"
24+
aria-label="Concatenate Prompt & Style"
25+
tooltip="Concatenate Prompt & Style"
2626
variant="outline"
2727
isChecked={shouldConcatSDXLStylePrompt}
2828
onClick={handleShouldConcatPromptChange}

0 commit comments

Comments
 (0)