Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e70bedb
refactor(ModelManager): factor out _get_implementation method
keturn Jul 29, 2023
dca685a
refactor(ModelManager): refactor rescan-on-miss to exists() method
keturn Jul 29, 2023
b163ae6
refactor(ModelManager): factor out get_model_config
keturn Jul 29, 2023
bc9a503
refactor(ModelManager): factor out get_model_path
keturn Jul 29, 2023
86b8b69
internal(ModelManager): add instantiate method
keturn Jul 29, 2023
21617e6
Merge remote-tracking branch 'origin/main' into refactor/model_manage…
keturn Jul 29, 2023
ccceb32
lint: formatting
keturn Jul 29, 2023
dbfd1bc
Merge branch 'main' into refactor/model_manager_instantiate
keturn Jul 30, 2023
ff1c407
lint: formatting
keturn Jul 30, 2023
0e48c98
Merge remote-tracking branch 'origin/main' into refactor/model_manage…
keturn Jul 30, 2023
adfd1e5
refactor(model_manager): avoid copy/paste logic
keturn Jul 30, 2023
e351905
Merge remote-tracking branch 'origin/main' into refactor/model_manage…
keturn Jul 31, 2023
bacdf98
doc(model_manager): docstrings
keturn Jul 31, 2023
5998509
Merge branch 'main' into refactor/model_manager_instantiate
keturn Aug 1, 2023
1f9e984
Merge branch 'main' into refactor/model_manager_instantiate
keturn Aug 1, 2023
02d2cc7
Merge branch 'main' into refactor/model_manager_instantiate
keturn Aug 3, 2023
91ebf9f
Merge branch 'main' into refactor/model_manager_instantiate
keturn Aug 3, 2023
b10cf20
Merge branch 'main' into refactor/model_manager_instantiate
keturn Aug 5, 2023
65ed224
Merge branch 'main' into refactor/model_manager_instantiate
keturn Aug 5, 2023
44bf308
test(model_management): add a couple tests for _get_model_path
keturn Aug 5, 2023
7a4ff4c
Merge branch 'main' into refactor/model_manager_instantiate
keturn Aug 5, 2023
80876bb
Merge remote-tracking branch 'origin/refactor/model_manager_instantia…
keturn Aug 5, 2023
7f4c387
test(model_management): factor out name strings
keturn Aug 5, 2023
5bfd6cb
Merge remote-tracking branch 'origin/main' into refactor/model_manage…
keturn Aug 6, 2023
f272a44
Merge branch 'main' into refactor/model_manager_instantiate
keturn Aug 7, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 88 additions & 39 deletions invokeai/backend/model_management/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,19 +228,19 @@ class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are n
"""
from __future__ import annotations

import os
import hashlib
import os
import textwrap
import yaml
import types
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree, move
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable

import torch
import yaml
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig

from pydantic import BaseModel, Field

import invokeai.backend.util.logging as logger
Expand All @@ -259,6 +259,7 @@ class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are n
ModelNotFoundException,
InvalidModelException,
DuplicateModelException,
ModelBase,
)

# We are only starting to number the config file with release 3.
Expand Down Expand Up @@ -361,7 +362,7 @@ def _read_models(self, config: Optional[DictConfig] = None):
if model_key.startswith("_"):
continue
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
Expand All @@ -381,18 +382,24 @@ def sync_to_config(self):
# causing otherwise unreferenced models to be removed from memory
self._read_models()

def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
Given a model name, returns True if it is a valid identifier.

:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param rescan: if True, scan_models_directory
"""
model_key = self.create_key(model_name, base_model, model_type)
return model_key in self.models
exists = model_key in self.models

# if model not found try to find it (maybe file just pasted)
if rescan and not exists:
self.scan_models_directory(base_model=base_model, model_type=model_type)
exists = self.model_exists(model_name, base_model, model_type, rescan=False)

return exists

@classmethod
def create_key(
Expand Down Expand Up @@ -443,39 +450,32 @@ def get_model(
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param submode_typel: an ModelType enum indicating the portion of
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
model_class = MODEL_CLASSES[base_model][model_type]
model_key = self.create_key(model_name, base_model, model_type)

# if model not found try to find it (maybe file just pasted)
if model_key not in self.models:
self.scan_models_directory(base_model=base_model, model_type=model_type)
if model_key not in self.models:
raise ModelNotFoundException(f"Model not found - {model_key}")
if not self.model_exists(model_name, base_model, model_type, rescan=True):
raise ModelNotFoundException(f"Model not found - {model_key}")

model_config = self._get_model_config(base_model, model_name, model_type)

model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)

model_config = self.models[model_key]
model_path = self.resolve_model_path(model_config.path)
if is_submodel_override:
model_type = submodel_type
submodel_type = None

model_class = self._get_implementation(base_model, model_type)

if not model_path.exists():
if model_class.save_to_config:
self.models[model_key].error = ModelError.NotFound
raise Exception(f'Files for model "{model_key}" not found')
raise Exception(f'Files for model "{model_key}" not found at {model_path}')

else:
self.models.pop(model_key, None)
raise ModelNotFoundException(f"Model not found - {model_key}")

# vae/movq override
# TODO:
if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type)
if override_path:
model_path = self.resolve_path(override_path)
model_type = submodel_type
submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type]
raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')

