diff --git a/.gitattributes b/.gitattributes index df4754736a6..6cf175e7c5a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -4,3 +4,4 @@ * text=auto docker/** text eol=lf tests/test_model_probe/stripped_models/** filter=lfs diff=lfs merge=lfs -text +tests/model_identification/stripped_models/** filter=lfs diff=lfs merge=lfs -text diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index e62a5a5b60d..0c325a4ce05 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -28,10 +28,12 @@ UnknownModelException, ) from invokeai.app.util.suppress_output import SuppressOutput -from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - MainCheckpointConfig, +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.main import ( + Main_Checkpoint_SD1_Config, + Main_Checkpoint_SD2_Config, + Main_Checkpoint_SDXL_Config, + Main_Checkpoint_SDXLRefiner_Config, ) from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch @@ -44,6 +46,7 @@ StarterModelBundle, StarterModelWithoutDependencies, ) +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"]) @@ -297,10 +300,8 @@ async def update_model_record( """Update a model's config.""" logger = ApiDependencies.invoker.services.logger record_store = ApiDependencies.invoker.services.model_manager.store - installer = ApiDependencies.invoker.services.model_manager.install try: - record_store.update_model(key, changes=changes) - config = installer.sync_model_path(key) + config = record_store.update_model(key, changes=changes, allow_class_change=True) config = add_cover_image_to_model_config(config, ApiDependencies) logger.info(f"Updated model: {key}") except UnknownModelException as e: @@ -743,9 +744,18 @@ async def convert_model( logger.error(str(e)) raise HTTPException(status_code=424, detail=str(e)) - if not isinstance(model_config, MainCheckpointConfig): - logger.error(f"The model with key {key} is not a main checkpoint model.") - raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.") + if not isinstance( + model_config, + ( + Main_Checkpoint_SD1_Config, + Main_Checkpoint_SD2_Config, + Main_Checkpoint_SDXL_Config, + Main_Checkpoint_SDXLRefiner_Config, + ), + ): + msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model." 
+ logger.error(msg) + raise HTTPException(400, msg) with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir: convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem diff --git a/invokeai/app/invocations/cogview4_denoise.py b/invokeai/app/invocations/cogview4_denoise.py index c0b962ba31d..070d8a34783 100644 --- a/invokeai/app/invocations/cogview4_denoise.py +++ b/invokeai/app/invocations/cogview4_denoise.py @@ -22,7 +22,7 @@ from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional -from invokeai.backend.model_manager.config import BaseModelType +from invokeai.backend.model_manager.taxonomy import BaseModelType from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import CogView4ConditioningInfo diff --git a/invokeai/app/invocations/cogview4_model_loader.py b/invokeai/app/invocations/cogview4_model_loader.py index 9db4f3c0537..fbafcd345fd 100644 --- a/invokeai/app/invocations/cogview4_model_loader.py +++ b/invokeai/app/invocations/cogview4_model_loader.py @@ -13,8 +13,7 @@ VAEField, ) from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_manager.config import SubModelType -from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType @invocation_output("cogview4_model_loader_output") diff --git a/invokeai/app/invocations/create_gradient_mask.py b/invokeai/app/invocations/create_gradient_mask.py index b232fbbc932..8a7e7c52317 100644 --- a/invokeai/app/invocations/create_gradient_mask.py +++ b/invokeai/app/invocations/create_gradient_mask.py @@ -20,9 +20,7 @@ from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation from invokeai.app.invocations.model import UNetField, VAEField from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_manager import LoadedModel -from invokeai.backend.model_manager.config import MainConfigBase -from invokeai.backend.model_manager.taxonomy import ModelVariantType +from invokeai.backend.model_manager.taxonomy import FluxVariantType, ModelType, ModelVariantType from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor @@ -182,10 +180,11 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput: if self.unet is not None and self.vae is not None and self.image is not None: # all three fields must be present at the same time main_model_config = context.models.get_config(self.unet.unet.key) - assert isinstance(main_model_config, MainConfigBase) - if main_model_config.variant is ModelVariantType.Inpaint: + assert main_model_config.type is ModelType.Main + variant = getattr(main_model_config, "variant", None) + if variant is ModelVariantType.Inpaint or variant is FluxVariantType.DevFill: mask = dilated_mask_tensor - vae_info: LoadedModel = context.models.load(self.vae.vae) + vae_info = context.models.load(self.vae.vae) image = context.images.get_pil(self.image.image_name) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 
3: diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 37b385914cc..bb114263e23 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -39,7 +39,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.backend.ip_adapter.ip_adapter import IPAdapter -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelVariantType from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.patches.layer_patcher import LayerPatcher diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 35d095e2799..b6d0399108d 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -48,7 +48,7 @@ unpack, ) from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning -from invokeai.backend.model_manager.taxonomy import ModelFormat, ModelVariantType +from invokeai.backend.model_manager.taxonomy import BaseModelType, FluxVariantType, ModelFormat, ModelType from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw @@ -232,7 +232,8 @@ def _run_diffusion( ) transformer_config = context.models.get_config(self.transformer.transformer) - is_schnell = "schnell" in getattr(transformer_config, "config_path", "") + assert transformer_config.base is BaseModelType.Flux and transformer_config.type is ModelType.Main + is_schnell = transformer_config.variant is FluxVariantType.Schnell # Calculate the timestep schedule. timesteps = get_schedule( @@ -277,7 +278,7 @@ def _run_diffusion( # Prepare the extra image conditioning tensor (img_cond) for either FLUX structural control or FLUX Fill. img_cond: torch.Tensor | None = None - is_flux_fill = transformer_config.variant == ModelVariantType.Inpaint # type: ignore + is_flux_fill = transformer_config.variant is FluxVariantType.DevFill if is_flux_fill: img_cond = self._prep_flux_fill_img_cond( context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype diff --git a/invokeai/app/invocations/flux_ip_adapter.py b/invokeai/app/invocations/flux_ip_adapter.py index db5754ee2b0..4a1997c5122 100644 --- a/invokeai/app/invocations/flux_ip_adapter.py +++ b/invokeai/app/invocations/flux_ip_adapter.py @@ -16,10 +16,7 @@ from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_manager.config import ( - IPAdapterCheckpointConfig, - IPAdapterInvokeAIConfig, -) +from invokeai.backend.model_manager.configs.ip_adapter import IPAdapter_Checkpoint_FLUX_Config from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType @@ -68,7 +65,7 @@ def validate_begin_end_step_percent(self) -> Self: def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. 
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) - assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig)) + assert isinstance(ip_adapter_info, IPAdapter_Checkpoint_FLUX_Config) # Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy. image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model] diff --git a/invokeai/app/invocations/flux_model_loader.py b/invokeai/app/invocations/flux_model_loader.py index e5a1966c659..eaac82bafc8 100644 --- a/invokeai/app/invocations/flux_model_loader.py +++ b/invokeai/app/invocations/flux_model_loader.py @@ -13,10 +13,8 @@ preprocess_t5_encoder_model_identifier, preprocess_t5_tokenizer_model_identifier, ) -from invokeai.backend.flux.util import max_seq_lengths -from invokeai.backend.model_manager.config import ( - CheckpointConfigBase, -) +from invokeai.backend.flux.util import get_flux_max_seq_length +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType @@ -87,12 +85,12 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput: t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model) transformer_config = context.models.get_config(transformer) - assert isinstance(transformer_config, CheckpointConfigBase) + assert isinstance(transformer_config, Checkpoint_Config_Base) return FluxModelLoaderOutput( transformer=TransformerField(transformer=transformer, loras=[]), clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0), t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]), vae=VAEField(vae=vae), - max_seq_len=max_seq_lengths[transformer_config.config_path], + max_seq_len=get_flux_max_seq_length(transformer_config.variant), ) diff --git a/invokeai/app/invocations/flux_redux.py b/invokeai/app/invocations/flux_redux.py index 3e34497b105..403d78b0786 100644 --- a/invokeai/app/invocations/flux_redux.py +++ b/invokeai/app/invocations/flux_redux.py @@ -24,9 +24,9 @@ from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel -from invokeai.backend.model_manager import BaseModelType, ModelType -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.starter_models import siglip +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline from invokeai.backend.util.devices import TorchDevice diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index 77b6187840c..c395a0bf22d 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -17,7 +17,7 @@ from invokeai.app.invocations.primitives import FluxConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.modules.conditioner import HFEncoder -from invokeai.backend.model_manager import ModelFormat +from invokeai.backend.model_manager.taxonomy import ModelFormat from invokeai.backend.patches.layer_patcher import LayerPatcher from 
invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw diff --git a/invokeai/app/invocations/flux_vae_encode.py b/invokeai/app/invocations/flux_vae_encode.py index 2932517edcf..4ec0365c2cb 100644 --- a/invokeai/app/invocations/flux_vae_encode.py +++ b/invokeai/app/invocations/flux_vae_encode.py @@ -12,7 +12,7 @@ from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.modules.autoencoder import AutoEncoder -from invokeai.backend.model_manager import LoadedModel +from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py index 552f5edb1b2..fde70a34fde 100644 --- a/invokeai/app/invocations/image_to_latents.py +++ b/invokeai/app/invocations/image_to_latents.py @@ -23,7 +23,7 @@ from invokeai.app.invocations.model import VAEField from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_manager import LoadedModel +from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params from invokeai.backend.util.devices import TorchDevice diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 35a98ff6ba0..2b2931e78f3 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -11,10 +11,10 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - IPAdapterCheckpointConfig, - IPAdapterInvokeAIConfig, +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.ip_adapter import ( + IPAdapter_Checkpoint_Config_Base, + IPAdapter_InvokeAI_Config_Base, ) from invokeai.backend.model_manager.starter_models import ( StarterModel, @@ -123,9 +123,9 @@ def validate_begin_end_step_percent(self) -> Self: def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. 
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) - assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig)) + assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapter_Checkpoint_Config_Base)) - if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig): + if isinstance(ip_adapter_info, IPAdapter_InvokeAI_Config_Base): image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() else: diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 2d338c677d2..753ae77c559 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -12,9 +12,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.shared.models import FreeUConfig -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType @@ -24,8 +22,9 @@ class ModelIdentifierField(BaseModel): name: str = Field(description="The model's name") base: BaseModelType = Field(description="The model's base model type") type: ModelType = Field(description="The model's type") - submodel_type: Optional[SubModelType] = Field( - description="The submodel to load, if this is a main model", default=None + submodel_type: SubModelType | None = Field( + description="The submodel to load, if this is a main model", + default=None, ) @classmethod diff --git a/invokeai/app/invocations/sd3_denoise.py b/invokeai/app/invocations/sd3_denoise.py index f43f26ae0ed..b9d69369b76 100644 --- a/invokeai/app/invocations/sd3_denoise.py +++ b/invokeai/app/invocations/sd3_denoise.py @@ -23,7 +23,7 @@ from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional -from invokeai.backend.model_manager import BaseModelType +from invokeai.backend.model_manager.taxonomy import BaseModelType from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 105ce0d2272..54ddaa30c32 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -108,6 +108,7 @@ class InvokeAIAppConfig(BaseSettings): remote_api_tokens: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token. scan_models_on_startup: Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes. unsafe_disable_picklescan: UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. 
This will allow arbitrary code execution during model installation, so should never be used in production. + allow_unknown_models: Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation. """ _root: Optional[Path] = PrivateAttr(default=None) @@ -198,6 +199,7 @@ class InvokeAIAppConfig(BaseSettings): remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.") scan_models_on_startup: bool = Field(default=False, description="Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.") unsafe_disable_picklescan: bool = Field(default=False, description="UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.") + allow_unknown_models: bool = Field(default=True, description="Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.") # fmt: on diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index fc0f0bb2c69..c70ef3fa16e 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -44,8 +44,8 @@ SessionQueueItem, SessionQueueStatus, ) - from invokeai.backend.model_manager import SubModelType - from invokeai.backend.model_manager.config import AnyModelConfig + from invokeai.backend.model_manager.configs.factory import AnyModelConfig + from invokeai.backend.model_manager.taxonomy import SubModelType class EventServiceBase: diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index add19d459e6..2f995293984 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -16,8 +16,8 @@ ) from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_manager import SubModelType -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.taxonomy import SubModelType if TYPE_CHECKING: from invokeai.app.services.download.download_base import DownloadJob @@ -546,11 +546,18 @@ class ModelInstallCompleteEvent(ModelEventBase): source: ModelSource = Field(description="Source of the model; local path, repo_id or url") key: str = Field(description="Model config record key") total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)") + config: AnyModelConfig = Field(description="The installed model's config") @classmethod def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent": assert job.config_out is not None - return cls(id=job.id, source=job.source, key=(job.config_out.key), 
total_bytes=job.total_bytes) + return cls( + id=job.id, + source=job.source, + key=(job.config_out.key), + total_bytes=job.total_bytes, + config=job.config_out, + ) @payload_schema.register diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 6ff6a42719f..39981071c18 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -12,7 +12,6 @@ from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase -from invokeai.backend.model_manager import AnyModelConfig if TYPE_CHECKING: from invokeai.app.services.events.events_base import EventServiceBase @@ -231,19 +230,6 @@ def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: will block indefinitely until the installs complete. """ - @abstractmethod - def sync_model_path(self, key: str) -> AnyModelConfig: - """ - Move model into the location indicated by its basetype, type and name. - - Call this after updating a model's attributes in order to move - the model's path into the location indicated by its basetype, type and - name. Applies only to models whose paths are within the root `models_dir` - directory. - - May raise an UnknownModelException. - """ - @abstractmethod def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path: """ diff --git a/invokeai/app/services/model_install/model_install_common.py b/invokeai/app/services/model_install/model_install_common.py index fea75d73752..f098a73d884 100644 --- a/invokeai/app/services/model_install/model_install_common.py +++ b/invokeai/app/services/model_install/model_install_common.py @@ -10,11 +10,17 @@ from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob from invokeai.app.services.model_records import ModelRecordChanges -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType +class InvalidModelConfigException(Exception): + """Raised when a model configuration is invalid.""" + + pass + + class InstallStatus(str, Enum): """State of an install job running in the background.""" diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 2a6e638876e..1379184cedb 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -5,6 +5,7 @@ import re import threading import time +from copy import deepcopy from pathlib import Path from queue import Empty, Queue from shutil import move, rmtree @@ -26,6 +27,7 @@ MODEL_SOURCE_TO_TYPE_MAP, HFModelSource, InstallStatus, + InvalidModelConfigException, LocalModelSource, ModelInstallJob, ModelSource, @@ -34,13 +36,12 @@ ) from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.app.services.model_records.model_records_base import ModelRecordChanges -from invokeai.backend.model_manager.config import ( +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base +from 
invokeai.backend.model_manager.configs.factory import ( AnyModelConfig, - CheckpointConfigBase, - InvalidModelConfigException, - ModelConfigBase, + ModelConfigFactory, ) -from invokeai.backend.model_manager.legacy_probe import ModelProbe +from invokeai.backend.model_manager.configs.unknown import Unknown_Config from invokeai.backend.model_manager.metadata import ( AnyModelRepoMetadata, HuggingFaceMetadataFetch, @@ -180,28 +181,32 @@ def install_path( self, model_path: Union[Path, str], config: Optional[ModelRecordChanges] = None, - ) -> str: # noqa D102 + ) -> str: model_path = Path(model_path) config = config or ModelRecordChanges() info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore - if preferred_name := config.name: - if Path(model_path).is_file(): - # Careful! Don't use pathlib.Path(...).with_suffix - it can will strip everything after the first dot. - preferred_name = f"{preferred_name}{model_path.suffix}" - - dest_path = ( - self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name) - ) + dest_dir = self.app_config.models_path / info.key try: - new_path = self._move_model(model_path, dest_path) - except FileExistsError as excp: + if dest_dir.exists(): + raise FileExistsError( + f"Cannot install model {model_path.name} to {dest_dir}: destination already exists" + ) + dest_dir.mkdir(parents=True) + dest_path = dest_dir / model_path.name if model_path.is_file() else dest_dir + if model_path.is_file(): + move(model_path, dest_path) + elif model_path.is_dir(): + # Move the contents of the directory, not the directory itself + for item in model_path.iterdir(): + move(item, dest_dir / item.name) + except FileExistsError as e: raise DuplicateModelException( - f"A model named {model_path.name} is already installed at {dest_path.as_posix()}" - ) from excp + f"A model named {model_path.name} is already installed at {dest_dir.as_posix()}" + ) from e return self._register( - new_path, + dest_path, config, info, ) @@ -364,9 +369,18 @@ def delete(self, key: str) -> None: # noqa D102 def unconditionally_delete(self, key: str) -> None: # noqa D102 model = self.record_store.get_model(key) model_path = self.app_config.models_path / model.path + # Models are stored in a directory named by their key. To delete the model on disk, we delete the entire + # directory. However, the path we store in the model record may be either a file within the key directory, + # or the directory itself. So we have to handle both cases. if model_path.is_file() or model_path.is_symlink(): - model_path.unlink() + # Sanity check - file models should be in their own directory under the models dir. The parent of the + # file should be the model's directory, not the Invoke models dir! + assert model_path.parent != self.app_config.models_path + rmtree(model_path.parent) elif model_path.is_dir(): + # Sanity check - folder models should be in their own directory under the models dir. The path should + # not be the Invoke models dir itself! + assert model_path != self.app_config.models_path rmtree(model_path) self.unregister(key) @@ -526,7 +540,7 @@ def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None: x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts ): install_job.set_error( - InvalidModelConfigException( + ValueError( f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." 
) ) @@ -589,66 +603,25 @@ def on_model_found(model_path: Path) -> bool: found_models = search.search(self._app_config.models_path) self._logger.info(f"{len(found_models)} new models registered") - def sync_model_path(self, key: str) -> AnyModelConfig: - """ - Move model into the location indicated by its basetype, type and name. - - Call this after updating a model's attributes in order to move - the model's path into the location indicated by its basetype, type and - name. Applies only to models whose paths are within the root `models_dir` - directory. - - May raise an UnknownModelException. - """ - model = self.record_store.get_model(key) - models_dir = self.app_config.models_path - old_path = self.app_config.models_path / model.path - - if not old_path.is_relative_to(models_dir): - # The model is not in the models directory - we don't need to move it. - return model - - new_path = models_dir / model.base.value / model.type.value / old_path.name - - if old_path == new_path or new_path.exists() and old_path == new_path.resolve(): - return model - - self._logger.info(f"Moving {model.name} to {new_path}.") - new_path = self._move_model(old_path, new_path) - model.path = new_path.relative_to(models_dir).as_posix() - self.record_store.update_model(key, ModelRecordChanges(path=model.path)) - return model - - def _move_model(self, old_path: Path, new_path: Path) -> Path: - if old_path == new_path: - return old_path - - if new_path.exists(): - raise FileExistsError(f"Cannot move {old_path} to {new_path}: destination already exists") - - new_path.parent.mkdir(parents=True, exist_ok=True) - - move(old_path, new_path) - - return new_path - def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None): config = config or ModelRecordChanges() hash_algo = self._app_config.hashing_algorithm fields = config.model_dump() - # WARNING! - # The legacy probe relies on the implicit order of tests to determine model classification. - # This can lead to regressions between the legacy and new probes. - # Do NOT change the order of `probe` and `classify` without implementing one of the following fixes: - # Short-term fix: `classify` tests `matches` in the same order as the legacy probe. - # Long-term fix: Improve `matches` to be more specific so that only one config matches - # any given model - eliminating ambiguity and removing reliance on order. 
- # After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe` - try: - return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore - except InvalidModelConfigException: - return ModelConfigBase.classify(model_path, hash_algo, **fields) + result = ModelConfigFactory.from_model_on_disk( + mod=model_path, + override_fields=deepcopy(fields), + hash_algo=hash_algo, + allow_unknown=self.app_config.allow_unknown_models, + ) + + if result.config is None: + self._logger.error(f"Could not identify model for {model_path}, detailed results: {result.details}") + raise InvalidModelConfigException(f"Could not identify model for {model_path}") + elif isinstance(result.config, Unknown_Config): + self._logger.error(f"Could not identify model for {model_path}, detailed results: {result.details}") + + return result.config def _register( self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None @@ -669,7 +642,7 @@ def _register( info.path = model_path.as_posix() - if isinstance(info, CheckpointConfigBase): + if isinstance(info, Checkpoint_Config_Base) and info.config_path is not None: # Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the # invoke-managed legacy config dir, we use a relative path. legacy_config_path = self.app_config.legacy_conf_path / info.config_path diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 8aae80e29da..87a405b4ea4 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Callable, Optional -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index ad4ad97a02c..2e2d2ae219d 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -11,7 +11,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load import ( LoadedModel, LoadedModelWithoutConfig, diff --git a/invokeai/app/services/model_manager/__init__.py b/invokeai/app/services/model_manager/__init__.py index aad67ff3527..e703d4f1ffc 100644 --- a/invokeai/app/services/model_manager/__init__.py +++ b/invokeai/app/services/model_manager/__init__.py @@ -1,12 +1,10 @@ """Initialization file for model manager service.""" from invokeai.app.services.model_manager.model_manager_default import ModelManagerService, ModelManagerServiceBase -from invokeai.backend.model_manager import AnyModelConfig from invokeai.backend.model_manager.load import LoadedModel __all__ = [ "ModelManagerServiceBase", "ModelManagerService", - 
"AnyModelConfig", "LoadedModel", ] diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 740d548a4a3..4ac227ba9f4 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -12,15 +12,14 @@ from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.util.model_exclude_null import BaseModelExcludeNull -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - ControlAdapterDefaultSettings, - LoraModelDefaultSettings, - MainModelDefaultSettings, -) +from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.lora import LoraModelDefaultSettings +from invokeai.backend.model_manager.configs.main import MainModelDefaultSettings from invokeai.backend.model_manager.taxonomy import ( BaseModelType, ClipVariantType, + FluxVariantType, ModelFormat, ModelSourceType, ModelType, @@ -90,7 +89,9 @@ class ModelRecordChanges(BaseModelExcludeNull): # Checkpoint-specific changes # TODO(MM2): Should we expose these? Feels footgun-y... - variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None) + variant: Optional[ModelVariantType | ClipVariantType | FluxVariantType] = Field( + description="The variant of the model.", default=None + ) prediction_type: Optional[SchedulerPredictionType] = Field( description="The prediction type of the model.", default=None ) @@ -126,12 +127,14 @@ def del_model(self, key: str) -> None: pass @abstractmethod - def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig: + def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change: bool = False) -> AnyModelConfig: """ Update the model, returning the updated version. :param key: Unique key for the model to be updated. :param changes: A set of changes to apply to this model. Changes are validated before being written. + :param allow_class_change: If True, allows changes that would change the model config class. For example, + changing a LoRA into a Main model. This does not disable validation, so the changes must still be valid. 
""" pass diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index e3b24a6e626..943ceefdbc8 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -58,10 +58,7 @@ ) from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - ModelConfigFactory, -) +from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType @@ -137,15 +134,36 @@ def del_model(self, key: str) -> None: if cursor.rowcount == 0: raise UnknownModelException("model not found") - def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig: + def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change: bool = False) -> AnyModelConfig: with self._db.transaction() as cursor: record = self.get_model(key) - # Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic. - for field_name in changes.model_fields_set: - setattr(record, field_name, getattr(changes, field_name)) + if allow_class_change: + # The changes may cause the model config class to change. To handle this, we need to construct the new + # class from scratch rather than trying to modify the existing instance in place. + # + # 1. Convert the existing record to a dict + # 2. Apply the changes to the dict + # 3. Attempt to create a new model config from the updated dict + + # 1. Convert the existing record to a dict + record_as_dict = record.model_dump() + + # 2. Apply the changes to the dict + for field_name in changes.model_fields_set: + record_as_dict[field_name] = getattr(changes, field_name) - json_serialized = record.model_dump_json() + # 3. Attempt to create a new model config from the updated dict + record = ModelConfigFactory.from_dict(record_as_dict) + + # If we get this far, the updated model config is valid, so we can save it to the database. + json_serialized = record.model_dump_json() + else: + # We are not allowing the model config class to change, so we can just update the existing instance in + # place. If the changes are invalid for the existing class, an exception will be raised by pydantic. 
+ for field_name in changes.model_fields_set: + setattr(record, field_name, getattr(changes, field_name)) + json_serialized = record.model_dump_json() cursor.execute( """--sql @@ -172,7 +190,7 @@ def get_model(self, key: str) -> AnyModelConfig: with self._db.transaction() as cursor: cursor.execute( """--sql - SELECT config, strftime('%s',updated_at) FROM models + SELECT config FROM models WHERE id=?; """, (key,), @@ -180,14 +198,14 @@ def get_model(self, key: str) -> AnyModelConfig: rows = cursor.fetchone() if not rows: raise UnknownModelException("model not found") - model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) + model = ModelConfigFactory.from_dict(json.loads(rows[0])) return model def get_model_by_hash(self, hash: str) -> AnyModelConfig: with self._db.transaction() as cursor: cursor.execute( """--sql - SELECT config, strftime('%s',updated_at) FROM models + SELECT config FROM models WHERE hash=?; """, (hash,), @@ -195,7 +213,7 @@ def get_model_by_hash(self, hash: str) -> AnyModelConfig: rows = cursor.fetchone() if not rows: raise UnknownModelException("model not found") - model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) + model = ModelConfigFactory.from_dict(json.loads(rows[0])) return model def exists(self, key: str) -> bool: @@ -263,7 +281,7 @@ def search_by_attr( cursor.execute( f"""--sql - SELECT config, strftime('%s',updated_at) + SELECT config FROM models {where} ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason; @@ -276,15 +294,20 @@ def search_by_attr( results: list[AnyModelConfig] = [] for row in result: try: - model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1]) - except pydantic.ValidationError: + model_config = ModelConfigFactory.from_dict(json.loads(row[0])) + except pydantic.ValidationError as e: # We catch this error so that the app can still run if there are invalid model configs in the database. # One reason that an invalid model config might be in the database is if someone had to rollback from a # newer version of the app that added a new model type. row_data = f"{row[0][:64]}..." if len(row[0]) > 64 else row[0] + try: + name = json.loads(row[0]).get("name", "") + except Exception: + name = "" self._logger.warning( - f"Found an invalid model config in the database. Ignoring this model. ({row_data})" + f"Skipping invalid model config in the database with name {name}. Ignoring this model. 
({row_data})" ) + self._logger.warning(f"Validation error: {e}") else: results.append(model_config) @@ -295,12 +318,12 @@ def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: with self._db.transaction() as cursor: cursor.execute( """--sql - SELECT config, strftime('%s',updated_at) FROM models + SELECT config FROM models WHERE path=?; """, (str(path),), ) - results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()] + results = [ModelConfigFactory.from_dict(json.loads(x[0])) for x in cursor.fetchall()] return results def search_by_hash(self, hash: str) -> List[AnyModelConfig]: @@ -308,12 +331,12 @@ def search_by_hash(self, hash: str) -> List[AnyModelConfig]: with self._db.transaction() as cursor: cursor.execute( """--sql - SELECT config, strftime('%s',updated_at) FROM models + SELECT config FROM models WHERE hash=?; """, (hash,), ) - results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()] + results = [ModelConfigFactory.from_dict(json.loads(x[0])) for x in cursor.fetchall()] return results def list_models( diff --git a/invokeai/app/services/model_relationships/model_relationships_default.py b/invokeai/app/services/model_relationships/model_relationships_default.py index 67fa6c0069d..e4da482ff27 100644 --- a/invokeai/app/services/model_relationships/model_relationships_default.py +++ b/invokeai/app/services/model_relationships/model_relationships_default.py @@ -1,6 +1,6 @@ from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig class ModelRelationshipsService(ModelRelationshipsServiceABC): diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 743b6208ead..97291230e04 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -19,10 +19,8 @@ from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.util.step_callback import diffusion_step_callback -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - ModelConfigBase, -) +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState @@ -558,7 +556,7 @@ def get_absolute_path(self, config_or_path: AnyModelConfig | Path | str) -> Path The absolute path to the model. 
""" - model_path = Path(config_or_path.path) if isinstance(config_or_path, ModelConfigBase) else Path(config_or_path) + model_path = Path(config_or_path.path) if isinstance(config_or_path, Config_Base) else Path(config_or_path) if model_path.is_absolute(): return model_path.resolve() diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 9a85d31eec1..df0e5fca049 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -24,6 +24,9 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_22 import build_migration_22 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_23 import build_migration_23 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import build_migration_24 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -65,6 +68,9 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_19(app_config=config)) migrator.register_migration(build_migration_20()) migrator.register_migration(build_migration_21()) + migrator.register_migration(build_migration_22(app_config=config, logger=logger)) + migrator.register_migration(build_migration_23(app_config=config, logger=logger)) + migrator.register_migration(build_migration_24(app_config=config, logger=logger)) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py new file mode 100644 index 00000000000..bf97cbd00ac --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py @@ -0,0 +1,89 @@ +import sqlite3 +from logging import Logger + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration22Callback: + def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None: + self._app_config = app_config + self._logger = logger + self._models_dir = app_config.models_path.resolve() + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._logger.info("Removing UNIQUE(name, base, type) constraint from models table") + + # Step 1: Rename the existing models table + cursor.execute("ALTER TABLE models RENAME TO models_old;") + + # Step 2: Create the new models table without the UNIQUE(name, base, type) constraint + cursor.execute( + """--sql + CREATE TABLE models ( + id TEXT NOT NULL PRIMARY KEY, + hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL, + base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL, + type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL, + path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL, + format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL, + name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL, + description TEXT GENERATED ALWAYS as (json_extract(config, 
'$.description')) VIRTUAL, + source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL, + source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL, + source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL, + trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL, + file_size INTEGER GENERATED ALWAYS as (json_extract(config, '$.file_size')) VIRTUAL NOT NULL, + -- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses + config TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Explicit unique constraint on path + UNIQUE(path) + ); + """ + ) + + # Step 3: Copy all data from the old table to the new table + # Only copy the stored columns (id, config, created_at, updated_at), not the virtual columns + cursor.execute( + "INSERT INTO models (id, config, created_at, updated_at) " + "SELECT id, config, created_at, updated_at FROM models_old;" + ) + + # Step 4: Drop the old table + cursor.execute("DROP TABLE models_old;") + + # Step 5: Recreate indexes + cursor.execute("CREATE INDEX IF NOT EXISTS base_index ON models(base);") + cursor.execute("CREATE INDEX IF NOT EXISTS type_index ON models(type);") + cursor.execute("CREATE INDEX IF NOT EXISTS name_index ON models(name);") + + # Step 6: Recreate the updated_at trigger + cursor.execute( + """--sql + CREATE TRIGGER models_updated_at + AFTER UPDATE + ON models FOR EACH ROW + BEGIN + UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE id = old.id; + END; + """ + ) + + +def build_migration_22(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: + """Builds the migration object for migrating from version 21 to version 22. 
+
+    This migration:
+    - Removes the UNIQUE constraint on the combination of (base, name, type) columns in the models table
+    - Adds an explicit UNIQUE constraint on the path column
+    """
+
+    return Migration(
+        from_version=21,
+        to_version=22,
+        callback=Migration22Callback(app_config=app_config, logger=logger),
+    )
diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_23.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_23.py
new file mode 100644
index 00000000000..3b5dc467b38
--- /dev/null
+++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_23.py
@@ -0,0 +1,193 @@
+import json
+import sqlite3
+from copy import deepcopy
+from logging import Logger
+from typing import Any
+
+from pydantic import ValidationError
+
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig, AnyModelConfigValidator
+from invokeai.backend.model_manager.configs.unknown import Unknown_Config
+from invokeai.backend.model_manager.taxonomy import (
+    BaseModelType,
+    ClipVariantType,
+    FluxVariantType,
+    ModelFormat,
+    ModelType,
+    ModelVariantType,
+    SchedulerPredictionType,
+)
+
+
+class Migration23Callback:
+    def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
+        self._app_config = app_config
+        self._logger = logger
+        self._models_dir = app_config.models_path.resolve()
+
+    def __call__(self, cursor: sqlite3.Cursor) -> None:
+        # Grab all model records
+        cursor.execute("SELECT id, config FROM models;")
+        rows = cursor.fetchall()
+
+        migrated_count = 0
+        fallback_count = 0
+
+        for model_id, config_json in rows:
+            try:
+                # Migrate the config JSON to the latest schema
+                config_dict: dict[str, Any] = json.loads(config_json)
+                migrated_config = self._parse_and_migrate_config(config_dict)
+
+                if isinstance(migrated_config, Unknown_Config):
+                    fallback_count += 1
+                else:
+                    migrated_count += 1
+
+                # Write the migrated config back to the database
+                cursor.execute(
+                    "UPDATE models SET config = ? WHERE id = ?;",
+                    (migrated_config.model_dump_json(), model_id),
+                )
+            except ValidationError as e:
+                self._logger.error("Invalid config schema for model %s: %s", model_id, e)
+                raise
+            except json.JSONDecodeError as e:
+                self._logger.error("Invalid config JSON for model %s: %s", model_id, e)
+                raise
+
+        if migrated_count > 0 and fallback_count == 0:
+            self._logger.info(f"Migration complete: {migrated_count} model configs migrated")
+        elif migrated_count > 0 and fallback_count > 0:
+            self._logger.warning(
+                f"Migration complete: {migrated_count} model configs migrated, "
+                f"{fallback_count} model configs could not be migrated and were saved as unknown models",
+            )
+        elif migrated_count == 0 and fallback_count > 0:
+            self._logger.warning(
+                f"Migration complete: all {fallback_count} model configs could not be migrated and were saved as unknown models",
+            )
+        else:
+            self._logger.info("Migration complete: no model configs needed migration")
+
+    def _parse_and_migrate_config(self, config_dict: dict[str, Any]) -> AnyModelConfig:
+        # In v6.9.0 we made some improvements to the model taxonomy and the model config schemas. There are changes
+        # we need to make to old configs to bring them up to date.
+ + type = config_dict.get("type") + format = config_dict.get("format") + base = config_dict.get("base") + + if base == BaseModelType.Flux.value and type == ModelType.Main.value: + # Prior to v6.9.0, we used an awkward combination of `config_path` and `variant` to distinguish between FLUX + # variants. + # + # `config_path` was set to one of: + # - flux-dev + # - flux-dev-fill + # - flux-schnell + # + # `variant` was set to ModelVariantType.Inpaint for FLUX Fill models and ModelVariantType.Normal for all other FLUX + # models. + # + # We now use the `variant` field to directly represent the FLUX variant type, and `config_path` is no longer used. + + # Extract and remove `config_path` if present. + config_path = config_dict.pop("config_path", None) + + match config_path: + case "flux-dev": + config_dict["variant"] = FluxVariantType.Dev.value + case "flux-dev-fill": + config_dict["variant"] = FluxVariantType.DevFill.value + case "flux-schnell": + config_dict["variant"] = FluxVariantType.Schnell.value + case _: + # Unknown config_path - default to Dev variant + config_dict["variant"] = FluxVariantType.Dev.value + + if ( + base + in { + BaseModelType.StableDiffusion1.value, + BaseModelType.StableDiffusion2.value, + BaseModelType.StableDiffusionXL.value, + BaseModelType.StableDiffusionXLRefiner.value, + } + and type == ModelType.Main.value + ): + # Prior to v6.9.0, the prediction_type field was optional and would default to Epsilon if not present. + # We now make it explicit and always present. Use the existing value if present, otherwise default to + # Epsilon, matching the probe logic. + # + # It's only on SD1.x, SD2.x, and SDXL main models. + config_dict["prediction_type"] = config_dict.get("prediction_type", SchedulerPredictionType.Epsilon.value) + + # Prior to v6.9.0, the variant field was optional and would default to Normal if not present. + # We now make it explicit and always present. Use the existing value if present, otherwise default to + # Normal. It's only on SD main models. + config_dict["variant"] = config_dict.get("variant", ModelVariantType.Normal.value) + + if base == BaseModelType.Flux.value and type == ModelType.LoRA.value and format == ModelFormat.Diffusers.value: + # Prior to v6.9.0, we used the Diffusers format for FLUX LoRA models that used the diffusers _key_ + # structure. This was misleading, as everywhere else in the application, we used the Diffusers format + # to indicate that the model files were in the Diffusers _file_ format (i.e. a directory containing + # the weights and config files). + # + # At runtime, we check the LoRA's state dict directly to determine the key structure, so we do not need + # to rely on the format field for this purpose. As of v6.9.0, we always use the LyCORIS format for single- + # file LoRAs, regardless of the key structure. + # + # This change allows LoRA model identification to not need a special case for FLUX LoRAs in the diffusers + # key format. + config_dict["format"] = ModelFormat.LyCORIS.value + + if type == ModelType.CLIPVision.value: + # Prior to v6.9.0, some CLIP Vision models were associated with a specific base model architecture: + # - CLIP-ViT-bigG-14-laion2B-39B-b160k is the image encoder for SDXL IP Adapter and was associated with SDXL + # - CLIP-ViT-H-14-laion2B-s32B-b79K is the image encoder for SD1.5 IP Adapter and was associated with SD1.5 + # + # While this made some sense at the time, it is more correct and flexible to treat CLIP Vision models + # as independent of any specific base model architecture. 
+ config_dict["base"] = BaseModelType.Any.value + + if type == ModelType.CLIPEmbed.value: + # Prior to v6.9.0, some CLIP Embed models did not have a variant set. The default was the L variant. + # We now make it explicit and always present. Use the existing value if present, otherwise default to + # L variant. Also, treat CLIP Embed models as independent of any specific base model architecture. + config_dict["base"] = BaseModelType.Any.value + config_dict["variant"] = config_dict.get("variant", ClipVariantType.L.value) + + try: + migrated_config = AnyModelConfigValidator.validate_python(config_dict) + # This could be a ValidationError or any other error that occurs during validation. A failure to generate a + # union discriminator could raise a ValueError, for example. Who knows what else could fail - catch all. + except Exception as e: + self._logger.error("Failed to validate migrated config, attempting to save as unknown model: %s", e) + cloned_config_dict = deepcopy(config_dict) + cloned_config_dict.pop("base", None) + cloned_config_dict.pop("type", None) + cloned_config_dict.pop("format", None) + + migrated_config = Unknown_Config( + **cloned_config_dict, + base=BaseModelType.Unknown, + type=ModelType.Unknown, + format=ModelFormat.Unknown, + ) + return migrated_config + + +def build_migration_23(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: + """Builds the migration object for migrating from version 22 to version 23. + + This migration updates model configurations to the latest config schemas for v6.9.0. + """ + + return Migration( + from_version=22, + to_version=23, + callback=Migration23Callback(app_config=app_config, logger=logger), + ) diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_24.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_24.py new file mode 100644 index 00000000000..5ae8563b3e6 --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_24.py @@ -0,0 +1,240 @@ +import json +import sqlite3 +from logging import Logger +from pathlib import Path +from typing import NamedTuple + +from pydantic import ValidationError + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration +from invokeai.backend.model_manager.configs.factory import AnyModelConfigValidator + + +class NormalizeResult(NamedTuple): + new_relative_path: str | None + rollback_ops: list[tuple[Path, Path]] + + +class Migration24Callback: + def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None: + self._app_config = app_config + self._logger = logger + self._models_dir = app_config.models_path.resolve() + + def __call__(self, cursor: sqlite3.Cursor) -> None: + # Grab all model records + cursor.execute("SELECT id, config FROM models;") + rows = cursor.fetchall() + + for model_id, config_json in rows: + try: + config = AnyModelConfigValidator.validate_json(config_json) + except ValidationError: + # This could happen if the config schema changed in a way that makes old configs invalid. Unlikely + # for users, more likely for devs testing out migration paths. + self._logger.warning("Skipping model %s: invalid config schema", model_id) + continue + except json.JSONDecodeError: + # This should never happen, as we use pydantic to serialize the config to JSON. 
+                self._logger.warning("Skipping model %s: invalid config JSON", model_id)
+                continue
+
+            # We'll use a savepoint so we can roll back the database update if something goes wrong, and a simple
+            # rollback of file operations if needed.
+            cursor.execute("SAVEPOINT migrate_model")
+            try:
+                new_relative_path, rollback_ops = self._normalize_model_storage(
+                    key=config.key,
+                    path_value=config.path,
+                )
+            except Exception as err:
+                self._logger.error("Error normalizing model %s: %s", config.key, err)
+                cursor.execute("ROLLBACK TO SAVEPOINT migrate_model")
+                cursor.execute("RELEASE SAVEPOINT migrate_model")
+                continue
+
+            if new_relative_path is None:
+                cursor.execute("RELEASE SAVEPOINT migrate_model")
+                continue
+
+            config.path = new_relative_path
+            try:
+                cursor.execute(
+                    "UPDATE models SET config = ? WHERE id = ?;",
+                    (config.model_dump_json(), model_id),
+                )
+            except Exception as err:
+                self._logger.error("Database update failed for model %s: %s", config.key, err)
+                cursor.execute("ROLLBACK TO SAVEPOINT migrate_model")
+                cursor.execute("RELEASE SAVEPOINT migrate_model")
+                self._rollback_file_ops(rollback_ops)
+                continue
+
+            cursor.execute("RELEASE SAVEPOINT migrate_model")
+
+        self._prune_empty_directories()
+
+    def _normalize_model_storage(self, key: str, path_value: str) -> NormalizeResult:
+        models_dir = self._models_dir
+        stored_path = Path(path_value)
+
+        relative_path: Path | None
+        if stored_path.is_absolute():
+            # If the stored path is absolute, we need to check if it's inside the models directory, which means it is
+            # an Invoke-managed model. If it's outside, it is user-managed and we leave it alone.
+            try:
+                relative_path = stored_path.resolve().relative_to(models_dir)
+            except ValueError:
+                self._logger.info("Leaving user-managed model %s at %s", key, stored_path)
+                return NormalizeResult(new_relative_path=None, rollback_ops=[])
+        else:
+            # Relative paths are always relative to the models directory and thus Invoke-managed.
+            relative_path = stored_path
+
+        # If the relative path is empty, assume something is wrong. Warn and skip.
+        if not relative_path.parts:
+            self._logger.warning("Skipping model %s: empty relative path", key)
+            return NormalizeResult(new_relative_path=None, rollback_ops=[])
+
+        # Sanity check: the path is relative. It should be present in the models directory.
+        absolute_path = (models_dir / relative_path).resolve()
+        if not absolute_path.exists():
+            self._logger.warning(
+                "Skipping model %s: expected model files at %s but nothing was found",
+                key,
+                absolute_path,
+            )
+            return NormalizeResult(new_relative_path=None, rollback_ops=[])
+
+        if relative_path.parts[0] == key:
+            # Already normalized. Still ensure the stored path is relative.
+            normalized_path = relative_path.as_posix()
+            # If the stored path is already the normalized path, no change is needed.
+            new_relative_path = normalized_path if stored_path.as_posix() != normalized_path else None
+            return NormalizeResult(new_relative_path=new_relative_path, rollback_ops=[])
+
+        # We'll store the file operations we perform so we can roll them back if needed.
+        rollback_ops: list[tuple[Path, Path]] = []
+
+        # Destination directory is models_dir/<key> - a flat directory structure.
+        destination_dir = models_dir / key
+
+        try:
+            if absolute_path.is_file():
+                destination_dir.mkdir(parents=True, exist_ok=True)
+                dest_file = destination_dir / absolute_path.name
+                # This really shouldn't happen.
+ if dest_file.exists(): + self._logger.warning( + "Destination for model %s already exists at %s; skipping move", + key, + dest_file, + ) + return NormalizeResult(new_relative_path=None, rollback_ops=[]) + + self._logger.info("Moving model file %s -> %s", absolute_path, dest_file) + + # `Path.rename()` effectively moves the file or directory. + absolute_path.rename(dest_file) + rollback_ops.append((dest_file, absolute_path)) + + return NormalizeResult( + new_relative_path=(Path(key) / dest_file.name).as_posix(), + rollback_ops=rollback_ops, + ) + + if absolute_path.is_dir(): + dest_path = destination_dir + # This really shouldn't happen. + if dest_path.exists(): + self._logger.warning( + "Destination directory %s already exists for model %s; skipping", + dest_path, + key, + ) + return NormalizeResult(new_relative_path=None, rollback_ops=[]) + + self._logger.info("Moving model directory %s -> %s", absolute_path, dest_path) + + # `Path.rename()` effectively moves the file or directory. + absolute_path.rename(dest_path) + rollback_ops.append((dest_path, absolute_path)) + + return NormalizeResult( + new_relative_path=Path(key).as_posix(), + rollback_ops=rollback_ops, + ) + + # Maybe a broken symlink or something else weird? + self._logger.warning("Skipping model %s: path %s is neither a file nor directory", key, absolute_path) + return NormalizeResult(new_relative_path=None, rollback_ops=[]) + except Exception: + self._rollback_file_ops(rollback_ops) + raise + + def _rollback_file_ops(self, rollback_ops: list[tuple[Path, Path]]) -> None: + # This is a super-simple rollback that just reverses the move operations we performed. + for source, destination in reversed(rollback_ops): + try: + if source.exists(): + source.rename(destination) + except Exception as err: + self._logger.error("Failed to rollback move %s -> %s: %s", source, destination, err) + + def _prune_empty_directories(self) -> None: + # These directories are system directories we want to keep even if empty. Technically, the app should not + # have any problems if these are removed, creating them as needed, but it's cleaner to just leave them alone. + keep_names = {"model_images", ".download_cache"} + keep_dirs = {self._models_dir / name for name in keep_names} + removed_dirs: set[Path] = set() + + # Walk the models directory tree from the bottom up, removing empty directories. We sort by path length + # descending to ensure we visit children before parents. + for directory in sorted(self._models_dir.rglob("*"), key=lambda p: len(p.parts), reverse=True): + if not directory.is_dir(): + continue + if directory == self._models_dir: + continue + if any(directory == keep or keep in directory.parents for keep in keep_dirs): + continue + + try: + next(directory.iterdir()) + except StopIteration: + try: + directory.rmdir() + removed_dirs.add(directory) + self._logger.debug("Removed empty directory %s", directory) + except OSError: + # Directory not empty (or some other error) - bail out. + self._logger.warning("Failed to prune directory %s - not empty?", directory) + continue + except OSError: + continue + + self._logger.info("Pruned %d empty directories under %s", len(removed_dirs), self._models_dir) + + +def build_migration_24(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: + """Builds the migration object for migrating from version 23 to version 24. 
+ + This migration normalizes on-disk model storage so that each model lives within + a directory named by its key inside the Invoke-managed models directory, and + updates database records to reference the new relative paths. + + This migration behaves a bit differently than others. Because it involves FS operations, if we rolled the + DB back on any failure, we could leave the FS out of sync with the DB. Instead, we use savepoints + to roll back individual model updates on failure, and we roll back any FS operations we performed + for that model. + + If a model cannot be migrated for any reason (invalid config, missing files, FS errors, DB errors), we log a + warning and skip it, leaving it in its original state and location. The model will still work, but it will be in + the "wrong" location on disk. + """ + + return Migration( + from_version=23, + to_version=24, + callback=Migration24Callback(app_config=app_config, logger=logger), + ) diff --git a/invokeai/app/util/custom_openapi.py b/invokeai/app/util/custom_openapi.py index d6b8f3786f1..d400e0ff11b 100644 --- a/invokeai/app/util/custom_openapi.py +++ b/invokeai/app/util/custom_openapi.py @@ -12,6 +12,7 @@ from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.events.events_common import EventBase from invokeai.app.services.session_processor.session_processor_common import ProgressImage +from invokeai.backend.model_manager.configs.factory import AnyModelConfigValidator from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.get_logger() @@ -115,6 +116,13 @@ def openapi() -> dict[str, Any]: # additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema move_defs_to_top_level(openapi_schema, additional_schemas[1]) + any_model_config_schema = AnyModelConfigValidator.json_schema( + mode="serialization", + ref_template="#/components/schemas/{model}", + ) + move_defs_to_top_level(openapi_schema, any_model_config_schema) + openapi_schema["components"]["schemas"]["AnyModelConfig"] = any_model_config_schema + if post_transform is not None: openapi_schema = post_transform(openapi_schema) diff --git a/invokeai/backend/flux/controlnet/state_dict_utils.py b/invokeai/backend/flux/controlnet/state_dict_utils.py index aa44e6c10f0..87eae5a96bc 100644 --- a/invokeai/backend/flux/controlnet/state_dict_utils.py +++ b/invokeai/backend/flux/controlnet/state_dict_utils.py @@ -5,7 +5,7 @@ from invokeai.backend.flux.model import FluxParams -def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool: +def is_state_dict_xlabs_controlnet(sd: dict[str | int, Any]) -> bool: """Is the state dict for an XLabs ControlNet model? This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. @@ -25,7 +25,7 @@ def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool: return False -def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool: +def is_state_dict_instantx_controlnet(sd: dict[str | int, Any]) -> bool: """Is the state dict for an InstantX ControlNet model? This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. 
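The per-model savepoint pattern that migration 24 relies on can be reduced to a small, self-contained sketch. The snippet below is illustrative only and is not part of this diff: the table, the move_files/undo_moves helpers and the in-memory database are placeholders standing in for the real normalization logic, but the SAVEPOINT / ROLLBACK TO / RELEASE sequence mirrors the one used in Migration24Callback above.

import sqlite3
from pathlib import Path

def move_files(key: str) -> list[tuple[Path, Path]]:
    # Placeholder for the real file normalization; returns (new, old) pairs for undo.
    return []

def undo_moves(ops: list[tuple[Path, Path]]) -> None:
    # Reverse the recorded moves, newest first.
    for new_path, old_path in reversed(ops):
        if new_path.exists():
            new_path.rename(old_path)

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE models (id TEXT PRIMARY KEY, config TEXT)")
cur.execute("INSERT INTO models VALUES ('abc123', '{}')")

cur.execute("SELECT id, config FROM models")
for model_id, config_json in cur.fetchall():
    cur.execute("SAVEPOINT migrate_model")  # scope DB changes to this one record
    rollback_ops: list[tuple[Path, Path]] = []
    try:
        rollback_ops = move_files(model_id)  # filesystem work for this record
        cur.execute("UPDATE models SET config = ? WHERE id = ?", (config_json, model_id))
    except Exception:
        # Undo only this record: DB via the savepoint, filesystem via the recorded moves.
        cur.execute("ROLLBACK TO SAVEPOINT migrate_model")
        cur.execute("RELEASE SAVEPOINT migrate_model")
        undo_moves(rollback_ops)
        continue
    cur.execute("RELEASE SAVEPOINT migrate_model")  # keep this record's changes

conn.commit()

The same shape appears in the migration itself, with the file moves recorded as rollback_ops so a failed database update for one model never leaves that model's files in the new location.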
diff --git a/invokeai/backend/flux/flux_state_dict_utils.py b/invokeai/backend/flux/flux_state_dict_utils.py index 8ffab54c688..c306c88f965 100644 --- a/invokeai/backend/flux/flux_state_dict_utils.py +++ b/invokeai/backend/flux/flux_state_dict_utils.py @@ -1,10 +1,7 @@ -from typing import TYPE_CHECKING +from typing import Any -if TYPE_CHECKING: - from invokeai.backend.model_manager.legacy_probe import CkptType - -def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None: +def get_flux_in_channels_from_state_dict(state_dict: dict[str | int, Any]) -> int | None: """Gets the in channels from the state dict.""" # "Standard" FLUX models use "img_in.weight", but some community fine tunes use diff --git a/invokeai/backend/flux/ip_adapter/state_dict_utils.py b/invokeai/backend/flux/ip_adapter/state_dict_utils.py index 90f11ff642b..24ac53550f9 100644 --- a/invokeai/backend/flux/ip_adapter/state_dict_utils.py +++ b/invokeai/backend/flux/ip_adapter/state_dict_utils.py @@ -1,11 +1,11 @@ -from typing import Any, Dict +from typing import Any import torch from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterParams -def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool: +def is_state_dict_xlabs_ip_adapter(sd: dict[str | int, Any]) -> bool: """Is the state dict for an XLabs FLUX IP-Adapter model? This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. @@ -27,7 +27,7 @@ def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool: return False -def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Tensor]) -> XlabsIpAdapterParams: +def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str | int, torch.Tensor]) -> XlabsIpAdapterParams: num_double_blocks = 0 context_dim = 0 hidden_dim = 0 diff --git a/invokeai/backend/flux/redux/flux_redux_state_dict_utils.py b/invokeai/backend/flux/redux/flux_redux_state_dict_utils.py index a5a13b402d3..83e96d38451 100644 --- a/invokeai/backend/flux/redux/flux_redux_state_dict_utils.py +++ b/invokeai/backend/flux/redux/flux_redux_state_dict_utils.py @@ -1,7 +1,7 @@ -from typing import Any, Dict +from typing import Any -def is_state_dict_likely_flux_redux(state_dict: Dict[str, Any]) -> bool: +def is_state_dict_likely_flux_redux(state_dict: dict[str | int, Any]) -> bool: """Checks if the provided state dict is likely a FLUX Redux model.""" expected_keys = {"redux_down.bias", "redux_down.weight", "redux_up.bias", "redux_up.weight"} diff --git a/invokeai/backend/flux/util.py b/invokeai/backend/flux/util.py index 2a5261cb5c6..2cf52b6ec11 100644 --- a/invokeai/backend/flux/util.py +++ b/invokeai/backend/flux/util.py @@ -1,10 +1,11 @@ # Initially pulled from https://github.com/black-forest-labs/flux from dataclasses import dataclass -from typing import Dict, Literal +from typing import Literal from invokeai.backend.flux.model import FluxParams from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams +from invokeai.backend.model_manager.taxonomy import AnyVariant, FluxVariantType @dataclass @@ -41,30 +42,39 @@ class ModelSpec: ] -max_seq_lengths: Dict[str, Literal[256, 512]] = { - "flux-dev": 512, - "flux-dev-fill": 512, - "flux-schnell": 256, +_flux_max_seq_lengths: dict[AnyVariant, Literal[256, 512]] = { + FluxVariantType.Dev: 512, + FluxVariantType.DevFill: 512, + FluxVariantType.Schnell: 256, } -ae_params = { - "flux": AutoEncoderParams( - resolution=256, - in_channels=3, - ch=128, - out_ch=3, - ch_mult=[1, 2, 4, 
4], - num_res_blocks=2, - z_channels=16, - scale_factor=0.3611, - shift_factor=0.1159, - ) -} +def get_flux_max_seq_length(variant: AnyVariant): + try: + return _flux_max_seq_lengths[variant] + except KeyError: + raise ValueError(f"Unknown variant for FLUX max seq len: {variant}") + + +_flux_ae_params = AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, +) + +def get_flux_ae_params() -> AutoEncoderParams: + return _flux_ae_params -params = { - "flux-dev": FluxParams( + +_flux_transformer_params: dict[AnyVariant, FluxParams] = { + FluxVariantType.Dev: FluxParams( in_channels=64, vec_in_dim=768, context_in_dim=4096, @@ -78,7 +88,7 @@ class ModelSpec: qkv_bias=True, guidance_embed=True, ), - "flux-schnell": FluxParams( + FluxVariantType.Schnell: FluxParams( in_channels=64, vec_in_dim=768, context_in_dim=4096, @@ -92,7 +102,7 @@ class ModelSpec: qkv_bias=True, guidance_embed=False, ), - "flux-dev-fill": FluxParams( + FluxVariantType.DevFill: FluxParams( in_channels=384, out_channels=64, vec_in_dim=768, @@ -108,3 +118,10 @@ class ModelSpec: guidance_embed=True, ), } + + +def get_flux_transformers_params(variant: AnyVariant): + try: + return _flux_transformer_params[variant] + except KeyError: + raise ValueError(f"Unknown variant for FLUX transformer params: {variant}") diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index dca72f170e0..e69de29bb2d 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -1,45 +0,0 @@ -"""Re-export frequently-used symbols from the Model Manager backend.""" - -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - InvalidModelConfigException, - ModelConfigBase, - ModelConfigFactory, -) -from invokeai.backend.model_manager.legacy_probe import ModelProbe -from invokeai.backend.model_manager.load import LoadedModel -from invokeai.backend.model_manager.search import ModelSearch -from invokeai.backend.model_manager.taxonomy import ( - AnyModel, - AnyVariant, - BaseModelType, - ClipVariantType, - ModelFormat, - ModelRepoVariant, - ModelSourceType, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SubModelType, -) - -__all__ = [ - "AnyModelConfig", - "InvalidModelConfigException", - "LoadedModel", - "ModelConfigFactory", - "ModelProbe", - "ModelSearch", - "ModelConfigBase", - "AnyModel", - "AnyVariant", - "BaseModelType", - "ClipVariantType", - "ModelFormat", - "ModelRepoVariant", - "ModelSourceType", - "ModelType", - "ModelVariantType", - "SchedulerPredictionType", - "SubModelType", -] diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py deleted file mode 100644 index 1bfc15c046f..00000000000 --- a/invokeai/backend/model_manager/config.py +++ /dev/null @@ -1,770 +0,0 @@ -# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team -""" -Configuration definitions for image generation models. - -Typical usage: - - from invokeai.backend.model_manager import ModelConfigFactory - raw = dict(path='models/sd-1/main/foo.ckpt', - name='foo', - base='sd-1', - type='main', - config='configs/stable-diffusion/v1-inference.yaml', - variant='normal', - format='checkpoint' - ) - config = ModelConfigFactory.make_config(raw) - print(config.name) - -Validation errors will raise an InvalidModelConfigException error. 
- -""" - -# pyright: reportIncompatibleVariableOverride=false -import json -import logging -import time -from abc import ABC, abstractmethod -from enum import Enum -from inspect import isabstract -from pathlib import Path -from typing import ClassVar, Literal, Optional, TypeAlias, Union - -from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter -from typing_extensions import Annotated, Any, Dict - -from invokeai.app.util.misc import uuid_string -from invokeai.backend.model_hash.hash_validator import validate_hash -from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS -from invokeai.backend.model_manager.model_on_disk import ModelOnDisk -from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora -from invokeai.backend.model_manager.taxonomy import ( - AnyVariant, - BaseModelType, - ClipVariantType, - FluxLoRAFormat, - ModelFormat, - ModelRepoVariant, - ModelSourceType, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SubModelType, -) -from invokeai.backend.model_manager.util.model_util import lora_token_vector_length -from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES - -logger = logging.getLogger(__name__) - - -class InvalidModelConfigException(Exception): - """Exception for when config parser doesn't recognize this combination of model type and format.""" - - pass - - -DEFAULTS_PRECISION = Literal["fp16", "fp32"] - - -class SubmodelDefinition(BaseModel): - path_or_prefix: str - model_type: ModelType - variant: AnyVariant = None - - model_config = ConfigDict(protected_namespaces=()) - - -class MainModelDefaultSettings(BaseModel): - vae: str | None = Field(default=None, description="Default VAE for this model (model key)") - vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model") - scheduler: SCHEDULER_NAME_VALUES | None = Field(default=None, description="Default scheduler for this model") - steps: int | None = Field(default=None, gt=0, description="Default number of steps for this model") - cfg_scale: float | None = Field(default=None, ge=1, description="Default CFG Scale for this model") - cfg_rescale_multiplier: float | None = Field( - default=None, ge=0, lt=1, description="Default CFG Rescale Multiplier for this model" - ) - width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model") - height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model") - guidance: float | None = Field(default=None, ge=1, description="Default Guidance for this model") - - model_config = ConfigDict(extra="forbid") - - -class LoraModelDefaultSettings(BaseModel): - weight: float | None = Field(default=None, ge=-1, le=2, description="Default weight for this model") - model_config = ConfigDict(extra="forbid") - - -class ControlAdapterDefaultSettings(BaseModel): - # This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer. - preprocessor: str | None - model_config = ConfigDict(extra="forbid") - - -class MatchSpeed(int, Enum): - """Represents the estimated runtime speed of a config's 'matches' method.""" - - FAST = 0 - MED = 1 - SLOW = 2 - - -class ModelConfigBase(ABC, BaseModel): - """ - Abstract Base class for model configurations. 
- - To create a new config type, inherit from this class and implement its interface: - - (mandatory) override methods 'matches' and 'parse' - - (mandatory) define fields 'type' and 'format' as class attributes - - - (optional) override method 'get_tag' - - (optional) override field _MATCH_SPEED - - See MinimalConfigExample in test_model_probe.py for an example implementation. - """ - - @staticmethod - def json_schema_extra(schema: dict[str, Any]) -> None: - schema["required"].extend(["key", "type", "format"]) - - model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra) - - key: str = Field(description="A unique key for this model.", default_factory=uuid_string) - hash: str = Field(description="The hash of the model file(s).") - path: str = Field( - description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory." - ) - file_size: int = Field(description="The size of the model in bytes.") - name: str = Field(description="Name of the model.") - type: ModelType = Field(description="Model type") - format: ModelFormat = Field(description="Model format") - base: BaseModelType = Field(description="The base model.") - source: str = Field(description="The original source of the model (path, URL or repo_id).") - source_type: ModelSourceType = Field(description="The type of source") - - description: Optional[str] = Field(description="Model description", default=None) - source_api_response: Optional[str] = Field( - description="The original API response from the source, as stringified JSON.", default=None - ) - cover_image: Optional[str] = Field(description="Url for image to preview model", default=None) - submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field( - description="Loadable submodels in this model", default=None - ) - usage_info: Optional[str] = Field(default=None, description="Usage information for this model") - - USING_LEGACY_PROBE: ClassVar[set] = set() - USING_CLASSIFY_API: ClassVar[set] = set() - _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - if issubclass(cls, LegacyProbeMixin): - ModelConfigBase.USING_LEGACY_PROBE.add(cls) - else: - ModelConfigBase.USING_CLASSIFY_API.add(cls) - - @staticmethod - def all_config_classes(): - subclasses = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API - concrete = {cls for cls in subclasses if not isabstract(cls)} - return concrete - - @staticmethod - def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides): - """ - Returns the best matching ModelConfig instance from a model's file/folder path. - Raises InvalidModelConfigException if no valid configuration is found. 
- Created to deprecate ModelProbe.probe - """ - if isinstance(mod, Path | str): - mod = ModelOnDisk(mod, hash_algo) - - candidates = ModelConfigBase.USING_CLASSIFY_API - sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__)) - - for config_cls in sorted_by_match_speed: - try: - if not config_cls.matches(mod): - continue - except Exception as e: - logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}") - continue - else: - return config_cls.from_model_on_disk(mod, **overrides) - - raise InvalidModelConfigException("Unable to determine model type") - - @classmethod - def get_tag(cls) -> Tag: - type = cls.model_fields["type"].default.value - format = cls.model_fields["format"].default.value - return Tag(f"{type}.{format}") - - @classmethod - @abstractmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - """Returns a dictionary with the fields needed to construct the model. - Raises InvalidModelConfigException if the model is invalid. - """ - pass - - @classmethod - @abstractmethod - def matches(cls, mod: ModelOnDisk) -> bool: - """Performs a quick check to determine if the config matches the model. - This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing.""" - pass - - @staticmethod - def cast_overrides(overrides: dict[str, Any]): - """Casts user overrides from str to Enum""" - if "type" in overrides: - overrides["type"] = ModelType(overrides["type"]) - - if "format" in overrides: - overrides["format"] = ModelFormat(overrides["format"]) - - if "base" in overrides: - overrides["base"] = BaseModelType(overrides["base"]) - - if "source_type" in overrides: - overrides["source_type"] = ModelSourceType(overrides["source_type"]) - - if "variant" in overrides: - overrides["variant"] = ModelVariantType(overrides["variant"]) - - @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, **overrides): - """Creates an instance of this config or raises InvalidModelConfigException.""" - fields = cls.parse(mod) - cls.cast_overrides(overrides) - fields.update(overrides) - - type = fields.get("type") or cls.model_fields["type"].default - base = fields.get("base") or cls.model_fields["base"].default - - fields["path"] = mod.path.as_posix() - fields["source"] = fields.get("source") or fields["path"] - fields["source_type"] = fields.get("source_type") or ModelSourceType.Path - fields["name"] = name = fields.get("name") or mod.name - fields["hash"] = fields.get("hash") or mod.hash() - fields["key"] = fields.get("key") or uuid_string() - fields["description"] = fields.get("description") or f"{base.value} {type.value} model {name}" - fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant() - fields["file_size"] = fields.get("file_size") or mod.size() - - return cls(**fields) - - -class LegacyProbeMixin: - """Mixin for classes using the legacy probe for model classification.""" - - @classmethod - def matches(cls, *args, **kwargs): - raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}") - - @classmethod - def parse(cls, *args, **kwargs): - raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}") - - -class CheckpointConfigBase(ABC, BaseModel): - """Base class for checkpoint-style models.""" - - format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field( - description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint - ) - config_path: str = 
Field(description="path to the checkpoint model config file") - converted_at: Optional[float] = Field( - description="When this model was last converted to diffusers", default_factory=time.time - ) - - -class DiffusersConfigBase(ABC, BaseModel): - """Base class for diffusers-style models.""" - - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default - - -class LoRAConfigBase(ABC, BaseModel): - """Base class for LoRA models.""" - - type: Literal[ModelType.LoRA] = ModelType.LoRA - trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) - default_settings: Optional[LoraModelDefaultSettings] = Field( - description="Default settings for this model", default=None - ) - - @classmethod - def flux_lora_format(cls, mod: ModelOnDisk): - key = "FLUX_LORA_FORMAT" - if key in mod.cache: - return mod.cache[key] - - from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict - - sd = mod.load_state_dict(mod.path) - value = flux_format_from_state_dict(sd, mod.metadata()) - mod.cache[key] = value - return value - - @classmethod - def base_model(cls, mod: ModelOnDisk) -> BaseModelType: - if cls.flux_lora_format(mod): - return BaseModelType.Flux - - state_dict = mod.load_state_dict() - # If we've gotten here, we assume that the model is a Stable Diffusion model - token_vector_length = lora_token_vector_length(state_dict) - if token_vector_length == 768: - return BaseModelType.StableDiffusion1 - elif token_vector_length == 1024: - return BaseModelType.StableDiffusion2 - elif token_vector_length == 1280: - return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 - elif token_vector_length == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException("Unknown LoRA type") - - -class T5EncoderConfigBase(ABC, BaseModel): - """Base class for diffusers-style models.""" - - type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder - - -class T5EncoderConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase): - format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder - - -class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase): - format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b - - -class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): - format: Literal[ModelFormat.OMI] = ModelFormat.OMI - - @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: - if mod.path.is_dir(): - return False - - metadata = mod.metadata() - return ( - metadata.get("modelspec.sai_model_spec") - and metadata.get("ot_branch") == "omi_format" - and metadata["modelspec.architecture"].split("/")[1].lower() == "lora" - ) - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - metadata = mod.metadata() - architecture = metadata["modelspec.architecture"] - - if architecture == stable_diffusion_xl_1_lora: - base = BaseModelType.StableDiffusionXL - elif architecture == flux_dev_1_lora: - base = BaseModelType.Flux - else: - raise InvalidModelConfigException(f"Unrecognised/unsupported architecture for OMI LoRA: {architecture}") - - return {"base": base} - - -class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): - """Model config for LoRA/Lycoris models.""" - - format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS - - @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: - if mod.path.is_dir(): - return 
False - - # Avoid false positive match against ControlLoRA and Diffusers - if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: - return False - - state_dict = mod.load_state_dict() - for key in state_dict.keys(): - if isinstance(key, int): - continue - - if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")): - return True - # "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT - # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models. - if key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")): - return True - - return False - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - return { - "base": cls.base_model(mod), - } - - -class ControlAdapterConfigBase(ABC, BaseModel): - default_settings: Optional[ControlAdapterDefaultSettings] = Field( - description="Default settings for this model", default=None - ) - - -class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for Control LoRA models.""" - - type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa - trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) - format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS - - -class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for Control LoRA models.""" - - type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa - trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - - -class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): - """Model config for LoRA/Diffusers models.""" - - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - - @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: - if mod.path.is_file(): - return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers - - suffixes = ["bin", "safetensors"] - weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes] - return any(wf.exists() for wf in weight_files) - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - return { - "base": cls.base_model(mod), - } - - -class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for standalone VAE models.""" - - type: Literal[ModelType.VAE] = ModelType.VAE - - -class VAEDiffusersConfig(LegacyProbeMixin, ModelConfigBase): - """Model config for standalone VAE models (diffusers version).""" - - type: Literal[ModelType.VAE] = ModelType.VAE - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - - -class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for ControlNet models (diffusers version).""" - - type: Literal[ModelType.ControlNet] = ModelType.ControlNet - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - - -class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for ControlNet models (diffusers version).""" - - type: Literal[ModelType.ControlNet] = ModelType.ControlNet - - -class TextualInversionFileConfig(LegacyProbeMixin, ModelConfigBase): - """Model config for textual inversion embeddings.""" - - 
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion - format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile - - -class TextualInversionFolderConfig(LegacyProbeMixin, ModelConfigBase): - """Model config for textual inversion embeddings.""" - - type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion - format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder - - -class MainConfigBase(ABC, BaseModel): - type: Literal[ModelType.Main] = ModelType.Main - trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) - default_settings: Optional[MainModelDefaultSettings] = Field( - description="Default settings for this model", default=None - ) - variant: AnyVariant = ModelVariantType.Normal - - -class VideoConfigBase(ABC, BaseModel): - type: Literal[ModelType.Video] = ModelType.Video - trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) - default_settings: Optional[MainModelDefaultSettings] = Field( - description="Default settings for this model", default=None - ) - variant: AnyVariant = ModelVariantType.Normal - - -class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for main checkpoint models.""" - - prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon - upcast_attention: bool = False - - -class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for main checkpoint models.""" - - format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b - prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon - upcast_attention: bool = False - - -class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for main checkpoint models.""" - - format: Literal[ModelFormat.GGUFQuantized] = ModelFormat.GGUFQuantized - prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon - upcast_attention: bool = False - - -class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for main diffusers models.""" - - pass - - -class IPAdapterConfigBase(ABC, BaseModel): - type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter - - -class IPAdapterInvokeAIConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for IP Adapter diffusers format models.""" - - # TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long - # time. Need to go through the history to make sure I'm understanding this fully. 
- image_encoder_model_id: str - format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI - - -class IPAdapterCheckpointConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for IP Adapter checkpoint format models.""" - - format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint - - -class CLIPEmbedDiffusersConfig(DiffusersConfigBase): - """Model config for Clip Embeddings.""" - - variant: ClipVariantType = Field(description="Clip variant for this model") - type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - - -class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase): - """Model config for CLIP-G Embeddings.""" - - variant: Literal[ClipVariantType.G] = ClipVariantType.G - - @classmethod - def get_tag(cls) -> Tag: - return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}") - - -class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase): - """Model config for CLIP-L Embeddings.""" - - variant: Literal[ClipVariantType.L] = ClipVariantType.L - - @classmethod - def get_tag(cls) -> Tag: - return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}") - - -class CLIPVisionDiffusersConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for CLIPVision.""" - - type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - - -class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for T2I.""" - - type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - - -class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase): - """Model config for Spandrel Image to Image models.""" - - _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk - - type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage - format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint - - -class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for SigLIP.""" - - type: Literal[ModelType.SigLIP] = ModelType.SigLIP - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - - -class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase): - """Model config for FLUX Tools Redux model.""" - - type: Literal[ModelType.FluxRedux] = ModelType.FluxRedux - format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint - - -class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): - """Model config for Llava Onevision models.""" - - type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - - @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: - if mod.path.is_file(): - return False - - config_path = mod.path / "config.json" - try: - with open(config_path, "r") as file: - config = json.load(file) - except FileNotFoundError: - return False - - architectures = config.get("architectures") - return architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration" - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - return { - "base": BaseModelType.Any, - "variant": ModelVariantType.Normal, - } - - -class ApiModelConfig(MainConfigBase, 
ModelConfigBase): - """Model config for API-based models.""" - - format: Literal[ModelFormat.Api] = ModelFormat.Api - - @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: - # API models are not stored on disk, so we can't match them. - return False - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - raise NotImplementedError("API models are not parsed from disk.") - - -class VideoApiModelConfig(VideoConfigBase, ModelConfigBase): - """Model config for API-based video models.""" - - format: Literal[ModelFormat.Api] = ModelFormat.Api - - @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: - # API models are not stored on disk, so we can't match them. - return False - - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - raise NotImplementedError("API models are not parsed from disk.") - - -def get_model_discriminator_value(v: Any) -> str: - """ - Computes the discriminator value for a model config. - https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator - """ - format_ = type_ = variant_ = None - - if isinstance(v, dict): - format_ = v.get("format") - if isinstance(format_, Enum): - format_ = format_.value - - type_ = v.get("type") - if isinstance(type_, Enum): - type_ = type_.value - - variant_ = v.get("variant") - if isinstance(variant_, Enum): - variant_ = variant_.value - else: - format_ = v.format.value - type_ = v.type.value - variant_ = getattr(v, "variant", None) - if variant_: - variant_ = variant_.value - - # Ideally, each config would be uniquely identified with a combination of fields - # i.e. (type, format, variant) without any special cases. Alas... - - # Previously, CLIPEmbed did not have any variants, meaning older database entries lack a variant field. - # To maintain compatibility, we default to ClipVariantType.L in this case. 
- if type_ == ModelType.CLIPEmbed.value and format_ == ModelFormat.Diffusers.value: - variant_ = variant_ or ClipVariantType.L.value - return f"{type_}.{format_}.{variant_}" - return f"{type_}.{format_}" - - -# The types are listed explicitly because IDEs/LSPs can't identify the correct types -# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes -AnyModelConfig = Annotated[ - Union[ - Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()], - Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()], - Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()], - Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()], - Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()], - Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()], - Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()], - Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()], - Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()], - Annotated[LoRAOmiConfig, LoRAOmiConfig.get_tag()], - Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()], - Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()], - Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()], - Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()], - Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()], - Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()], - Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()], - Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()], - Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()], - Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()], - Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()], - Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()], - Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()], - Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()], - Annotated[SigLIPConfig, SigLIPConfig.get_tag()], - Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()], - Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()], - Annotated[ApiModelConfig, ApiModelConfig.get_tag()], - Annotated[VideoApiModelConfig, VideoApiModelConfig.get_tag()], - ], - Discriminator(get_model_discriminator_value), -] - -AnyModelConfigValidator = TypeAdapter(AnyModelConfig) -AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings] - - -class ModelConfigFactory: - @staticmethod - def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig: - """Return the appropriate config object from raw dict values.""" - model = AnyModelConfigValidator.validate_python(model_data) # type: ignore - if isinstance(model, CheckpointConfigBase) and timestamp: - model.converted_at = timestamp - validate_hash(model.hash) - return model # type: ignore diff --git a/tests/test_model_probe/sd-1/main/dreamshaper-8-inpainting/unet/diffusion_pytorch_model.fp16.safetensors b/invokeai/backend/model_manager/configs/__init__.py similarity index 100% rename from tests/test_model_probe/sd-1/main/dreamshaper-8-inpainting/unet/diffusion_pytorch_model.fp16.safetensors rename to invokeai/backend/model_manager/configs/__init__.py diff --git 
a/invokeai/backend/model_manager/configs/base.py b/invokeai/backend/model_manager/configs/base.py new file mode 100644 index 00000000000..8de9a2b8316 --- /dev/null +++ b/invokeai/backend/model_manager/configs/base.py @@ -0,0 +1,245 @@ +from abc import ABC, abstractmethod +from enum import Enum +from inspect import isabstract +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Literal, + Self, + Type, +) + +from pydantic import BaseModel, ConfigDict, Field, Tag +from pydantic_core import PydanticUndefined + +from invokeai.app.util.misc import uuid_string +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + AnyVariant, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelSourceType, + ModelType, +) + +if TYPE_CHECKING: + pass + + +class Config_Base(ABC, BaseModel): + """ + Abstract base class for model configurations. A model config describes a specific combination of model base, type and + format, along with other metadata about the model. For example, a Stable Diffusion 1.x main model in checkpoint format + would have base=sd-1, type=main, format=checkpoint. + + To create a new config type, inherit from this class and implement its interface: + - Define method 'from_model_on_disk' that returns an instance of the class or raises NotAMatch. This method will be + called during model installation to determine the correct config class for a model. + - Define fields 'type', 'base' and 'format' as pydantic fields. These should be Literals with a single value. A + default must be provided for each of these fields. + + If multiple combinations of base, type and format need to be supported, create a separate subclass for each. + + See MinimalConfigExample in test_model_probe.py for an example implementation. + """ + + # These fields are common to all model configs. + + key: str = Field( + default_factory=uuid_string, + description="A unique key for this model.", + ) + hash: str = Field( + description="The hash of the model file(s).", + ) + path: str = Field( + description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory.", + ) + file_size: int = Field( + description="The size of the model in bytes.", + ) + name: str = Field( + description="Name of the model.", + ) + description: str | None = Field( + default=None, + description="Model description", + ) + source: str = Field( + description="The original source of the model (path, URL or repo_id).", + ) + source_type: ModelSourceType = Field( + description="The type of source", + ) + source_api_response: str | None = Field( + default=None, + description="The original API response from the source, as stringified JSON.", + ) + cover_image: str | None = Field( + default=None, + description="Url for image to preview model", + ) + usage_info: str | None = Field( + default=None, + description="Usage information for this model", + ) + + CONFIG_CLASSES: ClassVar[set[Type["Config_Base"]]] = set() + """Set of all non-abstract subclasses of Config_Base, for use during model probing. In other words, this is the set + of all known model config types.""" + + model_config = ConfigDict( + validate_assignment=True, + json_schema_serialization_defaults_required=True, + json_schema_mode_override="serialization", + ) + + @classmethod + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # Register non-abstract subclasses so we can iterate over them later during model probing. 
Note that + # isabstract() will return False if the class does not have any abstract methods, even if it inherits from ABC. + # We must check for ABC lest we unintentionally register some abstract model config classes. + if not isabstract(cls) and ABC not in cls.__bases__: + cls.CONFIG_CLASSES.add(cls) + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs): + # Ensure that model configs define 'base', 'type' and 'format' fields and provide defaults for them. Each + # subclass is expected to represent a single combination of base, type and format. + # + # This pydantic dunder method is called after the pydantic model for a class is created. The normal + # __init_subclass__ is too early to do this check. + for name in ("type", "base", "format"): + if name not in cls.model_fields: + raise NotImplementedError(f"{cls.__name__} must define a '{name}' field") + if cls.model_fields[name].default is PydanticUndefined: + raise NotImplementedError(f"{cls.__name__} must define a default for the '{name}' field") + + @classmethod + def get_tag(cls) -> Tag: + """Constructs a pydantic discriminated union tag for this model config class. When a config is deserialized, + pydantic uses the tag to determine which subclass to instantiate. + + The tag is a dot-separated string of the type, format, base and variant (if applicable). + """ + tag_strings: list[str] = [] + for name in ("type", "format", "base", "variant"): + if field := cls.model_fields.get(name): + # The check in __pydantic_init_subclass__ ensures that type, format and base are always present with + # defaults. variant does not require a default, but if it has one, we need to add it to the tag. We can + # check for the presence of a default by seeing if it's not PydanticUndefined, a sentinel value used by + # pydantic to indicate that no default was provided. + if field.default is not PydanticUndefined: + # We expect each of these fields has an Enum for its default; we want the value of the enum. + tag_strings.append(field.default.value) + return Tag(".".join(tag_strings)) + + @staticmethod + def get_model_discriminator_value(v: Any) -> str: + """Computes the discriminator value for a model config discriminated union.""" + # This is called by pydantic during deserialization and serialization to determine which model the data + # represents. It can get either a dict (during deserialization) or an instance of a Config_Base subclass + # (during serialization). + # + # See: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator + if isinstance(v, Config_Base): + # We have an instance of a ModelConfigBase subclass - use its tag directly. + return v.get_tag().tag + if isinstance(v, dict): + # We have a dict - attempt to compute a tag from its fields. 
+ tag_strings: list[str] = [] + if type_ := v.get("type"): + if isinstance(type_, Enum): + type_ = str(type_.value) + elif not isinstance(type_, str): + raise ValueError("Model config dict 'type' field must be a string or Enum") + tag_strings.append(type_) + + if format_ := v.get("format"): + if isinstance(format_, Enum): + format_ = str(format_.value) + elif not isinstance(format_, str): + raise ValueError("Model config dict 'format' field must be a string or Enum") + tag_strings.append(format_) + + if base_ := v.get("base"): + if isinstance(base_, Enum): + base_ = str(base_.value) + elif not isinstance(base_, str): + raise ValueError("Model config dict 'base' field must be a string or Enum") + tag_strings.append(base_) + + # Special case: CLIP Embed models also need the variant to distinguish them. + if ( + type_ == ModelType.CLIPEmbed.value + and format_ == ModelFormat.Diffusers.value + and base_ == BaseModelType.Any.value + ): + if variant_ := v.get("variant"): + if isinstance(variant_, Enum): + variant_ = variant_.value + elif not isinstance(variant_, str): + raise ValueError("Model config dict 'variant' field must be a string or Enum") + tag_strings.append(variant_) + else: + raise ValueError("CLIP Embed model config dict must include a 'variant' field") + + return ".".join(tag_strings) + else: + raise ValueError( + "Model config discriminator value must be computed from a dict or ModelConfigBase instance" + ) + + @classmethod + @abstractmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + """Given the model on disk and any override fields, attempt to construct an instance of this config class. + + This method serves to identify whether the model on disk matches this config class, and if so, to extract any + additional metadata needed to instantiate the config. 
+ + Implementations should raise a NotAMatchError if the model does not match this config class.""" + raise NotImplementedError(f"from_model_on_disk not implemented for {cls.__name__}") + + +class Checkpoint_Config_Base(ABC, BaseModel): + """Base class for checkpoint-style models.""" + + config_path: str | None = Field( + description="Path to the config for this model, if any.", + default=None, + ) + + +class Diffusers_Config_Base(ABC, BaseModel): + """Base class for diffusers-style models.""" + + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + repo_variant: ModelRepoVariant = Field(ModelRepoVariant.Default) + + @classmethod + def _get_repo_variant_or_raise(cls, mod: ModelOnDisk) -> ModelRepoVariant: + # get all files ending in .bin or .safetensors + weight_files = list(mod.path.glob("**/*.safetensors")) + weight_files.extend(list(mod.path.glob("**/*.bin"))) + for x in weight_files: + if ".fp16" in x.suffixes: + return ModelRepoVariant.FP16 + if "openvino_model" in x.name: + return ModelRepoVariant.OpenVINO + if "flax_model" in x.name: + return ModelRepoVariant.Flax + if x.suffix == ".onnx": + return ModelRepoVariant.ONNX + return ModelRepoVariant.Default + + +class SubmodelDefinition(BaseModel): + path_or_prefix: str + model_type: ModelType + variant: AnyVariant | None = None + + model_config = ConfigDict(protected_namespaces=()) diff --git a/invokeai/backend/model_manager/configs/clip_embed.py b/invokeai/backend/model_manager/configs/clip_embed.py new file mode 100644 index 00000000000..4bb24a0a637 --- /dev/null +++ b/invokeai/backend/model_manager/configs/clip_embed.py @@ -0,0 +1,91 @@ +from typing import ( + Literal, + Self, +) + +from pydantic import Field +from typing_extensions import Any + +from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + get_config_dict_or_raise, + raise_for_class_name, + raise_for_override_fields, + raise_if_not_dir, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ClipVariantType, + ModelFormat, + ModelType, +) + + +def get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantType | None: + try: + hidden_size = config.get("hidden_size") + match hidden_size: + case 1280: + return ClipVariantType.G + case 768: + return ClipVariantType.L + case _: + return None + except Exception: + return None + + +class CLIPEmbed_Diffusers_Config_Base(Diffusers_Config_Base): + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + raise_for_class_name( + { + mod.path / "config.json", + mod.path / "text_encoder" / "config.json", + }, + { + "CLIPModel", + "CLIPTextModel", + "CLIPTextModelWithProjection", + }, + ) + + cls._validate_variant(mod) + + return cls(**override_fields) + + @classmethod + def _validate_variant(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model variant does not match this config class.""" + expected_variant = cls.model_fields["variant"].default + config = get_config_dict_or_raise( + { + mod.path / "config.json", + 
mod.path / "text_encoder" / "config.json", + }, + ) + recognized_variant = get_clip_variant_type_from_config(config) + + if recognized_variant is None: + raise NotAMatchError("unable to determine CLIP variant from config") + + if expected_variant is not recognized_variant: + raise NotAMatchError(f"variant is {recognized_variant}, not {expected_variant}") + + +class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base): + variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G) + + +class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base): + variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L) diff --git a/invokeai/backend/model_manager/configs/clip_vision.py b/invokeai/backend/model_manager/configs/clip_vision.py new file mode 100644 index 00000000000..af5a539bc18 --- /dev/null +++ b/invokeai/backend/model_manager/configs/clip_vision.py @@ -0,0 +1,57 @@ +from typing import ( + Literal, + Self, +) + +from pydantic import Field +from typing_extensions import Any + +from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + get_class_name_from_config_dict_or_raise, + get_config_dict_or_raise, + raise_for_override_fields, + raise_if_not_dir, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelType, +) + + +class CLIPVision_Diffusers_Config(Diffusers_Config_Base, Config_Base): + """Model config for CLIPVision.""" + + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + cls.raise_if_config_doesnt_look_like_clip_vision(mod) + + return cls(**override_fields) + + @classmethod + def raise_if_config_doesnt_look_like_clip_vision(cls, mod: ModelOnDisk) -> None: + config_dict = get_config_dict_or_raise(mod.path / "config.json") + class_name = get_class_name_from_config_dict_or_raise(config_dict) + + if class_name == "CLIPVisionModelWithProjection": + looks_like_clip_vision = True + elif class_name == "CLIPModel" and "vision_config" in config_dict: + looks_like_clip_vision = True + else: + looks_like_clip_vision = False + + if not looks_like_clip_vision: + raise NotAMatchError( + f"config class name is {class_name}, not CLIPVisionModelWithProjection or CLIPModel with vision_config" + ) diff --git a/invokeai/backend/model_manager/configs/controlnet.py b/invokeai/backend/model_manager/configs/controlnet.py new file mode 100644 index 00000000000..630e81fd243 --- /dev/null +++ b/invokeai/backend/model_manager/configs/controlnet.py @@ -0,0 +1,230 @@ +from typing import ( + Literal, + Self, +) + +from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import Any + +from invokeai.backend.flux.controlnet.state_dict_utils import ( + is_state_dict_instantx_controlnet, + is_state_dict_xlabs_controlnet, +) +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + common_config_paths, 
+ get_config_dict_or_raise, + raise_for_class_name, + raise_for_override_fields, + raise_if_not_dir, + raise_if_not_file, + state_dict_has_any_keys_starting_with, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelType, +) + +MODEL_NAME_TO_PREPROCESSOR = { + "canny": "canny_image_processor", + "mlsd": "mlsd_image_processor", + "depth": "depth_anything_image_processor", + "bae": "normalbae_image_processor", + "normal": "normalbae_image_processor", + "sketch": "pidi_image_processor", + "scribble": "lineart_image_processor", + "lineart anime": "lineart_anime_image_processor", + "lineart_anime": "lineart_anime_image_processor", + "lineart": "lineart_image_processor", + "soft": "hed_image_processor", + "softedge": "hed_image_processor", + "hed": "hed_image_processor", + "shuffle": "content_shuffle_image_processor", + "pose": "dw_openpose_image_processor", + "mediapipe": "mediapipe_face_processor", + "pidi": "pidi_image_processor", + "zoe": "zoe_depth_image_processor", + "color": "color_map_image_processor", +} + + +class ControlAdapterDefaultSettings(BaseModel): + # This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer. + preprocessor: str | None + model_config = ConfigDict(extra="forbid") + + @classmethod + def from_model_name(cls, model_name: str) -> Self: + for k, v in MODEL_NAME_TO_PREPROCESSOR.items(): + model_name_lower = model_name.lower() + if k in model_name_lower: + return cls(preprocessor=v) + return cls(preprocessor=None) + + +class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base): + """Model config for ControlNet models (diffusers version).""" + + type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + default_settings: ControlAdapterDefaultSettings | None = Field(None) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + raise_for_class_name( + common_config_paths(mod.path), + { + "ControlNetModel", + "FluxControlNetModel", + }, + ) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + config_dict = get_config_dict_or_raise(common_config_paths(mod.path)) + + if config_dict.get("_class_name") == "FluxControlNetModel": + return BaseModelType.Flux + + dimension = config_dict.get("cross_attention_dim") + + match dimension: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + # No obvious way to distinguish between sd2-base and sd2-768, but we don't really differentiate them + # anyway. 
+ return BaseModelType.StableDiffusion2 + case 2048: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatchError(f"unrecognized cross_attention_dim {dimension}") + + +class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + +class ControlNet_Checkpoint_Config_Base(Checkpoint_Config_Base): + """Model config for ControlNet models (diffusers version).""" + + type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + default_settings: ControlAdapterDefaultSettings | None = Field(None) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_controlnet(mod) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None: + if not state_dict_has_any_keys_starting_with( + mod.load_state_dict(), + { + "controlnet", + "control_model", + "input_blocks", + # XLabs FLUX ControlNet models have keys starting with "controlnet_blocks." + # For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors + # TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with + # "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so + # delicate. + "controlnet_blocks", + }, + ): + raise NotAMatchError("state dict does not look like a ControlNet checkpoint") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + + if is_state_dict_xlabs_controlnet(state_dict) or is_state_dict_instantx_controlnet(state_dict): + # TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing + # get_format()? 
+ return BaseModelType.Flux + + for key in ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + "controlnet_mid_block.bias", + "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + ): + if key not in state_dict: + continue + width = state_dict[key].shape[-1] + match width: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 2048: + return BaseModelType.StableDiffusionXL + case 1280: + return BaseModelType.StableDiffusionXL + case _: + pass + + raise NotAMatchError("unable to determine base type from state dict") + + +class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) diff --git a/tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors b/invokeai/backend/model_manager/configs/external_api.py similarity index 100% rename from tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors rename to invokeai/backend/model_manager/configs/external_api.py diff --git a/invokeai/backend/model_manager/configs/factory.py b/invokeai/backend/model_manager/configs/factory.py new file mode 100644 index 00000000000..dcd7c4c0edc --- /dev/null +++ b/invokeai/backend/model_manager/configs/factory.py @@ -0,0 +1,523 @@ +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import ( + Union, +) + +from pydantic import Discriminator, TypeAdapter, ValidationError +from typing_extensions import Annotated, Any + +from invokeai.app.services.config.config_default import get_config +from invokeai.app.util.misc import uuid_string +from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.configs.clip_embed import CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_L_Config +from invokeai.backend.model_manager.configs.clip_vision import CLIPVision_Diffusers_Config +from invokeai.backend.model_manager.configs.controlnet import ( + ControlAdapterDefaultSettings, + ControlNet_Checkpoint_FLUX_Config, + ControlNet_Checkpoint_SD1_Config, + ControlNet_Checkpoint_SD2_Config, + ControlNet_Checkpoint_SDXL_Config, + ControlNet_Diffusers_FLUX_Config, + ControlNet_Diffusers_SD1_Config, + ControlNet_Diffusers_SD2_Config, + ControlNet_Diffusers_SDXL_Config, +) +from invokeai.backend.model_manager.configs.flux_redux import FLUXRedux_Checkpoint_Config +from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError +from invokeai.backend.model_manager.configs.ip_adapter import ( + IPAdapter_Checkpoint_FLUX_Config, + IPAdapter_Checkpoint_SD1_Config, + IPAdapter_Checkpoint_SD2_Config, + IPAdapter_Checkpoint_SDXL_Config, + IPAdapter_InvokeAI_SD1_Config, + IPAdapter_InvokeAI_SD2_Config, + 
IPAdapter_InvokeAI_SDXL_Config, +) +from invokeai.backend.model_manager.configs.llava_onevision import LlavaOnevision_Diffusers_Config +from invokeai.backend.model_manager.configs.lora import ( + ControlLoRA_LyCORIS_FLUX_Config, + LoRA_Diffusers_FLUX_Config, + LoRA_Diffusers_SD1_Config, + LoRA_Diffusers_SD2_Config, + LoRA_Diffusers_SDXL_Config, + LoRA_LyCORIS_FLUX_Config, + LoRA_LyCORIS_SD1_Config, + LoRA_LyCORIS_SD2_Config, + LoRA_LyCORIS_SDXL_Config, + LoRA_OMI_FLUX_Config, + LoRA_OMI_SDXL_Config, + LoraModelDefaultSettings, +) +from invokeai.backend.model_manager.configs.main import ( + Main_BnBNF4_FLUX_Config, + Main_Checkpoint_FLUX_Config, + Main_Checkpoint_SD1_Config, + Main_Checkpoint_SD2_Config, + Main_Checkpoint_SDXL_Config, + Main_Checkpoint_SDXLRefiner_Config, + Main_Diffusers_CogView4_Config, + Main_Diffusers_SD1_Config, + Main_Diffusers_SD2_Config, + Main_Diffusers_SD3_Config, + Main_Diffusers_SDXL_Config, + Main_Diffusers_SDXLRefiner_Config, + Main_ExternalAPI_ChatGPT4o_Config, + Main_ExternalAPI_FluxKontext_Config, + Main_ExternalAPI_Gemini2_5_Config, + Main_ExternalAPI_Imagen3_Config, + Main_ExternalAPI_Imagen4_Config, + Main_GGUF_FLUX_Config, + MainModelDefaultSettings, + Video_ExternalAPI_Runway_Config, + Video_ExternalAPI_Veo3_Config, +) +from invokeai.backend.model_manager.configs.siglip import SigLIP_Diffusers_Config +from invokeai.backend.model_manager.configs.spandrel import Spandrel_Checkpoint_Config +from invokeai.backend.model_manager.configs.t2i_adapter import ( + T2IAdapter_Diffusers_SD1_Config, + T2IAdapter_Diffusers_SDXL_Config, +) +from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config +from invokeai.backend.model_manager.configs.textual_inversion import ( + TI_File_SD1_Config, + TI_File_SD2_Config, + TI_File_SDXL_Config, + TI_Folder_SD1_Config, + TI_Folder_SD2_Config, + TI_Folder_SDXL_Config, +) +from invokeai.backend.model_manager.configs.unknown import Unknown_Config +from invokeai.backend.model_manager.configs.vae import ( + VAE_Checkpoint_FLUX_Config, + VAE_Checkpoint_SD1_Config, + VAE_Checkpoint_SD2_Config, + VAE_Checkpoint_SDXL_Config, + VAE_Diffusers_SD1_Config, + VAE_Diffusers_SDXL_Config, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelSourceType, + ModelType, + variant_type_adapter, +) + +logger = logging.getLogger(__name__) +app_config = get_config() + +# Known model file extensions for sanity checking +_MODEL_EXTENSIONS = { + ".safetensors", + ".ckpt", + ".pt", + ".pth", + ".bin", + ".gguf", + ".onnx", +} + +# Known config file names for diffusers/transformers models +_CONFIG_FILES = { + "model_index.json", + "config.json", +} + +# Maximum number of files in a directory to be considered a model +_MAX_FILES_IN_MODEL_DIR = 50 + +# Maximum depth to search for model files in directories +_MAX_SEARCH_DEPTH = 2 + + +# The types are listed explicitly because IDEs/LSPs can't identify the correct types +# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes +AnyModelConfig = Annotated[ + Union[ + # Main (Pipeline) - diffusers format + Annotated[Main_Diffusers_SD1_Config, Main_Diffusers_SD1_Config.get_tag()], + Annotated[Main_Diffusers_SD2_Config, Main_Diffusers_SD2_Config.get_tag()], + Annotated[Main_Diffusers_SDXL_Config, Main_Diffusers_SDXL_Config.get_tag()], + Annotated[Main_Diffusers_SDXLRefiner_Config, 
Main_Diffusers_SDXLRefiner_Config.get_tag()], + Annotated[Main_Diffusers_SD3_Config, Main_Diffusers_SD3_Config.get_tag()], + Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()], + # Main (Pipeline) - checkpoint format + Annotated[Main_Checkpoint_SD1_Config, Main_Checkpoint_SD1_Config.get_tag()], + Annotated[Main_Checkpoint_SD2_Config, Main_Checkpoint_SD2_Config.get_tag()], + Annotated[Main_Checkpoint_SDXL_Config, Main_Checkpoint_SDXL_Config.get_tag()], + Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()], + Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()], + # Main (Pipeline) - quantized formats + Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()], + Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()], + # VAE - checkpoint format + Annotated[VAE_Checkpoint_SD1_Config, VAE_Checkpoint_SD1_Config.get_tag()], + Annotated[VAE_Checkpoint_SD2_Config, VAE_Checkpoint_SD2_Config.get_tag()], + Annotated[VAE_Checkpoint_SDXL_Config, VAE_Checkpoint_SDXL_Config.get_tag()], + Annotated[VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_FLUX_Config.get_tag()], + # VAE - diffusers format + Annotated[VAE_Diffusers_SD1_Config, VAE_Diffusers_SD1_Config.get_tag()], + Annotated[VAE_Diffusers_SDXL_Config, VAE_Diffusers_SDXL_Config.get_tag()], + # ControlNet - checkpoint format + Annotated[ControlNet_Checkpoint_SD1_Config, ControlNet_Checkpoint_SD1_Config.get_tag()], + Annotated[ControlNet_Checkpoint_SD2_Config, ControlNet_Checkpoint_SD2_Config.get_tag()], + Annotated[ControlNet_Checkpoint_SDXL_Config, ControlNet_Checkpoint_SDXL_Config.get_tag()], + Annotated[ControlNet_Checkpoint_FLUX_Config, ControlNet_Checkpoint_FLUX_Config.get_tag()], + # ControlNet - diffusers format + Annotated[ControlNet_Diffusers_SD1_Config, ControlNet_Diffusers_SD1_Config.get_tag()], + Annotated[ControlNet_Diffusers_SD2_Config, ControlNet_Diffusers_SD2_Config.get_tag()], + Annotated[ControlNet_Diffusers_SDXL_Config, ControlNet_Diffusers_SDXL_Config.get_tag()], + Annotated[ControlNet_Diffusers_FLUX_Config, ControlNet_Diffusers_FLUX_Config.get_tag()], + # LoRA - LyCORIS format + Annotated[LoRA_LyCORIS_SD1_Config, LoRA_LyCORIS_SD1_Config.get_tag()], + Annotated[LoRA_LyCORIS_SD2_Config, LoRA_LyCORIS_SD2_Config.get_tag()], + Annotated[LoRA_LyCORIS_SDXL_Config, LoRA_LyCORIS_SDXL_Config.get_tag()], + Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()], + # LoRA - OMI format + Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()], + Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()], + # LoRA - diffusers format + Annotated[LoRA_Diffusers_SD1_Config, LoRA_Diffusers_SD1_Config.get_tag()], + Annotated[LoRA_Diffusers_SD2_Config, LoRA_Diffusers_SD2_Config.get_tag()], + Annotated[LoRA_Diffusers_SDXL_Config, LoRA_Diffusers_SDXL_Config.get_tag()], + Annotated[LoRA_Diffusers_FLUX_Config, LoRA_Diffusers_FLUX_Config.get_tag()], + # ControlLoRA - diffusers format + Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()], + # T5 Encoder - all formats + Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()], + Annotated[T5Encoder_BnBLLMint8_Config, T5Encoder_BnBLLMint8_Config.get_tag()], + # TI - file format + Annotated[TI_File_SD1_Config, TI_File_SD1_Config.get_tag()], + Annotated[TI_File_SD2_Config, TI_File_SD2_Config.get_tag()], + Annotated[TI_File_SDXL_Config, TI_File_SDXL_Config.get_tag()], + # TI - folder format + 
Annotated[TI_Folder_SD1_Config, TI_Folder_SD1_Config.get_tag()], + Annotated[TI_Folder_SD2_Config, TI_Folder_SD2_Config.get_tag()], + Annotated[TI_Folder_SDXL_Config, TI_Folder_SDXL_Config.get_tag()], + # IP Adapter - InvokeAI format + Annotated[IPAdapter_InvokeAI_SD1_Config, IPAdapter_InvokeAI_SD1_Config.get_tag()], + Annotated[IPAdapter_InvokeAI_SD2_Config, IPAdapter_InvokeAI_SD2_Config.get_tag()], + Annotated[IPAdapter_InvokeAI_SDXL_Config, IPAdapter_InvokeAI_SDXL_Config.get_tag()], + # IP Adapter - checkpoint format + Annotated[IPAdapter_Checkpoint_SD1_Config, IPAdapter_Checkpoint_SD1_Config.get_tag()], + Annotated[IPAdapter_Checkpoint_SD2_Config, IPAdapter_Checkpoint_SD2_Config.get_tag()], + Annotated[IPAdapter_Checkpoint_SDXL_Config, IPAdapter_Checkpoint_SDXL_Config.get_tag()], + Annotated[IPAdapter_Checkpoint_FLUX_Config, IPAdapter_Checkpoint_FLUX_Config.get_tag()], + # T2I Adapter - diffusers format + Annotated[T2IAdapter_Diffusers_SD1_Config, T2IAdapter_Diffusers_SD1_Config.get_tag()], + Annotated[T2IAdapter_Diffusers_SDXL_Config, T2IAdapter_Diffusers_SDXL_Config.get_tag()], + # Misc models + Annotated[Spandrel_Checkpoint_Config, Spandrel_Checkpoint_Config.get_tag()], + Annotated[CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_G_Config.get_tag()], + Annotated[CLIPEmbed_Diffusers_L_Config, CLIPEmbed_Diffusers_L_Config.get_tag()], + Annotated[CLIPVision_Diffusers_Config, CLIPVision_Diffusers_Config.get_tag()], + Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()], + Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()], + Annotated[LlavaOnevision_Diffusers_Config, LlavaOnevision_Diffusers_Config.get_tag()], + # Main - external API + Annotated[Main_ExternalAPI_ChatGPT4o_Config, Main_ExternalAPI_ChatGPT4o_Config.get_tag()], + Annotated[Main_ExternalAPI_Gemini2_5_Config, Main_ExternalAPI_Gemini2_5_Config.get_tag()], + Annotated[Main_ExternalAPI_Imagen3_Config, Main_ExternalAPI_Imagen3_Config.get_tag()], + Annotated[Main_ExternalAPI_Imagen4_Config, Main_ExternalAPI_Imagen4_Config.get_tag()], + Annotated[Main_ExternalAPI_FluxKontext_Config, Main_ExternalAPI_FluxKontext_Config.get_tag()], + # Video - external API + Annotated[Video_ExternalAPI_Veo3_Config, Video_ExternalAPI_Veo3_Config.get_tag()], + Annotated[Video_ExternalAPI_Runway_Config, Video_ExternalAPI_Runway_Config.get_tag()], + # Unknown model (fallback) + Annotated[Unknown_Config, Unknown_Config.get_tag()], + ], + Discriminator(Config_Base.get_model_discriminator_value), +] + +AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig) +"""Pydantic TypeAdapter for the AnyModelConfig union, used for parsing and validation. + +If you need to parse/validate a dict or JSON into an AnyModelConfig, you should probably use +ModelConfigFactory.from_dict or ModelConfigFactory.from_json instead as they may implement +additional logic in the future. +""" + + +@dataclass +class ModelClassificationResult: + """Result of attempting to classify a model on disk into a specific model config. + + Attributes: + match: The best matching model config, or None if no match was found. + results: A mapping of model config class names to either an instance of that class (if it matched) + or an Exception (if it didn't match or an error occurred during matching). 
+ """ + + config: AnyModelConfig | None + details: dict[str, AnyModelConfig | Exception] + + @property + def all_matches(self) -> list[AnyModelConfig]: + """Returns a list of all matching model configs found.""" + return [r for r in self.details.values() if isinstance(r, Config_Base)] + + @property + def match_count(self) -> int: + """Returns the number of matching model configs found.""" + return len(self.all_matches) + + +class ModelConfigFactory: + @staticmethod + def from_dict(fields: dict[str, Any]) -> AnyModelConfig: + """Return the appropriate config object from raw dict values.""" + model = AnyModelConfigValidator.validate_python(fields) + return model + + @staticmethod + def from_json(json: str | bytes | bytearray) -> AnyModelConfig: + """Return the appropriate config object from json.""" + model = AnyModelConfigValidator.validate_json(json) + return model + + @staticmethod + def build_common_fields( + mod: ModelOnDisk, + override_fields: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Builds the common fields for all model configs. + + Args: + mod: The model on disk to extract fields from. + overrides: A optional dictionary of fields to override. These fields will take precedence over the values + extracted from the model on disk. + + - Casts string fields to their Enum types. + - Does not validate the fields against the model config schema. + """ + + _overrides: dict[str, Any] = override_fields or {} + fields: dict[str, Any] = {} + + if "type" in _overrides: + fields["type"] = ModelType(_overrides["type"]) + + if "format" in _overrides: + fields["format"] = ModelFormat(_overrides["format"]) + + if "base" in _overrides: + fields["base"] = BaseModelType(_overrides["base"]) + + if "source_type" in _overrides: + fields["source_type"] = ModelSourceType(_overrides["source_type"]) + + if "variant" in _overrides: + fields["variant"] = variant_type_adapter.validate_strings(_overrides["variant"]) + + fields["path"] = mod.path.as_posix() + fields["source"] = _overrides.get("source") or fields["path"] + fields["source_type"] = _overrides.get("source_type") or ModelSourceType.Path + fields["name"] = _overrides.get("name") or mod.name + fields["hash"] = _overrides.get("hash") or mod.hash() + fields["key"] = _overrides.get("key") or uuid_string() + fields["description"] = _overrides.get("description") + fields["file_size"] = _overrides.get("file_size") or mod.size() + + return fields + + @staticmethod + def _validate_path_looks_like_model(path: Path) -> None: + """Perform basic sanity checks to ensure a path looks like a model. + + This prevents wasting time trying to identify obviously non-model paths like + home directories or downloads folders. Raises RuntimeError if the path doesn't + pass basic checks. + + Args: + path: The path to validate + + Raises: + ValueError: If the path doesn't look like a model + """ + if path.is_file(): + # For files, just check the extension + if path.suffix.lower() not in _MODEL_EXTENSIONS: + raise ValueError( + f"File extension {path.suffix} is not a recognized model format. " + f"Expected one of: {', '.join(sorted(_MODEL_EXTENSIONS))}" + ) + else: + # For directories, do a quick file count check with early exit + total_files = 0 + # Ignore hidden files and directories + paths_to_check = (p for p in path.rglob("*") if not p.name.startswith(".")) + for item in paths_to_check: + if item.is_file(): + total_files += 1 + if total_files > _MAX_FILES_IN_MODEL_DIR: + raise ValueError( + f"Directory contains more than {_MAX_FILES_IN_MODEL_DIR} files. 
" + "This looks like a general-purpose directory rather than a model. " + "Please provide a path to a specific model file or model directory." + ) + + # Check if it has config files at root (diffusers/transformers marker) + has_root_config = any((path / config).exists() for config in _CONFIG_FILES) + + if has_root_config: + # Has a config file, looks like a valid model directory + return + + # Otherwise, search for model files within depth limit + def find_model_files(current_path: Path, depth: int) -> bool: + if depth > _MAX_SEARCH_DEPTH: + return False + try: + for item in current_path.iterdir(): + if item.is_file() and item.suffix.lower() in _MODEL_EXTENSIONS: + return True + elif item.is_dir() and find_model_files(item, depth + 1): + return True + except PermissionError: + pass + return False + + if not find_model_files(path, 0): + raise ValueError( + f"No model files or config files found in directory {path}. " + f"Expected to find model files with extensions: {', '.join(sorted(_MODEL_EXTENSIONS))} " + f"or config files: {', '.join(sorted(_CONFIG_FILES))}" + ) + + @staticmethod + def matches_sort_key(m: AnyModelConfig) -> int: + """Sort key function to prioritize model config matches in case of multiple matches.""" + + # It is possible that we have multiple matches. We need to prioritize them. + + # Known cases where multiple matches can occur: + # - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model. + # - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with + # a config.json file. Prefer the main model. + + # Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more + # sophisticated approach. + match m.type: + case ModelType.Main: + return 0 + case ModelType.LoRA: + return 1 + case ModelType.CLIPEmbed: + return 2 + case _: + return 3 + + @staticmethod + def from_model_on_disk( + mod: str | Path | ModelOnDisk, + override_fields: dict[str, Any] | None = None, + hash_algo: HASHING_ALGORITHMS = "blake3_single", + allow_unknown: bool = True, + ) -> ModelClassificationResult: + """Classify a model on disk and return the best matching model config. + + Args: + mod: The model on disk to classify. Can be a path (str or Path) or a ModelOnDisk instance. + override_fields: Optional dictionary of fields to override. These fields will take precedence + over the values extracted from the model on disk, but this cannot force a match if the + model on disk doesn't actually match the config class. + hash_algo: The hashing algorithm to use when computing the model hash if needed. + + Returns: + A ModelClassificationResult containing the best matching model config (or None if no match) + and a mapping of all attempted model config classes to either an instance of that class (if it matched) + or an Exception (if it didn't match or an error occurred during matching). + + Raises: + ValueError: If the provided path doesn't look like a model. + """ + if isinstance(mod, Path | str): + mod = ModelOnDisk(Path(mod), hash_algo) + + # Perform basic sanity checks before attempting any config matching + # This rejects obviously non-model paths early, saving time + ModelConfigFactory._validate_path_looks_like_model(mod.path) + + # We will always need these fields to build any model config. 
+ fields = ModelConfigFactory.build_common_fields(mod, override_fields) + + # Store results as a mapping of config class to either an instance of that class or an exception + # that was raised when trying to build it. + details: dict[str, AnyModelConfig | Exception] = {} + + # Try to build an instance of each model config class that uses the classify API. + # Each class will either return an instance of itself or raise NotAMatch if it doesn't match. + # Other exceptions may be raised if something unexpected happens during matching or building. + for candidate_class in filter(lambda x: x is not Unknown_Config, Config_Base.CONFIG_CLASSES): + candidate_name = candidate_class.__name__ + try: + # Technically, from_model_on_disk returns a Config_Base, but in practice it will always be a member of + # the AnyModelConfig union. + details[candidate_name] = candidate_class.from_model_on_disk(mod, fields) # type: ignore + except NotAMatchError as e: + # This means the model didn't match this config class. It's not an error, just no match. + details[candidate_name] = e + except ValidationError as e: + # This means the model matched, but we couldn't create the pydantic model instance for the config. + # Maybe invalid overrides were provided? + details[candidate_name] = e + except Exception as e: + # Some other unexpected error occurred. Store the exception for reporting later. + details[candidate_name] = e + + # Extract just the successful matches + matches = [r for r in details.values() if isinstance(r, Config_Base)] + + if not matches: + if not allow_unknown: + # No matches and we are not allowed to fall back to Unknown_Config + return ModelClassificationResult(config=None, details=details) + else: + # Fall back to Unknown_Config + # This should always succeed as Unknown_Config.from_model_on_disk never raises NotAMatch + config = Unknown_Config.from_model_on_disk(mod, fields) + details[Unknown_Config.__name__] = config + return ModelClassificationResult(config=config, details=details) + + matches.sort(key=ModelConfigFactory.matches_sort_key) + config = matches[0] + + # Now do any post-processing needed for specific model types/bases/etc. 
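As a quick illustration of the name-based default-settings step that follows (the model names here are made up):

from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings

# "canny" appears in the (hypothetical) model name, so the canny preprocessor is selected.
defaults = ControlAdapterDefaultSettings.from_model_name("control_v11p_sd15_canny")
assert defaults.preprocessor == "canny_image_processor"

# Names with no recognized substring fall back to no preprocessor.
assert ControlAdapterDefaultSettings.from_model_name("my_custom_adapter").preprocessor is None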
+ match config.type: + case ModelType.Main: + config.default_settings = MainModelDefaultSettings.from_base(config.base) + case ModelType.ControlNet | ModelType.T2IAdapter | ModelType.ControlLoRa: + config.default_settings = ControlAdapterDefaultSettings.from_model_name(config.name) + case ModelType.LoRA: + config.default_settings = LoraModelDefaultSettings() + case _: + pass + + return ModelClassificationResult(config=config, details=details) + + +MODEL_NAME_TO_PREPROCESSOR = { + "canny": "canny_image_processor", + "mlsd": "mlsd_image_processor", + "depth": "depth_anything_image_processor", + "bae": "normalbae_image_processor", + "normal": "normalbae_image_processor", + "sketch": "pidi_image_processor", + "scribble": "lineart_image_processor", + "lineart anime": "lineart_anime_image_processor", + "lineart_anime": "lineart_anime_image_processor", + "lineart": "lineart_image_processor", + "soft": "hed_image_processor", + "softedge": "hed_image_processor", + "hed": "hed_image_processor", + "shuffle": "content_shuffle_image_processor", + "pose": "dw_openpose_image_processor", + "mediapipe": "mediapipe_face_processor", + "pidi": "pidi_image_processor", + "zoe": "zoe_depth_image_processor", + "color": "color_map_image_processor", +} diff --git a/invokeai/backend/model_manager/configs/flux_redux.py b/invokeai/backend/model_manager/configs/flux_redux.py new file mode 100644 index 00000000000..6eb76116fba --- /dev/null +++ b/invokeai/backend/model_manager/configs/flux_redux.py @@ -0,0 +1,40 @@ +from typing import ( + Literal, + Self, +) + +from pydantic import Field +from typing_extensions import Any + +from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + raise_for_override_fields, + raise_if_not_file, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelType, +) + + +class FLUXRedux_Checkpoint_Config(Config_Base): + """Model config for FLUX Tools Redux model.""" + + type: Literal[ModelType.FluxRedux] = Field(default=ModelType.FluxRedux) + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + if not is_state_dict_likely_flux_redux(mod.load_state_dict()): + raise NotAMatchError("model does not match FLUX Tools Redux heuristics") + + return cls(**override_fields) diff --git a/invokeai/backend/model_manager/configs/identification_utils.py b/invokeai/backend/model_manager/configs/identification_utils.py new file mode 100644 index 00000000000..ce7d2c792de --- /dev/null +++ b/invokeai/backend/model_manager/configs/identification_utils.py @@ -0,0 +1,206 @@ +import json +from functools import cache +from pathlib import Path + +from pydantic import BaseModel, ValidationError +from pydantic_core import CoreSchema, SchemaValidator +from typing_extensions import Any + +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk + + +class NotAMatchError(Exception): + """Exception for when a model does not match a config class. + + Args: + reason: The reason why the model did not match. 
+ """ + + def __init__(self, reason: str): + super().__init__(reason) + + +def get_config_dict_or_raise(config_path: Path | set[Path]) -> dict[str, Any]: + """Load the diffusers/transformers model config file and return it as a dictionary. The config file is expected + to be in JSON format. + + Args: + config_path: The path to the config file, or a set of paths to try. + + Returns: + The config file as a dictionary. + + Raises: + NotAMatch if the config file is missing or cannot be loaded. + """ + paths_to_check = config_path if isinstance(config_path, set) else {config_path} + + problems: dict[Path, str] = {} + + for p in paths_to_check: + if not p.exists(): + problems[p] = "file does not exist" + continue + + try: + with open(p, "r") as file: + config = json.load(file) + + return config + except Exception as e: + problems[p] = str(e) + continue + + raise NotAMatchError(f"unable to load config file(s): {problems}") + + +def get_class_name_from_config_dict_or_raise(config: Path | set[Path] | dict[str, Any]) -> str: + """Load the diffusers/transformers model config file and return the class name. + + Args: + config_path: The path to the config file, or a set of paths to try. + + Returns: + The class name from the config file. + + Raises: + NotAMatch if the config file is missing or does not contain a valid class name. + """ + + if not isinstance(config, dict): + config = get_config_dict_or_raise(config) + + try: + if "_class_name" in config: + # This is a diffusers-style config + config_class_name = config["_class_name"] + elif "architectures" in config: + # This is a transformers-style config + config_class_name = config["architectures"][0] + else: + raise ValueError("missing _class_name or architectures field") + except Exception as e: + raise NotAMatchError(f"unable to determine class name from config file: {config}") from e + + if not isinstance(config_class_name, str): + raise NotAMatchError(f"_class_name or architectures field is not a string: {config_class_name}") + + return config_class_name + + +def raise_for_class_name(config: Path | set[Path] | dict[str, Any], class_name: str | set[str]) -> None: + """Get the class name from the config file and raise NotAMatch if it is not in the expected set. + + Args: + config_path: The path to the config file, or a set of paths to try. + class_name: The expected class name, or a set of expected class names. + + Raises: + NotAMatch if the class name is not in the expected set. + """ + + class_name = {class_name} if isinstance(class_name, str) else class_name + + actual_class_name = get_class_name_from_config_dict_or_raise(config) + if actual_class_name not in class_name: + raise NotAMatchError(f"invalid class name from config: {actual_class_name}") + + +def raise_for_override_fields(candidate_config_class: type[BaseModel], override_fields: dict[str, Any]) -> None: + """Check if the provided override fields are valid for the config class using pydantic. + + For example, if the candidate config class has a field "base" of type Literal[BaseModelType.StableDiffusion1], and + the override fields contain "base": BaseModelType.Flux, this function will raise NotAMatch. + + Internally, this function extracts the pydantic schema for each individual override field from the candidate config + class and validates the override value against that schema. Post-instantiation validators are not run. + + Args: + candidate_config_class: The config class that is being tested. + override_fields: The override fields provided by the user. 
+ + Raises: + NotAMatch if any override field is invalid for the config class. + """ + for field_name, override_value in override_fields.items(): + if field_name not in candidate_config_class.model_fields: + raise NotAMatchError(f"unknown override field: {field_name}") + try: + PydanticFieldValidator.validate_field(candidate_config_class, field_name, override_value) + except ValidationError as e: + raise NotAMatchError(f"invalid override for field '{field_name}': {e}") from e + + +def raise_if_not_file(mod: ModelOnDisk) -> None: + """Raise NotAMatch if the model path is not a file.""" + if not mod.path.is_file(): + raise NotAMatchError("model path is not a file") + + +def raise_if_not_dir(mod: ModelOnDisk) -> None: + """Raise NotAMatch if the model path is not a directory.""" + if not mod.path.is_dir(): + raise NotAMatchError("model path is not a directory") + + +def state_dict_has_any_keys_exact(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool: + """Returns true if the state dict has any of the specified keys.""" + _keys = {keys} if isinstance(keys, str) else keys + return any(key in state_dict for key in _keys) + + +def state_dict_has_any_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool: + """Returns true if the state dict has any keys starting with any of the specified prefixes.""" + _prefixes = {prefixes} if isinstance(prefixes, str) else prefixes + return any(any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str)) + + +def state_dict_has_any_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool: + """Returns true if the state dict has any keys ending with any of the specified suffixes.""" + _suffixes = {suffixes} if isinstance(suffixes, str) else suffixes + return any(any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str)) + + +def common_config_paths(path: Path) -> set[Path]: + """Returns common config file paths for models stored in directories.""" + return {path / "config.json", path / "model_index.json"} + + +class PydanticFieldValidator: + """Utility class for validating individual fields of a Pydantic model without instantiating the whole model. + + See: https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144 + """ + + @staticmethod + def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema: + """Find the Pydantic core schema for a specific field in a model.""" + schema: CoreSchema = model.__pydantic_core_schema__.copy() + # we shallow copied, be careful not to mutate the original schema! 
+ + assert schema["type"] in ["definitions", "model"] + + # find the field schema + field_schema = schema["schema"] # type: ignore + while "fields" not in field_schema: + field_schema = field_schema["schema"] # type: ignore + + field_schema = field_schema["fields"][field_name]["schema"] # type: ignore + + # if the original schema is a definition schema, replace the model schema with the field schema + if schema["type"] == "definitions": + schema["schema"] = field_schema + return schema + else: + return field_schema + + @cache + @staticmethod + def get_validator(model: type[BaseModel], field_name: str) -> SchemaValidator: + """Get a SchemaValidator for a specific field in a model.""" + return SchemaValidator(PydanticFieldValidator.find_field_schema(model, field_name)) + + @staticmethod + def validate_field(model: type[BaseModel], field_name: str, value: Any) -> Any: + """Validate a value for a specific field in a model.""" + return PydanticFieldValidator.get_validator(model, field_name).validate_python(value) diff --git a/invokeai/backend/model_manager/configs/ip_adapter.py b/invokeai/backend/model_manager/configs/ip_adapter.py new file mode 100644 index 00000000000..ba27f176201 --- /dev/null +++ b/invokeai/backend/model_manager/configs/ip_adapter.py @@ -0,0 +1,180 @@ +from abc import ABC +from typing import ( + Literal, + Self, +) + +from pydantic import BaseModel, Field +from typing_extensions import Any + +from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + raise_for_override_fields, + raise_if_not_dir, + raise_if_not_file, + state_dict_has_any_keys_starting_with, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelType, +) + + +class IPAdapter_Config_Base(ABC, BaseModel): + type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter) + + +class IPAdapter_InvokeAI_Config_Base(IPAdapter_Config_Base): + """Model config for IP Adapter diffusers format models.""" + + format: Literal[ModelFormat.InvokeAI] = Field(default=ModelFormat.InvokeAI) + + # TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long + # time. Need to go through the history to make sure I'm understanding this fully. 
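The override handling in identification_utils.py above (raise_for_override_fields backed by PydanticFieldValidator) rejects a candidate class as soon as a user-supplied override cannot validate against that class's schema. A rough, self-contained approximation using only pydantic's public API (toy class name; unlike the core-schema approach above, this validates only the field's type annotation):

from typing import Any, Literal

from pydantic import BaseModel, TypeAdapter, ValidationError


class ToySD1Config(BaseModel):
    base: Literal["sd-1"] = "sd-1"
    name: str = ""


def check_overrides(config_cls: type[BaseModel], overrides: dict[str, Any]) -> None:
    # Validate each override value against the corresponding field's annotation,
    # without instantiating the whole model.
    for field_name, value in overrides.items():
        if field_name not in config_cls.model_fields:
            raise ValueError(f"unknown override field: {field_name}")
        try:
            TypeAdapter(config_cls.model_fields[field_name].annotation).validate_python(value)
        except ValidationError as e:
            raise ValueError(f"invalid override for field '{field_name}': {e}") from e


check_overrides(ToySD1Config, {"base": "sd-1"})  # passes
# check_overrides(ToySD1Config, {"base": "flux"})  # raises: "flux" is not Literal["sd-1"]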
+ image_encoder_model_id: str = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_has_weights_file(mod) + + cls._validate_has_image_encoder_metadata_file(mod) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_has_weights_file(cls, mod: ModelOnDisk) -> None: + weights_file = mod.path / "ip_adapter.bin" + if not weights_file.exists(): + raise NotAMatchError("missing ip_adapter.bin weights file") + + @classmethod + def _validate_has_image_encoder_metadata_file(cls, mod: ModelOnDisk) -> None: + image_encoder_metadata_file = mod.path / "image_encoder.txt" + if not image_encoder_metadata_file.exists(): + raise NotAMatchError("missing image_encoder.txt metadata file") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + + try: + cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] + except Exception as e: + raise NotAMatchError(f"unable to determine cross attention dimension: {e}") from e + + match cross_attention_dim: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 2048: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatchError(f"unrecognized cross attention dimension {cross_attention_dim}") + + +class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class IPAdapter_Checkpoint_Config_Base(IPAdapter_Config_Base): + """Model config for IP Adapter checkpoint format models.""" + + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_ip_adapter(mod) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None: + if not state_dict_has_any_keys_starting_with( + mod.load_state_dict(), + { + "image_proj.", + "ip_adapter.", + # XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.". 
+ "ip_adapter_proj_model.", + }, + ): + raise NotAMatchError("model does not match Checkpoint IP Adapter heuristics") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + + if is_state_dict_xlabs_ip_adapter(state_dict): + return BaseModelType.Flux + + try: + cross_attention_dim = state_dict["ip_adapter.1.to_k_ip.weight"].shape[-1] + except Exception as e: + raise NotAMatchError(f"unable to determine cross attention dimension: {e}") from e + + match cross_attention_dim: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 2048: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatchError(f"unrecognized cross attention dimension {cross_attention_dim}") + + +class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) diff --git a/invokeai/backend/model_manager/configs/llava_onevision.py b/invokeai/backend/model_manager/configs/llava_onevision.py new file mode 100644 index 00000000000..c6ceb43ca9d --- /dev/null +++ b/invokeai/backend/model_manager/configs/llava_onevision.py @@ -0,0 +1,42 @@ +from typing import ( + Literal, + Self, +) + +from pydantic import Field +from typing_extensions import Any + +from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + common_config_paths, + raise_for_class_name, + raise_for_override_fields, + raise_if_not_dir, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelType, +) + + +class LlavaOnevision_Diffusers_Config(Diffusers_Config_Base, Config_Base): + """Model config for Llava Onevision models.""" + + type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision) + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + raise_for_class_name( + common_config_paths(mod.path), + { + "LlavaOnevisionForConditionalGeneration", + }, + ) + + return cls(**override_fields) diff --git a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py new file mode 100644 index 00000000000..26b9541c310 --- /dev/null +++ b/invokeai/backend/model_manager/configs/lora.py @@ -0,0 +1,322 @@ +from abc import ABC +from pathlib import Path +from typing import ( + Any, + Literal, + Self, +) + +from pydantic import BaseModel, ConfigDict, Field + +from invokeai.backend.model_manager.configs.base import ( + Config_Base, +) +from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings +from 
invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + raise_for_override_fields, + raise_if_not_dir, + raise_if_not_file, + state_dict_has_any_keys_ending_with, + state_dict_has_any_keys_starting_with, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + FluxLoRAFormat, + ModelFormat, + ModelType, +) +from invokeai.backend.model_manager.util.model_util import lora_token_vector_length +from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control + + +class LoraModelDefaultSettings(BaseModel): + weight: float | None = Field(default=None, ge=-1, le=2, description="Default weight for this model") + model_config = ConfigDict(extra="forbid") + + +class LoRA_Config_Base(ABC, BaseModel): + """Base class for LoRA models.""" + + type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA) + trigger_phrases: set[str] | None = Field( + default=None, + description="Set of trigger phrases for this model", + ) + default_settings: LoraModelDefaultSettings | None = Field( + default=None, + description="Default settings for this model", + ) + + +def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None: + # TODO(psyche): Moving this import to the function to avoid circular imports. Refactor later. + from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict + + state_dict = mod.load_state_dict() + value = flux_format_from_state_dict(state_dict, mod.metadata()) + return value + + +class LoRA_OMI_Config_Base(LoRA_Config_Base): + format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_omi_lora(mod) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_looks_like_omi_lora(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model metadata does not look like an OMI LoRA.""" + flux_format = _get_flux_lora_format(mod) + if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: + raise NotAMatchError("model looks like ControlLoRA or Diffusers LoRA") + + metadata = mod.metadata() + + metadata_looks_like_omi_lora = ( + bool(metadata.get("modelspec.sai_model_spec")) + and metadata.get("ot_branch") == "omi_format" + and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora" + ) + + if not metadata_looks_like_omi_lora: + raise NotAMatchError("metadata does not look like OMI LoRA") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL]: + metadata = mod.metadata() + architecture = metadata["modelspec.architecture"] + + if architecture == stable_diffusion_xl_1_lora: + return BaseModelType.StableDiffusionXL + elif architecture == flux_dev_1_lora: + return BaseModelType.Flux + else: + 
raise NotAMatchError(f"unrecognised/unsupported architecture for OMI LoRA: {architecture}") + + +class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + +class LoRA_LyCORIS_Config_Base(LoRA_Config_Base): + """Model config for LoRA/Lycoris models.""" + + type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA) + format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_lora(mod) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None: + # First rule out ControlLoRA and Diffusers LoRA + flux_format = _get_flux_lora_format(mod) + if flux_format in [FluxLoRAFormat.Control]: + raise NotAMatchError("model looks like Control LoRA") + + # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA. + # Some main models have these keys, likely due to the creator merging in a LoRA. + has_key_with_lora_prefix = state_dict_has_any_keys_starting_with( + mod.load_state_dict(), + { + "lora_te_", + "lora_unet_", + "lora_te1_", + "lora_te2_", + "lora_transformer_", + }, + ) + + has_key_with_lora_suffix = state_dict_has_any_keys_ending_with( + mod.load_state_dict(), + { + "to_k_lora.up.weight", + "to_q_lora.down.weight", + "lora_A.weight", + "lora_B.weight", + }, + ) + + if not has_key_with_lora_prefix and not has_key_with_lora_suffix: + raise NotAMatchError("model does not match LyCORIS LoRA heuristics") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + if _get_flux_lora_format(mod): + return BaseModelType.Flux + + state_dict = mod.load_state_dict() + # If we've gotten here, we assume that the model is a Stable Diffusion model + token_vector_length = lora_token_vector_length(state_dict) + if token_vector_length == 768: + return BaseModelType.StableDiffusion1 + elif token_vector_length == 1024: + return BaseModelType.StableDiffusion2 + elif token_vector_length == 1280: + return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 + elif token_vector_length == 2048: + return BaseModelType.StableDiffusionXL + else: + raise NotAMatchError(f"unrecognized token vector length {token_vector_length}") + + +class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = 
Field(default=BaseModelType.StableDiffusionXL) + + +class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + +class ControlAdapter_Config_Base(ABC, BaseModel): + default_settings: ControlAdapterDefaultSettings | None = Field(None) + + +class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapter_Config_Base, Config_Base): + """Model config for Control LoRA models.""" + + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa) + format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS) + + trigger_phrases: set[str] | None = Field(None) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_control_lora(mod) + + return cls(**override_fields) + + @classmethod + def _validate_looks_like_control_lora(cls, mod: ModelOnDisk) -> None: + state_dict = mod.load_state_dict() + + if not is_state_dict_likely_flux_control(state_dict): + raise NotAMatchError("model state dict does not look like a Flux Control LoRA") + + +class LoRA_Diffusers_Config_Base(LoRA_Config_Base): + """Model config for LoRA/Diffusers models.""" + + # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates + # the weights format. FLUX Diffusers LoRAs are single files. + + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + if _get_flux_lora_format(mod): + return BaseModelType.Flux + + # If we've gotten here, we assume that the LoRA is a Stable Diffusion LoRA + path_to_weight_file = cls._get_weight_file_or_raise(mod) + state_dict = mod.load_state_dict(path_to_weight_file) + token_vector_length = lora_token_vector_length(state_dict) + + match token_vector_length: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 1280: + return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 + case 2048: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatchError(f"unrecognized token vector length {token_vector_length}") + + @classmethod + def _get_weight_file_or_raise(cls, mod: ModelOnDisk) -> Path: + suffixes = ["bin", "safetensors"] + weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes] + for wf in weight_files: + if wf.exists(): + return wf + raise NotAMatchError("missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors") + + +class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) 
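Every concrete config class in this module follows the same contract: `from_model_on_disk` either returns a fully populated config or raises `NotAMatchError`. The sketch below shows how a caller might consume that contract; the `identify_lora` helper and its candidate list are illustrative only, and the real identification entry point and candidate ordering are not part of this hunk.

from typing import Any

from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk

# Illustrative candidate order only -- a real caller would try every registered config class.
_LORA_CANDIDATES = [
    LoRA_OMI_FLUX_Config,
    LoRA_OMI_SDXL_Config,
    LoRA_LyCORIS_SD1_Config,
    LoRA_Diffusers_SD1_Config,
]


def identify_lora(mod: ModelOnDisk, override_fields: dict[str, Any]):
    """Return the first candidate config that accepts the model, else None."""
    for candidate in _LORA_CANDIDATES:
        try:
            # Each candidate validates the model itself and raises NotAMatchError on any mismatch.
            return candidate.from_model_on_disk(mod, override_fields)
        except NotAMatchError:
            continue
    return None

Because every class is self-validating in this way, supporting a new format is mostly a matter of adding another self-validating class rather than editing one central probe.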
+ + +class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py new file mode 100644 index 00000000000..dcb948d99bb --- /dev/null +++ b/invokeai/backend/model_manager/configs/main.py @@ -0,0 +1,705 @@ +from abc import ABC +from typing import Any, Literal, Self + +from pydantic import BaseModel, ConfigDict, Field + +from invokeai.backend.model_manager.configs.base import ( + Checkpoint_Config_Base, + Config_Base, + Diffusers_Config_Base, + SubmodelDefinition, +) +from invokeai.backend.model_manager.configs.clip_embed import get_clip_variant_type_from_config +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + common_config_paths, + get_config_dict_or_raise, + raise_for_class_name, + raise_for_override_fields, + raise_if_not_dir, + raise_if_not_file, + state_dict_has_any_keys_exact, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + FluxVariantType, + ModelFormat, + ModelType, + ModelVariantType, + SchedulerPredictionType, + SubModelType, +) +from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor +from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES + +DEFAULTS_PRECISION = Literal["fp16", "fp32"] + + +class MainModelDefaultSettings(BaseModel): + vae: str | None = Field(default=None, description="Default VAE for this model (model key)") + vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model") + scheduler: SCHEDULER_NAME_VALUES | None = Field(default=None, description="Default scheduler for this model") + steps: int | None = Field(default=None, gt=0, description="Default number of steps for this model") + cfg_scale: float | None = Field(default=None, ge=1, description="Default CFG Scale for this model") + cfg_rescale_multiplier: float | None = Field( + default=None, ge=0, lt=1, description="Default CFG Rescale Multiplier for this model" + ) + width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model") + height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model") + guidance: float | None = Field(default=None, ge=1, description="Default Guidance for this model") + + model_config = ConfigDict(extra="forbid") + + @classmethod + def from_base(cls, base: BaseModelType) -> Self | None: + match base: + case BaseModelType.StableDiffusion1: + return cls(width=512, height=512) + case BaseModelType.StableDiffusion2: + return cls(width=768, height=768) + case BaseModelType.StableDiffusionXL: + return cls(width=1024, height=1024) + case _: + # TODO(psyche): Do we want defaults for other base types? 
+ return None + + +class Main_Config_Base(ABC, BaseModel): + type: Literal[ModelType.Main] = Field(default=ModelType.Main) + trigger_phrases: set[str] | None = Field( + default=None, + description="Set of trigger phrases for this model", + ) + default_settings: MainModelDefaultSettings | None = Field( + default=None, + description="Default settings for this model", + ) + + +def _has_bnb_nf4_keys(state_dict: dict[str | int, Any]) -> bool: + bnb_nf4_keys = { + "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4", + "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4", + } + return any(key in state_dict for key in bnb_nf4_keys) + + +def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool: + return any(isinstance(v, GGMLTensor) for v in state_dict.values()) + + +def _has_main_keys(state_dict: dict[str | int, Any]) -> bool: + for key in state_dict.keys(): + if isinstance(key, int): + continue + elif key.startswith( + ( + "cond_stage_model.", + "first_stage_model.", + "model.diffusion_model.", + # Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model". + # This prefix is typically used to distinguish between multiple models bundled in a single file. + "model.diffusion_model.double_blocks.", + ) + ): + return True + elif key.startswith("double_blocks.") and "ip_adapter" not in key: + # FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be + # careful to avoid false positives on XLabs FLUX IP-Adapter models. + return True + return False + + +class Main_SD_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base): + """Model config for main checkpoint models.""" + + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + + prediction_type: SchedulerPredictionType = Field() + variant: ModelVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_base(mod) + + prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) + + variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**override_fields, prediction_type=prediction_type, variant=variant) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + + key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 768: + return BaseModelType.StableDiffusion1 + if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: + return BaseModelType.StableDiffusion2 + + key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: + return BaseModelType.StableDiffusionXL + elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: + return 
BaseModelType.StableDiffusionXLRefiner + + raise NotAMatchError("unable to determine base type from state dict") + + @classmethod + def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType: + base = cls.model_fields["base"].default + + if base is BaseModelType.StableDiffusion2: + state_dict = mod.load_state_dict() + key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: + if "global_step" in state_dict: + if state_dict["global_step"] == 220000: + return SchedulerPredictionType.Epsilon + elif state_dict["global_step"] == 110000: + return SchedulerPredictionType.VPrediction + return SchedulerPredictionType.VPrediction + else: + return SchedulerPredictionType.Epsilon + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType: + base = cls.model_fields["base"].default + + state_dict = mod.load_state_dict() + key_name = "model.diffusion_model.input_blocks.0.0.weight" + + if key_name not in state_dict: + raise NotAMatchError("unable to determine model variant from state dict") + + in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] + + match in_channels: + case 4: + return ModelVariantType.Normal + case 5: + # Only SD2 has a depth variant + assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'" + return ModelVariantType.Depth + case 9: + return ModelVariantType.Inpaint + case _: + raise NotAMatchError(f"unrecognized unet in_channels {in_channels} for base '{base}'") + + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatchError("state dict does not look like a main model") + + +class Main_Checkpoint_SD1_Config(Main_SD_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class Main_Checkpoint_SD2_Config(Main_SD_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class Main_Checkpoint_SDXL_Config(Main_SD_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class Main_Checkpoint_SDXLRefiner_Config(Main_SD_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner) + + +def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + + # Input channels are derived from the shape of either "img_in.weight" or "model.diffusion_model.img_in.weight". 
+ # + # Known models that use the latter key: + # - https://civitai.com/models/885098?modelVersionId=990775 + # - https://civitai.com/models/1018060?modelVersionId=1596255 + # - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133 + # + # Input channels for known FLUX models: + # - Unquantized Dev and Schnell have in_channels=64 + # - BNB-NF4 Dev and Schnell have in_channels=1 + # - FLUX Fill has in_channels=384 + # - Unsure of quantized FLUX Fill models + # - Unsure of GGUF-quantized models + + in_channels = None + for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}: + if key in state_dict: + in_channels = state_dict[key].shape[1] + break + + if in_channels is None: + # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. + return None + + # Because FLUX Dev and Schnell models have the same in_channels, we need to check for the presence of + # certain keys to distinguish between them. + is_flux_dev = ( + "guidance_in.out_layer.weight" in state_dict + or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict + ) + + if is_flux_dev and in_channels == 384: + return FluxVariantType.DevFill + elif is_flux_dev: + return FluxVariantType.Dev + else: + # Must be a Schnell model...? + return FluxVariantType.Schnell + + +class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): + """Model config for main checkpoint models.""" + + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + variant: FluxVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_is_flux(mod) + + cls._validate_does_not_look_like_bnb_quantized(mod) + + cls._validate_does_not_look_like_gguf_quantized(mod) + + variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**override_fields, variant=variant) + + @classmethod + def _validate_is_flux(cls, mod: ModelOnDisk) -> None: + if not state_dict_has_any_keys_exact( + mod.load_state_dict(), + { + "double_blocks.0.img_attn.norm.key_norm.scale", + "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", + }, + ): + raise NotAMatchError("state dict does not look like a FLUX checkpoint") + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + state_dict = mod.load_state_dict() + variant = _get_flux_variant(state_dict) + + if variant is None: + # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. 
+ raise NotAMatchError("unable to determine model variant from state dict") + + return variant + + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatchError("state dict does not look like a main model") + + @classmethod + def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: + has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict()) + if has_bnb_nf4_keys: + raise NotAMatchError("state dict looks like bnb quantized nf4") + + @classmethod + def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk): + has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict()) + if has_ggml_tensors: + raise NotAMatchError("state dict looks like GGUF quantized") + + +class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): + """Model config for main checkpoint models.""" + + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b) + + variant: FluxVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_model_looks_like_bnb_quantized(mod) + + variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**override_fields, variant=variant) + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + state_dict = mod.load_state_dict() + variant = _get_flux_variant(state_dict) + + if variant is None: + # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. 
+            raise NotAMatchError("unable to determine model variant from state dict")
+
+        return variant
+
+    @classmethod
+    def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
+        has_main_model_keys = _has_main_keys(mod.load_state_dict())
+        if not has_main_model_keys:
+            raise NotAMatchError("state dict does not look like a main model")
+
+    @classmethod
+    def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
+        has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
+        if not has_bnb_nf4_keys:
+            raise NotAMatchError("state dict does not look like bnb quantized nf4")
+
+
+class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for main GGUF-quantized FLUX checkpoint models."""
+
+    base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
+    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
+
+    variant: FluxVariantType = Field()
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_main_model(mod)
+
+        cls._validate_looks_like_gguf_quantized(mod)
+
+        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
+
+        return cls(**override_fields, variant=variant)
+
+    @classmethod
+    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
+        # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
+        state_dict = mod.load_state_dict()
+        variant = _get_flux_variant(state_dict)
+
+        if variant is None:
+            # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
+            # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
+            # model, we should figure out a good fallback value.
+            raise NotAMatchError("unable to determine model variant from state dict")
+
+        return variant
+
+    @classmethod
+    def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
+        has_main_model_keys = _has_main_keys(mod.load_state_dict())
+        if not has_main_model_keys:
+            raise NotAMatchError("state dict does not look like a main model")
+
+    @classmethod
+    def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
+        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
+        if not has_ggml_tensors:
+            raise NotAMatchError("state dict does not look like GGUF quantized")
+
+
+class Main_SD_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base):
+    prediction_type: SchedulerPredictionType = Field()
+    variant: ModelVariantType = Field()
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_dir(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        raise_for_class_name(
+            common_config_paths(mod.path),
+            {
+                # SD 1.x and 2.x
+                "StableDiffusionPipeline",
+                "StableDiffusionInpaintPipeline",
+                # SDXL
+                "StableDiffusionXLPipeline",
+                "StableDiffusionXLInpaintPipeline",
+                # SDXL Refiner
+                "StableDiffusionXLImg2ImgPipeline",
+                # TODO(psyche): Do we actually support LCM models? I don't see this class used anywhere in the codebase.
+ "LatentConsistencyModelPipeline", + }, + ) + + cls._validate_base(mod) + + variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + + prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) + + repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + + return cls( + **override_fields, + variant=variant, + prediction_type=prediction_type, + repo_variant=repo_variant, + ) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + # Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL). + unet_conf = get_config_dict_or_raise(mod.path / "unet" / "config.json") + cross_attention_dim = unet_conf.get("cross_attention_dim") + match cross_attention_dim: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 1280: + return BaseModelType.StableDiffusionXLRefiner + case 2048: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatchError(f"unrecognized cross_attention_dim {cross_attention_dim}") + + @classmethod + def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType: + scheduler_conf = get_config_dict_or_raise(mod.path / "scheduler" / "scheduler_config.json") + + # TODO(psyche): Is epsilon the right default or should we raise if it's not present? + prediction_type = scheduler_conf.get("prediction_type", "epsilon") + + match prediction_type: + case "v_prediction": + return SchedulerPredictionType.VPrediction + case "epsilon": + return SchedulerPredictionType.Epsilon + case _: + raise NotAMatchError(f"unrecognized scheduler prediction_type {prediction_type}") + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType: + base = cls.model_fields["base"].default + unet_config = get_config_dict_or_raise(mod.path / "unet" / "config.json") + in_channels = unet_config.get("in_channels") + + match in_channels: + case 4: + return ModelVariantType.Normal + case 5: + # Only SD2 has a depth variant + assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'" + return ModelVariantType.Depth + case 9: + return ModelVariantType.Inpaint + case _: + raise NotAMatchError(f"unrecognized unet in_channels {in_channels} for base '{base}'") + + +class Main_Diffusers_SD1_Config(Main_SD_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(BaseModelType.StableDiffusion1) + + +class Main_Diffusers_SD2_Config(Main_SD_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(BaseModelType.StableDiffusion2) + + +class Main_Diffusers_SDXL_Config(Main_SD_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(BaseModelType.StableDiffusionXL) + + +class Main_Diffusers_SDXLRefiner_Config(Main_SD_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(BaseModelType.StableDiffusionXLRefiner) + + +class Main_Diffusers_SD3_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base): + base: 
Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3) + submodels: dict[SubModelType, SubmodelDefinition] | None = Field( + description="Loadable submodels in this model", + default=None, + ) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + # This check implies the base type - no further validation needed. + raise_for_class_name( + common_config_paths(mod.path), + { + "StableDiffusion3Pipeline", + "SD3Transformer2DModel", + }, + ) + + submodels = override_fields.get("submodels") or cls._get_submodels_or_raise(mod) + + repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + + return cls( + **override_fields, + submodels=submodels, + repo_variant=repo_variant, + ) + + @classmethod + def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]: + # Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json + config = get_config_dict_or_raise(common_config_paths(mod.path)) + + submodels: dict[SubModelType, SubmodelDefinition] = {} + + for key, value in config.items(): + # Anything that starts with an underscore is top-level metadata, not a submodel + if key.startswith("_") or not (isinstance(value, list) and len(value) == 2): + continue + # The key is something like "transformer" and is a submodel - it will be in a dir of the same name. + # The value value is something like ["diffusers", "SD3Transformer2DModel"] + _library_name, class_name = value + + match class_name: + case "CLIPTextModelWithProjection": + model_type = ModelType.CLIPEmbed + path_or_prefix = (mod.path / key).resolve().as_posix() + + # We need to read the config to determine the variant of the CLIP model. + clip_embed_config = get_config_dict_or_raise( + { + mod.path / key / "config.json", + mod.path / key / "model_index.json", + } + ) + variant = get_clip_variant_type_from_config(clip_embed_config) + submodels[SubModelType(key)] = SubmodelDefinition( + path_or_prefix=path_or_prefix, + model_type=model_type, + variant=variant, + ) + case "SD3Transformer2DModel": + model_type = ModelType.Main + path_or_prefix = (mod.path / key).resolve().as_posix() + variant = None + submodels[SubModelType(key)] = SubmodelDefinition( + path_or_prefix=path_or_prefix, + model_type=model_type, + variant=variant, + ) + case _: + pass + + return submodels + + +class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + # This check implies the base type - no further validation needed. 
+ raise_for_class_name( + common_config_paths(mod.path), + { + "CogView4Pipeline", + }, + ) + + repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + + return cls( + **override_fields, + repo_variant=repo_variant, + ) + + +class ExternalAPI_Config_Base(ABC, BaseModel): + """Model config for API-based models.""" + + format: Literal[ModelFormat.Api] = Field(default=ModelFormat.Api) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise NotAMatchError("External API models cannot be built from disk") + + +class Main_ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.ChatGPT4o] = Field(default=BaseModelType.ChatGPT4o) + + +class Main_ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.Gemini2_5] = Field(default=BaseModelType.Gemini2_5) + + +class Main_ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.Imagen3] = Field(default=BaseModelType.Imagen3) + + +class Main_ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.Imagen4] = Field(default=BaseModelType.Imagen4) + + +class Main_ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext) + + +class Video_Config_Base(ABC, BaseModel): + type: Literal[ModelType.Video] = Field(default=ModelType.Video) + trigger_phrases: set[str] | None = Field(description="Set of trigger phrases for this model", default=None) + default_settings: MainModelDefaultSettings | None = Field( + description="Default settings for this model", default=None + ) + + +class Video_ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, Video_Config_Base, Config_Base): + base: Literal[BaseModelType.Veo3] = Field(default=BaseModelType.Veo3) + + +class Video_ExternalAPI_Runway_Config(ExternalAPI_Config_Base, Video_Config_Base, Config_Base): + base: Literal[BaseModelType.Runway] = Field(default=BaseModelType.Runway) diff --git a/invokeai/backend/model_manager/configs/siglip.py b/invokeai/backend/model_manager/configs/siglip.py new file mode 100644 index 00000000000..62ca9494e27 --- /dev/null +++ b/invokeai/backend/model_manager/configs/siglip.py @@ -0,0 +1,44 @@ +from typing import ( + Literal, + Self, +) + +from pydantic import Field +from typing_extensions import Any + +from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + common_config_paths, + raise_for_class_name, + raise_for_override_fields, + raise_if_not_dir, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelType, +) + + +class SigLIP_Diffusers_Config(Diffusers_Config_Base, Config_Base): + """Model config for SigLIP.""" + + type: Literal[ModelType.SigLIP] = Field(default=ModelType.SigLIP) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + raise_for_class_name( + 
common_config_paths(mod.path), + { + "SiglipModel", + }, + ) + + return cls(**override_fields) diff --git a/invokeai/backend/model_manager/configs/spandrel.py b/invokeai/backend/model_manager/configs/spandrel.py new file mode 100644 index 00000000000..8ca8ad5f603 --- /dev/null +++ b/invokeai/backend/model_manager/configs/spandrel.py @@ -0,0 +1,54 @@ +from typing import ( + Literal, + Self, +) + +from pydantic import Field +from typing_extensions import Any + +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + raise_for_override_fields, + raise_if_not_file, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelType, +) +from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel + + +class Spandrel_Checkpoint_Config(Config_Base): + """Model config for Spandrel Image to Image models.""" + + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + type: Literal[ModelType.SpandrelImageToImage] = Field(default=ModelType.SpandrelImageToImage) + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_spandrel_loads_model(mod) + + return cls(**override_fields) + + @classmethod + def _validate_spandrel_loads_model(cls, mod: ModelOnDisk) -> None: + try: + # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were + # explored to avoid this: + # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta + # device. Unfortunately, some Spandrel models perform operations during initialization that are not + # supported on meta tensors. + # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model. + # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to + # maintain it, and the risk of false positive detections is higher. 
+ SpandrelImageToImageModel.load_from_file(mod.path) + except Exception as e: + raise NotAMatchError("model does not match SpandrelImageToImage heuristics") from e diff --git a/invokeai/backend/model_manager/configs/t2i_adapter.py b/invokeai/backend/model_manager/configs/t2i_adapter.py new file mode 100644 index 00000000000..a1da40e9b4b --- /dev/null +++ b/invokeai/backend/model_manager/configs/t2i_adapter.py @@ -0,0 +1,79 @@ +from typing import ( + Literal, + Self, +) + +from pydantic import Field +from typing_extensions import Any + +from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + common_config_paths, + get_config_dict_or_raise, + raise_for_class_name, + raise_for_override_fields, + raise_if_not_dir, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelType, +) + + +class T2IAdapter_Diffusers_Config_Base(Diffusers_Config_Base): + """Model config for T2I.""" + + type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + default_settings: ControlAdapterDefaultSettings | None = Field(None) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + raise_for_class_name( + common_config_paths(mod.path), + { + "T2IAdapter", + }, + ) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + config_dict = get_config_dict_or_raise(common_config_paths(mod.path)) + + adapter_type = config_dict.get("adapter_type") + + match adapter_type: + case "full_adapter_xl": + return BaseModelType.StableDiffusionXL + case "full_adapter" | "light_adapter": + return BaseModelType.StableDiffusion1 + case _: + raise NotAMatchError(f"unrecognized adapter_type '{adapter_type}'") + + +class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) diff --git a/invokeai/backend/model_manager/configs/t5_encoder.py b/invokeai/backend/model_manager/configs/t5_encoder.py new file mode 100644 index 00000000000..ed682e14304 --- /dev/null +++ b/invokeai/backend/model_manager/configs/t5_encoder.py @@ -0,0 +1,80 @@ +from typing import Any, Literal, Self + +from pydantic import Field + +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + raise_for_class_name, + raise_for_override_fields, + raise_if_not_dir, + 
state_dict_has_any_keys_ending_with, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType + + +class T5Encoder_T5Encoder_Config(Config_Base): + """Configuration for T5 Encoder models in a bespoke, diffusers-like format. The model weights are expected to be in + a folder called text_encoder_2 inside the model directory, with a config file named model.safetensors.index.json.""" + + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder) + format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + expected_config_path = mod.path / "text_encoder_2" / "config.json" + expected_class_name = "T5EncoderModel" + raise_for_class_name(expected_config_path, expected_class_name) + + cls.raise_if_doesnt_have_unquantized_config_file(mod) + + return cls(**override_fields) + + @classmethod + def raise_if_doesnt_have_unquantized_config_file(cls, mod: ModelOnDisk) -> None: + has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists() + + if not has_unquantized_config: + raise NotAMatchError("missing text_encoder_2/model.safetensors.index.json") + + +class T5Encoder_BnBLLMint8_Config(Config_Base): + """Configuration for T5 Encoder models quantized by bitsandbytes' LLM.int8.""" + + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder) + format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + expected_config_path = mod.path / "text_encoder_2" / "config.json" + expected_class_name = "T5EncoderModel" + raise_for_class_name(expected_config_path, expected_class_name) + + cls.raise_if_filename_doesnt_look_like_bnb_quantized(mod) + + cls.raise_if_state_dict_doesnt_look_like_bnb_quantized(mod) + + return cls(**override_fields) + + @classmethod + def raise_if_filename_doesnt_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: + filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix()) + if not filename_looks_like_bnb: + raise NotAMatchError("filename does not look like bnb quantized llm_int8") + + @classmethod + def raise_if_state_dict_doesnt_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: + has_scb_key_suffix = state_dict_has_any_keys_ending_with(mod.load_state_dict(), "SCB") + if not has_scb_key_suffix: + raise NotAMatchError("state dict does not look like bnb quantized llm_int8") diff --git a/invokeai/backend/model_manager/configs/textual_inversion.py b/invokeai/backend/model_manager/configs/textual_inversion.py new file mode 100644 index 00000000000..c827f5234d5 --- /dev/null +++ b/invokeai/backend/model_manager/configs/textual_inversion.py @@ -0,0 +1,156 @@ +from abc import ABC +from pathlib import Path +from typing import ( + Literal, + Self, +) + +import torch +from pydantic import BaseModel, Field +from typing_extensions import Any + +from invokeai.backend.model_manager.configs.base import Config_Base +from 
invokeai.backend.model_manager.configs.identification_utils import (
+    NotAMatchError,
+    raise_for_override_fields,
+    raise_if_not_dir,
+    raise_if_not_file,
+)
+from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
+from invokeai.backend.model_manager.taxonomy import (
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+)
+
+
+class TI_Config_Base(ABC, BaseModel):
+    type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion)
+
+    @classmethod
+    def _validate_base(cls, mod: ModelOnDisk, path: Path | None = None) -> None:
+        expected_base = cls.model_fields["base"].default
+        recognized_base = cls._get_base_or_raise(mod, path)
+        if expected_base is not recognized_base:
+            raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
+
+    @classmethod
+    def _file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
+        try:
+            p = path or mod.path
+
+            if not p.exists():
+                return False
+
+            if p.is_dir():
+                return False
+
+            # Diffusers-style embeddings are stored as learned_embeds.bin or learned_embeds.safetensors.
+            if p.name in [f"learned_embeds.{ext}" for ext in ("bin", "safetensors")]:
+                return True
+
+            state_dict = mod.load_state_dict(p)
+
+            # Heuristic: textual inversion embeddings have these keys
+            if any(key in {"string_to_param", "emb_params", "clip_g"} for key in state_dict.keys()):
+                return True
+
+            # Heuristic: small state dict with all tensor values
+            if len(state_dict) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()):
+                return True
+
+            return False
+        except Exception:
+            return False
+
+    @classmethod
+    def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
+        p = path or mod.path
+
+        try:
+            state_dict = mod.load_state_dict(p)
+        except Exception as e:
+            raise NotAMatchError(f"unable to load state dict from {p}: {e}") from e
+
+        try:
+            if "string_to_token" in state_dict:
+                token_dim = list(state_dict["string_to_param"].values())[0].shape[-1]
+            elif "emb_params" in state_dict:
+                token_dim = state_dict["emb_params"].shape[-1]
+            elif "clip_g" in state_dict:
+                token_dim = state_dict["clip_g"].shape[-1]
+            else:
+                token_dim = list(state_dict.values())[0].shape[0]
+        except Exception as e:
+            raise NotAMatchError(f"unable to determine token dimension from state dict in {p}: {e}") from e
+
+        match token_dim:
+            case 768:
+                return BaseModelType.StableDiffusion1
+            case 1024:
+                return BaseModelType.StableDiffusion2
+            case 1280:
+                return BaseModelType.StableDiffusionXL
+            case _:
+                raise NotAMatchError(f"unrecognized token dimension {token_dim}")
+
+
+class TI_File_Config_Base(TI_Config_Base):
+    """Model config for textual inversion embeddings (single-file format)."""
+
+    format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        if not cls._file_looks_like_embedding(mod):
+            raise NotAMatchError("model does not look like a textual inversion embedding file")
+
+        cls._validate_base(mod)
+
+        return cls(**override_fields)
+
+
+class TI_File_SD1_Config(TI_File_Config_Base, Config_Base):
+    base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
+
+
+class TI_File_SD2_Config(TI_File_Config_Base, Config_Base):
+    base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
+
+
+class TI_File_SDXL_Config(TI_File_Config_Base, Config_Base):
+    base: Literal[BaseModelType.StableDiffusionXL] =
Field(default=BaseModelType.StableDiffusionXL) + + +class TI_Folder_Config_Base(TI_Config_Base): + """Model config for textual inversion embeddings.""" + + format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + for p in mod.weight_files(): + if cls._file_looks_like_embedding(mod, p): + cls._validate_base(mod, p) + return cls(**override_fields) + + raise NotAMatchError("model does not look like a textual inversion embedding folder") + + +class TI_Folder_SD1_Config(TI_Folder_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class TI_Folder_SD2_Config(TI_Folder_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class TI_Folder_SDXL_Config(TI_Folder_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) diff --git a/invokeai/backend/model_manager/configs/unknown.py b/invokeai/backend/model_manager/configs/unknown.py new file mode 100644 index 00000000000..13fbee1c928 --- /dev/null +++ b/invokeai/backend/model_manager/configs/unknown.py @@ -0,0 +1,44 @@ +from copy import deepcopy +from typing import Any, Literal, Self + +from pydantic import Field + +from invokeai.app.services.config.config_default import get_config +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelType, +) + +app_config = get_config() + + +class Unknown_Config(Config_Base): + """Model config for unknown models, used as a fallback when we cannot positively identify a model.""" + + base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown) + type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown) + format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + """Create an Unknown_Config for models that couldn't be positively identified. + + Note: Basic path validation (file extensions, directory structure) is already + performed by ModelConfigFactory before this method is called. + """ + + cloned_override_fields = deepcopy(override_fields) + cloned_override_fields.pop("base", None) + cloned_override_fields.pop("type", None) + cloned_override_fields.pop("format", None) + + return cls( + **cloned_override_fields, + # Override the type/format/base to ensure it's marked as unknown. 
+ base=BaseModelType.Unknown, + type=ModelType.Unknown, + format=ModelFormat.Unknown, + ) diff --git a/invokeai/backend/model_manager/configs/vae.py b/invokeai/backend/model_manager/configs/vae.py new file mode 100644 index 00000000000..2525e0a1e44 --- /dev/null +++ b/invokeai/backend/model_manager/configs/vae.py @@ -0,0 +1,163 @@ +import re +from typing import ( + Literal, + Self, +) + +from pydantic import Field +from typing_extensions import Any + +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.identification_utils import ( + NotAMatchError, + common_config_paths, + get_config_dict_or_raise, + raise_for_class_name, + raise_for_override_fields, + raise_if_not_dir, + raise_if_not_file, + state_dict_has_any_keys_starting_with, +) +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelFormat, + ModelType, +) + +REGEX_TO_BASE: dict[str, BaseModelType] = { + r"xl": BaseModelType.StableDiffusionXL, + r"sd2": BaseModelType.StableDiffusion2, + r"vae": BaseModelType.StableDiffusion1, + r"FLUX.1-schnell_ae": BaseModelType.Flux, +} + + +class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base): + """Model config for standalone VAE models.""" + + type: Literal[ModelType.VAE] = Field(default=ModelType.VAE) + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_vae(mod) + + cls._validate_base(mod) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None: + if not state_dict_has_any_keys_starting_with( + mod.load_state_dict(), + { + "encoder.conv_in", + "decoder.conv_in", + }, + ): + raise NotAMatchError("model does not match Checkpoint VAE heuristics") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name + for regexp, base in REGEX_TO_BASE.items(): + if re.search(regexp, mod.path.name, re.IGNORECASE): + return base + + raise NotAMatchError("cannot determine base type") + + +class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + +class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + +class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + +class VAE_Diffusers_Config_Base(Diffusers_Config_Base): + """Model config 
for standalone VAE models (diffusers version).""" + + type: Literal[ModelType.VAE] = Field(default=ModelType.VAE) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + raise_for_class_name( + common_config_paths(mod.path), + { + "AutoencoderKL", + "AutoencoderTiny", + }, + ) + + # Unfortunately it is difficult to distinguish SD1 and SDXL VAEs by config alone, so we may need to + # guess based on name if the config is inconclusive. + override_name = override_fields.get("name") + cls._validate_base(mod, override_name) + + return cls(**override_fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk, override_name: str | None = None) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod, override_name) + if expected_base is not recognized_base: + raise NotAMatchError(f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool: + # Heuristic: These config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. + return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] + + @classmethod + def _name_looks_like_sdxl(cls, mod: ModelOnDisk, override_name: str | None = None) -> bool: + # Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down + # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best + # we can do is guess based on name. + return bool(re.search(r"xl\b", override_name or mod.path.name, re.IGNORECASE)) + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk, override_name: str | None = None) -> BaseModelType: + config_dict = get_config_dict_or_raise(common_config_paths(mod.path)) + if cls._config_looks_like_sdxl(config_dict): + return BaseModelType.StableDiffusionXL + elif cls._name_looks_like_sdxl(mod, override_name): + return BaseModelType.StableDiffusionXL + else: + # TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO. 
+ return BaseModelType.StableDiffusion1 + + +class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + +class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py deleted file mode 100644 index 36fd82667d7..00000000000 --- a/invokeai/backend/model_manager/legacy_probe.py +++ /dev/null @@ -1,1169 +0,0 @@ -import json -import re -from pathlib import Path -from typing import Any, Callable, Dict, Literal, Optional, Union - -import picklescan.scanner as pscan -import safetensors.torch -import spandrel -import torch - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.misc import uuid_string -from invokeai.backend.flux.controlnet.state_dict_utils import ( - is_state_dict_instantx_controlnet, - is_state_dict_xlabs_controlnet, -) -from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict -from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter -from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux -from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - ControlAdapterDefaultSettings, - InvalidModelConfigException, - LoraModelDefaultSettings, - MainModelDefaultSettings, - ModelConfigFactory, - SubmodelDefinition, -) -from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader -from invokeai.backend.model_manager.model_on_disk import ModelOnDisk -from invokeai.backend.model_manager.taxonomy import ( - AnyVariant, - BaseModelType, - ModelFormat, - ModelRepoVariant, - ModelSourceType, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SubModelType, -) -from invokeai.backend.model_manager.util.model_util import ( - get_clip_variant_type, - lora_token_vector_length, - read_checkpoint_meta, -) -from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control -from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import ( - is_state_dict_likely_in_flux_diffusers_format, -) -from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import ( - is_state_dict_likely_in_flux_kohya_format, -) -from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import ( - is_state_dict_likely_in_flux_onetrainer_format, -) -from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor -from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader -from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel -from invokeai.backend.util.silence_warnings import SilenceWarnings - -CkptType = Dict[str | int, Any] - -LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = { - BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v1-inference.yaml", - SchedulerPredictionType.VPrediction: "v1-inference-v.yaml", - }, - ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", - }, - 
BaseModelType.StableDiffusion2: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v2-inference.yaml", - SchedulerPredictionType.VPrediction: "v2-inference-v.yaml", - }, - ModelVariantType.Inpaint: { - SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml", - SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml", - }, - ModelVariantType.Depth: "v2-midas-inference.yaml", - }, - BaseModelType.StableDiffusionXL: { - ModelVariantType.Normal: "sd_xl_base.yaml", - ModelVariantType.Inpaint: "sd_xl_inpaint.yaml", - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelVariantType.Normal: "sd_xl_refiner.yaml", - }, -} - - -class ProbeBase(object): - """Base class for probes.""" - - def __init__(self, model_path: Path): - self.model_path = model_path - - def get_base_type(self) -> BaseModelType: - """Get model base type.""" - raise NotImplementedError - - def get_format(self) -> ModelFormat: - """Get model file format.""" - raise NotImplementedError - - def get_variant_type(self) -> Optional[ModelVariantType]: - """Get model variant type.""" - return None - - def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: - """Get model scheduler prediction type.""" - return None - - def get_image_encoder_model_id(self) -> Optional[str]: - """Get image encoder (IP adapters only).""" - return None - - -class ModelProbe(object): - PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = { - "diffusers": {}, - "checkpoint": {}, - "onnx": {}, - } - - CLASS2TYPE = { - "FluxPipeline": ModelType.Main, - "StableDiffusionPipeline": ModelType.Main, - "StableDiffusionInpaintPipeline": ModelType.Main, - "StableDiffusionXLPipeline": ModelType.Main, - "StableDiffusionXLImg2ImgPipeline": ModelType.Main, - "StableDiffusionXLInpaintPipeline": ModelType.Main, - "StableDiffusion3Pipeline": ModelType.Main, - "LatentConsistencyModelPipeline": ModelType.Main, - "AutoencoderKL": ModelType.VAE, - "AutoencoderTiny": ModelType.VAE, - "ControlNetModel": ModelType.ControlNet, - "CLIPVisionModelWithProjection": ModelType.CLIPVision, - "T2IAdapter": ModelType.T2IAdapter, - "CLIPModel": ModelType.CLIPEmbed, - "CLIPTextModel": ModelType.CLIPEmbed, - "T5EncoderModel": ModelType.T5Encoder, - "FluxControlNetModel": ModelType.ControlNet, - "SD3Transformer2DModel": ModelType.Main, - "CLIPTextModelWithProjection": ModelType.CLIPEmbed, - "SiglipModel": ModelType.SigLIP, - "LlavaOnevisionForConditionalGeneration": ModelType.LlavaOnevision, - "CogView4Pipeline": ModelType.Main, - } - - TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type} - - @classmethod - def register_probe( - cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase] - ) -> None: - cls.PROBES[format][model_type] = probe_class - - @classmethod - def probe( - cls, model_path: Path, fields: Optional[Dict[str, Any]] = None, hash_algo: HASHING_ALGORITHMS = "blake3_single" - ) -> AnyModelConfig: - """ - Probe the model at model_path and return its configuration record. - - :param model_path: Path to the model file (checkpoint) or directory (diffusers). - :param fields: An optional dictionary that can be used to override probed - fields. Typically used for fields that don't probe well, such as prediction_type. - - Returns: The appropriate model configuration derived from ModelConfigBase. 
- """ - if fields is None: - fields = {} - - model_path = model_path.resolve() - - format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint - model_info = None - model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None - if not model_type: - if format_type is ModelFormat.Diffusers: - model_type = cls.get_model_type_from_folder(model_path) - else: - model_type = cls.get_model_type_from_checkpoint(model_path) - format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type - - probe_class = cls.PROBES[format_type].get(model_type) - if not probe_class: - raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}") - - probe = probe_class(model_path) - - fields["source_type"] = fields.get("source_type") or ModelSourceType.Path - fields["source"] = fields.get("source") or model_path.as_posix() - fields["key"] = fields.get("key", uuid_string()) - fields["path"] = model_path.as_posix() - fields["type"] = fields.get("type") or model_type - fields["base"] = fields.get("base") or probe.get_base_type() - variant_func = cls.TYPE2VARIANT.get(fields["type"], None) - fields["variant"] = ( - fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type() - ) - fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type() - fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id() - fields["name"] = fields.get("name") or cls.get_model_name(model_path) - fields["description"] = ( - fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}" - ) - fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format() - fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path) - fields["file_size"] = fields.get("file_size") or ModelOnDisk(model_path).size() - - fields["default_settings"] = fields.get("default_settings") - - if not fields["default_settings"]: - if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter, ModelType.ControlLoRa}: - fields["default_settings"] = get_default_settings_control_adapters(fields["name"]) - if fields["type"] in {ModelType.LoRA}: - fields["default_settings"] = get_default_settings_lora() - elif fields["type"] is ModelType.Main: - fields["default_settings"] = get_default_settings_main(fields["base"]) - - if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase): - fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() - - # additional fields needed for main and controlnet models - if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [ - ModelFormat.Checkpoint, - ModelFormat.BnbQuantizednf4b, - ModelFormat.GGUFQuantized, - ]: - ckpt_config_path = cls._get_checkpoint_config_path( - model_path, - model_type=fields["type"], - base_type=fields["base"], - variant_type=fields["variant"], - prediction_type=fields["prediction_type"], - ) - fields["config_path"] = str(ckpt_config_path) - - # additional fields needed for main non-checkpoint models - elif fields["type"] == ModelType.Main and fields["format"] in [ - ModelFormat.ONNX, - ModelFormat.Olive, - ModelFormat.Diffusers, - ]: - fields["upcast_attention"] = fields.get("upcast_attention") or ( - fields["base"] == BaseModelType.StableDiffusion2 - and fields["prediction_type"] == 
SchedulerPredictionType.VPrediction - ) - - get_submodels = getattr(probe, "get_submodels", None) - if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels): - fields["submodels"] = get_submodels() - - model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None)) - return model_info - - @classmethod - def get_model_name(cls, model_path: Path) -> str: - if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}: - return model_path.stem - else: - return model_path.name - - @classmethod - def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType: - if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth", ".gguf"): - raise InvalidModelConfigException(f"{model_path}: unrecognized suffix") - - if model_path.name == "learned_embeds.bin": - return ModelType.TextualInversion - - ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True) - ckpt = ckpt.get("state_dict", ckpt) - - if isinstance(ckpt, dict) and is_state_dict_likely_flux_control(ckpt): - return ModelType.ControlLoRa - - if isinstance(ckpt, dict) and is_state_dict_likely_flux_redux(ckpt): - return ModelType.FluxRedux - - for key in [str(k) for k in ckpt.keys()]: - if key.startswith( - ( - "cond_stage_model.", - "first_stage_model.", - "model.diffusion_model.", - # Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model". - # This prefix is typically used to distinguish between multiple models bundled in a single file. - "model.diffusion_model.double_blocks.", - ) - ): - # Keys starting with double_blocks are associated with Flux models - return ModelType.Main - # FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be - # careful to avoid false positives on XLabs FLUX IP-Adapter models. - elif key.startswith("double_blocks.") and "ip_adapter" not in key: - return ModelType.Main - elif key.startswith(("encoder.conv_in", "decoder.conv_in")): - return ModelType.VAE - elif key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")): - return ModelType.LoRA - # "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT - # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models. - elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")): - return ModelType.LoRA - elif key.startswith( - ( - "controlnet", - "control_model", - "input_blocks", - # XLabs FLUX ControlNet models have keys starting with "controlnet_blocks." - # For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors - # TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with - # "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so - # delicate. - "controlnet_blocks", - ) - ): - return ModelType.ControlNet - elif key.startswith( - ( - "image_proj.", - "ip_adapter.", - # XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.". 
- "ip_adapter_proj_model.", - ) - ): - return ModelType.IPAdapter - elif key in {"emb_params", "string_to_param"}: - return ModelType.TextualInversion - - # diffusers-ti - if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): - return ModelType.TextualInversion - - # Check if the model can be loaded as a SpandrelImageToImageModel. - # This check is intentionally performed last, as it can be expensive (it requires loading the model from disk). - try: - # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were - # explored to avoid this: - # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta - # device. Unfortunately, some Spandrel models perform operations during initialization that are not - # supported on meta tensors. - # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model. - # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to - # maintain it, and the risk of false positive detections is higher. - SpandrelImageToImageModel.load_from_file(model_path) - return ModelType.SpandrelImageToImage - except spandrel.UnsupportedModelError: - pass - except Exception as e: - logger.warning( - f"Encountered error while probing to determine if {model_path} is a Spandrel model. Ignoring. Error: {e}" - ) - - raise InvalidModelConfigException(f"Unable to determine model type for {model_path}") - - @classmethod - def get_model_type_from_folder(cls, folder_path: Path) -> ModelType: - """Get the model type of a hugging-face style folder.""" - class_name = None - error_hint = None - for suffix in ["bin", "safetensors"]: - if (folder_path / f"learned_embeds.{suffix}").exists(): - return ModelType.TextualInversion - if (folder_path / f"pytorch_lora_weights.{suffix}").exists(): - return ModelType.LoRA - if (folder_path / "unet/model.onnx").exists(): - return ModelType.ONNX - if (folder_path / "image_encoder.txt").exists(): - return ModelType.IPAdapter - - config_path = None - for p in [ - folder_path / "model_index.json", # pipeline - folder_path / "config.json", # most diffusers - folder_path / "text_encoder_2" / "config.json", # T5 text encoder - folder_path / "text_encoder" / "config.json", # T5 CLIP - ]: - if p.exists(): - config_path = p - break - - if config_path: - with open(config_path, "r") as file: - conf = json.load(file) - if "_class_name" in conf: - class_name = conf["_class_name"] - elif "architectures" in conf: - class_name = conf["architectures"][0] - else: - class_name = None - else: - error_hint = f"No model_index.json or config.json found in {folder_path}." 
- - if class_name and (type := cls.CLASS2TYPE.get(class_name)): - return type - else: - error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]" - - # give up - raise InvalidModelConfigException( - f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "") - ) - - @classmethod - def _get_checkpoint_config_path( - cls, - model_path: Path, - model_type: ModelType, - base_type: BaseModelType, - variant_type: ModelVariantType, - prediction_type: SchedulerPredictionType, - ) -> Path: - # look for a YAML file adjacent to the model file first - possible_conf = model_path.with_suffix(".yaml") - if possible_conf.exists(): - return possible_conf.absolute() - - if model_type is ModelType.Main: - if base_type == BaseModelType.Flux: - # TODO: Decide between dev/schnell - checkpoint = ModelProbe._scan_and_load_checkpoint(model_path) - state_dict = checkpoint.get("state_dict") or checkpoint - - # HACK: For FLUX, config_file is used as a key into invokeai.backend.flux.util.params during model - # loading. When FLUX support was first added, it was decided that this was the easiest way to support - # the various FLUX formats rather than adding new model types/formats. Be careful when modifying this in - # the future. - if ( - "guidance_in.out_layer.weight" in state_dict - or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict - ): - if variant_type == ModelVariantType.Normal: - config_file = "flux-dev" - elif variant_type == ModelVariantType.Inpaint: - config_file = "flux-dev-fill" - else: - raise ValueError(f"Unexpected FLUX variant type: {variant_type}") - else: - config_file = "flux-schnell" - else: - config_file = LEGACY_CONFIGS[base_type][variant_type] - if isinstance(config_file, dict): # need another tier for sd-2.x models - config_file = config_file[prediction_type] - config_file = f"stable-diffusion/{config_file}" - elif model_type is ModelType.ControlNet: - config_file = ( - "controlnet/cldm_v15.yaml" - if base_type is BaseModelType.StableDiffusion1 - else "controlnet/cldm_v21.yaml" - ) - elif model_type is ModelType.VAE: - config_file = ( - # For flux, this is a key in invokeai.backend.flux.util.ae_params - # Due to model type and format being the descriminator for model configs this - # is used rather than attempting to support flux with separate model types and format - # If changed in the future, please fix me - "flux" - if base_type is BaseModelType.Flux - else "stable-diffusion/v1-inference.yaml" - if base_type is BaseModelType.StableDiffusion1 - else "stable-diffusion/sd_xl_base.yaml" - if base_type is BaseModelType.StableDiffusionXL - else "stable-diffusion/v2-inference.yaml" - ) - else: - raise InvalidModelConfigException( - f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}" - ) - return Path(config_file) - - @classmethod - def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType: - with SilenceWarnings(): - if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")): - cls._scan_model(model_path.name, model_path) - model = torch.load(model_path, map_location="cpu") - assert isinstance(model, dict) - return model - elif model_path.suffix.endswith(".gguf"): - return gguf_sd_loader(model_path, compute_dtype=torch.float32) - else: - return safetensors.torch.load_file(model_path) - - @classmethod - def _scan_model(cls, model_name: str, checkpoint: Path) -> None: - """ - Apply picklescanner to the indicated checkpoint and issue a warning - 
and option to exit if an infected file is identified. - """ - # scan model - scan_result = pscan.scan_file_path(checkpoint) - if scan_result.infected_files != 0: - if get_config().unsafe_disable_picklescan: - logger.warning( - f"The model {model_name} is potentially infected by malware, but picklescan is disabled. " - "Proceeding with caution." - ) - else: - raise RuntimeError(f"The model {model_name} is potentially infected by malware. Aborting import.") - if scan_result.scan_err: - if get_config().unsafe_disable_picklescan: - logger.warning( - f"Error scanning the model at {model_name} for malware, but picklescan is disabled. " - "Proceeding with caution." - ) - else: - raise RuntimeError(f"Error scanning the model at {model_name} for malware. Aborting import.") - - -# Probing utilities -MODEL_NAME_TO_PREPROCESSOR = { - "canny": "canny_image_processor", - "mlsd": "mlsd_image_processor", - "depth": "depth_anything_image_processor", - "bae": "normalbae_image_processor", - "normal": "normalbae_image_processor", - "sketch": "pidi_image_processor", - "scribble": "lineart_image_processor", - "lineart anime": "lineart_anime_image_processor", - "lineart_anime": "lineart_anime_image_processor", - "lineart": "lineart_image_processor", - "soft": "hed_image_processor", - "softedge": "hed_image_processor", - "hed": "hed_image_processor", - "shuffle": "content_shuffle_image_processor", - "pose": "dw_openpose_image_processor", - "mediapipe": "mediapipe_face_processor", - "pidi": "pidi_image_processor", - "zoe": "zoe_depth_image_processor", - "color": "color_map_image_processor", -} - - -def get_default_settings_control_adapters(model_name: str) -> Optional[ControlAdapterDefaultSettings]: - for k, v in MODEL_NAME_TO_PREPROCESSOR.items(): - model_name_lower = model_name.lower() - if k in model_name_lower: - return ControlAdapterDefaultSettings(preprocessor=v) - return None - - -def get_default_settings_lora() -> LoraModelDefaultSettings: - return LoraModelDefaultSettings() - - -def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]: - if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2: - return MainModelDefaultSettings(width=512, height=512) - elif model_base is BaseModelType.StableDiffusionXL: - return MainModelDefaultSettings(width=1024, height=1024) - # We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models. 
- return None - - -# ##################################################3 -# Checkpoint probing -# ##################################################3 - - -class CheckpointProbeBase(ProbeBase): - def __init__(self, model_path: Path): - super().__init__(model_path) - self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path) - - def get_format(self) -> ModelFormat: - state_dict = self.checkpoint.get("state_dict") or self.checkpoint - if ( - "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict - or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict - ): - return ModelFormat.BnbQuantizednf4b - elif any(isinstance(v, GGMLTensor) for v in state_dict.values()): - return ModelFormat.GGUFQuantized - return ModelFormat("checkpoint") - - def get_variant_type(self) -> ModelVariantType: - model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint) - base_type = self.get_base_type() - if model_type != ModelType.Main: - return ModelVariantType.Normal - state_dict = self.checkpoint.get("state_dict") or self.checkpoint - - if base_type == BaseModelType.Flux: - in_channels = get_flux_in_channels_from_state_dict(state_dict) - - if in_channels is None: - # If we cannot find the in_channels, we assume that this is a normal variant. Log a warning. - logger.warning( - f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant." - ) - return ModelVariantType.Normal - - # FLUX Model variant types are distinguished by input channels: - # - Unquantized Dev and Schnell have in_channels=64 - # - BNB-NF4 Dev and Schnell have in_channels=1 - # - FLUX Fill has in_channels=384 - # - Unsure of quantized FLUX Fill models - # - Unsure of GGUF-quantized models - if in_channels == 384: - # This is a FLUX Fill model. FLUX Fill needs special handling throughout the application. The variant - # type is used to determine whether to use the fill model or the base model. - return ModelVariantType.Inpaint - else: - # Fall back on "normal" variant type for all other FLUX models. 
- return ModelVariantType.Normal - - in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - else: - raise InvalidModelConfigException( - f"Cannot determine variant type (in_channels={in_channels}) at {self.model_path}" - ) - - -class PipelineCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - if ( - "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict - or "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict - ): - return BaseModelType.Flux - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 768: - return BaseModelType.StableDiffusion1 - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - return BaseModelType.StableDiffusion2 - key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: - return BaseModelType.StableDiffusionXL - elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: - return BaseModelType.StableDiffusionXLRefiner - else: - raise InvalidModelConfigException("Cannot determine base type") - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - """Return model prediction type.""" - type = self.get_base_type() - if type == BaseModelType.StableDiffusion2: - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - if "global_step" in checkpoint: - if checkpoint["global_step"] == 220000: - return SchedulerPredictionType.Epsilon - elif checkpoint["global_step"] == 110000: - return SchedulerPredictionType.VPrediction - return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts - - elif type == BaseModelType.StableDiffusion1: - return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts - else: - return SchedulerPredictionType.Epsilon - - -class VaeCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - # VAEs of all base types have the same structure, so we wimp out and - # guess using the name. - for regexp, basetype in [ - (r"xl", BaseModelType.StableDiffusionXL), - (r"sd2", BaseModelType.StableDiffusion2), - (r"vae", BaseModelType.StableDiffusion1), - (r"FLUX.1-schnell_ae", BaseModelType.Flux), - ]: - if re.search(regexp, self.model_path.name, re.IGNORECASE): - return basetype - raise InvalidModelConfigException("Cannot determine base type") - - -class LoRACheckpointProbe(CheckpointProbeBase): - """Class for LoRA checkpoints.""" - - def get_format(self) -> ModelFormat: - if is_state_dict_likely_in_flux_diffusers_format(self.checkpoint): - # TODO(ryand): This is an unusual case. In other places throughout the codebase, we treat - # ModelFormat.Diffusers as meaning that the model is in a directory. In this case, the model is a single - # file, but the weight keys are in the diffusers format. 
- return ModelFormat.Diffusers - return ModelFormat.LyCORIS - - def get_base_type(self) -> BaseModelType: - if ( - is_state_dict_likely_in_flux_kohya_format(self.checkpoint) - or is_state_dict_likely_in_flux_onetrainer_format(self.checkpoint) - or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint) - or is_state_dict_likely_flux_control(self.checkpoint) - ): - return BaseModelType.Flux - - # If we've gotten here, we assume that the model is a Stable Diffusion model. - token_vector_length = lora_token_vector_length(self.checkpoint) - if token_vector_length == 768: - return BaseModelType.StableDiffusion1 - elif token_vector_length == 1024: - return BaseModelType.StableDiffusion2 - elif token_vector_length == 1280: - return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 - elif token_vector_length == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException(f"Unknown LoRA type: {self.model_path}") - - -class TextualInversionCheckpointProbe(CheckpointProbeBase): - """Class for probing embeddings.""" - - def get_format(self) -> ModelFormat: - return ModelFormat.EmbeddingFile - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - if "string_to_token" in checkpoint: - token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1] - elif "emb_params" in checkpoint: - token_dim = checkpoint["emb_params"].shape[-1] - elif "clip_g" in checkpoint: - token_dim = checkpoint["clip_g"].shape[-1] - else: - token_dim = list(checkpoint.values())[0].shape[0] - if token_dim == 768: - return BaseModelType.StableDiffusion1 - elif token_dim == 1024: - return BaseModelType.StableDiffusion2 - elif token_dim == 1280: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type") - - -class ControlNetCheckpointProbe(CheckpointProbeBase): - """Class for probing controlnets.""" - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - if is_state_dict_xlabs_controlnet(checkpoint) or is_state_dict_instantx_controlnet(checkpoint): - # TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing - # get_format()? 
- return BaseModelType.Flux - - for key_name in ( - "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - "controlnet_mid_block.bias", - "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", - ): - if key_name not in checkpoint: - continue - width = checkpoint[key_name].shape[-1] - if width == 768: - return BaseModelType.StableDiffusion1 - elif width == 1024: - return BaseModelType.StableDiffusion2 - elif width == 2048: - return BaseModelType.StableDiffusionXL - elif width == 1280: - return BaseModelType.StableDiffusionXL - raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type") - - -class IPAdapterCheckpointProbe(CheckpointProbeBase): - """Class for probing IP Adapters""" - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - - if is_state_dict_xlabs_ip_adapter(checkpoint): - return BaseModelType.Flux - - for key in checkpoint.keys(): - if not key.startswith(("image_proj.", "ip_adapter.")): - continue - cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1] - if cross_attention_dim == 768: - return BaseModelType.StableDiffusion1 - elif cross_attention_dim == 1024: - return BaseModelType.StableDiffusion2 - elif cross_attention_dim == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException( - f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}." - ) - raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type") - - -class CLIPVisionCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class T2IAdapterCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class SigLIPCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class FluxReduxCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Flux - - -class LlavaOnevisionCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -######################################################## -# classes for probing folders -####################################################### -class FolderProbeBase(ProbeBase): - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - def get_format(self) -> ModelFormat: - return ModelFormat("diffusers") - - def get_repo_variant(self) -> ModelRepoVariant: - # get all files ending in .bin or .safetensors - weight_files = list(self.model_path.glob("**/*.safetensors")) - weight_files.extend(list(self.model_path.glob("**/*.bin"))) - for x in weight_files: - if ".fp16" in x.suffixes: - return ModelRepoVariant.FP16 - if "openvino_model" in x.name: - return ModelRepoVariant.OpenVINO - if "flax_model" in x.name: - return ModelRepoVariant.Flax - if x.suffix == ".onnx": - return ModelRepoVariant.ONNX - return ModelRepoVariant.Default - - -class PipelineFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - # Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL). 
- config_path = self.model_path / "unet" / "config.json" - if config_path.exists(): - with open(config_path) as file: - unet_conf = json.load(file) - if unet_conf["cross_attention_dim"] == 768: - return BaseModelType.StableDiffusion1 - elif unet_conf["cross_attention_dim"] == 1024: - return BaseModelType.StableDiffusion2 - elif unet_conf["cross_attention_dim"] == 1280: - return BaseModelType.StableDiffusionXLRefiner - elif unet_conf["cross_attention_dim"] == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException(f"Unknown base model for {self.model_path}") - - # Handle pipelines with a transformer (i.e. SD3). - config_path = self.model_path / "transformer" / "config.json" - if config_path.exists(): - with open(config_path) as file: - transformer_conf = json.load(file) - if transformer_conf["_class_name"] == "SD3Transformer2DModel": - return BaseModelType.StableDiffusion3 - elif transformer_conf["_class_name"] == "CogView4Transformer2DModel": - return BaseModelType.CogView4 - else: - raise InvalidModelConfigException(f"Unknown base model for {self.model_path}") - - raise InvalidModelConfigException(f"Unknown base model for {self.model_path}") - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file: - scheduler_conf = json.load(file) - if scheduler_conf.get("prediction_type", "epsilon") == "v_prediction": - return SchedulerPredictionType.VPrediction - elif scheduler_conf.get("prediction_type", "epsilon") == "epsilon": - return SchedulerPredictionType.Epsilon - else: - raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}") - - def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]: - config = ConfigLoader.load_config(self.model_path, config_name="model_index.json") - submodels: Dict[SubModelType, SubmodelDefinition] = {} - for key, value in config.items(): - if key.startswith("_") or not (isinstance(value, list) and len(value) == 2): - continue - model_loader = str(value[1]) - if model_type := ModelProbe.CLASS2TYPE.get(model_loader): - variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None) - submodels[SubModelType(key)] = SubmodelDefinition( - path_or_prefix=(self.model_path / key).resolve().as_posix(), - model_type=model_type, - variant=variant_func and variant_func((self.model_path / key).as_posix()), - ) - - return submodels - - def get_variant_type(self) -> ModelVariantType: - # This only works for pipelines! Any kind of - # exception results in our returning the - # "normal" variant type - try: - config_file = self.model_path / "unet" / "config.json" - with open(config_file, "r") as file: - conf = json.load(file) - - in_channels = conf["in_channels"] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - except Exception: - pass - return ModelVariantType.Normal - - -class VaeFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - if self._config_looks_like_sdxl(): - return BaseModelType.StableDiffusionXL - elif self._name_looks_like_sdxl(): - # but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down - # by a factor of 8), we can't necessarily tell them apart by config hyperparameters. 
- return BaseModelType.StableDiffusionXL - else: - return BaseModelType.StableDiffusion1 - - def _config_looks_like_sdxl(self) -> bool: - # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. - config_file = self.model_path / "config.json" - if not config_file.exists(): - raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}") - with open(config_file, "r") as file: - config = json.load(file) - return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] - - def _name_looks_like_sdxl(self) -> bool: - return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE)) - - def _guess_name(self) -> str: - name = self.model_path.name - if name == "vae": - name = self.model_path.parent.name - return name - - -class TextualInversionFolderProbe(FolderProbeBase): - def get_format(self) -> ModelFormat: - return ModelFormat.EmbeddingFolder - - def get_base_type(self) -> BaseModelType: - path = self.model_path / "learned_embeds.bin" - if not path.exists(): - raise InvalidModelConfigException( - f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file" - ) - return TextualInversionCheckpointProbe(path).get_base_type() - - -class T5EncoderFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - def get_format(self) -> ModelFormat: - path = self.model_path / "text_encoder_2" - if (path / "model.safetensors.index.json").exists(): - return ModelFormat.T5Encoder - files = list(path.glob("*.safetensors")) - if len(files) == 0: - raise InvalidModelConfigException(f"{self.model_path.as_posix()}: no .safetensors files found") - - # shortcut: look for the quantization in the name - if any(x for x in files if "llm_int8" in x.as_posix()): - return ModelFormat.BnbQuantizedLlmInt8b - - # more reliable path: probe contents for a 'SCB' key - ckpt = read_checkpoint_meta(files[0], scan=True) - if any("SCB" in x for x in ckpt.keys()): - return ModelFormat.BnbQuantizedLlmInt8b - - raise InvalidModelConfigException(f"{self.model_path.as_posix()}: unknown model format") - - -class ONNXFolderProbe(PipelineFolderProbe): - def get_base_type(self) -> BaseModelType: - # Due to the way the installer is set up, the configuration file for safetensors - # will come along for the ride if both the onnx and safetensors forms - # share the same directory. We take advantage of this here. - if (self.model_path / "unet" / "config.json").exists(): - return super().get_base_type() - else: - logger.warning('Base type probing is not implemented for ONNX models. 
Assuming "sd-1"') - return BaseModelType.StableDiffusion1 - - def get_format(self) -> ModelFormat: - return ModelFormat("onnx") - - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - -class ControlNetFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - config_file = self.model_path / "config.json" - if not config_file.exists(): - raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}") - with open(config_file, "r") as file: - config = json.load(file) - - if config.get("_class_name", None) == "FluxControlNetModel": - return BaseModelType.Flux - - # no obvious way to distinguish between sd2-base and sd2-768 - dimension = config["cross_attention_dim"] - if dimension == 768: - return BaseModelType.StableDiffusion1 - if dimension == 1024: - return BaseModelType.StableDiffusion2 - if dimension == 2048: - return BaseModelType.StableDiffusionXL - raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}") - - -class LoRAFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - model_file = None - for suffix in ["safetensors", "bin"]: - base_file = self.model_path / f"pytorch_lora_weights.{suffix}" - if base_file.exists(): - model_file = base_file - break - if not model_file: - raise InvalidModelConfigException("Unknown LoRA format encountered") - return LoRACheckpointProbe(model_file).get_base_type() - - -class IPAdapterFolderProbe(FolderProbeBase): - def get_format(self) -> ModelFormat: - return ModelFormat.InvokeAI - - def get_base_type(self) -> BaseModelType: - model_file = self.model_path / "ip_adapter.bin" - if not model_file.exists(): - raise InvalidModelConfigException("Unknown IP-Adapter model format.") - - state_dict = torch.load(model_file, map_location="cpu") - cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] - if cross_attention_dim == 768: - return BaseModelType.StableDiffusion1 - elif cross_attention_dim == 1024: - return BaseModelType.StableDiffusion2 - elif cross_attention_dim == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException( - f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}." 
- ) - - def get_image_encoder_model_id(self) -> Optional[str]: - encoder_id_path = self.model_path / "image_encoder.txt" - if not encoder_id_path.exists(): - return None - with open(encoder_id_path, "r") as f: - image_encoder_model = f.readline().strip() - return image_encoder_model - - -class CLIPVisionFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class CLIPEmbedFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class SpandrelImageToImageFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class SigLIPFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class FluxReduxFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class LlaveOnevisionFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class T2IAdapterFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - config_file = self.model_path / "config.json" - if not config_file.exists(): - raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}") - with open(config_file, "r") as file: - config = json.load(file) - - adapter_type = config.get("adapter_type", None) - if adapter_type == "full_adapter_xl": - return BaseModelType.StableDiffusionXL - elif adapter_type == "full_adapter" or "light_adapter": - # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter. - return BaseModelType.StableDiffusion1 - else: - raise InvalidModelConfigException( - f"Unable to determine base model for '{self.model_path}' (adapter_type = {adapter_type})." 
- ) - - -# Register probe classes -ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe) - -ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.ControlLoRa, LoRACheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.FluxRedux, FluxReduxCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe) - -ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 458fc0cfc0c..a4004afba75 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -12,9 +12,7 @@ import torch from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType @@ -91,14 +89,6 @@ def __init__(self, config: Optional[AnyModelConfig], cache_record: CacheRecord, self.config = config -# TODO(MM2): -# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't -# know about. 
I think the problem may be related to this class being an ABC. -# -# For example, GenericDiffusersLoader defines `get_hf_load_class()`, and StableDiffusionDiffusersModel attempts to -# call it. However, the method is not defined in the ABC, so it is not guaranteed to be implemented. - - class ModelLoaderBase(ABC): """Abstract base class for loading models into RAM/VRAM.""" diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 3c26a956b76..3fb7a574f31 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -6,7 +6,8 @@ from typing import Optional from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException +from invokeai.backend.model_manager.configs.base import Diffusers_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key @@ -50,7 +51,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo model_path = self._get_model_path(model_config) if not model_path.exists(): - raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}") + raise FileNotFoundError(f"Files for model '{model_config.name}' not found at {model_path}") with skip_torch_weight_init(): cache_record = self._load_and_cache(model_config, submodel_type) @@ -90,7 +91,7 @@ def get_size_fs( return calc_model_size_by_fs( model_path=model_path, subfolder=submodel_type.value if submodel_type else None, - variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None, + variant=config.repo_variant if isinstance(config, Diffusers_Config_Base) else None, ) # This needs to be implemented in the subclass diff --git a/invokeai/backend/model_manager/load/model_loader_registry.py b/invokeai/backend/model_manager/load/model_loader_registry.py index ecc4d1fe93b..ca9ea56edbe 100644 --- a/invokeai/backend/model_manager/load/model_loader_registry.py +++ b/invokeai/backend/model_manager/load/model_loader_registry.py @@ -18,10 +18,8 @@ from abc import ABC, abstractmethod from typing import Callable, Dict, Optional, Tuple, Type, TypeVar -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - ModelConfigBase, -) +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load import ModelLoaderBase from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType @@ -40,7 +38,7 @@ def register( @abstractmethod def get_implementation( cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] - ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: + ) -> Tuple[Type[ModelLoaderBase], Config_Base, Optional[SubModelType]]: """ Get subclass of ModelLoaderBase registered to handle base and type. 
@@ -84,7 +82,7 @@ def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]: @classmethod def get_implementation( cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] - ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: + ) -> Tuple[Type[ModelLoaderBase], Config_Base, Optional[SubModelType]]: """Get subclass of ModelLoaderBase registered to handle base and type.""" key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type diff --git a/invokeai/backend/model_manager/load/model_loaders/clip_vision.py b/invokeai/backend/model_manager/load/model_loaders/clip_vision.py index 29d7bc691cf..0150e24248f 100644 --- a/invokeai/backend/model_manager/load/model_loaders/clip_vision.py +++ b/invokeai/backend/model_manager/load/model_loaders/clip_vision.py @@ -3,10 +3,8 @@ from transformers import CLIPVisionModelWithProjection -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - DiffusersConfigBase, -) +from invokeai.backend.model_manager.configs.base import Diffusers_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType @@ -21,7 +19,7 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, DiffusersConfigBase): + if not isinstance(config, Diffusers_Config_Base): raise ValueError("Only DiffusersConfigBase models are currently supported here.") if submodel_type is not None: diff --git a/invokeai/backend/model_manager/load/model_loaders/cogview4.py b/invokeai/backend/model_manager/load/model_loaders/cogview4.py index e7669a33c42..782ff38450c 100644 --- a/invokeai/backend/model_manager/load/model_loaders/cogview4.py +++ b/invokeai/backend/model_manager/load/model_loaders/cogview4.py @@ -3,11 +3,8 @@ import torch -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - CheckpointConfigBase, - DiffusersConfigBase, -) +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import ( @@ -28,7 +25,7 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if isinstance(config, CheckpointConfigBase): + if isinstance(config, Checkpoint_Config_Base): raise NotImplementedError("CheckpointConfigBase is not implemented for CogView4 models.") if submodel_type is None: @@ -36,7 +33,7 @@ def _load_model( model_path = Path(config.path) load_class = self.get_hf_load_class(model_path, submodel_type) - repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None + repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None variant = repo_variant.value if repo_variant else None model_path = model_path / submodel_type.value diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py index 
5bf93db3816..8fd1796b8f5 100644 --- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -5,10 +5,8 @@ from diffusers import ControlNetModel -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - ControlNetCheckpointConfig, -) +from invokeai.backend.model_manager.configs.controlnet import ControlNet_Checkpoint_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import ( @@ -46,7 +44,7 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if isinstance(config, ControlNetCheckpointConfig): + if isinstance(config, ControlNet_Checkpoint_Config_Base): return ControlNetModel.from_single_file( config.path, torch_dtype=self._torch_dtype, diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index 6ea7b539252..e44ddec382c 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -33,27 +33,29 @@ from invokeai.backend.flux.model import Flux from invokeai.backend.flux.modules.autoencoder import AutoEncoder from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel -from invokeai.backend.flux.util import ae_params, params -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - CheckpointConfigBase, - CLIPEmbedDiffusersConfig, - ControlNetCheckpointConfig, - ControlNetDiffusersConfig, - FluxReduxConfig, - IPAdapterCheckpointConfig, - MainBnbQuantized4bCheckpointConfig, - MainCheckpointConfig, - MainGGUFCheckpointConfig, - T5EncoderBnbQuantizedLlmInt8bConfig, - T5EncoderConfig, - VAECheckpointConfig, +from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers_params +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base +from invokeai.backend.model_manager.configs.clip_embed import CLIPEmbed_Diffusers_Config_Base +from invokeai.backend.model_manager.configs.controlnet import ( + ControlNet_Checkpoint_Config_Base, + ControlNet_Diffusers_Config_Base, ) +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.flux_redux import FLUXRedux_Checkpoint_Config +from invokeai.backend.model_manager.configs.ip_adapter import IPAdapter_Checkpoint_Config_Base +from invokeai.backend.model_manager.configs.main import ( + Main_BnBNF4_FLUX_Config, + Main_Checkpoint_FLUX_Config, + Main_GGUF_FLUX_Config, +) +from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config +from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Config_Base from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import ( AnyModel, BaseModelType, + FluxVariantType, ModelFormat, ModelType, SubModelType, @@ -85,12 +87,12 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, VAECheckpointConfig): + if not isinstance(config, 
VAE_Checkpoint_Config_Base): raise ValueError("Only VAECheckpointConfig models are currently supported here.") model_path = Path(config.path) with accelerate.init_empty_weights(): - model = AutoEncoder(ae_params[config.config_path]) + model = AutoEncoder(get_flux_ae_params()) sd = load_file(model_path) model.load_state_dict(sd, assign=True) # VAE is broken in float16, which mps defaults to @@ -107,7 +109,7 @@ def _load_model( @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers) -class ClipCheckpointModel(ModelLoader): +class CLIPDiffusersLoader(ModelLoader): """Class to load main models.""" def _load_model( @@ -115,7 +117,7 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, CLIPEmbedDiffusersConfig): + if not isinstance(config, CLIPEmbed_Diffusers_Config_Base): raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.") match submodel_type: @@ -138,7 +140,7 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig): + if not isinstance(config, T5Encoder_BnBLLMint8_Config): raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.") if not bnb_available: raise ImportError( @@ -185,7 +187,7 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, T5EncoderConfig): + if not isinstance(config, T5Encoder_T5Encoder_Config): raise ValueError("Only T5EncoderConfig models are currently supported here.") match submodel_type: @@ -210,7 +212,7 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, CheckpointConfigBase): + if not isinstance(config, Checkpoint_Config_Base): raise ValueError("Only CheckpointConfigBase models are currently supported here.") match submodel_type: @@ -225,11 +227,11 @@ def _load_from_singlefile( self, config: AnyModelConfig, ) -> AnyModel: - assert isinstance(config, MainCheckpointConfig) + assert isinstance(config, Main_Checkpoint_FLUX_Config) model_path = Path(config.path) with accelerate.init_empty_weights(): - model = Flux(params[config.config_path]) + model = Flux(get_flux_transformers_params(config.variant)) sd = load_file(model_path) if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd: @@ -252,7 +254,7 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, CheckpointConfigBase): + if not isinstance(config, Checkpoint_Config_Base): raise ValueError("Only CheckpointConfigBase models are currently supported here.") match submodel_type: @@ -267,11 +269,11 @@ def _load_from_singlefile( self, config: AnyModelConfig, ) -> AnyModel: - assert isinstance(config, MainGGUFCheckpointConfig) + assert isinstance(config, Main_GGUF_FLUX_Config) model_path = Path(config.path) with accelerate.init_empty_weights(): - model = Flux(params[config.config_path]) + model = Flux(get_flux_transformers_params(config.variant)) # HACK(ryand): We shouldn't be hard-coding the compute_dtype here. 
sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16) @@ -298,7 +300,7 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, CheckpointConfigBase): + if not isinstance(config, Checkpoint_Config_Base): raise ValueError("Only CheckpointConfigBase models are currently supported here.") match submodel_type: @@ -313,7 +315,7 @@ def _load_from_singlefile( self, config: AnyModelConfig, ) -> AnyModel: - assert isinstance(config, MainBnbQuantized4bCheckpointConfig) + assert isinstance(config, Main_BnBNF4_FLUX_Config) if not bnb_available: raise ImportError( "The bnb modules are not available. Please install bitsandbytes if available on your platform." @@ -322,7 +324,7 @@ def _load_from_singlefile( with SilenceWarnings(): with accelerate.init_empty_weights(): - model = Flux(params[config.config_path]) + model = Flux(get_flux_transformers_params(config.variant)) model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) sd = load_file(model_path) if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd: @@ -341,9 +343,9 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if isinstance(config, ControlNetCheckpointConfig): + if isinstance(config, ControlNet_Checkpoint_Config_Base): model_path = Path(config.path) - elif isinstance(config, ControlNetDiffusersConfig): + elif isinstance(config, ControlNet_Diffusers_Config_Base): # If this is a diffusers directory, we simply ignore the config file and load from the weight file. model_path = Path(config.path) / "diffusion_pytorch_model.safetensors" else: @@ -362,7 +364,7 @@ def _load_model( def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel: with accelerate.init_empty_weights(): # HACK(ryand): Is it safe to assume dev here? 
- model = XLabsControlNetFlux(params["flux-dev"]) + model = XLabsControlNetFlux(get_flux_transformers_params(FluxVariantType.Dev)) model.load_state_dict(sd, assign=True) return model @@ -388,7 +390,7 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, IPAdapterCheckpointConfig): + if not isinstance(config, IPAdapter_Checkpoint_Config_Base): raise ValueError(f"Unexpected model config type: {type(config)}.") sd = load_file(Path(config.path)) @@ -411,7 +413,7 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if not isinstance(config, FluxReduxConfig): + if not isinstance(config, FLUXRedux_Checkpoint_Config): raise ValueError(f"Unexpected model config type: {type(config)}.") sd = load_file(Path(config.path)) diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index 8a690583d5d..b888c69edf9 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -8,7 +8,8 @@ from diffusers.configuration_utils import ConfigMixin from diffusers.models.modeling_utils import ModelMixin -from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException +from invokeai.backend.model_manager.configs.base import Diffusers_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import ( @@ -33,7 +34,7 @@ def _load_model( model_class = self.get_hf_load_class(model_path) if submodel_type is not None: raise Exception(f"There are no submodels in models of type {model_class}") - repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None + repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None variant = repo_variant.value if repo_variant else None try: result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) @@ -56,9 +57,7 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy module, class_name = config[submodel_type.value] result = self._hf_definition_to_type(module=module, class_name=class_name) except KeyError as e: - raise InvalidModelConfigException( - f'The "{submodel_type}" submodel is not available for this model.' 
- ) from e + raise ValueError(f'The "{submodel_type}" submodel is not available for this model.') from e else: try: config = self._load_diffusers_config(model_path, config_name="config.json") @@ -67,9 +66,9 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy elif class_name := config.get("architectures"): result = self._hf_definition_to_type(module="transformers", class_name=class_name[0]) else: - raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json") + raise RuntimeError("Unable to decipher Load Class based on given config.json") except KeyError as e: - raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e + raise ValueError("An expected config.json file is missing from this model.") from e assert result is not None return result diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py index d103bc5dbcb..d133a36498c 100644 --- a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py +++ b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py @@ -7,7 +7,7 @@ import torch from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.raw_model import RawModel diff --git a/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py b/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py index b508137f814..e459bbf2bb1 100644 --- a/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py +++ b/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py @@ -3,9 +3,7 @@ from transformers import LlavaOnevisionForConditionalGeneration -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index 98f54224fad..b97c3efeb1f 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -9,7 +9,7 @@ from safetensors.torch import load_file from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry @@ -30,6 +30,7 @@ lora_model_from_flux_control_state_dict, ) from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import ( + is_state_dict_likely_in_flux_diffusers_format, lora_model_from_flux_diffusers_state_dict, ) 
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import ( @@ -96,15 +97,19 @@ def _load_model( state_dict = convert_sdxl_keys_to_diffusers_format(state_dict) model = lora_model_from_sd_state_dict(state_dict=state_dict) elif self._model_base == BaseModelType.Flux: - if config.format in [ModelFormat.Diffusers, ModelFormat.OMI]: + if config.format is ModelFormat.OMI: # HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically # distributed as a single file without the associated metadata containing the alpha value. We chose # alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. alpha=rank # is a popular choice. For example, in the diffusers training scripts: # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194 + # + # We assume the same for LyCORIS models in diffusers key format. model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None) - elif config.format == ModelFormat.LyCORIS: - if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict): + elif config.format is ModelFormat.LyCORIS: + if is_state_dict_likely_in_flux_diffusers_format(state_dict=state_dict): + model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None) + elif is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict): model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict) elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict): model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict) diff --git a/invokeai/backend/model_manager/load/model_loaders/onnx.py b/invokeai/backend/model_manager/load/model_loaders/onnx.py index 3078d622b4e..a565bb11d05 100644 --- a/invokeai/backend/model_manager/load/model_loaders/onnx.py +++ b/invokeai/backend/model_manager/load/model_loaders/onnx.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Optional -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import ( diff --git a/invokeai/backend/model_manager/load/model_loaders/sig_lip.py b/invokeai/backend/model_manager/load/model_loaders/sig_lip.py index bdf38887a3a..16b8e6c88da 100644 --- a/invokeai/backend/model_manager/load/model_loaders/sig_lip.py +++ b/invokeai/backend/model_manager/load/model_loaders/sig_lip.py @@ -3,9 +3,7 @@ from transformers import SiglipVisionModel -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType diff --git a/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py b/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py index 44cb0277fc4..e6d8f429904 100644 --- a/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py +++ 
b/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py @@ -3,9 +3,7 @@ import torch -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index aa692478cad..d0cc5893796 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -4,18 +4,24 @@ from pathlib import Path from typing import Optional -from diffusers import ( - StableDiffusionInpaintPipeline, - StableDiffusionPipeline, +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import ( StableDiffusionXLInpaintPipeline, - StableDiffusionXLPipeline, ) -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - CheckpointConfigBase, - DiffusersConfigBase, - MainCheckpointConfig, +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.main import ( + Main_Checkpoint_SD1_Config, + Main_Checkpoint_SD2_Config, + Main_Checkpoint_SDXL_Config, + Main_Checkpoint_SDXLRefiner_Config, + Main_Diffusers_SD1_Config, + Main_Diffusers_SD2_Config, + Main_Diffusers_SDXL_Config, + Main_Diffusers_SDXLRefiner_Config, ) from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry @@ -58,7 +64,7 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if isinstance(config, CheckpointConfigBase): + if isinstance(config, Checkpoint_Config_Base): return self._load_from_singlefile(config, submodel_type) if submodel_type is None: @@ -66,7 +72,7 @@ def _load_model( model_path = Path(config.path) load_class = self.get_hf_load_class(model_path, submodel_type) - repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None + repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None variant = repo_variant.value if repo_variant else None model_path = model_path / submodel_type.value try: @@ -107,7 +113,19 @@ def _load_from_singlefile( ModelVariantType.Normal: StableDiffusionXLPipeline, }, } - assert isinstance(config, MainCheckpointConfig) + assert isinstance( + config, + ( + Main_Diffusers_SD1_Config, + Main_Diffusers_SD2_Config, + Main_Diffusers_SDXL_Config, + Main_Diffusers_SDXLRefiner_Config, + Main_Checkpoint_SD1_Config, + Main_Checkpoint_SD2_Config, + Main_Checkpoint_SDXL_Config, + Main_Checkpoint_SDXLRefiner_Config, + ), + ) try: load_class = 
load_classes[config.base][config.variant] except KeyError as e: diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py index 60ae4ea08b7..2d0411a8df2 100644 --- a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py +++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Optional -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.taxonomy import ( diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index 365fa0a547c..e91903ccdad 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -3,9 +3,10 @@ from typing import Optional -from diffusers import AutoencoderKL +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from invokeai.backend.model_manager.config import AnyModelConfig, VAECheckpointConfig +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Config_Base from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import ( @@ -27,7 +28,7 @@ def _load_model( config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if isinstance(config, VAECheckpointConfig): + if isinstance(config, VAE_Checkpoint_Config_Base): return AutoencoderKL.from_single_file( config.path, torch_dtype=self._torch_dtype, diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py deleted file mode 100644 index 03056b10f59..00000000000 --- a/invokeai/backend/model_manager/merge.py +++ /dev/null @@ -1,163 +0,0 @@ -""" -invokeai.backend.model_manager.merge exports: -merge_diffusion_models() -- combine multiple models by location and return a pipeline object -merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to the models tables - -Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team -""" - -import warnings -from enum import Enum -from pathlib import Path -from typing import Any, List, Optional, Set - -import torch -from diffusers import AutoPipelineForText2Image -from diffusers.utils import logging as dlogging - -from invokeai.app.services.model_install import ModelInstallServiceBase -from invokeai.app.services.model_records.model_records_base import ModelRecordChanges -from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, ModelVariantType -from invokeai.backend.model_manager.config import MainDiffusersConfig -from invokeai.backend.util.devices import TorchDevice - - -class MergeInterpolationMethod(str, Enum): - WeightedSum = "weighted_sum" - Sigmoid = "sigmoid" - InvSigmoid = "inv_sigmoid" - AddDifference = "add_difference" - - -class ModelMerger(object): - """Wrapper class for model merge function.""" - - def __init__(self, installer: 
ModelInstallServiceBase): - """ - Initialize a ModelMerger object with the model installer. - """ - self._installer = installer - self._dtype = TorchDevice.choose_torch_dtype() - - def merge_diffusion_models( - self, - model_paths: List[Path], - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - variant: Optional[str] = None, - **kwargs: Any, - ) -> Any: # pipe.merge is an untyped function. - """ - :param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids - :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - :param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. - :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. - - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - verbosity = dlogging.get_verbosity() - dlogging.set_verbosity_error() - dtype = torch.float16 if variant == "fp16" else self._dtype - - # Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models - # until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released. - pipe = AutoPipelineForText2Image.from_pretrained( - model_paths[0], - custom_pipeline="checkpoint_merger", - torch_dtype=dtype, - variant=variant, - ) # type: ignore - merged_pipe = pipe.merge( - pretrained_model_name_or_path_list=model_paths, - alpha=alpha, - interp=interp.value if interp else None, # diffusers API treats None as "weighted sum" - force=force, - torch_dtype=dtype, - variant=variant, - **kwargs, - ) - dlogging.set_verbosity(verbosity) - return merged_pipe - - def merge_diffusion_models_and_save( - self, - model_keys: List[str], - merged_model_name: str, - alpha: float = 0.5, - force: bool = False, - interp: Optional[MergeInterpolationMethod] = None, - merge_dest_directory: Optional[Path] = None, - variant: Optional[str] = None, - **kwargs: Any, - ) -> AnyModelConfig: - """ - :param models: up to three models, designated by their registered InvokeAI model name - :param merged_model_name: name for new model - :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - :param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C). - :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. 
- :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - model_paths: List[Path] = [] - model_names: List[str] = [] - config = self._installer.app_config - store = self._installer.record_store - base_models: Set[BaseModelType] = set() - variant = None if self._installer.app_config.precision == "float32" else "fp16" - - assert len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference, ( - "When merging three models, only the 'add_difference' merge method is supported" - ) - - for key in model_keys: - info = store.get_model(key) - model_names.append(info.name) - assert isinstance(info, MainDiffusersConfig), ( - f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging" - ) - assert info.variant == ModelVariantType("normal"), ( - f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged" - ) - - # tally base models used - base_models.add(info.base) - model_paths.extend([config.models_path / info.path]) - - assert len(base_models) == 1, f"All models to merge must have same base model, but found bases {base_models}" - base_model = base_models.pop() - - merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp) - merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, variant=variant, **kwargs) - dump_path = ( - Path(merge_dest_directory) - if merge_dest_directory - else config.models_path / base_model.value / ModelType.Main.value - ) - dump_path.mkdir(parents=True, exist_ok=True) - dump_path = dump_path / merged_model_name - - dtype = torch.float16 if variant == "fp16" else self._dtype - merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant) - - # register model and get its unique key - key = self._installer.register_path(dump_path) - - # update model's config - model_config = self._installer.record_store.get_model(key) - model_config.name = merged_model_name - model_config.description = f"Merge of models {', '.join(model_names)}" - - self._installer.record_store.update_model( - key, ModelRecordChanges(name=model_config.name, description=model_config.description) - ) - return model_config diff --git a/invokeai/backend/model_manager/model_on_disk.py b/invokeai/backend/model_manager/model_on_disk.py index 502ca596a62..a77853c8f3d 100644 --- a/invokeai/backend/model_manager/model_on_disk.py +++ b/invokeai/backend/model_manager/model_on_disk.py @@ -30,7 +30,8 @@ def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"): self.hash_algo = hash_algo # Having a cache helps users of ModelOnDisk (i.e. 
configs) to save state # This prevents redundant computations during matching and parsing - self.cache = {"_CACHED_STATE_DICTS": {}} + self._state_dict_cache: dict[Path, Any] = {} + self._metadata_cache: dict[Path, Any] = {} def hash(self) -> str: return ModelHash(algorithm=self.hash_algo).hash(self.path) @@ -44,16 +45,21 @@ def weight_files(self) -> set[Path]: if self.path.is_file(): return {self.path} extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"} - return {f for f in self.path.rglob("*") if f.suffix in extensions} + return {f for f in self.path.rglob("*") if f.suffix in extensions and f.is_file()} def metadata(self, path: Optional[Path] = None) -> dict[str, str]: + path = path or self.path + if path in self._metadata_cache: + return self._metadata_cache[path] try: with safe_open(self.path, framework="pt", device="cpu") as f: metadata = f.metadata() assert isinstance(metadata, dict) - return metadata except Exception: - return {} + metadata = {} + + self._metadata_cache[path] = metadata + return metadata def repo_variant(self) -> Optional[ModelRepoVariant]: if self.path.is_file(): @@ -73,10 +79,8 @@ def repo_variant(self) -> Optional[ModelRepoVariant]: return ModelRepoVariant.Default def load_state_dict(self, path: Optional[Path] = None) -> StateDict: - sd_cache = self.cache["_CACHED_STATE_DICTS"] - - if path in sd_cache: - return sd_cache[path] + if path in self._state_dict_cache: + return self._state_dict_cache[path] path = self.resolve_weight_file(path) @@ -111,7 +115,7 @@ def load_state_dict(self, path: Optional[Path] = None) -> StateDict: raise ValueError(f"Unrecognized model extension: {path.suffix}") state_dict = checkpoint.get("state_dict", checkpoint) - sd_cache[path] = state_dict + self._state_dict_cache[path] = state_dict return state_dict def resolve_weight_file(self, path: Optional[Path] = None) -> Path: diff --git a/invokeai/backend/model_manager/single_file_config_files.py b/invokeai/backend/model_manager/single_file_config_files.py new file mode 100644 index 00000000000..fa4b9e934b8 --- /dev/null +++ b/invokeai/backend/model_manager/single_file_config_files.py @@ -0,0 +1,93 @@ +from dataclasses import dataclass + +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelType, + ModelVariantType, + SchedulerPredictionType, +) + + +@dataclass(frozen=True) +class LegacyConfigKey: + type: ModelType + base: BaseModelType + variant: ModelVariantType | None = None + pred: SchedulerPredictionType | None = None + + @classmethod + def from_model_config(cls, config: AnyModelConfig) -> "LegacyConfigKey": + variant = getattr(config, "variant", None) + pred = getattr(config, "prediction_type", None) + return cls(type=config.type, base=config.base, variant=variant, pred=pred) + + +LEGACY_CONFIG_MAP: dict[LegacyConfigKey, str] = { + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion1, + ModelVariantType.Normal, + SchedulerPredictionType.Epsilon, + ): "stable-diffusion/v1-inference.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion1, + ModelVariantType.Normal, + SchedulerPredictionType.VPrediction, + ): "stable-diffusion/v1-inference-v.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion1, + ModelVariantType.Inpaint, + ): "stable-diffusion/v1-inpainting-inference.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion2, + ModelVariantType.Normal, + SchedulerPredictionType.Epsilon, + ): 
"stable-diffusion/v2-inference.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion2, + ModelVariantType.Normal, + SchedulerPredictionType.VPrediction, + ): "stable-diffusion/v2-inference-v.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion2, + ModelVariantType.Inpaint, + SchedulerPredictionType.Epsilon, + ): "stable-diffusion/v2-inpainting-inference.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion2, + ModelVariantType.Inpaint, + SchedulerPredictionType.VPrediction, + ): "stable-diffusion/v2-inpainting-inference-v.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion2, + ModelVariantType.Depth, + ): "stable-diffusion/v2-midas-inference.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusionXL, + ModelVariantType.Normal, + ): "stable-diffusion/sd_xl_base.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusionXL, + ModelVariantType.Inpaint, + ): "stable-diffusion/sd_xl_inpaint.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusionXLRefiner, + ModelVariantType.Normal, + ): "stable-diffusion/sd_xl_refiner.yaml", + LegacyConfigKey(ModelType.ControlNet, BaseModelType.StableDiffusion1): "controlnet/cldm_v15.yaml", + LegacyConfigKey(ModelType.ControlNet, BaseModelType.StableDiffusion2): "controlnet/cldm_v21.yaml", + LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusion1): "stable-diffusion/v1-inference.yaml", + LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusion2): "stable-diffusion/v2-inference.yaml", + LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusionXL): "stable-diffusion/sd_xl_base.yaml", +} diff --git a/invokeai/backend/model_manager/starter_models.py b/invokeai/backend/model_manager/starter_models.py index 84e98fed4de..8958b6fd3c2 100644 --- a/invokeai/backend/model_manager/starter_models.py +++ b/invokeai/backend/model_manager/starter_models.py @@ -37,16 +37,20 @@ class StarterModelBundle(BaseModel): ) # region CLIP Image Encoders + +# This is CLIP-ViT-H-14-laion2B-s32B-b79K ip_adapter_sd_image_encoder = StarterModel( name="IP Adapter SD1.5 Image Encoder", - base=BaseModelType.StableDiffusion1, + base=BaseModelType.Any, source="InvokeAI/ip_adapter_sd_image_encoder", description="IP Adapter SD Image Encoder", type=ModelType.CLIPVision, ) + +# This is CLIP-ViT-bigG-14-laion2B-39B-b160k ip_adapter_sdxl_image_encoder = StarterModel( name="IP Adapter SDXL Image Encoder", - base=BaseModelType.StableDiffusionXL, + base=BaseModelType.Any, source="InvokeAI/ip_adapter_sdxl_image_encoder", description="IP Adapter SDXL Image Encoder", type=ModelType.CLIPVision, diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py index ba3c8586db4..99a31f438d1 100644 --- a/invokeai/backend/model_manager/taxonomy.py +++ b/invokeai/backend/model_manager/taxonomy.py @@ -1,38 +1,70 @@ from enum import Enum from typing import Dict, TypeAlias, Union -import diffusers import onnxruntime as ort import torch -from diffusers import ModelMixin +from diffusers.models.modeling_utils import ModelMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from pydantic import TypeAdapter from invokeai.backend.raw_model import RawModel # ModelMixin is the base class for all diffusers and transformers models # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime -AnyModel = Union[ - ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], 
diffusers.DiffusionPipeline, ort.InferenceSession +AnyModel: TypeAlias = Union[ + ModelMixin, + RawModel, + torch.nn.Module, + Dict[str, torch.Tensor], + DiffusionPipeline, + ort.InferenceSession, ] +"""Type alias for any kind of runtime, in-memory model representation. For example, a torch module or diffusers pipeline.""" class BaseModelType(str, Enum): - """Base model type.""" + """An enumeration of base model architectures. For example, Stable Diffusion 1.x, Stable Diffusion 2.x, FLUX, etc. + + Every model config must have a base architecture type. + + Not all models are associated with a base architecture. For example, CLIP models are their own thing, not related + to any particular model architecture. To simplify internal APIs and make it easier to work with models, we use a + fallback/null value `BaseModelType.Any` for these models, instead of making the model base optional.""" Any = "any" + """`Any` is essentially a fallback/null value for models with no base architecture association. + For example, CLIP models are not related to Stable Diffusion, FLUX, or any other model arch.""" StableDiffusion1 = "sd-1" + """Indicates the model is associated with the Stable Diffusion 1.x model architecture, including 1.4 and 1.5.""" StableDiffusion2 = "sd-2" + """Indicates the model is associated with the Stable Diffusion 2.x model architecture, including 2.0 and 2.1.""" StableDiffusion3 = "sd-3" + """Indicates the model is associated with the Stable Diffusion 3.5 model architecture.""" StableDiffusionXL = "sdxl" + """Indicates the model is associated with the Stable Diffusion XL model architecture.""" StableDiffusionXLRefiner = "sdxl-refiner" + """Indicates the model is associated with the Stable Diffusion XL Refiner model architecture.""" Flux = "flux" + """Indicates the model is associated with FLUX.1 model architecture, including FLUX Dev, Schnell and Fill.""" CogView4 = "cogview4" + """Indicates the model is associated with CogView 4 model architecture.""" Imagen3 = "imagen3" + """Indicates the model is associated with Google Imagen 3 model architecture. This is an external API model.""" Imagen4 = "imagen4" + """Indicates the model is associated with Google Imagen 4 model architecture. This is an external API model.""" Gemini2_5 = "gemini-2.5" + """Indicates the model is associated with Google Gemini 2.5 Flash Image model architecture. This is an external API model.""" ChatGPT4o = "chatgpt-4o" + """Indicates the model is associated with OpenAI ChatGPT 4o Image model architecture. This is an external API model.""" FluxKontext = "flux-kontext" + """Indicates the model is associated with FLUX Kontext model architecture. This is an external API model; local FLUX + Kontext models use the base `Flux`.""" Veo3 = "veo3" + """Indicates the model is associated with Google Veo 3 video model architecture. This is an external API model.""" Runway = "runway" + """Indicates the model is associated with Runway video model architecture. 
This is an external API model.""" + Unknown = "unknown" + """Indicates the model's base architecture is unknown.""" class ModelType(str, Enum): @@ -55,6 +87,7 @@ class ModelType(str, Enum): FluxRedux = "flux_redux" LlavaOnevision = "llava_onevision" Video = "video" + Unknown = "unknown" class SubModelType(str, Enum): @@ -90,6 +123,12 @@ class ModelVariantType(str, Enum): Depth = "depth" +class FluxVariantType(str, Enum): + Schnell = "schnell" + Dev = "dev" + DevFill = "dev_fill" + + class ModelFormat(str, Enum): """Storage format of model.""" @@ -107,6 +146,7 @@ class ModelFormat(str, Enum): BnbQuantizednf4b = "bnb_quantized_nf4b" GGUFQuantized = "gguf_quantized" Api = "api" + Unknown = "unknown" class SchedulerPredictionType(str, Enum): @@ -146,4 +186,7 @@ class FluxLoRAFormat(str, Enum): AIToolkit = "flux.aitoolkit" -AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None] +AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, FluxVariantType] +variant_type_adapter = TypeAdapter[ModelVariantType | ClipVariantType | FluxVariantType]( + ModelVariantType | ClipVariantType | FluxVariantType +) diff --git a/invokeai/backend/model_manager/util/lora_metadata_extractor.py b/invokeai/backend/model_manager/util/lora_metadata_extractor.py index 842e78a7880..12b10739354 100644 --- a/invokeai/backend/model_manager/util/lora_metadata_extractor.py +++ b/invokeai/backend/model_manager/util/lora_metadata_extractor.py @@ -8,7 +8,8 @@ from PIL import Image from invokeai.app.util.thumbnails import make_thumbnail -from invokeai.backend.model_manager.config import AnyModelConfig, ModelType +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.taxonomy import ModelType logger = logging.getLogger(__name__) diff --git a/invokeai/backend/model_manager/util/model_util.py b/invokeai/backend/model_manager/util/model_util.py index 4fa095b5999..c153129353b 100644 --- a/invokeai/backend/model_manager/util/model_util.py +++ b/invokeai/backend/model_manager/util/model_util.py @@ -83,14 +83,14 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str, return checkpoint -def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[int]: +def lora_token_vector_length(checkpoint: dict[str | int, torch.Tensor]) -> Optional[int]: """ Given a checkpoint in memory, return the lora token vector length :param checkpoint: The checkpoint """ - def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Tensor]) -> Optional[int]: + def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: dict[str | int, torch.Tensor]) -> Optional[int]: lora_token_vector_length = None if "." not in key: @@ -136,6 +136,8 @@ def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Ten lora_te1_length = None lora_te2_length = None for key, tensor in checkpoint.items(): + if isinstance(key, int): + continue if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." 
in key): lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) elif key.startswith("lora_unet_") and ( diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index a1d8bbed0a5..04f99495609 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -5,10 +5,10 @@ import pickle from contextlib import contextmanager -from typing import Any, Iterator, List, Optional, Tuple, Type, Union +from typing import Any, Generator, Iterator, List, Optional, Tuple, Type, Union import torch -from diffusers import UNet2DConditionModel +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from invokeai.app.shared.models import FreeUConfig @@ -146,7 +146,7 @@ def apply_clip_skip( cls, text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection], clip_skip: int, - ) -> None: + ) -> Generator[None, Any, Any]: skipped_layers = [] try: for _i in range(clip_skip): @@ -164,7 +164,7 @@ def apply_freeu( cls, unet: UNet2DConditionModel, freeu_config: Optional[FreeUConfig] = None, - ) -> None: + ) -> Generator[None, Any, Any]: did_apply_freeu = False try: assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute? diff --git a/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py index 6ca06a0355f..f3c202268a7 100644 --- a/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py @@ -12,7 +12,10 @@ from invokeai.backend.util import InvokeAILogger -def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool: +def is_state_dict_likely_in_flux_aitoolkit_format( + state_dict: dict[str | int, Any], + metadata: dict[str, Any] | None = None, +) -> bool: if metadata: try: software = json.loads(metadata.get("software", "{}")) @@ -20,7 +23,7 @@ def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], me return False return software.get("name") == "ai-toolkit" # metadata got lost somewhere - return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys()) + return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys() if isinstance(k, str)) @dataclass diff --git a/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py b/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py index fa9cc764628..1762a4d5f4c 100644 --- a/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py @@ -18,14 +18,16 @@ FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(\w+\.)+(lora_A\.weight|lora_B\.weight|lora_B\.bias|scale)" -def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool: +def is_state_dict_likely_flux_control(state_dict: dict[str | int, Any]) -> bool: """Checks if the provided state dict is likely in the FLUX Control LoRA format. This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.) 
""" - all_keys_match = all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, str(k)) for k in state_dict.keys()) + all_keys_match = all( + re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) for k in state_dict.keys() if isinstance(k, str) + ) # Check the shape of the img_in weight, because this layer shape is modified by FLUX control LoRAs. lora_a_weight = state_dict.get("img_in.lora_A.weight", None) diff --git a/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py index 188d118cc4d..f5b4bc66847 100644 --- a/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py @@ -9,14 +9,16 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool: +def is_state_dict_likely_in_flux_diffusers_format(state_dict: dict[str | int, torch.Tensor]) -> bool: """Checks if the provided state dict is likely in the Diffusers FLUX LoRA format. This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.) """ # First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format). - all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys()) + all_keys_in_peft_format = all( + k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys() if isinstance(k, str) + ) # Check if keys use transformer prefix transformer_prefix_keys = [ diff --git a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py index 7b5f3468963..f5a6830c4f1 100644 --- a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py @@ -44,7 +44,7 @@ FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*" -def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool: +def is_state_dict_likely_in_flux_kohya_format(state_dict: dict[str | int, Any]) -> bool: """Checks if the provided state dict is likely in the Kohya FLUX LoRA format. This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. 
(A @@ -56,6 +56,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k) or re.match(FLUX_KOHYA_T5_KEY_REGEX, k) for k in state_dict.keys() + if isinstance(k, str) ) diff --git a/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py index 0413f0ef49f..88aeee95e49 100644 --- a/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py @@ -40,7 +40,7 @@ ) -def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -> bool: +def is_state_dict_likely_in_flux_onetrainer_format(state_dict: dict[str | int, Any]) -> bool: """Checks if the provided state dict is likely in the OneTrainer FLUX LoRA format. This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A @@ -53,6 +53,7 @@ def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) - or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k) or re.match(FLUX_KOHYA_T5_KEY_REGEX, k) for k in state_dict.keys() + if isinstance(k, str) ) diff --git a/invokeai/backend/patches/lora_conversions/formats.py b/invokeai/backend/patches/lora_conversions/formats.py index 94f71e05ee6..4cde7c98f67 100644 --- a/invokeai/backend/patches/lora_conversions/formats.py +++ b/invokeai/backend/patches/lora_conversions/formats.py @@ -1,3 +1,5 @@ +from typing import Any + from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import ( is_state_dict_likely_in_flux_aitoolkit_format, @@ -14,7 +16,10 @@ ) -def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None: +def flux_format_from_state_dict( + state_dict: dict[str | int, Any], + metadata: dict[str, Any] | None = None, +) -> FluxLoRAFormat | None: if is_state_dict_likely_in_flux_kohya_format(state_dict): return FluxLoRAFormat.Kohya elif is_state_dict_likely_in_flux_onetrainer_format(state_dict): diff --git a/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py b/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py index 045ebbbf2c4..8231e313fdc 100644 --- a/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py +++ b/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py @@ -4,7 +4,8 @@ from safetensors.torch import load_file, save_file from invokeai.backend.flux.model import Flux -from invokeai.backend.flux.util import params +from invokeai.backend.flux.util import get_flux_transformers_params +from invokeai.backend.model_manager.taxonomy import FluxVariantType from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time @@ -22,7 +23,7 @@ def main(): with log_time("Initialize FLUX transformer on meta device"): # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config. - p = params["flux-schnell"] + p = get_flux_transformers_params(FluxVariantType.Schnell) # Initialize the model on the "meta" device.
with accelerate.init_empty_weights(): diff --git a/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py b/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py index c8802b9e49e..6a4ee3abf93 100644 --- a/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py +++ b/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py @@ -7,7 +7,8 @@ from safetensors.torch import load_file, save_file from invokeai.backend.flux.model import Flux -from invokeai.backend.flux.util import params +from invokeai.backend.flux.util import get_flux_transformers_params +from invokeai.backend.model_manager.taxonomy import FluxVariantType from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 @@ -35,7 +36,7 @@ def main(): # inference_dtype = torch.bfloat16 with log_time("Initialize FLUX transformer on meta device"): # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config. - p = params["flux-schnell"] + p = get_flux_transformers_params(FluxVariantType.Schnell) # Initialize the model on the "meta" device. with accelerate.init_empty_weights(): diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 95f2c904ad8..7e258b87795 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -23,6 +23,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from torch import nn +from invokeai.backend.model_manager.taxonomy import BaseModelType, SchedulerPredictionType from invokeai.backend.util.logging import InvokeAILogger # TODO: create PR to diffusers @@ -407,7 +408,8 @@ def from_unet( use_linear_projection=unet.config.use_linear_projection, class_embed_type=unet.config.class_embed_type, num_class_embeds=unet.config.num_class_embeds, - upcast_attention=unet.config.upcast_attention, + upcast_attention=unet.config.base is BaseModelType.StableDiffusion2 + and unet.config.prediction_type is SchedulerPredictionType.VPrediction, resnet_time_scale_shift=unet.config.resnet_time_scale_shift, projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py index add394e71be..e4208dc848f 100644 --- a/invokeai/backend/util/test_utils.py +++ b/invokeai/backend/util/test_utils.py @@ -7,7 +7,8 @@ from invokeai.app.services.model_manager import ModelManagerServiceBase from invokeai.app.services.model_records import UnknownModelException -from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType +from invokeai.backend.model_manager.load.load_base import LoadedModel +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType @pytest.fixture(scope="session") diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 64eeaf79fa5..aacd5c728f1 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -914,6 +914,9 @@ "hfTokenReset": "HF Token Reset", "urlUnauthorizedErrorMessage": "You may need to configure an API token to access this model.", "urlUnauthorizedErrorMessage2": "Learn how here.", + "unidentifiedModelTitle": "Unable to identify model", + "unidentifiedModelMessage": "We were unable to identify the type, base and/or format of the installed model.
Try editing the model and selecting the appropriate settings for the model.", + "unidentifiedModelMessage2": "If you don't see the correct settings, or the model doesn't work after changing them, ask for help on or create an issue on .", "imageEncoderModelId": "Image Encoder Model ID", "installedModelsCount": "{{installed}} of {{total}} models installed.", "includesNModels": "Includes {{n}} models and their dependencies.", @@ -942,6 +945,7 @@ "modelConverted": "Model Converted", "modelDeleted": "Model Deleted", "modelDeleteFailed": "Failed to delete model", + "modelFormat": "Model Format", "modelImageDeleted": "Model Image Deleted", "modelImageDeleteFailed": "Model Image Delete Failed", "modelImageUpdated": "Model Image Updated", @@ -949,6 +953,7 @@ "modelManager": "Model Manager", "modelName": "Model Name", "modelSettings": "Model Settings", + "modelSettingsWarning": "These settings tell Invoke what kind of model this is and how to load it. If Invoke didn't detect these correctly when you installed the model, or if the model is classified as Unknown, you may need to edit them manually.", "modelType": "Model Type", "modelUpdated": "Model Updated", "modelUpdateFailed": "Model Update Failed", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index 41b2eb509e5..e53fc977b98 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -11,8 +11,8 @@ import { selectCanvasSlice, } from 'features/controlLayers/store/selectors'; import { getEntityIdentifier } from 'features/controlLayers/store/types'; +import { SUPPORTS_REF_IMAGES_BASE_MODELS } from 'features/modelManagerV2/models'; import { modelSelected } from 'features/parameters/store/actions'; -import { SUPPORTS_REF_IMAGES_BASE_MODELS } from 'features/parameters/types/constants'; import { zParameterModel } from 'features/parameters/types/parameterSchemas'; import { toast } from 'features/toast/toast'; import { t } from 'i18next'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index 62f398b5ed8..63602339a9b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -37,7 +37,7 @@ import type { Logger } from 'roarr'; import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models'; import type { AnyModelConfig } from 'services/api/types'; import { - isCLIPEmbedModelConfig, + isCLIPEmbedModelConfigOrSubmodel, isControlLayerModelConfig, isControlNetModelConfig, isFluxReduxModelConfig, @@ -48,7 +48,7 @@ import { isNonRefinerMainModelConfig, isRefinerMainModelModelConfig, isSpandrelImageToImageModelConfig, - isT5EncoderModelConfig, + isT5EncoderModelConfigOrSubmodel, isVideoModelConfig, } from 'services/api/types'; import type { JsonObject } from 'type-fest'; @@ -418,7 +418,7 @@ const handleTileControlNetModel: ModelHandler = (models, state, dispatch, log) = const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => { const selectedT5EncoderModel = state.params.t5EncoderModel; - const t5EncoderModels = models.filter((m) => 
isT5EncoderModelConfig(m)); + const t5EncoderModels = models.filter((m) => isT5EncoderModelConfigOrSubmodel(m)); // If the currently selected model is available, we don't need to do anything if (selectedT5EncoderModel && t5EncoderModels.some((m) => m.key === selectedT5EncoderModel.key)) { @@ -446,7 +446,7 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => { const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => { const selectedCLIPEmbedModel = state.params.clipEmbedModel; - const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfig(m)); + const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfigOrSubmodel(m)); // If the currently selected model is available, we don't need to do anything if (selectedCLIPEmbedModel && CLIPEmbedModels.some((m) => m.key === selectedCLIPEmbedModel.key)) { diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx index bf4464bd5bd..49a289b875c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ParamDenoisingStrength.tsx @@ -17,6 +17,7 @@ import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig'; +import { isFluxFillMainModelModelConfig } from 'services/api/types'; const selectHasRasterLayersWithContent = createSelector( selectActiveRasterLayerEntities, @@ -46,11 +47,7 @@ export const ParamDenoisingStrength = memo(() => { // Denoising strength does nothing if there are no raster layers w/ content return true; } - if ( - selectedModelConfig?.type === 'main' && - selectedModelConfig?.base === 'flux' && - selectedModelConfig.variant === 'inpaint' - ) { + if (selectedModelConfig && isFluxFillMainModelModelConfig(selectedModelConfig)) { // Denoising strength is ignored by FLUX Fill, which is indicated by the variant being 'inpaint' return true; } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts index e2bfbc6f6a4..779d1c21f87 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts @@ -17,7 +17,7 @@ import Konva from 'konva'; import { atom, computed } from 'nanostores'; import type { Logger } from 'roarr'; import { serializeError } from 'serialize-error'; -import { buildSelectModelConfig } from 'services/api/hooks/modelsByType'; +import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models'; import { isControlLayerModelConfig } from 'services/api/types'; import stableHash from 'stable-hash'; import type { Equals } from 'tsafe'; @@ -202,11 +202,19 @@ export class CanvasEntityFilterer extends CanvasModuleBase { createInitialFilterConfig = (): FilterConfig => { if (this.parent.type === 'control_layer_adapter' && this.parent.state.controlAdapter.model) { // If the parent is a control layer adapter, we should check if the model has a default filter and set it if so - const selectModelConfig = buildSelectModelConfig( - 
this.parent.state.controlAdapter.model.key, - isControlLayerModelConfig - ); - const modelConfig = this.manager.stateApi.runSelector(selectModelConfig); + const key = this.parent.state.controlAdapter.model.key; + const modelConfig = this.manager.stateApi.runSelector((state) => { + const { data } = selectModelConfigsQuery(state); + if (!data) { + return null; + } + return ( + modelConfigsAdapterSelectors + .selectAll(data) + .filter(isControlLayerModelConfig) + .find((m) => m.key === key) ?? null + ); + }); // This always returns a filter const filter = getFilterForModel(modelConfig) ?? IMAGE_FILTERS.canny_edge_detection; return filter.buildDefaults(); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts index cb799dd3aeb..088e7f265a4 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts @@ -13,8 +13,8 @@ import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSl import { selectModel } from 'features/controlLayers/store/paramsSlice'; import { selectBbox } from 'features/controlLayers/store/selectors'; import type { Coordinate, Rect, Tool } from 'features/controlLayers/store/types'; +import { API_BASE_MODELS } from 'features/modelManagerV2/models'; import type { ModelIdentifierField } from 'features/nodes/types/common'; -import { API_BASE_MODELS } from 'features/parameters/types/constants'; import Konva from 'konva'; import { atom } from 'nanostores'; import type { Logger } from 'roarr'; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index ee1a9c6ba44..b8664aeb5ed 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -35,8 +35,8 @@ import { getScaledBoundingBoxDimensions, } from 'features/controlLayers/util/getScaledBoundingBoxDimensions'; import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify'; +import { API_BASE_MODELS } from 'features/modelManagerV2/models'; import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/common'; -import { API_BASE_MODELS } from 'features/parameters/types/constants'; import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension'; import type { IRect } from 'konva/lib/types'; import type { UndoableOptions } from 'redux-undo'; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 3f148c5efbb..609478b4c0c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -25,14 +25,14 @@ import { import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions'; import { API_BASE_MODELS, - CLIP_SKIP_MAP, SUPPORTS_ASPECT_RATIO_BASE_MODELS, SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS, SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS, SUPPORTS_PIXEL_DIMENSIONS_BASE_MODELS, SUPPORTS_REF_IMAGES_BASE_MODELS, SUPPORTS_SEED_BASE_MODELS, -} from 'features/parameters/types/constants'; +} from 'features/modelManagerV2/models'; +import { CLIP_SKIP_MAP } from 
'features/parameters/types/constants'; import type { ParameterCanvasCoherenceMode, ParameterCFGRescaleMultiplier, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts index 03ef5404a6d..197a3d6e3e3 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts @@ -154,7 +154,7 @@ export const getControlLayerWarnings = ( warnings.push(WARNINGS.CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL); } else if ( model.base === 'flux' && - model.variant === 'inpaint' && + model.variant === 'dev_fill' && entity.controlAdapter.model.type === 'control_lora' ) { // FLUX inpaint variants are FLUX Fill models - not compatible w/ Control LoRA diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index 259a3a41312..f0d1ffe878b 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -6,8 +6,8 @@ import { InformationalPopover } from 'common/components/InformationalPopover/Inf import type { GroupStatusMap } from 'common/components/Picker/Picker'; import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; import { selectBase } from 'features/controlLayers/store/paramsSlice'; +import { API_BASE_MODELS } from 'features/modelManagerV2/models'; import { ModelPicker } from 'features/parameters/components/ModelPicker'; -import { API_BASE_MODELS } from 'features/parameters/types/constants'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useLoRAModels } from 'services/api/hooks/modelsByType'; diff --git a/invokeai/frontend/web/src/features/metadata/parsing.tsx b/invokeai/frontend/web/src/features/metadata/parsing.tsx index 10cd0e32f7f..a2363004ef3 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.tsx @@ -49,7 +49,7 @@ import { zVideoDuration, zVideoResolution, } from 'features/controlLayers/store/types'; -import type { ModelIdentifierField } from 'features/nodes/types/common'; +import type { ModelIdentifierField, ModelType } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifier } from 'features/nodes/types/v2/common'; import { modelSelected } from 'features/parameters/store/actions'; @@ -108,7 +108,7 @@ import { useCallback, useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { imagesApi } from 'services/api/endpoints/images'; import { modelsApi } from 'services/api/endpoints/models'; -import type { AnyModelConfig, ModelType } from 'services/api/types'; +import type { AnyModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; import z from 'zod'; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts new file mode 100644 index 00000000000..0b4096e010b --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -0,0 +1,303 @@ +import type { AnyModelVariant, BaseModelType, ModelFormat, ModelType } from 'features/nodes/types/common'; +import type { AnyModelConfig } from 'services/api/types'; +import { + isCLIPEmbedModelConfig, + isCLIPVisionModelConfig, + 
isControlLoRAModelConfig, + isControlNetModelConfig, + isFluxReduxModelConfig, + isIPAdapterModelConfig, + isLLaVAModelConfig, + isLoRAModelConfig, + isNonRefinerMainModelConfig, + isRefinerMainModelModelConfig, + isSigLipModelConfig, + isSpandrelImageToImageModelConfig, + isT2IAdapterModelConfig, + isT5EncoderModelConfig, + isTIModelConfig, + isUnknownModelConfig, + isVAEModelConfig, + isVideoModelConfig, +} from 'services/api/types'; +import { objectEntries } from 'tsafe'; + +import type { FilterableModelType } from './store/modelManagerV2Slice'; + +export type ModelCategoryData = { + category: FilterableModelType; + i18nKey: string; + filter: (config: AnyModelConfig) => boolean; +}; + +export const MODEL_CATEGORIES: Record<FilterableModelType, ModelCategoryData> = { + unknown: { + category: 'unknown', + i18nKey: 'common.unknown', + filter: isUnknownModelConfig, + }, + main: { + category: 'main', + i18nKey: 'modelManager.main', + filter: isNonRefinerMainModelConfig, + }, + refiner: { + category: 'refiner', + i18nKey: 'sdxl.refiner', + filter: isRefinerMainModelModelConfig, + }, + lora: { + category: 'lora', + i18nKey: 'modelManager.loraModels', + filter: isLoRAModelConfig, + }, + embedding: { + category: 'embedding', + i18nKey: 'modelManager.textualInversions', + filter: isTIModelConfig, + }, + controlnet: { + category: 'controlnet', + i18nKey: 'ControlNet', + filter: isControlNetModelConfig, + }, + t2i_adapter: { + category: 't2i_adapter', + i18nKey: 'common.t2iAdapter', + filter: isT2IAdapterModelConfig, + }, + t5_encoder: { + category: 't5_encoder', + i18nKey: 'modelManager.t5Encoder', + filter: isT5EncoderModelConfig, + }, + control_lora: { + category: 'control_lora', + i18nKey: 'modelManager.controlLora', + filter: isControlLoRAModelConfig, + }, + clip_embed: { + category: 'clip_embed', + i18nKey: 'modelManager.clipEmbed', + filter: isCLIPEmbedModelConfig, + }, + spandrel_image_to_image: { + category: 'spandrel_image_to_image', + i18nKey: 'modelManager.spandrelImageToImage', + filter: isSpandrelImageToImageModelConfig, + }, + ip_adapter: { + category: 'ip_adapter', + i18nKey: 'common.ipAdapter', + filter: isIPAdapterModelConfig, + }, + vae: { + category: 'vae', + i18nKey: 'VAE', + filter: isVAEModelConfig, + }, + clip_vision: { + category: 'clip_vision', + i18nKey: 'CLIP Vision', + filter: isCLIPVisionModelConfig, + }, + siglip: { + category: 'siglip', + i18nKey: 'modelManager.sigLip', + filter: isSigLipModelConfig, + }, + flux_redux: { + category: 'flux_redux', + i18nKey: 'modelManager.fluxRedux', + filter: isFluxReduxModelConfig, + }, + llava_onevision: { + category: 'llava_onevision', + i18nKey: 'modelManager.llavaOnevision', + filter: isLLaVAModelConfig, + }, + video: { + category: 'video', + i18nKey: 'Video', + filter: isVideoModelConfig, + }, +}; + +export const MODEL_CATEGORIES_AS_LIST = objectEntries(MODEL_CATEGORIES).map(([category, { i18nKey, filter }]) => ({ + category, + i18nKey, + filter, +})); + +/** + * Mapping of model base to its color + */ +export const MODEL_BASE_TO_COLOR: Record<BaseModelType, string> = { + any: 'base', + 'sd-1': 'green', + 'sd-2': 'teal', + 'sd-3': 'purple', + sdxl: 'invokeBlue', + 'sdxl-refiner': 'invokeBlue', + flux: 'gold', + cogview4: 'red', + imagen3: 'pink', + imagen4: 'pink', + 'chatgpt-4o': 'pink', + 'flux-kontext': 'pink', + 'gemini-2.5': 'pink', + veo3: 'purple', + runway: 'green', + unknown: 'red', +}; + +/** + * Mapping of model type to human readable name + */ +export const MODEL_TYPE_TO_LONG_NAME: Record<ModelType, string> = { + main: 'Main', + vae: 'VAE', + lora: 'LoRA', + llava_onevision: 'LLaVA OneVision', + 
control_lora: 'ControlLoRA', + controlnet: 'ControlNet', + t2i_adapter: 'T2I Adapter', + ip_adapter: 'IP Adapter', + embedding: 'Embedding', + onnx: 'ONNX', + clip_vision: 'CLIP Vision', + spandrel_image_to_image: 'Spandrel (Image to Image)', + t5_encoder: 'T5 Encoder', + clip_embed: 'CLIP Embed', + siglip: 'SigLIP', + flux_redux: 'FLUX Redux', + video: 'Video', + unknown: 'Unknown', +}; + +/** + * Mapping of model base to human readable name + */ +export const MODEL_BASE_TO_LONG_NAME: Record<BaseModelType, string> = { + any: 'Any', + 'sd-1': 'Stable Diffusion 1.x', + 'sd-2': 'Stable Diffusion 2.x', + 'sd-3': 'Stable Diffusion 3.x', + sdxl: 'Stable Diffusion XL', + 'sdxl-refiner': 'Stable Diffusion XL Refiner', + flux: 'FLUX', + cogview4: 'CogView4', + imagen3: 'Imagen3', + imagen4: 'Imagen4', + 'chatgpt-4o': 'ChatGPT 4o', + 'flux-kontext': 'Flux Kontext', + 'gemini-2.5': 'Gemini 2.5', + veo3: 'Veo3', + runway: 'Runway', + unknown: 'Unknown', +}; + +/** + * Mapping of model base to short human readable name + */ +export const MODEL_BASE_TO_SHORT_NAME: Record<BaseModelType, string> = { + any: 'Any', + 'sd-1': 'SD1.X', + 'sd-2': 'SD2.X', + 'sd-3': 'SD3.X', + sdxl: 'SDXL', + 'sdxl-refiner': 'SDXLR', + flux: 'FLUX', + cogview4: 'CogView4', + imagen3: 'Imagen3', + imagen4: 'Imagen4', + 'chatgpt-4o': 'ChatGPT 4o', + 'flux-kontext': 'Flux Kontext', + 'gemini-2.5': 'Gemini 2.5', + veo3: 'Veo3', + runway: 'Runway', + unknown: 'Unknown', +}; + +export const MODEL_VARIANT_TO_LONG_NAME: Record<AnyModelVariant, string> = { + normal: 'Normal', + inpaint: 'Inpaint', + depth: 'Depth', + dev: 'FLUX Dev', + dev_fill: 'FLUX Dev - Fill', + schnell: 'FLUX Schnell', + large: 'CLIP L', + gigantic: 'CLIP G', +}; + +export const MODEL_FORMAT_TO_LONG_NAME: Record<ModelFormat, string> = { + omi: 'OMI', + diffusers: 'Diffusers', + checkpoint: 'Checkpoint', + lycoris: 'LyCORIS', + onnx: 'ONNX', + olive: 'Olive', + embedding_file: 'Embedding (file)', + embedding_folder: 'Embedding (folder)', + invokeai: 'InvokeAI', + t5_encoder: 'T5 Encoder', + bnb_quantized_int8b: 'BNB Quantized (int8b)', + bnb_quantized_nf4b: 'BNB Quantized (nf4b)', + gguf_quantized: 'GGUF Quantized', + api: 'API', + unknown: 'Unknown', +}; + +/** + * List of base models that make API requests + */ +export const API_BASE_MODELS: BaseModelType[] = ['imagen3', 'imagen4', 'chatgpt-4o', 'flux-kontext', 'gemini-2.5']; + +export const SUPPORTS_SEED_BASE_MODELS: BaseModelType[] = ['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4']; + +export const SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS: BaseModelType[] = ['flux', 'sd-3']; + +export const SUPPORTS_REF_IMAGES_BASE_MODELS: BaseModelType[] = [ + 'sd-1', + 'sdxl', + 'flux', + 'flux-kontext', + 'chatgpt-4o', + 'gemini-2.5', +]; + +export const SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS: BaseModelType[] = [ + 'sd-1', + 'sd-2', + 'sdxl', + 'cogview4', + 'sd-3', + 'imagen3', + 'imagen4', +]; + +export const SUPPORTS_PIXEL_DIMENSIONS_BASE_MODELS: BaseModelType[] = [ + 'sd-1', + 'sd-2', + 'sd-3', + 'sdxl', + 'flux', + 'cogview4', +]; + +export const SUPPORTS_ASPECT_RATIO_BASE_MODELS: BaseModelType[] = [ + 'sd-1', + 'sd-2', + 'sd-3', + 'sdxl', + 'flux', + 'cogview4', + 'imagen3', + 'imagen4', + 'flux-kontext', + 'chatgpt-4o', +]; + +export const VIDEO_BASE_MODELS = ['veo3', 'runway']; + +export const REQUIRES_STARTING_FRAME_BASE_MODELS = ['runway']; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx index 2b4755597af..a1be6b208ec 100644 --- 
a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx @@ -1,34 +1,16 @@ import { Badge } from '@invoke-ai/ui-library'; -import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants'; +import { MODEL_BASE_TO_COLOR, MODEL_BASE_TO_SHORT_NAME } from 'features/modelManagerV2/models'; +import type { BaseModelType } from 'features/nodes/types/common'; import { memo } from 'react'; -import type { BaseModelType } from 'services/api/types'; type Props = { base: BaseModelType; }; -export const BASE_COLOR_MAP: Record = { - any: 'base', - 'sd-1': 'green', - 'sd-2': 'teal', - 'sd-3': 'purple', - sdxl: 'invokeBlue', - 'sdxl-refiner': 'invokeBlue', - flux: 'gold', - cogview4: 'red', - imagen3: 'pink', - imagen4: 'pink', - 'chatgpt-4o': 'pink', - 'flux-kontext': 'pink', - 'gemini-2.5': 'pink', - veo3: 'purple', - runway: 'green', -}; - const ModelBaseBadge = ({ base }: Props) => { return ( - - {MODEL_TYPE_SHORT_MAP[base]} + + {MODEL_BASE_TO_SHORT_NAME[base]} ); }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx index 2dfb691008a..e139639f1f0 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx @@ -1,12 +1,12 @@ import { Badge } from '@invoke-ai/ui-library'; +import type { ModelFormat } from 'features/nodes/types/common'; import { memo } from 'react'; -import type { AnyModelConfig } from 'services/api/types'; type Props = { - format: AnyModelConfig['format']; + format: ModelFormat; }; -const FORMAT_NAME_MAP: Record = { +const FORMAT_NAME_MAP: Record = { diffusers: 'diffusers', lycoris: 'lycoris', checkpoint: 'checkpoint', @@ -19,9 +19,12 @@ const FORMAT_NAME_MAP: Record = { gguf_quantized: 'gguf', api: 'api', omi: 'omi', + unknown: 'unknown', + olive: 'olive', + onnx: 'onnx', }; -const FORMAT_COLOR_MAP: Record = { +const FORMAT_COLOR_MAP: Record = { diffusers: 'base', omi: 'base', lycoris: 'base', @@ -34,6 +37,9 @@ const FORMAT_COLOR_MAP: Record = { bnb_quantized_nf4b: 'base', gguf_quantized: 'base', api: 'base', + unknown: 'red', + olive: 'base', + onnx: 'base', }; const ModelFormatBadge = ({ format }: Props) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx index 945630229d1..bde3f1d5946 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx @@ -1,6 +1,8 @@ import { Flex, Text } from '@invoke-ai/ui-library'; +import { logger } from 'app/logging/logger'; import { useAppSelector } from 'app/store/storeHooks'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; +import { MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models'; import { type FilterableModelType, selectFilteredModelType, @@ -8,274 +10,50 @@ import { } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { - 
useCLIPEmbedModels, - useCLIPVisionModels, - useControlLoRAModel, - useControlNetModels, - useEmbeddingModels, - useFluxReduxModels, - useIPAdapterModels, - useLLaVAModels, - useLoRAModels, - useMainModels, - useRefinerModels, - useSigLipModels, - useSpandrelImageToImageModels, - useT2IAdapterModels, - useT5EncoderModels, - useVAEModels, -} from 'services/api/hooks/modelsByType'; +import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; import type { AnyModelConfig } from 'services/api/types'; import { FetchingModelsLoader } from './FetchingModelsLoader'; import { ModelListWrapper } from './ModelListWrapper'; +const log = logger('models'); + const ModelList = () => { const filteredModelType = useAppSelector(selectFilteredModelType); const searchTerm = useAppSelector(selectSearchTerm); const { t } = useTranslation(); - const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels(); - const filteredMainModels = useMemo( - () => modelsFilter(mainModels, searchTerm, filteredModelType), - [mainModels, searchTerm, filteredModelType] - ); - - const [refinerModels, { isLoading: isLoadingRefinerModels }] = useRefinerModels(); - const filteredRefinerModels = useMemo( - () => modelsFilter(refinerModels, searchTerm, filteredModelType), - [refinerModels, searchTerm, filteredModelType] - ); - - const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels(); - const filteredLoRAModels = useMemo( - () => modelsFilter(loraModels, searchTerm, filteredModelType), - [loraModels, searchTerm, filteredModelType] - ); - - const [embeddingModels, { isLoading: isLoadingEmbeddingModels }] = useEmbeddingModels(); - const filteredEmbeddingModels = useMemo( - () => modelsFilter(embeddingModels, searchTerm, filteredModelType), - [embeddingModels, searchTerm, filteredModelType] - ); - - const [controlNetModels, { isLoading: isLoadingControlNetModels }] = useControlNetModels(); - const filteredControlNetModels = useMemo( - () => modelsFilter(controlNetModels, searchTerm, filteredModelType), - [controlNetModels, searchTerm, filteredModelType] - ); - - const [t2iAdapterModels, { isLoading: isLoadingT2IAdapterModels }] = useT2IAdapterModels(); - const filteredT2IAdapterModels = useMemo( - () => modelsFilter(t2iAdapterModels, searchTerm, filteredModelType), - [t2iAdapterModels, searchTerm, filteredModelType] - ); - - const [ipAdapterModels, { isLoading: isLoadingIPAdapterModels }] = useIPAdapterModels(); - const filteredIPAdapterModels = useMemo( - () => modelsFilter(ipAdapterModels, searchTerm, filteredModelType), - [ipAdapterModels, searchTerm, filteredModelType] - ); - - const [clipVisionModels, { isLoading: isLoadingCLIPVisionModels }] = useCLIPVisionModels(); - const filteredCLIPVisionModels = useMemo( - () => modelsFilter(clipVisionModels, searchTerm, filteredModelType), - [clipVisionModels, searchTerm, filteredModelType] - ); - - const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels({ excludeSubmodels: true }); - const filteredVAEModels = useMemo( - () => modelsFilter(vaeModels, searchTerm, filteredModelType), - [vaeModels, searchTerm, filteredModelType] - ); - - const [t5EncoderModels, { isLoading: isLoadingT5EncoderModels }] = useT5EncoderModels({ excludeSubmodels: true }); - const filteredT5EncoderModels = useMemo( - () => modelsFilter(t5EncoderModels, searchTerm, filteredModelType), - [t5EncoderModels, searchTerm, filteredModelType] - ); - - const [controlLoRAModels, { isLoading: isLoadingControlLoRAModels }] = useControlLoRAModel(); - 
const filteredControlLoRAModels = useMemo( - () => modelsFilter(controlLoRAModels, searchTerm, filteredModelType), - [controlLoRAModels, searchTerm, filteredModelType] - ); - - const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels({ excludeSubmodels: true }); - const filteredClipEmbedModels = useMemo( - () => modelsFilter(clipEmbedModels, searchTerm, filteredModelType), - [clipEmbedModels, searchTerm, filteredModelType] - ); - - const [spandrelImageToImageModels, { isLoading: isLoadingSpandrelImageToImageModels }] = - useSpandrelImageToImageModels(); - const filteredSpandrelImageToImageModels = useMemo( - () => modelsFilter(spandrelImageToImageModels, searchTerm, filteredModelType), - [spandrelImageToImageModels, searchTerm, filteredModelType] - ); - - const [sigLipModels, { isLoading: isLoadingSigLipModels }] = useSigLipModels(); - const filteredSigLipModels = useMemo( - () => modelsFilter(sigLipModels, searchTerm, filteredModelType), - [sigLipModels, searchTerm, filteredModelType] - ); - - const [fluxReduxModels, { isLoading: isLoadingFluxReduxModels }] = useFluxReduxModels(); - const filteredFluxReduxModels = useMemo( - () => modelsFilter(fluxReduxModels, searchTerm, filteredModelType), - [fluxReduxModels, searchTerm, filteredModelType] - ); - - const [llavaOneVisionModels, { isLoading: isLoadingLlavaOneVisionModels }] = useLLaVAModels(); - const filteredLlavaOneVisionModels = useMemo( - () => modelsFilter(llavaOneVisionModels, searchTerm, filteredModelType), - [llavaOneVisionModels, searchTerm, filteredModelType] - ); - - const totalFilteredModels = useMemo(() => { - return ( - filteredMainModels.length + - filteredRefinerModels.length + - filteredLoRAModels.length + - filteredEmbeddingModels.length + - filteredControlNetModels.length + - filteredT2IAdapterModels.length + - filteredIPAdapterModels.length + - filteredCLIPVisionModels.length + - filteredVAEModels.length + - filteredSpandrelImageToImageModels.length + - filteredSigLipModels.length + - filteredFluxReduxModels.length + - t5EncoderModels.length + - clipEmbedModels.length + - controlLoRAModels.length - ); - }, [ - filteredControlNetModels.length, - filteredEmbeddingModels.length, - filteredIPAdapterModels.length, - filteredCLIPVisionModels.length, - filteredLoRAModels.length, - filteredMainModels.length, - filteredRefinerModels.length, - filteredT2IAdapterModels.length, - filteredVAEModels.length, - filteredSpandrelImageToImageModels.length, - filteredSigLipModels.length, - filteredFluxReduxModels.length, - t5EncoderModels.length, - clipEmbedModels.length, - controlLoRAModels.length, - ]); + const { data, isLoading } = useGetModelConfigsQuery(); + + const models = useMemo(() => { + const modelConfigs = modelConfigsAdapterSelectors.selectAll(data ?? 
{ ids: [], entities: {} }); + const baseFilteredModelConfigs = modelsFilter(modelConfigs, searchTerm, filteredModelType); + const byCategory: { i18nKey: string; configs: AnyModelConfig[] }[] = []; + const total = baseFilteredModelConfigs.length; + let renderedTotal = 0; + for (const { i18nKey, filter } of MODEL_CATEGORIES_AS_LIST) { + const configs = baseFilteredModelConfigs.filter(filter); + renderedTotal += configs.length; + byCategory.push({ i18nKey, configs }); + } + if (renderedTotal !== total) { + const ctx = { total, renderedTotal, difference: total - renderedTotal }; + log.warn( + ctx, + `ModelList: Not all models were categorized - ensure all possible models are covered in MODEL_CATEGORIES` + ); + } + return { total, byCategory }; + }, [data, filteredModelType, searchTerm]); return ( - {/* Main Model List */} - {isLoadingMainModels && } - {!isLoadingMainModels && filteredMainModels.length > 0 && ( - - )} - {/* Refiner Model List */} - {isLoadingRefinerModels && } - {!isLoadingRefinerModels && filteredRefinerModels.length > 0 && ( - - )} - {/* LoRAs List */} - {isLoadingLoRAModels && } - {!isLoadingLoRAModels && filteredLoRAModels.length > 0 && ( - - )} - - {/* TI List */} - {isLoadingEmbeddingModels && } - {!isLoadingEmbeddingModels && filteredEmbeddingModels.length > 0 && ( - - )} - - {/* VAE List */} - {isLoadingVAEModels && } - {!isLoadingVAEModels && filteredVAEModels.length > 0 && ( - - )} - - {/* Controlnet List */} - {isLoadingControlNetModels && } - {!isLoadingControlNetModels && filteredControlNetModels.length > 0 && ( - - )} - {/* IP Adapter List */} - {isLoadingIPAdapterModels && } - {!isLoadingIPAdapterModels && filteredIPAdapterModels.length > 0 && ( - - )} - {/* CLIP Vision List */} - {isLoadingCLIPVisionModels && } - {!isLoadingCLIPVisionModels && filteredCLIPVisionModels.length > 0 && ( - - )} - {/* T2I Adapters List */} - {isLoadingT2IAdapterModels && } - {!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && ( - - )} - {/* T5 Encoders List */} - {isLoadingT5EncoderModels && } - {!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && ( - - )} - {/* Control Lora List */} - {isLoadingControlLoRAModels && } - {!isLoadingControlLoRAModels && filteredControlLoRAModels.length > 0 && ( - - )} - {/* Clip Embed List */} - {isLoadingClipEmbedModels && } - {!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && ( - - )} - - {/* LLaVA OneVision List */} - {isLoadingLlavaOneVisionModels && } - {!isLoadingLlavaOneVisionModels && filteredLlavaOneVisionModels.length > 0 && ( - - )} - - {/* Spandrel Image to Image List */} - {isLoadingSpandrelImageToImageModels && ( - - )} - {!isLoadingSpandrelImageToImageModels && filteredSpandrelImageToImageModels.length > 0 && ( - - )} - {/* SigLIP List */} - {isLoadingSigLipModels && } - {!isLoadingSigLipModels && filteredSigLipModels.length > 0 && ( - - )} - {/* Flux Redux List */} - {isLoadingFluxReduxModels && } - {!isLoadingFluxReduxModels && filteredFluxReduxModels.length > 0 && ( - - )} - {totalFilteredModels === 0 && ( + {isLoading && } + {models.byCategory.map(({ i18nKey, configs }) => ( + + ))} + {!isLoading && models.total === 0 && ( {t('modelManager.noMatchingModels')} @@ -293,7 +71,13 @@ const modelsFilter = ( filteredModelType: FilterableModelType | null ): T[] => { return data.filter((model) => { - const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase()); + const matchesFilter = + model.name.toLowerCase().includes(nameFilter.toLowerCase()) || + 
model.base.toLowerCase().includes(nameFilter.toLowerCase()) || + model.type.toLowerCase().includes(nameFilter.toLowerCase()) || + model.description?.toLowerCase().includes(nameFilter.toLowerCase()) || + model.format.toLowerCase().includes(nameFilter.toLowerCase()); + const matchesType = getMatchesType(model, filteredModelType); return matchesFilter && matchesType; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListWrapper.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListWrapper.tsx index 08c3f4568d2..9783a88b062 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListWrapper.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListWrapper.tsx @@ -25,6 +25,9 @@ const contentSx = { export const ModelListWrapper = memo((props: ModelListWrapperProps) => { const { title, modelList } = props; + if (modelList.length === 0) { + return null; + } return ( {modelList.map((model) => ( diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx index 320397cfa0b..0ee479e86b2 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx @@ -1,46 +1,17 @@ import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; +import type { ModelCategoryData } from 'features/modelManagerV2/models'; +import { MODEL_CATEGORIES, MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models'; import { selectFilteredModelType, setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; -import { memo, useCallback, useMemo } from 'react'; +import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiFunnelBold } from 'react-icons/pi'; -import { objectKeys } from 'tsafe'; export const ModelTypeFilter = memo(() => { const { t } = useTranslation(); const dispatch = useAppDispatch(); - const MODEL_TYPE_LABELS: Record = useMemo( - () => ({ - main: t('modelManager.main'), - refiner: t('sdxl.refiner'), - lora: 'LoRA', - embedding: t('modelManager.textualInversions'), - controlnet: 'ControlNet', - vae: 'VAE', - t2i_adapter: t('common.t2iAdapter'), - t5_encoder: t('modelManager.t5Encoder'), - clip_embed: t('modelManager.clipEmbed'), - ip_adapter: t('common.ipAdapter'), - clip_vision: 'CLIP Vision', - spandrel_image_to_image: t('modelManager.spandrelImageToImage'), - control_lora: t('modelManager.controlLora'), - siglip: t('modelManager.sigLip'), - flux_redux: t('modelManager.fluxRedux'), - llava_onevision: t('modelManager.llavaOnevision'), - video: t('modelManager.video'), - }), - [t] - ); const filteredModelType = useAppSelector(selectFilteredModelType); - const selectModelType = useCallback( - (option: FilterableModelType) => { - dispatch(setFilteredModelType(option)); - }, - [dispatch] - ); - const clearModelType = useCallback(() => { dispatch(setFilteredModelType(null)); }, [dispatch]); @@ -48,18 +19,12 @@ export const ModelTypeFilter = memo(() => { return ( }> - {filteredModelType ? 
MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')} + {filteredModelType ? t(MODEL_CATEGORIES[filteredModelType].i18nKey) : t('modelManager.allModels')} {t('modelManager.allModels')} - {objectKeys(MODEL_TYPE_LABELS).map((option) => ( - - {MODEL_TYPE_LABELS[option]} - + {MODEL_CATEGORIES_AS_LIST.map((data) => ( + ))} @@ -67,3 +32,18 @@ export const ModelTypeFilter = memo(() => { }); ModelTypeFilter.displayName = 'ModelTypeFilter'; + +const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const filteredModelType = useAppSelector(selectFilteredModelType); + const onClick = useCallback(() => { + dispatch(setFilteredModelType(data.category)); + }, [data.category, dispatch]); + return ( + + {t(data.i18nKey)} + + ); +}); +ModelMenuItem.displayName = 'ModelMenuItem'; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx index a9159728717..8235d26efef 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx @@ -1,20 +1,17 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import { Combobox } from '@invoke-ai/ui-library'; import { typedMemo } from 'common/util/typedMemo'; -import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { MODEL_BASE_TO_LONG_NAME } from 'features/modelManagerV2/models'; import { useCallback, useMemo } from 'react'; import type { Control } from 'react-hook-form'; import { useController } from 'react-hook-form'; import type { UpdateModelArg } from 'services/api/endpoints/models'; +import { objectEntries } from 'tsafe'; -const options: ComboboxOption[] = [ - { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] }, - { value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] }, - { value: 'sd-3', label: MODEL_TYPE_MAP['sd-3'] }, - { value: 'flux', label: MODEL_TYPE_MAP['flux'] }, - { value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] }, - { value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] }, -]; +const options: ComboboxOption[] = objectEntries(MODEL_BASE_TO_LONG_NAME).map(([value, label]) => ({ + label, + value, +})); type Props = { control: Control; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect.tsx new file mode 100644 index 00000000000..1057ab7784c --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect.tsx @@ -0,0 +1,32 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { Combobox } from '@invoke-ai/ui-library'; +import { typedMemo } from 'common/util/typedMemo'; +import { MODEL_FORMAT_TO_LONG_NAME } from 'features/modelManagerV2/models'; +import { useCallback, useMemo } from 'react'; +import type { Control } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import type { UpdateModelArg } from 'services/api/endpoints/models'; +import { objectEntries } from 'tsafe'; + +const options: ComboboxOption[] = objectEntries(MODEL_FORMAT_TO_LONG_NAME).map(([value, label]) => ({ + label, + value, +})); + +type Props = { + control: Control; +}; + 
+const ModelFormatSelect = ({ control }: Props) => { + const { field } = useController({ control, name: 'format' }); + const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]); + const onChange = useCallback<ComboboxOnChange>( + (v) => { + field.onChange(v?.value); + }, + [field] + ); + return <Combobox value={value} options={options} onChange={onChange} />; +}; + +export default typedMemo(ModelFormatSelect); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx new file mode 100644 index 00000000000..44b41f01518 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx @@ -0,0 +1,32 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { Combobox } from '@invoke-ai/ui-library'; +import { typedMemo } from 'common/util/typedMemo'; +import { MODEL_TYPE_TO_LONG_NAME } from 'features/modelManagerV2/models'; +import { useCallback, useMemo } from 'react'; +import type { Control } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import type { UpdateModelArg } from 'services/api/endpoints/models'; +import { objectEntries } from 'tsafe'; + +const options: ComboboxOption[] = objectEntries(MODEL_TYPE_TO_LONG_NAME).map(([value, label]) => ({ + label, + value, +})); + +type Props = { + control: Control; +}; + +const ModelTypeSelect = ({ control }: Props) => { + const { field } = useController({ control, name: 'type' }); + const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]); + const onChange = useCallback<ComboboxOnChange>( + (v) => { + field.onChange(v?.value); + }, + [field] + ); + return <Combobox value={value} options={options} onChange={onChange} />; +}; + +export default typedMemo(ModelTypeSelect); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx index 6686cc43368..52eb2a4749d 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx @@ -1,16 +1,14 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import { Combobox } from '@invoke-ai/ui-library'; import { typedMemo } from 'common/util/typedMemo'; +import { MODEL_VARIANT_TO_LONG_NAME } from 'features/modelManagerV2/models'; import { useCallback, useMemo } from 'react'; import type { Control } from 'react-hook-form'; import { useController } from 'react-hook-form'; import type { UpdateModelArg } from 'services/api/endpoints/models'; +import { objectEntries } from 'tsafe'; -const options: ComboboxOption[] = [ - { value: 'normal', label: 'Normal' }, - { value: 'inpaint', label: 'Inpaint' }, - { value: 'depth', label: 'Depth' }, -]; +const options: ComboboxOption[] = objectEntries(MODEL_VARIANT_TO_LONG_NAME).map(([value, label]) => ({ label, value })); type Props = { control: Control; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx index ff5c680325c..d845eca3eec 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx @@ -8,6 +8,7 @@ import { Heading, 
Input, SimpleGrid, + Text, Textarea, } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; @@ -22,6 +23,8 @@ import { type UpdateModelArg, useUpdateModelMutation } from 'services/api/endpoi import type { AnyModelConfig } from 'services/api/types'; import BaseModelSelect from './Fields/BaseModelSelect'; +import ModelFormatSelect from './Fields/ModelFormatSelect'; +import ModelTypeSelect from './Fields/ModelTypeSelect'; import ModelVariantSelect from './Fields/ModelVariantSelect'; import PredictionTypeSelect from './Fields/PredictionTypeSelect'; import { ModelFooter } from './ModelFooter'; @@ -121,40 +124,41 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {