
Commit dfb41d8

Merge branch 'main' into bugfix/autodetect-sdxl-ckpt-config
2 parents: af04400 + e77400a

23 files changed: +658 -126 lines changed

invokeai/app/api/routers/images.py

Lines changed: 4 additions & 5 deletions
@@ -1,22 +1,20 @@
 import io
 from typing import Optional
 
+from PIL import Image
 from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
 from fastapi.responses import FileResponse
 from fastapi.routing import APIRouter
-from PIL import Image
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
 from invokeai.app.invocations.metadata import ImageMetadata
 from invokeai.app.models.image import ImageCategory, ResourceOrigin
 from invokeai.app.services.image_record_storage import OffsetPaginatedResults
-from invokeai.app.services.item_storage import PaginatedResults
 from invokeai.app.services.models.image_record import (
     ImageDTO,
     ImageRecordChanges,
     ImageUrlsDTO,
 )
-
 from ..dependencies import ApiDependencies
 
 images_router = APIRouter(prefix="/v1/images", tags=["images"])

@@ -152,8 +150,9 @@ async def get_image_metadata(
         raise HTTPException(status_code=404)
 
 
-@images_router.get(
+@images_router.api_route(
     "/i/{image_name}/full",
+    methods=["GET", "HEAD"],
     operation_id="get_image_full",
     response_class=Response,
     responses={
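
Switching the decorator from @images_router.get(...) to @images_router.api_route(..., methods=["GET", "HEAD"]) lets the same handler answer HEAD requests, which clients use to probe an image URL's headers without downloading the body. A minimal sketch of the same pattern, using a made-up /ping route rather than the real image endpoint:

    from fastapi import FastAPI, Response

    app = FastAPI()

    # One handler registered for both verbs; a route declared only with
    # @app.get() would answer HEAD requests with 405 Method Not Allowed.
    @app.api_route("/ping", methods=["GET", "HEAD"], response_class=Response)
    async def ping() -> Response:
        return Response(content=b"pong", media_type="text/plain")

With this in place, HEAD requests to the route (or, in this commit, to /v1/images/i/{image_name}/full) are routed to the handler instead of being rejected with 405.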

invokeai/app/services/config.py

Lines changed: 11 additions & 14 deletions
@@ -28,7 +28,6 @@
     always_use_cpu: false
     free_gpu_mem: false
   Features:
-    restore: true
     esrgan: true
     patchmatch: true
     internet_available: true

@@ -165,7 +164,7 @@ class InvokeBatch(InvokeAISettings):
 import os
 import sys
 from argparse import ArgumentParser
-from omegaconf import OmegaConf, DictConfig
+from omegaconf import OmegaConf, DictConfig, ListConfig
 from pathlib import Path
 from pydantic import BaseSettings, Field, parse_obj_as
 from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args

@@ -189,7 +188,12 @@ def parse_args(self, argv: list = sys.argv[1:]):
         opt = parser.parse_args(argv)
         for name in self.__fields__:
             if name not in self._excluded():
-                setattr(self, name, getattr(opt, name))
+                value = getattr(opt, name)
+                if isinstance(value, ListConfig):
+                    value = list(value)
+                elif isinstance(value, DictConfig):
+                    value = dict(value)
+                setattr(self, name, value)
 
     def to_yaml(self) -> str:
         """

@@ -282,14 +286,10 @@ def _excluded_from_yaml(self) -> List[str]:
         return [
             "type",
             "initconf",
-            "gpu_mem_reserved",
-            "max_loaded_models",
             "version",
             "from_file",
             "model",
-            "restore",
             "root",
-            "nsfw_checker",
         ]
 
     class Config:

@@ -388,15 +388,11 @@ class InvokeAIAppConfig(InvokeAISettings):
     internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
     log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
     patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
-    restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
 
     always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
     free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
-    max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='DEPRECATED')
     max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
     max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
-    gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED')
-    nsfw_checker : bool = Field(default=True, description="DEPRECATED: use Web settings to enable/disable", category='DEPRECATED')
     precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
     sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
     xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')

@@ -414,9 +410,7 @@ class InvokeAIAppConfig(InvokeAISettings):
     outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
     from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
     use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
-    ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert')
-
-    model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
+    ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
 
     log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
     # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues

@@ -426,6 +420,9 @@ class InvokeAIAppConfig(InvokeAISettings):
     version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
     # fmt: on
 
+    class Config:
+        validate_assignment = True
+
     def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
         """
         Update settings with contents of init file, environment, and
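
Two of these changes work together: the new inner class Config with validate_assignment = True makes pydantic re-validate every attribute assignment on the settings object, and parse_args() now converts OmegaConf ListConfig/DictConfig values into plain list/dict before assigning them, so the settings never end up holding OmegaConf wrapper objects. A rough, self-contained sketch of that coercion pattern with pydantic v1 (the field names here are illustrative, not the real InvokeAI settings):

    from typing import Dict, List

    from omegaconf import DictConfig, ListConfig, OmegaConf
    from pydantic import BaseSettings

    class DemoSettings(BaseSettings):
        # illustrative fields only
        log_handlers: List[str] = ["console"]
        extra_opts: Dict[str, int] = {}

        class Config:
            validate_assignment = True  # re-validate on every attribute assignment

    conf = OmegaConf.create({"log_handlers": ["console", "file=/tmp/invokeai.log"], "extra_opts": {"workers": 2}})
    settings = DemoSettings()

    for name in ("log_handlers", "extra_opts"):
        value = conf[name]
        # OmegaConf hands back ListConfig/DictConfig wrappers; unwrap them so the
        # settings object stores ordinary Python containers instead.
        if isinstance(value, ListConfig):
            value = list(value)
        elif isinstance(value, DictConfig):
            value = dict(value)
        setattr(settings, name, value)

Without the unwrap step the settings would hold OmegaConf objects rather than the plain containers the rest of the code expects.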

invokeai/backend/install/invokeai_configure.py

Lines changed: 8 additions & 2 deletions
@@ -44,6 +44,8 @@
 )
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
+
+# TO DO - Move all the frontend code into invokeai.frontend.install
 from invokeai.frontend.install.widgets import (
     SingleSelectColumns,
     CenteredButtonPress,

@@ -61,6 +63,7 @@
     ModelInstall,
 )
 from invokeai.backend.model_management.model_probe import ModelType, BaseModelType
+from pydantic.error_wrappers import ValidationError
 
 warnings.filterwarnings("ignore")
 transformers.logging.set_verbosity_error()

@@ -654,10 +657,13 @@ def migrate_init_file(legacy_format: Path):
     old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
     new = InvokeAIAppConfig.get_config()
 
-    fields = list(get_type_hints(InvokeAIAppConfig).keys())
+    fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"]
     for attr in fields:
         if hasattr(old, attr):
-            setattr(new, attr, getattr(old, attr))
+            try:
+                setattr(new, attr, getattr(old, attr))
+            except ValidationError as e:
+                print(f"* Ignoring incompatible value for field {attr}:\n {str(e)}")
 
     # a few places where the field names have changed and we have to
     # manually add in the new names/values
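
The migration loop now derives its field list from InvokeAIAppConfig.__fields__, skipping any field whose category extra is DEPRECATED, and wraps each assignment so that a legacy init-file value that no longer validates is reported and skipped instead of aborting the migration (this relies on the assignment validation added in config.py above). A small pydantic v1 sketch of both pieces, using a made-up DemoConfig:

    from pydantic import BaseSettings, Field
    from pydantic.error_wrappers import ValidationError

    class DemoConfig(BaseSettings):
        # the extra "category" keyword ends up in field_info.extra
        max_cache_size: float = Field(default=6.0, category="Memory/Performance")
        nsfw_checker: bool = Field(default=True, category="DEPRECATED")

        class Config:
            validate_assignment = True

    # Keep only non-deprecated fields, as the migration filter does.
    live_fields = [
        name for name, field in DemoConfig.__fields__.items()
        if field.field_info.extra.get("category") != "DEPRECATED"
    ]
    print(live_fields)  # ['max_cache_size']

    cfg = DemoConfig()
    try:
        cfg.max_cache_size = "not-a-number"  # rejected by assignment validation
    except ValidationError as e:
        print(f"* Ignoring incompatible value: {e}")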

invokeai/backend/model_management/model_manager.py

Lines changed: 91 additions & 41 deletions
@@ -228,19 +228,19 @@ class ConfigMeta(BaseModel):
 """
 from __future__ import annotations
 
-import os
 import hashlib
+import os
 import textwrap
-import yaml
+import types
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Literal, Optional, List, Tuple, Union, Dict, Set, Callable, types
 from shutil import rmtree, move
+from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
 
 import torch
+import yaml
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
-
 from pydantic import BaseModel, Field
 
 import invokeai.backend.util.logging as logger

@@ -259,6 +259,7 @@ class ConfigMeta(BaseModel):
     ModelNotFoundException,
     InvalidModelException,
     DuplicateModelException,
+    ModelBase,
 )
 
 # We are only starting to number the config file with release 3.

@@ -361,7 +362,7 @@ def _read_models(self, config: Optional[DictConfig] = None):
             if model_key.startswith("_"):
                 continue
             model_name, base_model, model_type = self.parse_key(model_key)
-            model_class = MODEL_CLASSES[base_model][model_type]
+            model_class = self._get_implementation(base_model, model_type)
             # alias for config file
             model_config["model_format"] = model_config.pop("format")
             self.models[model_key] = model_class.create_config(**model_config)

@@ -381,18 +382,24 @@ def sync_to_config(self):
         # causing otherwise unreferenced models to be removed from memory
         self._read_models()
 
-    def model_exists(
-        self,
-        model_name: str,
-        base_model: BaseModelType,
-        model_type: ModelType,
-    ) -> bool:
+    def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
         """
-        Given a model name, returns True if it is a valid
-        identifier.
+        Given a model name, returns True if it is a valid identifier.
+
+        :param model_name: symbolic name of the model in models.yaml
+        :param model_type: ModelType enum indicating the type of model to return
+        :param base_model: BaseModelType enum indicating the base model used by this model
+        :param rescan: if True, scan_models_directory
         """
         model_key = self.create_key(model_name, base_model, model_type)
-        return model_key in self.models
+        exists = model_key in self.models
+
+        # if model not found try to find it (maybe file just pasted)
+        if rescan and not exists:
+            self.scan_models_directory(base_model=base_model, model_type=model_type)
+            exists = self.model_exists(model_name, base_model, model_type, rescan=False)
+
+        return exists
 
     @classmethod
     def create_key(

@@ -443,39 +450,32 @@ def get_model(
         :param model_name: symbolic name of the model in models.yaml
         :param model_type: ModelType enum indicating the type of model to return
         :param base_model: BaseModelType enum indicating the base model used by this model
-        :param submode_typel: an ModelType enum indicating the portion of
+        :param submodel_type: an ModelType enum indicating the portion of
                               the model to retrieve (e.g. ModelType.Vae)
         """
-        model_class = MODEL_CLASSES[base_model][model_type]
         model_key = self.create_key(model_name, base_model, model_type)
 
-        # if model not found try to find it (maybe file just pasted)
-        if model_key not in self.models:
-            self.scan_models_directory(base_model=base_model, model_type=model_type)
-            if model_key not in self.models:
-                raise ModelNotFoundException(f"Model not found - {model_key}")
+        if not self.model_exists(model_name, base_model, model_type, rescan=True):
+            raise ModelNotFoundException(f"Model not found - {model_key}")
+
+        model_config = self._get_model_config(base_model, model_name, model_type)
+
+        model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
 
-        model_config = self.models[model_key]
-        model_path = self.resolve_model_path(model_config.path)
+        if is_submodel_override:
+            model_type = submodel_type
+            submodel_type = None
+
+        model_class = self._get_implementation(base_model, model_type)
 
         if not model_path.exists():
             if model_class.save_to_config:
                 self.models[model_key].error = ModelError.NotFound
-                raise Exception(f'Files for model "{model_key}" not found')
+                raise Exception(f'Files for model "{model_key}" not found at {model_path}')
 
             else:
                 self.models.pop(model_key, None)
-                raise ModelNotFoundException(f"Model not found - {model_key}")
-
-        # vae/movq override
-        # TODO:
-        if submodel_type is not None and hasattr(model_config, submodel_type):
-            override_path = getattr(model_config, submodel_type)
-            if override_path:
-                model_path = self.resolve_path(override_path)
-                model_type = submodel_type
-                submodel_type = None
-                model_class = MODEL_CLASSES[base_model][model_type]
+                raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')
 
         # TODO: path
         # TODO: is it accurate to use path as id

@@ -513,6 +513,55 @@ def get_model(
             _cache=self.cache,
         )
 
+    def _get_model_path(
+        self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
+    ) -> (Path, bool):
+        """Extract a model's filesystem path from its config.
+
+        :return: The fully qualified Path of the module (or submodule).
+        """
+        model_path = model_config.path
+        is_submodel_override = False
+
+        # Does the config explicitly override the submodel?
+        if submodel_type is not None and hasattr(model_config, submodel_type):
+            submodel_path = getattr(model_config, submodel_type)
+            if submodel_path is not None:
+                model_path = getattr(model_config, submodel_type)
+                is_submodel_override = True
+
+        model_path = self.resolve_model_path(model_path)
+        return model_path, is_submodel_override
+
+    def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
+        """Get a model's config object."""
+        model_key = self.create_key(model_name, base_model, model_type)
+        try:
+            model_config = self.models[model_key]
+        except KeyError:
+            raise ModelNotFoundException(f"Model not found - {model_key}")
+        return model_config
+
+    def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
+        """Get the concrete implementation class for a specific model type."""
+        model_class = MODEL_CLASSES[base_model][model_type]
+        return model_class
+
+    def _instantiate(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> ModelBase:
+        """Make a new instance of this model, without loading it."""
+        model_config = self._get_model_config(base_model, model_name, model_type)
+        model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
+        # FIXME: do non-overriden submodels get the right class?
+        constructor = self._get_implementation(base_model, model_type)
+        instance = constructor(model_path, base_model, model_type)
+        return instance
+
     def model_info(
         self,
         model_name: str,

@@ -546,9 +595,10 @@ def list_model(
         the combined format of the list_models() method.
         """
         models = self.list_models(base_model, model_type, model_name)
-        if len(models) > 1:
+        if len(models) >= 1:
             return models[0]
-        return None
+        else:
+            return None
 
     def list_models(
         self,

@@ -660,7 +710,7 @@ def add_model(
         if path := model_attributes.get("path"):
            model_attributes["path"] = str(self.relative_model_path(Path(path)))
 
-        model_class = MODEL_CLASSES[base_model][model_type]
+        model_class = self._get_implementation(base_model, model_type)
         model_config = model_class.create_config(**model_attributes)
         model_key = self.create_key(model_name, base_model, model_type)
 

@@ -851,7 +901,7 @@ def commit(self, conf_file: Optional[Path] = None) -> None:
 
         for model_key, model_config in self.models.items():
             model_name, base_model, model_type = self.parse_key(model_key)
-            model_class = MODEL_CLASSES[base_model][model_type]
+            model_class = self._get_implementation(base_model, model_type)
             if model_class.save_to_config:
                 # TODO: or exclude_unset better fits here?
                 data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})

@@ -909,7 +959,7 @@ def scan_models_directory(
 
             model_path = self.resolve_model_path(model_config.path).absolute()
             if not model_path.exists():
-                model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
+                model_class = self._get_implementation(cur_base_model, cur_model_type)
                 if model_class.save_to_config:
                     model_config.error = ModelError.NotFound
                     self.models.pop(model_key, None)

@@ -925,7 +975,7 @@ def scan_models_directory(
             for cur_model_type in ModelType:
                 if model_type is not None and cur_model_type != model_type:
                     continue
-                model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
+                model_class = self._get_implementation(cur_base_model, cur_model_type)
                 models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
 
                 if not models_dir.exists():
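
Most of the mechanical changes in this file funnel the repeated MODEL_CLASSES[base_model][model_type] lookups through _get_implementation(), while get_model() now delegates to model_exists(..., rescan=True), _get_model_config(), and _get_model_path() instead of doing the lookup and submodel-override handling inline. The caller-visible addition is the keyword-only rescan flag: a checkpoint dropped into the models directory after startup can be discovered on demand. A hedged usage sketch (the import path, constructor arguments, and model name below are assumptions for illustration, not values from this commit):

    from invokeai.backend.model_management import BaseModelType, ModelManager, ModelType

    # Hypothetical setup; real constructor arguments depend on the install.
    mgr = ModelManager("models.yaml")

    # rescan=True triggers one scan_models_directory() pass when the key is not
    # already registered, so a freshly pasted model can be found without a restart.
    if mgr.model_exists("my-model", BaseModelType.StableDiffusion1, ModelType.Main, rescan=True):
        model_info = mgr.get_model("my-model", BaseModelType.StableDiffusion1, ModelType.Main)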