# TODO: path
# TODO: is it accurate to use path as id
Expand Down Expand Up @@ -513,6 +513,55 @@ def get_model(
_cache=self.cache,
)

def _get_model_path(
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
) -> (Path, bool):
"""Extract a model's filesystem path from its config.

:return: The fully qualified Path of the module (or submodule).
"""
model_path = model_config.path
is_submodel_override = False

# Does the config explicitly override the submodel?
if submodel_type is not None and hasattr(model_config, submodel_type):
submodel_path = getattr(model_config, submodel_type)
if submodel_path is not None:
model_path = getattr(model_config, submodel_type)
is_submodel_override = True

model_path = self.resolve_model_path(model_path)
return model_path, is_submodel_override

def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
"""Get a model's config object."""
model_key = self.create_key(model_name, base_model, model_type)
try:
model_config = self.models[model_key]
except KeyError:
raise ModelNotFoundException(f"Model not found - {model_key}")
return model_config

def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
"""Get the concrete implementation class for a specific model type."""
model_class = MODEL_CLASSES[base_model][model_type]
return model_class

def _instantiate(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
) -> ModelBase:
"""Make a new instance of this model, without loading it."""
model_config = self._get_model_config(base_model, model_name, model_type)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
# FIXME: do non-overriden submodels get the right class?
constructor = self._get_implementation(base_model, model_type)
instance = constructor(model_path, base_model, model_type)
return instance

def model_info(
self,
model_name: str,
Expand Down Expand Up @@ -660,7 +709,7 @@ def add_model(
if path := model_attributes.get("path"):
model_attributes["path"] = str(self.relative_model_path(Path(path)))

model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type)

Expand Down Expand Up @@ -851,7 +900,7 @@ def commit(self, conf_file: Optional[Path] = None) -> None:

for model_key, model_config in self.models.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
Expand Down Expand Up @@ -909,7 +958,7 @@ def scan_models_directory(

model_path = self.resolve_model_path(model_config.path).absolute()
if not model_path.exists():
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
model_class = self._get_implementation(cur_base_model, cur_model_type)
if model_class.save_to_config:
model_config.error = ModelError.NotFound
self.models.pop(model_key, None)
Expand All @@ -925,7 +974,7 @@ def scan_models_directory(
for cur_model_type in ModelType:
if model_type is not None and cur_model_type != model_type:
continue
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
model_class = self._get_implementation(cur_base_model, cur_model_type)
models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))

if not models_dir.exists():
Expand Down
16 changes: 9 additions & 7 deletions invokeai/backend/model_management/models/vae.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import os
import torch
import safetensors
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from typing import Optional

import safetensors
import torch
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf

from invokeai.app.services.config import InvokeAIAppConfig
from .base import (
ModelBase,
ModelConfigBase,
Expand All @@ -18,9 +23,6 @@
InvalidModelException,
ModelNotFoundException,
)
from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf


class VaeModelFormat(str, Enum):
Expand Down Expand Up @@ -80,7 +82,7 @@ def save_to_config(cls) -> bool:
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
raise ModelNotFoundException(f"Does not exist as local file: {path}")

if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ dependencies = [
"dev" = [
"pudb",
]
"test" = ["pytest>6.0.0", "pytest-cov", "black"]
"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"]
"xformers" = [
"xformers~=0.0.19; sys_platform!='darwin'",
"triton; sys_platform=='linux'",
Expand Down
38 changes: 38 additions & 0 deletions tests/test_model_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from pathlib import Path

import pytest

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType

BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)


@pytest.fixture
def model_manager(datadir) -> ModelManager:
InvokeAIAppConfig.get_config(root=datadir)
return ModelManager(datadir / "configs" / "relative_sub.models.yaml")


def test_get_model_names(model_manager: ModelManager):
names = model_manager.model_names()
assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]


def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2])
top_model_path, is_override = model_manager._get_model_path(model_config)
expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
assert top_model_path == expected_model_path
assert not is_override


def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(
VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2]
)
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
assert vae_model_path == expected_vae_path
assert is_override
15 changes: 15 additions & 0 deletions tests/test_model_manager/configs/relative_sub.models.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
__metadata__:
version: 3.0.0

sdxl/main/SDXL base:
path: sdxl/main/SDXL base 1_0
description: SDXL base v1.0
variant: normal
format: diffusers

sdxl/main/SDXL with VAE:
path: sdxl/main/SDXL base 1_0
description: SDXL with customized VAE
vae: sdxl/vae/sdxl-vae-fp16-fix/
variant: normal
format: diffusers
Empty file.
Empty file.