@@ -228,19 +228,19 @@ class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are n
228228"""
229229from __future__ import annotations
230230
231- import os
232231import hashlib
232+ import os
233233import textwrap
234- import yaml
234+ import types
235235from dataclasses import dataclass
236236from pathlib import Path
237- from typing import Literal , Optional , List , Tuple , Union , Dict , Set , Callable , types
238237from shutil import rmtree , move
238+ from typing import Optional , List , Literal , Tuple , Union , Dict , Set , Callable
239239
240240import torch
241+ import yaml
241242from omegaconf import OmegaConf
242243from omegaconf .dictconfig import DictConfig
243-
244244from pydantic import BaseModel , Field
245245
246246import invokeai .backend .util .logging as logger
@@ -259,6 +259,7 @@ class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are n
259259 ModelNotFoundException ,
260260 InvalidModelException ,
261261 DuplicateModelException ,
262+ ModelBase ,
262263)
263264
264265# We are only starting to number the config file with release 3.
@@ -361,7 +362,7 @@ def _read_models(self, config: Optional[DictConfig] = None):
361362 if model_key .startswith ("_" ):
362363 continue
363364 model_name , base_model , model_type = self .parse_key (model_key )
364- model_class = MODEL_CLASSES [ base_model ][ model_type ]
365+ model_class = self . _get_implementation ( base_model , model_type )
365366 # alias for config file
366367 model_config ["model_format" ] = model_config .pop ("format" )
367368 self .models [model_key ] = model_class .create_config (** model_config )
@@ -381,18 +382,24 @@ def sync_to_config(self):
381382 # causing otherwise unreferenced models to be removed from memory
382383 self ._read_models ()
383384
384- def model_exists (
385- self ,
386- model_name : str ,
387- base_model : BaseModelType ,
388- model_type : ModelType ,
389- ) -> bool :
385+ def model_exists (self , model_name : str , base_model : BaseModelType , model_type : ModelType , * , rescan = False ) -> bool :
390386 """
391- Given a model name, returns True if it is a valid
392- identifier.
387+ Given a model name, returns True if it is a valid identifier.
388+
389+ :param model_name: symbolic name of the model in models.yaml
390+ :param model_type: ModelType enum indicating the type of model to return
391+ :param base_model: BaseModelType enum indicating the base model used by this model
392+ :param rescan: if True, scan_models_directory
393393 """
394394 model_key = self .create_key (model_name , base_model , model_type )
395- return model_key in self .models
395+ exists = model_key in self .models
396+
397+ # if model not found try to find it (maybe file just pasted)
398+ if rescan and not exists :
399+ self .scan_models_directory (base_model = base_model , model_type = model_type )
400+ exists = self .model_exists (model_name , base_model , model_type , rescan = False )
401+
402+ return exists
396403
397404 @classmethod
398405 def create_key (
@@ -443,39 +450,32 @@ def get_model(
443450 :param model_name: symbolic name of the model in models.yaml
444451 :param model_type: ModelType enum indicating the type of model to return
445452 :param base_model: BaseModelType enum indicating the base model used by this model
446- :param submode_typel : an ModelType enum indicating the portion of
453+ :param submodel_type : an ModelType enum indicating the portion of
447454 the model to retrieve (e.g. ModelType.Vae)
448455 """
449- model_class = MODEL_CLASSES [base_model ][model_type ]
450456 model_key = self .create_key (model_name , base_model , model_type )
451457
452- # if model not found try to find it (maybe file just pasted)
453- if model_key not in self .models :
454- self .scan_models_directory (base_model = base_model , model_type = model_type )
455- if model_key not in self .models :
456- raise ModelNotFoundException (f"Model not found - { model_key } " )
458+ if not self .model_exists (model_name , base_model , model_type , rescan = True ):
459+ raise ModelNotFoundException (f"Model not found - { model_key } " )
460+
461+ model_config = self ._get_model_config (base_model , model_name , model_type )
462+
463+ model_path , is_submodel_override = self ._get_model_path (model_config , submodel_type )
457464
458- model_config = self .models [model_key ]
459- model_path = self .resolve_model_path (model_config .path )
465+ if is_submodel_override :
466+ model_type = submodel_type
467+ submodel_type = None
468+
469+ model_class = self ._get_implementation (base_model , model_type )
460470
461471 if not model_path .exists ():
462472 if model_class .save_to_config :
463473 self .models [model_key ].error = ModelError .NotFound
464- raise Exception (f'Files for model "{ model_key } " not found' )
474+ raise Exception (f'Files for model "{ model_key } " not found at { model_path } ' )
465475
466476 else :
467477 self .models .pop (model_key , None )
468- raise ModelNotFoundException (f"Model not found - { model_key } " )
469-
470- # vae/movq override
471- # TODO:
472- if submodel_type is not None and hasattr (model_config , submodel_type ):
473- override_path = getattr (model_config , submodel_type )
474- if override_path :
475- model_path = self .resolve_path (override_path )
476- model_type = submodel_type
477- submodel_type = None
478- model_class = MODEL_CLASSES [base_model ][model_type ]
478+ raise ModelNotFoundException (f'Files for model "{ model_key } " not found at { model_path } ' )
479479
480480 # TODO: path
481481 # TODO: is it accurate to use path as id
@@ -513,6 +513,55 @@ def get_model(
513513 _cache = self .cache ,
514514 )
515515
516+ def _get_model_path (
517+ self , model_config : ModelConfigBase , submodel_type : Optional [SubModelType ] = None
518+ ) -> (Path , bool ):
519+ """Extract a model's filesystem path from its config.
520+
521+ :return: The fully qualified Path of the module (or submodule).
522+ """
523+ model_path = model_config .path
524+ is_submodel_override = False
525+
526+ # Does the config explicitly override the submodel?
527+ if submodel_type is not None and hasattr (model_config , submodel_type ):
528+ submodel_path = getattr (model_config , submodel_type )
529+ if submodel_path is not None :
530+ model_path = getattr (model_config , submodel_type )
531+ is_submodel_override = True
532+
533+ model_path = self .resolve_model_path (model_path )
534+ return model_path , is_submodel_override
535+
536+ def _get_model_config (self , base_model : BaseModelType , model_name : str , model_type : ModelType ) -> ModelConfigBase :
537+ """Get a model's config object."""
538+ model_key = self .create_key (model_name , base_model , model_type )
539+ try :
540+ model_config = self .models [model_key ]
541+ except KeyError :
542+ raise ModelNotFoundException (f"Model not found - { model_key } " )
543+ return model_config
544+
545+ def _get_implementation (self , base_model : BaseModelType , model_type : ModelType ) -> type [ModelBase ]:
546+ """Get the concrete implementation class for a specific model type."""
547+ model_class = MODEL_CLASSES [base_model ][model_type ]
548+ return model_class
549+
550+ def _instantiate (
551+ self ,
552+ model_name : str ,
553+ base_model : BaseModelType ,
554+ model_type : ModelType ,
555+ submodel_type : Optional [SubModelType ] = None ,
556+ ) -> ModelBase :
557+ """Make a new instance of this model, without loading it."""
558+ model_config = self ._get_model_config (base_model , model_name , model_type )
559+ model_path , is_submodel_override = self ._get_model_path (model_config , submodel_type )
560+ # FIXME: do non-overriden submodels get the right class?
561+ constructor = self ._get_implementation (base_model , model_type )
562+ instance = constructor (model_path , base_model , model_type )
563+ return instance
564+
516565 def model_info (
517566 self ,
518567 model_name : str ,
@@ -546,9 +595,10 @@ def list_model(
546595 the combined format of the list_models() method.
547596 """
548597 models = self .list_models (base_model , model_type , model_name )
549- if len (models ) > 1 :
598+ if len (models ) >= 1 :
550599 return models [0 ]
551- return None
600+ else :
601+ return None
552602
553603 def list_models (
554604 self ,
@@ -660,7 +710,7 @@ def add_model(
660710 if path := model_attributes .get ("path" ):
661711 model_attributes ["path" ] = str (self .relative_model_path (Path (path )))
662712
663- model_class = MODEL_CLASSES [ base_model ][ model_type ]
713+ model_class = self . _get_implementation ( base_model , model_type )
664714 model_config = model_class .create_config (** model_attributes )
665715 model_key = self .create_key (model_name , base_model , model_type )
666716
@@ -851,7 +901,7 @@ def commit(self, conf_file: Optional[Path] = None) -> None:
851901
852902 for model_key , model_config in self .models .items ():
853903 model_name , base_model , model_type = self .parse_key (model_key )
854- model_class = MODEL_CLASSES [ base_model ][ model_type ]
904+ model_class = self . _get_implementation ( base_model , model_type )
855905 if model_class .save_to_config :
856906 # TODO: or exclude_unset better fits here?
857907 data_to_save [model_key ] = model_config .dict (exclude_defaults = True , exclude = {"error" })
@@ -909,7 +959,7 @@ def scan_models_directory(
909959
910960 model_path = self .resolve_model_path (model_config .path ).absolute ()
911961 if not model_path .exists ():
912- model_class = MODEL_CLASSES [ cur_base_model ][ cur_model_type ]
962+ model_class = self . _get_implementation ( cur_base_model , cur_model_type )
913963 if model_class .save_to_config :
914964 model_config .error = ModelError .NotFound
915965 self .models .pop (model_key , None )
@@ -925,7 +975,7 @@ def scan_models_directory(
925975 for cur_model_type in ModelType :
926976 if model_type is not None and cur_model_type != model_type :
927977 continue
928- model_class = MODEL_CLASSES [ cur_base_model ][ cur_model_type ]
978+ model_class = self . _get_implementation ( cur_base_model , cur_model_type )
929979 models_dir = self .resolve_model_path (Path (cur_base_model .value , cur_model_type .value ))
930980
931981 if not models_dir .exists ():
0 commit comments