Skip to content

Commit 4ad2574

Browse files
feat(ui): add button to reidentify model to mm
1 parent 0e3d4be commit 4ad2574

File tree

4 files changed

+101
-0
lines changed

4 files changed

+101
-0
lines changed

invokeai/frontend/web/public/locales/en.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,11 @@
844844
"clipLEmbed": "CLIP-L Embed",
845845
"clipGEmbed": "CLIP-G Embed",
846846
"config": "Config",
847+
"reidentify": "Reidentify",
848+
"reidentifyTooltip": "If a model didn't install correctly (e.g. it has the wrong type or doesn't work), you can try reidentifying it. This will reset any custom settings you may have applied.",
849+
"reidentifySuccess": "Model reidentified successfully",
850+
"reidentifyUnknown": "Unable to identify model",
851+
"reidentifyError": "Error reidentifying model",
847852
"convert": "Convert",
848853
"convertingModelBegin": "Converting Model. Please wait.",
849854
"convertToDiffusers": "Convert To Diffusers",
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import { Button } from '@invoke-ai/ui-library';
2+
import { toast } from 'features/toast/toast';
3+
import { memo, useCallback } from 'react';
4+
import { useTranslation } from 'react-i18next';
5+
import { PiSparkleFill } from 'react-icons/pi';
6+
import { useReidentifyModelMutation } from 'services/api/endpoints/models';
7+
import type { AnyModelConfig } from 'services/api/types';
8+
9+
interface Props {
10+
modelConfig: AnyModelConfig;
11+
}
12+
13+
export const ModelReidentifyButton = memo(({ modelConfig }: Props) => {
14+
const { t } = useTranslation();
15+
const [reidentifyModel, { isLoading }] = useReidentifyModelMutation();
16+
17+
const onClick = useCallback(() => {
18+
reidentifyModel({ key: modelConfig.key })
19+
.unwrap()
20+
.then(({ type }) => {
21+
if (type === 'unknown') {
22+
toast({
23+
id: 'MODEL_REIDENTIFY_UNKNOWN',
24+
title: t('modelManager.reidentifyUnknown'),
25+
status: 'warning',
26+
});
27+
}
28+
toast({
29+
id: 'MODEL_REIDENTIFY_SUCCESS',
30+
title: t('modelManager.reidentifySuccess'),
31+
status: 'success',
32+
});
33+
})
34+
.catch((_) => {
35+
toast({
36+
id: 'MODEL_REIDENTIFY_ERROR',
37+
title: t('modelManager.reidentifyError'),
38+
status: 'error',
39+
});
40+
});
41+
}, [modelConfig.key, reidentifyModel, t]);
42+
43+
return (
44+
<Button
45+
onClick={onClick}
46+
size="sm"
47+
aria-label={t('modelManager.reidentifyTooltip')}
48+
tooltip={t('modelManager.reidentifyTooltip')}
49+
isLoading={isLoading}
50+
flexShrink={0}
51+
leftIcon={<PiSparkleFill />}
52+
>
53+
{t('modelManager.reidentify')}
54+
</Button>
55+
);
56+
});
57+
58+
ModelReidentifyButton.displayName = 'ModelReidentifyButton';

invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import type { AnyModelConfig } from 'services/api/types';
1313
import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings';
1414
import { ModelAttrView } from './ModelAttrView';
1515
import { ModelFooter } from './ModelFooter';
16+
import { ModelReidentifyButton } from './ModelReidentifyButton';
1617
import { RelatedModels } from './RelatedModels';
1718

1819
type Props = {
@@ -21,6 +22,7 @@ type Props = {
2122

2223
export const ModelView = memo(({ modelConfig }: Props) => {
2324
const { t } = useTranslation();
25+
2426
const withSettings = useMemo(() => {
2527
if (modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner') {
2628
return true;
@@ -46,6 +48,7 @@ export const ModelView = memo(({ modelConfig }: Props) => {
4648
<ModelConvertButton modelConfig={modelConfig} />
4749
)}
4850
<ModelEditButton />
51+
<ModelReidentifyButton modelConfig={modelConfig} />
4952
</ModelHeader>
5053
<Divider />
5154
<Flex flexDir="column" gap={4}>

invokeai/frontend/web/src/services/api/endpoints/models.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,40 @@ export const modelsApi = api.injectEndpoints({
299299
emptyModelCache: build.mutation<void, void>({
300300
query: () => ({ url: buildModelsUrl('empty_model_cache'), method: 'POST' }),
301301
}),
302+
reidentifyModel: build.mutation<
303+
paths['/api/v2/models/i/{key}/reidentify']['post']['responses']['200']['content']['application/json'],
304+
{ key: string }
305+
>({
306+
query: ({ key }) => {
307+
return {
308+
url: buildModelsUrl(`i/${key}/reidentify`),
309+
method: 'POST',
310+
};
311+
},
312+
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
313+
try {
314+
const { data } = await queryFulfilled;
315+
316+
// Update the individual model query caches
317+
dispatch(modelsApi.util.upsertQueryData('getModelConfig', data.key, data));
318+
319+
const { base, name, type } = data;
320+
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, data));
321+
322+
// Update the list query cache
323+
dispatch(
324+
modelsApi.util.updateQueryData('getModelConfigs', undefined, (draft) => {
325+
modelConfigsAdapter.updateOne(draft, {
326+
id: data.key,
327+
changes: data,
328+
});
329+
})
330+
);
331+
} catch {
332+
// no-op
333+
}
334+
},
335+
}),
302336
}),
303337
});
304338

@@ -321,6 +355,7 @@ export const {
321355
useSetHFTokenMutation,
322356
useResetHFTokenMutation,
323357
useEmptyModelCacheMutation,
358+
useReidentifyModelMutation,
324359
} = modelsApi;
325360

326361
export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select();

0 commit comments

Comments
 (0)