From a72a423a5ee9dd8e70a2e622d725bc5259fce073 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 15:23:06 +1100 Subject: [PATCH 001/340] tidy(nodes): move all field things to fields.py Unfortunately, this is necessary to prevent circular imports at runtime. --- invokeai/app/api/routers/images.py | 2 +- invokeai/app/api_app.py | 3 +- invokeai/app/invocations/baseinvocation.py | 453 +--------------- invokeai/app/invocations/collections.py | 3 +- invokeai/app/invocations/compel.py | 6 +- .../controlnet_image_processors.py | 6 +- invokeai/app/invocations/cv.py | 3 +- invokeai/app/invocations/facetools.py | 4 +- invokeai/app/invocations/fields.py | 501 ++++++++++++++++++ invokeai/app/invocations/image.py | 5 +- invokeai/app/invocations/infill.py | 3 +- invokeai/app/invocations/ip_adapter.py | 5 +- invokeai/app/invocations/latent.py | 7 +- invokeai/app/invocations/math.py | 4 +- invokeai/app/invocations/metadata.py | 6 +- invokeai/app/invocations/model.py | 5 +- invokeai/app/invocations/noise.py | 4 +- invokeai/app/invocations/onnx.py | 16 +- invokeai/app/invocations/param_easing.py | 3 +- invokeai/app/invocations/primitives.py | 6 +- invokeai/app/invocations/prompt.py | 3 +- invokeai/app/invocations/sdxl.py | 6 +- invokeai/app/invocations/strings.py | 4 +- invokeai/app/invocations/t2i_adapter.py | 5 +- invokeai/app/invocations/tiles.py | 5 +- invokeai/app/invocations/upscale.py | 3 +- .../services/image_files/image_files_base.py | 2 +- .../services/image_files/image_files_disk.py | 2 +- .../image_records/image_records_base.py | 2 +- .../image_records/image_records_sqlite.py | 2 +- invokeai/app/services/images/images_base.py | 2 +- .../app/services/images/images_default.py | 2 +- invokeai/app/services/shared/graph.py | 5 +- invokeai/app/shared/fields.py | 67 --- invokeai/app/shared/models.py | 2 +- tests/aa_nodes/test_nodes.py | 3 +- 36 files changed, 552 insertions(+), 608 deletions(-) create mode 100644 invokeai/app/invocations/fields.py delete mode 100644 invokeai/app/shared/fields.py diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 125896b8d3a..cc60ad1be83 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -8,7 +8,7 @@ from PIL import Image from pydantic import BaseModel, Field, ValidationError -from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator +from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.shared.pagination import OffsetPaginatedResults diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 6294083d0e1..f48074de7c7 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -6,6 +6,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.version.invokeai_version import __version__ +from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra from .services.config import InvokeAIAppConfig app_config = InvokeAIAppConfig.get_config() @@ -57,8 +58,6 @@ from .api.sockets import SocketIO from .invocations.baseinvocation import ( BaseInvocation, - InputFieldJSONSchemaExtra, - OutputFieldJSONSchemaExtra, UIConfigBase, ) diff --git a/invokeai/app/invocations/baseinvocation.py 
b/invokeai/app/invocations/baseinvocation.py index d9e0c7ba0d2..395d5e98707 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -12,10 +12,11 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast import semver -from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model -from pydantic.fields import FieldInfo, _Unset +from pydantic import BaseModel, ConfigDict, Field, create_model +from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined +from invokeai.app.invocations.fields import FieldKind, Input from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.shared.fields import FieldDescriptions @@ -52,393 +53,6 @@ class Classification(str, Enum, metaclass=MetaEnum): Prototype = "prototype" -class Input(str, Enum, metaclass=MetaEnum): - """ - The type of input a field accepts. - - `Input.Direct`: The field must have its value provided directly, when the invocation and field \ - are instantiated. - - `Input.Connection`: The field must have its value provided by a connection. - - `Input.Any`: The field may have its value provided either directly or by a connection. - """ - - Connection = "connection" - Direct = "direct" - Any = "any" - - -class FieldKind(str, Enum, metaclass=MetaEnum): - """ - The kind of field. - - `Input`: An input field on a node. - - `Output`: An output field on a node. - - `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is - one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name - "metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic, - allowing "metadata" for that field. - - `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs, - but which are used to store information about the node. For example, the `id` and `type` fields are node - attributes. - - The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app - startup, and when generating the OpenAPI schema for the workflow editor. - """ - - Input = "input" - Output = "output" - Internal = "internal" - NodeAttribute = "node_attribute" - - -class UIType(str, Enum, metaclass=MetaEnum): - """ - Type hints for the UI for situations in which the field type is not enough to infer the correct UI type. - - - Model Fields - The most common node-author-facing use will be for model fields. Internally, there is no difference - between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the - base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that - the field is an SDXL main model field. - - - Any Field - We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to - indicate that the field accepts any type. Use with caution. This cannot be used on outputs. - - - Scheduler Field - Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field. - - - Internal Fields - Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. 
To facilitate - handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These - should not be used by node authors. - - - DEPRECATED Fields - These types are deprecated and should not be used by node authors. A warning will be logged if one is - used, and the type will be ignored. They are included here for backwards compatibility. - """ - - # region Model Field Types - SDXLMainModel = "SDXLMainModelField" - SDXLRefinerModel = "SDXLRefinerModelField" - ONNXModel = "ONNXModelField" - VaeModel = "VAEModelField" - LoRAModel = "LoRAModelField" - ControlNetModel = "ControlNetModelField" - IPAdapterModel = "IPAdapterModelField" - # endregion - - # region Misc Field Types - Scheduler = "SchedulerField" - Any = "AnyField" - # endregion - - # region Internal Field Types - _Collection = "CollectionField" - _CollectionItem = "CollectionItemField" - # endregion - - # region DEPRECATED - Boolean = "DEPRECATED_Boolean" - Color = "DEPRECATED_Color" - Conditioning = "DEPRECATED_Conditioning" - Control = "DEPRECATED_Control" - Float = "DEPRECATED_Float" - Image = "DEPRECATED_Image" - Integer = "DEPRECATED_Integer" - Latents = "DEPRECATED_Latents" - String = "DEPRECATED_String" - BooleanCollection = "DEPRECATED_BooleanCollection" - ColorCollection = "DEPRECATED_ColorCollection" - ConditioningCollection = "DEPRECATED_ConditioningCollection" - ControlCollection = "DEPRECATED_ControlCollection" - FloatCollection = "DEPRECATED_FloatCollection" - ImageCollection = "DEPRECATED_ImageCollection" - IntegerCollection = "DEPRECATED_IntegerCollection" - LatentsCollection = "DEPRECATED_LatentsCollection" - StringCollection = "DEPRECATED_StringCollection" - BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic" - ColorPolymorphic = "DEPRECATED_ColorPolymorphic" - ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic" - ControlPolymorphic = "DEPRECATED_ControlPolymorphic" - FloatPolymorphic = "DEPRECATED_FloatPolymorphic" - ImagePolymorphic = "DEPRECATED_ImagePolymorphic" - IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic" - LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic" - StringPolymorphic = "DEPRECATED_StringPolymorphic" - MainModel = "DEPRECATED_MainModel" - UNet = "DEPRECATED_UNet" - Vae = "DEPRECATED_Vae" - CLIP = "DEPRECATED_CLIP" - Collection = "DEPRECATED_Collection" - CollectionItem = "DEPRECATED_CollectionItem" - Enum = "DEPRECATED_Enum" - WorkflowField = "DEPRECATED_WorkflowField" - IsIntermediate = "DEPRECATED_IsIntermediate" - BoardField = "DEPRECATED_BoardField" - MetadataItem = "DEPRECATED_MetadataItem" - MetadataItemCollection = "DEPRECATED_MetadataItemCollection" - MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic" - MetadataDict = "DEPRECATED_MetadataDict" - # endregion - - -class UIComponent(str, Enum, metaclass=MetaEnum): - """ - The type of UI component to use for a field, used to override the default components, which are - inferred from the field type. - """ - - None_ = "none" - Textarea = "textarea" - Slider = "slider" - - -class InputFieldJSONSchemaExtra(BaseModel): - """ - Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution, - and by the workflow editor during schema parsing and UI rendering. 
- """ - - input: Input - orig_required: bool - field_kind: FieldKind - default: Optional[Any] = None - orig_default: Optional[Any] = None - ui_hidden: bool = False - ui_type: Optional[UIType] = None - ui_component: Optional[UIComponent] = None - ui_order: Optional[int] = None - ui_choice_labels: Optional[dict[str, str]] = None - - model_config = ConfigDict( - validate_assignment=True, - json_schema_serialization_defaults_required=True, - ) - - -class OutputFieldJSONSchemaExtra(BaseModel): - """ - Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor - during schema parsing and UI rendering. - """ - - field_kind: FieldKind - ui_hidden: bool - ui_type: Optional[UIType] - ui_order: Optional[int] - - model_config = ConfigDict( - validate_assignment=True, - json_schema_serialization_defaults_required=True, - ) - - -def InputField( - # copied from pydantic's Field - # TODO: Can we support default_factory? - default: Any = _Unset, - default_factory: Callable[[], Any] | None = _Unset, - title: str | None = _Unset, - description: str | None = _Unset, - pattern: str | None = _Unset, - strict: bool | None = _Unset, - gt: float | None = _Unset, - ge: float | None = _Unset, - lt: float | None = _Unset, - le: float | None = _Unset, - multiple_of: float | None = _Unset, - allow_inf_nan: bool | None = _Unset, - max_digits: int | None = _Unset, - decimal_places: int | None = _Unset, - min_length: int | None = _Unset, - max_length: int | None = _Unset, - # custom - input: Input = Input.Any, - ui_type: Optional[UIType] = None, - ui_component: Optional[UIComponent] = None, - ui_hidden: bool = False, - ui_order: Optional[int] = None, - ui_choice_labels: Optional[dict[str, str]] = None, -) -> Any: - """ - Creates an input field for an invocation. - - This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \ - that adds a few extra parameters to support graph execution and the node editor UI. - - :param Input input: [Input.Any] The kind of input this field requires. \ - `Input.Direct` means a value must be provided on instantiation. \ - `Input.Connection` means the value must be provided by a connection. \ - `Input.Any` means either will do. - - :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ - In some situations, the field's type is not enough to infer the correct UI type. \ - For example, model selection fields should render a dropdown UI component to select a model. \ - Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ - `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ - `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. - - :param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \ - The UI will always render a suitable component, but sometimes you want something different than the default. \ - For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \ - For this case, you could provide `UIComponent.Textarea`. - - :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. - - :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. - - :param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field. 
- """ - - json_schema_extra_ = InputFieldJSONSchemaExtra( - input=input, - ui_type=ui_type, - ui_component=ui_component, - ui_hidden=ui_hidden, - ui_order=ui_order, - ui_choice_labels=ui_choice_labels, - field_kind=FieldKind.Input, - orig_required=True, - ) - - """ - There is a conflict between the typing of invocation definitions and the typing of an invocation's - `invoke()` function. - - On instantiation of a node, the invocation definition is used to create the python class. At this time, - any number of fields may be optional, because they may be provided by connections. - - On calling of `invoke()`, however, those fields may be required. - - For example, consider an ResizeImageInvocation with an `image: ImageField` field. - - `image` is required during the call to `invoke()`, but when the python class is instantiated, - the field may not be present. This is fine, because that image field will be provided by a - connection from an ancestor node, which outputs an image. - - This means we want to type the `image` field as optional for the node class definition, but required - for the `invoke()` function. - - If we use `typing.Optional` in the node class definition, the field will be typed as optional in the - `invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or - any static type analysis tools will complain. - - To get around this, in node class definitions, we type all fields correctly for the `invoke()` function, - but secretly make them optional in `InputField()`. We also store the original required bool and/or default - value. When we call `invoke()`, we use this stored information to do an additional check on the class. - """ - - if default_factory is not _Unset and default_factory is not None: - default = default_factory() - logger.warn('"default_factory" is not supported, calling it now to set "default"') - - # These are the args we may wish pass to the pydantic `Field()` function - field_args = { - "default": default, - "title": title, - "description": description, - "pattern": pattern, - "strict": strict, - "gt": gt, - "ge": ge, - "lt": lt, - "le": le, - "multiple_of": multiple_of, - "allow_inf_nan": allow_inf_nan, - "max_digits": max_digits, - "decimal_places": decimal_places, - "min_length": min_length, - "max_length": max_length, - } - - # We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected - provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined} - - # Because we are manually making fields optional, we need to store the original required bool for reference later - json_schema_extra_.orig_required = default is PydanticUndefined - - # Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one - if input is Input.Any or input is Input.Connection: - default_ = None if default is PydanticUndefined else default - provided_args.update({"default": default_}) - if default is not PydanticUndefined: - # Before invoking, we'll check for the original default value and set it on the field if the field has no value - json_schema_extra_.default = default - json_schema_extra_.orig_default = default - elif default is not PydanticUndefined: - default_ = default - provided_args.update({"default": default_}) - json_schema_extra_.orig_default = default_ - - return Field( - **provided_args, - json_schema_extra=json_schema_extra_.model_dump(exclude_none=True), - ) - - -def OutputField( - # copied from 
pydantic's Field - default: Any = _Unset, - title: str | None = _Unset, - description: str | None = _Unset, - pattern: str | None = _Unset, - strict: bool | None = _Unset, - gt: float | None = _Unset, - ge: float | None = _Unset, - lt: float | None = _Unset, - le: float | None = _Unset, - multiple_of: float | None = _Unset, - allow_inf_nan: bool | None = _Unset, - max_digits: int | None = _Unset, - decimal_places: int | None = _Unset, - min_length: int | None = _Unset, - max_length: int | None = _Unset, - # custom - ui_type: Optional[UIType] = None, - ui_hidden: bool = False, - ui_order: Optional[int] = None, -) -> Any: - """ - Creates an output field for an invocation output. - - This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ - that adds a few extra parameters to support graph execution and the node editor UI. - - :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ - In some situations, the field's type is not enough to infer the correct UI type. \ - For example, model selection fields should render a dropdown UI component to select a model. \ - Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ - `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ - `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. - - :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ - - :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ - """ - return Field( - default=default, - title=title, - description=description, - pattern=pattern, - strict=strict, - gt=gt, - ge=ge, - lt=lt, - le=le, - multiple_of=multiple_of, - allow_inf_nan=allow_inf_nan, - max_digits=max_digits, - decimal_places=decimal_places, - min_length=min_length, - max_length=max_length, - json_schema_extra=OutputFieldJSONSchemaExtra( - ui_type=ui_type, - ui_hidden=ui_hidden, - ui_order=ui_order, - field_kind=FieldKind.Output, - ).model_dump(exclude_none=True), - ) - - class UIConfigBase(BaseModel): """ Provides additional node configuration to the UI. @@ -460,33 +74,6 @@ class UIConfigBase(BaseModel): ) -class InvocationContext: - """Initialized and provided to on execution of invocations.""" - - services: InvocationServices - graph_execution_state_id: str - queue_id: str - queue_item_id: int - queue_batch_id: str - workflow: Optional[WorkflowWithoutID] - - def __init__( - self, - services: InvocationServices, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - workflow: Optional[WorkflowWithoutID], - ): - self.services = services - self.graph_execution_state_id = graph_execution_state_id - self.queue_id = queue_id - self.queue_item_id = queue_item_id - self.queue_batch_id = queue_batch_id - self.workflow = workflow - - class BaseInvocationOutput(BaseModel): """ Base class for all invocation outputs. @@ -926,37 +513,3 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]: return cls return wrapper - - -class MetadataField(RootModel): - """ - Pydantic model for metadata with custom root of type dict[str, Any]. - Metadata is stored without a strict schema. 
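# --- Illustrative sketch (not part of this patch) ---
# MetadataField is a RootModel wrapping an arbitrary dict, and
# MetadataFieldValidator is a pydantic TypeAdapter for it. A hypothetical
# round trip (example values are assumptions; after this patch the import
# comes from invokeai.app.invocations.fields):
from invokeai.app.invocations.fields import MetadataFieldValidator

raw = {"seed": 123, "positive_prompt": "a photo of a cat"}
metadata = MetadataFieldValidator.validate_python(raw)  # -> MetadataField instance
assert metadata.root["seed"] == 123
assert metadata.model_dump() == raw  # no strict schema is enforced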
- """ - - root: dict[str, Any] = Field(description="The metadata") - - -MetadataFieldValidator = TypeAdapter(MetadataField) - - -class WithMetadata(BaseModel): - metadata: Optional[MetadataField] = Field( - default=None, - description=FieldDescriptions.metadata, - json_schema_extra=InputFieldJSONSchemaExtra( - field_kind=FieldKind.Internal, - input=Input.Connection, - orig_required=False, - ).model_dump(exclude_none=True), - ) - - -class WithWorkflow: - workflow = None - - def __init_subclass__(cls) -> None: - logger.warn( - f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow." - ) - super().__init_subclass__() diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index 4c7b6f94cd4..d35a9d79c74 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -7,7 +7,8 @@ from invokeai.app.invocations.primitives import IntegerCollectionOutput from invokeai.app.util.misc import SEED_MAX -from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField @invocation( diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 49c62cff564..b386aef2cbe 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -5,8 +5,8 @@ from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ExtraConditioningInfo, @@ -20,11 +20,7 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIComponent, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 1f9342985a0..9b652b8eee9 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,10 +25,10 @@ from PIL import Image from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector @@ -36,11 +36,7 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - WithMetadata, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index cb6828d21ac..5865338e192 100644 --- a/invokeai/app/invocations/cv.py +++ 
b/invokeai/app/invocations/cv.py @@ -8,7 +8,8 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField, WithMetadata @invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0") diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index e0c89b4de5a..13f1066ec3e 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -13,13 +13,11 @@ import invokeai.assets.fonts as font_assets from invokeai.app.invocations.baseinvocation import ( BaseInvocation, - InputField, InvocationContext, - OutputField, - WithMetadata, invocation, invocation_output, ) +from invokeai.app.invocations.fields import InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py new file mode 100644 index 00000000000..0cce8e3c6b5 --- /dev/null +++ b/invokeai/app/invocations/fields.py @@ -0,0 +1,501 @@ +from enum import Enum +from typing import Any, Callable, Optional + +from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter +from pydantic.fields import _Unset +from pydantic_core import PydanticUndefined + +from invokeai.app.util.metaenum import MetaEnum +from invokeai.backend.util.logging import InvokeAILogger + +logger = InvokeAILogger.get_logger() + + +class UIType(str, Enum, metaclass=MetaEnum): + """ + Type hints for the UI for situations in which the field type is not enough to infer the correct UI type. + + - Model Fields + The most common node-author-facing use will be for model fields. Internally, there is no difference + between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the + base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that + the field is an SDXL main model field. + + - Any Field + We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to + indicate that the field accepts any type. Use with caution. This cannot be used on outputs. + + - Scheduler Field + Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field. + + - Internal Fields + Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate + handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These + should not be used by node authors. + + - DEPRECATED Fields + These types are deprecated and should not be used by node authors. A warning will be logged if one is + used, and the type will be ignored. They are included here for backwards compatibility. 
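# --- Illustrative sketch (not part of this patch) ---
# SD-1, SD-2 and SDXL main models all share the MainModelField type, so a node
# author passes ui_type to make the workflow editor render the SDXL-specific
# model picker. Note the enum *member* is UIType.SDXLMainModel; its *value* is
# the "SDXLMainModelField" string referred to above.
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIType

model_field = InputField(
    description=FieldDescriptions.sdxl_main_model,
    input=Input.Direct,
    ui_type=UIType.SDXLMainModel,  # UI hint only; does not affect validation
)
# InputField() returns a pydantic FieldInfo; the UI hints live in json_schema_extra.
assert model_field.json_schema_extra["input"] == "direct"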
+ """ + + # region Model Field Types + SDXLMainModel = "SDXLMainModelField" + SDXLRefinerModel = "SDXLRefinerModelField" + ONNXModel = "ONNXModelField" + VaeModel = "VAEModelField" + LoRAModel = "LoRAModelField" + ControlNetModel = "ControlNetModelField" + IPAdapterModel = "IPAdapterModelField" + # endregion + + # region Misc Field Types + Scheduler = "SchedulerField" + Any = "AnyField" + # endregion + + # region Internal Field Types + _Collection = "CollectionField" + _CollectionItem = "CollectionItemField" + # endregion + + # region DEPRECATED + Boolean = "DEPRECATED_Boolean" + Color = "DEPRECATED_Color" + Conditioning = "DEPRECATED_Conditioning" + Control = "DEPRECATED_Control" + Float = "DEPRECATED_Float" + Image = "DEPRECATED_Image" + Integer = "DEPRECATED_Integer" + Latents = "DEPRECATED_Latents" + String = "DEPRECATED_String" + BooleanCollection = "DEPRECATED_BooleanCollection" + ColorCollection = "DEPRECATED_ColorCollection" + ConditioningCollection = "DEPRECATED_ConditioningCollection" + ControlCollection = "DEPRECATED_ControlCollection" + FloatCollection = "DEPRECATED_FloatCollection" + ImageCollection = "DEPRECATED_ImageCollection" + IntegerCollection = "DEPRECATED_IntegerCollection" + LatentsCollection = "DEPRECATED_LatentsCollection" + StringCollection = "DEPRECATED_StringCollection" + BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic" + ColorPolymorphic = "DEPRECATED_ColorPolymorphic" + ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic" + ControlPolymorphic = "DEPRECATED_ControlPolymorphic" + FloatPolymorphic = "DEPRECATED_FloatPolymorphic" + ImagePolymorphic = "DEPRECATED_ImagePolymorphic" + IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic" + LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic" + StringPolymorphic = "DEPRECATED_StringPolymorphic" + MainModel = "DEPRECATED_MainModel" + UNet = "DEPRECATED_UNet" + Vae = "DEPRECATED_Vae" + CLIP = "DEPRECATED_CLIP" + Collection = "DEPRECATED_Collection" + CollectionItem = "DEPRECATED_CollectionItem" + Enum = "DEPRECATED_Enum" + WorkflowField = "DEPRECATED_WorkflowField" + IsIntermediate = "DEPRECATED_IsIntermediate" + BoardField = "DEPRECATED_BoardField" + MetadataItem = "DEPRECATED_MetadataItem" + MetadataItemCollection = "DEPRECATED_MetadataItemCollection" + MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic" + MetadataDict = "DEPRECATED_MetadataDict" + + +class UIComponent(str, Enum, metaclass=MetaEnum): + """ + The type of UI component to use for a field, used to override the default components, which are + inferred from the field type. 
+ """ + + None_ = "none" + Textarea = "textarea" + Slider = "slider" + + +class FieldDescriptions: + denoising_start = "When to start denoising, expressed a percentage of total steps" + denoising_end = "When to stop denoising, expressed a percentage of total steps" + cfg_scale = "Classifier-Free Guidance scale" + cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR" + scheduler = "Scheduler to use during inference" + positive_cond = "Positive conditioning tensor" + negative_cond = "Negative conditioning tensor" + noise = "Noise tensor" + clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" + unet = "UNet (scheduler, LoRAs)" + vae = "VAE" + cond = "Conditioning tensor" + controlnet_model = "ControlNet model to load" + vae_model = "VAE model to load" + lora_model = "LoRA model to load" + main_model = "Main model (UNet, VAE, CLIP) to load" + sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" + sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" + onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" + lora_weight = "The weight at which the LoRA is applied to each model" + compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" + raw_prompt = "Raw prompt text (no parsing)" + sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor" + skipped_layers = "Number of layers to skip in text encoder" + seed = "Seed for random number generation" + steps = "Number of steps to run" + width = "Width of output (px)" + height = "Height of output (px)" + control = "ControlNet(s) to apply" + ip_adapter = "IP-Adapter to apply" + t2i_adapter = "T2I-Adapter(s) to apply" + denoised_latents = "Denoised latents tensor" + latents = "Latents tensor" + strength = "Strength of denoising (proportional to steps)" + metadata = "Optional metadata to be saved with the image" + metadata_collection = "Collection of Metadata" + metadata_item_polymorphic = "A single metadata item or collection of metadata items" + metadata_item_label = "Label for this metadata item" + metadata_item_value = "The value for this metadata item (may be any type)" + workflow = "Optional workflow to be saved with the image" + interp_mode = "Interpolation mode" + torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)" + fp32 = "Whether or not to use full float32 precision" + precision = "Precision to use" + tiled = "Processing using overlapping tiles (reduce memory consumption)" + detect_res = "Pixel resolution for detection" + image_res = "Pixel resolution for output image" + safe_mode = "Whether or not to use safe mode" + scribble_mode = "Whether or not to use scribble mode" + scale_factor = "The factor by which to scale" + blend_alpha = ( + "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B." + ) + num_1 = "The first number" + num_2 = "The second number" + mask = "The mask to use for the operation" + board = "The board to save the image to" + image = "The image to process" + tile_size = "Tile size" + inclusive_low = "The inclusive low value" + exclusive_high = "The exclusive high value" + decimal_places = "The number of decimal places to round to" + freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' 
+ freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' + freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features." + freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features." + + +class MetadataField(RootModel): + """ + Pydantic model for metadata with custom root of type dict[str, Any]. + Metadata is stored without a strict schema. + """ + + root: dict[str, Any] = Field(description="The metadata") + + +MetadataFieldValidator = TypeAdapter(MetadataField) + + +class Input(str, Enum, metaclass=MetaEnum): + """ + The type of input a field accepts. + - `Input.Direct`: The field must have its value provided directly, when the invocation and field \ + are instantiated. + - `Input.Connection`: The field must have its value provided by a connection. + - `Input.Any`: The field may have its value provided either directly or by a connection. + """ + + Connection = "connection" + Direct = "direct" + Any = "any" + + +class FieldKind(str, Enum, metaclass=MetaEnum): + """ + The kind of field. + - `Input`: An input field on a node. + - `Output`: An output field on a node. + - `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is + one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name + "metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic, + allowing "metadata" for that field. + - `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs, + but which are used to store information about the node. For example, the `id` and `type` fields are node + attributes. + + The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app + startup, and when generating the OpenAPI schema for the workflow editor. + """ + + Input = "input" + Output = "output" + Internal = "internal" + NodeAttribute = "node_attribute" + + +class InputFieldJSONSchemaExtra(BaseModel): + """ + Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution, + and by the workflow editor during schema parsing and UI rendering. + """ + + input: Input + orig_required: bool + field_kind: FieldKind + default: Optional[Any] = None + orig_default: Optional[Any] = None + ui_hidden: bool = False + ui_type: Optional[UIType] = None + ui_component: Optional[UIComponent] = None + ui_order: Optional[int] = None + ui_choice_labels: Optional[dict[str, str]] = None + + model_config = ConfigDict( + validate_assignment=True, + json_schema_serialization_defaults_required=True, + ) + + +class WithMetadata(BaseModel): + metadata: Optional[MetadataField] = Field( + default=None, + description=FieldDescriptions.metadata, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Connection, + orig_required=False, + ).model_dump(exclude_none=True), + ) + + +class WithWorkflow: + workflow = None + + def __init_subclass__(cls) -> None: + logger.warn( + f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow." + ) + super().__init_subclass__() + + +class OutputFieldJSONSchemaExtra(BaseModel): + """ + Extra attributes to be added to input fields and their OpenAPI schema. 
Used by the workflow editor + during schema parsing and UI rendering. + """ + + field_kind: FieldKind + ui_hidden: bool + ui_type: Optional[UIType] + ui_order: Optional[int] + + model_config = ConfigDict( + validate_assignment=True, + json_schema_serialization_defaults_required=True, + ) + + +def InputField( + # copied from pydantic's Field + # TODO: Can we support default_factory? + default: Any = _Unset, + default_factory: Callable[[], Any] | None = _Unset, + title: str | None = _Unset, + description: str | None = _Unset, + pattern: str | None = _Unset, + strict: bool | None = _Unset, + gt: float | None = _Unset, + ge: float | None = _Unset, + lt: float | None = _Unset, + le: float | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + min_length: int | None = _Unset, + max_length: int | None = _Unset, + # custom + input: Input = Input.Any, + ui_type: Optional[UIType] = None, + ui_component: Optional[UIComponent] = None, + ui_hidden: bool = False, + ui_order: Optional[int] = None, + ui_choice_labels: Optional[dict[str, str]] = None, +) -> Any: + """ + Creates an input field for an invocation. + + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \ + that adds a few extra parameters to support graph execution and the node editor UI. + + :param Input input: [Input.Any] The kind of input this field requires. \ + `Input.Direct` means a value must be provided on instantiation. \ + `Input.Connection` means the value must be provided by a connection. \ + `Input.Any` means either will do. + + :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ + In some situations, the field's type is not enough to infer the correct UI type. \ + For example, model selection fields should render a dropdown UI component to select a model. \ + Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ + `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ + `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. + + :param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \ + The UI will always render a suitable component, but sometimes you want something different than the default. \ + For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \ + For this case, you could provide `UIComponent.Textarea`. + + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. + + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. + + :param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field. + """ + + json_schema_extra_ = InputFieldJSONSchemaExtra( + input=input, + ui_type=ui_type, + ui_component=ui_component, + ui_hidden=ui_hidden, + ui_order=ui_order, + ui_choice_labels=ui_choice_labels, + field_kind=FieldKind.Input, + orig_required=True, + ) + + """ + There is a conflict between the typing of invocation definitions and the typing of an invocation's + `invoke()` function. + + On instantiation of a node, the invocation definition is used to create the python class. At this time, + any number of fields may be optional, because they may be provided by connections. 
+ + On calling of `invoke()`, however, those fields may be required. + + For example, consider an ResizeImageInvocation with an `image: ImageField` field. + + `image` is required during the call to `invoke()`, but when the python class is instantiated, + the field may not be present. This is fine, because that image field will be provided by a + connection from an ancestor node, which outputs an image. + + This means we want to type the `image` field as optional for the node class definition, but required + for the `invoke()` function. + + If we use `typing.Optional` in the node class definition, the field will be typed as optional in the + `invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or + any static type analysis tools will complain. + + To get around this, in node class definitions, we type all fields correctly for the `invoke()` function, + but secretly make them optional in `InputField()`. We also store the original required bool and/or default + value. When we call `invoke()`, we use this stored information to do an additional check on the class. + """ + + if default_factory is not _Unset and default_factory is not None: + default = default_factory() + logger.warn('"default_factory" is not supported, calling it now to set "default"') + + # These are the args we may wish pass to the pydantic `Field()` function + field_args = { + "default": default, + "title": title, + "description": description, + "pattern": pattern, + "strict": strict, + "gt": gt, + "ge": ge, + "lt": lt, + "le": le, + "multiple_of": multiple_of, + "allow_inf_nan": allow_inf_nan, + "max_digits": max_digits, + "decimal_places": decimal_places, + "min_length": min_length, + "max_length": max_length, + } + + # We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected + provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined} + + # Because we are manually making fields optional, we need to store the original required bool for reference later + json_schema_extra_.orig_required = default is PydanticUndefined + + # Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one + if input is Input.Any or input is Input.Connection: + default_ = None if default is PydanticUndefined else default + provided_args.update({"default": default_}) + if default is not PydanticUndefined: + # Before invoking, we'll check for the original default value and set it on the field if the field has no value + json_schema_extra_.default = default + json_schema_extra_.orig_default = default + elif default is not PydanticUndefined: + default_ = default + provided_args.update({"default": default_}) + json_schema_extra_.orig_default = default_ + + return Field( + **provided_args, + json_schema_extra=json_schema_extra_.model_dump(exclude_none=True), + ) + + +def OutputField( + # copied from pydantic's Field + default: Any = _Unset, + title: str | None = _Unset, + description: str | None = _Unset, + pattern: str | None = _Unset, + strict: bool | None = _Unset, + gt: float | None = _Unset, + ge: float | None = _Unset, + lt: float | None = _Unset, + le: float | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + min_length: int | None = _Unset, + max_length: int | None = _Unset, + # custom + ui_type: Optional[UIType] = None, + ui_hidden: bool = False, + 
ui_order: Optional[int] = None, +) -> Any: + """ + Creates an output field for an invocation output. + + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ + that adds a few extra parameters to support graph execution and the node editor UI. + + :param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \ + In some situations, the field's type is not enough to infer the correct UI type. \ + For example, model selection fields should render a dropdown UI component to select a model. \ + Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ + `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ + `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. + + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ + + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ + """ + return Field( + default=default, + title=title, + description=description, + pattern=pattern, + strict=strict, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + min_length=min_length, + max_length=max_length, + json_schema_extra=OutputFieldJSONSchemaExtra( + ui_type=ui_type, + ui_hidden=ui_hidden, + ui_order=ui_order, + field_kind=FieldKind.Output, + ).model_dump(exclude_none=True), + ) + # endregion diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index f729d60cdd5..16d0f33dda3 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,19 +7,16 @@ import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, WithMetadata from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker from .baseinvocation import ( BaseInvocation, Classification, - Input, - InputField, InvocationContext, - WithMetadata, invocation, ) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index c3d00bb1330..d4d3d5bea44 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -13,7 +13,8 @@ from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch -from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField, WithMetadata from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 6bd28896244..c01e0ed0fb2 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -7,16 +7,13 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) +from 
invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.model_management.models.base import BaseModelType, ModelType from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b77363ceb86..909c307481e 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -23,6 +23,7 @@ from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, WithMetadata from invokeai.app.invocations.ip_adapter import IPAdapterField from invokeai.app.invocations.primitives import ( DenoiseMaskField, @@ -35,7 +36,6 @@ ) from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus @@ -59,12 +59,7 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIType, - WithMetadata, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index defc61275fe..6ca53011f0b 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -5,10 +5,10 @@ import numpy as np from pydantic import ValidationInfo, field_validator +from invokeai.app.invocations.fields import FieldDescriptions, InputField from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput -from invokeai.app.shared.fields import FieldDescriptions -from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation @invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0") diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 14d66f8ef68..399e217dc17 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -5,20 +5,16 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, InvocationContext, - MetadataField, - OutputField, - UIType, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import ControlField +from invokeai.app.invocations.fields import FieldDescriptions, InputField, MetadataField, OutputField, UIType from invokeai.app.invocations.ip_adapter import IPAdapterModelField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.t2i_adapter import T2IAdapterField -from invokeai.app.shared.fields import FieldDescriptions from ...version import __version__ diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 
99dcc72999b..c710c9761b0 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -3,17 +3,14 @@ from pydantic import BaseModel, ConfigDict, Field -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.shared.models import FreeUConfig from ...backend.model_management import BaseModelType, ModelType, SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index b1ee91e1cdf..2e717ac561b 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -4,17 +4,15 @@ import torch from pydantic import field_validator +from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField from invokeai.app.invocations.latent import LatentsField -from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.misc import SEED_MAX from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 759cfde700f..b43d7eaef2c 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -11,9 +11,17 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from tqdm import tqdm +from invokeai.app.invocations.fields import ( + FieldDescriptions, + Input, + InputField, + OutputField, + UIComponent, + UIType, + WithMetadata, +) from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend import BaseModelType, ModelType, SubModelType @@ -24,13 +32,7 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIComponent, - UIType, - WithMetadata, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index dccd18f754b..dab9c3dc0f4 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -41,7 +41,8 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput -from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField @invocation( diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index afe8ff06d9d..22f03454a55 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -5,16 +5,12 @@ import torch from pydantic import BaseModel, Field -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - 
UIComponent, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 4778d980771..94b4a217ae7 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -7,7 +7,8 @@ from invokeai.app.invocations.primitives import StringCollectionOutput -from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField, UIComponent @invocation( diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 68076fdfeb1..62df5bc8047 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,14 +1,10 @@ -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from ...backend.model_management import ModelType, SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIType, invocation, invocation_output, ) diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py index 3466206b377..ccbc2f6d924 100644 --- a/invokeai/app/invocations/strings.py +++ b/invokeai/app/invocations/strings.py @@ -5,13 +5,11 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, InvocationContext, - OutputField, - UIComponent, invocation, invocation_output, ) +from .fields import InputField, OutputField, UIComponent from .primitives import StringOutput diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index e055d23903f..66ac87c37b8 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -5,17 +5,14 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights -from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.model_management.models.base import BaseModelType diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index e51f891a8db..bdc23ef6edd 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -8,14 +8,11 @@ BaseInvocation, BaseInvocationOutput, Classification, - Input, - InputField, InvocationContext, - OutputField, - WithMetadata, invocation, invocation_output, ) +from invokeai.app.invocations.fields import Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.backend.tiles.tiles import ( diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 5f715c1a7ed..2cab279a9fc 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -14,7 +14,8 @@ from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from 
invokeai.backend.util.devices import choose_torch_device -from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation +from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .fields import InputField, WithMetadata # TODO: Populate this from disk? # TODO: Use model manager to load? diff --git a/invokeai/app/services/image_files/image_files_base.py b/invokeai/app/services/image_files/image_files_base.py index 27dd67531f4..f4036277b72 100644 --- a/invokeai/app/services/image_files/image_files_base.py +++ b/invokeai/app/services/image_files/image_files_base.py @@ -4,7 +4,7 @@ from PIL.Image import Image as PILImageType -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID diff --git a/invokeai/app/services/image_files/image_files_disk.py b/invokeai/app/services/image_files/image_files_disk.py index 08448216723..fb687973bad 100644 --- a/invokeai/app/services/image_files/image_files_disk.py +++ b/invokeai/app/services/image_files/image_files_disk.py @@ -7,7 +7,7 @@ from PIL.Image import Image as PILImageType from send2trash import send2trash -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.invoker import Invoker from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail diff --git a/invokeai/app/services/image_records/image_records_base.py b/invokeai/app/services/image_records/image_records_base.py index 727f4977fba..7b7b261ecab 100644 --- a/invokeai/app/services/image_records/image_records_base.py +++ b/invokeai/app/services/image_records/image_records_base.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Optional -from invokeai.app.invocations.metadata import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.shared.pagination import OffsetPaginatedResults from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index 74f82e7d84c..5b37913c8fd 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import Optional, Union, cast -from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator +from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py index df71dadb5b0..42c42667744 100644 --- a/invokeai/app/services/images/images_base.py +++ b/invokeai/app/services/images/images_base.py @@ -3,7 +3,7 @@ from PIL.Image import Image as PILImageType -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ( ImageCategory, ImageRecord, diff --git 
a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py index ff21731a506..adeed738119 100644 --- a/invokeai/app/services/images/images_default.py +++ b/invokeai/app/services/images/images_default.py @@ -2,7 +2,7 @@ from PIL.Image import Image as PILImageType -from invokeai.app.invocations.baseinvocation import MetadataField +from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 1acf165abac..ba05b050c5b 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -13,14 +13,11 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - Input, - InputField, InvocationContext, - OutputField, - UIType, invocation, invocation_output, ) +from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType from invokeai.app.util.misc import uuid_string # in 3.10 this would be "from types import NoneType" diff --git a/invokeai/app/shared/fields.py b/invokeai/app/shared/fields.py deleted file mode 100644 index 3e841ffbf22..00000000000 --- a/invokeai/app/shared/fields.py +++ /dev/null @@ -1,67 +0,0 @@ -class FieldDescriptions: - denoising_start = "When to start denoising, expressed a percentage of total steps" - denoising_end = "When to stop denoising, expressed a percentage of total steps" - cfg_scale = "Classifier-Free Guidance scale" - cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR" - scheduler = "Scheduler to use during inference" - positive_cond = "Positive conditioning tensor" - negative_cond = "Negative conditioning tensor" - noise = "Noise tensor" - clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" - unet = "UNet (scheduler, LoRAs)" - vae = "VAE" - cond = "Conditioning tensor" - controlnet_model = "ControlNet model to load" - vae_model = "VAE model to load" - lora_model = "LoRA model to load" - main_model = "Main model (UNet, VAE, CLIP) to load" - sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" - sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" - onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" - lora_weight = "The weight at which the LoRA is applied to each model" - compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" - raw_prompt = "Raw prompt text (no parsing)" - sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor" - skipped_layers = "Number of layers to skip in text encoder" - seed = "Seed for random number generation" - steps = "Number of steps to run" - width = "Width of output (px)" - height = "Height of output (px)" - control = "ControlNet(s) to apply" - ip_adapter = "IP-Adapter to apply" - t2i_adapter = "T2I-Adapter(s) to apply" - denoised_latents = "Denoised latents tensor" - latents = "Latents tensor" - strength = "Strength of denoising (proportional to steps)" - metadata = "Optional metadata to be saved with the image" - metadata_collection = "Collection of Metadata" - metadata_item_polymorphic = "A single metadata item or collection of metadata items" - metadata_item_label = "Label for this metadata item" - metadata_item_value = "The value for this 
metadata item (may be any type)" - workflow = "Optional workflow to be saved with the image" - interp_mode = "Interpolation mode" - torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)" - fp32 = "Whether or not to use full float32 precision" - precision = "Precision to use" - tiled = "Processing using overlapping tiles (reduce memory consumption)" - detect_res = "Pixel resolution for detection" - image_res = "Pixel resolution for output image" - safe_mode = "Whether or not to use safe mode" - scribble_mode = "Whether or not to use scribble mode" - scale_factor = "The factor by which to scale" - blend_alpha = ( - "Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B." - ) - num_1 = "The first number" - num_2 = "The second number" - mask = "The mask to use for the operation" - board = "The board to save the image to" - image = "The image to process" - tile_size = "Tile size" - inclusive_low = "The inclusive low value" - exclusive_high = "The exclusive high value" - decimal_places = "The number of decimal places to round to" - freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' - freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.' - freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features." - freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features." diff --git a/invokeai/app/shared/models.py b/invokeai/app/shared/models.py index ed68cb287e3..1a11b480cc5 100644 --- a/invokeai/app/shared/models.py +++ b/invokeai/app/shared/models.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, Field -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.invocations.fields import FieldDescriptions class FreeUConfig(BaseModel): diff --git a/tests/aa_nodes/test_nodes.py b/tests/aa_nodes/test_nodes.py index bca4e1011f8..e71daad3f3a 100644 --- a/tests/aa_nodes/test_nodes.py +++ b/tests/aa_nodes/test_nodes.py @@ -3,12 +3,11 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InputField, InvocationContext, - OutputField, invocation, invocation_output, ) +from invokeai.app.invocations.fields import InputField, OutputField from invokeai.app.invocations.image import ImageField From c61340c874c9920578dbddc4501a7842d2f1f8cc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 18:02:58 +1100 Subject: [PATCH 002/340] feat(nodes): restricts invocation context power Creates a low-power `InvocationContext` with simplified methods and data. See `invocation_context.py` for detailed comments. 
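To make the intended usage concrete, here is a rough sketch of a node written against the wrapped context (sketch only: the node class, its "example_passthrough" type string and its `image` field are hypothetical and not part of this commit, and the import locations reflect the code as of this branch):

    from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
    from invokeai.app.invocations.fields import InputField
    from invokeai.app.invocations.primitives import ImageField, ImageOutput

    @invocation("example_passthrough", title="Example Passthrough", tags=["example"], category="image", version="1.0.0")
    class ExamplePassthroughInvocation(BaseInvocation):
        """Hypothetical node demonstrating the wrapped InvocationContext."""

        image: ImageField = InputField(description="The image to pass through")

        def invoke(self, context) -> ImageOutput:
            # `context` is the new InvocationContext built for this node.
            image = context.images.get_pil(self.image.image_name)
            context.logger.info(f"Passing through {self.image.image_name}")
            # save() stores the image; the queue item's workflow is attached
            # automatically, and metadata is picked up if the node inherits
            # from WithMetadata.
            image_dto = context.images.save(image=image)
            return ImageOutput(
                image=ImageField(image_name=image_dto.image_name),
                width=image_dto.width,
                height=image_dto.height,
            )

The node never reaches into InvocationServices directly: images, latents, conditioning, models, config, logging and the step callback are all exposed through the corresponding attributes on the context.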
--- .../app/services/shared/invocation_context.py | 408 ++++++++++++++++++ invokeai/app/util/step_callback.py | 39 +- 2 files changed, 434 insertions(+), 13 deletions(-) create mode 100644 invokeai/app/services/shared/invocation_context.py diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py new file mode 100644 index 00000000000..c0aaac54f87 --- /dev/null +++ b/invokeai/app/services/shared/invocation_context.py @@ -0,0 +1,408 @@ +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Optional + +from PIL.Image import Image +from pydantic import ConfigDict +from torch import Tensor + +from invokeai.app.invocations.compel import ConditioningFieldData +from invokeai.app.invocations.fields import MetadataField, WithMetadata +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin +from invokeai.app.services.images.images_common import ImageDTO +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID +from invokeai.app.util.misc import uuid_string +from invokeai.app.util.step_callback import stable_diffusion_step_callback +from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState + +if TYPE_CHECKING: + from invokeai.app.invocations.baseinvocation import BaseInvocation + +""" +The InvocationContext provides access to various services and data about the current invocation. + +We do not provide the invocation services directly, as their methods are both dangerous and +inconvenient to use. + +For example: +- The `images` service allows nodes to delete or unsafely modify existing images. +- The `configuration` service allows nodes to change the app's config at runtime. +- The `events` service allows nodes to emit arbitrary events. + +Wrapping these services provides a simpler and safer interface for nodes to use. + +When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere +with each other. + +Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them. +""" + + +@dataclass(frozen=True) +class InvocationContextData: + invocation: "BaseInvocation" + session_id: str + queue_id: str + source_node_id: str + queue_item_id: int + batch_id: str + workflow: Optional[WorkflowWithoutID] = None + + +class LoggerInterface: + def __init__(self, services: InvocationServices) -> None: + def debug(message: str) -> None: + """ + Logs a debug message. + + :param message: The message to log. + """ + services.logger.debug(message) + + def info(message: str) -> None: + """ + Logs an info message. + + :param message: The message to log. + """ + services.logger.info(message) + + def warning(message: str) -> None: + """ + Logs a warning message. + + :param message: The message to log. + """ + services.logger.warning(message) + + def error(message: str) -> None: + """ + Logs an error message. + + :param message: The message to log. 
+ """ + services.logger.error(message) + + self.debug = debug + self.info = info + self.warning = warning + self.error = error + + +class ImagesInterface: + def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + def save( + image: Image, + board_id: Optional[str] = None, + image_category: ImageCategory = ImageCategory.GENERAL, + metadata: Optional[MetadataField] = None, + ) -> ImageDTO: + """ + Saves an image, returning its DTO. + + If the current queue item has a workflow, it is automatically saved with the image. + + :param image: The image to save, as a PIL image. + :param board_id: The board ID to add the image to, if it should be added. + :param image_category: The category of the image. Only the GENERAL category is added to the gallery. + :param metadata: The metadata to save with the image, if it should have any. If the invocation inherits \ + from `WithMetadata`, that metadata will be used automatically. Provide this only if you want to \ + override or provide metadata manually. + """ + + # If the invocation inherits metadata, use that. Else, use the metadata passed in. + metadata_ = ( + context_data.invocation.metadata if isinstance(context_data.invocation, WithMetadata) else metadata + ) + + return services.images.create( + image=image, + is_intermediate=context_data.invocation.is_intermediate, + image_category=image_category, + board_id=board_id, + metadata=metadata_, + image_origin=ResourceOrigin.INTERNAL, + workflow=context_data.workflow, + session_id=context_data.session_id, + node_id=context_data.invocation.id, + ) + + def get_pil(image_name: str) -> Image: + """ + Gets an image as a PIL Image object. + + :param image_name: The name of the image to get. + """ + return services.images.get_pil_image(image_name) + + def get_metadata(image_name: str) -> Optional[MetadataField]: + """ + Gets an image's metadata, if it has any. + + :param image_name: The name of the image to get the metadata for. + """ + return services.images.get_metadata(image_name) + + def get_dto(image_name: str) -> ImageDTO: + """ + Gets an image as an ImageDTO object. + + :param image_name: The name of the image to get. + """ + return services.images.get_dto(image_name) + + def update( + image_name: str, + board_id: Optional[str] = None, + is_intermediate: Optional[bool] = False, + ) -> ImageDTO: + """ + Updates an image, returning its updated DTO. + + It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. + + If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to + get the updated image. + + :param image_name: The name of the image to update. + :param board_id: The board ID to add the image to, if it should be added. + :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. 
+ """ + if is_intermediate is not None: + services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) + if board_id is None: + services.board_images.remove_image_from_board(image_name) + else: + services.board_images.add_image_to_board(image_name, board_id) + return services.images.get_dto(image_name) + + self.save = save + self.get_pil = get_pil + self.get_metadata = get_metadata + self.get_dto = get_dto + self.update = update + + +class LatentsKind(str, Enum): + IMAGE = "image" + NOISE = "noise" + MASK = "mask" + MASKED_IMAGE = "masked_image" + OTHER = "other" + + +class LatentsInterface: + def __init__( + self, + services: InvocationServices, + context_data: InvocationContextData, + ) -> None: + def save(tensor: Tensor) -> str: + """ + Saves a latents tensor, returning its name. + + :param tensor: The latents tensor to save. + """ + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" + services.latents.save( + name=name, + data=tensor, + ) + return name + + def get(latents_name: str) -> Tensor: + """ + Gets a latents tensor by name. + + :param latents_name: The name of the latents tensor to get. + """ + return services.latents.get(latents_name) + + self.save = save + self.get = get + + +class ConditioningInterface: + def __init__( + self, + services: InvocationServices, + context_data: InvocationContextData, + ) -> None: + def save(conditioning_data: ConditioningFieldData) -> str: + """ + Saves a conditioning data object, returning its name. + + :param conditioning_data: The conditioning data to save. + """ + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning" + services.latents.save( + name=name, + data=conditioning_data, # type: ignore [arg-type] + ) + return name + + def get(conditioning_name: str) -> Tensor: + """ + Gets conditioning data by name. + + :param conditioning_name: The name of the conditioning data to get. + """ + return services.latents.get(conditioning_name) + + self.save = save + self.get = get + + +class ModelsInterface: + def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + def exists(model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool: + """ + Checks if a model exists. + + :param model_name: The name of the model to check. + :param base_model: The base model of the model to check. + :param model_type: The type of the model to check. + """ + return services.model_manager.model_exists(model_name, base_model, model_type) + + def load( + model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None + ) -> ModelInfo: + """ + Loads a model, returning its `ModelInfo` object. + + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. + :param submodel: The submodel of the model to get. + """ + return services.model_manager.get_model( + model_name, base_model, model_type, submodel, context_data=context_data + ) + + def get_info(model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: + """ + Gets a model's info, an dict-like object. + + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. 
+ """ + return services.model_manager.model_info(model_name, base_model, model_type) + + self.exists = exists + self.load = load + self.get_info = get_info + + +class ConfigInterface: + def __init__(self, services: InvocationServices) -> None: + def get() -> InvokeAIAppConfig: + """ + Gets the app's config. + """ + # The config can be changed at runtime. We don't want nodes doing this, so we make a + # frozen copy.. + config = services.configuration.get_config() + frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) + return frozen_config + + self.get = get + + +class UtilInterface: + def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + def sd_step_callback( + intermediate_state: PipelineIntermediateState, + base_model: BaseModelType, + ) -> None: + """ + The step callback emits a progress event with the current step, the total number of + steps, a preview image, and some other internal metadata. + + This should be called after each step of the diffusion process. + + :param intermediate_state: The intermediate state of the diffusion pipeline. + :param base_model: The base model for the current denoising step. + """ + stable_diffusion_step_callback( + context_data=context_data, + intermediate_state=intermediate_state, + base_model=base_model, + invocation_queue=services.queue, + events=services.events, + ) + + self.sd_step_callback = sd_step_callback + + +class InvocationContext: + """ + The invocation context provides access to various services and data about the current invocation. + """ + + def __init__( + self, + images: ImagesInterface, + latents: LatentsInterface, + models: ModelsInterface, + config: ConfigInterface, + logger: LoggerInterface, + data: InvocationContextData, + util: UtilInterface, + conditioning: ConditioningInterface, + ) -> None: + self.images = images + "Provides methods to save, get and update images and their metadata." + self.logger = logger + "Provides access to the app logger." + self.latents = latents + "Provides methods to save and get latents tensors, including image, noise, masks, and masked images." + self.conditioning = conditioning + "Provides methods to save and get conditioning data." + self.models = models + "Provides methods to check if a model exists, get a model, and get a model's info." + self.config = config + "Provides access to the app's config." + self.data = data + "Provides data about the current queue item and invocation." + self.util = util + "Provides utility methods." + + +def build_invocation_context( + services: InvocationServices, + context_data: InvocationContextData, +) -> InvocationContext: + """ + Builds the invocation context. This is a wrapper around the invocation services that provides + a more convenient (and less dangerous) interface for nodes to use. + + :param invocation_services: The invocation services to wrap. + :param invocation_context_data: The invocation context data. 
+ """ + + logger = LoggerInterface(services=services) + images = ImagesInterface(services=services, context_data=context_data) + latents = LatentsInterface(services=services, context_data=context_data) + models = ModelsInterface(services=services, context_data=context_data) + config = ConfigInterface(services=services) + util = UtilInterface(services=services, context_data=context_data) + conditioning = ConditioningInterface(services=services, context_data=context_data) + + ctx = InvocationContext( + images=images, + logger=logger, + config=config, + latents=latents, + models=models, + data=context_data, + util=util, + conditioning=conditioning, + ) + + return ctx diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index f166206d528..5cc3caa9ba5 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,12 +1,25 @@ +from typing import Protocol + import torch from PIL import Image +from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage +from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC +from invokeai.app.services.shared.invocation_context import InvocationContextData from ...backend.model_management.models import BaseModelType from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util.util import image_to_dataURL -from ..invocations.baseinvocation import InvocationContext + + +class StepCallback(Protocol): + def __call__( + self, + intermediate_state: PipelineIntermediateState, + base_model: BaseModelType, + ) -> None: + ... def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None): @@ -25,13 +38,13 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix= def stable_diffusion_step_callback( - context: InvocationContext, + context_data: InvocationContextData, intermediate_state: PipelineIntermediateState, - node: dict, - source_node_id: str, base_model: BaseModelType, -): - if context.services.queue.is_canceled(context.graph_execution_state_id): + invocation_queue: InvocationQueueABC, + events: EventServiceBase, +) -> None: + if invocation_queue.is_canceled(context_data.session_id): raise CanceledException # Some schedulers report not only the noisy latents at the current timestep, @@ -108,13 +121,13 @@ def stable_diffusion_step_callback( dataURL = image_to_dataURL(image, image_format="JPEG") - context.services.events.emit_generator_progress( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - node=node, - source_node_id=source_node_id, + events.emit_generator_progress( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, + node_id=context_data.invocation.id, + source_node_id=context_data.source_node_id, progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), step=intermediate_state.step, order=intermediate_state.order, From 3a03f41ee1d756fdf5fcbcf2437b7e871f0b56fc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:23:01 +1100 Subject: [PATCH 003/340] feat: add pyright config I was having issues with mypy bother over- and under-reporting certain 
problems. I've added a pyright config. --- pyproject.toml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 7f4b0d77f25..d063f1ad0ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -280,3 +280,19 @@ module = [ "invokeai.frontend.install.model_install", ] #=== End: MyPy + +[tool.pyright] +include = [ + "invokeai/app/invocations/" +] +exclude = [ + "**/node_modules", + "**/__pycache__", + "invokeai/app/invocations/onnx.py", + "invokeai/app/api/routers/models.py", + "invokeai/app/services/invocation_stats/invocation_stats_default.py", + "invokeai/app/services/model_manager/model_manager_base.py", + "invokeai/app/services/model_manager/model_manager_default.py", + "invokeai/app/services/model_records/model_records_sql.py", + "invokeai/app/util/controlnet_utils.py", +] From 64c03d158e4ca97572739ae36abeed41c29bc293 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:23:16 +1100 Subject: [PATCH 004/340] feat(nodes): update all invocations to use new invocation context Update all invocations to use the new context. The changes are all fairly simple, but there are a lot of them. Supporting minor changes: - Patch bump for all nodes that use the context - Update invocation processor to provide new context - Minor change to `EventServiceBase` to accept a node's ID instead of the dict version of a node - Minor change to `ModelManagerService` to support the new wrapped context - Fanagling of imports to avoid circular dependencies --- invokeai/app/invocations/baseinvocation.py | 54 +- invokeai/app/invocations/collections.py | 8 +- invokeai/app/invocations/compel.py | 105 ++-- .../controlnet_image_processors.py | 56 +- invokeai/app/invocations/cv.py | 32 +- invokeai/app/invocations/facetools.py | 129 ++-- invokeai/app/invocations/fields.py | 57 +- invokeai/app/invocations/image.py | 586 +++++------------- invokeai/app/invocations/infill.py | 123 +--- invokeai/app/invocations/ip_adapter.py | 9 +- invokeai/app/invocations/latent.py | 238 +++---- invokeai/app/invocations/math.py | 22 +- invokeai/app/invocations/metadata.py | 19 +- invokeai/app/invocations/model.py | 29 +- invokeai/app/invocations/noise.py | 25 +- invokeai/app/invocations/onnx.py | 10 +- invokeai/app/invocations/param_easing.py | 44 +- invokeai/app/invocations/primitives.py | 136 ++-- invokeai/app/invocations/prompt.py | 6 +- invokeai/app/invocations/sdxl.py | 13 +- invokeai/app/invocations/strings.py | 11 +- invokeai/app/invocations/t2i_adapter.py | 6 +- invokeai/app/invocations/tiles.py | 38 +- invokeai/app/invocations/upscale.py | 33 +- invokeai/app/services/events/events_base.py | 4 +- .../invocation_processor_default.py | 24 +- .../model_manager/model_manager_base.py | 9 +- .../model_manager/model_manager_common.py | 0 .../model_manager/model_manager_default.py | 44 +- invokeai/app/services/shared/graph.py | 7 +- .../app/services/shared/invocation_context.py | 9 +- invokeai/app/util/step_callback.py | 23 +- 32 files changed, 717 insertions(+), 1192 deletions(-) create mode 100644 invokeai/app/services/model_manager/model_manager_common.py diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 395d5e98707..c4aed1fac5a 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -16,10 +16,16 @@ from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined -from invokeai.app.invocations.fields 
import FieldKind, Input +from invokeai.app.invocations.fields import ( + FieldDescriptions, + FieldKind, + Input, + InputFieldJSONSchemaExtra, + MetadataField, + logger, +) from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID -from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.misc import uuid_string from invokeai.backend.util.logging import InvokeAILogger @@ -219,7 +225,7 @@ def invoke(self, context: InvocationContext) -> BaseInvocationOutput: """Invoke with provided context and return outputs.""" pass - def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: + def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> BaseInvocationOutput: """ Internal invoke method, calls `invoke()` after some prep. Handles optional fields that are required to call `invoke()` and invocation cache. @@ -244,23 +250,23 @@ def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: raise MissingInputException(self.model_fields["type"].default, field_name) # skip node cache codepath if it's disabled - if context.services.configuration.node_cache_size == 0: + if services.configuration.node_cache_size == 0: return self.invoke(context) output: BaseInvocationOutput if self.use_cache: - key = context.services.invocation_cache.create_key(self) - cached_value = context.services.invocation_cache.get(key) + key = services.invocation_cache.create_key(self) + cached_value = services.invocation_cache.get(key) if cached_value is None: - context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}') + services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}') output = self.invoke(context) - context.services.invocation_cache.save(key, output) + services.invocation_cache.save(key, output) return output else: - context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}') + services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}') return cached_value else: - context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}') + services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}') return self.invoke(context) id: str = Field( @@ -513,3 +519,29 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]: return cls return wrapper + + +class WithMetadata(BaseModel): + """ + Inherit from this class if your node needs a metadata input field. + """ + + metadata: Optional[MetadataField] = Field( + default=None, + description=FieldDescriptions.metadata, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Connection, + orig_required=False, + ).model_dump(exclude_none=True), + ) + + +class WithWorkflow: + workflow = None + + def __init_subclass__(cls) -> None: + logger.warn( + f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow." 
+ ) + super().__init_subclass__() diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index d35a9d79c74..f5709b4ba36 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -7,7 +7,7 @@ from invokeai.app.invocations.primitives import IntegerCollectionOutput from invokeai.app.util.misc import SEED_MAX -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField @@ -27,7 +27,7 @@ def stop_gt_start(cls, v: int, info: ValidationInfo): raise ValueError("stop must be greater than start") return v - def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + def invoke(self, context) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step))) @@ -45,7 +45,7 @@ class RangeOfSizeInvocation(BaseInvocation): size: int = InputField(default=1, gt=0, description="The number of values") step: int = InputField(default=1, description="The step of the range") - def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + def invoke(self, context) -> IntegerCollectionOutput: return IntegerCollectionOutput( collection=list(range(self.start, self.start + (self.step * self.size), self.step)) ) @@ -72,6 +72,6 @@ class RandomRangeInvocation(BaseInvocation): description="The seed for the RNG (omit for random)", ) - def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + def invoke(self, context) -> IntegerCollectionOutput: rng = np.random.default_rng(self.seed) return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size))) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index b386aef2cbe..b4496031bc4 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,12 +1,18 @@ -from dataclasses import dataclass -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union import torch from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput +from invokeai.app.invocations.fields import ( + ConditioningFieldData, + FieldDescriptions, + Input, + InputField, + OutputField, + UIComponent, +) +from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ExtraConditioningInfo, @@ -20,16 +26,14 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) from .model import ClipField +if TYPE_CHECKING: + from invokeai.app.services.shared.invocation_context import InvocationContext -@dataclass -class ConditioningFieldData: - conditionings: List[BasicConditioningInfo] # unconditioned: Optional[torch.Tensor] @@ -44,7 +48,7 @@ class ConditioningFieldData: title="Prompt", tags=["prompt", "compel"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" @@ -61,26 +65,18 @@ class CompelInvocation(BaseInvocation): ) 
@torch.no_grad() - def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( - **self.clip.tokenizer.model_dump(), - context=context, - ) - text_encoder_info = context.services.model_manager.get_model( - **self.clip.text_encoder.model_dump(), - context=context, - ) + def invoke(self, context) -> ConditioningOutput: + tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump()) + text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump()) def _lora_loader(): for lora in self.clip.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), context=context - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.context.model, lora.weight) del lora_info return - # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in extract_ti_triggers_from_prompt(self.prompt): @@ -89,11 +85,10 @@ def _lora_loader(): ti_list.append( ( name, - context.services.model_manager.get_model( + context.models.load( model_name=name, base_model=self.clip.text_encoder.base_model, model_type=ModelType.TextualInversion, - context=context, ).context.model, ) ) @@ -124,7 +119,7 @@ def _lora_loader(): conjunction = Compel.parse_prompt_string(self.prompt) - if context.services.configuration.log_tokenization: + if context.config.get().log_tokenization: log_tokenization_for_conjunction(conjunction, tokenizer) c, options = compel.build_conditioning_tensor_for_conjunction(conjunction) @@ -145,34 +140,23 @@ def _lora_loader(): ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) class SDXLPromptInvocationBase: def run_clip_compel( self, - context: InvocationContext, + context: "InvocationContext", clip_field: ClipField, prompt: str, get_pooled: bool, lora_prefix: str, zero_on_empty: bool, ): - tokenizer_info = context.services.model_manager.get_model( - **clip_field.tokenizer.model_dump(), - context=context, - ) - text_encoder_info = context.services.model_manager.get_model( - **clip_field.text_encoder.model_dump(), - context=context, - ) + tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump()) + text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump()) # return zero on empty if prompt == "" and zero_on_empty: @@ -196,14 +180,12 @@ def run_clip_compel( def _lora_loader(): for lora in clip_field.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), context=context - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.context.model, lora.weight) del lora_info return - # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in 
extract_ti_triggers_from_prompt(prompt): @@ -212,11 +194,10 @@ def _lora_loader(): ti_list.append( ( name, - context.services.model_manager.get_model( + context.models.load( model_name=name, base_model=clip_field.text_encoder.base_model, model_type=ModelType.TextualInversion, - context=context, ).context.model, ) ) @@ -249,7 +230,7 @@ def _lora_loader(): conjunction = Compel.parse_prompt_string(prompt) - if context.services.configuration.log_tokenization: + if context.config.get().log_tokenization: # TODO: better logging for and syntax log_tokenization_for_conjunction(conjunction, tokenizer) @@ -282,7 +263,7 @@ def _lora_loader(): title="SDXL Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -307,7 +288,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") @torch.no_grad() - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: c1, c1_pooled, ec1 = self.run_clip_compel( context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True ) @@ -364,14 +345,9 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) @invocation( @@ -379,7 +355,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: title="SDXL Refiner Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -397,7 +373,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) @torch.no_grad() - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: # TODO: if there will appear lora for refiner - write proper prefix c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "", zero_on_empty=False) @@ -417,14 +393,9 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) @invocation_output("clip_skip_output") @@ -447,7 +418,7 @@ class ClipSkipInvocation(BaseInvocation): clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers) - def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: + def 
invoke(self, context) -> ClipSkipInvocationOutput: self.clip.skipped_layers += self.skipped_layers return ClipSkipInvocationOutput( clip=self.clip, diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 9b652b8eee9..3797722c93e 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,18 +25,17 @@ from PIL import Image from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, WithMetadata -from invokeai.app.invocations.primitives import ImageField, ImageOutput +from invokeai.app.invocations.baseinvocation import WithMetadata +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField +from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector +from invokeai.backend.model_management.models.base import BaseModelType -from ...backend.model_management import BaseModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -121,7 +120,7 @@ def validate_begin_end_step_percent(self) -> "ControlNetInvocation": validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context: InvocationContext) -> ControlOutput: + def invoke(self, context) -> ControlOutput: return ControlOutput( control=ControlField( image=self.image, @@ -145,23 +144,14 @@ def run_processor(self, image: Image.Image) -> Image.Image: # superclass just passes through image without processing return image - def invoke(self, context: InvocationContext) -> ImageOutput: - raw_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + raw_image = context.images.get_pil(self.image.image_name) # image type should be PIL.PngImagePlugin.PngImageFile ? 
processed_image = self.run_processor(raw_image) # currently can't see processed image in node UI without a showImage node, # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery - image_dto = context.services.images.create( - image=processed_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.CONTROL, - session_id=context.graph_execution_state_id, - node_id=self.id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=processed_image) """Builds an ImageOutput and its ImageField""" processed_image_field = ImageField(image_name=image_dto.image_name) @@ -180,7 +170,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Canny Processor", tags=["controlnet", "canny"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class CannyImageProcessorInvocation(ImageProcessorInvocation): """Canny edge detection for ControlNet""" @@ -203,7 +193,7 @@ def run_processor(self, image): title="HED (softedge) Processor", tags=["controlnet", "hed", "softedge"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class HedImageProcessorInvocation(ImageProcessorInvocation): """Applies HED edge detection to image""" @@ -232,7 +222,7 @@ def run_processor(self, image): title="Lineart Processor", tags=["controlnet", "lineart"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LineartImageProcessorInvocation(ImageProcessorInvocation): """Applies line art processing to image""" @@ -254,7 +244,7 @@ def run_processor(self, image): title="Lineart Anime Processor", tags=["controlnet", "lineart", "anime"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): """Applies line art anime processing to image""" @@ -277,7 +267,7 @@ def run_processor(self, image): title="Midas Depth Processor", tags=["controlnet", "midas"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Midas depth processing to image""" @@ -304,7 +294,7 @@ def run_processor(self, image): title="Normal BAE Processor", tags=["controlnet"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): """Applies NormalBae processing to image""" @@ -321,7 +311,7 @@ def run_processor(self, image): @invocation( - "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0" + "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.1" ) class MlsdImageProcessorInvocation(ImageProcessorInvocation): """Applies MLSD processing to image""" @@ -344,7 +334,7 @@ def run_processor(self, image): @invocation( - "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0" + "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.1" ) class PidiImageProcessorInvocation(ImageProcessorInvocation): """Applies PIDI processing to image""" @@ -371,7 +361,7 @@ def run_processor(self, image): title="Content Shuffle Processor", tags=["controlnet", "contentshuffle"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): """Applies content shuffle 
processing to image""" @@ -401,7 +391,7 @@ def run_processor(self, image): title="Zoe (Depth) Processor", tags=["controlnet", "zoe", "depth"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" @@ -417,7 +407,7 @@ def run_processor(self, image): title="Mediapipe Face Processor", tags=["controlnet", "mediapipe", "face"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): """Applies mediapipe face processing to image""" @@ -440,7 +430,7 @@ def run_processor(self, image): title="Leres (Depth) Processor", tags=["controlnet", "leres", "depth"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LeresImageProcessorInvocation(ImageProcessorInvocation): """Applies leres processing to image""" @@ -469,7 +459,7 @@ def run_processor(self, image): title="Tile Resample Processor", tags=["controlnet", "tile"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class TileResamplerProcessorInvocation(ImageProcessorInvocation): """Tile resampler processor""" @@ -509,7 +499,7 @@ def run_processor(self, img): title="Segment Anything Processor", tags=["controlnet", "segmentanything"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): """Applies segment anything processing to image""" @@ -551,7 +541,7 @@ def show_anns(self, anns: List[Dict]): title="Color Map Processor", tags=["controlnet"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ColorMapImageProcessorInvocation(ImageProcessorInvocation): """Generates a color map from the provided image""" diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 5865338e192..375b18f9c58 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -5,23 +5,23 @@ import numpy from PIL import Image, ImageOps -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField +from invokeai.app.invocations.primitives import ImageOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithMetadata -@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0") +@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1") class CvInpaintInvocation(BaseInvocation, WithMetadata): """Simple inpaint using opencv.""" image: ImageField = InputField(description="The image to inpaint") mask: ImageField = InputField(description="The mask to use when inpainting") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - mask = context.services.images.get_pil_image(self.mask.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + mask = context.images.get_pil(self.mask.image_name) # Convert to cv image/mask # TODO: consider making these utility functions @@ -35,18 +35,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # TODO: consider making a utility function image_inpainted = 
Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB)) - image_dto = context.services.images.create( - image=image_inpainted, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) - - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + image_dto = context.images.save(image=image_inpainted) + + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index 13f1066ec3e..2c92e28cfe0 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -1,7 +1,7 @@ import math import re from pathlib import Path -from typing import Optional, TypedDict +from typing import TYPE_CHECKING, Optional, TypedDict import cv2 import numpy as np @@ -13,13 +13,16 @@ import invokeai.assets.fonts as font_assets from invokeai.app.invocations.baseinvocation import ( BaseInvocation, - InvocationContext, + WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import InputField, OutputField, WithMetadata -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField, InputField, OutputField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.image_records.image_records_common import ImageCategory + +if TYPE_CHECKING: + from invokeai.app.services.shared.invocation_context import InvocationContext @invocation_output("face_mask_output") @@ -174,7 +177,7 @@ def prepare_faces_list( def generate_face_box_mask( - context: InvocationContext, + context: "InvocationContext", minimum_confidence: float, x_offset: float, y_offset: float, @@ -273,7 +276,7 @@ def generate_face_box_mask( def extract_face( - context: InvocationContext, + context: "InvocationContext", image: ImageType, face: FaceResultData, padding: int, @@ -304,37 +307,37 @@ def extract_face( # Adjust the crop boundaries to stay within the original image's dimensions if x_min < 0: - context.services.logger.warning("FaceTools --> -X-axis padding reached image edge.") + context.logger.warning("FaceTools --> -X-axis padding reached image edge.") x_max -= x_min x_min = 0 elif x_max > mask.width: - context.services.logger.warning("FaceTools --> +X-axis padding reached image edge.") + context.logger.warning("FaceTools --> +X-axis padding reached image edge.") x_min -= x_max - mask.width x_max = mask.width if y_min < 0: - context.services.logger.warning("FaceTools --> +Y-axis padding reached image edge.") + context.logger.warning("FaceTools --> +Y-axis padding reached image edge.") y_max -= y_min y_min = 0 elif y_max > mask.height: - context.services.logger.warning("FaceTools --> -Y-axis padding reached image edge.") + context.logger.warning("FaceTools --> -Y-axis padding reached image edge.") y_min -= y_max - mask.height y_max = mask.height # Ensure the crop is square and adjust the boundaries if needed if x_max - x_min != crop_size: - context.services.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.") + context.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.") diff = crop_size - (x_max - x_min) x_min -= diff 
// 2 x_max += diff - diff // 2 if y_max - y_min != crop_size: - context.services.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.") + context.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.") diff = crop_size - (y_max - y_min) y_min -= diff // 2 y_max += diff - diff // 2 - context.services.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}") + context.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}") # Crop the output image to the specified size with the center of the face mesh as the center. mask = mask.crop((x_min, y_min, x_max, y_max)) @@ -354,7 +357,7 @@ def extract_face( def get_faces_list( - context: InvocationContext, + context: "InvocationContext", image: ImageType, should_chunk: bool, minimum_confidence: float, @@ -366,7 +369,7 @@ def get_faces_list( # Generate the face box mask and get the center of the face. if not should_chunk: - context.services.logger.info("FaceTools --> Attempting full image face detection.") + context.logger.info("FaceTools --> Attempting full image face detection.") result = generate_face_box_mask( context=context, minimum_confidence=minimum_confidence, @@ -378,7 +381,7 @@ def get_faces_list( draw_mesh=draw_mesh, ) if should_chunk or len(result) == 0: - context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).") + context.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).") width, height = image.size image_chunks = [] x_offsets = [] @@ -397,7 +400,7 @@ def get_faces_list( x_offsets.append(x) y_offsets.append(0) fx += increment - context.services.logger.info(f"FaceTools --> Chunk starting at x = {x}") + context.logger.info(f"FaceTools --> Chunk starting at x = {x}") elif height > width: # Portrait - slice the image vertically fy = 0.0 @@ -409,10 +412,10 @@ def get_faces_list( x_offsets.append(0) y_offsets.append(y) fy += increment - context.services.logger.info(f"FaceTools --> Chunk starting at y = {y}") + context.logger.info(f"FaceTools --> Chunk starting at y = {y}") for idx in range(len(image_chunks)): - context.services.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}") + context.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}") result = result + generate_face_box_mask( context=context, minimum_confidence=minimum_confidence, @@ -426,7 +429,7 @@ def get_faces_list( if len(result) == 0: # Give up - context.services.logger.warning( + context.logger.warning( "FaceTools --> No face detected in chunked input image. Passing through original image." ) @@ -435,7 +438,7 @@ def get_faces_list( return all_faces -@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0") +@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.1") class FaceOffInvocation(BaseInvocation, WithMetadata): """Bound, extract, and mask a face from an image using MediaPipe detection""" @@ -456,7 +459,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. 
Chunking will occur if no faces are found in the full image.", ) - def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]: + def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[ExtractFaceData]: all_faces = get_faces_list( context=context, image=image, @@ -468,11 +471,11 @@ def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[Extr ) if len(all_faces) == 0: - context.services.logger.warning("FaceOff --> No faces detected. Passing through original image.") + context.logger.warning("FaceOff --> No faces detected. Passing through original image.") return None if self.face_id > len(all_faces) - 1: - context.services.logger.warning( + context.logger.warning( f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image." ) return None @@ -483,8 +486,8 @@ def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[Extr return face_data - def invoke(self, context: InvocationContext) -> FaceOffOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> FaceOffOutput: + image = context.images.get_pil(self.image.image_name) result = self.faceoff(context=context, image=image) if result is None: @@ -498,24 +501,9 @@ def invoke(self, context: InvocationContext) -> FaceOffOutput: x = result["x_min"] y = result["y_min"] - image_dto = context.services.images.create( - image=result_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result_image) - mask_dto = context.services.images.create( - image=result_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + mask_dto = context.images.save(image=result_mask, image_category=ImageCategory.MASK) output = FaceOffOutput( image=ImageField(image_name=image_dto.image_name), @@ -529,7 +517,7 @@ def invoke(self, context: InvocationContext) -> FaceOffOutput: return output -@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0") +@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.1") class FaceMaskInvocation(BaseInvocation, WithMetadata): """Face mask creation using mediapipe face detection""" @@ -556,7 +544,7 @@ def validate_comma_separated_ints(cls, v) -> str: raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")') return v - def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult: + def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskResult: all_faces = get_faces_list( context=context, image=image, @@ -578,7 +566,7 @@ def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResu if len(intersected_face_ids) == 0: id_range_str = ",".join([str(id) for id in id_range]) - context.services.logger.warning( + context.logger.warning( f"Face IDs must be in range of detected faces - requested {self.face_ids}, detected {id_range_str}. Passing through original image." 
) return FaceMaskResult( @@ -613,28 +601,13 @@ def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResu mask=mask_pil, ) - def invoke(self, context: InvocationContext) -> FaceMaskOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> FaceMaskOutput: + image = context.images.get_pil(self.image.image_name) result = self.facemask(context=context, image=image) - image_dto = context.services.images.create( - image=result["image"], - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result["image"]) - mask_dto = context.services.images.create( - image=result["mask"], - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + mask_dto = context.images.save(image=result["mask"], image_category=ImageCategory.MASK) output = FaceMaskOutput( image=ImageField(image_name=image_dto.image_name), @@ -647,7 +620,7 @@ def invoke(self, context: InvocationContext) -> FaceMaskOutput: @invocation( - "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0" + "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1" ) class FaceIdentifierInvocation(BaseInvocation, WithMetadata): """Outputs an image with detected face IDs printed on each face. For use with other FaceTools.""" @@ -661,7 +634,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. 
Chunking will occur if no faces are found in the full image.", ) - def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType: + def faceidentifier(self, context: "InvocationContext", image: ImageType) -> ImageType: image = image.copy() all_faces = get_faces_list( @@ -702,22 +675,10 @@ def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageT return image - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) result_image = self.faceidentifier(context=context, image=image) - image_dto = context.services.images.create( - image=result_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 0cce8e3c6b5..566babbb6b7 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -1,11 +1,13 @@ +from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional, Tuple from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter from pydantic.fields import _Unset from pydantic_core import PydanticUndefined from invokeai.app.util.metaenum import MetaEnum +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import BasicConditioningInfo from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.get_logger() @@ -255,6 +257,10 @@ class InputFieldJSONSchemaExtra(BaseModel): class WithMetadata(BaseModel): + """ + Inherit from this class if your node needs a metadata input field. 
+ """ + metadata: Optional[MetadataField] = Field( default=None, description=FieldDescriptions.metadata, @@ -498,4 +504,53 @@ def OutputField( field_kind=FieldKind.Output, ).model_dump(exclude_none=True), ) + + +class ImageField(BaseModel): + """An image primitive field""" + + image_name: str = Field(description="The name of the image") + + +class BoardField(BaseModel): + """A board primitive field""" + + board_id: str = Field(description="The id of the board") + + +class DenoiseMaskField(BaseModel): + """An inpaint mask field""" + + mask_name: str = Field(description="The name of the mask image") + masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") + + +class LatentsField(BaseModel): + """A latents tensor primitive field""" + + latents_name: str = Field(description="The name of the latents") + seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") + + +class ColorField(BaseModel): + """A color primitive field""" + + r: int = Field(ge=0, le=255, description="The red component") + g: int = Field(ge=0, le=255, description="The green component") + b: int = Field(ge=0, le=255, description="The blue component") + a: int = Field(ge=0, le=255, description="The alpha component") + + def tuple(self) -> Tuple[int, int, int, int]: + return (self.r, self.g, self.b, self.a) + + +@dataclass +class ConditioningFieldData: + conditionings: List[BasicConditioningInfo] + + +class ConditioningField(BaseModel): + """A conditioning tensor primitive value""" + + conditioning_name: str = Field(description="The name of conditioning tensor") # endregion diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 16d0f33dda3..10ebd97ace3 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,30 +7,36 @@ import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, WithMetadata -from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin +from invokeai.app.invocations.baseinvocation import WithMetadata +from invokeai.app.invocations.fields import ( + BoardField, + ColorField, + FieldDescriptions, + ImageField, + Input, + InputField, +) +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker from .baseinvocation import ( BaseInvocation, Classification, - InvocationContext, invocation, ) -@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0") +@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1") class ShowImageInvocation(BaseInvocation): """Displays a provided image using the OS image viewer, and passes it forward in the pipeline.""" image: ImageField = InputField(description="The image to show") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - if image: - image.show() + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + image.show() # TODO: how to handle failure? 
@@ -46,7 +52,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Blank Image", tags=["image"], category="image", - version="1.2.0", + version="1.2.1", ) class BlankImageInvocation(BaseInvocation, WithMetadata): """Creates a blank image and forwards it to the pipeline""" @@ -56,25 +62,12 @@ class BlankImageInvocation(BaseInvocation, WithMetadata): mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image") color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple()) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -82,7 +75,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Crop Image", tags=["image", "crop"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageCropInvocation(BaseInvocation, WithMetadata): """Crops an image to a specified box. The box can be outside of the image.""" @@ -93,28 +86,15 @@ class ImageCropInvocation(BaseInvocation, WithMetadata): width: int = InputField(default=512, gt=0, description="The width of the crop rectangle") height: int = InputField(default=512, gt=0, description="The height of the crop rectangle") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)) image_crop.paste(image, (-self.x, -self.y)) - image_dto = context.services.images.create( - image=image_crop, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image_crop) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -145,8 +125,8 @@ class CenterPadCropInvocation(BaseInvocation): description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)", ) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) # Calculate and create new image dimensions new_width = image.width + self.right + self.left @@ -156,20 +136,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Paste new image onto input image_crop.paste(image, (self.left, self.top)) - image_dto = context.services.images.create( - image=image_crop, - 
image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + image_dto = context.images.save(image=image_crop) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -177,7 +146,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Paste Image", tags=["image", "paste"], category="image", - version="1.2.0", + version="1.2.1", ) class ImagePasteInvocation(BaseInvocation, WithMetadata): """Pastes an image into another image.""" @@ -192,12 +161,12 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata): y: int = InputField(default=0, description="The top y coordinate at which to paste the image") crop: bool = InputField(default=False, description="Crop to base image dimensions") - def invoke(self, context: InvocationContext) -> ImageOutput: - base_image = context.services.images.get_pil_image(self.base_image.image_name) - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + base_image = context.images.get_pil(self.base_image.image_name) + image = context.images.get_pil(self.image.image_name) mask = None if self.mask is not None: - mask = context.services.images.get_pil_image(self.mask.image_name) + mask = context.images.get_pil(self.mask.image_name) mask = ImageOps.invert(mask.convert("L")) # TODO: probably shouldn't invert mask here... should user be required to do it? @@ -214,22 +183,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: base_w, base_h = base_image.size new_image = new_image.crop((abs(min_x), abs(min_y), abs(min_x) + base_w, abs(min_y) + base_h)) - image_dto = context.services.images.create( - image=new_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -237,7 +193,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Mask from Alpha", tags=["image", "mask"], category="image", - version="1.2.0", + version="1.2.1", ) class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): """Extracts the alpha channel of an image as a mask.""" @@ -245,29 +201,16 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to create the mask from") invert: bool = InputField(default=False, description="Whether or not to invert the mask") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_mask = image.split()[-1] if self.invert: image_mask = ImageOps.invert(image_mask) - image_dto = context.services.images.create( - image=image_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - 
workflow=context.workflow, - ) + image_dto = context.images.save(image=image_mask, image_category=ImageCategory.MASK) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -275,7 +218,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Multiply Images", tags=["image", "multiply"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageMultiplyInvocation(BaseInvocation, WithMetadata): """Multiplies two images together using `PIL.ImageChops.multiply()`.""" @@ -283,28 +226,15 @@ class ImageMultiplyInvocation(BaseInvocation, WithMetadata): image1: ImageField = InputField(description="The first image to multiply") image2: ImageField = InputField(description="The second image to multiply") - def invoke(self, context: InvocationContext) -> ImageOutput: - image1 = context.services.images.get_pil_image(self.image1.image_name) - image2 = context.services.images.get_pil_image(self.image2.image_name) + def invoke(self, context) -> ImageOutput: + image1 = context.images.get_pil(self.image1.image_name) + image2 = context.images.get_pil(self.image2.image_name) multiply_image = ImageChops.multiply(image1, image2) - image_dto = context.services.images.create( - image=multiply_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=multiply_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) IMAGE_CHANNELS = Literal["A", "R", "G", "B"] @@ -315,7 +245,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Extract Image Channel", tags=["image", "channel"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageChannelInvocation(BaseInvocation, WithMetadata): """Gets a channel from an image.""" @@ -323,27 +253,14 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to get the channel from") channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) channel_image = image.getchannel(self.channel) - image_dto = context.services.images.create( - image=channel_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=channel_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] @@ -354,7 +271,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Convert Image Mode", tags=["image", "convert"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageConvertInvocation(BaseInvocation, 
WithMetadata): """Converts an image to a different mode.""" @@ -362,27 +279,14 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to convert") mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) converted_image = image.convert(self.mode) - image_dto = context.services.images.create( - image=converted_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=converted_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -390,7 +294,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Blur Image", tags=["image", "blur"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageBlurInvocation(BaseInvocation, WithMetadata): """Blurs an image""" @@ -400,30 +304,17 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata): # Metadata blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) blur = ( ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius) ) blur_image = image.filter(blur) - image_dto = context.services.images.create( - image=blur_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=blur_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -431,7 +322,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Unsharp Mask", tags=["image", "unsharp_mask"], category="image", - version="1.2.0", + version="1.2.1", classification=Classification.Beta, ) class UnsharpMaskInvocation(BaseInvocation, WithMetadata): @@ -447,8 +338,8 @@ def pil_from_array(self, arr): def array_from_pil(self, img): return numpy.array(img) / 255 - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) mode = image.mode alpha_channel = image.getchannel("A") if mode == "RGBA" else None @@ -466,16 +357,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: if alpha_channel is not None: image.putalpha(alpha_channel) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - 
session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) return ImageOutput( image=ImageField(image_name=image_dto.image_name), @@ -509,7 +391,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Resize Image", tags=["image", "resize"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageResizeInvocation(BaseInvocation, WithMetadata): """Resizes an image to specific dimensions""" @@ -519,8 +401,8 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata): height: int = InputField(default=512, gt=0, description="The height to resize to (px)") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -529,22 +411,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: resample=resample_mode, ) - image_dto = context.services.images.create( - image=resize_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=resize_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -552,7 +421,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Scale Image", tags=["image", "scale"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageScaleInvocation(BaseInvocation, WithMetadata): """Scales an image by a factor""" @@ -565,8 +434,8 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata): ) resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] width = int(image.width * self.scale_factor) @@ -577,22 +446,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: resample=resample_mode, ) - image_dto = context.services.images.create( - image=resize_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=resize_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -600,7 +456,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Lerp Image", tags=["image", "lerp"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageLerpInvocation(BaseInvocation, WithMetadata): """Linear interpolation of all pixels of an image""" @@ -609,30 +465,17 @@ class 
ImageLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, description="The minimum output value") max: int = InputField(default=255, ge=0, le=255, description="The maximum output value") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 image_arr = image_arr * (self.max - self.min) + self.min lerp_image = Image.fromarray(numpy.uint8(image_arr)) - image_dto = context.services.images.create( - image=lerp_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=lerp_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -640,7 +483,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): """Inverse linear interpolation of all pixels of an image""" @@ -649,30 +492,17 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, description="The minimum input value") max: int = InputField(default=255, ge=0, le=255, description="The maximum input value") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 # type: ignore [assignment] ilerp_image = Image.fromarray(numpy.uint8(image_arr)) - image_dto = context.services.images.create( - image=ilerp_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=ilerp_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -680,17 +510,17 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Blur NSFW Image", tags=["image", "nsfw"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata): """Add blur to NSFW-flagged images""" image: ImageField = InputField(description="The image to check") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) - logger = context.services.logger + logger = context.logger logger.debug("Running NSFW checker") if SafetyChecker.has_nsfw_concept(image): logger.info("A potentially NSFW image has been 
detected. Image will be blurred.") @@ -699,22 +529,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: blurry_image.paste(caution, (0, 0), caution) image = blurry_image - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) def _get_caution_img(self) -> Image.Image: import invokeai.app.assets.images as image_assets @@ -728,7 +545,7 @@ def _get_caution_img(self) -> Image.Image: title="Add Invisible Watermark", tags=["image", "watermark"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageWatermarkInvocation(BaseInvocation, WithMetadata): """Add an invisible watermark to an image""" @@ -736,25 +553,12 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to check") text: str = InputField(default="InvokeAI", description="Watermark text") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) new_image = InvisibleWatermark.add_watermark(image, self.text) - image_dto = context.services.images.create( - image=new_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -762,7 +566,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", - version="1.2.0", + version="1.2.1", ) class MaskEdgeInvocation(BaseInvocation, WithMetadata): """Applies an edge mask to an image""" @@ -775,8 +579,8 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata): description="Second threshold for the hysteresis procedure in Canny edge detection" ) - def invoke(self, context: InvocationContext) -> ImageOutput: - mask = context.services.images.get_pil_image(self.image.image_name).convert("L") + def invoke(self, context) -> ImageOutput: + mask = context.images.get_pil(self.image.image_name).convert("L") npimg = numpy.asarray(mask, dtype=numpy.uint8) npgradient = numpy.uint8(255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0))) @@ -791,22 +595,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: new_mask = ImageOps.invert(new_mask) - image_dto = context.services.images.create( - image=new_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_mask, image_category=ImageCategory.MASK) - return ImageOutput( - 
image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -814,7 +605,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Combine Masks", tags=["image", "mask", "multiply"], category="image", - version="1.2.0", + version="1.2.1", ) class MaskCombineInvocation(BaseInvocation, WithMetadata): """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" @@ -822,28 +613,15 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata): mask1: ImageField = InputField(description="The first mask to combine") mask2: ImageField = InputField(description="The second image to combine") - def invoke(self, context: InvocationContext) -> ImageOutput: - mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L") - mask2 = context.services.images.get_pil_image(self.mask2.image_name).convert("L") + def invoke(self, context) -> ImageOutput: + mask1 = context.images.get_pil(self.mask1.image_name).convert("L") + mask2 = context.images.get_pil(self.mask2.image_name).convert("L") combined_mask = ImageChops.multiply(mask1, mask2) - image_dto = context.services.images.create( - image=combined_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=combined_mask, image_category=ImageCategory.MASK) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -851,7 +629,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Color Correct", tags=["image", "color"], category="image", - version="1.2.0", + version="1.2.1", ) class ColorCorrectInvocation(BaseInvocation, WithMetadata): """ @@ -864,14 +642,14 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata): mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction") mask_blur_radius: float = InputField(default=8, description="Mask blur radius") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: pil_init_mask = None if self.mask is not None: - pil_init_mask = context.services.images.get_pil_image(self.mask.image_name).convert("L") + pil_init_mask = context.images.get_pil(self.mask.image_name).convert("L") - init_image = context.services.images.get_pil_image(self.reference.image_name) + init_image = context.images.get_pil(self.reference.image_name) - result = context.services.images.get_pil_image(self.image.image_name).convert("RGBA") + result = context.images.get_pil(self.image.image_name).convert("RGBA") # if init_image is None or init_mask is None: # return result @@ -945,22 +723,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Paste original on color-corrected generation (using blurred mask) matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask) - image_dto = context.services.images.create( - image=matched_result, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = 
context.images.save(image=matched_result) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -968,7 +733,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Adjust Image Hue", tags=["image", "hue"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): """Adjusts the Hue of an image.""" @@ -976,8 +741,8 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to adjust") hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360") - def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + pil_image = context.images.get_pil(self.image.image_name) # Convert image to HSV color space hsv_image = numpy.array(pil_image.convert("HSV")) @@ -991,24 +756,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert back to PIL format and to original color mode pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) COLOR_CHANNELS = Literal[ @@ -1072,7 +822,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: "value", ], category="image", - version="1.2.0", + version="1.2.1", ) class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): """Add or subtract a value from a specific color channel of an image.""" @@ -1081,8 +831,8 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): channel: COLOR_CHANNELS = InputField(description="Which channel to adjust") offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by") - def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple mode = CHANNEL_FORMATS[self.channel][0] @@ -1101,24 +851,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert back to RGBA format and output pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1143,7 +878,7 @@ def 
invoke(self, context: InvocationContext) -> ImageOutput: "value", ], category="image", - version="1.2.0", + version="1.2.1", ) class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): """Scale a specific color channel of an image.""" @@ -1153,8 +888,8 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.") invert_channel: bool = InputField(default=False, description="Invert the channel after scaling") - def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple mode = CHANNEL_FORMATS[self.channel][0] @@ -1177,24 +912,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert back to RGBA format and output pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - workflow=context.workflow, - metadata=self.metadata, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1202,7 +922,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Save Image", tags=["primitives", "image"], category="primitives", - version="1.2.0", + version="1.2.1", use_cache=False, ) class SaveImageInvocation(BaseInvocation, WithMetadata): @@ -1211,26 +931,12 @@ class SaveImageInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - board_id=self.board.board_id if self.board else None, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1238,7 +944,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Linear UI Image Output", tags=["primitives", "image"], category="primitives", - version="1.0.1", + version="1.0.2", use_cache=False, ) class LinearUIOutputInvocation(BaseInvocation, WithMetadata): @@ -1247,19 +953,13 @@ class LinearUIOutputInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: Optional[BoardField] = InputField(default=None, 
description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context: InvocationContext) -> ImageOutput: - image_dto = context.services.images.get_dto(self.image.image_name) - - if self.board: - context.services.board_images.add_image_to_board(self.board.board_id, self.image.image_name) + def invoke(self, context) -> ImageOutput: + image_dto = context.images.get_dto(self.image.image_name) - if image_dto.is_intermediate != self.is_intermediate: - context.services.images.update( - self.image.image_name, changes=ImageRecordChanges(is_intermediate=self.is_intermediate) - ) - - return ImageOutput( - image=ImageField(image_name=self.image.image_name), - width=image_dto.width, - height=image_dto.height, + image_dto = context.images.update( + image_name=self.image.image_name, + board_id=self.board.board_id if self.board else None, + is_intermediate=self.is_intermediate, ) + + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index d4d3d5bea44..be51c8312f9 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -6,15 +6,15 @@ import numpy as np from PIL import Image, ImageOps -from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ColorField, ImageField +from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.util.misc import SEED_MAX from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch -from .baseinvocation import BaseInvocation, InvocationContext, invocation -from .fields import InputField, WithMetadata +from .baseinvocation import BaseInvocation, WithMetadata, invocation +from .fields import InputField from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES @@ -119,7 +119,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] return si -@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") +@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") class InfillColorInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image with a solid color""" @@ -129,33 +129,20 @@ class InfillColorInvocation(BaseInvocation, WithMetadata): description="The color to use to infill", ) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) solid_bg = Image.new("RGBA", image.size, self.color.tuple()) infilled = Image.alpha_composite(solid_bg, image.convert("RGBA")) infilled.paste(image, (0, 0), image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return 
ImageOutput.build(image_dto) -@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") +@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") class InfillTileInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image with tiles of the image""" @@ -168,32 +155,19 @@ class InfillTileInvocation(BaseInvocation, WithMetadata): description="The seed to use for tile generation (omit for random)", ) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size) infilled.paste(image, (0, 0), image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( - "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0" + "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1" ) class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image using the PatchMatch algorithm""" @@ -202,8 +176,8 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name).convert("RGBA") + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name).convert("RGBA") resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -228,77 +202,38 @@ def invoke(self, context: InvocationContext) -> ImageOutput: infilled.paste(image, (0, 0), mask=image.split()[-1]) # image.paste(infilled, (0, 0), mask=image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) -@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") +@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") class LaMaInfillInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image using the LaMa model""" image: ImageField = InputField(description="The 
image to infill") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) infilled = infill_lama(image.copy()) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) -@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") +@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") class CV2InfillInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image using OpenCV Inpainting""" image: ImageField = InputField(description="The image to infill") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) infilled = infill_cv2(image.copy()) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index c01e0ed0fb2..b836be04b58 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -7,7 +7,6 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -62,7 +61,7 @@ class IPAdapterOutput(BaseInvocationOutput): ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter") -@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.1") +@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.2") class IPAdapterInvocation(BaseInvocation): """Collects IP-Adapter info to pass to other nodes.""" @@ -93,9 +92,9 @@ def validate_begin_end_step_percent(self): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context: InvocationContext) -> IPAdapterOutput: + def invoke(self, context) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. 
- ip_adapter_info = context.services.model_manager.model_info( + ip_adapter_info = context.models.get_info( self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter ) # HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model @@ -104,7 +103,7 @@ def invoke(self, context: InvocationContext) -> IPAdapterOutput: # is currently messy due to differences between how the model info is generated when installing a model from # disk vs. downloading the model. image_encoder_model_id = get_ip_adapter_image_encoder_model_id( - os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"]) + os.path.join(context.config.get().models_path, ip_adapter_info["path"]) ) image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_model = CLIPVisionModelField( diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 909c307481e..0127a6521e1 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,7 +3,7 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import List, Literal, Optional, Union +from typing import TYPE_CHECKING, List, Literal, Optional, Union import einops import numpy as np @@ -23,21 +23,26 @@ from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, WithMetadata +from invokeai.app.invocations.fields import ( + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + ImageField, + Input, + InputField, + LatentsField, + OutputField, + UIType, + WithMetadata, +) from invokeai.app.invocations.ip_adapter import IPAdapterField from invokeai.app.invocations.primitives import ( - DenoiseMaskField, DenoiseMaskOutput, - ImageField, ImageOutput, - LatentsField, LatentsOutput, - build_latents_output, ) from invokeai.app.invocations.t2i_adapter import T2IAdapterField -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.util.controlnet_utils import prepare_control_image -from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo @@ -59,14 +64,15 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) -from .compel import ConditioningField from .controlnet_image_processors import ControlField from .model import ModelInfo, UNetField, VaeField +if TYPE_CHECKING: + from invokeai.app.services.shared.invocation_context import InvocationContext + if choose_torch_device() == torch.device("mps"): from torch import mps @@ -102,7 +108,7 @@ class SchedulerInvocation(BaseInvocation): ui_type=UIType.Scheduler, ) - def invoke(self, context: InvocationContext) -> SchedulerOutput: + def invoke(self, context) -> SchedulerOutput: return SchedulerOutput(scheduler=self.scheduler) @@ -111,7 +117,7 @@ def invoke(self, context: InvocationContext) -> SchedulerOutput: title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", - version="1.0.0", + version="1.0.1", ) class 
CreateDenoiseMaskInvocation(BaseInvocation): """Creates mask for denoising model run.""" @@ -137,9 +143,9 @@ def prep_mask_tensor(self, mask_image): return mask_tensor @torch.no_grad() - def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: + def invoke(self, context) -> DenoiseMaskOutput: if self.image is not None: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) image = image_resized_to_grid_as_tensor(image.convert("RGB")) if image.dim() == 3: image = image.unsqueeze(0) @@ -147,47 +153,37 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: image = None mask = self.prep_mask_tensor( - context.services.images.get_pil_image(self.mask.image_name), + context.images.get_pil(self.mask.image_name), ) if image is not None: - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0) # TODO: masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) - masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents" - context.services.latents.save(masked_latents_name, masked_latents) + masked_latents_name = context.latents.save(tensor=masked_latents) else: masked_latents_name = None - mask_name = f"{context.graph_execution_state_id}__{self.id}_mask" - context.services.latents.save(mask_name, mask) + mask_name = context.latents.save(tensor=mask) - return DenoiseMaskOutput( - denoise_mask=DenoiseMaskField( - mask_name=mask_name, - masked_latents_name=masked_latents_name, - ), + return DenoiseMaskOutput.build( + mask_name=mask_name, + masked_latents_name=masked_latents_name, ) def get_scheduler( - context: InvocationContext, + context: "InvocationContext", scheduler_info: ModelInfo, scheduler_name: str, seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.services.model_manager.get_model( - **scheduler_info.model_dump(), - context=context, - ) + orig_scheduler_info = context.models.load(**scheduler_info.model_dump()) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config @@ -216,7 +212,7 @@ def get_scheduler( title="Denoise Latents", tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], category="latents", - version="1.5.1", + version="1.5.2", ) class DenoiseLatentsInvocation(BaseInvocation): """Denoises noisy latents to decodable images""" @@ -302,34 +298,18 @@ def ge_one(cls, v): raise ValueError("cfg_scale must be greater than 1") return v - # TODO: pass this an emitter method or something? or a session for dispatching? 
- def dispatch_progress( - self, - context: InvocationContext, - source_node_id: str, - intermediate_state: PipelineIntermediateState, - base_model: BaseModelType, - ) -> None: - stable_diffusion_step_callback( - context=context, - intermediate_state=intermediate_state, - node=self.model_dump(), - source_node_id=source_node_id, - base_model=base_model, - ) - def get_conditioning_data( self, - context: InvocationContext, + context: "InvocationContext", scheduler, unet, seed, ) -> ConditioningData: - positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) + positive_cond_data = context.conditioning.get(self.positive_conditioning.conditioning_name) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) extra_conditioning_info = c.extra_conditioning - negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) + negative_cond_data = context.conditioning.get(self.negative_conditioning.conditioning_name) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) conditioning_data = ConditioningData( @@ -389,7 +369,7 @@ def __init__(self): def prep_control_data( self, - context: InvocationContext, + context: "InvocationContext", control_input: Union[ControlField, List[ControlField]], latents_shape: List[int], exit_stack: ExitStack, @@ -417,17 +397,16 @@ def prep_control_data( controlnet_data = [] for control_info in control_list: control_model = exit_stack.enter_context( - context.services.model_manager.get_model( + context.models.load( model_name=control_info.control_model.model_name, model_type=ModelType.ControlNet, base_model=control_info.control_model.base_model, - context=context, ) ) # control_models.append(control_model) control_image_field = control_info.image - input_image = context.services.images.get_pil_image(control_image_field.image_name) + input_image = context.images.get_pil(control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? @@ -463,7 +442,7 @@ def prep_control_data( def prep_ip_adapter_data( self, - context: InvocationContext, + context: "InvocationContext", ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], conditioning_data: ConditioningData, exit_stack: ExitStack, @@ -485,19 +464,17 @@ def prep_ip_adapter_data( conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.services.model_manager.get_model( + context.models.load( model_name=single_ip_adapter.ip_adapter_model.model_name, model_type=ModelType.IPAdapter, base_model=single_ip_adapter.ip_adapter_model.base_model, - context=context, ) ) - image_encoder_model_info = context.services.model_manager.get_model( + image_encoder_model_info = context.models.load( model_name=single_ip_adapter.image_encoder_model.model_name, model_type=ModelType.CLIPVision, base_model=single_ip_adapter.image_encoder_model.base_model, - context=context, ) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. 
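The hunks in this file, like the rest of the patch, apply one mechanical substitution: invocations stop reaching through context.services.* and instead use the narrower accessors on the new invocation context (context.images, context.latents, context.conditioning, context.models, context.config, context.logger, context.util). The sketch below shows the resulting shape of a node under the new API; the "example_blur" name, the _HypotheticalBlurInvocation class, and the Gaussian-blur step are illustrative only and are not part of this patch, while every context call and the ImageOutput.build helper are taken verbatim from hunks in this series.

from PIL import ImageFilter

from .baseinvocation import BaseInvocation, invocation
from .fields import ImageField, InputField
from .primitives import ImageOutput


@invocation("example_blur", title="Example Blur", tags=["image"], category="image", version="1.0.0")
class _HypotheticalBlurInvocation(BaseInvocation):
    """Illustrative only: demonstrates the context accessors this patch migrates to."""

    image: ImageField = InputField(description="The image to blur")

    def invoke(self, context) -> ImageOutput:
        # Old API: context.services.images.get_pil_image(self.image.image_name)
        image = context.images.get_pil(self.image.image_name)
        blurred = image.filter(ImageFilter.GaussianBlur(4))
        # Old API: context.services.images.create(image=..., image_origin=..., image_category=...,
        #          node_id=self.id, session_id=context.graph_execution_state_id, ...)
        image_dto = context.images.save(image=blurred)
        # Old API: ImageOutput(image=ImageField(image_name=...), width=..., height=...)
        return ImageOutput.build(image_dto)

Model loading follows the same pattern: context.services.model_manager.get_model(..., context=context) becomes context.models.load(...), and the removed dispatch_progress/stable_diffusion_step_callback plumbing collapses into context.util.sd_step_callback(state, base_model).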
@@ -505,7 +482,7 @@ def prep_ip_adapter_data( if not isinstance(single_ipa_images, list): single_ipa_images = [single_ipa_images] - single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images] + single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_images] # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. @@ -532,7 +509,7 @@ def prep_ip_adapter_data( def run_t2i_adapters( self, - context: InvocationContext, + context: "InvocationContext", t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]], latents_shape: list[int], do_classifier_free_guidance: bool, @@ -549,13 +526,12 @@ def run_t2i_adapters( t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.services.model_manager.get_model( + t2i_adapter_model_info = context.models.load( model_name=t2i_adapter_field.t2i_adapter_model.model_name, model_type=ModelType.T2IAdapter, base_model=t2i_adapter_field.t2i_adapter_model.base_model, - context=context, ) - image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name) + image = context.images.get_pil(t2i_adapter_field.image.image_name) # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1: @@ -642,30 +618,30 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en return num_inference_steps, timesteps, init_timestep - def prep_inpaint_mask(self, context, latents): + def prep_inpaint_mask(self, context: "InvocationContext", latents): if self.denoise_mask is None: return None, None - mask = context.services.latents.get(self.denoise_mask.mask_name) + mask = context.latents.get(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) if self.denoise_mask.masked_latents_name is not None: - masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name) + masked_latents = context.latents.get(self.denoise_mask.masked_latents_name) else: masked_latents = None return 1 - mask, masked_latents @torch.no_grad() - def invoke(self, context: InvocationContext) -> LatentsOutput: + def invoke(self, context) -> LatentsOutput: with SilenceWarnings(): # this quenches NSFW nag from diffusers seed = None noise = None if self.noise is not None: - noise = context.services.latents.get(self.noise.latents_name) + noise = context.latents.get(self.noise.latents_name) seed = self.noise.seed if self.latents is not None: - latents = context.services.latents.get(self.latents.latents_name) + latents = context.latents.get(self.latents.latents_name) if seed is None: seed = self.latents.seed @@ -691,27 +667,17 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: do_classifier_free_guidance=True, ) - # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] - def step_callback(state: PipelineIntermediateState): - self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model) + context.util.sd_step_callback(state, self.unet.unet.base_model) def _lora_loader(): for 
lora in self.unet.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), - context=context, - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.context.model, lora.weight) del lora_info return - unet_info = context.services.model_manager.get_model( - **self.unet.unet.model_dump(), - context=context, - ) + unet_info = context.models.load(**self.unet.unet.model_dump()) with ( ExitStack() as exit_stack, ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), @@ -787,9 +753,8 @@ def _lora_loader(): if choose_torch_device() == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, result_latents) - return build_latents_output(latents_name=name, latents=result_latents, seed=seed) + name = context.latents.save(tensor=result_latents) + return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed) @invocation( @@ -797,7 +762,7 @@ def _lora_loader(): title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", - version="1.2.0", + version="1.2.1", ) class LatentsToImageInvocation(BaseInvocation, WithMetadata): """Generates an image from latents.""" @@ -814,13 +779,10 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) @torch.no_grad() - def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> ImageOutput: + latents = context.latents.get(self.latents.latents_name) - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae: latents = latents.to(vae.device) @@ -849,7 +811,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: vae.to(dtype=torch.float16) latents = latents.half() - if self.tiled or context.services.configuration.tiled_decode: + if self.tiled or context.config.get().tiled_decode: vae.enable_tiling() else: vae.disable_tiling() @@ -873,22 +835,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput: if choose_torch_device() == torch.device("mps"): mps.empty_cache() - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] @@ -899,7 +848,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Resize Latents", tags=["latents", "resize"], category="latents", - version="1.0.0", + version="1.0.1", ) class ResizeLatentsInvocation(BaseInvocation): """Resizes latents to explicit width/height (in pixels). 
Provided dimensions are floor-divided by 8.""" @@ -921,8 +870,8 @@ class ResizeLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -940,10 +889,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, resized_latents) - return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) + name = context.latents.save(tensor=resized_latents) + return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @invocation( @@ -951,7 +898,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: title="Scale Latents", tags=["latents", "resize"], category="latents", - version="1.0.0", + version="1.0.1", ) class ScaleLatentsInvocation(BaseInvocation): """Scales latents by a given factor.""" @@ -964,8 +911,8 @@ class ScaleLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -984,10 +931,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, resized_latents) - return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) + name = context.latents.save(tensor=resized_latents) + return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @invocation( @@ -995,7 +940,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", - version="1.0.0", + version="1.0.1", ) class ImageToLatentsInvocation(BaseInvocation): """Encodes an image into latents.""" @@ -1055,13 +1000,10 @@ def vae_encode(vae_info, upcast, tiled, image_tensor): return latents @torch.no_grad() - def invoke(self, context: InvocationContext) -> LatentsOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> LatentsOutput: + image = context.images.get_pil(self.image.image_name) - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: @@ -1069,10 +1011,9 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: latents 
= self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor) - name = f"{context.graph_execution_state_id}__{self.id}" latents = latents.to("cpu") - context.services.latents.save(name, latents) - return build_latents_output(latents_name=name, latents=latents, seed=None) + name = context.latents.save(tensor=latents) + return LatentsOutput.build(latents_name=name, latents=latents, seed=None) @singledispatchmethod @staticmethod @@ -1092,7 +1033,7 @@ def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTenso title="Blend Latents", tags=["latents", "blend"], category="latents", - version="1.0.0", + version="1.0.1", ) class BlendLatentsInvocation(BaseInvocation): """Blend two latents using a given alpha. Latents must have same size.""" @@ -1107,9 +1048,9 @@ class BlendLatentsInvocation(BaseInvocation): ) alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents_a = context.services.latents.get(self.latents_a.latents_name) - latents_b = context.services.latents.get(self.latents_b.latents_name) + def invoke(self, context) -> LatentsOutput: + latents_a = context.latents.get(self.latents_a.latents_name) + latents_b = context.latents.get(self.latents_b.latents_name) if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") @@ -1163,10 +1104,8 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, blended_latents) - return build_latents_output(latents_name=name, latents=blended_latents) + name = context.latents.save(tensor=blended_latents) + return LatentsOutput.build(latents_name=name, latents=blended_latents) # The Crop Latents node was copied from @skunkworxdark's implementation here: @@ -1176,7 +1115,7 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): title="Crop Latents", tags=["latents", "crop"], category="latents", - version="1.0.0", + version="1.0.1", ) # TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`. # Currently, if the class names conflict then 'GET /openapi.json' fails. @@ -1210,8 +1149,8 @@ class CropLatentsCoreInvocation(BaseInvocation): description="The height (in px) of the crop rectangle in image space. 
This value will be converted to a dimension in latent space.", ) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR y1 = self.y // LATENT_SCALE_FACTOR @@ -1220,10 +1159,9 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: cropped_latents = latents[..., y1:y2, x1:x2] - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, cropped_latents) + name = context.latents.save(tensor=cropped_latents) - return build_latents_output(latents_name=name, latents=cropped_latents) + return LatentsOutput.build(latents_name=name, latents=cropped_latents) @invocation_output("ideal_size_output") diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 6ca53011f0b..d2dbf049816 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -8,7 +8,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation @invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0") @@ -18,7 +18,7 @@ class AddInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.a + self.b) @@ -29,7 +29,7 @@ class SubtractInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.a - self.b) @@ -40,7 +40,7 @@ class MultiplyInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.a * self.b) @@ -51,7 +51,7 @@ class DivideInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=int(self.a / self.b)) @@ -69,7 +69,7 @@ class RandomIntInvocation(BaseInvocation): low: int = InputField(default=0, description=FieldDescriptions.inclusive_low) high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=np.random.randint(self.low, self.high)) @@ -88,7 +88,7 @@ class RandomFloatInvocation(BaseInvocation): high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high) decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places) - def invoke(self, context: InvocationContext) -> 
FloatOutput: + def invoke(self, context) -> FloatOutput: random_float = np.random.uniform(self.low, self.high) rounded_float = round(random_float, self.decimals) return FloatOutput(value=rounded_float) @@ -110,7 +110,7 @@ class FloatToIntegerInvocation(BaseInvocation): default="Nearest", description="The method to use for rounding" ) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: if self.method == "Nearest": return IntegerOutput(value=round(self.value / self.multiple) * self.multiple) elif self.method == "Floor": @@ -128,7 +128,7 @@ class RoundInvocation(BaseInvocation): value: float = InputField(default=0, description="The float value") decimals: int = InputField(default=0, description="The number of decimal places") - def invoke(self, context: InvocationContext) -> FloatOutput: + def invoke(self, context) -> FloatOutput: return FloatOutput(value=round(self.value, self.decimals)) @@ -196,7 +196,7 @@ def no_unrepresentable_results(cls, v: int, info: ValidationInfo): raise ValueError("Result of exponentiation is not an integer") return v - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return IntegerOutput(value=self.a + self.b) @@ -270,7 +270,7 @@ def no_unrepresentable_results(cls, v: float, info: ValidationInfo): raise ValueError("Root operation resulted in a complex number") return v - def invoke(self, context: InvocationContext) -> FloatOutput: + def invoke(self, context) -> FloatOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return FloatOutput(value=self.a + self.b) diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 399e217dc17..9d74abd8c12 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -5,15 +5,20 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import ControlField -from invokeai.app.invocations.fields import FieldDescriptions, InputField, MetadataField, OutputField, UIType +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + InputField, + MetadataField, + OutputField, + UIType, +) from invokeai.app.invocations.ip_adapter import IPAdapterModelField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField -from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.t2i_adapter import T2IAdapterField from ...version import __version__ @@ -59,7 +64,7 @@ class MetadataItemInvocation(BaseInvocation): label: str = InputField(description=FieldDescriptions.metadata_item_label) value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any) - def invoke(self, context: InvocationContext) -> MetadataItemOutput: + def invoke(self, context) -> MetadataItemOutput: return MetadataItemOutput(item=MetadataItemField(label=self.label, value=self.value)) @@ -76,7 +81,7 @@ class MetadataInvocation(BaseInvocation): description=FieldDescriptions.metadata_item_polymorphic ) - def invoke(self, context: InvocationContext) -> MetadataOutput: + def invoke(self, context) -> MetadataOutput: if isinstance(self.items, 
MetadataItemField): # single metadata item data = {self.items.label: self.items.value} @@ -95,7 +100,7 @@ class MergeMetadataInvocation(BaseInvocation): collection: list[MetadataField] = InputField(description=FieldDescriptions.metadata_collection) - def invoke(self, context: InvocationContext) -> MetadataOutput: + def invoke(self, context) -> MetadataOutput: data = {} for item in self.collection: data.update(item.model_dump()) @@ -213,7 +218,7 @@ class CoreMetadataInvocation(BaseInvocation): description="The start value used for refiner denoising", ) - def invoke(self, context: InvocationContext) -> MetadataOutput: + def invoke(self, context) -> MetadataOutput: """Collects and outputs a CoreMetadata object""" return MetadataOutput( diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index c710c9761b0..f81e559e446 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -10,7 +10,6 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -102,7 +101,7 @@ class LoRAModelField(BaseModel): title="Main Model", tags=["model"], category="model", - version="1.0.0", + version="1.0.1", ) class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" @@ -110,13 +109,13 @@ class MainModelLoaderInvocation(BaseInvocation): model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct) # TODO: precision? - def invoke(self, context: InvocationContext) -> ModelLoaderOutput: + def invoke(self, context) -> ModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.models.exists( model_name=model_name, base_model=base_model, model_type=model_type, @@ -203,7 +202,7 @@ class LoraLoaderOutput(BaseInvocationOutput): clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") -@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0") +@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1") class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -222,14 +221,14 @@ class LoraLoaderInvocation(BaseInvocation): title="CLIP", ) - def invoke(self, context: InvocationContext) -> LoraLoaderOutput: + def invoke(self, context) -> LoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") base_model = self.lora.base_model lora_name = self.lora.model_name - if not context.services.model_manager.model_exists( + if not context.models.exists( base_model=base_model, model_name=lora_name, model_type=ModelType.Lora, @@ -285,7 +284,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput): title="SDXL LoRA", tags=["lora", "model"], category="model", - version="1.0.0", + version="1.0.1", ) class SDXLLoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -311,14 +310,14 @@ class SDXLLoraLoaderInvocation(BaseInvocation): title="CLIP 2", ) - def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: + def invoke(self, context) -> SDXLLoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") base_model = self.lora.base_model lora_name = self.lora.model_name - if not context.services.model_manager.model_exists( + if not context.models.exists( 
base_model=base_model, model_name=lora_name, model_type=ModelType.Lora, @@ -384,7 +383,7 @@ class VAEModelField(BaseModel): model_config = ConfigDict(protected_namespaces=()) -@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0") +@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1") class VaeLoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" @@ -394,12 +393,12 @@ class VaeLoaderInvocation(BaseInvocation): title="VAE", ) - def invoke(self, context: InvocationContext) -> VAEOutput: + def invoke(self, context) -> VAEOutput: base_model = self.vae_model.base_model model_name = self.vae_model.model_name model_type = ModelType.Vae - if not context.services.model_manager.model_exists( + if not context.models.exists( base_model=base_model, model_name=model_name, model_type=model_type, @@ -449,7 +448,7 @@ class SeamlessModeInvocation(BaseInvocation): seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless") seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") - def invoke(self, context: InvocationContext) -> SeamlessModeOutput: + def invoke(self, context) -> SeamlessModeOutput: # Conditionally append 'x' and 'y' based on seamless_x and seamless_y unet = copy.deepcopy(self.unet) vae = copy.deepcopy(self.vae) @@ -485,6 +484,6 @@ class FreeUInvocation(BaseInvocation): s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1) s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2) - def invoke(self, context: InvocationContext) -> UNetOutput: + def invoke(self, context) -> UNetOutput: self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2) return UNetOutput(unet=self.unet) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 2e717ac561b..41641152f04 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -4,15 +4,13 @@ import torch from pydantic import field_validator -from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField -from invokeai.app.invocations.latent import LatentsField +from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField from invokeai.app.util.misc import SEED_MAX from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -67,13 +65,13 @@ class NoiseOutput(BaseInvocationOutput): width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) - -def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): - return NoiseOutput( - noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) + @classmethod + def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": + return cls( + noise=LatentsField(latents_name=latents_name, seed=seed), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) @invocation( @@ -114,7 +112,7 @@ def modulo_seed(cls, v): """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" return v % (SEED_MAX + 1) - def invoke(self, context: InvocationContext) -> 
NoiseOutput: + def invoke(self, context) -> NoiseOutput: noise = get_noise( width=self.width, height=self.height, @@ -122,6 +120,5 @@ def invoke(self, context: InvocationContext) -> NoiseOutput: seed=self.seed, use_cpu=self.use_cpu, ) - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, noise) - return build_noise_output(latents_name=name, latents=noise, seed=self.seed) + name = context.latents.save(tensor=noise) + return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed) diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index b43d7eaef2c..3f8e6669ab8 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -37,7 +37,7 @@ invocation_output, ) from .controlnet_image_processors import ControlField -from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler +from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, get_scheduler from .model import ClipField, ModelInfo, UNetField, VaeField ORT_TO_NP_TYPE = { @@ -63,7 +63,7 @@ class ONNXPromptInvocation(BaseInvocation): prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.model_dump(), ) @@ -201,7 +201,7 @@ def ge_one(cls, v): # based on # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 - def invoke(self, context: InvocationContext) -> LatentsOutput: + def invoke(self, context) -> LatentsOutput: c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) @@ -342,7 +342,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): ) # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) if self.vae.vae.submodel != SubModelType.VaeDecoder: @@ -417,7 +417,7 @@ class OnnxModelLoaderInvocation(BaseInvocation): description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel ) - def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: + def invoke(self, context) -> ONNXModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.ONNX diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index dab9c3dc0f4..bf59e87d270 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -41,7 +41,7 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField @@ -62,7 +62,7 @@ class FloatLinearRangeInvocation(BaseInvocation): 
description="number of values to interpolate over (including start and stop)", ) - def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + def invoke(self, context) -> FloatCollectionOutput: param_list = list(np.linspace(self.start, self.stop, self.steps)) return FloatCollectionOutput(collection=param_list) @@ -110,7 +110,7 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: title="Step Param Easing", tags=["step", "easing"], category="step", - version="1.0.0", + version="1.0.1", ) class StepParamEasingInvocation(BaseInvocation): """Experimental per-step parameter easing for denoising steps""" @@ -130,7 +130,7 @@ class StepParamEasingInvocation(BaseInvocation): # alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing") show_easing_plot: bool = InputField(default=False, description="show easing plot") - def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + def invoke(self, context) -> FloatCollectionOutput: log_diagnostics = False # convert from start_step_percent to nearest step <= (steps * start_step_percent) # start_step = int(np.floor(self.num_steps * self.start_step_percent)) @@ -149,19 +149,19 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: postlist = list(num_poststeps * [self.post_end_value]) if log_diagnostics: - context.services.logger.debug("start_step: " + str(start_step)) - context.services.logger.debug("end_step: " + str(end_step)) - context.services.logger.debug("num_easing_steps: " + str(num_easing_steps)) - context.services.logger.debug("num_presteps: " + str(num_presteps)) - context.services.logger.debug("num_poststeps: " + str(num_poststeps)) - context.services.logger.debug("prelist size: " + str(len(prelist))) - context.services.logger.debug("postlist size: " + str(len(postlist))) - context.services.logger.debug("prelist: " + str(prelist)) - context.services.logger.debug("postlist: " + str(postlist)) + context.logger.debug("start_step: " + str(start_step)) + context.logger.debug("end_step: " + str(end_step)) + context.logger.debug("num_easing_steps: " + str(num_easing_steps)) + context.logger.debug("num_presteps: " + str(num_presteps)) + context.logger.debug("num_poststeps: " + str(num_poststeps)) + context.logger.debug("prelist size: " + str(len(prelist))) + context.logger.debug("postlist size: " + str(len(postlist))) + context.logger.debug("prelist: " + str(prelist)) + context.logger.debug("postlist: " + str(postlist)) easing_class = EASING_FUNCTIONS_MAP[self.easing] if log_diagnostics: - context.services.logger.debug("easing class: " + str(easing_class)) + context.logger.debug("easing class: " + str(easing_class)) easing_list = [] if self.mirror: # "expected" mirroring # if number of steps is even, squeeze duration down to (number_of_steps)/2 @@ -172,7 +172,7 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: base_easing_duration = int(np.ceil(num_easing_steps / 2.0)) if log_diagnostics: - context.services.logger.debug("base easing duration: " + str(base_easing_duration)) + context.logger.debug("base easing duration: " + str(base_easing_duration)) even_num_steps = num_easing_steps % 2 == 0 # even number of steps easing_function = easing_class( start=self.start_value, @@ -184,14 +184,14 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: easing_val = easing_function.ease(step_index) base_easing_vals.append(easing_val) if log_diagnostics: - context.services.logger.debug("step_index: " + str(step_index) + 
", easing_val: " + str(easing_val)) + context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val)) if even_num_steps: mirror_easing_vals = list(reversed(base_easing_vals)) else: mirror_easing_vals = list(reversed(base_easing_vals[0:-1])) if log_diagnostics: - context.services.logger.debug("base easing vals: " + str(base_easing_vals)) - context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals)) + context.logger.debug("base easing vals: " + str(base_easing_vals)) + context.logger.debug("mirror easing vals: " + str(mirror_easing_vals)) easing_list = base_easing_vals + mirror_easing_vals # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely @@ -226,12 +226,12 @@ def invoke(self, context: InvocationContext) -> FloatCollectionOutput: step_val = easing_function.ease(step_index) easing_list.append(step_val) if log_diagnostics: - context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val)) + context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val)) if log_diagnostics: - context.services.logger.debug("prelist size: " + str(len(prelist))) - context.services.logger.debug("easing_list size: " + str(len(easing_list))) - context.services.logger.debug("postlist size: " + str(len(postlist))) + context.logger.debug("prelist size: " + str(len(prelist))) + context.logger.debug("easing_list size: " + str(len(easing_list))) + context.logger.debug("postlist size: " + str(len(postlist))) param_list = prelist + easing_list + postlist diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 22f03454a55..ee04345eed8 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -1,16 +1,26 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) -from typing import Optional, Tuple +from typing import Optional import torch -from pydantic import BaseModel, Field -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent +from invokeai.app.invocations.fields import ( + ColorField, + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + ImageField, + Input, + InputField, + LatentsField, + OutputField, + UIComponent, +) +from invokeai.app.services.images.images_common import ImageDTO from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -49,7 +59,7 @@ class BooleanInvocation(BaseInvocation): value: bool = InputField(default=False, description="The boolean value") - def invoke(self, context: InvocationContext) -> BooleanOutput: + def invoke(self, context) -> BooleanOutput: return BooleanOutput(value=self.value) @@ -65,7 +75,7 @@ class BooleanCollectionInvocation(BaseInvocation): collection: list[bool] = InputField(default=[], description="The collection of boolean values") - def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: + def invoke(self, context) -> BooleanCollectionOutput: return BooleanCollectionOutput(collection=self.collection) @@ -98,7 +108,7 @@ class IntegerInvocation(BaseInvocation): value: int = InputField(default=0, description="The integer value") - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.value) @@ -114,7 +124,7 @@ class IntegerCollectionInvocation(BaseInvocation): collection: list[int] = InputField(default=[], 
description="The collection of integer values") - def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + def invoke(self, context) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=self.collection) @@ -145,7 +155,7 @@ class FloatInvocation(BaseInvocation): value: float = InputField(default=0.0, description="The float value") - def invoke(self, context: InvocationContext) -> FloatOutput: + def invoke(self, context) -> FloatOutput: return FloatOutput(value=self.value) @@ -161,7 +171,7 @@ class FloatCollectionInvocation(BaseInvocation): collection: list[float] = InputField(default=[], description="The collection of float values") - def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + def invoke(self, context) -> FloatCollectionOutput: return FloatCollectionOutput(collection=self.collection) @@ -192,7 +202,7 @@ class StringInvocation(BaseInvocation): value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: return StringOutput(value=self.value) @@ -208,7 +218,7 @@ class StringCollectionInvocation(BaseInvocation): collection: list[str] = InputField(default=[], description="The collection of string values") - def invoke(self, context: InvocationContext) -> StringCollectionOutput: + def invoke(self, context) -> StringCollectionOutput: return StringCollectionOutput(collection=self.collection) @@ -217,18 +227,6 @@ def invoke(self, context: InvocationContext) -> StringCollectionOutput: # region Image -class ImageField(BaseModel): - """An image primitive field""" - - image_name: str = Field(description="The name of the image") - - -class BoardField(BaseModel): - """A board primitive field""" - - board_id: str = Field(description="The id of the board") - - @invocation_output("image_output") class ImageOutput(BaseInvocationOutput): """Base class for nodes that output a single image""" @@ -237,6 +235,14 @@ class ImageOutput(BaseInvocationOutput): width: int = OutputField(description="The width of the image in pixels") height: int = OutputField(description="The height of the image in pixels") + @classmethod + def build(cls, image_dto: ImageDTO) -> "ImageOutput": + return cls( + image=ImageField(image_name=image_dto.image_name), + width=image_dto.width, + height=image_dto.height, + ) + @invocation_output("image_collection_output") class ImageCollectionOutput(BaseInvocationOutput): @@ -247,7 +253,7 @@ class ImageCollectionOutput(BaseInvocationOutput): ) -@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0") +@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1") class ImageInvocation( BaseInvocation, ): @@ -255,8 +261,8 @@ class ImageInvocation( image: ImageField = InputField(description="The image to load") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) return ImageOutput( image=ImageField(image_name=self.image.image_name), @@ -277,7 +283,7 @@ class ImageCollectionInvocation(BaseInvocation): collection: list[ImageField] = InputField(description="The collection of image values") - def invoke(self, context: InvocationContext) -> ImageCollectionOutput: + def invoke(self, context) -> 
ImageCollectionOutput: return ImageCollectionOutput(collection=self.collection) @@ -286,32 +292,24 @@ def invoke(self, context: InvocationContext) -> ImageCollectionOutput: # region DenoiseMask -class DenoiseMaskField(BaseModel): - """An inpaint mask field""" - - mask_name: str = Field(description="The name of the mask image") - masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") - - @invocation_output("denoise_mask_output") class DenoiseMaskOutput(BaseInvocationOutput): """Base class for nodes that output a single image""" denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run") + @classmethod + def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput": + return cls( + denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name), + ) + # endregion # region Latents -class LatentsField(BaseModel): - """A latents tensor primitive field""" - - latents_name: str = Field(description="The name of the latents") - seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") - - @invocation_output("latents_output") class LatentsOutput(BaseInvocationOutput): """Base class for nodes that output a single latents tensor""" @@ -322,6 +320,14 @@ class LatentsOutput(BaseInvocationOutput): width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) + @classmethod + def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": + return cls( + latents=LatentsField(latents_name=latents_name, seed=seed), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) + @invocation_output("latents_collection_output") class LatentsCollectionOutput(BaseInvocationOutput): @@ -333,17 +339,17 @@ class LatentsCollectionOutput(BaseInvocationOutput): @invocation( - "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0" + "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.1" ) class LatentsInvocation(BaseInvocation): """A latents tensor primitive value""" latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) - return build_latents_output(self.latents.latents_name, latents) + return LatentsOutput.build(self.latents.latents_name, latents) @invocation( @@ -360,35 +366,15 @@ class LatentsCollectionInvocation(BaseInvocation): description="The collection of latents tensors", ) - def invoke(self, context: InvocationContext) -> LatentsCollectionOutput: + def invoke(self, context) -> LatentsCollectionOutput: return LatentsCollectionOutput(collection=self.collection) -def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None): - return LatentsOutput( - latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) - - # endregion # region Color -class ColorField(BaseModel): - """A color primitive field""" - - r: int = Field(ge=0, le=255, description="The red component") - g: int = Field(ge=0, le=255, description="The green component") - b: int = Field(ge=0, 
le=255, description="The blue component") - a: int = Field(ge=0, le=255, description="The alpha component") - - def tuple(self) -> Tuple[int, int, int, int]: - return (self.r, self.g, self.b, self.a) - - @invocation_output("color_output") class ColorOutput(BaseInvocationOutput): """Base class for nodes that output a single color""" @@ -411,7 +397,7 @@ class ColorInvocation(BaseInvocation): color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value") - def invoke(self, context: InvocationContext) -> ColorOutput: + def invoke(self, context) -> ColorOutput: return ColorOutput(color=self.color) @@ -420,18 +406,16 @@ def invoke(self, context: InvocationContext) -> ColorOutput: # region Conditioning -class ConditioningField(BaseModel): - """A conditioning tensor primitive value""" - - conditioning_name: str = Field(description="The name of conditioning tensor") - - @invocation_output("conditioning_output") class ConditioningOutput(BaseInvocationOutput): """Base class for nodes that output a single conditioning tensor""" conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond) + @classmethod + def build(cls, conditioning_name: str) -> "ConditioningOutput": + return cls(conditioning=ConditioningField(conditioning_name=conditioning_name)) + @invocation_output("conditioning_collection_output") class ConditioningCollectionOutput(BaseInvocationOutput): @@ -454,7 +438,7 @@ class ConditioningInvocation(BaseInvocation): conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection) - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: return ConditioningOutput(conditioning=self.conditioning) @@ -473,7 +457,7 @@ class ConditioningCollectionInvocation(BaseInvocation): description="The collection of conditioning tensors", ) - def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput: + def invoke(self, context) -> ConditioningCollectionOutput: return ConditioningCollectionOutput(collection=self.collection) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 94b4a217ae7..4f5ef43a568 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -7,7 +7,7 @@ from invokeai.app.invocations.primitives import StringCollectionOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField, UIComponent @@ -29,7 +29,7 @@ class DynamicPromptInvocation(BaseInvocation): max_prompts: int = InputField(default=1, description="The number of prompts to generate") combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator") - def invoke(self, context: InvocationContext) -> StringCollectionOutput: + def invoke(self, context) -> StringCollectionOutput: if self.combinatorial: generator = CombinatorialPromptGenerator() prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) @@ -91,7 +91,7 @@ def promptsFromFile( break return prompts - def invoke(self, context: InvocationContext) -> StringCollectionOutput: + def invoke(self, context) -> StringCollectionOutput: prompts = self.promptsFromFile( self.file_path, self.pre_prompt, diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 62df5bc8047..75a526cfff6 100644 --- a/invokeai/app/invocations/sdxl.py +++ 
b/invokeai/app/invocations/sdxl.py @@ -4,7 +4,6 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -30,7 +29,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") -@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0") +@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1") class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" @@ -39,13 +38,13 @@ class SDXLModelLoaderInvocation(BaseInvocation): ) # TODO: precision? - def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: + def invoke(self, context) -> SDXLModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.models.exists( model_name=model_name, base_model=base_model, model_type=model_type, @@ -116,7 +115,7 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: title="SDXL Refiner Model", tags=["model", "sdxl", "refiner"], category="model", - version="1.0.0", + version="1.0.1", ) class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" @@ -128,13 +127,13 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): ) # TODO: precision? - def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: + def invoke(self, context) -> SDXLRefinerModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.models.exists( model_name=model_name, base_model=base_model, model_type=model_type, diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py index ccbc2f6d924..a4c92d9de56 100644 --- a/invokeai/app/invocations/strings.py +++ b/invokeai/app/invocations/strings.py @@ -5,7 +5,6 @@ from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -33,7 +32,7 @@ class StringSplitNegInvocation(BaseInvocation): string: str = InputField(default="", description="String to split", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringPosNegOutput: + def invoke(self, context) -> StringPosNegOutput: p_string = "" n_string = "" brackets_depth = 0 @@ -77,7 +76,7 @@ class StringSplitInvocation(BaseInvocation): default="", description="Delimiter to spilt with. 
blank will split on the first whitespace" ) - def invoke(self, context: InvocationContext) -> String2Output: + def invoke(self, context) -> String2Output: result = self.string.split(self.delimiter, 1) if len(result) == 2: part1, part2 = result @@ -95,7 +94,7 @@ class StringJoinInvocation(BaseInvocation): string_left: str = InputField(default="", description="String Left", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_right or ""))) @@ -107,7 +106,7 @@ class StringJoinThreeInvocation(BaseInvocation): string_middle: str = InputField(default="", description="String Middle", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_middle or "") + (self.string_right or ""))) @@ -126,7 +125,7 @@ class StringReplaceInvocation(BaseInvocation): default=False, description="Use search string as a regex expression (non regex is case insensitive)" ) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: pattern = self.search_string or "" new_string = self.string or "" if len(pattern) > 0: diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 66ac87c37b8..74a098a501c 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -5,13 +5,11 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField -from invokeai.app.invocations.primitives import ImageField +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.backend.model_management.models.base import BaseModelType @@ -91,7 +89,7 @@ def validate_begin_end_step_percent(self): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context: InvocationContext) -> T2IAdapterOutput: + def invoke(self, context) -> T2IAdapterOutput: return T2IAdapterOutput( t2i_adapter=T2IAdapterField( image=self.image, diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index bdc23ef6edd..dd34c3dc093 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -8,13 +8,12 @@ BaseInvocation, BaseInvocationOutput, Classification, - InvocationContext, + WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import Input, InputField, OutputField, WithMetadata -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField +from invokeai.app.invocations.primitives import 
ImageOutput from invokeai.backend.tiles.tiles import ( calc_tiles_even_split, calc_tiles_min_overlap, @@ -58,7 +57,7 @@ class CalculateImageTilesInvocation(BaseInvocation): description="The target overlap, in pixels, between adjacent tiles. Adjacent tiles will overlap by at least this amount", ) - def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + def invoke(self, context) -> CalculateImageTilesOutput: tiles = calc_tiles_with_overlap( image_height=self.image_height, image_width=self.image_width, @@ -101,7 +100,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation): description="The overlap, in pixels, between adjacent tiles.", ) - def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + def invoke(self, context) -> CalculateImageTilesOutput: tiles = calc_tiles_even_split( image_height=self.image_height, image_width=self.image_width, @@ -131,7 +130,7 @@ class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation): tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.") min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.") - def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + def invoke(self, context) -> CalculateImageTilesOutput: tiles = calc_tiles_min_overlap( image_height=self.image_height, image_width=self.image_width, @@ -176,7 +175,7 @@ class TileToPropertiesInvocation(BaseInvocation): tile: Tile = InputField(description="The tile to split into properties.") - def invoke(self, context: InvocationContext) -> TileToPropertiesOutput: + def invoke(self, context) -> TileToPropertiesOutput: return TileToPropertiesOutput( coords_left=self.tile.coords.left, coords_right=self.tile.coords.right, @@ -213,7 +212,7 @@ class PairTileImageInvocation(BaseInvocation): image: ImageField = InputField(description="The tile image.") tile: Tile = InputField(description="The tile properties.") - def invoke(self, context: InvocationContext) -> PairTileImageOutput: + def invoke(self, context) -> PairTileImageOutput: return PairTileImageOutput( tile_with_image=TileWithImage( tile=self.tile, @@ -249,7 +248,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.", ) - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: images = [twi.image for twi in self.tiles_with_images] tiles = [twi.tile for twi in self.tiles_with_images] @@ -265,7 +264,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # existed in memory at an earlier point in the graph. 
tile_np_images: list[np.ndarray] = [] for image in images: - pil_image = context.services.images.get_pil_image(image.image_name) + pil_image = context.images.get_pil(image.image_name) pil_image = pil_image.convert("RGB") tile_np_images.append(np.array(pil_image)) @@ -288,18 +287,5 @@ def invoke(self, context: InvocationContext) -> ImageOutput: # Convert into a PIL image and save pil_image = Image.fromarray(np_image) - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + image_dto = context.images.save(image=pil_image) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 2cab279a9fc..ef174809860 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -8,13 +8,13 @@ from PIL import Image from pydantic import ConfigDict -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField +from invokeai.app.invocations.primitives import ImageOutput from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.util.devices import choose_torch_device -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithMetadata # TODO: Populate this from disk? 
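Taken together, the image-handling hunks in this patch collapse the old multi-argument `context.services.images.create(...)` / `context.services.images.get_pil_image(...)` calls into `context.images.get_pil(...)`, `context.images.save(...)` and `ImageOutput.build(...)`. A minimal sketch of a node written against that surface; the node itself is hypothetical, and the imports and calls simply mirror the ones used in the hunks above:

```py
# Hypothetical example node, not part of this patch. It assumes the new context API
# exactly as used above: context.images.get_pil() returns a PIL image,
# context.images.save() returns an ImageDTO, and ImageOutput.build() wraps that DTO.
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput


@invocation("example_grayscale", title="Grayscale (example)", tags=["image"], category="image", version="1.0.0")
class ExampleGrayscaleInvocation(BaseInvocation, WithMetadata):
    """Converts an image to grayscale (illustrative only)."""

    image: ImageField = InputField(description="The image to convert")

    def invoke(self, context) -> ImageOutput:
        # Load the input image as a PIL image via the new wrapper.
        pil_image = context.images.get_pil(self.image.image_name)
        grayscale = pil_image.convert("L")
        # Saving returns an ImageDTO; metadata and workflow on the queue item are
        # attached automatically, per the ImagesInterface.save docstring.
        image_dto = context.images.save(image=grayscale)
        return ImageOutput.build(image_dto)
```

Compared with the removed code, the node no longer threads `node_id`, `session_id`, `is_intermediate`, metadata or workflow through the call by hand; the context wrapper fills those in from the queue item.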
@@ -30,7 +30,7 @@ from torch import mps -@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0") +@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1") class ESRGANInvocation(BaseInvocation, WithMetadata): """Upscales an image using RealESRGAN.""" @@ -42,9 +42,9 @@ class ESRGANInvocation(BaseInvocation, WithMetadata): model_config = ConfigDict(protected_namespaces=()) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - models_path = context.services.configuration.models_path + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + models_path = context.config.get().models_path rrdbnet_model = None netscale = None @@ -88,7 +88,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: netscale = 2 else: msg = f"Invalid RealESRGAN model: {self.model_name}" - context.services.logger.error(msg) + context.logger.error(msg) raise ValueError(msg) esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}") @@ -111,19 +111,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput: if choose_torch_device() == torch.device("mps"): mps.empty_cache() - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index e9365f33495..ad08ae03956 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -55,7 +55,7 @@ def emit_generator_progress( queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - node: dict, + node_id: str, source_node_id: str, progress_image: Optional[ProgressImage], step: int, @@ -70,7 +70,7 @@ def emit_generator_progress( "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "node_id": node.get("id"), + "node_id": node_id, "source_node_id": source_node_id, "progress_image": progress_image.model_dump() if progress_image is not None else None, "step": step, diff --git a/invokeai/app/services/invocation_processor/invocation_processor_default.py b/invokeai/app/services/invocation_processor/invocation_processor_default.py index 54342c0da13..d2ebe235e63 100644 --- a/invokeai/app/services/invocation_processor/invocation_processor_default.py +++ b/invokeai/app/services/invocation_processor/invocation_processor_default.py @@ -5,11 +5,11 @@ from typing import Optional import invokeai.backend.util.logging as logger -from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem from invokeai.app.services.invocation_stats.invocation_stats_common import ( GESStatsNotFoundError, ) +from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import 
Profiler from ..invoker import Invoker @@ -131,16 +131,20 @@ def stats_cleanup(graph_execution_state_id: str) -> None: # which handles a few things: # - nodes that require a value, but get it only from a connection # - referencing the invocation cache instead of executing the node - outputs = invocation.invoke_internal( - InvocationContext( - services=self.__invoker.services, - graph_execution_state_id=graph_execution_state.id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - queue_batch_id=queue_item.session_queue_batch_id, - workflow=queue_item.workflow, - ) + context_data = InvocationContextData( + invocation=invocation, + session_id=graph_id, + workflow=queue_item.workflow, + source_node_id=source_node_id, + queue_id=queue_item.session_queue_id, + queue_item_id=queue_item.session_queue_item_id, + batch_id=queue_item.session_queue_batch_id, + ) + context = build_invocation_context( + services=self.__invoker.services, + context_data=context_data, ) + outputs = invocation.invoke_internal(context=context, services=self.__invoker.services) # Check queue to see if this is canceled, and skip if so if self.__invoker.services.queue.is_canceled(graph_execution_state.id): diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index 4c2fc4c085c..a9b53ae2242 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -5,11 +5,12 @@ from abc import ABC, abstractmethod from logging import Logger from pathlib import Path -from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union +from typing import Callable, List, Literal, Optional, Tuple, Union from pydantic import Field from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_management import ( AddModelResult, BaseModelType, @@ -21,9 +22,6 @@ ) from invokeai.backend.model_management.model_cache import CacheStats -if TYPE_CHECKING: - from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext - class ModelManagerServiceBase(ABC): """Responsible for managing models on disk and in memory""" @@ -49,8 +47,7 @@ def get_model( base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - node: Optional[BaseInvocation] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> ModelInfo: """Retrieve the indicated model with name and type. 
submodel can be used to get a part (such as the vae) diff --git a/invokeai/app/services/model_manager/model_manager_common.py b/invokeai/app/services/model_manager/model_manager_common.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index cdb3e59a91c..b641dd3f1ed 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -11,6 +11,8 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_management import ( AddModelResult, BaseModelType, @@ -30,7 +32,7 @@ from .model_manager_base import ModelManagerServiceBase if TYPE_CHECKING: - from invokeai.app.invocations.baseinvocation import InvocationContext + pass # simple implementation @@ -86,13 +88,16 @@ def __init__( ) logger.info("Model manager service initialized") + def start(self, invoker: Invoker) -> None: + self._invoker: Optional[Invoker] = invoker + def get_model( self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> ModelInfo: """ Retrieve the indicated model. submodel can be used to get a @@ -100,9 +105,9 @@ def get_model( """ # we can emit model loading events if we are executing with access to the invocation context - if context: + if context_data is not None: self._emit_load_event( - context=context, + context_data=context_data, model_name=model_name, base_model=base_model, model_type=model_type, @@ -116,9 +121,9 @@ def get_model( submodel, ) - if context: + if context_data is not None: self._emit_load_event( - context=context, + context_data=context_data, model_name=model_name, base_model=base_model, model_type=model_type, @@ -263,22 +268,25 @@ def commit(self, conf_file: Optional[Path] = None): def _emit_load_event( self, - context: InvocationContext, + context_data: InvocationContextData, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, model_info: Optional[ModelInfo] = None, ): - if context.services.queue.is_canceled(context.graph_execution_state_id): + if self._invoker is None: + return + + if self._invoker.services.queue.is_canceled(context_data.session_id): raise CanceledException() if model_info: - context.services.events.emit_model_load_completed( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, + self._invoker.services.events.emit_model_load_completed( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, model_name=model_name, base_model=base_model, model_type=model_type, @@ -286,11 +294,11 @@ def _emit_load_event( model_info=model_info, ) else: - context.services.events.emit_model_load_started( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, 
+ self._invoker.services.events.emit_model_load_started( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, model_name=model_name, base_model=base_model, model_type=model_type, diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index ba05b050c5b..c0699eb96bb 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -13,7 +13,6 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -202,7 +201,7 @@ class GraphInvocation(BaseInvocation): # TODO: figure out how to create a default here graph: "Graph" = InputField(description="The graph to run", default=None) - def invoke(self, context: InvocationContext) -> GraphInvocationOutput: + def invoke(self, context) -> GraphInvocationOutput: """Invoke with provided services and return outputs.""" return GraphInvocationOutput() @@ -228,7 +227,7 @@ class IterateInvocation(BaseInvocation): ) index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True) - def invoke(self, context: InvocationContext) -> IterateInvocationOutput: + def invoke(self, context) -> IterateInvocationOutput: """Produces the outputs as values""" return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection)) @@ -255,7 +254,7 @@ class CollectInvocation(BaseInvocation): description="The collection, will be provided on execution", default=[], ui_hidden=True ) - def invoke(self, context: InvocationContext) -> CollectInvocationOutput: + def invoke(self, context) -> CollectInvocationOutput: """Invoke with provided services and return outputs.""" return CollectInvocationOutput(collection=copy.copy(self.collection)) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index c0aaac54f87..b68e521c73f 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -6,8 +6,7 @@ from pydantic import ConfigDict from torch import Tensor -from invokeai.app.invocations.compel import ConditioningFieldData -from invokeai.app.invocations.fields import MetadataField, WithMetadata +from invokeai.app.invocations.fields import ConditioningFieldData, MetadataField, WithMetadata from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO @@ -245,13 +244,15 @@ def save(conditioning_data: ConditioningFieldData) -> str: ) return name - def get(conditioning_name: str) -> Tensor: + def get(conditioning_name: str) -> ConditioningFieldData: """ Gets conditioning data by name. :param conditioning_name: The name of the conditioning data to get. """ - return services.latents.get(conditioning_name) + # TODO(sm): We are (ab)using the latents storage service as a general pickle storage + # service, but it is typed as returning tensors, so we need to ignore the type here. 
+ return services.latents.get(conditioning_name) # type: ignore [return-value] self.save = save self.get = get diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 5cc3caa9ba5..d83b380d95d 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,25 +1,18 @@ -from typing import Protocol +from typing import TYPE_CHECKING import torch from PIL import Image -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage -from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC -from invokeai.app.services.shared.invocation_context import InvocationContextData from ...backend.model_management.models import BaseModelType from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util.util import image_to_dataURL - -class StepCallback(Protocol): - def __call__( - self, - intermediate_state: PipelineIntermediateState, - base_model: BaseModelType, - ) -> None: - ... +if TYPE_CHECKING: + from invokeai.app.services.events.events_base import EventServiceBase + from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC + from invokeai.app.services.shared.invocation_context import InvocationContextData def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None): @@ -38,11 +31,11 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix= def stable_diffusion_step_callback( - context_data: InvocationContextData, + context_data: "InvocationContextData", intermediate_state: PipelineIntermediateState, base_model: BaseModelType, - invocation_queue: InvocationQueueABC, - events: EventServiceBase, + invocation_queue: "InvocationQueueABC", + events: "EventServiceBase", ) -> None: if invocation_queue.is_canceled(context_data.session_id): raise CanceledException From 5f1712fccbc01d25fecd5f152e5a228f683ca59d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:23:27 +1100 Subject: [PATCH 005/340] docs: update INVOCATIONS.md --- docs/contributing/INVOCATIONS.md | 97 ++++++++++++-------------------- 1 file changed, 36 insertions(+), 61 deletions(-) diff --git a/docs/contributing/INVOCATIONS.md b/docs/contributing/INVOCATIONS.md index 124589f44ce..5d9a3690bad 100644 --- a/docs/contributing/INVOCATIONS.md +++ b/docs/contributing/INVOCATIONS.md @@ -9,11 +9,15 @@ complex functionality. ## Invocations Directory -InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These can be used as examples to create your own nodes. +InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These +can be used as examples to create your own nodes. -New nodes should be added to a subfolder in `nodes` direction found at the root level of the InvokeAI installation location. Nodes added to this folder will be able to be used upon application startup. +New nodes should be added to a subfolder in `nodes` direction found at the root +level of the InvokeAI installation location. Nodes added to this folder will be +able to be used upon application startup. 
+ +Example `nodes` subfolder structure: -Example `nodes` subfolder structure: ```py ├── __init__.py # Invoke-managed custom node loader │ @@ -30,14 +34,14 @@ Example `nodes` subfolder structure: └── fancy_node.py ``` -Each node folder must have an `__init__.py` file that imports its nodes. Only nodes imported in the `__init__.py` file are loaded. - See the README in the nodes folder for more examples: +Each node folder must have an `__init__.py` file that imports its nodes. Only +nodes imported in the `__init__.py` file are loaded. See the README in the nodes +folder for more examples: ```py from .cool_node import CoolInvocation ``` - ## Creating A New Invocation In order to understand the process of creating a new Invocation, let us actually @@ -131,7 +135,6 @@ from invokeai.app.invocations.primitives import ImageField class ResizeInvocation(BaseInvocation): '''Resizes an image''' - # Inputs image: ImageField = InputField(description="The input image") width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") @@ -167,12 +170,11 @@ from invokeai.app.invocations.primitives import ImageField class ResizeInvocation(BaseInvocation): '''Resizes an image''' - # Inputs image: ImageField = InputField(description="The input image") width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context: InvocationContext): + def invoke(self, context): pass ``` @@ -197,12 +199,11 @@ from invokeai.app.invocations.image import ImageOutput class ResizeInvocation(BaseInvocation): '''Resizes an image''' - # Inputs image: ImageField = InputField(description="The input image") width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: pass ``` @@ -228,31 +229,18 @@ class ResizeInvocation(BaseInvocation): width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context: InvocationContext) -> ImageOutput: - # Load the image using InvokeAI's predefined Image Service. Returns the PIL image. - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + # Load the input image as a PIL image + image = context.images.get_pil(self.image.image_name) - # Resizing the image + # Resize the image resized_image = image.resize((self.width, self.height)) - # Save the image using InvokeAI's predefined Image Service. Returns the prepared PIL image. 
- output_image = context.services.images.create( - image=resized_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) - - # Returning the Image - return ImageOutput( - image=ImageField( - image_name=output_image.image_name, - ), - width=output_image.width, - height=output_image.height, - ) + # Save the image + image_dto = context.images.save(image=resized_image) + + # Return an ImageOutput + return ImageOutput.build(image_dto) ``` **Note:** Do not be overwhelmed by the `ImageOutput` process. InvokeAI has a @@ -343,27 +331,25 @@ class ImageColorStringOutput(BaseInvocationOutput): That's all there is to it. - +Custom fields only support connection inputs in the Workflow Editor. From b4a4c877250bf71d76ee0346742084b11dddd91b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:23:38 +1100 Subject: [PATCH 006/340] tests: fix tests for new invocation context --- tests/aa_nodes/test_graph_execution_state.py | 8 +------- tests/aa_nodes/test_nodes.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index fab1fa4598f..9cc30e43e11 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -21,7 +21,6 @@ from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService -from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID from invokeai.app.services.shared.graph import ( CollectInvocation, Graph, @@ -86,12 +85,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B print(f"invoking {n.id}: {type(n)}") o = n.invoke( InvocationContext( - queue_batch_id="1", - queue_item_id=1, - queue_id=DEFAULT_QUEUE_ID, - services=services, - graph_execution_state_id="1", - workflow=None, + conditioning=None, config=None, data=None, images=None, latents=None, logger=None, models=None, util=None ) ) g.complete(n.id, o) diff --git a/tests/aa_nodes/test_nodes.py b/tests/aa_nodes/test_nodes.py index e71daad3f3a..559457c0e11 100644 --- a/tests/aa_nodes/test_nodes.py +++ b/tests/aa_nodes/test_nodes.py @@ -3,7 +3,6 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -21,7 +20,7 @@ class ListPassThroughInvocationOutput(BaseInvocationOutput): class ListPassThroughInvocation(BaseInvocation): collection: list[ImageField] = InputField(default=[]) - def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: + def invoke(self, context) -> ListPassThroughInvocationOutput: return ListPassThroughInvocationOutput(collection=self.collection) @@ -34,13 +33,13 @@ class PromptTestInvocationOutput(BaseInvocationOutput): class PromptTestInvocation(BaseInvocation): prompt: str = InputField(default="") - def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: + def invoke(self, context) -> PromptTestInvocationOutput: return PromptTestInvocationOutput(prompt=self.prompt) @invocation("test_error", version="1.0.0") class ErrorInvocation(BaseInvocation): - def 
invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: + def invoke(self, context) -> PromptTestInvocationOutput: raise Exception("This invocation is supposed to fail") @@ -54,7 +53,7 @@ class TextToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") prompt2: str = InputField(default="") - def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: + def invoke(self, context) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -63,7 +62,7 @@ class ImageToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") image: Union[ImageField, None] = InputField(default=None) - def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: + def invoke(self, context) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -76,7 +75,7 @@ class PromptCollectionTestInvocationOutput(BaseInvocationOutput): class PromptCollectionTestInvocation(BaseInvocation): collection: list[str] = InputField() - def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: + def invoke(self, context) -> PromptCollectionTestInvocationOutput: return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) @@ -89,7 +88,7 @@ class AnyTypeTestInvocationOutput(BaseInvocationOutput): class AnyTypeTestInvocation(BaseInvocation): value: Any = InputField(default=None) - def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput: + def invoke(self, context) -> AnyTypeTestInvocationOutput: return AnyTypeTestInvocationOutput(value=self.value) @@ -97,7 +96,7 @@ def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput: class PolymorphicStringTestInvocation(BaseInvocation): value: Union[str, list[str]] = InputField(default="") - def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: + def invoke(self, context) -> PromptCollectionTestInvocationOutput: if isinstance(self.value, str): return PromptCollectionTestInvocationOutput(collection=[self.value]) return PromptCollectionTestInvocationOutput(collection=self.value) From 02db1d04938e11fac6bd60efa92ef82578336a7b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 14 Jan 2024 00:05:15 +1100 Subject: [PATCH 007/340] feat(nodes): tidy `invocation_context.py`, improve comments --- .../app/services/shared/invocation_context.py | 115 ++++++++++++------ 1 file changed, 80 insertions(+), 35 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index b68e521c73f..7961c011aff 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from enum import Enum from typing import TYPE_CHECKING, Optional from PIL.Image import Image @@ -37,6 +36,9 @@ When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere with each other. +Many of the wrappers have the same signature as the methods they wrap. This allows us to write +user-facing docstrings and not need to go and update the internal services to match. + Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them. 
""" @@ -44,12 +46,19 @@ @dataclass(frozen=True) class InvocationContextData: invocation: "BaseInvocation" + """The invocation that is being executed.""" session_id: str + """The session that is being executed.""" queue_id: str + """The queue in which the session is being executed.""" source_node_id: str + """The ID of the node from which the currently executing invocation was prepared.""" queue_item_id: int + """The ID of the queue item that is being executed.""" batch_id: str + """The ID of the batch that is being executed.""" workflow: Optional[WorkflowWithoutID] = None + """The workflow associated with this queue item, if any.""" class LoggerInterface: @@ -103,14 +112,15 @@ def save( """ Saves an image, returning its DTO. - If the current queue item has a workflow, it is automatically saved with the image. + If the current queue item has a workflow or metadata, it is automatically saved with the image. :param image: The image to save, as a PIL image. :param board_id: The board ID to add the image to, if it should be added. - :param image_category: The category of the image. Only the GENERAL category is added to the gallery. - :param metadata: The metadata to save with the image, if it should have any. If the invocation inherits \ - from `WithMetadata`, that metadata will be used automatically. Provide this only if you want to \ - override or provide metadata manually. + :param image_category: The category of the image. Only the GENERAL category is added \ + to the gallery. + :param metadata: The metadata to save with the image, if it should have any. If the \ + invocation inherits from `WithMetadata`, that metadata will be used automatically. \ + **Use this only if you want to override or provide metadata manually!** """ # If the invocation inherits metadata, use that. Else, use the metadata passed in. @@ -186,14 +196,6 @@ def update( self.update = update -class LatentsKind(str, Enum): - IMAGE = "image" - NOISE = "noise" - MASK = "mask" - MASKED_IMAGE = "masked_image" - OTHER = "other" - - class LatentsInterface: def __init__( self, @@ -206,6 +208,22 @@ def save(tensor: Tensor) -> str: :param tensor: The latents tensor to save. """ + + # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. + # "mask", "noise", "masked_latents", etc. + # + # Retaining that capability in this wrapper would require either many different methods + # to save latents, or extra args for this method. Instead of complicating the API, we + # will use the same naming scheme for all latents. + # + # This has a very minor impact as we don't use them after a session completes. + + # Previously, invocations chose the name for their latents. This is a bit risky, so we + # will generate a name for them instead. We use a uuid to ensure the name is unique. + # + # Because the name of the latents file will includes the session and invocation IDs, + # we don't need to worry about collisions. A truncated UUIDv4 is fine. + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" services.latents.save( name=name, @@ -231,12 +249,21 @@ def __init__( services: InvocationServices, context_data: InvocationContextData, ) -> None: + # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage + # service, but it is typed to work with Tensors only. We have to fudge the types here. + def save(conditioning_data: ConditioningFieldData) -> str: """ Saves a conditioning data object, returning its name. 
:param conditioning_data: The conditioning data to save. """ + + # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. + # + # See comment for `LatentsInterface.save` for more info about this method (it's very + # similar). + name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning" services.latents.save( name=name, @@ -250,9 +277,8 @@ def get(conditioning_name: str) -> ConditioningFieldData: :param conditioning_name: The name of the conditioning data to get. """ - # TODO(sm): We are (ab)using the latents storage service as a general pickle storage - # service, but it is typed as returning tensors, so we need to ignore the type here. - return services.latents.get(conditioning_name) # type: ignore [return-value] + + return services.latents.get(conditioning_name) # type: ignore [return-value] self.save = save self.get = get @@ -281,6 +307,17 @@ def load( :param model_type: The type of the model to get. :param submodel: The submodel of the model to get. """ + + # During this call, the model manager emits events with model loading status. The model + # manager itself has access to the events services, but does not have access to the + # required metadata for the events. + # + # For example, it needs access to the node's ID so that the events can be associated + # with the execution of a specific node. + # + # While this is available within the node, it's tedious to need to pass it in on every + # call. We can avoid that by wrapping the method here. + return services.model_manager.get_model( model_name, base_model, model_type, submodel, context_data=context_data ) @@ -306,8 +343,11 @@ def get() -> InvokeAIAppConfig: """ Gets the app's config. """ - # The config can be changed at runtime. We don't want nodes doing this, so we make a - # frozen copy.. + + # The config can be changed at runtime. + # + # We don't want nodes doing this, so we make a frozen copy. + config = services.configuration.get_config() frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) return frozen_config @@ -330,6 +370,12 @@ def sd_step_callback( :param intermediate_state: The intermediate state of the diffusion pipeline. :param base_model: The base model for the current denoising step. """ + + # The step callback needs access to the events and the invocation queue services, but this + # represents a dangerous level of access. + # + # We wrap the step callback so that nodes do not have direct access to these services. + stable_diffusion_step_callback( context_data=context_data, intermediate_state=intermediate_state, @@ -343,36 +389,36 @@ def sd_step_callback( class InvocationContext: """ - The invocation context provides access to various services and data about the current invocation. + The `InvocationContext` provides access to various services and data for the current invocation. """ def __init__( self, images: ImagesInterface, latents: LatentsInterface, + conditioning: ConditioningInterface, models: ModelsInterface, - config: ConfigInterface, logger: LoggerInterface, - data: InvocationContextData, + config: ConfigInterface, util: UtilInterface, - conditioning: ConditioningInterface, + data: InvocationContextData, ) -> None: self.images = images - "Provides methods to save, get and update images and their metadata." - self.logger = logger - "Provides access to the app logger." 
+ """Provides methods to save, get and update images and their metadata.""" self.latents = latents - "Provides methods to save and get latents tensors, including image, noise, masks, and masked images." + """Provides methods to save and get latents tensors, including image, noise, masks, and masked images.""" self.conditioning = conditioning - "Provides methods to save and get conditioning data." + """Provides methods to save and get conditioning data.""" self.models = models - "Provides methods to check if a model exists, get a model, and get a model's info." + """Provides methods to check if a model exists, get a model, and get a model's info.""" + self.logger = logger + """Provides access to the app logger.""" self.config = config - "Provides access to the app's config." - self.data = data - "Provides data about the current queue item and invocation." + """Provides access to the app's config.""" self.util = util - "Provides utility methods." + """Provides utility methods.""" + self.data = data + """Provides data about the current queue item and invocation.""" def build_invocation_context( @@ -380,8 +426,7 @@ def build_invocation_context( context_data: InvocationContextData, ) -> InvocationContext: """ - Builds the invocation context. This is a wrapper around the invocation services that provides - a more convenient (and less dangerous) interface for nodes to use. + Builds the invocation context for a specific invocation execution. :param invocation_services: The invocation services to wrap. :param invocation_context_data: The invocation context data. From a08221656b7ba0e6c51ba85fcc20d27c2b8ad46d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 14 Jan 2024 00:34:56 +1100 Subject: [PATCH 008/340] chore: ruff --- invokeai/app/invocations/baseinvocation.py | 1 - invokeai/app/invocations/onnx.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index c4aed1fac5a..df0596c9a15 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -22,7 +22,6 @@ Input, InputFieldJSONSchemaExtra, MetadataField, - logger, ) from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.shared.invocation_context import InvocationContext diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 3f8e6669ab8..a1e318a3802 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -318,7 +318,7 @@ def dispatch_progress( name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.save(name, latents) - return build_latents_output(latents_name=name, latents=torch.from_numpy(latents)) + # return build_latents_output(latents_name=name, latents=torch.from_numpy(latents)) # Latent to image From 195f76c3f3e17d6ba623b1a58305c761d8f99eef Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 14 Jan 2024 20:16:51 +1100 Subject: [PATCH 009/340] feat(nodes): restore previous invocation context methods with deprecation warnings --- .../app/services/shared/invocation_context.py | 117 +++++++++++++++++- pyproject.toml | 1 + 2 files changed, 117 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 7961c011aff..023274d49fa 100644 --- 
a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional +from deprecated import deprecated from PIL.Image import Image from pydantic import ConfigDict from torch import Tensor @@ -365,7 +366,7 @@ def sd_step_callback( The step callback emits a progress event with the current step, the total number of steps, a preview image, and some other internal metadata. - This should be called after each step of the diffusion process. + This should be called after each denoising step. :param intermediate_state: The intermediate state of the diffusion pipeline. :param base_model: The base model for the current denoising step. @@ -387,6 +388,30 @@ def sd_step_callback( self.sd_step_callback = sd_step_callback +deprecation_version = "3.7.0" +removed_version = "3.8.0" + + +def get_deprecation_reason(property_name: str, alternative: Optional[str] = None) -> str: + msg = f"{property_name} is deprecated as of v{deprecation_version}. It will be removed in v{removed_version}." + if alternative is not None: + msg += f" Use {alternative} instead." + msg += " See PLACEHOLDER_URL for details." + return msg + + +# Deprecation docstrings template. I don't think we can implement these programmatically with +# __doc__ because the IDE won't see them. + +""" +**DEPRECATED as of v3.7.0** + +PROPERTY_NAME will be removed in v3.8.0. Use ALTERNATIVE instead. See PLACEHOLDER_URL for details. + +OG_DOCSTRING +""" + + class InvocationContext: """ The `InvocationContext` provides access to various services and data for the current invocation. @@ -402,6 +427,7 @@ def __init__( config: ConfigInterface, util: UtilInterface, data: InvocationContextData, + services: InvocationServices, ) -> None: self.images = images """Provides methods to save, get and update images and their metadata.""" @@ -419,6 +445,94 @@ def __init__( """Provides utility methods.""" self.data = data """Provides data about the current queue item and invocation.""" + self.__services = services + + @property + @deprecated(version=deprecation_version, reason=get_deprecation_reason("`context.services`")) + def services(self) -> InvocationServices: + """ + **DEPRECATED as of v3.7.0** + + `context.services` will be removed in v3.8.0. See PLACEHOLDER_URL for details. + + The invocation services. + """ + return self.__services + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.graph_execution_state_api`", "`context.data.session_id`"), + ) + def graph_execution_state_id(self) -> str: + """ + **DEPRECATED as of v3.7.0** + + `context.graph_execution_state_api` will be removed in v3.8.0. Use `context.data.session_id` instead. See PLACEHOLDER_URL for details. + + The ID of the session (aka graph execution state). + """ + return self.data.session_id + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.queue_id`", "`context.data.queue_id`"), + ) + def queue_id(self) -> str: + """ + **DEPRECATED as of v3.7.0** + + `context.queue_id` will be removed in v3.8.0. Use `context.data.queue_id` instead. See PLACEHOLDER_URL for details. + + The ID of the queue. 
+ """ + return self.data.queue_id + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.queue_item_id`", "`context.data.queue_item_id`"), + ) + def queue_item_id(self) -> int: + """ + **DEPRECATED as of v3.7.0** + + `context.queue_item_id` will be removed in v3.8.0. Use `context.data.queue_item_id` instead. See PLACEHOLDER_URL for details. + + The ID of the queue item. + """ + return self.data.queue_item_id + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.queue_batch_id`", "`context.data.batch_id`"), + ) + def queue_batch_id(self) -> str: + """ + **DEPRECATED as of v3.7.0** + + `context.queue_batch_id` will be removed in v3.8.0. Use `context.data.batch_id` instead. See PLACEHOLDER_URL for details. + + The ID of the batch. + """ + return self.data.batch_id + + @property + @deprecated( + version=deprecation_version, + reason=get_deprecation_reason("`context.workflow`", "`context.data.workflow`"), + ) + def workflow(self) -> Optional[WorkflowWithoutID]: + """ + **DEPRECATED as of v3.7.0** + + `context.workflow` will be removed in v3.8.0. Use `context.data.workflow` instead. See PLACEHOLDER_URL for details. + + The workflow associated with this queue item, if any. + """ + return self.data.workflow def build_invocation_context( @@ -449,6 +563,7 @@ def build_invocation_context( data=context_data, util=util, conditioning=conditioning, + services=services, ) return ctx diff --git a/pyproject.toml b/pyproject.toml index d063f1ad0ee..8d25ed20910 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ dependencies = [ "albumentations", "click", "datasets", + "Deprecated", "dnspython~=2.4.0", "dynamicprompts", "easing-functions", From 1a5f84bbd2a2af1f8e39249120a3f6f29ac31aa1 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 15 Jan 2024 09:37:05 +1100 Subject: [PATCH 010/340] tests: fix missing arg for InvocationContext --- tests/aa_nodes/test_graph_execution_state.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index 9cc30e43e11..3577a78ae2c 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -85,7 +85,15 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B print(f"invoking {n.id}: {type(n)}") o = n.invoke( InvocationContext( - conditioning=None, config=None, data=None, images=None, latents=None, logger=None, models=None, util=None + conditioning=None, + config=None, + data=None, + images=None, + latents=None, + logger=None, + models=None, + util=None, + services=None, ) ) g.complete(n.id, o) From d354dbdd09691667e42df4b506f7591e403a3aed Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 15 Jan 2024 10:41:25 +1100 Subject: [PATCH 011/340] feat(nodes): move `ConditioningFieldData` to `conditioning_data.py` --- invokeai/app/invocations/compel.py | 2 +- invokeai/app/invocations/fields.py | 9 +-------- invokeai/app/services/shared/invocation_context.py | 3 ++- .../stable_diffusion/diffusion/conditioning_data.py | 5 +++++ 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index b4496031bc4..94caf4128d2 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py 
@@ -5,7 +5,6 @@ from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from invokeai.app.invocations.fields import ( - ConditioningFieldData, FieldDescriptions, Input, InputField, @@ -15,6 +14,7 @@ from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, + ConditioningFieldData, ExtraConditioningInfo, SDXLConditioningInfo, ) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 566babbb6b7..8879f760770 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -1,13 +1,11 @@ -from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Optional, Tuple from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter from pydantic.fields import _Unset from pydantic_core import PydanticUndefined from invokeai.app.util.metaenum import MetaEnum -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import BasicConditioningInfo from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.get_logger() @@ -544,11 +542,6 @@ def tuple(self) -> Tuple[int, int, int, int]: return (self.r, self.g, self.b, self.a) -@dataclass -class ConditioningFieldData: - conditionings: List[BasicConditioningInfo] - - class ConditioningField(BaseModel): """A conditioning tensor primitive value""" diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 023274d49fa..3cf3952de0e 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -6,7 +6,7 @@ from pydantic import ConfigDict from torch import Tensor -from invokeai.app.invocations.fields import ConditioningFieldData, MetadataField, WithMetadata +from invokeai.app.invocations.fields import MetadataField, WithMetadata from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO @@ -17,6 +17,7 @@ from invokeai.backend.model_management.model_manager import ModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData if TYPE_CHECKING: from invokeai.app.invocations.baseinvocation import BaseInvocation diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 3e38f9f78d5..0676555f7a9 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -32,6 +32,11 @@ def to(self, device, dtype=None): return self +@dataclass +class ConditioningFieldData: + conditionings: List[BasicConditioningInfo] + + @dataclass class SDXLConditioningInfo(BasicConditioningInfo): pooled_embeds: torch.Tensor From 2ab1b4b07673c59c767a9290b4519f4188595aa3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 15 Jan 2024 10:48:33 +1100 Subject: [PATCH 012/340] 
feat(nodes): create invocation_api.py This is the public API for invocations. Everything a custom node might need should be re-exported from this file. --- .../controlnet_image_processors.py | 3 +- invokeai/app/invocations/facetools.py | 3 +- invokeai/app/invocations/image.py | 2 +- invokeai/app/invocations/infill.py | 4 +- invokeai/app/invocations/tiles.py | 3 +- invokeai/invocation_api/__init__.py | 109 ++++++++++++++++++ pyproject.toml | 1 + 7 files changed, 116 insertions(+), 9 deletions(-) create mode 100644 invokeai/invocation_api/__init__.py diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 3797722c93e..e993ceffde5 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,8 +25,7 @@ from PIL import Image from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from invokeai.app.invocations.baseinvocation import WithMetadata -from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.backend.image_util.depth_anything import DepthAnythingDetector diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index 2c92e28cfe0..dad63089816 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -13,11 +13,10 @@ import invokeai.assets.fonts as font_assets from invokeai.app.invocations.baseinvocation import ( BaseInvocation, - WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import ImageField, InputField, OutputField +from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 10ebd97ace3..3b8b0b4b80b 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,7 +7,6 @@ import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps -from invokeai.app.invocations.baseinvocation import WithMetadata from invokeai.app.invocations.fields import ( BoardField, ColorField, @@ -15,6 +14,7 @@ ImageField, Input, InputField, + WithMetadata, ) from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index be51c8312f9..159bdb5f7ad 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -13,8 +13,8 @@ from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch -from .baseinvocation import BaseInvocation, WithMetadata, invocation -from .fields import InputField +from .baseinvocation import BaseInvocation, invocation +from .fields import InputField, WithMetadata from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index dd34c3dc093..0b4c472696b 100644 
--- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -8,11 +8,10 @@ BaseInvocation, BaseInvocationOutput, Classification, - WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField +from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.backend.tiles.tiles import ( calc_tiles_even_split, diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py new file mode 100644 index 00000000000..e867ec3cc4e --- /dev/null +++ b/invokeai/invocation_api/__init__.py @@ -0,0 +1,109 @@ +""" +This file re-exports all the public API for invocations. This is the only file that should be imported by custom nodes. + +TODO(psyche): Do we want to dogfood this? +""" + +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + invocation, + invocation_output, +) +from invokeai.app.invocations.fields import ( + BoardField, + ColorField, + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + FieldKind, + ImageField, + Input, + InputField, + LatentsField, + MetadataField, + OutputField, + UIComponent, + UIType, + WithMetadata, + WithWorkflow, +) +from invokeai.app.invocations.primitives import ( + BooleanCollectionOutput, + BooleanOutput, + ColorCollectionOutput, + ColorOutput, + ConditioningCollectionOutput, + ConditioningOutput, + DenoiseMaskOutput, + FloatCollectionOutput, + FloatOutput, + ImageCollectionOutput, + ImageOutput, + IntegerCollectionOutput, + IntegerOutput, + LatentsCollectionOutput, + LatentsOutput, + StringCollectionOutput, + StringOutput, +) +from invokeai.app.services.image_records.image_records_common import ImageCategory +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + BasicConditioningInfo, + ConditioningFieldData, + ExtraConditioningInfo, + SDXLConditioningInfo, +) + +__all__ = [ + # invokeai.app.invocations.baseinvocation + "BaseInvocation", + "BaseInvocationOutput", + "invocation", + "invocation_output", + # invokeai.app.services.shared.invocation_context + "InvocationContext", + # invokeai.app.invocations.fields + "BoardField", + "ColorField", + "ConditioningField", + "DenoiseMaskField", + "FieldDescriptions", + "FieldKind", + "ImageField", + "Input", + "InputField", + "LatentsField", + "MetadataField", + "OutputField", + "UIComponent", + "UIType", + "WithMetadata", + "WithWorkflow", + # invokeai.app.invocations.primitives + "BooleanCollectionOutput", + "BooleanOutput", + "ColorCollectionOutput", + "ColorOutput", + "ConditioningCollectionOutput", + "ConditioningOutput", + "DenoiseMaskOutput", + "FloatCollectionOutput", + "FloatOutput", + "ImageCollectionOutput", + "ImageOutput", + "IntegerCollectionOutput", + "IntegerOutput", + "LatentsCollectionOutput", + "LatentsOutput", + "StringCollectionOutput", + "StringOutput", + # invokeai.app.services.image_records.image_records_common + "ImageCategory", + # invokeai.backend.stable_diffusion.diffusion.conditioning_data + "BasicConditioningInfo", + "ConditioningFieldData", + "ExtraConditioningInfo", + "SDXLConditioningInfo", +] diff --git a/pyproject.toml b/pyproject.toml index 8d25ed20910..69958064c6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,6 +170,7 @@ version = { attr = "invokeai.version.__version__" } 
"invokeai.frontend.web.static*", "invokeai.configs*", "invokeai.app*", + "invokeai.invocation_api*", ] [tool.setuptools.package-data] From 3176c84ab662c7b6e4de2403767b7e61b0d5ac4d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 16 Jan 2024 19:19:49 +1100 Subject: [PATCH 013/340] feat: tweak pyright config --- pyproject.toml | 47 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 69958064c6d..8b28375e291 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -284,17 +284,36 @@ module = [ #=== End: MyPy [tool.pyright] -include = [ - "invokeai/app/invocations/" -] -exclude = [ - "**/node_modules", - "**/__pycache__", - "invokeai/app/invocations/onnx.py", - "invokeai/app/api/routers/models.py", - "invokeai/app/services/invocation_stats/invocation_stats_default.py", - "invokeai/app/services/model_manager/model_manager_base.py", - "invokeai/app/services/model_manager/model_manager_default.py", - "invokeai/app/services/model_records/model_records_sql.py", - "invokeai/app/util/controlnet_utils.py", -] +# Start from strict mode +typeCheckingMode = "strict" +# This errors whenever an import is missing a type stub file - way too noisy +reportMissingTypeStubs = "none" +# These are the rest of the rules enabled by strict mode - enable them @ warning +reportConstantRedefinition = "warning" +reportDeprecated = "warning" +reportDuplicateImport = "warning" +reportIncompleteStub = "warning" +reportInconsistentConstructor = "warning" +reportInvalidStubStatement = "warning" +reportMatchNotExhaustive = "warning" +reportMissingParameterType = "warning" +reportMissingTypeArgument = "warning" +reportPrivateUsage = "warning" +reportTypeCommentUsage = "warning" +reportUnknownArgumentType = "warning" +reportUnknownLambdaType = "warning" +reportUnknownMemberType = "warning" +reportUnknownParameterType = "warning" +reportUnknownVariableType = "warning" +reportUnnecessaryCast = "warning" +reportUnnecessaryComparison = "warning" +reportUnnecessaryContains = "warning" +reportUnnecessaryIsInstance = "warning" +reportUnusedClass = "warning" +reportUnusedImport = "warning" +reportUnusedFunction = "warning" +reportUnusedVariable = "warning" +reportUntypedBaseClass = "warning" +reportUntypedClassDecorator = "warning" +reportUntypedFunctionDecorator = "warning" +reportUntypedNamedTuple = "warning" From 700dcae79dbb38298cad3fc8b74aba6fbde3c93b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:02:38 +1100 Subject: [PATCH 014/340] feat(nodes): do not freeze InvocationContextData, prevents it from being subclassesd --- invokeai/app/services/shared/invocation_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 3cf3952de0e..a849d6b17a2 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -45,7 +45,7 @@ """ -@dataclass(frozen=True) +@dataclass class InvocationContextData: invocation: "BaseInvocation" """The invocation that is being executed.""" From b9702d88dea170c98e92c3231f4cab4942852c8d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:16:35 +1100 Subject: [PATCH 015/340] fix(nodes): restore type annotations for `InvocationContext` 
--- docs/contributing/INVOCATIONS.md | 6 +-- invokeai/app/invocations/collections.py | 7 +-- invokeai/app/invocations/compel.py | 18 +++---- .../controlnet_image_processors.py | 5 +- invokeai/app/invocations/cv.py | 3 +- invokeai/app/invocations/facetools.py | 24 ++++----- invokeai/app/invocations/image.py | 51 ++++++++++--------- invokeai/app/invocations/infill.py | 11 ++-- invokeai/app/invocations/ip_adapter.py | 3 +- invokeai/app/invocations/latent.py | 18 +++---- invokeai/app/invocations/math.py | 21 ++++---- invokeai/app/invocations/metadata.py | 9 ++-- invokeai/app/invocations/model.py | 13 ++--- invokeai/app/invocations/noise.py | 3 +- invokeai/app/invocations/onnx.py | 8 +-- invokeai/app/invocations/param_easing.py | 5 +- invokeai/app/invocations/primitives.py | 31 +++++------ invokeai/app/invocations/prompt.py | 5 +- invokeai/app/invocations/sdxl.py | 5 +- invokeai/app/invocations/strings.py | 12 +++-- invokeai/app/invocations/t2i_adapter.py | 3 +- invokeai/app/invocations/tiles.py | 13 ++--- invokeai/app/invocations/upscale.py | 3 +- invokeai/app/services/shared/graph.py | 7 +-- tests/aa_nodes/test_nodes.py | 17 ++++--- 25 files changed, 158 insertions(+), 143 deletions(-) diff --git a/docs/contributing/INVOCATIONS.md b/docs/contributing/INVOCATIONS.md index 5d9a3690bad..ce1ee9e808a 100644 --- a/docs/contributing/INVOCATIONS.md +++ b/docs/contributing/INVOCATIONS.md @@ -174,7 +174,7 @@ class ResizeInvocation(BaseInvocation): width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context): + def invoke(self, context: InvocationContext): pass ``` @@ -203,7 +203,7 @@ class ResizeInvocation(BaseInvocation): width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: pass ``` @@ -229,7 +229,7 @@ class ResizeInvocation(BaseInvocation): width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image") height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: # Load the input image as a PIL image image = context.images.get_pil(self.image.image_name) diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index f5709b4ba36..e02291980f9 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -5,6 +5,7 @@ from pydantic import ValidationInfo, field_validator from invokeai.app.invocations.primitives import IntegerCollectionOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX from .baseinvocation import BaseInvocation, invocation @@ -27,7 +28,7 @@ def stop_gt_start(cls, v: int, info: ValidationInfo): raise ValueError("stop must be greater than start") return v - def invoke(self, context) -> IntegerCollectionOutput: + def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step))) @@ -45,7 +46,7 @@ class RangeOfSizeInvocation(BaseInvocation): size: int = InputField(default=1, gt=0, description="The 
number of values") step: int = InputField(default=1, description="The step of the range") - def invoke(self, context) -> IntegerCollectionOutput: + def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: return IntegerCollectionOutput( collection=list(range(self.start, self.start + (self.step * self.size), self.step)) ) @@ -72,6 +73,6 @@ class RandomRangeInvocation(BaseInvocation): description="The seed for the RNG (omit for random)", ) - def invoke(self, context) -> IntegerCollectionOutput: + def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: rng = np.random.default_rng(self.seed) return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size))) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 94caf4128d2..978c6dcb17f 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Optional, Union +from typing import List, Optional, Union import torch from compel import Compel, ReturnedEmbeddingsType @@ -12,6 +12,7 @@ UIComponent, ) from invokeai.app.invocations.primitives import ConditioningOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, @@ -31,10 +32,7 @@ ) from .model import ClipField -if TYPE_CHECKING: - from invokeai.app.services.shared.invocation_context import InvocationContext - - # unconditioned: Optional[torch.Tensor] +# unconditioned: Optional[torch.Tensor] # class ConditioningAlgo(str, Enum): @@ -65,7 +63,7 @@ class CompelInvocation(BaseInvocation): ) @torch.no_grad() - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump()) text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump()) @@ -148,7 +146,7 @@ def _lora_loader(): class SDXLPromptInvocationBase: def run_clip_compel( self, - context: "InvocationContext", + context: InvocationContext, clip_field: ClipField, prompt: str, get_pooled: bool, @@ -288,7 +286,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") @torch.no_grad() - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: c1, c1_pooled, ec1 = self.run_clip_compel( context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True ) @@ -373,7 +371,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) @torch.no_grad() - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: # TODO: if there will appear lora for refiner - write proper prefix c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "", zero_on_empty=False) @@ -418,7 +416,7 @@ class ClipSkipInvocation(BaseInvocation): clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers) - def invoke(self, context) -> ClipSkipInvocationOutput: + def invoke(self, 
context: InvocationContext) -> ClipSkipInvocationOutput: self.clip.skipped_layers += self.skipped_layers return ClipSkipInvocationOutput( clip=self.clip, diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index e993ceffde5..f8bdf14117c 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -28,6 +28,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector from invokeai.backend.model_management.models.base import BaseModelType @@ -119,7 +120,7 @@ def validate_begin_end_step_percent(self) -> "ControlNetInvocation": validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context) -> ControlOutput: + def invoke(self, context: InvocationContext) -> ControlOutput: return ControlOutput( control=ControlField( image=self.image, @@ -143,7 +144,7 @@ def run_processor(self, image: Image.Image) -> Image.Image: # superclass just passes through image without processing return image - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: raw_image = context.images.get_pil(self.image.image_name) # image type should be PIL.PngImagePlugin.PngImageFile ? processed_image = self.run_processor(raw_image) diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 375b18f9c58..1ebabf5e064 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -7,6 +7,7 @@ from invokeai.app.invocations.fields import ImageField from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithMetadata @@ -19,7 +20,7 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to inpaint") mask: ImageField = InputField(description="The mask to use when inpainting") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) mask = context.images.get_pil(self.mask.image_name) diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index dad63089816..a1702d6517c 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -1,7 +1,7 @@ import math import re from pathlib import Path -from typing import TYPE_CHECKING, Optional, TypedDict +from typing import Optional, TypedDict import cv2 import numpy as np @@ -19,9 +19,7 @@ from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory - -if TYPE_CHECKING: - from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.services.shared.invocation_context import 
InvocationContext @invocation_output("face_mask_output") @@ -176,7 +174,7 @@ def prepare_faces_list( def generate_face_box_mask( - context: "InvocationContext", + context: InvocationContext, minimum_confidence: float, x_offset: float, y_offset: float, @@ -275,7 +273,7 @@ def generate_face_box_mask( def extract_face( - context: "InvocationContext", + context: InvocationContext, image: ImageType, face: FaceResultData, padding: int, @@ -356,7 +354,7 @@ def extract_face( def get_faces_list( - context: "InvocationContext", + context: InvocationContext, image: ImageType, should_chunk: bool, minimum_confidence: float, @@ -458,7 +456,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.", ) - def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[ExtractFaceData]: + def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]: all_faces = get_faces_list( context=context, image=image, @@ -485,7 +483,7 @@ def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[Ex return face_data - def invoke(self, context) -> FaceOffOutput: + def invoke(self, context: InvocationContext) -> FaceOffOutput: image = context.images.get_pil(self.image.image_name) result = self.faceoff(context=context, image=image) @@ -543,7 +541,7 @@ def validate_comma_separated_ints(cls, v) -> str: raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")') return v - def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskResult: + def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult: all_faces = get_faces_list( context=context, image=image, @@ -600,7 +598,7 @@ def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskRe mask=mask_pil, ) - def invoke(self, context) -> FaceMaskOutput: + def invoke(self, context: InvocationContext) -> FaceMaskOutput: image = context.images.get_pil(self.image.image_name) result = self.facemask(context=context, image=image) @@ -633,7 +631,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. 
Chunking will occur if no faces are found in the full image.", ) - def faceidentifier(self, context: "InvocationContext", image: ImageType) -> ImageType: + def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType: image = image.copy() all_faces = get_faces_list( @@ -674,7 +672,7 @@ def faceidentifier(self, context: "InvocationContext", image: ImageType) -> Imag return image - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) result_image = self.faceidentifier(context=context, image=image) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 3b8b0b4b80b..7b74e4d96d4 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -18,6 +18,7 @@ ) from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker @@ -34,7 +35,7 @@ class ShowImageInvocation(BaseInvocation): image: ImageField = InputField(description="The image to show") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image.show() @@ -62,7 +63,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata): mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image") color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple()) image_dto = context.images.save(image=image) @@ -86,7 +87,7 @@ class ImageCropInvocation(BaseInvocation, WithMetadata): width: int = InputField(default=512, gt=0, description="The width of the crop rectangle") height: int = InputField(default=512, gt=0, description="The height of the crop rectangle") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)) @@ -125,7 +126,7 @@ class CenterPadCropInvocation(BaseInvocation): description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)", ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) # Calculate and create new image dimensions @@ -161,7 +162,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata): y: int = InputField(default=0, description="The top y coordinate at which to paste the image") crop: bool = InputField(default=False, description="Crop to base image dimensions") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: base_image = context.images.get_pil(self.base_image.image_name) image = context.images.get_pil(self.image.image_name) mask = None @@ -201,7 +202,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): image: ImageField = 
InputField(description="The image to create the mask from") invert: bool = InputField(default=False, description="Whether or not to invert the mask") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_mask = image.split()[-1] @@ -226,7 +227,7 @@ class ImageMultiplyInvocation(BaseInvocation, WithMetadata): image1: ImageField = InputField(description="The first image to multiply") image2: ImageField = InputField(description="The second image to multiply") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image1 = context.images.get_pil(self.image1.image_name) image2 = context.images.get_pil(self.image2.image_name) @@ -253,7 +254,7 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to get the channel from") channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) channel_image = image.getchannel(self.channel) @@ -279,7 +280,7 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to convert") mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) converted_image = image.convert(self.mode) @@ -304,7 +305,7 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata): # Metadata blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) blur = ( @@ -338,7 +339,7 @@ def pil_from_array(self, arr): def array_from_pil(self, img): return numpy.array(img) / 255 - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) mode = image.mode @@ -401,7 +402,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata): height: int = InputField(default=512, gt=0, description="The height to resize to (px)") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -434,7 +435,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata): ) resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -465,7 +466,7 @@ class ImageLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, description="The minimum output value") max: int = InputField(default=255, ge=0, le=255, description="The maximum output value") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: 
InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 @@ -492,7 +493,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, description="The minimum input value") max: int = InputField(default=255, ge=0, le=255, description="The maximum input value") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) @@ -517,7 +518,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to check") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) logger = context.logger @@ -553,7 +554,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to check") text: str = InputField(default="InvokeAI", description="Watermark text") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) new_image = InvisibleWatermark.add_watermark(image, self.text) image_dto = context.images.save(image=new_image) @@ -579,7 +580,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata): description="Second threshold for the hysteresis procedure in Canny edge detection" ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: mask = context.images.get_pil(self.image.image_name).convert("L") npimg = numpy.asarray(mask, dtype=numpy.uint8) @@ -613,7 +614,7 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata): mask1: ImageField = InputField(description="The first mask to combine") mask2: ImageField = InputField(description="The second image to combine") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: mask1 = context.images.get_pil(self.mask1.image_name).convert("L") mask2 = context.images.get_pil(self.mask2.image_name).convert("L") @@ -642,7 +643,7 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata): mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction") mask_blur_radius: float = InputField(default=8, description="Mask blur radius") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: pil_init_mask = None if self.mask is not None: pil_init_mask = context.images.get_pil(self.mask.image_name).convert("L") @@ -741,7 +742,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to adjust") hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.images.get_pil(self.image.image_name) # Convert image to HSV color space @@ -831,7 +832,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): channel: COLOR_CHANNELS = InputField(description="Which channel to adjust") offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by") - def invoke(self, context) -> ImageOutput: + 
def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple @@ -888,7 +889,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.") invert_channel: bool = InputField(default=False, description="Invert the channel after scaling") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple @@ -931,7 +932,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None) @@ -953,7 +954,7 @@ class LinearUIOutputInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image_dto = context.images.get_dto(self.image.image_name) image_dto = context.images.update( diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 159bdb5f7ad..b007edd9e42 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -8,6 +8,7 @@ from invokeai.app.invocations.fields import ColorField, ImageField from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint from invokeai.backend.image_util.lama import LaMA @@ -129,7 +130,7 @@ class InfillColorInvocation(BaseInvocation, WithMetadata): description="The color to use to infill", ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) solid_bg = Image.new("RGBA", image.size, self.color.tuple()) @@ -155,7 +156,7 @@ class InfillTileInvocation(BaseInvocation, WithMetadata): description="The seed to use for tile generation (omit for random)", ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size) @@ -176,7 +177,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name).convert("RGBA") resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -213,7 +214,7 @@ class 
LaMaInfillInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to infill") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) infilled = infill_lama(image.copy()) @@ -229,7 +230,7 @@ class CV2InfillInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to infill") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) infilled = infill_cv2(image.copy()) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index b836be04b58..845fcfa2848 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -13,6 +13,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_management.models.base import BaseModelType, ModelType from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id @@ -92,7 +93,7 @@ def validate_begin_end_step_percent(self): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context) -> IPAdapterOutput: + def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. ip_adapter_info = context.models.get_info( self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 0127a6521e1..2cc84f80a73 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,7 +3,7 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import TYPE_CHECKING, List, Literal, Optional, Union +from typing import List, Literal, Optional, Union import einops import numpy as np @@ -42,6 +42,7 @@ LatentsOutput, ) from invokeai.app.invocations.t2i_adapter import T2IAdapterField +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.model_management.models import ModelType, SilenceWarnings @@ -70,9 +71,6 @@ from .controlnet_image_processors import ControlField from .model import ModelInfo, UNetField, VaeField -if TYPE_CHECKING: - from invokeai.app.services.shared.invocation_context import InvocationContext - if choose_torch_device() == torch.device("mps"): from torch import mps @@ -177,7 +175,7 @@ def invoke(self, context) -> DenoiseMaskOutput: def get_scheduler( - context: "InvocationContext", + context: InvocationContext, scheduler_info: ModelInfo, scheduler_name: str, seed: int, @@ -300,7 +298,7 @@ def ge_one(cls, v): def get_conditioning_data( self, - context: "InvocationContext", + context: InvocationContext, scheduler, unet, seed, @@ -369,7 +367,7 @@ def __init__(self): def prep_control_data( self, - context: "InvocationContext", + context: InvocationContext, control_input: 
Union[ControlField, List[ControlField]], latents_shape: List[int], exit_stack: ExitStack, @@ -442,7 +440,7 @@ def prep_control_data( def prep_ip_adapter_data( self, - context: "InvocationContext", + context: InvocationContext, ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], conditioning_data: ConditioningData, exit_stack: ExitStack, @@ -509,7 +507,7 @@ def prep_ip_adapter_data( def run_t2i_adapters( self, - context: "InvocationContext", + context: InvocationContext, t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]], latents_shape: list[int], do_classifier_free_guidance: bool, @@ -618,7 +616,7 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en return num_inference_steps, timesteps, init_timestep - def prep_inpaint_mask(self, context: "InvocationContext", latents): + def prep_inpaint_mask(self, context: InvocationContext, latents): if self.denoise_mask is None: return None, None diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index d2dbf049816..83a092be69e 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -7,6 +7,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation @@ -18,7 +19,7 @@ class AddInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=self.a + self.b) @@ -29,7 +30,7 @@ class SubtractInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=self.a - self.b) @@ -40,7 +41,7 @@ class MultiplyInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=self.a * self.b) @@ -51,7 +52,7 @@ class DivideInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=int(self.a / self.b)) @@ -69,7 +70,7 @@ class RandomIntInvocation(BaseInvocation): low: int = InputField(default=0, description=FieldDescriptions.inclusive_low) high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=np.random.randint(self.low, self.high)) @@ -88,7 +89,7 @@ class RandomFloatInvocation(BaseInvocation): high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high) decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places) - def invoke(self, context) -> FloatOutput: + 
def invoke(self, context: InvocationContext) -> FloatOutput: random_float = np.random.uniform(self.low, self.high) rounded_float = round(random_float, self.decimals) return FloatOutput(value=rounded_float) @@ -110,7 +111,7 @@ class FloatToIntegerInvocation(BaseInvocation): default="Nearest", description="The method to use for rounding" ) - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: if self.method == "Nearest": return IntegerOutput(value=round(self.value / self.multiple) * self.multiple) elif self.method == "Floor": @@ -128,7 +129,7 @@ class RoundInvocation(BaseInvocation): value: float = InputField(default=0, description="The float value") decimals: int = InputField(default=0, description="The number of decimal places") - def invoke(self, context) -> FloatOutput: + def invoke(self, context: InvocationContext) -> FloatOutput: return FloatOutput(value=round(self.value, self.decimals)) @@ -196,7 +197,7 @@ def no_unrepresentable_results(cls, v: int, info: ValidationInfo): raise ValueError("Result of exponentiation is not an integer") return v - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return IntegerOutput(value=self.a + self.b) @@ -270,7 +271,7 @@ def no_unrepresentable_results(cls, v: float, info: ValidationInfo): raise ValueError("Root operation resulted in a complex number") return v - def invoke(self, context) -> FloatOutput: + def invoke(self, context: InvocationContext) -> FloatOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return FloatOutput(value=self.a + self.b) diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 9d74abd8c12..58edfab711a 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -20,6 +20,7 @@ from invokeai.app.invocations.ip_adapter import IPAdapterModelField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField from invokeai.app.invocations.t2i_adapter import T2IAdapterField +from invokeai.app.services.shared.invocation_context import InvocationContext from ...version import __version__ @@ -64,7 +65,7 @@ class MetadataItemInvocation(BaseInvocation): label: str = InputField(description=FieldDescriptions.metadata_item_label) value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any) - def invoke(self, context) -> MetadataItemOutput: + def invoke(self, context: InvocationContext) -> MetadataItemOutput: return MetadataItemOutput(item=MetadataItemField(label=self.label, value=self.value)) @@ -81,7 +82,7 @@ class MetadataInvocation(BaseInvocation): description=FieldDescriptions.metadata_item_polymorphic ) - def invoke(self, context) -> MetadataOutput: + def invoke(self, context: InvocationContext) -> MetadataOutput: if isinstance(self.items, MetadataItemField): # single metadata item data = {self.items.label: self.items.value} @@ -100,7 +101,7 @@ class MergeMetadataInvocation(BaseInvocation): collection: list[MetadataField] = InputField(description=FieldDescriptions.metadata_collection) - def invoke(self, context) -> MetadataOutput: + def invoke(self, context: InvocationContext) -> MetadataOutput: data = {} for item in self.collection: data.update(item.model_dump()) @@ -218,7 +219,7 @@ class 
CoreMetadataInvocation(BaseInvocation): description="The start value used for refiner denoising", ) - def invoke(self, context) -> MetadataOutput: + def invoke(self, context: InvocationContext) -> MetadataOutput: """Collects and outputs a CoreMetadata object""" return MetadataOutput( diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index f81e559e446..6a1fd6d36bc 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.shared.models import FreeUConfig from ...backend.model_management import BaseModelType, ModelType, SubModelType @@ -109,7 +110,7 @@ class MainModelLoaderInvocation(BaseInvocation): model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct) # TODO: precision? - def invoke(self, context) -> ModelLoaderOutput: + def invoke(self, context: InvocationContext) -> ModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main @@ -221,7 +222,7 @@ class LoraLoaderInvocation(BaseInvocation): title="CLIP", ) - def invoke(self, context) -> LoraLoaderOutput: + def invoke(self, context: InvocationContext) -> LoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") @@ -310,7 +311,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): title="CLIP 2", ) - def invoke(self, context) -> SDXLLoraLoaderOutput: + def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") @@ -393,7 +394,7 @@ class VaeLoaderInvocation(BaseInvocation): title="VAE", ) - def invoke(self, context) -> VAEOutput: + def invoke(self, context: InvocationContext) -> VAEOutput: base_model = self.vae_model.base_model model_name = self.vae_model.model_name model_type = ModelType.Vae @@ -448,7 +449,7 @@ class SeamlessModeInvocation(BaseInvocation): seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless") seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") - def invoke(self, context) -> SeamlessModeOutput: + def invoke(self, context: InvocationContext) -> SeamlessModeOutput: # Conditionally append 'x' and 'y' based on seamless_x and seamless_y unet = copy.deepcopy(self.unet) vae = copy.deepcopy(self.vae) @@ -484,6 +485,6 @@ class FreeUInvocation(BaseInvocation): s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1) s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2) - def invoke(self, context) -> UNetOutput: + def invoke(self, context: InvocationContext) -> UNetOutput: self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2) return UNetOutput(unet=self.unet) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 41641152f04..78f13cc52d1 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -5,6 +5,7 @@ from pydantic import field_validator from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX from 
...backend.util.devices import choose_torch_device, torch_dtype @@ -112,7 +113,7 @@ def modulo_seed(cls, v): """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" return v % (SEED_MAX + 1) - def invoke(self, context) -> NoiseOutput: + def invoke(self, context: InvocationContext) -> NoiseOutput: noise = get_noise( width=self.width, height=self.height, diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index a1e318a3802..e7b4d3d9fc5 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -63,7 +63,7 @@ class ONNXPromptInvocation(BaseInvocation): prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.model_dump(), ) @@ -201,7 +201,7 @@ def ge_one(cls, v): # based on # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) @@ -342,7 +342,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): ) # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) if self.vae.vae.submodel != SubModelType.VaeDecoder: @@ -417,7 +417,7 @@ class OnnxModelLoaderInvocation(BaseInvocation): description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel ) - def invoke(self, context) -> ONNXModelLoaderOutput: + def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.ONNX diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index bf59e87d270..6845637de92 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -40,6 +40,7 @@ from matplotlib.ticker import MaxNLocator from invokeai.app.invocations.primitives import FloatCollectionOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation from .fields import InputField @@ -62,7 +63,7 @@ class FloatLinearRangeInvocation(BaseInvocation): description="number of values to interpolate over (including start and stop)", ) - def invoke(self, context) -> FloatCollectionOutput: + def invoke(self, context: InvocationContext) -> FloatCollectionOutput: param_list = list(np.linspace(self.start, self.stop, self.steps)) return FloatCollectionOutput(collection=param_list) @@ -130,7 +131,7 @@ class StepParamEasingInvocation(BaseInvocation): # alt_mirror: bool = InputField(default=False, description="alternative mirroring by 
dual easing") show_easing_plot: bool = InputField(default=False, description="show easing plot") - def invoke(self, context) -> FloatCollectionOutput: + def invoke(self, context: InvocationContext) -> FloatCollectionOutput: log_diagnostics = False # convert from start_step_percent to nearest step <= (steps * start_step_percent) # start_step = int(np.floor(self.num_steps * self.start_step_percent)) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index ee04345eed8..c90d3230b2b 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -17,6 +17,7 @@ UIComponent, ) from invokeai.app.services.images.images_common import ImageDTO +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import ( BaseInvocation, @@ -59,7 +60,7 @@ class BooleanInvocation(BaseInvocation): value: bool = InputField(default=False, description="The boolean value") - def invoke(self, context) -> BooleanOutput: + def invoke(self, context: InvocationContext) -> BooleanOutput: return BooleanOutput(value=self.value) @@ -75,7 +76,7 @@ class BooleanCollectionInvocation(BaseInvocation): collection: list[bool] = InputField(default=[], description="The collection of boolean values") - def invoke(self, context) -> BooleanCollectionOutput: + def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: return BooleanCollectionOutput(collection=self.collection) @@ -108,7 +109,7 @@ class IntegerInvocation(BaseInvocation): value: int = InputField(default=0, description="The integer value") - def invoke(self, context) -> IntegerOutput: + def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=self.value) @@ -124,7 +125,7 @@ class IntegerCollectionInvocation(BaseInvocation): collection: list[int] = InputField(default=[], description="The collection of integer values") - def invoke(self, context) -> IntegerCollectionOutput: + def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=self.collection) @@ -155,7 +156,7 @@ class FloatInvocation(BaseInvocation): value: float = InputField(default=0.0, description="The float value") - def invoke(self, context) -> FloatOutput: + def invoke(self, context: InvocationContext) -> FloatOutput: return FloatOutput(value=self.value) @@ -171,7 +172,7 @@ class FloatCollectionInvocation(BaseInvocation): collection: list[float] = InputField(default=[], description="The collection of float values") - def invoke(self, context) -> FloatCollectionOutput: + def invoke(self, context: InvocationContext) -> FloatCollectionOutput: return FloatCollectionOutput(collection=self.collection) @@ -202,7 +203,7 @@ class StringInvocation(BaseInvocation): value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea) - def invoke(self, context) -> StringOutput: + def invoke(self, context: InvocationContext) -> StringOutput: return StringOutput(value=self.value) @@ -218,7 +219,7 @@ class StringCollectionInvocation(BaseInvocation): collection: list[str] = InputField(default=[], description="The collection of string values") - def invoke(self, context) -> StringCollectionOutput: + def invoke(self, context: InvocationContext) -> StringCollectionOutput: return StringCollectionOutput(collection=self.collection) @@ -261,7 +262,7 @@ class ImageInvocation( image: ImageField = InputField(description="The image to load") - def invoke(self, context) -> ImageOutput: + def 
invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) return ImageOutput( @@ -283,7 +284,7 @@ class ImageCollectionInvocation(BaseInvocation): collection: list[ImageField] = InputField(description="The collection of image values") - def invoke(self, context) -> ImageCollectionOutput: + def invoke(self, context: InvocationContext) -> ImageCollectionOutput: return ImageCollectionOutput(collection=self.collection) @@ -346,7 +347,7 @@ class LatentsInvocation(BaseInvocation): latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.latents.get(self.latents.latents_name) return LatentsOutput.build(self.latents.latents_name, latents) @@ -366,7 +367,7 @@ class LatentsCollectionInvocation(BaseInvocation): description="The collection of latents tensors", ) - def invoke(self, context) -> LatentsCollectionOutput: + def invoke(self, context: InvocationContext) -> LatentsCollectionOutput: return LatentsCollectionOutput(collection=self.collection) @@ -397,7 +398,7 @@ class ColorInvocation(BaseInvocation): color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value") - def invoke(self, context) -> ColorOutput: + def invoke(self, context: InvocationContext) -> ColorOutput: return ColorOutput(color=self.color) @@ -438,7 +439,7 @@ class ConditioningInvocation(BaseInvocation): conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection) - def invoke(self, context) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> ConditioningOutput: return ConditioningOutput(conditioning=self.conditioning) @@ -457,7 +458,7 @@ class ConditioningCollectionInvocation(BaseInvocation): description="The collection of conditioning tensors", ) - def invoke(self, context) -> ConditioningCollectionOutput: + def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput: return ConditioningCollectionOutput(collection=self.collection) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 4f5ef43a568..234743a0035 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -6,6 +6,7 @@ from pydantic import field_validator from invokeai.app.invocations.primitives import StringCollectionOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation from .fields import InputField, UIComponent @@ -29,7 +30,7 @@ class DynamicPromptInvocation(BaseInvocation): max_prompts: int = InputField(default=1, description="The number of prompts to generate") combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator") - def invoke(self, context) -> StringCollectionOutput: + def invoke(self, context: InvocationContext) -> StringCollectionOutput: if self.combinatorial: generator = CombinatorialPromptGenerator() prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) @@ -91,7 +92,7 @@ def promptsFromFile( break return prompts - def invoke(self, context) -> StringCollectionOutput: + def invoke(self, context: InvocationContext) -> StringCollectionOutput: prompts = self.promptsFromFile( self.file_path, self.pre_prompt, diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 
75a526cfff6..8d51674a046 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,4 +1,5 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType +from invokeai.app.services.shared.invocation_context import InvocationContext from ...backend.model_management import ModelType, SubModelType from .baseinvocation import ( @@ -38,7 +39,7 @@ class SDXLModelLoaderInvocation(BaseInvocation): ) # TODO: precision? - def invoke(self, context) -> SDXLModelLoaderOutput: + def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main @@ -127,7 +128,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): ) # TODO: precision? - def invoke(self, context) -> SDXLRefinerModelLoaderOutput: + def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py index a4c92d9de56..182c976cd77 100644 --- a/invokeai/app/invocations/strings.py +++ b/invokeai/app/invocations/strings.py @@ -2,6 +2,8 @@ import re +from invokeai.app.services.shared.invocation_context import InvocationContext + from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -32,7 +34,7 @@ class StringSplitNegInvocation(BaseInvocation): string: str = InputField(default="", description="String to split", ui_component=UIComponent.Textarea) - def invoke(self, context) -> StringPosNegOutput: + def invoke(self, context: InvocationContext) -> StringPosNegOutput: p_string = "" n_string = "" brackets_depth = 0 @@ -76,7 +78,7 @@ class StringSplitInvocation(BaseInvocation): default="", description="Delimiter to spilt with. 
blank will split on the first whitespace" ) - def invoke(self, context) -> String2Output: + def invoke(self, context: InvocationContext) -> String2Output: result = self.string.split(self.delimiter, 1) if len(result) == 2: part1, part2 = result @@ -94,7 +96,7 @@ class StringJoinInvocation(BaseInvocation): string_left: str = InputField(default="", description="String Left", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context) -> StringOutput: + def invoke(self, context: InvocationContext) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_right or ""))) @@ -106,7 +108,7 @@ class StringJoinThreeInvocation(BaseInvocation): string_middle: str = InputField(default="", description="String Middle", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context) -> StringOutput: + def invoke(self, context: InvocationContext) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_middle or "") + (self.string_right or ""))) @@ -125,7 +127,7 @@ class StringReplaceInvocation(BaseInvocation): default=False, description="Use search string as a regex expression (non regex is case insensitive)" ) - def invoke(self, context) -> StringOutput: + def invoke(self, context: InvocationContext) -> StringOutput: pattern = self.search_string or "" new_string = self.string or "" if len(pattern) > 0: diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 74a098a501c..0f4fe66ada1 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -11,6 +11,7 @@ from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_management.models.base import BaseModelType @@ -89,7 +90,7 @@ def validate_begin_end_step_percent(self): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context) -> T2IAdapterOutput: + def invoke(self, context: InvocationContext) -> T2IAdapterOutput: return T2IAdapterOutput( t2i_adapter=T2IAdapterField( image=self.image, diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index 0b4c472696b..19ece423761 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -13,6 +13,7 @@ ) from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithMetadata from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.tiles.tiles import ( calc_tiles_even_split, calc_tiles_min_overlap, @@ -56,7 +57,7 @@ class CalculateImageTilesInvocation(BaseInvocation): description="The target overlap, in pixels, between adjacent tiles. 
Adjacent tiles will overlap by at least this amount", ) - def invoke(self, context) -> CalculateImageTilesOutput: + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: tiles = calc_tiles_with_overlap( image_height=self.image_height, image_width=self.image_width, @@ -99,7 +100,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation): description="The overlap, in pixels, between adjacent tiles.", ) - def invoke(self, context) -> CalculateImageTilesOutput: + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: tiles = calc_tiles_even_split( image_height=self.image_height, image_width=self.image_width, @@ -129,7 +130,7 @@ class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation): tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.") min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.") - def invoke(self, context) -> CalculateImageTilesOutput: + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: tiles = calc_tiles_min_overlap( image_height=self.image_height, image_width=self.image_width, @@ -174,7 +175,7 @@ class TileToPropertiesInvocation(BaseInvocation): tile: Tile = InputField(description="The tile to split into properties.") - def invoke(self, context) -> TileToPropertiesOutput: + def invoke(self, context: InvocationContext) -> TileToPropertiesOutput: return TileToPropertiesOutput( coords_left=self.tile.coords.left, coords_right=self.tile.coords.right, @@ -211,7 +212,7 @@ class PairTileImageInvocation(BaseInvocation): image: ImageField = InputField(description="The tile image.") tile: Tile = InputField(description="The tile properties.") - def invoke(self, context) -> PairTileImageOutput: + def invoke(self, context: InvocationContext) -> PairTileImageOutput: return PairTileImageOutput( tile_with_image=TileWithImage( tile=self.tile, @@ -247,7 +248,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): description="The amount to blend adjacent tiles in pixels. 
Must be <= the amount of overlap between adjacent tiles.", ) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: images = [twi.image for twi in self.tiles_with_images] tiles = [twi.tile for twi in self.tiles_with_images] diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index ef174809860..71ef7ca3aa0 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -10,6 +10,7 @@ from invokeai.app.invocations.fields import ImageField from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.util.devices import choose_torch_device @@ -42,7 +43,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata): model_config = ConfigDict(protected_namespaces=()) - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) models_path = context.config.get().models_path diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index c0699eb96bb..3df230f5ee7 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -17,6 +17,7 @@ invocation_output, ) from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import uuid_string # in 3.10 this would be "from types import NoneType" @@ -201,7 +202,7 @@ class GraphInvocation(BaseInvocation): # TODO: figure out how to create a default here graph: "Graph" = InputField(description="The graph to run", default=None) - def invoke(self, context) -> GraphInvocationOutput: + def invoke(self, context: InvocationContext) -> GraphInvocationOutput: """Invoke with provided services and return outputs.""" return GraphInvocationOutput() @@ -227,7 +228,7 @@ class IterateInvocation(BaseInvocation): ) index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True) - def invoke(self, context) -> IterateInvocationOutput: + def invoke(self, context: InvocationContext) -> IterateInvocationOutput: """Produces the outputs as values""" return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection)) @@ -254,7 +255,7 @@ class CollectInvocation(BaseInvocation): description="The collection, will be provided on execution", default=[], ui_hidden=True ) - def invoke(self, context) -> CollectInvocationOutput: + def invoke(self, context: InvocationContext) -> CollectInvocationOutput: """Invoke with provided services and return outputs.""" return CollectInvocationOutput(collection=copy.copy(self.collection)) diff --git a/tests/aa_nodes/test_nodes.py b/tests/aa_nodes/test_nodes.py index 559457c0e11..aab3d9c7b4b 100644 --- a/tests/aa_nodes/test_nodes.py +++ b/tests/aa_nodes/test_nodes.py @@ -8,6 +8,7 @@ ) from invokeai.app.invocations.fields import InputField, OutputField from invokeai.app.invocations.image import ImageField +from invokeai.app.services.shared.invocation_context import InvocationContext # Define test invocations before importing anything that uses invocations @@ -20,7 +21,7 @@ class 
ListPassThroughInvocationOutput(BaseInvocationOutput): class ListPassThroughInvocation(BaseInvocation): collection: list[ImageField] = InputField(default=[]) - def invoke(self, context) -> ListPassThroughInvocationOutput: + def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: return ListPassThroughInvocationOutput(collection=self.collection) @@ -33,13 +34,13 @@ class PromptTestInvocationOutput(BaseInvocationOutput): class PromptTestInvocation(BaseInvocation): prompt: str = InputField(default="") - def invoke(self, context) -> PromptTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: return PromptTestInvocationOutput(prompt=self.prompt) @invocation("test_error", version="1.0.0") class ErrorInvocation(BaseInvocation): - def invoke(self, context) -> PromptTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: raise Exception("This invocation is supposed to fail") @@ -53,7 +54,7 @@ class TextToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") prompt2: str = InputField(default="") - def invoke(self, context) -> ImageTestInvocationOutput: + def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -62,7 +63,7 @@ class ImageToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") image: Union[ImageField, None] = InputField(default=None) - def invoke(self, context) -> ImageTestInvocationOutput: + def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -75,7 +76,7 @@ class PromptCollectionTestInvocationOutput(BaseInvocationOutput): class PromptCollectionTestInvocation(BaseInvocation): collection: list[str] = InputField() - def invoke(self, context) -> PromptCollectionTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) @@ -88,7 +89,7 @@ class AnyTypeTestInvocationOutput(BaseInvocationOutput): class AnyTypeTestInvocation(BaseInvocation): value: Any = InputField(default=None) - def invoke(self, context) -> AnyTypeTestInvocationOutput: + def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput: return AnyTypeTestInvocationOutput(value=self.value) @@ -96,7 +97,7 @@ def invoke(self, context) -> AnyTypeTestInvocationOutput: class PolymorphicStringTestInvocation(BaseInvocation): value: Union[str, list[str]] = InputField(default="") - def invoke(self, context) -> PromptCollectionTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: if isinstance(self.value, str): return PromptCollectionTestInvocationOutput(collection=[self.value]) return PromptCollectionTestInvocationOutput(collection=self.value) From 251d4af0ae32b4c7ebcf4bb0c3de25bdc742cf9b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:40:49 +1100 Subject: [PATCH 016/340] feat(nodes): add boards interface to invocation context --- .../app/services/shared/invocation_context.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index a849d6b17a2..cbcaa6a5489 100644 --- 
a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -7,6 +7,7 @@ from torch import Tensor from invokeai.app.invocations.fields import MetadataField, WithMetadata +from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO @@ -63,6 +64,54 @@ class InvocationContextData: """The workflow associated with this queue item, if any.""" +class BoardsInterface: + def __init__(self, services: InvocationServices) -> None: + def create(board_name: str) -> BoardDTO: + """ + Creates a board. + + :param board_name: The name of the board to create. + """ + return services.boards.create(board_name) + + def get_dto(board_id: str) -> BoardDTO: + """ + Gets a board DTO. + + :param board_id: The ID of the board to get. + """ + return services.boards.get_dto(board_id) + + def get_all() -> list[BoardDTO]: + """ + Gets all boards. + """ + return services.boards.get_all() + + def add_image_to_board(board_id: str, image_name: str) -> None: + """ + Adds an image to a board. + + :param board_id: The ID of the board to add the image to. + :param image_name: The name of the image to add to the board. + """ + services.board_images.add_image_to_board(board_id, image_name) + + def get_all_image_names_for_board(board_id: str) -> list[str]: + """ + Gets all image names for a board. + + :param board_id: The ID of the board to get the image names for. + """ + return services.board_images.get_all_board_image_names_for_board(board_id) + + self.create = create + self.get_dto = get_dto + self.get_all = get_all + self.add_image_to_board = add_image_to_board + self.get_all_image_names_for_board = get_all_image_names_for_board + + class LoggerInterface: def __init__(self, services: InvocationServices) -> None: def debug(message: str) -> None: @@ -427,6 +476,7 @@ def __init__( logger: LoggerInterface, config: ConfigInterface, util: UtilInterface, + boards: BoardsInterface, data: InvocationContextData, services: InvocationServices, ) -> None: @@ -444,6 +494,8 @@ def __init__( """Provides access to the app's config.""" self.util = util """Provides utility methods.""" + self.boards = boards + """Provides methods to interact with boards.""" self.data = data """Provides data about the current queue item and invocation.""" self.__services = services @@ -554,6 +606,7 @@ def build_invocation_context( config = ConfigInterface(services=services) util = UtilInterface(services=services, context_data=context_data) conditioning = ConditioningInterface(services=services, context_data=context_data) + boards = BoardsInterface(services=services) ctx = InvocationContext( images=images, @@ -565,6 +618,7 @@ def build_invocation_context( util=util, conditioning=conditioning, services=services, + boards=boards, ) return ctx From 8327f2dff58f30b748cd863451aa251bed132fb6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:48:32 +1100 Subject: [PATCH 017/340] feat(nodes): export more things from `invocation_api" --- invokeai/invocation_api/__init__.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index e867ec3cc4e..e80bc26a003 100644 --- a/invokeai/invocation_api/__init__.py 
+++ b/invokeai/invocation_api/__init__.py @@ -47,8 +47,14 @@ StringCollectionOutput, StringOutput, ) +from invokeai.app.services.boards.boards_common import BoardDTO +from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID +from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, @@ -101,9 +107,23 @@ "StringOutput", # invokeai.app.services.image_records.image_records_common "ImageCategory", + # invokeai.app.services.boards.boards_common + "BoardDTO", # invokeai.backend.stable_diffusion.diffusion.conditioning_data "BasicConditioningInfo", "ConditioningFieldData", "ExtraConditioningInfo", "SDXLConditioningInfo", + # invokeai.backend.stable_diffusion.diffusers_pipeline + "PipelineIntermediateState", + # invokeai.app.services.workflow_records.workflow_records_common + "WorkflowWithoutID", + # invokeai.app.services.config.config_default + "InvokeAIAppConfig", + # invokeai.backend.model_management.model_manager + "ModelInfo", + # invokeai.backend.model_management.models.base + "BaseModelType", + "ModelType", + "SubModelType", ] From eaabad1fa13761b4c429db65fe86045e6bfd2eab Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 5 Feb 2024 18:06:01 +1100 Subject: [PATCH 018/340] chore(nodes): add comments for ConfigInterface --- invokeai/app/services/shared/invocation_context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index cbcaa6a5489..cb989cb15e0 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -392,7 +392,7 @@ class ConfigInterface: def __init__(self, services: InvocationServices) -> None: def get() -> InvokeAIAppConfig: """ - Gets the app's config. + Gets the app's config. The config is read-only; attempts to mutate it will raise an error. """ # The config can be changed at runtime. @@ -400,6 +400,7 @@ def get() -> InvokeAIAppConfig: # We don't want nodes doing this, so we make a frozen copy. config = services.configuration.get_config() + # TODO(psyche): If config cannot be changed at runtime, should we cache this? 
frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) return frozen_config From b3e2bdf72689ef47a5d6b96cd3aebb88c7b3d1b5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 6 Feb 2024 00:37:18 +1100 Subject: [PATCH 019/340] tests(nodes): fix mock InvocationContext --- tests/aa_nodes/test_graph_execution_state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index 3577a78ae2c..1612cbe7198 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -93,6 +93,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B logger=None, models=None, util=None, + boards=None, services=None, ) ) From 3b1c20d9f491d8a23e687384654a25d0ccf28361 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 6 Feb 2024 10:22:58 +1100 Subject: [PATCH 020/340] fix(nodes): restore missing context type annotations --- invokeai/app/invocations/latent.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 2cc84f80a73..5e36e73ec8f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -106,7 +106,7 @@ class SchedulerInvocation(BaseInvocation): ui_type=UIType.Scheduler, ) - def invoke(self, context) -> SchedulerOutput: + def invoke(self, context: InvocationContext) -> SchedulerOutput: return SchedulerOutput(scheduler=self.scheduler) @@ -141,7 +141,7 @@ def prep_mask_tensor(self, mask_image): return mask_tensor @torch.no_grad() - def invoke(self, context) -> DenoiseMaskOutput: + def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: if self.image is not None: image = context.images.get_pil(self.image.image_name) image = image_resized_to_grid_as_tensor(image.convert("RGB")) @@ -630,7 +630,7 @@ def prep_inpaint_mask(self, context: InvocationContext, latents): return 1 - mask, masked_latents @torch.no_grad() - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: with SilenceWarnings(): # this quenches NSFW nag from diffusers seed = None noise = None @@ -777,7 +777,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) @torch.no_grad() - def invoke(self, context) -> ImageOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.latents.get(self.latents.latents_name) vae_info = context.models.load(**self.vae.vae.model_dump()) @@ -868,7 +868,7 @@ class ResizeLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.latents.get(self.latents.latents_name) # TODO: @@ -909,7 +909,7 @@ class ScaleLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context) -> LatentsOutput: 
+ def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.latents.get(self.latents.latents_name) # TODO: @@ -998,7 +998,7 @@ def vae_encode(vae_info, upcast, tiled, image_tensor): return latents @torch.no_grad() - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.images.get_pil(self.image.image_name) vae_info = context.models.load(**self.vae.vae.model_dump()) @@ -1046,7 +1046,7 @@ class BlendLatentsInvocation(BaseInvocation): ) alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: latents_a = context.latents.get(self.latents_a.latents_name) latents_b = context.latents.get(self.latents_b.latents_name) @@ -1147,7 +1147,7 @@ class CropLatentsCoreInvocation(BaseInvocation): description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", ) - def invoke(self, context) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.latents.get(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR From a780734044bff5eac3875e2e0a28c944ca6e8a21 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:24:05 +1100 Subject: [PATCH 021/340] feat(nodes): do not hide `services` in invocation context interfaces --- .../app/services/shared/invocation_context.py | 675 ++++++++---------- 1 file changed, 317 insertions(+), 358 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index cb989cb15e0..54c50bcf76b 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -64,379 +64,338 @@ class InvocationContextData: """The workflow associated with this queue item, if any.""" -class BoardsInterface: - def __init__(self, services: InvocationServices) -> None: - def create(board_name: str) -> BoardDTO: - """ - Creates a board. - - :param board_name: The name of the board to create. - """ - return services.boards.create(board_name) - - def get_dto(board_id: str) -> BoardDTO: - """ - Gets a board DTO. - - :param board_id: The ID of the board to get. - """ - return services.boards.get_dto(board_id) - - def get_all() -> list[BoardDTO]: - """ - Gets all boards. - """ - return services.boards.get_all() - - def add_image_to_board(board_id: str, image_name: str) -> None: - """ - Adds an image to a board. - - :param board_id: The ID of the board to add the image to. - :param image_name: The name of the image to add to the board. - """ - services.board_images.add_image_to_board(board_id, image_name) - - def get_all_image_names_for_board(board_id: str) -> list[str]: - """ - Gets all image names for a board. - - :param board_id: The ID of the board to get the image names for. - """ - return services.board_images.get_all_board_image_names_for_board(board_id) - - self.create = create - self.get_dto = get_dto - self.get_all = get_all - self.add_image_to_board = add_image_to_board - self.get_all_image_names_for_board = get_all_image_names_for_board - - -class LoggerInterface: - def __init__(self, services: InvocationServices) -> None: - def debug(message: str) -> None: - """ - Logs a debug message. - - :param message: The message to log. 
- """ - services.logger.debug(message) - - def info(message: str) -> None: - """ - Logs an info message. - - :param message: The message to log. - """ - services.logger.info(message) - - def warning(message: str) -> None: - """ - Logs a warning message. - - :param message: The message to log. - """ - services.logger.warning(message) - - def error(message: str) -> None: - """ - Logs an error message. - - :param message: The message to log. - """ - services.logger.error(message) - - self.debug = debug - self.info = info - self.warning = warning - self.error = error - - -class ImagesInterface: +class InvocationContextInterface: def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - def save( - image: Image, - board_id: Optional[str] = None, - image_category: ImageCategory = ImageCategory.GENERAL, - metadata: Optional[MetadataField] = None, - ) -> ImageDTO: - """ - Saves an image, returning its DTO. - - If the current queue item has a workflow or metadata, it is automatically saved with the image. - - :param image: The image to save, as a PIL image. - :param board_id: The board ID to add the image to, if it should be added. - :param image_category: The category of the image. Only the GENERAL category is added \ - to the gallery. - :param metadata: The metadata to save with the image, if it should have any. If the \ - invocation inherits from `WithMetadata`, that metadata will be used automatically. \ - **Use this only if you want to override or provide metadata manually!** - """ - - # If the invocation inherits metadata, use that. Else, use the metadata passed in. - metadata_ = ( - context_data.invocation.metadata if isinstance(context_data.invocation, WithMetadata) else metadata - ) - - return services.images.create( - image=image, - is_intermediate=context_data.invocation.is_intermediate, - image_category=image_category, - board_id=board_id, - metadata=metadata_, - image_origin=ResourceOrigin.INTERNAL, - workflow=context_data.workflow, - session_id=context_data.session_id, - node_id=context_data.invocation.id, - ) - - def get_pil(image_name: str) -> Image: - """ - Gets an image as a PIL Image object. - - :param image_name: The name of the image to get. - """ - return services.images.get_pil_image(image_name) - - def get_metadata(image_name: str) -> Optional[MetadataField]: - """ - Gets an image's metadata, if it has any. - - :param image_name: The name of the image to get the metadata for. - """ - return services.images.get_metadata(image_name) - - def get_dto(image_name: str) -> ImageDTO: - """ - Gets an image as an ImageDTO object. - - :param image_name: The name of the image to get. - """ - return services.images.get_dto(image_name) - - def update( - image_name: str, - board_id: Optional[str] = None, - is_intermediate: Optional[bool] = False, - ) -> ImageDTO: - """ - Updates an image, returning its updated DTO. - - It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. - - If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to - get the updated image. - - :param image_name: The name of the image to update. - :param board_id: The board ID to add the image to, if it should be added. - :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. 
- """ - if is_intermediate is not None: - services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) - if board_id is None: - services.board_images.remove_image_from_board(image_name) - else: - services.board_images.add_image_to_board(image_name, board_id) - return services.images.get_dto(image_name) - - self.save = save - self.get_pil = get_pil - self.get_metadata = get_metadata - self.get_dto = get_dto - self.update = update - - -class LatentsInterface: - def __init__( + self._services = services + self._context_data = context_data + + +class BoardsInterface(InvocationContextInterface): + def create(self, board_name: str) -> BoardDTO: + """ + Creates a board. + + :param board_name: The name of the board to create. + """ + return self._services.boards.create(board_name) + + def get_dto(self, board_id: str) -> BoardDTO: + """ + Gets a board DTO. + + :param board_id: The ID of the board to get. + """ + return self._services.boards.get_dto(board_id) + + def get_all(self) -> list[BoardDTO]: + """ + Gets all boards. + """ + return self._services.boards.get_all() + + def add_image_to_board(self, board_id: str, image_name: str) -> None: + """ + Adds an image to a board. + + :param board_id: The ID of the board to add the image to. + :param image_name: The name of the image to add to the board. + """ + return self._services.board_images.add_image_to_board(board_id, image_name) + + def get_all_image_names_for_board(self, board_id: str) -> list[str]: + """ + Gets all image names for a board. + + :param board_id: The ID of the board to get the image names for. + """ + return self._services.board_images.get_all_board_image_names_for_board(board_id) + + +class LoggerInterface(InvocationContextInterface): + def debug(self, message: str) -> None: + """ + Logs a debug message. + + :param message: The message to log. + """ + self._services.logger.debug(message) + + def info(self, message: str) -> None: + """ + Logs an info message. + + :param message: The message to log. + """ + self._services.logger.info(message) + + def warning(self, message: str) -> None: + """ + Logs a warning message. + + :param message: The message to log. + """ + self._services.logger.warning(message) + + def error(self, message: str) -> None: + """ + Logs an error message. + + :param message: The message to log. + """ + self._services.logger.error(message) + + +class ImagesInterface(InvocationContextInterface): + def save( self, - services: InvocationServices, - context_data: InvocationContextData, - ) -> None: - def save(tensor: Tensor) -> str: - """ - Saves a latents tensor, returning its name. - - :param tensor: The latents tensor to save. - """ - - # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. - # "mask", "noise", "masked_latents", etc. - # - # Retaining that capability in this wrapper would require either many different methods - # to save latents, or extra args for this method. Instead of complicating the API, we - # will use the same naming scheme for all latents. - # - # This has a very minor impact as we don't use them after a session completes. - - # Previously, invocations chose the name for their latents. This is a bit risky, so we - # will generate a name for them instead. We use a uuid to ensure the name is unique. - # - # Because the name of the latents file will includes the session and invocation IDs, - # we don't need to worry about collisions. A truncated UUIDv4 is fine. 
- - name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}" - services.latents.save( - name=name, - data=tensor, - ) - return name - - def get(latents_name: str) -> Tensor: - """ - Gets a latents tensor by name. - - :param latents_name: The name of the latents tensor to get. - """ - return services.latents.get(latents_name) - - self.save = save - self.get = get - - -class ConditioningInterface: - def __init__( + image: Image, + board_id: Optional[str] = None, + image_category: ImageCategory = ImageCategory.GENERAL, + metadata: Optional[MetadataField] = None, + ) -> ImageDTO: + """ + Saves an image, returning its DTO. + + If the current queue item has a workflow or metadata, it is automatically saved with the image. + + :param image: The image to save, as a PIL image. + :param board_id: The board ID to add the image to, if it should be added. + :param image_category: The category of the image. Only the GENERAL category is added \ + to the gallery. + :param metadata: The metadata to save with the image, if it should have any. If the \ + invocation inherits from `WithMetadata`, that metadata will be used automatically. \ + **Use this only if you want to override or provide metadata manually!** + """ + + # If the invocation inherits metadata, use that. Else, use the metadata passed in. + metadata_ = ( + self._context_data.invocation.metadata + if isinstance(self._context_data.invocation, WithMetadata) + else metadata + ) + + return self._services.images.create( + image=image, + is_intermediate=self._context_data.invocation.is_intermediate, + image_category=image_category, + board_id=board_id, + metadata=metadata_, + image_origin=ResourceOrigin.INTERNAL, + workflow=self._context_data.workflow, + session_id=self._context_data.session_id, + node_id=self._context_data.invocation.id, + ) + + def get_pil(self, image_name: str) -> Image: + """ + Gets an image as a PIL Image object. + + :param image_name: The name of the image to get. + """ + return self._services.images.get_pil_image(image_name) + + def get_metadata(self, image_name: str) -> Optional[MetadataField]: + """ + Gets an image's metadata, if it has any. + + :param image_name: The name of the image to get the metadata for. + """ + return self._services.images.get_metadata(image_name) + + def get_dto(self, image_name: str) -> ImageDTO: + """ + Gets an image as an ImageDTO object. + + :param image_name: The name of the image to get. + """ + return self._services.images.get_dto(image_name) + + def update( self, - services: InvocationServices, - context_data: InvocationContextData, - ) -> None: - # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage - # service, but it is typed to work with Tensors only. We have to fudge the types here. + image_name: str, + board_id: Optional[str] = None, + is_intermediate: Optional[bool] = False, + ) -> ImageDTO: + """ + Updates an image, returning its updated DTO. + + It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. - def save(conditioning_data: ConditioningFieldData) -> str: - """ - Saves a conditioning data object, returning its name. + If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to + get the updated image. - :param conditioning_data: The conditioning data to save. - """ + :param image_name: The name of the image to update. + :param board_id: The board ID to add the image to, if it should be added. 
+ :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. + """ + if is_intermediate is not None: + self._services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) + if board_id is None: + self._services.board_images.remove_image_from_board(image_name) + else: + self._services.board_images.add_image_to_board(image_name, board_id) + return self._services.images.get_dto(image_name) - # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. - # - # See comment for `LatentsInterface.save` for more info about this method (it's very - # similar). - name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning" - services.latents.save( - name=name, - data=conditioning_data, # type: ignore [arg-type] - ) - return name +class LatentsInterface(InvocationContextInterface): + def save(self, tensor: Tensor) -> str: + """ + Saves a latents tensor, returning its name. - def get(conditioning_name: str) -> ConditioningFieldData: - """ - Gets conditioning data by name. + :param tensor: The latents tensor to save. + """ - :param conditioning_name: The name of the conditioning data to get. - """ + # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. + # "mask", "noise", "masked_latents", etc. + # + # Retaining that capability in this wrapper would require either many different methods + # to save latents, or extra args for this method. Instead of complicating the API, we + # will use the same naming scheme for all latents. + # + # This has a very minor impact as we don't use them after a session completes. + + # Previously, invocations chose the name for their latents. This is a bit risky, so we + # will generate a name for them instead. We use a uuid to ensure the name is unique. + # + # Because the name of the latents file will includes the session and invocation IDs, + # we don't need to worry about collisions. A truncated UUIDv4 is fine. + + name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" + self._services.latents.save( + name=name, + data=tensor, + ) + return name + + def get(self, latents_name: str) -> Tensor: + """ + Gets a latents tensor by name. - return services.latents.get(conditioning_name) # type: ignore [return-value] + :param latents_name: The name of the latents tensor to get. + """ + return self._services.latents.get(latents_name) - self.save = save - self.get = get +class ConditioningInterface(InvocationContextInterface): + # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage + # service, but it is typed to work with Tensors only. We have to fudge the types here. + def save(self, conditioning_data: ConditioningFieldData) -> str: + """ + Saves a conditioning data object, returning its name. -class ModelsInterface: - def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - def exists(model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool: - """ - Checks if a model exists. - - :param model_name: The name of the model to check. - :param base_model: The base model of the model to check. - :param model_type: The type of the model to check. 
- """ - return services.model_manager.model_exists(model_name, base_model, model_type) - - def load( - model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None - ) -> ModelInfo: - """ - Loads a model, returning its `ModelInfo` object. - - :param model_name: The name of the model to get. - :param base_model: The base model of the model to get. - :param model_type: The type of the model to get. - :param submodel: The submodel of the model to get. - """ - - # During this call, the model manager emits events with model loading status. The model - # manager itself has access to the events services, but does not have access to the - # required metadata for the events. - # - # For example, it needs access to the node's ID so that the events can be associated - # with the execution of a specific node. - # - # While this is available within the node, it's tedious to need to pass it in on every - # call. We can avoid that by wrapping the method here. - - return services.model_manager.get_model( - model_name, base_model, model_type, submodel, context_data=context_data - ) - - def get_info(model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Gets a model's info, an dict-like object. - - :param model_name: The name of the model to get. - :param base_model: The base model of the model to get. - :param model_type: The type of the model to get. - """ - return services.model_manager.model_info(model_name, base_model, model_type) - - self.exists = exists - self.load = load - self.get_info = get_info - - -class ConfigInterface: - def __init__(self, services: InvocationServices) -> None: - def get() -> InvokeAIAppConfig: - """ - Gets the app's config. The config is read-only; attempts to mutate it will raise an error. - """ - - # The config can be changed at runtime. - # - # We don't want nodes doing this, so we make a frozen copy. - - config = services.configuration.get_config() - # TODO(psyche): If config cannot be changed at runtime, should we cache this? - frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) - return frozen_config - - self.get = get - - -class UtilInterface: - def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - def sd_step_callback( - intermediate_state: PipelineIntermediateState, - base_model: BaseModelType, - ) -> None: - """ - The step callback emits a progress event with the current step, the total number of - steps, a preview image, and some other internal metadata. + :param conditioning_context_data: The conditioning data to save. + """ + + # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. + # + # See comment for `LatentsInterface.save` for more info about this method (it's very + # similar). - This should be called after each denoising step. + name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning" + self._services.latents.save( + name=name, + data=conditioning_data, # type: ignore [arg-type] + ) + return name - :param intermediate_state: The intermediate state of the diffusion pipeline. - :param base_model: The base model for the current denoising step. - """ + def get(self, conditioning_name: str) -> ConditioningFieldData: + """ + Gets conditioning data by name. + + :param conditioning_name: The name of the conditioning data to get. 
+ """ - # The step callback needs access to the events and the invocation queue services, but this - # represents a dangerous level of access. - # - # We wrap the step callback so that nodes do not have direct access to these services. + return self._services.latents.get(conditioning_name) # type: ignore [return-value] + + +class ModelsInterface(InvocationContextInterface): + def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool: + """ + Checks if a model exists. + + :param model_name: The name of the model to check. + :param base_model: The base model of the model to check. + :param model_type: The type of the model to check. + """ + return self._services.model_manager.model_exists(model_name, base_model, model_type) + + def load( + self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None + ) -> ModelInfo: + """ + Loads a model, returning its `ModelInfo` object. + + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. + :param submodel: The submodel of the model to get. + """ + + # During this call, the model manager emits events with model loading status. The model + # manager itself has access to the events services, but does not have access to the + # required metadata for the events. + # + # For example, it needs access to the node's ID so that the events can be associated + # with the execution of a specific node. + # + # While this is available within the node, it's tedious to need to pass it in on every + # call. We can avoid that by wrapping the method here. + + return self._services.model_manager.get_model( + model_name, base_model, model_type, submodel, context_data=self._context_data + ) + + def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: + """ + Gets a model's info, an dict-like object. + + :param model_name: The name of the model to get. + :param base_model: The base model of the model to get. + :param model_type: The type of the model to get. + """ + return self._services.model_manager.model_info(model_name, base_model, model_type) + + +class ConfigInterface(InvocationContextInterface): + def get(self) -> InvokeAIAppConfig: + """ + Gets the app's config. The config is read-only; attempts to mutate it will raise an error. + """ + + # The config can be changed at runtime. + # + # We don't want nodes doing this, so we make a frozen copy. + + config = self._services.configuration.get_config() + # TODO(psyche): If config cannot be changed at runtime, should we cache this? + frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) + return frozen_config + + +class UtilInterface(InvocationContextInterface): + def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: + """ + The step callback emits a progress event with the current step, the total number of + steps, a preview image, and some other internal metadata. + + This should be called after each denoising step. + + :param intermediate_state: The intermediate state of the diffusion pipeline. + :param base_model: The base model for the current denoising step. 
+ """ - stable_diffusion_step_callback( - context_data=context_data, - intermediate_state=intermediate_state, - base_model=base_model, - invocation_queue=services.queue, - events=services.events, - ) + # The step callback needs access to the events and the invocation queue services, but this + # represents a dangerous level of access. + # + # We wrap the step callback so that nodes do not have direct access to these services. - self.sd_step_callback = sd_step_callback + stable_diffusion_step_callback( + context_data=self._context_data, + intermediate_state=intermediate_state, + base_model=base_model, + invocation_queue=self._services.queue, + events=self._services.events, + ) deprecation_version = "3.7.0" @@ -600,14 +559,14 @@ def build_invocation_context( :param invocation_context_data: The invocation context data. """ - logger = LoggerInterface(services=services) + logger = LoggerInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data) latents = LatentsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data) - config = ConfigInterface(services=services) + config = ConfigInterface(services=services, context_data=context_data) util = UtilInterface(services=services, context_data=context_data) conditioning = ConditioningInterface(services=services, context_data=context_data) - boards = BoardsInterface(services=services) + boards = BoardsInterface(services=services, context_data=context_data) ctx = InvocationContext( images=images, From 337c2dcb1e979b52cca10d58c2e46cbe718baba0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:28:29 +1100 Subject: [PATCH 022/340] feat(nodes): cache invocation interface config --- .../app/services/shared/invocation_context.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 54c50bcf76b..99e439ad96d 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -357,19 +357,24 @@ def get_info(self, model_name: str, base_model: BaseModelType, model_type: Model class ConfigInterface(InvocationContextInterface): + def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + super().__init__(services, context_data) + # Config cache, only populated at runtime if requested + self._frozen_config: Optional[InvokeAIAppConfig] = None + def get(self) -> InvokeAIAppConfig: """ Gets the app's config. The config is read-only; attempts to mutate it will raise an error. """ - # The config can be changed at runtime. - # - # We don't want nodes doing this, so we make a frozen copy. + if self._frozen_config is None: + # The config is a live pydantic model and can be changed at runtime. + # We don't want nodes doing this, so we make a frozen copy. + self._frozen_config = self._services.configuration.get_config().model_copy( + update={"model_config": ConfigDict(frozen=True)} + ) - config = self._services.configuration.get_config() - # TODO(psyche): If config cannot be changed at runtime, should we cache this? 
- frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)}) - return frozen_config + return self._frozen_config class UtilInterface(InvocationContextInterface): From cf9db751d78469de42af9642510f106cb9f38e94 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:36:42 +1100 Subject: [PATCH 023/340] feat(nodes): context.__services -> context._services --- invokeai/app/services/shared/invocation_context.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 99e439ad96d..5da85931672 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -463,7 +463,8 @@ def __init__( """Provides methods to interact with boards.""" self.data = data """Provides data about the current queue item and invocation.""" - self.__services = services + self._services = services + """Provides access to the full application services. This is an internal API and may change without warning.""" @property @deprecated(version=deprecation_version, reason=get_deprecation_reason("`context.services`")) @@ -475,7 +476,7 @@ def services(self) -> InvocationServices: The invocation services. """ - return self.__services + return self._services @property @deprecated( From 2146dfb1684515df1003c84f1fe3090a06b67514 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:39:26 +1100 Subject: [PATCH 024/340] feat(nodes): context.data -> context._data --- .../app/services/shared/invocation_context.py | 38 +++++++++---------- tests/aa_nodes/test_graph_execution_state.py | 2 +- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 5da85931672..b48a6acc545 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -442,7 +442,7 @@ def __init__( config: ConfigInterface, util: UtilInterface, boards: BoardsInterface, - data: InvocationContextData, + context_data: InvocationContextData, services: InvocationServices, ) -> None: self.images = images @@ -461,8 +461,8 @@ def __init__( """Provides utility methods.""" self.boards = boards """Provides methods to interact with boards.""" - self.data = data - """Provides data about the current queue item and invocation.""" + self._data = context_data + """Provides data about the current queue item and invocation. This is an internal API and may change without warning.""" self._services = services """Provides access to the full application services. This is an internal API and may change without warning.""" @@ -481,77 +481,77 @@ def services(self) -> InvocationServices: @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.graph_execution_state_api`", "`context.data.session_id`"), + reason=get_deprecation_reason("`context.graph_execution_state_id", "`context._data.session_id`"), ) def graph_execution_state_id(self) -> str: """ **DEPRECATED as of v3.7.0** - `context.graph_execution_state_api` will be removed in v3.8.0. Use `context.data.session_id` instead. See PLACEHOLDER_URL for details. + `context.graph_execution_state_api` will be removed in v3.8.0. Use `context._data.session_id` instead. See PLACEHOLDER_URL for details. 
The ID of the session (aka graph execution state). """ - return self.data.session_id + return self._data.session_id @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.queue_id`", "`context.data.queue_id`"), + reason=get_deprecation_reason("`context.queue_id`", "`context._data.queue_id`"), ) def queue_id(self) -> str: """ **DEPRECATED as of v3.7.0** - `context.queue_id` will be removed in v3.8.0. Use `context.data.queue_id` instead. See PLACEHOLDER_URL for details. + `context.queue_id` will be removed in v3.8.0. Use `context._data.queue_id` instead. See PLACEHOLDER_URL for details. The ID of the queue. """ - return self.data.queue_id + return self._data.queue_id @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.queue_item_id`", "`context.data.queue_item_id`"), + reason=get_deprecation_reason("`context.queue_item_id`", "`context._data.queue_item_id`"), ) def queue_item_id(self) -> int: """ **DEPRECATED as of v3.7.0** - `context.queue_item_id` will be removed in v3.8.0. Use `context.data.queue_item_id` instead. See PLACEHOLDER_URL for details. + `context.queue_item_id` will be removed in v3.8.0. Use `context._data.queue_item_id` instead. See PLACEHOLDER_URL for details. The ID of the queue item. """ - return self.data.queue_item_id + return self._data.queue_item_id @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.queue_batch_id`", "`context.data.batch_id`"), + reason=get_deprecation_reason("`context.queue_batch_id`", "`context._data.batch_id`"), ) def queue_batch_id(self) -> str: """ **DEPRECATED as of v3.7.0** - `context.queue_batch_id` will be removed in v3.8.0. Use `context.data.batch_id` instead. See PLACEHOLDER_URL for details. + `context.queue_batch_id` will be removed in v3.8.0. Use `context._data.batch_id` instead. See PLACEHOLDER_URL for details. The ID of the batch. """ - return self.data.batch_id + return self._data.batch_id @property @deprecated( version=deprecation_version, - reason=get_deprecation_reason("`context.workflow`", "`context.data.workflow`"), + reason=get_deprecation_reason("`context.workflow`", "`context._data.workflow`"), ) def workflow(self) -> Optional[WorkflowWithoutID]: """ **DEPRECATED as of v3.7.0** - `context.workflow` will be removed in v3.8.0. Use `context.data.workflow` instead. See PLACEHOLDER_URL for details. + `context.workflow` will be removed in v3.8.0. Use `context._data.workflow` instead. See PLACEHOLDER_URL for details. The workflow associated with this queue item, if any. 
""" - return self.data.workflow + return self._data.workflow def build_invocation_context( @@ -580,7 +580,7 @@ def build_invocation_context( config=config, latents=latents, models=models, - data=context_data, + context_data=context_data, util=util, conditioning=conditioning, services=services, diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index 1612cbe7198..aba7c5694f3 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -87,7 +87,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B InvocationContext( conditioning=None, config=None, - data=None, + context_data=None, images=None, latents=None, logger=None, From 6f94af2cacb6d45b5b683a56c950bfae0a7454dc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 15:58:46 +1100 Subject: [PATCH 025/340] fix(nodes): do not freeze or cache config in context wrapper - The config is already cached by the config class's `get_config()` method. - The config mutates itself in its `root_path` property getter. Freezing the class makes any attempt to grab a path from the config error. Unfortunately this means we cannot easily freeze the class without fiddling with the inner workings of `InvokeAIAppConfig`, which is outside the scope here. --- .../app/services/shared/invocation_context.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index b48a6acc545..cd88ec876dd 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -357,24 +357,10 @@ def get_info(self, model_name: str, base_model: BaseModelType, model_type: Model class ConfigInterface(InvocationContextInterface): - def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: - super().__init__(services, context_data) - # Config cache, only populated at runtime if requested - self._frozen_config: Optional[InvokeAIAppConfig] = None - def get(self) -> InvokeAIAppConfig: - """ - Gets the app's config. The config is read-only; attempts to mutate it will raise an error. - """ - - if self._frozen_config is None: - # The config is a live pydantic model and can be changed at runtime. - # We don't want nodes doing this, so we make a frozen copy. 
- self._frozen_config = self._services.configuration.get_config().model_copy( - update={"model_config": ConfigDict(frozen=True)} - ) + """Gets the app's config.""" - return self._frozen_config + return self._services.configuration.get_config() class UtilInterface(InvocationContextInterface): From f486ebf8c166e4ee10305807b03616f8956bef0b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:14:35 +1100 Subject: [PATCH 026/340] fix(ui): remove original l2i node in HRF graph --- .../web/src/features/nodes/util/graph/addHrfToGraph.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts index 7413302fa57..8a4448833cf 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts @@ -314,6 +314,10 @@ export const addHrfToGraph = (state: RootState, graph: NonNullableGraph): void = ); copyConnectionsToDenoiseLatentsHrf(graph); + // The original l2i node is unnecessary now, remove it + graph.edges = graph.edges.filter((edge) => edge.destination.node_id !== LATENTS_TO_IMAGE); + delete graph.nodes[LATENTS_TO_IMAGE]; + graph.nodes[LATENTS_TO_IMAGE_HRF_HR] = { type: 'l2i', id: LATENTS_TO_IMAGE_HRF_HR, From 33a8c4c7d37a6a4e8eb299574fad62fe1c1b1e5e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:23:57 +1100 Subject: [PATCH 027/340] remove unused configdict import --- invokeai/app/services/shared/invocation_context.py | 1 - 1 file changed, 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index cd88ec876dd..8aaa5233afd 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -3,7 +3,6 @@ from deprecated import deprecated from PIL.Image import Image -from pydantic import ConfigDict from torch import Tensor from invokeai.app.invocations.fields import MetadataField, WithMetadata From 9e5d11081acdcc81ffae58ab8fc924ddc12ba113 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:33:55 +1100 Subject: [PATCH 028/340] feat(nodes): add `WithBoard` field helper class This class works the same way as `WithMetadata` - it simply adds a `board` field to the node. The context wrapper function is able to pull the board id from this. This allows image-outputting nodes to get a board field "for free", and have their outputs automatically saved to it. This is a breaking change for node authors who may have a field called `board`, because it makes `board` a reserved field name. I'll look into how to avoid this - maybe by naming this invoke-managed field `_board` to avoid collisions? Supporting changes: - `WithBoard` is added to all image-outputting nodes, giving them the ability to save to board. - Unused, duplicate `WithMetadata` and `WithWorkflow` classes are deleted from `baseinvocation.py`. The "real" versions are in `fields.py`. - Remove `LinearUIOutputInvocation`. Now that all nodes that output images also have a `board` field by default, this node is no longer necessary. 
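(A rough sketch of the mixin pattern, under the assumption of simplified stand-in classes; `DemoImageNode` and `resolve_board_id` are hypothetical and not part of the InvokeAI codebase, but they illustrate how inheriting `WithBoard` gives a node a `board` field and how a context wrapper can pull the board id from it.)

from typing import Optional
from pydantic import BaseModel, Field

class BoardField(BaseModel):
    board_id: str

class WithBoard(BaseModel):
    # Mixing this into a node adds an optional `board` input "for free".
    board: Optional[BoardField] = Field(default=None, description="The board to save the image to")

class DemoImageNode(WithBoard):
    # A hypothetical image-outputting node; real nodes carry many more fields.
    width: int = 512
    height: int = 512

def resolve_board_id(invocation: BaseModel, board_id: Optional[str] = None) -> Optional[str]:
    # Mirrors the context wrapper's fallback: prefer the node's own board, then the explicit argument.
    board = invocation.board if isinstance(invocation, WithBoard) else None
    return board.board_id if board is not None else board_id

node = DemoImageNode(board=BoardField(board_id="my-board"))
assert resolve_board_id(node) == "my-board"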
See comment here for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629 - Without `LinearUIOutputInvocation`, the `ImagesInferface.update` method is no longer needed, and removed. Note: This commit does not bump all node versions. I will ensure that is done correctly before merging the PR of which this commit is a part. Note: A followup commit will implement the frontend changes to support this change. --- invokeai/app/invocations/baseinvocation.py | 33 +------- .../controlnet_image_processors.py | 12 ++- invokeai/app/invocations/cv.py | 4 +- invokeai/app/invocations/facetools.py | 4 +- invokeai/app/invocations/fields.py | 16 ++++ invokeai/app/invocations/image.py | 76 ++++++------------- invokeai/app/invocations/infill.py | 12 +-- invokeai/app/invocations/latent.py | 3 +- invokeai/app/invocations/primitives.py | 4 +- invokeai/app/invocations/tiles.py | 4 +- invokeai/app/invocations/upscale.py | 4 +- .../app/services/shared/invocation_context.py | 40 +++------- 12 files changed, 78 insertions(+), 134 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index df0596c9a15..3243714937f 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -17,11 +17,8 @@ from pydantic_core import PydanticUndefined from invokeai.app.invocations.fields import ( - FieldDescriptions, FieldKind, Input, - InputFieldJSONSchemaExtra, - MetadataField, ) from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.shared.invocation_context import InvocationContext @@ -306,9 +303,7 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi "workflow", } -RESERVED_INPUT_FIELD_NAMES = { - "metadata", -} +RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"} RESERVED_OUTPUT_FIELD_NAMES = {"type"} @@ -518,29 +513,3 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]: return cls return wrapper - - -class WithMetadata(BaseModel): - """ - Inherit from this class if your node needs a metadata input field. - """ - - metadata: Optional[MetadataField] = Field( - default=None, - description=FieldDescriptions.metadata, - json_schema_extra=InputFieldJSONSchemaExtra( - field_kind=FieldKind.Internal, - input=Input.Connection, - orig_required=False, - ).model_dump(exclude_none=True), - ) - - -class WithWorkflow: - workflow = None - - def __init_subclass__(cls) -> None: - logger.warn( - f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow." 
- ) - super().__init_subclass__() diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index f8bdf14117c..37954c1097e 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,7 +25,15 @@ from PIL import Image from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, WithMetadata +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + Input, + InputField, + OutputField, + WithBoard, + WithMetadata, +) from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext @@ -135,7 +143,7 @@ def invoke(self, context: InvocationContext) -> ControlOutput: # This invocation exists for other invocations to subclass it - do not register with @invocation! -class ImageProcessorInvocation(BaseInvocation, WithMetadata): +class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): """Base class for invocations that preprocess images for ControlNet""" image: ImageField = InputField(description="The image to process") diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 1ebabf5e064..8174f19b64c 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -10,11 +10,11 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from .baseinvocation import BaseInvocation, invocation -from .fields import InputField, WithMetadata +from .fields import InputField, WithBoard, WithMetadata @invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1") -class CvInpaintInvocation(BaseInvocation, WithMetadata): +class CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard): """Simple inpaint using opencv.""" image: ImageField = InputField(description="The image to inpaint") diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index a1702d6517c..fed2ed5e4f2 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -16,7 +16,7 @@ invocation, invocation_output, ) -from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithMetadata +from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithBoard, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext @@ -619,7 +619,7 @@ def invoke(self, context: InvocationContext) -> FaceMaskOutput: @invocation( "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1" ) -class FaceIdentifierInvocation(BaseInvocation, WithMetadata): +class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard): """Outputs an image with detected face IDs printed on each face. 
For use with other FaceTools.""" image: ImageField = InputField(description="Image to face detect") diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 8879f760770..c42d2f83120 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -280,6 +280,22 @@ def __init_subclass__(cls) -> None: super().__init_subclass__() +class WithBoard(BaseModel): + """ + Inherit from this class if your node needs a board input field. + """ + + board: Optional["BoardField"] = Field( + default=None, + description=FieldDescriptions.board, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Direct, + orig_required=False, + ).model_dump(exclude_none=True), + ) + + class OutputFieldJSONSchemaExtra(BaseModel): """ Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 7b74e4d96d4..f5ad5515a68 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -8,12 +8,11 @@ from PIL import Image, ImageChops, ImageFilter, ImageOps from invokeai.app.invocations.fields import ( - BoardField, ColorField, FieldDescriptions, ImageField, - Input, InputField, + WithBoard, WithMetadata, ) from invokeai.app.invocations.primitives import ImageOutput @@ -55,7 +54,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class BlankImageInvocation(BaseInvocation, WithMetadata): +class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Creates a blank image and forwards it to the pipeline""" width: int = InputField(default=512, description="The width of the image") @@ -78,7 +77,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageCropInvocation(BaseInvocation, WithMetadata): +class ImageCropInvocation(BaseInvocation, WithMetadata, WithBoard): """Crops an image to a specified box. 
The box can be outside of the image.""" image: ImageField = InputField(description="The image to crop") @@ -149,7 +148,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImagePasteInvocation(BaseInvocation, WithMetadata): +class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard): """Pastes an image into another image.""" base_image: ImageField = InputField(description="The base image") @@ -196,7 +195,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): +class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard): """Extracts the alpha channel of an image as a mask.""" image: ImageField = InputField(description="The image to create the mask from") @@ -221,7 +220,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageMultiplyInvocation(BaseInvocation, WithMetadata): +class ImageMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard): """Multiplies two images together using `PIL.ImageChops.multiply()`.""" image1: ImageField = InputField(description="The first image to multiply") @@ -248,7 +247,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageChannelInvocation(BaseInvocation, WithMetadata): +class ImageChannelInvocation(BaseInvocation, WithMetadata, WithBoard): """Gets a channel from an image.""" image: ImageField = InputField(description="The image to get the channel from") @@ -274,7 +273,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageConvertInvocation(BaseInvocation, WithMetadata): +class ImageConvertInvocation(BaseInvocation, WithMetadata, WithBoard): """Converts an image to a different mode.""" image: ImageField = InputField(description="The image to convert") @@ -297,7 +296,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageBlurInvocation(BaseInvocation, WithMetadata): +class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard): """Blurs an image""" image: ImageField = InputField(description="The image to blur") @@ -326,7 +325,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: version="1.2.1", classification=Classification.Beta, ) -class UnsharpMaskInvocation(BaseInvocation, WithMetadata): +class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard): """Applies an unsharp mask filter to an image""" image: ImageField = InputField(description="The image to use") @@ -394,7 +393,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageResizeInvocation(BaseInvocation, WithMetadata): +class ImageResizeInvocation(BaseInvocation, WithMetadata, WithBoard): """Resizes an image to specific dimensions""" image: ImageField = InputField(description="The image to resize") @@ -424,7 +423,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageScaleInvocation(BaseInvocation, WithMetadata): +class ImageScaleInvocation(BaseInvocation, WithMetadata, WithBoard): """Scales an image by a factor""" image: ImageField = InputField(description="The image to scale") @@ -459,7 +458,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageLerpInvocation(BaseInvocation, 
WithMetadata): +class ImageLerpInvocation(BaseInvocation, WithMetadata, WithBoard): """Linear interpolation of all pixels of an image""" image: ImageField = InputField(description="The image to lerp") @@ -486,7 +485,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): +class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard): """Inverse linear interpolation of all pixels of an image""" image: ImageField = InputField(description="The image to lerp") @@ -513,7 +512,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata): +class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard): """Add blur to NSFW-flagged images""" image: ImageField = InputField(description="The image to check") @@ -548,7 +547,7 @@ def _get_caution_img(self) -> Image.Image: category="image", version="1.2.1", ) -class ImageWatermarkInvocation(BaseInvocation, WithMetadata): +class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard): """Add an invisible watermark to an image""" image: ImageField = InputField(description="The image to check") @@ -569,7 +568,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class MaskEdgeInvocation(BaseInvocation, WithMetadata): +class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard): """Applies an edge mask to an image""" image: ImageField = InputField(description="The image to apply the mask to") @@ -608,7 +607,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class MaskCombineInvocation(BaseInvocation, WithMetadata): +class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard): """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" mask1: ImageField = InputField(description="The first mask to combine") @@ -632,7 +631,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ColorCorrectInvocation(BaseInvocation, WithMetadata): +class ColorCorrectInvocation(BaseInvocation, WithMetadata, WithBoard): """ Shifts the colors of a target image to match the reference image, optionally using a mask to only color-correct certain regions of the target image. 
@@ -736,7 +735,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): +class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata, WithBoard): """Adjusts the Hue of an image.""" image: ImageField = InputField(description="The image to adjust") @@ -825,7 +824,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): +class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard): """Add or subtract a value from a specific color channel of an image.""" image: ImageField = InputField(description="The image to adjust") @@ -881,7 +880,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: category="image", version="1.2.1", ) -class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): +class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard): """Scale a specific color channel of an image.""" image: ImageField = InputField(description="The image to adjust") @@ -926,41 +925,14 @@ def invoke(self, context: InvocationContext) -> ImageOutput: version="1.2.1", use_cache=False, ) -class SaveImageInvocation(BaseInvocation, WithMetadata): +class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Saves an image. Unlike an image primitive, this invocation stores a copy of the image.""" image: ImageField = InputField(description=FieldDescriptions.image) - board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) - image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None) - - return ImageOutput.build(image_dto) - - -@invocation( - "linear_ui_output", - title="Linear UI Image Output", - tags=["primitives", "image"], - category="primitives", - version="1.0.2", - use_cache=False, -) -class LinearUIOutputInvocation(BaseInvocation, WithMetadata): - """Handles Linear UI Image Outputting tasks.""" - - image: ImageField = InputField(description=FieldDescriptions.image) - board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - - def invoke(self, context: InvocationContext) -> ImageOutput: - image_dto = context.images.get_dto(self.image.image_name) - - image_dto = context.images.update( - image_name=self.image.image_name, - board_id=self.board.board_id if self.board else None, - is_intermediate=self.is_intermediate, - ) + image_dto = context.images.save(image=image) return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index b007edd9e42..53f6f4732fe 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -15,7 +15,7 @@ from invokeai.backend.image_util.patchmatch import PatchMatch from .baseinvocation import BaseInvocation, invocation -from .fields import InputField, WithMetadata +from .fields import InputField, WithBoard, WithMetadata from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES @@ -121,7 +121,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] @invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") -class InfillColorInvocation(BaseInvocation, WithMetadata): 
+class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image with a solid color""" image: ImageField = InputField(description="The image to infill") @@ -144,7 +144,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: @invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") -class InfillTileInvocation(BaseInvocation, WithMetadata): +class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image with tiles of the image""" image: ImageField = InputField(description="The image to infill") @@ -170,7 +170,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: @invocation( "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1" ) -class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): +class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image using the PatchMatch algorithm""" image: ImageField = InputField(description="The image to infill") @@ -209,7 +209,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: @invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") -class LaMaInfillInvocation(BaseInvocation, WithMetadata): +class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image using the LaMa model""" image: ImageField = InputField(description="The image to infill") @@ -225,7 +225,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") -class CV2InfillInvocation(BaseInvocation, WithMetadata): +class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard): """Infills transparent areas of an image using OpenCV Inpainting""" image: ImageField = InputField(description="The image to infill") diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 5e36e73ec8f..5449ec9af7a 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -33,6 +33,7 @@ LatentsField, OutputField, UIType, + WithBoard, WithMetadata, ) from invokeai.app.invocations.ip_adapter import IPAdapterField @@ -762,7 +763,7 @@ def _lora_loader(): category="latents", version="1.2.1", ) -class LatentsToImageInvocation(BaseInvocation, WithMetadata): +class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Generates an image from latents.""" latents: LatentsField = InputField( diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index c90d3230b2b..a77939943ae 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -255,9 +255,7 @@ class ImageCollectionOutput(BaseInvocationOutput): @invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1") -class ImageInvocation( - BaseInvocation, -): +class ImageInvocation(BaseInvocation): """An image primitive value""" image: ImageField = InputField(description="The image to load") diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index 19ece423761..cb5373bbf75 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -11,7 +11,7 @@ invocation, invocation_output, ) -from 
invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithMetadata +from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithBoard, WithMetadata from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.tiles.tiles import ( @@ -232,7 +232,7 @@ def invoke(self, context: InvocationContext) -> PairTileImageOutput: version="1.1.0", classification=Classification.Beta, ) -class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): +class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Merge multiple tile images into a single image.""" # Inputs diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 71ef7ca3aa0..2e2a6ce8813 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -16,7 +16,7 @@ from invokeai.backend.util.devices import choose_torch_device from .baseinvocation import BaseInvocation, invocation -from .fields import InputField, WithMetadata +from .fields import InputField, WithBoard, WithMetadata # TODO: Populate this from disk? # TODO: Use model manager to load? @@ -32,7 +32,7 @@ @invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1") -class ESRGANInvocation(BaseInvocation, WithMetadata): +class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): """Upscales an image using RealESRGAN.""" image: ImageField = InputField(description="The input image") diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 8aaa5233afd..97a62246fbc 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -5,10 +5,10 @@ from PIL.Image import Image from torch import Tensor -from invokeai.app.invocations.fields import MetadataField, WithMetadata +from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin +from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID @@ -158,7 +158,9 @@ def save( If the current queue item has a workflow or metadata, it is automatically saved with the image. :param image: The image to save, as a PIL image. - :param board_id: The board ID to add the image to, if it should be added. + :param board_id: The board ID to add the image to, if it should be added. It the invocation \ + inherits from `WithBoard`, that board will be used automatically. **Use this only if \ + you want to override or provide a board manually!** :param image_category: The category of the image. Only the GENERAL category is added \ to the gallery. :param metadata: The metadata to save with the image, if it should have any. If the \ @@ -173,11 +175,15 @@ def save( else metadata ) + # If the invocation inherits WithBoard, use that. Else, use the board_id passed in. 
+ board_ = self._context_data.invocation.board if isinstance(self._context_data.invocation, WithBoard) else None + board_id_ = board_.board_id if board_ is not None else board_id + return self._services.images.create( image=image, is_intermediate=self._context_data.invocation.is_intermediate, image_category=image_category, - board_id=board_id, + board_id=board_id_, metadata=metadata_, image_origin=ResourceOrigin.INTERNAL, workflow=self._context_data.workflow, @@ -209,32 +215,6 @@ def get_dto(self, image_name: str) -> ImageDTO: """ return self._services.images.get_dto(image_name) - def update( - self, - image_name: str, - board_id: Optional[str] = None, - is_intermediate: Optional[bool] = False, - ) -> ImageDTO: - """ - Updates an image, returning its updated DTO. - - It is not suggested to update images saved by earlier nodes, as this can cause confusion for users. - - If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to - get the updated image. - - :param image_name: The name of the image to update. - :param board_id: The board ID to add the image to, if it should be added. - :param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery. - """ - if is_intermediate is not None: - self._services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate)) - if board_id is None: - self._services.board_images.remove_image_from_board(image_name) - else: - self._services.board_images.add_image_to_board(image_name, board_id) - return self._services.images.get_dto(image_name) - class LatentsInterface(InvocationContextInterface): def save(self, tensor: Tensor) -> str: From 365b0b36a549f5499d0c0e83bcfd4c8a429fee15 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:34:40 +1100 Subject: [PATCH 029/340] chore(ui): regen types --- .../frontend/web/src/services/api/schema.ts | 223 ++++++++---------- 1 file changed, 96 insertions(+), 127 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index da036b6d40a..45358ed97d5 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -968,6 +968,8 @@ export type components = { * @description Creates a blank image and forwards it to the pipeline */ BlankImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -1860,6 +1862,8 @@ export type components = { * @description Infills transparent areas of an image using OpenCV Inpainting */ CV2InfillInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2095,6 +2099,8 @@ export type components = { * @description Canny edge detection for ControlNet */ CannyImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2482,6 +2488,8 @@ export type components = { * using a mask to only color-correct certain regions of the 
target image. */ ColorCorrectInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2590,6 +2598,8 @@ export type components = { * @description Generates a color map from the provided image */ ColorMapImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -2797,6 +2807,8 @@ export type components = { * @description Applies content shuffle processing to image */ ContentShuffleImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -3442,6 +3454,8 @@ export type components = { * @description Simple inpaint using opencv. */ CvInpaintInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -3677,6 +3691,8 @@ export type components = { * @description Generates a depth map based on the Depth Anything algorithm */ DepthAnythingImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -3910,6 +3926,8 @@ export type components = { * @description Upscales an image using RealESRGAN. */ ESRGANInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4041,6 +4059,8 @@ export type components = { * @description Outputs an image with detected face IDs printed on each face. For use with other FaceTools. */ FaceIdentifierInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4873,6 +4893,8 @@ export type components = { * @description Applies HED edge detection to image */ HedImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5324,6 +5346,8 @@ export type components = { * @description Blurs an image */ ImageBlurInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5382,6 +5406,8 @@ export type components = { * @description Gets a channel from an image. 
*/ ImageChannelInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5422,6 +5448,8 @@ export type components = { * @description Scale a specific color channel of an image. */ ImageChannelMultiplyInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5473,6 +5501,8 @@ export type components = { * @description Add or subtract a value from a specific color channel of an image. */ ImageChannelOffsetInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5640,6 +5670,8 @@ export type components = { * @description Converts an image to a different mode. */ ImageConvertInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5680,6 +5712,8 @@ export type components = { * @description Crops an image to a specified box. The box can be outside of the image. */ ImageCropInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5949,6 +5983,8 @@ export type components = { * @description Adjusts the Hue of an image. */ ImageHueAdjustmentInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5988,6 +6024,8 @@ export type components = { * @description Inverse linear interpolation of all pixels of an image */ ImageInverseLerpInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6064,6 +6102,8 @@ export type components = { * @description Linear interpolation of all pixels of an image */ ImageLerpInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6109,6 +6149,8 @@ export type components = { * @description Multiplies two images together using `PIL.ImageChops.multiply()`. 
*/ ImageMultiplyInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6144,6 +6186,8 @@ export type components = { * @description Add blur to NSFW-flagged images */ ImageNSFWBlurInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6252,6 +6296,8 @@ export type components = { * @description Pastes an image into another image. */ ImagePasteInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6337,6 +6383,8 @@ export type components = { * @description Resizes an image to specific dimensions */ ImageResizeInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6446,6 +6494,8 @@ export type components = { * @description Scales an image by a factor */ ImageScaleInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6621,6 +6671,8 @@ export type components = { * @description Add an invisible watermark to an image */ ImageWatermarkInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6676,6 +6728,8 @@ export type components = { * @description Infills transparent areas of an image with a solid color */ InfillColorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6719,6 +6773,8 @@ export type components = { * @description Infills transparent areas of an image using the PatchMatch algorithm */ InfillPatchMatchInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -6765,6 +6821,8 @@ export type components = { * @description Infills transparent areas of an image with tiles of the image */ InfillTileInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7065,6 +7123,8 @@ export type components = { * @description Infills transparent areas of an image using the LaMa model */ LaMaInfillInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the 
image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7093,96 +7153,6 @@ export type components = { */ type: "infill_lama"; }; - /** - * Latent Consistency MonoNode - * @description Wrapper node around diffusers LatentConsistencyTxt2ImgPipeline - */ - LatentConsistencyInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Prompt - * @description The prompt to use - */ - prompt?: string; - /** - * Num Inference Steps - * @description The number of inference steps to use, 4-8 recommended - * @default 8 - */ - num_inference_steps?: number; - /** - * Guidance Scale - * @description The guidance scale to use - * @default 8 - */ - guidance_scale?: number; - /** - * Batches - * @description The number of batches to use - * @default 1 - */ - batches?: number; - /** - * Images Per Batch - * @description The number of images per batch to use - * @default 1 - */ - images_per_batch?: number; - /** - * Seeds - * @description List of noise seeds to use - */ - seeds?: number[]; - /** - * Lcm Origin Steps - * @description The lcm origin steps to use - * @default 50 - */ - lcm_origin_steps?: number; - /** - * Width - * @description The width to use - * @default 512 - */ - width?: number; - /** - * Height - * @description The height to use - * @default 512 - */ - height?: number; - /** - * Precision - * @description floating point precision - * @default fp16 - * @enum {string} - */ - precision?: "fp16" | "fp32"; - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"]; - /** - * type - * @default latent_consistency_mononode - * @constant - */ - type: "latent_consistency_mononode"; - }; /** * Latents Collection Primitive * @description A collection of latents tensor primitive values @@ -7310,6 +7280,8 @@ export type components = { * @description Generates an image from latents. */ LatentsToImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7357,6 +7329,8 @@ export type components = { * @description Applies leres processing to image */ LeresImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7441,46 +7415,13 @@ export type components = { /** @description Type of commercial use allowed or 'No' if no commercial use is allowed. */ AllowCommercialUse?: components["schemas"]["CommercialUsage"]; }; - /** - * Linear UI Image Output - * @description Handles Linear UI Image Outputting tasks. - */ - LinearUIOutputInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. 
- */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default false - */ - use_cache?: boolean; - /** @description The image to process */ - image?: components["schemas"]["ImageField"]; - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"] | null; - /** - * type - * @default linear_ui_output - * @constant - */ - type: "linear_ui_output"; - }; /** * Lineart Anime Processor * @description Applies line art anime processing to image */ LineartAnimeImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7526,6 +7467,8 @@ export type components = { * @description Applies line art processing to image */ LineartImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7954,6 +7897,8 @@ export type components = { * @description Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`. */ MaskCombineInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -7989,6 +7934,8 @@ export type components = { * @description Applies an edge mask to an image */ MaskEdgeInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8042,6 +7989,8 @@ export type components = { * @description Extracts the alpha channel of an image as a mask. */ MaskFromAlphaInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8122,6 +8071,8 @@ export type components = { * @description Applies mediapipe face processing to image */ MediapipeFaceProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8238,6 +8189,8 @@ export type components = { * @description Merge multiple tile images into a single image. 
*/ MergeTilesToImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8404,6 +8357,8 @@ export type components = { * @description Applies Midas depth processing to image */ MidasDepthImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8449,6 +8404,8 @@ export type components = { * @description Applies MLSD processing to image */ MlsdImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -8961,6 +8918,8 @@ export type components = { * @description Applies NormalBae processing to image */ NormalbaeImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -9591,6 +9550,8 @@ export type components = { * @description Applies PIDI processing to image */ PidiImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -10443,6 +10404,8 @@ export type components = { * @description Saves an image. Unlike an image primitive, this invocation stores a copy of the image. 
*/ SaveImageInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -10464,8 +10427,6 @@ export type components = { use_cache?: boolean; /** @description The image to process */ image?: components["schemas"]["ImageField"]; - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"]; /** * type * @default save_image @@ -10651,6 +10612,8 @@ export type components = { * @description Applies segment anything processing to image */ SegmentAnythingProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -12189,6 +12152,8 @@ export type components = { * @description Tile resampler processor */ TileResamplerProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -12378,6 +12343,8 @@ export type components = { * @description Applies an unsharp mask filter to an image */ UnsharpMaskInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -12846,6 +12813,8 @@ export type components = { * @description Applies Zoe depth processing to image */ ZoeDepthImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** From c76f963acca8e735fb61a738276f3b74d8459119 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:41:24 +1100 Subject: [PATCH 030/340] feat(ui): revise graphs to not use `LinearUIOutputInvocation` See this comment for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629 - Remove this now-unnecessary node from all graphs - Update graphs' terminal image-outputting nodes' `is_intermediate` and `board` fields appropriately - Add util function to prepare the `board` field, tidy the utils - Update `socketInvocationComplete` listener to work correctly with this change I've manually tested all graph permutations that were changed (I think this is all...) 
to ensure images go to the gallery as expected: - ad-hoc upscaling - t2i w/ sd1.5 - t2i w/ sd1.5 & hrf - t2i w/ sdxl - t2i w/ sdxl + refiner - i2i w/ sd1.5 - i2i w/ sdxl - i2i w/ sdxl + refiner - canvas t2i w/ sd1.5 - canvas t2i w/ sdxl - canvas t2i w/ sdxl + refiner - canvas i2i w/ sd1.5 - canvas i2i w/ sdxl - canvas i2i w/ sdxl + refiner - canvas inpaint w/ sd1.5 - canvas inpaint w/ sdxl - canvas inpaint w/ sdxl + refiner - canvas outpaint w/ sd1.5 - canvas outpaint w/ sdxl - canvas outpaint w/ sdxl + refiner --- .../socketio/socketInvocationComplete.ts | 9 +-- .../listeners/upscaleRequested.ts | 6 +- .../nodes/util/graph/addHrfToGraph.ts | 4 +- .../nodes/util/graph/addLinearUIOutputNode.ts | 78 ------------------- .../nodes/util/graph/addNSFWCheckerToGraph.ts | 4 +- .../nodes/util/graph/addSDXLRefinerToGraph.ts | 2 +- .../nodes/util/graph/addWatermarkerToGraph.ts | 9 +-- .../util/graph/buildAdHocUpscaleGraph.ts | 40 +++------- .../graph/buildCanvasImageToImageGraph.ts | 13 ++-- .../util/graph/buildCanvasInpaintGraph.ts | 7 +- .../util/graph/buildCanvasOutpaintGraph.ts | 7 +- .../graph/buildCanvasSDXLImageToImageGraph.ts | 8 +- .../util/graph/buildCanvasSDXLInpaintGraph.ts | 8 +- .../graph/buildCanvasSDXLOutpaintGraph.ts | 8 +- .../graph/buildCanvasSDXLTextToImageGraph.ts | 11 ++- .../util/graph/buildCanvasTextToImageGraph.ts | 10 +-- .../graph/buildLinearImageToImageGraph.ts | 7 +- .../graph/buildLinearSDXLImageToImageGraph.ts | 8 +- .../graph/buildLinearSDXLTextToImageGraph.ts | 8 +- .../util/graph/buildLinearTextToImageGraph.ts | 7 +- .../features/nodes/util/graph/constants.ts | 1 - .../nodes/util/graph/getSDXLStylePrompt.ts | 11 --- .../nodes/util/graph/graphBuilderUtils.ts | 38 +++++++++ .../frontend/web/src/services/api/types.ts | 1 - 24 files changed, 108 insertions(+), 197 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index d49f35cd2ab..75fa9e10949 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -4,7 +4,7 @@ import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice'; import { IMAGE_CATEGORIES } from 'features/gallery/store/types'; import { isImageOutput } from 'features/nodes/types/common'; -import { LINEAR_UI_OUTPUT, nodeIDDenyList } from 'features/nodes/util/graph/constants'; +import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants'; import { boardsApi } from 'services/api/endpoints/boards'; import { imagesApi } from 'services/api/endpoints/images'; import { imagesAdapter } from 'services/api/util'; @@ -24,10 +24,9 @@ export const addInvocationCompleteEventListener = () => { const { data } = action.payload; log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`); - const { result, node, queue_batch_id, source_node_id } = data; - + 
const { result, node, queue_batch_id } = data; // This complete event has an associated image output - if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type) && !nodeIDDenyList.includes(source_node_id)) { + if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) { const { image_name } = result.image; const { canvas, gallery } = getState(); @@ -42,7 +41,7 @@ export const addInvocationCompleteEventListener = () => { imageDTORequest.unsubscribe(); // Add canvas images to the staging area - if (canvas.batchIds.includes(queue_batch_id) && [LINEAR_UI_OUTPUT].includes(data.source_node_id)) { + if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) { dispatch(addImageToStagingArea(imageDTO)); } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts index 46f55ef21ff..ab989301796 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts @@ -39,16 +39,12 @@ export const addUpscaleRequestedListener = () => { return; } - const { esrganModelName } = state.postprocessing; - const { autoAddBoardId } = state.gallery; - const enqueueBatchArg: BatchConfig = { prepend: true, batch: { graph: buildAdHocUpscaleGraph({ image_name, - esrganModelName, - autoAddBoardId, + state, }), runs: 1, }, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts index 8a4448833cf..5632cfd1122 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts @@ -1,6 +1,7 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; import { roundToMultiple } from 'common/util/roundDownToMultiple'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; import type { DenoiseLatentsInvocation, @@ -322,7 +323,8 @@ export const addHrfToGraph = (state: RootState, graph: NonNullableGraph): void = type: 'l2i', id: LATENTS_TO_IMAGE_HRF_HR, fp32: originalLatentsToImageNode?.fp32, - is_intermediate: true, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; graph.edges.push( { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts deleted file mode 100644 index 5c78ad804ed..00000000000 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts +++ /dev/null @@ -1,78 +0,0 @@ -import type { RootState } from 'app/store/store'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import type { LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types'; - -import { - CANVAS_OUTPUT, - LATENTS_TO_IMAGE, - LATENTS_TO_IMAGE_HRF_HR, - LINEAR_UI_OUTPUT, - NSFW_CHECKER, - WATERMARKER, -} from './constants'; - -/** - * Set the `use_cache` field on the linear/canvas graph's final image output node to False. 
- */ -export const addLinearUIOutputNode = (state: RootState, graph: NonNullableGraph): void => { - const activeTabName = activeTabNameSelector(state); - const is_intermediate = activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false; - const { autoAddBoardId } = state.gallery; - - const linearUIOutputNode: LinearUIOutputInvocation = { - id: LINEAR_UI_OUTPUT, - type: 'linear_ui_output', - is_intermediate, - use_cache: false, - board: autoAddBoardId === 'none' ? undefined : { board_id: autoAddBoardId }, - }; - - graph.nodes[LINEAR_UI_OUTPUT] = linearUIOutputNode; - - const destination = { - node_id: LINEAR_UI_OUTPUT, - field: 'image', - }; - - if (WATERMARKER in graph.nodes) { - graph.edges.push({ - source: { - node_id: WATERMARKER, - field: 'image', - }, - destination, - }); - } else if (NSFW_CHECKER in graph.nodes) { - graph.edges.push({ - source: { - node_id: NSFW_CHECKER, - field: 'image', - }, - destination, - }); - } else if (CANVAS_OUTPUT in graph.nodes) { - graph.edges.push({ - source: { - node_id: CANVAS_OUTPUT, - field: 'image', - }, - destination, - }); - } else if (LATENTS_TO_IMAGE_HRF_HR in graph.nodes) { - graph.edges.push({ - source: { - node_id: LATENTS_TO_IMAGE_HRF_HR, - field: 'image', - }, - destination, - }); - } else if (LATENTS_TO_IMAGE in graph.nodes) { - graph.edges.push({ - source: { - node_id: LATENTS_TO_IMAGE, - field: 'image', - }, - destination, - }); - } -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts index 4a8e77abfa0..35fc3246890 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts @@ -2,6 +2,7 @@ import type { RootState } from 'app/store/store'; import type { ImageNSFWBlurInvocation, LatentsToImageInvocation, NonNullableGraph } from 'services/api/types'; import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants'; +import { getBoardField, getIsIntermediate } from './graphBuilderUtils'; export const addNSFWCheckerToGraph = ( state: RootState, @@ -21,7 +22,8 @@ export const addNSFWCheckerToGraph = ( const nsfwCheckerNode: ImageNSFWBlurInvocation = { id: NSFW_CHECKER, type: 'img_nsfw', - is_intermediate: true, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts index 708353e4d6a..fc4d998969d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts @@ -24,7 +24,7 @@ import { SDXL_REFINER_POSITIVE_CONDITIONING, SDXL_REFINER_SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getSDXLStylePrompts } from './graphBuilderUtils'; import { upsertMetadata } from './metadata'; export const addSDXLRefinerToGraph = ( diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts index 99c5c07be47..61beb11df49 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts @@ 
-1,5 +1,4 @@ import type { RootState } from 'app/store/store'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import type { ImageNSFWBlurInvocation, ImageWatermarkInvocation, @@ -8,16 +7,13 @@ import type { } from 'services/api/types'; import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants'; +import { getBoardField, getIsIntermediate } from './graphBuilderUtils'; export const addWatermarkerToGraph = ( state: RootState, graph: NonNullableGraph, nodeIdToAddTo = LATENTS_TO_IMAGE ): void => { - const activeTabName = activeTabNameSelector(state); - - const is_intermediate = activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false; - const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined; const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as ImageNSFWBlurInvocation | undefined; @@ -30,7 +26,8 @@ export const addWatermarkerToGraph = ( const watermarkerNode: ImageWatermarkInvocation = { id: WATERMARKER, type: 'img_watermark', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; graph.nodes[WATERMARKER] = watermarkerNode; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts index fa20206d91f..52c09b1db06 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts @@ -1,51 +1,33 @@ -import type { BoardId } from 'features/gallery/store/types'; -import type { ParamESRGANModelName } from 'features/parameters/store/postprocessingSlice'; -import type { ESRGANInvocation, Graph, LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types'; +import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; +import type { ESRGANInvocation, Graph, NonNullableGraph } from 'services/api/types'; -import { ESRGAN, LINEAR_UI_OUTPUT } from './constants'; +import { ESRGAN } from './constants'; import { addCoreMetadataNode, upsertMetadata } from './metadata'; type Arg = { image_name: string; - esrganModelName: ParamESRGANModelName; - autoAddBoardId: BoardId; + state: RootState; }; -export const buildAdHocUpscaleGraph = ({ image_name, esrganModelName, autoAddBoardId }: Arg): Graph => { +export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => { + const { esrganModelName } = state.postprocessing; + const realesrganNode: ESRGANInvocation = { id: ESRGAN, type: 'esrgan', image: { image_name }, model_name: esrganModelName, - is_intermediate: true, - }; - - const linearUIOutputNode: LinearUIOutputInvocation = { - id: LINEAR_UI_OUTPUT, - type: 'linear_ui_output', - use_cache: false, - is_intermediate: false, - board: autoAddBoardId === 'none' ? 
undefined : { board_id: autoAddBoardId }, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }; const graph: NonNullableGraph = { id: `adhoc-esrgan-graph`, nodes: { [ESRGAN]: realesrganNode, - [LINEAR_UI_OUTPUT]: linearUIOutputNode, }, - edges: [ - { - source: { - node_id: ESRGAN, - field: 'image', - }, - destination: { - node_id: LINEAR_UI_OUTPUT, - field: 'image', - }, - }, - ], + edges: [], }; addCoreMetadataNode(graph, {}, ESRGAN); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts index 3002e05441b..bc6a83f4fa6 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts @@ -1,10 +1,10 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -132,7 +132,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima [CANVAS_OUTPUT]: { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -242,7 +243,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -284,7 +286,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima graph.nodes[CANVAS_OUTPUT] = { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), fp32, use_cache: false, }; @@ -355,7 +358,5 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts index bb52a44a8e4..d983b9cf4f5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { CreateDenoiseMaskInvocation, ImageBlurInvocation, @@ -12,7 +13,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from 
'./addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -191,7 +191,8 @@ export const buildCanvasInpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), reference: canvasInitImage, use_cache: false, }, @@ -663,7 +664,5 @@ export const buildCanvasInpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts index b82b55cfee7..1d028943818 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { ImageDTO, ImageToLatentsInvocation, @@ -11,7 +12,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -200,7 +200,8 @@ export const buildCanvasOutpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -769,7 +770,5 @@ export const buildCanvasOutpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts index 1b586371a02..58269afce3f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts @@ -4,7 +4,6 @@ import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'servi import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -26,7 +25,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from './metadata'; /** @@ -246,7 +245,8 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + 
is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -368,7 +368,5 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts index 00fea9a37e6..5902dee2fc4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts @@ -12,7 +12,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -44,7 +43,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; /** * Builds the Canvas tab's Inpaint graph. @@ -190,7 +189,8 @@ export const buildCanvasSDXLInpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), reference: canvasInitImage, use_cache: false, }, @@ -687,7 +687,5 @@ export const buildCanvasSDXLInpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts index f85760d8f2e..7a78750e8d2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts @@ -11,7 +11,6 @@ import type { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -46,7 +45,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; /** * Builds the Canvas tab's Outpaint graph. 
@@ -199,7 +198,8 @@ export const buildCanvasSDXLOutpaintGraph = ( [CANVAS_OUTPUT]: { type: 'color_correct', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -786,7 +786,5 @@ export const buildCanvasSDXLOutpaintGraph = ( addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts index 91d9da4cb5d..22da39c67da 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts @@ -4,7 +4,6 @@ import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -24,7 +23,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from './metadata'; /** @@ -222,7 +221,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -254,7 +254,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr graph.nodes[CANVAS_OUTPUT] = { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), fp32, use_cache: false, }; @@ -330,7 +331,5 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts index 967dd3ff4a5..93f0470c7ad 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts @@ -1,10 +1,10 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -211,7 +211,8 @@ export const buildCanvasTextToImageGraph = 
(state: RootState): NonNullableGraph graph.nodes[CANVAS_OUTPUT] = { id: CANVAS_OUTPUT, type: 'img_resize', - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), width: width, height: height, use_cache: false, @@ -243,7 +244,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph graph.nodes[CANVAS_OUTPUT] = { type: 'l2i', id: CANVAS_OUTPUT, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), fp32, use_cache: false, }; @@ -310,7 +312,5 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph addWatermarkerToGraph(state, graph, CANVAS_OUTPUT); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts index c76776d94d3..d1f1546b23b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts @@ -1,10 +1,10 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -117,7 +117,8 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }, [DENOISE_LATENTS]: { type: 'denoise_latents', @@ -358,7 +359,5 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts index 9ae602bcacb..de4ad7ceceb 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts @@ -4,7 +4,6 @@ import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -25,7 +24,7 @@ import { SDXL_REFINER_SEAMLESS, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from 
'./metadata'; /** @@ -120,7 +119,8 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), }, [SDXL_DENOISE_LATENTS]: { type: 'denoise_latents', @@ -380,7 +380,5 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts index 222dc1a3595..58b97b07c75 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts @@ -4,7 +4,6 @@ import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; @@ -23,7 +22,7 @@ import { SDXL_TEXT_TO_IMAGE_GRAPH, SEAMLESS, } from './constants'; -import { getSDXLStylePrompts } from './getSDXLStylePrompt'; +import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; import { addCoreMetadataNode } from './metadata'; export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => { @@ -120,7 +119,8 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -281,7 +281,5 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts index 0a45d91debc..b2b84cfdad7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts @@ -1,11 +1,11 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; import type { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addHrfToGraph } from './addHrfToGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; @@ -119,7 +119,8 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph type: 'l2i', id: LATENTS_TO_IMAGE, fp32, - is_intermediate, + is_intermediate: 
getIsIntermediate(state), + board: getBoardField(state), use_cache: false, }, }, @@ -267,7 +268,5 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph addWatermarkerToGraph(state, graph); } - addLinearUIOutputNode(state, graph); - return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts index 363d3191210..767bf25df0a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts @@ -9,7 +9,6 @@ export const LATENTS_TO_IMAGE_HRF_LR = 'latents_to_image_hrf_lr'; export const IMAGE_TO_LATENTS_HRF = 'image_to_latents_hrf'; export const RESIZE_HRF = 'resize_hrf'; export const ESRGAN_HRF = 'esrgan_hrf'; -export const LINEAR_UI_OUTPUT = 'linear_ui_output'; export const NSFW_CHECKER = 'nsfw_checker'; export const WATERMARKER = 'invisible_watermark'; export const NOISE = 'noise'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts b/invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts deleted file mode 100644 index e1cd8518fdd..00000000000 --- a/invokeai/frontend/web/src/features/nodes/util/graph/getSDXLStylePrompt.ts +++ /dev/null @@ -1,11 +0,0 @@ -import type { RootState } from 'app/store/store'; - -export const getSDXLStylePrompts = (state: RootState): { positiveStylePrompt: string; negativeStylePrompt: string } => { - const { positivePrompt, negativePrompt } = state.generation; - const { positiveStylePrompt, negativeStylePrompt, shouldConcatSDXLStylePrompt } = state.sdxl; - - return { - positiveStylePrompt: shouldConcatSDXLStylePrompt ? positivePrompt : positiveStylePrompt, - negativeStylePrompt: shouldConcatSDXLStylePrompt ? negativePrompt : negativeStylePrompt, - }; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts new file mode 100644 index 00000000000..cb6fc9acf1e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts @@ -0,0 +1,38 @@ +import type { RootState } from 'app/store/store'; +import type { BoardField } from 'features/nodes/types/common'; +import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; + +/** + * Gets the board field, based on the autoAddBoardId setting. + */ +export const getBoardField = (state: RootState): BoardField | undefined => { + const { autoAddBoardId } = state.gallery; + if (autoAddBoardId === 'none') { + return undefined; + } + return { board_id: autoAddBoardId }; +}; + +/** + * Gets the SDXL style prompts, based on the concat setting. + */ +export const getSDXLStylePrompts = (state: RootState): { positiveStylePrompt: string; negativeStylePrompt: string } => { + const { positivePrompt, negativePrompt } = state.generation; + const { positiveStylePrompt, negativeStylePrompt, shouldConcatSDXLStylePrompt } = state.sdxl; + + return { + positiveStylePrompt: shouldConcatSDXLStylePrompt ? positivePrompt : positiveStylePrompt, + negativeStylePrompt: shouldConcatSDXLStylePrompt ? negativePrompt : negativeStylePrompt, + }; +}; + +/** + * Gets the is_intermediate field, based on the active tab and shouldAutoSave setting. 
+ */ +export const getIsIntermediate = (state: RootState) => { + const activeTabName = activeTabNameSelector(state); + if (activeTabName === 'unifiedCanvas') { + return !state.canvas.shouldAutoSave; + } + return false; +}; diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 1382fbe275a..55ff808b404 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -132,7 +132,6 @@ export type DivideInvocation = s['DivideInvocation']; export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation']; export type ImageWatermarkInvocation = s['ImageWatermarkInvocation']; export type SeamlessModeInvocation = s['SeamlessModeInvocation']; -export type LinearUIOutputInvocation = s['LinearUIOutputInvocation']; export type MetadataInvocation = s['MetadataInvocation']; export type CoreMetadataInvocation = s['CoreMetadataInvocation']; export type MetadataItemInvocation = s['MetadataItemInvocation']; From 8c9981e89d13889673996de57e1e90a774d35808 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:01:39 +1100 Subject: [PATCH 031/340] tidy(nodes): remove unnecessary, shadowing class attr declarations --- invokeai/app/services/invocation_services.py | 27 -------------------- 1 file changed, 27 deletions(-) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 11a4de99d6e..51bfd5d77a1 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -36,33 +36,6 @@ class InvocationServices: """Services that can be used by invocations""" - # TODO: Just forward-declared everything due to circular dependencies. Fix structure. 
- board_images: "BoardImagesServiceABC" - board_image_record_storage: "BoardImageRecordStorageBase" - boards: "BoardServiceABC" - board_records: "BoardRecordStorageBase" - configuration: "InvokeAIAppConfig" - events: "EventServiceBase" - graph_execution_manager: "ItemStorageABC[GraphExecutionState]" - images: "ImageServiceABC" - image_records: "ImageRecordStorageBase" - image_files: "ImageFileStorageBase" - latents: "LatentsStorageBase" - logger: "Logger" - model_manager: "ModelManagerServiceBase" - model_records: "ModelRecordServiceBase" - download_queue: "DownloadQueueServiceBase" - model_install: "ModelInstallServiceBase" - processor: "InvocationProcessorABC" - performance_statistics: "InvocationStatsServiceBase" - queue: "InvocationQueueABC" - session_queue: "SessionQueueBase" - session_processor: "SessionProcessorBase" - invocation_cache: "InvocationCacheBase" - names: "NameServiceBase" - urls: "UrlServiceBase" - workflow_records: "WorkflowRecordsStorageBase" - def __init__( self, board_images: "BoardImagesServiceABC", From c96d363d845fd7baade37ba80d8fcf6e147a5c4e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:10:25 +1100 Subject: [PATCH 032/340] fix(nodes): rearrange fields.py to avoid needing forward refs --- invokeai/app/invocations/fields.py | 92 +++++++++++++++--------------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index c42d2f83120..40d403c03d9 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -182,6 +182,51 @@ class FieldDescriptions: freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features." +class ImageField(BaseModel): + """An image primitive field""" + + image_name: str = Field(description="The name of the image") + + +class BoardField(BaseModel): + """A board primitive field""" + + board_id: str = Field(description="The id of the board") + + +class DenoiseMaskField(BaseModel): + """An inpaint mask field""" + + mask_name: str = Field(description="The name of the mask image") + masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") + + +class LatentsField(BaseModel): + """A latents tensor primitive field""" + + latents_name: str = Field(description="The name of the latents") + seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") + + +class ColorField(BaseModel): + """A color primitive field""" + + r: int = Field(ge=0, le=255, description="The red component") + g: int = Field(ge=0, le=255, description="The green component") + b: int = Field(ge=0, le=255, description="The blue component") + a: int = Field(ge=0, le=255, description="The alpha component") + + def tuple(self) -> Tuple[int, int, int, int]: + return (self.r, self.g, self.b, self.a) + + +class ConditioningField(BaseModel): + """A conditioning tensor primitive value""" + + conditioning_name: str = Field(description="The name of conditioning tensor") + # endregion + + class MetadataField(RootModel): """ Pydantic model for metadata with custom root of type dict[str, Any]. @@ -285,7 +330,7 @@ class WithBoard(BaseModel): Inherit from this class if your node needs a board input field. 
""" - board: Optional["BoardField"] = Field( + board: Optional[BoardField] = Field( default=None, description=FieldDescriptions.board, json_schema_extra=InputFieldJSONSchemaExtra( @@ -518,48 +563,3 @@ def OutputField( field_kind=FieldKind.Output, ).model_dump(exclude_none=True), ) - - -class ImageField(BaseModel): - """An image primitive field""" - - image_name: str = Field(description="The name of the image") - - -class BoardField(BaseModel): - """A board primitive field""" - - board_id: str = Field(description="The id of the board") - - -class DenoiseMaskField(BaseModel): - """An inpaint mask field""" - - mask_name: str = Field(description="The name of the mask image") - masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") - - -class LatentsField(BaseModel): - """A latents tensor primitive field""" - - latents_name: str = Field(description="The name of the latents") - seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") - - -class ColorField(BaseModel): - """A color primitive field""" - - r: int = Field(ge=0, le=255, description="The red component") - g: int = Field(ge=0, le=255, description="The green component") - b: int = Field(ge=0, le=255, description="The blue component") - a: int = Field(ge=0, le=255, description="The alpha component") - - def tuple(self) -> Tuple[int, int, int, int]: - return (self.r, self.g, self.b, self.a) - - -class ConditioningField(BaseModel): - """A conditioning tensor primitive value""" - - conditioning_name: str = Field(description="The name of conditioning tensor") - # endregion From e003542f393bf0f180d49f7d6a4de1148e45be3c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:11:22 +1100 Subject: [PATCH 033/340] tidy(nodes): delete onnx.py It doesn't work and keeping it updated to prevent the app from starting was getting tedious. Deleted. 
--- invokeai/app/invocations/onnx.py | 510 ------------------------------- 1 file changed, 510 deletions(-) delete mode 100644 invokeai/app/invocations/onnx.py diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py deleted file mode 100644 index e7b4d3d9fc5..00000000000 --- a/invokeai/app/invocations/onnx.py +++ /dev/null @@ -1,510 +0,0 @@ -# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779) - -import inspect - -# from contextlib import ExitStack -from typing import List, Literal, Union - -import numpy as np -import torch -from diffusers.image_processor import VaeImageProcessor -from pydantic import BaseModel, ConfigDict, Field, field_validator -from tqdm import tqdm - -from invokeai.app.invocations.fields import ( - FieldDescriptions, - Input, - InputField, - OutputField, - UIComponent, - UIType, - WithMetadata, -) -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend import BaseModelType, ModelType, SubModelType - -from ...backend.model_management import ONNXModelPatcher -from ...backend.stable_diffusion import PipelineIntermediateState -from ...backend.util import choose_torch_device -from ..util.ti_utils import extract_ti_triggers_from_prompt -from .baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, - InvocationContext, - invocation, - invocation_output, -) -from .controlnet_image_processors import ControlField -from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, get_scheduler -from .model import ClipField, ModelInfo, UNetField, VaeField - -ORT_TO_NP_TYPE = { - "tensor(bool)": np.bool_, - "tensor(int8)": np.int8, - "tensor(uint8)": np.uint8, - "tensor(int16)": np.int16, - "tensor(uint16)": np.uint16, - "tensor(int32)": np.int32, - "tensor(uint32)": np.uint32, - "tensor(int64)": np.int64, - "tensor(uint64)": np.uint64, - "tensor(float16)": np.float16, - "tensor(float)": np.float32, - "tensor(double)": np.float64, -} - -PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())] - - -@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0") -class ONNXPromptInvocation(BaseInvocation): - prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) - clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) - - def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( - **self.clip.tokenizer.model_dump(), - ) - text_encoder_info = context.services.model_manager.get_model( - **self.clip.text_encoder.model_dump(), - ) - with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack: - loras = [ - ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, - lora.weight, - ) - for lora in self.clip.loras - ] - - ti_list = [] - for trigger in extract_ti_triggers_from_prompt(self.prompt): - name = trigger[1:-1] - try: - ti_list.append( - ( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model, - ) - ) - except Exception: - # print(e) - # import traceback - # 
print(traceback.format_exc()) - print(f'Warn: trigger: "{trigger}" not found') - if loras or ti_list: - text_encoder.release_session() - with ( - ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), - ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager), - ): - text_encoder.create_session() - - # copy from - # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153 - text_inputs = tokenizer( - self.prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - """ - untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids - - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - """ - - prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - - # TODO: hacky but works ;D maybe rename latents somehow? - context.services.latents.save(conditioning_name, (prompt_embeds, None)) - - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) - - -# Text to image -@invocation( - "t2l_onnx", - title="ONNX Text to Latents", - tags=["latents", "inference", "txt2img", "onnx"], - category="latents", - version="1.0.0", -) -class ONNXTextToLatentsInvocation(BaseInvocation): - """Generates latents from conditionings.""" - - positive_conditioning: ConditioningField = InputField( - description=FieldDescriptions.positive_cond, - input=Input.Connection, - ) - negative_conditioning: ConditioningField = InputField( - description=FieldDescriptions.negative_cond, - input=Input.Connection, - ) - noise: LatentsField = InputField( - description=FieldDescriptions.noise, - input=Input.Connection, - ) - steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) - cfg_scale: Union[float, List[float]] = InputField( - default=7.5, - ge=1, - description=FieldDescriptions.cfg_scale, - ) - scheduler: SAMPLER_NAME_VALUES = InputField( - default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler - ) - precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision) - unet: UNetField = InputField( - description=FieldDescriptions.unet, - input=Input.Connection, - ) - control: Union[ControlField, list[ControlField]] = InputField( - default=None, - description=FieldDescriptions.control, - ) - # seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", ) - # seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'") - - @field_validator("cfg_scale") - def ge_one(cls, v): - """validate that all cfg_scale values are >= 1""" - if isinstance(v, list): - for i in v: - if i < 1: - raise ValueError("cfg_scale must be greater than 1") - else: - if v < 1: - raise ValueError("cfg_scale must be greater than 1") - return v - - # based on - # 
https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 - def invoke(self, context: InvocationContext) -> LatentsOutput: - c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) - uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] - if isinstance(c, torch.Tensor): - c = c.cpu().numpy() - if isinstance(uc, torch.Tensor): - uc = uc.cpu().numpy() - device = torch.device(choose_torch_device()) - prompt_embeds = np.concatenate([uc, c]) - - latents = context.services.latents.get(self.noise.latents_name) - if isinstance(latents, torch.Tensor): - latents = latents.cpu().numpy() - - # TODO: better execution device handling - latents = latents.astype(ORT_TO_NP_TYPE[self.precision]) - - # get the initial random noise unless the user supplied it - do_classifier_free_guidance = True - # latents_dtype = prompt_embeds.dtype - # latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) - # if latents.shape != latents_shape: - # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - scheduler = get_scheduler( - context=context, - scheduler_info=self.unet.scheduler, - scheduler_name=self.scheduler, - seed=0, # TODO: refactor this node - ) - - def torch2numpy(latent: torch.Tensor): - return latent.cpu().numpy() - - def numpy2torch(latent, device): - return torch.from_numpy(latent).to(device) - - def dispatch_progress( - self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState - ) -> None: - stable_diffusion_step_callback( - context=context, - intermediate_state=intermediate_state, - node=self.model_dump(), - source_node_id=source_node_id, - ) - - scheduler.set_timesteps(self.steps) - latents = latents * np.float64(scheduler.init_noise_sigma) - - extra_step_kwargs = {} - if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): - extra_step_kwargs.update( - eta=0.0, - ) - - unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump()) - - with unet_info as unet: # , ExitStack() as stack: - # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] - loras = [ - ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, - lora.weight, - ) - for lora in self.unet.loras - ] - - if loras: - unet.release_session() - with ONNXModelPatcher.apply_lora_unet(unet, loras): - # TODO: - _, _, h, w = latents.shape - unet.create_session(h, w) - - timestep_dtype = next( - (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)" - ) - timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] - for i in tqdm(range(len(scheduler.timesteps))): - t = scheduler.timesteps[i] - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t) - latent_model_input = latent_model_input.cpu().numpy() - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = 
unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) - noise_pred = noise_pred[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = scheduler.step( - numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs - ) - latents = torch2numpy(scheduler_output.prev_sample) - - state = PipelineIntermediateState( - run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample - ) - dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state) - - # call the callback, if provided - # if callback is not None and i % callback_steps == 0: - # callback(i, t, latents) - - torch.cuda.empty_cache() - - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, latents) - # return build_latents_output(latents_name=name, latents=torch.from_numpy(latents)) - - -# Latent to image -@invocation( - "l2i_onnx", - title="ONNX Latents to Image", - tags=["latents", "image", "vae", "onnx"], - category="image", - version="1.2.0", -) -class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): - """Generates an image from latents.""" - - latents: LatentsField = InputField( - description=FieldDescriptions.denoised_latents, - input=Input.Connection, - ) - vae: VaeField = InputField( - description=FieldDescriptions.vae, - input=Input.Connection, - ) - # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - - def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.services.latents.get(self.latents.latents_name) - - if self.vae.vae.submodel != SubModelType.VaeDecoder: - raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}") - - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - ) - - # clear memory as vae decode can request a lot - torch.cuda.empty_cache() - - with vae_info as vae: - vae.create_session() - - # copied from - # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427 - latents = 1 / 0.18215 * latents - # image = self.vae_decoder(latent_sample=latents)[0] - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]) - - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) - image = VaeImageProcessor.numpy_to_pil(image)[0] - - torch.cuda.empty_cache() - - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) - - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) - - -@invocation_output("model_loader_output_onnx") -class ONNXModelLoaderOutput(BaseInvocationOutput): - """Model loader output""" - - unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") - clip: 
ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") - vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder") - vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder") - - -class OnnxModelField(BaseModel): - """Onnx model field""" - - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Model Type") - - model_config = ConfigDict(protected_namespaces=()) - - -@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0") -class OnnxModelLoaderInvocation(BaseInvocation): - """Loads a main model, outputting its submodels.""" - - model: OnnxModelField = InputField( - description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel - ) - - def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.ONNX - - # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! 
Check if model corrupted" - ) - """ - - return ONNXModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.UNet, - ), - scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Scheduler, - ), - loras=[], - ), - clip=ClipField( - tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer, - ), - text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder, - ), - loras=[], - skipped_layers=0, - ), - vae_decoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.VaeDecoder, - ), - ), - vae_encoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.VaeEncoder, - ), - ), - ) From 9c5bada4fc416cf4d3aa1b581955621cc129e9fb Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:41:23 +1100 Subject: [PATCH 034/340] feat(nodes): replace latents service with tensors and conditioning services - New generic class `PickleStorageBase`, implements the same API as `LatentsStorageBase`, use for storing non-serializable data via pickling - Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as `LatentsStorageDisk` - Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices` - Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices` - Remove `latents` service and all `LatentsStorage` classes - Update `InvocationContext` and all usage of old `latents` service to use the new services/context wrapper methods --- invokeai/app/api/dependencies.py | 18 +++-- invokeai/app/invocations/latent.py | 36 +++++----- invokeai/app/invocations/noise.py | 2 +- invokeai/app/invocations/primitives.py | 6 +- .../invocation_cache_memory.py | 3 +- invokeai/app/services/invocation_services.py | 12 +++- .../app/services/latents_storage/__init__.py | 0 .../latents_storage/latents_storage_disk.py | 58 ---------------- .../latents_storage_forward_cache.py | 68 ------------------- .../pickle_storage_base.py} | 18 ++--- .../pickle_storage_forward_cache.py | 58 ++++++++++++++++ .../pickle_storage/pickle_storage_torch.py | 62 +++++++++++++++++ .../app/services/shared/invocation_context.py | 49 ++++++------- 13 files changed, 197 insertions(+), 193 deletions(-) delete mode 100644 invokeai/app/services/latents_storage/__init__.py delete mode 100644 invokeai/app/services/latents_storage/latents_storage_disk.py delete mode 100644 invokeai/app/services/latents_storage/latents_storage_forward_cache.py rename invokeai/app/services/{latents_storage/latents_storage_base.py => pickle_storage/pickle_storage_base.py} (68%) create mode 100644 invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py create mode 100644 invokeai/app/services/pickle_storage/pickle_storage_torch.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index c8309e1729e..6bb0915cb6e 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -2,9 +2,14 @@ from logging import Logger +import torch + from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory +from 
invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache +from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData from invokeai.backend.util.logging import InvokeAILogger from invokeai.version.invokeai_version import __version__ @@ -23,8 +28,6 @@ from ..services.invocation_services import InvocationServices from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invoker import Invoker -from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage -from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage from ..services.model_install import ModelInstallService from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_records import ModelRecordServiceSQL @@ -68,6 +71,9 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger logger.debug(f"Internet connectivity is {config.internet_available}") output_folder = config.output_path + if output_folder is None: + raise ValueError("Output folder is not set") + image_files = DiskImageFileStorage(f"{output_folder}/images") db = init_db(config=config, logger=logger, image_files=image_files) @@ -84,7 +90,10 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) + tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor")) + conditioning = PickleStorageForwardCache( + PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning") + ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) download_queue_service = DownloadQueueService(event_bus=events) @@ -117,7 +126,6 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records=image_records, images=images, invocation_cache=invocation_cache, - latents=latents, logger=logger, model_manager=model_manager, model_records=model_record_service, @@ -131,6 +139,8 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger session_queue=session_queue, urls=urls, workflow_records=workflow_records, + tensors=tensors, + conditioning=conditioning, ) ApiDependencies.invoker = Invoker(services) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 5449ec9af7a..94440d3e2aa 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -163,11 +163,11 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: # TODO: masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) - masked_latents_name = context.latents.save(tensor=masked_latents) + masked_latents_name = context.tensors.save(tensor=masked_latents) else: masked_latents_name = None - mask_name = context.latents.save(tensor=mask) + mask_name = context.tensors.save(tensor=mask) return DenoiseMaskOutput.build( 
mask_name=mask_name, @@ -621,10 +621,10 @@ def prep_inpaint_mask(self, context: InvocationContext, latents): if self.denoise_mask is None: return None, None - mask = context.latents.get(self.denoise_mask.mask_name) + mask = context.tensors.get(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) if self.denoise_mask.masked_latents_name is not None: - masked_latents = context.latents.get(self.denoise_mask.masked_latents_name) + masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name) else: masked_latents = None @@ -636,11 +636,11 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: seed = None noise = None if self.noise is not None: - noise = context.latents.get(self.noise.latents_name) + noise = context.tensors.get(self.noise.latents_name) seed = self.noise.seed if self.latents is not None: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) if seed is None: seed = self.latents.seed @@ -752,7 +752,7 @@ def _lora_loader(): if choose_torch_device() == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=result_latents) + name = context.tensors.save(tensor=result_latents) return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed) @@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) vae_info = context.models.load(**self.vae.vae.model_dump()) @@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -888,7 +888,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=resized_latents) + name = context.tensors.save(tensor=resized_latents) return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -930,7 +930,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if device == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=resized_latents) + name = context.tensors.save(tensor=resized_latents) return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @@ -1011,7 +1011,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor) latents = latents.to("cpu") - name = context.latents.save(tensor=latents) + name = context.tensors.save(tensor=latents) return LatentsOutput.build(latents_name=name, latents=latents, seed=None) @singledispatchmethod @@ -1048,8 +1048,8 @@ class 
BlendLatentsInvocation(BaseInvocation): alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents_a = context.latents.get(self.latents_a.latents_name) - latents_b = context.latents.get(self.latents_b.latents_name) + latents_a = context.tensors.get(self.latents_a.latents_name) + latents_b = context.tensors.get(self.latents_b.latents_name) if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") @@ -1103,7 +1103,7 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): if device == torch.device("mps"): mps.empty_cache() - name = context.latents.save(tensor=blended_latents) + name = context.tensors.save(tensor=blended_latents) return LatentsOutput.build(latents_name=name, latents=blended_latents) @@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR y1 = self.y // LATENT_SCALE_FACTOR @@ -1158,7 +1158,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: cropped_latents = latents[..., y1:y2, x1:x2] - name = context.latents.save(tensor=cropped_latents) + name = context.tensors.save(tensor=cropped_latents) return LatentsOutput.build(latents_name=name, latents=cropped_latents) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 78f13cc52d1..74b3d6e4cb1 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -121,5 +121,5 @@ def invoke(self, context: InvocationContext) -> NoiseOutput: seed=self.seed, use_cpu=self.use_cpu, ) - name = context.latents.save(tensor=noise) + name = context.tensors.save(tensor=noise) return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index a77939943ae..082d5432ccf 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -313,9 +313,7 @@ def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "De class LatentsOutput(BaseInvocationOutput): """Base class for nodes that output a single latents tensor""" - latents: LatentsField = OutputField( - description=FieldDescriptions.latents, - ) + latents: LatentsField = OutputField(description=FieldDescriptions.latents) width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) @@ -346,7 +344,7 @@ class LatentsInvocation(BaseInvocation): latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.latents.get(self.latents.latents_name) + latents = context.tensors.get(self.latents.latents_name) return LatentsOutput.build(self.latents.latents_name, latents) diff --git a/invokeai/app/services/invocation_cache/invocation_cache_memory.py b/invokeai/app/services/invocation_cache/invocation_cache_memory.py index 4a503b3c6b1..c700f81186f 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_memory.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_memory.py @@ -37,7 +37,8 @@ def start(self, invoker: Invoker) -> None: if self._max_cache_size == 0: return self._invoker.services.images.on_deleted(self._delete_by_match) 
- self._invoker.services.latents.on_deleted(self._delete_by_match) + self._invoker.services.tensors.on_deleted(self._delete_by_match) + self._invoker.services.conditioning.on_deleted(self._delete_by_match) def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: with self._lock: diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 51bfd5d77a1..81885781acb 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -6,6 +6,10 @@ if TYPE_CHECKING: from logging import Logger + import torch + + from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData + from .board_image_records.board_image_records_base import BoardImageRecordStorageBase from .board_images.board_images_base import BoardImagesServiceABC from .board_records.board_records_base import BoardRecordStorageBase @@ -21,11 +25,11 @@ from .invocation_queue.invocation_queue_base import InvocationQueueABC from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase from .item_storage.item_storage_base import ItemStorageABC - from .latents_storage.latents_storage_base import LatentsStorageBase from .model_install import ModelInstallServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase from .model_records import ModelRecordServiceBase from .names.names_base import NameServiceBase + from .pickle_storage.pickle_storage_base import PickleStorageBase from .session_processor.session_processor_base import SessionProcessorBase from .session_queue.session_queue_base import SessionQueueBase from .shared.graph import GraphExecutionState @@ -48,7 +52,6 @@ def __init__( images: "ImageServiceABC", image_files: "ImageFileStorageBase", image_records: "ImageRecordStorageBase", - latents: "LatentsStorageBase", logger: "Logger", model_manager: "ModelManagerServiceBase", model_records: "ModelRecordServiceBase", @@ -63,6 +66,8 @@ def __init__( names: "NameServiceBase", urls: "UrlServiceBase", workflow_records: "WorkflowRecordsStorageBase", + tensors: "PickleStorageBase[torch.Tensor]", + conditioning: "PickleStorageBase[ConditioningFieldData]", ): self.board_images = board_images self.board_image_records = board_image_records @@ -74,7 +79,6 @@ def __init__( self.images = images self.image_files = image_files self.image_records = image_records - self.latents = latents self.logger = logger self.model_manager = model_manager self.model_records = model_records @@ -89,3 +93,5 @@ def __init__( self.names = names self.urls = urls self.workflow_records = workflow_records + self.tensors = tensors + self.conditioning = conditioning diff --git a/invokeai/app/services/latents_storage/__init__.py b/invokeai/app/services/latents_storage/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/invokeai/app/services/latents_storage/latents_storage_disk.py b/invokeai/app/services/latents_storage/latents_storage_disk.py deleted file mode 100644 index 9192b9147f7..00000000000 --- a/invokeai/app/services/latents_storage/latents_storage_disk.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from pathlib import Path -from typing import Union - -import torch - -from invokeai.app.services.invoker import Invoker - -from .latents_storage_base import LatentsStorageBase - - -class DiskLatentsStorage(LatentsStorageBase): - """Stores latents in a folder on disk without caching""" - - __output_folder: Path - - 
def __init__(self, output_folder: Union[str, Path]): - self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) - self.__output_folder.mkdir(parents=True, exist_ok=True) - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - self._delete_all_latents() - - def get(self, name: str) -> torch.Tensor: - latent_path = self.get_path(name) - return torch.load(latent_path) - - def save(self, name: str, data: torch.Tensor) -> None: - self.__output_folder.mkdir(parents=True, exist_ok=True) - latent_path = self.get_path(name) - torch.save(data, latent_path) - - def delete(self, name: str) -> None: - latent_path = self.get_path(name) - latent_path.unlink() - - def get_path(self, name: str) -> Path: - return self.__output_folder / name - - def _delete_all_latents(self) -> None: - """ - Deletes all latents from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). - """ - deleted_latents_count = 0 - freed_space = 0 - for latents_file in Path(self.__output_folder).glob("*"): - if latents_file.is_file(): - freed_space += latents_file.stat().st_size - deleted_latents_count += 1 - latents_file.unlink() - if deleted_latents_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)" - ) diff --git a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py b/invokeai/app/services/latents_storage/latents_storage_forward_cache.py deleted file mode 100644 index 6232b76a27d..00000000000 --- a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from queue import Queue -from typing import Dict, Optional - -import torch - -from invokeai.app.services.invoker import Invoker - -from .latents_storage_base import LatentsStorageBase - - -class ForwardCacheLatentsStorage(LatentsStorageBase): - """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage""" - - __cache: Dict[str, torch.Tensor] - __cache_ids: Queue - __max_cache_size: int - __underlying_storage: LatentsStorageBase - - def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20): - super().__init__() - self.__underlying_storage = underlying_storage - self.__cache = {} - self.__cache_ids = Queue() - self.__max_cache_size = max_cache_size - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - start_op = getattr(self.__underlying_storage, "start", None) - if callable(start_op): - start_op(invoker) - - def stop(self, invoker: Invoker) -> None: - self._invoker = invoker - stop_op = getattr(self.__underlying_storage, "stop", None) - if callable(stop_op): - stop_op(invoker) - - def get(self, name: str) -> torch.Tensor: - cache_item = self.__get_cache(name) - if cache_item is not None: - return cache_item - - latent = self.__underlying_storage.get(name) - self.__set_cache(name, latent) - return latent - - def save(self, name: str, data: torch.Tensor) -> None: - self.__underlying_storage.save(name, data) - self.__set_cache(name, data) - self._on_changed(data) - - def delete(self, name: str) -> None: - self.__underlying_storage.delete(name) - if name in self.__cache: - del self.__cache[name] - self._on_deleted(name) - - def __get_cache(self, name: str) -> Optional[torch.Tensor]: - return None if name not in self.__cache else self.__cache[name] 
- - def __set_cache(self, name: str, data: torch.Tensor): - if name not in self.__cache: - self.__cache[name] = data - self.__cache_ids.put(name) - if self.__cache_ids.qsize() > self.__max_cache_size: - self.__cache.pop(self.__cache_ids.get()) diff --git a/invokeai/app/services/latents_storage/latents_storage_base.py b/invokeai/app/services/pickle_storage/pickle_storage_base.py similarity index 68% rename from invokeai/app/services/latents_storage/latents_storage_base.py rename to invokeai/app/services/pickle_storage/pickle_storage_base.py index 9fa42b0ae61..558b97c0f1b 100644 --- a/invokeai/app/services/latents_storage/latents_storage_base.py +++ b/invokeai/app/services/pickle_storage/pickle_storage_base.py @@ -1,15 +1,15 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, Generic, TypeVar -import torch +T = TypeVar("T") -class LatentsStorageBase(ABC): - """Responsible for storing and retrieving latents.""" +class PickleStorageBase(ABC, Generic[T]): + """Responsible for storing and retrieving non-serializable data using a pickler.""" - _on_changed_callbacks: list[Callable[[torch.Tensor], None]] + _on_changed_callbacks: list[Callable[[T], None]] _on_deleted_callbacks: list[Callable[[str], None]] def __init__(self) -> None: @@ -17,18 +17,18 @@ def __init__(self) -> None: self._on_deleted_callbacks = [] @abstractmethod - def get(self, name: str) -> torch.Tensor: + def get(self, name: str) -> T: pass @abstractmethod - def save(self, name: str, data: torch.Tensor) -> None: + def save(self, name: str, data: T) -> None: pass @abstractmethod def delete(self, name: str) -> None: pass - def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None: + def on_changed(self, on_changed: Callable[[T], None]) -> None: """Register a callback for when an item is changed""" self._on_changed_callbacks.append(on_changed) @@ -36,7 +36,7 @@ def on_deleted(self, on_deleted: Callable[[str], None]) -> None: """Register a callback for when an item is deleted""" self._on_deleted_callbacks.append(on_deleted) - def _on_changed(self, item: torch.Tensor) -> None: + def _on_changed(self, item: T) -> None: for callback in self._on_changed_callbacks: callback(item) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py b/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py new file mode 100644 index 00000000000..3002d9e045d --- /dev/null +++ b/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py @@ -0,0 +1,58 @@ +from queue import Queue +from typing import Optional, TypeVar + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase + +T = TypeVar("T") + + +class PickleStorageForwardCache(PickleStorageBase[T]): + def __init__(self, underlying_storage: PickleStorageBase[T], max_cache_size: int = 20): + super().__init__() + self._underlying_storage = underlying_storage + self._cache: dict[str, T] = {} + self._cache_ids = Queue[str]() + self._max_cache_size = max_cache_size + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + start_op = getattr(self._underlying_storage, "start", None) + if callable(start_op): + start_op(invoker) + + def stop(self, invoker: Invoker) -> None: + self._invoker = invoker + stop_op = getattr(self._underlying_storage, "stop", None) + if callable(stop_op): + stop_op(invoker) + + def get(self, name: str) -> T: + 
cache_item = self._get_cache(name) + if cache_item is not None: + return cache_item + + latent = self._underlying_storage.get(name) + self._set_cache(name, latent) + return latent + + def save(self, name: str, data: T) -> None: + self._underlying_storage.save(name, data) + self._set_cache(name, data) + self._on_changed(data) + + def delete(self, name: str) -> None: + self._underlying_storage.delete(name) + if name in self._cache: + del self._cache[name] + self._on_deleted(name) + + def _get_cache(self, name: str) -> Optional[T]: + return None if name not in self._cache else self._cache[name] + + def _set_cache(self, name: str, data: T): + if name not in self._cache: + self._cache[name] = data + self._cache_ids.put(name) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py new file mode 100644 index 00000000000..0b3c9af7a33 --- /dev/null +++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from pathlib import Path +from typing import TypeVar + +import torch + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase + +T = TypeVar("T") + + +class PickleStorageTorch(PickleStorageBase[T]): + """Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`.""" + + def __init__(self, output_folder: Path, item_type_name: "str"): + self._output_folder = output_folder + self._output_folder.mkdir(parents=True, exist_ok=True) + self._item_type_name = item_type_name + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + self._delete_all_items() + + def get(self, name: str) -> T: + latent_path = self._get_path(name) + return torch.load(latent_path) + + def save(self, name: str, data: T) -> None: + self._output_folder.mkdir(parents=True, exist_ok=True) + latent_path = self._get_path(name) + torch.save(data, latent_path) + + def delete(self, name: str) -> None: + latent_path = self._get_path(name) + latent_path.unlink() + + def _get_path(self, name: str) -> Path: + return self._output_folder / name + + def _delete_all_items(self) -> None: + """ + Deletes all pickled items from disk. + Must be called after we have access to `self._invoker` (e.g. in `start()`). + """ + + if not self._invoker: + raise ValueError("Invoker is not set. 
Must call `start()` first.") + + deleted_latents_count = 0 + freed_space = 0 + for latents_file in Path(self._output_folder).glob("*"): + if latents_file.is_file(): + freed_space += latents_file.stat().st_size + deleted_latents_count += 1 + latents_file.unlink() + if deleted_latents_count > 0: + freed_space_in_mb = round(freed_space / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {deleted_latents_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" + ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 97a62246fbc..6756b1f5c6c 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -216,48 +216,46 @@ def get_dto(self, image_name: str) -> ImageDTO: return self._services.images.get_dto(image_name) -class LatentsInterface(InvocationContextInterface): +class TensorsInterface(InvocationContextInterface): def save(self, tensor: Tensor) -> str: """ - Saves a latents tensor, returning its name. + Saves a tensor, returning its name. - :param tensor: The latents tensor to save. + :param tensor: The tensor to save. """ # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. # "mask", "noise", "masked_latents", etc. # # Retaining that capability in this wrapper would require either many different methods - # to save latents, or extra args for this method. Instead of complicating the API, we - # will use the same naming scheme for all latents. + # to save tensors, or extra args for this method. Instead of complicating the API, we + # will use the same naming scheme for all tensors. # # This has a very minor impact as we don't use them after a session completes. - # Previously, invocations chose the name for their latents. This is a bit risky, so we + # Previously, invocations chose the name for their tensors. This is a bit risky, so we # will generate a name for them instead. We use a uuid to ensure the name is unique. # - # Because the name of the latents file will includes the session and invocation IDs, + # Because the name of the tensors file will includes the session and invocation IDs, # we don't need to worry about collisions. A truncated UUIDv4 is fine. name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" - self._services.latents.save( + self._services.tensors.save( name=name, data=tensor, ) return name - def get(self, latents_name: str) -> Tensor: + def get(self, tensor_name: str) -> Tensor: """ - Gets a latents tensor by name. + Gets a tensor by name. - :param latents_name: The name of the latents tensor to get. + :param tensor_name: The name of the tensor to get. """ - return self._services.latents.get(latents_name) + return self._services.tensors.get(tensor_name) class ConditioningInterface(InvocationContextInterface): - # TODO(psyche): We are (ab)using the latents storage service as a general pickle storage - # service, but it is typed to work with Tensors only. We have to fudge the types here. def save(self, conditioning_data: ConditioningFieldData) -> str: """ Saves a conditioning data object, returning its name. @@ -265,15 +263,12 @@ def save(self, conditioning_data: ConditioningFieldData) -> str: :param conditioning_context_data: The conditioning data to save. """ - # Conditioning data is *not* a Tensor, so we will suffix it to indicate this. 
- # - # See comment for `LatentsInterface.save` for more info about this method (it's very - # similar). + # See comment in TensorsInterface.save for why we generate the name here. - name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning" - self._services.latents.save( + name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" + self._services.conditioning.save( name=name, - data=conditioning_data, # type: ignore [arg-type] + data=conditioning_data, ) return name @@ -284,7 +279,7 @@ def get(self, conditioning_name: str) -> ConditioningFieldData: :param conditioning_name: The name of the conditioning data to get. """ - return self._services.latents.get(conditioning_name) # type: ignore [return-value] + return self._services.conditioning.get(conditioning_name) class ModelsInterface(InvocationContextInterface): @@ -400,7 +395,7 @@ class InvocationContext: def __init__( self, images: ImagesInterface, - latents: LatentsInterface, + tensors: TensorsInterface, conditioning: ConditioningInterface, models: ModelsInterface, logger: LoggerInterface, @@ -412,8 +407,8 @@ def __init__( ) -> None: self.images = images """Provides methods to save, get and update images and their metadata.""" - self.latents = latents - """Provides methods to save and get latents tensors, including image, noise, masks, and masked images.""" + self.tensors = tensors + """Provides methods to save and get tensors, including image, noise, masks, and masked images.""" self.conditioning = conditioning """Provides methods to save and get conditioning data.""" self.models = models @@ -532,7 +527,7 @@ def build_invocation_context( logger = LoggerInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data) - latents = LatentsInterface(services=services, context_data=context_data) + tensors = TensorsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data) config = ConfigInterface(services=services, context_data=context_data) util = UtilInterface(services=services, context_data=context_data) @@ -543,7 +538,7 @@ def build_invocation_context( images=images, logger=logger, config=config, - latents=latents, + tensors=tensors, models=models, context_data=context_data, util=util, From 35c73e84d6093a221fca29f9872e4c0612f56574 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:17:23 +1100 Subject: [PATCH 035/340] tidy(nodes): do not refer to files as latents in `PickleStorageTorch` --- .../pickle_storage/pickle_storage_torch.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py index 0b3c9af7a33..7b18dc0625e 100644 --- a/invokeai/app/services/pickle_storage/pickle_storage_torch.py +++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py @@ -48,15 +48,15 @@ def _delete_all_items(self) -> None: if not self._invoker: raise ValueError("Invoker is not set. 
Must call `start()` first.") - deleted_latents_count = 0 + deleted_count = 0 freed_space = 0 - for latents_file in Path(self._output_folder).glob("*"): - if latents_file.is_file(): - freed_space += latents_file.stat().st_size - deleted_latents_count += 1 - latents_file.unlink() - if deleted_latents_count > 0: + for file in Path(self._output_folder).glob("*"): + if file.is_file(): + freed_space += file.stat().st_size + deleted_count += 1 + file.unlink() + if deleted_count > 0: freed_space_in_mb = round(freed_space / 1024 / 1024, 2) self._invoker.services.logger.info( - f"Deleted {deleted_latents_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" + f"Deleted {deleted_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" ) From 89ffcba7698e7c0d649303cf9f7dcff579c13ea1 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:35:58 +1100 Subject: [PATCH 036/340] fix(nodes): add super init to `PickleStorageTorch` --- invokeai/app/services/pickle_storage/pickle_storage_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py index 7b18dc0625e..de411bbf47d 100644 --- a/invokeai/app/services/pickle_storage/pickle_storage_torch.py +++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py @@ -15,6 +15,7 @@ class PickleStorageTorch(PickleStorageBase[T]): """Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`.""" def __init__(self, output_folder: Path, item_type_name: "str"): + super().__init__() self._output_folder = output_folder self._output_folder.mkdir(parents=True, exist_ok=True) self._item_type_name = item_type_name From a4d0d8bd4ec32d5ae6aa6807a7f27bf966d8dcfc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:43:33 +1100 Subject: [PATCH 037/340] feat(nodes): ItemStorageABC typevar no longer bound to pydantic.BaseModel This bound is totally unnecessary. There's no requirement for any implementation of `ItemStorageABC` to work only on pydantic models. 
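To illustrate the effect of the looser bound (a sketch only, not part of this patch): with a plain `T = TypeVar("T")`, an `ItemStorageABC`-style generic can be parameterized with non-pydantic types such as tensors, which is exactly what the tensor and conditioning storages rely on.

    # Sketch: an unbound TypeVar allows arbitrary item types, not just BaseModel.
    from typing import Generic, TypeVar

    T = TypeVar("T")  # previously TypeVar("T", bound=BaseModel)

    class StorageSketch(Generic[T]):
        def __init__(self) -> None:
            self._items: dict[str, T] = {}

        def set(self, item_id: str, item: T) -> str:
            self._items[item_id] = item
            return item_id

        def get(self, item_id: str) -> T:
            return self._items[item_id]

    # Valid parameterizations now include plain python objects and tensors:
    numbers = StorageSketch[float]()
    numbers.set("pi", 3.14159)
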
--- invokeai/app/services/item_storage/item_storage_base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/invokeai/app/services/item_storage/item_storage_base.py b/invokeai/app/services/item_storage/item_storage_base.py index c93edf5188d..d7366791594 100644 --- a/invokeai/app/services/item_storage/item_storage_base.py +++ b/invokeai/app/services/item_storage/item_storage_base.py @@ -1,9 +1,7 @@ from abc import ABC, abstractmethod from typing import Callable, Generic, TypeVar -from pydantic import BaseModel - -T = TypeVar("T", bound=BaseModel) +T = TypeVar("T") class ItemStorageABC(ABC, Generic[T]): From 971851d62eaf4b877ce0a1fd873ccba8e6159f18 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:44:30 +1100 Subject: [PATCH 038/340] tidy(nodes): do not refer to files as latents in `PickleStorageTorch` (again) --- .../services/pickle_storage/pickle_storage_torch.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py index de411bbf47d..16f0d7bb7ad 100644 --- a/invokeai/app/services/pickle_storage/pickle_storage_torch.py +++ b/invokeai/app/services/pickle_storage/pickle_storage_torch.py @@ -25,17 +25,17 @@ def start(self, invoker: Invoker) -> None: self._delete_all_items() def get(self, name: str) -> T: - latent_path = self._get_path(name) - return torch.load(latent_path) + file_path = self._get_path(name) + return torch.load(file_path) def save(self, name: str, data: T) -> None: self._output_folder.mkdir(parents=True, exist_ok=True) - latent_path = self._get_path(name) - torch.save(data, latent_path) + file_path = self._get_path(name) + torch.save(data, file_path) def delete(self, name: str) -> None: - latent_path = self._get_path(name) - latent_path.unlink() + file_path = self._get_path(name) + file_path.unlink() def _get_path(self, name: str) -> Path: return self._output_folder / name From 023ed835924e3782e66140ade7e0ba0bf92b9ab3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:39:03 +1100 Subject: [PATCH 039/340] feat(nodes): use `ItemStorageABC` for tensors and conditioning Turns out `ItemStorageABC` was almost identical to `PickleStorageBase`. Instead of maintaining separate classes, we can use `ItemStorageABC` for both. There's only one change needed - the `ItemStorageABC.set` method must return the newly stored item's ID. This allows us to let the service handle the responsibility of naming the item, but still create the requisite output objects during node execution. The naming implementation is improved here. It extracts the name of the generic and appends a UUID to that string when saving items. 
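As a rough, self-contained sketch of the naming scheme described above (the patch itself uses the repo's `uuid_string()` helper; stdlib `uuid4` is used here so the snippet runs on its own), the item ID is derived from the generic argument's class name plus a UUID:

    # Sketch of deriving IDs like "Tensor_8f3a9c..." from the generic parameter.
    import typing
    from typing import Generic, TypeVar
    from uuid import uuid4

    T = TypeVar("T")

    class EphemeralDiskSketch(Generic[T]):
        def _item_class_name(self) -> str:
            # __orig_class__ is only set after __init__, when the instance was
            # created from a parameterized generic, e.g. EphemeralDiskSketch[int]().
            return typing.get_args(self.__orig_class__)[0].__name__

        def _new_item_id(self) -> str:
            return f"{self._item_class_name()}_{uuid4().hex}"

    storage = EphemeralDiskSketch[int]()
    print(storage._new_item_id())  # e.g. "int_1b9d6bcd4f0a..."
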
--- invokeai/app/api/dependencies.py | 10 +-- invokeai/app/services/invocation_services.py | 5 +- .../item_storage/item_storage_base.py | 2 +- .../item_storage_ephemeral_disk.py | 72 +++++++++++++++++++ .../item_storage_forward_cache.py | 61 ++++++++++++++++ .../item_storage/item_storage_memory.py | 3 +- .../pickle_storage/pickle_storage_base.py | 45 ------------ .../pickle_storage_forward_cache.py | 58 --------------- .../pickle_storage/pickle_storage_torch.py | 63 ---------------- .../app/services/shared/invocation_context.py | 30 +------- 10 files changed, 145 insertions(+), 204 deletions(-) create mode 100644 invokeai/app/services/item_storage/item_storage_ephemeral_disk.py create mode 100644 invokeai/app/services/item_storage/item_storage_forward_cache.py delete mode 100644 invokeai/app/services/pickle_storage/pickle_storage_base.py delete mode 100644 invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py delete mode 100644 invokeai/app/services/pickle_storage/pickle_storage_torch.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 6bb0915cb6e..d6fd970a22d 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -4,9 +4,9 @@ import torch +from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk +from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory -from invokeai.app.services.pickle_storage.pickle_storage_forward_cache import PickleStorageForwardCache -from invokeai.app.services.pickle_storage.pickle_storage_torch import PickleStorageTorch from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -90,9 +90,9 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - tensors = PickleStorageForwardCache(PickleStorageTorch[torch.Tensor](output_folder / "tensors", "tensor")) - conditioning = PickleStorageForwardCache( - PickleStorageTorch[ConditioningFieldData](output_folder / "conditioning", "conditioning") + tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](output_folder / "tensors")) + conditioning = ItemStorageForwardCache( + ItemStorageEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 81885781acb..69599d83a4b 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -29,7 +29,6 @@ from .model_manager.model_manager_base import ModelManagerServiceBase from .model_records import ModelRecordServiceBase from .names.names_base import NameServiceBase - from .pickle_storage.pickle_storage_base import PickleStorageBase from .session_processor.session_processor_base import SessionProcessorBase from .session_queue.session_queue_base import SessionQueueBase from .shared.graph import GraphExecutionState @@ -66,8 +65,8 @@ def __init__( names: "NameServiceBase", urls: "UrlServiceBase", 
workflow_records: "WorkflowRecordsStorageBase", - tensors: "PickleStorageBase[torch.Tensor]", - conditioning: "PickleStorageBase[ConditioningFieldData]", + tensors: "ItemStorageABC[torch.Tensor]", + conditioning: "ItemStorageABC[ConditioningFieldData]", ): self.board_images = board_images self.board_image_records = board_image_records diff --git a/invokeai/app/services/item_storage/item_storage_base.py b/invokeai/app/services/item_storage/item_storage_base.py index d7366791594..f2d62ea45fb 100644 --- a/invokeai/app/services/item_storage/item_storage_base.py +++ b/invokeai/app/services/item_storage/item_storage_base.py @@ -26,7 +26,7 @@ def get(self, item_id: str) -> T: pass @abstractmethod - def set(self, item: T) -> None: + def set(self, item: T) -> str: """ Sets the item. The id will be extracted based on id_field. :param item: the item to set diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py new file mode 100644 index 00000000000..9843d1e54bc --- /dev/null +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -0,0 +1,72 @@ +import typing +from pathlib import Path +from typing import Optional, TypeVar + +import torch + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC +from invokeai.app.util.misc import uuid_string + +T = TypeVar("T") + + +class ItemStorageEphemeralDisk(ItemStorageABC[T]): + """Provides arbitrary item storage with a disk-backed ephemeral storage. The storage is cleared at startup.""" + + def __init__(self, output_folder: Path): + super().__init__() + self._output_folder = output_folder + self._output_folder.mkdir(parents=True, exist_ok=True) + self.__item_class_name: Optional[str] = None + + @property + def _item_class_name(self) -> str: + if not self.__item_class_name: + # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason + self.__item_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues] + return self.__item_class_name + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + self._delete_all_items() + + def get(self, item_id: str) -> T: + file_path = self._get_path(item_id) + return torch.load(file_path) # pyright: ignore [reportUnknownMemberType] + + def set(self, item: T) -> str: + self._output_folder.mkdir(parents=True, exist_ok=True) + item_id = f"{self._item_class_name}_{uuid_string()}" + file_path = self._get_path(item_id) + torch.save(item, file_path) # pyright: ignore [reportUnknownMemberType] + return item_id + + def delete(self, item_id: str) -> None: + file_path = self._get_path(item_id) + file_path.unlink() + + def _get_path(self, item_id: str) -> Path: + return self._output_folder / item_id + + def _delete_all_items(self) -> None: + """ + Deletes all pickled items from disk. + Must be called after we have access to `self._invoker` (e.g. in `start()`). + """ + + if not self._invoker: + raise ValueError("Invoker is not set. 
Must call `start()` first.") + + deleted_count = 0 + freed_space = 0 + for file in Path(self._output_folder).glob("*"): + if file.is_file(): + freed_space += file.stat().st_size + deleted_count += 1 + file.unlink() + if deleted_count > 0: + freed_space_in_mb = round(freed_space / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {deleted_count} {self._item_class_name} files (freed {freed_space_in_mb}MB)" + ) diff --git a/invokeai/app/services/item_storage/item_storage_forward_cache.py b/invokeai/app/services/item_storage/item_storage_forward_cache.py new file mode 100644 index 00000000000..d1fe8e13fa9 --- /dev/null +++ b/invokeai/app/services/item_storage/item_storage_forward_cache.py @@ -0,0 +1,61 @@ +from queue import Queue +from typing import Optional, TypeVar + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC + +T = TypeVar("T") + + +class ItemStorageForwardCache(ItemStorageABC[T]): + """Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size.""" + + def __init__(self, underlying_storage: ItemStorageABC[T], max_cache_size: int = 20): + super().__init__() + self._underlying_storage = underlying_storage + self._cache: dict[str, T] = {} + self._cache_ids = Queue[str]() + self._max_cache_size = max_cache_size + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + start_op = getattr(self._underlying_storage, "start", None) + if callable(start_op): + start_op(invoker) + + def stop(self, invoker: Invoker) -> None: + self._invoker = invoker + stop_op = getattr(self._underlying_storage, "stop", None) + if callable(stop_op): + stop_op(invoker) + + def get(self, item_id: str) -> T: + cache_item = self._get_cache(item_id) + if cache_item is not None: + return cache_item + + latent = self._underlying_storage.get(item_id) + self._set_cache(item_id, latent) + return latent + + def set(self, item: T) -> str: + item_id = self._underlying_storage.set(item) + self._set_cache(item_id, item) + self._on_changed(item) + return item_id + + def delete(self, item_id: str) -> None: + self._underlying_storage.delete(item_id) + if item_id in self._cache: + del self._cache[item_id] + self._on_deleted(item_id) + + def _get_cache(self, item_id: str) -> Optional[T]: + return None if item_id not in self._cache else self._cache[item_id] + + def _set_cache(self, item_id: str, data: T): + if item_id not in self._cache: + self._cache[item_id] = data + self._cache_ids.put(item_id) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/item_storage/item_storage_memory.py b/invokeai/app/services/item_storage/item_storage_memory.py index d8dd0e06645..6d028745164 100644 --- a/invokeai/app/services/item_storage/item_storage_memory.py +++ b/invokeai/app/services/item_storage/item_storage_memory.py @@ -34,7 +34,7 @@ def get(self, item_id: str) -> T: self._items[item_id] = item return item - def set(self, item: T) -> None: + def set(self, item: T) -> str: item_id = getattr(item, self._id_field) if item_id in self._items: # If item already exists, remove it and add it to the end @@ -44,6 +44,7 @@ def set(self, item: T) -> None: self._items.popitem(last=False) self._items[item_id] = item self._on_changed(item) + return item_id def delete(self, item_id: str) -> None: # This is a no-op if the item doesn't exist. 
diff --git a/invokeai/app/services/pickle_storage/pickle_storage_base.py b/invokeai/app/services/pickle_storage/pickle_storage_base.py deleted file mode 100644 index 558b97c0f1b..00000000000 --- a/invokeai/app/services/pickle_storage/pickle_storage_base.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from abc import ABC, abstractmethod -from typing import Callable, Generic, TypeVar - -T = TypeVar("T") - - -class PickleStorageBase(ABC, Generic[T]): - """Responsible for storing and retrieving non-serializable data using a pickler.""" - - _on_changed_callbacks: list[Callable[[T], None]] - _on_deleted_callbacks: list[Callable[[str], None]] - - def __init__(self) -> None: - self._on_changed_callbacks = [] - self._on_deleted_callbacks = [] - - @abstractmethod - def get(self, name: str) -> T: - pass - - @abstractmethod - def save(self, name: str, data: T) -> None: - pass - - @abstractmethod - def delete(self, name: str) -> None: - pass - - def on_changed(self, on_changed: Callable[[T], None]) -> None: - """Register a callback for when an item is changed""" - self._on_changed_callbacks.append(on_changed) - - def on_deleted(self, on_deleted: Callable[[str], None]) -> None: - """Register a callback for when an item is deleted""" - self._on_deleted_callbacks.append(on_deleted) - - def _on_changed(self, item: T) -> None: - for callback in self._on_changed_callbacks: - callback(item) - - def _on_deleted(self, item_id: str) -> None: - for callback in self._on_deleted_callbacks: - callback(item_id) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py b/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py deleted file mode 100644 index 3002d9e045d..00000000000 --- a/invokeai/app/services/pickle_storage/pickle_storage_forward_cache.py +++ /dev/null @@ -1,58 +0,0 @@ -from queue import Queue -from typing import Optional, TypeVar - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase - -T = TypeVar("T") - - -class PickleStorageForwardCache(PickleStorageBase[T]): - def __init__(self, underlying_storage: PickleStorageBase[T], max_cache_size: int = 20): - super().__init__() - self._underlying_storage = underlying_storage - self._cache: dict[str, T] = {} - self._cache_ids = Queue[str]() - self._max_cache_size = max_cache_size - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - start_op = getattr(self._underlying_storage, "start", None) - if callable(start_op): - start_op(invoker) - - def stop(self, invoker: Invoker) -> None: - self._invoker = invoker - stop_op = getattr(self._underlying_storage, "stop", None) - if callable(stop_op): - stop_op(invoker) - - def get(self, name: str) -> T: - cache_item = self._get_cache(name) - if cache_item is not None: - return cache_item - - latent = self._underlying_storage.get(name) - self._set_cache(name, latent) - return latent - - def save(self, name: str, data: T) -> None: - self._underlying_storage.save(name, data) - self._set_cache(name, data) - self._on_changed(data) - - def delete(self, name: str) -> None: - self._underlying_storage.delete(name) - if name in self._cache: - del self._cache[name] - self._on_deleted(name) - - def _get_cache(self, name: str) -> Optional[T]: - return None if name not in self._cache else self._cache[name] - - def _set_cache(self, name: str, data: T): - if name not in self._cache: - self._cache[name] = data - self._cache_ids.put(name) - 
if self._cache_ids.qsize() > self._max_cache_size: - self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/pickle_storage/pickle_storage_torch.py b/invokeai/app/services/pickle_storage/pickle_storage_torch.py deleted file mode 100644 index 16f0d7bb7ad..00000000000 --- a/invokeai/app/services/pickle_storage/pickle_storage_torch.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) - -from pathlib import Path -from typing import TypeVar - -import torch - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.pickle_storage.pickle_storage_base import PickleStorageBase - -T = TypeVar("T") - - -class PickleStorageTorch(PickleStorageBase[T]): - """Responsible for storing and retrieving non-serializable data using `torch.save` and `torch.load`.""" - - def __init__(self, output_folder: Path, item_type_name: "str"): - super().__init__() - self._output_folder = output_folder - self._output_folder.mkdir(parents=True, exist_ok=True) - self._item_type_name = item_type_name - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - self._delete_all_items() - - def get(self, name: str) -> T: - file_path = self._get_path(name) - return torch.load(file_path) - - def save(self, name: str, data: T) -> None: - self._output_folder.mkdir(parents=True, exist_ok=True) - file_path = self._get_path(name) - torch.save(data, file_path) - - def delete(self, name: str) -> None: - file_path = self._get_path(name) - file_path.unlink() - - def _get_path(self, name: str) -> Path: - return self._output_folder / name - - def _delete_all_items(self) -> None: - """ - Deletes all pickled items from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). - """ - - if not self._invoker: - raise ValueError("Invoker is not set. Must call `start()` first.") - - deleted_count = 0 - freed_space = 0 - for file in Path(self._output_folder).glob("*"): - if file.is_file(): - freed_space += file.stat().st_size - deleted_count += 1 - file.unlink() - if deleted_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_count} {self._item_type_name} files (freed {freed_space_in_mb}MB)" - ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 6756b1f5c6c..baff47a3df4 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -12,7 +12,6 @@ from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID -from invokeai.app.util.misc import uuid_string from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_management.model_manager import ModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType @@ -224,26 +223,7 @@ def save(self, tensor: Tensor) -> str: :param tensor: The tensor to save. """ - # Previously, we added a suffix indicating the type of Tensor we were saving, e.g. - # "mask", "noise", "masked_latents", etc. - # - # Retaining that capability in this wrapper would require either many different methods - # to save tensors, or extra args for this method. Instead of complicating the API, we - # will use the same naming scheme for all tensors. 
- # - # This has a very minor impact as we don't use them after a session completes. - - # Previously, invocations chose the name for their tensors. This is a bit risky, so we - # will generate a name for them instead. We use a uuid to ensure the name is unique. - # - # Because the name of the tensors file will includes the session and invocation IDs, - # we don't need to worry about collisions. A truncated UUIDv4 is fine. - - name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" - self._services.tensors.save( - name=name, - data=tensor, - ) + name = self._services.tensors.set(item=tensor) return name def get(self, tensor_name: str) -> Tensor: @@ -263,13 +243,7 @@ def save(self, conditioning_data: ConditioningFieldData) -> str: :param conditioning_context_data: The conditioning data to save. """ - # See comment in TensorsInterface.save for why we generate the name here. - - name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}" - self._services.conditioning.save( - name=name, - data=conditioning_data, - ) + name = self._services.conditioning.set(item=conditioning_data) return name def get(self, conditioning_name: str) -> ConditioningFieldData: From 7a5ea1c45b515f2741965a693be64e05027494d3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:50:30 +1100 Subject: [PATCH 040/340] feat(nodes): create helper function to generate the item ID --- .../app/services/item_storage/item_storage_ephemeral_disk.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py index 9843d1e54bc..377c9c39b3e 100644 --- a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -37,7 +37,7 @@ def get(self, item_id: str) -> T: def set(self, item: T) -> str: self._output_folder.mkdir(parents=True, exist_ok=True) - item_id = f"{self._item_class_name}_{uuid_string()}" + item_id = self._new_item_id() file_path = self._get_path(item_id) torch.save(item, file_path) # pyright: ignore [reportUnknownMemberType] return item_id @@ -49,6 +49,9 @@ def delete(self, item_id: str) -> None: def _get_path(self, item_id: str) -> Path: return self._output_folder / item_id + def _new_item_id(self) -> str: + return f"{self._item_class_name}_{uuid_string()}" + def _delete_all_items(self) -> None: """ Deletes all pickled items from disk. 
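For reference, the bounded write-through cache used by the `PickleStorageForwardCache` / `ItemStorageForwardCache` wrappers above reduces to the pattern below (a simplified sketch, not the actual class). Note that although the docstring says "LRU", eviction in the implementation above is FIFO over insertion order: a `get()` does not refresh an entry's position.

    # Simplified sketch of the bounded write-through cache pattern used above.
    from queue import Queue
    from typing import Dict, Generic, Optional, TypeVar

    T = TypeVar("T")

    class ForwardCacheSketch(Generic[T]):
        def __init__(self, max_size: int = 20) -> None:
            self._cache: Dict[str, T] = {}
            self._ids: "Queue[str]" = Queue()
            self._max_size = max_size

        def put(self, item_id: str, item: T) -> None:
            if item_id not in self._cache:
                self._cache[item_id] = item
                self._ids.put(item_id)
                if self._ids.qsize() > self._max_size:
                    # Evict the oldest inserted entry (FIFO, not true LRU).
                    self._cache.pop(self._ids.get())

        def get(self, item_id: str) -> Optional[T]:
            return self._cache.get(item_id)
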
From 6edd85ebda9964130e17c39a8610767786fd45a2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:51:04 +1100 Subject: [PATCH 041/340] feat(nodes): support custom save and load functions in `ItemStorageEphemeralDisk` --- .../item_storage/item_storage_common.py | 10 +++++++ .../item_storage_ephemeral_disk.py | 26 +++++++++++++++---- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/invokeai/app/services/item_storage/item_storage_common.py b/invokeai/app/services/item_storage/item_storage_common.py index 8fd677c71b7..7f9bd7bd4ef 100644 --- a/invokeai/app/services/item_storage/item_storage_common.py +++ b/invokeai/app/services/item_storage/item_storage_common.py @@ -1,5 +1,15 @@ +from pathlib import Path +from typing import Callable, TypeAlias, TypeVar + + class ItemNotFoundError(KeyError): """Raised when an item is not found in storage""" def __init__(self, item_id: str) -> None: super().__init__(f"Item with id {item_id} not found") + + +T = TypeVar("T") + +SaveFunc: TypeAlias = Callable[[T, Path], None] +LoadFunc: TypeAlias = Callable[[Path], T] diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py index 377c9c39b3e..4dc67129dac 100644 --- a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -6,18 +6,31 @@ from invokeai.app.services.invoker import Invoker from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC +from invokeai.app.services.item_storage.item_storage_common import LoadFunc, SaveFunc from invokeai.app.util.misc import uuid_string T = TypeVar("T") class ItemStorageEphemeralDisk(ItemStorageABC[T]): - """Provides arbitrary item storage with a disk-backed ephemeral storage. The storage is cleared at startup.""" - - def __init__(self, output_folder: Path): + """Provides a disk-backed ephemeral storage. The storage is cleared at startup. + + :param output_folder: The folder where the items will be stored + :param save: The function to use to save the items to disk [torch.save] + :param load: The function to use to load the items from disk [torch.load] + """ + + def __init__( + self, + output_folder: Path, + save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType] + load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType] + ): super().__init__() self._output_folder = output_folder self._output_folder.mkdir(parents=True, exist_ok=True) + self._save = save + self._load = load self.__item_class_name: Optional[str] = None @property @@ -33,13 +46,13 @@ def start(self, invoker: Invoker) -> None: def get(self, item_id: str) -> T: file_path = self._get_path(item_id) - return torch.load(file_path) # pyright: ignore [reportUnknownMemberType] + return self._load(file_path) def set(self, item: T) -> str: self._output_folder.mkdir(parents=True, exist_ok=True) item_id = self._new_item_id() file_path = self._get_path(item_id) - torch.save(item, file_path) # pyright: ignore [reportUnknownMemberType] + self._save(item, file_path) return item_id def delete(self, item_id: str) -> None: @@ -58,6 +71,9 @@ def _delete_all_items(self) -> None: Must be called after we have access to `self._invoker` (e.g. in `start()`). """ + # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have + # to manually clear them on startup anyways. 
This is a bit simpler and more reliable. + if not self._invoker: raise ValueError("Invoker is not set. Must call `start()` first.") From 77c3a48abd250f4229b5bb3624bef0a2882b3c32 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 22:54:52 +1100 Subject: [PATCH 042/340] feat(nodes): support custom exception in ephemeral disk storage --- .../item_storage/item_storage_ephemeral_disk.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py index 4dc67129dac..97c767c87d7 100644 --- a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py +++ b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py @@ -1,12 +1,12 @@ import typing from pathlib import Path -from typing import Optional, TypeVar +from typing import Optional, Type, TypeVar import torch from invokeai.app.services.invoker import Invoker from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC -from invokeai.app.services.item_storage.item_storage_common import LoadFunc, SaveFunc +from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError, LoadFunc, SaveFunc from invokeai.app.util.misc import uuid_string T = TypeVar("T") @@ -18,6 +18,7 @@ class ItemStorageEphemeralDisk(ItemStorageABC[T]): :param output_folder: The folder where the items will be stored :param save: The function to use to save the items to disk [torch.save] :param load: The function to use to load the items from disk [torch.load] + :param load_exc: The exception that is raised when an item is not found [FileNotFoundError] """ def __init__( @@ -25,12 +26,14 @@ def __init__( output_folder: Path, save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType] load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType] + load_exc: Type[Exception] = FileNotFoundError, ): super().__init__() self._output_folder = output_folder self._output_folder.mkdir(parents=True, exist_ok=True) self._save = save self._load = load + self._load_exc = load_exc self.__item_class_name: Optional[str] = None @property @@ -46,7 +49,10 @@ def start(self, invoker: Invoker) -> None: def get(self, item_id: str) -> T: file_path = self._get_path(item_id) - return self._load(file_path) + try: + return self._load(file_path) + except self._load_exc as e: + raise ItemNotFoundError(item_id) from e def set(self, item: T) -> str: self._output_folder.mkdir(parents=True, exist_ok=True) From 4813b7d35c7839c9c7250fba4f049db4c02f5727 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 23:30:46 +1100 Subject: [PATCH 043/340] revert(nodes): revert making tensors/conditioning use item storage Turns out they are just different enough in purpose that the implementations would be rather unintuitive. I've made a separate ObjectSerializer service to handle tensors and conditioning. Refined the class a bit too. 
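As a rough usage sketch of the new service (class and method names are taken from the diff below; the output folder and tensor are illustrative, and in the app this wiring lives in `dependencies.py`):

from pathlib import Path

import torch

from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache

# Disk-backed serializer fronted by a small LRU cache.
tensors = ObjectSerializerForwardCache(
    ObjectSerializerEphemeralDisk[torch.Tensor](Path("outputs") / "tensors")
)

name = tensors.save(torch.zeros(1, 4, 64, 64))  # returns a generated name, e.g. "Tensor_<uuid>"
loaded = tensors.load(name)                     # served from the cache or read back from disk
tensors.delete(name)
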
--- invokeai/app/api/dependencies.py | 10 +- invokeai/app/invocations/latent.py | 24 ++--- invokeai/app/invocations/primitives.py | 2 +- invokeai/app/services/invocation_services.py | 6 +- .../item_storage/item_storage_base.py | 8 +- .../item_storage/item_storage_common.py | 10 -- .../item_storage_ephemeral_disk.py | 97 ------------------- .../item_storage_forward_cache.py | 61 ------------ .../item_storage/item_storage_memory.py | 3 +- .../object_serializer_base.py | 53 ++++++++++ .../object_serializer_common.py | 5 + .../object_serializer_ephemeral_disk.py | 84 ++++++++++++++++ .../object_serializer_forward_cache.py | 61 ++++++++++++ .../app/services/shared/invocation_context.py | 24 ++--- 14 files changed, 243 insertions(+), 205 deletions(-) delete mode 100644 invokeai/app/services/item_storage/item_storage_ephemeral_disk.py delete mode 100644 invokeai/app/services/item_storage/item_storage_forward_cache.py create mode 100644 invokeai/app/services/object_serializer/object_serializer_base.py create mode 100644 invokeai/app/services/object_serializer/object_serializer_common.py create mode 100644 invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py create mode 100644 invokeai/app/services/object_serializer/object_serializer_forward_cache.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index d6fd970a22d..0c80494616f 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -4,9 +4,9 @@ import torch -from invokeai.app.services.item_storage.item_storage_ephemeral_disk import ItemStorageEphemeralDisk -from invokeai.app.services.item_storage.item_storage_forward_cache import ItemStorageForwardCache from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory +from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -90,9 +90,9 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - tensors = ItemStorageForwardCache(ItemStorageEphemeralDisk[torch.Tensor](output_folder / "tensors")) - conditioning = ItemStorageForwardCache( - ItemStorageEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") + tensors = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[torch.Tensor](output_folder / "tensors")) + conditioning = ObjectSerializerForwardCache( + ObjectSerializerEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 94440d3e2aa..4137ab6e2f6 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -304,11 +304,11 @@ def get_conditioning_data( unet, seed, ) -> ConditioningData: - positive_cond_data = context.conditioning.get(self.positive_conditioning.conditioning_name) + positive_cond_data = 
context.conditioning.load(self.positive_conditioning.conditioning_name) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) extra_conditioning_info = c.extra_conditioning - negative_cond_data = context.conditioning.get(self.negative_conditioning.conditioning_name) + negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) conditioning_data = ConditioningData( @@ -621,10 +621,10 @@ def prep_inpaint_mask(self, context: InvocationContext, latents): if self.denoise_mask is None: return None, None - mask = context.tensors.get(self.denoise_mask.mask_name) + mask = context.tensors.load(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) if self.denoise_mask.masked_latents_name is not None: - masked_latents = context.tensors.get(self.denoise_mask.masked_latents_name) + masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name) else: masked_latents = None @@ -636,11 +636,11 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: seed = None noise = None if self.noise is not None: - noise = context.tensors.get(self.noise.latents_name) + noise = context.tensors.load(self.noise.latents_name) seed = self.noise.seed if self.latents is not None: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) if seed is None: seed = self.latents.seed @@ -779,7 +779,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) vae_info = context.models.load(**self.vae.vae.model_dump()) @@ -870,7 +870,7 @@ class ResizeLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -911,7 +911,7 @@ class ScaleLatentsInvocation(BaseInvocation): antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -1048,8 +1048,8 @@ class BlendLatentsInvocation(BaseInvocation): alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents_a = context.tensors.get(self.latents_a.latents_name) - latents_b = context.tensors.get(self.latents_b.latents_name) + latents_a = context.tensors.load(self.latents_a.latents_name) + latents_b = context.tensors.load(self.latents_b.latents_name) if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") @@ -1149,7 +1149,7 @@ class CropLatentsCoreInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR y1 = self.y // LATENT_SCALE_FACTOR diff 
--git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 082d5432ccf..d0f95c92d02 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -344,7 +344,7 @@ class LatentsInvocation(BaseInvocation): latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.tensors.get(self.latents.latents_name) + latents = context.tensors.load(self.latents.latents_name) return LatentsOutput.build(self.latents.latents_name, latents) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 69599d83a4b..e893be87636 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -3,6 +3,8 @@ from typing import TYPE_CHECKING +from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase + if TYPE_CHECKING: from logging import Logger @@ -65,8 +67,8 @@ def __init__( names: "NameServiceBase", urls: "UrlServiceBase", workflow_records: "WorkflowRecordsStorageBase", - tensors: "ItemStorageABC[torch.Tensor]", - conditioning: "ItemStorageABC[ConditioningFieldData]", + tensors: "ObjectSerializerBase[torch.Tensor]", + conditioning: "ObjectSerializerBase[ConditioningFieldData]", ): self.board_images = board_images self.board_image_records = board_image_records diff --git a/invokeai/app/services/item_storage/item_storage_base.py b/invokeai/app/services/item_storage/item_storage_base.py index f2d62ea45fb..ef227ba241c 100644 --- a/invokeai/app/services/item_storage/item_storage_base.py +++ b/invokeai/app/services/item_storage/item_storage_base.py @@ -1,7 +1,9 @@ from abc import ABC, abstractmethod from typing import Callable, Generic, TypeVar -T = TypeVar("T") +from pydantic import BaseModel + +T = TypeVar("T", bound=BaseModel) class ItemStorageABC(ABC, Generic[T]): @@ -26,9 +28,9 @@ def get(self, item_id: str) -> T: pass @abstractmethod - def set(self, item: T) -> str: + def set(self, item: T) -> None: """ - Sets the item. The id will be extracted based on id_field. + Sets the item. 
:param item: the item to set """ pass diff --git a/invokeai/app/services/item_storage/item_storage_common.py b/invokeai/app/services/item_storage/item_storage_common.py index 7f9bd7bd4ef..8fd677c71b7 100644 --- a/invokeai/app/services/item_storage/item_storage_common.py +++ b/invokeai/app/services/item_storage/item_storage_common.py @@ -1,15 +1,5 @@ -from pathlib import Path -from typing import Callable, TypeAlias, TypeVar - - class ItemNotFoundError(KeyError): """Raised when an item is not found in storage""" def __init__(self, item_id: str) -> None: super().__init__(f"Item with id {item_id} not found") - - -T = TypeVar("T") - -SaveFunc: TypeAlias = Callable[[T, Path], None] -LoadFunc: TypeAlias = Callable[[Path], T] diff --git a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py b/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py deleted file mode 100644 index 97c767c87d7..00000000000 --- a/invokeai/app/services/item_storage/item_storage_ephemeral_disk.py +++ /dev/null @@ -1,97 +0,0 @@ -import typing -from pathlib import Path -from typing import Optional, Type, TypeVar - -import torch - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC -from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError, LoadFunc, SaveFunc -from invokeai.app.util.misc import uuid_string - -T = TypeVar("T") - - -class ItemStorageEphemeralDisk(ItemStorageABC[T]): - """Provides a disk-backed ephemeral storage. The storage is cleared at startup. - - :param output_folder: The folder where the items will be stored - :param save: The function to use to save the items to disk [torch.save] - :param load: The function to use to load the items from disk [torch.load] - :param load_exc: The exception that is raised when an item is not found [FileNotFoundError] - """ - - def __init__( - self, - output_folder: Path, - save: SaveFunc[T] = torch.save, # pyright: ignore [reportUnknownMemberType] - load: LoadFunc[T] = torch.load, # pyright: ignore [reportUnknownMemberType] - load_exc: Type[Exception] = FileNotFoundError, - ): - super().__init__() - self._output_folder = output_folder - self._output_folder.mkdir(parents=True, exist_ok=True) - self._save = save - self._load = load - self._load_exc = load_exc - self.__item_class_name: Optional[str] = None - - @property - def _item_class_name(self) -> str: - if not self.__item_class_name: - # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason - self.__item_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues] - return self.__item_class_name - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - self._delete_all_items() - - def get(self, item_id: str) -> T: - file_path = self._get_path(item_id) - try: - return self._load(file_path) - except self._load_exc as e: - raise ItemNotFoundError(item_id) from e - - def set(self, item: T) -> str: - self._output_folder.mkdir(parents=True, exist_ok=True) - item_id = self._new_item_id() - file_path = self._get_path(item_id) - self._save(item, file_path) - return item_id - - def delete(self, item_id: str) -> None: - file_path = self._get_path(item_id) - file_path.unlink() - - def _get_path(self, item_id: str) -> Path: - return self._output_folder / item_id - - def _new_item_id(self) -> str: - return f"{self._item_class_name}_{uuid_string()}" - - def 
_delete_all_items(self) -> None: - """ - Deletes all pickled items from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). - """ - - # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have - # to manually clear them on startup anyways. This is a bit simpler and more reliable. - - if not self._invoker: - raise ValueError("Invoker is not set. Must call `start()` first.") - - deleted_count = 0 - freed_space = 0 - for file in Path(self._output_folder).glob("*"): - if file.is_file(): - freed_space += file.stat().st_size - deleted_count += 1 - file.unlink() - if deleted_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_count} {self._item_class_name} files (freed {freed_space_in_mb}MB)" - ) diff --git a/invokeai/app/services/item_storage/item_storage_forward_cache.py b/invokeai/app/services/item_storage/item_storage_forward_cache.py deleted file mode 100644 index d1fe8e13fa9..00000000000 --- a/invokeai/app/services/item_storage/item_storage_forward_cache.py +++ /dev/null @@ -1,61 +0,0 @@ -from queue import Queue -from typing import Optional, TypeVar - -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC - -T = TypeVar("T") - - -class ItemStorageForwardCache(ItemStorageABC[T]): - """Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size.""" - - def __init__(self, underlying_storage: ItemStorageABC[T], max_cache_size: int = 20): - super().__init__() - self._underlying_storage = underlying_storage - self._cache: dict[str, T] = {} - self._cache_ids = Queue[str]() - self._max_cache_size = max_cache_size - - def start(self, invoker: Invoker) -> None: - self._invoker = invoker - start_op = getattr(self._underlying_storage, "start", None) - if callable(start_op): - start_op(invoker) - - def stop(self, invoker: Invoker) -> None: - self._invoker = invoker - stop_op = getattr(self._underlying_storage, "stop", None) - if callable(stop_op): - stop_op(invoker) - - def get(self, item_id: str) -> T: - cache_item = self._get_cache(item_id) - if cache_item is not None: - return cache_item - - latent = self._underlying_storage.get(item_id) - self._set_cache(item_id, latent) - return latent - - def set(self, item: T) -> str: - item_id = self._underlying_storage.set(item) - self._set_cache(item_id, item) - self._on_changed(item) - return item_id - - def delete(self, item_id: str) -> None: - self._underlying_storage.delete(item_id) - if item_id in self._cache: - del self._cache[item_id] - self._on_deleted(item_id) - - def _get_cache(self, item_id: str) -> Optional[T]: - return None if item_id not in self._cache else self._cache[item_id] - - def _set_cache(self, item_id: str, data: T): - if item_id not in self._cache: - self._cache[item_id] = data - self._cache_ids.put(item_id) - if self._cache_ids.qsize() > self._max_cache_size: - self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/item_storage/item_storage_memory.py b/invokeai/app/services/item_storage/item_storage_memory.py index 6d028745164..d8dd0e06645 100644 --- a/invokeai/app/services/item_storage/item_storage_memory.py +++ b/invokeai/app/services/item_storage/item_storage_memory.py @@ -34,7 +34,7 @@ def get(self, item_id: str) -> T: self._items[item_id] = item return item - def set(self, item: T) -> str: + def set(self, item: T) -> None: item_id = 
getattr(item, self._id_field) if item_id in self._items: # If item already exists, remove it and add it to the end @@ -44,7 +44,6 @@ def set(self, item: T) -> str: self._items.popitem(last=False) self._items[item_id] = item self._on_changed(item) - return item_id def delete(self, item_id: str) -> None: # This is a no-op if the item doesn't exist. diff --git a/invokeai/app/services/object_serializer/object_serializer_base.py b/invokeai/app/services/object_serializer/object_serializer_base.py new file mode 100644 index 00000000000..b01a641d8fb --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_base.py @@ -0,0 +1,53 @@ +from abc import ABC, abstractmethod +from typing import Callable, Generic, TypeVar + +T = TypeVar("T") + + +class ObjectSerializerBase(ABC, Generic[T]): + """Saves and loads arbitrary python objects.""" + + def __init__(self) -> None: + self._on_saved_callbacks: list[Callable[[str, T], None]] = [] + self._on_deleted_callbacks: list[Callable[[str], None]] = [] + + @abstractmethod + def load(self, name: str) -> T: + """ + Loads the object. + :param name: The name of the object to load. + :raises ObjectNotFoundError: if the object is not found + """ + pass + + @abstractmethod + def save(self, obj: T) -> str: + """ + Saves the object, returning its name. + :param obj: The object to save. + """ + pass + + @abstractmethod + def delete(self, name: str) -> None: + """ + Deletes the object, if it exists. + :param name: The name of the object to delete. + """ + pass + + def on_saved(self, on_saved: Callable[[str, T], None]) -> None: + """Register a callback for when an object is saved""" + self._on_saved_callbacks.append(on_saved) + + def on_deleted(self, on_deleted: Callable[[str], None]) -> None: + """Register a callback for when an object is deleted""" + self._on_deleted_callbacks.append(on_deleted) + + def _on_saved(self, name: str, obj: T) -> None: + for callback in self._on_saved_callbacks: + callback(name, obj) + + def _on_deleted(self, name: str) -> None: + for callback in self._on_deleted_callbacks: + callback(name) diff --git a/invokeai/app/services/object_serializer/object_serializer_common.py b/invokeai/app/services/object_serializer/object_serializer_common.py new file mode 100644 index 00000000000..7057386541f --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_common.py @@ -0,0 +1,5 @@ +class ObjectNotFoundError(KeyError): + """Raised when an object is not found while loading""" + + def __init__(self, name: str) -> None: + super().__init__(f"Object with name {name} not found") diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py new file mode 100644 index 00000000000..afa868b157f --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py @@ -0,0 +1,84 @@ +import typing +from pathlib import Path +from typing import Optional, TypeVar + +import torch + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase +from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError +from invokeai.app.util.misc import uuid_string + +T = TypeVar("T") + + +class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]): + """Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup. 
+ + :param output_folder: The folder where the objects will be stored + """ + + def __init__(self, output_dir: Path): + super().__init__() + self._output_dir = output_dir + self._output_dir.mkdir(parents=True, exist_ok=True) + self.__obj_class_name: Optional[str] = None + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + self._delete_all() + + def load(self, name: str) -> T: + file_path = self._get_path(name) + try: + return torch.load(file_path) # pyright: ignore [reportUnknownMemberType] + except FileNotFoundError as e: + raise ObjectNotFoundError(name) from e + + def save(self, obj: T) -> str: + name = self._new_name() + file_path = self._get_path(name) + torch.save(obj, file_path) # pyright: ignore [reportUnknownMemberType] + return name + + def delete(self, name: str) -> None: + file_path = self._get_path(name) + file_path.unlink() + + @property + def _obj_class_name(self) -> str: + if not self.__obj_class_name: + # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason + self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues] + return self.__obj_class_name + + def _get_path(self, name: str) -> Path: + return self._output_dir / name + + def _new_name(self) -> str: + return f"{self._obj_class_name}_{uuid_string()}" + + def _delete_all(self) -> None: + """ + Deletes all objects from disk. + Must be called after we have access to `self._invoker` (e.g. in `start()`). + """ + + # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have + # to manually clear them on startup anyways. This is a bit simpler and more reliable. + + if not self._invoker: + raise ValueError("Invoker is not set. Must call `start()` first.") + + deleted_count = 0 + freed_space = 0 + for file in Path(self._output_dir).glob("*"): + if file.is_file(): + freed_space += file.stat().st_size + deleted_count += 1 + file.unlink() + if deleted_count > 0: + freed_space_in_mb = round(freed_space / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" + ) diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py new file mode 100644 index 00000000000..40e34e65406 --- /dev/null +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -0,0 +1,61 @@ +from queue import Queue +from typing import Optional, TypeVar + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase + +T = TypeVar("T") + + +class ObjectSerializerForwardCache(ObjectSerializerBase[T]): + """Provides a simple forward cache for an underlying storage. 
The cache is LRU and has a maximum size.""" + + def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20): + super().__init__() + self._underlying_storage = underlying_storage + self._cache: dict[str, T] = {} + self._cache_ids = Queue[str]() + self._max_cache_size = max_cache_size + + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + start_op = getattr(self._underlying_storage, "start", None) + if callable(start_op): + start_op(invoker) + + def stop(self, invoker: Invoker) -> None: + self._invoker = invoker + stop_op = getattr(self._underlying_storage, "stop", None) + if callable(stop_op): + stop_op(invoker) + + def load(self, name: str) -> T: + cache_item = self._get_cache(name) + if cache_item is not None: + return cache_item + + latent = self._underlying_storage.load(name) + self._set_cache(name, latent) + return latent + + def save(self, obj: T) -> str: + name = self._underlying_storage.save(obj) + self._set_cache(name, obj) + self._on_saved(name, obj) + return name + + def delete(self, name: str) -> None: + self._underlying_storage.delete(name) + if name in self._cache: + del self._cache[name] + self._on_deleted(name) + + def _get_cache(self, name: str) -> Optional[T]: + return None if name not in self._cache else self._cache[name] + + def _set_cache(self, name: str, data: T): + if name not in self._cache: + self._cache[name] = data + self._cache_ids.put(name) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index baff47a3df4..8c5a821fd0f 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -223,16 +223,16 @@ def save(self, tensor: Tensor) -> str: :param tensor: The tensor to save. """ - name = self._services.tensors.set(item=tensor) - return name + tensor_id = self._services.tensors.save(obj=tensor) + return tensor_id - def get(self, tensor_name: str) -> Tensor: + def load(self, name: str) -> Tensor: """ - Gets a tensor by name. + Loads a tensor by name. - :param tensor_name: The name of the tensor to get. + :param name: The name of the tensor to load. """ - return self._services.tensors.get(tensor_name) + return self._services.tensors.load(name) class ConditioningInterface(InvocationContextInterface): @@ -243,17 +243,17 @@ def save(self, conditioning_data: ConditioningFieldData) -> str: :param conditioning_context_data: The conditioning data to save. """ - name = self._services.conditioning.set(item=conditioning_data) - return name + conditioning_id = self._services.conditioning.save(obj=conditioning_data) + return conditioning_id - def get(self, conditioning_name: str) -> ConditioningFieldData: + def load(self, name: str) -> ConditioningFieldData: """ - Gets conditioning data by name. + Loads conditioning data by name. - :param conditioning_name: The name of the conditioning data to get. + :param name: The name of the conditioning data to load. """ - return self._services.conditioning.get(conditioning_name) + return self._services.conditioning.load(name) class ModelsInterface(InvocationContextInterface): From 42ee9043804dd6ab2f74927e6c86dbd879a9f50a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 7 Feb 2024 23:43:22 +1100 Subject: [PATCH 044/340] tidy(nodes): remove object serializer on_saved It's unused. 
--- .../services/object_serializer/object_serializer_base.py | 9 --------- .../object_serializer/object_serializer_forward_cache.py | 1 - 2 files changed, 10 deletions(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_base.py b/invokeai/app/services/object_serializer/object_serializer_base.py index b01a641d8fb..ff19b4a039d 100644 --- a/invokeai/app/services/object_serializer/object_serializer_base.py +++ b/invokeai/app/services/object_serializer/object_serializer_base.py @@ -8,7 +8,6 @@ class ObjectSerializerBase(ABC, Generic[T]): """Saves and loads arbitrary python objects.""" def __init__(self) -> None: - self._on_saved_callbacks: list[Callable[[str, T], None]] = [] self._on_deleted_callbacks: list[Callable[[str], None]] = [] @abstractmethod @@ -36,18 +35,10 @@ def delete(self, name: str) -> None: """ pass - def on_saved(self, on_saved: Callable[[str, T], None]) -> None: - """Register a callback for when an object is saved""" - self._on_saved_callbacks.append(on_saved) - def on_deleted(self, on_deleted: Callable[[str], None]) -> None: """Register a callback for when an object is deleted""" self._on_deleted_callbacks.append(on_deleted) - def _on_saved(self, name: str, obj: T) -> None: - for callback in self._on_saved_callbacks: - callback(name, obj) - def _on_deleted(self, name: str) -> None: for callback in self._on_deleted_callbacks: callback(name) diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py index 40e34e65406..2a4ecdd844b 100644 --- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -41,7 +41,6 @@ def load(self, name: str) -> T: def save(self, obj: T) -> str: name = self._underlying_storage.save(obj) self._set_cache(name, obj) - self._on_saved(name, obj) return name def delete(self, name: str) -> None: From 0cd7cb98a86f7d561d179b4b3cdbfa9bb935d69c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 00:18:58 +1100 Subject: [PATCH 045/340] feat(nodes): allow `_delete_all` in obj serializer to be called at any time `_delete_all` logged how many items it deleted, and had to be called _after_ service start bc it needed access to logger. Move the logger call to the startup method and return the the deleted stats from `_delete_all`. This lets `_delete_all` be called at any time. --- .../object_serializer_ephemeral_disk.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py index afa868b157f..9545d1714d7 100644 --- a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py @@ -1,4 +1,5 @@ import typing +from dataclasses import dataclass from pathlib import Path from typing import Optional, TypeVar @@ -12,6 +13,12 @@ T = TypeVar("T") +@dataclass +class DeleteAllResult: + deleted_count: int + freed_space_bytes: float + + class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]): """Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup. 
@@ -26,7 +33,12 @@ def __init__(self, output_dir: Path): def start(self, invoker: Invoker) -> None: self._invoker = invoker - self._delete_all() + delete_all_result = self._delete_all() + if delete_all_result.deleted_count > 0: + freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" + ) def load(self, name: str) -> T: file_path = self._get_path(name) @@ -58,18 +70,14 @@ def _get_path(self, name: str) -> Path: def _new_name(self) -> str: return f"{self._obj_class_name}_{uuid_string()}" - def _delete_all(self) -> None: + def _delete_all(self) -> DeleteAllResult: """ Deletes all objects from disk. - Must be called after we have access to `self._invoker` (e.g. in `start()`). """ # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have # to manually clear them on startup anyways. This is a bit simpler and more reliable. - if not self._invoker: - raise ValueError("Invoker is not set. Must call `start()` first.") - deleted_count = 0 freed_space = 0 for file in Path(self._output_dir).glob("*"): @@ -77,8 +85,4 @@ def _delete_all(self) -> None: freed_space += file.stat().st_size deleted_count += 1 file.unlink() - if deleted_count > 0: - freed_space_in_mb = round(freed_space / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" - ) + return DeleteAllResult(deleted_count, freed_space) From 7599693a3ce1396da8cf99490723609b1ca49a9b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 00:20:10 +1100 Subject: [PATCH 046/340] tests: add object serializer tests These test both object serializer and its forward cache implementation. 
--- .../test_object_serializer_ephemeral_disk.py | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 tests/test_object_serializer_ephemeral_disk.py diff --git a/tests/test_object_serializer_ephemeral_disk.py b/tests/test_object_serializer_ephemeral_disk.py new file mode 100644 index 00000000000..fffa65304f6 --- /dev/null +++ b/tests/test_object_serializer_ephemeral_disk.py @@ -0,0 +1,148 @@ +from dataclasses import dataclass +from pathlib import Path + +import pytest +import torch + +from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError +from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache + + +@dataclass +class MockDataclass: + foo: str + + +@pytest.fixture +def obj_serializer(tmp_path: Path): + return ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + + +@pytest.fixture +def fwd_cache(tmp_path: Path): + return ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[MockDataclass](tmp_path), max_cache_size=2) + + +def test_obj_serializer_ephemeral_disk_initializes(tmp_path: Path): + obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + assert obj_serializer._output_dir == tmp_path + + +def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + assert Path(obj_serializer._output_dir, obj_1_name).exists() + + obj_2 = MockDataclass(foo="baz") + obj_2_name = obj_serializer.save(obj_2) + assert Path(obj_serializer._output_dir, obj_2_name).exists() + + +def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + assert obj_serializer.load(obj_1_name).foo == "bar" + + obj_2 = MockDataclass(foo="baz") + obj_2_name = obj_serializer.save(obj_2) + assert obj_serializer.load(obj_2_name).foo == "baz" + + with pytest.raises(ObjectNotFoundError): + obj_serializer.load("nonexistent_object_name") + + +def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_2 = MockDataclass(foo="bar") + obj_2_name = obj_serializer.save(obj_2) + + obj_serializer.delete(obj_1_name) + assert not Path(obj_serializer._output_dir, obj_1_name).exists() + assert Path(obj_serializer._output_dir, obj_2_name).exists() + + +def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_2 = MockDataclass(foo="bar") + obj_2_name = obj_serializer.save(obj_2) + + delete_all_result = obj_serializer._delete_all() + + assert not Path(obj_serializer._output_dir, obj_1_name).exists() + assert not Path(obj_serializer._output_dir, obj_2_name).exists() + assert delete_all_result.deleted_count == 2 + + +def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): + obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + obj_1_loaded = obj_serializer.load(obj_1_name) + assert isinstance(obj_1_loaded, MockDataclass) + assert obj_1_loaded.foo == "bar" + 
assert obj_1_name.startswith("MockDataclass_") + + obj_serializer = ObjectSerializerEphemeralDisk[int](tmp_path) + obj_2_name = obj_serializer.save(9001) + assert obj_serializer.load(obj_2_name) == 9001 + assert obj_2_name.startswith("int_") + + obj_serializer = ObjectSerializerEphemeralDisk[str](tmp_path) + obj_3_name = obj_serializer.save("foo") + assert obj_serializer.load(obj_3_name) == "foo" + assert obj_3_name.startswith("str_") + + obj_serializer = ObjectSerializerEphemeralDisk[torch.Tensor](tmp_path) + obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3])) + obj_4_loaded = obj_serializer.load(obj_4_name) + assert isinstance(obj_4_loaded, torch.Tensor) + assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3])) + assert obj_4_name.startswith("Tensor_") + + +def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): + fwd_cache = ObjectSerializerForwardCache(obj_serializer) + assert fwd_cache._underlying_storage == obj_serializer + + +def test_obj_serializer_fwd_cache_saves_and_loads(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + obj = MockDataclass(foo="bar") + obj_name = fwd_cache.save(obj) + obj_loaded = fwd_cache.load(obj_name) + obj_underlying = fwd_cache._underlying_storage.load(obj_name) + assert obj_loaded == obj_underlying + assert obj_loaded.foo == "bar" + + +def test_obj_serializer_fwd_cache_respects_cache_size(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + obj_1 = MockDataclass(foo="bar") + obj_1_name = fwd_cache.save(obj_1) + obj_2 = MockDataclass(foo="baz") + obj_2_name = fwd_cache.save(obj_2) + obj_3 = MockDataclass(foo="qux") + obj_3_name = fwd_cache.save(obj_3) + assert obj_1_name not in fwd_cache._cache + assert obj_2_name in fwd_cache._cache + assert obj_3_name in fwd_cache._cache + # apparently qsize is "not reliable"? + assert fwd_cache._cache_ids.qsize() == 2 + + +def test_obj_serializer_fwd_cache_calls_delete_callback(fwd_cache: ObjectSerializerForwardCache[MockDataclass]): + called_name = None + obj_1 = MockDataclass(foo="bar") + + def on_deleted(name: str): + nonlocal called_name + called_name = name + + fwd_cache.on_deleted(on_deleted) + obj_1_name = fwd_cache.save(obj_1) + fwd_cache.delete(obj_1_name) + assert called_name == obj_1_name From 9b0dc8a36ba0c8b202b38971b84904c4fb99b734 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 00:23:47 +1100 Subject: [PATCH 047/340] tidy(nodes): minor spelling correction --- invokeai/app/services/shared/invocation_context.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 8c5a821fd0f..828d3d84904 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -223,8 +223,8 @@ def save(self, tensor: Tensor) -> str: :param tensor: The tensor to save. """ - tensor_id = self._services.tensors.save(obj=tensor) - return tensor_id + name = self._services.tensors.save(obj=tensor) + return name def load(self, name: str) -> Tensor: """ @@ -243,8 +243,8 @@ def save(self, conditioning_data: ConditioningFieldData) -> str: :param conditioning_context_data: The conditioning data to save. 
""" - conditioning_id = self._services.conditioning.save(obj=conditioning_data) - return conditioning_id + name = self._services.conditioning.save(obj=conditioning_data) + return name def load(self, name: str) -> ConditioningFieldData: """ From 640b7f0ef494d8d8c263ea5fc37e1ddea91253ab Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 00:36:53 +1100 Subject: [PATCH 048/340] tests: fix broken tests --- .../object_serializer_ephemeral_disk.py | 9 ++++++--- .../object_serializer_forward_cache.py | 10 ++++++---- tests/aa_nodes/test_graph_execution_state.py | 5 +++-- tests/aa_nodes/test_invoker.py | 3 ++- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py index 9545d1714d7..880848a1425 100644 --- a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py @@ -1,15 +1,18 @@ import typing from dataclasses import dataclass from pathlib import Path -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar import torch -from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError from invokeai.app.util.misc import uuid_string +if TYPE_CHECKING: + from invokeai.app.services.invoker import Invoker + + T = TypeVar("T") @@ -31,7 +34,7 @@ def __init__(self, output_dir: Path): self._output_dir.mkdir(parents=True, exist_ok=True) self.__obj_class_name: Optional[str] = None - def start(self, invoker: Invoker) -> None: + def start(self, invoker: "Invoker") -> None: self._invoker = invoker delete_all_result = self._delete_all() if delete_all_result.deleted_count > 0: diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py index 2a4ecdd844b..c8ca13982c1 100644 --- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -1,11 +1,13 @@ from queue import Queue -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar -from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase T = TypeVar("T") +if TYPE_CHECKING: + from invokeai.app.services.invoker import Invoker + class ObjectSerializerForwardCache(ObjectSerializerBase[T]): """Provides a simple forward cache for an underlying storage. 
The cache is LRU and has a maximum size.""" @@ -17,13 +19,13 @@ def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: self._cache_ids = Queue[str]() self._max_cache_size = max_cache_size - def start(self, invoker: Invoker) -> None: + def start(self, invoker: "Invoker") -> None: self._invoker = invoker start_op = getattr(self._underlying_storage, "start", None) if callable(start_op): start_op(invoker) - def stop(self, invoker: Invoker) -> None: + def stop(self, invoker: "Invoker") -> None: self._invoker = invoker stop_op = getattr(self._underlying_storage, "stop", None) if callable(stop_op): diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index aba7c5694f3..27d2d2230a3 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -60,7 +60,6 @@ def mock_services() -> InvocationServices: image_records=None, # type: ignore images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), - latents=None, # type: ignore logger=logging, # type: ignore model_manager=None, # type: ignore model_records=None, # type: ignore @@ -74,6 +73,8 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore + tensors=None, + conditioning=None, ) @@ -89,7 +90,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B config=None, context_data=None, images=None, - latents=None, + tensors=None, logger=None, models=None, util=None, diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index 2ae4eab58a0..437ea0f00d3 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -63,7 +63,6 @@ def mock_services() -> InvocationServices: image_records=None, # type: ignore images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), - latents=None, # type: ignore logger=logging, # type: ignore model_manager=None, # type: ignore model_records=None, # type: ignore @@ -77,6 +76,8 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore + tensors=None, + conditioning=None, ) From fadc88417d972a50b2c308e4a692bebc07474bc4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 07:55:36 +1100 Subject: [PATCH 049/340] feat(nodes): use LATENT_SCALE_FACTOR const in tensor output builders --- invokeai/app/invocations/noise.py | 5 +++-- invokeai/app/invocations/primitives.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 74b3d6e4cb1..4093030388b 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -5,6 +5,7 @@ from pydantic import field_validator from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField +from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX @@ -70,8 +71,8 @@ class NoiseOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": return cls( noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, + 
width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index d0f95c92d02..2a9cb8cf9bf 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -16,6 +16,7 @@ OutputField, UIComponent, ) +from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.shared.invocation_context import InvocationContext @@ -321,8 +322,8 @@ class LatentsOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": return cls( latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, + width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, ) From b7720224173d41e1a363c9b4bfb2d51fc6d7d937 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 08:14:04 +1100 Subject: [PATCH 050/340] Revert "feat(nodes): use LATENT_SCALE_FACTOR const in tensor output builders" This reverts commit ef18fc546560277302f3886e456da9a47e8edce0. --- invokeai/app/invocations/noise.py | 5 ++--- invokeai/app/invocations/primitives.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 4093030388b..74b3d6e4cb1 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -5,7 +5,6 @@ from pydantic import field_validator from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField -from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX @@ -71,8 +70,8 @@ class NoiseOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": return cls( noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * LATENT_SCALE_FACTOR, - height=latents.size()[2] * LATENT_SCALE_FACTOR, + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 2a9cb8cf9bf..d0f95c92d02 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -16,7 +16,6 @@ OutputField, UIComponent, ) -from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.shared.invocation_context import InvocationContext @@ -322,8 +321,8 @@ class LatentsOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": return cls( latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * LATENT_SCALE_FACTOR, - height=latents.size()[2] * LATENT_SCALE_FACTOR, + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, ) From df8948da966c4e09141460af6df43f4cd4fdad58 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 10:57:01 +1100 Subject: [PATCH 051/340] tidy(nodes): clarify comment --- 
invokeai/app/services/shared/invocation_context.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 828d3d84904..3d06cf92725 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -279,15 +279,8 @@ def load( :param submodel: The submodel of the model to get. """ - # During this call, the model manager emits events with model loading status. The model - # manager itself has access to the events services, but does not have access to the - # required metadata for the events. - # - # For example, it needs access to the node's ID so that the events can be associated - # with the execution of a specific node. - # - # While this is available within the node, it's tedious to need to pass it in on every - # call. We can avoid that by wrapping the method here. + # The model manager emits events as it loads the model. It needs the context data to build + # the event payloads. return self._services.model_manager.get_model( model_name, base_model, model_type, submodel, context_data=self._context_data From 49211f10d570de615700d33165319e024e7d8e01 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 11:05:33 +1100 Subject: [PATCH 052/340] fix(nodes): use `metadata`/`board_id` if provided by user, overriding `WithMetadata`/`WithBoard`-provided values --- .../app/services/shared/invocation_context.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 3d06cf92725..1ca44b78625 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -167,16 +167,19 @@ def save( **Use this only if you want to override or provide metadata manually!** """ - # If the invocation inherits metadata, use that. Else, use the metadata passed in. - metadata_ = ( - self._context_data.invocation.metadata - if isinstance(self._context_data.invocation, WithMetadata) - else metadata - ) - - # If the invocation inherits WithBoard, use that. Else, use the board_id passed in. - board_ = self._context_data.invocation.board if isinstance(self._context_data.invocation, WithBoard) else None - board_id_ = board_.board_id if board_ is not None else board_id + # If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None. + metadata_ = None + if metadata: + metadata_ = metadata + elif isinstance(self._context_data.invocation, WithMetadata): + metadata_ = self._context_data.invocation.metadata + + # If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None. + board_id_ = None + if board_id: + board_id_ = board_id + elif isinstance(self._context_data.invocation, WithBoard) and self._context_data.invocation.board: + board_id_ = self._context_data.invocation.board.board_id return self._services.images.create( image=image, From abc192503a6bf1b9b6bc775ae21aa14ce10ffc27 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 16:09:59 +1100 Subject: [PATCH 053/340] feat(nodes): make delete on startup configurable for obj serializer - The default is to not delete on startup - feels safer. 
- The two services using this class _do_ delete on startup. - The class has "ephemeral" removed from its name. - Tests & app updated for this change. --- invokeai/app/api/dependencies.py | 8 ++- ...eral_disk.py => object_serializer_disk.py} | 22 ++++--- ...disk.py => test_object_serializer_disk.py} | 64 ++++++++++++++----- 3 files changed, 67 insertions(+), 27 deletions(-) rename invokeai/app/services/object_serializer/{object_serializer_ephemeral_disk.py => object_serializer_disk.py} (77%) rename tests/{test_object_serializer_ephemeral_disk.py => test_object_serializer_disk.py} (65%) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 0c80494616f..2acb961aa7a 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -5,7 +5,7 @@ import torch from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory -from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager.metadata import ModelMetadataStore @@ -90,9 +90,11 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) - tensors = ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[torch.Tensor](output_folder / "tensors")) + tensors = ObjectSerializerForwardCache( + ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", delete_on_startup=True) + ) conditioning = ObjectSerializerForwardCache( - ObjectSerializerEphemeralDisk[ConditioningFieldData](output_folder / "conditioning") + ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", delete_on_startup=True) ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py similarity index 77% rename from invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py rename to invokeai/app/services/object_serializer/object_serializer_disk.py index 880848a1425..174ff15192d 100644 --- a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -22,26 +22,30 @@ class DeleteAllResult: freed_space_bytes: float -class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]): - """Provides a disk-backed ephemeral storage for arbitrary python objects. The storage is cleared at startup. +class ObjectSerializerDisk(ObjectSerializerBase[T]): + """Provides a disk-backed storage for arbitrary python objects. 
:param output_folder: The folder where the objects will be stored + :param delete_on_startup: If True, all objects in the output folder will be deleted on startup """ - def __init__(self, output_dir: Path): + def __init__(self, output_dir: Path, delete_on_startup: bool = False): super().__init__() self._output_dir = output_dir self._output_dir.mkdir(parents=True, exist_ok=True) + self._delete_on_startup = delete_on_startup self.__obj_class_name: Optional[str] = None def start(self, invoker: "Invoker") -> None: self._invoker = invoker - delete_all_result = self._delete_all() - if delete_all_result.deleted_count > 0: - freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) - self._invoker.services.logger.info( - f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" - ) + + if self._delete_on_startup: + delete_all_result = self._delete_all() + if delete_all_result.deleted_count > 0: + freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) + self._invoker.services.logger.info( + f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" + ) def load(self, name: str) -> T: file_path = self._get_path(name) diff --git a/tests/test_object_serializer_ephemeral_disk.py b/tests/test_object_serializer_disk.py similarity index 65% rename from tests/test_object_serializer_ephemeral_disk.py rename to tests/test_object_serializer_disk.py index fffa65304f6..5ce1e579013 100644 --- a/tests/test_object_serializer_ephemeral_disk.py +++ b/tests/test_object_serializer_disk.py @@ -1,11 +1,14 @@ from dataclasses import dataclass +from logging import Logger from pathlib import Path +from unittest.mock import Mock import pytest import torch +from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError -from invokeai.app.services.object_serializer.object_serializer_ephemeral_disk import ObjectSerializerEphemeralDisk +from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache @@ -14,22 +17,31 @@ class MockDataclass: foo: str +def count_files(path: Path): + return len(list(path.iterdir())) + + @pytest.fixture def obj_serializer(tmp_path: Path): - return ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) + return ObjectSerializerDisk[MockDataclass](tmp_path) @pytest.fixture def fwd_cache(tmp_path: Path): - return ObjectSerializerForwardCache(ObjectSerializerEphemeralDisk[MockDataclass](tmp_path), max_cache_size=2) + return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2) + + +@pytest.fixture +def mock_invoker_with_logger(): + return Mock(Invoker, services=Mock(logger=Mock(Logger))) -def test_obj_serializer_ephemeral_disk_initializes(tmp_path: Path): - obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) +def test_obj_serializer_disk_initializes(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) assert obj_serializer._output_dir == tmp_path -def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_saves(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) assert Path(obj_serializer._output_dir, 
obj_1_name).exists() @@ -39,7 +51,7 @@ def test_obj_serializer_ephemeral_disk_saves(obj_serializer: ObjectSerializerEph assert Path(obj_serializer._output_dir, obj_2_name).exists() -def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_loads(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) assert obj_serializer.load(obj_1_name).foo == "bar" @@ -52,7 +64,7 @@ def test_obj_serializer_ephemeral_disk_loads(obj_serializer: ObjectSerializerEph obj_serializer.load("nonexistent_object_name") -def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) @@ -64,7 +76,7 @@ def test_obj_serializer_ephemeral_disk_deletes(obj_serializer: ObjectSerializerE assert Path(obj_serializer._output_dir, obj_2_name).exists() -def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_disk_deletes_all(obj_serializer: ObjectSerializerDisk[MockDataclass]): obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) @@ -78,8 +90,30 @@ def test_obj_serializer_ephemeral_disk_deletes_all(obj_serializer: ObjectSeriali assert delete_all_result.deleted_count == 2 -def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): - obj_serializer = ObjectSerializerEphemeralDisk[MockDataclass](tmp_path) +def test_obj_serializer_disk_default_no_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) + assert obj_serializer._delete_on_startup is False + + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_serializer.start(mock_invoker_with_logger) + assert Path(tmp_path, obj_1_name).exists() + + +def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, delete_on_startup=True) + assert obj_serializer._delete_on_startup is True + + obj_1 = MockDataclass(foo="bar") + obj_1_name = obj_serializer.save(obj_1) + + obj_serializer.start(mock_invoker_with_logger) + assert not Path(tmp_path, obj_1_name).exists() + + +def test_obj_serializer_disk_different_types(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) @@ -88,17 +122,17 @@ def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): assert obj_1_loaded.foo == "bar" assert obj_1_name.startswith("MockDataclass_") - obj_serializer = ObjectSerializerEphemeralDisk[int](tmp_path) + obj_serializer = ObjectSerializerDisk[int](tmp_path) obj_2_name = obj_serializer.save(9001) assert obj_serializer.load(obj_2_name) == 9001 assert obj_2_name.startswith("int_") - obj_serializer = ObjectSerializerEphemeralDisk[str](tmp_path) + obj_serializer = ObjectSerializerDisk[str](tmp_path) obj_3_name = obj_serializer.save("foo") assert obj_serializer.load(obj_3_name) == "foo" assert obj_3_name.startswith("str_") - obj_serializer = ObjectSerializerEphemeralDisk[torch.Tensor](tmp_path) + obj_serializer = ObjectSerializerDisk[torch.Tensor](tmp_path) obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3])) 
obj_4_loaded = obj_serializer.load(obj_4_name) assert isinstance(obj_4_loaded, torch.Tensor) @@ -106,7 +140,7 @@ def test_obj_serializer_ephemeral_disk_different_types(tmp_path: Path): assert obj_4_name.startswith("Tensor_") -def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerEphemeralDisk[MockDataclass]): +def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerDisk[MockDataclass]): fwd_cache = ObjectSerializerForwardCache(obj_serializer) assert fwd_cache._underlying_storage == obj_serializer From e1a8998de27b6beff0ad68231ca0a208386cff74 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 10 Feb 2024 09:41:23 +1100 Subject: [PATCH 054/340] tidy(nodes): do not store unnecessarily store invoker --- .../app/services/object_serializer/object_serializer_disk.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py index 174ff15192d..b3827e16a92 100644 --- a/invokeai/app/services/object_serializer/object_serializer_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -37,13 +37,11 @@ def __init__(self, output_dir: Path, delete_on_startup: bool = False): self.__obj_class_name: Optional[str] = None def start(self, invoker: "Invoker") -> None: - self._invoker = invoker - if self._delete_on_startup: delete_all_result = self._delete_all() if delete_all_result.deleted_count > 0: freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) - self._invoker.services.logger.info( + invoker.services.logger.info( f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" ) From 7c6a7297ebc206cc4675b3a301b9211d9a485a3b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 10 Feb 2024 10:06:50 +1100 Subject: [PATCH 055/340] tidy(nodes): "latents" -> "obj" --- .../object_serializer/object_serializer_forward_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py index c8ca13982c1..812731f456a 100644 --- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -36,9 +36,9 @@ def load(self, name: str) -> T: if cache_item is not None: return cache_item - latent = self._underlying_storage.load(name) - self._set_cache(name, latent) - return latent + obj = self._underlying_storage.load(name) + self._set_cache(name, obj) + return obj def save(self, obj: T) -> str: name = self._underlying_storage.save(obj) From 46cf3c03646c5c51cc5a853df4556dde7a5a5a49 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 10 Feb 2024 10:10:48 +1100 Subject: [PATCH 056/340] chore(nodes): fix pyright ignore --- .../app/services/object_serializer/object_serializer_disk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py index b3827e16a92..06f86aa460c 100644 --- a/invokeai/app/services/object_serializer/object_serializer_disk.py +++ 
b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -66,7 +66,7 @@ def delete(self, name: str) -> None: def _obj_class_name(self) -> str: if not self.__obj_class_name: # `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason - self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportGeneralTypeIssues] + self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue] return self.__obj_class_name def _get_path(self, name: str) -> Path: From d3fc317b2c2ec08829db54b86e8b710e1ffa6702 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 10 Feb 2024 10:11:31 +1100 Subject: [PATCH 057/340] chore(nodes): update ObjectSerializerForwardCache docstring --- .../object_serializer/object_serializer_forward_cache.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py index 812731f456a..b361259a4b1 100644 --- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -10,7 +10,10 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]): - """Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size.""" + """ + Provides a LRU cache for an instance of `ObjectSerializerBase`. + Saving an object to the cache always writes through to the underlying storage. + """ def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20): super().__init__() From 72d6383bf4200b17324defebf943443ce821a190 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 10 Feb 2024 18:46:51 +1100 Subject: [PATCH 058/340] tests: test ObjectSerializerDisk class name extraction --- tests/test_object_serializer_disk.py | 29 +++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/test_object_serializer_disk.py b/tests/test_object_serializer_disk.py index 5ce1e579013..2bc7e16937e 100644 --- a/tests/test_object_serializer_disk.py +++ b/tests/test_object_serializer_disk.py @@ -113,28 +113,31 @@ def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with def test_obj_serializer_disk_different_types(tmp_path: Path): - obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) - + obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path) obj_1 = MockDataclass(foo="bar") - obj_1_name = obj_serializer.save(obj_1) - obj_1_loaded = obj_serializer.load(obj_1_name) + obj_1_name = obj_serializer_1.save(obj_1) + obj_1_loaded = obj_serializer_1.load(obj_1_name) + assert obj_serializer_1._obj_class_name == "MockDataclass" assert isinstance(obj_1_loaded, MockDataclass) assert obj_1_loaded.foo == "bar" assert obj_1_name.startswith("MockDataclass_") - obj_serializer = ObjectSerializerDisk[int](tmp_path) - obj_2_name = obj_serializer.save(9001) - assert obj_serializer.load(obj_2_name) == 9001 + obj_serializer_2 = ObjectSerializerDisk[int](tmp_path) + obj_2_name = obj_serializer_2.save(9001) + assert obj_serializer_2._obj_class_name == "int" + assert obj_serializer_2.load(obj_2_name) == 9001 assert obj_2_name.startswith("int_") - 
obj_serializer = ObjectSerializerDisk[str](tmp_path) - obj_3_name = obj_serializer.save("foo") - assert obj_serializer.load(obj_3_name) == "foo" + obj_serializer_3 = ObjectSerializerDisk[str](tmp_path) + obj_3_name = obj_serializer_3.save("foo") + assert obj_serializer_3._obj_class_name == "str" + assert obj_serializer_3.load(obj_3_name) == "foo" assert obj_3_name.startswith("str_") - obj_serializer = ObjectSerializerDisk[torch.Tensor](tmp_path) - obj_4_name = obj_serializer.save(torch.tensor([1, 2, 3])) - obj_4_loaded = obj_serializer.load(obj_4_name) + obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path) + obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3])) + obj_4_loaded = obj_serializer_4.load(obj_4_name) + assert obj_serializer_4._obj_class_name == "Tensor" assert isinstance(obj_4_loaded, torch.Tensor) assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3])) assert obj_4_name.startswith("Tensor_") From ec4ee989689ceb1c0274862f59923904334a4236 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 10 Feb 2024 19:11:28 +1100 Subject: [PATCH 059/340] feat(nodes): use TemporaryDirectory to handle ephemeral storage in ObjectSerializerDisk Replace `delete_on_startup: bool` & associated logic with `ephemeral: bool` and `TemporaryDirectory`. The temp dir is created inside of `output_dir`. For example, if `output_dir` is `invokeai/outputs/tensors/`, then the temp dir might be `invokeai/outputs/tensors/tmpvj35ht7b/`. The temp dir is cleaned up when the service is stopped, or when it is GC'd if not properly stopped. In the event of a catastrophic crash where the temp files are not cleaned up, the user can delete the tempdir themselves. This situation may not occur in normal use, but if you kill the process, python cannot clean up the temp dir itself. This includes running the app in a debugger and killing the debugger process - something I do relatively often. Tests updated. 
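As a rough, self-contained sketch of the ephemeral-directory pattern this message describes (illustrative only; the class name and method surface here are placeholders, and the real implementation is the diff below):

import tempfile
from pathlib import Path


class EphemeralDiskStorageSketch:
    """Illustrative sketch: keep files in a TemporaryDirectory nested inside a base output dir."""

    def __init__(self, output_dir: Path, ephemeral: bool = False) -> None:
        self._base_output_dir = output_dir
        self._base_output_dir.mkdir(parents=True, exist_ok=True)
        # `ignore_cleanup_errors` (Python 3.10+) avoids fatal errors during cleanup on Windows
        self._tempdir = (
            tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True)
            if ephemeral
            else None
        )
        # Ephemeral mode writes into e.g. <output_dir>/tmpXXXXXXXX/, otherwise into <output_dir> itself
        self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir

    def stop(self) -> None:
        # Called when the service shuts down cleanly
        if self._tempdir:
            self._tempdir.cleanup()

    def __del__(self) -> None:
        # Fallback for the GC'd-without-stop case; a hard kill of the process still leaves the dir behind
        if self._tempdir:
            self._tempdir.cleanup()
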
--- invokeai/app/api/dependencies.py | 4 +- .../object_serializer_disk.py | 56 ++++++++----------- tests/test_object_serializer_disk.py | 53 +++++++----------- 3 files changed, 46 insertions(+), 67 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 2acb961aa7a..0f2a92b5c8e 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -91,10 +91,10 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) tensors = ObjectSerializerForwardCache( - ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", delete_on_startup=True) + ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True) ) conditioning = ObjectSerializerForwardCache( - ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", delete_on_startup=True) + ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py index 06f86aa460c..935fec30605 100644 --- a/invokeai/app/services/object_serializer/object_serializer_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -1,3 +1,4 @@ +import tempfile import typing from dataclasses import dataclass from pathlib import Path @@ -23,28 +24,24 @@ class DeleteAllResult: class ObjectSerializerDisk(ObjectSerializerBase[T]): - """Provides a disk-backed storage for arbitrary python objects. + """Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`. 
- :param output_folder: The folder where the objects will be stored - :param delete_on_startup: If True, all objects in the output folder will be deleted on startup + :param output_dir: The folder where the serialized objects will be stored + :param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit """ - def __init__(self, output_dir: Path, delete_on_startup: bool = False): + def __init__(self, output_dir: Path, ephemeral: bool = False): super().__init__() - self._output_dir = output_dir - self._output_dir.mkdir(parents=True, exist_ok=True) - self._delete_on_startup = delete_on_startup + self._ephemeral = ephemeral + self._base_output_dir = output_dir + self._base_output_dir.mkdir(parents=True, exist_ok=True) + # Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows + self._tempdir = ( + tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None + ) + self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir self.__obj_class_name: Optional[str] = None - def start(self, invoker: "Invoker") -> None: - if self._delete_on_startup: - delete_all_result = self._delete_all() - if delete_all_result.deleted_count > 0: - freed_space_in_mb = round(delete_all_result.freed_space_bytes / 1024 / 1024, 2) - invoker.services.logger.info( - f"Deleted {delete_all_result.deleted_count} {self._obj_class_name} files (freed {freed_space_in_mb}MB)" - ) - def load(self, name: str) -> T: file_path = self._get_path(name) try: @@ -75,19 +72,14 @@ def _get_path(self, name: str) -> Path: def _new_name(self) -> str: return f"{self._obj_class_name}_{uuid_string()}" - def _delete_all(self) -> DeleteAllResult: - """ - Deletes all objects from disk. - """ - - # We could try using a temporary directory here, but they aren't cleared in the event of a crash, so we'd have - # to manually clear them on startup anyways. This is a bit simpler and more reliable. - - deleted_count = 0 - freed_space = 0 - for file in Path(self._output_dir).glob("*"): - if file.is_file(): - freed_space += file.stat().st_size - deleted_count += 1 - file.unlink() - return DeleteAllResult(deleted_count, freed_space) + def _tempdir_cleanup(self) -> None: + """Calls `cleanup` on the temporary directory, if it exists.""" + if self._tempdir: + self._tempdir.cleanup() + + def __del__(self) -> None: + # In case the service is not properly stopped, clean up the temporary directory when the class instance is GC'd. 
+ self._tempdir_cleanup() + + def stop(self, invoker: "Invoker") -> None: + self._tempdir_cleanup() diff --git a/tests/test_object_serializer_disk.py b/tests/test_object_serializer_disk.py index 2bc7e16937e..125534c5002 100644 --- a/tests/test_object_serializer_disk.py +++ b/tests/test_object_serializer_disk.py @@ -1,12 +1,10 @@ +import tempfile from dataclasses import dataclass -from logging import Logger from pathlib import Path -from unittest.mock import Mock import pytest import torch -from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache @@ -31,11 +29,6 @@ def fwd_cache(tmp_path: Path): return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2) -@pytest.fixture -def mock_invoker_with_logger(): - return Mock(Invoker, services=Mock(logger=Mock(Logger))) - - def test_obj_serializer_disk_initializes(tmp_path: Path): obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) assert obj_serializer._output_dir == tmp_path @@ -76,39 +69,33 @@ def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDa assert Path(obj_serializer._output_dir, obj_2_name).exists() -def test_obj_serializer_disk_deletes_all(obj_serializer: ObjectSerializerDisk[MockDataclass]): - obj_1 = MockDataclass(foo="bar") - obj_1_name = obj_serializer.save(obj_1) +def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory) + assert obj_serializer._base_output_dir == tmp_path + assert obj_serializer._output_dir != tmp_path + assert obj_serializer._output_dir == Path(obj_serializer._tempdir.name) - obj_2 = MockDataclass(foo="bar") - obj_2_name = obj_serializer.save(obj_2) - delete_all_result = obj_serializer._delete_all() +def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + tempdir_path = obj_serializer._output_dir + del obj_serializer + assert not tempdir_path.exists() - assert not Path(obj_serializer._output_dir, obj_1_name).exists() - assert not Path(obj_serializer._output_dir, obj_2_name).exists() - assert delete_all_result.deleted_count == 2 +def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) + tempdir_path = obj_serializer._output_dir + obj_serializer.stop(None) # pyright: ignore [reportArgumentType] + assert not tempdir_path.exists() -def test_obj_serializer_disk_default_no_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): - obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path) - assert obj_serializer._delete_on_startup is False +def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path): + obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True) obj_1 = MockDataclass(foo="bar") obj_1_name = obj_serializer.save(obj_1) - - obj_serializer.start(mock_invoker_with_logger) - assert Path(tmp_path, obj_1_name).exists() - - -def test_obj_serializer_disk_delete_on_startup(tmp_path: Path, mock_invoker_with_logger: Invoker): - obj_serializer = 
ObjectSerializerDisk[MockDataclass](tmp_path, delete_on_startup=True) - assert obj_serializer._delete_on_startup is True - - obj_1 = MockDataclass(foo="bar") - obj_1_name = obj_serializer.save(obj_1) - - obj_serializer.start(mock_invoker_with_logger) + assert Path(obj_serializer._output_dir, obj_1_name).exists() assert not Path(tmp_path, obj_1_name).exists() From 640d7b832c5092d6508cf09b14c8cb2db83a31c8 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 08:52:07 +1100 Subject: [PATCH 060/340] feat(nodes): extract LATENT_SCALE_FACTOR to constants.py --- invokeai/app/invocations/constants.py | 7 +++++++ invokeai/app/invocations/latent.py | 7 +------ invokeai/backend/tiles/tiles.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) create mode 100644 invokeai/app/invocations/constants.py diff --git a/invokeai/app/invocations/constants.py b/invokeai/app/invocations/constants.py new file mode 100644 index 00000000000..95b16f0d057 --- /dev/null +++ b/invokeai/app/invocations/constants.py @@ -0,0 +1,7 @@ +LATENT_SCALE_FACTOR = 8 +""" +HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to +be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale +factor is hard-coded to a literal '8' rather than using this constant. +The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. +""" diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 4137ab6e2f6..fedfc38402d 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -23,6 +23,7 @@ from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.fields import ( ConditioningField, DenoiseMaskField, @@ -79,12 +80,6 @@ SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] -# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to -# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale -# factor is hard-coded to a literal '8' rather than using this constant. -# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. 
-LATENT_SCALE_FACTOR = 8 - @invocation_output("scheduler_output") class SchedulerOutput(BaseInvocationOutput): diff --git a/invokeai/backend/tiles/tiles.py b/invokeai/backend/tiles/tiles.py index 3c400fc87ce..2757dadba20 100644 --- a/invokeai/backend/tiles/tiles.py +++ b/invokeai/backend/tiles/tiles.py @@ -3,7 +3,7 @@ import numpy as np -from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.backend.tiles.utils import TBLR, Tile, paste, seam_blend From e7e22acdcd2f9fa97d044cc43520f2a2f1f271d9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 08:54:01 +1100 Subject: [PATCH 061/340] feat(nodes): use LATENT_SCALE_FACTOR in primitives.py, noise.py - LatentsOutput.build - NoiseOutput.build - Noise.width, Noise.height multiple_of --- invokeai/app/invocations/noise.py | 9 +++++---- invokeai/app/invocations/primitives.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 74b3d6e4cb1..335d3df292e 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -4,6 +4,7 @@ import torch from pydantic import field_validator +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX @@ -70,8 +71,8 @@ class NoiseOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": return cls( noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, + width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, ) @@ -93,13 +94,13 @@ class NoiseInvocation(BaseInvocation): ) width: int = InputField( default=512, - multiple_of=8, + multiple_of=LATENT_SCALE_FACTOR, gt=0, description=FieldDescriptions.width, ) height: int = InputField( default=512, - multiple_of=8, + multiple_of=LATENT_SCALE_FACTOR, gt=0, description=FieldDescriptions.height, ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index d0f95c92d02..43422134829 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -4,6 +4,7 @@ import torch +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.fields import ( ColorField, ConditioningField, @@ -321,8 +322,8 @@ class LatentsOutput(BaseInvocationOutput): def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": return cls( latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, + width=latents.size()[3] * LATENT_SCALE_FACTOR, + height=latents.size()[2] * LATENT_SCALE_FACTOR, ) From 5007c7786b328b3601b0847fadbee1ff08362005 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 09:27:57 +1100 Subject: [PATCH 062/340] chore(backend): rename `ModelInfo` -> `LoadedModelInfo` We have two different classes named `ModelInfo` which might need to be used by API consumers. We need to export both but have to deal with this naming collision. 
The `ModelInfo` I've renamed here is the one that is returned when a model is loaded. It's the object least likely to be used by API consumers. --- invokeai/app/services/events/events_base.py | 10 +++++----- .../services/model_manager/model_manager_base.py | 4 ++-- .../model_manager/model_manager_default.py | 16 ++++++++-------- .../app/services/shared/invocation_context.py | 7 ++++--- invokeai/backend/__init__.py | 9 ++++++++- invokeai/backend/model_management/__init__.py | 2 +- .../backend/model_management/model_manager.py | 6 +++--- invokeai/backend/util/test_utils.py | 10 +++++----- invokeai/invocation_api/__init__.py | 4 ++-- 9 files changed, 38 insertions(+), 30 deletions(-) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index ad08ae03956..6b441efc2bf 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -11,7 +11,7 @@ SessionQueueStatus, ) from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType @@ -201,7 +201,7 @@ def emit_model_load_completed( base_model: BaseModelType, model_type: ModelType, submodel: SubModelType, - model_info: ModelInfo, + loaded_model_info: LoadedModelInfo, ) -> None: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_queue_event( @@ -215,9 +215,9 @@ def emit_model_load_completed( "base_model": base_model, "model_type": model_type, "submodel": submodel, - "hash": model_info.hash, - "location": str(model_info.location), - "precision": str(model_info.precision), + "hash": loaded_model_info.hash, + "location": str(loaded_model_info.location), + "precision": str(loaded_model_info.precision), }, ) diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index a9b53ae2242..f888c0ec973 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -14,8 +14,8 @@ from invokeai.backend.model_management import ( AddModelResult, BaseModelType, + LoadedModelInfo, MergeInterpolationMethod, - ModelInfo, ModelType, SchedulerPredictionType, SubModelType, @@ -48,7 +48,7 @@ def get_model( model_type: ModelType, submodel: Optional[SubModelType] = None, context_data: Optional[InvocationContextData] = None, - ) -> ModelInfo: + ) -> LoadedModelInfo: """Retrieve the indicated model with name and type. submodel can be used to get a part (such as the vae) of a diffusers pipeline.""" diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index b641dd3f1ed..c3712abf8e6 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -16,8 +16,8 @@ from invokeai.backend.model_management import ( AddModelResult, BaseModelType, + LoadedModelInfo, MergeInterpolationMethod, - ModelInfo, ModelManager, ModelMerger, ModelNotFoundException, @@ -98,7 +98,7 @@ def get_model( model_type: ModelType, submodel: Optional[SubModelType] = None, context_data: Optional[InvocationContextData] = None, - ) -> ModelInfo: + ) -> LoadedModelInfo: """ Retrieve the indicated model. 
submodel can be used to get a part (such as the vae) of a diffusers mode. @@ -114,7 +114,7 @@ def get_model( submodel=submodel, ) - model_info = self.mgr.get_model( + loaded_model_info = self.mgr.get_model( model_name, base_model, model_type, @@ -128,10 +128,10 @@ def get_model( base_model=base_model, model_type=model_type, submodel=submodel, - model_info=model_info, + loaded_model_info=loaded_model_info, ) - return model_info + return loaded_model_info def model_exists( self, @@ -273,7 +273,7 @@ def _emit_load_event( base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - model_info: Optional[ModelInfo] = None, + loaded_model_info: Optional[LoadedModelInfo] = None, ): if self._invoker is None: return @@ -281,7 +281,7 @@ def _emit_load_event( if self._invoker.services.queue.is_canceled(context_data.session_id): raise CanceledException() - if model_info: + if loaded_model_info: self._invoker.services.events.emit_model_load_completed( queue_id=context_data.queue_id, queue_item_id=context_data.queue_item_id, @@ -291,7 +291,7 @@ def _emit_load_event( base_model=base_model, model_type=model_type, submodel=submodel, - model_info=model_info, + loaded_model_info=loaded_model_info, ) else: self._invoker.services.events.emit_model_load_started( diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 1ca44b78625..68fb78c1430 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -13,7 +13,7 @@ from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -272,14 +272,15 @@ def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelTy def load( self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None - ) -> ModelInfo: + ) -> LoadedModelInfo: """ - Loads a model, returning its `ModelInfo` object. + Loads a model. :param model_name: The name of the model to get. :param base_model: The base model of the model to get. :param model_type: The type of the model to get. :param submodel: The submodel of the model to get. + :returns: An object representing the loaded model. """ # The model manager emits events as it loads the model. 
It needs the context data to build diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index ae9a12edbe2..54a1843d463 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -1,5 +1,12 @@ """ Initialization file for invokeai.backend """ -from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401 +from .model_management import ( # noqa: F401 + BaseModelType, + LoadedModelInfo, + ModelCache, + ModelManager, + ModelType, + SubModelType, +) from .model_management.models import SilenceWarnings # noqa: F401 diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index 03abf58eb46..d523a7a0c8d 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -3,7 +3,7 @@ Initialization file for invokeai.backend.model_management """ # This import must be first -from .model_manager import AddModelResult, ModelInfo, ModelManager, SchedulerPredictionType +from .model_manager import AddModelResult, LoadedModelInfo, ModelManager, SchedulerPredictionType from .lora import ModelPatcher, ONNXModelPatcher from .model_cache import ModelCache diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 362d8d3ff55..da74ca3fb58 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -271,7 +271,7 @@ class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are n @dataclass -class ModelInfo: +class LoadedModelInfo: context: ModelLocker name: str base_model: BaseModelType @@ -450,7 +450,7 @@ def get_model( base_model: BaseModelType, model_type: ModelType, submodel_type: Optional[SubModelType] = None, - ) -> ModelInfo: + ) -> LoadedModelInfo: """Given a model named identified in models.yaml, return an ModelInfo object describing it. :param model_name: symbolic name of the model in models.yaml @@ -508,7 +508,7 @@ def get_model( model_hash = "" # TODO: - return ModelInfo( + return LoadedModelInfo( context=model_context, name=model_name, base_model=base_model, diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py index 09b9de9e984..685603cedc6 100644 --- a/invokeai/backend/util/test_utils.py +++ b/invokeai/backend/util/test_utils.py @@ -7,7 +7,7 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.backend.install.model_install_backend import ModelInstall -from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType @@ -34,8 +34,8 @@ def install_and_load_model( base_model: BaseModelType, model_type: ModelType, submodel_type: Optional[SubModelType] = None, -) -> ModelInfo: - """Install a model if it is not already installed, then get the ModelInfo for that model. +) -> LoadedModelInfo: + """Install a model if it is not already installed, then get the LoadedModelInfo for that model. This is intended as a utility function for tests. @@ -49,9 +49,9 @@ def install_and_load_model( submodel_type (Optional[SubModelType]): The submodel type, forwarded to ModelManager.get_model(...). Returns: - ModelInfo + LoadedModelInfo """ - # If the requested model is already installed, return its ModelInfo. 
+ # If the requested model is already installed, return its LoadedModelInfo. with contextlib.suppress(ModelNotFoundException): return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type) diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index e80bc26a003..2d3ceca11e2 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -52,7 +52,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID -from invokeai.backend.model_management.model_manager import ModelInfo +from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( @@ -121,7 +121,7 @@ # invokeai.app.services.config.config_default "InvokeAIAppConfig", # invokeai.backend.model_management.model_manager - "ModelInfo", + "LoadedModelInfo", # invokeai.backend.model_management.models.base "BaseModelType", "ModelType", From ea9351c70274cde270c2198b6577968f2e76759a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 09:39:36 +1100 Subject: [PATCH 063/340] chore(nodes): export model-related objects from invocation_api --- invokeai/invocation_api/__init__.py | 31 +++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index 2d3ceca11e2..055dd12757d 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -28,6 +28,22 @@ WithMetadata, WithWorkflow, ) +from invokeai.app.invocations.model import ( + ClipField, + CLIPOutput, + LoraInfo, + LoraLoaderOutput, + LoRAModelField, + MainModelField, + ModelInfo, + ModelLoaderOutput, + SDXLLoraLoaderOutput, + UNetField, + UNetOutput, + VaeField, + VAEModelField, + VAEOutput, +) from invokeai.app.invocations.primitives import ( BooleanCollectionOutput, BooleanOutput, @@ -87,6 +103,21 @@ "UIType", "WithMetadata", "WithWorkflow", + # invokeai.app.invocations.model + "ModelInfo", + "LoraInfo", + "UNetField", + "ClipField", + "VaeField", + "MainModelField", + "LoRAModelField", + "VAEModelField", + "UNetOutput", + "VAEOutput", + "CLIPOutput", + "ModelLoaderOutput", + "LoraLoaderOutput", + "SDXLLoraLoaderOutput", # invokeai.app.invocations.primitives "BooleanCollectionOutput", "BooleanOutput", From 9cd207920b914f56a5fd997c16deceafd2352c5e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 09:43:36 +1100 Subject: [PATCH 064/340] chore(nodes): remove deprecation logic for nodes API --- .../app/services/shared/invocation_context.py | 112 ------------------ 1 file changed, 112 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 68fb78c1430..c68dc1140b2 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional -from deprecated import deprecated from 
PIL.Image import Image from torch import Tensor @@ -334,30 +333,6 @@ def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_m ) -deprecation_version = "3.7.0" -removed_version = "3.8.0" - - -def get_deprecation_reason(property_name: str, alternative: Optional[str] = None) -> str: - msg = f"{property_name} is deprecated as of v{deprecation_version}. It will be removed in v{removed_version}." - if alternative is not None: - msg += f" Use {alternative} instead." - msg += " See PLACEHOLDER_URL for details." - return msg - - -# Deprecation docstrings template. I don't think we can implement these programmatically with -# __doc__ because the IDE won't see them. - -""" -**DEPRECATED as of v3.7.0** - -PROPERTY_NAME will be removed in v3.8.0. Use ALTERNATIVE instead. See PLACEHOLDER_URL for details. - -OG_DOCSTRING -""" - - class InvocationContext: """ The `InvocationContext` provides access to various services and data for the current invocation. @@ -397,93 +372,6 @@ def __init__( self._services = services """Provides access to the full application services. This is an internal API and may change without warning.""" - @property - @deprecated(version=deprecation_version, reason=get_deprecation_reason("`context.services`")) - def services(self) -> InvocationServices: - """ - **DEPRECATED as of v3.7.0** - - `context.services` will be removed in v3.8.0. See PLACEHOLDER_URL for details. - - The invocation services. - """ - return self._services - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.graph_execution_state_id", "`context._data.session_id`"), - ) - def graph_execution_state_id(self) -> str: - """ - **DEPRECATED as of v3.7.0** - - `context.graph_execution_state_api` will be removed in v3.8.0. Use `context._data.session_id` instead. See PLACEHOLDER_URL for details. - - The ID of the session (aka graph execution state). - """ - return self._data.session_id - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.queue_id`", "`context._data.queue_id`"), - ) - def queue_id(self) -> str: - """ - **DEPRECATED as of v3.7.0** - - `context.queue_id` will be removed in v3.8.0. Use `context._data.queue_id` instead. See PLACEHOLDER_URL for details. - - The ID of the queue. - """ - return self._data.queue_id - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.queue_item_id`", "`context._data.queue_item_id`"), - ) - def queue_item_id(self) -> int: - """ - **DEPRECATED as of v3.7.0** - - `context.queue_item_id` will be removed in v3.8.0. Use `context._data.queue_item_id` instead. See PLACEHOLDER_URL for details. - - The ID of the queue item. - """ - return self._data.queue_item_id - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.queue_batch_id`", "`context._data.batch_id`"), - ) - def queue_batch_id(self) -> str: - """ - **DEPRECATED as of v3.7.0** - - `context.queue_batch_id` will be removed in v3.8.0. Use `context._data.batch_id` instead. See PLACEHOLDER_URL for details. - - The ID of the batch. - """ - return self._data.batch_id - - @property - @deprecated( - version=deprecation_version, - reason=get_deprecation_reason("`context.workflow`", "`context._data.workflow`"), - ) - def workflow(self) -> Optional[WorkflowWithoutID]: - """ - **DEPRECATED as of v3.7.0** - - `context.workflow` will be removed in v3.8.0. Use `context._data.workflow` instead. See PLACEHOLDER_URL for details. 
- - The workflow associated with this queue item, if any. - """ - return self._data.workflow - def build_invocation_context( services: InvocationServices, From 170ceb6663d5f272374433258ff96c7b58cdd0c5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 09:51:25 +1100 Subject: [PATCH 065/340] chore(nodes): "SAMPLER_NAME_VALUES" -> "SCHEDULER_NAME_VALUES" This was named inaccurately. --- invokeai/app/invocations/constants.py | 7 +++++++ invokeai/app/invocations/latent.py | 10 ++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/constants.py b/invokeai/app/invocations/constants.py index 95b16f0d057..795e7a3b604 100644 --- a/invokeai/app/invocations/constants.py +++ b/invokeai/app/invocations/constants.py @@ -1,3 +1,7 @@ +from typing import Literal + +from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP + LATENT_SCALE_FACTOR = 8 """ HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to @@ -5,3 +9,6 @@ factor is hard-coded to a literal '8' rather than using this constant. The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. """ + +SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] +"""A literal type representing the valid scheduler names.""" diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index fedfc38402d..69e3f055ca8 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -23,7 +23,7 @@ from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize -from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES from invokeai.app.invocations.fields import ( ConditioningField, DenoiseMaskField, @@ -78,12 +78,10 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device()) -SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] - @invocation_output("scheduler_output") class SchedulerOutput(BaseInvocationOutput): - scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) + scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) @invocation( @@ -96,7 +94,7 @@ class SchedulerOutput(BaseInvocationOutput): class SchedulerInvocation(BaseInvocation): """Selects a scheduler.""" - scheduler: SAMPLER_NAME_VALUES = InputField( + scheduler: SCHEDULER_NAME_VALUES = InputField( default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler, @@ -234,7 +232,7 @@ class DenoiseLatentsInvocation(BaseInvocation): description=FieldDescriptions.denoising_start, ) denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) - scheduler: SAMPLER_NAME_VALUES = InputField( + scheduler: SCHEDULER_NAME_VALUES = InputField( default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler, From 39c82c33e8aa63008aab819e67fc9e7ebbe6ec34 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 11 Feb 2024 10:06:53 +1100 Subject: [PATCH 066/340] feat(nodes): add more missing exports to invocation_api Crawled through a few custom nodes to figure out what I had missed. 
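As a rough illustration of how these exports are consumed (a hypothetical custom node, not part of this patch; the node name and behaviour are made up, but the imported names appear in the export list in the diff below):

from invokeai.invocation_api import BaseInvocation, InvocationContext, SchedulerOutput, invocation


@invocation("example_scheduler_passthrough", title="Example Scheduler Passthrough", version="1.0.0")
class ExampleSchedulerPassthroughInvocation(BaseInvocation):
    """Hypothetical community node written only against the public invocation API."""

    def invoke(self, context: InvocationContext) -> SchedulerOutput:
        # Returns a fixed scheduler name; a real node would compute something useful here
        return SchedulerOutput(scheduler="euler")
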
--- invokeai/invocation_api/__init__.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index 055dd12757d..e110b5a2db3 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -7,9 +7,11 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, + Classification, invocation, invocation_output, ) +from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES from invokeai.app.invocations.fields import ( BoardField, ColorField, @@ -28,6 +30,8 @@ WithMetadata, WithWorkflow, ) +from invokeai.app.invocations.latent import SchedulerOutput +from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOutput, MetadataOutput from invokeai.app.invocations.model import ( ClipField, CLIPOutput, @@ -68,6 +72,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID +from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState @@ -77,11 +82,14 @@ ExtraConditioningInfo, SDXLConditioningInfo, ) +from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device +from invokeai.version import __version__ __all__ = [ # invokeai.app.invocations.baseinvocation "BaseInvocation", "BaseInvocationOutput", + "Classification", "invocation", "invocation_output", # invokeai.app.services.shared.invocation_context @@ -103,6 +111,12 @@ "UIType", "WithMetadata", "WithWorkflow", + # invokeai.app.invocations.latent + "SchedulerOutput", + # invokeai.app.invocations.metadata + "MetadataItemField", + "MetadataItemOutput", + "MetadataOutput", # invokeai.app.invocations.model "ModelInfo", "LoraInfo", @@ -157,4 +171,17 @@ "BaseModelType", "ModelType", "SubModelType", + # invokeai.app.invocations.constants + "SCHEDULER_NAME_VALUES", + # invokeai.version + "__version__", + # invokeai.backend.util.devices + "choose_precision", + "choose_torch_device", + "CPU_DEVICE", + "CUDA_DEVICE", + "MPS_DEVICE", + # invokeai.app.util.misc + "SEED_MAX", + "get_random_seed", ] From c227adfa30c7a1a81f993fe396a72c0c1088c297 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:02:30 +1100 Subject: [PATCH 067/340] chore(ui): regen types --- .../frontend/web/src/services/api/schema.ts | 1892 +---------------- 1 file changed, 67 insertions(+), 1825 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 45358ed97d5..1599b310c9a 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -680,70 +680,6 @@ export type components = { */ type: "add"; }; - /** - * Adjust Image Hue Plus - * @description Adjusts the Hue of an image by rotating it in the selected color space - */ - AdjustImageHuePlusInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * 
@description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to adjust */ - image?: components["schemas"]["ImageField"]; - /** - * Space - * @description Color space in which to rotate hue by polar coords (*: non-invertible) - * @default Okhsl - * @enum {string} - */ - space?: "HSV / HSL / RGB" | "Okhsl" | "Okhsv" | "*Oklch / Oklab" | "*LCh / CIELab" | "*UPLab (w/CIELab_to_UPLab.icc)"; - /** - * Degrees - * @description Degrees by which to rotate image hue - * @default 0 - */ - degrees?: number; - /** - * Preserve Lightness - * @description Whether to preserve CIELAB lightness values - * @default false - */ - preserve_lightness?: boolean; - /** - * Ok Adaptive Gamut - * @description Higher preserves chroma at the expense of lightness (Oklab) - * @default 0.05 - */ - ok_adaptive_gamut?: number; - /** - * Ok High Precision - * @description Use more steps in computing gamut (Oklab/Okhsv/Okhsl) - * @default true - */ - ok_high_precision?: boolean; - /** - * type - * @default img_hue_adjust_plus - * @constant - */ - type: "img_hue_adjust_plus"; - }; /** * AppConfig * @description App Config Response @@ -1454,39 +1390,6 @@ export type components = { */ type: "boolean_output"; }; - /** - * BRIA AI Background Removal - * @description Uses the new Bria 1.4 model to remove backgrounds from images. - */ - BriaRemoveBackgroundInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to crop */ - image?: components["schemas"]["ImageField"]; - /** - * type - * @default bria_bg_remove - * @constant - */ - type: "bria_bg_remove"; - }; /** * CLIPOutput * @description Base class for invocations that output a CLIP field @@ -1581,282 +1484,6 @@ export type components = { /** @description Base model (usually 'Any') */ base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; }; - /** - * CMYK Color Separation - * @description Get color images from a base color and two others that subtractively mix to obtain it - */ - CMYKColorSeparationInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Width - * @description Desired image width - * @default 512 - */ - width?: number; - /** - * Height - * @description Desired image height - * @default 512 - */ - height?: number; - /** - * C Value - * @description Desired final cyan value - * @default 0 - */ - c_value?: number; - /** - * M Value - * @description Desired final magenta value - * @default 25 - */ - m_value?: number; - /** - * Y Value - * @description Desired final yellow value - * @default 28 - */ - y_value?: number; - /** - * K Value - * @description Desired final black value - * @default 76 - */ - k_value?: number; - /** - * C Split - * @description Desired cyan split point % [0..1.0] - * @default 0.5 - */ - c_split?: number; - /** - * M Split - * @description Desired magenta split point % [0..1.0] - * @default 1 - */ - m_split?: number; - /** - * Y Split - * @description Desired yellow split point % [0..1.0] - * @default 0 - */ - y_split?: number; - /** - * K Split - * @description Desired black split point % [0..1.0] - * @default 0.5 - */ - k_split?: number; - /** - * Profile - * @description CMYK Color Profile - * @default Default - * @enum {string} - */ - profile?: "Default" | "PIL"; - /** - * type - * @default cmyk_separation - * @constant - */ - type: "cmyk_separation"; - }; - /** - * CMYK Merge - * @description Merge subtractive color channels (CMYK+alpha) - */ - CMYKMergeInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The c channel */ - c_channel?: components["schemas"]["ImageField"] | null; - /** @description The m channel */ - m_channel?: components["schemas"]["ImageField"] | null; - /** @description The y channel */ - y_channel?: components["schemas"]["ImageField"] | null; - /** @description The k channel */ - k_channel?: components["schemas"]["ImageField"] | null; - /** @description The alpha channel */ - alpha_channel?: components["schemas"]["ImageField"] | null; - /** - * Profile - * @description CMYK Color Profile - * @default Default - * @enum {string} - */ - profile?: "Default" | "PIL"; - /** - * type - * @default cmyk_merge - * @constant - */ - type: "cmyk_merge"; - }; - /** - * CMYKSeparationOutput - * @description Base class for invocations that output four L-mode images (C, M, Y, K) - */ - CMYKSeparationOutput: { - /** @description Blank image of the specified color */ - color_image: components["schemas"]["ImageField"]; - /** - * Width - * @description The width of the image in pixels - */ - width: number; - /** - * Height - * @description The height of the image in pixels - */ - height: number; - /** @description Blank image of the first separated color */ - part_a: components["schemas"]["ImageField"]; - /** - * Rgb Red A - * @description R value of color part A - */ - rgb_red_a: number; - /** - * Rgb Green A - * @description G value of color part A - */ - rgb_green_a: number; - /** - * Rgb Blue A - * @description B value of color part A - */ - rgb_blue_a: number; - /** @description Blank image of the second separated color */ - part_b: components["schemas"]["ImageField"]; - /** - * Rgb Red B - * @description R value of color part B - */ - rgb_red_b: number; - /** - * Rgb Green B - * @description G value of color part B - */ - rgb_green_b: number; - /** - * Rgb Blue B - * @description B value of color part B - */ - rgb_blue_b: number; - /** - * type - * @default cmyk_separation_output - * @constant - */ - type: "cmyk_separation_output"; - }; - /** - * CMYK Split - * @description Split an image into subtractive color channels (CMYK+alpha) - */ - CMYKSplitInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to halftone */ - image?: components["schemas"]["ImageField"]; - /** - * Profile - * @description CMYK Color Profile - * @default Default - * @enum {string} - */ - profile?: "Default" | "PIL"; - /** - * type - * @default cmyk_split - * @constant - */ - type: "cmyk_split"; - }; - /** - * CMYKSplitOutput - * @description Base class for invocations that output four L-mode images (C, M, Y, K) - */ - CMYKSplitOutput: { - /** @description Grayscale image of the cyan channel */ - c_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the magenta channel */ - m_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the yellow channel */ - y_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the k channel */ - k_channel: components["schemas"]["ImageField"]; - /** @description Grayscale image of the alpha channel */ - alpha_channel: components["schemas"]["ImageField"]; - /** - * Width - * @description The width of the image in pixels - */ - width: number; - /** - * Height - * @description The height of the image in pixels - */ - height: number; - /** - * type - * @default cmyk_split_output - * @constant - */ - type: "cmyk_split_output"; - }; /** * CV2 Infill * @description Infills transparent areas of an image using OpenCV Inpainting @@ -3491,6 +3118,8 @@ export type components = { * @description Generates an openpose pose from an image using DWPose */ DWOpenposeImageProcessorInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4014,11 +3643,20 @@ export type components = { */ priority: number; }; + /** ExposedField */ + ExposedField: { + /** Nodeid */ + nodeId: string; + /** Fieldname */ + fieldName: string; + }; /** - * Equivalent Achromatic Lightness - * @description Calculate Equivalent Achromatic Lightness from image + * FaceIdentifier + * @description Outputs an image with detected face IDs printed on each face. For use with other FaceTools. */ - EquivalentAchromaticLightnessInvocation: { + FaceIdentifierInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -4038,54 +3676,12 @@ export type components = { * @default true */ use_cache?: boolean; - /** @description Image from which to get channel */ + /** @description Image to face detect */ image?: components["schemas"]["ImageField"]; /** - * type - * @default ealightness - * @constant - */ - type: "ealightness"; - }; - /** ExposedField */ - ExposedField: { - /** Nodeid */ - nodeId: string; - /** Fieldname */ - fieldName: string; - }; - /** - * FaceIdentifier - * @description Outputs an image with detected face IDs printed on each face. For use with other FaceTools. 
- */ - FaceIdentifierInvocation: { - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"] | null; - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image to face detect */ - image?: components["schemas"]["ImageField"]; - /** - * Minimum Confidence - * @description Minimum confidence for face detection (lower if detection is failing) - * @default 0.5 + * Minimum Confidence + * @description Minimum confidence for face detection (lower if detection is failing) + * @default 0.5 */ minimum_confidence?: number; /** @@ -4301,39 +3897,6 @@ export type components = { */ y: number; }; - /** - * Flatten Histogram (Grayscale) - * @description Scales the values of an L-mode image by scaling them to the full range 0..255 in equal proportions - */ - FlattenHistogramMono: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Single-channel image for which to flatten the histogram */ - image?: components["schemas"]["ImageField"]; - /** - * type - * @default flatten_histogram_mono - * @constant - */ - type: "flatten_histogram_mono"; - }; /** * Float Collection Primitive * @description A collection of float primitive values @@ -4683,7 +4246,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["ImageDilateOrErodeInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["HandDepthMeshGraphormerProcessor"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["TextToMaskClipsegInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["OffsetLatentsInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["CMYKColorSeparationInvocation"] | components["schemas"]["NoiseImage2DInvocation"] | 
components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["EquivalentAchromaticLightnessInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageBlendInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["ImageRotateInvocation"] | components["schemas"]["ShadowsHighlightsMidtonesMaskInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["BriaRemoveBackgroundInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ImageValueThresholdsInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["NoiseSpectralInvocation"] | components["schemas"]["TextMaskInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["MaskedBlendLatentsInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CMYKMergeInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["TextToMaskClipsegAdvancedInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageCompositorInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["FlattenHistogramMono"] | components["schemas"]["AdjustImageHuePlusInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["CMYKSplitInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | 
components["schemas"]["FloatInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageEnhanceInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["LatentConsistencyInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ImageOffsetInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["InfillColorInvocation"]; + [key: string]: components["schemas"]["ImageCropInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["AddInvocation"] | 
components["schemas"]["NoiseInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["CvInpaintInvocation"] | 
components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["IdealSizeInvocation"]; }; /** * Edges @@ -4720,7 +4283,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["LoraLoaderOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["CMYKSplitOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["CollectInvocationOutput"] | 
components["schemas"]["LatentsOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["CMYKSeparationOutput"] | components["schemas"]["String2Output"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["HandDepthOutput"] | components["schemas"]["ShadowsHighlightsMidtonesMasksOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["ColorCollectionOutput"]; + [key: string]: components["schemas"]["MetadataOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["String2Output"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ControlOutput"] | 
components["schemas"]["IPAdapterOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["ClipSkipInvocationOutput"]; }; /** * Errors @@ -4811,83 +4374,6 @@ export type components = { /** Detail */ detail?: components["schemas"]["ValidationError"][]; }; - /** - * Hand Depth w/ MeshGraphormer - * @description Generate hand depth maps to inpaint with using ControlNet - */ - HandDepthMeshGraphormerProcessor: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to process */ - image?: components["schemas"]["ImageField"]; - /** - * Resolution - * @description Pixel resolution for output image - * @default 512 - */ - resolution?: number; - /** - * Mask Padding - * @description Amount to pad the hand mask by - * @default 30 - */ - mask_padding?: number; - /** - * Offload - * @description Offload model after usage - * @default false - */ - offload?: boolean; - /** - * type - * @default hand_depth_mesh_graphormer_image_processor - * @constant - */ - type: "hand_depth_mesh_graphormer_image_processor"; - }; - /** - * HandDepthOutput - * @description Base class for to output Meshgraphormer results - */ - HandDepthOutput: { - /** @description Improved hands depth map */ - image: components["schemas"]["ImageField"]; - /** @description Hands area mask */ - mask: components["schemas"]["ImageField"]; - /** - * Width - * @description The width of the depth map in pixels - */ - width: number; - /** - * Height - * @description The height of the depth map in pixels - */ - height: number; - /** - * type - * @default meshgraphormer_output - * @constant - */ - type: "meshgraphormer_output"; - }; /** * HED (softedge) Processor * @description Applies HED edge detection to image @@ -5260,87 +4746,6 @@ export type components = { */ type: "ideal_size_output"; }; - /** - * Image Layer Blend - * @description Blend two images together, with optional opacity, mask, and blend modes - */ - ImageBlendInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The top image to blend */ - layer_upper?: components["schemas"]["ImageField"]; - /** - * Blend Mode - * @description Available blend modes - * @default Normal - * @enum {string} - */ - blend_mode?: "Normal" | "Lighten Only" | "Darken Only" | "Lighten Only (EAL)" | "Darken Only (EAL)" | "Hue" | "Saturation" | "Color" | "Luminosity" | "Linear Dodge (Add)" | "Subtract" | "Multiply" | "Divide" | "Screen" | "Overlay" | "Linear Burn" | "Difference" | "Hard Light" | "Soft Light" | "Vivid Light" | "Linear Light" | "Color Burn" | "Color Dodge"; - /** - * Opacity - * @description Desired opacity of the upper layer - * @default 1 - */ - opacity?: number; - /** @description Optional mask, used to restrict areas from blending */ - mask?: components["schemas"]["ImageField"] | null; - /** - * Fit To Width - * @description Scale upper layer to fit base width - * @default false - */ - fit_to_width?: boolean; - /** - * Fit To Height - * @description Scale upper layer to fit base height - * @default true - */ - fit_to_height?: boolean; - /** @description The bottom image to blend */ - layer_base?: components["schemas"]["ImageField"]; - /** - * Color Space - * @description Available color spaces for blend computations - * @default Linear RGB - * @enum {string} - */ - color_space?: "RGB" | "Linear RGB" | "HSL (RGB)" | "HSV (RGB)" | "Okhsl" | "Okhsv" | "Oklch (Oklab)" | "LCh (CIELab)"; - /** - * Adaptive Gamut - * @description Adaptive gamut clipping (0=off). Higher prioritizes chroma over lightness - * @default 0 - */ - adaptive_gamut?: number; - /** - * High Precision - * @description Use more steps in computing gamut when possible - * @default true - */ - high_precision?: boolean; - /** - * type - * @default img_blend - * @constant - */ - type: "img_blend"; - }; /** * Blur Image * @description Blurs an image @@ -5594,77 +4999,6 @@ export type components = { */ type: "image_collection_output"; }; - /** - * Image Compositor - * @description Removes backdrop from subject image then overlays subject on background image - */ - ImageCompositorInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image of the subject on a plain monochrome background */ - image_subject?: components["schemas"]["ImageField"]; - /** @description Image of a background scene */ - image_background?: components["schemas"]["ImageField"]; - /** - * Chroma Key - * @description Can be empty for corner flood select, or CSS-3 color or tuple - * @default - */ - chroma_key?: string; - /** - * Threshold - * @description Subject isolation flood-fill threshold - * @default 50 - */ - threshold?: number; - /** - * Fill X - * @description Scale base subject image to fit background width - * @default false - */ - fill_x?: boolean; - /** - * Fill Y - * @description Scale base subject image to fit background height - * @default true - */ - fill_y?: boolean; - /** - * X Offset - * @description x-offset for the subject - * @default 0 - */ - x_offset?: number; - /** - * Y Offset - * @description y-offset for the subject - * @default 0 - */ - y_offset?: number; - /** - * type - * @default img_composite - * @constant - */ - type: "img_composite"; - }; /** * Convert Image Mode * @description Converts an image to a different mode. @@ -5847,10 +5181,23 @@ export type components = { board_id?: string | null; }; /** - * Image Dilate or Erode - * @description Dilate (expand) or erode (contract) an image + * ImageField + * @description An image primitive field + */ + ImageField: { + /** + * Image Name + * @description The name of the image + */ + image_name: string; + }; + /** + * Adjust Image Hue + * @description Adjusts the Hue of an image. */ - ImageDilateOrErodeInvocation: { + ImageHueAdjustmentInvocation: { + /** @description The board to save the image to */ + board?: components["schemas"]["BoardField"] | null; /** @description Optional metadata to be saved with the image */ metadata?: components["schemas"]["MetadataField"] | null; /** @@ -5870,158 +5217,24 @@ export type components = { * @default true */ use_cache?: boolean; - /** @description The image from which to create a mask */ + /** @description The image to adjust */ image?: components["schemas"]["ImageField"]; /** - * Lightness Only - * @description If true, only applies to image lightness (CIELa*b*) - * @default false - */ - lightness_only?: boolean; - /** - * Radius W - * @description Width (in pixels) by which to dilate(expand) or erode (contract) the image - * @default 4 - */ - radius_w?: number; - /** - * Radius H - * @description Height (in pixels) by which to dilate(expand) or erode (contract) the image - * @default 4 - */ - radius_h?: number; - /** - * Mode - * @description How to operate on the image - * @default Dilate - * @enum {string} + * Hue + * @description The degrees by which to rotate the hue, 0-360 + * @default 0 */ - mode?: "Dilate" | "Erode"; + hue?: number; /** * type - * @default img_dilate_erode + * @default img_hue_adjust * @constant */ - type: "img_dilate_erode"; + type: "img_hue_adjust"; }; /** - * Enhance Image - * @description Applies processing from PIL's ImageEnhance module. - */ - ImageEnhanceInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. 
- */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image for which to apply processing */ - image?: components["schemas"]["ImageField"]; - /** - * Invert - * @description Whether to invert the image colors - * @default false - */ - invert?: boolean; - /** - * Color - * @description Color enhancement factor - * @default 1 - */ - color?: number; - /** - * Contrast - * @description Contrast enhancement factor - * @default 1 - */ - contrast?: number; - /** - * Brightness - * @description Brightness enhancement factor - * @default 1 - */ - brightness?: number; - /** - * Sharpness - * @description Sharpness enhancement factor - * @default 1 - */ - sharpness?: number; - /** - * type - * @default img_enhance - * @constant - */ - type: "img_enhance"; - }; - /** - * ImageField - * @description An image primitive field - */ - ImageField: { - /** - * Image Name - * @description The name of the image - */ - image_name: string; - }; - /** - * Adjust Image Hue - * @description Adjusts the Hue of an image. - */ - ImageHueAdjustmentInvocation: { - /** @description The board to save the image to */ - board?: components["schemas"]["BoardField"] | null; - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image to adjust */ - image?: components["schemas"]["ImageField"]; - /** - * Hue - * @description The degrees by which to rotate the hue, 0-360 - * @default 0 - */ - hue?: number; - /** - * type - * @default img_hue_adjust - * @constant - */ - type: "img_hue_adjust"; - }; - /** - * Inverse Lerp Image - * @description Inverse linear interpolation of all pixels of an image + * Inverse Lerp Image + * @description Inverse linear interpolation of all pixels of an image */ ImageInverseLerpInvocation: { /** @description The board to save the image to */ @@ -6216,57 +5429,6 @@ export type components = { */ type: "img_nsfw"; }; - /** - * Offset Image - * @description Offsets an image by a given percentage (or pixel amount). - */ - ImageOffsetInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * As Pixels - * @description Interpret offsets as pixels rather than percentages - * @default false - */ - as_pixels?: boolean; - /** @description Image to be offset */ - image?: components["schemas"]["ImageField"]; - /** - * X Offset - * @description x-offset for the subject - * @default 0.5 - */ - x_offset?: number; - /** - * Y Offset - * @description y-offset for the subject - * @default 0.5 - */ - y_offset?: number; - /** - * type - * @default offset_image - * @constant - */ - type: "offset_image"; - }; /** * ImageOutput * @description Base class for nodes that output a single image @@ -6432,63 +5594,6 @@ export type components = { */ type: "img_resize"; }; - /** - * Rotate/Flip Image - * @description Rotates an image by a given angle (in degrees clockwise). - */ - ImageRotateInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image to be rotated clockwise */ - image?: components["schemas"]["ImageField"]; - /** - * Degrees - * @description Angle (in degrees clockwise) by which to rotate - * @default 90 - */ - degrees?: number; - /** - * Expand To Fit - * @description If true, extends the image boundary to fit the rotated content - * @default true - */ - expand_to_fit?: boolean; - /** - * Flip Horizontal - * @description If true, flips the image horizontally - * @default false - */ - flip_horizontal?: boolean; - /** - * Flip Vertical - * @description If true, flips the image vertically - * @default false - */ - flip_vertical?: boolean; - /** - * type - * @default rotate_image - * @constant - */ - type: "rotate_image"; - }; /** * Scale Image * @description Scales an image by a factor @@ -6603,69 +5708,6 @@ export type components = { */ thumbnail_url: string; }; - /** - * Image Value Thresholds - * @description Clip image to pure black/white past specified thresholds - */ - ImageValueThresholdsInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image from which to create a mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Make light areas dark and vice versa - * @default false - */ - invert_output?: boolean; - /** - * Renormalize Values - * @description Rescale remaining values from minimum to maximum - * @default false - */ - renormalize_values?: boolean; - /** - * Lightness Only - * @description If true, only applies to image lightness (CIELa*b*) - * @default false - */ - lightness_only?: boolean; - /** - * Threshold Upper - * @description Threshold above which will be set to full value - * @default 0.5 - */ - threshold_upper?: number; - /** - * Threshold Lower - * @description Threshold below which will be set to minimum value - * @default 0.5 - */ - threshold_lower?: number; - /** - * type - * @default img_val_thresholds - * @constant - */ - type: "img_val_thresholds"; - }; /** * Add Invisible Watermark * @description Add an invisible watermark to an image @@ -8025,47 +7067,6 @@ export type components = { */ type: "tomask"; }; - /** - * Blend Latents/Noise (Masked) - * @description Blend two latents using a given alpha and mask. Latents must have same size. - */ - MaskedBlendLatentsInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Latents tensor */ - latents_a?: components["schemas"]["LatentsField"]; - /** @description Latents tensor */ - latents_b?: components["schemas"]["LatentsField"]; - /** @description Mask for blending in latents B */ - mask?: components["schemas"]["ImageField"]; - /** - * Alpha - * @description Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B. - * @default 0.5 - */ - alpha?: number; - /** - * type - * @default lmblend - * @constant - */ - type: "lmblend"; - }; /** * Mediapipe Face Processor * @description Applies mediapipe face processing to image @@ -8679,88 +7680,8 @@ export type components = { value: string | number; }; /** - * 2D Noise Image - * @description Creates an image of 2D Noise approximating the desired characteristics - */ - NoiseImage2DInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Noise Type - * @description Desired noise spectral characteristics - * @default White - * @enum {string} - */ - noise_type?: "White" | "Red" | "Blue" | "Green"; - /** - * Width - * @description Desired image width - * @default 512 - */ - width?: number; - /** - * Height - * @description Desired image height - * @default 512 - */ - height?: number; - /** - * Seed - * @description Seed for noise generation - * @default 0 - */ - seed?: number; - /** - * Iterations - * @description Noise approx. iterations - * @default 15 - */ - iterations?: number; - /** - * Blur Threshold - * @description Threshold used in computing noise (lower is better/slower) - * @default 0.2 - */ - blur_threshold?: number; - /** - * Sigma Red - * @description Sigma for strong gaussian blur LPF for red/green - * @default 3 - */ - sigma_red?: number; - /** - * Sigma Blue - * @description Sigma for weak gaussian blur HPF for blue/green - * @default 1 - */ - sigma_blue?: number; - /** - * type - * @default noiseimg_2d - * @constant - */ - type: "noiseimg_2d"; - }; - /** - * Noise - * @description Generates latent noise. + * Noise + * @description Generates latent noise. */ NoiseInvocation: { /** @@ -8835,84 +7756,6 @@ export type components = { */ type: "noise_output"; }; - /** - * Noise (Spectral characteristics) - * @description Creates an image of 2D Noise approximating the desired characteristics - */ - NoiseSpectralInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Noise Type - * @description Desired noise spectral characteristics - * @default White - * @enum {string} - */ - noise_type?: "White" | "Red" | "Blue" | "Green"; - /** - * Width - * @description Desired image width - * @default 512 - */ - width?: number; - /** - * Height - * @description Desired image height - * @default 512 - */ - height?: number; - /** - * Seed - * @description Seed for noise generation - * @default 0 - */ - seed?: number; - /** - * Iterations - * @description Noise approx. iterations - * @default 15 - */ - iterations?: number; - /** - * Blur Threshold - * @description Threshold used in computing noise (lower is better/slower) - * @default 0.2 - */ - blur_threshold?: number; - /** - * Sigma Red - * @description Sigma for strong gaussian blur LPF for red/green - * @default 3 - */ - sigma_red?: number; - /** - * Sigma Blue - * @description Sigma for weak gaussian blur HPF for blue/green - * @default 1 - */ - sigma_blue?: number; - /** - * type - * @default noise_spectral - * @constant - */ - type: "noise_spectral"; - }; /** * Normal BAE Processor * @description Applies NormalBae processing to image @@ -8960,107 +7803,6 @@ export type components = { */ type: "normalbae_image_processor"; }; - /** - * ONNX Latents to Image - * @description Generates an image from latents. 
- */ - ONNXLatentsToImageInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Denoised latents tensor */ - latents?: components["schemas"]["LatentsField"]; - /** @description VAE */ - vae?: components["schemas"]["VaeField"]; - /** - * type - * @default l2i_onnx - * @constant - */ - type: "l2i_onnx"; - }; - /** - * ONNXModelLoaderOutput - * @description Model loader output - */ - ONNXModelLoaderOutput: { - /** - * UNet - * @description UNet (scheduler, LoRAs) - */ - unet?: components["schemas"]["UNetField"]; - /** - * CLIP - * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count - */ - clip?: components["schemas"]["ClipField"]; - /** - * VAE Decoder - * @description VAE - */ - vae_decoder?: components["schemas"]["VaeField"]; - /** - * VAE Encoder - * @description VAE - */ - vae_encoder?: components["schemas"]["VaeField"]; - /** - * type - * @default model_loader_output_onnx - * @constant - */ - type: "model_loader_output_onnx"; - }; - /** ONNX Prompt (Raw) */ - ONNXPromptInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Prompt - * @description Raw prompt text (no parsing) - * @default - */ - prompt?: string; - /** @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count */ - clip?: components["schemas"]["ClipField"]; - /** - * type - * @default prompt_onnx - * @constant - */ - type: "prompt_onnx"; - }; /** * ONNXSD1Config * @description Model config for ONNX format models based on sd-1. @@ -9242,117 +7984,6 @@ export type components = { /** Upcast Attention */ upcast_attention: boolean; }; - /** - * ONNX Text to Latents - * @description Generates latents from conditionings. - */ - ONNXTextToLatentsInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Positive conditioning tensor */ - positive_conditioning?: components["schemas"]["ConditioningField"]; - /** @description Negative conditioning tensor */ - negative_conditioning?: components["schemas"]["ConditioningField"]; - /** @description Noise tensor */ - noise?: components["schemas"]["LatentsField"]; - /** - * Steps - * @description Number of steps to run - * @default 10 - */ - steps?: number; - /** - * Cfg Scale - * @description Classifier-Free Guidance scale - * @default 7.5 - */ - cfg_scale?: number | number[]; - /** - * Scheduler - * @description Scheduler to use during inference - * @default euler - * @enum {string} - */ - scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm"; - /** - * Precision - * @description Precision to use - * @default tensor(float16) - * @enum {string} - */ - precision?: "tensor(bool)" | "tensor(int8)" | "tensor(uint8)" | "tensor(int16)" | "tensor(uint16)" | "tensor(int32)" | "tensor(uint32)" | "tensor(int64)" | "tensor(uint64)" | "tensor(float16)" | "tensor(float)" | "tensor(double)"; - /** @description UNet (scheduler, LoRAs) */ - unet?: components["schemas"]["UNetField"]; - /** - * Control - * @description ControlNet(s) to apply - */ - control?: components["schemas"]["ControlField"] | components["schemas"]["ControlField"][]; - /** - * type - * @default t2l_onnx - * @constant - */ - type: "t2l_onnx"; - }; - /** - * Offset Latents - * @description Offsets a latents tensor by a given percentage of height/width. - */ - OffsetLatentsInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Latents tensor */ - latents?: components["schemas"]["LatentsField"]; - /** - * X Offset - * @description Approx percentage to offset (H) - * @default 0.5 - */ - x_offset?: number; - /** - * Y Offset - * @description Approx percentage to offset (V) - * @default 0.5 - */ - y_offset?: number; - /** - * type - * @default offset_latents - * @constant - */ - type: "offset_latents"; - }; /** OffsetPaginatedResults[BoardDTO] */ OffsetPaginatedResults_BoardDTO_: { /** @@ -9392,58 +8023,12 @@ export type components = { * Total * @description Total number of items in result */ - total: number; - /** - * Items - * @description Items - */ - items: components["schemas"]["ImageDTO"][]; - }; - /** - * OnnxModelField - * @description Onnx model field - */ - OnnxModelField: { - /** - * Model Name - * @description Name of the model - */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description Model Type */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; - }; - /** - * ONNX Main Model - * @description Loads a main model, outputting its submodels. - */ - OnnxModelLoaderInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description ONNX Main model (UNet, VAE, CLIP) to load */ - model: components["schemas"]["OnnxModelField"]; + total: number; /** - * type - * @default onnx_model_loader - * @constant + * Items + * @description Items */ - type: "onnx_model_loader"; + items: components["schemas"]["ImageDTO"][]; }; /** PaginatedResults[ModelSummary] */ PaginatedResults_ModelSummary_: { @@ -10852,106 +9437,6 @@ export type components = { */ total: number; }; - /** - * Shadows/Highlights/Midtones - * @description Extract a Shadows/Highlights/Midtones mask from an image - */ - ShadowsHighlightsMidtonesMaskInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description Image from which to extract mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Off: white on black / On: black on white - * @default true - */ - invert_output?: boolean; - /** - * Highlight Threshold - * @description Threshold beyond which mask values will be at extremum - * @default 0.75 - */ - highlight_threshold?: number; - /** - * Upper Mid Threshold - * @description Threshold to which to extend mask border by 0..1 gradient - * @default 0.7 - */ - upper_mid_threshold?: number; - /** - * Lower Mid Threshold - * @description Threshold to which to extend mask border by 0..1 gradient - * @default 0.3 - */ - lower_mid_threshold?: number; - /** - * Shadow Threshold - * @description Threshold beyond which mask values will be at extremum - * @default 0.25 - */ - shadow_threshold?: number; - /** - * Mask Expand Or Contract - * @description Pixels to grow (or shrink) the mask areas - * @default 0 - */ - mask_expand_or_contract?: number; - /** - * Mask Blur - * @description Gaussian blur radius to apply to the masks - * @default 0 - */ - mask_blur?: number; - /** - * type - * @default shmmask - * @constant - */ - type: "shmmask"; - }; - /** ShadowsHighlightsMidtonesMasksOutput */ - ShadowsHighlightsMidtonesMasksOutput: { - /** @description Soft-edged highlights mask */ - highlights_mask?: components["schemas"]["ImageField"]; - /** @description Soft-edged midtones mask */ - midtones_mask?: components["schemas"]["ImageField"]; - /** @description Soft-edged shadows mask */ - shadows_mask?: components["schemas"]["ImageField"]; - /** - * Width - * @description Width of the input/outputs - */ - width: number; - /** - * Height - * @description Height of the input/outputs - */ - height: number; - /** - * type - * @default shmmask_output - * @constant - */ - type: "shmmask_output"; - }; /** * Show Image * @description Displays a provided image using the OS image viewer, and passes it forward in the pipeline. @@ -11833,249 +10318,6 @@ export type components = { /** Right */ right: number; }; - /** - * Text Mask - * @description Creates a 2D rendering of a text mask from a given font - */ - TextMaskInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. 
- * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** - * Width - * @description The width of the desired mask - * @default 512 - */ - width?: number; - /** - * Height - * @description The height of the desired mask - * @default 512 - */ - height?: number; - /** - * Text - * @description The text to render - * @default - */ - text?: string; - /** - * Font - * @description Path to a FreeType-supported TTF/OTF font file - * @default - */ - font?: string; - /** - * Size - * @description Desired point size of text to use - * @default 64 - */ - size?: number; - /** - * Angle - * @description Angle of rotation to apply to the text - * @default 0 - */ - angle?: number; - /** - * X Offset - * @description x-offset for text rendering - * @default 24 - */ - x_offset?: number; - /** - * Y Offset - * @description y-offset for text rendering - * @default 36 - */ - y_offset?: number; - /** - * Invert - * @description Whether to invert color of the output - * @default false - */ - invert?: boolean; - /** - * type - * @default text_mask - * @constant - */ - type: "text_mask"; - }; - /** - * Text to Mask Advanced (Clipseg) - * @description Uses the Clipseg model to generate an image mask from a text prompt - */ - TextToMaskClipsegAdvancedInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image from which to create a mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Off: white on black / On: black on white - * @default true - */ - invert_output?: boolean; - /** - * Prompt 1 - * @description First prompt with which to create a mask - */ - prompt_1?: string; - /** - * Prompt 2 - * @description Second prompt with which to create a mask (optional) - */ - prompt_2?: string; - /** - * Prompt 3 - * @description Third prompt with which to create a mask (optional) - */ - prompt_3?: string; - /** - * Prompt 4 - * @description Fourth prompt with which to create a mask (optional) - */ - prompt_4?: string; - /** - * Combine - * @description How to combine the results - * @default or - * @enum {string} - */ - combine?: "or" | "and" | "none (rgba multiplex)"; - /** - * Smoothing - * @description Radius of blur to apply before thresholding - * @default 4 - */ - smoothing?: number; - /** - * Subject Threshold - * @description Threshold above which is considered the subject - * @default 1 - */ - subject_threshold?: number; - /** - * Background Threshold - * @description Threshold below which is considered the background - * @default 0 - */ - background_threshold?: number; - /** - * type - * @default txt2mask_clipseg_adv - * @constant - */ - type: "txt2mask_clipseg_adv"; - }; - /** - * Text to Mask (Clipseg) - * @description Uses the Clipseg model to generate an image mask from a text prompt - */ - TextToMaskClipsegInvocation: { - /** @description Optional metadata to be saved with the image */ - metadata?: components["schemas"]["MetadataField"] | null; - 
/** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The image from which to create a mask */ - image?: components["schemas"]["ImageField"]; - /** - * Invert Output - * @description Off: white on black / On: black on white - * @default true - */ - invert_output?: boolean; - /** - * Prompt - * @description The prompt with which to create a mask - */ - prompt?: string; - /** - * Smoothing - * @description Radius of blur to apply before thresholding - * @default 4 - */ - smoothing?: number; - /** - * Subject Threshold - * @description Threshold above which is considered the subject - * @default 0.4 - */ - subject_threshold?: number; - /** - * Background Threshold - * @description Threshold below which is considered the background - * @default 0.4 - */ - background_threshold?: number; - /** - * Mask Expand Or Contract - * @description Pixels by which to grow (or shrink) mask after thresholding - * @default 0 - */ - mask_expand_or_contract?: number; - /** - * Mask Blur - * @description Radius of blur to apply after thresholding - * @default 0 - */ - mask_blur?: number; - /** - * type - * @default txt2mask_clipseg - * @constant - */ - type: "txt2mask_clipseg"; - }; /** * TextualInversionConfig * @description Model config for textual inversion embeddings. @@ -13067,53 +11309,53 @@ export type components = { */ UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; /** - * T2IAdapterModelFormat + * StableDiffusionXLModelFormat * @description An enumeration. * @enum {string} */ - T2IAdapterModelFormat: "diffusers"; + StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; /** - * IPAdapterModelFormat + * T2IAdapterModelFormat * @description An enumeration. 
* @enum {string} */ - IPAdapterModelFormat: "invokeai"; + T2IAdapterModelFormat: "diffusers"; /** - * StableDiffusionXLModelFormat + * CLIPVisionModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; + CLIPVisionModelFormat: "diffusers"; /** - * StableDiffusionOnnxModelFormat + * ControlNetModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; + ControlNetModelFormat: "checkpoint" | "diffusers"; /** - * CLIPVisionModelFormat + * StableDiffusion1ModelFormat * @description An enumeration. * @enum {string} */ - CLIPVisionModelFormat: "diffusers"; + StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusion2ModelFormat + * IPAdapterModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + IPAdapterModelFormat: "invokeai"; /** - * StableDiffusion1ModelFormat + * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** - * ControlNetModelFormat + * StableDiffusionOnnxModelFormat * @description An enumeration. * @enum {string} */ - ControlNetModelFormat: "checkpoint" | "diffusers"; + StableDiffusionOnnxModelFormat: "olive" | "onnx"; }; responses: never; parameters: never;

From 3c47e6f025d5da8b88050e89c55cf49dcf34b002 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 13 Feb 2024 16:30:00 +1100
Subject: [PATCH 068/340] feat(ui): workflow schema v3 (WIP)

The changes aim to deduplicate data between workflows and node templates, decoupling workflows from internal implementation details. A good amount of data that was needlessly duplicated from the node template to the workflow is removed.

These changes substantially reduce the file size of workflows (and therefore the images with embedded workflows):
- Default T2I SD1.5 workflow JSON is reduced from 23.7kb (798 lines) to 10.9kb (407 lines).
- Default tiled upscale workflow JSON is reduced from 102.7kb (3341 lines) to 51.9kb (1774 lines).

The trade-off is that we need to reference node templates to get things like the field type. In practice, this is a non-issue, because we need a node template to do anything with a node anyway.

- Field types are not included in the workflow. They are always pulled from the node templates. The field type is now properly an internal implementation detail that we are free to change as needed; previously, changing it would have required a migration for the workflow itself.
- Workflow nodes no longer have an `outputs` property and there is no longer such a thing as a `FieldOutputInstance`. These exist only on the templates. They were never referenced at a time when we didn't also have the templates available, and there'd be no reason to do so.
- Node width and height are no longer stored in the node. These weren't used. Also, per https://reactflow.dev/api-reference/types/node, we shouldn't be programmatically changing these properties. A future enhancement can properly add node resizing.
- The `nodeTemplates` slice is merged back into `nodesSlice` as `nodes.templates`. Turns out it's just a hassle having them in separate slices. A rough sketch of the resulting template-first lookup follows below.
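To make the above concrete, here is a minimal sketch of the v3 shape and the template-first lookup it enables. The types are simplified stand-ins, not the actual definitions in `features/nodes/types`; only `selectNodeTemplate` and `selectFieldOutputTemplate` correspond to helpers added in `store/selectors.ts`, and even those are shown in reduced form.

```ts
// Illustrative only: simplified stand-ins for the real types.
type FieldType = { name: string; isCollectionOrScalar: boolean };
type FieldInputTemplate = { title: string; type: FieldType; input: 'connection' | 'direct' | 'any' };
type FieldOutputTemplate = { title: string; type: FieldType };

type InvocationTemplate = {
  type: string;
  inputs: Record<string, FieldInputTemplate>;
  outputs: Record<string, FieldOutputTemplate>;
};

// v3 node data stores only input values and labels - no `outputs`, no field types, no width/height.
type InvocationNodeData = {
  id: string;
  type: string;
  inputs: Record<string, { name: string; label: string; value?: unknown }>;
};

// Templates now live alongside the nodes (formerly the separate `nodeTemplates` slice).
type NodesState = {
  nodes: { id: string; data: InvocationNodeData }[];
  templates: Record<string, InvocationTemplate>;
};

// Template-first lookups, in the spirit of the new store/selectors.ts helpers.
const selectNodeTemplate = (nodes: NodesState, nodeId: string): InvocationTemplate | null => {
  const node = nodes.nodes.find((n) => n.id === nodeId);
  return node ? (nodes.templates[node.data.type] ?? null) : null;
};

const selectFieldOutputTemplate = (
  nodes: NodesState,
  nodeId: string,
  fieldName: string
): FieldOutputTemplate | null => selectNodeTemplate(nodes, nodeId)?.outputs[fieldName] ?? null;
```

Anything that previously read a `FieldOutputInstance` off the node (e.g. edge coloring) can instead derive the same information from `selectFieldOutputTemplate(nodes, sourceNodeId, sourceHandleId)?.type`.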
- Workflow migration logic updated to support the new schema. V1 workflows migrate all the way to v3 now. - Changes throughout the nodes code to accommodate the above changes. --- .../middleware/devtools/actionSanitizer.ts | 2 +- .../listeners/getOpenAPISchema.ts | 2 +- .../listeners/updateAllNodesRequested.ts | 3 +- .../listeners/workflowLoadRequested.ts | 2 +- invokeai/frontend/web/src/app/store/store.ts | 2 - .../frontend/web/src/app/store/storeHooks.ts | 3 +- invokeai/frontend/web/src/app/store/util.ts | 2 + .../src/common/hooks/useIsReadyToEnqueue.ts | 6 +- .../flow/AddNodePopover/AddNodePopover.tsx | 14 +- .../flow/edges/util/makeEdgeSelector.ts | 18 +- .../InvocationNodeCollapsedHandles.tsx | 19 +- .../Invocation/InvocationNodeWrapper.tsx | 4 +- .../Invocation/fields/EditableFieldTitle.tsx | 4 +- .../nodes/Invocation/fields/FieldTitle.tsx | 2 +- .../Invocation/fields/FieldTooltipContent.tsx | 6 +- .../nodes/Invocation/fields/InputField.tsx | 6 +- .../Invocation/fields/InputFieldRenderer.tsx | 29 +- .../Invocation/fields/LinearViewField.tsx | 4 +- .../nodes/Invocation/fields/OutputField.tsx | 10 +- .../inspector/InspectorDetailsTab.tsx | 5 +- .../inspector/InspectorOutputsTab.tsx | 5 +- .../inspector/InspectorTemplateTab.tsx | 5 +- .../hooks/useAnyOrDirectInputFieldNames.ts | 20 +- .../src/features/nodes/hooks/useBuildNode.ts | 2 +- .../hooks/useConnectionInputFieldNames.ts | 20 +- .../nodes/hooks/useConnectionState.ts | 10 +- .../nodes/hooks/useDoNodeVersionsMatch.ts | 18 +- .../nodes/hooks/useDoesInputHaveValue.ts | 12 +- .../src/features/nodes/hooks/useFieldData.ts | 23 - .../nodes/hooks/useFieldInputInstance.ts | 15 +- .../features/nodes/hooks/useFieldInputKind.ts | 15 +- .../nodes/hooks/useFieldInputTemplate.ts | 15 +- .../src/features/nodes/hooks/useFieldLabel.ts | 10 +- .../nodes/hooks/useFieldOutputInstance.ts | 23 - .../nodes/hooks/useFieldOutputTemplate.ts | 15 +- .../features/nodes/hooks/useFieldTemplate.ts | 21 +- .../nodes/hooks/useFieldTemplateTitle.ts | 16 +- .../features/nodes/hooks/useFieldType.ts.ts | 14 +- .../nodes/hooks/useGetNodesNeedUpdate.ts | 5 +- .../features/nodes/hooks/useHasImageOutput.ts | 13 +- .../features/nodes/hooks/useIsIntermediate.ts | 10 +- .../nodes/hooks/useIsValidConnection.ts | 38 +- .../nodes/hooks/useNodeClassification.ts | 17 +- .../src/features/nodes/hooks/useNodeData.ts | 7 +- .../src/features/nodes/hooks/useNodeLabel.ts | 9 +- .../nodes/hooks/useNodeNeedsUpdate.ts | 15 +- .../src/features/nodes/hooks/useNodePack.ts | 10 +- .../features/nodes/hooks/useNodeTemplate.ts | 13 +- .../nodes/hooks/useNodeTemplateByType.ts | 10 +- .../nodes/hooks/useNodeTemplateTitle.ts | 15 +- .../nodes/hooks/useOutputFieldNames.ts | 20 +- .../src/features/nodes/hooks/useUseCache.ts | 8 +- .../nodes/hooks/useWorkflowWatcher.ts | 4 +- .../web/src/features/nodes/store/actions.ts | 4 +- .../nodes/store/nodeTemplatesSlice.ts | 24 - .../src/features/nodes/store/nodesSlice.ts | 15 +- .../web/src/features/nodes/store/selectors.ts | 51 + .../web/src/features/nodes/store/types.ts | 5 +- .../store/util/findConnectionToValidHandle.ts | 30 +- .../util/makeIsConnectionValidSelector.ts | 2 +- .../src/features/nodes/store/workflowSlice.ts | 6 +- .../web/src/features/nodes/types/field.ts | 130 +-- .../src/features/nodes/types/invocation.ts | 22 +- .../web/src/features/nodes/types/v2/common.ts | 188 ++++ .../src/features/nodes/types/v2/constants.ts | 80 ++ .../web/src/features/nodes/types/v2/error.ts | 58 ++ .../web/src/features/nodes/types/v2/field.ts | 875 ++++++++++++++++++ 
.../src/features/nodes/types/v2/invocation.ts | 93 ++ .../src/features/nodes/types/v2/metadata.ts | 77 ++ .../src/features/nodes/types/v2/openapi.ts | 86 ++ .../web/src/features/nodes/types/v2/semver.ts | 21 + .../src/features/nodes/types/v2/workflow.ts | 89 ++ .../web/src/features/nodes/types/workflow.ts | 10 +- .../nodes/util/node/buildInvocationNode.ts | 22 +- .../features/nodes/util/node/nodeUpdate.ts | 1 - .../util/schema/buildFieldInputInstance.ts | 3 - .../nodes/util/workflow/buildWorkflow.ts | 20 +- .../nodes/util/workflow/migrations.ts | 32 +- .../nodes/util/workflow/validateWorkflow.ts | 4 +- .../workflowLibrary/hooks/useSaveWorkflow.ts | 4 +- 80 files changed, 1936 insertions(+), 612 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/util.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/selectors.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/common.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/constants.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/error.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/field.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/semver.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts index 2e2d2014b23..ed8c82d91ca 100644 --- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts +++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts @@ -1,6 +1,6 @@ import type { UnknownAction } from '@reduxjs/toolkit'; import { isAnyGraphBuilt } from 'features/nodes/store/actions'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice'; +import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { cloneDeep } from 'lodash-es'; import { appInfoApi } from 'services/api/endpoints/appInfo'; import type { Graph } from 'services/api/types'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts index b2d36159098..88518e2c0bb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts @@ -1,6 +1,6 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice'; +import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { size } from 'lodash-es'; import { appInfoApi } from 'services/api/endpoints/appInfo'; diff --git 
a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts index 752c3b09df2..ac1298da5ba 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts @@ -15,8 +15,7 @@ export const addUpdateAllNodesRequestedListener = () => { actionCreator: updateAllNodesRequested, effect: (action, { dispatch, getState }) => { const log = logger('nodes'); - const nodes = getState().nodes.nodes; - const templates = getState().nodeTemplates.templates; + const { nodes, templates } = getState().nodes; let unableToUpdateCount = 0; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts index 9307031e6d0..ad41dc2654f 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts @@ -18,7 +18,7 @@ export const addWorkflowLoadRequestedListener = () => { effect: (action, { dispatch, getState }) => { const log = logger('nodes'); const { workflow, asCopy } = action.payload; - const nodeTemplates = getState().nodeTemplates.templates; + const nodeTemplates = getState().nodes.templates; try { const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates); diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index e25e1351eb9..270662c3d21 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -16,7 +16,6 @@ import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice'; import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice'; import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice'; import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice'; -import { nodesTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice'; import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice'; import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice'; @@ -46,7 +45,6 @@ const allReducers = { [gallerySlice.name]: gallerySlice.reducer, [generationSlice.name]: generationSlice.reducer, [nodesSlice.name]: nodesSlice.reducer, - [nodesTemplatesSlice.name]: nodesTemplatesSlice.reducer, [postprocessingSlice.name]: postprocessingSlice.reducer, [systemSlice.name]: systemSlice.reducer, [configSlice.name]: configSlice.reducer, diff --git a/invokeai/frontend/web/src/app/store/storeHooks.ts b/invokeai/frontend/web/src/app/store/storeHooks.ts index f1a9aa979c0..6bc904acb31 100644 --- a/invokeai/frontend/web/src/app/store/storeHooks.ts +++ b/invokeai/frontend/web/src/app/store/storeHooks.ts @@ -1,7 +1,8 @@ import type { AppThunkDispatch, RootState } from 'app/store/store'; import type { TypedUseSelectorHook } from 'react-redux'; -import { useDispatch, useSelector } from 'react-redux'; +import { useDispatch, 
useSelector, useStore } from 'react-redux'; // Use throughout your app instead of plain `useDispatch` and `useSelector` export const useAppDispatch = () => useDispatch(); export const useAppSelector: TypedUseSelectorHook = useSelector; +export const useAppStore = () => useStore(); diff --git a/invokeai/frontend/web/src/app/store/util.ts b/invokeai/frontend/web/src/app/store/util.ts new file mode 100644 index 00000000000..381f7f85d26 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/util.ts @@ -0,0 +1,2 @@ +export const EMPTY_ARRAY = []; +export const EMPTY_OBJECT = {}; diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index 4952fa1c47b..baa704e75ca 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -8,7 +8,6 @@ import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { selectSystemSlice } from 'features/system/store/systemSlice'; @@ -23,11 +22,10 @@ const selector = createMemoizedSelector( selectGenerationSlice, selectSystemSlice, selectNodesSlice, - selectNodeTemplatesSlice, selectDynamicPromptsSlice, activeTabNameSelector, ], - (controlAdapters, generation, system, nodes, nodeTemplates, dynamicPrompts, activeTabName) => { + (controlAdapters, generation, system, nodes, dynamicPrompts, activeTabName) => { const { initialImage, model, positivePrompt } = generation; const { isConnected } = system; @@ -54,7 +52,7 @@ const selector = createMemoizedSelector( return; } - const nodeTemplate = nodeTemplates.templates[node.data.type]; + const nodeTemplate = nodes.templates[node.data.type]; if (!nodeTemplate) { // Node type not found diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index b24b52c6abf..061209cafc0 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -7,8 +7,12 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import type { SelectInstance } from 'chakra-react-select'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; -import { addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; +import { + addNodePopoverClosed, + addNodePopoverOpened, + nodeAdded, + selectNodesSlice, +} from 'features/nodes/store/nodesSlice'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; import { filter, map, memoize, some } from 'lodash-es'; import type { KeyboardEventHandler } from 'react'; @@ -54,10 +58,10 @@ const 
AddNodePopover = () => { const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType); const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType); - const selector = createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates) => { + const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { // If we have a connection in progress, we need to filter the node choices const filteredNodeTemplates = fieldFilter - ? filter(nodeTemplates.templates, (template) => { + ? filter(nodes.templates, (template) => { const handles = handleFilter === 'source' ? template.inputs : template.outputs; return some(handles, (handle) => { @@ -67,7 +71,7 @@ const AddNodePopover = () => { return validateSourceAndTargetTypes(sourceType, targetType); }); }) - : map(nodeTemplates.templates); + : map(nodes.templates); const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => { return { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts index 4bfc588e675..ba40b4984cd 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -1,10 +1,17 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { selectFieldOutputTemplate } from 'features/nodes/store/selectors'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getFieldColor } from './getEdgeColor'; +const defaultReturnValue = { + isSelected: false, + shouldAnimate: false, + stroke: colorTokenToCssVar('base.500'), +}; + export const makeEdgeSelector = ( source: string, sourceHandleId: string | null | undefined, @@ -12,14 +19,19 @@ export const makeEdgeSelector = ( targetHandleId: string | null | undefined, selected?: boolean ) => - createMemoizedSelector(selectNodesSlice, (nodes) => { + createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => { const sourceNode = nodes.nodes.find((node) => node.id === source); const targetNode = nodes.nodes.find((node) => node.id === target); const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); - const isSelected = sourceNode?.selected || targetNode?.selected || selected; - const sourceType = isInvocationToInvocationEdge ? sourceNode?.data?.outputs[sourceHandleId || '']?.type : undefined; + const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected); + if (!sourceNode || !sourceHandleId) { + return defaultReturnValue; + } + + const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId); + const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; const stroke = sourceType && nodes.shouldColorEdges ? 
getFieldColor(sourceType) : colorTokenToCssVar('base.500'); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx index c287842f6ed..b888e8a5162 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx @@ -1,6 +1,5 @@ import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; -import { useNodeData } from 'features/nodes/hooks/useNodeData'; -import { isInvocationNodeData } from 'features/nodes/types/invocation'; +import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; import { map } from 'lodash-es'; import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; @@ -13,7 +12,7 @@ interface Props { const hiddenHandleStyles: CSSProperties = { visibility: 'hidden' }; const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { - const data = useNodeData(nodeId); + const template = useNodeTemplate(nodeId); const { base600 } = useChakraThemeTokens(); const dummyHandleStyles: CSSProperties = useMemo( @@ -37,7 +36,7 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { [dummyHandleStyles] ); - if (!isInvocationNodeData(data)) { + if (!template) { return null; } @@ -45,14 +44,14 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { <> - {map(data.inputs, (input) => ( + {map(template.inputs, (input) => ( { ))} - {map(data.outputs, (output) => ( + {map(template.outputs, (output) => ( ) => { const { id: nodeId, type, isOpen, label } = data; const hasTemplateSelector = useMemo( - () => createSelector(selectNodeTemplatesSlice, (nodeTemplates) => Boolean(nodeTemplates.templates[type])), + () => createSelector(selectNodesSlice, (nodes) => Boolean(nodes.templates[type])), [type] ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx index c2231f703ab..e02b1a1474e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx @@ -22,7 +22,7 @@ import FieldTooltipContent from './FieldTooltipContent'; interface Props { nodeId: string; fieldName: string; - kind: 'input' | 'output'; + kind: 'inputs' | 'outputs'; isMissingInput?: boolean; withTooltip?: boolean; } @@ -58,7 +58,7 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => { return ( : undefined} + label={withTooltip ? 
: undefined} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} > { - const field = useFieldInstance(nodeId, fieldName); + const field = useFieldInputInstance(nodeId, fieldName); const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind); const isInputTemplate = isFieldInputTemplate(fieldTemplate); const fieldTypeName = useFieldTypeName(fieldTemplate?.type); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx index 2b9f7960e4b..66b0d3f7556 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx @@ -25,7 +25,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { const [isHovered, setIsHovered] = useState(false); const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = - useConnectionState({ nodeId, fieldName, kind: 'input' }); + useConnectionState({ nodeId, fieldName, kind: 'inputs' }); const isMissingInput = useMemo(() => { if (!fieldTemplate) { @@ -76,7 +76,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { @@ -101,7 +101,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index c1d52c1d4fb..b6e331c1149 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -1,6 +1,5 @@ -import { Box, Text } from '@invoke-ai/ui-library'; -import { useFieldInstance } from 'features/nodes/hooks/useFieldData'; -import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; +import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance'; +import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate'; import { isBoardFieldInputInstance, isBoardFieldInputTemplate, @@ -38,7 +37,6 @@ import { isVAEModelFieldInputTemplate, } from 'features/nodes/types/field'; import { memo } from 'react'; -import { useTranslation } from 'react-i18next'; import BoardFieldInputComponent from './inputs/BoardFieldInputComponent'; import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent'; @@ -63,17 +61,8 @@ type InputFieldProps = { }; const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { - const { t } = useTranslation(); - const fieldInstance = useFieldInstance(nodeId, fieldName); - const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); - - if (fieldTemplate?.fieldKind === 'output') { - return ( - - {t('nodes.outputFieldInInput')}: {fieldInstance?.type.name} - - ); - } + const fieldInstance = useFieldInputInstance(nodeId, fieldName); + const fieldTemplate = useFieldInputTemplate(nodeId, fieldName); if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) { return ; @@ -141,18 +130,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return ; } - if (fieldInstance && fieldTemplate) { + if (fieldTemplate) { // Fallback for when there is no component for the type return null; } - - return ( - - - {t('nodes.unknownFieldType', { 
type: fieldInstance?.type.name })} - - - ); }; export default memo(InputFieldRenderer); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx index d0a30ecc3c7..0cd199f7a47 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx @@ -62,7 +62,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => { /> - + {isValueChanged && ( { /> )} } + label={} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} placement="top" > diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx index 48c4c0d7404..f2d776a2da1 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx @@ -1,6 +1,5 @@ import { Flex, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library'; import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; -import { useFieldOutputInstance } from 'features/nodes/hooks/useFieldOutputInstance'; import { useFieldOutputTemplate } from 'features/nodes/hooks/useFieldOutputTemplate'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; import type { PropsWithChildren } from 'react'; @@ -18,18 +17,17 @@ interface Props { const OutputField = ({ nodeId, fieldName }: Props) => { const { t } = useTranslation(); const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName); - const fieldInstance = useFieldOutputInstance(nodeId, fieldName); const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = - useConnectionState({ nodeId, fieldName, kind: 'output' }); + useConnectionState({ nodeId, fieldName, kind: 'outputs' }); - if (!fieldTemplate || !fieldInstance) { + if (!fieldTemplate) { return ( {t('nodes.unknownOutput', { - name: fieldTemplate?.title ?? 
fieldName, + name: fieldName, })} @@ -40,7 +38,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => { return ( } + label={} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} placement="top" shouldWrapChildren diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx index b7c9033d6b2..d72d2f5aa8d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx @@ -6,19 +6,18 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea'; import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import EditableNodeTitle from './details/EditableNodeTitle'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) { return; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx index ee7dfaa6932..978eeddd24a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx @@ -5,7 +5,6 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -14,12 +13,12 @@ import type { AnyResult } from 'services/events/types'; import ImageOutputPreview from './outputs/ImageOutputPreview'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? 
nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__']; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx index 28f0e82d68c..ea6e8ed704d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx @@ -3,16 +3,15 @@ import { useAppSelector } from 'app/store/storeHooks'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; return { template: lastSelectedNodeTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index d0263a8bdaf..c882924e241 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -1,26 +1,22 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; -export const useAnyOrDirectInputFieldNames = (nodeId: string) => { +export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return []; + createMemoizedSelector(selectNodesSlice, (nodes) => { + const template = selectNodeTemplate(nodes, nodeId); + if (!template) { + return EMPTY_ARRAY; } - const nodeTemplate = nodeTemplates.templates[node.data.type]; - if 
(!nodeTemplate) { - return []; - } - const fields = map(nodeTemplate.inputs).filter( + const fields = map(template.inputs).filter( (field) => (['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) && keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts index aecc9318938..b19edf3c85a 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts @@ -13,7 +13,7 @@ export const SHARED_NODE_PROPERTIES: Partial = { }; export const useBuildNode = () => { - const nodeTemplates = useAppSelector((s) => s.nodeTemplates.templates); + const nodeTemplates = useAppSelector((s) => s.nodes.templates); const flow = useReactFlow(); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts index 23f318517b5..dc8a05b88c2 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts @@ -1,28 +1,24 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; -export const useConnectionInputFieldNames = (nodeId: string) => { +export const useConnectionInputFieldNames = (nodeId: string): string[] => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return []; - } - const nodeTemplate = nodeTemplates.templates[node.data.type]; - if (!nodeTemplate) { - return []; + createMemoizedSelector(selectNodesSlice, (nodes) => { + const template = selectNodeTemplate(nodes, nodeId); + if (!template) { + return EMPTY_ARRAY; } // get the visible fields - const fields = map(nodeTemplate.inputs).filter( + const fields = map(template.inputs).filter( (field) => (field.input === 'connection' && !field.type.isCollectionOrScalar) || !keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index a6f8b663f69..97b96f323ad 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -14,7 +14,7 @@ const selectIsConnectionInProgress = createSelector( export type UseConnectionStateProps = { nodeId: string; fieldName: string; - kind: 'input' | 'output'; + kind: 'inputs' | 'outputs'; }; export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => { @@ -26,8 +26,8 @@ 
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta Boolean( nodes.edges.filter((edge) => { return ( - (kind === 'input' ? edge.target : edge.source) === nodeId && - (kind === 'input' ? edge.targetHandle : edge.sourceHandle) === fieldName + (kind === 'inputs' ? edge.target : edge.source) === nodeId && + (kind === 'inputs' ? edge.targetHandle : edge.sourceHandle) === fieldName ); }).length ) @@ -36,7 +36,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta ); const selectConnectionError = useMemo( - () => makeConnectionErrorSelector(nodeId, fieldName, kind === 'input' ? 'target' : 'source', fieldType), + () => makeConnectionErrorSelector(nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType), [nodeId, fieldName, kind, fieldType] ); @@ -46,7 +46,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta Boolean( nodes.connectionStartParams?.nodeId === nodeId && nodes.connectionStartParams?.handleId === fieldName && - nodes.connectionStartParams?.handleType === { input: 'target', output: 'source' }[kind] + nodes.connectionStartParams?.handleType === { inputs: 'target', outputs: 'source' }[kind] ) ), [fieldName, kind, nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts index bfbf0a3b2d3..91994cf7525 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts @@ -2,23 +2,19 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { compareVersions } from 'compare-versions'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData, selectNodeTemplate } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useDoNodeVersionsMatch = (nodeId: string) => { +export const useDoNodeVersionsMatch = (nodeId: string): boolean => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { + createSelector(selectNodesSlice, (nodes) => { + const data = selectNodeData(nodes, nodeId); + const template = selectNodeTemplate(nodes, nodeId); + if (!template?.version || !data?.version) { return false; } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? 
'']; - if (!nodeTemplate?.version || !node.data?.version) { - return false; - } - return compareVersions(nodeTemplate.version, node.data.version) === 0; + return compareVersions(template.version, data.version) === 0; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts index cfe5c90d9cc..5051eaa55b3 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts @@ -1,18 +1,18 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => { +export const useDoesInputHaveValue = (nodeId: string, fieldName: string): boolean => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + const data = selectNodeData(nodes, nodeId); + if (!data) { + return false; } - return node?.data.inputs[fieldName]?.value !== undefined; + return data.inputs[fieldName]?.value !== undefined; }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts deleted file mode 100644 index 8b35a2d44be..00000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; -import { useMemo } from 'react'; - -export const useFieldInstance = (nodeId: string, fieldName: string) => { - const selector = useMemo( - () => - createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node?.data.inputs[fieldName]; - }), - [fieldName, nodeId] - ); - - const fieldData = useAppSelector(selector); - - return fieldData; -}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts index 0793f1f9529..25065e7aba5 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts @@ -1,23 +1,20 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputInstance } from 'features/nodes/store/selectors'; +import type { FieldInputInstance } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldInputInstance = (nodeId: string, fieldName: string) => { +export const useFieldInputInstance = (nodeId: string, fieldName: string): 
FieldInputInstance | null => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node.data.inputs[fieldName]; + return selectFieldInputInstance(nodes, nodeId, fieldName); }), [fieldName, nodeId] ); - const fieldTemplate = useAppSelector(selector); + const fieldData = useAppSelector(selector); - return fieldTemplate; + return fieldData; }; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts index 11d44dbde2e..08de3d9b205 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts @@ -1,21 +1,16 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate } from 'features/nodes/store/selectors'; +import type { FieldInput } from 'features/nodes/types/field'; import { useMemo } from 'react'; export const useFieldInputKind = (nodeId: string, fieldName: string) => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - const fieldTemplate = nodeTemplate?.inputs[fieldName]; - return fieldTemplate?.input; + createSelector(selectNodesSlice, (nodes): FieldInput | null => { + const template = selectFieldInputTemplate(nodes, nodeId, fieldName); + return template?.input ?? null; }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts index 8533d2be8df..e8289d7e07d 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts @@ -1,20 +1,15 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate } from 'features/nodes/store/selectors'; +import type { FieldInputTemplate } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldInputTemplate = (nodeId: string, fieldName: string) => { +export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? 
'']; - return nodeTemplate?.inputs[fieldName]; + createMemoizedSelector(selectNodesSlice, (nodes) => { + return selectFieldInputTemplate(nodes, nodeId, fieldName); }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts index ef57956047e..92eab8d1b15 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputInstance } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useFieldLabel = (nodeId: string, fieldName: string) => { +export const useFieldLabel = (nodeId: string, fieldName: string): string | null => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node?.data.inputs[fieldName]?.label; + return selectFieldInputInstance(nodes, nodeId, fieldName)?.label ?? null; }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts deleted file mode 100644 index 8b71f1ea014..00000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; -import { useMemo } from 'react'; - -export const useFieldOutputInstance = (nodeId: string, fieldName: string) => { - const selector = useMemo( - () => - createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node.data.outputs[fieldName]; - }), - [fieldName, nodeId] - ); - - const fieldTemplate = useAppSelector(selector); - - return fieldTemplate; -}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts index 11f592b399e..cb154071e97 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts @@ -1,20 +1,15 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldOutputTemplate } from 'features/nodes/store/selectors'; +import type { FieldOutputTemplate } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldOutputTemplate = (nodeId: string, fieldName: string) => { +export const useFieldOutputTemplate = (nodeId: string, fieldName: string): FieldOutputTemplate | null => { const 
selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.outputs[fieldName]; + createMemoizedSelector(selectNodesSlice, (nodes) => { + return selectFieldOutputTemplate(nodes, nodeId, fieldName); }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts index 663821da81e..7be4ecfd4df 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts @@ -1,21 +1,22 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { KIND_MAP } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors'; +import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldTemplate = (nodeId: string, fieldName: string, kind: 'input' | 'output') => { +export const useFieldTemplate = ( + nodeId: string, + fieldName: string, + kind: 'inputs' | 'outputs' +): FieldInputTemplate | FieldOutputTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + createMemoizedSelector(selectNodesSlice, (nodes) => { + if (kind === 'inputs') { + return selectFieldInputTemplate(nodes, nodeId, fieldName); } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? 
'']; - return nodeTemplate?.[KIND_MAP[kind]][fieldName]; + return selectFieldOutputTemplate(nodes, nodeId, fieldName); }), [fieldName, kind, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts index cfdcda6efab..e41e0195724 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts @@ -1,21 +1,17 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { KIND_MAP } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'input' | 'output') => { +export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): string | null => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + createSelector(selectNodesSlice, (nodes) => { + if (kind === 'inputs') { + return selectFieldInputTemplate(nodes, nodeId, fieldName)?.title ?? null; } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.[KIND_MAP[kind]][fieldName]?.title; + return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.title ?? null; }), [fieldName, kind, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts index a834726a136..a71a4d044ee 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts @@ -1,20 +1,18 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { KIND_MAP } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors'; +import type { FieldType } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldType = (nodeId: string, fieldName: string, kind: 'input' | 'output') => { +export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType | null => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + if (kind === 'inputs') { + return selectFieldInputTemplate(nodes, nodeId, fieldName)?.type ?? null; } - const field = node.data[KIND_MAP[kind]][fieldName]; - return field?.type; + return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.type ?? 
null; }), [fieldName, kind, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts index a8019c92d6d..71344197d54 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts @@ -1,13 +1,12 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; -const selector = createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => +const selector = createSelector(selectNodesSlice, (nodes) => nodes.nodes.filter(isInvocationNode).some((node) => { - const template = nodeTemplates.templates[node.data.type]; + const template = nodes.templates[node.data.type]; if (!template) { return false; } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts index 617e713c7cc..3ac3cabb220 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts @@ -1,24 +1,21 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { some } from 'lodash-es'; import { useMemo } from 'react'; -export const useHasImageOutput = (nodeId: string) => { +export const useHasImageOutput = (nodeId: string): boolean => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } + const template = selectNodeTemplate(nodes, nodeId); return some( - node.data.outputs, + template?.outputs, (output) => output.type.name === 'ImageField' && // the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes - node.data.type !== 'image' + template?.type !== 'image' ); }), [nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts index 729bfa0cea0..3fad0a2a861 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useIsIntermediate = (nodeId: string) => { +export const useIsIntermediate = (nodeId: string): boolean => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id 
=== nodeId); - if (!isInvocationNode(node)) { - return false; - } - return node.data.isIntermediate; + return selectNodeData(nodes, nodeId)?.isIntermediate ?? false; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 39a8abbe7a2..ded05c7b9bf 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -1,11 +1,10 @@ // TODO: enable this at some point -import { useAppSelector } from 'app/store/storeHooks'; +import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { useCallback } from 'react'; import type { Connection, Node } from 'reactflow'; -import { useReactFlow } from 'reactflow'; /** * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts` @@ -13,36 +12,31 @@ import { useReactFlow } from 'reactflow'; */ export const useIsValidConnection = () => { - const flow = useReactFlow(); + const store = useAppStore(); const shouldValidateGraph = useAppSelector((s) => s.nodes.shouldValidateGraph); const isValidConnection = useCallback( ({ source, sourceHandle, target, targetHandle }: Connection): boolean => { - const edges = flow.getEdges(); - const nodes = flow.getNodes(); // Connection must have valid targets if (!(source && sourceHandle && target && targetHandle)) { return false; } - // Find the source and target nodes - const sourceNode = flow.getNode(source) as Node; - const targetNode = flow.getNode(target) as Node; - - // Conditional guards against undefined nodes/handles - if (!(sourceNode && targetNode && sourceNode.data && targetNode.data)) { + if (source === target) { + // Don't allow nodes to connect to themselves, even if validation is disabled return false; } - const sourceField = sourceNode.data.outputs[sourceHandle]; - const targetField = targetNode.data.inputs[targetHandle]; + const state = store.getState(); + const { nodes, edges, templates } = state.nodes; - if (!sourceField || !targetField) { - // something has gone terribly awry - return false; - } + // Find the source and target nodes + const sourceNode = nodes.find((node) => node.id === source) as Node; + const targetNode = nodes.find((node) => node.id === target) as Node; + const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle]; + const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle]; - if (source === target) { - // Don't allow nodes to connect to themselves, even if validation is disabled + // Conditional guards against undefined nodes/handles + if (!(sourceFieldTemplate && targetFieldTemplate)) { return false; } @@ -69,20 +63,20 @@ export const useIsValidConnection = () => { return edge.target === target && edge.targetHandle === targetHandle; }) && // except CollectionItem inputs can have multiples - targetField.type.name !== 'CollectionItemField' + targetFieldTemplate.type.name !== 'CollectionItemField' ) { return false; } // Must use the originalType here if it exists - if (!validateSourceAndTargetTypes(sourceField.type, targetField.type)) { + if 
(!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { return false; } // Graphs much be acyclic (no loops!) return getIsGraphAcyclic(source, target, nodes, edges); }, - [flow, shouldValidateGraph] + [shouldValidateGraph, store] ); return isValidConnection; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts index c61721030eb..bab8ff3f194 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts @@ -1,20 +1,15 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; +import type { Classification } from 'features/nodes/types/common'; import { useMemo } from 'react'; -export const useNodeClassification = (nodeId: string) => { +export const useNodeClassification = (nodeId: string): Classification | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.classification; + createSelector(selectNodesSlice, (nodes) => { + return selectNodeTemplate(nodes, nodeId)?.classification ?? 
null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts index c507def5ee3..fa21008ff8b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts @@ -1,14 +1,15 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { selectNodeData } from 'features/nodes/store/selectors'; +import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; -export const useNodeData = (nodeId: string) => { +export const useNodeData = (nodeId: string): InvocationNodeData | null => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - return node?.data; + return selectNodeData(nodes, nodeId); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts index c5fc43742a1..31dcb9c466e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts @@ -1,19 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; export const useNodeLabel = (nodeId: string) => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - - return node.data.label; + return selectNodeData(nodes, nodeId)?.label ?? null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts index e6efa667f12..aa0294f70f0 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts @@ -1,21 +1,20 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectInvocationNode, selectNodeTemplate } from 'features/nodes/store/selectors'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import { useMemo } from 'react'; export const useNodeNeedsUpdate = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - const template = nodeTemplates.templates[node?.data.type ?? 
'']; - if (isInvocationNode(node) && template) { - return getNeedsUpdate(node, template); + createMemoizedSelector(selectNodesSlice, (nodes) => { + const node = selectInvocationNode(nodes, nodeId); + const template = selectNodeTemplate(nodes, nodeId); + if (!node || !template) { + return false; } - return false; + return getNeedsUpdate(node, template); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts index ca3dd5cfdf6..5c920866e9d 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useNodePack = (nodeId: string) => { +export const useNodePack = (nodeId: string): string | null => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - return node.data.nodePack; + return selectNodeData(nodes, nodeId)?.nodePack ?? null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts index 7544cbff461..866c9275fb3 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts @@ -1,16 +1,15 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; +import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; -export const useNodeTemplate = (nodeId: string) => { +export const useNodeTemplate = (nodeId: string): InvocationTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? 
'']; - return nodeTemplate; + createSelector(selectNodesSlice, (nodes) => { + return selectNodeTemplate(nodes, nodeId); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts index 8fd1345f6f5..a0c870f6941 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts @@ -1,14 +1,14 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; +import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; -export const useNodeTemplateByType = (type: string) => { +export const useNodeTemplateByType = (type: string): InvocationTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates): InvocationTemplate | undefined => { - return nodeTemplates.templates[type]; + createSelector(selectNodesSlice, (nodes) => { + return nodes.templates[type] ?? null; }), [type] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts index 15d2ec38c32..120b8c758be 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts @@ -1,21 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useNodeTemplateTitle = (nodeId: string) => { +export const useNodeTemplateTitle = (nodeId: string): string | null => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - const nodeTemplate = node ? nodeTemplates.templates[node.data.type] : undefined; - - return nodeTemplate?.title; + createSelector(selectNodesSlice, (nodes) => { + return selectNodeTemplate(nodes, nodeId)?.title ?? 
null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts index e352bd8b90f..24863080a74 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts @@ -1,8 +1,8 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { map } from 'lodash-es'; import { useMemo } from 'react'; @@ -10,17 +10,13 @@ import { useMemo } from 'react'; export const useOutputFieldNames = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return []; - } - const nodeTemplate = nodeTemplates.templates[node.data.type]; - if (!nodeTemplate) { - return []; + createSelector(selectNodesSlice, (nodes) => { + const template = selectNodeTemplate(nodes, nodeId); + if (!template) { + return EMPTY_ARRAY; } - return getSortedFilteredFieldNames(map(nodeTemplate.outputs)); + return getSortedFilteredFieldNames(map(template.outputs)); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts index edfc990882b..aaca80039b0 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; export const useUseCache = (nodeId: string) => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - return node.data.useCache; + return selectNodeData(nodes, nodeId)?.useCache ?? 
false;
      }),
    [nodeId]
  );
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts
index 0e4806d81b7..5d79c154428 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts
@@ -2,14 +2,14 @@ import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
 import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
-import type { WorkflowV2 } from 'features/nodes/types/workflow';
+import type { WorkflowV3 } from 'features/nodes/types/workflow';
 import type { BuildWorkflowArg } from 'features/nodes/util/workflow/buildWorkflow';
 import { buildWorkflowFast } from 'features/nodes/util/workflow/buildWorkflow';
 import { debounce } from 'lodash-es';
 import { atom } from 'nanostores';
 import { useEffect } from 'react';
 
-export const $builtWorkflow = atom<WorkflowV2 | null>(null);
+export const $builtWorkflow = atom<WorkflowV3 | null>(null);
 
 const debouncedBuildWorkflow = debounce((arg: BuildWorkflowArg) => {
   $builtWorkflow.set(buildWorkflowFast(arg));
diff --git a/invokeai/frontend/web/src/features/nodes/store/actions.ts b/invokeai/frontend/web/src/features/nodes/store/actions.ts
index 00457494bfb..b32a3ba9979 100644
--- a/invokeai/frontend/web/src/features/nodes/store/actions.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/actions.ts
@@ -1,5 +1,5 @@
 import { createAction, isAnyOf } from '@reduxjs/toolkit';
-import type { WorkflowV2 } from 'features/nodes/types/workflow';
+import type { WorkflowV3 } from 'features/nodes/types/workflow';
 import type { Graph } from 'services/api/types';
 
 export const textToImageGraphBuilt = createAction<Graph>('nodes/textToImageGraphBuilt');
@@ -21,4 +21,4 @@ export const workflowLoadRequested = createAction<{
 
 export const updateAllNodesRequested = createAction('nodes/updateAllNodesRequested');
 
-export const workflowLoaded = createAction<WorkflowV2>('workflow/workflowLoaded');
+export const workflowLoaded = createAction<WorkflowV3>('workflow/workflowLoaded');
diff --git a/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts
deleted file mode 100644
index c211131aab7..00000000000
--- a/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts
+++ /dev/null
@@ -1,24 +0,0 @@
-import type { PayloadAction } from '@reduxjs/toolkit';
-import { createSlice } from '@reduxjs/toolkit';
-import type { RootState } from 'app/store/store';
-import type { InvocationTemplate } from 'features/nodes/types/invocation';
-
-import type { NodeTemplatesState } from './types';
-
-export const initialNodeTemplatesState: NodeTemplatesState = {
-  templates: {},
-};
-
-export const nodesTemplatesSlice = createSlice({
-  name: 'nodeTemplates',
-  initialState: initialNodeTemplatesState,
-  reducers: {
-    nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
-      state.templates = action.payload;
-    },
-  },
-});
-
-export const { nodeTemplatesBuilt } = nodesTemplatesSlice.actions;
-
-export const selectNodeTemplatesSlice = (state: RootState) => state.nodeTemplates;
diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
index aee01b381ba..6b596da0633 100644
--- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
@@ -42,7 +42,7 @@ import {
   zT2IAdapterModelFieldValue,
   zVAEModelFieldValue,
 } from 'features/nodes/types/field';
-import type { AnyNode, NodeExecutionState } from 'features/nodes/types/invocation';
+import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
 import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
 import { cloneDeep, forEach } from 'lodash-es';
 import type {
@@ -92,6 +92,7 @@ export const initialNodesState: NodesState = {
   _version: 1,
   nodes: [],
   edges: [],
+  templates: {},
   connectionStartParams: null,
   connectionStartFieldType: null,
   connectionMade: false,
@@ -190,6 +191,7 @@ export const nodesSlice = createSlice({
           node,
           state.nodes,
           state.edges,
+          state.templates,
           nodeId,
           handleId,
           handleType,
@@ -224,12 +226,12 @@ export const nodesSlice = createSlice({
       if (!nodeId || !handleId) {
         return;
       }
-      const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
-      const node = state.nodes?.[nodeIndex];
+      const node = state.nodes.find((n) => n.id === nodeId);
       if (!isInvocationNode(node)) {
         return;
       }
-      const field = handleType === 'source' ? node.data.outputs[handleId] : node.data.inputs[handleId];
+      const template = state.templates[node.data.type];
+      const field = handleType === 'source' ? template?.outputs[handleId] : template?.inputs[handleId];
       state.connectionStartFieldType = field?.type ?? null;
     },
     connectionMade: (state, action: PayloadAction<Connection>) => {
@@ -260,6 +262,7 @@ export const nodesSlice = createSlice({
             mouseOverNode,
             state.nodes,
             state.edges,
+            state.templates,
             nodeId,
             handleId,
             handleType,
@@ -677,6 +680,9 @@ export const nodesSlice = createSlice({
     selectionModeChanged: (state, action: PayloadAction<boolean>) => {
       state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial;
     },
+    nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
+      state.templates = action.payload;
+    },
   },
   extraReducers: (builder) => {
     builder.addCase(workflowLoaded, (state, action) => {
@@ -808,6 +814,7 @@ export const {
   shouldValidateGraphChanged,
   viewportChanged,
   edgeAdded,
+  nodeTemplatesBuilt,
 } = nodesSlice.actions;
 
 // This is used for tracking `state.workflow.isTouched`
diff --git a/invokeai/frontend/web/src/features/nodes/store/selectors.ts b/invokeai/frontend/web/src/features/nodes/store/selectors.ts
new file mode 100644
index 00000000000..90675d62707
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/store/selectors.ts
@@ -0,0 +1,51 @@
+import type { NodesState } from 'features/nodes/store/types';
+import type { FieldInputInstance, FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
+import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation';
+import { isInvocationNode } from 'features/nodes/types/invocation';
+
+export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode | null => {
+  const node = nodesSlice.nodes.find((node) => node.id === nodeId);
+  if (!isInvocationNode(node)) {
+    return null;
+  }
+  return node;
+};
+
+export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData | null => {
+  return selectInvocationNode(nodesSlice, nodeId)?.data ?? null;
+};
+
+export const selectNodeTemplate = (nodesSlice: NodesState, nodeId: string): InvocationTemplate | null => {
+  const node = selectInvocationNode(nodesSlice, nodeId);
+  if (!node) {
+    return null;
+  }
+  return nodesSlice.templates[node.data.type] ?? null;
+};
+
+export const selectFieldInputInstance = (
+  nodesSlice: NodesState,
+  nodeId: string,
+  fieldName: string
+): FieldInputInstance | null => {
+  const data = selectNodeData(nodesSlice, nodeId);
+  return data?.inputs[fieldName] ?? null;
+};
+
+export const selectFieldInputTemplate = (
+  nodesSlice: NodesState,
+  nodeId: string,
+  fieldName: string
+): FieldInputTemplate | null => {
+  const template = selectNodeTemplate(nodesSlice, nodeId);
+  return template?.inputs[fieldName] ?? null;
+};
+
+export const selectFieldOutputTemplate = (
+  nodesSlice: NodesState,
+  nodeId: string,
+  fieldName: string
+): FieldOutputTemplate | null => {
+  const template = selectNodeTemplate(nodesSlice, nodeId);
+  return template?.outputs[fieldName] ?? null;
+};
diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts
index 8b0de447e43..1a040d2c705 100644
--- a/invokeai/frontend/web/src/features/nodes/store/types.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/types.ts
@@ -5,13 +5,14 @@ import type {
   InvocationTemplate,
   NodeExecutionState,
 } from 'features/nodes/types/invocation';
-import type { WorkflowV2 } from 'features/nodes/types/workflow';
+import type { WorkflowV3 } from 'features/nodes/types/workflow';
 import type { OnConnectStartParams, SelectionMode, Viewport, XYPosition } from 'reactflow';
 
 export type NodesState = {
   _version: 1;
   nodes: AnyNode[];
   edges: InvocationNodeEdge[];
+  templates: Record<string, InvocationTemplate>;
   connectionStartParams: OnConnectStartParams | null;
   connectionStartFieldType: FieldType | null;
   connectionMade: boolean;
@@ -38,7 +39,7 @@ export type FieldIdentifierWithValue = FieldIdentifier & {
   value: StatefulFieldValue;
 };
 
-export type WorkflowsState = Omit<WorkflowV2, 'nodes' | 'edges'> & {
+export type WorkflowsState = Omit<WorkflowV3, 'nodes' | 'edges'> & {
   _version: 1;
   isTouched: boolean;
   mode: WorkflowMode;
diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts
index 9f2c37a2ad7..ef899c5f414 100644
--- a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts
@@ -1,4 +1,6 @@
-import type { FieldInputInstance, FieldOutputInstance, FieldType } from 'features/nodes/types/field';
+import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
+import type { AnyNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
+import { isInvocationNode } from 'features/nodes/types/invocation';
 import type { Connection, Edge, HandleType, Node } from 'reactflow';
 
 import { getIsGraphAcyclic } from './getIsGraphAcyclic';
@@ -9,7 +11,7 @@ const isValidConnection = (
   handleCurrentType: HandleType,
   handleCurrentFieldType: FieldType,
   node: Node,
-  handle: FieldInputInstance | FieldOutputInstance
+  handle: FieldInputTemplate | FieldOutputTemplate
 ) => {
   let isValidConnection = true;
   if (handleCurrentType === 'source') {
@@ -38,24 +40,31 @@ const isValidConnection = (
 };
 
 export const findConnectionToValidHandle = (
-  node: Node,
-  nodes: Node[],
-  edges: Edge[],
+  node: AnyNode,
+  nodes: AnyNode[],
+  edges: InvocationNodeEdge[],
+ 
templates: Record, handleCurrentNodeId: string, handleCurrentName: string, handleCurrentType: HandleType, handleCurrentFieldType: FieldType ): Connection | null => { - if (node.id === handleCurrentNodeId) { + if (node.id === handleCurrentNodeId || !isInvocationNode(node)) { return null; } - const handles = handleCurrentType === 'source' ? node.data.inputs : node.data.outputs; + const template = templates[node.data.type]; + + if (!template) { + return null; + } + + const handles = handleCurrentType === 'source' ? template.inputs : template.outputs; //Prioritize handles whos name matches the node we're coming from - if (handles[handleCurrentName]) { - const handle = handles[handleCurrentName]; + const handle = handles[handleCurrentName]; + if (handle) { const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id; const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId; const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name; @@ -77,6 +86,9 @@ export const findConnectionToValidHandle = ( for (const handleName in handles) { const handle = handles[handleName]; + if (!handle) { + continue; + } const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id; const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 8575932cbdd..d6ea0d9c86e 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -16,7 +16,7 @@ export const makeConnectionErrorSelector = ( nodeId: string, fieldName: string, handleType: HandleType, - fieldType?: FieldType + fieldType?: FieldType | null ) => { return createSelector(selectNodesSlice, (nodesSlice) => { if (!fieldType) { diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts index 2978f25138d..4f40a68e1f0 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts @@ -10,10 +10,10 @@ import type { } from 'features/nodes/store/types'; import type { FieldIdentifier } from 'features/nodes/types/field'; import { isInvocationNode } from 'features/nodes/types/invocation'; -import type { WorkflowCategory, WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowCategory, WorkflowV3 } from 'features/nodes/types/workflow'; import { cloneDeep, isEqual, omit, uniqBy } from 'lodash-es'; -export const blankWorkflow: Omit = { +export const blankWorkflow: Omit = { name: '', author: '', description: '', @@ -22,7 +22,7 @@ export const blankWorkflow: Omit = { tags: '', notes: '', exposedFields: [], - meta: { version: '2.0.0', category: 'user' }, + meta: { version: '3.0.0', category: 'user' }, id: undefined, }; diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 38f1af55dd8..aa6164d6e53 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -46,20 +46,11 @@ export type FieldInput = z.infer; export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']); export type FieldUIComponent = z.infer; -export 
const zFieldInstanceBase = z.object({ - id: z.string().trim().min(1), +export const zFieldInputInstanceBase = z.object({ name: z.string().trim().min(1), -}); -export const zFieldInputInstanceBase = zFieldInstanceBase.extend({ - fieldKind: z.literal('input'), label: z.string().nullish(), }); -export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({ - fieldKind: z.literal('output'), -}); -export type FieldInstanceBase = z.infer; export type FieldInputInstanceBase = z.infer; -export type FieldOutputInstanceBase = z.infer; export const zFieldTemplateBase = z.object({ name: z.string().min(1), @@ -102,12 +93,8 @@ export const zIntegerFieldType = zFieldTypeBase.extend({ }); export const zIntegerFieldValue = z.number().int(); export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zIntegerFieldType, value: zIntegerFieldValue, }); -export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zIntegerFieldType, -}); export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIntegerFieldType, default: zIntegerFieldValue, @@ -136,12 +123,8 @@ export const zFloatFieldType = zFieldTypeBase.extend({ }); export const zFloatFieldValue = z.number(); export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zFloatFieldType, value: zFloatFieldValue, }); -export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zFloatFieldType, -}); export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zFloatFieldType, default: zFloatFieldValue, @@ -157,7 +140,6 @@ export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type FloatFieldType = z.infer; export type FloatFieldValue = z.infer; export type FloatFieldInputInstance = z.infer; -export type FloatFieldOutputInstance = z.infer; export type FloatFieldInputTemplate = z.infer; export type FloatFieldOutputTemplate = z.infer; export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance => @@ -172,12 +154,8 @@ export const zStringFieldType = zFieldTypeBase.extend({ }); export const zStringFieldValue = z.string(); export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zStringFieldType, value: zStringFieldValue, }); -export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zStringFieldType, -}); export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zStringFieldType, default: zStringFieldValue, @@ -191,7 +169,6 @@ export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type StringFieldType = z.infer; export type StringFieldValue = z.infer; export type StringFieldInputInstance = z.infer; -export type StringFieldOutputInstance = z.infer; export type StringFieldInputTemplate = z.infer; export type StringFieldOutputTemplate = z.infer; export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance => @@ -206,12 +183,8 @@ export const zBooleanFieldType = zFieldTypeBase.extend({ }); export const zBooleanFieldValue = z.boolean(); export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zBooleanFieldType, value: zBooleanFieldValue, }); -export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zBooleanFieldType, -}); export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBooleanFieldType, default: zBooleanFieldValue, @@ -222,7 +195,6 @@ export const 
zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type BooleanFieldType = z.infer; export type BooleanFieldValue = z.infer; export type BooleanFieldInputInstance = z.infer; -export type BooleanFieldOutputInstance = z.infer; export type BooleanFieldInputTemplate = z.infer; export type BooleanFieldOutputTemplate = z.infer; export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance => @@ -237,12 +209,8 @@ export const zEnumFieldType = zFieldTypeBase.extend({ }); export const zEnumFieldValue = z.string(); export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zEnumFieldType, value: zEnumFieldValue, }); -export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zEnumFieldType, -}); export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zEnumFieldType, default: zEnumFieldValue, @@ -255,7 +223,6 @@ export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type EnumFieldType = z.infer; export type EnumFieldValue = z.infer; export type EnumFieldInputInstance = z.infer; -export type EnumFieldOutputInstance = z.infer; export type EnumFieldInputTemplate = z.infer; export type EnumFieldOutputTemplate = z.infer; export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance => @@ -270,12 +237,8 @@ export const zImageFieldType = zFieldTypeBase.extend({ }); export const zImageFieldValue = zImageField.optional(); export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zImageFieldType, value: zImageFieldValue, }); -export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zImageFieldType, -}); export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zImageFieldType, default: zImageFieldValue, @@ -286,7 +249,6 @@ export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type ImageFieldType = z.infer; export type ImageFieldValue = z.infer; export type ImageFieldInputInstance = z.infer; -export type ImageFieldOutputInstance = z.infer; export type ImageFieldInputTemplate = z.infer; export type ImageFieldOutputTemplate = z.infer; export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance => @@ -301,12 +263,8 @@ export const zBoardFieldType = zFieldTypeBase.extend({ }); export const zBoardFieldValue = zBoardField.optional(); export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zBoardFieldType, value: zBoardFieldValue, }); -export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zBoardFieldType, -}); export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBoardFieldType, default: zBoardFieldValue, @@ -317,7 +275,6 @@ export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type BoardFieldType = z.infer; export type BoardFieldValue = z.infer; export type BoardFieldInputInstance = z.infer; -export type BoardFieldOutputInstance = z.infer; export type BoardFieldInputTemplate = z.infer; export type BoardFieldOutputTemplate = z.infer; export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance => @@ -332,12 +289,8 @@ export const zColorFieldType = zFieldTypeBase.extend({ }); export const zColorFieldValue = zColorField.optional(); export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zColorFieldType, value: zColorFieldValue, }); -export const zColorFieldOutputInstance = 
zFieldOutputInstanceBase.extend({ - type: zColorFieldType, -}); export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zColorFieldType, default: zColorFieldValue, @@ -348,7 +301,6 @@ export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type ColorFieldType = z.infer; export type ColorFieldValue = z.infer; export type ColorFieldInputInstance = z.infer; -export type ColorFieldOutputInstance = z.infer; export type ColorFieldInputTemplate = z.infer; export type ColorFieldOutputTemplate = z.infer; export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance => @@ -363,12 +315,8 @@ export const zMainModelFieldType = zFieldTypeBase.extend({ }); export const zMainModelFieldValue = zMainModelField.optional(); export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zMainModelFieldType, value: zMainModelFieldValue, }); -export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zMainModelFieldType, -}); export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zMainModelFieldType, default: zMainModelFieldValue, @@ -379,7 +327,6 @@ export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type MainModelFieldType = z.infer; export type MainModelFieldValue = z.infer; export type MainModelFieldInputInstance = z.infer; -export type MainModelFieldOutputInstance = z.infer; export type MainModelFieldInputTemplate = z.infer; export type MainModelFieldOutputTemplate = z.infer; export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance => @@ -394,12 +341,8 @@ export const zSDXLMainModelFieldType = zFieldTypeBase.extend({ }); export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSDXLMainModelFieldType, value: zSDXLMainModelFieldValue, }); -export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSDXLMainModelFieldType, -}); export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLMainModelFieldType, default: zSDXLMainModelFieldValue, @@ -410,7 +353,6 @@ export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend export type SDXLMainModelFieldType = z.infer; export type SDXLMainModelFieldValue = z.infer; export type SDXLMainModelFieldInputInstance = z.infer; -export type SDXLMainModelFieldOutputInstance = z.infer; export type SDXLMainModelFieldInputTemplate = z.infer; export type SDXLMainModelFieldOutputTemplate = z.infer; export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance => @@ -425,12 +367,8 @@ export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ }); export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. 
export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSDXLRefinerModelFieldType, value: zSDXLRefinerModelFieldValue, }); -export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSDXLRefinerModelFieldType, -}); export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLRefinerModelFieldType, default: zSDXLRefinerModelFieldValue, @@ -441,7 +379,6 @@ export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.ext export type SDXLRefinerModelFieldType = z.infer; export type SDXLRefinerModelFieldValue = z.infer; export type SDXLRefinerModelFieldInputInstance = z.infer; -export type SDXLRefinerModelFieldOutputInstance = z.infer; export type SDXLRefinerModelFieldInputTemplate = z.infer; export type SDXLRefinerModelFieldOutputTemplate = z.infer; export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance => @@ -456,12 +393,8 @@ export const zVAEModelFieldType = zFieldTypeBase.extend({ }); export const zVAEModelFieldValue = zVAEModelField.optional(); export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zVAEModelFieldType, value: zVAEModelFieldValue, }); -export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zVAEModelFieldType, -}); export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zVAEModelFieldType, default: zVAEModelFieldValue, @@ -472,7 +405,6 @@ export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type VAEModelFieldType = z.infer; export type VAEModelFieldValue = z.infer; export type VAEModelFieldInputInstance = z.infer; -export type VAEModelFieldOutputInstance = z.infer; export type VAEModelFieldInputTemplate = z.infer; export type VAEModelFieldOutputTemplate = z.infer; export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance => @@ -487,12 +419,8 @@ export const zLoRAModelFieldType = zFieldTypeBase.extend({ }); export const zLoRAModelFieldValue = zLoRAModelField.optional(); export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zLoRAModelFieldType, value: zLoRAModelFieldValue, }); -export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zLoRAModelFieldType, -}); export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zLoRAModelFieldType, default: zLoRAModelFieldValue, @@ -503,7 +431,6 @@ export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type LoRAModelFieldType = z.infer; export type LoRAModelFieldValue = z.infer; export type LoRAModelFieldInputInstance = z.infer; -export type LoRAModelFieldOutputInstance = z.infer; export type LoRAModelFieldInputTemplate = z.infer; export type LoRAModelFieldOutputTemplate = z.infer; export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance => @@ -518,12 +445,8 @@ export const zControlNetModelFieldType = zFieldTypeBase.extend({ }); export const zControlNetModelFieldValue = zControlNetModelField.optional(); export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zControlNetModelFieldType, value: zControlNetModelFieldValue, }); -export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zControlNetModelFieldType, -}); export const zControlNetModelFieldInputTemplate = 
zFieldInputTemplateBase.extend({ type: zControlNetModelFieldType, default: zControlNetModelFieldValue, @@ -534,7 +457,6 @@ export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.exte export type ControlNetModelFieldType = z.infer; export type ControlNetModelFieldValue = z.infer; export type ControlNetModelFieldInputInstance = z.infer; -export type ControlNetModelFieldOutputInstance = z.infer; export type ControlNetModelFieldInputTemplate = z.infer; export type ControlNetModelFieldOutputTemplate = z.infer; export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance => @@ -551,12 +473,8 @@ export const zIPAdapterModelFieldType = zFieldTypeBase.extend({ }); export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional(); export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zIPAdapterModelFieldType, value: zIPAdapterModelFieldValue, }); -export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zIPAdapterModelFieldType, -}); export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIPAdapterModelFieldType, default: zIPAdapterModelFieldValue, @@ -567,7 +485,6 @@ export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.exten export type IPAdapterModelFieldType = z.infer; export type IPAdapterModelFieldValue = z.infer; export type IPAdapterModelFieldInputInstance = z.infer; -export type IPAdapterModelFieldOutputInstance = z.infer; export type IPAdapterModelFieldInputTemplate = z.infer; export type IPAdapterModelFieldOutputTemplate = z.infer; export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance => @@ -584,12 +501,8 @@ export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ }); export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional(); export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zT2IAdapterModelFieldType, value: zT2IAdapterModelFieldValue, }); -export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zT2IAdapterModelFieldType, -}); export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zT2IAdapterModelFieldType, default: zT2IAdapterModelFieldValue, @@ -600,7 +513,6 @@ export const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.exte export type T2IAdapterModelFieldType = z.infer; export type T2IAdapterModelFieldValue = z.infer; export type T2IAdapterModelFieldInputInstance = z.infer; -export type T2IAdapterModelFieldOutputInstance = z.infer; export type T2IAdapterModelFieldInputTemplate = z.infer; export type T2IAdapterModelFieldOutputTemplate = z.infer; export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance => @@ -615,12 +527,8 @@ export const zSchedulerFieldType = zFieldTypeBase.extend({ }); export const zSchedulerFieldValue = zSchedulerField.optional(); export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSchedulerFieldType, value: zSchedulerFieldValue, }); -export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSchedulerFieldType, -}); export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSchedulerFieldType, default: zSchedulerFieldValue, @@ -631,7 +539,6 @@ export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type 
SchedulerFieldType = z.infer; export type SchedulerFieldValue = z.infer; export type SchedulerFieldInputInstance = z.infer; -export type SchedulerFieldOutputInstance = z.infer; export type SchedulerFieldInputTemplate = z.infer; export type SchedulerFieldOutputTemplate = z.infer; export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance => @@ -657,12 +564,8 @@ export const zStatelessFieldType = zFieldTypeBase.extend({ }); export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zStatelessFieldType, value: zStatelessFieldValue, }); -export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zStatelessFieldType, -}); export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zStatelessFieldType, default: zStatelessFieldValue, @@ -675,7 +578,6 @@ export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type StatelessFieldType = z.infer; export type StatelessFieldValue = z.infer; export type StatelessFieldInputInstance = z.infer; -export type StatelessFieldOutputInstance = z.infer; export type StatelessFieldInputTemplate = z.infer; export type StatelessFieldOutputTemplate = z.infer; // #endregion @@ -783,36 +685,6 @@ export const isFieldInputInstance = (val: unknown): val is FieldInputInstance => zFieldInputInstance.safeParse(val).success; // #endregion -// #region StatefulFieldOutputInstance & FieldOutputInstance -export const zStatefulFieldOutputInstance = z.union([ - zIntegerFieldOutputInstance, - zFloatFieldOutputInstance, - zStringFieldOutputInstance, - zBooleanFieldOutputInstance, - zEnumFieldOutputInstance, - zImageFieldOutputInstance, - zBoardFieldOutputInstance, - zMainModelFieldOutputInstance, - zSDXLMainModelFieldOutputInstance, - zSDXLRefinerModelFieldOutputInstance, - zVAEModelFieldOutputInstance, - zLoRAModelFieldOutputInstance, - zControlNetModelFieldOutputInstance, - zIPAdapterModelFieldOutputInstance, - zT2IAdapterModelFieldOutputInstance, - zColorFieldOutputInstance, - zSchedulerFieldOutputInstance, -]); -export type StatefulFieldOutputInstance = z.infer; -export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance => - zStatefulFieldOutputInstance.safeParse(val).success; - -export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]); -export type FieldOutputInstance = z.infer; -export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance => - zFieldOutputInstance.safeParse(val).success; -// #endregion - // #region StatefulFieldInputTemplate & FieldInputTemplate export const zStatefulFieldInputTemplate = z.union([ zIntegerFieldInputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 86ec70fd9bd..5ccb19430da 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -2,7 +2,7 @@ import type { Edge, Node } from 'reactflow'; import { z } from 'zod'; import { zClassification, zProgressImage } from './common'; -import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field'; +import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputTemplate } 
from './field'; import { zSemVer } from './semver'; // #region InvocationTemplate @@ -25,16 +25,15 @@ export type InvocationTemplate = z.infer; // #region NodeData export const zInvocationNodeData = z.object({ id: z.string().trim().min(1), - type: z.string().trim().min(1), + version: zSemVer, + nodePack: z.string().min(1).nullish(), label: z.string(), - isOpen: z.boolean(), notes: z.string(), + type: z.string().trim().min(1), + inputs: z.record(zFieldInputInstance), + isOpen: z.boolean(), isIntermediate: z.boolean(), useCache: z.boolean(), - version: zSemVer, - nodePack: z.string().min(1).nullish(), - inputs: z.record(zFieldInputInstance), - outputs: z.record(zFieldOutputInstance), }); export const zNotesNodeData = z.object({ @@ -62,11 +61,12 @@ export type NotesNode = Node; export type CurrentImageNode = Node; export type AnyNode = Node; -export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation'); -export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes'); -export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode => +export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode => + Boolean(node && node.type === 'invocation'); +export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes'); +export const isCurrentImageNode = (node?: AnyNode | null): node is CurrentImageNode => Boolean(node && node.type === 'current_image'); -export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData => +export const isInvocationNodeData = (node?: AnyNodeData | null): node is InvocationNodeData => Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type // #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/common.ts b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts new file mode 100644 index 00000000000..b5244743799 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts @@ -0,0 +1,188 @@ +import { z } from 'zod'; + +// #region Field data schemas +export const zImageField = z.object({ + image_name: z.string().trim().min(1), +}); +export type ImageField = z.infer; + +export const zBoardField = z.object({ + board_id: z.string().trim().min(1), +}); +export type BoardField = z.infer; + +export const zColorField = z.object({ + r: z.number().int().min(0).max(255), + g: z.number().int().min(0).max(255), + b: z.number().int().min(0).max(255), + a: z.number().int().min(0).max(255), +}); +export type ColorField = z.infer; + +export const zClassification = z.enum(['stable', 'beta', 'prototype']); +export type Classification = z.infer; + +export const zSchedulerField = z.enum([ + 'euler', + 'deis', + 'ddim', + 'ddpm', + 'dpmpp_2s', + 'dpmpp_2m', + 'dpmpp_2m_sde', + 'dpmpp_sde', + 'heun', + 'kdpm_2', + 'lms', + 'pndm', + 'unipc', + 'euler_k', + 'dpmpp_2s_k', + 'dpmpp_2m_k', + 'dpmpp_2m_sde_k', + 'dpmpp_sde_k', + 'heun_k', + 'lms_k', + 'euler_a', + 'kdpm_2_a', + 'lcm', +]); +export type SchedulerField = z.infer; +// #endregion + +// #region Model-related schemas +export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +export const zModelType = z.enum(['main', 'vae', 'lora', 'controlnet', 'embedding']); +export const zModelName = z.string().min(3); +export const zModelIdentifier = z.object({ + model_name: zModelName, + base_model: zBaseModel, 
+}); +export type BaseModel = z.infer; +export type ModelType = z.infer; +export type ModelIdentifier = z.infer; + +export const zMainModelField = z.object({ + model_name: zModelName, + base_model: zBaseModel, + model_type: z.literal('main'), +}); +export const zSDXLRefinerModelField = z.object({ + model_name: z.string().min(1), + base_model: z.literal('sdxl-refiner'), + model_type: z.literal('main'), +}); +export type MainModelField = z.infer; +export type SDXLRefinerModelField = z.infer; + +export const zSubModelType = z.enum([ + 'unet', + 'text_encoder', + 'text_encoder_2', + 'tokenizer', + 'tokenizer_2', + 'vae', + 'vae_decoder', + 'vae_encoder', + 'scheduler', + 'safety_checker', +]); +export type SubModelType = z.infer; + +export const zVAEModelField = zModelIdentifier; + +export const zModelInfo = zModelIdentifier.extend({ + model_type: zModelType, + submodel: zSubModelType.optional(), +}); +export type ModelInfo = z.infer; + +export const zLoRAModelField = zModelIdentifier; +export type LoRAModelField = z.infer; + +export const zControlNetModelField = zModelIdentifier; +export type ControlNetModelField = z.infer; + +export const zIPAdapterModelField = zModelIdentifier; +export type IPAdapterModelField = z.infer; + +export const zT2IAdapterModelField = zModelIdentifier; +export type T2IAdapterModelField = z.infer; + +export const zLoraInfo = zModelInfo.extend({ + weight: z.number().optional(), +}); +export type LoraInfo = z.infer; + +export const zUNetField = z.object({ + unet: zModelInfo, + scheduler: zModelInfo, + loras: z.array(zLoraInfo), +}); +export type UNetField = z.infer; + +export const zCLIPField = z.object({ + tokenizer: zModelInfo, + text_encoder: zModelInfo, + skipped_layers: z.number(), + loras: z.array(zLoraInfo), +}); +export type CLIPField = z.infer; + +export const zVAEField = z.object({ + vae: zModelInfo, +}); +export type VAEField = z.infer; +// #endregion + +// #region Control Adapters +export const zControlField = z.object({ + image: zImageField, + control_model: zControlNetModelField, + control_weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + control_mode: z.enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']).optional(), + resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']).optional(), +}); +export type ControlField = z.infer; + +export const zIPAdapterField = z.object({ + image: zImageField, + ip_adapter_model: zIPAdapterModelField, + weight: z.number(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), +}); +export type IPAdapterField = z.infer; + +export const zT2IAdapterField = z.object({ + image: zImageField, + t2i_adapter_model: zT2IAdapterModelField, + weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']).optional(), +}); +export type T2IAdapterField = z.infer; +// #endregion + +// #region ProgressImage +export const zProgressImage = z.object({ + dataURL: z.string(), + width: z.number().int(), + height: z.number().int(), +}); +export type ProgressImage = z.infer; +// #endregion + +// #region ImageOutput +export const zImageOutput = z.object({ + image: zImageField, + width: z.number().int().gt(0), + height: z.number().int().gt(0), + type: z.literal('image_output'), +}); +export 
type ImageOutput = z.infer; +export const isImageOutput = (output: unknown): output is ImageOutput => zImageOutput.safeParse(output).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts b/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts new file mode 100644 index 00000000000..35ef9e9fd2c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts @@ -0,0 +1,80 @@ +import type { Node } from 'reactflow'; + +/** + * How long to wait before showing a tooltip when hovering a field handle. + */ +export const HANDLE_TOOLTIP_OPEN_DELAY = 500; + +/** + * The width of a node in the UI in pixels. + */ +export const NODE_WIDTH = 320; + +/** + * This class name is special - reactflow uses it to identify the drag handle of a node, + * applying the appropriate listeners to it. + */ +export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle'; + +/** + * reactflow-specifc properties shared between all node types. + */ +export const SHARED_NODE_PROPERTIES: Partial = { + dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, +}; + +/** + * Helper for getting the kind of a field. + */ +export const KIND_MAP = { + input: 'inputs' as const, + output: 'outputs' as const, +}; + +/** + * Model types' handles are rendered as squares in the UI. + */ +export const MODEL_TYPES = [ + 'IPAdapterModelField', + 'ControlNetModelField', + 'LoRAModelField', + 'MainModelField', + 'SDXLMainModelField', + 'SDXLRefinerModelField', + 'VaeModelField', + 'UNetField', + 'VaeField', + 'ClipField', + 'T2IAdapterModelField', + 'IPAdapterModelField', +]; + +/** + * Colors for each field type - applies to their handles and edges. + */ +export const FIELD_COLORS: { [key: string]: string } = { + BoardField: 'purple.500', + BooleanField: 'green.500', + ClipField: 'green.500', + ColorField: 'pink.300', + ConditioningField: 'cyan.500', + ControlField: 'teal.500', + ControlNetModelField: 'teal.500', + EnumField: 'blue.500', + FloatField: 'orange.500', + ImageField: 'purple.500', + IntegerField: 'red.500', + IPAdapterField: 'teal.500', + IPAdapterModelField: 'teal.500', + LatentsField: 'pink.500', + LoRAModelField: 'teal.500', + MainModelField: 'teal.500', + SDXLMainModelField: 'teal.500', + SDXLRefinerModelField: 'teal.500', + StringField: 'yellow.500', + T2IAdapterField: 'teal.500', + T2IAdapterModelField: 'teal.500', + UNetField: 'red.500', + VaeField: 'blue.500', + VaeModelField: 'teal.500', +}; diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/error.ts b/invokeai/frontend/web/src/features/nodes/types/v2/error.ts new file mode 100644 index 00000000000..905b487fb04 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/error.ts @@ -0,0 +1,58 @@ +/** + * Invalid Workflow Version Error + * Raised when a workflow version is not recognized. + */ +export class WorkflowVersionError extends Error { + /** + * Create WorkflowVersionError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} +/** + * Workflow Migration Error + * Raised when a workflow migration fails. + */ +export class WorkflowMigrationError extends Error { + /** + * Create WorkflowMigrationError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * Unable to Update Node Error + * Raised when a node cannot be updated. 
+ */ +export class NodeUpdateError extends Error { + /** + * Create NodeUpdateError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * FieldParseError + * Raised when a field cannot be parsed from a field schema. + */ +export class FieldParseError extends Error { + /** + * Create FieldTypeParseError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts new file mode 100644 index 00000000000..38f1af55dd8 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts @@ -0,0 +1,875 @@ +import { z } from 'zod'; + +import { + zBoardField, + zColorField, + zControlNetModelField, + zImageField, + zIPAdapterModelField, + zLoRAModelField, + zMainModelField, + zSchedulerField, + zT2IAdapterModelField, + zVAEModelField, +} from './common'; + +/** + * zod schemas & inferred types for fields. + * + * These schemas and types are only required for stateful field - fields that have UI components + * and allow the user to directly provide values. + * + * This includes primitive values (numbers, strings, booleans), models, scheduler, etc. + * + * If a field type does not have a UI component, then it does not need to be included here, because + * we never store its value. Such field types will be handled via the "StatelessField" logic. + * + * Fields require: + * - zFieldType - zod schema for the field type + * - zFieldValue - zod schema for the field value + * - zFieldInputInstance - zod schema for the field's input instance + * - zFieldOutputInstance - zod schema for the field's output instance + * - zFieldInputTemplate - zod schema for the field's input template + * - zFieldOutputTemplate - zod schema for the field's output template + * - inferred types for each schema + * - type guards for InputInstance and InputTemplate + * + * These then must be added to the unions at the bottom of this file. 
+ */ + +/** */ + +// #region Base schemas & misc +export const zFieldInput = z.enum(['connection', 'direct', 'any']); +export type FieldInput = z.infer; + +export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']); +export type FieldUIComponent = z.infer; + +export const zFieldInstanceBase = z.object({ + id: z.string().trim().min(1), + name: z.string().trim().min(1), +}); +export const zFieldInputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('input'), + label: z.string().nullish(), +}); +export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldInstanceBase = z.infer; +export type FieldInputInstanceBase = z.infer; +export type FieldOutputInstanceBase = z.infer; + +export const zFieldTemplateBase = z.object({ + name: z.string().min(1), + title: z.string().min(1), + description: z.string().nullish(), + ui_hidden: z.boolean(), + ui_type: z.string().nullish(), + ui_order: z.number().int().nullish(), +}); +export const zFieldInputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('input'), + input: zFieldInput, + required: z.boolean(), + ui_component: zFieldUIComponent.nullish(), + ui_choice_labels: z.record(z.string()).nullish(), +}); +export const zFieldOutputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldTemplateBase = z.infer; +export type FieldInputTemplateBase = z.infer; +export type FieldOutputTemplateBase = z.infer; + +export const zFieldTypeBase = z.object({ + isCollection: z.boolean(), + isCollectionOrScalar: z.boolean(), +}); + +export const zFieldIdentifier = z.object({ + nodeId: z.string().trim().min(1), + fieldName: z.string().trim().min(1), +}); +export type FieldIdentifier = z.infer; +export const isFieldIdentifier = (val: unknown): val is FieldIdentifier => zFieldIdentifier.safeParse(val).success; +// #endregion + +// #region IntegerField +export const zIntegerFieldType = zFieldTypeBase.extend({ + name: z.literal('IntegerField'), +}); +export const zIntegerFieldValue = z.number().int(); +export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zIntegerFieldType, + value: zIntegerFieldValue, +}); +export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zIntegerFieldType, +}); +export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zIntegerFieldType, + default: zIntegerFieldValue, + multipleOf: z.number().int().optional(), + maximum: z.number().int().optional(), + exclusiveMaximum: z.number().int().optional(), + minimum: z.number().int().optional(), + exclusiveMinimum: z.number().int().optional(), +}); +export const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zIntegerFieldType, +}); +export type IntegerFieldType = z.infer; +export type IntegerFieldValue = z.infer; +export type IntegerFieldInputInstance = z.infer; +export type IntegerFieldInputTemplate = z.infer; +export const isIntegerFieldInputInstance = (val: unknown): val is IntegerFieldInputInstance => + zIntegerFieldInputInstance.safeParse(val).success; +export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldInputTemplate => + zIntegerFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region FloatField +export const zFloatFieldType = zFieldTypeBase.extend({ + name: z.literal('FloatField'), +}); +export const zFloatFieldValue = z.number(); +export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ + 
type: zFloatFieldType, + value: zFloatFieldValue, +}); +export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zFloatFieldType, +}); +export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zFloatFieldType, + default: zFloatFieldValue, + multipleOf: z.number().optional(), + maximum: z.number().optional(), + exclusiveMaximum: z.number().optional(), + minimum: z.number().optional(), + exclusiveMinimum: z.number().optional(), +}); +export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zFloatFieldType, +}); +export type FloatFieldType = z.infer; +export type FloatFieldValue = z.infer; +export type FloatFieldInputInstance = z.infer; +export type FloatFieldOutputInstance = z.infer; +export type FloatFieldInputTemplate = z.infer; +export type FloatFieldOutputTemplate = z.infer; +export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance => + zFloatFieldInputInstance.safeParse(val).success; +export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputTemplate => + zFloatFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StringField +export const zStringFieldType = zFieldTypeBase.extend({ + name: z.literal('StringField'), +}); +export const zStringFieldValue = z.string(); +export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zStringFieldType, + value: zStringFieldValue, +}); +export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zStringFieldType, +}); +export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zStringFieldType, + default: zStringFieldValue, + maxLength: z.number().int().optional(), + minLength: z.number().int().optional(), +}); +export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zStringFieldType, +}); + +export type StringFieldType = z.infer; +export type StringFieldValue = z.infer; +export type StringFieldInputInstance = z.infer; +export type StringFieldOutputInstance = z.infer; +export type StringFieldInputTemplate = z.infer; +export type StringFieldOutputTemplate = z.infer; +export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance => + zStringFieldInputInstance.safeParse(val).success; +export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInputTemplate => + zStringFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BooleanField +export const zBooleanFieldType = zFieldTypeBase.extend({ + name: z.literal('BooleanField'), +}); +export const zBooleanFieldValue = z.boolean(); +export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zBooleanFieldType, + value: zBooleanFieldValue, +}); +export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBooleanFieldType, +}); +export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBooleanFieldType, + default: zBooleanFieldValue, +}); +export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBooleanFieldType, +}); +export type BooleanFieldType = z.infer; +export type BooleanFieldValue = z.infer; +export type BooleanFieldInputInstance = z.infer; +export type BooleanFieldOutputInstance = z.infer; +export type BooleanFieldInputTemplate = z.infer; +export type BooleanFieldOutputTemplate = z.infer; +export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance => + 
zBooleanFieldInputInstance.safeParse(val).success; +export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldInputTemplate => + zBooleanFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region EnumField +export const zEnumFieldType = zFieldTypeBase.extend({ + name: z.literal('EnumField'), +}); +export const zEnumFieldValue = z.string(); +export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zEnumFieldType, + value: zEnumFieldValue, +}); +export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zEnumFieldType, +}); +export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zEnumFieldType, + default: zEnumFieldValue, + options: z.array(z.string()), + labels: z.record(z.string()).optional(), +}); +export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zEnumFieldType, +}); +export type EnumFieldType = z.infer; +export type EnumFieldValue = z.infer; +export type EnumFieldInputInstance = z.infer; +export type EnumFieldOutputInstance = z.infer; +export type EnumFieldInputTemplate = z.infer; +export type EnumFieldOutputTemplate = z.infer; +export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance => + zEnumFieldInputInstance.safeParse(val).success; +export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTemplate => + zEnumFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ImageField +export const zImageFieldType = zFieldTypeBase.extend({ + name: z.literal('ImageField'), +}); +export const zImageFieldValue = zImageField.optional(); +export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zImageFieldType, + value: zImageFieldValue, +}); +export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zImageFieldType, +}); +export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zImageFieldType, + default: zImageFieldValue, +}); +export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zImageFieldType, +}); +export type ImageFieldType = z.infer; +export type ImageFieldValue = z.infer; +export type ImageFieldInputInstance = z.infer; +export type ImageFieldOutputInstance = z.infer; +export type ImageFieldInputTemplate = z.infer; +export type ImageFieldOutputTemplate = z.infer; +export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance => + zImageFieldInputInstance.safeParse(val).success; +export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate => + zImageFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BoardField +export const zBoardFieldType = zFieldTypeBase.extend({ + name: z.literal('BoardField'), +}); +export const zBoardFieldValue = zBoardField.optional(); +export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zBoardFieldType, + value: zBoardFieldValue, +}); +export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBoardFieldType, +}); +export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBoardFieldType, + default: zBoardFieldValue, +}); +export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBoardFieldType, +}); +export type BoardFieldType = z.infer; +export type BoardFieldValue = z.infer; +export type BoardFieldInputInstance = z.infer; +export type BoardFieldOutputInstance = z.infer; +export type 
BoardFieldInputTemplate = z.infer; +export type BoardFieldOutputTemplate = z.infer; +export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance => + zBoardFieldInputInstance.safeParse(val).success; +export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputTemplate => + zBoardFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ColorField +export const zColorFieldType = zFieldTypeBase.extend({ + name: z.literal('ColorField'), +}); +export const zColorFieldValue = zColorField.optional(); +export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zColorFieldType, + value: zColorFieldValue, +}); +export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zColorFieldType, +}); +export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zColorFieldType, + default: zColorFieldValue, +}); +export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zColorFieldType, +}); +export type ColorFieldType = z.infer; +export type ColorFieldValue = z.infer; +export type ColorFieldInputInstance = z.infer; +export type ColorFieldOutputInstance = z.infer; +export type ColorFieldInputTemplate = z.infer; +export type ColorFieldOutputTemplate = z.infer; +export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance => + zColorFieldInputInstance.safeParse(val).success; +export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputTemplate => + zColorFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region MainModelField +export const zMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('MainModelField'), +}); +export const zMainModelFieldValue = zMainModelField.optional(); +export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zMainModelFieldType, + value: zMainModelFieldValue, +}); +export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zMainModelFieldType, +}); +export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zMainModelFieldType, + default: zMainModelFieldValue, +}); +export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zMainModelFieldType, +}); +export type MainModelFieldType = z.infer; +export type MainModelFieldValue = z.infer; +export type MainModelFieldInputInstance = z.infer; +export type MainModelFieldOutputInstance = z.infer; +export type MainModelFieldInputTemplate = z.infer; +export type MainModelFieldOutputTemplate = z.infer; +export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance => + zMainModelFieldInputInstance.safeParse(val).success; +export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFieldInputTemplate => + zMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLMainModelField +export const zSDXLMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLMainModelField'), +}); +export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. 
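The TODO above leaves `zSDXLMainModelFieldValue` as a plain alias of the main-model value. One possible way to narrow it, assuming the `model_name` / `base_model` / `model_type` shapes defined in `common.ts`, is to pin `base_model` to the `'sdxl'` literal. This is an illustrative sketch only, with hypothetical model names:

```ts
import { z } from 'zod';

// Stand-ins for the shapes defined in common.ts.
const zModelName = z.string().min(3);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zMainModelField = z.object({
  model_name: zModelName,
  base_model: zBaseModel,
  model_type: z.literal('main'),
});

// One way to narrow the value to SDXL checkpoints only: pin base_model to the
// 'sdxl' literal instead of accepting the whole zBaseModel enum.
const zSDXLMainModelFieldValue = zMainModelField
  .extend({ base_model: z.literal('sdxl') })
  .optional();

// Accepted: an SDXL main model (hypothetical name).
zSDXLMainModelFieldValue.parse({
  model_name: 'stable-diffusion-xl-base-1-0',
  base_model: 'sdxl',
  model_type: 'main',
});

// Rejected: an SD 1.x model no longer satisfies the narrowed schema.
console.log(
  zSDXLMainModelFieldValue.safeParse({
    model_name: 'stable-diffusion-v1-5',
    base_model: 'sd-1',
    model_type: 'main',
  }).success, // false
);
```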
+export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSDXLMainModelFieldType, + value: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSDXLMainModelFieldType, +}); +export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSDXLMainModelFieldType, + default: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSDXLMainModelFieldType, +}); +export type SDXLMainModelFieldType = z.infer; +export type SDXLMainModelFieldValue = z.infer; +export type SDXLMainModelFieldInputInstance = z.infer; +export type SDXLMainModelFieldOutputInstance = z.infer; +export type SDXLMainModelFieldInputTemplate = z.infer; +export type SDXLMainModelFieldOutputTemplate = z.infer; +export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance => + zSDXLMainModelFieldInputInstance.safeParse(val).success; +export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMainModelFieldInputTemplate => + zSDXLMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLRefinerModelField +export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLRefinerModelField'), +}); +export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. +export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, + value: zSDXLRefinerModelFieldValue, +}); +export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, +}); +export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, + default: zSDXLRefinerModelFieldValue, +}); +export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, +}); +export type SDXLRefinerModelFieldType = z.infer; +export type SDXLRefinerModelFieldValue = z.infer; +export type SDXLRefinerModelFieldInputInstance = z.infer; +export type SDXLRefinerModelFieldOutputInstance = z.infer; +export type SDXLRefinerModelFieldInputTemplate = z.infer; +export type SDXLRefinerModelFieldOutputTemplate = z.infer; +export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance => + zSDXLRefinerModelFieldInputInstance.safeParse(val).success; +export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLRefinerModelFieldInputTemplate => + zSDXLRefinerModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region VAEModelField +export const zVAEModelFieldType = zFieldTypeBase.extend({ + name: z.literal('VAEModelField'), +}); +export const zVAEModelFieldValue = zVAEModelField.optional(); +export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zVAEModelFieldType, + value: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zVAEModelFieldType, +}); +export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zVAEModelFieldType, + default: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zVAEModelFieldType, +}); +export type VAEModelFieldType = z.infer; +export type VAEModelFieldValue = 
z.infer; +export type VAEModelFieldInputInstance = z.infer; +export type VAEModelFieldOutputInstance = z.infer; +export type VAEModelFieldInputTemplate = z.infer; +export type VAEModelFieldOutputTemplate = z.infer; +export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance => + zVAEModelFieldInputInstance.safeParse(val).success; +export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelFieldInputTemplate => + zVAEModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region LoRAModelField +export const zLoRAModelFieldType = zFieldTypeBase.extend({ + name: z.literal('LoRAModelField'), +}); +export const zLoRAModelFieldValue = zLoRAModelField.optional(); +export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zLoRAModelFieldType, + value: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zLoRAModelFieldType, +}); +export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zLoRAModelFieldType, + default: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zLoRAModelFieldType, +}); +export type LoRAModelFieldType = z.infer; +export type LoRAModelFieldValue = z.infer; +export type LoRAModelFieldInputInstance = z.infer; +export type LoRAModelFieldOutputInstance = z.infer; +export type LoRAModelFieldInputTemplate = z.infer; +export type LoRAModelFieldOutputTemplate = z.infer; +export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance => + zLoRAModelFieldInputInstance.safeParse(val).success; +export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFieldInputTemplate => + zLoRAModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ControlNetModelField +export const zControlNetModelFieldType = zFieldTypeBase.extend({ + name: z.literal('ControlNetModelField'), +}); +export const zControlNetModelFieldValue = zControlNetModelField.optional(); +export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zControlNetModelFieldType, + value: zControlNetModelFieldValue, +}); +export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zControlNetModelFieldType, +}); +export const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zControlNetModelFieldType, + default: zControlNetModelFieldValue, +}); +export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zControlNetModelFieldType, +}); +export type ControlNetModelFieldType = z.infer; +export type ControlNetModelFieldValue = z.infer; +export type ControlNetModelFieldInputInstance = z.infer; +export type ControlNetModelFieldOutputInstance = z.infer; +export type ControlNetModelFieldInputTemplate = z.infer; +export type ControlNetModelFieldOutputTemplate = z.infer; +export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance => + zControlNetModelFieldInputInstance.safeParse(val).success; +export const isControlNetModelFieldInputTemplate = (val: unknown): val is ControlNetModelFieldInputTemplate => + zControlNetModelFieldInputTemplate.safeParse(val).success; +export const isControlNetModelFieldValue = (v: unknown): v is ControlNetModelFieldValue => + zControlNetModelFieldValue.safeParse(v).success; +// #endregion + +// #region IPAdapterModelField +export const 
zIPAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('IPAdapterModelField'), +}); +export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional(); +export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zIPAdapterModelFieldType, + value: zIPAdapterModelFieldValue, +}); +export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zIPAdapterModelFieldType, +}); +export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zIPAdapterModelFieldType, + default: zIPAdapterModelFieldValue, +}); +export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zIPAdapterModelFieldType, +}); +export type IPAdapterModelFieldType = z.infer; +export type IPAdapterModelFieldValue = z.infer; +export type IPAdapterModelFieldInputInstance = z.infer; +export type IPAdapterModelFieldOutputInstance = z.infer; +export type IPAdapterModelFieldInputTemplate = z.infer; +export type IPAdapterModelFieldOutputTemplate = z.infer; +export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance => + zIPAdapterModelFieldInputInstance.safeParse(val).success; +export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapterModelFieldInputTemplate => + zIPAdapterModelFieldInputTemplate.safeParse(val).success; +export const isIPAdapterModelFieldValue = (val: unknown): val is IPAdapterModelFieldValue => + zIPAdapterModelFieldValue.safeParse(val).success; +// #endregion + +// #region T2IAdapterField +export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('T2IAdapterModelField'), +}); +export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional(); +export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, + value: zT2IAdapterModelFieldValue, +}); +export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, +}); +export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, + default: zT2IAdapterModelFieldValue, +}); +export const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, +}); +export type T2IAdapterModelFieldType = z.infer; +export type T2IAdapterModelFieldValue = z.infer; +export type T2IAdapterModelFieldInputInstance = z.infer; +export type T2IAdapterModelFieldOutputInstance = z.infer; +export type T2IAdapterModelFieldInputTemplate = z.infer; +export type T2IAdapterModelFieldOutputTemplate = z.infer; +export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance => + zT2IAdapterModelFieldInputInstance.safeParse(val).success; +export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAdapterModelFieldInputTemplate => + zT2IAdapterModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SchedulerField +export const zSchedulerFieldType = zFieldTypeBase.extend({ + name: z.literal('SchedulerField'), +}); +export const zSchedulerFieldValue = zSchedulerField.optional(); +export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSchedulerFieldType, + value: zSchedulerFieldValue, +}); +export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSchedulerFieldType, +}); +export const zSchedulerFieldInputTemplate 
= zFieldInputTemplateBase.extend({ + type: zSchedulerFieldType, + default: zSchedulerFieldValue, +}); +export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSchedulerFieldType, +}); +export type SchedulerFieldType = z.infer; +export type SchedulerFieldValue = z.infer; +export type SchedulerFieldInputInstance = z.infer; +export type SchedulerFieldOutputInstance = z.infer; +export type SchedulerFieldInputTemplate = z.infer; +export type SchedulerFieldOutputTemplate = z.infer; +export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance => + zSchedulerFieldInputInstance.safeParse(val).success; +export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFieldInputTemplate => + zSchedulerFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StatelessField +/** + * StatelessField is a catchall for stateless fields with no UI input components. They do not + * do not support "direct" input, instead only accepting connections from other fields. + * + * This field type serves as a "generic" field type. + * + * Examples include: + * - Fields like UNetField or LatentsField where we do not allow direct UI input + * - Reserved fields like IsIntermediate + * - Any other field we don't have full-on schemas for + */ +export const zStatelessFieldType = zFieldTypeBase.extend({ + name: z.string().min(1), // stateless --> we accept the field's name as the type +}); +export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling +export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zStatelessFieldType, + value: zStatelessFieldValue, +}); +export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zStatelessFieldType, +}); +export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zStatelessFieldType, + default: zStatelessFieldValue, + input: z.literal('connection'), // stateless --> only accepts connection inputs +}); +export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zStatelessFieldType, +}); + +export type StatelessFieldType = z.infer; +export type StatelessFieldValue = z.infer; +export type StatelessFieldInputInstance = z.infer; +export type StatelessFieldOutputInstance = z.infer; +export type StatelessFieldInputTemplate = z.infer; +export type StatelessFieldOutputTemplate = z.infer; +// #endregion + +/** + * Here we define the main field unions: + * - FieldType + * - FieldValue + * - FieldInputInstance + * - FieldOutputInstance + * - FieldInputTemplate + * - FieldOutputTemplate + * + * All stateful fields are unioned together, and then that union is unioned with StatelessField. + * + * This allows us to interact with stateful fields without needing to worry about "generic" handling + * for all other StatelessFields. 
+ */ + +// #region StatefulFieldType & FieldType +export const zStatefulFieldType = z.union([ + zIntegerFieldType, + zFloatFieldType, + zStringFieldType, + zBooleanFieldType, + zEnumFieldType, + zImageFieldType, + zBoardFieldType, + zMainModelFieldType, + zSDXLMainModelFieldType, + zSDXLRefinerModelFieldType, + zVAEModelFieldType, + zLoRAModelFieldType, + zControlNetModelFieldType, + zIPAdapterModelFieldType, + zT2IAdapterModelFieldType, + zColorFieldType, + zSchedulerFieldType, +]); +export type StatefulFieldType = z.infer; +export const isStatefulFieldType = (val: unknown): val is StatefulFieldType => + zStatefulFieldType.safeParse(val).success; + +export const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]); +export type FieldType = z.infer; +export const isFieldType = (val: unknown): val is FieldType => zFieldType.safeParse(val).success; +// #endregion + +// #region StatefulFieldValue & FieldValue +export const zStatefulFieldValue = z.union([ + zIntegerFieldValue, + zFloatFieldValue, + zStringFieldValue, + zBooleanFieldValue, + zEnumFieldValue, + zImageFieldValue, + zBoardFieldValue, + zMainModelFieldValue, + zSDXLMainModelFieldValue, + zSDXLRefinerModelFieldValue, + zVAEModelFieldValue, + zLoRAModelFieldValue, + zControlNetModelFieldValue, + zIPAdapterModelFieldValue, + zT2IAdapterModelFieldValue, + zColorFieldValue, + zSchedulerFieldValue, +]); +export type StatefulFieldValue = z.infer; +export const isStatefulFieldValue = (val: unknown): val is StatefulFieldValue => + zStatefulFieldValue.safeParse(val).success; + +export const zFieldValue = z.union([zStatefulFieldValue, zStatelessFieldValue]); +export type FieldValue = z.infer; +export const isFieldValue = (val: unknown): val is FieldValue => zFieldValue.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputInstance & FieldInputInstance +export const zStatefulFieldInputInstance = z.union([ + zIntegerFieldInputInstance, + zFloatFieldInputInstance, + zStringFieldInputInstance, + zBooleanFieldInputInstance, + zEnumFieldInputInstance, + zImageFieldInputInstance, + zBoardFieldInputInstance, + zMainModelFieldInputInstance, + zSDXLMainModelFieldInputInstance, + zSDXLRefinerModelFieldInputInstance, + zVAEModelFieldInputInstance, + zLoRAModelFieldInputInstance, + zControlNetModelFieldInputInstance, + zIPAdapterModelFieldInputInstance, + zT2IAdapterModelFieldInputInstance, + zColorFieldInputInstance, + zSchedulerFieldInputInstance, +]); +export type StatefulFieldInputInstance = z.infer; +export const isStatefulFieldInputInstance = (val: unknown): val is StatefulFieldInputInstance => + zStatefulFieldInputInstance.safeParse(val).success; + +export const zFieldInputInstance = z.union([zStatefulFieldInputInstance, zStatelessFieldInputInstance]); +export type FieldInputInstance = z.infer; +export const isFieldInputInstance = (val: unknown): val is FieldInputInstance => + zFieldInputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputInstance & FieldOutputInstance +export const zStatefulFieldOutputInstance = z.union([ + zIntegerFieldOutputInstance, + zFloatFieldOutputInstance, + zStringFieldOutputInstance, + zBooleanFieldOutputInstance, + zEnumFieldOutputInstance, + zImageFieldOutputInstance, + zBoardFieldOutputInstance, + zMainModelFieldOutputInstance, + zSDXLMainModelFieldOutputInstance, + zSDXLRefinerModelFieldOutputInstance, + zVAEModelFieldOutputInstance, + zLoRAModelFieldOutputInstance, + zControlNetModelFieldOutputInstance, + zIPAdapterModelFieldOutputInstance, + 
zT2IAdapterModelFieldOutputInstance, + zColorFieldOutputInstance, + zSchedulerFieldOutputInstance, +]); +export type StatefulFieldOutputInstance = z.infer; +export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance => + zStatefulFieldOutputInstance.safeParse(val).success; + +export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]); +export type FieldOutputInstance = z.infer; +export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance => + zFieldOutputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputTemplate & FieldInputTemplate +export const zStatefulFieldInputTemplate = z.union([ + zIntegerFieldInputTemplate, + zFloatFieldInputTemplate, + zStringFieldInputTemplate, + zBooleanFieldInputTemplate, + zEnumFieldInputTemplate, + zImageFieldInputTemplate, + zBoardFieldInputTemplate, + zMainModelFieldInputTemplate, + zSDXLMainModelFieldInputTemplate, + zSDXLRefinerModelFieldInputTemplate, + zVAEModelFieldInputTemplate, + zLoRAModelFieldInputTemplate, + zControlNetModelFieldInputTemplate, + zIPAdapterModelFieldInputTemplate, + zT2IAdapterModelFieldInputTemplate, + zColorFieldInputTemplate, + zSchedulerFieldInputTemplate, + zStatelessFieldInputTemplate, +]); +export type StatefulFieldInputTemplate = z.infer; +export const isStatefulFieldInputTemplate = (val: unknown): val is StatefulFieldInputTemplate => + zStatefulFieldInputTemplate.safeParse(val).success; + +export const zFieldInputTemplate = z.union([zStatefulFieldInputTemplate, zStatelessFieldInputTemplate]); +export type FieldInputTemplate = z.infer; +export const isFieldInputTemplate = (val: unknown): val is FieldInputTemplate => + zFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputTemplate & FieldOutputTemplate +export const zStatefulFieldOutputTemplate = z.union([ + zIntegerFieldOutputTemplate, + zFloatFieldOutputTemplate, + zStringFieldOutputTemplate, + zBooleanFieldOutputTemplate, + zEnumFieldOutputTemplate, + zImageFieldOutputTemplate, + zBoardFieldOutputTemplate, + zMainModelFieldOutputTemplate, + zSDXLMainModelFieldOutputTemplate, + zSDXLRefinerModelFieldOutputTemplate, + zVAEModelFieldOutputTemplate, + zLoRAModelFieldOutputTemplate, + zControlNetModelFieldOutputTemplate, + zIPAdapterModelFieldOutputTemplate, + zT2IAdapterModelFieldOutputTemplate, + zColorFieldOutputTemplate, + zSchedulerFieldOutputTemplate, +]); +export type StatefulFieldOutputTemplate = z.infer; +export const isStatefulFieldOutputTemplate = (val: unknown): val is StatefulFieldOutputTemplate => + zStatefulFieldOutputTemplate.safeParse(val).success; + +export const zFieldOutputTemplate = z.union([zStatefulFieldOutputTemplate, zStatelessFieldOutputTemplate]); +export type FieldOutputTemplate = z.infer; +export const isFieldOutputTemplate = (val: unknown): val is FieldOutputTemplate => + zFieldOutputTemplate.safeParse(val).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts new file mode 100644 index 00000000000..86ec70fd9bd --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts @@ -0,0 +1,93 @@ +import type { Edge, Node } from 'reactflow'; +import { z } from 'zod'; + +import { zClassification, zProgressImage } from './common'; +import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field'; +import { 
zSemVer } from './semver'; + +// #region InvocationTemplate +export const zInvocationTemplate = z.object({ + type: z.string(), + title: z.string(), + description: z.string(), + tags: z.array(z.string().min(1)), + inputs: z.record(zFieldInputTemplate), + outputs: z.record(zFieldOutputTemplate), + outputType: z.string().min(1), + version: zSemVer, + useCache: z.boolean(), + nodePack: z.string().min(1).nullish(), + classification: zClassification, +}); +export type InvocationTemplate = z.infer; +// #endregion + +// #region NodeData +export const zInvocationNodeData = z.object({ + id: z.string().trim().min(1), + type: z.string().trim().min(1), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), + isIntermediate: z.boolean(), + useCache: z.boolean(), + version: zSemVer, + nodePack: z.string().min(1).nullish(), + inputs: z.record(zFieldInputInstance), + outputs: z.record(zFieldOutputInstance), +}); + +export const zNotesNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), +}); +export const zCurrentImageNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('current_image'), + label: z.string(), + isOpen: z.boolean(), +}); +export const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData]); + +export type NotesNodeData = z.infer; +export type InvocationNodeData = z.infer; +export type CurrentImageNodeData = z.infer; +export type AnyNodeData = z.infer; + +export type InvocationNode = Node; +export type NotesNode = Node; +export type CurrentImageNode = Node; +export type AnyNode = Node; + +export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation'); +export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes'); +export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode => + Boolean(node && node.type === 'current_image'); +export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData => + Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type +// #endregion + +// #region NodeExecutionState +export const zNodeStatus = z.enum(['PENDING', 'IN_PROGRESS', 'COMPLETED', 'FAILED']); +export const zNodeExecutionState = z.object({ + nodeId: z.string().trim().min(1), + status: zNodeStatus, + progress: z.number().nullable(), + progressImage: zProgressImage.nullable(), + error: z.string().nullable(), + outputs: z.array(z.any()), +}); +export type NodeExecutionState = z.infer; +export type NodeStatus = z.infer; +// #endregion + +// #region Edges +export const zInvocationNodeEdgeExtra = z.object({ + type: z.union([z.literal('default'), z.literal('collapsed')]), +}); +export type InvocationNodeEdgeExtra = z.infer; +export type InvocationNodeEdge = Edge; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts b/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts new file mode 100644 index 00000000000..0cc30499e38 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts @@ -0,0 +1,77 @@ +import { z } from 'zod'; + +import { + zControlField, + zIPAdapterField, + zLoRAModelField, + zMainModelField, + zSDXLRefinerModelField, + zT2IAdapterField, + zVAEModelField, +} from './common'; + +// #region Metadata-optimized versions of schemas +// TODO: It's possible that `deepPartial` 
will be deprecated: +// - https://github.com/colinhacks/zod/issues/2106 +// - https://github.com/colinhacks/zod/issues/2854 +export const zLoRAMetadataItem = z.object({ + lora: zLoRAModelField.deepPartial(), + weight: z.number(), +}); +const zControlNetMetadataItem = zControlField.deepPartial(); +const zIPAdapterMetadataItem = zIPAdapterField.deepPartial(); +const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial(); +const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial(); +const zModelMetadataItem = zMainModelField.deepPartial(); +const zVAEModelMetadataItem = zVAEModelField.deepPartial(); +export type LoRAMetadataItem = z.infer; +export type ControlNetMetadataItem = z.infer; +export type IPAdapterMetadataItem = z.infer; +export type T2IAdapterMetadataItem = z.infer; +export type SDXLRefinerModelMetadataItem = z.infer; +export type ModelMetadataItem = z.infer; +export type VAEModelMetadataItem = z.infer; +// #endregion + +// #region CoreMetadata +export const zCoreMetadata = z + .object({ + app_version: z.string().nullish().catch(null), + generation_mode: z.string().nullish().catch(null), + created_by: z.string().nullish().catch(null), + positive_prompt: z.string().nullish().catch(null), + negative_prompt: z.string().nullish().catch(null), + width: z.number().int().nullish().catch(null), + height: z.number().int().nullish().catch(null), + seed: z.number().int().nullish().catch(null), + rand_device: z.string().nullish().catch(null), + cfg_scale: z.number().nullish().catch(null), + cfg_rescale_multiplier: z.number().nullish().catch(null), + steps: z.number().int().nullish().catch(null), + scheduler: z.string().nullish().catch(null), + clip_skip: z.number().int().nullish().catch(null), + model: zModelMetadataItem.nullish().catch(null), + controlnets: z.array(zControlNetMetadataItem).nullish().catch(null), + ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null), + t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null), + loras: z.array(zLoRAMetadataItem).nullish().catch(null), + vae: zVAEModelMetadataItem.nullish().catch(null), + strength: z.number().nullish().catch(null), + hrf_enabled: z.boolean().nullish().catch(null), + hrf_strength: z.number().nullish().catch(null), + hrf_method: z.string().nullish().catch(null), + init_image: z.string().nullish().catch(null), + positive_style_prompt: z.string().nullish().catch(null), + negative_style_prompt: z.string().nullish().catch(null), + refiner_model: zSDXLRefinerModelMetadataItem.nullish().catch(null), + refiner_cfg_scale: z.number().nullish().catch(null), + refiner_steps: z.number().int().nullish().catch(null), + refiner_scheduler: z.string().nullish().catch(null), + refiner_positive_aesthetic_score: z.number().nullish().catch(null), + refiner_negative_aesthetic_score: z.number().nullish().catch(null), + refiner_start: z.number().nullish().catch(null), + }) + .passthrough(); +export type CoreMetadata = z.infer; + +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts b/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts new file mode 100644 index 00000000000..83d774439a3 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts @@ -0,0 +1,86 @@ +import type { OpenAPIV3_1 } from 'openapi-types'; +import type { + InputFieldJSONSchemaExtra, + InvocationJSONSchemaExtra, + OutputFieldJSONSchemaExtra, +} from 'services/api/types'; + +// Janky customization of OpenAPI Schema :/ + +export type InvocationSchemaExtra = 
InvocationJSONSchemaExtra & { + output: OpenAPIV3_1.ReferenceObject; // the output of the invocation + title: string; + category?: string; + tags?: string[]; + version: string; + properties: Omit< + NonNullable & (InputFieldJSONSchemaExtra | OutputFieldJSONSchemaExtra), + 'type' + > & { + type: Omit & { + default: string; + }; + use_cache: Omit & { + default: boolean; + }; + }; +}; + +export type InvocationSchemaType = { + default: string; // the type of the invocation +}; + +export type InvocationBaseSchemaObject = Omit & + InvocationSchemaExtra; + +export type InvocationOutputSchemaObject = Omit & { + properties: OpenAPIV3_1.SchemaObject['properties'] & { + type: Omit & { + default: string; + }; + } & { + class: 'output'; + }; +}; + +export type InvocationFieldSchema = OpenAPIV3_1.SchemaObject & InputFieldJSONSchemaExtra; + +export type OpenAPIV3_1SchemaOrRef = OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject; + +export interface ArraySchemaObject extends InvocationBaseSchemaObject { + type: OpenAPIV3_1.ArraySchemaObjectType; + items: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject; +} +export interface NonArraySchemaObject extends InvocationBaseSchemaObject { + type?: OpenAPIV3_1.NonArraySchemaObjectType; +} + +export type InvocationSchemaObject = (ArraySchemaObject | NonArraySchemaObject) & { class: 'invocation' }; + +export const isSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.SchemaObject => Boolean(obj && !('$ref' in obj)); + +export const isArraySchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.ArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type === 'array'); + +export const isNonArraySchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.NonArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type !== 'array'); + +export const isRefObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.ReferenceObject => Boolean(obj && '$ref' in obj); + +export const isInvocationSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationSchemaObject +): obj is InvocationSchemaObject => 'class' in obj && obj.class === 'invocation'; + +export const isInvocationOutputSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationOutputSchemaObject +): obj is InvocationOutputSchemaObject => 'class' in obj && obj.class === 'output'; + +export const isInvocationFieldSchema = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject +): obj is InvocationFieldSchema => !('$ref' in obj); diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts b/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts new file mode 100644 index 00000000000..3ba330eac47 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts @@ -0,0 +1,21 @@ +import { z } from 'zod'; + +// Schemas and types for working with semver + +const zVersionInt = z.coerce.number().int().min(0); + +export const zSemVer = z.string().refine((val) => { + const [major, minor, patch] = val.split('.'); + return ( + zVersionInt.safeParse(major).success && zVersionInt.safeParse(minor).success && zVersionInt.safeParse(patch).success + ); +}); + +export const zParsedSemver = zSemVer.transform((val) => { + const [major, minor, patch] = val.split('.'); + return { + major: 
Number(major), + minor: Number(minor), + patch: Number(patch), + }; +}); diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts new file mode 100644 index 00000000000..723a354013b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts @@ -0,0 +1,89 @@ +import { z } from 'zod'; + +import { zFieldIdentifier } from './field'; +import { zInvocationNodeData, zNotesNodeData } from './invocation'; + +// #region Workflow misc +export const zXYPosition = z + .object({ + x: z.number(), + y: z.number(), + }) + .default({ x: 0, y: 0 }); +export type XYPosition = z.infer; + +export const zDimension = z.number().gt(0).nullish(); +export type Dimension = z.infer; + +export const zWorkflowCategory = z.enum(['user', 'default', 'project']); +export type WorkflowCategory = z.infer; +// #endregion + +// #region Workflow Nodes +export const zWorkflowInvocationNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('invocation'), + data: zInvocationNodeData, + width: zDimension, + height: zDimension, + position: zXYPosition, +}); +export const zWorkflowNotesNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + data: zNotesNodeData, + width: zDimension, + height: zDimension, + position: zXYPosition, +}); +export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]); + +export type WorkflowInvocationNode = z.infer; +export type WorkflowNotesNode = z.infer; +export type WorkflowNode = z.infer; + +export const isWorkflowInvocationNode = (val: unknown): val is WorkflowInvocationNode => + zWorkflowInvocationNode.safeParse(val).success; +// #endregion + +// #region Workflow Edges +export const zWorkflowEdgeBase = z.object({ + id: z.string().trim().min(1), + source: z.string().trim().min(1), + target: z.string().trim().min(1), +}); +export const zWorkflowEdgeDefault = zWorkflowEdgeBase.extend({ + type: z.literal('default'), + sourceHandle: z.string().trim().min(1), + targetHandle: z.string().trim().min(1), +}); +export const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({ + type: z.literal('collapsed'), +}); +export const zWorkflowEdge = z.union([zWorkflowEdgeDefault, zWorkflowEdgeCollapsed]); + +export type WorkflowEdgeDefault = z.infer; +export type WorkflowEdgeCollapsed = z.infer; +export type WorkflowEdge = z.infer; +// #endregion + +// #region Workflow +export const zWorkflowV2 = z.object({ + id: z.string().min(1).optional(), + name: z.string(), + author: z.string(), + description: z.string(), + version: z.string(), + contact: z.string(), + tags: z.string(), + notes: z.string(), + nodes: z.array(zWorkflowNode), + edges: z.array(zWorkflowEdge), + exposedFields: z.array(zFieldIdentifier), + meta: z.object({ + category: zWorkflowCategory.default('user'), + version: z.literal('2.0.0'), + }), +}); +export type WorkflowV2 = z.infer; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.ts index 723a354013b..adad7c0f219 100644 --- a/invokeai/frontend/web/src/features/nodes/types/workflow.ts +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.ts @@ -24,16 +24,12 @@ export const zWorkflowInvocationNode = z.object({ id: z.string().trim().min(1), type: z.literal('invocation'), data: zInvocationNodeData, - width: zDimension, - height: zDimension, position: zXYPosition, }); export const zWorkflowNotesNode = z.object({ id: z.string().trim().min(1), type: 
z.literal('notes'), data: zNotesNodeData, - width: zDimension, - height: zDimension, position: zXYPosition, }); export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]); @@ -68,7 +64,7 @@ export type WorkflowEdge = z.infer; // #endregion // #region Workflow -export const zWorkflowV2 = z.object({ +export const zWorkflowV3 = z.object({ id: z.string().min(1).optional(), name: z.string(), author: z.string(), @@ -82,8 +78,8 @@ export const zWorkflowV2 = z.object({ exposedFields: z.array(zFieldIdentifier), meta: z.object({ category: zWorkflowCategory.default('user'), - version: z.literal('2.0.0'), + version: z.literal('3.0.0'), }), }); -export type WorkflowV2 = z.infer; +export type WorkflowV3 = z.infer; // #endregion diff --git a/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts index ea40bd4660f..af19aa86eaf 100644 --- a/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts @@ -1,5 +1,5 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; -import type { FieldInputInstance, FieldOutputInstance } from 'features/nodes/types/field'; +import type { FieldInputInstance } from 'features/nodes/types/field'; import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation'; import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance'; import { reduce } from 'lodash-es'; @@ -24,25 +24,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe {} as Record ); - const outputs = reduce( - template.outputs, - (outputsAccumulator, outputTemplate, outputName) => { - const fieldId = uuidv4(); - - const outputFieldValue: FieldOutputInstance = { - id: fieldId, - name: outputName, - type: outputTemplate.type, - fieldKind: 'output', - }; - - outputsAccumulator[outputName] = outputFieldValue; - - return outputsAccumulator; - }, - {} as Record - ); - const node: InvocationNode = { ...SHARED_NODE_PROPERTIES, id: nodeId, @@ -58,7 +39,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe isIntermediate: type === 'save_image' ? 
false : true, useCache: template.useCache, inputs, - outputs, }, }; diff --git a/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts b/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts index f195c49d30a..5ece51d0f30 100644 --- a/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts @@ -54,6 +54,5 @@ export const updateNode = (node: InvocationNode, template: InvocationTemplate): // Remove any fields that are not in the template clone.data.inputs = pick(clone.data.inputs, keys(defaults.data.inputs)); - clone.data.outputs = pick(clone.data.outputs, keys(defaults.data.outputs)); return clone; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index dd3cf0ad7b4..f8097566c95 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -23,11 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record = export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => { const fieldInstance: FieldInputInstance = { - id, name: template.name, - type: template.type, label: '', - fieldKind: 'input' as const, value: template.default ?? get(FIELD_VALUE_FALLBACK_MAP, template.type.name), }; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts index 70775a98823..720da164648 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts @@ -2,8 +2,8 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; import type { NodesState, WorkflowsState } from 'features/nodes/store/types'; import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; -import { zWorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; +import { zWorkflowV3 } from 'features/nodes/types/workflow'; import i18n from 'i18n'; import { cloneDeep, pick } from 'lodash-es'; import { fromZodError } from 'zod-validation-error'; @@ -25,14 +25,14 @@ const workflowKeys = [ 'exposedFields', 'meta', 'id', -] satisfies (keyof WorkflowV2)[]; +] satisfies (keyof WorkflowV3)[]; -export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV2; +export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV3; -export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 => { +export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 => { const clonedWorkflow = pick(cloneDeep(workflow), workflowKeys); - const newWorkflow: WorkflowV2 = { + const newWorkflow: WorkflowV3 = { ...clonedWorkflow, nodes: [], edges: [], @@ -45,8 +45,6 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo type: node.type, data: cloneDeep(node.data), position: { ...node.position }, - width: node.width, - height: node.height, }); } else if (isNotesNode(node) && node.type) { newWorkflow.nodes.push({ @@ -54,8 +52,6 @@ export const buildWorkflowFast: 
BuildWorkflowFunction = ({ nodes, edges, workflo type: node.type, data: cloneDeep(node.data), position: { ...node.position }, - width: node.width, - height: node.height, }); } }); @@ -83,12 +79,12 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo return newWorkflow; }; -export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 | null => { +export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 | null => { // builds what really, really should be a valid workflow const workflowToValidate = buildWorkflowFast({ nodes, edges, workflow }); // but bc we are storing this in the DB, let's be extra sure - const result = zWorkflowV2.safeParse(workflowToValidate); + const result = zWorkflowV3.safeParse(workflowToValidate); if (!result.success) { const { message } = fromZodError(result.error, { diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts index a2677f3d174..a023c96ba92 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts @@ -6,8 +6,10 @@ import { zSemVer } from 'features/nodes/types/semver'; import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from 'features/nodes/types/v1/fieldTypeMap'; import type { WorkflowV1 } from 'features/nodes/types/v1/workflowV1'; import { zWorkflowV1 } from 'features/nodes/types/v1/workflowV1'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; -import { zWorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV2 } from 'features/nodes/types/v2/workflow'; +import { zWorkflowV2 } from 'features/nodes/types/v2/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; +import { zWorkflowV3 } from 'features/nodes/types/workflow'; import { t } from 'i18next'; import { forEach } from 'lodash-es'; import { z } from 'zod'; @@ -30,7 +32,7 @@ const zWorkflowMetaVersion = z.object({ * - Workflow schema version bumped to 2.0.0 */ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { - const invocationTemplates = $store.get()?.getState().nodeTemplates.templates; + const invocationTemplates = $store.get()?.getState().nodes.templates; if (!invocationTemplates) { throw new Error(t('app.storeNotInitialized')); @@ -70,26 +72,34 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { return zWorkflowV2.parse(workflowToMigrate); }; +const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => { + // Bump version + (workflowToMigrate as unknown as WorkflowV3).meta.version = '3.0.0'; + // Parsing strips out any extra properties not in the latest version + return zWorkflowV3.parse(workflowToMigrate); +}; + /** * Parses a workflow and migrates it to the latest version if necessary. 
*/ -export const parseAndMigrateWorkflow = (data: unknown): WorkflowV2 => { +export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => { const workflowVersionResult = zWorkflowMetaVersion.safeParse(data); if (!workflowVersionResult.success) { throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion')); } - const { version } = workflowVersionResult.data.meta; + let workflow = data as WorkflowV1 | WorkflowV2 | WorkflowV3; - if (version === '1.0.0') { - const v1 = zWorkflowV1.parse(data); - return migrateV1toV2(v1); + if (workflow.meta.version === '1.0.0') { + const v1 = zWorkflowV1.parse(workflow); + workflow = migrateV1toV2(v1); } - if (version === '2.0.0') { - return zWorkflowV2.parse(data); + if (workflow.meta.version === '2.0.0') { + const v2 = zWorkflowV2.parse(workflow); + workflow = migrateV2toV3(v2); } - throw new WorkflowVersionError(t('nodes.unrecognizedWorkflowVersion', { version })); + return workflow as WorkflowV3; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index 848d2aee77a..5096e588b06 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -1,6 +1,6 @@ import { parseify } from 'common/util/serialize'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { isWorkflowInvocationNode } from 'features/nodes/types/workflow'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import { t } from 'i18next'; @@ -16,7 +16,7 @@ type WorkflowWarning = { }; type ValidateWorkflowResult = { - workflow: WorkflowV2; + workflow: WorkflowV3; warnings: WorkflowWarning[]; }; diff --git a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts index 5d484b68971..7b49d70213f 100644 --- a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts +++ b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts @@ -3,7 +3,7 @@ import { useToast } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import { $builtWorkflow } from 'features/nodes/hooks/useWorkflowWatcher'; import { workflowIDChanged, workflowSaved } from 'features/nodes/store/workflowSlice'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { workflowUpdated } from 'features/workflowLibrary/store/actions'; import { useCallback, useRef } from 'react'; import { useTranslation } from 'react-i18next'; @@ -18,7 +18,7 @@ type UseSaveLibraryWorkflowReturn = { type UseSaveLibraryWorkflow = () => UseSaveLibraryWorkflowReturn; -export const isWorkflowWithID = (workflow: WorkflowV2): workflow is O.Required => +export const isWorkflowWithID = (workflow: WorkflowV3): workflow is O.Required => Boolean(workflow.id); export const useSaveLibraryWorkflow: UseSaveLibraryWorkflow = () => { From 88b2fbefeb65c7302bc9b7bd1a2723f639afd328 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 13 Feb 2024 22:51:44 +1100 Subject: [PATCH 069/340] feat(ui): add vitest - Add vitest. 
- Consolidate vite configs into single file (easier to config everything based on env for testing) --- invokeai/frontend/web/config/common.mts | 12 - .../frontend/web/config/vite.app.config.mts | 33 --- .../web/config/vite.package.config.mts | 46 ---- invokeai/frontend/web/package.json | 7 +- invokeai/frontend/web/pnpm-lock.yaml | 222 +++++++++++++++++- invokeai/frontend/web/vite.config.mts | 88 ++++++- 6 files changed, 306 insertions(+), 102 deletions(-) delete mode 100644 invokeai/frontend/web/config/common.mts delete mode 100644 invokeai/frontend/web/config/vite.app.config.mts delete mode 100644 invokeai/frontend/web/config/vite.package.config.mts diff --git a/invokeai/frontend/web/config/common.mts b/invokeai/frontend/web/config/common.mts deleted file mode 100644 index fd559cabd1e..00000000000 --- a/invokeai/frontend/web/config/common.mts +++ /dev/null @@ -1,12 +0,0 @@ -import react from '@vitejs/plugin-react-swc'; -import { visualizer } from 'rollup-plugin-visualizer'; -import type { PluginOption, UserConfig } from 'vite'; -import eslint from 'vite-plugin-eslint'; -import tsconfigPaths from 'vite-tsconfig-paths'; - -export const commonPlugins: UserConfig['plugins'] = [ - react(), - eslint(), - tsconfigPaths(), - visualizer() as unknown as PluginOption, -]; diff --git a/invokeai/frontend/web/config/vite.app.config.mts b/invokeai/frontend/web/config/vite.app.config.mts deleted file mode 100644 index 9683ed26a48..00000000000 --- a/invokeai/frontend/web/config/vite.app.config.mts +++ /dev/null @@ -1,33 +0,0 @@ -import type { UserConfig } from 'vite'; - -import { commonPlugins } from './common.mjs'; - -export const appConfig: UserConfig = { - base: './', - plugins: [...commonPlugins], - build: { - chunkSizeWarningLimit: 1500, - }, - server: { - // Proxy HTTP requests to the flask server - proxy: { - // Proxy socket.io to the nodes socketio server - '/ws/socket.io': { - target: 'ws://127.0.0.1:9090', - ws: true, - }, - // Proxy openapi schema definiton - '/openapi.json': { - target: 'http://127.0.0.1:9090/openapi.json', - rewrite: (path) => path.replace(/^\/openapi.json/, ''), - changeOrigin: true, - }, - // proxy nodes api - '/api/v1': { - target: 'http://127.0.0.1:9090/api/v1', - rewrite: (path) => path.replace(/^\/api\/v1/, ''), - changeOrigin: true, - }, - }, - }, -}; diff --git a/invokeai/frontend/web/config/vite.package.config.mts b/invokeai/frontend/web/config/vite.package.config.mts deleted file mode 100644 index 3c05d52e005..00000000000 --- a/invokeai/frontend/web/config/vite.package.config.mts +++ /dev/null @@ -1,46 +0,0 @@ -import path from 'path'; -import type { UserConfig } from 'vite'; -import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js'; -import dts from 'vite-plugin-dts'; - -import { commonPlugins } from './common.mjs'; - -export const packageConfig: UserConfig = { - base: './', - plugins: [ - ...commonPlugins, - dts({ - insertTypesEntry: true, - }), - cssInjectedByJsPlugin(), - ], - build: { - cssCodeSplit: true, - lib: { - entry: path.resolve(__dirname, '../src/index.ts'), - name: 'InvokeAIUI', - fileName: (format) => `invoke-ai-ui.${format}.js`, - }, - rollupOptions: { - external: ['react', 'react-dom', '@emotion/react', '@chakra-ui/react', '@invoke-ai/ui-library'], - output: { - globals: { - react: 'React', - 'react-dom': 'ReactDOM', - '@emotion/react': 'EmotionReact', - '@invoke-ai/ui-library': 'UiLibrary', - }, - }, - }, - }, - resolve: { - alias: { - app: path.resolve(__dirname, '../src/app'), - assets: path.resolve(__dirname, '../src/assets'), - 
common: path.resolve(__dirname, '../src/common'), - features: path.resolve(__dirname, '../src/features'), - services: path.resolve(__dirname, '../src/services'), - theme: path.resolve(__dirname, '../src/theme'), - }, - }, -}; diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index cd95183c7a4..b2838e538ce 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -33,7 +33,9 @@ "preinstall": "npx only-allow pnpm", "storybook": "storybook dev -p 6006", "build-storybook": "storybook build", - "unimported": "npx unimported" + "unimported": "npx unimported", + "test": "vitest", + "test:no-watch": "vitest --no-watch" }, "madge": { "excludeRegExp": [ @@ -157,7 +159,8 @@ "vite-plugin-css-injected-by-js": "^3.3.1", "vite-plugin-dts": "^3.7.1", "vite-plugin-eslint": "^1.8.1", - "vite-tsconfig-paths": "^4.3.1" + "vite-tsconfig-paths": "^4.3.1", + "vitest": "^1.2.2" }, "pnpm": { "patchedDependencies": { diff --git a/invokeai/frontend/web/pnpm-lock.yaml b/invokeai/frontend/web/pnpm-lock.yaml index 1d9083d1b44..f3bf68cf1da 100644 --- a/invokeai/frontend/web/pnpm-lock.yaml +++ b/invokeai/frontend/web/pnpm-lock.yaml @@ -215,7 +215,7 @@ devDependencies: version: 7.6.10(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3)(vite@5.0.12) '@storybook/test': specifier: ^7.6.10 - version: 7.6.10 + version: 7.6.10(vitest@1.2.2) '@storybook/theming': specifier: ^7.6.10 version: 7.6.10(react-dom@18.2.0)(react@18.2.0) @@ -318,6 +318,9 @@ devDependencies: vite-tsconfig-paths: specifier: ^4.3.1 version: 4.3.1(typescript@5.3.3)(vite@5.0.12) + vitest: + specifier: ^1.2.2 + version: 1.2.2(@types/node@20.11.5) packages: @@ -5464,7 +5467,7 @@ packages: - supports-color dev: true - /@storybook/test@7.6.10: + /@storybook/test@7.6.10(vitest@1.2.2): resolution: {integrity: sha512-dn/T+HcWOBlVh3c74BHurp++BaqBoQgNbSIaXlYDpJoZ+DzNIoEQVsWFYm5gCbtKK27iFd4n52RiQI3f6Vblqw==} dependencies: '@storybook/client-logger': 7.6.10 @@ -5472,7 +5475,7 @@ packages: '@storybook/instrumenter': 7.6.10 '@storybook/preview-api': 7.6.10 '@testing-library/dom': 9.3.4 - '@testing-library/jest-dom': 6.2.0 + '@testing-library/jest-dom': 6.2.0(vitest@1.2.2) '@testing-library/user-event': 14.3.0(@testing-library/dom@9.3.4) '@types/chai': 4.3.11 '@vitest/expect': 0.34.7 @@ -5652,7 +5655,7 @@ packages: pretty-format: 27.5.1 dev: true - /@testing-library/jest-dom@6.2.0: + /@testing-library/jest-dom@6.2.0(vitest@1.2.2): resolution: {integrity: sha512-+BVQlJ9cmEn5RDMUS8c2+TU6giLvzaHZ8sU/x0Jj7fk+6/46wPdwlgOPcpxS17CjcanBi/3VmGMqVr2rmbUmNw==} engines: {node: '>=14', npm: '>=6', yarn: '>=1'} peerDependencies: @@ -5678,6 +5681,7 @@ packages: dom-accessibility-api: 0.6.3 lodash: 4.17.21 redent: 3.0.0 + vitest: 1.2.2(@types/node@20.11.5) dev: true /@testing-library/user-event@14.3.0(@testing-library/dom@9.3.4): @@ -6490,12 +6494,42 @@ packages: chai: 4.4.1 dev: true + /@vitest/expect@1.2.2: + resolution: {integrity: sha512-3jpcdPAD7LwHUUiT2pZTj2U82I2Tcgg2oVPvKxhn6mDI2On6tfvPQTjAI4628GUGDZrCm4Zna9iQHm5cEexOAg==} + dependencies: + '@vitest/spy': 1.2.2 + '@vitest/utils': 1.2.2 + chai: 4.4.1 + dev: true + + /@vitest/runner@1.2.2: + resolution: {integrity: sha512-JctG7QZ4LSDXr5CsUweFgcpEvrcxOV1Gft7uHrvkQ+fsAVylmWQvnaAr/HDp3LAH1fztGMQZugIheTWjaGzYIg==} + dependencies: + '@vitest/utils': 1.2.2 + p-limit: 5.0.0 + pathe: 1.1.2 + dev: true + + /@vitest/snapshot@1.2.2: + resolution: {integrity: sha512-SmGY4saEw1+bwE1th6S/cZmPxz/Q4JWsl7LvbQIky2tKE35US4gd0Mjzqfr84/4OD0tikGWaWdMja/nWL5NIPA==} + dependencies: + 
magic-string: 0.30.5 + pathe: 1.1.2 + pretty-format: 29.7.0 + dev: true + /@vitest/spy@0.34.7: resolution: {integrity: sha512-NMMSzOY2d8L0mcOt4XcliDOS1ISyGlAXuQtERWVOoVHnKwmG+kKhinAiGw3dTtMQWybfa89FG8Ucg9tiC/FhTQ==} dependencies: tinyspy: 2.2.0 dev: true + /@vitest/spy@1.2.2: + resolution: {integrity: sha512-k9Gcahssw8d7X3pSLq3e3XEu/0L78mUkCjivUqCQeXJm9clfXR/Td8+AP+VC1O6fKPIDLcHDTAmBOINVuv6+7g==} + dependencies: + tinyspy: 2.2.0 + dev: true + /@vitest/utils@0.34.7: resolution: {integrity: sha512-ziAavQLpCYS9sLOorGrFFKmy2gnfiNU0ZJ15TsMz/K92NAPS/rp9K4z6AJQQk5Y8adCy4Iwpxy7pQumQ/psnRg==} dependencies: @@ -6504,6 +6538,15 @@ packages: pretty-format: 29.7.0 dev: true + /@vitest/utils@1.2.2: + resolution: {integrity: sha512-WKITBHLsBHlpjnDQahr+XK6RE7MiAsgrIkr0pGhQ9ygoxBfUeG0lUG5iLlzqjmKSlBv3+j5EGsriBzh+C3Tq9g==} + dependencies: + diff-sequences: 29.6.3 + estree-walker: 3.0.3 + loupe: 2.3.7 + pretty-format: 29.7.0 + dev: true + /@volar/language-core@1.11.1: resolution: {integrity: sha512-dOcNn3i9GgZAcJt43wuaEykSluAuOkQgzni1cuxLxTV0nJKanQztp7FxyswdRILaKH+P2XZMPRp2S4MV/pElCw==} dependencies: @@ -7184,6 +7227,11 @@ packages: engines: {node: '>=0.4.0'} dev: true + /acorn-walk@8.3.2: + resolution: {integrity: sha512-cjkyv4OtNCIeqhHrfS81QWXoCBPExR/J62oyEqepVw8WaQeSqpW2uhuLPh1m9eWhDuOo/jUXVTlifvesOWp/4A==} + engines: {node: '>=0.4.0'} + dev: true + /acorn@7.4.1: resolution: {integrity: sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A==} engines: {node: '>=0.4.0'} @@ -7661,6 +7709,11 @@ packages: engines: {node: '>= 0.8'} dev: true + /cac@6.7.14: + resolution: {integrity: sha512-b6Ilus+c3RrdDk+JhLKUAQfzzgLEPy6wcXqS7f/xe1EETvsDP6GORG7SFuOs6cID5YkqchW/LXZbX5bc8j7ZcQ==} + engines: {node: '>=8'} + dev: true + /call-bind@1.0.5: resolution: {integrity: sha512-C3nQxfFZxFRVoJoGKKI8y3MOEo129NQ+FgQ08iye+Mk4zNZZGdjfs06bVTr+DBSlA66Q2VEcMki/cUCP4SercQ==} dependencies: @@ -9173,6 +9226,12 @@ packages: resolution: {integrity: sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==} dev: true + /estree-walker@3.0.3: + resolution: {integrity: sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==} + dependencies: + '@types/estree': 1.0.5 + dev: true + /esutils@2.0.3: resolution: {integrity: sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==} engines: {node: '>=0.10.0'} @@ -10547,6 +10606,10 @@ packages: hasBin: true dev: true + /jsonc-parser@3.2.1: + resolution: {integrity: sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==} + dev: true + /jsondiffpatch@0.6.0: resolution: {integrity: sha512-3QItJOXp2AP1uv7waBkao5nCvhEv+QmJAd38Ybq7wNI74Q+BBmnLn4EDKz6yI9xGAIQoUF87qHt+kc1IVxB4zQ==} engines: {node: ^18.0.0 || >=20.0.0} @@ -10648,6 +10711,14 @@ packages: engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} dev: true + /local-pkg@0.5.0: + resolution: {integrity: sha512-ok6z3qlYyCDS4ZEU27HaU6x/xZa9Whf8jD4ptH5UZTQYZVYeb9bnZ3ojVhiJNLiXK1Hfc0GNbLXcmZ5plLDDBg==} + engines: {node: '>=14'} + dependencies: + mlly: 1.5.0 + pkg-types: 1.0.3 + dev: true + /locate-path@3.0.0: resolution: {integrity: sha512-7AO748wWnIhNqAuaty2ZWHkQHRSNfPVIsPIfwEOWO22AmaoVrWavlOcMR5nzTLNYvp36X220/maaRsrec1G65A==} engines: {node: '>=6'} @@ -10986,6 +11057,15 @@ packages: hasBin: true dev: true + /mlly@1.5.0: + resolution: {integrity: sha512-NPVQvAY1xr1QoVeG0cy8yUYC7FQcOx6evl/RjT1wL5FvzPnzOysoqB/jmx/DhssT2dYa8nxECLAaFI/+gVLhDQ==} + dependencies: + 
acorn: 8.11.3 + pathe: 1.1.2 + pkg-types: 1.0.3 + ufo: 1.3.2 + dev: true + /module-definition@3.4.0: resolution: {integrity: sha512-XxJ88R1v458pifaSkPNLUTdSPNVGMP2SXVncVmApGO+gAfrLANiYe6JofymCzVceGOMwQE2xogxBSc8uB7XegA==} engines: {node: '>=6.0'} @@ -11380,6 +11460,13 @@ packages: yocto-queue: 0.1.0 dev: true + /p-limit@5.0.0: + resolution: {integrity: sha512-/Eaoq+QyLSiXQ4lyYV23f14mZRQcXnxfHrN0vCai+ak9G0pp9iEQukIIZq5NccEvwRB8PUnZT0KsOoDCINS1qQ==} + engines: {node: '>=18'} + dependencies: + yocto-queue: 1.0.0 + dev: true + /p-locate@3.0.0: resolution: {integrity: sha512-x+12w/To+4GFfgJhBEpiDcLozRJGegY+Ei7/z0tSLkMmxGZNybVMSfWj9aJn8Z5Fc7dBUNJOOVgPv2H7IwulSQ==} engines: {node: '>=6'} @@ -11550,6 +11637,14 @@ packages: find-up: 5.0.0 dev: true + /pkg-types@1.0.3: + resolution: {integrity: sha512-nN7pYi0AQqJnoLPC9eHFQ8AcyaixBUOwvqc5TDnIKCMEE6I0y8P7OKA7fPexsXGCGxQDl/cmrLAp26LhcwxZ4A==} + dependencies: + jsonc-parser: 3.2.1 + mlly: 1.5.0 + pathe: 1.1.2 + dev: true + /pluralize@8.0.0: resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==} engines: {node: '>=4'} @@ -12850,6 +12945,10 @@ packages: object-inspect: 1.13.1 dev: true + /siginfo@2.0.0: + resolution: {integrity: sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==} + dev: true + /signal-exit@3.0.7: resolution: {integrity: sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==} dev: true @@ -12968,6 +13067,10 @@ packages: stackframe: 1.3.4 dev: false + /stackback@0.0.2: + resolution: {integrity: sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==} + dev: true + /stackframe@1.3.4: resolution: {integrity: sha512-oeVtt7eWQS+Na6F//S4kJ2K2VbRlS9D43mAlMyVpVWovy9o+jfgH8O9agzANzaiLjclA0oYzUXEM4PurhSUChw==} dev: false @@ -12992,6 +13095,10 @@ packages: engines: {node: '>= 0.8'} dev: true + /std-env@3.7.0: + resolution: {integrity: sha512-JPbdCEQLj1w5GilpiHAx3qJvFndqybBysA3qUOnznweH4QbNYUsW/ea8QzSrnh0vNsezMMw5bcVool8lM0gwzg==} + dev: true + /stop-iteration-iterator@1.0.0: resolution: {integrity: sha512-iCGQj+0l0HOdZ2AEeBADlsRC+vsnDsZsbdSiH1yNSjcfKM7fdpCMfqAL/dwF5BLiw/XhRft/Wax6zQbhq2BcjQ==} engines: {node: '>= 0.4'} @@ -13161,6 +13268,12 @@ packages: engines: {node: '>=8'} dev: true + /strip-literal@1.3.0: + resolution: {integrity: sha512-PugKzOsyXpArk0yWmUwqOZecSO0GH0bPoctLcqNDH9J04pVW3lflYE0ujElBGTloevcxF5MofAOZ7C5l2b+wLg==} + dependencies: + acorn: 8.11.3 + dev: true + /stylis@4.2.0: resolution: {integrity: sha512-Orov6g6BB1sDfYgzWfTHDOxamtX1bE/zo104Dh9e6fqJ3PooipYyfJ0pUmrZO2wAvO8YbEyeFrkV91XTsGMSrw==} dev: false @@ -13311,6 +13424,15 @@ packages: /tiny-invariant@1.3.1: resolution: {integrity: sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw==} + /tinybench@2.6.0: + resolution: {integrity: sha512-N8hW3PG/3aOoZAN5V/NSAEDz0ZixDSSt5b/a05iqtpgfLWMSVuCo7w0k2vVvEjdrIoeGqZzweX2WlyioNIHchA==} + dev: true + + /tinypool@0.8.2: + resolution: {integrity: sha512-SUszKYe5wgsxnNOVlBYO6IC+8VGWdVGZWAqUxp3UErNBtptZvWbwyUOyzNL59zigz2rCA92QiL3wvG+JDSdJdQ==} + engines: {node: '>=14.0.0'} + dev: true + /tinyspy@2.2.0: resolution: {integrity: sha512-d2eda04AN/cPOR89F7Xv5bK/jrQEhmcLFe6HFldoeO9AJtps+fqEnh486vnT/8y4bw38pSyxDcTCAq+Ks2aJTg==} engines: {node: '>=14.0.0'} @@ -13828,6 +13950,27 @@ packages: engines: {node: '>= 0.8'} dev: true + /vite-node@1.2.2(@types/node@20.11.5): + resolution: {integrity: 
sha512-1as4rDTgVWJO3n1uHmUYqq7nsFgINQ9u+mRcXpjeOMJUmviqNKjcZB7UfRZrlM7MjYXMKpuWp5oGkjaFLnjawg==} + engines: {node: ^18.0.0 || >=20.0.0} + hasBin: true + dependencies: + cac: 6.7.14 + debug: 4.3.4 + pathe: 1.1.2 + picocolors: 1.0.0 + vite: 5.0.12(@types/node@20.11.5) + transitivePeerDependencies: + - '@types/node' + - less + - lightningcss + - sass + - stylus + - sugarss + - supports-color + - terser + dev: true + /vite-plugin-css-injected-by-js@3.3.1(vite@5.0.12): resolution: {integrity: sha512-PjM/X45DR3/V1K1fTRs8HtZHEQ55kIfdrn+dzaqNBFrOYO073SeSNCxp4j7gSYhV9NffVHaEnOL4myoko0ePAg==} peerDependencies: @@ -13926,6 +14069,63 @@ packages: fsevents: 2.3.3 dev: true + /vitest@1.2.2(@types/node@20.11.5): + resolution: {integrity: sha512-d5Ouvrnms3GD9USIK36KG8OZ5bEvKEkITFtnGv56HFaSlbItJuYr7hv2Lkn903+AvRAgSixiamozUVfORUekjw==} + engines: {node: ^18.0.0 || >=20.0.0} + hasBin: true + peerDependencies: + '@edge-runtime/vm': '*' + '@types/node': ^18.0.0 || >=20.0.0 + '@vitest/browser': ^1.0.0 + '@vitest/ui': ^1.0.0 + happy-dom: '*' + jsdom: '*' + peerDependenciesMeta: + '@edge-runtime/vm': + optional: true + '@types/node': + optional: true + '@vitest/browser': + optional: true + '@vitest/ui': + optional: true + happy-dom: + optional: true + jsdom: + optional: true + dependencies: + '@types/node': 20.11.5 + '@vitest/expect': 1.2.2 + '@vitest/runner': 1.2.2 + '@vitest/snapshot': 1.2.2 + '@vitest/spy': 1.2.2 + '@vitest/utils': 1.2.2 + acorn-walk: 8.3.2 + cac: 6.7.14 + chai: 4.4.1 + debug: 4.3.4 + execa: 8.0.1 + local-pkg: 0.5.0 + magic-string: 0.30.5 + pathe: 1.1.2 + picocolors: 1.0.0 + std-env: 3.7.0 + strip-literal: 1.3.0 + tinybench: 2.6.0 + tinypool: 0.8.2 + vite: 5.0.12(@types/node@20.11.5) + vite-node: 1.2.2(@types/node@20.11.5) + why-is-node-running: 2.2.2 + transitivePeerDependencies: + - less + - lightningcss + - sass + - stylus + - sugarss + - supports-color + - terser + dev: true + /void-elements@3.1.0: resolution: {integrity: sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w==} engines: {node: '>=0.10.0'} @@ -14049,6 +14249,15 @@ packages: isexe: 2.0.0 dev: true + /why-is-node-running@2.2.2: + resolution: {integrity: sha512-6tSwToZxTOcotxHeA+qGCq1mVzKR3CwcJGmVcY+QE8SHy6TnpFnh8PAvPNHYr7EcuVeG0QSMxtYCuO1ta/G/oA==} + engines: {node: '>=8'} + hasBin: true + dependencies: + siginfo: 2.0.0 + stackback: 0.0.2 + dev: true + /wordwrap@1.0.0: resolution: {integrity: sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==} dev: true @@ -14189,6 +14398,11 @@ packages: engines: {node: '>=10'} dev: true + /yocto-queue@1.0.0: + resolution: {integrity: sha512-9bnSc/HEW2uRy67wc+T8UwauLuPJVn28jb+GtJY16iiKWyvmYJRXVT4UamsAEGQfPohgr2q4Tq0sQbQlxTfi1g==} + engines: {node: '>=12.20'} + dev: true + /z-schema@5.0.5: resolution: {integrity: sha512-D7eujBWkLa3p2sIpJA0d1pr7es+a7m0vFAnZLlCEKq/Ij2k0MLi9Br2UPxoxdYystm5K1yeBGzub0FlYUEWj2Q==} engines: {node: '>=8.0.0'} diff --git a/invokeai/frontend/web/vite.config.mts b/invokeai/frontend/web/vite.config.mts index b76dd24b628..325c6467dee 100644 --- a/invokeai/frontend/web/vite.config.mts +++ b/invokeai/frontend/web/vite.config.mts @@ -1,12 +1,90 @@ +/// +import react from '@vitejs/plugin-react-swc'; +import path from 'path'; +import { visualizer } from 'rollup-plugin-visualizer'; +import type { PluginOption } from 'vite'; import { defineConfig } from 'vite'; - -import { appConfig } from './config/vite.app.config.mjs'; -import { packageConfig } from './config/vite.package.config.mjs'; +import 
cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js'; +import dts from 'vite-plugin-dts'; +import eslint from 'vite-plugin-eslint'; +import tsconfigPaths from 'vite-tsconfig-paths'; export default defineConfig(({ mode }) => { if (mode === 'package') { - return packageConfig; + return { + base: './', + plugins: [ + react(), + eslint(), + tsconfigPaths(), + visualizer() as unknown as PluginOption, + dts({ + insertTypesEntry: true, + }), + cssInjectedByJsPlugin(), + ], + build: { + cssCodeSplit: true, + lib: { + entry: path.resolve(__dirname, '../src/index.ts'), + name: 'InvokeAIUI', + fileName: (format) => `invoke-ai-ui.${format}.js`, + }, + rollupOptions: { + external: ['react', 'react-dom', '@emotion/react', '@chakra-ui/react', '@invoke-ai/ui-library'], + output: { + globals: { + react: 'React', + 'react-dom': 'ReactDOM', + '@emotion/react': 'EmotionReact', + '@invoke-ai/ui-library': 'UiLibrary', + }, + }, + }, + }, + resolve: { + alias: { + app: path.resolve(__dirname, '../src/app'), + assets: path.resolve(__dirname, '../src/assets'), + common: path.resolve(__dirname, '../src/common'), + features: path.resolve(__dirname, '../src/features'), + services: path.resolve(__dirname, '../src/services'), + theme: path.resolve(__dirname, '../src/theme'), + }, + }, + }; } - return appConfig; + return { + base: './', + plugins: [react(), mode !== 'test' && eslint(), tsconfigPaths(), visualizer() as unknown as PluginOption], + build: { + chunkSizeWarningLimit: 1500, + }, + server: { + // Proxy HTTP requests to the flask server + proxy: { + // Proxy socket.io to the nodes socketio server + '/ws/socket.io': { + target: 'ws://127.0.0.1:9090', + ws: true, + }, + // Proxy openapi schema definiton + '/openapi.json': { + target: 'http://127.0.0.1:9090/openapi.json', + rewrite: (path) => path.replace(/^\/openapi.json/, ''), + changeOrigin: true, + }, + // proxy nodes api + '/api/v1': { + target: 'http://127.0.0.1:9090/api/v1', + rewrite: (path) => path.replace(/^\/api\/v1/, ''), + changeOrigin: true, + }, + }, + }, + test: { + // + }, + }; }); From 6b9f21414a8b37228097ffb414300460241b34bb Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 13 Feb 2024 22:53:30 +1100 Subject: [PATCH 070/340] feat(ui): add more types of FieldParseError Unfortunately you cannot test for both a specific type of error and match its message. Splitting the error classes makes it easier to test expected error conditions. 
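
For example, a minimal sketch of the kind of assertion this enables (assuming vitest's `toThrowError` matcher and the subclasses introduced in this change; the schema literals are illustrative only):

```ts
import { describe, expect, it } from 'vitest';

import {
  FieldParseError,
  UnsupportedPrimitiveTypeError,
  UnsupportedUnionError,
} from 'features/nodes/types/error';
import { parseFieldType } from 'features/nodes/util/schema/parseFieldType';

describe('FieldParseError subclasses', () => {
  it('lets tests assert the specific failure mode without matching messages', () => {
    // An unsupported primitive type and an unsupported union previously both
    // surfaced as a bare FieldParseError; now each can be asserted directly.
    expect(() => parseFieldType({ type: 'object' })).toThrowError(UnsupportedPrimitiveTypeError);
    expect(() =>
      parseFieldType({ anyOf: [{ type: 'string' }, { type: 'integer' }] })
    ).toThrowError(UnsupportedUnionError);
    // The subclasses still satisfy checks against the shared base class.
    expect(() => parseFieldType({ type: 'object' })).toThrowError(FieldParseError);
  });
});
```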
--- .../web/src/features/nodes/types/error.ts | 5 ++++ .../nodes/util/schema/parseFieldType.ts | 30 +++++++++++-------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/error.ts b/invokeai/frontend/web/src/features/nodes/types/error.ts index 905b487fb04..c3da136c7a8 100644 --- a/invokeai/frontend/web/src/features/nodes/types/error.ts +++ b/invokeai/frontend/web/src/features/nodes/types/error.ts @@ -56,3 +56,8 @@ export class FieldParseError extends Error { this.name = this.constructor.name; } } + +export class UnableToExtractSchemaNameFromRefError extends FieldParseError {} +export class UnsupportedArrayItemType extends FieldParseError {} +export class UnsupportedUnionError extends FieldParseError {} +export class UnsupportedPrimitiveTypeError extends FieldParseError {} \ No newline at end of file diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts index 14b1aefd6d3..13da6b38312 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts @@ -1,6 +1,12 @@ -import { FieldParseError } from 'features/nodes/types/error'; +import { + FieldParseError, + UnableToExtractSchemaNameFromRefError, + UnsupportedArrayItemType, + UnsupportedPrimitiveTypeError, + UnsupportedUnionError, +} from 'features/nodes/types/error'; import type { FieldType } from 'features/nodes/types/field'; -import type { OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; +import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; import { isArraySchemaObject, isInvocationFieldSchema, @@ -42,7 +48,7 @@ const isCollectionFieldType = (fieldType: string) => { return false; }; -export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType => { +export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema): FieldType => { if (isInvocationFieldSchema(schemaObject)) { // Check if this field has an explicit type provided by the node schema const { ui_type } = schemaObject; @@ -72,7 +78,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType // This is a single ref type const name = refObjectToSchemaName(allOf[0]); if (!name) { - throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef')); + throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef')); } return { name, @@ -95,7 +101,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType if (isRefObject(filteredAnyOf[0])) { const name = refObjectToSchemaName(filteredAnyOf[0]); if (!name) { - throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef')); + throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef')); } return { @@ -118,7 +124,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType if (filteredAnyOf.length !== 2) { // This is a union of more than 2 types, which we don't support - throw new FieldParseError( + throw new UnsupportedUnionError( t('nodes.unsupportedAnyOfLength', { count: filteredAnyOf.length, }) @@ -159,7 +165,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType }; } - throw new FieldParseError( + throw new UnsupportedUnionError( t('nodes.unsupportedMismatchedUnion', { firstType, secondType, @@ -178,7 
+184,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType if (isSchemaObject(schemaObject.items)) { const itemType = schemaObject.items.type; if (!itemType || isArray(itemType)) { - throw new FieldParseError( + throw new UnsupportedArrayItemType( t('nodes.unsupportedArrayItemType', { type: itemType, }) @@ -188,7 +194,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType const name = OPENAPI_TO_FIELD_TYPE_MAP[itemType]; if (!name) { // it's 'null', 'object', or 'array' - skip - throw new FieldParseError( + throw new UnsupportedArrayItemType( t('nodes.unsupportedArrayItemType', { type: itemType, }) @@ -204,7 +210,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType // This is a ref object, extract the type name const name = refObjectToSchemaName(schemaObject.items); if (!name) { - throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef')); + throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef')); } return { name, @@ -216,7 +222,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType const name = OPENAPI_TO_FIELD_TYPE_MAP[schemaObject.type]; if (!name) { // it's 'null', 'object', or 'array' - skip - throw new FieldParseError( + throw new UnsupportedPrimitiveTypeError( t('nodes.unsupportedArrayItemType', { type: schemaObject.type, }) @@ -232,7 +238,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType } else if (isRefObject(schemaObject)) { const name = refObjectToSchemaName(schemaObject); if (!name) { - throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef')); + throw new UnableToExtractSchemaNameFromRefError(t('nodes.unableToExtractSchemaNameFromRef')); } return { name, From 76031bcf43157b0a3070d03dd561b34d57a9a9b0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 13 Feb 2024 22:53:52 +1100 Subject: [PATCH 071/340] tests(ui): add `parseFieldType.test.ts` --- .../nodes/util/schema/parseFieldType.test.ts | 379 ++++++++++++++++++ 1 file changed, 379 insertions(+) create mode 100644 invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts new file mode 100644 index 00000000000..2f4ce48a326 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts @@ -0,0 +1,379 @@ +import { + UnableToExtractSchemaNameFromRefError, + UnsupportedArrayItemType, + UnsupportedPrimitiveTypeError, + UnsupportedUnionError, +} from 'features/nodes/types/error'; +import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; +import { parseFieldType, refObjectToSchemaName } from 'features/nodes/util/schema/parseFieldType'; +import { describe, expect, it } from 'vitest'; + +type ParseFieldTypeTestCase = { + name: string; + schema: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema; + expected: { name: string; isCollection: boolean; isCollectionOrScalar: boolean }; +}; + +const primitiveTypes: ParseFieldTypeTestCase[] = [ + { + name: 'Scalar IntegerField', + schema: { type: 'integer' }, + expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Scalar FloatField', + schema: { type: 'number' }, + expected: { name: 'FloatField', isCollection: false, 
isCollectionOrScalar: false }, + }, + { + name: 'Scalar StringField', + schema: { type: 'string' }, + expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Scalar BooleanField', + schema: { type: 'boolean' }, + expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Collection IntegerField', + schema: { items: { type: 'integer' }, type: 'array' }, + expected: { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'Collection FloatField', + schema: { items: { type: 'number' }, type: 'array' }, + expected: { name: 'FloatField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'Collection StringField', + schema: { items: { type: 'string' }, type: 'array' }, + expected: { name: 'StringField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'Collection BooleanField', + schema: { items: { type: 'boolean' }, type: 'array' }, + expected: { name: 'BooleanField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'CollectionOrScalar IntegerField', + schema: { + anyOf: [ + { + type: 'integer', + }, + { + items: { + type: 'integer', + }, + type: 'array', + }, + ], + }, + expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }, + }, + { + name: 'CollectionOrScalar FloatField', + schema: { + anyOf: [ + { + type: 'number', + }, + { + items: { + type: 'number', + }, + type: 'array', + }, + ], + }, + expected: { name: 'FloatField', isCollection: false, isCollectionOrScalar: true }, + }, + { + name: 'CollectionOrScalar StringField', + schema: { + anyOf: [ + { + type: 'string', + }, + { + items: { + type: 'string', + }, + type: 'array', + }, + ], + }, + expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: true }, + }, + { + name: 'CollectionOrScalar BooleanField', + schema: { + anyOf: [ + { + type: 'boolean', + }, + { + items: { + type: 'boolean', + }, + type: 'array', + }, + ], + }, + expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: true }, + }, +]; + +const complexTypes: ParseFieldTypeTestCase[] = [ + { + name: 'Scalar ConditioningField', + schema: { + allOf: [ + { + $ref: '#/components/schemas/ConditioningField', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Nullable Scalar ConditioningField', + schema: { + anyOf: [ + { + $ref: '#/components/schemas/ConditioningField', + }, + { + type: 'null', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Collection ConditioningField', + schema: { + anyOf: [ + { + items: { + $ref: '#/components/schemas/ConditioningField', + }, + type: 'array', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'Nullable Collection ConditioningField', + schema: { + anyOf: [ + { + items: { + $ref: '#/components/schemas/ConditioningField', + }, + type: 'array', + }, + { + type: 'null', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false }, + }, + { + name: 'CollectionOrScalar ConditioningField', + schema: { + anyOf: [ + { + items: { + $ref: '#/components/schemas/ConditioningField', + }, + type: 'array', + }, + { + $ref: '#/components/schemas/ConditioningField', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: false, 
isCollectionOrScalar: true }, + }, + { + name: 'Nullable CollectionOrScalar ConditioningField', + schema: { + anyOf: [ + { + items: { + $ref: '#/components/schemas/ConditioningField', + }, + type: 'array', + }, + { + $ref: '#/components/schemas/ConditioningField', + }, + { + type: 'null', + }, + ], + }, + expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: true }, + }, +]; + +const specialCases: ParseFieldTypeTestCase[] = [ + { + name: 'String EnumField', + schema: { + type: 'string', + enum: ['large', 'base', 'small'], + }, + expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'String EnumField with one value', + schema: { + const: 'Some Value', + }, + expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Explicit ui_type (SchedulerField)', + schema: { + type: 'string', + enum: ['ddim', 'ddpm', 'deis'], + ui_type: 'SchedulerField', + }, + expected: { name: 'SchedulerField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Explicit ui_type (AnyField)', + schema: { + type: 'string', + enum: ['ddim', 'ddpm', 'deis'], + ui_type: 'AnyField', + }, + expected: { name: 'AnyField', isCollection: false, isCollectionOrScalar: false }, + }, + { + name: 'Explicit ui_type (CollectionField)', + schema: { + type: 'string', + enum: ['ddim', 'ddpm', 'deis'], + ui_type: 'CollectionField', + }, + expected: { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false }, + }, +]; + +describe('refObjectToSchemaName', async () => { + it('parses ref object 1', () => { + expect( + refObjectToSchemaName({ + $ref: '#/components/schemas/ImageField', + }) + ).toEqual('ImageField'); + }); + it('parses ref object 2', () => { + expect( + refObjectToSchemaName({ + $ref: '#/components/schemas/T2IAdapterModelField', + }) + ).toEqual('T2IAdapterModelField'); + }); +}); + +describe.concurrent('parseFieldType', async () => { + it.each(primitiveTypes)('parses primitive types ($name)', ({ schema, expected }) => { + expect(parseFieldType(schema)).toEqual(expected); + }); + it.each(complexTypes)('parses complex types ($name)', ({ schema, expected }) => { + expect(parseFieldType(schema)).toEqual(expected); + }); + it.each(specialCases)('parses special case types ($name)', ({ schema, expected }) => { + expect(parseFieldType(schema)).toEqual(expected); + }); + + it('raises if it cannot extract a schema name from a ref', () => { + expect(() => + parseFieldType({ + allOf: [ + { + $ref: '#/components/schemas/', + }, + ], + }) + ).toThrowError(UnableToExtractSchemaNameFromRefError); + }); + + it('raises if it receives a union of mismatched types', () => { + expect(() => + parseFieldType({ + anyOf: [ + { + type: 'string', + }, + { + type: 'integer', + }, + ], + }) + ).toThrowError(UnsupportedUnionError); + }); + + it('raises if it receives a union of mismatched types (excluding null)', () => { + expect(() => + parseFieldType({ + anyOf: [ + { + type: 'string', + }, + { + type: 'integer', + }, + { + type: 'null', + }, + ], + }) + ).toThrowError(UnsupportedUnionError); + }); + + it('raises if it received an unsupported primitive type (object)', () => { + expect(() => + parseFieldType({ + type: 'object', + }) + ).toThrowError(UnsupportedPrimitiveTypeError); + }); + + it('raises if it received an unsupported primitive type (null)', () => { + expect(() => + parseFieldType({ + type: 'null', + }) + ).toThrowError(UnsupportedPrimitiveTypeError); + }); + + it('raises if it received an 
unsupported array item type (object)', () => { + expect(() => + parseFieldType({ + items: { + type: 'object', + }, + type: 'array', + }) + ).toThrowError(UnsupportedArrayItemType); + }); + + it('raises if it received an unsupported array item type (null)', () => { + expect(() => + parseFieldType({ + items: { + type: 'null', + }, + type: 'array', + }) + ).toThrowError(UnsupportedArrayItemType); + }); +}); From 37cac353a3285189800104453c8347a892bdab09 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 22 Jan 2024 14:37:23 -0500 Subject: [PATCH 072/340] add concept of repo variant --- invokeai/backend/model_manager/config.py | 4 +- invokeai/backend/model_manager/probe.py | 19 ++++++++++ tests/test_model_probe.py | 9 ++++- .../vae/taesdxl-fp16/config.json | 37 +++++++++++++++++++ .../diffusion_pytorch_model.fp16.safetensors | 0 5 files changed, 65 insertions(+), 4 deletions(-) create mode 100644 tests/test_model_probe/vae/taesdxl-fp16/config.json create mode 100644 tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 964cc19f196..b4685caf10d 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -150,7 +150,7 @@ class _DiffusersConfig(ModelConfigBase): """Model config for diffusers-style models.""" format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - + repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT class LoRAConfig(ModelConfigBase): """Model config for LoRA/Lycoris models.""" @@ -179,7 +179,6 @@ class ControlNetDiffusersConfig(_DiffusersConfig): type: Literal[ModelType.ControlNet] = ModelType.ControlNet format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - class ControlNetCheckpointConfig(_CheckpointConfig): """Model config for ControlNet models (diffusers version).""" @@ -215,7 +214,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig): prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False - class ONNXSD1Config(_MainConfig): """Model config for ONNX format models based on sd-1.""" diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index cd048d2fe78..ba3ac3dd0cc 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -20,6 +20,7 @@ ModelFormat, ModelType, ModelVariantType, + ModelRepoVariant, SchedulerPredictionType, ) from .hash import FastModelHash @@ -155,6 +156,9 @@ def probe( fields["original_hash"] = fields.get("original_hash") or hash fields["current_hash"] = fields.get("current_hash") or hash + if format_type == ModelFormat.Diffusers: + fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() + # additional fields needed for main and controlnet models if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint: fields["config"] = cls._get_checkpoint_config_path( @@ -477,6 +481,20 @@ def get_variant_type(self) -> ModelVariantType: def get_format(self) -> ModelFormat: return ModelFormat("diffusers") + def get_repo_variant(self) -> ModelRepoVariant: + # get all files ending in .bin or .safetensors + weight_files = list(self.model_path.glob('**/*.safetensors')) + weight_files.extend(list(self.model_path.glob('**/*.bin'))) + for x in weight_files: + if ".fp16" in x.suffixes: + return ModelRepoVariant.FP16 + if "openvino_model" in x.name: + return 
ModelRepoVariant.OPENVINO + if "flax_model" in x.name: + return ModelRepoVariant.FLAX + if x.suffix == ".onnx": + return ModelRepoVariant.ONNX + return ModelRepoVariant.DEFAULT class PipelineFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: @@ -522,6 +540,7 @@ def get_variant_type(self) -> ModelVariantType: except Exception: pass return ModelVariantType.Normal + class VaeFolderProbe(FolderProbeBase): diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index 248b7d602fd..415559a64cd 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -3,7 +3,7 @@ import pytest from invokeai.backend import BaseModelType -from invokeai.backend.model_management.model_probe import VaeFolderProbe +from invokeai.backend.model_manager.probe import VaeFolderProbe @pytest.mark.parametrize( @@ -20,3 +20,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat probe = VaeFolderProbe(sd1_vae_path) base_type = probe.get_base_type() assert base_type == expected_type + repo_variant = probe.get_repo_variant() + assert repo_variant == 'default' + +def test_repo_variant(datadir: Path): + probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16") + repo_variant = probe.get_repo_variant() + assert repo_variant == 'fp16' diff --git a/tests/test_model_probe/vae/taesdxl-fp16/config.json b/tests/test_model_probe/vae/taesdxl-fp16/config.json new file mode 100644 index 00000000000..62f01c3eb44 --- /dev/null +++ b/tests/test_model_probe/vae/taesdxl-fp16/config.json @@ -0,0 +1,37 @@ +{ + "_class_name": "AutoencoderTiny", + "_diffusers_version": "0.20.0.dev0", + "act_fn": "relu", + "decoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "encoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "force_upcast": false, + "in_channels": 3, + "latent_channels": 4, + "latent_magnitude": 3, + "latent_shift": 0.5, + "num_decoder_blocks": [ + 3, + 3, + 3, + 1 + ], + "num_encoder_blocks": [ + 1, + 3, + 3, + 3 + ], + "out_channels": 3, + "scaling_factor": 1.0, + "upsampling_scaling_factor": 2 +} diff --git a/tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors b/tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors new file mode 100644 index 00000000000..e69de29bb2d From 88da301b6a333961a731dbc68cf5c09830935727 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 31 Jan 2024 23:37:59 -0500 Subject: [PATCH 073/340] add ram cache module and support files --- invokeai/backend/model_manager/config.py | 3 + .../backend/model_manager/load/__init__.py | 0 .../backend/model_manager/load/load_base.py | 193 ++++++++++ .../model_manager/load/load_default.py | 168 +++++++++ .../model_manager/load/memory_snapshot.py | 100 ++++++ .../backend/model_manager/load/model_util.py | 109 ++++++ .../model_manager/load/optimizations.py | 30 ++ .../model_manager/load/ram_cache/__init__.py | 0 .../load/ram_cache/ram_cache_base.py | 145 ++++++++ .../load/ram_cache/ram_cache_default.py | 332 ++++++++++++++++++ invokeai/backend/model_manager/load/vae.py | 31 ++ .../backend/model_manager/onnx_runtime.py | 216 ++++++++++++ invokeai/backend/model_manager/probe.py | 8 +- tests/test_model_probe.py | 5 +- 14 files changed, 1334 insertions(+), 6 deletions(-) create mode 100644 invokeai/backend/model_manager/load/__init__.py create mode 100644 invokeai/backend/model_manager/load/load_base.py create mode 100644 invokeai/backend/model_manager/load/load_default.py create mode 100644 
invokeai/backend/model_manager/load/memory_snapshot.py create mode 100644 invokeai/backend/model_manager/load/model_util.py create mode 100644 invokeai/backend/model_manager/load/optimizations.py create mode 100644 invokeai/backend/model_manager/load/ram_cache/__init__.py create mode 100644 invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py create mode 100644 invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py create mode 100644 invokeai/backend/model_manager/load/vae.py create mode 100644 invokeai/backend/model_manager/onnx_runtime.py diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index b4685caf10d..338669c873a 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -152,6 +152,7 @@ class _DiffusersConfig(ModelConfigBase): format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT + class LoRAConfig(ModelConfigBase): """Model config for LoRA/Lycoris models.""" @@ -179,6 +180,7 @@ class ControlNetDiffusersConfig(_DiffusersConfig): type: Literal[ModelType.ControlNet] = ModelType.ControlNet format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + class ControlNetCheckpointConfig(_CheckpointConfig): """Model config for ControlNet models (diffusers version).""" @@ -214,6 +216,7 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig): prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False + class ONNXSD1Config(_MainConfig): """Model config for ONNX format models based on sd-1.""" diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py new file mode 100644 index 00000000000..7cb7222b717 --- /dev/null +++ b/invokeai/backend/model_manager/load/load_base.py @@ -0,0 +1,193 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +""" +Base class for model loading in InvokeAI. + +Use like this: + + loader = AnyModelLoader(...) 
+ loaded_model = loader.get_model('019ab39adfa1840455') + with loaded_model as model: # context manager moves model into VRAM + # do something with loaded_model +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from logging import Logger +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Type, Union + +import torch +from diffusers import DiffusionPipeline +from injector import inject + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_records import ModelRecordServiceBase +from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType +from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel +from invokeai.backend.model_manager.ram_cache import ModelCacheBase + +AnyModel = Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel] + + +class ModelLockerBase(ABC): + """Base class for the model locker used by the loader.""" + + @abstractmethod + def lock(self) -> None: + """Lock the contained model and move it into VRAM.""" + pass + + @abstractmethod + def unlock(self) -> None: + """Unlock the contained model, and remove it from VRAM.""" + pass + + @property + @abstractmethod + def model(self) -> AnyModel: + """Return the model.""" + pass + + +@dataclass +class LoadedModel: + """Context manager object that mediates transfer from RAM<->VRAM.""" + + config: AnyModelConfig + locker: ModelLockerBase + + def __enter__(self) -> AnyModel: # I think load_file() always returns a dict + """Context entry.""" + self.locker.lock() + return self.model + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + """Context exit.""" + self.locker.unlock() + + @property + def model(self) -> AnyModel: + """Return the model without locking it.""" + return self.locker.model() + + +class ModelLoaderBase(ABC): + """Abstract base class for loading models into RAM/VRAM.""" + + @abstractmethod + def __init__( + self, + app_config: InvokeAIAppConfig, + logger: Logger, + ram_cache: ModelCacheBase, + convert_cache: ModelConvertCacheBase, + ): + """Initialize the loader.""" + pass + + @abstractmethod + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """ + Return a model given its key. + + Given a model key identified in the model configuration backend, + return a ModelInfo object that can be used to retrieve the model. + + :param model_config: Model configuration, as returned by ModelConfigRecordStore + :param submodel_type: an ModelType enum indicating the portion of + the model to retrieve (e.g. ModelType.Vae) + """ + pass + + @abstractmethod + def get_size_fs( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> int: + """Return size in bytes of the model, calculated before loading.""" + pass + + +# TO DO: Better name? 
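The `LoadedModel`/`ModelLockerBase` pair above is the RAM<->VRAM handoff: entering the context manager locks the model onto the execution device, and leaving it unlocks the model so the cache may offload it again. A self-contained toy that mirrors that contract (nothing below is part of the patch; the `Toy*` names are illustrative only):

class ToyLocker:
    """Stands in for a ModelLockerBase implementation."""

    def __init__(self, model: str) -> None:
        self._model = model
        self.in_vram = False

    def lock(self) -> None:
        self.in_vram = True  # a real locker moves the model onto the execution device here

    def unlock(self) -> None:
        self.in_vram = False  # a real locker lets the cache offload the model again

    @property
    def model(self) -> str:
        return self._model


class ToyLoadedModel:
    """Mirrors LoadedModel.__enter__/__exit__ above."""

    def __init__(self, locker: ToyLocker) -> None:
        self.locker = locker

    def __enter__(self) -> str:
        self.locker.lock()
        return self.locker.model

    def __exit__(self, *args: object) -> None:
        self.locker.unlock()


locker = ToyLocker("fake-vae")
with ToyLoadedModel(locker) as model:
    assert locker.in_vram  # the model is "resident" only inside the block
assert not locker.in_vram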
+class AnyModelLoader: + """This class manages the model loaders and invokes the correct one to load a model of given base and type.""" + + # this tracks the loader subclasses + _registry: Dict[str, Type[ModelLoaderBase]] = {} + + @inject + def __init__( + self, + store: ModelRecordServiceBase, + app_config: InvokeAIAppConfig, + logger: Logger, + ram_cache: ModelCacheBase, + convert_cache: ModelConvertCacheBase, + ): + """Store the provided ModelRecordServiceBase and empty the registry.""" + self._store = store + self._app_config = app_config + self._logger = logger + self._ram_cache = ram_cache + self._convert_cache = convert_cache + + def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """ + Return a model given its key. + + Given a model key identified in the model configuration backend, + return a ModelInfo object that can be used to retrieve the model. + + :param key: model key, as known to the config backend + :param submodel_type: an ModelType enum indicating the portion of + the model to retrieve (e.g. ModelType.Vae) + """ + model_config = self._store.get_model(key) + implementation = self.__class__.get_implementation( + base=model_config.base, type=model_config.type, format=model_config.format + ) + return implementation( + app_config=self._app_config, + logger=self._logger, + ram_cache=self._ram_cache, + convert_cache=self._convert_cache, + ).load_model(model_config, submodel_type) + + @staticmethod + def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str: + return "-".join([base.value, type.value, format.value]) + + @classmethod + def get_implementation(cls, base: BaseModelType, type: ModelType, format: ModelFormat) -> Type[ModelLoaderBase]: + """Get subclass of ModelLoaderBase registered to handle base and type.""" + key1 = cls._to_registry_key(base, type, format) # for a specific base type + key2 = cls._to_registry_key(BaseModelType.Any, type, format) # with wildcard Any + implementation = cls._registry.get(key1) or cls._registry.get(key2) + if not implementation: + raise NotImplementedError( + "No subclass of LoadedModel is registered for base={base}, type={type}, format={format}" + ) + return implementation + + @classmethod + def register( + cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any + ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]: + """Define a decorator which registers the subclass of loader.""" + + def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: + print("Registering class", subclass.__name__) + key = cls._to_registry_key(base, type, format) + cls._registry[key] = subclass + return subclass + + return decorator + + +# in _init__.py will call something like +# def configure_loader_dependencies(binder): +# binder.bind(ModelRecordServiceBase, ApiDependencies.invoker.services.model_records, scope=singleton) +# binder.bind(InvokeAIAppConfig, ApiDependencies.invoker.services.configuration, scope=singleton) +# etc +# injector = Injector(configure_loader_dependencies) +# loader = injector.get(ModelFactory) diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py new file mode 100644 index 00000000000..eb2d432aaae --- /dev/null +++ b/invokeai/backend/model_manager/load/load_default.py @@ -0,0 +1,168 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Default implementation of model loading in InvokeAI.""" + +import sys +from logging import Logger +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +from diffusers import ModelMixin +from diffusers.configuration_utils import ConfigMixin +from injector import inject + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType +from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs +from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init +from invokeai.backend.model_manager.ram_cache import ModelCacheBase, ModelLockerBase +from invokeai.backend.util.devices import choose_torch_device, torch_dtype + + +class ConfigLoader(ConfigMixin): + """Subclass of ConfigMixin for loading diffusers configuration files.""" + + @classmethod + def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: + """Load a diffusrs ConfigMixin configuration.""" + cls.config_name = kwargs.pop("config_name") + # Diffusers doesn't provide typing info + return super().load_config(*args, **kwargs) # type: ignore + + +# TO DO: The loader is not thread safe! +class ModelLoader(ModelLoaderBase): + """Default implementation of ModelLoaderBase.""" + + @inject # can inject instances of each of the classes in the call signature + def __init__( + self, + app_config: InvokeAIAppConfig, + logger: Logger, + ram_cache: ModelCacheBase, + convert_cache: ModelConvertCacheBase, + ): + """Initialize the loader.""" + self._app_config = app_config + self._logger = logger + self._ram_cache = ram_cache + self._convert_cache = convert_cache + self._torch_dtype = torch_dtype(choose_torch_device()) + self._size: Optional[int] = None # model size + + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """ + Return a model given its configuration. + + Given a model's configuration as returned by the ModelRecordConfigStore service, + return a LoadedModel object that can be used for inference. + + :param model config: Configuration record for this model + :param submodel_type: an ModelType enum indicating the portion of + the model to retrieve (e.g. ModelType.Vae) + """ + if model_config.type == "main" and not submodel_type: + raise InvalidModelConfigException("submodel_type is required when loading a main model") + + model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) + if is_submodel_override: + submodel_type = None + + if not model_path.exists(): + raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}") + + model_path = self._convert_if_needed(model_config, model_path, submodel_type) + locker = self._load_if_needed(model_config, model_path, submodel_type) + return LoadedModel(config=model_config, locker=locker) + + # IMPORTANT: This needs to be overridden in the StableDiffusion subclass so as to handle vae overrides + # and submodels!!!! 
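`load_model()` above defers format conversion to a pair of hooks (`_needs_conversion()` and `_convert_model()`, defined just below) that do nothing in this base class; a format-converting loader is expected to override them and let the convert cache hold the converted copy. A hypothetical sketch of such a subclass (the class name and the conversion body are placeholders, not part of the patch; an actual converting loader lands later in this series):

from pathlib import Path
from typing import Optional

from invokeai.backend.model_manager import AnyModelConfig, ModelFormat, ModelRepoVariant, SubModelType
from invokeai.backend.model_manager.load.load_default import ModelLoader


class HypotheticalConvertingLoader(ModelLoader):
    """Illustrative only: a loader whose source format must be converted before loading."""

    def _needs_conversion(self, config: AnyModelConfig) -> bool:
        # anything not already in diffusers folder layout gets converted into the convert cache
        return config.format is not ModelFormat.Diffusers

    def _convert_model(self, model_path: Path, cache_path: Path) -> None:
        # placeholder: write a converted (diffusers-layout) copy of model_path into cache_path,
        # for which _convert_if_needed() has already made room in the convert cache
        ...

    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ):
        load_class = self._get_hf_load_class(model_path, submodel_type)
        variant = model_variant.value if model_variant else None
        return load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)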
+ def _get_model_path( + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + ) -> Tuple[Path, bool]: + model_base = self._app_config.models_path + return ((model_base / config.path).resolve(), False) + + def _convert_if_needed( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> Path: + if not self._needs_conversion(config): + return model_path + + self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type)) + cache_path: Path = self._convert_cache.cache_path(config.key) + if cache_path.exists(): + return cache_path + + self._convert_model(model_path, cache_path) + return cache_path + + def _needs_conversion(self, config: AnyModelConfig) -> bool: + return False + + def _load_if_needed( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> ModelLockerBase: + # TO DO: This is not thread safe! + if self._ram_cache.exists(config.key, submodel_type): + return self._ram_cache.get(config.key, submodel_type) + + model_variant = getattr(config, "repo_variant", None) + self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type)) + + # This is where the model is actually loaded! + with skip_torch_weight_init(): + loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type) + + self._ram_cache.put( + config.key, + submodel_type=submodel_type, + model=loaded_model, + ) + + return self._ram_cache.get(config.key, submodel_type) + + def get_size_fs( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> int: + """Get the size of the model on disk.""" + return calc_model_size_by_fs( + model_path=model_path, + subfolder=submodel_type.value if submodel_type else None, + variant=config.repo_variant if hasattr(config, "repo_variant") else None, + ) + + def _convert_model(self, model_path: Path, cache_path: Path) -> None: + raise NotImplementedError + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + raise NotImplementedError + + def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]: + return ConfigLoader.load_config(model_path, config_name=config_name) + + # TO DO: Add exception handling + def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type + if module in ["diffusers", "transformers"]: + res_type = sys.modules[module] + else: + res_type = sys.modules["diffusers"].pipelines + result: ModelMixin = getattr(res_type, class_name) + return result + + # TO DO: Add exception handling + def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: + if submodel_type: + config = self._load_diffusers_config(model_path, config_name="model_index.json") + module, class_name = config[submodel_type.value] + return self._hf_definition_to_type(module=module, class_name=class_name) + else: + config = self._load_diffusers_config(model_path, config_name="config.json") + class_name = config["_class_name"] + return self._hf_definition_to_type(module="diffusers", class_name=class_name) diff --git a/invokeai/backend/model_manager/load/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py new file mode 100644 index 00000000000..504829a4271 --- /dev/null +++ 
b/invokeai/backend/model_manager/load/memory_snapshot.py @@ -0,0 +1,100 @@ +import gc +from typing import Optional + +import psutil +import torch +from typing_extensions import Self + +from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 + +GB = 2**30 # 1 GB + + +class MemorySnapshot: + """A snapshot of RAM and VRAM usage. All values are in bytes.""" + + def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[Struct_mallinfo2]): + """Initialize a MemorySnapshot. + + Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`. + + Args: + process_ram (int): CPU RAM used by the current process. + vram (Optional[int]): VRAM used by torch. + malloc_info (Optional[Struct_mallinfo2]): Malloc info obtained from LibcUtil. + """ + self.process_ram = process_ram + self.vram = vram + self.malloc_info = malloc_info + + @classmethod + def capture(cls, run_garbage_collector: bool = True) -> Self: + """Capture and return a MemorySnapshot. + + Note: This function has significant overhead, particularly if `run_garbage_collector == True`. + + Args: + run_garbage_collector (bool, optional): If true, gc.collect() will be run before checking the process RAM + usage. Defaults to True. + + Returns: + MemorySnapshot + """ + if run_garbage_collector: + gc.collect() + + # According to the psutil docs (https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info), rss is + # supported on all platforms. + process_ram = psutil.Process().memory_info().rss + + if torch.cuda.is_available(): + vram = torch.cuda.memory_allocated() + else: + # TODO: We could add support for mps.current_allocated_memory() as well. Leaving out for now until we have + # time to test it properly. + vram = None + + try: + malloc_info = LibcUtil().mallinfo2() # type: ignore + except (OSError, AttributeError): + # OSError: This is expected in environments that do not have the 'libc.so.6' shared library. + # AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33) + # TODO: Does `mallinfo` work? 
+ malloc_info = None + + return cls(process_ram, vram, malloc_info) + + +def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str: + """Get a pretty string describing the difference between two `MemorySnapshot`s.""" + + def get_msg_line(prefix: str, val1: int, val2: int) -> str: + diff = val2 - val1 + return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n" + + msg = "" + + if snapshot_1 is None or snapshot_2 is None: + return msg + + msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram) + + if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None: + msg += get_msg_line("libc mmap allocated", snapshot_1.malloc_info.hblkhd, snapshot_2.malloc_info.hblkhd) + + msg += get_msg_line("libc arena used", snapshot_1.malloc_info.uordblks, snapshot_2.malloc_info.uordblks) + + msg += get_msg_line("libc arena free", snapshot_1.malloc_info.fordblks, snapshot_2.malloc_info.fordblks) + + libc_total_allocated_1 = snapshot_1.malloc_info.arena + snapshot_1.malloc_info.hblkhd + libc_total_allocated_2 = snapshot_2.malloc_info.arena + snapshot_2.malloc_info.hblkhd + msg += get_msg_line("libc total allocated", libc_total_allocated_1, libc_total_allocated_2) + + libc_total_used_1 = snapshot_1.malloc_info.uordblks + snapshot_1.malloc_info.hblkhd + libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd + msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2) + + if snapshot_1.vram is not None and snapshot_2.vram is not None: + msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) + + return msg diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py new file mode 100644 index 00000000000..18407cbca2e --- /dev/null +++ b/invokeai/backend/model_manager/load/model_util.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024 The InvokeAI Development Team +"""Various utility functions needed by the loader and caching system.""" + +import json +from pathlib import Path +from typing import Optional, Union + +import torch +from diffusers import DiffusionPipeline + +from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel + + +def calc_model_size_by_data(model: Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]) -> int: + """Get size of a model in memory in bytes.""" + if isinstance(model, DiffusionPipeline): + return _calc_pipeline_by_data(model) + elif isinstance(model, torch.nn.Module): + return _calc_model_by_data(model) + elif isinstance(model, IAIOnnxRuntimeModel): + return _calc_onnx_model_by_data(model) + else: + return 0 + + +def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int: + res = 0 + assert hasattr(pipeline, "components") + for submodel_key in pipeline.components.keys(): + submodel = getattr(pipeline, submodel_key) + if submodel is not None and isinstance(submodel, torch.nn.Module): + res += _calc_model_by_data(submodel) + return res + + +def _calc_model_by_data(model: torch.nn.Module) -> int: + mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()]) + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()]) + mem: int = mem_params + mem_bufs # in bytes + return mem + + +def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int: + tensor_size = model.tensors.size() * 2 # The session doubles this + mem = tensor_size # in bytes + return mem + + +def 
calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int: + """Estimate the size of a model on disk in bytes.""" + if subfolder is not None: + model_path = model_path / subfolder + + # this can happen when, for example, the safety checker is not downloaded. + if not model_path.exists(): + return 0 + + all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()] + + fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name} + bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name} + other_files = set(all_files) - fp16_files - bit8_files + + if variant is None: + files = other_files + elif variant == "fp16": + files = fp16_files + elif variant == "8bit": + files = bit8_files + else: + raise NotImplementedError(f"Unknown variant: {variant}") + + # try read from index if exists + index_postfix = ".index.json" + if variant is not None: + index_postfix = f".index.{variant}.json" + + for file in files: + if not file.name.endswith(index_postfix): + continue + try: + with open(model_path / file, "r") as f: + index_data = json.loads(f.read()) + return int(index_data["metadata"]["total_size"]) + except Exception: + pass + + # calculate files size if there is no index file + formats = [ + (".safetensors",), # safetensors + (".bin",), # torch + (".onnx", ".pb"), # onnx + (".msgpack",), # flax + (".ckpt",), # tf + (".h5",), # tf2 + ] + + for file_format in formats: + model_files = [f for f in files if f.suffix in file_format] + if len(model_files) == 0: + continue + + model_size = 0 + for model_file in model_files: + file_stats = (model_path / model_file).stat() + model_size += file_stats.st_size + return model_size + + return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu diff --git a/invokeai/backend/model_manager/load/optimizations.py b/invokeai/backend/model_manager/load/optimizations.py new file mode 100644 index 00000000000..a46d262175f --- /dev/null +++ b/invokeai/backend/model_manager/load/optimizations.py @@ -0,0 +1,30 @@ +from contextlib import contextmanager + +import torch + + +def _no_op(*args, **kwargs): + pass + + +@contextmanager +def skip_torch_weight_init(): + """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) + to skip weight initialization. + + By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular + distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is + completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager + monkey-patches common torch layers to skip the weight initialization step. 
+ """ + torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding] + saved_functions = [m.reset_parameters for m in torch_modules] + + try: + for torch_module in torch_modules: + torch_module.reset_parameters = _no_op + + yield None + finally: + for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True): + torch_module.reset_parameters = saved_function diff --git a/invokeai/backend/model_manager/load/ram_cache/__init__.py b/invokeai/backend/model_manager/load/ram_cache/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py b/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py new file mode 100644 index 00000000000..cd80d1e78b2 --- /dev/null +++ b/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py @@ -0,0 +1,145 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team +# TODO: Add Stalker's proper name to copyright +""" +Manage a RAM cache of diffusion/transformer models for fast switching. +They are moved between GPU VRAM and CPU RAM as necessary. If the cache +grows larger than a preset maximum, then the least recently used +model will be cleared and (re)loaded from disk when next needed. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from logging import Logger +from typing import Dict, Optional + +import torch + +from invokeai.backend.model_manager import SubModelType +from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase + + +@dataclass +class CacheStats(object): + """Data object to record statistics on cache hits/misses.""" + + hits: int = 0 # cache hits + misses: int = 0 # cache misses + high_watermark: int = 0 # amount of cache used + in_cache: int = 0 # number of models in cache + cleared: int = 0 # number of models cleared to make space + cache_size: int = 0 # total size of cache + loaded_model_sizes: Dict[str, int] = field(default_factory=dict) + + +@dataclass +class CacheRecord: + """Elements of the cache.""" + + key: str + model: AnyModel + size: int + _locks: int = 0 + + def lock(self) -> None: + """Lock this record.""" + self._locks += 1 + + def unlock(self) -> None: + """Unlock this record.""" + self._locks -= 1 + assert self._locks >= 0 + + @property + def locked(self) -> bool: + """Return true if record is locked.""" + return self._locks > 0 + + +class ModelCacheBase(ABC): + """Virtual base class for RAM model cache.""" + + @property + @abstractmethod + def storage_device(self) -> torch.device: + """Return the storage device (e.g. "CPU" for RAM).""" + pass + + @property + @abstractmethod + def execution_device(self) -> torch.device: + """Return the exection device (e.g. 
"cuda" for VRAM).""" + pass + + @property + @abstractmethod + def lazy_offloading(self) -> bool: + """Return true if the cache is configured to lazily offload models in VRAM.""" + pass + + @abstractmethod + def offload_unlocked_models(self) -> None: + """Offload from VRAM any models not actively in use.""" + pass + + @abstractmethod + def move_model_to_device(self, cache_entry: CacheRecord, device: torch.device) -> None: + """Move model into the indicated device.""" + pass + + @property + @abstractmethod + def logger(self) -> Logger: + """Return the logger used by the cache.""" + pass + + @abstractmethod + def make_room(self, size: int) -> None: + """Make enough room in the cache to accommodate a new model of indicated size.""" + pass + + @abstractmethod + def put( + self, + key: str, + model: AnyModel, + submodel_type: Optional[SubModelType] = None, + ) -> None: + """Store model under key and optional submodel_type.""" + pass + + @abstractmethod + def get( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> ModelLockerBase: + """ + Retrieve model locker object using key and optional submodel_type. + + This may return an UnknownModelException if the model is not in the cache. + """ + pass + + @abstractmethod + def exists( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> bool: + """Return true if the model identified by key and submodel_type is in the cache.""" + pass + + @abstractmethod + def cache_size(self) -> int: + """Get the total size of the models currently cached.""" + pass + + @abstractmethod + def get_stats(self) -> CacheStats: + """Return cache hit/miss/size statistics.""" + pass + + @abstractmethod + def print_cuda_stats(self) -> None: + """Log debugging information on CUDA usage.""" + pass diff --git a/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py b/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py new file mode 100644 index 00000000000..bd43e978c83 --- /dev/null +++ b/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py @@ -0,0 +1,332 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team +# TODO: Add Stalker's proper name to copyright +""" +Manage a RAM cache of diffusion/transformer models for fast switching. +They are moved between GPU VRAM and CPU RAM as necessary. If the cache +grows larger than a preset maximum, then the least recently used +model will be cleared and (re)loaded from disk when next needed. + +The cache returns context manager generators designed to load the +model into the GPU within the context, and unload outside the +context. 
Use like this: + + cache = ModelCache(max_cache_size=7.5) + with cache.get_model('runwayml/stable-diffusion-1-5') as SD1, + cache.get_model('stabilityai/stable-diffusion-2') as SD2: + do_something_in_GPU(SD1,SD2) + + +""" + +import math +import time +from contextlib import suppress +from logging import Logger +from typing import Any, Dict, List, Optional + +import torch + +from invokeai.app.services.model_records import UnknownModelException +from invokeai.backend.model_manager import SubModelType +from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase +from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data +from invokeai.backend.model_manager.load.ram_cache.ram_cache_base import CacheRecord, CacheStats, ModelCacheBase +from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.logging import InvokeAILogger + +if choose_torch_device() == torch.device("mps"): + from torch import mps + +# Maximum size of the cache, in gigs +# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously +DEFAULT_MAX_CACHE_SIZE = 6.0 + +# amount of GPU memory to hold in reserve for use by generations (GB) +DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75 + +# actual size of a gig +GIG = 1073741824 + +# Size of a MB in bytes. +MB = 2**20 + + +class ModelCache(ModelCacheBase): + """Implementation of ModelCacheBase.""" + + def __init__( + self, + max_cache_size: float = DEFAULT_MAX_CACHE_SIZE, + max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE, + execution_device: torch.device = torch.device("cuda"), + storage_device: torch.device = torch.device("cpu"), + precision: torch.dtype = torch.float16, + sequential_offload: bool = False, + lazy_offloading: bool = True, + sha_chunksize: int = 16777216, + log_memory_usage: bool = False, + logger: Optional[Logger] = None, + ): + """ + Initialize the model RAM cache. + + :param max_cache_size: Maximum size of the RAM cache [6.0 GB] + :param execution_device: Torch device to load active model into [torch.device('cuda')] + :param storage_device: Torch device to save inactive model in [torch.device('cpu')] + :param precision: Precision for loaded models [torch.float16] + :param lazy_offloading: Keep model in VRAM until another model needs to be loaded + :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially + :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache + operation, and the result will be logged (at debug level). There is a time cost to capturing the memory + snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's + behaviour. 
+ """ + # allow lazy offloading only when vram cache enabled + self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0 + self._precision: torch.dtype = precision + self._max_cache_size: float = max_cache_size + self._max_vram_cache_size: float = max_vram_cache_size + self._execution_device: torch.device = execution_device + self._storage_device: torch.device = storage_device + self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) + self._log_memory_usage = log_memory_usage + + # used for stats collection + self.stats = None + + self._cached_models: Dict[str, CacheRecord] = {} + self._cache_stack: List[str] = [] + + class ModelLocker(ModelLockerBase): + """Internal class that mediates movement in and out of GPU.""" + + def __init__(self, cache: ModelCacheBase, cache_entry: CacheRecord): + """ + Initialize the model locker. + + :param cache: The ModelCache object + :param cache_entry: The entry in the model cache + """ + self._cache = cache + self._cache_entry = cache_entry + + @property + def model(self) -> AnyModel: + """Return the model without moving it around.""" + return self._cache_entry.model + + def lock(self) -> Any: + """Move the model into the execution device (GPU) and lock it.""" + if not hasattr(self.model, "to"): + return self.model + + # NOTE that the model has to have the to() method in order for this code to move it into GPU! + self._cache_entry.lock() + + try: + if self._cache.lazy_offloading: + self._cache.offload_unlocked_models() + + self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) + + self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") + self._cache.print_cuda_stats() + + except Exception: + self._cache_entry.unlock() + raise + return self.model + + def unlock(self) -> None: + """Call upon exit from context.""" + if not hasattr(self.model, "to"): + return + + self._cache_entry.unlock() + if not self._cache.lazy_offloading: + self._cache.offload_unlocked_models() + self._cache.print_cuda_stats() + + @property + def logger(self) -> Logger: + """Return the logger used by the cache.""" + return self._logger + + @property + def lazy_offloading(self) -> bool: + """Return true if the cache is configured to lazily offload models in VRAM.""" + return self._lazy_offloading + + @property + def storage_device(self) -> torch.device: + """Return the storage device (e.g. "CPU" for RAM).""" + return self._storage_device + + @property + def execution_device(self) -> torch.device: + """Return the exection device (e.g. 
"cuda" for VRAM).""" + return self._execution_device + + def cache_size(self) -> int: + """Get the total size of the models currently cached.""" + total = 0 + for cache_record in self._cached_models.values(): + total += cache_record.size + return total + + def exists( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> bool: + """Return true if the model identified by key and submodel_type is in the cache.""" + key = self._make_cache_key(key, submodel_type) + return key in self._cached_models + + def put( + self, + key: str, + model: AnyModel, + submodel_type: Optional[SubModelType] = None, + ) -> None: + """Store model under key and optional submodel_type.""" + key = self._make_cache_key(key, submodel_type) + assert key not in self._cached_models + + loaded_model_size = calc_model_size_by_data(model) + cache_record = CacheRecord(key, model, loaded_model_size) + self._cached_models[key] = cache_record + self._cache_stack.append(key) + + def get( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> ModelLockerBase: + """ + Retrieve model using key and optional submodel_type. + + This may return an UnknownModelException if the model is not in the cache. + """ + key = self._make_cache_key(key, submodel_type) + if key not in self._cached_models: + raise UnknownModelException + + # this moves the entry to the top (right end) of the stack + with suppress(Exception): + self._cache_stack.remove(key) + self._cache_stack.append(key) + cache_entry = self._cached_models[key] + return self.ModelLocker( + cache=self, + cache_entry=cache_entry, + ) + + def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: + if self._log_memory_usage: + return MemorySnapshot.capture() + return None + + def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str: + if submodel_type: + return f"{model_key}:{submodel_type.value}" + else: + return model_key + + def offload_unlocked_models(self) -> None: + """Move any unused models from VRAM.""" + reserved = self._max_vram_cache_size * GIG + vram_in_use = torch.cuda.memory_allocated() + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") + for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): + if vram_in_use <= reserved: + break + if not cache_entry.locked: + self.move_model_to_device(cache_entry, self.storage_device) + + vram_in_use = torch.cuda.memory_allocated() + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") + + torch.cuda.empty_cache() + if choose_torch_device() == torch.device("mps"): + mps.empty_cache() + + # TO DO: Only reason to pass the CacheRecord rather than the model is to get the key and size + # for printing debugging messages. Revisit whether this is necessary + def move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None: + """Move model into the indicated device.""" + # These attributes are not in the base class but in derived classes + assert hasattr(cache_entry.model, "device") + assert hasattr(cache_entry.model, "to") + + source_device = cache_entry.model.device + + # Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support + # multi-GPU. 
+ if torch.device(source_device).type == torch.device(target_device).type: + return + + start_model_to_time = time.time() + snapshot_before = self._capture_memory_snapshot() + cache_entry.model.to(target_device) + snapshot_after = self._capture_memory_snapshot() + end_model_to_time = time.time() + self.logger.debug( + f"Moved model '{cache_entry.key}' from {source_device} to" + f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n" + f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n" + f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" + ) + + if ( + snapshot_before is not None + and snapshot_after is not None + and snapshot_before.vram is not None + and snapshot_after.vram is not None + ): + vram_change = abs(snapshot_before.vram - snapshot_after.vram) + + # If the estimated model size does not match the change in VRAM, log a warning. + if not math.isclose( + vram_change, + cache_entry.size, + rel_tol=0.1, + abs_tol=10 * MB, + ): + self.logger.debug( + f"Moving model '{cache_entry.key}' from {source_device} to" + f" {target_device} caused an unexpected change in VRAM usage. The model's" + " estimated size may be incorrect. Estimated model size:" + f" {(cache_entry.size/GIG):.3f} GB.\n" + f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" + ) + + def print_cuda_stats(self) -> None: + """Log CUDA diagnostics.""" + vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) + ram = "%4.2fG" % self.cache_size() + + cached_models = 0 + loaded_models = 0 + locked_models = 0 + for cache_record in self._cached_models.values(): + cached_models += 1 + assert hasattr(cache_record.model, "device") + if cache_record.model.device is self.storage_device: + loaded_models += 1 + if cache_record.locked: + locked_models += 1 + + self.logger.debug( + f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ =" + f" {cached_models}/{loaded_models}/{locked_models}" + ) + + def get_stats(self) -> CacheStats: + """Return cache hit/miss/size statistics.""" + raise NotImplementedError + + def make_room(self, size: int) -> None: + """Make enough room in the cache to accommodate a new model of indicated size.""" + raise NotImplementedError diff --git a/invokeai/backend/model_manager/load/vae.py b/invokeai/backend/model_manager/load/vae.py new file mode 100644 index 00000000000..a6cbe241e1e --- /dev/null +++ b/invokeai/backend/model_manager/load/vae.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Class for VAE model loading in InvokeAI.""" + +from pathlib import Path +from typing import Dict, Optional + +import torch + +from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) +class VaeDiffusersModel(ModelLoader): + """Class to load VAE models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> Dict[str, torch.Tensor]: + if submodel_type is not None: + raise Exception("There are no submodels in VAEs") + vae_class = self._get_hf_load_class(model_path) + variant = model_variant.value if model_variant else "" + result: Dict[str, torch.Tensor] = vae_class.from_pretrained( + model_path, torch_dtype=self._torch_dtype, variant=variant + ) # type: ignore + return result diff --git a/invokeai/backend/model_manager/onnx_runtime.py b/invokeai/backend/model_manager/onnx_runtime.py new file mode 100644 index 00000000000..f79fa015692 --- /dev/null +++ b/invokeai/backend/model_manager/onnx_runtime.py @@ -0,0 +1,216 @@ +# Copyright (c) 2024 The InvokeAI Development Team +import os +import sys +from pathlib import Path +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import onnx +from onnx import numpy_helper +from onnxruntime import InferenceSession, SessionOptions, get_available_providers + +ONNX_WEIGHTS_NAME = "model.onnx" + + +# NOTE FROM LS: This was copied from Stalker's original implementation. 
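With `VaeDiffusersModel` registered above, the call path sketched in the `load_base.py` docstring becomes concrete: look a model up by its record key, receive a `LoadedModel`, and touch the model only inside the `with` block so the locker keeps it on the execution device. A small illustrative helper (not part of the patch; the function and its name are hypothetical):

from typing import Any, Callable

from invokeai.backend.model_manager.load.load_base import AnyModelLoader


def with_locked_model(loader: AnyModelLoader, key: str, fn: Callable[[Any], Any]) -> Any:
    """Run `fn` on a model while the cache holds it locked on the execution device."""
    loaded = loader.get_model(key)  # resolves the config, then converts/loads as needed
    with loaded as model:  # ModelLocker.lock() moves the model onto the GPU
        return fn(model)
    # leaving the block triggers ModelLocker.unlock(); the cache may offload the model again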
+# I have not yet gone through and fixed all the type hints +class IAIOnnxRuntimeModel: + class _tensor_access: + def __init__(self, model): # type: ignore + self.model = model + self.indexes = {} + for idx, obj in enumerate(self.model.proto.graph.initializer): + self.indexes[obj.name] = idx + + def __getitem__(self, key: str): # type: ignore + value = self.model.proto.graph.initializer[self.indexes[key]] + return numpy_helper.to_array(value) + + def __setitem__(self, key: str, value: np.ndarray): # type: ignore + new_node = numpy_helper.from_array(value) + # set_external_data(new_node, location="in-memory-location") + new_node.name = key + # new_node.ClearField("raw_data") + del self.model.proto.graph.initializer[self.indexes[key]] + self.model.proto.graph.initializer.insert(self.indexes[key], new_node) + # self.model.data[key] = OrtValue.ortvalue_from_numpy(value) + + # __delitem__ + + def __contains__(self, key: str) -> bool: + return self.indexes[key] in self.model.proto.graph.initializer + + def items(self) -> List[Tuple[str, Any]]: # fixme + raise NotImplementedError("tensor.items") + # return [(obj.name, obj) for obj in self.raw_proto] + + def keys(self) -> List[str]: + return list(self.indexes.keys()) + + def values(self) -> List[Any]: # fixme + raise NotImplementedError("tensor.values") + # return [obj for obj in self.raw_proto] + + def size(self) -> int: + bytesSum = 0 + for node in self.model.proto.graph.initializer: + bytesSum += sys.getsizeof(node.raw_data) + return bytesSum + + class _access_helper: + def __init__(self, raw_proto): # type: ignore + self.indexes = {} + self.raw_proto = raw_proto + for idx, obj in enumerate(raw_proto): + self.indexes[obj.name] = idx + + def __getitem__(self, key: str): # type: ignore + return self.raw_proto[self.indexes[key]] + + def __setitem__(self, key: str, value): # type: ignore + index = self.indexes[key] + del self.raw_proto[index] + self.raw_proto.insert(index, value) + + # __delitem__ + + def __contains__(self, key: str) -> bool: + return key in self.indexes + + def items(self) -> List[Tuple[str, Any]]: + return [(obj.name, obj) for obj in self.raw_proto] + + def keys(self) -> List[str]: + return list(self.indexes.keys()) + + def values(self) -> List[Any]: # fixme + return list(self.raw_proto) + + def __init__(self, model_path: str, provider: Optional[str]): + self.path = model_path + self.session = None + self.provider = provider + """ + self.data_path = self.path + "_data" + if not os.path.exists(self.data_path): + print(f"Moving model tensors to separate file: {self.data_path}") + tmp_proto = onnx.load(model_path, load_external_data=True) + onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False) + del tmp_proto + gc.collect() + + self.proto = onnx.load(model_path, load_external_data=False) + """ + + self.proto = onnx.load(model_path, load_external_data=True) + # self.data = dict() + # for tensor in self.proto.graph.initializer: + # name = tensor.name + + # if tensor.HasField("raw_data"): + # npt = numpy_helper.to_array(tensor) + # orv = OrtValue.ortvalue_from_numpy(npt) + # # self.data[name] = orv + # # set_external_data(tensor, location="in-memory-location") + # tensor.name = name + # # tensor.ClearField("raw_data") + + self.nodes = self._access_helper(self.proto.graph.node) # type: ignore + # self.initializers = self._access_helper(self.proto.graph.initializer) + # print(self.proto.graph.input) + # 
print(self.proto.graph.initializer) + + self.tensors = self._tensor_access(self) # type: ignore + + # TODO: integrate with model manager/cache + def create_session(self, height=None, width=None): + if self.session is None or self.session_width != width or self.session_height != height: + # onnx.save(self.proto, "tmp.onnx") + # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False) + # TODO: something to be able to get weight when they already moved outside of model proto + # (trimmed_model, external_data) = buffer_external_data_tensors(self.proto) + sess = SessionOptions() + # self._external_data.update(**external_data) + # sess.add_external_initializers(list(self.data.keys()), list(self.data.values())) + # sess.enable_profiling = True + + # sess.intra_op_num_threads = 1 + # sess.inter_op_num_threads = 1 + # sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL + # sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL + # sess.enable_cpu_mem_arena = True + # sess.enable_mem_pattern = True + # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code + self.session_height = height + self.session_width = width + if height and width: + sess.add_free_dimension_override_by_name("unet_sample_batch", 2) + sess.add_free_dimension_override_by_name("unet_sample_channels", 4) + sess.add_free_dimension_override_by_name("unet_hidden_batch", 2) + sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77) + sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height) + sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width) + sess.add_free_dimension_override_by_name("unet_time_batch", 1) + providers = [] + if self.provider: + providers.append(self.provider) + else: + providers = get_available_providers() + if "TensorrtExecutionProvider" in providers: + providers.remove("TensorrtExecutionProvider") + try: + self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess) + except Exception as e: + raise e + # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options) + # self.io_binding = self.session.io_binding() + + def release_session(self): + self.session = None + import gc + + gc.collect() + return + + def __call__(self, **kwargs): + if self.session is None: + raise Exception("You should call create_session before running model") + + inputs = {k: np.array(v) for k, v in kwargs.items()} + # output_names = self.session.get_outputs() + # for k in inputs: + # self.io_binding.bind_cpu_input(k, inputs[k]) + # for name in output_names: + # self.io_binding.bind_output(name.name) + # self.session.run_with_iobinding(self.io_binding, None) + # return self.io_binding.copy_outputs_to_cpu() + return self.session.run(None, inputs) + + # compatability with diffusers load code + @classmethod + def from_pretrained( + cls, + model_id: Union[str, Path], + subfolder: Optional[Union[str, Path]] = None, + file_name: Optional[str] = None, + provider: Optional[str] = None, + sess_options: Optional["SessionOptions"] = None, + **kwargs: Any, + ) -> Any: # fixme + file_name = file_name or ONNX_WEIGHTS_NAME + + if os.path.isdir(model_id): + model_path = model_id + if subfolder is not None: + model_path = os.path.join(model_path, subfolder) + model_path = os.path.join(model_path, file_name) + + else: + model_path = 
model_id + + # load model from local directory + if not os.path.isfile(model_path): + raise Exception(f"Model not found: {model_path}") + + # TODO: session options + return cls(str(model_path), provider=provider) diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index ba3ac3dd0cc..9fd118b7822 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -18,9 +18,9 @@ InvalidModelConfigException, ModelConfigFactory, ModelFormat, + ModelRepoVariant, ModelType, ModelVariantType, - ModelRepoVariant, SchedulerPredictionType, ) from .hash import FastModelHash @@ -483,8 +483,8 @@ def get_format(self) -> ModelFormat: def get_repo_variant(self) -> ModelRepoVariant: # get all files ending in .bin or .safetensors - weight_files = list(self.model_path.glob('**/*.safetensors')) - weight_files.extend(list(self.model_path.glob('**/*.bin'))) + weight_files = list(self.model_path.glob("**/*.safetensors")) + weight_files.extend(list(self.model_path.glob("**/*.bin"))) for x in weight_files: if ".fp16" in x.suffixes: return ModelRepoVariant.FP16 @@ -496,6 +496,7 @@ def get_repo_variant(self) -> ModelRepoVariant: return ModelRepoVariant.ONNX return ModelRepoVariant.DEFAULT + class PipelineFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: with open(self.model_path / "unet" / "config.json", "r") as file: @@ -540,7 +541,6 @@ def get_variant_type(self) -> ModelVariantType: except Exception: pass return ModelVariantType.Normal - class VaeFolderProbe(FolderProbeBase): diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index 415559a64cd..aacae06a8bb 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -21,9 +21,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat base_type = probe.get_base_type() assert base_type == expected_type repo_variant = probe.get_repo_variant() - assert repo_variant == 'default' + assert repo_variant == "default" + def test_repo_variant(datadir: Path): probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16") repo_variant = probe.get_repo_variant() - assert repo_variant == 'fp16' + assert repo_variant == "fp16" From 4d58f4ba4463c77de4bc092a876a501867a8e1a6 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 3 Feb 2024 22:55:09 -0500 Subject: [PATCH 074/340] model loading and conversion implemented for vaes --- invokeai/app/api/dependencies.py | 17 +- .../app/services/config/config_default.py | 21 +- .../model_install/model_install_default.py | 5 +- .../model_records/model_records_base.py | 15 +- .../model_records/model_records_sql.py | 43 +- .../app/services/shared/sqlite/sqlite_util.py | 2 + .../sqlite_migrator/migrations/migration_6.py | 44 + invokeai/backend/install/install_helper.py | 11 +- invokeai/backend/model_manager/__init__.py | 4 + invokeai/backend/model_manager/config.py | 10 +- .../convert_ckpt_to_diffusers.py | 1744 +++++++++++++++++ .../backend/model_manager/load/__init__.py | 35 + .../load/convert_cache/__init__.py | 4 + .../load/convert_cache/convert_cache_base.py | 28 + .../convert_cache/convert_cache_default.py | 64 + .../backend/model_manager/load/load_base.py | 72 +- .../model_manager/load/load_default.py | 23 +- .../load/model_cache/__init__.py | 5 + .../model_cache_base.py} | 54 +- .../model_cache_default.py} | 202 +- .../load/model_cache/model_locker.py | 59 + .../load/model_loaders/__init__.py | 3 + .../model_manager/load/model_loaders/vae.py | 83 + 
.../backend/model_manager/load/model_util.py | 3 + .../model_manager/load/ram_cache/__init__.py | 0 invokeai/backend/model_manager/load/vae.py | 31 - invokeai/backend/util/__init__.py | 12 +- invokeai/backend/util/devices.py | 5 +- invokeai/backend/util/util.py | 14 + 29 files changed, 2379 insertions(+), 234 deletions(-) create mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py create mode 100644 invokeai/backend/model_manager/convert_ckpt_to_diffusers.py create mode 100644 invokeai/backend/model_manager/load/convert_cache/__init__.py create mode 100644 invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py create mode 100644 invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py create mode 100644 invokeai/backend/model_manager/load/model_cache/__init__.py rename invokeai/backend/model_manager/load/{ram_cache/ram_cache_base.py => model_cache/model_cache_base.py} (77%) rename invokeai/backend/model_manager/load/{ram_cache/ram_cache_default.py => model_cache/model_cache_default.py} (63%) create mode 100644 invokeai/backend/model_manager/load/model_cache/model_locker.py create mode 100644 invokeai/backend/model_manager/load/model_loaders/__init__.py create mode 100644 invokeai/backend/model_manager/load/model_loaders/vae.py delete mode 100644 invokeai/backend/model_manager/load/ram_cache/__init__.py delete mode 100644 invokeai/backend/model_manager/load/vae.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 0f2a92b5c8e..dcb8d219971 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -8,6 +8,8 @@ from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.shared.sqlite.sqlite_util import init_db +from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache +from invokeai.backend.model_manager.load.model_cache import ModelCache from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData from invokeai.backend.util.logging import InvokeAILogger @@ -98,15 +100,26 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger ) model_manager = ModelManagerService(config, logger) model_record_service = ModelRecordServiceSQL(db=db) + model_loader = AnyModelLoader( + app_config=config, + logger=logger, + ram_cache=ModelCache( + max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger + ), + convert_cache=ModelConvertCache( + cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size + ), + ) + model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader) download_queue_service = DownloadQueueService(event_bus=events) - metadata_store = ModelMetadataStore(db=db) model_install_service = ModelInstallService( app_config=config, record_store=model_record_service, download_queue=download_queue_service, - metadata_store=metadata_store, + metadata_store=ModelMetadataStore(db=db), event_bus=events, ) + model_manager = ModelManagerService(config, logger) # TO DO: legacy model manager v1. 
Remove names = SimpleNameService() performance_statistics = InvocationStatsService() processor = DefaultInvocationProcessor() diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 132afc22722..b161ea18d61 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -237,6 +237,7 @@ class InvokeAIAppConfig(InvokeAISettings): autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths) conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths) models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths) + convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths) legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths) db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths) outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths) @@ -262,6 +263,8 @@ class InvokeAIAppConfig(InvokeAISettings): # CACHE ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) + convert_cache : float = Field(default=10.0, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache) + lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, ) log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). 
There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache) @@ -404,6 +407,11 @@ def models_path(self) -> Path: """Path to the models directory.""" return self._resolve(self.models_dir) + @property + def models_convert_cache_path(self) -> Path: + """Path to the converted cache models directory.""" + return self._resolve(self.convert_cache_dir) + @property def custom_nodes_path(self) -> Path: """Path to the custom nodes directory.""" @@ -433,15 +441,20 @@ def invisible_watermark(self) -> bool: return True @property - def ram_cache_size(self) -> Union[Literal["auto"], float]: - """Return the ram cache size using the legacy or modern setting.""" + def ram_cache_size(self) -> float: + """Return the ram cache size using the legacy or modern setting (GB).""" return self.max_cache_size or self.ram @property - def vram_cache_size(self) -> Union[Literal["auto"], float]: - """Return the vram cache size using the legacy or modern setting.""" + def vram_cache_size(self) -> float: + """Return the vram cache size using the legacy or modern setting (GB).""" return self.max_vram_cache_size or self.vram + @property + def convert_cache_size(self) -> float: + """Return the convert cache size on disk (GB).""" + return self.convert_cache + @property def use_cpu(self) -> bool: """Return true if the device is set to CPU or the always_use_cpu flag is set.""" diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 82c667f584f..2b2294bfce4 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -145,7 +145,7 @@ def register_path( ) -> str: # noqa D102 model_path = Path(model_path) config = config or {} - if config.get("source") is None: + if not config.get("source"): config["source"] = model_path.resolve().as_posix() return self._register(model_path, config) @@ -156,7 +156,7 @@ def install_path( ) -> str: # noqa D102 model_path = Path(model_path) config = config or {} - if config.get("source") is None: + if not config.get("source"): config["source"] = model_path.resolve().as_posix() info: AnyModelConfig = self._probe_model(Path(model_path), config) @@ -300,6 +300,7 @@ def _install_next_item(self) -> None: job.total_bytes = self._stat_size(job.local_path) job.bytes = job.total_bytes self._signal_job_running(job) + job.config_in["source"] = str(job.source) if job.inplace: key = self.register_path(job.local_path, job.config_in) else: diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 57597570cde..31cfecb4ec8 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -11,7 +11,7 @@ from pydantic import BaseModel, Field from invokeai.app.services.shared.pagination import PaginatedResults -from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType +from invokeai.backend.model_manager import LoadedModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore @@ -102,6 +102,19 @@ def get_model(self, key: str) -> AnyModelConfig: """ pass + @abstractmethod + def load_model(self, key: str, 
submodel_type: Optional[SubModelType]) -> LoadedModel: + """ + Load the indicated model into memory and return a LoadedModel object. + + :param key: Key of model config to be fetched. + :param submodel_type: For main (pipeline) models, the submodel to fetch + + Exceptions: UnknownModelException -- model with this key not known + NotImplementedError -- a model loader was not provided at initialization time + """ + pass + @property @abstractmethod def metadata_store(self) -> ModelMetadataStore: diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 4512da5d413..eee867ccb46 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -42,6 +42,7 @@ import json import sqlite3 +import time from math import ceil from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple, Union @@ -53,8 +54,10 @@ ModelConfigFactory, ModelFormat, ModelType, + SubModelType, ) from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException +from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_records_base import ( @@ -69,16 +72,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """Implementation of the ModelConfigStore ABC using a SQL database.""" - def __init__(self, db: SqliteDatabase): + def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader]=None): """ Initialize a new object from preexisting sqlite3 connection and threading lock objects. - :param conn: sqlite3 connection object - :param lock: threading Lock object + :param db: Sqlite connection object + :param loader: Initialized model loader object (optional) """ super().__init__() self._db = db - self._cursor = self._db.conn.cursor() + self._cursor = db.conn.cursor() + self._loader = loader @property def db(self) -> SqliteDatabase: @@ -199,7 +203,7 @@ def get_model(self, key: str) -> AnyModelConfig: with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE id=?; """, (key,), @@ -207,9 +211,24 @@ def get_model(self, key: str) -> AnyModelConfig: rows = self._cursor.fetchone() if not rows: raise UnknownModelException("model not found") - model = ModelConfigFactory.make_config(json.loads(rows[0])) + model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) return model + def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel: + """ + Load the indicated model into memory and return a LoadedModel object. + + :param key: Key of model config to be fetched. + :param submodel_type: For main (pipeline) models, the submodel to fetch. + + Exceptions: UnknownModelException -- model with this key not known + NotImplementedError -- a model loader was not provided at initialization time + """ + if not self._loader: + raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader") + model_config = self.get_model(key) + return self._loader.load_model(model_config, submodel_type) + def exists(self, key: str) -> bool: """ Return True if a model with the indicated key exists in the database.
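A minimal usage sketch of the load_model() API introduced above (illustrative only, not part of this patch). It assumes a ModelRecordServiceSQL constructed with a loader, as wired up in dependencies.py, and uses a placeholder model key:

    from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
    from invokeai.backend.model_manager import LoadedModel

    def load_registered_vae(store: ModelRecordServiceSQL) -> LoadedModel:
        # "placeholder-key" stands in for a real key from the model_config table.
        # A VAE has no submodels, so submodel_type is None; for main (pipeline)
        # models a SubModelType value would be passed instead.
        return store.load_model("placeholder-key", submodel_type=None)
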
@@ -265,12 +284,12 @@ def search_by_attr( with self._db.lock: self._cursor.execute( f"""--sql - select config FROM model_config + select config, strftime('%s',updated_at) FROM model_config {where}; """, tuple(bindings), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] return results def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: @@ -279,12 +298,12 @@ def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE path=?; """, (str(path),), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] return results def search_by_hash(self, hash: str) -> List[AnyModelConfig]: @@ -293,12 +312,12 @@ def search_by_hash(self, hash: str) -> List[AnyModelConfig]: with self._db.lock: self._cursor.execute( """--sql - SELECT config FROM model_config + SELECT config, strftime('%s',updated_at) FROM model_config WHERE original_hash=?; """, (hash,), ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] + results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] return results @property diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 6079b3f08d7..681886eacd3 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -8,6 +8,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -33,6 +34,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_3(app_config=config, logger=logger)) migrator.register_migration(build_migration_4()) migrator.register_migration(build_migration_5()) + migrator.register_migration(build_migration_6()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py new file mode 100644 index 00000000000..e72878f726f --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py @@ -0,0 +1,44 @@ +import sqlite3 +from logging import Logger + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + +class Migration6Callback: + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._recreate_model_triggers(cursor) + + def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None: + """ + Adds the timestamp trigger to the model_config table. 
+ + This trigger was inadvertently dropped in earlier migration scripts. + """ + + cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS model_config_updated_at + AFTER UPDATE + ON model_config FOR EACH ROW + BEGIN + UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE id = old.id; + END; + """ + ) + +def build_migration_6() -> Migration: + """ + Build the migration from database version 5 to 6. + + This migration does the following: + - Adds the model_config_updated_at trigger if it does not exist + """ + migration_6 = Migration( + from_version=5, + to_version=6, + callback=Migration6Callback(), + ) + + return migration_6 diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index e54be527d95..8c03d2ccf84 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -98,11 +98,13 @@ def __init__(self) -> None: super().__init__() self._bars: Dict[str, tqdm] = {} self._last: Dict[str, int] = {} + self._logger = InvokeAILogger.get_logger(__name__) def dispatch(self, event_name: str, payload: Any) -> None: """Dispatch an event by appending it to self.events.""" + data = payload["data"] + source = data["source"] if payload["event"] == "model_install_downloading": - data = payload["data"] dest = data["local_path"] total_bytes = data["total_bytes"] bytes = data["bytes"] @@ -111,7 +113,12 @@ def dispatch(self, event_name: str, payload: Any) -> None: self._last[dest] = 0 self._bars[dest].update(bytes - self._last[dest]) self._last[dest] = bytes - + elif payload["event"] == "model_install_completed": + self._logger.info(f"{source}: installed successfully.") + elif payload["event"] == "model_install_error": + self._logger.warning(f"{source}: installation failed with error {data['error']}") + elif payload["event"] == "model_install_cancelled": + self._logger.warning(f"{source}: installation cancelled") class InstallHelper(object): """Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db.""" diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index 0f16852c934..f3c84cd01f7 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -1,6 +1,7 @@ """Re-export frequently-used symbols from the Model Manager backend.""" from .config import ( + AnyModel, AnyModelConfig, BaseModelType, InvalidModelConfigException, @@ -14,12 +15,15 @@ ) from .probe import ModelProbe from .search import ModelSearch +from .load import LoadedModel __all__ = [ + "AnyModel", "AnyModelConfig", "BaseModelType", "ModelRepoVariant", "InvalidModelConfigException", + "LoadedModel", "ModelConfigFactory", "ModelFormat", "ModelProbe", diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 338669c873a..796ccbacde0 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -19,12 +19,15 @@ Validation errors will raise an InvalidModelConfigException error. 
""" +import time +import torch from enum import Enum from typing import Literal, Optional, Type, Union from pydantic import BaseModel, ConfigDict, Field, TypeAdapter +from diffusers import ModelMixin from typing_extensions import Annotated, Any, Dict - +from .onnx_runtime import IAIOnnxRuntimeModel class InvalidModelConfigException(Exception): """Exception for when config parser doesn't recognized this combination of model type and format.""" @@ -127,6 +130,7 @@ class ModelConfigBase(BaseModel): ) # if model is converted or otherwise modified, this will hold updated hash description: Optional[str] = Field(default=None) source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None) + last_modified: Optional[float] = Field(description="Timestamp for modification time", default_factory=time.time) model_config = ConfigDict( use_enum_values=False, @@ -280,6 +284,7 @@ class T2IConfig(ModelConfigBase): ] AnyModelConfigValidator = TypeAdapter(AnyModelConfig) +AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel] # IMPLEMENTATION NOTE: # The preferred alternative to the above is a discriminated Union as shown @@ -312,6 +317,7 @@ def make_config( model_data: Union[dict, AnyModelConfig], key: Optional[str] = None, dest_class: Optional[Type] = None, + timestamp: Optional[float] = None ) -> AnyModelConfig: """ Return the appropriate config object from raw dict values. @@ -330,4 +336,6 @@ def make_config( model = AnyModelConfigValidator.validate_python(model_data) if key: model.key = key + if timestamp: + model.last_modified = timestamp return model diff --git a/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py new file mode 100644 index 00000000000..9d6fc4841f2 --- /dev/null +++ b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py @@ -0,0 +1,1744 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Adapted for use in InvokeAI by Lincoln Stein, July 2023 +# +""" Conversion script for the Stable Diffusion checkpoints.""" + +import re +from contextlib import nullcontext +from io import BytesIO +from pathlib import Path +from typing import Optional, Union + +import requests +import torch +from diffusers.models import AutoencoderKL, ControlNetModel, PriorTransformer, UNet2DConditionModel +from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from diffusers.utils import is_accelerate_available +from diffusers.utils.import_utils import BACKENDS_MAPPING +from picklescan.scanner import scan_file_path +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.util.logging import InvokeAILogger +from invokeai.backend.model_manager import BaseModelType, ModelVariantType + +try: + from omegaconf import OmegaConf + from omegaconf.dictconfig import DictConfig +except ImportError: + raise ImportError( + "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." + ) + +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + +logger = InvokeAILogger.get_logger(__name__) +CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core/convert" + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. 
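+ + Example: shave_segments("a.b.c.d", 1) -> "b.c.d"; shave_segments("a.b.c.d", -1) -> "a.b.c".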
+ """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. 
+ if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. 
+ """ + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: + unet_params = original_config.model.params.unet_config.params + else: + unet_params = original_config.model.params.network_config.params + + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for _i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params.transformer_depth is not None: + transformer_layers_per_block = ( + unet_params.transformer_depth + if isinstance(unet_params.transformer_depth, int) + else list(unet_params.transformer_depth) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params.model_channels // unet_params.num_head_channels + head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params.context_dim is not None: + context_dim = ( + unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] + ) + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + else: + raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if controlnet: + config["conditioning_channels"] = unet_params.hint_channels + else: + config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + 
Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config.model.params.first_stage_config.params.ddconfig + _ = original_config.model.params.first_stage_config.params.embed_dim + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params.in_channels, + "out_channels": vae_params.out_ch, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params.z_channels, + "layers_per_block": vae_params.num_res_blocks, + } + return config + + +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config.model.params.timesteps, + beta_start=original_config.model.params.linear_start, + beta_end=original_config.model.params.linear_end, + beta_schedule="scaled_linear", + ) + return schedular + + +def create_ldm_bert_config(original_config): + bert_params = original_config.model.parms.cond_stage_config.params + config = LDMBertConfig( + d_model=bert_params.n_embed, + encoder_layers=bert_params.n_layer, + encoder_ffn_dim=bert_params.n_embed * 4, + ) + return config + + +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... 
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, 
new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. 
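+ # (when the only keys under output_blocks.{i}.1 are the upsampler conv.bias/conv.weight already assigned above, there is no attention block to convert)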
+ if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." 
if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": 
"mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_bert_checkpoint(checkpoint, config): + def _copy_attn_layer(hf_attn_layer, pt_attn_layer): + hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight + hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight + hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight + + hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight + hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias + + def _copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + def _copy_layer(hf_layer, pt_layer): + # copy layer norms + _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) + _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) + + # copy attn + _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) + + # copy MLP + pt_mlp = pt_layer[1][1] + _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) + _copy_linear(hf_layer.fc2, pt_mlp.net[2]) + + def _copy_layers(hf_layers, pt_layers): + for i, hf_layer in enumerate(hf_layers): + if i != 0: + i += i + pt_layer = pt_layers[i : i + 2] + _copy_layer(hf_layer, pt_layer) + + hf_model = LDMBertModel(config).eval() + + # copy embeds + hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight + hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + + # copy layer norm + _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) + + # copy hidden layers + _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) + + _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) + + return hf_model + + +def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): + if text_encoder is None: + config = 
CLIPTextConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + text_model.load_state_dict(text_model_dict) + + return text_model + + +textenc_conversion_lst = [ + ("positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("ln_final.weight", "text_model.final_layer_norm.weight"), + ("ln_final.bias", "text_model.final_layer_norm.bias"), + ("text_projection", "text_projection.weight"), +] +textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} + +textenc_transformer_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + + +def convert_paint_by_example_checkpoint(checkpoint): + config = CLIPVisionConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load 
uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + +def convert_open_clip_checkpoint( + checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs +): + # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + # text_model = CLIPTextModelWithProjection.from_pretrained( + # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280 + # ) + config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + keys_to_ignore = [] + if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23: + # make sure to remove all keys > 22 + keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")] + keys_to_ignore += ["cond_stage_model.model.text_projection"] + + text_model_dict = {} + + if prefix + "text_projection" in checkpoint: + d_model = int(checkpoint[prefix + "text_projection"].shape[0]) + else: + d_model = 1024 + + text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + + for key in keys: + if key in keys_to_ignore: + continue + if key[len(prefix) :] in textenc_conversion_map: + if key.endswith("text_projection"): + value = checkpoint[key].T.contiguous() + else: + value = checkpoint[key] + + text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value + + if key.startswith(prefix + "transformer."): + new_key = key[len(prefix + "transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + else: + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + + text_model_dict[new_key] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + text_model.load_state_dict(text_model_dict) + + return text_model + + +def stable_unclip_image_encoder(original_config): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. 
+ """ + + image_embedder_config = original_config.model.params.embedder_config + + sd_clip_image_embedder_class = image_embedder_config.target + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + # InvokeAI doesn't use CLIPVisionModelWithProjection so it isn't in the core - if this code is hit a download will occur + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + CONVERT_MODEL_ROOT / "CLIP-ViT-H-14-laion2B-s32B-b79K" + ) + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def stable_unclip_image_noising_components( + original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None +): + """ + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. a `DDPMScheduler` for holding the noise schedule + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. + """ + noise_aug_config = original_config.model.params.noise_aug_config + noise_aug_class = noise_aug_config.target + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + image_normalizer.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return image_normalizer, image_noising_scheduler + + +def convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=None, + cross_attention_dim=None, + precision: Optional[torch.dtype] = None, +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + original_config = ctrlnet_config.copy() + + ctrlnet_config.pop("addition_embed_type") + ctrlnet_config.pop("addition_time_embed_dim") + ctrlnet_config.pop("transformer_layers_per_block") + + if 
use_linear_projection is not None: + ctrlnet_config["use_linear_projection"] = use_linear_projection + + if cross_attention_dim is not None: + ctrlnet_config["cross_attention_dim"] = cross_attention_dim + + controlnet = ControlNetModel(**ctrlnet_config) + + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + skip_extract_state_dict = True + else: + skip_extract_state_dict = False + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, + original_config, + path=checkpoint_path, + extract_ema=extract_ema, + controlnet=True, + skip_extract_state_dict=skip_extract_state_dict, + ) + + controlnet.load_state_dict(converted_ctrl_checkpoint) + + return controlnet.to(precision) + + +def download_from_original_stable_diffusion_ckpt( + checkpoint_path: str, + model_version: BaseModelType, + model_variant: ModelVariantType, + original_config_file: str = None, + image_size: Optional[int] = None, + prediction_type: str = None, + model_type: str = None, + extract_ema: bool = False, + precision: Optional[torch.dtype] = None, + scheduler_type: str = "pndm", + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + stable_unclip: Optional[str] = None, + stable_unclip_prior: Optional[str] = None, + clip_stats_path: Optional[str] = None, + controlnet: Optional[bool] = None, + load_safety_checker: bool = True, + pipeline_class: DiffusionPipeline = None, + local_files_only=False, + vae_path=None, + text_encoder=None, + tokenizer=None, + scan_needed: bool = True, +) -> DiffusionPipeline: + """ + Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` + config file. + + Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the + global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is + recommended that you override the default values and/or supply an `original_config_file` wherever possible. + + Args: + checkpoint_path (`str`): Path to `.ckpt` file. + original_config_file (`str`): + Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically + inferred by looking for a key that only exists in SD2.0 models. + image_size (`int`, *optional*, defaults to 512): + The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 + Base. Use 768 for Stable Diffusion v2. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable + Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. + num_in_channels (`int`, *optional*, defaults to None): + The number of input channels. If `None`, it will be automatically inferred. + scheduler_type (`str`, *optional*, defaults to 'pndm'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + model_type (`str`, *optional*, defaults to `None`): + The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", + "FrozenCLIPEmbedder", "PaintByExample"]`. + is_img2img (`bool`, *optional*, defaults to `False`): + Whether the model should be loaded as an img2img pipeline. 
+        extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
+            checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
+            `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
+            inference. Non-EMA weights are usually better to continue fine-tuning.
+        upcast_attention (`bool`, *optional*, defaults to `None`):
+            Whether the attention computation should always be upcasted. This is necessary when running stable
+            diffusion 2.1.
+        device (`str`, *optional*, defaults to `None`):
+            The device to use. Pass `None` to determine automatically.
+        from_safetensors (`bool`, *optional*, defaults to `False`):
+            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
+        load_safety_checker (`bool`, *optional*, defaults to `True`):
+            Whether to load the safety checker or not. Defaults to `True`.
+        pipeline_class (`str`, *optional*, defaults to `None`):
+            The pipeline class to use. Pass `None` to determine automatically.
+        local_files_only (`bool`, *optional*, defaults to `False`):
+            Whether or not to only look at local files (i.e., do not try to download the model).
+        text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
+            An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
+            to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
+            variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
+        tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
+            An instance of
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
+            to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
+            needed.
+        precision (`torch.dtype`, *optional*, defaults to `None`):
+            If not provided the precision will be set to the precision of the original file.
+        return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
+    """
+
+    # import pipelines here to avoid circular import error when using from_single_file method
+    from diffusers import (
+        LDMTextToImagePipeline,
+        PaintByExamplePipeline,
+        StableDiffusionControlNetPipeline,
+        StableDiffusionInpaintPipeline,
+        StableDiffusionPipeline,
+        StableDiffusionXLImg2ImgPipeline,
+        StableDiffusionXLPipeline,
+        StableUnCLIPImg2ImgPipeline,
+        StableUnCLIPPipeline,
+    )
+
+    if pipeline_class is None:
+        pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
+
+    if prediction_type == "v-prediction":
+        prediction_type = "v_prediction"
+
+    if from_safetensors:
+        from safetensors.torch import load_file as safe_load
+
+        checkpoint = safe_load(checkpoint_path, device="cpu")
+    else:
+        if scan_needed:
+            # scan model
+            scan_result = scan_file_path(checkpoint_path)
+            if scan_result.infected_files != 0:
+                raise Exception(f"The model {checkpoint_path} is potentially infected by malware. 
Aborting import.") + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # Sometimes models don't have the global_step item + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + else: + logger.debug("global_step key not found in model") + global_step = None + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}") + + precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias" + logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}") + precision = precision or checkpoint[precision_probing_key].dtype + + if original_config_file is None: + key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" + key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" + + # model_type = "v1" + config_url = ( + "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + ) + + if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: + # model_type = "v2" + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" + + if global_step == 110000: + # v2.1 needs to upcast attention + upcast_attention = True + elif key_name_sd_xl_base in checkpoint: + # only base xl has two text embedders + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" + elif key_name_sd_xl_refiner in checkpoint: + # only refiner xl has embedder and one text embedders + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" + + original_config_file = BytesIO(requests.get(config_url).content) + + original_config = OmegaConf.load(original_config_file) + if original_config["model"]["params"].get("use_ema") is not None: + extract_ema = original_config["model"]["params"]["use_ema"] + + if ( + model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1] + and original_config["model"]["params"].get("parameterization") == "v" + ): + prediction_type = "v_prediction" + upcast_attention = True + image_size = 768 if model_version == BaseModelType.StableDiffusion2 else 512 + else: + prediction_type = "epsilon" + upcast_attention = False + image_size = 512 + + # Convert the text model. 
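+    # When `model_type` is not supplied, it is inferred from the checkpoint's own config:
+    # for v1.x/v2.x checkpoints the last component of `cond_stage_config.target` (e.g.
+    # "FrozenCLIPEmbedder" or "FrozenOpenCLIPEmbedder") selects the matching text-encoder
+    # branch further down, while SDXL checkpoints are recognised by
+    # `network_config.params.context_dim == 2048`.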
+ if ( + model_type is None + and "cond_stage_config" in original_config.model.params + and original_config.model.params.cond_stage_config is not None + ): + model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] + logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") + elif model_type is None and original_config.model.params.network_config is not None: + if original_config.model.params.network_config.params.context_dim == 2048: + model_type = "SDXL" + else: + model_type = "SDXL-Refiner" + if image_size is None: + image_size = 1024 + + if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline: + num_in_channels = 9 + elif num_in_channels is None: + num_in_channels = 4 + + if "unet_config" in original_config.model.params: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + if image_size is None: + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + image_size = 512 if global_step == 875000 else 768 + else: + if prediction_type is None: + prediction_type = "epsilon" + if image_size is None: + image_size = 512 + + if controlnet is None and "control_stage_config" in original_config.model.params: + controlnet = convert_controlnet_checkpoint( + checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema + ) + + num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 + + if model_type in ["SDXL", "SDXL-Refiner"]: + scheduler_dict = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": num_train_timesteps, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", + } + scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) + scheduler_type = "euler" + else: + beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 + beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + 
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["upcast_attention"] = upcast_attention + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + unet = UNet2DConditionModel(**unet_config) + + if is_accelerate_available(): + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + else: + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model. + if vae_path is None: + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if ( + "model" in original_config + and "params" in original_config.model + and "scale_factor" in original_config.model.params + ): + vae_scaling_factor = original_config.model.params.scale_factor + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + vae = AutoencoderKL(**vae_config) + + if is_accelerate_available(): + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + else: + vae.load_state_dict(converted_vae_checkpoint) + else: + vae = AutoencoderKL.from_pretrained(vae_path) + + if model_type == "FrozenOpenCLIPEmbedder": + config_name = "stabilityai/stable-diffusion-2" + config_kwargs = {"subfolder": "text_encoder"} + + text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs) + tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer") + + if stable_unclip is None: + if controlnet: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + scheduler=scheduler, + controlnet=controlnet, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + else: + image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( + original_config, clip_stats_path=clip_stats_path, device=device + ) + + if stable_unclip == "img2img": + feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) + + pipe = StableUnCLIPImg2ImgPipeline( + # image encoding components + feature_extractor=feature_extractor, + image_encoder=image_encoder, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model.to(precision), + unet=unet.to(precision), + scheduler=scheduler, + # vae + vae=vae, + ) + elif stable_unclip == "txt2img": + if stable_unclip_prior is None or stable_unclip_prior == 
"karlo": + karlo_model = "kakaobrain/karlo-v1-alpha" + prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior") + + prior_tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + prior_text_model = CLIPTextModelWithProjection.from_pretrained( + CONVERT_MODEL_ROOT / "clip-vit-large-patch14" + ) + + prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler") + prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) + else: + raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") + + pipe = StableUnCLIPPipeline( + # prior components + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_model, + prior=prior, + prior_scheduler=prior_scheduler, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + else: + raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") + elif model_type == "PaintByExample": + vision_model = convert_paint_by_example_checkpoint(checkpoint) + tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + feature_extractor = AutoFeatureExtractor.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker") + pipe = PaintByExamplePipeline( + vae=vae, + image_encoder=vision_model, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=feature_extractor, + ) + elif model_type == "FrozenCLIPEmbedder": + text_model = convert_ldm_clip_checkpoint( + checkpoint, local_files_only=local_files_only, text_encoder=text_encoder + ) + tokenizer = ( + CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + if tokenizer is None + else tokenizer + ) + + if load_safety_checker: + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" + ) + else: + safety_checker = None + feature_extractor = None + + if controlnet: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + pipe = pipeline_class( + vae=vae.to(precision), + text_encoder=text_model.to(precision), + tokenizer=tokenizer, + unet=unet.to(precision), + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + elif model_type in ["SDXL", "SDXL-Refiner"]: + if model_type == "SDXL": + tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + + tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k" + tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!") + + config_name = tokenizer_name + config_kwargs = {"projection_dim": 1280} + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs + ) + + pipe = StableDiffusionXLPipeline( + vae=vae.to(precision), + text_encoder=text_encoder.to(precision), + 
tokenizer=tokenizer, + text_encoder_2=text_encoder_2.to(precision), + tokenizer_2=tokenizer_2, + unet=unet.to(precision), + scheduler=scheduler, + force_zeros_for_empty_prompt=True, + ) + else: + tokenizer = None + text_encoder = None + tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k" + tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!") + + config_name = tokenizer_name + config_kwargs = {"projection_dim": 1280} + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs + ) + + pipe = StableDiffusionXLImg2ImgPipeline( + vae=vae.to(precision), + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet.to(precision), + scheduler=scheduler, + requires_aesthetics_score=True, + force_zeros_for_empty_prompt=False, + ) + else: + text_config = create_ldm_bert_config(original_config) + text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) + tokenizer = BertTokenizerFast.from_pretrained(CONVERT_MODEL_ROOT / "bert-base-uncased") + pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + return pipe + + +def download_controlnet_from_original_ckpt( + checkpoint_path: str, + original_config_file: str, + image_size: int = 512, + extract_ema: bool = False, + precision: Optional[torch.dtype] = None, + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + device: str = None, + from_safetensors: bool = False, + use_linear_projection: Optional[bool] = None, + cross_attention_dim: Optional[bool] = None, + scan_needed: bool = False, +) -> DiffusionPipeline: + + from omegaconf import OmegaConf + + if from_safetensors: + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + if scan_needed: + # scan model + scan_result = scan_file_path(checkpoint_path) + if scan_result.infected_files != 0: + raise Exception("The model {checkpoint_path} is potentially infected by malware. 
Aborting import.") + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + # use original precision + precision_probing_key = "input_blocks.0.0.bias" + ckpt_precision = checkpoint[precision_probing_key].dtype + logger.debug(f"original controlnet precision = {ckpt_precision}") + precision = precision or ckpt_precision + + original_config = OmegaConf.load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if "control_stage_config" not in original_config.model.params: + raise ValueError("`control_stage_config` not present in original config") + + controlnet = convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, + ) + + return controlnet.to(precision) + + +def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL: + vae_config = create_vae_diffusers_config(vae_config, image_size=image_size) + + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + return vae + + +def convert_ckpt_to_diffusers( + checkpoint_path: Union[str, Path], + dump_path: Union[str, Path], + use_safetensors: bool = True, + **kwargs, +): + """ + Takes all the arguments of download_from_original_stable_diffusion_ckpt(), + and in addition a path-like object indicating the location of the desired diffusers + model to be written. + """ + pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs) + + # TO DO: save correct repo variant + pipe.save_pretrained( + dump_path, + safe_serialization=use_safetensors, + ) + + +def convert_controlnet_to_diffusers( + checkpoint_path: Union[str, Path], + dump_path: Union[str, Path], + **kwargs, +): + """ + Takes all the arguments of download_controlnet_from_original_ckpt(), + and in addition a path-like object indicating the location of the desired diffusers + model to be written. + """ + pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs) + + # TO DO: save correct repo variant + pipe.save_pretrained(dump_path, safe_serialization=True) diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index e69de29bb2d..357677bb7f7 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development Team +""" +Init file for the model loader. 
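+
+A rough usage sketch (illustrative only; `model_config` stands in for an AnyModelConfig
+record obtained elsewhere, e.g. from the model record store):
+
+    from invokeai.backend.model_manager.load import AnyModelLoader, get_standalone_loader
+
+    loader: AnyModelLoader = get_standalone_loader(None)  # falls back to InvokeAIAppConfig.get_config()
+    loaded = loader.load_model(model_config)
+    model = loaded.model  # returns the model without locking it into VRAM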
+""" +from importlib import import_module +from pathlib import Path +from typing import Optional + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.util.logging import InvokeAILogger +from .load_base import AnyModelLoader, LoadedModel +from .model_cache.model_cache_default import ModelCache +from .convert_cache.convert_cache_default import ModelConvertCache + +# This registers the subclasses that implement loaders of specific model types +loaders = [x.stem for x in Path(Path(__file__).parent,'model_loaders').glob('*.py') if x.stem != '__init__'] +for module in loaders: + print(f'module={module}') + import_module(f"{__package__}.model_loaders.{module}") + +__all__ = ["AnyModelLoader", "LoadedModel"] + + +def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader: + app_config = app_config or InvokeAIAppConfig.get_config() + logger = InvokeAILogger.get_logger(config=app_config) + return AnyModelLoader(app_config=app_config, + logger=logger, + ram_cache=ModelCache(logger=logger, + max_cache_size=app_config.ram_cache_size, + max_vram_cache_size=app_config.vram_cache_size + ), + convert_cache=ModelConvertCache(app_config.models_convert_cache_path) + ) + diff --git a/invokeai/backend/model_manager/load/convert_cache/__init__.py b/invokeai/backend/model_manager/load/convert_cache/__init__.py new file mode 100644 index 00000000000..eb3149be329 --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/__init__.py @@ -0,0 +1,4 @@ +from .convert_cache_base import ModelConvertCacheBase +from .convert_cache_default import ModelConvertCache + +__all__ = ['ModelConvertCacheBase', 'ModelConvertCache'] diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py new file mode 100644 index 00000000000..25263f96aaa --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py @@ -0,0 +1,28 @@ +""" +Disk-based converted model cache. +""" +from abc import ABC, abstractmethod +from pathlib import Path + +class ModelConvertCacheBase(ABC): + + @property + @abstractmethod + def max_size(self) -> float: + """Return the maximum size of this cache directory.""" + pass + + @abstractmethod + def make_room(self, size: float) -> None: + """ + Make sufficient room in the cache directory for a model of max_size. + + :param size: Size required (GB) + """ + pass + + @abstractmethod + def cache_path(self, key: str) -> Path: + """Return the path for a model with the indicated key.""" + pass + diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py new file mode 100644 index 00000000000..f799510ec5b --- /dev/null +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py @@ -0,0 +1,64 @@ +""" +Placeholder for convert cache implementation. 
+""" + +from pathlib import Path +import shutil +from invokeai.backend.util.logging import InvokeAILogger +from invokeai.backend.util import GIG, directory_size +from .convert_cache_base import ModelConvertCacheBase + +class ModelConvertCache(ModelConvertCacheBase): + + def __init__(self, cache_path: Path, max_size: float=10.0): + """Initialize the convert cache with the base directory and a limit on its maximum size (in GBs).""" + if not cache_path.exists(): + cache_path.mkdir(parents=True) + self._cache_path = cache_path + self._max_size = max_size + + @property + def max_size(self) -> float: + """Return the maximum size of this cache directory (GB).""" + return self._max_size + + def cache_path(self, key: str) -> Path: + """Return the path for a model with the indicated key.""" + return self._cache_path / key + + def make_room(self, size: float) -> None: + """ + Make sufficient room in the cache directory for a model of max_size. + + :param size: Size required (GB) + """ + size_needed = directory_size(self._cache_path) + size + max_size = int(self.max_size) * GIG + logger = InvokeAILogger.get_logger() + + if size_needed <= max_size: + return + + logger.debug( + f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming." + ) + + # For this to work, we make the assumption that the directory contains + # a 'model_index.json', 'unet/config.json' file, or a 'config.json' file at top level. + # This should be true for any diffusers model. + def by_atime(path: Path) -> float: + for config in ["model_index.json", "unet/config.json", "config.json"]: + sentinel = path / config + if sentinel.exists(): + return sentinel.stat().st_atime + return 0.0 + + # sort by last access time - least accessed files will be at the end + lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True) + logger.debug(f"cached models in descending atime order: {lru_models}") + while size_needed > max_size and len(lru_models) > 0: + next_victim = lru_models.pop() + victim_size = directory_size(next_victim) + logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB") + shutil.rmtree(next_victim) + size_needed -= victim_size diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 7cb7222b717..3ade83160a2 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -16,39 +16,11 @@ from pathlib import Path from typing import Any, Callable, Dict, Optional, Type, Union -import torch -from diffusers import DiffusionPipeline -from injector import inject - from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase -from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel -from invokeai.backend.model_manager.ram_cache import ModelCacheBase - -AnyModel = Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel] - - -class ModelLockerBase(ABC): - """Base class for the model locker used by the loader.""" - - @abstractmethod - def lock(self) -> None: - """Lock the contained model and move it into VRAM.""" - pass - - @abstractmethod - def unlock(self) -> None: - """Unlock the contained model, and remove it from VRAM.""" - pass - - @property 
- @abstractmethod - def model(self) -> AnyModel: - """Return the model.""" - pass - +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase +from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLockerBase +from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase @dataclass class LoadedModel: @@ -69,7 +41,7 @@ def __exit__(self, *args: Any, **kwargs: Any) -> None: @property def model(self) -> AnyModel: """Return the model without locking it.""" - return self.locker.model() + return self.locker.model class ModelLoaderBase(ABC): @@ -89,9 +61,9 @@ def __init__( @abstractmethod def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ - Return a model given its key. + Return a model given its confguration. - Given a model key identified in the model configuration backend, + Given a model identified in the model configuration backend, return a ModelInfo object that can be used to retrieve the model. :param model_config: Model configuration, as returned by ModelConfigRecordStore @@ -115,34 +87,32 @@ class AnyModelLoader: # this tracks the loader subclasses _registry: Dict[str, Type[ModelLoaderBase]] = {} - @inject def __init__( self, - store: ModelRecordServiceBase, app_config: InvokeAIAppConfig, logger: Logger, ram_cache: ModelCacheBase, convert_cache: ModelConvertCacheBase, ): - """Store the provided ModelRecordServiceBase and empty the registry.""" - self._store = store + """Initialize AnyModelLoader with its dependencies.""" self._app_config = app_config self._logger = logger self._ram_cache = ram_cache self._convert_cache = convert_cache - def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: - """ - Return a model given its key. + @property + def ram_cache(self) -> ModelCacheBase: + """Return the RAM cache associated used by the loaders.""" + return self._ram_cache - Given a model key identified in the model configuration backend, - return a ModelInfo object that can be used to retrieve the model. + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType]=None) -> LoadedModel: + """ + Return a model given its configuration. :param key: model key, as known to the config backend :param submodel_type: an ModelType enum indicating the portion of the model to retrieve (e.g. 
ModelType.Vae) """ - model_config = self._store.get_model(key) implementation = self.__class__.get_implementation( base=model_config.base, type=model_config.type, format=model_config.format ) @@ -165,7 +135,7 @@ def get_implementation(cls, base: BaseModelType, type: ModelType, format: ModelF implementation = cls._registry.get(key1) or cls._registry.get(key2) if not implementation: raise NotImplementedError( - "No subclass of LoadedModel is registered for base={base}, type={type}, format={format}" + f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}" ) return implementation @@ -176,18 +146,10 @@ def register( """Define a decorator which registers the subclass of loader.""" def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: - print("Registering class", subclass.__name__) + print("DEBUG: Registering class", subclass.__name__) key = cls._to_registry_key(base, type, format) cls._registry[key] = subclass return subclass return decorator - -# in _init__.py will call something like -# def configure_loader_dependencies(binder): -# binder.bind(ModelRecordServiceBase, ApiDependencies.invoker.services.model_records, scope=singleton) -# binder.bind(InvokeAIAppConfig, ApiDependencies.invoker.services.configuration, scope=singleton) -# etc -# injector = Injector(configure_loader_dependencies) -# loader = injector.get(ModelFactory) diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index eb2d432aaae..0b028235fdd 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -8,15 +8,14 @@ from diffusers import ModelMixin from diffusers.configuration_utils import ConfigMixin -from injector import inject from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType -from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init -from invokeai.backend.model_manager.ram_cache import ModelCacheBase, ModelLockerBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.util.devices import choose_torch_device, torch_dtype @@ -35,7 +34,6 @@ def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: class ModelLoader(ModelLoaderBase): """Default implementation of ModelLoaderBase.""" - @inject # can inject instances of each of the classes in the call signature def __init__( self, app_config: InvokeAIAppConfig, @@ -87,18 +85,15 @@ def _get_model_path( def _convert_if_needed( self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None ) -> Path: - if not self._needs_conversion(config): - return model_path - - self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type)) cache_path: Path = self._convert_cache.cache_path(config.key) - if cache_path.exists(): - return cache_path - self._convert_model(model_path, cache_path) - return cache_path + if not self._needs_conversion(config, model_path, cache_path): + 
return cache_path if cache_path.exists() else model_path + + self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type)) + return self._convert_model(config, model_path, cache_path) - def _needs_conversion(self, config: AnyModelConfig) -> bool: + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool: return False def _load_if_needed( @@ -133,7 +128,7 @@ def get_size_fs( variant=config.repo_variant if hasattr(config, "repo_variant") else None, ) - def _convert_model(self, model_path: Path, cache_path: Path) -> None: + def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: raise NotImplementedError def _load_model( diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py new file mode 100644 index 00000000000..776b9d8936d --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/__init__.py @@ -0,0 +1,5 @@ +"""Init file for RamCache.""" + +from .model_cache_base import ModelCacheBase +from .model_cache_default import ModelCache +_all__ = ['ModelCacheBase', 'ModelCache'] diff --git a/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py similarity index 77% rename from invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py rename to invokeai/backend/model_manager/load/model_cache/model_cache_base.py index cd80d1e78b2..50b69d961c6 100644 --- a/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -10,34 +10,41 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from logging import Logger -from typing import Dict, Optional +from typing import Dict, Optional, TypeVar, Generic import torch -from invokeai.backend.model_manager import SubModelType -from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase +from invokeai.backend.model_manager import AnyModel, SubModelType +class ModelLockerBase(ABC): + """Base class for the model locker used by the loader.""" -@dataclass -class CacheStats(object): - """Data object to record statistics on cache hits/misses.""" + @abstractmethod + def lock(self) -> AnyModel: + """Lock the contained model and move it into VRAM.""" + pass - hits: int = 0 # cache hits - misses: int = 0 # cache misses - high_watermark: int = 0 # amount of cache used - in_cache: int = 0 # number of models in cache - cleared: int = 0 # number of models cleared to make space - cache_size: int = 0 # total size of cache - loaded_model_sizes: Dict[str, int] = field(default_factory=dict) + @abstractmethod + def unlock(self) -> None: + """Unlock the contained model, and remove it from VRAM.""" + pass + + @property + @abstractmethod + def model(self) -> AnyModel: + """Return the model.""" + pass +T = TypeVar("T") @dataclass -class CacheRecord: +class CacheRecord(Generic[T]): """Elements of the cache.""" key: str - model: AnyModel + model: T size: int + loaded: bool = False _locks: int = 0 def lock(self) -> None: @@ -55,7 +62,7 @@ def locked(self) -> bool: return self._locks > 0 -class ModelCacheBase(ABC): +class ModelCacheBase(ABC, Generic[T]): """Virtual base class for RAM model cache.""" @property @@ -76,8 +83,14 @@ def lazy_offloading(self) -> bool: """Return true if the cache is configured to lazily offload models in VRAM.""" pass + @property @abstractmethod - 
def offload_unlocked_models(self) -> None: + def max_cache_size(self) -> float: + """Return true if the cache is configured to lazily offload models in VRAM.""" + pass + + @abstractmethod + def offload_unlocked_models(self, size_required: int) -> None: """Offload from VRAM any models not actively in use.""" pass @@ -101,7 +114,7 @@ def make_room(self, size: int) -> None: def put( self, key: str, - model: AnyModel, + model: T, submodel_type: Optional[SubModelType] = None, ) -> None: """Store model under key and optional submodel_type.""" @@ -134,11 +147,6 @@ def cache_size(self) -> int: """Get the total size of the models currently cached.""" pass - @abstractmethod - def get_stats(self) -> CacheStats: - """Return cache hit/miss/size statistics.""" - pass - @abstractmethod def print_cuda_stats(self) -> None: """Log debugging information on CUDA usage.""" diff --git a/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py similarity index 63% rename from invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py rename to invokeai/backend/model_manager/load/model_cache/model_cache_default.py index bd43e978c83..961f68a4bea 100644 --- a/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -18,6 +18,7 @@ """ +import gc import math import time from contextlib import suppress @@ -26,14 +27,14 @@ import torch -from invokeai.app.services.model_records import UnknownModelException from invokeai.backend.model_manager import SubModelType -from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase +from invokeai.backend.model_manager.load.load_base import AnyModel from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data -from invokeai.backend.model_manager.load.ram_cache.ram_cache_base import CacheRecord, CacheStats, ModelCacheBase from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger +from .model_cache_base import CacheRecord, ModelCacheBase +from .model_locker import ModelLockerBase, ModelLocker if choose_torch_device() == torch.device("mps"): from torch import mps @@ -52,7 +53,7 @@ MB = 2**20 -class ModelCache(ModelCacheBase): +class ModelCache(ModelCacheBase[AnyModel]): """Implementation of ModelCacheBase.""" def __init__( @@ -92,62 +93,9 @@ def __init__( self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) self._log_memory_usage = log_memory_usage - # used for stats collection - self.stats = None - - self._cached_models: Dict[str, CacheRecord] = {} + self._cached_models: Dict[str, CacheRecord[AnyModel]] = {} self._cache_stack: List[str] = [] - class ModelLocker(ModelLockerBase): - """Internal class that mediates movement in and out of GPU.""" - - def __init__(self, cache: ModelCacheBase, cache_entry: CacheRecord): - """ - Initialize the model locker. 
- - :param cache: The ModelCache object - :param cache_entry: The entry in the model cache - """ - self._cache = cache - self._cache_entry = cache_entry - - @property - def model(self) -> AnyModel: - """Return the model without moving it around.""" - return self._cache_entry.model - - def lock(self) -> Any: - """Move the model into the execution device (GPU) and lock it.""" - if not hasattr(self.model, "to"): - return self.model - - # NOTE that the model has to have the to() method in order for this code to move it into GPU! - self._cache_entry.lock() - - try: - if self._cache.lazy_offloading: - self._cache.offload_unlocked_models() - - self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) - - self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") - self._cache.print_cuda_stats() - - except Exception: - self._cache_entry.unlock() - raise - return self.model - - def unlock(self) -> None: - """Call upon exit from context.""" - if not hasattr(self.model, "to"): - return - - self._cache_entry.unlock() - if not self._cache.lazy_offloading: - self._cache.offload_unlocked_models() - self._cache.print_cuda_stats() - @property def logger(self) -> Logger: """Return the logger used by the cache.""" @@ -168,6 +116,11 @@ def execution_device(self) -> torch.device: """Return the exection device (e.g. "cuda" for VRAM).""" return self._execution_device + @property + def max_cache_size(self) -> float: + """Return the cap on cache size.""" + return self._max_cache_size + def cache_size(self) -> int: """Get the total size of the models currently cached.""" total = 0 @@ -207,18 +160,18 @@ def get( """ Retrieve model using key and optional submodel_type. - This may return an UnknownModelException if the model is not in the cache. + This may return an IndexError if the model is not in the cache. 
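+
+        A rough usage sketch (the key is whatever was passed to `put()`; illustrative only):
+
+            locker = cache.get("some-model-key")
+            model = locker.lock()  # moves the model onto the execution device
+            try:
+                ...  # run inference
+            finally:
+                locker.unlock()  # allows the model to be offloaded again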
""" key = self._make_cache_key(key, submodel_type) if key not in self._cached_models: - raise UnknownModelException + raise IndexError(f"The model with key {key} is not in the cache.") # this moves the entry to the top (right end) of the stack with suppress(Exception): self._cache_stack.remove(key) self._cache_stack.append(key) cache_entry = self._cached_models[key] - return self.ModelLocker( + return ModelLocker( cache=self, cache_entry=cache_entry, ) @@ -234,19 +187,19 @@ def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] else: return model_key - def offload_unlocked_models(self) -> None: + def offload_unlocked_models(self, size_required: int) -> None: """Move any unused models from VRAM.""" reserved = self._max_vram_cache_size * GIG - vram_in_use = torch.cuda.memory_allocated() - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") + vram_in_use = torch.cuda.memory_allocated() + size_required + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB") for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): if vram_in_use <= reserved: break if not cache_entry.locked: self.move_model_to_device(cache_entry, self.storage_device) - - vram_in_use = torch.cuda.memory_allocated() - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") + cache_entry.loaded = False + vram_in_use = torch.cuda.memory_allocated() + size_required + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM now available for models; max allowed={(reserved/GIG):.2f}GB") torch.cuda.empty_cache() if choose_torch_device() == torch.device("mps"): @@ -305,28 +258,111 @@ def move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.de def print_cuda_stats(self) -> None: """Log CUDA diagnostics.""" vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) - ram = "%4.2fG" % self.cache_size() + ram = "%4.2fG" % (self.cache_size() / GIG) - cached_models = 0 - loaded_models = 0 - locked_models = 0 + in_ram_models = 0 + in_vram_models = 0 + locked_in_vram_models = 0 for cache_record in self._cached_models.values(): - cached_models += 1 assert hasattr(cache_record.model, "device") - if cache_record.model.device is self.storage_device: - loaded_models += 1 + if cache_record.model.device == self.storage_device: + in_ram_models += 1 + else: + in_vram_models += 1 if cache_record.locked: - locked_models += 1 + locked_in_vram_models += 1 self.logger.debug( - f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ =" - f" {cached_models}/{loaded_models}/{locked_models}" + f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" + f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" ) - def get_stats(self) -> CacheStats: - """Return cache hit/miss/size statistics.""" - raise NotImplementedError - - def make_room(self, size: int) -> None: + def make_room(self, model_size: int) -> None: """Make enough room in the cache to accommodate a new model of indicated size.""" - raise NotImplementedError + # calculate how much memory this model will require + # multiplier = 2 if self.precision==torch.float32 else 1 + bytes_needed = model_size + maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes + current_size = self.cache_size() + + if current_size + bytes_needed > maximum_size: + self.logger.debug( + f"Max cache size exceeded: 
{(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
+                f" {(bytes_needed/GIG):.2f} GB"
+            )
+
+        self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
+
+        pos = 0
+        models_cleared = 0
+        while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
+            model_key = self._cache_stack[pos]
+            cache_entry = self._cached_models[model_key]
+
+            refs = sys.getrefcount(cache_entry.model)
+
+            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
+            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
+            # https://docs.python.org/3/library/gc.html#gc.get_referrers
+
+            # manually clear local variable references of just-finished function calls;
+            # for some reason Python doesn't want to collect them even by gc.collect() immediately
+            if refs > 2:
+                while True:
+                    cleared = False
+                    for referrer in gc.get_referrers(cache_entry.model):
+                        if type(referrer).__name__ == "frame":
+                            # RuntimeError: cannot clear an executing frame
+                            with suppress(RuntimeError):
+                                referrer.clear()
+                                cleared = True
+                                # break
+
+                    # repeat if the referrers changed (due to frame clear), else exit loop
+                    if cleared:
+                        gc.collect()
+                    else:
+                        break
+
+            device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
+            self.logger.debug(
+                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
+                f" refs: {refs}"
+            )
+
+            # Expected refs:
+            # 1 from cache_entry
+            # 1 from getrefcount function
+            # 1 from onnx runtime object
+            if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
+                self.logger.debug(
+                    f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
+                )
+                current_size -= cache_entry.size
+                models_cleared += 1
+                del self._cache_stack[pos]
+                del self._cached_models[model_key]
+                del cache_entry
+
+            else:
+                pos += 1
+
+        if models_cleared > 0:
+            # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
+            # there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
+            # is high even if no garbage gets collected.)
+            #
+            # Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
+            # - If models had to be cleared, it's a signal that we are close to our memory limit.
+            # - If models were cleared, there's a good chance that there's a significant amount of garbage to be
+            #   collected.
+            #
+            # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
+            # immediately when their reference count hits 0.
+            gc.collect()
+
+        torch.cuda.empty_cache()
+        if choose_torch_device() == torch.device("mps"):
+            mps.empty_cache()
+
+        self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py
new file mode 100644
index 00000000000..506d0129491
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py
@@ -0,0 +1,59 @@
+"""
+Base class and implementation of a class that moves models in and out of VRAM.
+""" + +from abc import ABC, abstractmethod +from invokeai.backend.model_manager import AnyModel +from .model_cache_base import ModelLockerBase, ModelCacheBase, CacheRecord + +class ModelLocker(ModelLockerBase): + """Internal class that mediates movement in and out of GPU.""" + + def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]): + """ + Initialize the model locker. + + :param cache: The ModelCache object + :param cache_entry: The entry in the model cache + """ + self._cache = cache + self._cache_entry = cache_entry + + @property + def model(self) -> AnyModel: + """Return the model without moving it around.""" + return self._cache_entry.model + + def lock(self) -> AnyModel: + """Move the model into the execution device (GPU) and lock it.""" + if not hasattr(self.model, "to"): + return self.model + + # NOTE that the model has to have the to() method in order for this code to move it into GPU! + self._cache_entry.lock() + + try: + if self._cache.lazy_offloading: + self._cache.offload_unlocked_models(self._cache_entry.size) + + self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) + self._cache_entry.loaded = True + + self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") + self._cache.print_cuda_stats() + + except Exception: + self._cache_entry.unlock() + raise + return self.model + + def unlock(self) -> None: + """Call upon exit from context.""" + if not hasattr(self.model, "to"): + return + + self._cache_entry.unlock() + if not self._cache.lazy_offloading: + self._cache.offload_unlocked_models(self._cache_entry.size) + self._cache.print_cuda_stats() + diff --git a/invokeai/backend/model_manager/load/model_loaders/__init__.py b/invokeai/backend/model_manager/load/model_loaders/__init__.py new file mode 100644 index 00000000000..962cba54811 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/__init__.py @@ -0,0 +1,3 @@ +""" +Init file for model_loaders. +""" diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py new file mode 100644 index 00000000000..6f21c3d0903 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Class for VAE model loading in InvokeAI.""" + +from pathlib import Path +from typing import Optional + +import torch +import safetensors +from omegaconf import OmegaConf, DictConfig +from invokeai.backend.util.devices import torch_dtype +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) +@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) +class VaeDiffusersModel(ModelLoader): + """Class to load VAE models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if submodel_type is not None: + raise Exception("There are no submodels in VAEs") + vae_class = self._get_hf_load_class(model_path) + variant = model_variant.value if model_variant else None + result: AnyModel = vae_class.from_pretrained( + model_path, torch_dtype=self._torch_dtype, variant=variant + ) # type: ignore + return result + + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: + print(f'DEBUG: last_modified={config.last_modified}') + print(f'DEBUG: cache_path={(dest_path / "config.json").stat().st_mtime}') + print(f'DEBUG: model_path={model_path.stat().st_mtime}') + if config.format != ModelFormat.Checkpoint: + return False + elif dest_path.exists() \ + and (dest_path / "config.json").stat().st_mtime >= config.last_modified \ + and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime: + return False + else: + return True + + def _convert_model(self, + config: AnyModelConfig, + weights_path: Path, + output_path: Path + ) -> Path: + if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: + raise Exception(f"Vae conversion not supported for model type: {config.base}") + else: + config_file = 'v1-inference.yaml' if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" + + if weights_path.suffix == ".safetensors": + checkpoint = safetensors.torch.load_file(weights_path, device="cpu") + else: + checkpoint = torch.load(weights_path, map_location="cpu") + + dtype = torch_dtype() + + # sometimes weights are hidden under "state_dict", and sometimes not + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file) + assert isinstance(ckpt_config, DictConfig) + + print(f'DEBUG: CONVERTIGN') + vae_model = convert_ldm_vae_to_diffusers( + checkpoint=checkpoint, + vae_config=ckpt_config, + image_size=512, + ) + vae_model.to(dtype) # set precision appropriately + vae_model.save_pretrained(output_path, safe_serialization=True, torch_dtype=dtype) + return output_path + diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 18407cbca2e..7c27e66472f 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ 
b/invokeai/backend/model_manager/load/model_util.py @@ -48,6 +48,9 @@ def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int: def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int: """Estimate the size of a model on disk in bytes.""" + if model_path.is_file(): + return model_path.stat().st_size + if subfolder is not None: model_path = model_path / subfolder diff --git a/invokeai/backend/model_manager/load/ram_cache/__init__.py b/invokeai/backend/model_manager/load/ram_cache/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/invokeai/backend/model_manager/load/vae.py b/invokeai/backend/model_manager/load/vae.py deleted file mode 100644 index a6cbe241e1e..00000000000 --- a/invokeai/backend/model_manager/load/vae.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team -"""Class for VAE model loading in InvokeAI.""" - -from pathlib import Path -from typing import Dict, Optional - -import torch - -from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader - - -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) -class VaeDiffusersModel(ModelLoader): - """Class to load VAE models.""" - - def _load_model( - self, - model_path: Path, - model_variant: Optional[ModelRepoVariant] = None, - submodel_type: Optional[SubModelType] = None, - ) -> Dict[str, torch.Tensor]: - if submodel_type is not None: - raise Exception("There are no submodels in VAEs") - vae_class = self._get_hf_load_class(model_path) - variant = model_variant.value if model_variant else "" - result: Dict[str, torch.Tensor] = vae_class.from_pretrained( - model_path, torch_dtype=self._torch_dtype, variant=variant - ) # type: ignore - return result diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py index 87ae1480f54..0164dffe303 100644 --- a/invokeai/backend/util/__init__.py +++ b/invokeai/backend/util/__init__.py @@ -12,6 +12,14 @@ torch_dtype, ) from .logging import InvokeAILogger -from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401 +from .util import ( # TO DO: Clean this up; remove the unused symbols + GIG, + Chdir, + ask_user, # noqa + directory_size, + download_with_resume, + instantiate_from_config, # noqa + url_attachment_name, # noqa + ) -__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"] +__all__ = ["GIG", "directory_size","Chdir", "download_with_resume", "InvokeAILogger", "choose_precision", "choose_torch_device"] diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index d6d3ad727f7..ad3f4e139a7 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import nullcontext -from typing import Union +from typing import Union, Optional import torch from torch import autocast @@ -43,7 +43,8 @@ def choose_precision(device: torch.device) -> str: return "float32" -def torch_dtype(device: torch.device) -> torch.dtype: +def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype: + device = device or choose_torch_device() precision = choose_precision(device) if precision == "float16": return 
torch.float16 diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 13751e27702..6589aa72784 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -24,6 +24,20 @@ from .devices import torch_dtype +# actual size of a gig +GIG = 1073741824 + +def directory_size(directory: Path) -> int: + """ + Return the aggregate size of all files in a directory (bytes). + """ + sum = 0 + for root, dirs, files in os.walk(directory): + for f in files: + sum += Path(root, f).stat().st_size + for d in dirs: + sum += Path(root, d).stat().st_size + return sum def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) From 3959957bc3d8a96b96c8da64210cc11f3f11d343 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 4 Feb 2024 17:23:10 -0500 Subject: [PATCH 075/340] loaders for main, controlnet, ip-adapter, clipvision and t2i --- .../app/services/config/config_default.py | 2 +- .../model_records/model_records_base.py | 11 +- .../model_records/model_records_sql.py | 17 +- .../sqlite_migrator/migrations/migration_6.py | 5 +- invokeai/backend/install/install_helper.py | 1 + .../model_management/models/controlnet.py | 1 - invokeai/backend/model_manager/__init__.py | 2 +- invokeai/backend/model_manager/config.py | 15 +- .../convert_ckpt_to_diffusers.py | 4 +- .../backend/model_manager/load/__init__.py | 24 +- .../load/convert_cache/__init__.py | 2 +- .../load/convert_cache/convert_cache_base.py | 3 +- .../convert_cache/convert_cache_default.py | 10 +- .../backend/model_manager/load/load_base.py | 52 +- .../model_manager/load/load_default.py | 50 +- .../model_manager/load/memory_snapshot.py | 2 +- .../load/model_cache/__init__.py | 2 +- .../load/model_cache/model_cache_base.py | 8 +- .../load/model_cache/model_cache_default.py | 46 +- .../load/model_cache/model_locker.py | 6 +- .../load/model_loaders/controlnet.py | 60 ++ .../load/model_loaders/generic_diffusers.py | 34 + .../load/model_loaders/ip_adapter.py | 39 ++ .../model_manager/load/model_loaders/lora.py | 76 +++ .../load/model_loaders/stable_diffusion.py | 93 +++ .../model_manager/load/model_loaders/vae.py | 66 +- .../backend/model_manager/load/model_util.py | 5 +- invokeai/backend/model_manager/lora.py | 620 ++++++++++++++++++ invokeai/backend/model_manager/probe.py | 6 +- invokeai/backend/util/__init__.py | 16 +- invokeai/backend/util/devices.py | 2 +- invokeai/backend/util/util.py | 2 + 32 files changed, 1123 insertions(+), 159 deletions(-) create mode 100644 invokeai/backend/model_manager/load/model_loaders/controlnet.py create mode 100644 invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py create mode 100644 invokeai/backend/model_manager/load/model_loaders/ip_adapter.py create mode 100644 invokeai/backend/model_manager/load/model_loaders/lora.py create mode 100644 invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py create mode 100644 invokeai/backend/model_manager/lora.py diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index b161ea18d61..b39e916da34 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -173,7 +173,7 @@ class InvokeBatch(InvokeAISettings): import os from pathlib import Path -from typing import Any, ClassVar, Dict, List, Literal, Optional, Union +from typing import Any, ClassVar, Dict, List, Literal, Optional from omegaconf import DictConfig, OmegaConf from pydantic import Field diff --git 
a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 31cfecb4ec8..42e3c8f83a7 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -11,7 +11,14 @@ from pydantic import BaseModel, Field from invokeai.app.services.shared.pagination import PaginatedResults -from invokeai.backend.model_manager import LoadedModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType +from invokeai.backend.model_manager import ( + AnyModelConfig, + BaseModelType, + LoadedModel, + ModelFormat, + ModelType, + SubModelType, +) from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore @@ -108,7 +115,7 @@ def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedM Load the indicated model into memory and return a LoadedModel object. :param key: Key of model config to be fetched. - :param submodel_type: For main (pipeline models), the submodel to fetch + :param submodel_type: For main (pipeline models), the submodel to fetch Exceptions: UnknownModelException -- model with this key not known NotImplementedException -- a model loader was not provided at initialization time diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index eee867ccb46..b50cd17a75d 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -42,7 +42,6 @@ import json import sqlite3 -import time from math import ceil from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple, Union @@ -56,8 +55,8 @@ ModelType, SubModelType, ) -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_records_base import ( @@ -72,7 +71,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """Implementation of the ModelConfigStore ABC using a SQL database.""" - def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader]=None): + def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader] = None): """ Initialize a new object from preexisting sqlite3 connection and threading lock objects. 
@@ -289,7 +288,9 @@ def search_by_attr( """, tuple(bindings), ) - results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] + results = [ + ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall() + ] return results def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: @@ -303,7 +304,9 @@ def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: """, (str(path),), ) - results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] + results = [ + ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall() + ] return results def search_by_hash(self, hash: str) -> List[AnyModelConfig]: @@ -317,7 +320,9 @@ def search_by_hash(self, hash: str) -> List[AnyModelConfig]: """, (hash,), ) - results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()] + results = [ + ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall() + ] return results @property diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py index e72878f726f..b4734445110 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py @@ -1,11 +1,9 @@ import sqlite3 -from logging import Logger -from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration -class Migration6Callback: +class Migration6Callback: def __call__(self, cursor: sqlite3.Cursor) -> None: self._recreate_model_triggers(cursor) @@ -28,6 +26,7 @@ def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None: """ ) + def build_migration_6() -> Migration: """ Build the migration from database version 5 to 6. diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 8c03d2ccf84..9f219132d4d 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -120,6 +120,7 @@ def dispatch(self, event_name: str, payload: Any) -> None: elif payload["event"] == "model_install_cancelled": self._logger.warning(f"{source}: installation cancelled") + class InstallHelper(object): """Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db.""" diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py index da269eba4b7..3b534cb9d14 100644 --- a/invokeai/backend/model_management/models/controlnet.py +++ b/invokeai/backend/model_management/models/controlnet.py @@ -139,7 +139,6 @@ def _convert_controlnet_ckpt_and_cache( cache it to disk, and return Path to converted file. If already on disk then just returns Path. 
""" - print(f"DEBUG: controlnet config = {model_config}") app_config = InvokeAIAppConfig.get_config() weights = app_config.root_path / model_path output_path = Path(output_path) diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index f3c84cd01f7..98cc5054c73 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -13,9 +13,9 @@ SchedulerPredictionType, SubModelType, ) +from .load import LoadedModel from .probe import ModelProbe from .search import ModelSearch -from .load import LoadedModel __all__ = [ "AnyModel", diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 796ccbacde0..e59a84d7291 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -20,14 +20,16 @@ """ import time -import torch from enum import Enum from typing import Literal, Optional, Type, Union -from pydantic import BaseModel, ConfigDict, Field, TypeAdapter +import torch from diffusers import ModelMixin +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from typing_extensions import Annotated, Any, Dict + from .onnx_runtime import IAIOnnxRuntimeModel +from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus class InvalidModelConfigException(Exception): """Exception for when config parser doesn't recognized this combination of model type and format.""" @@ -204,6 +206,8 @@ class _MainConfig(ModelConfigBase): vae: Optional[str] = Field(default=None) variant: ModelVariantType = ModelVariantType.Normal + prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon + upcast_attention: bool = False ztsnr_training: bool = False @@ -217,8 +221,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig): """Model config for main diffusers models.""" type: Literal[ModelType.Main] = ModelType.Main - prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon - upcast_attention: bool = False class ONNXSD1Config(_MainConfig): @@ -276,6 +278,7 @@ class T2IConfig(ModelConfigBase): _ONNXConfig, _VaeConfig, _ControlNetConfig, + # ModelConfigBase, LoRAConfig, TextualInversionConfig, IPAdapterConfig, @@ -284,7 +287,7 @@ class T2IConfig(ModelConfigBase): ] AnyModelConfigValidator = TypeAdapter(AnyModelConfig) -AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel] +AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus] # IMPLEMENTATION NOTE: # The preferred alternative to the above is a discriminated Union as shown @@ -317,7 +320,7 @@ def make_config( model_data: Union[dict, AnyModelConfig], key: Optional[str] = None, dest_class: Optional[Type] = None, - timestamp: Optional[float] = None + timestamp: Optional[float] = None, ) -> AnyModelConfig: """ Return the appropriate config object from raw dict values. 
diff --git a/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py index 9d6fc4841f2..6f5acd58329 100644 --- a/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_manager/convert_ckpt_to_diffusers.py @@ -43,7 +43,6 @@ UnCLIPScheduler, ) from diffusers.utils import is_accelerate_available -from diffusers.utils.import_utils import BACKENDS_MAPPING from picklescan.scanner import scan_file_path from transformers import ( AutoFeatureExtractor, @@ -58,8 +57,8 @@ ) from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.model_manager import BaseModelType, ModelVariantType +from invokeai.backend.util.logging import InvokeAILogger try: from omegaconf import OmegaConf @@ -1643,7 +1642,6 @@ def download_controlnet_from_original_ckpt( cross_attention_dim: Optional[bool] = None, scan_needed: bool = False, ) -> DiffusionPipeline: - from omegaconf import OmegaConf if from_safetensors: diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index 357677bb7f7..19b0116ba3b 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -8,14 +8,15 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util.logging import InvokeAILogger + +from .convert_cache.convert_cache_default import ModelConvertCache from .load_base import AnyModelLoader, LoadedModel from .model_cache.model_cache_default import ModelCache -from .convert_cache.convert_cache_default import ModelConvertCache # This registers the subclasses that implement loaders of specific model types -loaders = [x.stem for x in Path(Path(__file__).parent,'model_loaders').glob('*.py') if x.stem != '__init__'] +loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"] for module in loaders: - print(f'module={module}') + print(f"module={module}") import_module(f"{__package__}.model_loaders.{module}") __all__ = ["AnyModelLoader", "LoadedModel"] @@ -24,12 +25,11 @@ def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader: app_config = app_config or InvokeAIAppConfig.get_config() logger = InvokeAILogger.get_logger(config=app_config) - return AnyModelLoader(app_config=app_config, - logger=logger, - ram_cache=ModelCache(logger=logger, - max_cache_size=app_config.ram_cache_size, - max_vram_cache_size=app_config.vram_cache_size - ), - convert_cache=ModelConvertCache(app_config.models_convert_cache_path) - ) - + return AnyModelLoader( + app_config=app_config, + logger=logger, + ram_cache=ModelCache( + logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size + ), + convert_cache=ModelConvertCache(app_config.models_convert_cache_path), + ) diff --git a/invokeai/backend/model_manager/load/convert_cache/__init__.py b/invokeai/backend/model_manager/load/convert_cache/__init__.py index eb3149be329..5be56d2d584 100644 --- a/invokeai/backend/model_manager/load/convert_cache/__init__.py +++ b/invokeai/backend/model_manager/load/convert_cache/__init__.py @@ -1,4 +1,4 @@ from .convert_cache_base import ModelConvertCacheBase from .convert_cache_default import ModelConvertCache -__all__ = ['ModelConvertCacheBase', 'ModelConvertCache'] +__all__ = ["ModelConvertCacheBase", "ModelConvertCache"] diff --git 
a/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py index 25263f96aaa..6268c099a5f 100644 --- a/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_base.py @@ -4,8 +4,8 @@ from abc import ABC, abstractmethod from pathlib import Path -class ModelConvertCacheBase(ABC): +class ModelConvertCacheBase(ABC): @property @abstractmethod def max_size(self) -> float: @@ -25,4 +25,3 @@ def make_room(self, size: float) -> None: def cache_path(self, key: str) -> Path: """Return the path for a model with the indicated key.""" pass - diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py index f799510ec5b..4c361258d90 100644 --- a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py @@ -2,15 +2,17 @@ Placeholder for convert cache implementation. """ -from pathlib import Path import shutil -from invokeai.backend.util.logging import InvokeAILogger +from pathlib import Path + from invokeai.backend.util import GIG, directory_size +from invokeai.backend.util.logging import InvokeAILogger + from .convert_cache_base import ModelConvertCacheBase -class ModelConvertCache(ModelConvertCacheBase): - def __init__(self, cache_path: Path, max_size: float=10.0): +class ModelConvertCache(ModelConvertCacheBase): + def __init__(self, cache_path: Path, max_size: float = 10.0): """Initialize the convert cache with the base directory and a limit on its maximum size (in GBs).""" if not cache_path.exists(): cache_path.mkdir(parents=True) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 3ade83160a2..7d4e8337c3c 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -10,17 +10,19 @@ # do something with loaded_model """ +import hashlib from abc import ABC, abstractmethod from dataclasses import dataclass from logging import Logger from pathlib import Path -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase -from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLockerBase +from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase + @dataclass class LoadedModel: @@ -52,7 +54,7 @@ def __init__( self, app_config: InvokeAIAppConfig, logger: Logger, - ram_cache: ModelCacheBase, + ram_cache: ModelCacheBase[AnyModel], convert_cache: ModelConvertCacheBase, ): """Initialize the loader.""" @@ -91,7 +93,7 @@ def __init__( self, app_config: InvokeAIAppConfig, logger: Logger, - ram_cache: ModelCacheBase, + ram_cache: ModelCacheBase[AnyModel], convert_cache: ModelConvertCacheBase, ): """Initialize AnyModelLoader with 
its dependencies.""" @@ -101,11 +103,11 @@ def __init__( self._convert_cache = convert_cache @property - def ram_cache(self) -> ModelCacheBase: + def ram_cache(self) -> ModelCacheBase[AnyModel]: """Return the RAM cache associated used by the loaders.""" return self._ram_cache - def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType]=None) -> LoadedModel: + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Return a model given its configuration. @@ -113,9 +115,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo :param submodel_type: an ModelType enum indicating the portion of the model to retrieve (e.g. ModelType.Vae) """ - implementation = self.__class__.get_implementation( - base=model_config.base, type=model_config.type, format=model_config.format - ) + implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type) return implementation( app_config=self._app_config, logger=self._logger, @@ -128,16 +128,37 @@ def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) return "-".join([base.value, type.value, format.value]) @classmethod - def get_implementation(cls, base: BaseModelType, type: ModelType, format: ModelFormat) -> Type[ModelLoaderBase]: + def get_implementation( + cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] + ) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]: """Get subclass of ModelLoaderBase registered to handle base and type.""" - key1 = cls._to_registry_key(base, type, format) # for a specific base type - key2 = cls._to_registry_key(BaseModelType.Any, type, format) # with wildcard Any + # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned + conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type) + + key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type + key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any implementation = cls._registry.get(key1) or cls._registry.get(key2) if not implementation: raise NotImplementedError( - f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}" + f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}" ) - return implementation + return implementation, conf2, submodel_type + + @classmethod + def _handle_subtype_overrides( + cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] + ) -> Tuple[AnyModelConfig, Optional[SubModelType]]: + if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None: + model_path = Path(config.vae) + config_class = ( + VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig + ) + hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest() + new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash) + submodel_type = None + else: + new_conf = config + return new_conf, submodel_type @classmethod def register( @@ -152,4 +173,3 @@ def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: return subclass return decorator - diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 
0b028235fdd..453283e9b4a 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -10,12 +10,12 @@ from diffusers.configuration_utils import ConfigMixin from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase -from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase -from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs -from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init +from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs, calc_model_size_by_data +from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.util.devices import choose_torch_device, torch_dtype @@ -38,7 +38,7 @@ def __init__( self, app_config: InvokeAIAppConfig, logger: Logger, - ram_cache: ModelCacheBase, + ram_cache: ModelCacheBase[AnyModel], convert_cache: ModelConvertCacheBase, ): """Initialize the loader.""" @@ -47,7 +47,6 @@ def __init__( self._ram_cache = ram_cache self._convert_cache = convert_cache self._torch_dtype = torch_dtype(choose_torch_device()) - self._size: Optional[int] = None # model size def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ @@ -63,9 +62,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo if model_config.type == "main" and not submodel_type: raise InvalidModelConfigException("submodel_type is required when loading a main model") - model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) - if is_submodel_override: - submodel_type = None + model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type) if not model_path.exists(): raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}") @@ -74,13 +71,12 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo locker = self._load_if_needed(model_config, model_path, submodel_type) return LoadedModel(config=model_config, locker=locker) - # IMPORTANT: This needs to be overridden in the StableDiffusion subclass so as to handle vae overrides - # and submodels!!!! 
def _get_model_path( self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None - ) -> Tuple[Path, bool]: + ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]: model_base = self._app_config.models_path - return ((model_base / config.path).resolve(), False) + result = (model_base / config.path).resolve(), config, submodel_type + return result def _convert_if_needed( self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None @@ -90,7 +86,7 @@ def _convert_if_needed( if not self._needs_conversion(config, model_path, cache_path): return cache_path if cache_path.exists() else model_path - self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type)) + self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type)) return self._convert_model(config, model_path, cache_path) def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool: @@ -114,6 +110,7 @@ def _load_if_needed( config.key, submodel_type=submodel_type, model=loaded_model, + size=calc_model_size_by_data(loaded_model), ) return self._ram_cache.get(config.key, submodel_type) @@ -128,17 +125,6 @@ def get_size_fs( variant=config.repo_variant if hasattr(config, "repo_variant") else None, ) - def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: - raise NotImplementedError - - def _load_model( - self, - model_path: Path, - model_variant: Optional[ModelRepoVariant] = None, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - raise NotImplementedError - def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]: return ConfigLoader.load_config(model_path, config_name=config_name) @@ -161,3 +147,17 @@ def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelT config = self._load_diffusers_config(model_path, config_name="config.json") class_name = config["_class_name"] return self._hf_definition_to_type(module="diffusers", class_name=class_name) + + # This needs to be implemented in subclasses that handle checkpoints + def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + raise NotImplementedError + + # This needs to be implemented in the subclass + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + raise NotImplementedError + diff --git a/invokeai/backend/model_manager/load/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py index 504829a4271..295be0c5514 100644 --- a/invokeai/backend/model_manager/load/memory_snapshot.py +++ b/invokeai/backend/model_manager/load/memory_snapshot.py @@ -97,4 +97,4 @@ def get_msg_line(prefix: str, val1: int, val2: int) -> str: if snapshot_1.vram is not None and snapshot_2.vram is not None: msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) - return msg + return "\n"+msg if len(msg)>0 else msg diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py index 776b9d8936d..50cafa37696 100644 --- a/invokeai/backend/model_manager/load/model_cache/__init__.py +++ b/invokeai/backend/model_manager/load/model_cache/__init__.py @@ -2,4 +2,4 @@ from .model_cache_base import ModelCacheBase from .model_cache_default import ModelCache -_all__ = ['ModelCacheBase', 'ModelCache'] +_all__ = 
["ModelCacheBase", "ModelCache"] diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index 50b69d961c6..14a7dfb4a1f 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -8,14 +8,15 @@ """ from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass from logging import Logger -from typing import Dict, Optional, TypeVar, Generic +from typing import Generic, Optional, TypeVar import torch from invokeai.backend.model_manager import AnyModel, SubModelType + class ModelLockerBase(ABC): """Base class for the model locker used by the loader.""" @@ -35,8 +36,10 @@ def model(self) -> AnyModel: """Return the model.""" pass + T = TypeVar("T") + @dataclass class CacheRecord(Generic[T]): """Elements of the cache.""" @@ -115,6 +118,7 @@ def put( self, key: str, model: T, + size: int, submodel_type: Optional[SubModelType] = None, ) -> None: """Store model under key and optional submodel_type.""" diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 961f68a4bea..688be8ceb48 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -19,22 +19,24 @@ """ import gc +import logging import math +import sys import time from contextlib import suppress from logging import Logger -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import torch from invokeai.backend.model_manager import SubModelType from invokeai.backend.model_manager.load.load_base import AnyModel from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff -from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger + from .model_cache_base import CacheRecord, ModelCacheBase -from .model_locker import ModelLockerBase, ModelLocker +from .model_locker import ModelLocker, ModelLockerBase if choose_torch_device() == torch.device("mps"): from torch import mps @@ -91,7 +93,7 @@ def __init__( self._execution_device: torch.device = execution_device self._storage_device: torch.device = storage_device self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) - self._log_memory_usage = log_memory_usage + self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG self._cached_models: Dict[str, CacheRecord[AnyModel]] = {} self._cache_stack: List[str] = [] @@ -141,14 +143,14 @@ def put( self, key: str, model: AnyModel, + size: int, submodel_type: Optional[SubModelType] = None, ) -> None: """Store model under key and optional submodel_type.""" key = self._make_cache_key(key, submodel_type) assert key not in self._cached_models - loaded_model_size = calc_model_size_by_data(model) - cache_record = CacheRecord(key, model, loaded_model_size) + cache_record = CacheRecord(key, model, size) self._cached_models[key] = cache_record self._cache_stack.append(key) @@ -195,28 +197,32 @@ def offload_unlocked_models(self, size_required: int) -> None: for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): if vram_in_use <= reserved: 
break + if not cache_entry.loaded: + continue if not cache_entry.locked: self.move_model_to_device(cache_entry, self.storage_device) cache_entry.loaded = False vram_in_use = torch.cuda.memory_allocated() + size_required - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM now available for models; max allowed={(reserved/GIG):.2f}GB") + self.logger.debug( + f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB" + ) torch.cuda.empty_cache() if choose_torch_device() == torch.device("mps"): mps.empty_cache() - # TO DO: Only reason to pass the CacheRecord rather than the model is to get the key and size - # for printing debugging messages. Revisit whether this is necessary - def move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None: + def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: """Move model into the indicated device.""" - # These attributes are not in the base class but in derived classes - assert hasattr(cache_entry.model, "device") - assert hasattr(cache_entry.model, "to") + # These attributes are not in the base ModelMixin class but in derived classes. + # Some models don't have these attributes, in which case they run in RAM/CPU. + self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") + if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")): + return source_device = cache_entry.model.device - # Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support - # multi-GPU. + # Note: We compare device types only so that 'cuda' == 'cuda:0'. + # This would need to be revised to support multi-GPU. if torch.device(source_device).type == torch.device(target_device).type: return @@ -227,8 +233,8 @@ def move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.de end_model_to_time = time.time() self.logger.debug( f"Moved model '{cache_entry.key}' from {source_device} to" - f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n" - f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n" + f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s." + f"Estimated model size: {(cache_entry.size/GIG):.3f} GB." 
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" ) @@ -291,7 +297,7 @@ def make_room(self, model_size: int) -> None: f" {(bytes_needed/GIG):.2f} GB" ) - self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}") + self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}") pos = 0 models_cleared = 0 @@ -336,7 +342,7 @@ def make_room(self, model_size: int) -> None: # 1 from onnx runtime object if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2): self.logger.debug( - f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" + f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" ) current_size -= cache_entry.size models_cleared += 1 @@ -365,4 +371,4 @@ def make_room(self, model_size: int) -> None: if choose_torch_device() == torch.device("mps"): mps.empty_cache() - self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}") + self.logger.debug(f"After making room: cached_models={len(self._cached_models)}") diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py index 506d0129491..7a5fdd4284b 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_locker.py +++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py @@ -2,9 +2,10 @@ Base class and implementation of a class that moves models in and out of VRAM. """ -from abc import ABC, abstractmethod from invokeai.backend.model_manager import AnyModel -from .model_cache_base import ModelLockerBase, ModelCacheBase, CacheRecord + +from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase + class ModelLocker(ModelLockerBase): """Internal class that mediates movement in and out of GPU.""" @@ -56,4 +57,3 @@ def unlock(self) -> None: if not self._cache.lazy_offloading: self._cache.offload_unlocked_models(self._cache_entry.size) self._cache.print_cuda_stats() - diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py new file mode 100644 index 00000000000..8e6a80ceb20 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Class for ControlNet model loading in InvokeAI.""" + +from pathlib import Path + +import safetensors +import torch + +from invokeai.backend.model_manager import ( + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, +) +from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from .generic_diffusers import GenericDiffusersLoader + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) +class ControlnetLoader(GenericDiffusersLoader): + """Class to load ControlNet models.""" + + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: + if config.format != ModelFormat.Checkpoint: + return False + elif ( + dest_path.exists() + and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0) + and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime + ): + return False + else: + return True + + def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: + raise Exception(f"Vae conversion not supported for model type: {config.base}") + else: + assert hasattr(config, 'config') + config_file = config.config + + if weights_path.suffix == ".safetensors": + checkpoint = safetensors.torch.load_file(weights_path, device="cpu") + else: + checkpoint = torch.load(weights_path, map_location="cpu") + + # sometimes weights are hidden under "state_dict", and sometimes not + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + convert_controlnet_to_diffusers( + weights_path, + output_path, + original_config_file=self._app_config.root_path / config_file, + image_size=512, + scan_needed=True, + from_safetensors=weights_path.suffix == ".safetensors", + ) + return output_path diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py new file mode 100644 index 00000000000..f92a9048c50 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Class for simple diffusers model loading in InvokeAI.""" + +from pathlib import Path +from typing import Optional + +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) +class GenericDiffusersLoader(ModelLoader): + """Class to load simple diffusers models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + model_class = self._get_hf_load_class(model_path) + if submodel_type is not None: + raise Exception(f"There are no submodels in models of type {model_class}") + variant = model_variant.value if model_variant else None + result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore + return result diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py new file mode 100644 index 00000000000..63dc3790f16 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for IP Adapter model loading in InvokeAI.""" + +import torch + +from pathlib import Path +from typing import Optional + +from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI) +class IPAdapterInvokeAILoader(ModelLoader): + """Class to load IP Adapter diffusers models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if submodel_type is not None: + raise ValueError("There are no submodels in an IP-Adapter model.") + model = build_ip_adapter( + ip_adapter_ckpt_path=model_path / "ip_adapter.bin", + device=torch.device("cpu"), + dtype=self._torch_dtype, + ) + return model + diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py new file mode 100644 index 00000000000..4d19aadb7d2 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Class for LoRA model loading in InvokeAI.""" + + +from pathlib import Path +from typing import Optional, Tuple +from logging import Logger + +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.lora import LoRAModelRaw +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris) +class LoraLoader(ModelLoader): + """Class to load LoRA models.""" + + # We cheat a little bit to get access to the model base + def __init__( + self, + app_config: InvokeAIAppConfig, + logger: Logger, + ram_cache: ModelCacheBase[AnyModel], + convert_cache: ModelConvertCacheBase, + ): + """Initialize the loader.""" + super().__init__(app_config, logger, ram_cache, convert_cache) + self._model_base: Optional[BaseModelType] = None + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if submodel_type is not None: + raise ValueError("There are no submodels in a LoRA model.") + model = LoRAModelRaw.from_checkpoint( + file_path=model_path, + dtype=self._torch_dtype, + base_model=self._model_base, + ) + return model + + # override + def _get_model_path( + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]: + self._model_base = config.base # cheating a little - setting this variable for later call to _load_model() + + model_base_path = self._app_config.models_path + model_path = model_base_path / config.path + + if config.format == ModelFormat.Diffusers: + for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder + path = model_base_path / config.path / f"pytorch_lora_weights.{ext}" + if path.exists(): + model_path = path + break + + result = model_path.resolve(), config, submodel_type + return result + + diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py new file mode 100644 index 00000000000..a963e8403b9 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -0,0 +1,93 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Class for StableDiffusion model loading in InvokeAI.""" + + +from pathlib import Path +from typing import Optional + +from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline + +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + ModelVariantType, + SubModelType, +) +from invokeai.backend.model_manager.config import MainCheckpointConfig +from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint) +class StableDiffusionDiffusersModel(ModelLoader): + """Class to load main models.""" + + model_base_to_model_type = { + BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder", + BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder", + BaseModelType.StableDiffusionXL: "SDXL", + BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner", + } + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not submodel_type is not None: + raise Exception("A submodel type must be provided when loading main pipelines.") + load_class = self._get_hf_load_class(model_path, submodel_type) + variant = model_variant.value if model_variant else None + model_path = model_path / submodel_type.value + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=self._torch_dtype, + variant=variant, + ) # type: ignore + return result + + def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: + if config.format != ModelFormat.Checkpoint: + return False + elif ( + dest_path.exists() + and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0) + and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime + ): + return False + else: + return True + + def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + assert isinstance(config, MainCheckpointConfig) + variant = config.variant + base = config.base + pipeline_class = ( + StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline + ) + + config_file = config.config + + self._logger.info(f"Converting {weights_path} to diffusers format") + convert_ckpt_to_diffusers( + weights_path, + output_path, + model_type=self.model_base_to_model_type[base], + model_version=base, + model_variant=variant, + original_config_file=self._app_config.root_path / config_file, + extract_ema=True, + scan_needed=True, + pipeline_class=pipeline_class, + from_safetensors=weights_path.suffix == ".safetensors", + precision=self._torch_dtype, + load_safety_checker=False, + ) + return output_path diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index 6f21c3d0903..7a35e53459a 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -2,68 +2,54 @@ """Class for VAE model loading in InvokeAI.""" from pathlib import Path -from 
typing import Optional -import torch import safetensors -from omegaconf import OmegaConf, DictConfig -from invokeai.backend.util.devices import torch_dtype -from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader +import torch +from omegaconf import DictConfig, OmegaConf + +from invokeai.backend.model_manager import ( + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, +) from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from .generic_diffusers import GenericDiffusersLoader + @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) @AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) @AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) -class VaeDiffusersModel(ModelLoader): +class VaeLoader(GenericDiffusersLoader): """Class to load VAE models.""" - def _load_model( - self, - model_path: Path, - model_variant: Optional[ModelRepoVariant] = None, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if submodel_type is not None: - raise Exception("There are no submodels in VAEs") - vae_class = self._get_hf_load_class(model_path) - variant = model_variant.value if model_variant else None - result: AnyModel = vae_class.from_pretrained( - model_path, torch_dtype=self._torch_dtype, variant=variant - ) # type: ignore - return result - def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: - print(f'DEBUG: last_modified={config.last_modified}') - print(f'DEBUG: cache_path={(dest_path / "config.json").stat().st_mtime}') - print(f'DEBUG: model_path={model_path.stat().st_mtime}') if config.format != ModelFormat.Checkpoint: return False - elif dest_path.exists() \ - and (dest_path / "config.json").stat().st_mtime >= config.last_modified \ - and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime: + elif ( + dest_path.exists() + and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0) + and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime + ): return False else: return True - def _convert_model(self, - config: AnyModelConfig, - weights_path: Path, - output_path: Path - ) -> Path: + def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + # TO DO: check whether sdxl VAE models convert. 
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: raise Exception(f"Vae conversion not supported for model type: {config.base}") else: - config_file = 'v1-inference.yaml' if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" + config_file = ( + "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" + ) if weights_path.suffix == ".safetensors": checkpoint = safetensors.torch.load_file(weights_path, device="cpu") else: checkpoint = torch.load(weights_path, map_location="cpu") - dtype = torch_dtype() - # sometimes weights are hidden under "state_dict", and sometimes not if "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] @@ -71,13 +57,11 @@ def _convert_model(self, ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file) assert isinstance(ckpt_config, DictConfig) - print(f'DEBUG: CONVERTIGN') vae_model = convert_ldm_vae_to_diffusers( checkpoint=checkpoint, vae_config=ckpt_config, image_size=512, ) - vae_model.to(dtype) # set precision appropriately - vae_model.save_pretrained(output_path, safe_serialization=True, torch_dtype=dtype) + vae_model.to(self._torch_dtype) # set precision appropriately + vae_model.save_pretrained(output_path, safe_serialization=True) return output_path - diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 7c27e66472f..404c88bbbcd 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -8,10 +8,11 @@ import torch from diffusers import DiffusionPipeline +from invokeai.backend.model_manager.config import AnyModel from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel -def calc_model_size_by_data(model: Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]) -> int: +def calc_model_size_by_data(model: AnyModel) -> int: """Get size of a model in memory in bytes.""" if isinstance(model, DiffusionPipeline): return _calc_pipeline_by_data(model) @@ -50,7 +51,7 @@ def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, var """Estimate the size of a model on disk in bytes.""" if model_path.is_file(): return model_path.stat().st_size - + if subfolder is not None: model_path = model_path / subfolder diff --git a/invokeai/backend/model_manager/lora.py b/invokeai/backend/model_manager/lora.py new file mode 100644 index 00000000000..4c48de48ec7 --- /dev/null +++ b/invokeai/backend/model_manager/lora.py @@ -0,0 +1,620 @@ +# Copyright (c) 2024 The InvokeAI Development team +"""LoRA model support.""" + +import torch +from safetensors.torch import load_file +from pathlib import Path +from typing import Dict, Optional, Union, List, Tuple +from typing_extensions import Self +from invokeai.backend.model_manager import BaseModelType + +class LoRALayerBase: + # rank: Optional[int] + # alpha: Optional[float] + # bias: Optional[torch.Tensor] + # layer_key: str + + # @property + # def scale(self): + # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + if "alpha" in values: + self.alpha = values["alpha"].item() + else: + self.alpha = None + + if "bias_indices" in values and "bias_values" in values and "bias_size" in values: + self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor( + values["bias_indices"], + values["bias_values"], + tuple(values["bias_size"]), + ) + + 
else: + self.bias = None + + self.rank = None # set in layer implementation + self.layer_key = layer_key + + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + raise NotImplementedError() + + def calc_size(self) -> int: + model_size = 0 + for val in [self.bias]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + if self.bias is not None: + self.bias = self.bias.to(device=device, dtype=dtype) + + +# TODO: find and debug lora/locon with bias +class LoRALayer(LoRALayerBase): + # up: torch.Tensor + # mid: Optional[torch.Tensor] + # down: torch.Tensor + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + super().__init__(layer_key, values) + + self.up = values["lora_up.weight"] + self.down = values["lora_down.weight"] + if "lora_mid.weight" in values: + self.mid: Optional[torch.Tensor] = values["lora_mid.weight"] + else: + self.mid = None + + self.rank = self.down.shape[0] + + def get_weight(self, orig_weight: torch.Tensor): + if self.mid is not None: + up = self.up.reshape(self.up.shape[0], self.up.shape[1]) + down = self.down.reshape(self.down.shape[0], self.down.shape[1]) + weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) + else: + weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.up, self.mid, self.down]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().to(device=device, dtype=dtype) + + self.up = self.up.to(device=device, dtype=dtype) + self.down = self.down.to(device=device, dtype=dtype) + + if self.mid is not None: + self.mid = self.mid.to(device=device, dtype=dtype) + + +class LoHALayer(LoRALayerBase): + # w1_a: torch.Tensor + # w1_b: torch.Tensor + # w2_a: torch.Tensor + # w2_b: torch.Tensor + # t1: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor] + ): + super().__init__(layer_key, values) + + self.w1_a = values["hada_w1_a"] + self.w1_b = values["hada_w1_b"] + self.w2_a = values["hada_w2_a"] + self.w2_b = values["hada_w2_b"] + + if "hada_t1" in values: + self.t1: Optional[torch.Tensor] = values["hada_t1"] + else: + self.t1 = None + + if "hada_t2" in values: + self.t2: Optional[torch.Tensor] = values["hada_t2"] + else: + self.t2 = None + + self.rank = self.w1_b.shape[0] + + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + if self.t1 is None: + weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) + + else: + rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) + rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a) + weight = rebuild1 * rebuild2 + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().to(device=device, dtype=dtype) + + 
self.w1_a = self.w1_a.to(device=device, dtype=dtype) + self.w1_b = self.w1_b.to(device=device, dtype=dtype) + if self.t1 is not None: + self.t1 = self.t1.to(device=device, dtype=dtype) + + self.w2_a = self.w2_a.to(device=device, dtype=dtype) + self.w2_b = self.w2_b.to(device=device, dtype=dtype) + if self.t2 is not None: + self.t2 = self.t2.to(device=device, dtype=dtype) + + +class LoKRLayer(LoRALayerBase): + # w1: Optional[torch.Tensor] = None + # w1_a: Optional[torch.Tensor] = None + # w1_b: Optional[torch.Tensor] = None + # w2: Optional[torch.Tensor] = None + # w2_a: Optional[torch.Tensor] = None + # w2_b: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + super().__init__(layer_key, values) + + if "lokr_w1" in values: + self.w1: Optional[torch.Tensor] = values["lokr_w1"] + self.w1_a = None + self.w1_b = None + else: + self.w1 = None + self.w1_a = values["lokr_w1_a"] + self.w1_b = values["lokr_w1_b"] + + if "lokr_w2" in values: + self.w2: Optional[torch.Tensor] = values["lokr_w2"] + self.w2_a = None + self.w2_b = None + else: + self.w2 = None + self.w2_a = values["lokr_w2_a"] + self.w2_b = values["lokr_w2_b"] + + if "lokr_t2" in values: + self.t2: Optional[torch.Tensor] = values["lokr_t2"] + else: + self.t2 = None + + if "lokr_w1_b" in values: + self.rank = values["lokr_w1_b"].shape[0] + elif "lokr_w2_b" in values: + self.rank = values["lokr_w2_b"].shape[0] + else: + self.rank = None # unscaled + + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + w1: Optional[torch.Tensor] = self.w1 + if w1 is None: + assert self.w1_a is not None + assert self.w1_b is not None + w1 = self.w1_a @ self.w1_b + + w2 = self.w2 + if w2 is None: + if self.t2 is None: + assert self.w2_a is not None + assert self.w2_b is not None + w2 = self.w2_a @ self.w2_b + else: + w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + w2 = w2.contiguous() + assert w1 is not None + assert w2 is not None + weight = torch.kron(w1, w2) + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().to(device=device, dtype=dtype) + + if self.w1 is not None: + self.w1 = self.w1.to(device=device, dtype=dtype) + else: + assert self.w1_a is not None + assert self.w1_b is not None + self.w1_a = self.w1_a.to(device=device, dtype=dtype) + self.w1_b = self.w1_b.to(device=device, dtype=dtype) + + if self.w2 is not None: + self.w2 = self.w2.to(device=device, dtype=dtype) + else: + assert self.w2_a is not None + assert self.w2_b is not None + self.w2_a = self.w2_a.to(device=device, dtype=dtype) + self.w2_b = self.w2_b.to(device=device, dtype=dtype) + + if self.t2 is not None: + self.t2 = self.t2.to(device=device, dtype=dtype) + + +class FullLayer(LoRALayerBase): + # weight: torch.Tensor + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + super().__init__(layer_key, values) + + self.weight = values["diff"] + + if len(values.keys()) > 1: + _keys = list(values.keys()) + _keys.remove("diff") + raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}") + + self.rank = 
None # unscaled + + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + return self.weight + + def calc_size(self) -> int: + model_size = super().calc_size() + model_size += self.weight.nelement() * self.weight.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.weight = self.weight.to(device=device, dtype=dtype) + + +class IA3Layer(LoRALayerBase): + # weight: torch.Tensor + # on_input: torch.Tensor + + def __init__( + self, + layer_key: str, + values: Dict[str, torch.Tensor], + ): + super().__init__(layer_key, values) + + self.weight = values["weight"] + self.on_input = values["on_input"] + + self.rank = None # unscaled + + def get_weight(self, orig_weight: torch.Tensor): + weight = self.weight + if not self.on_input: + weight = weight.reshape(-1, 1) + return orig_weight * weight + + def calc_size(self) -> int: + model_size = super().calc_size() + model_size += self.weight.nelement() * self.weight.element_size() + model_size += self.on_input.nelement() * self.on_input.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.weight = self.weight.to(device=device, dtype=dtype) + self.on_input = self.on_input.to(device=device, dtype=dtype) + +AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] + +# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix +class LoRAModelRaw: # (torch.nn.Module): + _name: str + layers: Dict[str, AnyLoRALayer] + + def __init__( + self, + name: str, + layers: Dict[str, AnyLoRALayer], + ): + self._name = name + self.layers = layers + + @property + def name(self) -> str: + return self._name + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + # TODO: try revert if exception? + for _key, layer in self.layers.items(): + layer.to(device=device, dtype=dtype) + + def calc_size(self) -> int: + model_size = 0 + for _, layer in self.layers.items(): + model_size += layer.calc_size() + return model_size + + @classmethod + def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Convert the keys of an SDXL LoRA state_dict to diffusers format. + + The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in + diffusers format, then this function will have no effect. + + This function is adapted from: + https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409 + + Args: + state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict. + + Raises: + ValueError: If state_dict contains an unrecognized key, or not all keys could be converted. + + Returns: + Dict[str, Tensor]: The diffusers-format state_dict. + """ + converted_count = 0 # The number of Stability AI keys converted to diffusers format. + not_converted_count = 0 # The number of keys that were not converted. + + # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes. + # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for + # `input_blocks_4_1_proj_in`. 
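A minimal, self-contained sketch of the prefix lookup described in the comment above, using a toy key list (the names below are illustrative and not part of the patch):

import bisect

stability_unet_keys_example = sorted(["input_blocks_4_0", "input_blocks_4_1", "middle_block_1"])
search_key_example = "input_blocks_4_1_proj_in"
# bisect_right returns the insertion point for the search key; the entry just
# before it is the only sorted key that can possibly be a prefix of the search key.
candidate = stability_unet_keys_example[bisect.bisect_right(stability_unet_keys_example, search_key_example) - 1]
assert search_key_example.startswith(candidate)  # candidate == "input_blocks_4_1"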
+ stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP) + stability_unet_keys.sort() + + new_state_dict = {} + for full_key, value in state_dict.items(): + if full_key.startswith("lora_unet_"): + search_key = full_key.replace("lora_unet_", "") + # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix. + position = bisect.bisect_right(stability_unet_keys, search_key) + map_key = stability_unet_keys[position - 1] + # Now, check if the map_key *actually* matches the search_key. + if search_key.startswith(map_key): + new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key]) + new_state_dict[new_key] = value + converted_count += 1 + else: + new_state_dict[full_key] = value + not_converted_count += 1 + elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): + # The CLIP text encoders have the same keys in both Stability AI and diffusers formats. + new_state_dict[full_key] = value + continue + else: + raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.") + + if converted_count > 0 and not_converted_count > 0: + raise ValueError( + f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count}," + f" not_converted={not_converted_count}" + ) + + return new_state_dict + + @classmethod + def from_checkpoint( + cls, + file_path: Union[str, Path], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + base_model: Optional[BaseModelType] = None, + ) -> Self: + device = device or torch.device("cpu") + dtype = dtype or torch.float32 + + if isinstance(file_path, str): + file_path = Path(file_path) + + model = cls( + name=file_path.stem, # TODO: + layers={}, + ) + + if file_path.suffix == ".safetensors": + state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(file_path, map_location="cpu") + + state_dict = cls._group_state(state_dict) + + if base_model == BaseModelType.StableDiffusionXL: + state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) + + for layer_key, values in state_dict.items(): + # lora and locon + if "lora_down.weight" in values: + layer: AnyLoRALayer = LoRALayer(layer_key, values) + + # loha + elif "hada_w1_b" in values: + layer = LoHALayer(layer_key, values) + + # lokr + elif "lokr_w1_b" in values or "lokr_w1" in values: + layer = LoKRLayer(layer_key, values) + + # diff + elif "diff" in values: + layer = FullLayer(layer_key, values) + + # ia3 + elif "weight" in values and "on_input" in values: + layer = IA3Layer(layer_key, values) + + else: + print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") + raise Exception("Unknown lora format!") + + # lower memory consumption by removing already parsed layer values + state_dict[layer_key].clear() + + layer.to(device=device, dtype=dtype) + model.layers[layer_key] = layer + + return model + + @staticmethod + def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]: + state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {} + + for key, value in state_dict.items(): + stem, leaf = key.split(".", 1) + if stem not in state_dict_groupped: + state_dict_groupped[stem] = {} + state_dict_groupped[stem][leaf] = value + + return state_dict_groupped + + +# code from +# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 +def make_sdxl_unet_conversion_map() -> 
List[Tuple[str,str]]: + """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.""" + unet_conversion_map_layer = [] + + for i in range(3): # num_blocks is 3 in sdxl + # loop over downblocks/upblocks + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + # if i > 0: commentout for sdxl + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0.", "norm1."), + ("in_layers.2.", "conv1."), + ("out_layers.0.", "norm2."), + ("out_layers.3.", "conv2."), + ("emb_layers.1.", "time_emb_proj."), + ("skip_connection.", "conv_shortcut."), + ] + + unet_conversion_map = [] + for sd, hf in unet_conversion_map_layer: + if "resnets" in hf: + for sd_res, hf_res in unet_conversion_map_resnet: + unet_conversion_map.append((sd + sd_res, hf + hf_res)) + else: + unet_conversion_map.append((sd, hf)) + + for j in range(2): + hf_time_embed_prefix = f"time_embedding.linear_{j+1}." + sd_time_embed_prefix = f"time_embed.{j*2}." + unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) + + for j in range(2): + hf_label_embed_prefix = f"add_embedding.linear_{j+1}." + sd_label_embed_prefix = f"label_emb.0.{j*2}." 
+ unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) + + unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) + unet_conversion_map.append(("out.0.", "conv_norm_out.")) + unet_conversion_map.append(("out.2.", "conv_out.")) + + return unet_conversion_map + + +SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = { + sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map() +} diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 9fd118b7822..64a20a20923 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -29,8 +29,12 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = { BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: "v1-inference.yaml", + ModelVariantType.Normal: { + SchedulerPredictionType.Epsilon: "v1-inference.yaml", + SchedulerPredictionType.VPrediction: "v1-inference-v.yaml", + }, ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", + ModelVariantType.Depth: "v2-midas-inference.yaml", }, BaseModelType.StableDiffusion2: { ModelVariantType.Normal: { diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py index 0164dffe303..7b48f0364ea 100644 --- a/invokeai/backend/util/__init__.py +++ b/invokeai/backend/util/__init__.py @@ -12,14 +12,22 @@ torch_dtype, ) from .logging import InvokeAILogger -from .util import ( # TO DO: Clean this up; remove the unused symbols +from .util import ( # TO DO: Clean this up; remove the unused symbols GIG, Chdir, ask_user, # noqa directory_size, download_with_resume, - instantiate_from_config, # noqa + instantiate_from_config, # noqa url_attachment_name, # noqa - ) +) -__all__ = ["GIG", "directory_size","Chdir", "download_with_resume", "InvokeAILogger", "choose_precision", "choose_torch_device"] +__all__ = [ + "GIG", + "directory_size", + "Chdir", + "download_with_resume", + "InvokeAILogger", + "choose_precision", + "choose_torch_device", +] diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index ad3f4e139a7..a787f9b6f42 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import nullcontext -from typing import Union, Optional +from typing import Optional, Union import torch from torch import autocast diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 6589aa72784..ae376b41b25 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -27,6 +27,7 @@ # actual size of a gig GIG = 1073741824 + def directory_size(directory: Path) -> int: """ Return the aggregate size of all files in a directory (bytes). 
@@ -39,6 +40,7 @@ def directory_size(directory: Path) -> int: sum += Path(root, d).stat().st_size return sum + def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot From da01d20348fb6526bb2f112d5d0fc358ebe74b26 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 4 Feb 2024 23:18:00 -0500 Subject: [PATCH 076/340] added textual inversion and lora loaders --- .../model_install/model_install_default.py | 5 + .../{model_manager => embeddings}/lora.py | 35 +- invokeai/backend/embeddings/model_patcher.py | 586 ++++++++++++++++++ invokeai/backend/model_management/lora.py | 5 +- invokeai/backend/model_manager/config.py | 4 +- .../model_manager/load/load_default.py | 11 +- .../model_manager/load/memory_snapshot.py | 2 +- .../load/model_cache/__init__.py | 2 - .../load/model_loaders/controlnet.py | 4 +- .../load/model_loaders/generic_diffusers.py | 1 + .../load/model_loaders/ip_adapter.py | 6 +- .../model_manager/load/model_loaders/lora.py | 18 +- .../load/model_loaders/textual_inversion.py | 55 ++ .../model_manager/load/model_loaders/vae.py | 1 + .../backend/model_manager/load/model_util.py | 4 +- .../{model_manager => onnx}/onnx_runtime.py | 0 16 files changed, 701 insertions(+), 38 deletions(-) rename invokeai/backend/{model_manager => embeddings}/lora.py (96%) create mode 100644 invokeai/backend/embeddings/model_patcher.py create mode 100644 invokeai/backend/model_manager/load/model_loaders/textual_inversion.py rename invokeai/backend/{model_manager => onnx}/onnx_runtime.py (100%) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 2b2294bfce4..1c188b300df 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -178,6 +178,11 @@ def install_path( ) def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102 + similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state] + if similar_jobs: + self._logger.warning(f"There is already an active install job for {source}. 
Not enqueuing.") + return similar_jobs[0] + if isinstance(source, LocalModelSource): install_job = self._import_local_model(source, config) self._install_queue.put(install_job) # synchronously install diff --git a/invokeai/backend/model_manager/lora.py b/invokeai/backend/embeddings/lora.py similarity index 96% rename from invokeai/backend/model_manager/lora.py rename to invokeai/backend/embeddings/lora.py index 4c48de48ec7..9a59a977087 100644 --- a/invokeai/backend/model_manager/lora.py +++ b/invokeai/backend/embeddings/lora.py @@ -1,13 +1,17 @@ # Copyright (c) 2024 The InvokeAI Development team """LoRA model support.""" +import bisect +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + import torch from safetensors.torch import load_file -from pathlib import Path -from typing import Dict, Optional, Union, List, Tuple from typing_extensions import Self + from invokeai.backend.model_manager import BaseModelType + class LoRALayerBase: # rank: Optional[int] # alpha: Optional[float] @@ -41,7 +45,7 @@ def __init__( self.rank = None # set in layer implementation self.layer_key = layer_key - def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: raise NotImplementedError() def calc_size(self) -> int: @@ -82,7 +86,7 @@ def __init__( self.rank = self.down.shape[0] - def get_weight(self, orig_weight: torch.Tensor): + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: if self.mid is not None: up = self.up.reshape(self.up.shape[0], self.up.shape[1]) down = self.down.reshape(self.down.shape[0], self.down.shape[1]) @@ -121,11 +125,7 @@ class LoHALayer(LoRALayerBase): # t1: Optional[torch.Tensor] = None # t2: Optional[torch.Tensor] = None - def __init__( - self, - layer_key: str, - values: Dict[str, torch.Tensor] - ): + def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]): super().__init__(layer_key, values) self.w1_a = values["hada_w1_a"] @@ -145,7 +145,7 @@ def __init__( self.rank = self.w1_b.shape[0] - def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: if self.t1 is None: weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) @@ -227,7 +227,7 @@ def __init__( else: self.rank = None # unscaled - def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: w1: Optional[torch.Tensor] = self.w1 if w1 is None: assert self.w1_a is not None @@ -305,7 +305,7 @@ def __init__( self.rank = None # unscaled - def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: return self.weight def calc_size(self) -> int: @@ -330,7 +330,7 @@ class IA3Layer(LoRALayerBase): def __init__( self, layer_key: str, - values: Dict[str, torch.Tensor], + values: Dict[str, torch.Tensor], ): super().__init__(layer_key, values) @@ -339,10 +339,11 @@ def __init__( self.rank = None # unscaled - def get_weight(self, orig_weight: torch.Tensor): + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: weight = self.weight if not self.on_input: weight = weight.reshape(-1, 1) + assert orig_weight is not None return orig_weight * weight def calc_size(self) -> int: @@ -361,8 +362,10 @@ def to( self.weight = self.weight.to(device=device, dtype=dtype) self.on_input = self.on_input.to(device=device, 
dtype=dtype) + AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] - + + # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix class LoRAModelRaw: # (torch.nn.Module): _name: str @@ -530,7 +533,7 @@ def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, tor # code from # https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 -def make_sdxl_unet_conversion_map() -> List[Tuple[str,str]]: +def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]: """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.""" unet_conversion_map_layer = [] diff --git a/invokeai/backend/embeddings/model_patcher.py b/invokeai/backend/embeddings/model_patcher.py new file mode 100644 index 00000000000..6d73235197d --- /dev/null +++ b/invokeai/backend/embeddings/model_patcher.py @@ -0,0 +1,586 @@ +# Copyright (c) 2024 Ryan Dick, Lincoln D. Stein, and the InvokeAI Development Team +"""These classes implement model patching with LoRAs and Textual Inversions.""" +from __future__ import annotations + +import pickle +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Dict, Generator, List, Optional, Tuple, Union + +import numpy as np +import torch +from compel.embeddings_provider import BaseTextualInversionManager +from diffusers import ModelMixin, OnnxRuntimeModel, UNet2DConditionModel +from safetensors.torch import load_file +from transformers import CLIPTextModel, CLIPTokenizer +from typing_extensions import Self + +from invokeai.app.shared.models import FreeUConfig +from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init +from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel + +from .lora import LoRAModelRaw + +""" +loras = [ + (lora_model1, 0.7), + (lora_model2, 0.4), +] +with LoRAHelper.apply_lora_unet(unet, loras): + # unet with applied loras +# unmodified unet + +""" + + +# TODO: rename smth like ModelPatcher and add TI method? +class ModelPatcher: + @staticmethod + def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: + assert "." not in lora_key + + if not lora_key.startswith(prefix): + raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}") + + module = model + module_key = "" + key_parts = lora_key[len(prefix) :].split("_") + + submodule_name = key_parts.pop(0) + + while len(key_parts) > 0: + try: + module = module.get_submodule(submodule_name) + module_key += "." + submodule_name + submodule_name = key_parts.pop(0) + except Exception: + submodule_name += "_" + key_parts.pop(0) + + module = module.get_submodule(submodule_name) + module_key = (module_key + "." 
+ submodule_name).lstrip(".") + + return (module_key, module) + + @classmethod + @contextmanager + def apply_lora_unet( + cls, + unet: UNet2DConditionModel, + loras: List[Tuple[LoRAModelRaw, float]], + ) -> Generator[None, None, None]: + with cls.apply_lora(unet, loras, "lora_unet_"): + yield + + @classmethod + @contextmanager + def apply_lora_text_encoder( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModelRaw, float]], + ): + with cls.apply_lora(text_encoder, loras, "lora_te_"): + yield + + @classmethod + @contextmanager + def apply_sdxl_lora_text_encoder( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModelRaw, float]], + ): + with cls.apply_lora(text_encoder, loras, "lora_te1_"): + yield + + @classmethod + @contextmanager + def apply_sdxl_lora_text_encoder2( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModelRaw, float]], + ): + with cls.apply_lora(text_encoder, loras, "lora_te2_"): + yield + + @classmethod + @contextmanager + def apply_lora( + cls, + model: Union[torch.nn.Module, ModelMixin, UNet2DConditionModel], + loras: List[Tuple[LoRAModelRaw, float]], + prefix: str, + ) -> Generator[None, None, None]: + original_weights = {} + try: + with torch.no_grad(): + for lora, lora_weight in loras: + # assert lora.device.type == "cpu" + for layer_key, layer in lora.layers.items(): + if not layer_key.startswith(prefix): + continue + + # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This + # should be improved in the following ways: + # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a + # LoRA model is applied. + # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the + # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA + # weights to have valid keys. + module_key, module = cls._resolve_lora_key(model, layer_key, prefix) + + # All of the LoRA weight calculations will be done on the same device as the module weight. + # (Performance will be best if this is a CUDA device.) + device = module.weight.device + dtype = module.weight.dtype + + if module_key not in original_weights: + original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) + + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + + # We intentionally move to the target device first, then cast. Experimentally, this was found to + # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the + # same thing in a single call to '.to(...)'. + layer.to(device=device) + layer.to(dtype=torch.float32) + # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA + # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. + layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) + layer.to(device=torch.device("cpu")) + + assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! + if module.weight.shape != layer_weight.shape: + # TODO: debug on lycoris + assert hasattr(layer_weight, "reshape") + layer_weight = layer_weight.reshape(module.weight.shape) + + assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! 
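The statements that follow complete a patch-then-restore contract: the scaled LoRA delta is added to the module weight here, and the saved original weights are copied back in the finally block. A toy sketch of the same pattern for a single Linear layer, using only plain PyTorch (illustrative, not part of the patch):

from contextlib import contextmanager

import torch

@contextmanager
def add_weight_delta(module: torch.nn.Linear, delta: torch.Tensor):
    # Save the original weight, apply the delta, and always restore on exit.
    original = module.weight.detach().clone()
    try:
        with torch.no_grad():
            module.weight += delta
        yield
    finally:
        with torch.no_grad():
            module.weight.copy_(original)

layer = torch.nn.Linear(2, 2)
before = layer.weight.detach().clone()
with add_weight_delta(layer, torch.ones(2, 2)):
    assert not torch.equal(layer.weight, before)  # temporarily patched
assert torch.equal(layer.weight, before)  # restored afterwards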
+ module.weight += layer_weight.to(dtype=dtype) + + yield # wait for context manager exit + + finally: + assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule() + with torch.no_grad(): + for module_key, weight in original_weights.items(): + model.get_submodule(module_key).weight.copy_(weight) + + @classmethod + @contextmanager + def apply_ti( + cls, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + ti_list: List[Tuple[str, TextualInversionModel]], + ) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]: + init_tokens_count = None + new_tokens_added = None + + # TODO: This is required since Transformers 4.32 see + # https://github.com/huggingface/transformers/pull/25088 + # More information by NVIDIA: + # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc + # This value might need to be changed in the future and take the GPUs model into account as there seem + # to be ideal values for different GPUS. This value is temporary! + # For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817 + pad_to_multiple_of = 8 + + try: + # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a + # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after + # exiting this `apply_ti(...)` context manager. + # + # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`, + # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs). + ti_tokenizer = pickle.loads(pickle.dumps(tokenizer)) + ti_manager = TextualInversionManager(ti_tokenizer) + init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings + + def _get_trigger(ti_name: str, index: int) -> str: + trigger = ti_name + if index > 0: + trigger += f"-!pad-{i}" + return f"<{trigger}>" + + def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModel) -> torch.Tensor: + # for SDXL models, select the embedding that matches the text encoder's dimensions + if ti.embedding_2 is not None: + return ( + ti.embedding_2 + if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0] + else ti.embedding + ) + else: + return ti.embedding + + # modify tokenizer + new_tokens_added = 0 + for ti_name, ti in ti_list: + ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti) + + for i in range(ti_embedding.shape[0]): + new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) + + # Modify text_encoder. + # resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of + # this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some + # time. 
+ with skip_torch_weight_init(): + text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of) + model_embeddings = text_encoder.get_input_embeddings() + + for ti_name, ti in ti_list: + ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti) + + ti_tokens = [] + for i in range(ti_embedding.shape[0]): + embedding = ti_embedding[i] + trigger = _get_trigger(ti_name, i) + + token_id = ti_tokenizer.convert_tokens_to_ids(trigger) + if token_id == ti_tokenizer.unk_token_id: + raise RuntimeError(f"Unable to find token id for token '{trigger}'") + + if model_embeddings.weight.data[token_id].shape != embedding.shape: + raise ValueError( + f"Cannot load embedding for {trigger}. It was trained on a model with token dimension" + f" {embedding.shape[0]}, but the current model has token dimension" + f" {model_embeddings.weight.data[token_id].shape[0]}." + ) + + model_embeddings.weight.data[token_id] = embedding.to( + device=text_encoder.device, dtype=text_encoder.dtype + ) + ti_tokens.append(token_id) + + if len(ti_tokens) > 1: + ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:] + + yield ti_tokenizer, ti_manager + + finally: + if init_tokens_count and new_tokens_added: + text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of) + + @classmethod + @contextmanager + def apply_clip_skip( + cls, + text_encoder: CLIPTextModel, + clip_skip: int, + ) -> Generator[None, None, None]: + skipped_layers = [] + try: + for _i in range(clip_skip): + skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1)) + + yield + + finally: + while len(skipped_layers) > 0: + text_encoder.text_model.encoder.layers.append(skipped_layers.pop()) + + @classmethod + @contextmanager + def apply_freeu( + cls, + unet: UNet2DConditionModel, + freeu_config: Optional[FreeUConfig] = None, + ) -> Generator[None, None, None]: + did_apply_freeu = False + try: + assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute? + if freeu_config is not None: + unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2) + did_apply_freeu = True + + yield + + finally: + assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute? + if did_apply_freeu: + unet.disable_freeu() + + +class TextualInversionModel: + embedding: torch.Tensor # [n, 768]|[n, 1280] + embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models + + @classmethod + def from_checkpoint( + cls, + file_path: Union[str, Path], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Self: + if not isinstance(file_path, Path): + file_path = Path(file_path) + + result = cls() # TODO: + + if file_path.suffix == ".safetensors": + state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(file_path, map_location="cpu") + + # both v1 and v2 format embeddings + # difference mostly in metadata + if "string_to_param" in state_dict: + if len(state_dict["string_to_param"]) > 1: + print( + f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. 
The first', + " token will be used.", + ) + + result.embedding = next(iter(state_dict["string_to_param"].values())) + + # v3 (easynegative) + elif "emb_params" in state_dict: + result.embedding = state_dict["emb_params"] + + # v5(sdxl safetensors file) + elif "clip_g" in state_dict and "clip_l" in state_dict: + result.embedding = state_dict["clip_g"] + result.embedding_2 = state_dict["clip_l"] + + # v4(diffusers bin files) + else: + result.embedding = next(iter(state_dict.values())) + + if len(result.embedding.shape) == 1: + result.embedding = result.embedding.unsqueeze(0) + + if not isinstance(result.embedding, torch.Tensor): + raise ValueError(f"Invalid embeddings file: {file_path.name}") + + return result + + +# no type hints for BaseTextualInversionManager? +class TextualInversionManager(BaseTextualInversionManager): # type: ignore + pad_tokens: Dict[int, List[int]] + tokenizer: CLIPTokenizer + + def __init__(self, tokenizer: CLIPTokenizer): + self.pad_tokens = {} + self.tokenizer = tokenizer + + def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: + if len(self.pad_tokens) == 0: + return token_ids + + if token_ids[0] == self.tokenizer.bos_token_id: + raise ValueError("token_ids must not start with bos_token_id") + if token_ids[-1] == self.tokenizer.eos_token_id: + raise ValueError("token_ids must not end with eos_token_id") + + new_token_ids = [] + for token_id in token_ids: + new_token_ids.append(token_id) + if token_id in self.pad_tokens: + new_token_ids.extend(self.pad_tokens[token_id]) + + # Do not exceed the max model input size + # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(), + # which first removes and then adds back the start and end tokens. + max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2 + if len(new_token_ids) > max_length: + new_token_ids = new_token_ids[0:max_length] + + return new_token_ids + + +class ONNXModelPatcher: + @classmethod + @contextmanager + def apply_lora_unet( + cls, + unet: OnnxRuntimeModel, + loras: List[Tuple[LoRAModelRaw, float]], + ) -> Generator[None, None, None]: + with cls.apply_lora(unet, loras, "lora_unet_"): + yield + + @classmethod + @contextmanager + def apply_lora_text_encoder( + cls, + text_encoder: OnnxRuntimeModel, + loras: List[Tuple[LoRAModelRaw, float]], + ) -> Generator[None, None, None]: + with cls.apply_lora(text_encoder, loras, "lora_te_"): + yield + + # based on + # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323 + @classmethod + @contextmanager + def apply_lora( + cls, + model: IAIOnnxRuntimeModel, + loras: List[Tuple[LoRAModelRaw, float]], + prefix: str, + ) -> Generator[None, None, None]: + from .models.base import IAIOnnxRuntimeModel + + if not isinstance(model, IAIOnnxRuntimeModel): + raise Exception("Only IAIOnnxRuntimeModel models supported") + + orig_weights = {} + + try: + blended_loras: Dict[str, torch.Tensor] = {} + + for lora, lora_weight in loras: + for layer_key, layer in lora.layers.items(): + if not layer_key.startswith(prefix): + continue + + layer.to(dtype=torch.float32) + layer_key = layer_key.replace(prefix, "") + # TODO: rewrite to pass original tensor weight(required by ia3) + layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight + if layer_key in blended_loras: + blended_loras[layer_key] += layer_weight + else: + blended_loras[layer_key] = layer_weight + + node_names = {} + for node in model.nodes.values(): + 
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name + + for layer_key, lora_weight in blended_loras.items(): + conv_key = layer_key + "_Conv" + gemm_key = layer_key + "_Gemm" + matmul_key = layer_key + "_MatMul" + + if conv_key in node_names or gemm_key in node_names: + if conv_key in node_names: + conv_node = model.nodes[node_names[conv_key]] + else: + conv_node = model.nodes[node_names[gemm_key]] + + weight_name = [n for n in conv_node.input if ".weight" in n][0] + orig_weight = model.tensors[weight_name] + + if orig_weight.shape[-2:] == (1, 1): + if lora_weight.shape[-2:] == (1, 1): + new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2)) + else: + new_weight = orig_weight.squeeze((3, 2)) + lora_weight + + new_weight = np.expand_dims(new_weight, (2, 3)) + else: + if orig_weight.shape != lora_weight.shape: + new_weight = orig_weight + lora_weight.reshape(orig_weight.shape) + else: + new_weight = orig_weight + lora_weight + + orig_weights[weight_name] = orig_weight + model.tensors[weight_name] = new_weight.astype(orig_weight.dtype) + + elif matmul_key in node_names: + weight_node = model.nodes[node_names[matmul_key]] + matmul_name = [n for n in weight_node.input if "MatMul" in n][0] + + orig_weight = model.tensors[matmul_name] + new_weight = orig_weight + lora_weight.transpose() + + orig_weights[matmul_name] = orig_weight + model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype) + + else: + # warn? err? + pass + + yield + + finally: + # restore original weights + for name, orig_weight in orig_weights.items(): + model.tensors[name] = orig_weight + + @classmethod + @contextmanager + def apply_ti( + cls, + tokenizer: CLIPTokenizer, + text_encoder: IAIOnnxRuntimeModel, + ti_list: List[Tuple[str, Any]], + ) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]: + from .models.base import IAIOnnxRuntimeModel + + if not isinstance(text_encoder, IAIOnnxRuntimeModel): + raise Exception("Only IAIOnnxRuntimeModel models supported") + + orig_embeddings = None + + try: + # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a + # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after + # exiting this `apply_ti(...)` context manager. + # + # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`, + # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs). 
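A small, self-contained sketch of the copy trick described in the comment above; the only assumption is that the object being copied (the tokenizer here) is picklable:

import copy
import pickle

def pickle_deepcopy(obj):
    # A pickle round-trip yields an independent copy, equivalent to
    # copy.deepcopy(obj) for picklable objects but typically much faster
    # for large tokenizers.
    return pickle.loads(pickle.dumps(obj))

original = {"vocab": ["a", "b"], "added_tokens": []}
clone = pickle_deepcopy(original)
assert clone == copy.deepcopy(original) and clone is not original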
+ ti_tokenizer = pickle.loads(pickle.dumps(tokenizer)) + ti_manager = TextualInversionManager(ti_tokenizer) + + def _get_trigger(ti_name: str, index: int) -> str: + trigger = ti_name + if index > 0: + trigger += f"-!pad-{i}" + return f"<{trigger}>" + + # modify text_encoder + orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"] + + # modify tokenizer + new_tokens_added = 0 + for ti_name, ti in ti_list: + if ti.embedding_2 is not None: + ti_embedding = ( + ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding + ) + else: + ti_embedding = ti.embedding + + for i in range(ti_embedding.shape[0]): + new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) + + embeddings = np.concatenate( + (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))), + axis=0, + ) + + for ti_name, _ in ti_list: + ti_tokens = [] + for i in range(ti_embedding.shape[0]): + embedding = ti_embedding[i].detach().numpy() + trigger = _get_trigger(ti_name, i) + + token_id = ti_tokenizer.convert_tokens_to_ids(trigger) + if token_id == ti_tokenizer.unk_token_id: + raise RuntimeError(f"Unable to find token id for token '{trigger}'") + + if embeddings[token_id].shape != embedding.shape: + raise ValueError( + f"Cannot load embedding for {trigger}. It was trained on a model with token dimension" + f" {embedding.shape[0]}, but the current model has token dimension" + f" {embeddings[token_id].shape[0]}." + ) + + embeddings[token_id] = embedding + ti_tokens.append(token_id) + + if len(ti_tokens) > 1: + ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:] + + text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype( + orig_embeddings.dtype + ) + + yield ti_tokenizer, ti_manager + + finally: + # restore + if orig_embeddings is not None: + text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index d72f55794d3..aed5eb60d57 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -102,7 +102,7 @@ def apply_sdxl_lora_text_encoder2( def apply_lora( cls, model: torch.nn.Module, - loras: List[Tuple[LoRAModel, float]], + loras: List[Tuple[LoRAModel, float]], # THIS IS INCORRECT. IT IS ACTUALLY A LoRAModelRaw prefix: str, ): original_weights = {} @@ -194,6 +194,8 @@ def _get_trigger(ti_name, index): return f"<{trigger}>" def _get_ti_embedding(model_embeddings, ti): + print(f"DEBUG: model_embeddings={type(model_embeddings)}, ti={type(ti)}") + print(f"DEBUG: is it an nn.Module? 
{isinstance(model_embeddings, torch.nn.Module)}") # for SDXL models, select the embedding that matches the text encoder's dimensions if ti.embedding_2 is not None: return ( @@ -202,6 +204,7 @@ def _get_ti_embedding(model_embeddings, ti): else ti.embedding ) else: + print(f"DEBUG: ti.embedding={type(ti.embedding)}") return ti.embedding # modify tokenizer diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index e59a84d7291..4488f8eafc5 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -28,9 +28,11 @@ from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from typing_extensions import Annotated, Any, Dict -from .onnx_runtime import IAIOnnxRuntimeModel +from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel + from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus + class InvalidModelConfigException(Exception): """Exception for when config parser doesn't recognized this combination of model type and format.""" diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 453283e9b4a..adc84d20516 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -10,11 +10,17 @@ from diffusers.configuration_utils import ConfigMixin from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager import AnyModel, AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + InvalidModelConfigException, + ModelRepoVariant, + SubModelType, +) from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase -from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs, calc_model_size_by_data +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.util.devices import choose_torch_device, torch_dtype @@ -160,4 +166,3 @@ def _load_model( submodel_type: Optional[SubModelType] = None, ) -> AnyModel: raise NotImplementedError - diff --git a/invokeai/backend/model_manager/load/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py index 295be0c5514..346f5dc4247 100644 --- a/invokeai/backend/model_manager/load/memory_snapshot.py +++ b/invokeai/backend/model_manager/load/memory_snapshot.py @@ -97,4 +97,4 @@ def get_msg_line(prefix: str, val1: int, val2: int) -> str: if snapshot_1.vram is not None and snapshot_2.vram is not None: msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) - return "\n"+msg if len(msg)>0 else msg + return "\n" + msg if len(msg) > 0 else msg diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py index 50cafa37696..6c87e2519e5 100644 --- a/invokeai/backend/model_manager/load/model_cache/__init__.py +++ b/invokeai/backend/model_manager/load/model_cache/__init__.py @@ -1,5 +1,3 @@ """Init file for RamCache.""" -from .model_cache_base import ModelCacheBase -from .model_cache_default import ModelCache _all__ = ["ModelCacheBase", 
"ModelCache"] diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py index 8e6a80ceb20..e61e2b46a63 100644 --- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -14,8 +14,10 @@ ) from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers from invokeai.backend.model_manager.load.load_base import AnyModelLoader + from .generic_diffusers import GenericDiffusersLoader + @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers) @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) class ControlnetLoader(GenericDiffusersLoader): @@ -37,7 +39,7 @@ def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: raise Exception(f"Vae conversion not supported for model type: {config.base}") else: - assert hasattr(config, 'config') + assert hasattr(config, "config") config_file = config.config if weights_path.suffix == ".safetensors": diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index f92a9048c50..03c26f3a0c0 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -15,6 +15,7 @@ from invokeai.backend.model_manager.load.load_base import AnyModelLoader from invokeai.backend.model_manager.load.load_default import ModelLoader + @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) class GenericDiffusersLoader(ModelLoader): diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py index 63dc3790f16..27ced41c1e9 100644 --- a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py +++ b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py @@ -1,11 +1,11 @@ # Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team """Class for IP Adapter model loading in InvokeAI.""" -import torch - from pathlib import Path from typing import Optional +import torch + from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter from invokeai.backend.model_manager import ( AnyModel, @@ -18,6 +18,7 @@ from invokeai.backend.model_manager.load.load_base import AnyModelLoader from invokeai.backend.model_manager.load.load_default import ModelLoader + @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI) class IPAdapterInvokeAILoader(ModelLoader): """Class to load IP Adapter diffusers models.""" @@ -36,4 +37,3 @@ def _load_model( dtype=self._torch_dtype, ) return model - diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index 4d19aadb7d2..d8e5f920e24 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -2,13 +2,12 @@ """Class for LoRA model loading in InvokeAI.""" +from logging import Logger from pathlib import Path from typing import Optional, Tuple -from logging import Logger -from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase -from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.embeddings.lora import LoRAModelRaw from invokeai.backend.model_manager import ( AnyModel, AnyModelConfig, @@ -18,9 +17,11 @@ ModelType, SubModelType, ) -from invokeai.backend.model_manager.lora import LoRAModelRaw +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import AnyModelLoader from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase + @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers) @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris) @@ -47,6 +48,7 @@ def _load_model( ) -> AnyModel: if submodel_type is not None: raise ValueError("There are no submodels in a LoRA model.") + assert self._model_base is not None model = LoRAModelRaw.from_checkpoint( file_path=model_path, dtype=self._torch_dtype, @@ -56,9 +58,11 @@ def _load_model( # override def _get_model_path( - self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]: - self._model_base = config.base # cheating a little - setting this variable for later call to _load_model() + self._model_base = ( + config.base + ) # cheating a little - we remember this variable for using in the subsequent call to _load_model() model_base_path = self._app_config.models_path model_path = model_base_path / config.path @@ -72,5 +76,3 @@ def _get_model_path( result = model_path.resolve(), config, submodel_type return result - - diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py new file mode 100644 index 00000000000..394fddc75d0 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024, 
Lincoln D. Stein and the InvokeAI Development Team +"""Class for TI model loading in InvokeAI.""" + + +from pathlib import Path +from typing import Optional, Tuple + +from invokeai.backend.embeddings.model_patcher import TextualInversionModel as TextualInversionModelRaw +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder) +class TextualInversionLoader(ModelLoader): + """Class to load TI models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if submodel_type is not None: + raise ValueError("There are no submodels in a TI model.") + model = TextualInversionModelRaw.from_checkpoint( + file_path=model_path, + dtype=self._torch_dtype, + ) + return model + + # override + def _get_model_path( + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]: + model_path = self._app_config.models_path / config.path + + if config.format == ModelFormat.EmbeddingFolder: + path = model_path / "learned_embeds.bin" + else: + path = model_path + + if not path.exists(): + raise OSError(f"The embedding file at {path} was not found") + + return path, config, submodel_type diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index 7a35e53459a..882ae055771 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -15,6 +15,7 @@ ) from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers from invokeai.backend.model_manager.load.load_base import AnyModelLoader + from .generic_diffusers import GenericDiffusersLoader diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 404c88bbbcd..3f2d22595e2 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -3,13 +3,13 @@ import json from pathlib import Path -from typing import Optional, Union +from typing import Optional import torch from diffusers import DiffusionPipeline from invokeai.backend.model_manager.config import AnyModel -from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel +from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel def calc_model_size_by_data(model: AnyModel) -> int: diff --git a/invokeai/backend/model_manager/onnx_runtime.py b/invokeai/backend/onnx/onnx_runtime.py similarity index 100% rename from invokeai/backend/model_manager/onnx_runtime.py rename to invokeai/backend/onnx/onnx_runtime.py From 0e113beb5b3a36b3460f21728aed8ac1fcc172a8 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 5 Feb 2024 21:55:11 -0500 Subject: [PATCH 077/340] Multiple refinements on loaders: - Cache stat collection enabled. - Implemented ONNX loading. - Add ability to specify the repo version variant in installer CLI. 
- If caller asks for a repo version that doesn't exist, will fall back to empty version rather than raising an error. --- .../model_install/model_install_default.py | 6 +-- invokeai/backend/install/install_helper.py | 18 ++++++-- invokeai/backend/model_manager/config.py | 14 ++++++- .../backend/model_manager/load/__init__.py | 1 - .../backend/model_manager/load/load_base.py | 4 +- .../model_manager/load/load_default.py | 32 ++++++++++---- .../load/model_cache/__init__.py | 3 +- .../load/model_cache/model_cache_base.py | 10 ++++- .../load/model_cache/model_cache_default.py | 42 +++++++++++++++++-- .../model_manager/load/model_loaders/onnx.py | 41 ++++++++++++++++++ .../model_manager/metadata/fetch/civitai.py | 7 +++- .../metadata/fetch/fetch_base.py | 7 +++- .../metadata/fetch/huggingface.py | 26 ++++++++---- .../model_manager/metadata/metadata_base.py | 1 - invokeai/backend/model_manager/probe.py | 16 +++++-- .../model_manager/util/select_hf_files.py | 14 +++++-- invokeai/backend/util/devices.py | 20 ++++++--- invokeai/frontend/install/model_install2.py | 2 +- 18 files changed, 215 insertions(+), 49 deletions(-) create mode 100644 invokeai/backend/model_manager/load/model_loaders/onnx.py diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 1c188b300df..d32af4a513d 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -495,10 +495,10 @@ def _next_id(self) -> int: return id @staticmethod - def _guess_variant() -> ModelRepoVariant: + def _guess_variant() -> Optional[ModelRepoVariant]: """Guess the best HuggingFace variant type to download.""" precision = choose_precision(choose_torch_device()) - return ModelRepoVariant.FP16 if precision == "float16" else ModelRepoVariant.DEFAULT + return ModelRepoVariant.FP16 if precision == "float16" else None def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: return ModelInstallJob( @@ -523,7 +523,7 @@ def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any] if not source.access_token: self._logger.info("No HuggingFace access token present; some models may not be downloadable.") - metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id) + metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant) assert isinstance(metadata, ModelMetadataWithFiles) remote_files = metadata.download_urls( variant=source.variant or self._guess_variant(), diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 9f219132d4d..57dfadcaeaa 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -30,6 +30,7 @@ from invokeai.backend.model_manager import ( BaseModelType, InvalidModelConfigException, + ModelRepoVariant, ModelType, ) from invokeai.backend.model_manager.metadata import UnknownMetadataException @@ -233,11 +234,18 @@ def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource: if model_path.exists(): # local file on disk return LocalModelSource(path=model_path.absolute(), inplace=True) - if re.match(r"^[^/]+/[^/]+$", model_path_id_or_url): # hugging face repo_id + + # parsing huggingface repo ids + # we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16" + variants = "|".join([x.lower() for x in 
ModelRepoVariant.__members__]) + if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url): + repo_id = match.group(1) + repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None return HFModelSource( - repo_id=model_path_id_or_url, + repo_id=repo_id, access_token=HfFolder.get_token(), subfolder=model_info.subfolder, + variant=repo_variant, ) if re.match(r"^(http|https):", model_path_id_or_url): return URLModelSource(url=AnyHttpUrl(model_path_id_or_url)) @@ -278,9 +286,11 @@ def add_or_delete(self, selections: InstallSelections) -> None: model_name=model_name, ) if len(matches) > 1: - print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.") + print( + f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate." + ) elif not matches: - print(f"{model}: unknown model") + print(f"{model_to_remove}: unknown model") else: for m in matches: print(f"Deleting {m.type}:{m.name}") diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 4488f8eafc5..49ce6af2b81 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -109,7 +109,7 @@ class SchedulerPredictionType(str, Enum): class ModelRepoVariant(str, Enum): """Various hugging face variants on the diffusers format.""" - DEFAULT = "default" # model files without "fp16" or other qualifier + DEFAULT = "" # model files without "fp16" or other qualifier - empty str FP16 = "fp16" FP32 = "fp32" ONNX = "onnx" @@ -246,6 +246,16 @@ class ONNXSD2Config(_MainConfig): upcast_attention: bool = True +class ONNXSDXLConfig(_MainConfig): + """Model config for ONNX format models based on sdxl.""" + + type: Literal[ModelType.ONNX] = ModelType.ONNX + format: Literal[ModelFormat.Onnx, ModelFormat.Olive] + # No yaml config file for ONNX, so these are part of config + base: Literal[BaseModelType.StableDiffusionXL] = BaseModelType.StableDiffusionXL + prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction + + class IPAdapterConfig(ModelConfigBase): """Model config for IP Adaptor format models.""" @@ -267,7 +277,7 @@ class T2IConfig(ModelConfigBase): format: Literal[ModelFormat.Diffusers] -_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")] +_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config, ONNXSDXLConfig], Field(discriminator="base")] _ControlNetConfig = Annotated[ Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator="format"), diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index 19b0116ba3b..e4c7077f783 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -16,7 +16,6 @@ # This registers the subclasses that implement loaders of specific model types loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"] for module in loaders: - print(f"module={module}") import_module(f"{__package__}.model_loaders.{module}") __all__ = ["AnyModelLoader", "LoadedModel"] diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 7d4e8337c3c..ee9d6d53e3d 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -22,6 +22,7 @@ from 
invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase +from invokeai.backend.util.logging import InvokeAILogger @dataclass @@ -88,6 +89,7 @@ class AnyModelLoader: # this tracks the loader subclasses _registry: Dict[str, Type[ModelLoaderBase]] = {} + _logger: Logger = InvokeAILogger.get_logger() def __init__( self, @@ -167,7 +169,7 @@ def register( """Define a decorator which registers the subclass of loader.""" def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: - print("DEBUG: Registering class", subclass.__name__) + cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}") key = cls._to_registry_key(base, type, format) cls._registry[key] = subclass return subclass diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index adc84d20516..757745072d1 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -52,7 +52,7 @@ def __init__( self._logger = logger self._ram_cache = ram_cache self._convert_cache = convert_cache - self._torch_dtype = torch_dtype(choose_torch_device()) + self._torch_dtype = torch_dtype(choose_torch_device(), app_config) def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ @@ -102,8 +102,10 @@ def _load_if_needed( self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None ) -> ModelLockerBase: # TO DO: This is not thread safe! - if self._ram_cache.exists(config.key, submodel_type): + try: return self._ram_cache.get(config.key, submodel_type) + except IndexError: + pass model_variant = getattr(config, "repo_variant", None) self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type)) @@ -119,7 +121,11 @@ def _load_if_needed( size=calc_model_size_by_data(loaded_model), ) - return self._ram_cache.get(config.key, submodel_type) + return self._ram_cache.get( + key=config.key, + submodel_type=submodel_type, + stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]), + ) def get_size_fs( self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None @@ -146,13 +152,21 @@ def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # # TO DO: Add exception handling def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: if submodel_type: - config = self._load_diffusers_config(model_path, config_name="model_index.json") - module, class_name = config[submodel_type.value] - return self._hf_definition_to_type(module=module, class_name=class_name) + try: + config = self._load_diffusers_config(model_path, config_name="model_index.json") + module, class_name = config[submodel_type.value] + return self._hf_definition_to_type(module=module, class_name=class_name) + except KeyError as e: + raise InvalidModelConfigException( + f'The "{submodel_type}" submodel is not available for this model.' 
+ ) from e else: - config = self._load_diffusers_config(model_path, config_name="config.json") - class_name = config["_class_name"] - return self._hf_definition_to_type(module="diffusers", class_name=class_name) + try: + config = self._load_diffusers_config(model_path, config_name="config.json") + class_name = config["_class_name"] + return self._hf_definition_to_type(module="diffusers", class_name=class_name) + except KeyError as e: + raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e # This needs to be implemented in subclasses that handle checkpoints def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py index 6c87e2519e5..0cb5184f3a4 100644 --- a/invokeai/backend/model_manager/load/model_cache/__init__.py +++ b/invokeai/backend/model_manager/load/model_cache/__init__.py @@ -1,3 +1,4 @@ -"""Init file for RamCache.""" +"""Init file for ModelCache.""" + _all__ = ["ModelCacheBase", "ModelCache"] diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index 14a7dfb4a1f..b1a6768ee8f 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -129,11 +129,17 @@ def get( self, key: str, submodel_type: Optional[SubModelType] = None, + stats_name: Optional[str] = None, ) -> ModelLockerBase: """ - Retrieve model locker object using key and optional submodel_type. + Retrieve model using key and optional submodel_type. - This may return an UnknownModelException if the model is not in the cache. + :param key: Opaque model key + :param submodel_type: Type of the submodel to fetch + :param stats_name: A human-readable id for the model for the purposes of + stats reporting. + + This may raise an IndexError if the model is not in the cache. 
""" pass diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 688be8ceb48..7e30512a588 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -24,6 +24,7 @@ import sys import time from contextlib import suppress +from dataclasses import dataclass, field from logging import Logger from typing import Dict, List, Optional @@ -55,6 +56,20 @@ MB = 2**20 +@dataclass +class CacheStats(object): + """Collect statistics on cache performance.""" + + hits: int = 0 # cache hits + misses: int = 0 # cache misses + high_watermark: int = 0 # amount of cache used + in_cache: int = 0 # number of models in cache + cleared: int = 0 # number of models cleared to make space + cache_size: int = 0 # total size of cache + # {submodel_key => size} + loaded_model_sizes: Dict[str, int] = field(default_factory=dict) + + class ModelCache(ModelCacheBase[AnyModel]): """Implementation of ModelCacheBase.""" @@ -94,6 +109,8 @@ def __init__( self._storage_device: torch.device = storage_device self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG + # used for stats collection + self.stats = CacheStats() self._cached_models: Dict[str, CacheRecord[AnyModel]] = {} self._cache_stack: List[str] = [] @@ -158,21 +175,40 @@ def get( self, key: str, submodel_type: Optional[SubModelType] = None, + stats_name: Optional[str] = None, ) -> ModelLockerBase: """ Retrieve model using key and optional submodel_type. - This may return an IndexError if the model is not in the cache. + :param key: Opaque model key + :param submodel_type: Type of the submodel to fetch + :param stats_name: A human-readable id for the model for the purposes of + stats reporting. + + This may raise an IndexError if the model is not in the cache. """ key = self._make_cache_key(key, submodel_type) - if key not in self._cached_models: + if key in self._cached_models: + self.stats.hits += 1 + else: + self.stats.misses += 1 raise IndexError(f"The model with key {key} is not in the cache.") + cache_entry = self._cached_models[key] + + # more stats + stats_name = stats_name or key + self.stats.cache_size = int(self._max_cache_size * GIG) + self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size()) + self.stats.in_cache = len(self._cached_models) + self.stats.loaded_model_sizes[stats_name] = max( + self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size + ) + # this moves the entry to the top (right end) of the stack with suppress(Exception): self._cache_stack.remove(key) self._cache_stack.append(key) - cache_entry = self._cached_models[key] return ModelLocker( cache=self, cache_entry=cache_entry, diff --git a/invokeai/backend/model_manager/load/model_loaders/onnx.py b/invokeai/backend/model_manager/load/model_loaders/onnx.py new file mode 100644 index 00000000000..935a6b7c953 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/onnx.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Class for Onnx model loading in InvokeAI.""" + +# This should work the same as Stable Diffusion pipelines +from pathlib import Path +from typing import Optional + +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive) +class OnnyxDiffusersModel(ModelLoader): + """Class to load onnx models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not submodel_type is not None: + raise Exception("A submodel type must be provided when loading onnx pipelines.") + load_class = self._get_hf_load_class(model_path, submodel_type) + variant = model_variant.value if model_variant else None + model_path = model_path / submodel_type.value + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=self._torch_dtype, + variant=variant, + ) # type: ignore + return result diff --git a/invokeai/backend/model_manager/metadata/fetch/civitai.py b/invokeai/backend/model_manager/metadata/fetch/civitai.py index 6e41d6f11b2..7991f6a7489 100644 --- a/invokeai/backend/model_manager/metadata/fetch/civitai.py +++ b/invokeai/backend/model_manager/metadata/fetch/civitai.py @@ -32,6 +32,8 @@ from pydantic.networks import AnyHttpUrl from requests.sessions import Session +from invokeai.backend.model_manager import ModelRepoVariant + from ..metadata_base import ( AnyModelRepoMetadata, CivitaiMetadata, @@ -82,10 +84,13 @@ def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: return self.from_civitai_versionid(int(version_id)) raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns") - def from_id(self, id: str) -> AnyModelRepoMetadata: + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """ Given a Civitai model version ID, return a ModelRepoMetadata object. + :param id: An ID. + :param variant: A model variant from the ModelRepoVariant enum (currently ignored) + May raise an `UnknownMetadataException`. """ return self.from_civitai_versionid(int(id)) diff --git a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py index 58b65b69477..d628ab5c178 100644 --- a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py +++ b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py @@ -18,6 +18,8 @@ from pydantic.networks import AnyHttpUrl from requests.sessions import Session +from invokeai.backend.model_manager import ModelRepoVariant + from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator @@ -45,10 +47,13 @@ def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: pass @abstractmethod - def from_id(self, id: str) -> AnyModelRepoMetadata: + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """ Given an ID for a model, return a ModelMetadata object. + :param id: An ID. + :param variant: A model variant from the ModelRepoVariant enum. 
+ This method will raise a `UnknownMetadataException` in the event that the requested model's metadata is not found at the provided id. """ diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index 5d1eb0cc9e4..6f04e8713b2 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -19,10 +19,12 @@ import requests from huggingface_hub import HfApi, configure_http_backend, hf_hub_url -from huggingface_hub.utils._errors import RepositoryNotFoundError +from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError from pydantic.networks import AnyHttpUrl from requests.sessions import Session +from invokeai.backend.model_manager import ModelRepoVariant + from ..metadata_base import ( AnyModelRepoMetadata, HuggingFaceMetadata, @@ -53,12 +55,22 @@ def from_json(cls, json: str) -> HuggingFaceMetadata: metadata = HuggingFaceMetadata.model_validate_json(json) return metadata - def from_id(self, id: str) -> AnyModelRepoMetadata: + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """Return a HuggingFaceMetadata object given the model's repo_id.""" - try: - model_info = HfApi().model_info(repo_id=id, files_metadata=True) - except RepositoryNotFoundError as excp: - raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp + # Little loop which tries fetching a revision corresponding to the selected variant. + # If not available, then set variant to None and get the default. + # If this too fails, raise exception. + model_info = None + while not model_info: + try: + model_info = HfApi().model_info(repo_id=id, files_metadata=True, revision=variant) + except RepositoryNotFoundError as excp: + raise UnknownMetadataException(f"'{id}' not found. 
See trace for details.") from excp + except RevisionNotFoundError: + if variant is None: + raise + else: + variant = None _, name = id.split("/") return HuggingFaceMetadata( @@ -70,7 +82,7 @@ def from_id(self, id: str) -> AnyModelRepoMetadata: tags=model_info.tags, files=[ RemoteModelFile( - url=hf_hub_url(id, x.rfilename), + url=hf_hub_url(id, x.rfilename, revision=variant), path=Path(name, x.rfilename), size=x.size, sha256=x.lfs.get("sha256") if x.lfs else None, diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 5aa883d26d0..5c3afcdc960 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -184,7 +184,6 @@ def download_urls( [x.path for x in self.files], variant, subfolder ) # all files in the model prefix = f"{subfolder}/" if subfolder else "" - # the next step reads model_index.json to determine which subdirectories belong # to the model if Path(f"{prefix}model_index.json") in paths: diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 64a20a20923..55a9c0464a5 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -7,6 +7,7 @@ import torch from picklescan.scanner import scan_file_path +import invokeai.backend.util.logging as logger from invokeai.backend.model_management.models.base import read_checkpoint_meta from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat from invokeai.backend.model_management.util import lora_token_vector_length @@ -590,13 +591,20 @@ def get_base_type(self) -> BaseModelType: return TextualInversionCheckpointProbe(path).get_base_type() -class ONNXFolderProbe(FolderProbeBase): +class ONNXFolderProbe(PipelineFolderProbe): + def get_base_type(self) -> BaseModelType: + # Due to the way the installer is set up, the configuration file for safetensors + # will come along for the ride if both the onnx and safetensors forms + # share the same directory. We take advantage of this here. + if (self.model_path / "unet" / "config.json").exists(): + return super().get_base_type() + else: + logger.warning('Base type probing is not implemented for ONNX models. Assuming "sd-1"') + return BaseModelType.StableDiffusion1 + def get_format(self) -> ModelFormat: return ModelFormat("onnx") - def get_base_type(self) -> BaseModelType: - return BaseModelType.StableDiffusion1 - def get_variant_type(self) -> ModelVariantType: return ModelVariantType.Normal diff --git a/invokeai/backend/model_manager/util/select_hf_files.py b/invokeai/backend/model_manager/util/select_hf_files.py index 69760590440..a894d915de6 100644 --- a/invokeai/backend/model_manager/util/select_hf_files.py +++ b/invokeai/backend/model_manager/util/select_hf_files.py @@ -41,13 +41,21 @@ def filter_files( for file in files: if file.name.endswith((".json", ".txt")): paths.append(file) - elif file.name.endswith(("learned_embeds.bin", "ip_adapter.bin", "lora_weights.safetensors")): + elif file.name.endswith( + ( + "learned_embeds.bin", + "ip_adapter.bin", + "lora_weights.safetensors", + "weights.pb", + "onnx_data", + ) + ): paths.append(file) # BRITTLENESS WARNING!! # Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid # downloading random checkpoints that might also be in the repo. 
However there is no guarantee # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models - # will adhere to this naming convention, so this is an area of brittleness. + # will adhere to this naming convention, so this is an area to be careful of. elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name): paths.append(file) @@ -64,7 +72,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path result = set() basenames: Dict[Path, Path] = {} for path in files: - if path.suffix == ".onnx": + if path.suffix in [".onnx", ".pb", ".onnx_data"]: if variant == ModelRepoVariant.ONNX: result.add(path) diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index a787f9b6f42..b4f24d8483b 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -29,12 +29,17 @@ def choose_torch_device() -> torch.device: return torch.device(config.device) -def choose_precision(device: torch.device) -> str: - """Returns an appropriate precision for the given torch device""" +# We are in transition here from using a single global AppConfig to allowing multiple +# configurations. It is strongly recommended to pass the app_config to this function. +def choose_precision(device: torch.device, app_config: Optional[InvokeAIAppConfig] = None) -> str: + """Return an appropriate precision for the given torch device.""" + app_config = app_config or config if device.type == "cuda": device_name = torch.cuda.get_device_name(device) if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name): - if config.precision == "bfloat16": + if app_config.precision == "float32": + return "float32" + elif app_config.precision == "bfloat16": return "bfloat16" else: return "float16" @@ -43,9 +48,14 @@ def choose_precision(device: torch.device) -> str: return "float32" -def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype: +# We are in transition here from using a single global AppConfig to allowing multiple +# configurations. It is strongly recommended to pass the app_config to this function. +def torch_dtype( + device: Optional[torch.device] = None, + app_config: Optional[InvokeAIAppConfig] = None, +) -> torch.dtype: device = device or choose_torch_device() - precision = choose_precision(device) + precision = choose_precision(device, app_config) if precision == "float16": return torch.float16 if precision == "bfloat16": diff --git a/invokeai/frontend/install/model_install2.py b/invokeai/frontend/install/model_install2.py index 6eb480c8d9d..51a633a5654 100644 --- a/invokeai/frontend/install/model_install2.py +++ b/invokeai/frontend/install/model_install2.py @@ -505,7 +505,7 @@ def list_models(installer: ModelInstallService, model_type: ModelType): print(f"Installed models of type `{model_type}`:") for model in models: path = (config.models_path / model.path).resolve() - print(f"{model.name:40}{model.base.value:14}{path}") + print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}") # -------------------------------------------------------- From cbb1090461ccd3a5e2f2775fbcce1cfb9deef414 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 5 Feb 2024 22:56:32 -0500 Subject: [PATCH 078/340] BREAKING CHANGES: invocations now require model key, not base/type/name - Implement new model loader and modify invocations and embeddings - Finish implementation loaders for all models currently supported by InvokeAI. 
- Move lora, textual_inversion, and model patching support into backend/embeddings. - Restore support for model cache statistics collection (a little ugly, needs work). - Fixed up invocations that load and patch models. - Move seamless and silencewarnings utils into better location --- invokeai/app/api/routers/download_queue.py | 2 +- invokeai/app/invocations/compel.py | 114 +++++++----- .../controlnet_image_processors.py | 7 +- invokeai/app/invocations/ip_adapter.py | 49 +++-- invokeai/app/invocations/latent.py | 86 +++++---- invokeai/app/invocations/model.py | 176 +++++------------- invokeai/app/invocations/sdxl.py | 74 ++------ invokeai/app/invocations/t2i_adapter.py | 8 +- invokeai/app/services/events/events_base.py | 27 +-- .../invocation_stats_default.py | 16 +- .../model_records/model_records_base.py | 47 ++++- .../model_records/model_records_sql.py | 92 ++++++++- invokeai/backend/embeddings/__init__.py | 4 + invokeai/backend/embeddings/embedding_base.py | 12 ++ invokeai/backend/embeddings/lora.py | 14 +- invokeai/backend/embeddings/model_patcher.py | 134 +++---------- .../backend/embeddings/textual_inversion.py | 100 ++++++++++ invokeai/backend/install/install_helper.py | 3 +- invokeai/backend/model_manager/config.py | 5 +- .../backend/model_manager/load/load_base.py | 4 +- .../model_manager/load/load_default.py | 4 +- .../load/model_cache/__init__.py | 4 +- .../load/model_cache/model_cache_base.py | 33 +++- .../load/model_cache/model_cache_default.py | 53 +++--- .../load/model_loaders/textual_inversion.py | 2 +- invokeai/backend/stable_diffusion/__init__.py | 9 + invokeai/backend/stable_diffusion/seamless.py | 102 ++++++++++ invokeai/backend/util/silence_warnings.py | 28 +++ invokeai/frontend/install/model_install2.py | 8 +- .../util/test_hf_model_select.py | 2 + tests/test_model_probe.py | 6 +- 31 files changed, 728 insertions(+), 497 deletions(-) create mode 100644 invokeai/backend/embeddings/__init__.py create mode 100644 invokeai/backend/embeddings/embedding_base.py create mode 100644 invokeai/backend/embeddings/textual_inversion.py create mode 100644 invokeai/backend/stable_diffusion/seamless.py create mode 100644 invokeai/backend/util/silence_warnings.py diff --git a/invokeai/app/api/routers/download_queue.py b/invokeai/app/api/routers/download_queue.py index 92b658c3708..2dba376c181 100644 --- a/invokeai/app/api/routers/download_queue.py +++ b/invokeai/app/api/routers/download_queue.py @@ -55,7 +55,7 @@ async def download( ) -> DownloadJob: """Download the source URL to the file or directory indicted in dest.""" queue = ApiDependencies.invoker.services.download_queue - return queue.download(source, dest, priority, access_token) + return queue.download(source, Path(dest), priority, access_token) @download_queue_router.get( diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 978c6dcb17f..0e1a6bdc6fb 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,9 +1,10 @@ -from typing import List, Optional, Union +from typing import Iterator, List, Optional, Tuple, Union import torch from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment +import invokeai.backend.util.logging as logger from invokeai.app.invocations.fields import ( FieldDescriptions, Input, @@ -12,18 +13,21 @@ UIComponent, ) from invokeai.app.invocations.primitives import ConditioningOutput +from invokeai.app.services.model_records 
import UnknownModelException from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt +from invokeai.backend.embeddings.lora import LoRAModelRaw +from invokeai.backend.embeddings.model_patcher import ModelPatcher +from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw +from invokeai.backend.model_manager import ModelType from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, ExtraConditioningInfo, SDXLConditioningInfo, ) +from invokeai.backend.util.devices import torch_dtype -from ...backend.model_management.lora import ModelPatcher -from ...backend.model_management.models import ModelNotFoundException, ModelType -from ...backend.util.devices import torch_dtype -from ..util.ti_utils import extract_ti_triggers_from_prompt from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -64,13 +68,22 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump()) - text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump()) + tokenizer_info = context.services.model_records.load_model( + **self.clip.tokenizer.model_dump(), + context=context, + ) + text_encoder_info = context.services.model_records.load_model( + **self.clip.text_encoder.model_dump(), + context=context, + ) - def _lora_loader(): + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.clip.loras: - lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) - yield (lora_info.context.model, lora.weight) + lora_info = context.services.model_records.load_model( + **lora.model_dump(exclude={"weight"}), context=context + ) + assert isinstance(lora_info.model, LoRAModelRaw) + yield (lora_info.model, lora.weight) del lora_info return @@ -80,24 +93,20 @@ def _lora_loader(): for trigger in extract_ti_triggers_from_prompt(self.prompt): name = trigger[1:-1] try: - ti_list.append( - ( - name, - context.models.load( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model, - ) - ) - except ModelNotFoundException: + loaded_model = context.services.model_records.load_model( + **self.clip.text_encoder.model_dump(), + context=context, + ).model + assert isinstance(loaded_model, TextualInversionModelRaw) + ti_list.append((name, loaded_model)) + except UnknownModelException: # print(e) # import traceback # print(traceback.format_exc()) print(f'Warn: trigger: "{trigger}" not found') with ( - ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( + ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as ( tokenizer, ti_manager, ), @@ -105,7 +114,7 @@ def _lora_loader(): # Apply the LoRA after text_encoder has been moved to its target device for faster patching. ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. 
- ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers), ): compel = Compel( tokenizer=tokenizer, @@ -144,6 +153,8 @@ def _lora_loader(): class SDXLPromptInvocationBase: + """Prompt processor for SDXL models.""" + def run_clip_compel( self, context: InvocationContext, @@ -152,20 +163,27 @@ def run_clip_compel( get_pooled: bool, lora_prefix: str, zero_on_empty: bool, - ): - tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump()) - text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump()) + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]: + tokenizer_info = context.services.model_records.load_model( + **clip_field.tokenizer.model_dump(), + context=context, + ) + text_encoder_info = context.services.model_records.load_model( + **clip_field.text_encoder.model_dump(), + context=context, + ) # return zero on empty if prompt == "" and zero_on_empty: - cpu_text_encoder = text_encoder_info.context.model + cpu_text_encoder = text_encoder_info.model + assert isinstance(cpu_text_encoder, torch.nn.Module) c = torch.zeros( ( 1, cpu_text_encoder.config.max_position_embeddings, cpu_text_encoder.config.hidden_size, ), - dtype=text_encoder_info.context.cache.precision, + dtype=cpu_text_encoder.dtype, ) if get_pooled: c_pooled = torch.zeros( @@ -176,10 +194,14 @@ def run_clip_compel( c_pooled = None return c, c_pooled, None - def _lora_loader(): + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in clip_field.loras: - lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) - yield (lora_info.context.model, lora.weight) + lora_info = context.services.model_records.load_model( + **lora.model_dump(exclude={"weight"}), context=context + ) + lora_model = lora_info.model + assert isinstance(lora_model, LoRAModelRaw) + yield (lora_model, lora.weight) del lora_info return @@ -189,24 +211,24 @@ def _lora_loader(): for trigger in extract_ti_triggers_from_prompt(prompt): name = trigger[1:-1] try: - ti_list.append( - ( - name, - context.models.load( - model_name=name, - base_model=clip_field.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model, - ) - ) - except ModelNotFoundException: + ti_model = context.services.model_records.load_model_by_attr( + model_name=name, + base_model=text_encoder_info.config.base, + model_type=ModelType.TextualInversion, + context=context, + ).model + assert isinstance(ti_model, TextualInversionModelRaw) + ti_list.append((name, ti_model)) + except UnknownModelException: # print(e) # import traceback # print(traceback.format_exc()) - print(f'Warn: trigger: "{trigger}" not found') + logger.warning(f'trigger: "{trigger}" not found') + except ValueError: + logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models') with ( - ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( + ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as ( tokenizer, ti_manager, ), @@ -214,7 +236,7 @@ def _lora_loader(): # Apply the LoRA after text_encoder has been moved to its target device for faster patching. ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. 
- ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers), ): compel = Compel( tokenizer=tokenizer, @@ -332,6 +354,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: dim=1, ) + assert c2_pooled is not None conditioning_data = ConditioningFieldData( conditionings=[ SDXLConditioningInfo( @@ -380,6 +403,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)]) + assert c2_pooled is not None conditioning_data = ConditioningFieldData( conditionings=[ SDXLConditioningInfo( diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 37954c1097e..580ee085627 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -23,7 +23,7 @@ ) from controlnet_aux.util import HWC3, ade_palette from PIL import Image -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.invocations.fields import ( FieldDescriptions, @@ -60,10 +60,7 @@ class ControlNetModelField(BaseModel): """ControlNet model field""" - model_name: str = Field(description="Name of the ControlNet model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model config record key for the ControlNet model") class ControlField(BaseModel): diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 845fcfa2848..700b285a45f 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -2,7 +2,8 @@ from builtins import float from typing import List, Union -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator +from typing_extensions import Self from invokeai.app.invocations.baseinvocation import ( BaseInvocation, @@ -18,18 +19,13 @@ from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id +# LS: Consider moving these two classes into model.py class IPAdapterModelField(BaseModel): - model_name: str = Field(description="Name of the IP-Adapter model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Key to the IP-Adapter model") class CLIPVisionModelField(BaseModel): - model_name: str = Field(description="Name of the CLIP Vision image encoder model") - base_model: BaseModelType = Field(description="Base model (usually 'Any')") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Key to the CLIP Vision image encoder model") class IPAdapterField(BaseModel): @@ -46,16 +42,26 @@ class IPAdapterField(BaseModel): @field_validator("weight") @classmethod - def validate_ip_adapter_weight(cls, v): + def validate_ip_adapter_weight(cls, v: float) -> float: validate_weights(v) return v @model_validator(mode="after") - def validate_begin_end_step_percent(self): + def validate_begin_end_step_percent(self) -> Self: validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self +def 
get_ip_adapter_image_encoder_model_id(model_path: str): + """Read the ID of the image encoder associated with the IP-Adapter at `model_path`.""" + image_encoder_config_file = os.path.join(model_path, "image_encoder.txt") + + with open(image_encoder_config_file, "r") as f: + image_encoder_model = f.readline().strip() + + return image_encoder_model + + @invocation_output("ip_adapter_output") class IPAdapterOutput(BaseInvocationOutput): # Outputs @@ -84,33 +90,36 @@ class IPAdapterInvocation(BaseInvocation): @field_validator("weight") @classmethod - def validate_ip_adapter_weight(cls, v): + def validate_ip_adapter_weight(cls, v: float) -> float: validate_weights(v) return v @model_validator(mode="after") - def validate_begin_end_step_percent(self): + def validate_begin_end_step_percent(self) -> Self: validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. - ip_adapter_info = context.models.get_info( - self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter - ) + ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key) # HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model # directly, and 2) we are reading from disk every time this invocation is called without caching the result. # A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this # is currently messy due to differences between how the model info is generated when installing a model from # disk vs. downloading the model. + # TODO (LS): Fix the issue above by: + # 1. Change IPAdapterConfig definition to include a field for the repo_id of the image encoder model. + # 2. Update probe.py to read `image_encoder.txt` and store it in the config. + # 3. Change below to get the image encoder from the configuration record. 
image_encoder_model_id = get_ip_adapter_image_encoder_model_id( - os.path.join(context.config.get().models_path, ip_adapter_info["path"]) + os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info.path) ) image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() - image_encoder_model = CLIPVisionModelField( - model_name=image_encoder_model_name, - base_model=BaseModelType.Any, + image_encoder_models = context.services.model_records.search_by_attr( + model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision ) + assert len(image_encoder_models) == 1 + image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key) return IPAdapterOutput( ip_adapter=IPAdapterField( image=self.image, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 69e3f055ca8..063b23fa589 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,13 +3,13 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import List, Literal, Optional, Union +from typing import Iterator, List, Literal, Optional, Tuple, Union import einops import numpy as np import torch import torchvision.transforms as T -from diffusers import AutoencoderKL, AutoencoderTiny +from diffusers import AutoencoderKL, AutoencoderTiny, UNet2DConditionModel from diffusers.image_processor import VaeImageProcessor from diffusers.models.adapter import T2IAdapter from diffusers.models.attention_processor import ( @@ -46,14 +46,13 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image +from invokeai.backend.embeddings.model_patcher import ModelPatcher from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus -from invokeai.backend.model_management.models import ModelType, SilenceWarnings +from invokeai.backend.model_manager import AnyModel, BaseModelType +from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo +from invokeai.backend.util.silence_warnings import SilenceWarnings -from ...backend.model_management.lora import ModelPatcher -from ...backend.model_management.models import BaseModelType -from ...backend.model_management.seamless import set_seamless -from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ControlNetData, IPAdapterData, @@ -149,7 +148,10 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: ) if image is not None: - vae_info = context.models.load(**self.vae.vae.model_dump()) + vae_info = context.services.model_records.load_model( + **self.vae.vae.model_dump(), + context=context, + ) img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0) @@ -175,7 +177,10 @@ def get_scheduler( seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.models.load(**scheduler_info.model_dump()) + orig_scheduler_info = context.services.model_records.load_model( + **scheduler_info.model_dump(), + context=context, + ) with orig_scheduler_info as 
orig_scheduler: scheduler_config = orig_scheduler.config @@ -389,10 +394,9 @@ def prep_control_data( controlnet_data = [] for control_info in control_list: control_model = exit_stack.enter_context( - context.models.load( - model_name=control_info.control_model.model_name, - model_type=ModelType.ControlNet, - base_model=control_info.control_model.base_model, + context.services.model_records.load_model( + key=control_info.control_model.key, + context=context, ) ) @@ -456,17 +460,15 @@ def prep_ip_adapter_data( conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.models.load( - model_name=single_ip_adapter.ip_adapter_model.model_name, - model_type=ModelType.IPAdapter, - base_model=single_ip_adapter.ip_adapter_model.base_model, + context.services.model_records.load_model( + key=single_ip_adapter.ip_adapter_model.key, + context=context, ) ) - image_encoder_model_info = context.models.load( - model_name=single_ip_adapter.image_encoder_model.model_name, - model_type=ModelType.CLIPVision, - base_model=single_ip_adapter.image_encoder_model.base_model, + image_encoder_model_info = context.services.model_records.load_model( + key=single_ip_adapter.image_encoder_model.key, + context=context, ) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. @@ -518,10 +520,9 @@ def run_t2i_adapters( t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.models.load( - model_name=t2i_adapter_field.t2i_adapter_model.model_name, - model_type=ModelType.T2IAdapter, - base_model=t2i_adapter_field.t2i_adapter_model.base_model, + t2i_adapter_model_info = context.services.model_records.load_model( + key=t2i_adapter_field.t2i_adapter_model.key, + context=context, ) image = context.images.get_pil(t2i_adapter_field.image.image_name) @@ -556,7 +557,7 @@ def run_t2i_adapters( do_classifier_free_guidance=False, width=t2i_input_width, height=t2i_input_height, - num_channels=t2i_adapter_model.config.in_channels, + num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict device=t2i_adapter_model.device, dtype=t2i_adapter_model.dtype, resize_mode=t2i_adapter_field.resize_mode, @@ -662,22 +663,30 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: def step_callback(state: PipelineIntermediateState): context.util.sd_step_callback(state, self.unet.unet.base_model) - def _lora_loader(): + def _lora_loader() -> Iterator[Tuple[AnyModel, float]]: for lora in self.unet.loras: - lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) - yield (lora_info.context.model, lora.weight) + lora_info = context.services.model_records.load_model( + **lora.model_dump(exclude={"weight"}), + context=context, + ) + yield (lora_info.model, lora.weight) del lora_info return - unet_info = context.models.load(**self.unet.unet.model_dump()) + unet_info = context.services.model_records.load_model( + **self.unet.unet.model_dump(), + context=context, + ) + assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, - ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), - set_seamless(unet_info.context.model, self.unet.seamless_axes), + ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config), + set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME unet_info as unet, # Apply the LoRA after unet has been moved to its target device for 
faster patching. ModelPatcher.apply_lora_unet(unet, _lora_loader()), ): + assert isinstance(unet, torch.Tensor) latents = latents.to(device=unet.device, dtype=unet.dtype) if noise is not None: noise = noise.to(device=unet.device, dtype=unet.dtype) @@ -774,9 +783,13 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.tensors.load(self.latents.latents_name) - vae_info = context.models.load(**self.vae.vae.model_dump()) + vae_info = context.services.model_records.load_model( + **self.vae.vae.model_dump(), + context=context, + ) - with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae: + with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: + assert isinstance(vae, torch.Tensor) latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) @@ -995,7 +1008,10 @@ def vae_encode(vae_info, upcast, tiled, image_tensor): def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.images.get_pil(self.image.image_name) - vae_info = context.models.load(**self.vae.vae.model_dump()) + vae_info = context.services.model_records.load_model( + **self.vae.vae.model_dump(), + context=context, + ) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 6a1fd6d36bc..e2ea7442839 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -1,13 +1,13 @@ import copy from typing import List, Optional -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, Field from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.shared.models import FreeUConfig -from ...backend.model_management import BaseModelType, ModelType, SubModelType +from ...backend.model_manager import SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -17,13 +17,9 @@ class ModelInfo(BaseModel): - model_name: str = Field(description="Info to load submodel") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Info to load submodel") + key: str = Field(description="Info to load submodel") submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") - model_config = ConfigDict(protected_namespaces=()) - class LoraInfo(ModelInfo): weight: float = Field(description="Lora's weight which to use when apply to model") @@ -52,7 +48,7 @@ class VaeField(BaseModel): @invocation_output("unet_output") class UNetOutput(BaseInvocationOutput): - """Base class for invocations that output a UNet field""" + """Base class for invocations that output a UNet field.""" unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") @@ -81,20 +77,13 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput): class MainModelField(BaseModel): """Main model field""" - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Model Type") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model key") class LoRAModelField(BaseModel): """LoRA model field""" - model_name: str = Field(description="Name of the LoRA 
model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="LoRA model key") @invocation( @@ -111,74 +100,31 @@ class MainModelLoaderInvocation(BaseInvocation): # TODO: precision? def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + key = self.model.key # TODO: not found exceptions - if not context.models.exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! Check if model corrupted" - ) - """ + if not context.services.model_records.exists(key): + raise Exception(f"Unknown model {key}") return ModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.TextEncoder, ), loras=[], @@ -186,9 +132,7 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput: ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.Vae, ), ), @@ -226,21 +170,16 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") - base_model = self.lora.base_model - lora_name = self.lora.model_name + lora_key = self.lora.key - if not context.models.exists( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - ): - raise Exception(f"Unkown lora name: {lora_name}!") + if not context.services.model_records.exists(lora_key): + raise Exception(f"Unkown lora: {lora_key}!") - if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): - raise Exception(f'Lora "{lora_name}" already applied to unet') + if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): + raise Exception(f'Lora "{lora_key}" already applied to unet') - if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip') + if self.clip is not None and any(lora.key == lora_key for lora in 
self.clip.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip') output = LoraLoaderOutput() @@ -248,9 +187,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -260,9 +197,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -315,24 +250,19 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") - base_model = self.lora.base_model - lora_name = self.lora.model_name + lora_key = self.lora.key - if not context.models.exists( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - ): - raise Exception(f"Unknown lora name: {lora_name}!") + if not context.services.model_records.exists(lora_key): + raise Exception(f"Unknown lora: {lora_key}!") - if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): - raise Exception(f'Lora "{lora_name}" already applied to unet') + if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): + raise Exception(f'Lora "{lora_key}" already applied to unet') - if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip') + if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip') - if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip2') + if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip2') output = SDXLLoraLoaderOutput() @@ -340,9 +270,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -352,9 +280,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -364,9 +290,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.clip2 = copy.deepcopy(self.clip2) output.clip2.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -378,10 +302,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: class VAEModelField(BaseModel): """Vae model field""" - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model's key") @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1") @@ -395,25 
+316,12 @@ class VaeLoaderInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> VAEOutput: - base_model = self.vae_model.base_model - model_name = self.vae_model.model_name - model_type = ModelType.Vae - - if not context.models.exists( - base_model=base_model, - model_name=model_name, - model_type=model_type, - ): - raise Exception(f"Unkown vae name: {model_name}!") - return VAEOutput( - vae=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - ) - ) + key = self.vae_model.key + + if not context.services.model_records.exists(key): + raise Exception(f"Unkown vae: {key}!") + + return VAEOutput(vae=VaeField(vae=ModelInfo(key=key))) @invocation_output("seamless_output") diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 8d51674a046..633a6477fdb 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,7 +1,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager import SubModelType -from ...backend.model_management import ModelType, SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -40,45 +40,31 @@ class SDXLModelLoaderInvocation(BaseInvocation): # TODO: precision? def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + model_key = self.model.key # TODO: not found exceptions - if not context.models.exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") + if not context.services.model_records.exists(model_key): + raise Exception(f"Unknown model: {model_key}") return SDXLModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.TextEncoder, ), loras=[], @@ -86,15 +72,11 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: ), clip2=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.TextEncoder2, ), loras=[], @@ -102,9 +84,7 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Vae, ), ), @@ -129,45 +109,31 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): # TODO: precision? 
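# The loader invocations above now address every submodel through a single model record
# key plus a SubModelType, in place of the old (model_name, base_model, model_type)
# triple. A minimal sketch of that pattern; the key value is a hypothetical example:
from invokeai.app.invocations.model import ModelInfo, UNetField, VaeField
from invokeai.backend.model_manager import SubModelType

key = "f7a3c9e1"  # hypothetical model record key
unet_field = UNetField(
    unet=ModelInfo(key=key, submodel=SubModelType.UNet),
    scheduler=ModelInfo(key=key, submodel=SubModelType.Scheduler),
    loras=[],
)
vae_field = VaeField(vae=ModelInfo(key=key, submodel=SubModelType.Vae))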
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + model_key = self.model.key # TODO: not found exceptions - if not context.models.exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") + if not context.services.model_records.exists(model_key): + raise Exception(f"Unknown model: {model_key}") return SDXLRefinerModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Scheduler, ), loras=[], ), clip2=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.TextEncoder2, ), loras=[], @@ -175,9 +141,7 @@ def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Vae, ), ), diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 0f4fe66ada1..0f1e251bb36 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -1,6 +1,6 @@ from typing import Union -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.invocations.baseinvocation import ( BaseInvocation, @@ -12,14 +12,10 @@ from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_management.models.base import BaseModelType class T2IAdapterModelField(BaseModel): - model_name: str = Field(description="Name of the T2I-Adapter model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model record key for the T2I-Adapter model") class T2IAdapterField(BaseModel): diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 6b441efc2bf..90d9068b88c 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -11,8 +11,7 @@ SessionQueueStatus, ) from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_management.model_manager import LoadedModelInfo -from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import AnyModelConfig class EventServiceBase: @@ -171,10 +170,7 @@ def emit_model_load_started( queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, + model_config: AnyModelConfig, ) -> None: """Emitted when a model is 
requested""" self.__emit_queue_event( @@ -184,10 +180,7 @@ def emit_model_load_started( "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "model_name": model_name, - "base_model": base_model, - "model_type": model_type, - "submodel": submodel, + "model_config": model_config.model_dump(), }, ) @@ -197,11 +190,7 @@ def emit_model_load_completed( queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, - loaded_model_info: LoadedModelInfo, + model_config: AnyModelConfig, ) -> None: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_queue_event( @@ -211,13 +200,7 @@ def emit_model_load_completed( "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "model_name": model_name, - "base_model": base_model, - "model_type": model_type, - "submodel": submodel, - "hash": loaded_model_info.hash, - "location": str(loaded_model_info.location), - "precision": str(loaded_model_info.precision), + "model_config": model_config.model_dump(), }, ) diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index be58aaad2dd..0c63b545ff2 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -2,6 +2,7 @@ import time from contextlib import contextmanager from pathlib import Path +from typing import Iterator import psutil import torch @@ -10,7 +11,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.invoker import Invoker from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError -from invokeai.backend.model_management.model_cache import CacheStats +from invokeai.backend.model_manager.load.model_cache import CacheStats from .invocation_stats_base import InvocationStatsServiceBase from .invocation_stats_common import ( @@ -41,7 +42,10 @@ def start(self, invoker: Invoker) -> None: self._invoker = invoker @contextmanager - def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str): + def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]: + services = self._invoker.services + if services.model_records is None or services.model_records.loader is None: + yield None if not self._stats.get(graph_execution_state_id): # First time we're seeing this graph_execution_state_id. self._stats[graph_execution_state_id] = GraphExecutionStats() @@ -55,8 +59,10 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st start_ram = psutil.Process().memory_info().rss if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() - if self._invoker.services.model_manager: - self._invoker.services.model_manager.collect_cache_stats(self._cache_stats[graph_execution_state_id]) + + # TO DO [LS]: clean up loader service - shouldn't be an attribute of model records + assert services.model_records.loader is not None + services.model_records.loader.ram_cache.stats = self._cache_stats[graph_execution_state_id] try: # Let the invocation run. 
@@ -73,7 +79,7 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st ) self._stats[graph_execution_state_id].add_node_execution_stats(node_stats) - def _prune_stale_stats(self): + def _prune_stale_stats(self) -> None: """Check all graphs being tracked and prune any that have completed/errored. This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 42e3c8f83a7..e00dd4169d5 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field +from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.backend.model_manager import ( AnyModelConfig, @@ -19,6 +20,7 @@ ModelType, SubModelType, ) +from invokeai.backend.model_manager.load import AnyModelLoader from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore @@ -110,12 +112,45 @@ def get_model(self, key: str) -> AnyModelConfig: pass @abstractmethod - def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel: + def load_model( + self, + key: str, + submodel: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: """ Load the indicated model into memory and return a LoadedModel object. :param key: Key of model config to be fetched. - :param submodel_type: For main (pipeline models), the submodel to fetch + :param submodel: For main (pipeline models), the submodel to fetch + :param context: Invocation context, used for event issuing. + + Exceptions: UnknownModelException -- model with this key not known + NotImplementedException -- a model loader was not provided at initialization time + """ + pass + + @abstractmethod + def load_model_by_attr( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: + """ + Load the indicated model into memory and return a LoadedModel object. + + This is provided for API compatability with the get_model() method + in the original model manager. However, note that LoadedModel is + not the same as the original ModelInfo that ws returned. + + :param model_name: Key of model config to be fetched. + :param base_model: Base model + :param model_type: Type of the model + :param submodel: For main (pipeline models), the submodel to fetch + :param context: The invocation context. Exceptions: UnknownModelException -- model with this key not known NotImplementedException -- a model loader was not provided at initialization time @@ -166,7 +201,7 @@ def list_models( @abstractmethod def exists(self, key: str) -> bool: """ - Return True if a model with the indicated key exists in the databse. + Return True if a model with the indicated key exists in the database. 
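# With the load_model() signature above, invocations load models through the record store
# using a key, an optional submodel, and the invocation context (so load started/completed
# events can be emitted), as in the latent.py changes earlier in this patch. A minimal
# sketch of a call site; the `context` argument and key value are assumed to come from a
# running invocation:
from invokeai.backend.model_manager import SubModelType


def load_vae_for_decode(context, key: str):
    """Sketch: load a VAE submodel by record key; the returned LoadedModel is a context manager."""
    vae_info = context.services.model_records.load_model(
        key=key,
        submodel=SubModelType.Vae,
        context=context,  # lets the record store emit load started/completed events
    )
    return vae_info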
:param key: Unique key for the model to be deleted """ @@ -209,6 +244,12 @@ def search_by_attr( """ pass + @property + @abstractmethod + def loader(self) -> Optional[AnyModelLoader]: + """Return the model loader used by this instance.""" + pass + def all_models(self) -> List[AnyModelConfig]: """Return all the model configs in the database.""" return self.search_by_attr() diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index b50cd17a75d..28a77b1b1ab 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -46,6 +46,8 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple, Union +from invokeai.app.invocations.baseinvocation import InvocationContext +from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.backend.model_manager.config import ( AnyModelConfig, @@ -88,6 +90,11 @@ def db(self) -> SqliteDatabase: """Return the underlying database.""" return self._db + @property + def loader(self) -> Optional[AnyModelLoader]: + """Return the model loader used by this instance.""" + return self._loader + def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: """ Add a model to the database. @@ -213,20 +220,73 @@ def get_model(self, key: str) -> AnyModelConfig: model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) return model - def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel: + def load_model( + self, + key: str, + submodel: Optional[SubModelType], + context: Optional[InvocationContext] = None, + ) -> LoadedModel: """ Load the indicated model into memory and return a LoadedModel object. :param key: Key of model config to be fetched. - :param submodel_type: For main (pipeline models), the submodel to fetch. + :param submodel: For main (pipeline models), the submodel to fetch. + :param context: Invocation context used for event reporting Exceptions: UnknownModelException -- model with this key not known NotImplementedException -- a model loader was not provided at initialization time """ if not self._loader: raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader") + # we can emit model loading events if we are executing with access to the invocation context + model_config = self.get_model(key) - return self._loader.load_model(model_config, submodel_type) + if context: + self._emit_load_event( + context=context, + model_config=model_config, + ) + loaded_model = self._loader.load_model(model_config, submodel) + if context: + self._emit_load_event( + context=context, + model_config=model_config, + loaded=True, + ) + return loaded_model + + def load_model_by_attr( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: + """ + Load the indicated model into memory and return a LoadedModel object. + + This is provided for API compatability with the get_model() method + in the original model manager. However, note that LoadedModel is + not the same as the original ModelInfo that ws returned. + + :param model_name: Key of model config to be fetched. 
+ :param base_model: Base model + :param model_type: Type of the model + :param submodel: For main (pipeline models), the submodel to fetch + :param context: The invocation context. + + Exceptions: UnknownModelException -- model with this key not known + NotImplementedException -- a model loader was not provided at initialization time + ValueError -- more than one model matches this combination + """ + configs = self.search_by_attr(model_name, base_model, model_type) + if len(configs) == 0: + raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model") + elif len(configs) > 1: + raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.") + else: + return self.load_model(configs[0].key, submodel) def exists(self, key: str) -> bool: """ @@ -416,3 +476,29 @@ def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]: return PaginatedResults( page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items ) + + def _emit_load_event( + self, + context: InvocationContext, + model_config: AnyModelConfig, + loaded: Optional[bool] = False, + ) -> None: + if context.services.queue.is_canceled(context.graph_execution_state_id): + raise CanceledException() + + if not loaded: + context.services.events.emit_model_load_started( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_config=model_config, + ) + else: + context.services.events.emit_model_load_completed( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_config=model_config, + ) diff --git a/invokeai/backend/embeddings/__init__.py b/invokeai/backend/embeddings/__init__.py new file mode 100644 index 00000000000..46ead533c4d --- /dev/null +++ b/invokeai/backend/embeddings/__init__.py @@ -0,0 +1,4 @@ +"""Initialization file for invokeai.backend.embeddings modules.""" + +# from .model_patcher import ModelPatcher +# __all__ = ["ModelPatcher"] diff --git a/invokeai/backend/embeddings/embedding_base.py b/invokeai/backend/embeddings/embedding_base.py new file mode 100644 index 00000000000..5e752a29e14 --- /dev/null +++ b/invokeai/backend/embeddings/embedding_base.py @@ -0,0 +1,12 @@ +"""Base class for LoRA and Textual Inversion models. + +The EmbeddingRaw class is the base class of LoRAModelRaw and TextualInversionModelRaw, +and is used for type checking of calls to the model patcher. + +The use of "Raw" here is a historical artifact, and carried forward in +order to avoid confusion. 
+""" + + +class EmbeddingModelRaw: + """Base class for LoRA and Textual Inversion models.""" diff --git a/invokeai/backend/embeddings/lora.py b/invokeai/backend/embeddings/lora.py index 9a59a977087..3c7ef074efe 100644 --- a/invokeai/backend/embeddings/lora.py +++ b/invokeai/backend/embeddings/lora.py @@ -11,6 +11,8 @@ from invokeai.backend.model_manager import BaseModelType +from .embedding_base import EmbeddingModelRaw + class LoRALayerBase: # rank: Optional[int] @@ -317,7 +319,7 @@ def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ): + ) -> None: super().to(device=device, dtype=dtype) self.weight = self.weight.to(device=device, dtype=dtype) @@ -367,7 +369,7 @@ def to( # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix -class LoRAModelRaw: # (torch.nn.Module): +class LoRAModelRaw(EmbeddingModelRaw): # (torch.nn.Module): _name: str layers: Dict[str, AnyLoRALayer] @@ -471,16 +473,16 @@ def from_checkpoint( file_path = Path(file_path) model = cls( - name=file_path.stem, # TODO: + name=file_path.stem, layers={}, ) if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + sd = load_file(file_path.absolute().as_posix(), device="cpu") else: - state_dict = torch.load(file_path, map_location="cpu") + sd = torch.load(file_path, map_location="cpu") - state_dict = cls._group_state(state_dict) + state_dict = cls._group_state(sd) if base_model == BaseModelType.StableDiffusionXL: state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) diff --git a/invokeai/backend/embeddings/model_patcher.py b/invokeai/backend/embeddings/model_patcher.py index 6d73235197d..4725181b8ed 100644 --- a/invokeai/backend/embeddings/model_patcher.py +++ b/invokeai/backend/embeddings/model_patcher.py @@ -4,22 +4,20 @@ import pickle from contextlib import contextmanager -from pathlib import Path -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple import numpy as np import torch -from compel.embeddings_provider import BaseTextualInversionManager -from diffusers import ModelMixin, OnnxRuntimeModel, UNet2DConditionModel -from safetensors.torch import load_file +from diffusers import OnnxRuntimeModel, UNet2DConditionModel from transformers import CLIPTextModel, CLIPTokenizer -from typing_extensions import Self from invokeai.app.shared.models import FreeUConfig +from invokeai.backend.model_manager import AnyModel from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel from .lora import LoRAModelRaw +from .textual_inversion import TextualInversionManager, TextualInversionModelRaw """ loras = [ @@ -67,7 +65,7 @@ def apply_lora_unet( cls, unet: UNet2DConditionModel, loras: List[Tuple[LoRAModelRaw, float]], - ) -> Generator[None, None, None]: + ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield @@ -76,8 +74,8 @@ def apply_lora_unet( def apply_lora_text_encoder( cls, text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModelRaw, float]], - ): + loras: Iterator[Tuple[LoRAModelRaw, float]], + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield @@ -87,7 +85,7 @@ def apply_sdxl_lora_text_encoder( cls, text_encoder: CLIPTextModel, loras: List[Tuple[LoRAModelRaw, float]], - ): + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te1_"): yield @@ -97,7 +95,7 @@ def 
apply_sdxl_lora_text_encoder2( cls, text_encoder: CLIPTextModel, loras: List[Tuple[LoRAModelRaw, float]], - ): + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te2_"): yield @@ -105,10 +103,10 @@ def apply_sdxl_lora_text_encoder2( @contextmanager def apply_lora( cls, - model: Union[torch.nn.Module, ModelMixin, UNet2DConditionModel], - loras: List[Tuple[LoRAModelRaw, float]], + model: AnyModel, + loras: Iterator[Tuple[LoRAModelRaw, float]], prefix: str, - ) -> Generator[None, None, None]: + ) -> None: original_weights = {} try: with torch.no_grad(): @@ -125,6 +123,7 @@ def apply_lora( # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA # weights to have valid keys. + assert isinstance(model, torch.nn.Module) module_key, module = cls._resolve_lora_key(model, layer_key, prefix) # All of the LoRA weight calculations will be done on the same device as the module weight. @@ -170,8 +169,8 @@ def apply_ti( cls, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, - ti_list: List[Tuple[str, TextualInversionModel]], - ) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]: + ti_list: List[Tuple[str, TextualInversionModelRaw]], + ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]: init_tokens_count = None new_tokens_added = None @@ -201,7 +200,7 @@ def _get_trigger(ti_name: str, index: int) -> str: trigger += f"-!pad-{i}" return f"<{trigger}>" - def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModel) -> torch.Tensor: + def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModelRaw) -> torch.Tensor: # for SDXL models, select the embedding that matches the text encoder's dimensions if ti.embedding_2 is not None: return ( @@ -229,6 +228,7 @@ def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionMod model_embeddings = text_encoder.get_input_embeddings() for ti_name, ti in ti_list: + assert isinstance(ti, TextualInversionModelRaw) ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti) ti_tokens = [] @@ -267,7 +267,7 @@ def apply_clip_skip( cls, text_encoder: CLIPTextModel, clip_skip: int, - ) -> Generator[None, None, None]: + ) -> None: skipped_layers = [] try: for _i in range(clip_skip): @@ -285,7 +285,7 @@ def apply_freeu( cls, unet: UNet2DConditionModel, freeu_config: Optional[FreeUConfig] = None, - ) -> Generator[None, None, None]: + ) -> None: did_apply_freeu = False try: assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute? 
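# ModelPatcher.apply_lora_unet (and the text-encoder variants above) patch LoRA weights
# into the model inside a context manager and restore the originals on exit; callers pass
# an iterable of (LoRAModelRaw, weight) pairs, as `_lora_loader()` does in latent.py. A
# minimal usage sketch under those assumptions:
from typing import List, Tuple

from diffusers import UNet2DConditionModel

from invokeai.backend.embeddings.lora import LoRAModelRaw
from invokeai.backend.embeddings.model_patcher import ModelPatcher


def denoise_with_loras(unet: UNet2DConditionModel, loras: List[Tuple[LoRAModelRaw, float]]) -> None:
    """Apply the given LoRAs to the UNet only for the duration of the block."""
    with ModelPatcher.apply_lora_unet(unet, loras):
        # run the denoising loop here; the original UNet weights are restored on exit
        pass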
@@ -301,94 +301,6 @@ def apply_freeu( unet.disable_freeu() -class TextualInversionModel: - embedding: torch.Tensor # [n, 768]|[n, 1280] - embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models - - @classmethod - def from_checkpoint( - cls, - file_path: Union[str, Path], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> Self: - if not isinstance(file_path, Path): - file_path = Path(file_path) - - result = cls() # TODO: - - if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") - else: - state_dict = torch.load(file_path, map_location="cpu") - - # both v1 and v2 format embeddings - # difference mostly in metadata - if "string_to_param" in state_dict: - if len(state_dict["string_to_param"]) > 1: - print( - f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first', - " token will be used.", - ) - - result.embedding = next(iter(state_dict["string_to_param"].values())) - - # v3 (easynegative) - elif "emb_params" in state_dict: - result.embedding = state_dict["emb_params"] - - # v5(sdxl safetensors file) - elif "clip_g" in state_dict and "clip_l" in state_dict: - result.embedding = state_dict["clip_g"] - result.embedding_2 = state_dict["clip_l"] - - # v4(diffusers bin files) - else: - result.embedding = next(iter(state_dict.values())) - - if len(result.embedding.shape) == 1: - result.embedding = result.embedding.unsqueeze(0) - - if not isinstance(result.embedding, torch.Tensor): - raise ValueError(f"Invalid embeddings file: {file_path.name}") - - return result - - -# no type hints for BaseTextualInversionManager? -class TextualInversionManager(BaseTextualInversionManager): # type: ignore - pad_tokens: Dict[int, List[int]] - tokenizer: CLIPTokenizer - - def __init__(self, tokenizer: CLIPTokenizer): - self.pad_tokens = {} - self.tokenizer = tokenizer - - def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: - if len(self.pad_tokens) == 0: - return token_ids - - if token_ids[0] == self.tokenizer.bos_token_id: - raise ValueError("token_ids must not start with bos_token_id") - if token_ids[-1] == self.tokenizer.eos_token_id: - raise ValueError("token_ids must not end with eos_token_id") - - new_token_ids = [] - for token_id in token_ids: - new_token_ids.append(token_id) - if token_id in self.pad_tokens: - new_token_ids.extend(self.pad_tokens[token_id]) - - # Do not exceed the max model input size - # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(), - # which first removes and then adds back the start and end tokens. 
- max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2 - if len(new_token_ids) > max_length: - new_token_ids = new_token_ids[0:max_length] - - return new_token_ids - - class ONNXModelPatcher: @classmethod @contextmanager @@ -396,7 +308,7 @@ def apply_lora_unet( cls, unet: OnnxRuntimeModel, loras: List[Tuple[LoRAModelRaw, float]], - ) -> Generator[None, None, None]: + ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield @@ -406,7 +318,7 @@ def apply_lora_text_encoder( cls, text_encoder: OnnxRuntimeModel, loras: List[Tuple[LoRAModelRaw, float]], - ) -> Generator[None, None, None]: + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield @@ -419,7 +331,7 @@ def apply_lora( model: IAIOnnxRuntimeModel, loras: List[Tuple[LoRAModelRaw, float]], prefix: str, - ) -> Generator[None, None, None]: + ) -> None: from .models.base import IAIOnnxRuntimeModel if not isinstance(model, IAIOnnxRuntimeModel): @@ -506,7 +418,7 @@ def apply_ti( tokenizer: CLIPTokenizer, text_encoder: IAIOnnxRuntimeModel, ti_list: List[Tuple[str, Any]], - ) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]: + ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]: from .models.base import IAIOnnxRuntimeModel if not isinstance(text_encoder, IAIOnnxRuntimeModel): diff --git a/invokeai/backend/embeddings/textual_inversion.py b/invokeai/backend/embeddings/textual_inversion.py new file mode 100644 index 00000000000..389edff039d --- /dev/null +++ b/invokeai/backend/embeddings/textual_inversion.py @@ -0,0 +1,100 @@ +"""Textual Inversion wrapper class.""" + +from pathlib import Path +from typing import Dict, List, Optional, Union + +import torch +from compel.embeddings_provider import BaseTextualInversionManager +from safetensors.torch import load_file +from transformers import CLIPTokenizer +from typing_extensions import Self + +from .embedding_base import EmbeddingModelRaw + + +class TextualInversionModelRaw(EmbeddingModelRaw): + embedding: torch.Tensor # [n, 768]|[n, 1280] + embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models + + @classmethod + def from_checkpoint( + cls, + file_path: Union[str, Path], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Self: + if not isinstance(file_path, Path): + file_path = Path(file_path) + + result = cls() # TODO: + + if file_path.suffix == ".safetensors": + state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(file_path, map_location="cpu") + + # both v1 and v2 format embeddings + # difference mostly in metadata + if "string_to_param" in state_dict: + if len(state_dict["string_to_param"]) > 1: + print( + f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. 
The first', + " token will be used.", + ) + + result.embedding = next(iter(state_dict["string_to_param"].values())) + + # v3 (easynegative) + elif "emb_params" in state_dict: + result.embedding = state_dict["emb_params"] + + # v5(sdxl safetensors file) + elif "clip_g" in state_dict and "clip_l" in state_dict: + result.embedding = state_dict["clip_g"] + result.embedding_2 = state_dict["clip_l"] + + # v4(diffusers bin files) + else: + result.embedding = next(iter(state_dict.values())) + + if len(result.embedding.shape) == 1: + result.embedding = result.embedding.unsqueeze(0) + + if not isinstance(result.embedding, torch.Tensor): + raise ValueError(f"Invalid embeddings file: {file_path.name}") + + return result + + +# no type hints for BaseTextualInversionManager? +class TextualInversionManager(BaseTextualInversionManager): # type: ignore + pad_tokens: Dict[int, List[int]] + tokenizer: CLIPTokenizer + + def __init__(self, tokenizer: CLIPTokenizer): + self.pad_tokens = {} + self.tokenizer = tokenizer + + def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: + if len(self.pad_tokens) == 0: + return token_ids + + if token_ids[0] == self.tokenizer.bos_token_id: + raise ValueError("token_ids must not start with bos_token_id") + if token_ids[-1] == self.tokenizer.eos_token_id: + raise ValueError("token_ids must not end with eos_token_id") + + new_token_ids = [] + for token_id in token_ids: + new_token_ids.append(token_id) + if token_id in self.pad_tokens: + new_token_ids.extend(self.pad_tokens[token_id]) + + # Do not exceed the max model input size + # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(), + # which first removes and then adds back the start and end tokens. + max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2 + if len(new_token_ids) > max_length: + new_token_ids = new_token_ids[0:max_length] + + return new_token_ids diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 57dfadcaeaa..8877e33092c 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -241,10 +241,11 @@ def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource: if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url): repo_id = match.group(1) repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None + subfolder = Path(model_info.subfolder) if model_info.subfolder else None return HFModelSource( repo_id=repo_id, access_token=HfFolder.get_token(), - subfolder=model_info.subfolder, + subfolder=subfolder, variant=repo_variant, ) if re.match(r"^(http|https):", model_path_id_or_url): diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 49ce6af2b81..0dcd925c84b 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -30,8 +30,11 @@ from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel +from ..embeddings.embedding_base import EmbeddingModelRaw from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus +AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus, EmbeddingModelRaw] + class InvalidModelConfigException(Exception): """Exception for when config parser doesn't recognized this combination of model type and format.""" @@ -299,7 +302,7 @@ class T2IConfig(ModelConfigBase): ] AnyModelConfigValidator = 
TypeAdapter(AnyModelConfig) -AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus] + # IMPLEMENTATION NOTE: # The preferred alternative to the above is a discriminated Union as shown diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index ee9d6d53e3d..9d98ee30531 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -18,8 +18,8 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig +from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType +from invokeai.backend.model_manager.config import AnyModel, VaeCheckpointConfig, VaeDiffusersConfig from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.util.logging import InvokeAILogger diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 757745072d1..2192c88ac2f 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -19,7 +19,7 @@ ) from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase -from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats, ModelCacheBase, ModelLockerBase from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.util.devices import choose_torch_device, torch_dtype @@ -71,7 +71,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type) if not model_path.exists(): - raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}") + raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}") model_path = self._convert_if_needed(model_config, model_path, submodel_type) locker = self._load_if_needed(model_config, model_path, submodel_type) diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py index 0cb5184f3a4..32c682d0424 100644 --- a/invokeai/backend/model_manager/load/model_cache/__init__.py +++ b/invokeai/backend/model_manager/load/model_cache/__init__.py @@ -1,4 +1,6 @@ """Init file for ModelCache.""" +from .model_cache_base import ModelCacheBase, CacheStats # noqa F401 +from .model_cache_default import ModelCache # noqa F401 -_all__ = ["ModelCacheBase", "ModelCache"] +_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"] diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py 
b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index b1a6768ee8f..4a4a3c7d299 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -8,13 +8,13 @@ """ from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from logging import Logger -from typing import Generic, Optional, TypeVar +from typing import Dict, Generic, Optional, TypeVar import torch -from invokeai.backend.model_manager import AnyModel, SubModelType +from invokeai.backend.model_manager.config import AnyModel, SubModelType class ModelLockerBase(ABC): @@ -65,6 +65,19 @@ def locked(self) -> bool: return self._locks > 0 +@dataclass +class CacheStats(object): + """Collect statistics on cache performance.""" + + hits: int = 0 # cache hits + misses: int = 0 # cache misses + high_watermark: int = 0 # amount of cache used + in_cache: int = 0 # number of models in cache + cleared: int = 0 # number of models cleared to make space + cache_size: int = 0 # total size of cache + loaded_model_sizes: Dict[str, int] = field(default_factory=dict) + + class ModelCacheBase(ABC, Generic[T]): """Virtual base class for RAM model cache.""" @@ -98,10 +111,22 @@ def offload_unlocked_models(self, size_required: int) -> None: pass @abstractmethod - def move_model_to_device(self, cache_entry: CacheRecord, device: torch.device) -> None: + def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], device: torch.device) -> None: """Move model into the indicated device.""" pass + @property + @abstractmethod + def stats(self) -> CacheStats: + """Return collected CacheStats object.""" + pass + + @stats.setter + @abstractmethod + def stats(self, stats: CacheStats) -> None: + """Set the CacheStats object for collectin cache statistics.""" + pass + @property @abstractmethod def logger(self) -> Logger: diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 7e30512a588..b1deb215b2b 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -24,19 +24,17 @@ import sys import time from contextlib import suppress -from dataclasses import dataclass, field from logging import Logger from typing import Dict, List, Optional import torch -from invokeai.backend.model_manager import SubModelType -from invokeai.backend.model_manager.load.load_base import AnyModel +from invokeai.backend.model_manager import AnyModel, SubModelType from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger -from .model_cache_base import CacheRecord, ModelCacheBase +from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase from .model_locker import ModelLocker, ModelLockerBase if choose_torch_device() == torch.device("mps"): @@ -56,20 +54,6 @@ MB = 2**20 -@dataclass -class CacheStats(object): - """Collect statistics on cache performance.""" - - hits: int = 0 # cache hits - misses: int = 0 # cache misses - high_watermark: int = 0 # amount of cache used - in_cache: int = 0 # number of models in cache - cleared: int = 0 # number of models cleared to make space - cache_size: int = 0 # total size of cache - # {submodel_key => size} - 
loaded_model_sizes: Dict[str, int] = field(default_factory=dict) - - class ModelCache(ModelCacheBase[AnyModel]): """Implementation of ModelCacheBase.""" @@ -110,7 +94,7 @@ def __init__( self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG # used for stats collection - self.stats = CacheStats() + self._stats: Optional[CacheStats] = None self._cached_models: Dict[str, CacheRecord[AnyModel]] = {} self._cache_stack: List[str] = [] @@ -140,6 +124,16 @@ def max_cache_size(self) -> float: """Return the cap on cache size.""" return self._max_cache_size + @property + def stats(self) -> Optional[CacheStats]: + """Return collected CacheStats object.""" + return self._stats + + @stats.setter + def stats(self, stats: CacheStats) -> None: + """Set the CacheStats object for collectin cache statistics.""" + self._stats = stats + def cache_size(self) -> int: """Get the total size of the models currently cached.""" total = 0 @@ -189,21 +183,24 @@ def get( """ key = self._make_cache_key(key, submodel_type) if key in self._cached_models: - self.stats.hits += 1 + if self.stats: + self.stats.hits += 1 else: - self.stats.misses += 1 + if self.stats: + self.stats.misses += 1 raise IndexError(f"The model with key {key} is not in the cache.") cache_entry = self._cached_models[key] # more stats - stats_name = stats_name or key - self.stats.cache_size = int(self._max_cache_size * GIG) - self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size()) - self.stats.in_cache = len(self._cached_models) - self.stats.loaded_model_sizes[stats_name] = max( - self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size - ) + if self.stats: + stats_name = stats_name or key + self.stats.cache_size = int(self._max_cache_size * GIG) + self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size()) + self.stats.in_cache = len(self._cached_models) + self.stats.loaded_model_sizes[stats_name] = max( + self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size + ) # this moves the entry to the top (right end) of the stack with suppress(Exception): diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py index 394fddc75d0..6635f6b43fe 100644 --- a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py +++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Optional, Tuple -from invokeai.backend.embeddings.model_patcher import TextualInversionModel as TextualInversionModelRaw +from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw from invokeai.backend.model_manager import ( AnyModel, AnyModelConfig, diff --git a/invokeai/backend/stable_diffusion/__init__.py b/invokeai/backend/stable_diffusion/__init__.py index 212045f81b8..75e6aa0a5de 100644 --- a/invokeai/backend/stable_diffusion/__init__.py +++ b/invokeai/backend/stable_diffusion/__init__.py @@ -4,3 +4,12 @@ from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401 from .diffusion import InvokeAIDiffuserComponent # noqa: F401 from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401 +from .seamless import set_seamless # noqa: F401 + +__all__ = [ + "PipelineIntermediateState", + "StableDiffusionGeneratorPipeline", + "InvokeAIDiffuserComponent", + 
"AttentionMapSaver", + "set_seamless", +] diff --git a/invokeai/backend/stable_diffusion/seamless.py b/invokeai/backend/stable_diffusion/seamless.py new file mode 100644 index 00000000000..bfdf9e0c536 --- /dev/null +++ b/invokeai/backend/stable_diffusion/seamless.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import List, Union + +import torch.nn as nn +from diffusers.models import AutoencoderKL, UNet2DConditionModel + + +def _conv_forward_asymmetric(self, input, weight, bias): + """ + Patch for Conv2d._conv_forward that supports asymmetric padding + """ + working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]) + working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]) + return nn.functional.conv2d( + working, + weight, + bias, + self.stride, + nn.modules.utils._pair(0), + self.dilation, + self.groups, + ) + + +@contextmanager +def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): + try: + to_restore = [] + + for m_name, m in model.named_modules(): + if isinstance(model, UNet2DConditionModel): + if ".attentions." in m_name: + continue + + if ".resnets." in m_name: + if ".conv2" in m_name: + continue + if ".conv_shortcut" in m_name: + continue + + """ + if isinstance(model, UNet2DConditionModel): + if False and ".upsamplers." in m_name: + continue + + if False and ".downsamplers." in m_name: + continue + + if True and ".resnets." in m_name: + if True and ".conv1" in m_name: + if False and "down_blocks" in m_name: + continue + if False and "mid_block" in m_name: + continue + if False and "up_blocks" in m_name: + continue + + if True and ".conv2" in m_name: + continue + + if True and ".conv_shortcut" in m_name: + continue + + if True and ".attentions." in m_name: + continue + + if False and m_name in ["conv_in", "conv_out"]: + continue + """ + + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + m.asymmetric_padding_mode = {} + m.asymmetric_padding = {} + m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" + m.asymmetric_padding["x"] = ( + m._reversed_padding_repeated_twice[0], + m._reversed_padding_repeated_twice[1], + 0, + 0, + ) + m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" + m.asymmetric_padding["y"] = ( + 0, + 0, + m._reversed_padding_repeated_twice[2], + m._reversed_padding_repeated_twice[3], + ) + + to_restore.append((m, m._conv_forward)) + m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) + + yield + + finally: + for module, orig_conv_forward in to_restore: + module._conv_forward = orig_conv_forward + if hasattr(module, "asymmetric_padding_mode"): + del module.asymmetric_padding_mode + if hasattr(module, "asymmetric_padding"): + del module.asymmetric_padding diff --git a/invokeai/backend/util/silence_warnings.py b/invokeai/backend/util/silence_warnings.py new file mode 100644 index 00000000000..068b605da97 --- /dev/null +++ b/invokeai/backend/util/silence_warnings.py @@ -0,0 +1,28 @@ +"""Context class to silence transformers and diffusers warnings.""" +import warnings +from typing import Any + +from diffusers import logging as diffusers_logging +from transformers import logging as transformers_logging + + +class SilenceWarnings(object): + """Use in context to temporarily turn off warnings from transformers & diffusers modules. 
+ + with SilenceWarnings(): + # do something + """ + + def __init__(self) -> None: + self.transformers_verbosity = transformers_logging.get_verbosity() + self.diffusers_verbosity = diffusers_logging.get_verbosity() + + def __enter__(self) -> None: + transformers_logging.set_verbosity_error() + diffusers_logging.set_verbosity_error() + warnings.simplefilter("ignore") + + def __exit__(self, *args: Any) -> None: + transformers_logging.set_verbosity(self.transformers_verbosity) + diffusers_logging.set_verbosity(self.diffusers_verbosity) + warnings.simplefilter("default") diff --git a/invokeai/frontend/install/model_install2.py b/invokeai/frontend/install/model_install2.py index 51a633a5654..22b132370e6 100644 --- a/invokeai/frontend/install/model_install2.py +++ b/invokeai/frontend/install/model_install2.py @@ -23,7 +23,7 @@ from npyscreen import widget from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_install import ModelInstallService +from invokeai.app.services.model_install import ModelInstallServiceBase from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo from invokeai.backend.model_manager import ModelType from invokeai.backend.util import choose_precision, choose_torch_device @@ -499,7 +499,7 @@ def onStart(self) -> None: ) -def list_models(installer: ModelInstallService, model_type: ModelType): +def list_models(installer: ModelInstallServiceBase, model_type: ModelType): """Print out all models of type model_type.""" models = installer.record_store.search_by_attr(model_type=model_type) print(f"Installed models of type `{model_type}`:") @@ -527,7 +527,9 @@ def select_and_download_models(opt: Namespace) -> None: install_helper.add_or_delete(selections) elif opt.default_only: - selections = InstallSelections(install_models=[install_helper.default_model()]) + default_model = install_helper.default_model() + assert default_model is not None + selections = InstallSelections(install_models=[default_model]) install_helper.add_or_delete(selections) elif opt.yes_to_all: diff --git a/tests/backend/model_manager_2/util/test_hf_model_select.py b/tests/backend/model_manager_2/util/test_hf_model_select.py index f14d9a6823a..5bef9cb2e19 100644 --- a/tests/backend/model_manager_2/util/test_hf_model_select.py +++ b/tests/backend/model_manager_2/util/test_hf_model_select.py @@ -192,6 +192,7 @@ def sdxl_base_files() -> List[Path]: "text_encoder/model.onnx", "text_encoder_2/config.json", "text_encoder_2/model.onnx", + "text_encoder_2/model.onnx_data", "tokenizer/merges.txt", "tokenizer/special_tokens_map.json", "tokenizer/tokenizer_config.json", @@ -202,6 +203,7 @@ def sdxl_base_files() -> List[Path]: "tokenizer_2/vocab.json", "unet/config.json", "unet/model.onnx", + "unet/model.onnx_data", "vae_decoder/config.json", "vae_decoder/model.onnx", "vae_encoder/config.json", diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index aacae06a8bb..be823e2be9f 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -2,7 +2,7 @@ import pytest -from invokeai.backend import BaseModelType +from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant from invokeai.backend.model_manager.probe import VaeFolderProbe @@ -21,10 +21,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat base_type = probe.get_base_type() assert base_type == expected_type repo_variant = probe.get_repo_variant() - assert repo_variant == "default" + assert repo_variant == 
ModelRepoVariant.DEFAULT def test_repo_variant(datadir: Path): probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16") repo_variant = probe.get_repo_variant() - assert repo_variant == "fp16" + assert repo_variant == ModelRepoVariant.FP16 From 0f0e8cea1c0f4483403e38cd7d0d43689dde8156 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 9 Feb 2024 16:42:33 -0500 Subject: [PATCH 079/340] fix invokeai_configure script to work with new mm; rename CLIs --- .../app/services/config/config_default.py | 10 +- invokeai/backend/install/install_helper.py | 2 +- .../backend/install/invokeai_configure.py | 197 ++++---- .../model_manager/load/load_default.py | 2 +- invokeai/backend/util/devices.py | 6 +- invokeai/configs/INITIAL_MODELS.yaml | 106 +++-- ...L_MODELS2.yaml => INITIAL_MODELS.yaml.OLD} | 106 ++--- invokeai/frontend/install/model_install.py | 448 +++++------------- ...model_install2.py => model_install.py.OLD} | 448 +++++++++++++----- invokeai/frontend/install/widgets.py | 11 + ...e_diffusers2.py => merge_diffusers.py.OLD} | 0 pyproject.toml | 3 +- tests/test_model_manager.py | 47 -- 13 files changed, 690 insertions(+), 696 deletions(-) rename invokeai/configs/{INITIAL_MODELS2.yaml => INITIAL_MODELS.yaml.OLD} (59%) rename invokeai/frontend/install/{model_install2.py => model_install.py.OLD} (57%) rename invokeai/frontend/merge/{merge_diffusers2.py => merge_diffusers.py.OLD} (100%) delete mode 100644 tests/test_model_manager.py diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index b39e916da34..2af775372dd 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -185,7 +185,9 @@ class InvokeBatch(InvokeAISettings): INIT_FILE = Path("invokeai.yaml") DB_FILE = Path("invokeai.db") LEGACY_INIT_FILE = Path("invokeai.init") -DEFAULT_MAX_VRAM = 0.5 +DEFAULT_RAM_CACHE = 10.0 +DEFAULT_VRAM_CACHE = 0.25 +DEFAULT_CONVERT_CACHE = 20.0 class Categories(object): @@ -261,9 +263,9 @@ class InvokeAIAppConfig(InvokeAISettings): version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other) # CACHE - ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) - vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) - convert_cache : float = Field(default=10.0, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache) + ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) + vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, ) + convert_cache : float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache) lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, ) log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model 
cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache) diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 8877e33092c..9c386c209ce 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -37,7 +37,7 @@ from invokeai.backend.util.logging import InvokeAILogger # name of the starter models file -INITIAL_MODELS = "INITIAL_MODELS2.yaml" +INITIAL_MODELS = "INITIAL_MODELS.yaml" def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index 3cb7db6c82c..4dfa2b070c0 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -18,31 +18,30 @@ from enum import Enum from pathlib import Path from shutil import get_terminal_size -from typing import Any, get_args, get_type_hints +from typing import Any, Optional, Set, Tuple, Type, get_args, get_type_hints from urllib import request import npyscreen -import omegaconf import psutil import torch import transformers -import yaml -from diffusers import AutoencoderKL +from diffusers import AutoencoderKL, ModelMixin from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from huggingface_hub import HfFolder from huggingface_hub import login as hf_hub_login -from omegaconf import OmegaConf -from pydantic import ValidationError +from omegaconf import DictConfig, OmegaConf +from pydantic.error_wrappers import ValidationError from tqdm import tqdm from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer import invokeai.configs as configs from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.install.install_helper import InstallHelper, InstallSelections from invokeai.backend.install.legacy_arg_parsing import legacy_parser -from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, hf_download_from_pretrained -from invokeai.backend.model_management.model_probe import BaseModelType, ModelType +from invokeai.backend.model_manager import BaseModelType, ModelType +from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger -from invokeai.frontend.install.model_install import addModelsForm, process_and_execute +from invokeai.frontend.install.model_install import addModelsForm # TO DO - Move all the frontend code into invokeai.frontend.install from invokeai.frontend.install.widgets import ( @@ -61,7 +60,7 @@ transformers.logging.set_verbosity_error() -def get_literal_fields(field) -> list[Any]: +def get_literal_fields(field: str) -> Tuple[Any]: return get_args(get_type_hints(InvokeAIAppConfig).get(field)) @@ -80,8 +79,7 @@ def get_literal_fields(field) -> list[Any]: GENERATION_OPT_CHOICES = ["sequential_guidance", "force_tiled_decode", "lazy_offload"] GB = 1073741824 # GB in bytes HAS_CUDA = torch.cuda.is_available() -_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0) - +_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0.0, 0.0) MAX_VRAM /= GB MAX_RAM = psutil.virtual_memory().total / GB @@ -96,13 +94,15 @@ def 
get_literal_fields(field) -> list[Any]: class DummyWidgetValue(Enum): + """Dummy widget values.""" + zero = 0 true = True false = False # -------------------------------------------- -def postscript(errors: None): +def postscript(errors: Set[str]) -> None: if not any(errors): message = f""" ** INVOKEAI INSTALLATION SUCCESSFUL ** @@ -143,7 +143,7 @@ def yes_or_no(prompt: str, default_yes=True): # --------------------------------------------- -def HfLogin(access_token) -> str: +def HfLogin(access_token) -> None: """ Helper for logging in to Huggingface The stdout capture is needed to hide the irrelevant "git credential helper" warning @@ -162,7 +162,7 @@ def HfLogin(access_token) -> str: # ------------------------------------- class ProgressBar: - def __init__(self, model_name="file"): + def __init__(self, model_name: str = "file"): self.pbar = None self.name = model_name @@ -179,6 +179,22 @@ def __call__(self, block_num, block_size, total_size): self.pbar.update(block_size) +# --------------------------------------------- +def hf_download_from_pretrained(model_class: Type[ModelMixin], model_name: str, destination: Path, **kwargs: Any): + filter = lambda x: "fp16 is not a valid" not in x.getMessage() # noqa E731 + logger.addFilter(filter) + try: + model = model_class.from_pretrained( + model_name, + resume_download=True, + **kwargs, + ) + model.save_pretrained(destination, safe_serialization=True) + finally: + logger.removeFilter(filter) + return destination + + # --------------------------------------------- def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"): try: @@ -249,6 +265,7 @@ def download_conversion_models(): # --------------------------------------------- +# TO DO: use the download queue here. def download_realesrgan(): logger.info("Installing ESRGAN Upscaling models...") URLs = [ @@ -288,18 +305,19 @@ def download_lama(): # --------------------------------------------- -def download_support_models(): +def download_support_models() -> None: download_realesrgan() download_lama() download_conversion_models() # ------------------------------------- -def get_root(root: str = None) -> str: +def get_root(root: Optional[str] = None) -> str: if root: return root - elif os.environ.get("INVOKEAI_ROOT"): - return os.environ.get("INVOKEAI_ROOT") + elif root := os.environ.get("INVOKEAI_ROOT"): + assert root is not None + return root else: return str(config.root_path) @@ -455,6 +473,25 @@ def create(self): max_width=110, scroll_exit=True, ) + self.add_widget_intelligent( + npyscreen.TitleFixedText, + name="Model disk conversion cache size (GB). This is used to cache safetensors files that need to be converted to diffusers..", + begin_entry_at=0, + editable=False, + color="CONTROL", + scroll_exit=True, + ) + self.nextrely -= 1 + self.disk = self.add_widget_intelligent( + npyscreen.Slider, + value=clip(old_opts.convert_cache, range=(0, 100), step=0.5), + out_of=100, + lowest=0.0, + step=0.5, + relx=8, + scroll_exit=True, + ) + self.nextrely += 1 self.add_widget_intelligent( npyscreen.TitleFixedText, name="Model RAM cache size (GB). 
Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).", @@ -495,6 +532,14 @@ def create(self): ) else: self.vram = DummyWidgetValue.zero + + self.nextrely += 1 + self.add_widget_intelligent( + npyscreen.FixedText, + value="Location of the database used to store model path and configuration information:", + editable=False, + color="CONTROL", + ) self.nextrely += 1 self.outdir = self.add_widget_intelligent( FileBox, @@ -506,19 +551,21 @@ def create(self): labelColor="GOOD", begin_entry_at=40, max_height=3, + max_width=127, scroll_exit=True, ) self.autoimport_dirs = {} self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent( FileBox, - name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models", - value=str(config.root_path / config.autoimport_dir), + name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models", + value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else "", select_dir=True, must_exist=False, use_two_lines=False, labelColor="GOOD", begin_entry_at=32, max_height=3, + max_width=127, scroll_exit=True, ) self.nextrely += 1 @@ -555,6 +602,10 @@ def show_hide_slice_sizes(self, value): self.attention_slice_label.hidden = not show self.attention_slice_size.hidden = not show + def show_hide_model_conf_override(self, value): + self.model_conf_override.hidden = value + self.model_conf_override.display() + def on_ok(self): options = self.marshall_arguments() if self.validate_field_values(options): @@ -584,18 +635,21 @@ def validate_field_values(self, opt: Namespace) -> bool: else: return True - def marshall_arguments(self): + def marshall_arguments(self) -> Namespace: new_opts = Namespace() for attr in [ "ram", "vram", + "convert_cache", "outdir", ]: if hasattr(self, attr): setattr(new_opts, attr, getattr(self, attr).value) for attr in self.autoimport_dirs: + if not self.autoimport_dirs[attr].value: + continue directory = Path(self.autoimport_dirs[attr].value) if directory.is_relative_to(config.root_path): directory = directory.relative_to(config.root_path) @@ -615,13 +669,14 @@ def marshall_arguments(self): class EditOptApplication(npyscreen.NPSAppManaged): - def __init__(self, program_opts: Namespace, invokeai_opts: Namespace): + def __init__(self, program_opts: Namespace, invokeai_opts: InvokeAIAppConfig, install_helper: InstallHelper): super().__init__() self.program_opts = program_opts self.invokeai_opts = invokeai_opts self.user_cancelled = False self.autoload_pending = True - self.install_selections = default_user_selections(program_opts) + self.install_helper = install_helper + self.install_selections = default_user_selections(program_opts, install_helper) def onStart(self): npyscreen.setTheme(npyscreen.Themes.DefaultTheme) @@ -640,16 +695,10 @@ def onStart(self): cycle_widgets=False, ) - def new_opts(self): + def new_opts(self) -> Namespace: return self.options.marshall_arguments() -def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace: - editApp = EditOptApplication(program_opts, invokeai_opts) - editApp.run() - return editApp.new_opts() - - def default_ramcache() -> float: """Run a heuristic for the default RAM cache based on installed RAM.""" @@ -660,27 +709,18 @@ def default_ramcache() -> float: ) # 2.1 is just large enough for sd 1.5 ;-) -def default_startup_options(init_file: Path) -> Namespace: +def default_startup_options(init_file: Path) -> InvokeAIAppConfig: opts = InvokeAIAppConfig.get_config() - opts.ram = 
opts.ram or default_ramcache() + opts.ram = default_ramcache() return opts -def default_user_selections(program_opts: Namespace) -> InstallSelections: - try: - installer = ModelInstall(config) - except omegaconf.errors.ConfigKeyError: - logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing") - initialize_rootdir(config.root_path, True) - installer = ModelInstall(config) - - models = installer.all_models() +def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections: + default_model = install_helper.default_model() + assert default_model is not None + default_models = [default_model] if program_opts.default_only else install_helper.recommended_models() return InstallSelections( - install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id] - if program_opts.default_only - else [models[x].path or models[x].repo_id for x in installer.recommended_models()] - if program_opts.yes_to_all - else [], + install_models=default_models if program_opts.yes_to_all else [], ) @@ -716,21 +756,10 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False): path.mkdir(parents=True, exist_ok=True) -def maybe_create_models_yaml(root: Path): - models_yaml = root / "configs" / "models.yaml" - if models_yaml.exists(): - if OmegaConf.load(models_yaml).get("__metadata__"): # up to date - return - else: - logger.info("Creating new models.yaml, original saved as models.yaml.orig") - models_yaml.rename(models_yaml.parent / "models.yaml.orig") - - with open(models_yaml, "w") as yaml_file: - yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) - - # ------------------------------------- -def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace): +def run_console_ui( + program_opts: Namespace, initfile: Path, install_helper: InstallHelper +) -> Tuple[Optional[Namespace], Optional[InstallSelections]]: invokeai_opts = default_startup_options(initfile) invokeai_opts.root = program_opts.root @@ -739,22 +768,16 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace "Could not increase terminal size. Try running again with a larger window or smaller font size." ) - # the install-models application spawns a subprocess to install - # models, and will crash unless this is set before running. - import torch - - torch.multiprocessing.set_start_method("spawn") - - editApp = EditOptApplication(program_opts, invokeai_opts) + editApp = EditOptApplication(program_opts, invokeai_opts, install_helper) editApp.run() if editApp.user_cancelled: return (None, None) else: - return (editApp.new_opts, editApp.install_selections) + return (editApp.new_opts(), editApp.install_selections) # ------------------------------------- -def write_opts(opts: Namespace, init_file: Path): +def write_opts(opts: InvokeAIAppConfig, init_file: Path) -> None: """ Update the invokeai.yaml file with values from current settings. 
""" @@ -762,7 +785,7 @@ def write_opts(opts: Namespace, init_file: Path): new_config = InvokeAIAppConfig.get_config() new_config.root = config.root - for key, value in opts.__dict__.items(): + for key, value in opts.model_dump().items(): if hasattr(new_config, key): setattr(new_config, key, value) @@ -779,7 +802,7 @@ def default_output_dir() -> Path: # ------------------------------------- -def write_default_options(program_opts: Namespace, initfile: Path): +def write_default_options(program_opts: Namespace, initfile: Path) -> None: opt = default_startup_options(initfile) write_opts(opt, initfile) @@ -789,16 +812,11 @@ def write_default_options(program_opts: Namespace, initfile: Path): # the legacy Args object in order to parse # the old init file and write out the new # yaml format. -def migrate_init_file(legacy_format: Path): +def migrate_init_file(legacy_format: Path) -> None: old = legacy_parser.parse_args([f"@{str(legacy_format)}"]) new = InvokeAIAppConfig.get_config() - fields = [ - x - for x, y in InvokeAIAppConfig.model_fields.items() - if (y.json_schema_extra.get("category", None) if y.json_schema_extra else None) != "DEPRECATED" - ] - for attr in fields: + for attr in InvokeAIAppConfig.model_fields.keys(): if hasattr(old, attr): try: setattr(new, attr, getattr(old, attr)) @@ -819,7 +837,7 @@ def migrate_init_file(legacy_format: Path): # ------------------------------------- -def migrate_models(root: Path): +def migrate_models(root: Path) -> None: from invokeai.backend.install.migrate_to_3 import do_migrate do_migrate(root, root) @@ -838,7 +856,9 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool: ): logger.info("** Migrating invokeai.init to invokeai.yaml") migrate_init_file(old_init_file) - config.parse_args(argv=[], conf=OmegaConf.load(new_init_file)) + omegaconf = OmegaConf.load(new_init_file) + assert isinstance(omegaconf, DictConfig) + config.parse_args(argv=[], conf=omegaconf) if old_hub.exists(): migrate_models(config.root_path) @@ -849,7 +869,7 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool: # ------------------------------------- -def main() -> None: +def main(): parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( "--skip-sd-weights", @@ -908,6 +928,7 @@ def main() -> None: if opt.full_precision: invoke_args.extend(["--precision", "float32"]) config.parse_args(invoke_args) + config.precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) logger = InvokeAILogger().get_logger(config=config) errors = set() @@ -921,14 +942,18 @@ def main() -> None: # run this unconditionally in case new directories need to be added initialize_rootdir(config.root_path, opt.yes_to_all) - models_to_download = default_user_selections(opt) + # this will initialize the models.yaml file if not present + install_helper = InstallHelper(config, logger) + + models_to_download = default_user_selections(opt, install_helper) new_init_file = config.root_path / "invokeai.yaml" if opt.yes_to_all: write_default_options(opt, new_init_file) init_options = Namespace(precision="float32" if opt.full_precision else "float16") + else: - init_options, models_to_download = run_console_ui(opt, new_init_file) + init_options, models_to_download = run_console_ui(opt, new_init_file, install_helper) if init_options: write_opts(init_options, new_init_file) else: @@ -943,10 +968,12 @@ def main() -> None: if opt.skip_sd_weights: logger.warning("Skipping diffusion weights download per user request") + elif 
models_to_download: - process_and_execute(opt, models_to_download) + install_helper.add_or_delete(models_to_download) postscript(errors=errors) + if not opt.yes_to_all: input("Press any key to continue...") except WindowTooSmallException as e: diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 2192c88ac2f..c1dfe729af7 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -19,7 +19,7 @@ ) from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase -from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats, ModelCacheBase, ModelLockerBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.util.devices import choose_torch_device, torch_dtype diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index b4f24d8483b..a83d1045f70 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import nullcontext -from typing import Optional, Union +from typing import Literal, Optional, Union import torch from torch import autocast @@ -31,7 +31,9 @@ def choose_torch_device() -> torch.device: # We are in transition here from using a single global AppConfig to allowing multiple # configurations. It is strongly recommended to pass the app_config to this function. -def choose_precision(device: torch.device, app_config: Optional[InvokeAIAppConfig] = None) -> str: +def choose_precision( + device: torch.device, app_config: Optional[InvokeAIAppConfig] = None +) -> Literal["float32", "float16", "bfloat16"]: """Return an appropriate precision for the given torch device.""" app_config = app_config or config if device.type == "cuda": diff --git a/invokeai/configs/INITIAL_MODELS.yaml b/invokeai/configs/INITIAL_MODELS.yaml index c230665e3a6..ca2283ab811 100644 --- a/invokeai/configs/INITIAL_MODELS.yaml +++ b/invokeai/configs/INITIAL_MODELS.yaml @@ -1,153 +1,157 @@ # This file predefines a few models that the user may want to install. 
sd-1/main/stable-diffusion-v1-5: description: Stable Diffusion version 1.5 diffusers model (4.27 GB) - repo_id: runwayml/stable-diffusion-v1-5 + source: runwayml/stable-diffusion-v1-5 recommended: True default: True sd-1/main/stable-diffusion-v1-5-inpainting: description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) - repo_id: runwayml/stable-diffusion-inpainting + source: runwayml/stable-diffusion-inpainting recommended: True sd-2/main/stable-diffusion-2-1: description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-1 + source: stabilityai/stable-diffusion-2-1 recommended: False sd-2/main/stable-diffusion-2-inpainting: description: Stable Diffusion version 2.0 inpainting model (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-inpainting + source: stabilityai/stable-diffusion-2-inpainting recommended: False sdxl/main/stable-diffusion-xl-base-1-0: description: Stable Diffusion XL base model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-base-1.0 + source: stabilityai/stable-diffusion-xl-base-1.0 recommended: True sdxl-refiner/main/stable-diffusion-xl-refiner-1-0: description: Stable Diffusion XL refiner model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-refiner-1.0 + source: stabilityai/stable-diffusion-xl-refiner-1.0 recommended: False -sdxl/vae/sdxl-1-0-vae-fix: - description: Fine tuned version of the SDXL-1.0 VAE - repo_id: madebyollin/sdxl-vae-fp16-fix +sdxl/vae/sdxl-vae-fp16-fix: + description: Version of the SDXL-1.0 VAE that works in half precision mode + source: madebyollin/sdxl-vae-fp16-fix recommended: True sd-1/main/Analog-Diffusion: description: An SD-1.5 model trained on diverse analog photographs (2.13 GB) - repo_id: wavymulder/Analog-Diffusion + source: wavymulder/Analog-Diffusion recommended: False -sd-1/main/Deliberate_v5: +sd-1/main/Deliberate: description: Versatile model that produces detailed images up to 768px (4.27 GB) - path: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors + source: XpucT/Deliberate recommended: False sd-1/main/Dungeons-and-Diffusion: description: Dungeons & Dragons characters (2.13 GB) - repo_id: 0xJustin/Dungeons-and-Diffusion + source: 0xJustin/Dungeons-and-Diffusion recommended: False sd-1/main/dreamlike-photoreal-2: description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) - repo_id: dreamlike-art/dreamlike-photoreal-2.0 + source: dreamlike-art/dreamlike-photoreal-2.0 recommended: False sd-1/main/Inkpunk-Diffusion: description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB) - repo_id: Envvi/Inkpunk-Diffusion + source: Envvi/Inkpunk-Diffusion recommended: False sd-1/main/openjourney: description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB) - repo_id: prompthero/openjourney + source: prompthero/openjourney recommended: False sd-1/main/seek.art_MEGA: - repo_id: coreco/seek.art_MEGA + source: coreco/seek.art_MEGA description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB) recommended: False sd-1/main/trinart_stable_diffusion_v2: description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB) - repo_id: naclbit/trinart_stable_diffusion_v2 + source: naclbit/trinart_stable_diffusion_v2 recommended: False sd-1/controlnet/qrcode_monster: - repo_id: monster-labs/control_v1p_sd15_qrcode_monster + source: 
monster-labs/control_v1p_sd15_qrcode_monster subfolder: v2 sd-1/controlnet/canny: - repo_id: lllyasviel/control_v11p_sd15_canny + source: lllyasviel/control_v11p_sd15_canny recommended: True sd-1/controlnet/inpaint: - repo_id: lllyasviel/control_v11p_sd15_inpaint + source: lllyasviel/control_v11p_sd15_inpaint sd-1/controlnet/mlsd: - repo_id: lllyasviel/control_v11p_sd15_mlsd + source: lllyasviel/control_v11p_sd15_mlsd sd-1/controlnet/depth: - repo_id: lllyasviel/control_v11f1p_sd15_depth + source: lllyasviel/control_v11f1p_sd15_depth recommended: True sd-1/controlnet/normal_bae: - repo_id: lllyasviel/control_v11p_sd15_normalbae + source: lllyasviel/control_v11p_sd15_normalbae sd-1/controlnet/seg: - repo_id: lllyasviel/control_v11p_sd15_seg + source: lllyasviel/control_v11p_sd15_seg sd-1/controlnet/lineart: - repo_id: lllyasviel/control_v11p_sd15_lineart + source: lllyasviel/control_v11p_sd15_lineart recommended: True sd-1/controlnet/lineart_anime: - repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime + source: lllyasviel/control_v11p_sd15s2_lineart_anime sd-1/controlnet/openpose: - repo_id: lllyasviel/control_v11p_sd15_openpose + source: lllyasviel/control_v11p_sd15_openpose recommended: True sd-1/controlnet/scribble: - repo_id: lllyasviel/control_v11p_sd15_scribble + source: lllyasviel/control_v11p_sd15_scribble recommended: False sd-1/controlnet/softedge: - repo_id: lllyasviel/control_v11p_sd15_softedge + source: lllyasviel/control_v11p_sd15_softedge sd-1/controlnet/shuffle: - repo_id: lllyasviel/control_v11e_sd15_shuffle + source: lllyasviel/control_v11e_sd15_shuffle sd-1/controlnet/tile: - repo_id: lllyasviel/control_v11f1e_sd15_tile + source: lllyasviel/control_v11f1e_sd15_tile sd-1/controlnet/ip2p: - repo_id: lllyasviel/control_v11e_sd15_ip2p + source: lllyasviel/control_v11e_sd15_ip2p sd-1/t2i_adapter/canny-sd15: - repo_id: TencentARC/t2iadapter_canny_sd15v2 + source: TencentARC/t2iadapter_canny_sd15v2 sd-1/t2i_adapter/sketch-sd15: - repo_id: TencentARC/t2iadapter_sketch_sd15v2 + source: TencentARC/t2iadapter_sketch_sd15v2 sd-1/t2i_adapter/depth-sd15: - repo_id: TencentARC/t2iadapter_depth_sd15v2 + source: TencentARC/t2iadapter_depth_sd15v2 sd-1/t2i_adapter/zoedepth-sd15: - repo_id: TencentARC/t2iadapter_zoedepth_sd15v1 + source: TencentARC/t2iadapter_zoedepth_sd15v1 sdxl/t2i_adapter/canny-sdxl: - repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0 + source: TencentARC/t2i-adapter-canny-sdxl-1.0 sdxl/t2i_adapter/zoedepth-sdxl: - repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 + source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 sdxl/t2i_adapter/lineart-sdxl: - repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0 + source: TencentARC/t2i-adapter-lineart-sdxl-1.0 sdxl/t2i_adapter/sketch-sdxl: - repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0 + source: TencentARC/t2i-adapter-sketch-sdxl-1.0 sd-1/embedding/EasyNegative: - path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors + source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors recommended: True -sd-1/embedding/ahx-beta-453407d: - repo_id: sd-concepts-library/ahx-beta-453407d + description: A textual inversion to use in the negative prompt to reduce bad anatomy +sd-1/lora/FlatColor: + source: https://civitai.com/models/6433/loraflatcolor + recommended: True + description: A LoRA that generates scenery using solid blocks of color sd-1/lora/Ink scenery: - path: https://civitai.com/api/download/models/83390 + source: https://civitai.com/api/download/models/83390 + description: 
Generate india ink-like landscapes sd-1/ip_adapter/ip_adapter_sd15: - repo_id: InvokeAI/ip_adapter_sd15 + source: InvokeAI/ip_adapter_sd15 recommended: True requires: - InvokeAI/ip_adapter_sd_image_encoder description: IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_sd15: - repo_id: InvokeAI/ip_adapter_plus_sd15 + source: InvokeAI/ip_adapter_plus_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_face_sd15: - repo_id: InvokeAI/ip_adapter_plus_face_sd15 + source: InvokeAI/ip_adapter_plus_face_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models, adapted for faces sdxl/ip_adapter/ip_adapter_sdxl: - repo_id: InvokeAI/ip_adapter_sdxl + source: InvokeAI/ip_adapter_sdxl recommended: False requires: - InvokeAI/ip_adapter_sdxl_image_encoder description: IP-Adapter for SDXL models any/clip_vision/ip_adapter_sd_image_encoder: - repo_id: InvokeAI/ip_adapter_sd_image_encoder + source: InvokeAI/ip_adapter_sd_image_encoder recommended: False description: Required model for using IP-Adapters with SD-1/2 models any/clip_vision/ip_adapter_sdxl_image_encoder: - repo_id: InvokeAI/ip_adapter_sdxl_image_encoder + source: InvokeAI/ip_adapter_sdxl_image_encoder recommended: False description: Required model for using IP-Adapters with SDXL models diff --git a/invokeai/configs/INITIAL_MODELS2.yaml b/invokeai/configs/INITIAL_MODELS.yaml.OLD similarity index 59% rename from invokeai/configs/INITIAL_MODELS2.yaml rename to invokeai/configs/INITIAL_MODELS.yaml.OLD index ca2283ab811..c230665e3a6 100644 --- a/invokeai/configs/INITIAL_MODELS2.yaml +++ b/invokeai/configs/INITIAL_MODELS.yaml.OLD @@ -1,157 +1,153 @@ # This file predefines a few models that the user may want to install. 
sd-1/main/stable-diffusion-v1-5: description: Stable Diffusion version 1.5 diffusers model (4.27 GB) - source: runwayml/stable-diffusion-v1-5 + repo_id: runwayml/stable-diffusion-v1-5 recommended: True default: True sd-1/main/stable-diffusion-v1-5-inpainting: description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) - source: runwayml/stable-diffusion-inpainting + repo_id: runwayml/stable-diffusion-inpainting recommended: True sd-2/main/stable-diffusion-2-1: description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB) - source: stabilityai/stable-diffusion-2-1 + repo_id: stabilityai/stable-diffusion-2-1 recommended: False sd-2/main/stable-diffusion-2-inpainting: description: Stable Diffusion version 2.0 inpainting model (5.21 GB) - source: stabilityai/stable-diffusion-2-inpainting + repo_id: stabilityai/stable-diffusion-2-inpainting recommended: False sdxl/main/stable-diffusion-xl-base-1-0: description: Stable Diffusion XL base model (12 GB) - source: stabilityai/stable-diffusion-xl-base-1.0 + repo_id: stabilityai/stable-diffusion-xl-base-1.0 recommended: True sdxl-refiner/main/stable-diffusion-xl-refiner-1-0: description: Stable Diffusion XL refiner model (12 GB) - source: stabilityai/stable-diffusion-xl-refiner-1.0 + repo_id: stabilityai/stable-diffusion-xl-refiner-1.0 recommended: False -sdxl/vae/sdxl-vae-fp16-fix: - description: Version of the SDXL-1.0 VAE that works in half precision mode - source: madebyollin/sdxl-vae-fp16-fix +sdxl/vae/sdxl-1-0-vae-fix: + description: Fine tuned version of the SDXL-1.0 VAE + repo_id: madebyollin/sdxl-vae-fp16-fix recommended: True sd-1/main/Analog-Diffusion: description: An SD-1.5 model trained on diverse analog photographs (2.13 GB) - source: wavymulder/Analog-Diffusion + repo_id: wavymulder/Analog-Diffusion recommended: False -sd-1/main/Deliberate: +sd-1/main/Deliberate_v5: description: Versatile model that produces detailed images up to 768px (4.27 GB) - source: XpucT/Deliberate + path: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors recommended: False sd-1/main/Dungeons-and-Diffusion: description: Dungeons & Dragons characters (2.13 GB) - source: 0xJustin/Dungeons-and-Diffusion + repo_id: 0xJustin/Dungeons-and-Diffusion recommended: False sd-1/main/dreamlike-photoreal-2: description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) - source: dreamlike-art/dreamlike-photoreal-2.0 + repo_id: dreamlike-art/dreamlike-photoreal-2.0 recommended: False sd-1/main/Inkpunk-Diffusion: description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB) - source: Envvi/Inkpunk-Diffusion + repo_id: Envvi/Inkpunk-Diffusion recommended: False sd-1/main/openjourney: description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB) - source: prompthero/openjourney + repo_id: prompthero/openjourney recommended: False sd-1/main/seek.art_MEGA: - source: coreco/seek.art_MEGA + repo_id: coreco/seek.art_MEGA description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB) recommended: False sd-1/main/trinart_stable_diffusion_v2: description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB) - source: naclbit/trinart_stable_diffusion_v2 + repo_id: naclbit/trinart_stable_diffusion_v2 recommended: False sd-1/controlnet/qrcode_monster: - source: monster-labs/control_v1p_sd15_qrcode_monster + repo_id: 
monster-labs/control_v1p_sd15_qrcode_monster subfolder: v2 sd-1/controlnet/canny: - source: lllyasviel/control_v11p_sd15_canny + repo_id: lllyasviel/control_v11p_sd15_canny recommended: True sd-1/controlnet/inpaint: - source: lllyasviel/control_v11p_sd15_inpaint + repo_id: lllyasviel/control_v11p_sd15_inpaint sd-1/controlnet/mlsd: - source: lllyasviel/control_v11p_sd15_mlsd + repo_id: lllyasviel/control_v11p_sd15_mlsd sd-1/controlnet/depth: - source: lllyasviel/control_v11f1p_sd15_depth + repo_id: lllyasviel/control_v11f1p_sd15_depth recommended: True sd-1/controlnet/normal_bae: - source: lllyasviel/control_v11p_sd15_normalbae + repo_id: lllyasviel/control_v11p_sd15_normalbae sd-1/controlnet/seg: - source: lllyasviel/control_v11p_sd15_seg + repo_id: lllyasviel/control_v11p_sd15_seg sd-1/controlnet/lineart: - source: lllyasviel/control_v11p_sd15_lineart + repo_id: lllyasviel/control_v11p_sd15_lineart recommended: True sd-1/controlnet/lineart_anime: - source: lllyasviel/control_v11p_sd15s2_lineart_anime + repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime sd-1/controlnet/openpose: - source: lllyasviel/control_v11p_sd15_openpose + repo_id: lllyasviel/control_v11p_sd15_openpose recommended: True sd-1/controlnet/scribble: - source: lllyasviel/control_v11p_sd15_scribble + repo_id: lllyasviel/control_v11p_sd15_scribble recommended: False sd-1/controlnet/softedge: - source: lllyasviel/control_v11p_sd15_softedge + repo_id: lllyasviel/control_v11p_sd15_softedge sd-1/controlnet/shuffle: - source: lllyasviel/control_v11e_sd15_shuffle + repo_id: lllyasviel/control_v11e_sd15_shuffle sd-1/controlnet/tile: - source: lllyasviel/control_v11f1e_sd15_tile + repo_id: lllyasviel/control_v11f1e_sd15_tile sd-1/controlnet/ip2p: - source: lllyasviel/control_v11e_sd15_ip2p + repo_id: lllyasviel/control_v11e_sd15_ip2p sd-1/t2i_adapter/canny-sd15: - source: TencentARC/t2iadapter_canny_sd15v2 + repo_id: TencentARC/t2iadapter_canny_sd15v2 sd-1/t2i_adapter/sketch-sd15: - source: TencentARC/t2iadapter_sketch_sd15v2 + repo_id: TencentARC/t2iadapter_sketch_sd15v2 sd-1/t2i_adapter/depth-sd15: - source: TencentARC/t2iadapter_depth_sd15v2 + repo_id: TencentARC/t2iadapter_depth_sd15v2 sd-1/t2i_adapter/zoedepth-sd15: - source: TencentARC/t2iadapter_zoedepth_sd15v1 + repo_id: TencentARC/t2iadapter_zoedepth_sd15v1 sdxl/t2i_adapter/canny-sdxl: - source: TencentARC/t2i-adapter-canny-sdxl-1.0 + repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0 sdxl/t2i_adapter/zoedepth-sdxl: - source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 + repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 sdxl/t2i_adapter/lineart-sdxl: - source: TencentARC/t2i-adapter-lineart-sdxl-1.0 + repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0 sdxl/t2i_adapter/sketch-sdxl: - source: TencentARC/t2i-adapter-sketch-sdxl-1.0 + repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0 sd-1/embedding/EasyNegative: - source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors + path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors recommended: True - description: A textual inversion to use in the negative prompt to reduce bad anatomy -sd-1/lora/FlatColor: - source: https://civitai.com/models/6433/loraflatcolor - recommended: True - description: A LoRA that generates scenery using solid blocks of color +sd-1/embedding/ahx-beta-453407d: + repo_id: sd-concepts-library/ahx-beta-453407d sd-1/lora/Ink scenery: - source: https://civitai.com/api/download/models/83390 - description: Generate india ink-like landscapes + path: 
https://civitai.com/api/download/models/83390 sd-1/ip_adapter/ip_adapter_sd15: - source: InvokeAI/ip_adapter_sd15 + repo_id: InvokeAI/ip_adapter_sd15 recommended: True requires: - InvokeAI/ip_adapter_sd_image_encoder description: IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_sd15: - source: InvokeAI/ip_adapter_plus_sd15 + repo_id: InvokeAI/ip_adapter_plus_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models sd-1/ip_adapter/ip_adapter_plus_face_sd15: - source: InvokeAI/ip_adapter_plus_face_sd15 + repo_id: InvokeAI/ip_adapter_plus_face_sd15 recommended: False requires: - InvokeAI/ip_adapter_sd_image_encoder description: Refined IP-Adapter for SD 1.5 models, adapted for faces sdxl/ip_adapter/ip_adapter_sdxl: - source: InvokeAI/ip_adapter_sdxl + repo_id: InvokeAI/ip_adapter_sdxl recommended: False requires: - InvokeAI/ip_adapter_sdxl_image_encoder description: IP-Adapter for SDXL models any/clip_vision/ip_adapter_sd_image_encoder: - source: InvokeAI/ip_adapter_sd_image_encoder + repo_id: InvokeAI/ip_adapter_sd_image_encoder recommended: False description: Required model for using IP-Adapters with SD-1/2 models any/clip_vision/ip_adapter_sdxl_image_encoder: - source: InvokeAI/ip_adapter_sdxl_image_encoder + repo_id: InvokeAI/ip_adapter_sdxl_image_encoder recommended: False description: Required model for using IP-Adapters with SDXL models diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index e23538ffd66..22b132370e6 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -6,47 +6,45 @@ """ This is the npyscreen frontend to the model installation application. -The work is actually done in backend code in model_install_backend.py. +It is currently named model_install2.py, but will ultimately replace model_install.py. 
""" import argparse import curses -import logging import sys -import textwrap import traceback +import warnings from argparse import Namespace -from multiprocessing import Process -from multiprocessing.connection import Connection, Pipe -from pathlib import Path from shutil import get_terminal_size -from typing import Optional +from typing import Any, Dict, List, Optional, Set import npyscreen import torch from npyscreen import widget from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType -from invokeai.backend.model_management import ModelManager, ModelType +from invokeai.app.services.model_install import ModelInstallServiceBase +from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo +from invokeai.backend.model_manager import ModelType from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.widgets import ( MIN_COLS, MIN_LINES, - BufferBox, CenteredTitleText, CyclingForm, MultiSelectColumns, SingleSelectColumns, TextBox, WindowTooSmallException, - select_stable_diffusion_config_file, set_min_terminal_size, ) +warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402 config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger() +logger = InvokeAILogger.get_logger("ModelInstallService") +logger.setLevel("WARNING") +# logger.setLevel('DEBUG') # build a table mapping all non-printable characters to None # for stripping control characters @@ -58,44 +56,42 @@ def make_printable(s: str) -> str: - """Replace non-printable characters in a string""" + """Replace non-printable characters in a string.""" return s.translate(NOPRINT_TRANS_TABLE) class addModelsForm(CyclingForm, npyscreen.FormMultiPage): + """Main form for interactive TUI.""" + # for responsive resizing set to False, but this seems to cause a crash! FIX_MINIMUM_SIZE_WHEN_CREATED = True # for persistence current_tab = 0 - def __init__(self, parentApp, name, multipage=False, *args, **keywords): + def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any): self.multipage = multipage self.subprocess = None - super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad? + super().__init__(parentApp=parentApp, name=name, **keywords) - def create(self): + def create(self) -> None: + self.installer = self.parentApp.install_helper.installer + self.model_labels = self._get_model_labels() self.keypress_timeout = 10 self.counter = 0 self.subprocess_connection = None - if not config.model_conf_path.exists(): - with open(config.model_conf_path, "w") as file: - print("# InvokeAI model configuration file", file=file) - self.installer = ModelInstall(config) - self.all_models = self.installer.all_models() - self.starter_models = self.installer.starter_models() - self.model_labels = self._get_model_labels() window_width, window_height = get_terminal_size() - self.nextrely -= 1 + # npyscreen has no typing hints + self.nextrely -= 1 # type: ignore self.add_widget_intelligent( npyscreen.FixedText, value="Use ctrl-N and ctrl-P to move to the ext and

revious fields. Cursor keys navigate, and selects.", editable=False, color="CAUTION", ) - self.nextrely += 1 + self.nextrely += 1 # type: ignore self.tabs = self.add_widget_intelligent( SingleSelectColumns, values=[ @@ -115,9 +111,9 @@ def create(self): ) self.tabs.on_changed = self._toggle_tables - top_of_table = self.nextrely + top_of_table = self.nextrely # type: ignore self.starter_pipelines = self.add_starter_pipelines() - bottom_of_table = self.nextrely + bottom_of_table = self.nextrely # type: ignore self.nextrely = top_of_table self.pipeline_models = self.add_pipeline_widgets( @@ -162,15 +158,7 @@ def create(self): self.nextrely = bottom_of_table + 1 - self.monitor = self.add_widget_intelligent( - BufferBox, - name="Log Messages", - editable=False, - max_height=6, - ) - self.nextrely += 1 - done_label = "APPLY CHANGES" back_label = "BACK" cancel_label = "CANCEL" current_position = self.nextrely @@ -186,14 +174,8 @@ def create(self): npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel ) self.nextrely = current_position - self.ok_button = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=done_label, - relx=(window_width - len(done_label)) // 2, - when_pressed_function=self.on_execute, - ) - label = "APPLY CHANGES & EXIT" + label = "APPLY CHANGES" self.nextrely = current_position self.done = self.add_widget_intelligent( npyscreen.ButtonPress, @@ -210,17 +192,16 @@ def create(self): ############# diffusers tab ########## def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: """Add widgets responsible for selecting diffusers models""" - widgets = {} - models = self.all_models - starters = self.starter_models - starter_model_labels = self.model_labels + widgets: Dict[str, npyscreen.widget] = {} - self.installed_models = sorted([x for x in starters if models[x].installed]) + all_models = self.all_models # master dict of all models, indexed by key + model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]] + model_labels = [self.model_labels[x] for x in model_list] widgets.update( label1=self.add_widget_intelligent( CenteredTitleText, - name="Select from a starter set of Stable Diffusion models from HuggingFace.", + name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.", editable=False, labelColor="CAUTION", ) @@ -230,23 +211,24 @@ def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: # if user has already installed some initial models, then don't patronize them # by showing more recommendations show_recommended = len(self.installed_models) == 0 - keys = [x for x in models.keys() if x in starters] + + checked = [ + model_list.index(x) + for x in model_list + if (show_recommended and all_models[x].recommended) or all_models[x].installed + ] widgets.update( models_selected=self.add_widget_intelligent( MultiSelectColumns, columns=1, name="Install Starter Models", - values=[starter_model_labels[x] for x in keys], - value=[ - keys.index(x) - for x in keys - if (show_recommended and models[x].recommended) or (x in self.installed_models) - ], - max_height=len(starters) + 1, + values=model_labels, + value=checked, + max_height=len(model_list) + 1, relx=4, scroll_exit=True, ), - models=keys, + models=model_list, ) self.nextrely += 1 @@ -257,14 +239,18 @@ def add_model_widgets( self, model_type: ModelType, window_width: int = 120, - install_prompt: str = None, - exclude: set = None, + install_prompt: Optional[str] = None, + exclude: Optional[Set[str]] = None, ) -> dict[str, 
npyscreen.widget]: """Generic code to create model selection widgets""" if exclude is None: exclude = set() - widgets = {} - model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude] + widgets: Dict[str, npyscreen.widget] = {} + all_models = self.all_models + model_list = sorted( + [x for x in all_models if all_models[x].type == model_type and x not in exclude], + key=lambda x: all_models[x].name or "", + ) model_labels = [self.model_labels[x] for x in model_list] show_recommended = len(self.installed_models) == 0 @@ -300,7 +286,7 @@ def add_model_widgets( value=[ model_list.index(x) for x in model_list - if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed + if (show_recommended and all_models[x].recommended) or all_models[x].installed ], max_height=len(model_list) // columns + 1, relx=4, @@ -324,7 +310,7 @@ def add_model_widgets( download_ids=self.add_widget_intelligent( TextBox, name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):", - max_height=4, + max_height=6, scroll_exit=True, editable=True, ) @@ -349,13 +335,13 @@ def add_pipeline_widgets( return widgets - def resize(self): + def resize(self) -> None: super().resize() if s := self.starter_pipelines.get("models_selected"): - keys = [x for x in self.all_models.keys() if x in self.starter_models] - s.values = [self.model_labels[x] for x in keys] + if model_list := self.starter_pipelines.get("models"): + s.values = [self.model_labels[x] for x in model_list] - def _toggle_tables(self, value=None): + def _toggle_tables(self, value: List[int]) -> None: selected_tab = value[0] widgets = [ self.starter_pipelines, @@ -385,17 +371,18 @@ def _toggle_tables(self, value=None): self.display() def _get_model_labels(self) -> dict[str, str]: + """Return a list of trimmed labels for all models.""" window_width, window_height = get_terminal_size() checkbox_width = 4 spacing_width = 2 + result = {} models = self.all_models - label_width = max([len(models[x].name) for x in models]) + label_width = max([len(models[x].name or "") for x in self.starter_models]) description_width = window_width - label_width - checkbox_width - spacing_width - result = {} - for x in models.keys(): - description = models[x].description + for key in self.all_models: + description = models[key].description description = ( description[0 : description_width - 3] + "..." if description and len(description) > description_width @@ -403,7 +390,8 @@ def _get_model_labels(self) -> dict[str, str]: if description else "" ) - result[x] = f"%-{label_width}s %s" % (models[x].name, description) + result[key] = f"%-{label_width}s %s" % (models[key].name, description) + return result def _get_columns(self) -> int: @@ -413,50 +401,40 @@ def _get_columns(self) -> int: def confirm_deletions(self, selections: InstallSelections) -> bool: remove_models = selections.remove_models - if len(remove_models) > 0: - mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models]) - return npyscreen.notify_ok_cancel( + if remove_models: + model_names = [self.all_models[x].name or "" for x in remove_models] + mods = "\n".join(model_names) + is_ok = npyscreen.notify_ok_cancel( f"These unchecked models will be deleted from disk. 
Continue?\n---------\n{mods}" ) + assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations + return is_ok else: return True - def on_execute(self): - self.marshall_arguments() - app = self.parentApp - if not self.confirm_deletions(app.install_selections): - return + @property + def all_models(self) -> Dict[str, UnifiedModelInfo]: + # npyscreen doesn't having typing hints + return self.parentApp.install_helper.all_models # type: ignore - self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True) - self.ok_button.hidden = True - self.display() + @property + def starter_models(self) -> List[str]: + return self.parentApp.install_helper._starter_models # type: ignore - # TO DO: Spawn a worker thread, not a subprocess - parent_conn, child_conn = Pipe() - p = Process( - target=process_and_execute, - kwargs={ - "opt": app.program_opts, - "selections": app.install_selections, - "conn_out": child_conn, - }, - ) - p.start() - child_conn.close() - self.subprocess_connection = parent_conn - self.subprocess = p - app.install_selections = InstallSelections() + @property + def installed_models(self) -> List[str]: + return self.parentApp.install_helper._installed_models # type: ignore - def on_back(self): + def on_back(self) -> None: self.parentApp.switchFormPrevious() self.editing = False - def on_cancel(self): + def on_cancel(self) -> None: self.parentApp.setNextForm(None) self.parentApp.user_cancelled = True self.editing = False - def on_done(self): + def on_done(self) -> None: self.marshall_arguments() if not self.confirm_deletions(self.parentApp.install_selections): return @@ -464,77 +442,7 @@ def on_done(self): self.parentApp.user_cancelled = False self.editing = False - ########## This routine monitors the child process that is performing model installation and removal ##### - def while_waiting(self): - """Called during idle periods. Main task is to update the Log Messages box with messages - from the child process that does the actual installation/removal""" - c = self.subprocess_connection - if not c: - return - - monitor_widget = self.monitor.entry_widget - while c.poll(): - try: - data = c.recv_bytes().decode("utf-8") - data.strip("\n") - - # processing child is requesting user input to select the - # right configuration file - if data.startswith("*need v2 config"): - _, model_path, *_ = data.split(":", 2) - self._return_v2_config(model_path) - - # processing child is done - elif data == "*done*": - self._close_subprocess_and_regenerate_form() - break - - # update the log message box - else: - data = make_printable(data) - data = data.replace("[A", "") - monitor_widget.buffer( - textwrap.wrap( - data, - width=monitor_widget.width, - subsequent_indent=" ", - ), - scroll_end=True, - ) - self.display() - except (EOFError, OSError): - self.subprocess_connection = None - - def _return_v2_config(self, model_path: str): - c = self.subprocess_connection - model_name = Path(model_path).name - message = select_stable_diffusion_config_file(model_name=model_name) - c.send_bytes(message.encode("utf-8")) - - def _close_subprocess_and_regenerate_form(self): - app = self.parentApp - self.subprocess_connection.close() - self.subprocess_connection = None - self.monitor.entry_widget.buffer(["** Action Complete **"]) - self.display() - - # rebuild the form, saving and restoring some of the fields that need to be preserved. 
- saved_messages = self.monitor.entry_widget.values - - app.main_form = app.addForm( - "MAIN", - addModelsForm, - name="Install Stable Diffusion Models", - multipage=self.multipage, - ) - app.switchForm("MAIN") - - app.main_form.monitor.entry_widget.values = saved_messages - app.main_form.monitor.entry_widget.buffer([""], scroll_end=True) - # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir - # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan - - def marshall_arguments(self): + def marshall_arguments(self) -> None: """ Assemble arguments and store as attributes of the application: .starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml @@ -564,46 +472,24 @@ def marshall_arguments(self): models_to_install = [x for x in selected if not self.all_models[x].installed] models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] selections.remove_models.extend(models_to_remove) - selections.install_models.extend( - all_models[x].path or all_models[x].repo_id - for x in models_to_install - if all_models[x].path or all_models[x].repo_id - ) + selections.install_models.extend([all_models[x] for x in models_to_install]) # models located in the 'download_ids" section for section in ui_sections: if downloads := section.get("download_ids"): - selections.install_models.extend(downloads.value.split()) - - # NOT NEEDED - DONE IN BACKEND NOW - # # special case for the ipadapter_models. If any of the adapters are - # # chosen, then we add the corresponding encoder(s) to the install list. - # section = self.ipadapter_models - # if section.get("models_selected"): - # selected_adapters = [ - # self.all_models[section["models"][x]].name for x in section.get("models_selected").value - # ] - # encoders = [] - # if any(["sdxl" in x for x in selected_adapters]): - # encoders.append("ip_adapter_sdxl_image_encoder") - # if any(["sd15" in x for x in selected_adapters]): - # encoders.append("ip_adapter_sd_image_encoder") - # for encoder in encoders: - # key = f"any/clip_vision/{encoder}" - # repo_id = f"InvokeAI/{encoder}" - # if key not in self.all_models: - # selections.install_models.append(repo_id) - - -class AddModelApplication(npyscreen.NPSAppManaged): - def __init__(self, opt): + models = [UnifiedModelInfo(source=x) for x in downloads.value.split()] + selections.install_models.extend(models) + + +class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore + def __init__(self, opt: Namespace, install_helper: InstallHelper): super().__init__() self.program_opts = opt self.user_cancelled = False - # self.autoload_pending = True self.install_selections = InstallSelections() + self.install_helper = install_helper - def onStart(self): + def onStart(self) -> None: npyscreen.setTheme(npyscreen.Themes.DefaultTheme) self.main_form = self.addForm( "MAIN", @@ -613,138 +499,62 @@ def onStart(self): ) -class StderrToMessage: - def __init__(self, connection: Connection): - self.connection = connection - - def write(self, data: str): - self.connection.send_bytes(data.encode("utf-8")) - - def flush(self): - pass +def list_models(installer: ModelInstallServiceBase, model_type: ModelType): + """Print out all models of type model_type.""" + models = installer.record_store.search_by_attr(model_type=model_type) + print(f"Installed models of type `{model_type}`:") + for model in models: + path = (config.models_path / model.path).resolve() + 
print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}") # -------------------------------------------------------- -def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType: - if tui_conn: - logger.debug("Waiting for user response...") - return _ask_user_for_pt_tui(model_path, tui_conn) - else: - return _ask_user_for_pt_cmdline(model_path) - - -def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]: - choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] - print( - f""" -Please select the scheduler prediction type of the checkpoint named {model_path.name}: -[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images -[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models -[3] Accept the best guess; you can fix it in the Web UI later -""" - ) - choice = None - ok = False - while not ok: - try: - choice = input("select [3]> ").strip() - if not choice: - return None - choice = choices[int(choice) - 1] - ok = True - except (ValueError, IndexError): - print(f"{choice} is not a valid choice") - except EOFError: - return - return choice - - -def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType: - tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8")) - # note that we don't do any status checking here - response = tui_conn.recv_bytes().decode("utf-8") - if response is None: - return None - elif response == "epsilon": - return SchedulerPredictionType.epsilon - elif response == "v": - return SchedulerPredictionType.VPrediction - elif response == "guess": - return None - else: - return None - - -# -------------------------------------------------------- -def process_and_execute( - opt: Namespace, - selections: InstallSelections, - conn_out: Connection = None, -): - # need to reinitialize config in subprocess - config = InvokeAIAppConfig.get_config() - args = ["--root", opt.root] if opt.root else [] - config.parse_args(args) - - # set up so that stderr is sent to conn_out - if conn_out: - translator = StderrToMessage(conn_out) - sys.stderr = translator - sys.stdout = translator - logger = InvokeAILogger.get_logger() - logger.handlers.clear() - logger.addHandler(logging.StreamHandler(translator)) - - installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out)) - installer.install(selections) - - if conn_out: - conn_out.send_bytes("*done*".encode("utf-8")) - conn_out.close() - - -# -------------------------------------------------------- -def select_and_download_models(opt: Namespace): +def select_and_download_models(opt: Namespace) -> None: + """Prompt user for install/delete selections and execute.""" precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) - config.precision = precision - installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type) + # unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal + config.precision = precision # type: ignore + install_helper = InstallHelper(config, logger) + installer = install_helper.installer + if opt.list_models: - installer.list_models(opt.list_models) + list_models(installer, opt.list_models) + elif opt.add or opt.delete: - selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or []) - installer.install(selections) + 
selections = InstallSelections( + install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or [] + ) + install_helper.add_or_delete(selections) + elif opt.default_only: - selections = InstallSelections(install_models=installer.default_model()) - installer.install(selections) + default_model = install_helper.default_model() + assert default_model is not None + selections = InstallSelections(install_models=[default_model]) + install_helper.add_or_delete(selections) + elif opt.yes_to_all: - selections = InstallSelections(install_models=installer.recommended_models()) - installer.install(selections) + selections = InstallSelections(install_models=install_helper.recommended_models()) + install_helper.add_or_delete(selections) # this is where the TUI is called else: - # needed to support the probe() method running under a subprocess - torch.multiprocessing.set_start_method("spawn") - if not set_min_terminal_size(MIN_COLS, MIN_LINES): raise WindowTooSmallException( "Could not increase terminal size. Try running again with a larger window or smaller font size." ) - installApp = AddModelApplication(opt) + installApp = AddModelApplication(opt, install_helper) try: installApp.run() - except KeyboardInterrupt as e: - if hasattr(installApp, "main_form"): - if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive(): - logger.info("Terminating subprocesses") - installApp.main_form.subprocess.terminate() - installApp.main_form.subprocess = None - raise e - process_and_execute(opt, installApp.install_selections) + except KeyboardInterrupt: + print("Aborted...") + sys.exit(-1) + + install_helper.add_or_delete(installApp.install_selections) # ------------------------------------- -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( "--add", @@ -754,7 +564,7 @@ def main(): parser.add_argument( "--delete", nargs="*", - help="List of names of models to idelete", + help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`", ) parser.add_argument( "--full-precision", @@ -781,14 +591,6 @@ def main(): choices=[x.value for x in ModelType], help="list installed models", ) - parser.add_argument( - "--config_file", - "-c", - dest="config_file", - type=str, - default=None, - help="path to configuration file to create", - ) parser.add_argument( "--root_dir", dest="root", diff --git a/invokeai/frontend/install/model_install2.py b/invokeai/frontend/install/model_install.py.OLD similarity index 57% rename from invokeai/frontend/install/model_install2.py rename to invokeai/frontend/install/model_install.py.OLD index 22b132370e6..e23538ffd66 100644 --- a/invokeai/frontend/install/model_install2.py +++ b/invokeai/frontend/install/model_install.py.OLD @@ -6,45 +6,47 @@ """ This is the npyscreen frontend to the model installation application. -It is currently named model_install2.py, but will ultimately replace model_install.py. +The work is actually done in backend code in model_install_backend.py. 
""" import argparse import curses +import logging import sys +import textwrap import traceback -import warnings from argparse import Namespace +from multiprocessing import Process +from multiprocessing.connection import Connection, Pipe +from pathlib import Path from shutil import get_terminal_size -from typing import Any, Dict, List, Optional, Set +from typing import Optional import npyscreen import torch from npyscreen import widget from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_install import ModelInstallServiceBase -from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo -from invokeai.backend.model_manager import ModelType +from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType +from invokeai.backend.model_management import ModelManager, ModelType from invokeai.backend.util import choose_precision, choose_torch_device from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.widgets import ( MIN_COLS, MIN_LINES, + BufferBox, CenteredTitleText, CyclingForm, MultiSelectColumns, SingleSelectColumns, TextBox, WindowTooSmallException, + select_stable_diffusion_config_file, set_min_terminal_size, ) -warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402 config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger("ModelInstallService") -logger.setLevel("WARNING") -# logger.setLevel('DEBUG') +logger = InvokeAILogger.get_logger() # build a table mapping all non-printable characters to None # for stripping control characters @@ -56,42 +58,44 @@ def make_printable(s: str) -> str: - """Replace non-printable characters in a string.""" + """Replace non-printable characters in a string""" return s.translate(NOPRINT_TRANS_TABLE) class addModelsForm(CyclingForm, npyscreen.FormMultiPage): - """Main form for interactive TUI.""" - # for responsive resizing set to False, but this seems to cause a crash! FIX_MINIMUM_SIZE_WHEN_CREATED = True # for persistence current_tab = 0 - def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any): + def __init__(self, parentApp, name, multipage=False, *args, **keywords): self.multipage = multipage self.subprocess = None - super().__init__(parentApp=parentApp, name=name, **keywords) + super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad? - def create(self) -> None: - self.installer = self.parentApp.install_helper.installer - self.model_labels = self._get_model_labels() + def create(self): self.keypress_timeout = 10 self.counter = 0 self.subprocess_connection = None + if not config.model_conf_path.exists(): + with open(config.model_conf_path, "w") as file: + print("# InvokeAI model configuration file", file=file) + self.installer = ModelInstall(config) + self.all_models = self.installer.all_models() + self.starter_models = self.installer.starter_models() + self.model_labels = self._get_model_labels() window_width, window_height = get_terminal_size() - # npyscreen has no typing hints - self.nextrely -= 1 # type: ignore + self.nextrely -= 1 self.add_widget_intelligent( npyscreen.FixedText, value="Use ctrl-N and ctrl-P to move to the ext and

revious fields. Cursor keys navigate, and selects.", editable=False, color="CAUTION", ) - self.nextrely += 1 # type: ignore + self.nextrely += 1 self.tabs = self.add_widget_intelligent( SingleSelectColumns, values=[ @@ -111,9 +115,9 @@ def create(self) -> None: ) self.tabs.on_changed = self._toggle_tables - top_of_table = self.nextrely # type: ignore + top_of_table = self.nextrely self.starter_pipelines = self.add_starter_pipelines() - bottom_of_table = self.nextrely # type: ignore + bottom_of_table = self.nextrely self.nextrely = top_of_table self.pipeline_models = self.add_pipeline_widgets( @@ -158,7 +162,15 @@ def create(self) -> None: self.nextrely = bottom_of_table + 1 + self.monitor = self.add_widget_intelligent( + BufferBox, + name="Log Messages", + editable=False, + max_height=6, + ) + self.nextrely += 1 + done_label = "APPLY CHANGES" back_label = "BACK" cancel_label = "CANCEL" current_position = self.nextrely @@ -174,8 +186,14 @@ def create(self) -> None: npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel ) self.nextrely = current_position + self.ok_button = self.add_widget_intelligent( + npyscreen.ButtonPress, + name=done_label, + relx=(window_width - len(done_label)) // 2, + when_pressed_function=self.on_execute, + ) - label = "APPLY CHANGES" + label = "APPLY CHANGES & EXIT" self.nextrely = current_position self.done = self.add_widget_intelligent( npyscreen.ButtonPress, @@ -192,16 +210,17 @@ def create(self) -> None: ############# diffusers tab ########## def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: """Add widgets responsible for selecting diffusers models""" - widgets: Dict[str, npyscreen.widget] = {} + widgets = {} + models = self.all_models + starters = self.starter_models + starter_model_labels = self.model_labels - all_models = self.all_models # master dict of all models, indexed by key - model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]] - model_labels = [self.model_labels[x] for x in model_list] + self.installed_models = sorted([x for x in starters if models[x].installed]) widgets.update( label1=self.add_widget_intelligent( CenteredTitleText, - name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.", + name="Select from a starter set of Stable Diffusion models from HuggingFace.", editable=False, labelColor="CAUTION", ) @@ -211,24 +230,23 @@ def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: # if user has already installed some initial models, then don't patronize them # by showing more recommendations show_recommended = len(self.installed_models) == 0 - - checked = [ - model_list.index(x) - for x in model_list - if (show_recommended and all_models[x].recommended) or all_models[x].installed - ] + keys = [x for x in models.keys() if x in starters] widgets.update( models_selected=self.add_widget_intelligent( MultiSelectColumns, columns=1, name="Install Starter Models", - values=model_labels, - value=checked, - max_height=len(model_list) + 1, + values=[starter_model_labels[x] for x in keys], + value=[ + keys.index(x) + for x in keys + if (show_recommended and models[x].recommended) or (x in self.installed_models) + ], + max_height=len(starters) + 1, relx=4, scroll_exit=True, ), - models=model_list, + models=keys, ) self.nextrely += 1 @@ -239,18 +257,14 @@ def add_model_widgets( self, model_type: ModelType, window_width: int = 120, - install_prompt: Optional[str] = None, - exclude: Optional[Set[str]] = None, + install_prompt: str = None, + exclude: 
set = None, ) -> dict[str, npyscreen.widget]: """Generic code to create model selection widgets""" if exclude is None: exclude = set() - widgets: Dict[str, npyscreen.widget] = {} - all_models = self.all_models - model_list = sorted( - [x for x in all_models if all_models[x].type == model_type and x not in exclude], - key=lambda x: all_models[x].name or "", - ) + widgets = {} + model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude] model_labels = [self.model_labels[x] for x in model_list] show_recommended = len(self.installed_models) == 0 @@ -286,7 +300,7 @@ def add_model_widgets( value=[ model_list.index(x) for x in model_list - if (show_recommended and all_models[x].recommended) or all_models[x].installed + if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed ], max_height=len(model_list) // columns + 1, relx=4, @@ -310,7 +324,7 @@ def add_model_widgets( download_ids=self.add_widget_intelligent( TextBox, name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):", - max_height=6, + max_height=4, scroll_exit=True, editable=True, ) @@ -335,13 +349,13 @@ def add_pipeline_widgets( return widgets - def resize(self) -> None: + def resize(self): super().resize() if s := self.starter_pipelines.get("models_selected"): - if model_list := self.starter_pipelines.get("models"): - s.values = [self.model_labels[x] for x in model_list] + keys = [x for x in self.all_models.keys() if x in self.starter_models] + s.values = [self.model_labels[x] for x in keys] - def _toggle_tables(self, value: List[int]) -> None: + def _toggle_tables(self, value=None): selected_tab = value[0] widgets = [ self.starter_pipelines, @@ -371,18 +385,17 @@ def _toggle_tables(self, value: List[int]) -> None: self.display() def _get_model_labels(self) -> dict[str, str]: - """Return a list of trimmed labels for all models.""" window_width, window_height = get_terminal_size() checkbox_width = 4 spacing_width = 2 - result = {} models = self.all_models - label_width = max([len(models[x].name or "") for x in self.starter_models]) + label_width = max([len(models[x].name) for x in models]) description_width = window_width - label_width - checkbox_width - spacing_width - for key in self.all_models: - description = models[key].description + result = {} + for x in models.keys(): + description = models[x].description description = ( description[0 : description_width - 3] + "..." if description and len(description) > description_width @@ -390,8 +403,7 @@ def _get_model_labels(self) -> dict[str, str]: if description else "" ) - result[key] = f"%-{label_width}s %s" % (models[key].name, description) - + result[x] = f"%-{label_width}s %s" % (models[x].name, description) return result def _get_columns(self) -> int: @@ -401,40 +413,50 @@ def _get_columns(self) -> int: def confirm_deletions(self, selections: InstallSelections) -> bool: remove_models = selections.remove_models - if remove_models: - model_names = [self.all_models[x].name or "" for x in remove_models] - mods = "\n".join(model_names) - is_ok = npyscreen.notify_ok_cancel( + if len(remove_models) > 0: + mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models]) + return npyscreen.notify_ok_cancel( f"These unchecked models will be deleted from disk. 
Continue?\n---------\n{mods}" ) - assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations - return is_ok else: return True - @property - def all_models(self) -> Dict[str, UnifiedModelInfo]: - # npyscreen doesn't having typing hints - return self.parentApp.install_helper.all_models # type: ignore + def on_execute(self): + self.marshall_arguments() + app = self.parentApp + if not self.confirm_deletions(app.install_selections): + return - @property - def starter_models(self) -> List[str]: - return self.parentApp.install_helper._starter_models # type: ignore + self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True) + self.ok_button.hidden = True + self.display() - @property - def installed_models(self) -> List[str]: - return self.parentApp.install_helper._installed_models # type: ignore + # TO DO: Spawn a worker thread, not a subprocess + parent_conn, child_conn = Pipe() + p = Process( + target=process_and_execute, + kwargs={ + "opt": app.program_opts, + "selections": app.install_selections, + "conn_out": child_conn, + }, + ) + p.start() + child_conn.close() + self.subprocess_connection = parent_conn + self.subprocess = p + app.install_selections = InstallSelections() - def on_back(self) -> None: + def on_back(self): self.parentApp.switchFormPrevious() self.editing = False - def on_cancel(self) -> None: + def on_cancel(self): self.parentApp.setNextForm(None) self.parentApp.user_cancelled = True self.editing = False - def on_done(self) -> None: + def on_done(self): self.marshall_arguments() if not self.confirm_deletions(self.parentApp.install_selections): return @@ -442,7 +464,77 @@ def on_done(self) -> None: self.parentApp.user_cancelled = False self.editing = False - def marshall_arguments(self) -> None: + ########## This routine monitors the child process that is performing model installation and removal ##### + def while_waiting(self): + """Called during idle periods. Main task is to update the Log Messages box with messages + from the child process that does the actual installation/removal""" + c = self.subprocess_connection + if not c: + return + + monitor_widget = self.monitor.entry_widget + while c.poll(): + try: + data = c.recv_bytes().decode("utf-8") + data.strip("\n") + + # processing child is requesting user input to select the + # right configuration file + if data.startswith("*need v2 config"): + _, model_path, *_ = data.split(":", 2) + self._return_v2_config(model_path) + + # processing child is done + elif data == "*done*": + self._close_subprocess_and_regenerate_form() + break + + # update the log message box + else: + data = make_printable(data) + data = data.replace("[A", "") + monitor_widget.buffer( + textwrap.wrap( + data, + width=monitor_widget.width, + subsequent_indent=" ", + ), + scroll_end=True, + ) + self.display() + except (EOFError, OSError): + self.subprocess_connection = None + + def _return_v2_config(self, model_path: str): + c = self.subprocess_connection + model_name = Path(model_path).name + message = select_stable_diffusion_config_file(model_name=model_name) + c.send_bytes(message.encode("utf-8")) + + def _close_subprocess_and_regenerate_form(self): + app = self.parentApp + self.subprocess_connection.close() + self.subprocess_connection = None + self.monitor.entry_widget.buffer(["** Action Complete **"]) + self.display() + + # rebuild the form, saving and restoring some of the fields that need to be preserved. 
+ saved_messages = self.monitor.entry_widget.values + + app.main_form = app.addForm( + "MAIN", + addModelsForm, + name="Install Stable Diffusion Models", + multipage=self.multipage, + ) + app.switchForm("MAIN") + + app.main_form.monitor.entry_widget.values = saved_messages + app.main_form.monitor.entry_widget.buffer([""], scroll_end=True) + # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir + # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan + + def marshall_arguments(self): """ Assemble arguments and store as attributes of the application: .starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml @@ -472,24 +564,46 @@ def marshall_arguments(self) -> None: models_to_install = [x for x in selected if not self.all_models[x].installed] models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] selections.remove_models.extend(models_to_remove) - selections.install_models.extend([all_models[x] for x in models_to_install]) + selections.install_models.extend( + all_models[x].path or all_models[x].repo_id + for x in models_to_install + if all_models[x].path or all_models[x].repo_id + ) # models located in the 'download_ids" section for section in ui_sections: if downloads := section.get("download_ids"): - models = [UnifiedModelInfo(source=x) for x in downloads.value.split()] - selections.install_models.extend(models) - - -class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore - def __init__(self, opt: Namespace, install_helper: InstallHelper): + selections.install_models.extend(downloads.value.split()) + + # NOT NEEDED - DONE IN BACKEND NOW + # # special case for the ipadapter_models. If any of the adapters are + # # chosen, then we add the corresponding encoder(s) to the install list. 
+ # section = self.ipadapter_models + # if section.get("models_selected"): + # selected_adapters = [ + # self.all_models[section["models"][x]].name for x in section.get("models_selected").value + # ] + # encoders = [] + # if any(["sdxl" in x for x in selected_adapters]): + # encoders.append("ip_adapter_sdxl_image_encoder") + # if any(["sd15" in x for x in selected_adapters]): + # encoders.append("ip_adapter_sd_image_encoder") + # for encoder in encoders: + # key = f"any/clip_vision/{encoder}" + # repo_id = f"InvokeAI/{encoder}" + # if key not in self.all_models: + # selections.install_models.append(repo_id) + + +class AddModelApplication(npyscreen.NPSAppManaged): + def __init__(self, opt): super().__init__() self.program_opts = opt self.user_cancelled = False + # self.autoload_pending = True self.install_selections = InstallSelections() - self.install_helper = install_helper - def onStart(self) -> None: + def onStart(self): npyscreen.setTheme(npyscreen.Themes.DefaultTheme) self.main_form = self.addForm( "MAIN", @@ -499,62 +613,138 @@ def onStart(self) -> None: ) -def list_models(installer: ModelInstallServiceBase, model_type: ModelType): - """Print out all models of type model_type.""" - models = installer.record_store.search_by_attr(model_type=model_type) - print(f"Installed models of type `{model_type}`:") - for model in models: - path = (config.models_path / model.path).resolve() - print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}") +class StderrToMessage: + def __init__(self, connection: Connection): + self.connection = connection + + def write(self, data: str): + self.connection.send_bytes(data.encode("utf-8")) + + def flush(self): + pass # -------------------------------------------------------- -def select_and_download_models(opt: Namespace) -> None: - """Prompt user for install/delete selections and execute.""" - precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) - # unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal - config.precision = precision # type: ignore - install_helper = InstallHelper(config, logger) - installer = install_helper.installer +def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType: + if tui_conn: + logger.debug("Waiting for user response...") + return _ask_user_for_pt_tui(model_path, tui_conn) + else: + return _ask_user_for_pt_cmdline(model_path) - if opt.list_models: - list_models(installer, opt.list_models) - elif opt.add or opt.delete: - selections = InstallSelections( - install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or [] - ) - install_helper.add_or_delete(selections) +def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]: + choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] + print( + f""" +Please select the scheduler prediction type of the checkpoint named {model_path.name}: +[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images +[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models +[3] Accept the best guess; you can fix it in the Web UI later +""" + ) + choice = None + ok = False + while not ok: + try: + choice = input("select [3]> ").strip() + if not choice: + return None + choice = choices[int(choice) - 1] + ok = True + except (ValueError, IndexError): + print(f"{choice} is not a valid choice") + 
except EOFError: + return + return choice + + +def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType: + tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8")) + # note that we don't do any status checking here + response = tui_conn.recv_bytes().decode("utf-8") + if response is None: + return None + elif response == "epsilon": + return SchedulerPredictionType.epsilon + elif response == "v": + return SchedulerPredictionType.VPrediction + elif response == "guess": + return None + else: + return None - elif opt.default_only: - default_model = install_helper.default_model() - assert default_model is not None - selections = InstallSelections(install_models=[default_model]) - install_helper.add_or_delete(selections) +# -------------------------------------------------------- +def process_and_execute( + opt: Namespace, + selections: InstallSelections, + conn_out: Connection = None, +): + # need to reinitialize config in subprocess + config = InvokeAIAppConfig.get_config() + args = ["--root", opt.root] if opt.root else [] + config.parse_args(args) + + # set up so that stderr is sent to conn_out + if conn_out: + translator = StderrToMessage(conn_out) + sys.stderr = translator + sys.stdout = translator + logger = InvokeAILogger.get_logger() + logger.handlers.clear() + logger.addHandler(logging.StreamHandler(translator)) + + installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out)) + installer.install(selections) + + if conn_out: + conn_out.send_bytes("*done*".encode("utf-8")) + conn_out.close() + + +# -------------------------------------------------------- +def select_and_download_models(opt: Namespace): + precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) + config.precision = precision + installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type) + if opt.list_models: + installer.list_models(opt.list_models) + elif opt.add or opt.delete: + selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or []) + installer.install(selections) + elif opt.default_only: + selections = InstallSelections(install_models=installer.default_model()) + installer.install(selections) elif opt.yes_to_all: - selections = InstallSelections(install_models=install_helper.recommended_models()) - install_helper.add_or_delete(selections) + selections = InstallSelections(install_models=installer.recommended_models()) + installer.install(selections) # this is where the TUI is called else: + # needed to support the probe() method running under a subprocess + torch.multiprocessing.set_start_method("spawn") + if not set_min_terminal_size(MIN_COLS, MIN_LINES): raise WindowTooSmallException( "Could not increase terminal size. Try running again with a larger window or smaller font size." 
) - installApp = AddModelApplication(opt, install_helper) + installApp = AddModelApplication(opt) try: installApp.run() - except KeyboardInterrupt: - print("Aborted...") - sys.exit(-1) - - install_helper.add_or_delete(installApp.install_selections) + except KeyboardInterrupt as e: + if hasattr(installApp, "main_form"): + if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive(): + logger.info("Terminating subprocesses") + installApp.main_form.subprocess.terminate() + installApp.main_form.subprocess = None + raise e + process_and_execute(opt, installApp.install_selections) # ------------------------------------- -def main() -> None: +def main(): parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( "--add", @@ -564,7 +754,7 @@ def main() -> None: parser.add_argument( "--delete", nargs="*", - help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`", + help="List of names of models to idelete", ) parser.add_argument( "--full-precision", @@ -591,6 +781,14 @@ def main() -> None: choices=[x.value for x in ModelType], help="list installed models", ) + parser.add_argument( + "--config_file", + "-c", + dest="config_file", + type=str, + default=None, + help="path to configuration file to create", + ) parser.add_argument( "--root_dir", dest="root", diff --git a/invokeai/frontend/install/widgets.py b/invokeai/frontend/install/widgets.py index 5905ae29dab..4dbc6349a0b 100644 --- a/invokeai/frontend/install/widgets.py +++ b/invokeai/frontend/install/widgets.py @@ -267,6 +267,17 @@ def h_select(self, ch): self.on_changed(self.value) +class CheckboxWithChanged(npyscreen.Checkbox): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.on_changed = None + + def whenToggled(self): + super().whenToggled() + if self.on_changed: + self.on_changed(self.value) + + class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged): """Row of radio buttons. 
Spacebar to select.""" diff --git a/invokeai/frontend/merge/merge_diffusers2.py b/invokeai/frontend/merge/merge_diffusers.py.OLD similarity index 100% rename from invokeai/frontend/merge/merge_diffusers2.py rename to invokeai/frontend/merge/merge_diffusers.py.OLD diff --git a/pyproject.toml b/pyproject.toml index 8b28375e291..2958e3629a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,8 +136,7 @@ dependencies = [ # full commands "invokeai-configure" = "invokeai.frontend.install.invokeai_configure:invokeai_configure" -"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers" -"invokeai-merge2" = "invokeai.frontend.merge.merge_diffusers2:main" +"invokeai-merge" = "invokeai.frontend.merge.merge_diffusers:main" "invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion" "invokeai-model-install" = "invokeai.frontend.install.model_install:main" "invokeai-model-install2" = "invokeai.frontend.install.model_install2:main" # will eventually be renamed to invokeai-model-install diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py deleted file mode 100644 index 3e48c7ed6fc..00000000000 --- a/tests/test_model_manager.py +++ /dev/null @@ -1,47 +0,0 @@ -from pathlib import Path - -import pytest - -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.backend import BaseModelType, ModelManager, ModelType, SubModelType - -BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main) -VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main) -VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main) - - -@pytest.fixture -def model_manager(datadir) -> ModelManager: - InvokeAIAppConfig.get_config(root=datadir) - return ModelManager(datadir / "configs" / "relative_sub.models.yaml") - - -def test_get_model_names(model_manager: ModelManager): - names = model_manager.model_names() - assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME] - - -def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2]) - top_model_path, is_override = model_manager._get_model_path(model_config) - expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0" - assert top_model_path == expected_model_path - assert not is_override - - -def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config( - VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2] - ) - vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) - expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix" - assert vae_model_path == expected_vae_path - assert is_override - - -def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config( - VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2] - ) - vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) - assert not is_override From d80575e99a7a205d9bdd09752b1614bf48edeab0 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 9 Feb 2024 20:46:47 -0500 Subject: [PATCH 080/340] probe for required encoder for IPAdapters and add to config --- 
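Context for the hunks below (a sketch only; probe_image_encoder_id is a hypothetical stand-in name, not part of the patch): the probe now reads the adapter's image_encoder.txt once at probe/install time, the result is persisted on the new IPAdapterConfig.image_encoder_model_id field, and the invocation resolves the CLIP Vision encoder from that field instead of re-reading the file from disk on every call.

    from pathlib import Path
    from typing import Optional

    def probe_image_encoder_id(model_path: Path) -> Optional[str]:
        # Same logic as the IPAdapterFolderProbe.get_image_encoder_model_id() added below:
        # the encoder's repo id is stored next to the adapter weights in image_encoder.txt.
        encoder_id_path = model_path / "image_encoder.txt"
        if not encoder_id_path.exists():
            return None
        with open(encoder_id_path, "r") as f:
            return f.readline().strip()

    # At probe/install time the value lands in the config record, e.g.
    #   IPAdapterConfig(..., image_encoder_model_id="InvokeAI/ip_adapter_sd_image_encoder")
    # and the invocation then looks the encoder up by name in the record store:
    #   name = ip_adapter_info.image_encoder_model_id.split("/")[-1].strip()
    #   context.services.model_records.search_by_attr(
    #       model_name=name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
    #   )
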
invokeai/app/invocations/ip_adapter.py | 24 +----------------------- invokeai/backend/model_manager/config.py | 1 + invokeai/backend/model_manager/probe.py | 13 +++++++++++++ 3 files changed, 15 insertions(+), 23 deletions(-) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 700b285a45f..f64b3266bbb 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -1,4 +1,3 @@ -import os from builtins import float from typing import List, Union @@ -52,16 +51,6 @@ def validate_begin_end_step_percent(self) -> Self: return self -def get_ip_adapter_image_encoder_model_id(model_path: str): - """Read the ID of the image encoder associated with the IP-Adapter at `model_path`.""" - image_encoder_config_file = os.path.join(model_path, "image_encoder.txt") - - with open(image_encoder_config_file, "r") as f: - image_encoder_model = f.readline().strip() - - return image_encoder_model - - @invocation_output("ip_adapter_output") class IPAdapterOutput(BaseInvocationOutput): # Outputs @@ -102,18 +91,7 @@ def validate_begin_end_step_percent(self) -> Self: def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key) - # HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model - # directly, and 2) we are reading from disk every time this invocation is called without caching the result. - # A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this - # is currently messy due to differences between how the model info is generated when installing a model from - # disk vs. downloading the model. - # TODO (LS): Fix the issue above by: - # 1. Change IPAdapterConfig definition to include a field for the repo_id of the image encoder model. - # 2. Update probe.py to read `image_encoder.txt` and store it in the config. - # 3. Change below to get the image encoder from the configuration record. 
- image_encoder_model_id = get_ip_adapter_image_encoder_model_id( - os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info.path) - ) + image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_models = context.services.model_records.search_by_attr( model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 0dcd925c84b..d2e7a0923a4 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -263,6 +263,7 @@ class IPAdapterConfig(ModelConfigBase): """Model config for IP Adaptor format models.""" type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter + image_encoder_model_id: str format: Literal[ModelFormat.InvokeAI] diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 55a9c0464a5..e7d21c578fd 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -78,6 +78,10 @@ def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: """Get model scheduler prediction type.""" return None + def get_image_encoder_model_id(self) -> Optional[str]: + """Get image encoder (IP adapters only).""" + return None + class ModelProbe(object): PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = { @@ -153,6 +157,7 @@ def probe( fields["base"] = fields.get("base") or probe.get_base_type() fields["variant"] = fields.get("variant") or probe.get_variant_type() fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type() + fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id() fields["name"] = fields.get("name") or cls.get_model_name(model_path) fields["description"] = ( fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" @@ -669,6 +674,14 @@ def get_base_type(self) -> BaseModelType: f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}." 
) + def get_image_encoder_model_id(self) -> Optional[str]: + encoder_id_path = self.model_path / "image_encoder.txt" + if not encoder_id_path.exists(): + return None + with open(encoder_id_path, "r") as f: + image_encoder_model = f.readline().strip() + return image_encoder_model + class CLIPVisionFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: From ec15e3a1668b8e0b744ae1d612eca2bb7f60fcce Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 9 Feb 2024 23:08:38 -0500 Subject: [PATCH 081/340] consolidate model manager parts into a single class --- invokeai/app/services/model_load/__init__.py | 6 + .../services/model_load/model_load_base.py | 22 + .../services/model_load/model_load_default.py | 54 +++ .../app/services/model_manager/__init__.py | 17 +- .../model_manager/model_manager_base.py | 294 +---------- .../model_manager/model_manager_default.py | 456 ++---------------- invokeai/backend/__init__.py | 9 - invokeai/backend/model_manager/config.py | 6 +- .../backend/model_manager/load/__init__.py | 2 +- invokeai/backend/model_manager/search.py | 12 +- 10 files changed, 184 insertions(+), 694 deletions(-) create mode 100644 invokeai/app/services/model_load/__init__.py create mode 100644 invokeai/app/services/model_load/model_load_base.py create mode 100644 invokeai/app/services/model_load/model_load_default.py diff --git a/invokeai/app/services/model_load/__init__.py b/invokeai/app/services/model_load/__init__.py new file mode 100644 index 00000000000..b4a86e9348d --- /dev/null +++ b/invokeai/app/services/model_load/__init__.py @@ -0,0 +1,6 @@ +"""Initialization file for model load service module.""" + +from .model_load_base import ModelLoadServiceBase +from .model_load_default import ModelLoadService + +__all__ = ["ModelLoadServiceBase", "ModelLoadService"] diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py new file mode 100644 index 00000000000..7228806e809 --- /dev/null +++ b/invokeai/app/services/model_load/model_load_base.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team +"""Base class for model loader.""" + +from abc import ABC, abstractmethod +from typing import Optional + +from invokeai.backend.model_manager import AnyModelConfig, SubModelType +from invokeai.backend.model_manager.load import LoadedModel + + +class ModelLoadServiceBase(ABC): + """Wrapper around AnyModelLoader.""" + + @abstractmethod + def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """Given a model's key, load it and return the LoadedModel object.""" + pass + + @abstractmethod + def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """Given a model's configuration, load it and return the LoadedModel object.""" + pass diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py new file mode 100644 index 00000000000..80e2fe161d0 --- /dev/null +++ b/invokeai/app/services/model_load/model_load_default.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 Lincoln D. 
Stein and the InvokeAI Team +"""Implementation of model loader service.""" + +from typing import Optional + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_records import ModelRecordServiceBase +from invokeai.backend.model_manager import AnyModelConfig, SubModelType +from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache +from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.load.ram_cache import ModelCacheBase +from invokeai.backend.util.logging import InvokeAILogger + +from .model_load_base import ModelLoadServiceBase + + +class ModelLoadService(ModelLoadServiceBase): + """Wrapper around AnyModelLoader.""" + + def __init__( + self, + app_config: InvokeAIAppConfig, + record_store: ModelRecordServiceBase, + ram_cache: Optional[ModelCacheBase] = None, + convert_cache: Optional[ModelConvertCacheBase] = None, + ): + """Initialize the model load service.""" + logger = InvokeAILogger.get_logger(self.__class__.__name__) + logger.setLevel(app_config.log_level.upper()) + self._store = record_store + self._any_loader = AnyModelLoader( + app_config=app_config, + logger=logger, + ram_cache=ram_cache + or ModelCache( + max_cache_size=app_config.ram_cache_size, + max_vram_cache_size=app_config.vram_cache_size, + logger=logger, + ), + convert_cache=convert_cache + or ModelConvertCache( + cache_path=app_config.models_convert_cache_path, + max_size=app_config.convert_cache_size, + ), + ) + + def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """Given a model's key, load it and return the LoadedModel object.""" + config = self._store.get_model(key) + return self.load_model_by_config(config, submodel_type) + + def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """Given a model's configuration, load it and return the LoadedModel object.""" + return self._any_loader.load_model(config, submodel_type) diff --git a/invokeai/app/services/model_manager/__init__.py b/invokeai/app/services/model_manager/__init__.py index 3d6a9c248c6..5e281922a8b 100644 --- a/invokeai/app/services/model_manager/__init__.py +++ b/invokeai/app/services/model_manager/__init__.py @@ -1 +1,16 @@ -from .model_manager_default import ModelManagerService # noqa F401 +"""Initialization file for model manager service.""" + +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager.load import LoadedModel + +from .model_manager_default import ModelManagerService + +__all__ = [ + "ModelManagerService", + "AnyModel", + "AnyModelConfig", + "BaseModelType", + "ModelType", + "SubModelType", + "LoadedModel", +] diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index f888c0ec973..c6e77fa163d 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -1,283 +1,39 @@ # Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Team -from __future__ import annotations - from abc import ABC, abstractmethod -from logging import Logger -from pathlib import Path -from typing import Callable, List, Literal, Optional, Tuple, Union - -from pydantic import Field - -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.shared.invocation_context import InvocationContextData -from invokeai.backend.model_management import ( - AddModelResult, - BaseModelType, - LoadedModelInfo, - MergeInterpolationMethod, - ModelType, - SchedulerPredictionType, - SubModelType, -) -from invokeai.backend.model_management.model_cache import CacheStats - - -class ModelManagerServiceBase(ABC): - """Responsible for managing models on disk and in memory""" - - @abstractmethod - def __init__( - self, - config: InvokeAIAppConfig, - logger: Logger, - ): - """ - Initialize with the path to the models.yaml config file. - Optional parameters are the torch device type, precision, max_models, - and sequential_offload boolean. Note that the default device - type and precision are set up for a CUDA system running at half precision. - """ - pass - @abstractmethod - def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModelInfo: - """Retrieve the indicated model with name and type. - submodel can be used to get a part (such as the vae) - of a diffusers pipeline.""" - pass - - @property - @abstractmethod - def logger(self): - pass - - @abstractmethod - def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> bool: - pass - - @abstractmethod - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Given a model name returns a dict-like (OmegaConf) object describing it. - Uses the exact format as the omegaconf stanza. - """ - pass +from pydantic import BaseModel, Field +from typing_extensions import Self - @abstractmethod - def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict: - """ - Return a dict of models in the format: - { model_type1: - { model_name1: {'status': 'active'|'cached'|'not loaded', - 'model_name' : name, - 'model_type' : SDModelType, - 'description': description, - 'format': 'folder'|'safetensors'|'ckpt' - }, - model_name2: { etc } - }, - model_type2: - { model_name_n: etc - } - """ - pass +from ..config import InvokeAIAppConfig +from ..download import DownloadQueueServiceBase +from ..events.events_base import EventServiceBase +from ..model_install import ModelInstallServiceBase +from ..model_load import ModelLoadServiceBase +from ..model_records import ModelRecordServiceBase +from ..shared.sqlite.sqlite_database import SqliteDatabase - @abstractmethod - def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: - """ - Return information about the model using the same format as list_models() - """ - pass - @abstractmethod - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Returns a list of all the model names known. 
- """ - pass +class ModelManagerServiceBase(BaseModel, ABC): + """Abstract base class for the model manager service.""" - @abstractmethod - def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - pass + store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.") + install: ModelInstallServiceBase = Field(description="An instance of the model install service.") + load: ModelLoadServiceBase = Field(description="An instance of the model load service.") + @classmethod @abstractmethod - def update_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - ) -> AddModelResult: + def build_model_manager( + cls, + app_config: InvokeAIAppConfig, + db: SqliteDatabase, + download_queue: DownloadQueueServiceBase, + events: EventServiceBase, + ) -> Self: """ - Update the named model with a dictionary of attributes. Will fail with a - ModelNotFoundException if the name does not already exist. + Construct the model manager service instance. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - pass - - @abstractmethod - def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ): - """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted - as well. Call commit() to write to disk. - """ - pass - - @abstractmethod - def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: str, - ): - """ - Rename the indicated model. - """ - pass - - @abstractmethod - def list_checkpoint_configs(self) -> List[Path]: - """ - List the checkpoint config paths from ROOT/configs/stable-diffusion. - """ - pass - - @abstractmethod - def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - ) -> AddModelResult: - """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] - - This will raise a ValueError unless the model is not a checkpoint. It will - also raise a ValueError in the event that there is a similarly-named diffusers - directory already in place. - """ - pass - - @abstractmethod - def heuristic_import( - self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> dict[str, AddModelResult]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. 
- :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. - - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. - """ - pass - - @abstractmethod - def merge_models( - self, - model_names: List[str] = Field( - default=None, min_length=2, max_length=3, description="List of model names to merge" - ), - base_model: Union[BaseModelType, str] = Field( - default=None, description="Base model shared by all models to be merged" - ), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), - alpha: Optional[float] = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: Optional[bool] = False, - merge_dest_directory: Optional[Path] = None, - ) -> AddModelResult: - """ - Merge two to three diffusrs pipeline models and save as a new model. - :param model_names: List of 2-3 models to merge - :param base_model: Base model to use for all models - :param merged_model_name: Name of destination merged model - :param alpha: Alpha strength to apply to 2d and 3d model - :param interp: Interpolation method. None (default) - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - """ - pass - - @abstractmethod - def search_for_models(self, directory: Path) -> List[Path]: - """ - Return list of all models found in the designated directory. - """ - pass - - @abstractmethod - def sync_to_config(self): - """ - Re-read models.yaml, rescan the models directory, and reimport models - in the autoimport directories. Call after making changes outside the - model manager API. - """ - pass - - @abstractmethod - def collect_cache_stats(self, cache_stats: CacheStats): - """ - Reset model cache statistics for graph with graph_id. - """ - pass - - @abstractmethod - def commit(self, conf_file: Optional[Path] = None) -> None: - """ - Write current configuration out to the indicated file. - If no conf_file is provided, then replaces the - original file/database used to initialize the object. + Use it rather than the __init__ constructor. This class + method simplifies the construction considerably. """ pass diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index c3712abf8e6..ad0fd66dbbd 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,421 +1,67 @@ # Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Team +"""Implementation of ModelManagerServiceBase.""" -from __future__ import annotations +from typing_extensions import Self -from logging import Logger -from pathlib import Path -from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union - -import torch -from pydantic import Field - -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.shared.invocation_context import InvocationContextData -from invokeai.backend.model_management import ( - AddModelResult, - BaseModelType, - LoadedModelInfo, - MergeInterpolationMethod, - ModelManager, - ModelMerger, - ModelNotFoundException, - ModelType, - SchedulerPredictionType, - SubModelType, -) -from invokeai.backend.model_management.model_cache import CacheStats -from invokeai.backend.model_management.model_search import FindModels -from invokeai.backend.util import choose_precision, choose_torch_device +from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache +from invokeai.backend.model_manager.metadata import ModelMetadataStore +from invokeai.backend.util.logging import InvokeAILogger +from ..config import InvokeAIAppConfig +from ..download import DownloadQueueServiceBase +from ..events.events_base import EventServiceBase +from ..model_install import ModelInstallService +from ..model_load import ModelLoadService +from ..model_records import ModelRecordServiceSQL +from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_manager_base import ModelManagerServiceBase -if TYPE_CHECKING: - pass - -# simple implementation class ModelManagerService(ModelManagerServiceBase): - """Responsible for managing models on disk and in memory""" - - def __init__( - self, - config: InvokeAIAppConfig, - logger: Logger, - ): - """ - Initialize with the path to the models.yaml config file. - Optional parameters are the torch device type, precision, max_models, - and sequential_offload boolean. Note that the default device - type and precision are set up for a CUDA system running at half precision. - """ - if config.model_conf_path and config.model_conf_path.exists(): - config_file = config.model_conf_path - else: - config_file = config.root_dir / "configs/models.yaml" - - logger.debug(f"Config file={config_file}") - - device = torch.device(choose_torch_device()) - device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else "" - logger.info(f"GPU device = {device} {device_name}") - - precision = config.precision - if precision == "auto": - precision = choose_precision(device) - dtype = torch.float32 if precision == "float32" else torch.float16 - - # this is transitional backward compatibility - # support for the deprecated `max_loaded_models` - # configuration value. If present, then the - # cache size is set to 2.5 GB times - # the number of max_loaded_models. Otherwise - # use new `ram_cache_size` config setting - max_cache_size = config.ram_cache_size + """ + The ModelManagerService handles various aspects of model installation, maintenance and loading. - logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB") + It bundles three distinct services: + model_manager.store -- Routines to manage the database of model configuration records. + model_manager.install -- Routines to install, move and delete models. 
+ model_manager.load -- Routines to load models into memory. + """ - sequential_offload = config.sequential_guidance - - self.mgr = ModelManager( - config=config_file, - device_type=device, - precision=dtype, - max_cache_size=max_cache_size, - sequential_offload=sequential_offload, - logger=logger, - ) - logger.info("Model manager service initialized") - - def start(self, invoker: Invoker) -> None: - self._invoker: Optional[Invoker] = invoker - - def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModelInfo: - """ - Retrieve the indicated model. submodel can be used to get a - part (such as the vae) of a diffusers mode. + @classmethod + def build_model_manager( + cls, + app_config: InvokeAIAppConfig, + db: SqliteDatabase, + download_queue: DownloadQueueServiceBase, + events: EventServiceBase, + ) -> Self: """ + Construct the model manager service instance. - # we can emit model loading events if we are executing with access to the invocation context - if context_data is not None: - self._emit_load_event( - context_data=context_data, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - ) + For simplicity, use this class method rather than the __init__ constructor. + """ + logger = InvokeAILogger.get_logger(cls.__name__) + logger.setLevel(app_config.log_level.upper()) - loaded_model_info = self.mgr.get_model( - model_name, - base_model, - model_type, - submodel, + ram_cache = ModelCache( + max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size, logger=logger ) - - if context_data is not None: - self._emit_load_event( - context_data=context_data, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - loaded_model_info=loaded_model_info, - ) - - return loaded_model_info - - def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> bool: - """ - Given a model name, returns True if it is a valid - identifier. - """ - return self.mgr.model_exists( - model_name, - base_model, - model_type, + convert_cache = ModelConvertCache( + cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size ) - - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]: - """ - Given a model name returns a dict-like (OmegaConf) object describing it. - """ - return self.mgr.model_info(model_name, base_model, model_type) - - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Returns a list of all the model names known. - """ - return self.mgr.model_names() - - def list_models( - self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None - ) -> list[dict]: - """ - Return a list of models. 
- """ - return self.mgr.list_models(base_model, model_type) - - def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]: - """ - Return information about the model using the same format as list_models() - """ - return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type) - - def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - self.logger.debug(f"add/update model {model_name}") - return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) - - def update_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with a - ModelNotFoundException exception if the name does not already exist. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or - the model name is missing. Call commit() to write changes to disk. - """ - self.logger.debug(f"update model {model_name}") - if not self.model_exists(model_name, base_model, model_type): - raise ModelNotFoundException(f"Unknown model {model_name}") - return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) - - def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ): - """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted - as well. - """ - self.logger.debug(f"delete model {model_name}") - self.mgr.del_model(model_name, base_model, model_type) - self.mgr.commit() - - def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - convert_dest_directory: Optional[Path] = Field( - default=None, description="Optional directory location for merged model" - ), - ) -> AddModelResult: - """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] - :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default) - - This will raise a ValueError unless the model is not a checkpoint. It will - also raise a ValueError in the event that there is a similarly-named diffusers - directory already in place. - """ - self.logger.debug(f"convert model {model_name}") - return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory) - - def collect_cache_stats(self, cache_stats: CacheStats): - """ - Reset model cache statistics for graph with graph_id. 
- """ - self.mgr.cache.stats = cache_stats - - def commit(self, conf_file: Optional[Path] = None): - """ - Write current configuration out to the indicated file. - If no conf_file is provided, then replaces the - original file/database used to initialize the object. - """ - return self.mgr.commit(conf_file) - - def _emit_load_event( - self, - context_data: InvocationContextData, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - loaded_model_info: Optional[LoadedModelInfo] = None, - ): - if self._invoker is None: - return - - if self._invoker.services.queue.is_canceled(context_data.session_id): - raise CanceledException() - - if loaded_model_info: - self._invoker.services.events.emit_model_load_completed( - queue_id=context_data.queue_id, - queue_item_id=context_data.queue_item_id, - queue_batch_id=context_data.batch_id, - graph_execution_state_id=context_data.session_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - loaded_model_info=loaded_model_info, - ) - else: - self._invoker.services.events.emit_model_load_started( - queue_id=context_data.queue_id, - queue_item_id=context_data.queue_item_id, - queue_batch_id=context_data.batch_id, - graph_execution_state_id=context_data.session_id, - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - ) - - @property - def logger(self): - return self.mgr.logger - - def heuristic_import( - self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> dict[str, AddModelResult]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. - :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. - - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. - """ - return self.mgr.heuristic_import(items_to_import, prediction_type_helper) - - def merge_models( - self, - model_names: List[str] = Field( - default=None, min_length=2, max_length=3, description="List of model names to merge" - ), - base_model: Union[BaseModelType, str] = Field( - default=None, description="Base model shared by all models to be merged" - ), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - merge_dest_directory: Optional[Path] = Field( - default=None, description="Optional directory location for merged model" - ), - ) -> AddModelResult: - """ - Merge two to three diffusrs pipeline models and save as a new model. 
- :param model_names: List of 2-3 models to merge - :param base_model: Base model to use for all models - :param merged_model_name: Name of destination merged model - :param alpha: Alpha strength to apply to 2d and 3d model - :param interp: Interpolation method. None (default) - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - """ - merger = ModelMerger(self.mgr) - try: - result = merger.merge_diffusion_models_and_save( - model_names=model_names, - base_model=base_model, - merged_model_name=merged_model_name, - alpha=alpha, - interp=interp, - force=force, - merge_dest_directory=merge_dest_directory, - ) - except AssertionError as e: - raise ValueError(e) - return result - - def search_for_models(self, directory: Path) -> List[Path]: - """ - Return list of all models found in the designated directory. - """ - search = FindModels([directory], self.logger) - return search.list_models() - - def sync_to_config(self): - """ - Re-read models.yaml, rescan the models directory, and reimport models - in the autoimport directories. Call after making changes outside the - model manager API. - """ - return self.mgr.sync_to_config() - - def list_checkpoint_configs(self) -> List[Path]: - """ - List the checkpoint config paths from ROOT/configs/stable-diffusion. - """ - config = self.mgr.app_config - conf_path = config.legacy_conf_path - root_path = config.root_path - return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")] - - def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: Optional[str] = None, - new_base: Optional[BaseModelType] = None, - ): - """ - Rename the indicated model. Can provide a new name and/or a new base. 
- :param model_name: Current name of the model - :param base_model: Current base of the model - :param model_type: Model type (can't be changed) - :param new_name: New name for the model - :param new_base: New base for the model - """ - self.mgr.rename_model( - base_model=base_model, - model_type=model_type, - model_name=model_name, - new_name=new_name, - new_base=new_base, + record_store = ModelRecordServiceSQL(db=db) + loader = ModelLoadService( + app_config=app_config, + record_store=record_store, + ram_cache=ram_cache, + convert_cache=convert_cache, + ) + record_store._loader = loader # yeah, there is a circular reference here + installer = ModelInstallService( + app_config=app_config, + record_store=record_store, + download_queue=download_queue, + metadata_store=ModelMetadataStore(db=db), + event_bus=events, ) + return cls(store=record_store, install=installer, load=loader) diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index 54a1843d463..9fe97ee525e 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -1,12 +1,3 @@ """ Initialization file for invokeai.backend """ -from .model_management import ( # noqa: F401 - BaseModelType, - LoadedModelInfo, - ModelCache, - ModelManager, - ModelType, - SubModelType, -) -from .model_management.models import SilenceWarnings # noqa: F401 diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index d2e7a0923a4..4534a4892fb 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -21,7 +21,7 @@ """ import time from enum import Enum -from typing import Literal, Optional, Type, Union +from typing import Literal, Optional, Type, Union, Class import torch from diffusers import ModelMixin @@ -333,9 +333,9 @@ class ModelConfigFactory(object): @classmethod def make_config( cls, - model_data: Union[dict, AnyModelConfig], + model_data: Union[Dict[str, Any], AnyModelConfig], key: Optional[str] = None, - dest_class: Optional[Type] = None, + dest_class: Optional[Type[Class]] = None, timestamp: Optional[float] = None, ) -> AnyModelConfig: """ diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index e4c7077f783..966a739237a 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -18,7 +18,7 @@ for module in loaders: import_module(f"{__package__}.model_loaders.{module}") -__all__ = ["AnyModelLoader", "LoadedModel"] +__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"] def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader: diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index 4cc3caebe47..a54938fdd5c 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -26,10 +26,10 @@ def find_main_models(model: Path) -> bool: from typing import Callable, Optional, Set, Union from pydantic import BaseModel, Field - +from logging import Logger from invokeai.backend.util.logging import InvokeAILogger -default_logger = InvokeAILogger.get_logger() +default_logger: Logger = InvokeAILogger.get_logger() class SearchStats(BaseModel): @@ -56,7 +56,7 @@ class ModelSearchBase(ABC, BaseModel): on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221 on_search_completed : Optional[Callable[[Set[Path]], None]] = 
Field(default=None, description="Called when search is complete.") # noqa E221 stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221 - logger : InvokeAILogger = Field(default=default_logger, description="Logger instance.") # noqa E221 + logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221 # fmt: on class Config: @@ -128,13 +128,13 @@ def search_started(self) -> None: def model_found(self, model: Path) -> None: self.stats.models_found += 1 - if not self.on_model_found or self.on_model_found(model): + if self.on_model_found is None or self.on_model_found(model): self.stats.models_filtered += 1 self.models_found.add(model) def search_completed(self) -> None: - if self.on_search_completed: - self.on_search_completed(self._models_found) + if self.on_search_completed is not None: + self.on_search_completed(self.models_found) def search(self, directory: Union[Path, str]) -> Set[Path]: self._directory = Path(directory) From 387c8d901cdda12c05f088354ced25814494e70c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 10 Feb 2024 18:09:45 -0500 Subject: [PATCH 082/340] make model manager v2 ready for PR review - Replace legacy model manager service with the v2 manager. - Update invocations to use new load interface. - Fixed many but not all type checking errors in the invocations. Most were unrelated to model manager - Updated routes. All the new routes live under the route tag `model_manager_v2`. To avoid confusion with the old routes, they have the URL prefix `/api/v2/models`. The old routes have been de-registered. - Added a pytest for the loader. - Updated documentation in contributing/MODEL_MANAGER.md --- docs/contributing/MODEL_MANAGER.md | 223 ++++++++++++------ invokeai/app/api/dependencies.py | 29 +-- .../{model_records.py => model_manager_v2.py} | 82 ++++--- invokeai/app/api/routers/models.py | 3 +- invokeai/app/api_app.py | 5 +- invokeai/app/invocations/compel.py | 43 ++-- invokeai/app/invocations/latent.py | 168 ++++++++----- invokeai/app/invocations/model.py | 2 +- invokeai/app/services/invocation_services.py | 6 - .../invocation_stats_default.py | 9 +- .../services/model_load/model_load_base.py | 60 ++++- .../services/model_load/model_load_default.py | 115 ++++++++- .../model_manager/model_manager_base.py | 38 ++- .../model_manager/model_manager_default.py | 39 ++- .../model_records/model_records_base.py | 49 ---- .../model_records/model_records_sql.py | 99 +------- .../sqlite_migrator/migrations/migration_6.py | 19 ++ invokeai/backend/embeddings/model_patcher.py | 4 +- invokeai/backend/image_util/safety_checker.py | 2 +- invokeai/backend/ip_adapter/ip_adapter.py | 4 +- invokeai/backend/model_manager/config.py | 13 +- .../backend/model_manager/load/load_base.py | 18 +- .../model_manager/load/load_default.py | 2 +- .../load/model_cache/model_cache_default.py | 2 +- .../load/model_loaders/controlnet.py | 12 +- .../load/model_loaders/generic_diffusers.py | 5 +- .../load/model_loaders/stable_diffusion.py | 8 +- .../model_manager/load/model_loaders/vae.py | 8 +- .../backend/model_manager/load/model_util.py | 2 +- invokeai/backend/model_manager/search.py | 3 +- .../stable_diffusion/schedulers/__init__.py | 2 + invokeai/frontend/install/model_install.py | 2 +- tests/aa_nodes/test_graph_execution_state.py | 2 - tests/aa_nodes/test_invoker.py | 2 - .../model_loading/test_model_load.py | 22 ++ .../model_manager_2_fixtures.py | 11 + 36 files changed, 679 insertions(+), 434 deletions(-) rename 
invokeai/app/api/routers/{model_records.py => model_manager_v2.py} (86%) create mode 100644 tests/backend/model_manager_2/model_loading/test_model_load.py diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 880c8b24801..39220f4ba89 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -28,7 +28,7 @@ model. These are the: Hugging Face, as well as discriminating among model versions in Civitai, but can be used for arbitrary content. - * _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**) + * _ModelLoadServiceBase_ Responsible for loading a model from disk into RAM and VRAM and getting it ready for inference. @@ -41,10 +41,10 @@ The four main services can be found in * `invokeai/app/services/model_records/` * `invokeai/app/services/model_install/` * `invokeai/app/services/downloads/` -* `invokeai/app/services/model_loader/` (**under development**) +* `invokeai/app/services/model_load/` Code related to the FastAPI web API can be found in -`invokeai/app/api/routers/model_records.py`. +`invokeai/app/api/routers/model_manager_v2.py`. *** @@ -84,10 +84,10 @@ diffusers model. When this happens, `original_hash` is unchanged, but `ModelType`, `ModelFormat` and `BaseModelType` are string enums that are defined in `invokeai.backend.model_manager.config`. They are also imported by, and can be reexported from, -`invokeai.app.services.model_record_service`: +`invokeai.app.services.model_manager.model_records`: ``` -from invokeai.app.services.model_record_service import ModelType, ModelFormat, BaseModelType +from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType ``` The `path` field can be absolute or relative. If relative, it is taken @@ -123,7 +123,7 @@ taken to be the `models_dir` directory. `variant` is an enumerated string class with values `normal`, `inpaint` and `depth`. If needed, it can be imported if needed from -either `invokeai.app.services.model_record_service` or +either `invokeai.app.services.model_records` or `invokeai.backend.model_manager.config`. ### ONNXSD2Config @@ -134,7 +134,7 @@ either `invokeai.app.services.model_record_service` or | `upcast_attention` | bool | Model requires its attention module to be upcast | The `SchedulerPredictionType` enum can be imported from either -`invokeai.app.services.model_record_service` or +`invokeai.app.services.model_records` or `invokeai.backend.model_manager.config`. ### Other config classes @@ -157,15 +157,6 @@ indicates that the model is compatible with any of the base models. This works OK for some models, such as the IP Adapter image encoders, but is an all-or-nothing proposition. -Another issue is that the config class hierarchy is paralleled to some -extent by a `ModelBase` class hierarchy defined in -`invokeai.backend.model_manager.models.base` and its subclasses. These -are classes representing the models after they are loaded into RAM and -include runtime information such as load status and bytes used. Some -of the fields, including `name`, `model_type` and `base_model`, are -shared between `ModelConfigBase` and `ModelBase`, and this is a -potential source of confusion. 
- ## Reading and Writing Model Configuration Records The `ModelRecordService` provides the ability to retrieve model @@ -177,11 +168,11 @@ initialization and can be retrieved within an invocation from the `InvocationContext` object: ``` -store = context.services.model_record_store +store = context.services.model_manager.store ``` or from elsewhere in the code by accessing -`ApiDependencies.invoker.services.model_record_store`. +`ApiDependencies.invoker.services.model_manager.store`. ### Creating a `ModelRecordService` @@ -190,7 +181,7 @@ you can directly create either a `ModelRecordServiceSQL` or a `ModelRecordServiceFile` object: ``` -from invokeai.app.services.model_record_service import ModelRecordServiceSQL, ModelRecordServiceFile +from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceFile store = ModelRecordServiceSQL.from_connection(connection, lock) store = ModelRecordServiceSQL.from_db_file('/path/to/sqlite_database.db') @@ -252,7 +243,7 @@ So a typical startup pattern would be: ``` import sqlite3 from invokeai.app.services.thread import lock -from invokeai.app.services.model_record_service import ModelRecordServiceBase +from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.config import InvokeAIAppConfig config = InvokeAIAppConfig.get_config() @@ -260,19 +251,6 @@ db_conn = sqlite3.connect(config.db_path.as_posix(), check_same_thread=False) store = ModelRecordServiceBase.open(config, db_conn, lock) ``` -_A note on simultaneous access to `invokeai.db`_: The current InvokeAI -service architecture for the image and graph databases is careful to -use a shared sqlite3 connection and a thread lock to ensure that two -threads don't attempt to access the database simultaneously. However, -the default `sqlite3` library used by Python reports using -**Serialized** mode, which allows multiple threads to access the -database simultaneously using multiple database connections (see -https://www.sqlite.org/threadsafe.html and -https://ricardoanderegg.com/posts/python-sqlite-thread-safety/). Therefore -it should be safe to allow the record service to open its own SQLite -database connection. Opening a model record service should then be as -simple as `ModelRecordServiceBase.open(config)`. - ### Fetching a Model's Configuration from `ModelRecordServiceBase` Configurations can be retrieved in several ways. @@ -1465,7 +1443,7 @@ create alternative instances if you wish. ### Creating a ModelLoadService object The class is defined in -`invokeai.app.services.model_loader_service`. It is initialized with +`invokeai.app.services.model_load`. It is initialized with an InvokeAIAppConfig object, from which it gets configuration information such as the user's desired GPU and precision, and with a previously-created `ModelRecordServiceBase` object, from which it @@ -1475,8 +1453,8 @@ Here is a typical initialization pattern: ``` from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_record_service import ModelRecordServiceBase -from invokeai.app.services.model_loader_service import ModelLoadService +from invokeai.app.services.model_records import ModelRecordServiceBase +from invokeai.app.services.model_load import ModelLoadService config = InvokeAIAppConfig.get_config() store = ModelRecordServiceBase.open(config) @@ -1487,14 +1465,11 @@ Note that we are relying on the contents of the application configuration to choose the implementation of `ModelRecordServiceBase`. 
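In the running application you will not normally build the loader by hand.
Instead, a bundled `ModelManagerService` is constructed once at startup via
its `build_model_manager()` class method, and the loader is then available
as the service's `load` attribute (alongside the record store as `store`
and the installer as `install`). The following is a minimal sketch only; it
assumes that `db`, `download_queue` and `events` have already been created,
as `dependencies.py` does during startup:

```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService

config = InvokeAIAppConfig.get_config()

# db, download_queue and events are assumed to exist already -- they are
# normally created during application startup in dependencies.py.
mm = ModelManagerService.build_model_manager(
    app_config=config,
    db=db,
    download_queue=download_queue,
    events=events,
)

store = mm.store   # ModelRecordServiceBase -- the record store
loader = mm.load   # ModelLoadService -- the loader described below
```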
-### get_model(key, [submodel_type], [context]) -> ModelInfo: +### load_model_by_key(key, [submodel_type], [context]) -> LoadedModel -*** TO DO: change to get_model(key, context=None, **kwargs) - -The `get_model()` method, like its similarly-named cousin in -`ModelRecordService`, receives the unique key that identifies the -model. It loads the model into memory, gets the model ready for use, -and returns a `ModelInfo` object. +The `load_model_by_key()` method receives the unique key that +identifies the model. It loads the model into memory, gets the model +ready for use, and returns a `LoadedModel` object. The optional second argument, `subtype` is a `SubModelType` string enum, such as "vae". It is mandatory when used with a main model, and @@ -1504,46 +1479,64 @@ The optional third argument, `context` can be provided by an invocation to trigger model load event reporting. See below for details. -The returned `ModelInfo` object shares some fields in common with -`ModelConfigBase`, but is otherwise a completely different beast: +The returned `LoadedModel` object contains a copy of the configuration +record returned by the model record `get_model()` method, as well as +the in-memory loaded model: -| **Field Name** | **Type** | **Description** | + +| **Attribute Name** | **Type** | **Description** | |----------------|-----------------|------------------| -| `key` | str | The model key derived from the ModelRecordService database | -| `name` | str | Name of this model | -| `base_model` | BaseModelType | Base model for this model | -| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)| -| `location` | Path or str | Location of the model on the filesystem | -| `precision` | torch.dtype | The torch.precision to use for inference | -| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use | +| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. | +| `model` | AnyModel | The instantiated model (details below) | +| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM | + +Because the loader can return multiple model types, it is typed to +return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`, +`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and +`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers +models, `EmbeddingModelRaw` is used for LoRA and TextualInversion +models. The others are obvious. -The types for `ModelInfo` and `SubModelType` can be imported from -`invokeai.app.services.model_loader_service`. -To use the model, you use the `ModelInfo` as a context manager using -the following pattern: +`LoadedModel` acts as a context manager. The context loads the model +into the execution device (e.g. VRAM on CUDA systems), locks the model +in the execution device for the duration of the context, and returns +the model. Use it like this: ``` -model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae')) +model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae')) with model_info as vae: image = vae.decode(latents)[0] ``` -The `vae` model will stay locked in the GPU during the period of time -it is in the context manager's scope. 
+`get_model_by_key()` may raise any of the following exceptions: -`get_model()` may raise any of the following exceptions: - -- `UnknownModelException` -- key not in database -- `ModelNotFoundException` -- key in database but model not found at path -- `InvalidModelException` -- the model is guilty of a variety of sins +- `UnknownModelException` -- key not in database +- `ModelNotFoundException` -- key in database but model not found at path +- `NotImplementedException` -- the loader doesn't know how to load this type of model -** TO DO: ** Resolve discrepancy between ModelInfo.location and -ModelConfig.path. +### load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel + +This is similar to `load_model_by_key`, but instead it accepts the +combination of the model's name, type and base, which it passes to the +model record config store for retrieval. If successful, this method +returns a `LoadedModel`. It can raise the following exceptions: + +``` +UnknownModelException -- model with these attributes not known +NotImplementedException -- the loader doesn't know how to load this type of model +ValueError -- more than one model matches this combination of base/type/name +``` + +### load_model_by_config(config, [submodel], [context]) -> LoadedModel + +This method takes an `AnyModelConfig` returned by +ModelRecordService.get_model() and returns the corresponding loaded +model. It may raise a `NotImplementedException`. ### Emitting model loading events -When the `context` argument is passed to `get_model()`, it will +When the `context` argument is passed to `load_model_*()`, it will retrieve the invocation event bus from the passed `InvocationContext` object to emit events on the invocation bus. The two events are "model_load_started" and "model_load_completed". Both carry the @@ -1563,3 +1556,97 @@ payload=dict( ) ``` +### Adding Model Loaders + +Model loaders are small classes that inherit from the `ModelLoader` +base class. They typically implement one method `_load_model()` whose +signature is: + +``` +def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, +) -> AnyModel: +``` + +`_load_model()` will be passed the path to the model on disk, an +optional repository variant (used by the diffusers loaders to select, +e.g. the `fp16` variant, and an optional submodel_type for main and +onnx models. + +To install a new loader, place it in +`invokeai/backend/model_manager/load/model_loaders`. Inherit from +`ModelLoader` and use the `@AnyModelLoader.register()` decorator to +indicate what type of models the loader can handle. 
+ +Here is a complete example from `generic_diffusers.py`, which is able +to load several different diffusers types: + +``` +from pathlib import Path +from typing import Optional + +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from ..load_base import AnyModelLoader +from ..load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) +class GenericDiffusersLoader(ModelLoader): + """Class to load simple diffusers models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + model_class = self._get_hf_load_class(model_path) + if submodel_type is not None: + raise Exception(f"There are no submodels in models of type {model_class}") + variant = model_variant.value if model_variant else None + result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore + return result +``` + +Note that a loader can register itself to handle several different +model types. An exception will be raised if more than one loader tries +to register the same model type. + +#### Conversion + +Some models require conversion to diffusers format before they can be +loaded. These loaders should override two additional methods: + +``` +_needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool +_convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: +``` + +The first method accepts the model configuration, the path to where +the unmodified model is currently installed, and a proposed +destination for the converted model. This method returns True if the +model needs to be converted. It typically does this by comparing the +last modification time of the original model file to the modification +time of the converted model. In some cases you will also want to check +the modification date of the configuration record, in the event that +the user has changed something like the scheduler prediction type that +will require the model to be re-converted. See `controlnet.py` for an +example of this logic. + +The second method accepts the model configuration, the path to the +original model on disk, and the desired output path for the converted +model. It does whatever it needs to do to get the model into diffusers +format, and returns the Path of the resulting model. (The path should +ordinarily be the same as `output_path`.) 
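+To make the shape of a converting loader concrete, here is a minimal,
+hypothetical sketch. The class name and the registration arguments are
+illustrative only, the import paths follow the `generic_diffusers.py`
+example above (with `AnyModelConfig` assumed to be exported from the same
+package), and the body of `_convert_model()` is left as a placeholder for
+whatever conversion routine the model type actually requires (see
+`controlnet.py` for a real implementation). The timestamp comparison
+mirrors the logic described above:
+
+```
+from pathlib import Path
+
+from invokeai.backend.model_manager import (
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+)
+from ..load_base import AnyModelLoader
+from ..load_default import ModelLoader
+
+
+@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
+class HypotheticalCheckpointLoader(ModelLoader):
+    """Illustrative loader that converts a checkpoint to diffusers format before loading."""
+
+    def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
+        # Convert if no converted copy exists yet, or if the original model
+        # file has been modified more recently than the converted copy.
+        if not dest_path.exists():
+            return True
+        return model_path.stat().st_mtime > dest_path.stat().st_mtime
+
+    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
+        # Placeholder: write a diffusers-format copy of the model to output_path.
+        ...
+        return output_path
+```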
+ diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index dcb8d219971..378961a0557 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -8,9 +8,6 @@ from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.shared.sqlite.sqlite_util import init_db -from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache -from invokeai.backend.model_manager.load.model_cache import ModelCache -from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData from invokeai.backend.util.logging import InvokeAILogger from invokeai.version.invokeai_version import __version__ @@ -30,9 +27,7 @@ from ..services.invocation_services import InvocationServices from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invoker import Invoker -from ..services.model_install import ModelInstallService from ..services.model_manager.model_manager_default import ModelManagerService -from ..services.model_records import ModelRecordServiceSQL from ..services.names.names_default import SimpleNameService from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue @@ -98,28 +93,10 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger conditioning = ObjectSerializerForwardCache( ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ) - model_manager = ModelManagerService(config, logger) - model_record_service = ModelRecordServiceSQL(db=db) - model_loader = AnyModelLoader( - app_config=config, - logger=logger, - ram_cache=ModelCache( - max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger - ), - convert_cache=ModelConvertCache( - cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size - ), - ) - model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader) download_queue_service = DownloadQueueService(event_bus=events) - model_install_service = ModelInstallService( - app_config=config, - record_store=model_record_service, - download_queue=download_queue_service, - metadata_store=ModelMetadataStore(db=db), - event_bus=events, + model_manager = ModelManagerService.build_model_manager( + app_config=configuration, db=db, download_queue=download_queue_service, events=events ) - model_manager = ModelManagerService(config, logger) # TO DO: legacy model manager v1. 
Remove names = SimpleNameService() performance_statistics = InvocationStatsService() processor = DefaultInvocationProcessor() @@ -143,9 +120,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger invocation_cache=invocation_cache, logger=logger, model_manager=model_manager, - model_records=model_record_service, download_queue=download_queue_service, - model_install=model_install_service, names=names, performance_statistics=performance_statistics, processor=processor, diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_manager_v2.py similarity index 86% rename from invokeai/app/api/routers/model_records.py rename to invokeai/app/api/routers/model_manager_v2.py index f9a3e408985..4fc785e4f7a 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_manager_v2.py @@ -32,7 +32,7 @@ from ..dependencies import ApiDependencies -model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"]) +model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"]) class ModelsList(BaseModel): @@ -52,7 +52,7 @@ class ModelTagSet(BaseModel): tags: Set[str] -@model_records_router.get( +@model_manager_v2_router.get( "/", operation_id="list_model_records", ) @@ -65,7 +65,7 @@ async def list_model_records( ), ) -> ModelsList: """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store found_models: list[AnyModelConfig] = [] if base_models: for base_model in base_models: @@ -81,7 +81,7 @@ async def list_model_records( return ModelsList(models=found_models) -@model_records_router.get( +@model_manager_v2_router.get( "/i/{key}", operation_id="get_model_record", responses={ @@ -94,24 +94,27 @@ async def get_model_record( key: str = Path(description="Key of the model record to fetch."), ) -> AnyModelConfig: """Get a model record""" - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store try: - return record_store.get_model(key) + config: AnyModelConfig = record_store.get_model(key) + return config except UnknownModelException as e: raise HTTPException(status_code=404, detail=str(e)) -@model_records_router.get("/meta", operation_id="list_model_summary") +@model_manager_v2_router.get("/meta", operation_id="list_model_summary") async def list_model_summary( page: int = Query(default=0, description="The page to get"), per_page: int = Query(default=10, description="The number of models per page"), order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"), ) -> PaginatedResults[ModelSummary]: """Gets a page of model summary data.""" - return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by) + record_store = ApiDependencies.invoker.services.model_manager.store + results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by) + return results -@model_records_router.get( +@model_manager_v2_router.get( "/meta/i/{key}", operation_id="get_model_metadata", responses={ @@ -124,24 +127,25 @@ async def get_model_metadata( key: str = Path(description="Key of the model repo metadata to fetch."), ) -> Optional[AnyModelRepoMetadata]: """Get a model metadata object.""" - record_store = ApiDependencies.invoker.services.model_records - result = 
record_store.get_metadata(key) + record_store = ApiDependencies.invoker.services.model_manager.store + result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key) if not result: raise HTTPException(status_code=404, detail="No metadata for a model with this key") return result -@model_records_router.get( +@model_manager_v2_router.get( "/tags", operation_id="list_tags", ) async def list_tags() -> Set[str]: """Get a unique set of all the model tags.""" - record_store = ApiDependencies.invoker.services.model_records - return record_store.list_tags() + record_store = ApiDependencies.invoker.services.model_manager.store + result: Set[str] = record_store.list_tags() + return result -@model_records_router.get( +@model_manager_v2_router.get( "/tags/search", operation_id="search_by_metadata_tags", ) @@ -149,12 +153,12 @@ async def search_by_metadata_tags( tags: Set[str] = Query(default=None, description="Tags to search for"), ) -> ModelsList: """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store results = record_store.search_by_metadata_tag(tags) return ModelsList(models=results) -@model_records_router.patch( +@model_manager_v2_router.patch( "/i/{key}", operation_id="update_model_record", responses={ @@ -172,9 +176,9 @@ async def update_model_record( ) -> AnyModelConfig: """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" logger = ApiDependencies.invoker.services.logger - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store try: - model_response = record_store.update_model(key, config=info) + model_response: AnyModelConfig = record_store.update_model(key, config=info) logger.info(f"Updated model: {key}") except UnknownModelException as e: raise HTTPException(status_code=404, detail=str(e)) @@ -184,7 +188,7 @@ async def update_model_record( return model_response -@model_records_router.delete( +@model_manager_v2_router.delete( "/i/{key}", operation_id="del_model_record", responses={ @@ -205,7 +209,7 @@ async def del_model_record( logger = ApiDependencies.invoker.services.logger try: - installer = ApiDependencies.invoker.services.model_install + installer = ApiDependencies.invoker.services.model_manager.install installer.delete(key) logger.info(f"Deleted model: {key}") return Response(status_code=204) @@ -214,7 +218,7 @@ async def del_model_record( raise HTTPException(status_code=404, detail=str(e)) -@model_records_router.post( +@model_manager_v2_router.post( "/i/", operation_id="add_model_record", responses={ @@ -229,7 +233,7 @@ async def add_model_record( ) -> AnyModelConfig: """Add a model using the configuration information appropriate for its type.""" logger = ApiDependencies.invoker.services.logger - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store if config.key == "": config.key = sha1(randbytes(100)).hexdigest() logger.info(f"Created model {config.key} for {config.name}") @@ -243,10 +247,11 @@ async def add_model_record( raise HTTPException(status_code=415) # now fetch it out - return record_store.get_model(config.key) + result: AnyModelConfig = record_store.get_model(config.key) + return result -@model_records_router.post( +@model_manager_v2_router.post( "/import", operation_id="import_model_record", responses={ @@ -322,7 +327,7 @@ async def 
import_model( logger = ApiDependencies.invoker.services.logger try: - installer = ApiDependencies.invoker.services.model_install + installer = ApiDependencies.invoker.services.model_manager.install result: ModelInstallJob = installer.import_model( source=source, config=config, @@ -340,17 +345,17 @@ async def import_model( return result -@model_records_router.get( +@model_manager_v2_router.get( "/import", operation_id="list_model_install_jobs", ) async def list_model_install_jobs() -> List[ModelInstallJob]: """Return list of model install jobs.""" - jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs() + jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs() return jobs -@model_records_router.get( +@model_manager_v2_router.get( "/import/{id}", operation_id="get_model_install_job", responses={ @@ -361,12 +366,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]: async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: """Return model install job corresponding to the given source.""" try: - return ApiDependencies.invoker.services.model_install.get_job_by_id(id) + result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id) + return result except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) -@model_records_router.delete( +@model_manager_v2_router.delete( "/import/{id}", operation_id="cancel_model_install_job", responses={ @@ -377,7 +383,7 @@ async def get_model_install_job(id: int = Path(description="Model install id")) ) async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None: """Cancel the model install job(s) corresponding to the given job ID.""" - installer = ApiDependencies.invoker.services.model_install + installer = ApiDependencies.invoker.services.model_manager.install try: job = installer.get_job_by_id(id) except ValueError as e: @@ -385,7 +391,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job installer.cancel_job(job) -@model_records_router.patch( +@model_manager_v2_router.patch( "/import", operation_id="prune_model_install_jobs", responses={ @@ -395,11 +401,11 @@ async def cancel_model_install_job(id: int = Path(description="Model install job ) async def prune_model_install_jobs() -> Response: """Prune all completed and errored jobs from the install job list.""" - ApiDependencies.invoker.services.model_install.prune_jobs() + ApiDependencies.invoker.services.model_manager.install.prune_jobs() return Response(status_code=204) -@model_records_router.patch( +@model_manager_v2_router.patch( "/sync", operation_id="sync_models_to_config", responses={ @@ -414,11 +420,11 @@ async def sync_models_to_config() -> Response: Model files without a corresponding record in the database are added. Orphan records without a models file are deleted. 
""" - ApiDependencies.invoker.services.model_install.sync_to_config() + ApiDependencies.invoker.services.model_manager.install.sync_to_config() return Response(status_code=204) -@model_records_router.put( +@model_manager_v2_router.put( "/merge", operation_id="merge", ) @@ -451,7 +457,7 @@ async def merge( try: logger.info(f"Merging models: {keys} into {merge_dest_directory or ''}/{merged_model_name}") dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None - installer = ApiDependencies.invoker.services.model_install + installer = ApiDependencies.invoker.services.model_manager.install merger = ModelMerger(installer) model_names = [installer.record_store.get_model(x).name for x in keys] response = merger.merge_diffusion_models_and_save( diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 8f83820cf89..0aa7aa0ecba 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -8,8 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from starlette.exceptions import HTTPException -from invokeai.backend import BaseModelType, ModelType -from invokeai.backend.model_management import MergeInterpolationMethod +from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType from invokeai.backend.model_management.models import ( OPENAPI_MODEL_CONFIGS, InvalidModelException, diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index f48074de7c7..851cbc8160e 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -48,7 +48,7 @@ boards, download_queue, images, - model_records, + model_manager_v2, models, session_queue, sessions, @@ -114,8 +114,7 @@ async def shutdown_event() -> None: app.include_router(sessions.session_router, prefix="/api") app.include_router(utilities.utilities_router, prefix="/api") -app.include_router(models.models_router, prefix="/api") -app.include_router(model_records.model_records_router, prefix="/api") +app.include_router(model_manager_v2.model_manager_v2_router, prefix="/api") app.include_router(download_queue.download_queue_router, prefix="/api") app.include_router(images.images_router, prefix="/api") app.include_router(boards.boards_router, prefix="/api") diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 0e1a6bdc6fb..3850fb6cc3d 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -3,6 +3,7 @@ import torch from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment +from transformers import CLIPTokenizer import invokeai.backend.util.logging as logger from invokeai.app.invocations.fields import ( @@ -68,18 +69,18 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_records.load_model( + tokenizer_info = context.services.model_manager.load.load_model_by_key( **self.clip.tokenizer.model_dump(), context=context, ) - text_encoder_info = context.services.model_records.load_model( + text_encoder_info = context.services.model_manager.load.load_model_by_key( **self.clip.text_encoder.model_dump(), context=context, ) def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.clip.loras: - lora_info = context.services.model_records.load_model( + lora_info = context.services.model_manager.load.load_model_by_key( 
**lora.model_dump(exclude={"weight"}), context=context ) assert isinstance(lora_info.model, LoRAModelRaw) @@ -93,7 +94,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for trigger in extract_ti_triggers_from_prompt(self.prompt): name = trigger[1:-1] try: - loaded_model = context.services.model_records.load_model( + loaded_model = context.services.model_manager.load.load_model_by_key( **self.clip.text_encoder.model_dump(), context=context, ).model @@ -164,11 +165,11 @@ def run_clip_compel( lora_prefix: str, zero_on_empty: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]: - tokenizer_info = context.services.model_records.load_model( + tokenizer_info = context.services.model_manager.load.load_model_by_key( **clip_field.tokenizer.model_dump(), context=context, ) - text_encoder_info = context.services.model_records.load_model( + text_encoder_info = context.services.model_manager.load.load_model_by_key( **clip_field.text_encoder.model_dump(), context=context, ) @@ -196,7 +197,7 @@ def run_clip_compel( def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in clip_field.loras: - lora_info = context.services.model_records.load_model( + lora_info = context.services.model_manager.load.load_model_by_key( **lora.model_dump(exclude={"weight"}), context=context ) lora_model = lora_info.model @@ -211,7 +212,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for trigger in extract_ti_triggers_from_prompt(prompt): name = trigger[1:-1] try: - ti_model = context.services.model_records.load_model_by_attr( + ti_model = context.services.model_manager.load.load_model_by_attr( model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion, @@ -448,9 +449,9 @@ def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: def get_max_token_count( - tokenizer, + tokenizer: CLIPTokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], - truncate_if_too_long=False, + truncate_if_too_long: bool = False, ) -> int: if type(prompt) is Blend: blend: Blend = prompt @@ -462,7 +463,9 @@ def get_max_token_count( return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)) -def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]: +def get_tokens_for_prompt_object( + tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True +) -> List[str]: if type(parsed_prompt) is Blend: raise ValueError("Blend is not supported here - you need to get tokens for each of its .children") @@ -475,24 +478,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun for x in parsed_prompt.children ] text = " ".join(text_fragments) - tokens = tokenizer.tokenize(text) + tokens: List[str] = tokenizer.tokenize(text) if truncate_if_too_long: max_tokens_length = tokenizer.model_max_length - 2 # typically 75 tokens = tokens[0:max_tokens_length] return tokens -def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None): +def log_tokenization_for_conjunction( + c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None +) -> None: display_label_prefix = display_label_prefix or "" for i, p in enumerate(c.prompts): if len(c.prompts) > 1: this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" else: + assert display_label_prefix is not None this_display_label_prefix = display_label_prefix 
log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix) -def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None): +def log_tokenization_for_prompt_object( + p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None +) -> None: display_label_prefix = display_label_prefix or "" if type(p) is Blend: blend: Blend = p @@ -532,7 +540,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix) -def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): +def log_tokenization_for_text( + text: str, + tokenizer: CLIPTokenizer, + display_label: Optional[str] = None, + truncate_if_too_long: Optional[bool] = False, +) -> None: """shows how the prompt is tokenized # usually tokens have '' to indicate end-of-word, # but for readability it has been replaced with ' ' diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 063b23fa589..289da2dd73d 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,13 +3,15 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import Iterator, List, Literal, Optional, Tuple, Union +from typing import Any, Iterator, List, Literal, Optional, Tuple, Union import einops import numpy as np +import numpy.typing as npt import torch import torchvision.transforms as T -from diffusers import AutoencoderKL, AutoencoderTiny, UNet2DConditionModel +from diffusers import AutoencoderKL, AutoencoderTiny +from diffusers.configuration_utils import ConfigMixin from diffusers.image_processor import VaeImageProcessor from diffusers.models.adapter import T2IAdapter from diffusers.models.attention_processor import ( @@ -18,8 +20,10 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import SchedulerMixin as Scheduler +from PIL import Image from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize @@ -46,9 +50,10 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image +from invokeai.backend.embeddings.lora import LoRAModelRaw from invokeai.backend.embeddings.model_patcher import ModelPatcher from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus -from invokeai.backend.model_manager import AnyModel, BaseModelType +from invokeai.backend.model_manager import BaseModelType, LoadedModel from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo from invokeai.backend.util.silence_warnings import SilenceWarnings @@ -123,10 +128,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation): ui_order=4, ) - def prep_mask_tensor(self, mask_image): + def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor: if mask_image.mode != "L": mask_image = mask_image.convert("L") - mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) + mask_tensor: torch.Tensor = 
image_resized_to_grid_as_tensor(mask_image, normalize=False) if mask_tensor.dim() == 3: mask_tensor = mask_tensor.unsqueeze(0) # if shape is not None: @@ -136,25 +141,25 @@ def prep_mask_tensor(self, mask_image): @torch.no_grad() def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: if self.image is not None: - image = context.images.get_pil(self.image.image_name) - image = image_resized_to_grid_as_tensor(image.convert("RGB")) - if image.dim() == 3: - image = image.unsqueeze(0) + image = context.services.images.get_pil_image(self.image.image_name) + image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) + if image_tensor.dim() == 3: + image_tensor = image_tensor.unsqueeze(0) else: - image = None + image_tensor = None mask = self.prep_mask_tensor( context.images.get_pil(self.mask.image_name), ) - if image is not None: - vae_info = context.services.model_records.load_model( + if image_tensor is not None: + vae_info = context.services.model_manager.load.load_model_by_key( **self.vae.vae.model_dump(), context=context, ) - img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) - masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0) + img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) + masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) # TODO: masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) @@ -177,7 +182,7 @@ def get_scheduler( seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.services.model_records.load_model( + orig_scheduler_info = context.services.model_manager.load.load_model_by_key( **scheduler_info.model_dump(), context=context, ) @@ -188,7 +193,7 @@ def get_scheduler( scheduler_config = scheduler_config["_backup"] scheduler_config = { **scheduler_config, - **scheduler_extra_config, + **scheduler_extra_config, # FIXME "_backup": scheduler_config, } @@ -201,6 +206,7 @@ def get_scheduler( # hack copied over from generate.py if not hasattr(scheduler, "uses_inpainting_model"): scheduler.uses_inpainting_model = lambda: False + assert isinstance(scheduler, Scheduler) return scheduler @@ -284,7 +290,7 @@ class DenoiseLatentsInvocation(BaseInvocation): ) @field_validator("cfg_scale") - def ge_one(cls, v): + def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]: """validate that all cfg_scale values are >= 1""" if isinstance(v, list): for i in v: @@ -298,9 +304,9 @@ def ge_one(cls, v): def get_conditioning_data( self, context: InvocationContext, - scheduler, - unet, - seed, + scheduler: Scheduler, + unet: UNet2DConditionModel, + seed: int, ) -> ConditioningData: positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) @@ -323,7 +329,7 @@ def get_conditioning_data( ), ) - conditioning_data = conditioning_data.add_scheduler_args_if_applicable( + conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME scheduler, # for ddim scheduler eta=0.0, # ddim_eta @@ -335,8 +341,8 @@ def get_conditioning_data( def create_pipeline( self, - unet, - scheduler, + unet: UNet2DConditionModel, + scheduler: Scheduler, ) -> StableDiffusionGeneratorPipeline: # TODO: # configure_model_padding( @@ -347,10 +353,10 @@ def create_pipeline( class FakeVae: class 
FakeVaeConfig: - def __init__(self): + def __init__(self) -> None: self.block_out_channels = [0] - def __init__(self): + def __init__(self) -> None: self.config = FakeVae.FakeVaeConfig() return StableDiffusionGeneratorPipeline( @@ -367,11 +373,11 @@ def __init__(self): def prep_control_data( self, context: InvocationContext, - control_input: Union[ControlField, List[ControlField]], + control_input: Optional[Union[ControlField, List[ControlField]]], latents_shape: List[int], exit_stack: ExitStack, do_classifier_free_guidance: bool = True, - ) -> List[ControlNetData]: + ) -> Optional[List[ControlNetData]]: # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR @@ -394,7 +400,7 @@ def prep_control_data( controlnet_data = [] for control_info in control_list: control_model = exit_stack.enter_context( - context.services.model_records.load_model( + context.services.model_manager.load.load_model_by_key( key=control_info.control_model.key, context=context, ) @@ -460,23 +466,25 @@ def prep_ip_adapter_data( conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.services.model_records.load_model( + context.services.model_manager.load.load_model_by_key( key=single_ip_adapter.ip_adapter_model.key, context=context, ) ) - image_encoder_model_info = context.services.model_records.load_model( + image_encoder_model_info = context.services.model_manager.load.load_model_by_key( key=single_ip_adapter.image_encoder_model.key, context=context, ) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. - single_ipa_images = single_ip_adapter.image - if not isinstance(single_ipa_images, list): - single_ipa_images = [single_ipa_images] + single_ipa_image_fields = single_ip_adapter.image + if not isinstance(single_ipa_image_fields, list): + single_ipa_image_fields = [single_ipa_image_fields] - single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_images] + single_ipa_images = [ + context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields + ] # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. @@ -520,21 +528,19 @@ def run_t2i_adapters( t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.services.model_records.load_model( + t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key( key=t2i_adapter_field.t2i_adapter_model.key, context=context, ) image = context.images.get_pil(t2i_adapter_field.image.image_name) # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. - if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1: + if t2i_adapter_model_info.base == BaseModelType.StableDiffusion1: max_unet_downscale = 8 - elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL: + elif t2i_adapter_model_info.base == BaseModelType.StableDiffusionXL: max_unet_downscale = 4 else: - raise ValueError( - f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'." 
- ) + raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_info.base}'.") t2i_adapter_model: T2IAdapter with t2i_adapter_model_info as t2i_adapter_model: @@ -582,7 +588,15 @@ def run_t2i_adapters( # original idea by https://github.com/AmericanPresidentJimmyCarter # TODO: research more for second order schedulers timesteps - def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end): + def init_scheduler( + self, + scheduler: Union[Scheduler, ConfigMixin], + device: torch.device, + steps: int, + denoising_start: float, + denoising_end: float, + ) -> Tuple[int, List[int], int]: + assert isinstance(scheduler, ConfigMixin) if scheduler.config.get("cpu_only", False): scheduler.set_timesteps(steps, device="cpu") timesteps = scheduler.timesteps.to(device=device) @@ -594,11 +608,11 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en _timesteps = timesteps[:: scheduler.order] # get start timestep index - t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start))) + t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start))) t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps))) # get end timestep index - t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end))) + t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end))) t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:]))) # apply order to indexes @@ -611,7 +625,9 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en return num_inference_steps, timesteps, init_timestep - def prep_inpaint_mask(self, context: InvocationContext, latents): + def prep_inpaint_mask( + self, context: InvocationContext, latents: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: if self.denoise_mask is None: return None, None @@ -660,12 +676,19 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: do_classifier_free_guidance=True, ) - def step_callback(state: PipelineIntermediateState): - context.util.sd_step_callback(state, self.unet.unet.base_model) + # Get the source node id (we are invoking the prepared node) + graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) + source_node_id = graph_execution_state.prepared_source_mapping[self.id] + + # get the unet's config so that we can pass the base to dispatch_progress() + unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump()) - def _lora_loader() -> Iterator[Tuple[AnyModel, float]]: + def step_callback(state: PipelineIntermediateState) -> None: + self.dispatch_progress(context, source_node_id, state, unet_config.base) + + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.unet.loras: - lora_info = context.services.model_records.load_model( + lora_info = context.services.model_manager.load.load_model_by_key( **lora.model_dump(exclude={"weight"}), context=context, ) @@ -673,7 +696,7 @@ def _lora_loader() -> Iterator[Tuple[AnyModel, float]]: del lora_info return - unet_info = context.services.model_records.load_model( + unet_info = context.services.model_manager.load.load_model_by_key( **self.unet.unet.model_dump(), context=context, ) @@ -783,7 +806,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): def invoke(self, context: InvocationContext) -> ImageOutput: latents = 
context.tensors.load(self.latents.latents_name) - vae_info = context.services.model_records.load_model( + vae_info = context.services.model_manager.load.load_model_by_key( **self.vae.vae.model_dump(), context=context, ) @@ -961,8 +984,9 @@ class ImageToLatentsInvocation(BaseInvocation): fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) @staticmethod - def vae_encode(vae_info, upcast, tiled, image_tensor): + def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor: with vae_info as vae: + assert isinstance(vae, torch.nn.Module) orig_dtype = vae.dtype if upcast: vae.to(dtype=torch.float32) @@ -1008,7 +1032,7 @@ def vae_encode(vae_info, upcast, tiled, image_tensor): def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.images.get_pil(self.image.image_name) - vae_info = context.services.model_records.load_model( + vae_info = context.services.model_manager.load.load_model_by_key( **self.vae.vae.model_dump(), context=context, ) @@ -1026,14 +1050,19 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: @singledispatchmethod @staticmethod def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor: + assert isinstance(vae, torch.nn.Module) image_tensor_dist = vae.encode(image_tensor).latent_dist - latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible! + latents: torch.Tensor = image_tensor_dist.sample().to( + dtype=vae.dtype + ) # FIXME: uses torch.randn. make reproducible! return latents @_encode_to_tensor.register @staticmethod def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor: - return vae.encode(image_tensor).latents + assert isinstance(vae, torch.nn.Module) + latents: torch.FloatTensor = vae.encode(image_tensor).latents + return latents @invocation( @@ -1066,7 +1095,12 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: # TODO: device = choose_torch_device() - def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): + def slerp( + t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here? 
+ v0: Union[torch.Tensor, npt.NDArray[Any]], + v1: Union[torch.Tensor, npt.NDArray[Any]], + DOT_THRESHOLD: float = 0.9995, + ) -> Union[torch.Tensor, npt.NDArray[Any]]: """ Spherical linear interpolation Args: @@ -1099,12 +1133,16 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): v2 = s0 * v0 + s1 * v1 if inputs_are_torch: - v2 = torch.from_numpy(v2).to(device) - - return v2 + v2_torch: torch.Tensor = torch.from_numpy(v2).to(device) + return v2_torch + else: + assert isinstance(v2, np.ndarray) + return v2 # blend - blended_latents = slerp(self.alpha, latents_a, latents_b) + bl = slerp(self.alpha, latents_a, latents_b) + assert isinstance(bl, torch.Tensor) + blended_latents: torch.Tensor = bl # for type checking convenience # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 blended_latents = blended_latents.to("cpu") @@ -1197,15 +1235,19 @@ class IdealSizeInvocation(BaseInvocation): description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)", ) - def trim_to_multiple_of(self, *args, multiple_of=LATENT_SCALE_FACTOR): + def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]: return tuple((x - x % multiple_of) for x in args) def invoke(self, context: InvocationContext) -> IdealSizeOutput: + unet_config = context.services.model_manager.load.load_model_by_key( + **self.unet.unet.model_dump(), + context=context, + ) aspect = self.width / self.height - dimension = 512 - if self.unet.unet.base_model == BaseModelType.StableDiffusion2: + dimension: float = 512 + if unet_config.base == BaseModelType.StableDiffusion2: dimension = 768 - elif self.unet.unet.base_model == BaseModelType.StableDiffusionXL: + elif unet_config.base == BaseModelType.StableDiffusionXL: dimension = 1024 dimension = dimension * self.multiplier min_dimension = math.floor(dimension * 0.5) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index e2ea7442839..fa6e8b98da0 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -17,7 +17,7 @@ class ModelInfo(BaseModel): - key: str = Field(description="Info to load submodel") + key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()") submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index e893be87636..0a1fa1e9222 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -27,9 +27,7 @@ from .invocation_queue.invocation_queue_base import InvocationQueueABC from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase from .item_storage.item_storage_base import ItemStorageABC - from .model_install import ModelInstallServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase - from .model_records import ModelRecordServiceBase from .names.names_base import NameServiceBase from .session_processor.session_processor_base import SessionProcessorBase from .session_queue.session_queue_base import SessionQueueBase @@ -55,9 +53,7 @@ def __init__( image_records: "ImageRecordStorageBase", logger: "Logger", model_manager: "ModelManagerServiceBase", - model_records: "ModelRecordServiceBase", download_queue: "DownloadQueueServiceBase", - model_install: "ModelInstallServiceBase", processor: 
"InvocationProcessorABC", performance_statistics: "InvocationStatsServiceBase", queue: "InvocationQueueABC", @@ -82,9 +78,7 @@ def __init__( self.image_records = image_records self.logger = logger self.model_manager = model_manager - self.model_records = model_records self.download_queue = download_queue - self.model_install = model_install self.processor = processor self.performance_statistics = performance_statistics self.queue = queue diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index 0c63b545ff2..6c893021de4 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -43,8 +43,10 @@ def start(self, invoker: Invoker) -> None: @contextmanager def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]: + # This is to handle case of the model manager not being initialized, which happens + # during some tests. services = self._invoker.services - if services.model_records is None or services.model_records.loader is None: + if services.model_manager is None or services.model_manager.load is None: yield None if not self._stats.get(graph_execution_state_id): # First time we're seeing this graph_execution_state_id. @@ -60,9 +62,8 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() - # TO DO [LS]: clean up loader service - shouldn't be an attribute of model records - assert services.model_records.loader is not None - services.model_records.loader.ram_cache.stats = self._cache_stats[graph_execution_state_id] + assert services.model_manager.load is not None + services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id] try: # Let the invocation run. diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 7228806e809..f298d98ce6d 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -4,7 +4,8 @@ from abc import ABC, abstractmethod from typing import Optional -from invokeai.backend.model_manager import AnyModelConfig, SubModelType +from invokeai.app.invocations.baseinvocation import InvocationContext +from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import LoadedModel @@ -12,11 +13,60 @@ class ModelLoadServiceBase(ABC): """Wrapper around AnyModelLoader.""" @abstractmethod - def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: - """Given a model's key, load it and return the LoadedModel object.""" + def load_model_by_key( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: + """ + Given a model's key, load it and return the LoadedModel object. + + :param key: Key of model config to be fetched. + :param submodel: For main (pipeline models), the submodel to fetch. 
+ :param context: Invocation context used for event reporting
+ """
pass
@abstractmethod
- def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
- """Given a model's configuration, load it and return the LoadedModel object."""
+ def load_model_by_config(
+ self,
+ model_config: AnyModelConfig,
+ submodel_type: Optional[SubModelType] = None,
+ context: Optional[InvocationContext] = None,
+ ) -> LoadedModel:
+ """
+ Given a model's configuration, load it and return the LoadedModel object.
+
+ :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
+ :param submodel: For main (pipeline models), the submodel to fetch.
+ :param context: Invocation context used for event reporting
+ """
pass
+
+ @abstractmethod
+ def load_model_by_attr(
+ self,
+ model_name: str,
+ base_model: BaseModelType,
+ model_type: ModelType,
+ submodel: Optional[SubModelType] = None,
+ context: Optional[InvocationContext] = None,
+ ) -> LoadedModel:
+ """
+ Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
+
+ This is provided for API compatibility with the get_model() method
+ in the original model manager. However, note that LoadedModel is
+ not the same as the original ModelInfo that was returned.
+
+ :param model_name: Name of the model to be fetched.
+ :param base_model: Base model
+ :param model_type: Type of the model
+ :param submodel: For main (pipeline models), the submodel to fetch
+ :param context: The invocation context.
+
+ Exceptions: UnknownModelException -- model with these attributes not known
+ NotImplementedException -- a model loader was not provided at initialization time
+ ValueError -- more than one model matches this combination
+ """
diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py
index 80e2fe161d0..67107cada6e 100644
--- a/invokeai/app/services/model_load/model_load_default.py
+++ b/invokeai/app/services/model_load/model_load_default.py
@@ -3,12 +3,14 @@
from typing import Optional
+from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.model_records import ModelRecordServiceBase
-from invokeai.backend.model_manager import AnyModelConfig, SubModelType
+from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
+from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
+from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
-from invokeai.backend.model_manager.load.ram_cache import ModelCacheBase
+from invokeai.backend.model_manager.load.model_cache import ModelCacheBase
from invokeai.backend.util.logging import InvokeAILogger
from .model_load_base import ModelLoadServiceBase
@@ -21,7 +23,7 @@ def __init__(
self,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
- ram_cache: Optional[ModelCacheBase] = None,
+ ram_cache: Optional[ModelCacheBase[AnyModel]] = None,
convert_cache: Optional[ModelConvertCacheBase] = None,
):
"""Initialize the model load service."""
@@ -44,11 +46,104 @@ def __init__(
),
)
- def load_model_by_key(self, key: str,
submodel_type: Optional[SubModelType] = None) -> LoadedModel:
- """Given a model's key, load it and return the LoadedModel object."""
+ def load_model_by_key(
+ self,
+ key: str,
+ submodel_type: Optional[SubModelType] = None,
+ context: Optional[InvocationContext] = None,
+ ) -> LoadedModel:
+ """
+ Given a model's key, load it and return the LoadedModel object.
+
+ :param key: Key of model config to be fetched.
+ :param submodel: For main (pipeline models), the submodel to fetch.
+ :param context: Invocation context used for event reporting
+ """
config = self._store.get_model(key)
- return self.load_model_by_config(config, submodel_type)
+ return self.load_model_by_config(config, submodel_type, context)
+
+ def load_model_by_attr(
+ self,
+ model_name: str,
+ base_model: BaseModelType,
+ model_type: ModelType,
+ submodel: Optional[SubModelType] = None,
+ context: Optional[InvocationContext] = None,
+ ) -> LoadedModel:
+ """
+ Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
+
+ This is provided for API compatibility with the get_model() method
+ in the original model manager. However, note that LoadedModel is
+ not the same as the original ModelInfo that was returned.
+
+ :param model_name: Name of the model to be fetched.
+ :param base_model: Base model
+ :param model_type: Type of the model
+ :param submodel: For main (pipeline models), the submodel to fetch
+ :param context: The invocation context.
+
+ Exceptions: UnknownModelException -- model with this key not known
+ NotImplementedException -- a model loader was not provided at initialization time
+ ValueError -- more than one model matches this combination
+ """
+ configs = self._store.search_by_attr(model_name, base_model, model_type)
+ if len(configs) == 0:
+ raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
+ elif len(configs) > 1:
+ raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
+ else:
+ return self.load_model_by_key(configs[0].key, submodel)
+
+ def load_model_by_config(
+ self,
+ model_config: AnyModelConfig,
+ submodel_type: Optional[SubModelType] = None,
+ context: Optional[InvocationContext] = None,
+ ) -> LoadedModel:
+ """
+ Given a model's configuration, load it and return the LoadedModel object.
+
+ :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
+ :param submodel: For main (pipeline models), the submodel to fetch.
+ :param context: Invocation context used for event reporting + """ + if context: + self._emit_load_event( + context=context, + model_config=model_config, + ) + loaded_model = self._any_loader.load_model(model_config, submodel_type) + if context: + self._emit_load_event( + context=context, + model_config=model_config, + loaded=True, + ) + return loaded_model + + def _emit_load_event( + self, + context: InvocationContext, + model_config: AnyModelConfig, + loaded: Optional[bool] = False, + ) -> None: + if context.services.queue.is_canceled(context.graph_execution_state_id): + raise CanceledException() - def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: - """Given a model's configuration, load it and return the LoadedModel object.""" - return self._any_loader.load_model(config, submodel_type) + if not loaded: + context.services.events.emit_model_load_started( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_config=model_config, + ) + else: + context.services.events.emit_model_load_completed( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_config=model_config, + ) diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index c6e77fa163d..1116c82ff1f 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -2,9 +2,10 @@ from abc import ABC, abstractmethod -from pydantic import BaseModel, Field from typing_extensions import Self +from invokeai.app.services.invoker import Invoker + from ..config import InvokeAIAppConfig from ..download import DownloadQueueServiceBase from ..events.events_base import EventServiceBase @@ -14,12 +15,13 @@ from ..shared.sqlite.sqlite_database import SqliteDatabase -class ModelManagerServiceBase(BaseModel, ABC): +class ModelManagerServiceBase(ABC): """Abstract base class for the model manager service.""" - store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.") - install: ModelInstallServiceBase = Field(description="An instance of the model install service.") - load: ModelLoadServiceBase = Field(description="An instance of the model load service.") + # attributes: + # store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.") + # install: ModelInstallServiceBase = Field(description="An instance of the model install service.") + # load: ModelLoadServiceBase = Field(description="An instance of the model load service.") @classmethod @abstractmethod @@ -37,3 +39,29 @@ def build_model_manager( method simplifies the construction considerably. 
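        A rough sketch of the hand-wired equivalent, using the concrete classes named in this patch (the record-store and installer constructors are not shown here and are elided):

            store = ...      # a ModelRecordServiceSQL bound to the application database
            install = ...    # a ModelInstallService sharing the same record store
            load = ModelLoadService(app_config=app_config, record_store=store)
            mm = ModelManagerService(store=store, install=install, load=load)
            mm.start(invoker)  # forwards start() to each sub-service that defines it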
""" pass + + @property + @abstractmethod + def store(self) -> ModelRecordServiceBase: + """Return the ModelRecordServiceBase used to store and retrieve configuration records.""" + pass + + @property + @abstractmethod + def load(self) -> ModelLoadServiceBase: + """Return the ModelLoadServiceBase used to load models from their configuration records.""" + pass + + @property + @abstractmethod + def install(self) -> ModelInstallServiceBase: + """Return the ModelInstallServiceBase used to download and manipulate model files.""" + pass + + @abstractmethod + def start(self, invoker: Invoker) -> None: + pass + + @abstractmethod + def stop(self, invoker: Invoker) -> None: + pass diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index ad0fd66dbbd..028d4af6159 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -3,6 +3,7 @@ from typing_extensions import Self +from invokeai.app.services.invoker import Invoker from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger @@ -10,9 +11,9 @@ from ..config import InvokeAIAppConfig from ..download import DownloadQueueServiceBase from ..events.events_base import EventServiceBase -from ..model_install import ModelInstallService -from ..model_load import ModelLoadService -from ..model_records import ModelRecordServiceSQL +from ..model_install import ModelInstallService, ModelInstallServiceBase +from ..model_load import ModelLoadService, ModelLoadServiceBase +from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_manager_base import ModelManagerServiceBase @@ -27,6 +28,38 @@ class ModelManagerService(ModelManagerServiceBase): model_manager.load -- Routines to load models into memory. 
""" + def __init__( + self, + store: ModelRecordServiceBase, + install: ModelInstallServiceBase, + load: ModelLoadServiceBase, + ): + self._store = store + self._install = install + self._load = load + + @property + def store(self) -> ModelRecordServiceBase: + return self._store + + @property + def install(self) -> ModelInstallServiceBase: + return self._install + + @property + def load(self) -> ModelLoadServiceBase: + return self._load + + def start(self, invoker: Invoker) -> None: + for service in [self._store, self._install, self._load]: + if hasattr(service, "start"): + service.start(invoker) + + def stop(self, invoker: Invoker) -> None: + for service in [self._store, self._install, self._load]: + if hasattr(service, "stop"): + service.stop(invoker) + @classmethod def build_model_manager( cls, diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index e00dd4169d5..e2e98c7e896 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -10,15 +10,12 @@ from pydantic import BaseModel, Field -from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.backend.model_manager import ( AnyModelConfig, BaseModelType, - LoadedModel, ModelFormat, ModelType, - SubModelType, ) from invokeai.backend.model_manager.load import AnyModelLoader from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore @@ -111,52 +108,6 @@ def get_model(self, key: str) -> AnyModelConfig: """ pass - @abstractmethod - def load_model( - self, - key: str, - submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, - ) -> LoadedModel: - """ - Load the indicated model into memory and return a LoadedModel object. - - :param key: Key of model config to be fetched. - :param submodel: For main (pipeline models), the submodel to fetch - :param context: Invocation context, used for event issuing. - - Exceptions: UnknownModelException -- model with this key not known - NotImplementedException -- a model loader was not provided at initialization time - """ - pass - - @abstractmethod - def load_model_by_attr( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, - ) -> LoadedModel: - """ - Load the indicated model into memory and return a LoadedModel object. - - This is provided for API compatability with the get_model() method - in the original model manager. However, note that LoadedModel is - not the same as the original ModelInfo that ws returned. - - :param model_name: Key of model config to be fetched. - :param base_model: Base model - :param model_type: Type of the model - :param submodel: For main (pipeline models), the submodel to fetch - :param context: The invocation context. 
- - Exceptions: UnknownModelException -- model with this key not known - NotImplementedException -- a model loader was not provided at initialization time - """ - pass - @property @abstractmethod def metadata_store(self) -> ModelMetadataStore: diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 28a77b1b1ab..f48175351de 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -46,8 +46,6 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple, Union -from invokeai.app.invocations.baseinvocation import InvocationContext -from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.backend.model_manager.config import ( AnyModelConfig, @@ -55,9 +53,8 @@ ModelConfigFactory, ModelFormat, ModelType, - SubModelType, ) -from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel +from invokeai.backend.model_manager.load import AnyModelLoader from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException from ..shared.sqlite.sqlite_database import SqliteDatabase @@ -220,74 +217,6 @@ def get_model(self, key: str) -> AnyModelConfig: model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) return model - def load_model( - self, - key: str, - submodel: Optional[SubModelType], - context: Optional[InvocationContext] = None, - ) -> LoadedModel: - """ - Load the indicated model into memory and return a LoadedModel object. - - :param key: Key of model config to be fetched. - :param submodel: For main (pipeline models), the submodel to fetch. - :param context: Invocation context used for event reporting - - Exceptions: UnknownModelException -- model with this key not known - NotImplementedException -- a model loader was not provided at initialization time - """ - if not self._loader: - raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader") - # we can emit model loading events if we are executing with access to the invocation context - - model_config = self.get_model(key) - if context: - self._emit_load_event( - context=context, - model_config=model_config, - ) - loaded_model = self._loader.load_model(model_config, submodel) - if context: - self._emit_load_event( - context=context, - model_config=model_config, - loaded=True, - ) - return loaded_model - - def load_model_by_attr( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, - ) -> LoadedModel: - """ - Load the indicated model into memory and return a LoadedModel object. - - This is provided for API compatability with the get_model() method - in the original model manager. However, note that LoadedModel is - not the same as the original ModelInfo that ws returned. - - :param model_name: Key of model config to be fetched. - :param base_model: Base model - :param model_type: Type of the model - :param submodel: For main (pipeline models), the submodel to fetch - :param context: The invocation context. 
- - Exceptions: UnknownModelException -- model with this key not known - NotImplementedException -- a model loader was not provided at initialization time - ValueError -- more than one model matches this combination - """ - configs = self.search_by_attr(model_name, base_model, model_type) - if len(configs) == 0: - raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model") - elif len(configs) > 1: - raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.") - else: - return self.load_model(configs[0].key, submodel) - def exists(self, key: str) -> bool: """ Return True if a model with the indicated key exists in the databse. @@ -476,29 +405,3 @@ def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]: return PaginatedResults( page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items ) - - def _emit_load_event( - self, - context: InvocationContext, - model_config: AnyModelConfig, - loaded: Optional[bool] = False, - ) -> None: - if context.services.queue.is_canceled(context.graph_execution_state_id): - raise CanceledException() - - if not loaded: - context.services.events.emit_model_load_started( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - model_config=model_config, - ) - else: - context.services.events.emit_model_load_completed( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - model_config=model_config, - ) diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py index b4734445110..1f9ac56518c 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py @@ -6,6 +6,7 @@ class Migration6Callback: def __call__(self, cursor: sqlite3.Cursor) -> None: self._recreate_model_triggers(cursor) + self._delete_ip_adapters(cursor) def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None: """ @@ -26,6 +27,22 @@ def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None: """ ) + def _delete_ip_adapters(self, cursor: sqlite3.Cursor) -> None: + """ + Delete all the IP adapters. + + The model manager will automatically find and re-add them after the migration + is done. This allows the manager to add the correct image encoder to their + configuration records. + """ + + cursor.execute( + """--sql + DELETE FROM model_config + WHERE type='ip_adapter'; + """ + ) + def build_migration_6() -> Migration: """ @@ -33,6 +50,8 @@ def build_migration_6() -> Migration: This migration does the following: - Adds the model_config_updated_at trigger if it does not exist + - Delete all ip_adapter models so that the model prober can find and + update with the correct image processor model. 
""" migration_6 = Migration( from_version=5, diff --git a/invokeai/backend/embeddings/model_patcher.py b/invokeai/backend/embeddings/model_patcher.py index 4725181b8ed..bee8909c311 100644 --- a/invokeai/backend/embeddings/model_patcher.py +++ b/invokeai/backend/embeddings/model_patcher.py @@ -64,7 +64,7 @@ def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tup def apply_lora_unet( cls, unet: UNet2DConditionModel, - loras: List[Tuple[LoRAModelRaw, float]], + loras: Iterator[Tuple[LoRAModelRaw, float]], ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield @@ -307,7 +307,7 @@ class ONNXModelPatcher: def apply_lora_unet( cls, unet: OnnxRuntimeModel, - loras: List[Tuple[LoRAModelRaw, float]], + loras: Iterator[Tuple[LoRAModelRaw, float]], ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py index b9649925e14..92ddef5ecc3 100644 --- a/invokeai/backend/image_util/safety_checker.py +++ b/invokeai/backend/image_util/safety_checker.py @@ -8,8 +8,8 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend import SilenceWarnings from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.silence_warnings import SilenceWarnings config = InvokeAIAppConfig.get_config() diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 9176bf1f49f..b4706ea99c0 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -8,7 +8,6 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights -from invokeai.backend.model_management.models.base import calc_model_size_by_data from .resampler import Resampler @@ -124,6 +123,9 @@ def to(self, device: torch.device, dtype: Optional[torch.dtype] = None): self.attn_weights.to(device=self.device, dtype=self.dtype) def calc_size(self): + # workaround for circular import + from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data + return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights) def _init_image_proj_model(self, state_dict): diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 4534a4892fb..9f0f774b499 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -21,7 +21,7 @@ """ import time from enum import Enum -from typing import Literal, Optional, Type, Union, Class +from typing import Literal, Optional, Type, Union import torch from diffusers import ModelMixin @@ -335,7 +335,7 @@ def make_config( cls, model_data: Union[Dict[str, Any], AnyModelConfig], key: Optional[str] = None, - dest_class: Optional[Type[Class]] = None, + dest_class: Optional[Type[ModelConfigBase]] = None, timestamp: Optional[float] = None, ) -> AnyModelConfig: """ @@ -347,14 +347,17 @@ def make_config( :param dest_class: The config class to be returned. If not provided, will be selected automatically. 
""" + model: Optional[ModelConfigBase] = None if isinstance(model_data, ModelConfigBase): model = model_data elif dest_class: - model = dest_class.validate_python(model_data) + model = dest_class.model_validate(model_data) else: - model = AnyModelConfigValidator.validate_python(model_data) + # mypy doesn't typecheck TypeAdapters well? + model = AnyModelConfigValidator.validate_python(model_data) # type: ignore + assert model is not None if key: model.key = key if timestamp: model.last_modified = timestamp - return model + return model # type: ignore diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 9d98ee30531..3d026af2269 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -18,8 +18,16 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.model_manager.config import AnyModel, VaeCheckpointConfig, VaeDiffusersConfig +from invokeai.backend.model_manager.config import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, + SubModelType, + VaeCheckpointConfig, + VaeDiffusersConfig, +) from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.util.logging import InvokeAILogger @@ -32,7 +40,7 @@ class LoadedModel: config: AnyModelConfig locker: ModelLockerBase - def __enter__(self) -> AnyModel: # I think load_file() always returns a dict + def __enter__(self) -> AnyModel: """Context entry.""" self.locker.lock() return self.model @@ -171,6 +179,10 @@ def register( def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}") key = cls._to_registry_key(base, type, format) + if key in cls._registry: + raise Exception( + f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}" + ) cls._registry[key] = subclass return subclass diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index c1dfe729af7..df83c8320d9 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -169,7 +169,7 @@ def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelT raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e # This needs to be implemented in subclasses that handle checkpoints - def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: raise NotImplementedError # This needs to be implemented in the subclass diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index b1deb215b2b..98d6f34cead 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ 
b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -246,7 +246,7 @@ def offload_unlocked_models(self, size_required: int) -> None: def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: """Move model into the indicated device.""" - # These attributes are not in the base ModelMixin class but in derived classes. + # These attributes are not in the base ModelMixin class but in various derived classes. # Some models don't have these attributes, in which case they run in RAM/CPU. self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")): diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py index e61e2b46a63..d446d079336 100644 --- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -35,28 +35,28 @@ def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: else: return True - def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: raise Exception(f"Vae conversion not supported for model type: {config.base}") else: assert hasattr(config, "config") config_file = config.config - if weights_path.suffix == ".safetensors": - checkpoint = safetensors.torch.load_file(weights_path, device="cpu") + if model_path.suffix == ".safetensors": + checkpoint = safetensors.torch.load_file(model_path, device="cpu") else: - checkpoint = torch.load(weights_path, map_location="cpu") + checkpoint = torch.load(model_path, map_location="cpu") # sometimes weights are hidden under "state_dict", and sometimes not if "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] convert_controlnet_to_diffusers( - weights_path, + model_path, output_path, original_config_file=self._app_config.root_path / config_file, image_size=512, scan_needed=True, - from_safetensors=weights_path.suffix == ".safetensors", + from_safetensors=model_path.suffix == ".safetensors", ) return output_path diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index 03c26f3a0c0..114e317f3c6 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -12,8 +12,9 @@ ModelType, SubModelType, ) -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader + +from ..load_base import AnyModelLoader +from ..load_default import ModelLoader @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index a963e8403b9..23b4e1fccd6 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -65,7 +65,7 @@ def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: 
else: return True - def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: assert isinstance(config, MainCheckpointConfig) variant = config.variant base = config.base @@ -75,9 +75,9 @@ def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path config_file = config.config - self._logger.info(f"Converting {weights_path} to diffusers format") + self._logger.info(f"Converting {model_path} to diffusers format") convert_ckpt_to_diffusers( - weights_path, + model_path, output_path, model_type=self.model_base_to_model_type[base], model_version=base, @@ -86,7 +86,7 @@ def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path extract_ema=True, scan_needed=True, pipeline_class=pipeline_class, - from_safetensors=weights_path.suffix == ".safetensors", + from_safetensors=model_path.suffix == ".safetensors", precision=self._torch_dtype, load_safety_checker=False, ) diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index 882ae055771..3983ea75950 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -37,7 +37,7 @@ def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: else: return True - def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: # TO DO: check whether sdxl VAE models convert. if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: raise Exception(f"Vae conversion not supported for model type: {config.base}") @@ -46,10 +46,10 @@ def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" ) - if weights_path.suffix == ".safetensors": - checkpoint = safetensors.torch.load_file(weights_path, device="cpu") + if model_path.suffix == ".safetensors": + checkpoint = safetensors.torch.load_file(model_path, device="cpu") else: - checkpoint = torch.load(weights_path, map_location="cpu") + checkpoint = torch.load(model_path, map_location="cpu") # sometimes weights are hidden under "state_dict", and sometimes not if "state_dict" in checkpoint: diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 3f2d22595e2..c55eee48fa5 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -65,7 +65,7 @@ def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, var bit8_files = {f for f in all_files if ".8bit." 
in f.name or ".8bit-" in f.name} other_files = set(all_files) - fp16_files - bit8_files - if variant is None: + if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatability with HF files = other_files elif variant == "fp16": files = fp16_files diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index a54938fdd5c..f7e1e1bed76 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -22,11 +22,12 @@ def find_main_models(model: Path) -> bool: import os from abc import ABC, abstractmethod +from logging import Logger from pathlib import Path from typing import Callable, Optional, Set, Union from pydantic import BaseModel, Field -from logging import Logger + from invokeai.backend.util.logging import InvokeAILogger default_logger: Logger = InvokeAILogger.get_logger() diff --git a/invokeai/backend/stable_diffusion/schedulers/__init__.py b/invokeai/backend/stable_diffusion/schedulers/__init__.py index a4e9dbf9dad..0b780d3ee27 100644 --- a/invokeai/backend/stable_diffusion/schedulers/__init__.py +++ b/invokeai/backend/stable_diffusion/schedulers/__init__.py @@ -1 +1,3 @@ from .schedulers import SCHEDULER_MAP # noqa: F401 + +__all__ = ["SCHEDULER_MAP"] diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index 22b132370e6..20b630dfc62 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -513,7 +513,7 @@ def select_and_download_models(opt: Namespace) -> None: """Prompt user for install/delete selections and execute.""" precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) # unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal - config.precision = precision # type: ignore + config.precision = precision install_helper = InstallHelper(config, logger) installer = install_helper.installer diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index 27d2d2230a3..f839a4a8785 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -62,9 +62,7 @@ def mock_services() -> InvocationServices: invocation_cache=MemoryInvocationCache(max_cache_size=0), logger=logging, # type: ignore model_manager=None, # type: ignore - model_records=None, # type: ignore download_queue=None, # type: ignore - model_install=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), processor=DefaultInvocationProcessor(), diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index 437ea0f00d3..774f7501dc2 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -65,9 +65,7 @@ def mock_services() -> InvocationServices: invocation_cache=MemoryInvocationCache(max_cache_size=0), logger=logging, # type: ignore model_manager=None, # type: ignore - model_records=None, # type: ignore download_queue=None, # type: ignore - model_install=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), processor=DefaultInvocationProcessor(), diff --git a/tests/backend/model_manager_2/model_loading/test_model_load.py b/tests/backend/model_manager_2/model_loading/test_model_load.py new file mode 100644 index 00000000000..a7a64e91ac0 --- /dev/null +++ b/tests/backend/model_manager_2/model_loading/test_model_load.py @@ 
-0,0 +1,22 @@ +""" +Test model loading +""" + +from pathlib import Path + +from invokeai.app.services.model_install import ModelInstallServiceBase +from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw +from invokeai.backend.model_manager.load import AnyModelLoader +from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 + + +def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: AnyModelLoader, embedding_file: Path): + store = mm2_installer.record_store + matches = store.search_by_attr(model_name="test_embedding") + assert len(matches) == 0 + key = mm2_installer.register_path(embedding_file) + loaded_model = mm2_loader.load_model(store.get_model(key)) + assert loaded_model is not None + assert loaded_model.config.key == key + with loaded_model as model: + assert isinstance(model, TextualInversionModelRaw) diff --git a/tests/backend/model_manager_2/model_manager_2_fixtures.py b/tests/backend/model_manager_2/model_manager_2_fixtures.py index d6d091befea..d85eab67dd3 100644 --- a/tests/backend/model_manager_2/model_manager_2_fixtures.py +++ b/tests/backend/model_manager_2/model_manager_2_fixtures.py @@ -20,6 +20,7 @@ ModelFormat, ModelType, ) +from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger from tests.backend.model_manager_2.model_metadata.metadata_examples import ( @@ -89,6 +90,16 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig: return app_config +@pytest.fixture +def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceSQL) -> AnyModelLoader: + logger = InvokeAILogger.get_logger(config=mm2_app_config) + ram_cache = ModelCache( + logger=logger, max_cache_size=mm2_app_config.ram_cache_size, max_vram_cache_size=mm2_app_config.vram_cache_size + ) + convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path) + return AnyModelLoader(app_config=mm2_app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache) + + @pytest.fixture def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL: logger = InvokeAILogger.get_logger(config=mm2_app_config) From 7418865426312f99d062f74cf8c12471f1d6a1d7 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 11 Feb 2024 23:37:49 -0500 Subject: [PATCH 083/340] add back the `heuristic_import()` method and extend repo_ids to arbitrary file paths --- docs/contributing/MODEL_MANAGER.md | 52 ++++++++++++-- invokeai/app/api/routers/model_manager_v2.py | 70 ++++++++++++++++++- invokeai/app/api_app.py | 1 - .../model_install/model_install_base.py | 39 ++++++++++- .../model_install/model_install_default.py | 43 +++++++++++- .../model_manager/util/select_hf_files.py | 6 ++ 6 files changed, 199 insertions(+), 12 deletions(-) diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 39220f4ba89..959b7f9733c 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -446,6 +446,44 @@ required parameters: Once initialized, the installer will provide the following methods: +#### install_job = installer.heuristic_import(source, [config], [access_token]) + +This is a simplified interface to the installer which takes a source +string, an optional model configuration dictionary and an optional +access token. + +The `source` is a string that can be any of these forms + +1. 
A path on the local filesystem (`C:\\users\\fred\\model.safetensors`) +2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`) +3. A HuggingFace repo_id with any of the following formats: + - `model/name` -- entire model + - `model/name:fp32` -- entire model, using the fp32 variant + - `model/name:fp16:vae` -- vae submodel, using the fp16 variant + - `model/name::vae` -- vae submodel, using default precision + - `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant + - `model/name::path/to/model.safetensors` -- an individual model file, default variant + +Note that by specifying a relative path to the top of the HuggingFace +repo, you can download and install arbitrary models files. + +The variant, if not provided, will be automatically filled in with +`fp32` if the user has requested full precision, and `fp16` +otherwise. If a variant that does not exist is requested, then the +method will install whatever HuggingFace returns as its default +revision. + +`config` is an optional dict of values that will override the +autoprobed values for model type, base, scheduler prediction type, and +so forth. See [Model configuration and +probing](#Model-configuration-and-probing) for details. + +`access_token` is an optional access token for accessing resources +that need authentication. + +The method will return a `ModelInstallJob`. This object is discussed +at length in the following section. + #### install_job = installer.import_model() The `import_model()` method is the core of the installer. The @@ -464,9 +502,10 @@ source2 = LocalModelSource(path='/opt/models/sushi_diffusers') # a local dif source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5') # a repo_id source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae') # a subfolder within a repo_id source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16') # a named variant of a HF model +source6 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='OrangeMix/OrangeMix1.ckpt') # path to an individual model file -source6 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL -source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token +source7 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL +source8 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token for source in [source1, source2, source3, source4, source5, source6, source7]: install_job = installer.install_model(source) @@ -522,7 +561,6 @@ can be passed to `import_model()`. attributes returned by the model prober. See the section below for details. - #### LocalModelSource This is used for a model that is located on a locally-accessible Posix @@ -715,7 +753,7 @@ and `cancelled`, as well as `in_terminal_state`. The last will return True if the job is in the complete, errored or cancelled states. 
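(Aside, not part of the patch: a minimal usage sketch of the install API documented above. Construction of the installer is elided; `heuristic_import()`, `wait_for_installs()` and `in_terminal_state` are as described in this section, and the source strings are illustrative only.)

```python
# Illustrative only -- assumes an initialized ModelInstallService as `installer`.
job1 = installer.heuristic_import("/opt/models/detail-tweaker.safetensors")         # local path
job2 = installer.heuristic_import("https://civitai.com/api/download/models/63006")  # direct URL
job3 = installer.heuristic_import("runwayml/stable-diffusion-v1-5:fp16:vae")        # repo_id:variant:subfolder

installer.wait_for_installs()  # block until every queued install reaches a terminal state
for job in (job1, job2, job3):
    assert job.in_terminal_state  # complete, errored or cancelled
```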
-#### Model confguration and probing +#### Model configuration and probing The install service uses the `invokeai.backend.model_manager.probe` module during import to determine the model's type, base type, and @@ -1106,7 +1144,7 @@ job = queue.create_download_job( event_handlers=[my_handler1, my_handler2], # if desired start=True, ) - ``` +``` The `filename` argument forces the downloader to use the specified name for the file rather than the name provided by the remote source, @@ -1427,9 +1465,9 @@ set of keys to the corresponding model config objects. Find all model metadata records that have the given author and return a set of keys to the corresponding model config objects. -# The remainder of this documentation is provisional, pending implementation of the Load service +*** -## Let's get loaded, the lowdown on ModelLoadService +## The Lowdown on the ModelLoadService The `ModelLoadService` is responsible for loading a named model into memory so that it can be used for inference. Despite the fact that it diff --git a/invokeai/app/api/routers/model_manager_v2.py b/invokeai/app/api/routers/model_manager_v2.py index 4fc785e4f7a..4482edfa0f6 100644 --- a/invokeai/app/api/routers/model_manager_v2.py +++ b/invokeai/app/api/routers/model_manager_v2.py @@ -251,9 +251,75 @@ async def add_model_record( return result +@model_manager_v2_router.post( + "/heuristic_import", + operation_id="heuristic_import_model", + responses={ + 201: {"description": "The model imported successfully"}, + 415: {"description": "Unrecognized file/folder format"}, + 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"}, + 409: {"description": "There is already a model corresponding to this path or repo_id"}, + }, + status_code=201, +) +async def heuristic_import( + source: str, + config: Optional[Dict[str, Any]] = Body( + description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", + default=None, + ), + access_token: Optional[str] = None, +) -> ModelInstallJob: + """Install a model using a string identifier. + + `source` can be any of the following. + + 1. A path on the local filesystem ('C:\\users\\fred\\model.safetensors') + 2. A Url pointing to a single downloadable model file + 3. A HuggingFace repo_id with any of the following formats: + - model/name + - model/name:fp16:vae + - model/name::vae -- use default precision + - model/name:fp16:path/to/model.safetensors + - model/name::path/to/model.safetensors + + `config` is an optional dict containing model configuration values that will override + the ones that are probed automatically. + + `access_token` is an optional access token for use with Urls that require + authentication. + + Models will be downloaded, probed, configured and installed in a + series of background threads. The return object has `status` attribute + that can be used to monitor progress. + + See the documentation for `import_model_record` for more information on + interpreting the job information returned by this route. 
+ """ + logger = ApiDependencies.invoker.services.logger + + try: + installer = ApiDependencies.invoker.services.model_manager.install + result: ModelInstallJob = installer.heuristic_import( + source=source, + config=config, + ) + logger.info(f"Started installation of {source}") + except UnknownModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=424, detail=str(e)) + except InvalidModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=415) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + return result + + @model_manager_v2_router.post( "/import", - operation_id="import_model_record", + operation_id="import_model", responses={ 201: {"description": "The model imported successfully"}, 415: {"description": "Unrecognized file/folder format"}, @@ -269,7 +335,7 @@ async def import_model( default=None, ), ) -> ModelInstallJob: - """Add a model using its local path, repo_id, or remote URL. + """Install a model using its local path, repo_id, or remote URL. Models will be downloaded, probed, configured and installed in a series of background threads. The return object has `status` attribute diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 851cbc8160e..1831b54c13c 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -49,7 +49,6 @@ download_queue, images, model_manager_v2, - models, session_queue, sessions, utilities, diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 635cb154d64..943cdf1157f 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -127,8 +127,8 @@ def proper_repo_id(cls, v: str) -> str: # noqa D102 def __str__(self) -> str: """Return string version of repoid when string rep needed.""" base: str = self.repo_id + base += f":{self.variant or ''}" base += f":{self.subfolder}" if self.subfolder else "" - base += f" ({self.variant})" if self.variant else "" return base @@ -324,6 +324,43 @@ def install_path( :returns id: The string ID of the registered model. """ + @abstractmethod + def heuristic_import( + self, + source: str, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, + ) -> ModelInstallJob: + r"""Install the indicated model using heuristics to interpret user intentions. + + :param source: String source + :param config: Optional dict. Any fields in this dict + will override corresponding autoassigned probe fields in the + model's config record as described in `import_model()`. + :param access_token: Optional access token for remote sources. + + The source can be: + 1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`) + 2. An http or https URL (`https://foo.bar/foo`) + 3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`) + + We extend the HuggingFace repo_id syntax to include the variant and the + subfolder or path. The following are acceptable alternatives: + stabilityai/stable-diffusion-v4 + stabilityai/stable-diffusion-v4:fp16 + stabilityai/stable-diffusion-v4:fp16:vae + stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors + stabilityai/stable-diffusion-v4:onnx:vae + + Because a local file path can look like a huggingface repo_id, the logic + first checks whether the path exists on disk, and if not, it is treated as + a parseable huggingface repo. 
+ + The previous support for recursing into a local folder and loading all model-like files + has been removed. + """ + pass + @abstractmethod def import_model( self, diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index d32af4a513d..df73fcb8cbe 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -50,6 +50,7 @@ ModelInstallJob, ModelInstallServiceBase, ModelSource, + StringLikeSource, URLModelSource, ) @@ -177,6 +178,34 @@ def install_path( info, ) + def heuristic_import( + self, + source: str, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, + ) -> ModelInstallJob: + variants = "|".join(ModelRepoVariant.__members__.values()) + hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" + source_obj: Optional[StringLikeSource] = None + + if Path(source).exists(): # A local file or directory + source_obj = LocalModelSource(path=Path(source)) + elif match := re.match(hf_repoid_re, source): + source_obj = HFModelSource( + repo_id=match.group(1), + variant=match.group(2) if match.group(2) else None, # pass None rather than '' + subfolder=Path(match.group(3)) if match.group(3) else None, + access_token=access_token, + ) + elif re.match(r"^https?://[^/]+", source): + source_obj = URLModelSource( + url=AnyHttpUrl(source), + access_token=access_token, + ) + else: + raise ValueError(f"Unsupported model source: '{source}'") + return self.import_model(source_obj, config) + def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102 similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state] if similar_jobs: @@ -571,6 +600,8 @@ def _import_remote_model( # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up. # Currently the tmpdir isn't automatically removed at exit because it is # being held in a daemon thread. + if len(remote_files) == 0: + raise ValueError(f"{source}: No downloadable files found") tmpdir = Path( mkdtemp( dir=self._app_config.models_path, @@ -586,6 +617,16 @@ def _import_remote_model( bytes=0, total_bytes=0, ) + # In the event that there is a subfolder specified in the source, + # we need to remove it from the destination path in order to avoid + # creating unwanted subfolders + if hasattr(source, "subfolder") and source.subfolder: + root = Path(remote_files[0].path.parts[0]) + subfolder = root / source.subfolder + else: + root = Path(".") + subfolder = Path(".") + # we remember the path up to the top of the tmpdir so that it may be # removed safely at the end of the install process. 
install_job._install_tmpdir = tmpdir @@ -595,7 +636,7 @@ def _import_remote_model( self._logger.debug(f"remote_files={remote_files}") for model_file in remote_files: url = model_file.url - path = model_file.path + path = root / model_file.path.relative_to(subfolder) self._logger.info(f"Downloading {url} => {path}") install_job.total_bytes += model_file.size assert hasattr(source, "access_token") diff --git a/invokeai/backend/model_manager/util/select_hf_files.py b/invokeai/backend/model_manager/util/select_hf_files.py index a894d915de6..2fd7a3721ab 100644 --- a/invokeai/backend/model_manager/util/select_hf_files.py +++ b/invokeai/backend/model_manager/util/select_hf_files.py @@ -36,6 +36,11 @@ def filter_files( """ variant = variant or ModelRepoVariant.DEFAULT paths: List[Path] = [] + root = files[0].parts[0] + + # if the subfolder is a single file, then bypass the selection and just return it + if subfolder and subfolder.suffix in [".safetensors", ".bin", ".onnx", ".xml", ".pth", ".pt", ".ckpt", ".msgpack"]: + return [root / subfolder] # Start by filtering on model file extensions, discarding images, docs, etc for file in files: @@ -61,6 +66,7 @@ def filter_files( # limit search to subfolder if requested if subfolder: + subfolder = root / subfolder paths = [x for x in paths if x.parent == Path(subfolder)] # _filter_by_variant uniquifies the paths and returns a set From 6fadb9d55c06e157ec190447f06373ad040831ce Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 12 Feb 2024 14:27:17 -0500 Subject: [PATCH 084/340] add a JIT download_and_cache() call to the model installer --- docs/contributing/MODEL_MANAGER.md | 40 +++++++++++++++ .../app/services/download/download_base.py | 13 +++++ .../app/services/download/download_default.py | 15 +++++- .../model_install/model_install_base.py | 34 ++++++++++++- .../model_install/model_install_default.py | 49 ++++++++++++++++++- .../convert_cache/convert_cache_default.py | 8 ++- 6 files changed, 154 insertions(+), 5 deletions(-) diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 959b7f9733c..b711c654de8 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -792,6 +792,14 @@ returns a list of completed jobs. The optional `timeout` argument will return from the call if jobs aren't completed in the specified time. An argument of 0 (the default) will block indefinitely. +#### jobs = installer.wait_for_job(job, [timeout]) + +Like `wait_for_installs()`, but block until a specific job has +completed or errored, and then return the job. The optional `timeout` +argument will return from the call if the job doesn't complete in the +specified time. An argument of 0 (the default) will block +indefinitely. + #### jobs = installer.list_jobs() Return a list of all active and complete `ModelInstallJobs`. @@ -854,6 +862,31 @@ This method is similar to `unregister()`, but also unconditionally deletes the corresponding model weights file(s), regardless of whether they are inside or outside the InvokeAI models hierarchy. + +#### path = installer.download_and_cache(remote_source, [access_token], [timeout]) + +This utility routine will download the model file located at source, +cache it, and return the path to the cached file. It does not attempt +to determine the model type, probe its configuration values, or +register it with the models database. + +You may provide an access token if the remote source requires +authorization. 
The call will block indefinitely until the file is +completely downloaded, cancelled or raises an error of some sort. If +you provide a timeout (in seconds), the call will raise a +`TimeoutError` exception if the download hasn't completed in the +specified period. + +You may use this mechanism to request any type of file, not just a +model. The file will be stored in a subdirectory of +`INVOKEAI_ROOT/models/.cache`. If the requested file is found in the +cache, its path will be returned without redownloading it. + +Be aware that the models cache is cleared of infrequently-used files +and directories at regular intervals when the size of the cache +exceeds the value specified in Invoke's `convert_cache` configuration +variable. + #### List[str]=installer.scan_directory(scan_dir: Path, install: bool) This method will recursively scan the directory indicated in @@ -1187,6 +1220,13 @@ queue or was not created by this queue. This method will block until all the active jobs in the queue have reached a terminal state (completed, errored or cancelled). +#### queue.wait_for_job(job, [timeout]) + +This method will block until the indicated job has reached a terminal +state (completed, errored or cancelled). If the optional timeout is +provided, the call will block for at most timeout seconds, and raise a +TimeoutError otherwise. + #### jobs = queue.list_jobs() This will return a list of all jobs, including ones that have not yet diff --git a/invokeai/app/services/download/download_base.py b/invokeai/app/services/download/download_base.py index f854f64f585..2ac13b825fe 100644 --- a/invokeai/app/services/download/download_base.py +++ b/invokeai/app/services/download/download_base.py @@ -260,3 +260,16 @@ def cancel_job(self, job: DownloadJob) -> None: def join(self) -> None: """Wait until all jobs are off the queue.""" pass + + @abstractmethod + def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + """Wait until the indicated download job has reached a terminal state. + + This will block until the indicated install job has completed, + been cancelled, or errored out. + + :param job: The job to wait on. + :param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if + the job hasn't completed within the indicated time. 
+ """ + pass diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 7613c0893fc..f740c500873 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -4,6 +4,7 @@ import os import re import threading +import time import traceback from pathlib import Path from queue import Empty, PriorityQueue @@ -52,6 +53,7 @@ def __init__( self._next_job_id = 0 self._queue = PriorityQueue() self._stop_event = threading.Event() + self._job_completed_event = threading.Event() self._worker_pool = set() self._lock = threading.Lock() self._logger = InvokeAILogger.get_logger("DownloadQueueService") @@ -188,6 +190,16 @@ def cancel_all_jobs(self) -> None: if not job.in_terminal_state: self.cancel_job(job) + def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + """Block until the indicated job has reached terminal state, or when timeout limit reached.""" + start = time.time() + while not job.in_terminal_state: + if self._job_completed_event.wait(timeout=5): # in case we miss an event + self._job_completed_event.clear() + if timeout > 0 and time.time() - start > timeout: + raise TimeoutError("Timeout exceeded") + return job + def _start_workers(self, max_workers: int) -> None: """Start the requested number of worker threads.""" self._stop_event.clear() @@ -223,6 +235,7 @@ def _download_next_item(self) -> None: finally: job.job_ended = get_iso_timestamp() + self._job_completed_event.set() # signal a change to terminal state self._queue.task_done() self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.") @@ -407,7 +420,7 @@ def _cleanup_cancelled_job(self, job: DownloadJob) -> None: # Example on_progress event handler to display a TQDM status bar # Activate with: -# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update +# download_service.download(DownloadJob('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().update)) class TqdmProgress(object): """TQDM-based progress bar object to use in on_progress handlers.""" diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 943cdf1157f..39ea8c4a0d1 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -422,6 +422,18 @@ def prune_jobs(self) -> None: def cancel_job(self, job: ModelInstallJob) -> None: """Cancel the indicated job.""" + @abstractmethod + def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob: + """Wait for the indicated job to reach a terminal state. + + This will block until the indicated install job has completed, + been cancelled, or errored out. + + :param job: The job to wait on. + :param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if + the job hasn't completed within the indicated time. + """ + @abstractmethod def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: """ @@ -431,7 +443,8 @@ def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: completed, been cancelled, or errored out. :param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if - installs do not complete within the indicated time. + installs do not complete within the indicated time. A timeout of zero (the default) + will block indefinitely until the installs complete. 
""" @abstractmethod @@ -447,3 +460,22 @@ def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: @abstractmethod def sync_to_config(self) -> None: """Synchronize models on disk to those in the model record database.""" + + @abstractmethod + def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path: + """ + Download the model file located at source to the models cache and return its Path. + + :param source: A Url or a string that can be converted into one. + :param access_token: Optional access token to access restricted resources. + + The model file will be downloaded into the system-wide model cache + (`models/.cache`) if it isn't already there. Note that the model cache + is periodically cleared of infrequently-used entries when the model + converter runs. + + Note that this doesn't automaticallly install or register the model, but is + intended for use by nodes that need access to models that aren't directly + supported by InvokeAI. The downloading process takes advantage of the download queue + to avoid interrupting other operations. + """ diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index df73fcb8cbe..414e3007157 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -17,7 +17,7 @@ from requests import Session from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase +from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL @@ -87,6 +87,7 @@ def __init__( self._lock = threading.Lock() self._stop_event = threading.Event() self._downloads_changed_event = threading.Event() + self._install_completed_event = threading.Event() self._download_queue = download_queue self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} self._running = False @@ -241,6 +242,17 @@ def get_job_by_id(self, id: int) -> ModelInstallJob: # noqa D102 assert isinstance(jobs[0], ModelInstallJob) return jobs[0] + def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob: + """Block until the indicated job has reached terminal state, or when timeout limit reached.""" + start = time.time() + while not job.in_terminal_state: + if self._install_completed_event.wait(timeout=5): # in case we miss an event + self._install_completed_event.clear() + if timeout > 0 and time.time() - start > timeout: + raise TimeoutError("Timeout exceeded") + return job + + # TODO: Better name? Maybe wait_for_jobs()? 
Maybe too easily confused with above def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102 """Block until all installation jobs are done.""" start = time.time() @@ -248,7 +260,7 @@ def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa if self._downloads_changed_event.wait(timeout=5): # in case we miss an event self._downloads_changed_event.clear() if timeout > 0 and time.time() - start > timeout: - raise Exception("Timeout exceeded") + raise TimeoutError("Timeout exceeded") self._install_queue.join() return self._install_jobs @@ -302,6 +314,38 @@ def unconditionally_delete(self, key: str) -> None: # noqa D102 path.unlink() self.unregister(key) + def download_and_cache( + self, + source: Union[str, AnyHttpUrl], + access_token: Optional[str] = None, + timeout: int = 0, + ) -> Path: + """Download the model file located at source to the models cache and return its Path.""" + model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] + model_path = self._app_config.models_convert_cache_path / model_hash + + # We expect the cache directory to contain one and only one downloaded file. + # We don't know the file's name in advance, as it is set by the download + # content-disposition header. + if model_path.exists(): + contents = [x for x in model_path.iterdir() if x.is_file()] + if len(contents) > 0: + return contents[0] + + model_path.mkdir(parents=True, exist_ok=True) + job = self._download_queue.download( + source=AnyHttpUrl(str(source)), + dest=model_path, + access_token=access_token, + on_progress=TqdmProgress().update, + ) + self._download_queue.wait_for_job(job, timeout) + if job.complete: + assert job.download_path is not None + return job.download_path + else: + raise Exception(job.error) + # -------------------------------------------------------------------------------------------- # Internal functions that manage the installer threads # -------------------------------------------------------------------------------------------- @@ -365,6 +409,7 @@ def _install_next_item(self) -> None: # if this is an install of a remote file, then clean up the temporary directory if job._install_tmpdir is not None: rmtree(job._install_tmpdir) + self._install_completed_event.set() self._install_queue.task_done() self._logger.info("Install thread exiting") diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py index 4c361258d90..84f4f76299a 100644 --- a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py @@ -53,7 +53,13 @@ def by_atime(path: Path) -> float: sentinel = path / config if sentinel.exists(): return sentinel.stat().st_atime - return 0.0 + + # no sentinel file found! - pick the most recent file in the directory + try: + atimes = sorted([x.stat().st_atime for x in path.iterdir() if x.is_file()], reverse=True) + return atimes[0] + except IndexError: + return 0.0 # sort by last access time - least accessed files will be at the end lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True) From b06ebbd98d7fa2ca7bdfc6ee2738300d0924f9cc Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 12 Feb 2024 21:25:42 -0500 Subject: [PATCH 085/340] add route for model conversion from safetensors to diffusers - Begin to add SwaggerUI documentation for AnyModelConfig and other discriminated Unions. 
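(Aside, not part of the patch: a rough client-side sketch of the conversion route this commit adds, based on the handler in the diff below. The host, port and `/api` prefix are assumptions; the route itself is `PUT /v2/models/convert/{key}` and returns the config record for the new diffusers copy, whose key and hashes differ from the original safetensors model.)

```python
# Illustrative only -- endpoint prefix and model key are assumptions.
import requests

BASE = "http://localhost:9090/api/v2/models"    # assumed server address and prefix
key = "3a0e45ff858926fd4a63da630688b1e1"        # key of an installed main checkpoint model

resp = requests.put(f"{BASE}/convert/{key}")
resp.raise_for_status()
new_config = resp.json()                        # record for the converted diffusers model
print(new_config["key"], new_config["format"])  # note: the key differs from the original
```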
--- invokeai/app/api/routers/model_manager_v2.py | 80 ++++++++++++++++++- .../model_install/model_install_default.py | 6 +- .../services/model_load/model_load_base.py | 14 +++- .../services/model_load/model_load_default.py | 12 ++- .../model_records/model_records_base.py | 7 -- .../model_records/model_records_sql.py | 10 +-- .../backend/model_manager/load/load_base.py | 5 ++ 7 files changed, 113 insertions(+), 21 deletions(-) diff --git a/invokeai/app/api/routers/model_manager_v2.py b/invokeai/app/api/routers/model_manager_v2.py index 4482edfa0f6..8d31c6f286b 100644 --- a/invokeai/app/api/routers/model_manager_v2.py +++ b/invokeai/app/api/routers/model_manager_v2.py @@ -2,6 +2,7 @@ """FastAPI route for model configuration records.""" import pathlib +import shutil from hashlib import sha1 from random import randbytes from typing import Any, Dict, List, Optional, Set @@ -24,8 +25,10 @@ from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, + MainCheckpointConfig, ModelFormat, ModelType, + SubModelType, ) from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata @@ -318,7 +321,7 @@ async def heuristic_import( @model_manager_v2_router.post( - "/import", + "/install", operation_id="import_model", responses={ 201: {"description": "The model imported successfully"}, @@ -490,6 +493,81 @@ async def sync_models_to_config() -> Response: return Response(status_code=204) +@model_manager_v2_router.put( + "/convert/{key}", + operation_id="convert_model", + responses={ + 200: {"description": "Model converted successfully"}, + 400: {"description": "Bad request"}, + 404: {"description": "Model not found"}, + 409: {"description": "There is already a model registered at this location"}, + }, +) +async def convert_model( + key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."), +) -> AnyModelConfig: + """ + Permanently convert a model into diffusers format, replacing the safetensors version. + Note that the key and model hash will change. Use the model configuration record returned + by this call to get the new values. 
+ """ + logger = ApiDependencies.invoker.services.logger + loader = ApiDependencies.invoker.services.model_manager.load + store = ApiDependencies.invoker.services.model_manager.store + installer = ApiDependencies.invoker.services.model_manager.install + + try: + model_config = store.get_model(key) + except UnknownModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=424, detail=str(e)) + + if not isinstance(model_config, MainCheckpointConfig): + logger.error(f"The model with key {key} is not a main checkpoint model.") + raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.") + + # loading the model will convert it into a cached diffusers file + loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler) + + # Get the path of the converted model from the loader + cache_path = loader.convert_cache.cache_path(key) + assert cache_path.exists() + + # temporarily rename the original safetensors file so that there is no naming conflict + original_name = model_config.name + model_config.name = f"{original_name}.DELETE" + store.update_model(key, config=model_config) + + # install the diffusers + try: + new_key = installer.install_path( + cache_path, + config={ + "name": original_name, + "description": model_config.description, + "original_hash": model_config.original_hash, + "source": model_config.source, + }, + ) + except DuplicateModelException as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + + # get the original metadata + if orig_metadata := store.get_metadata(key): + store.metadata_store.add_metadata(new_key, orig_metadata) + + # delete the original safetensors file + installer.delete(key) + + # delete the cached version + shutil.rmtree(cache_path) + + # return the config record for the new diffusers directory + new_config: AnyModelConfig = store.get_model(new_key) + return new_config + + @model_manager_v2_router.put( "/merge", operation_id="merge", diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 414e3007157..20a85a82a14 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -162,8 +162,10 @@ def install_path( config["source"] = model_path.resolve().as_posix() info: AnyModelConfig = self._probe_model(Path(model_path), config) - old_hash = info.original_hash - dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name + old_hash = info.current_hash + dest_path = ( + self.app_config.models_path / info.base.value / info.type.value / (config.get("name") or model_path.name) + ) try: new_path = self._copy_model(model_path, dest_path) except FileExistsError as excp: diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index f298d98ce6d..45eaf4652fb 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -5,8 +5,10 @@ from typing import Optional from invokeai.app.invocations.baseinvocation import InvocationContext -from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import LoadedModel +from invokeai.backend.model_manager.load.convert_cache import 
ModelConvertCacheBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase class ModelLoadServiceBase(ABC): @@ -70,3 +72,13 @@ def load_model_by_attr( NotImplementedException -- a model loader was not provided at initialization time ValueError -- more than one model matches this combination """ + + @property + @abstractmethod + def ram_cache(self) -> ModelCacheBase[AnyModel]: + """Return the RAM cache used by this loader.""" + + @property + @abstractmethod + def convert_cache(self) -> ModelConvertCacheBase: + """Return the checkpoint convert cache used by this loader.""" diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 67107cada6e..a6ccd5afbc3 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -10,7 +10,7 @@ from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase -from invokeai.backend.model_manager.load.model_cache import ModelCacheBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase from invokeai.backend.util.logging import InvokeAILogger from .model_load_base import ModelLoadServiceBase @@ -46,6 +46,16 @@ def __init__( ), ) + @property + def ram_cache(self) -> ModelCacheBase[AnyModel]: + """Return the RAM cache used by this loader.""" + return self._any_loader.ram_cache + + @property + def convert_cache(self) -> ModelConvertCacheBase: + """Return the checkpoint convert cache used by this loader.""" + return self._any_loader.convert_cache + def load_model_by_key( self, key: str, diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index e2e98c7e896..b2eacc524b7 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -17,7 +17,6 @@ ModelFormat, ModelType, ) -from invokeai.backend.model_manager.load import AnyModelLoader from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore @@ -195,12 +194,6 @@ def search_by_attr( """ pass - @property - @abstractmethod - def loader(self) -> Optional[AnyModelLoader]: - """Return the model loader used by this instance.""" - pass - def all_models(self) -> List[AnyModelConfig]: """Return all the model configs in the database.""" return self.search_by_attr() diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index f48175351de..84a14123838 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -54,7 +54,6 @@ ModelFormat, ModelType, ) -from invokeai.backend.model_manager.load import AnyModelLoader from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException from ..shared.sqlite.sqlite_database import SqliteDatabase @@ -70,28 +69,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """Implementation of the ModelConfigStore ABC using a SQL database.""" - def __init__(self, db: SqliteDatabase, loader: Optional[AnyModelLoader] = None): + def __init__(self, db: 
SqliteDatabase): """ Initialize a new object from preexisting sqlite3 connection and threading lock objects. :param db: Sqlite connection object - :param loader: Initialized model loader object (optional) """ super().__init__() self._db = db self._cursor = db.conn.cursor() - self._loader = loader @property def db(self) -> SqliteDatabase: """Return the underlying database.""" return self._db - @property - def loader(self) -> Optional[AnyModelLoader]: - """Return the model loader used by this instance.""" - return self._loader - def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: """ Add a model to the database. diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 3d026af2269..5f392ada75e 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -117,6 +117,11 @@ def ram_cache(self) -> ModelCacheBase[AnyModel]: """Return the RAM cache associated used by the loaders.""" return self._ram_cache + @property + def convert_cache(self) -> ModelConvertCacheBase: + """Return the convert cache associated used by the loaders.""" + return self._convert_cache + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Return a model given its configuration. From 9cdca07dc6f3d0c6ec6656a022bd0cf78e25813b Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 13 Feb 2024 00:26:49 -0500 Subject: [PATCH 086/340] fix a number of typechecking errors --- invokeai/app/api/routers/download_queue.py | 6 +- invokeai/app/api/routers/model_manager_v2.py | 69 ++++++++++++++++--- invokeai/app/invocations/ip_adapter.py | 4 +- invokeai/app/invocations/model.py | 8 +-- invokeai/app/services/config/config_base.py | 11 +-- invokeai/app/services/config/config_common.py | 2 +- .../app/services/download/download_default.py | 12 ++-- invokeai/app/util/misc.py | 8 +-- .../backend/model_manager/load/load_base.py | 13 ++-- .../load/model_cache/model_cache_default.py | 4 +- .../metadata/fetch/fetch_base.py | 4 +- invokeai/backend/model_manager/probe.py | 2 +- invokeai/backend/model_manager/search.py | 6 +- 13 files changed, 101 insertions(+), 48 deletions(-) diff --git a/invokeai/app/api/routers/download_queue.py b/invokeai/app/api/routers/download_queue.py index 2dba376c181..a6e53c7a5c4 100644 --- a/invokeai/app/api/routers/download_queue.py +++ b/invokeai/app/api/routers/download_queue.py @@ -36,7 +36,7 @@ async def list_downloads() -> List[DownloadJob]: 400: {"description": "Bad request"}, }, ) -async def prune_downloads(): +async def prune_downloads() -> Response: """Prune completed and errored jobs.""" queue = ApiDependencies.invoker.services.download_queue queue.prune_jobs() @@ -87,7 +87,7 @@ async def get_download_job( ) async def cancel_download_job( id: int = Path(description="ID of the download job to cancel."), -): +) -> Response: """Cancel a download job using its ID.""" try: queue = ApiDependencies.invoker.services.download_queue @@ -105,7 +105,7 @@ async def cancel_download_job( 204: {"description": "Download jobs have been cancelled"}, }, ) -async def cancel_all_download_jobs(): +async def cancel_all_download_jobs() -> Response: """Cancel all download jobs.""" ApiDependencies.invoker.services.download_queue.cancel_all_jobs() return Response(status_code=204) diff --git a/invokeai/app/api/routers/model_manager_v2.py b/invokeai/app/api/routers/model_manager_v2.py index 8d31c6f286b..029c6207072 
100644 --- a/invokeai/app/api/routers/model_manager_v2.py +++ b/invokeai/app/api/routers/model_manager_v2.py @@ -9,7 +9,7 @@ from fastapi import Body, Path, Query, Response from fastapi.routing import APIRouter -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from starlette.exceptions import HTTPException from typing_extensions import Annotated @@ -37,6 +37,35 @@ model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"]) +example_model_output = { + "path": "sd-1/main/openjourney", + "name": "openjourney", + "base": "sd-1", + "type": "main", + "format": "diffusers", + "key": "3a0e45ff858926fd4a63da630688b1e1", + "original_hash": "1c12f18fb6e403baef26fb9d720fbd2f", + "current_hash": "1c12f18fb6e403baef26fb9d720fbd2f", + "description": "sd-1 main model openjourney", + "source": "/opt/invokeai/models/sd-1/main/openjourney", + "last_modified": 1707794711, + "vae": "/opt/invokeai/models/sd-1/vae/vae-ft-mse-840000-ema-pruned_fp16.safetensors", + "variant": "normal", + "prediction_type": "epsilon", + "repo_variant": "fp16", +} + +example_model_input = { + "path": "base/type/name", + "name": "model_name", + "base": "sd-1", + "type": "main", + "format": "diffusers", + "description": "Model description", + "vae": None, + "variant": "normal", +} + class ModelsList(BaseModel): """Return list of configs.""" @@ -88,7 +117,10 @@ async def list_model_records( "/i/{key}", operation_id="get_model_record", responses={ - 200: {"description": "Success"}, + 200: { + "description": "The model configuration was retrieved successfully", + "content": {"application/json": {"example": example_model_output}}, + }, 400: {"description": "Bad request"}, 404: {"description": "The model could not be found"}, }, @@ -165,18 +197,22 @@ async def search_by_metadata_tags( "/i/{key}", operation_id="update_model_record", responses={ - 200: {"description": "The model was updated successfully"}, + 200: { + "description": "The model was updated successfully", + "content": {"application/json": {"example": example_model_output}}, + }, 400: {"description": "Bad request"}, 404: {"description": "The model could not be found"}, 409: {"description": "There is already a model corresponding to the new name"}, }, status_code=200, - response_model=AnyModelConfig, ) async def update_model_record( key: Annotated[str, Path(description="Unique key of model")], - info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")], -) -> AnyModelConfig: + info: Annotated[ + AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) + ], +) -> Annotated[AnyModelConfig, Field(example="this is neat")]: """Update model contents with a new config. 
If the model name or base fields are changed, then the model is renamed.""" logger = ApiDependencies.invoker.services.logger record_store = ApiDependencies.invoker.services.model_manager.store @@ -225,7 +261,10 @@ async def del_model_record( "/i/", operation_id="add_model_record", responses={ - 201: {"description": "The model added successfully"}, + 201: { + "description": "The model added successfully", + "content": {"application/json": {"example": example_model_output}}, + }, 409: {"description": "There is already a model corresponding to this path or repo_id"}, 415: {"description": "Unrecognized file/folder format"}, }, @@ -270,6 +309,7 @@ async def heuristic_import( config: Optional[Dict[str, Any]] = Body( description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", default=None, + example={"name": "modelT", "description": "antique cars"}, ), access_token: Optional[str] = None, ) -> ModelInstallJob: @@ -497,7 +537,10 @@ async def sync_models_to_config() -> Response: "/convert/{key}", operation_id="convert_model", responses={ - 200: {"description": "Model converted successfully"}, + 200: { + "description": "Model converted successfully", + "content": {"application/json": {"example": example_model_output}}, + }, 400: {"description": "Bad request"}, 404: {"description": "Model not found"}, 409: {"description": "There is already a model registered at this location"}, @@ -571,6 +614,15 @@ async def convert_model( @model_manager_v2_router.put( "/merge", operation_id="merge", + responses={ + 200: { + "description": "Model converted successfully", + "content": {"application/json": {"example": example_model_output}}, + }, + 400: {"description": "Bad request"}, + 404: {"description": "Model not found"}, + 409: {"description": "There is already a model registered at this location"}, + }, ) async def merge( keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3), @@ -596,7 +648,6 @@ async def merge( interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] merge_dest_directory: Specify a directory to store the merged model in [models directory] """ - print(f"here i am, keys={keys}") logger = ApiDependencies.invoker.services.logger try: logger.info(f"Merging models: {keys} into {merge_dest_directory or ''}/{merged_model_name}") diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index f64b3266bbb..01124f62f3c 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -90,10 +90,10 @@ def validate_begin_end_step_percent(self) -> Self: def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. 
- ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key) + ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key) image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() - image_encoder_models = context.services.model_records.search_by_attr( + image_encoder_models = context.services.model_manager.store.search_by_attr( model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision ) assert len(image_encoder_models) == 1 diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index fa6e8b98da0..f78425c6ee3 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -103,7 +103,7 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput: key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(key): + if not context.services.model_manager.store.exists(key): raise Exception(f"Unknown model {key}") return ModelLoaderOutput( @@ -172,7 +172,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: lora_key = self.lora.key - if not context.services.model_records.exists(lora_key): + if not context.services.model_manager.store.exists(lora_key): raise Exception(f"Unkown lora: {lora_key}!") if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): @@ -252,7 +252,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: lora_key = self.lora.key - if not context.services.model_records.exists(lora_key): + if not context.services.model_manager.store.exists(lora_key): raise Exception(f"Unknown lora: {lora_key}!") if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): @@ -318,7 +318,7 @@ class VaeLoaderInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> VAEOutput: key = self.vae_model.key - if not context.services.model_records.exists(key): + if not context.services.model_manager.store.exists(key): raise Exception(f"Unkown vae: {key}!") return VAEOutput(vae=VaeField(vae=ModelInfo(key=key))) diff --git a/invokeai/app/services/config/config_base.py b/invokeai/app/services/config/config_base.py index a304b38a955..983df6b4684 100644 --- a/invokeai/app/services/config/config_base.py +++ b/invokeai/app/services/config/config_base.py @@ -27,11 +27,11 @@ class InvokeAISettings(BaseSettings): """Runtime configuration settings in which default values are read from an omegaconf .yaml file.""" initconf: ClassVar[Optional[DictConfig]] = None - argparse_groups: ClassVar[Dict] = {} + argparse_groups: ClassVar[Dict[str, Any]] = {} model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True) - def parse_args(self, argv: Optional[list] = sys.argv[1:]): + def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None: """Call to parse command-line arguments.""" parser = self.get_parser() opt, unknown_opts = parser.parse_known_args(argv) @@ -68,7 +68,7 @@ def to_yaml(self) -> str: return OmegaConf.to_yaml(conf) @classmethod - def add_parser_arguments(cls, parser): + def add_parser_arguments(cls, parser) -> None: """Dynamically create arguments for a settings parser.""" if "type" in get_type_hints(cls): settings_stanza = get_args(get_type_hints(cls)["type"])[0] @@ -117,7 +117,8 @@ def cmd_name(cls, command_field: str = "type") -> str: """Return the category of a setting.""" 
hints = get_type_hints(cls) if command_field in hints: - return get_args(hints[command_field])[0] + result: str = get_args(hints[command_field])[0] + return result else: return "Uncategorized" @@ -158,7 +159,7 @@ def _excluded_from_yaml(cls) -> List[str]: ] @classmethod - def add_field_argument(cls, command_parser, name: str, field, default_override=None): + def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None: """Add the argparse arguments for a setting parser.""" field_type = get_type_hints(cls).get(name) default = ( diff --git a/invokeai/app/services/config/config_common.py b/invokeai/app/services/config/config_common.py index d11bcabcf9c..27a0f859c23 100644 --- a/invokeai/app/services/config/config_common.py +++ b/invokeai/app/services/config/config_common.py @@ -21,7 +21,7 @@ class PagingArgumentParser(argparse.ArgumentParser): It also supports reading defaults from an init file. """ - def print_help(self, file=None): + def print_help(self, file=None) -> None: text = self.format_help() pydoc.pager(text) diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index f740c500873..7008f8ed741 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -8,12 +8,12 @@ import traceback from pathlib import Path from queue import Empty, PriorityQueue -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import requests from pydantic.networks import AnyHttpUrl from requests import HTTPError -from tqdm import tqdm +from tqdm import tqdm, std from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.util.misc import get_iso_timestamp @@ -49,12 +49,12 @@ def __init__( :param max_parallel_dl: Number of simultaneous downloads allowed [5]. :param requests_session: Optional requests.sessions.Session object, for unit tests. 
""" - self._jobs = {} + self._jobs: Dict[int, DownloadJob] = {} self._next_job_id = 0 - self._queue = PriorityQueue() + self._queue: PriorityQueue[DownloadJob] = PriorityQueue() self._stop_event = threading.Event() self._job_completed_event = threading.Event() - self._worker_pool = set() + self._worker_pool: Set[threading.Thread] = set() self._lock = threading.Lock() self._logger = InvokeAILogger.get_logger("DownloadQueueService") self._event_bus = event_bus @@ -424,7 +424,7 @@ def _cleanup_cancelled_job(self, job: DownloadJob) -> None: class TqdmProgress(object): """TQDM-based progress bar object to use in on_progress handlers.""" - _bars: Dict[int, tqdm] # the tqdm object + _bars: Dict[int, tqdm] # type: ignore _last: Dict[int, int] # last bytes downloaded def __init__(self) -> None: # noqa D107 diff --git a/invokeai/app/util/misc.py b/invokeai/app/util/misc.py index 910b05d8dde..da431929dbe 100644 --- a/invokeai/app/util/misc.py +++ b/invokeai/app/util/misc.py @@ -5,7 +5,7 @@ import numpy as np -def get_timestamp(): +def get_timestamp() -> int: return int(datetime.datetime.now(datetime.timezone.utc).timestamp()) @@ -20,16 +20,16 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime: SEED_MAX = np.iinfo(np.uint32).max -def get_random_seed(): +def get_random_seed() -> int: rng = np.random.default_rng(seed=None) return int(rng.integers(0, SEED_MAX)) -def uuid_string(): +def uuid_string() -> str: res = uuid.uuid4() return str(res) -def is_optional(value: typing.Any): +def is_optional(value: typing.Any) -> bool: """Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None].""" return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 5f392ada75e..7649dee762b 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -22,6 +22,7 @@ AnyModel, AnyModelConfig, BaseModelType, + ModelConfigBase, ModelFormat, ModelType, SubModelType, @@ -70,7 +71,7 @@ def __init__( pass @abstractmethod - def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Return a model given its confguration. @@ -122,7 +123,7 @@ def convert_cache(self) -> ModelConvertCacheBase: """Return the convert cache associated used by the loaders.""" return self._convert_cache - def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Return a model given its configuration. 
@@ -144,8 +145,8 @@ def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) @classmethod def get_implementation( - cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] - ) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]: + cls, config: ModelConfigBase, submodel_type: Optional[SubModelType] + ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: """Get subclass of ModelLoaderBase registered to handle base and type.""" # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type) @@ -161,8 +162,8 @@ def get_implementation( @classmethod def _handle_subtype_overrides( - cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] - ) -> Tuple[AnyModelConfig, Optional[SubModelType]]: + cls, config: ModelConfigBase, submodel_type: Optional[SubModelType] + ) -> Tuple[ModelConfigBase, Optional[SubModelType]]: if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None: model_path = Path(config.vae) config_class = ( diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 98d6f34cead..786396062cf 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -34,8 +34,8 @@ from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger -from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase -from .model_locker import ModelLocker, ModelLockerBase +from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase +from .model_locker import ModelLocker if choose_torch_device() == torch.device("mps"): from torch import mps diff --git a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py index d628ab5c178..5d75493b92f 100644 --- a/invokeai/backend/model_manager/metadata/fetch/fetch_base.py +++ b/invokeai/backend/model_manager/metadata/fetch/fetch_base.py @@ -20,7 +20,7 @@ from invokeai.backend.model_manager import ModelRepoVariant -from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator +from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, BaseMetadata class ModelMetadataFetchBase(ABC): @@ -62,5 +62,5 @@ def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyMod @classmethod def from_json(cls, json: str) -> AnyModelRepoMetadata: """Given the JSON representation of the metadata, return the corresponding Pydantic object.""" - metadata = AnyModelRepoMetadataValidator.validate_json(json) + metadata: BaseMetadata = AnyModelRepoMetadataValidator.validate_json(json) # type: ignore return metadata diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index e7d21c578fd..2c2066d7c52 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -166,7 +166,7 @@ def probe( fields["original_hash"] = fields.get("original_hash") or hash fields["current_hash"] = fields.get("current_hash") or hash - if format_type == ModelFormat.Diffusers: + if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"): fields["repo_variant"] = 
fields.get("repo_variant") or probe.get_repo_variant() # additional fields needed for main and controlnet models diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index f7e1e1bed76..0ead22b743f 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase): # returns all models that have 'anime' in the path """ - models_found: Set[Path] = Field(default=None) - scanned_dirs: Set[Path] = Field(default=None) - pruned_paths: Set[Path] = Field(default=None) + models_found: Optional[Set[Path]] = Field(default=None) + scanned_dirs: Optional[Set[Path]] = Field(default=None) + pruned_paths: Optional[Set[Path]] = Field(default=None) def search_started(self) -> None: self.models_found = set() From a8eb9e26bb0b29ab3f5fbc13108334f051e1c423 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 14 Feb 2024 11:10:50 -0500 Subject: [PATCH 087/340] improve swagger documentation --- invokeai/app/api/routers/model_manager_v2.py | 214 ++++++++++++------ .../app/services/download/download_default.py | 2 +- invokeai/backend/model_manager/config.py | 16 +- 3 files changed, 159 insertions(+), 73 deletions(-) diff --git a/invokeai/app/api/routers/model_manager_v2.py b/invokeai/app/api/routers/model_manager_v2.py index 029c6207072..2471e0d8c9b 100644 --- a/invokeai/app/api/routers/model_manager_v2.py +++ b/invokeai/app/api/routers/model_manager_v2.py @@ -9,7 +9,7 @@ from fastapi import Body, Path, Query, Response from fastapi.routing import APIRouter -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict from starlette.exceptions import HTTPException from typing_extensions import Annotated @@ -37,51 +37,102 @@ model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"]) -example_model_output = { - "path": "sd-1/main/openjourney", - "name": "openjourney", + +class ModelsList(BaseModel): + """Return list of configs.""" + + models: List[AnyModelConfig] + + model_config = ConfigDict(use_enum_values=True) + + +class ModelTagSet(BaseModel): + """Return tags for a set of models.""" + + key: str + name: str + author: str + tags: Set[str] + + +############################################################################## +# These are example inputs and outputs that are used in places where Swagger +# is unable to generate a correct example. 
+############################################################################## +example_model_config = { + "path": "string", + "name": "string", "base": "sd-1", "type": "main", - "format": "diffusers", - "key": "3a0e45ff858926fd4a63da630688b1e1", - "original_hash": "1c12f18fb6e403baef26fb9d720fbd2f", - "current_hash": "1c12f18fb6e403baef26fb9d720fbd2f", - "description": "sd-1 main model openjourney", - "source": "/opt/invokeai/models/sd-1/main/openjourney", - "last_modified": 1707794711, - "vae": "/opt/invokeai/models/sd-1/vae/vae-ft-mse-840000-ema-pruned_fp16.safetensors", + "format": "checkpoint", + "config": "string", + "key": "string", + "original_hash": "string", + "current_hash": "string", + "description": "string", + "source": "string", + "last_modified": 0, + "vae": "string", "variant": "normal", "prediction_type": "epsilon", "repo_variant": "fp16", + "upcast_attention": False, + "ztsnr_training": False, } example_model_input = { - "path": "base/type/name", + "path": "/path/to/model", "name": "model_name", "base": "sd-1", "type": "main", - "format": "diffusers", + "format": "checkpoint", + "config": "configs/stable-diffusion/v1-inference.yaml", "description": "Model description", "vae": None, "variant": "normal", } +example_model_metadata = { + "name": "ip_adapter_sd_image_encoder", + "author": "InvokeAI", + "tags": [ + "transformers", + "safetensors", + "clip_vision_model", + "endpoints_compatible", + "region:us", + "has_space", + "license:apache-2.0", + ], + "files": [ + { + "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md", + "path": "ip_adapter_sd_image_encoder/README.md", + "size": 628, + "sha256": None, + }, + { + "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json", + "path": "ip_adapter_sd_image_encoder/config.json", + "size": 560, + "sha256": None, + }, + { + "url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors", + "path": "ip_adapter_sd_image_encoder/model.safetensors", + "size": 2528373448, + "sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030", + }, + ], + "type": "huggingface", + "id": "InvokeAI/ip_adapter_sd_image_encoder", + "tag_dict": {"license": "apache-2.0"}, + "last_modified": "2023-09-23T17:33:25Z", +} -class ModelsList(BaseModel): - """Return list of configs.""" - - models: List[AnyModelConfig] - - model_config = ConfigDict(use_enum_values=True) - - -class ModelTagSet(BaseModel): - """Return tags for a set of models.""" - - key: str - name: str - author: str - tags: Set[str] +############################################################################## +# ROUTES +############################################################################## @model_manager_v2_router.get( @@ -119,7 +170,7 @@ async def list_model_records( responses={ 200: { "description": "The model configuration was retrieved successfully", - "content": {"application/json": {"example": example_model_output}}, + "content": {"application/json": {"example": example_model_config}}, }, 400: {"description": "Bad request"}, 404: {"description": "The model could not be found"}, @@ -137,7 +188,7 @@ async def get_model_record( raise HTTPException(status_code=404, detail=str(e)) -@model_manager_v2_router.get("/meta", operation_id="list_model_summary") +@model_manager_v2_router.get("/summary", operation_id="list_model_summary") async def list_model_summary( page: int = Query(default=0, description="The page to get"), per_page: int = Query(default=10, 
description="The number of models per page"), @@ -153,7 +204,10 @@ async def list_model_summary( "/meta/i/{key}", operation_id="get_model_metadata", responses={ - 200: {"description": "Success"}, + 200: { + "description": "The model metadata was retrieved successfully", + "content": {"application/json": {"example": example_model_metadata}}, + }, 400: {"description": "Bad request"}, 404: {"description": "No metadata available"}, }, @@ -199,7 +253,7 @@ async def search_by_metadata_tags( responses={ 200: { "description": "The model was updated successfully", - "content": {"application/json": {"example": example_model_output}}, + "content": {"application/json": {"example": example_model_config}}, }, 400: {"description": "Bad request"}, 404: {"description": "The model could not be found"}, @@ -212,7 +266,7 @@ async def update_model_record( info: Annotated[ AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) ], -) -> Annotated[AnyModelConfig, Field(example="this is neat")]: +) -> AnyModelConfig: """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" logger = ApiDependencies.invoker.services.logger record_store = ApiDependencies.invoker.services.model_manager.store @@ -263,7 +317,7 @@ async def del_model_record( responses={ 201: { "description": "The model added successfully", - "content": {"application/json": {"example": example_model_output}}, + "content": {"application/json": {"example": example_model_config}}, }, 409: {"description": "There is already a model corresponding to this path or repo_id"}, 415: {"description": "Unrecognized file/folder format"}, @@ -271,7 +325,9 @@ async def del_model_record( status_code=201, ) async def add_model_record( - config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")], + config: Annotated[ + AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) + ], ) -> AnyModelConfig: """Add a model using the configuration information appropriate for its type.""" logger = ApiDependencies.invoker.services.logger @@ -389,32 +445,38 @@ async def import_model( appropriate value: * To install a local path using LocalModelSource, pass a source of form: - `{ + ``` + { "type": "local", "path": "/path/to/model", "inplace": false - }` - The "inplace" flag, if true, will register the model in place in its - current filesystem location. Otherwise, the model will be copied - into the InvokeAI models directory. + } + ``` + The "inplace" flag, if true, will register the model in place in its + current filesystem location. Otherwise, the model will be copied + into the InvokeAI models directory. * To install a HuggingFace repo_id using HFModelSource, pass a source of form: - `{ + ``` + { "type": "hf", "repo_id": "stabilityai/stable-diffusion-2.0", "variant": "fp16", "subfolder": "vae", "access_token": "f5820a918aaf01" - }` - The `variant`, `subfolder` and `access_token` fields are optional. + } + ``` + The `variant`, `subfolder` and `access_token` fields are optional. * To install a remote model using an arbitrary URL, pass: - `{ + ``` + { "type": "url", "url": "http://www.civitai.com/models/123456", "access_token": "f5820a918aaf01" - }` - The `access_token` field is optonal + } + ``` + The `access_token` field is optonal The model's configuration record will be probed and filled in automatically. 
To override the default guesses, pass "metadata" @@ -423,9 +485,9 @@ async def import_model( Installation occurs in the background. Either use list_model_install_jobs() to poll for completion, or listen on the event bus for the following events: - "model_install_running" - "model_install_completed" - "model_install_error" + * "model_install_running" + * "model_install_completed" + * "model_install_error" On successful completion, the event's payload will contain the field "key" containing the installed ID of the model. On an error, the event's payload @@ -459,7 +521,25 @@ async def import_model( operation_id="list_model_install_jobs", ) async def list_model_install_jobs() -> List[ModelInstallJob]: - """Return list of model install jobs.""" + """Return the list of model install jobs. + + Install jobs have a numeric `id`, a `status`, and other fields that provide information on + the nature of the job and its progress. The `status` is one of: + + * "waiting" -- Job is waiting in the queue to run + * "downloading" -- Model file(s) are downloading + * "running" -- Model has downloaded and the model probing and registration process is running + * "completed" -- Installation completed successfully + * "error" -- An error occurred. Details will be in the "error_type" and "error" fields. + * "cancelled" -- Job was cancelled before completion. + + Once completed, information about the model such as its size, base + model, type, and metadata can be retrieved from the `config_out` + field. For multi-file models such as diffusers, information on individual files + can be retrieved from `download_parts`. + + See the example and schema below for more information. + """ jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs() return jobs @@ -473,7 +553,10 @@ async def list_model_install_jobs() -> List[ModelInstallJob]: }, ) async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: - """Return model install job corresponding to the given source.""" + """ + Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs' + for information on the format of the return value. + """ try: result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id) return result @@ -539,7 +622,7 @@ async def sync_models_to_config() -> Response: responses={ 200: { "description": "Model converted successfully", - "content": {"application/json": {"example": example_model_output}}, + "content": {"application/json": {"example": example_model_config}}, }, 400: {"description": "Bad request"}, 404: {"description": "Model not found"}, @@ -551,8 +634,8 @@ async def convert_model( ) -> AnyModelConfig: """ Permanently convert a model into diffusers format, replacing the safetensors version. - Note that the key and model hash will change. Use the model configuration record returned - by this call to get the new values. + Note that during the conversion process the key and model hash will change. + The return value is the model configuration for the converted model. 
""" logger = ApiDependencies.invoker.services.logger loader = ApiDependencies.invoker.services.model_manager.load @@ -617,7 +700,7 @@ async def convert_model( responses={ 200: { "description": "Model converted successfully", - "content": {"application/json": {"example": example_model_output}}, + "content": {"application/json": {"example": example_model_config}}, }, 400: {"description": "Bad request"}, 404: {"description": "Model not found"}, @@ -639,14 +722,17 @@ async def merge( ), ) -> AnyModelConfig: """ - Merge diffusers models. - - keys: List of 2-3 model keys to merge together. All models must use the same base type. - merged_model_name: Name for the merged model [Concat model names] - alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] - force: If true, force the merge even if the models were generated by different versions of the diffusers library [False] - interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] - merge_dest_directory: Specify a directory to store the merged model in [models directory] + Merge diffusers models. The process is controlled by a set parameters provided in the body of the request. + ``` + Argument Description [default] + -------- ---------------------- + keys List of 2-3 model keys to merge together. All models must use the same base type. + merged_model_name Name for the merged model [Concat model names] + alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] + force If true, force the merge even if the models were generated by different versions of the diffusers library [False] + interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] + merge_dest_directory Specify a directory to store the merged model in [models directory] + ``` """ logger = ApiDependencies.invoker.services.logger try: diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 7008f8ed741..6d5cedbcad8 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -13,7 +13,7 @@ import requests from pydantic.networks import AnyHttpUrl from requests import HTTPError -from tqdm import tqdm, std +from tqdm import tqdm from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.util.misc import get_iso_timestamp diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 9f0f774b499..42921f0b32c 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -123,11 +123,11 @@ class ModelRepoVariant(str, Enum): class ModelConfigBase(BaseModel): """Base class for model configuration information.""" - path: str - name: str - base: BaseModelType - type: ModelType - format: ModelFormat + path: str = Field(description="filesystem path to the model file or directory") + name: str = Field(description="model name") + base: BaseModelType = Field(description="base model") + type: ModelType = Field(description="type of the model") + format: ModelFormat = Field(description="model format") key: str = Field(description="unique key for model", default="") original_hash: Optional[str] = Field( description="original fasthash of model contents", default=None @@ -135,9 +135,9 @@ class ModelConfigBase(BaseModel): current_hash: Optional[str] = Field( description="current fasthash of model contents", 
default=None ) # if model is converted or otherwise modified, this will hold updated hash - description: Optional[str] = Field(default=None) - source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None) - last_modified: Optional[float] = Field(description="Timestamp for modification time", default_factory=time.time) + description: Optional[str] = Field(description="human readable description of the model", default=None) + source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None) + last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time) model_config = ConfigDict( use_enum_values=False, From 0cdeea0fa35a1c1e76df1f371838531de5ad6d94 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 14 Feb 2024 09:36:30 -0500 Subject: [PATCH 088/340] Remove references to model_records service, change submodel property on ModelInfo to submodel_type to support new params in model manager --- docs/contributing/MODEL_MANAGER.md | 2 +- invokeai/app/invocations/latent.py | 2 +- invokeai/app/invocations/model.py | 22 +++++++-------- invokeai/app/invocations/sdxl.py | 28 +++++++++---------- .../backend/model_management/model_manager.py | 2 +- pyproject.toml | 2 +- 6 files changed, 29 insertions(+), 29 deletions(-) diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index b711c654de8..b19699de73d 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -1627,7 +1627,7 @@ payload=dict( queue_batch_id=queue_batch_id, graph_execution_state_id=graph_execution_state_id, model_key=model_key, - submodel=submodel, + submodel_type=submodel, hash=model_info.hash, location=str(model_info.location), precision=str(model_info.precision), diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 289da2dd73d..c3de5219406 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -812,7 +812,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: ) with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: - assert isinstance(vae, torch.Tensor) + assert isinstance(vae, torch.nn.Module) latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index f78425c6ee3..71a71a63c83 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -18,7 +18,7 @@ class ModelInfo(BaseModel): key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()") - submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") + submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel") class LoraInfo(ModelInfo): @@ -110,22 +110,22 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput: unet=UNetField( unet=ModelInfo( key=key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( key=key, - submodel=SubModelType.Tokenizer, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( key=key, - submodel=SubModelType.TextEncoder, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, @@ -133,7 +133,7 @@ def invoke(self, context: 
InvocationContext) -> ModelLoaderOutput: vae=VaeField( vae=ModelInfo( key=key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) @@ -188,7 +188,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.unet.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -198,7 +198,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: output.clip.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -271,7 +271,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.unet.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -281,7 +281,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.clip.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) @@ -291,7 +291,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: output.clip2.loras.append( LoraInfo( key=lora_key, - submodel=None, + submodel_type=None, weight=self.weight, ) ) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 633a6477fdb..85e6fb787fa 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -43,29 +43,29 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: model_key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(model_key): + if not context.services.model_manager.store.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLModelLoaderOutput( unet=UNetField( unet=ModelInfo( key=model_key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=model_key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer, + submodel_type=SubModelType.Tokenizer, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder, + submodel_type=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, @@ -73,11 +73,11 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: clip2=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer2, + submodel_type=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder2, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, @@ -85,7 +85,7 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: vae=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) @@ -112,29 +112,29 @@ def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: model_key = self.model.key # TODO: not found exceptions - if not context.services.model_records.exists(model_key): + if not context.services.model_manager.store.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLRefinerModelLoaderOutput( unet=UNetField( unet=ModelInfo( key=model_key, - submodel=SubModelType.UNet, + submodel_type=SubModelType.UNet, ), scheduler=ModelInfo( key=model_key, - submodel=SubModelType.Scheduler, + submodel_type=SubModelType.Scheduler, ), loras=[], ), clip2=ClipField( tokenizer=ModelInfo( key=model_key, - submodel=SubModelType.Tokenizer2, + submodel_type=SubModelType.Tokenizer2, ), 
text_encoder=ModelInfo( key=model_key, - submodel=SubModelType.TextEncoder2, + submodel_type=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, @@ -142,7 +142,7 @@ def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: vae=VaeField( vae=ModelInfo( key=model_key, - submodel=SubModelType.Vae, + submodel_type=SubModelType.Vae, ), ), ) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index da74ca3fb58..84d93f15fa8 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -499,7 +499,7 @@ def get_model( model_class=model_class, base_model=base_model, model_type=model_type, - submodel=submodel_type, + submodel_type=submodel_type, ) if model_key not in self.cache_keys: diff --git a/pyproject.toml b/pyproject.toml index 2958e3629a8..f57607bc0af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -245,7 +245,7 @@ module = [ "invokeai.app.services.invocation_stats.invocation_stats_default", "invokeai.app.services.model_manager.model_manager_base", "invokeai.app.services.model_manager.model_manager_default", - "invokeai.app.services.model_records.model_records_sql", + "invokeai.app.services.model_manager.store.model_records_sql", "invokeai.app.util.controlnet_utils", "invokeai.backend.image_util.txt2mask", "invokeai.backend.image_util.safety_checker", From 162de87cd096154be4b47c3686aa69103c68a5a5 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 14 Feb 2024 09:51:11 -0500 Subject: [PATCH 089/340] References to context.services.model_manager.store.get_model can only accept keys, remove invalid assertion --- invokeai/app/invocations/latent.py | 4 ++-- .../load/model_cache/model_cache_default.py | 22 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index c3de5219406..05293fdfee3 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -681,7 +681,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: source_node_id = graph_execution_state.prepared_source_mapping[self.id] # get the unet's config so that we can pass the base to dispatch_progress() - unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump()) + unet_config = context.services.model_manager.store.get_model(self.unet.unet.key) def step_callback(state: PipelineIntermediateState) -> None: self.dispatch_progress(context, source_node_id, state, unet_config.base) @@ -709,7 +709,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: # Apply the LoRA after unet has been moved to its target device for faster patching. 
ModelPatcher.apply_lora_unet(unet, _lora_loader()), ): - assert isinstance(unet, torch.Tensor) + assert isinstance(unet, UNet2DConditionModel) latents = latents.to(device=unet.device, dtype=unet.dtype) if noise is not None: noise = noise.to(device=unet.device, dtype=unet.dtype) diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 786396062cf..02ce1266c75 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -303,18 +303,18 @@ def print_cuda_stats(self) -> None: in_vram_models = 0 locked_in_vram_models = 0 for cache_record in self._cached_models.values(): - assert hasattr(cache_record.model, "device") - if cache_record.model.device == self.storage_device: - in_ram_models += 1 - else: - in_vram_models += 1 - if cache_record.locked: - locked_in_vram_models += 1 + if hasattr(cache_record.model, "device"): + if cache_record.model.device == self.storage_device: + in_ram_models += 1 + else: + in_vram_models += 1 + if cache_record.locked: + locked_in_vram_models += 1 - self.logger.debug( - f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" - f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" - ) + self.logger.debug( + f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" + f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" + ) def make_room(self, model_size: int) -> None: """Make enough room in the cache to accommodate a new model of indicated size.""" From 88b9038c3ec0c4e10cb0bb5e551bc748a4ba5cf2 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 14 Feb 2024 13:07:11 -0500 Subject: [PATCH 090/340] Update _get_hf_load_class to support clipvision models --- invokeai/backend/model_manager/load/load_default.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index df83c8320d9..9ed0ccb2d3c 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -163,8 +163,12 @@ def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelT else: try: config = self._load_diffusers_config(model_path, config_name="config.json") - class_name = config["_class_name"] - return self._hf_definition_to_type(module="diffusers", class_name=class_name) + class_name = config.get("_class_name", None) + if class_name: + return self._hf_definition_to_type(module="diffusers", class_name=class_name) + if config.get("model_type", None) == "clip_vision_model": + class_name = config.get("architectures")[0] + return self._hf_definition_to_type(module="transformers", class_name=class_name) except KeyError as e: raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e From 0f93698a09be4e1d74feda6cd79b1d8f45200dbf Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 14 Feb 2024 13:16:15 -0500 Subject: [PATCH 091/340] Raise InvalidModelConfigException when unable to detect load class in ModelLoader --- invokeai/backend/model_manager/load/load_default.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 9ed0ccb2d3c..1dac121a300 100644 --- 
a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -169,6 +169,8 @@ def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelT if config.get("model_type", None) == "clip_vision_model": class_name = config.get("architectures")[0] return self._hf_definition_to_type(module="transformers", class_name=class_name) + if not class_name: + raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json") except KeyError as e: raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e From 7c454b2d7f53c97464f1ff1120e42114f08d73ca Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Feb 2024 20:43:41 +1100 Subject: [PATCH 092/340] feat(nodes): update invocation context for mm2, update nodes model usage --- invokeai/app/invocations/compel.py | 40 ++------ invokeai/app/invocations/ip_adapter.py | 7 +- invokeai/app/invocations/latent.py | 71 +++----------- invokeai/app/invocations/model.py | 8 +- invokeai/app/invocations/sdxl.py | 4 +- .../services/model_load/model_load_base.py | 14 +-- .../services/model_load/model_load_default.py | 48 +++++----- .../app/services/shared/invocation_context.py | 94 ++++++++++++++----- invokeai/app/util/step_callback.py | 2 +- 9 files changed, 141 insertions(+), 147 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 3850fb6cc3d..5159d5b89c5 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -69,20 +69,12 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.load.load_model_by_key( - **self.clip.tokenizer.model_dump(), - context=context, - ) - text_encoder_info = context.services.model_manager.load.load_model_by_key( - **self.clip.text_encoder.model_dump(), - context=context, - ) + tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump()) + text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump()) def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.clip.loras: - lora_info = context.services.model_manager.load.load_model_by_key( - **lora.model_dump(exclude={"weight"}), context=context - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) assert isinstance(lora_info.model, LoRAModelRaw) yield (lora_info.model, lora.weight) del lora_info @@ -94,10 +86,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for trigger in extract_ti_triggers_from_prompt(self.prompt): name = trigger[1:-1] try: - loaded_model = context.services.model_manager.load.load_model_by_key( - **self.clip.text_encoder.model_dump(), - context=context, - ).model + loaded_model = context.models.load(**self.clip.text_encoder.model_dump()).model assert isinstance(loaded_model, TextualInversionModelRaw) ti_list.append((name, loaded_model)) except UnknownModelException: @@ -165,14 +154,8 @@ def run_clip_compel( lora_prefix: str, zero_on_empty: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]: - tokenizer_info = context.services.model_manager.load.load_model_by_key( **clip_field.tokenizer.model_dump(),
context=context, - ) + tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump()) + text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump()) # return zero on empty if prompt == "" and zero_on_empty: @@ -197,9 +180,7 @@ def run_clip_compel( def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in clip_field.loras: - lora_info = context.services.model_manager.load.load_model_by_key( - **lora.model_dump(exclude={"weight"}), context=context - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) lora_model = lora_info.model assert isinstance(lora_model, LoRAModelRaw) yield (lora_model, lora.weight) @@ -212,11 +193,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for trigger in extract_ti_triggers_from_prompt(prompt): name = trigger[1:-1] try: - ti_model = context.services.model_manager.load.load_model_by_attr( - model_name=name, - base_model=text_encoder_info.config.base, - model_type=ModelType.TextualInversion, - context=context, + ti_model = context.models.load_by_attrs( + model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion ).model assert isinstance(ti_model, TextualInversionModelRaw) ti_list.append((name, ti_model)) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 01124f62f3c..15e254010b5 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -14,8 +14,7 @@ from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_management.models.base import BaseModelType, ModelType -from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id +from invokeai.backend.model_manager.config import BaseModelType, ModelType # LS: Consider moving these two classes into model.py @@ -90,10 +89,10 @@ def validate_begin_end_step_percent(self) -> Self: def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. 
- ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key) + ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() - image_encoder_models = context.services.model_manager.store.search_by_attr( + image_encoder_models = context.models.search_by_attrs( model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision ) assert len(image_encoder_models) == 1 diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 05293fdfee3..5dd0eb074d5 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -141,7 +141,7 @@ def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor: @torch.no_grad() def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: if self.image is not None: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: image_tensor = image_tensor.unsqueeze(0) @@ -153,10 +153,7 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: ) if image_tensor is not None: - vae_info = context.services.model_manager.load.load_model_by_key( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) @@ -182,10 +179,7 @@ def get_scheduler( seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.services.model_manager.load.load_model_by_key( - **scheduler_info.model_dump(), - context=context, - ) + orig_scheduler_info = context.models.load(**scheduler_info.model_dump()) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config @@ -399,12 +393,7 @@ def prep_control_data( # and if weight is None, populate with default 1.0? controlnet_data = [] for control_info in control_list: - control_model = exit_stack.enter_context( - context.services.model_manager.load.load_model_by_key( - key=control_info.control_model.key, - context=context, - ) - ) + control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key)) # control_models.append(control_model) control_image_field = control_info.image @@ -466,25 +455,17 @@ def prep_ip_adapter_data( conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.services.model_manager.load.load_model_by_key( - key=single_ip_adapter.ip_adapter_model.key, - context=context, - ) + context.models.load(key=single_ip_adapter.ip_adapter_model.key) ) - image_encoder_model_info = context.services.model_manager.load.load_model_by_key( - key=single_ip_adapter.image_encoder_model.key, - context=context, - ) + image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. 
single_ipa_image_fields = single_ip_adapter.image if not isinstance(single_ipa_image_fields, list): single_ipa_image_fields = [single_ipa_image_fields] - single_ipa_images = [ - context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields - ] + single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields] # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. @@ -528,10 +509,7 @@ def run_t2i_adapters( t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key( - key=t2i_adapter_field.t2i_adapter_model.key, - context=context, - ) + t2i_adapter_model_info = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key) image = context.images.get_pil(t2i_adapter_field.image.image_name) # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. @@ -676,30 +654,20 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: do_classifier_free_guidance=True, ) - # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] - # get the unet's config so that we can pass the base to dispatch_progress() - unet_config = context.services.model_manager.store.get_model(self.unet.unet.key) + unet_config = context.models.get_config(self.unet.unet.key) def step_callback(state: PipelineIntermediateState) -> None: - self.dispatch_progress(context, source_node_id, state, unet_config.base) + context.util.sd_step_callback(state, unet_config.base) def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.unet.loras: - lora_info = context.services.model_manager.load.load_model_by_key( - **lora.model_dump(exclude={"weight"}), - context=context, - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.model, lora.weight) del lora_info return - unet_info = context.services.model_manager.load.load_model_by_key( - **self.unet.unet.model_dump(), - context=context, - ) + unet_info = context.models.load(**self.unet.unet.model_dump()) assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, @@ -806,10 +774,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.tensors.load(self.latents.latents_name) - vae_info = context.services.model_manager.load.load_model_by_key( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: assert isinstance(vae, torch.nn.Module) @@ -1032,10 +997,7 @@ def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: t def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.images.get_pil(self.image.image_name) - vae_info = context.services.model_manager.load.load_model_by_key( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: @@ -1239,10 
+1201,7 @@ def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR return tuple((x - x % multiple_of) for x in args) def invoke(self, context: InvocationContext) -> IdealSizeOutput: - unet_config = context.services.model_manager.load.load_model_by_key( - **self.unet.unet.model_dump(), - context=context, - ) + unet_config = context.models.get_config(**self.unet.unet.model_dump()) aspect = self.width / self.height dimension: float = 512 if unet_config.base == BaseModelType.StableDiffusion2: diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 71a71a63c83..6087bc82db1 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -103,7 +103,7 @@ def invoke(self, context: InvocationContext) -> ModelLoaderOutput: key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.store.exists(key): + if not context.models.exists(key): raise Exception(f"Unknown model {key}") return ModelLoaderOutput( @@ -172,7 +172,7 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput: lora_key = self.lora.key - if not context.services.model_manager.store.exists(lora_key): + if not context.models.exists(lora_key): raise Exception(f"Unkown lora: {lora_key}!") if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): @@ -252,7 +252,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: lora_key = self.lora.key - if not context.services.model_manager.store.exists(lora_key): + if not context.models.exists(lora_key): raise Exception(f"Unknown lora: {lora_key}!") if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): @@ -318,7 +318,7 @@ class VaeLoaderInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> VAEOutput: key = self.vae_model.key - if not context.services.model_manager.store.exists(key): + if not context.models.exists(key): raise Exception(f"Unkown vae: {key}!") return VAEOutput(vae=VaeField(vae=ModelInfo(key=key))) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 85e6fb787fa..0df27c00110 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -43,7 +43,7 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: model_key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.store.exists(model_key): + if not context.models.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLModelLoaderOutput( @@ -112,7 +112,7 @@ def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: model_key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.store.exists(model_key): + if not context.models.exists(model_key): raise Exception(f"Unknown model: {model_key}") return SDXLRefinerModelLoaderOutput( diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 45eaf4652fb..f4dd905135a 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from typing import Optional -from invokeai.app.invocations.baseinvocation import InvocationContext +from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from 
invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase @@ -19,14 +19,14 @@ def load_model_by_key( self, key: str, submodel_type: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> LoadedModel: """ Given a model's key, load it and return the LoadedModel object. :param key: Key of model config to be fetched. :param submodel: For main (pipeline models), the submodel to fetch. - :param context: Invocation context used for event reporting + :param context_data: Invocation context data used for event reporting """ pass @@ -35,14 +35,14 @@ def load_model_by_config( self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param submodel: For main (pipeline models), the submodel to fetch. - :param context: Invocation context used for event reporting + :param context_data: Invocation context data used for event reporting """ pass @@ -53,7 +53,7 @@ def load_model_by_attr( base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> LoadedModel: """ Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. @@ -66,7 +66,7 @@ def load_model_by_attr( :param base_model: Base model :param model_type: Type of the model :param submodel: For main (pipeline models), the submodel to fetch - :param context: The invocation context. + :param context_data: The invocation context data. 
Exceptions: UnknownModelException -- model with these attributes not known NotImplementedException -- a model loader was not provided at initialization time diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index a6ccd5afbc3..29b297c8145 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -3,10 +3,11 @@ from typing import Optional -from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException +from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException +from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase @@ -46,6 +47,9 @@ def __init__( ), ) + def start(self, invoker: Invoker) -> None: + self._invoker = invoker + @property def ram_cache(self) -> ModelCacheBase[AnyModel]: """Return the RAM cache used by this loader.""" @@ -60,7 +64,7 @@ def load_model_by_key( self, key: str, submodel_type: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> LoadedModel: """ Given a model's key, load it and return the LoadedModel object. @@ -70,7 +74,7 @@ def load_model_by_key( :param context: Invocation context used for event reporting """ config = self._store.get_model(key) - return self.load_model_by_config(config, submodel_type, context) + return self.load_model_by_config(config, submodel_type, context_data) def load_model_by_attr( self, @@ -78,7 +82,7 @@ def load_model_by_attr( base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> LoadedModel: """ Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. @@ -109,7 +113,7 @@ def load_model_by_config( self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. @@ -118,15 +122,15 @@ def load_model_by_config( :param submodel: For main (pipeline models), the submodel to fetch. 
:param context: Invocation context used for event reporting """ - if context: + if context_data: self._emit_load_event( - context=context, + context_data=context_data, model_config=model_config, ) loaded_model = self._any_loader.load_model(model_config, submodel_type) - if context: + if context_data: self._emit_load_event( - context=context, + context_data=context_data, model_config=model_config, loaded=True, ) @@ -134,26 +138,28 @@ def load_model_by_config( def _emit_load_event( self, - context: InvocationContext, + context_data: InvocationContextData, model_config: AnyModelConfig, loaded: Optional[bool] = False, ) -> None: - if context.services.queue.is_canceled(context.graph_execution_state_id): + if not self._invoker: + return + if self._invoker.services.queue.is_canceled(context_data.session_id): raise CanceledException() if not loaded: - context.services.events.emit_model_load_started( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, + self._invoker.services.events.emit_model_load_started( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, model_config=model_config, ) else: - context.services.events.emit_model_load_completed( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, + self._invoker.services.events.emit_model_load_completed( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, model_config=model_config, ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index c68dc1140b2..089d09f825c 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from pathlib import Path from typing import TYPE_CHECKING, Optional from PIL.Image import Image @@ -12,8 +13,9 @@ from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend.model_management.model_manager import LoadedModelInfo -from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType +from invokeai.backend.model_manager.load.load_base import LoadedModel +from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -259,45 +261,95 @@ def load(self, name: str) -> ConditioningFieldData: class ModelsInterface(InvocationContextInterface): - def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool: + def exists(self, key: str) -> bool: """ Checks if a model exists. - :param model_name: The name of the model to check. - :param base_model: The base model of the model to check. 
-        :param model_type: The type of the model to check.
+        :param key: The key of the model.
         """
-        return self._services.model_manager.model_exists(model_name, base_model, model_type)
+        return self._services.model_manager.store.exists(key)
 
-    def load(
-        self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
-    ) -> LoadedModelInfo:
+    def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
         Loads a model.
 
-        :param model_name: The name of the model to get.
-        :param base_model: The base model of the model to get.
-        :param model_type: The type of the model to get.
-        :param submodel: The submodel of the model to get.
+        :param key: The key of the model.
+        :param submodel_type: The submodel of the model to get.
         :returns: An object representing the loaded model.
         """
         # The model manager emits events as it loads the model. It needs the context data to build
         # the event payloads.
-        return self._services.model_manager.get_model(
-            model_name, base_model, model_type, submodel, context_data=self._context_data
+        return self._services.model_manager.load.load_model_by_key(
+            key=key, submodel_type=submodel_type, context_data=self._context_data
         )
 
-    def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
+    def load_by_attrs(
+        self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
+    ) -> LoadedModel:
+        """
+        Loads a model by its attributes.
+
+        :param model_name: Name of the model to be fetched.
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param submodel: For main (pipeline models), the submodel to fetch
+        """
+        return self._services.model_manager.load.load_model_by_attr(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+            submodel=submodel,
+            context_data=self._context_data,
+        )
+
+    def get_config(self, key: str) -> AnyModelConfig:
         """
         Gets a model's info, a dict-like object.
 
-        :param model_name: The name of the model to get.
-        :param base_model: The base model of the model to get.
-        :param model_type: The type of the model to get.
+        :param key: The key of the model.
+        """
+        return self._services.model_manager.store.get_model(key=key)
+
+    def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
+        """
+        Gets a model's metadata, if it has any.
+
+        :param key: The key of the model.
         """
-        return self._services.model_manager.model_info(model_name, base_model, model_type)
+        return self._services.model_manager.store.get_metadata(key=key)
+
+    def search_by_path(self, path: Path) -> list[AnyModelConfig]:
+        """
+        Searches for models by path.
+
+        :param path: The path to search for.
+        """
+        return self._services.model_manager.store.search_by_path(path)
+
+    def search_by_attrs(
+        self,
+        model_name: Optional[str] = None,
+        base_model: Optional[BaseModelType] = None,
+        model_type: Optional[ModelType] = None,
+        model_format: Optional[ModelFormat] = None,
+    ) -> list[AnyModelConfig]:
+        """
+        Searches for models by attributes.
+
+        :param model_name: Name of the model to be fetched.
+ :param base_model: Base model + :param model_type: Type of the model + :param submodel: For main (pipeline models), the submodel to fetch + """ + + return self._services.model_manager.store.search_by_attr( + model_name=model_name, + base_model=base_model, + model_type=model_type, + model_format=model_format, + ) class ConfigInterface(InvocationContextInterface): diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index d83b380d95d..33d00ca3660 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -4,8 +4,8 @@ from PIL import Image from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage +from invokeai.backend.model_manager.config import BaseModelType -from ...backend.model_management.models import BaseModelType from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util.util import image_to_dataURL From 758e0c1565e1c7407e10666da0d93bfa4b441d62 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Feb 2024 20:50:47 +1100 Subject: [PATCH 093/340] chore: ruff --- invokeai/app/invocations/controlnet_image_processors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 580ee085627..8542134fff0 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -39,7 +39,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector -from invokeai.backend.model_management.models.base import BaseModelType from .baseinvocation import ( BaseInvocation, From 8651b7bd38353294b32a7fa9344cf491a2c12d86 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Feb 2024 20:52:44 +1100 Subject: [PATCH 094/340] chore: lint --- invokeai/frontend/web/src/features/nodes/types/error.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/error.ts b/invokeai/frontend/web/src/features/nodes/types/error.ts index c3da136c7a8..82bc0f86e09 100644 --- a/invokeai/frontend/web/src/features/nodes/types/error.ts +++ b/invokeai/frontend/web/src/features/nodes/types/error.ts @@ -60,4 +60,4 @@ export class FieldParseError extends Error { export class UnableToExtractSchemaNameFromRefError extends FieldParseError {} export class UnsupportedArrayItemType extends FieldParseError {} export class UnsupportedUnionError extends FieldParseError {} -export class UnsupportedPrimitiveTypeError extends FieldParseError {} \ No newline at end of file +export class UnsupportedPrimitiveTypeError extends FieldParseError {} From 8e91c3a08bc9573f7b88172e635dbb89fd6b616c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Feb 2024 20:53:41 +1100 Subject: [PATCH 095/340] fix(ui): fix type issues --- .../nodes/components/sidePanel/viewMode/WorkflowField.tsx | 4 ++-- .../src/features/nodes/util/schema/parseFieldType.test.ts | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx 
b/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx index 0e5857933a7..e707dd4f54d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx @@ -16,7 +16,7 @@ type Props = { const WorkflowField = ({ nodeId, fieldName }: Props) => { const label = useFieldLabel(nodeId, fieldName); - const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'input'); + const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs'); const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName); return ( @@ -36,7 +36,7 @@ const WorkflowField = ({ nodeId, fieldName }: Props) => { /> )} } + label={} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} placement="top" > diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts index 2f4ce48a326..d7011ad6f84 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts @@ -284,13 +284,13 @@ describe('refObjectToSchemaName', async () => { }); describe.concurrent('parseFieldType', async () => { - it.each(primitiveTypes)('parses primitive types ($name)', ({ schema, expected }) => { + it.each(primitiveTypes)('parses primitive types ($name)', ({ schema, expected }: ParseFieldTypeTestCase) => { expect(parseFieldType(schema)).toEqual(expected); }); - it.each(complexTypes)('parses complex types ($name)', ({ schema, expected }) => { + it.each(complexTypes)('parses complex types ($name)', ({ schema, expected }: ParseFieldTypeTestCase) => { expect(parseFieldType(schema)).toEqual(expected); }); - it.each(specialCases)('parses special case types ($name)', ({ schema, expected }) => { + it.each(specialCases)('parses special case types ($name)', ({ schema, expected }: ParseFieldTypeTestCase) => { expect(parseFieldType(schema)).toEqual(expected); }); From af7b5d69a4b769fc81f1c3675ad575798ada1540 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Feb 2024 21:16:25 +1100 Subject: [PATCH 096/340] feat(ui): export components type --- .../frontend/web/src/services/api/types.ts | 228 +++++++++--------- 1 file changed, 114 insertions(+), 114 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 55ff808b404..f9a1decf655 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -3,7 +3,7 @@ import type { EntityState } from '@reduxjs/toolkit'; import type { components, paths } from 'services/api/schema'; import type { O } from 'ts-toolbelt'; -type s = components['schemas']; +export type S = components['schemas']; export type ImageCache = EntityState; @@ -23,60 +23,60 @@ export type BatchConfig = export type EnqueueBatchResult = components['schemas']['EnqueueBatchResult']; -export type InputFieldJSONSchemaExtra = s['InputFieldJSONSchemaExtra']; -export type OutputFieldJSONSchemaExtra = s['OutputFieldJSONSchemaExtra']; -export type InvocationJSONSchemaExtra = s['UIConfigBase']; +export type InputFieldJSONSchemaExtra = S['InputFieldJSONSchemaExtra']; +export type OutputFieldJSONSchemaExtra = S['OutputFieldJSONSchemaExtra']; +export type InvocationJSONSchemaExtra = S['UIConfigBase']; // 
App Info -export type AppVersion = s['AppVersion']; -export type AppConfig = s['AppConfig']; -export type AppDependencyVersions = s['AppDependencyVersions']; +export type AppVersion = S['AppVersion']; +export type AppConfig = S['AppConfig']; +export type AppDependencyVersions = S['AppDependencyVersions']; // Images -export type ImageDTO = s['ImageDTO']; -export type BoardDTO = s['BoardDTO']; -export type BoardChanges = s['BoardChanges']; -export type ImageChanges = s['ImageRecordChanges']; -export type ImageCategory = s['ImageCategory']; -export type ResourceOrigin = s['ResourceOrigin']; -export type ImageField = s['ImageField']; -export type OffsetPaginatedResults_BoardDTO_ = s['OffsetPaginatedResults_BoardDTO_']; -export type OffsetPaginatedResults_ImageDTO_ = s['OffsetPaginatedResults_ImageDTO_']; +export type ImageDTO = S['ImageDTO']; +export type BoardDTO = S['BoardDTO']; +export type BoardChanges = S['BoardChanges']; +export type ImageChanges = S['ImageRecordChanges']; +export type ImageCategory = S['ImageCategory']; +export type ResourceOrigin = S['ResourceOrigin']; +export type ImageField = S['ImageField']; +export type OffsetPaginatedResults_BoardDTO_ = S['OffsetPaginatedResults_BoardDTO_']; +export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDTO_']; // Models -export type ModelType = s['invokeai__backend__model_management__models__base__ModelType']; -export type SubModelType = s['SubModelType']; -export type BaseModelType = s['invokeai__backend__model_management__models__base__BaseModelType']; -export type MainModelField = s['MainModelField']; -export type VAEModelField = s['VAEModelField']; -export type LoRAModelField = s['LoRAModelField']; -export type LoRAModelFormat = s['LoRAModelFormat']; -export type ControlNetModelField = s['ControlNetModelField']; -export type IPAdapterModelField = s['IPAdapterModelField']; -export type T2IAdapterModelField = s['T2IAdapterModelField']; -export type ModelsList = s['invokeai__app__api__routers__models__ModelsList']; -export type ControlField = s['ControlField']; -export type IPAdapterField = s['IPAdapterField']; +export type ModelType = S['ModelType']; +export type SubModelType = S['SubModelType']; +export type BaseModelType = S['BaseModelType']; +export type MainModelField = S['MainModelField']; +export type VAEModelField = S['VAEModelField']; +export type LoRAModelField = S['LoRAModelField']; +export type LoRAModelFormat = S['LoRAModelFormat']; +export type ControlNetModelField = S['ControlNetModelField']; +export type IPAdapterModelField = S['IPAdapterModelField']; +export type T2IAdapterModelField = S['T2IAdapterModelField']; +export type ModelsList = S['invokeai__app__api__routers__models__ModelsList']; +export type ControlField = S['ControlField']; +export type IPAdapterField = S['IPAdapterField']; // Model Configs -export type LoRAModelConfig = s['LoRAModelConfig']; -export type VaeModelConfig = s['VaeModelConfig']; -export type ControlNetModelCheckpointConfig = s['ControlNetModelCheckpointConfig']; -export type ControlNetModelDiffusersConfig = s['ControlNetModelDiffusersConfig']; +export type LoRAModelConfig = S['LoRAModelConfig']; +export type VaeModelConfig = S['VaeModelConfig']; +export type ControlNetModelCheckpointConfig = S['ControlNetModelCheckpointConfig']; +export type ControlNetModelDiffusersConfig = S['ControlNetModelDiffusersConfig']; export type ControlNetModelConfig = ControlNetModelCheckpointConfig | ControlNetModelDiffusersConfig; -export type IPAdapterModelInvokeAIConfig = 
s['IPAdapterModelInvokeAIConfig']; +export type IPAdapterModelInvokeAIConfig = S['IPAdapterModelInvokeAIConfig']; export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig; -export type T2IAdapterModelDiffusersConfig = s['T2IAdapterModelDiffusersConfig']; +export type T2IAdapterModelDiffusersConfig = S['T2IAdapterModelDiffusersConfig']; export type T2IAdapterModelConfig = T2IAdapterModelDiffusersConfig; -export type TextualInversionModelConfig = s['TextualInversionModelConfig']; +export type TextualInversionModelConfig = S['TextualInversionModelConfig']; export type DiffusersModelConfig = - | s['StableDiffusion1ModelDiffusersConfig'] - | s['StableDiffusion2ModelDiffusersConfig'] - | s['StableDiffusionXLModelDiffusersConfig']; + | S['StableDiffusion1ModelDiffusersConfig'] + | S['StableDiffusion2ModelDiffusersConfig'] + | S['StableDiffusionXLModelDiffusersConfig']; export type CheckpointModelConfig = - | s['StableDiffusion1ModelCheckpointConfig'] - | s['StableDiffusion2ModelCheckpointConfig'] - | s['StableDiffusionXLModelCheckpointConfig']; + | S['StableDiffusion1ModelCheckpointConfig'] + | S['StableDiffusion2ModelCheckpointConfig'] + | S['StableDiffusionXLModelCheckpointConfig']; export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; export type AnyModelConfig = | LoRAModelConfig @@ -87,87 +87,87 @@ export type AnyModelConfig = | TextualInversionModelConfig | MainModelConfig; -export type MergeModelConfig = s['Body_merge_models']; -export type ImportModelConfig = s['Body_import_model']; +export type MergeModelConfig = S['Body_merge_models']; +export type ImportModelConfig = S['Body_import_model']; // Graphs -export type Graph = s['Graph']; +export type Graph = S['Graph']; export type NonNullableGraph = O.Required; -export type Edge = s['Edge']; -export type GraphExecutionState = s['GraphExecutionState']; -export type Batch = s['Batch']; -export type SessionQueueItemDTO = s['SessionQueueItemDTO']; -export type SessionQueueItem = s['SessionQueueItem']; -export type WorkflowRecordOrderBy = s['WorkflowRecordOrderBy']; -export type SQLiteDirection = s['SQLiteDirection']; -export type WorkflowDTO = s['WorkflowRecordDTO']; -export type WorkflowRecordListItemDTO = s['WorkflowRecordListItemDTO']; +export type Edge = S['Edge']; +export type GraphExecutionState = S['GraphExecutionState']; +export type Batch = S['Batch']; +export type SessionQueueItemDTO = S['SessionQueueItemDTO']; +export type SessionQueueItem = S['SessionQueueItem']; +export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy']; +export type SQLiteDirection = S['SQLiteDirection']; +export type WorkflowDTO = S['WorkflowRecordDTO']; +export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO']; // General nodes -export type CollectInvocation = s['CollectInvocation']; -export type IterateInvocation = s['IterateInvocation']; -export type RangeInvocation = s['RangeInvocation']; -export type RandomRangeInvocation = s['RandomRangeInvocation']; -export type RangeOfSizeInvocation = s['RangeOfSizeInvocation']; -export type ImageResizeInvocation = s['ImageResizeInvocation']; -export type ImageBlurInvocation = s['ImageBlurInvocation']; -export type ImageScaleInvocation = s['ImageScaleInvocation']; -export type InfillPatchMatchInvocation = s['InfillPatchMatchInvocation']; -export type InfillTileInvocation = s['InfillTileInvocation']; -export type CreateDenoiseMaskInvocation = s['CreateDenoiseMaskInvocation']; -export type MaskEdgeInvocation = s['MaskEdgeInvocation']; -export type RandomIntInvocation = 
s['RandomIntInvocation']; -export type CompelInvocation = s['CompelInvocation']; -export type DynamicPromptInvocation = s['DynamicPromptInvocation']; -export type NoiseInvocation = s['NoiseInvocation']; -export type DenoiseLatentsInvocation = s['DenoiseLatentsInvocation']; -export type SDXLLoraLoaderInvocation = s['SDXLLoraLoaderInvocation']; -export type ImageToLatentsInvocation = s['ImageToLatentsInvocation']; -export type LatentsToImageInvocation = s['LatentsToImageInvocation']; -export type ImageCollectionInvocation = s['ImageCollectionInvocation']; -export type MainModelLoaderInvocation = s['MainModelLoaderInvocation']; -export type LoraLoaderInvocation = s['LoraLoaderInvocation']; -export type ESRGANInvocation = s['ESRGANInvocation']; -export type DivideInvocation = s['DivideInvocation']; -export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation']; -export type ImageWatermarkInvocation = s['ImageWatermarkInvocation']; -export type SeamlessModeInvocation = s['SeamlessModeInvocation']; -export type MetadataInvocation = s['MetadataInvocation']; -export type CoreMetadataInvocation = s['CoreMetadataInvocation']; -export type MetadataItemInvocation = s['MetadataItemInvocation']; -export type MergeMetadataInvocation = s['MergeMetadataInvocation']; -export type IPAdapterMetadataField = s['IPAdapterMetadataField']; -export type T2IAdapterField = s['T2IAdapterField']; -export type LoRAMetadataField = s['LoRAMetadataField']; +export type CollectInvocation = S['CollectInvocation']; +export type IterateInvocation = S['IterateInvocation']; +export type RangeInvocation = S['RangeInvocation']; +export type RandomRangeInvocation = S['RandomRangeInvocation']; +export type RangeOfSizeInvocation = S['RangeOfSizeInvocation']; +export type ImageResizeInvocation = S['ImageResizeInvocation']; +export type ImageBlurInvocation = S['ImageBlurInvocation']; +export type ImageScaleInvocation = S['ImageScaleInvocation']; +export type InfillPatchMatchInvocation = S['InfillPatchMatchInvocation']; +export type InfillTileInvocation = S['InfillTileInvocation']; +export type CreateDenoiseMaskInvocation = S['CreateDenoiseMaskInvocation']; +export type MaskEdgeInvocation = S['MaskEdgeInvocation']; +export type RandomIntInvocation = S['RandomIntInvocation']; +export type CompelInvocation = S['CompelInvocation']; +export type DynamicPromptInvocation = S['DynamicPromptInvocation']; +export type NoiseInvocation = S['NoiseInvocation']; +export type DenoiseLatentsInvocation = S['DenoiseLatentsInvocation']; +export type SDXLLoraLoaderInvocation = S['SDXLLoraLoaderInvocation']; +export type ImageToLatentsInvocation = S['ImageToLatentsInvocation']; +export type LatentsToImageInvocation = S['LatentsToImageInvocation']; +export type ImageCollectionInvocation = S['ImageCollectionInvocation']; +export type MainModelLoaderInvocation = S['MainModelLoaderInvocation']; +export type LoraLoaderInvocation = S['LoraLoaderInvocation']; +export type ESRGANInvocation = S['ESRGANInvocation']; +export type DivideInvocation = S['DivideInvocation']; +export type ImageNSFWBlurInvocation = S['ImageNSFWBlurInvocation']; +export type ImageWatermarkInvocation = S['ImageWatermarkInvocation']; +export type SeamlessModeInvocation = S['SeamlessModeInvocation']; +export type MetadataInvocation = S['MetadataInvocation']; +export type CoreMetadataInvocation = S['CoreMetadataInvocation']; +export type MetadataItemInvocation = S['MetadataItemInvocation']; +export type MergeMetadataInvocation = S['MergeMetadataInvocation']; +export type 
IPAdapterMetadataField = S['IPAdapterMetadataField']; +export type T2IAdapterField = S['T2IAdapterField']; +export type LoRAMetadataField = S['LoRAMetadataField']; // ControlNet Nodes -export type ControlNetInvocation = s['ControlNetInvocation']; -export type T2IAdapterInvocation = s['T2IAdapterInvocation']; -export type IPAdapterInvocation = s['IPAdapterInvocation']; -export type CannyImageProcessorInvocation = s['CannyImageProcessorInvocation']; -export type ColorMapImageProcessorInvocation = s['ColorMapImageProcessorInvocation']; -export type ContentShuffleImageProcessorInvocation = s['ContentShuffleImageProcessorInvocation']; -export type DepthAnythingImageProcessorInvocation = s['DepthAnythingImageProcessorInvocation']; -export type HedImageProcessorInvocation = s['HedImageProcessorInvocation']; -export type LineartAnimeImageProcessorInvocation = s['LineartAnimeImageProcessorInvocation']; -export type LineartImageProcessorInvocation = s['LineartImageProcessorInvocation']; -export type MediapipeFaceProcessorInvocation = s['MediapipeFaceProcessorInvocation']; -export type MidasDepthImageProcessorInvocation = s['MidasDepthImageProcessorInvocation']; -export type MlsdImageProcessorInvocation = s['MlsdImageProcessorInvocation']; -export type NormalbaeImageProcessorInvocation = s['NormalbaeImageProcessorInvocation']; -export type DWOpenposeImageProcessorInvocation = s['DWOpenposeImageProcessorInvocation']; -export type PidiImageProcessorInvocation = s['PidiImageProcessorInvocation']; -export type ZoeDepthImageProcessorInvocation = s['ZoeDepthImageProcessorInvocation']; +export type ControlNetInvocation = S['ControlNetInvocation']; +export type T2IAdapterInvocation = S['T2IAdapterInvocation']; +export type IPAdapterInvocation = S['IPAdapterInvocation']; +export type CannyImageProcessorInvocation = S['CannyImageProcessorInvocation']; +export type ColorMapImageProcessorInvocation = S['ColorMapImageProcessorInvocation']; +export type ContentShuffleImageProcessorInvocation = S['ContentShuffleImageProcessorInvocation']; +export type DepthAnythingImageProcessorInvocation = S['DepthAnythingImageProcessorInvocation']; +export type HedImageProcessorInvocation = S['HedImageProcessorInvocation']; +export type LineartAnimeImageProcessorInvocation = S['LineartAnimeImageProcessorInvocation']; +export type LineartImageProcessorInvocation = S['LineartImageProcessorInvocation']; +export type MediapipeFaceProcessorInvocation = S['MediapipeFaceProcessorInvocation']; +export type MidasDepthImageProcessorInvocation = S['MidasDepthImageProcessorInvocation']; +export type MlsdImageProcessorInvocation = S['MlsdImageProcessorInvocation']; +export type NormalbaeImageProcessorInvocation = S['NormalbaeImageProcessorInvocation']; +export type DWOpenposeImageProcessorInvocation = S['DWOpenposeImageProcessorInvocation']; +export type PidiImageProcessorInvocation = S['PidiImageProcessorInvocation']; +export type ZoeDepthImageProcessorInvocation = S['ZoeDepthImageProcessorInvocation']; // Node Outputs -export type ImageOutput = s['ImageOutput']; -export type StringOutput = s['StringOutput']; -export type FloatOutput = s['FloatOutput']; -export type IntegerOutput = s['IntegerOutput']; -export type IterateInvocationOutput = s['IterateInvocationOutput']; -export type CollectInvocationOutput = s['CollectInvocationOutput']; -export type LatentsOutput = s['LatentsOutput']; -export type GraphInvocationOutput = s['GraphInvocationOutput']; +export type ImageOutput = S['ImageOutput']; +export type StringOutput = S['StringOutput']; 
+export type FloatOutput = S['FloatOutput']; +export type IntegerOutput = S['IntegerOutput']; +export type IterateInvocationOutput = S['IterateInvocationOutput']; +export type CollectInvocationOutput = S['CollectInvocationOutput']; +export type LatentsOutput = S['LatentsOutput']; +export type GraphInvocationOutput = S['GraphInvocationOutput']; // Post-image upload actions, controls workflows when images are uploaded From 2dbde99f01c8dcfb45bc115be8e6c52a72e7d968 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Feb 2024 22:15:21 +1100 Subject: [PATCH 097/340] chore(ui): typegen --- .../frontend/web/src/services/api/schema.ts | 2024 +++++++---------- 1 file changed, 847 insertions(+), 1177 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 1599b310c9a..3393e74d486 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -19,80 +19,14 @@ export type paths = { */ post: operations["parse_dynamicprompts"]; }; - "/api/v1/models/": { - /** - * List Models - * @description Gets a list of models - */ - get: operations["list_models"]; - }; - "/api/v1/models/{base_model}/{model_type}/{model_name}": { - /** - * Delete Model - * @description Delete Model - */ - delete: operations["del_model"]; - /** - * Update Model - * @description Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. - */ - patch: operations["update_model"]; - }; - "/api/v1/models/import": { - /** - * Import Model - * @description Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically - */ - post: operations["import_model"]; - }; - "/api/v1/models/add": { - /** - * Add Model - * @description Add a model using the configuration information appropriate for its type. Only local models can be added by path - */ - post: operations["add_model"]; - }; - "/api/v1/models/convert/{base_model}/{model_type}/{model_name}": { - /** - * Convert Model - * @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none. - */ - put: operations["convert_model"]; - }; - "/api/v1/models/search": { - /** Search For Models */ - get: operations["search_for_models"]; - }; - "/api/v1/models/ckpt_confs": { - /** - * List Ckpt Configs - * @description Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT. - */ - get: operations["list_ckpt_configs"]; - }; - "/api/v1/models/sync": { - /** - * Sync To Config - * @description Call after making changes to models.yaml, autoimport directories or models directory to synchronize - * in-memory data structures with disk data structures. - */ - post: operations["sync_to_config"]; - }; - "/api/v1/models/merge/{base_model}": { - /** - * Merge Models - * @description Convert a checkpoint model into a diffusers model - */ - put: operations["merge_models"]; - }; - "/api/v1/model/record/": { + "/api/v2/models/": { /** * List Model Records * @description Get a list of models. 
*/ get: operations["list_model_records"]; }; - "/api/v1/model/record/i/{key}": { + "/api/v2/models/i/{key}": { /** * Get Model Record * @description Get a model record @@ -112,50 +46,76 @@ export type paths = { */ patch: operations["update_model_record"]; }; - "/api/v1/model/record/meta": { + "/api/v2/models/summary": { /** * List Model Summary * @description Gets a page of model summary data. */ get: operations["list_model_summary"]; }; - "/api/v1/model/record/meta/i/{key}": { + "/api/v2/models/meta/i/{key}": { /** * Get Model Metadata * @description Get a model metadata object. */ get: operations["get_model_metadata"]; }; - "/api/v1/model/record/tags": { + "/api/v2/models/tags": { /** * List Tags * @description Get a unique set of all the model tags. */ get: operations["list_tags"]; }; - "/api/v1/model/record/tags/search": { + "/api/v2/models/tags/search": { /** * Search By Metadata Tags * @description Get a list of models. */ get: operations["search_by_metadata_tags"]; }; - "/api/v1/model/record/i/": { + "/api/v2/models/i/": { /** * Add Model Record * @description Add a model using the configuration information appropriate for its type. */ post: operations["add_model_record"]; }; - "/api/v1/model/record/import": { + "/api/v2/models/heuristic_import": { /** - * List Model Install Jobs - * @description Return list of model install jobs. + * Heuristic Import + * @description Install a model using a string identifier. + * + * `source` can be any of the following. + * + * 1. A path on the local filesystem ('C:\users\fred\model.safetensors') + * 2. A Url pointing to a single downloadable model file + * 3. A HuggingFace repo_id with any of the following formats: + * - model/name + * - model/name:fp16:vae + * - model/name::vae -- use default precision + * - model/name:fp16:path/to/model.safetensors + * - model/name::path/to/model.safetensors + * + * `config` is an optional dict containing model configuration values that will override + * the ones that are probed automatically. + * + * `access_token` is an optional access token for use with Urls that require + * authentication. + * + * Models will be downloaded, probed, configured and installed in a + * series of background threads. The return object has `status` attribute + * that can be used to monitor progress. + * + * See the documentation for `import_model_record` for more information on + * interpreting the job information returned by this route. */ - get: operations["list_model_install_jobs"]; + post: operations["heuristic_import_model"]; + }; + "/api/v2/models/install": { /** * Import Model - * @description Add a model using its local path, repo_id, or remote URL. + * @description Install a model using its local path, repo_id, or remote URL. * * Models will be downloaded, probed, configured and installed in a * series of background threads. The return object has `status` attribute @@ -166,32 +126,38 @@ export type paths = { * appropriate value: * * * To install a local path using LocalModelSource, pass a source of form: - * `{ + * ``` + * { * "type": "local", * "path": "/path/to/model", * "inplace": false - * }` - * The "inplace" flag, if true, will register the model in place in its - * current filesystem location. Otherwise, the model will be copied - * into the InvokeAI models directory. + * } + * ``` + * The "inplace" flag, if true, will register the model in place in its + * current filesystem location. Otherwise, the model will be copied + * into the InvokeAI models directory. 
* * * To install a HuggingFace repo_id using HFModelSource, pass a source of form: - * `{ + * ``` + * { * "type": "hf", * "repo_id": "stabilityai/stable-diffusion-2.0", * "variant": "fp16", * "subfolder": "vae", * "access_token": "f5820a918aaf01" - * }` - * The `variant`, `subfolder` and `access_token` fields are optional. + * } + * ``` + * The `variant`, `subfolder` and `access_token` fields are optional. * * * To install a remote model using an arbitrary URL, pass: - * `{ + * ``` + * { * "type": "url", * "url": "http://www.civitai.com/models/123456", * "access_token": "f5820a918aaf01" - * }` - * The `access_token` field is optonal + * } + * ``` + * The `access_token` field is optonal * * The model's configuration record will be probed and filled in * automatically. To override the default guesses, pass "metadata" @@ -200,26 +166,51 @@ export type paths = { * Installation occurs in the background. Either use list_model_install_jobs() * to poll for completion, or listen on the event bus for the following events: * - * "model_install_running" - * "model_install_completed" - * "model_install_error" + * * "model_install_running" + * * "model_install_completed" + * * "model_install_error" * * On successful completion, the event's payload will contain the field "key" * containing the installed ID of the model. On an error, the event's payload * will contain the fields "error_type" and "error" describing the nature of the * error and its traceback, respectively. */ - post: operations["import_model_record"]; + post: operations["import_model"]; + }; + "/api/v2/models/import": { + /** + * List Model Install Jobs + * @description Return the list of model install jobs. + * + * Install jobs have a numeric `id`, a `status`, and other fields that provide information on + * the nature of the job and its progress. The `status` is one of: + * + * * "waiting" -- Job is waiting in the queue to run + * * "downloading" -- Model file(s) are downloading + * * "running" -- Model has downloaded and the model probing and registration process is running + * * "completed" -- Installation completed successfully + * * "error" -- An error occurred. Details will be in the "error_type" and "error" fields. + * * "cancelled" -- Job was cancelled before completion. + * + * Once completed, information about the model such as its size, base + * model, type, and metadata can be retrieved from the `config_out` + * field. For multi-file models such as diffusers, information on individual files + * can be retrieved from `download_parts`. + * + * See the example and schema below for more information. + */ + get: operations["list_model_install_jobs"]; /** * Prune Model Install Jobs * @description Prune all completed and errored jobs from the install job list. */ patch: operations["prune_model_install_jobs"]; }; - "/api/v1/model/record/import/{id}": { + "/api/v2/models/import/{id}": { /** * Get Model Install Job - * @description Return model install job corresponding to the given source. + * @description Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs' + * for information on the format of the return value. */ get: operations["get_model_install_job"]; /** @@ -228,7 +219,7 @@ export type paths = { */ delete: operations["cancel_model_install_job"]; }; - "/api/v1/model/record/sync": { + "/api/v2/models/sync": { /** * Sync Models To Config * @description Traverse the models and autoimport directories. 
@@ -238,17 +229,29 @@ export type paths = { */ patch: operations["sync_models_to_config"]; }; - "/api/v1/model/record/merge": { + "/api/v2/models/convert/{key}": { + /** + * Convert Model + * @description Permanently convert a model into diffusers format, replacing the safetensors version. + * Note that during the conversion process the key and model hash will change. + * The return value is the model configuration for the converted model. + */ + put: operations["convert_model"]; + }; + "/api/v2/models/merge": { /** * Merge - * @description Merge diffusers models. - * - * keys: List of 2-3 model keys to merge together. All models must use the same base type. - * merged_model_name: Name for the merged model [Concat model names] - * alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] - * force: If true, force the merge even if the models were generated by different versions of the diffusers library [False] - * interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] - * merge_dest_directory: Specify a directory to store the merged model in [models directory] + * @description Merge diffusers models. The process is controlled by a set parameters provided in the body of the request. + * ``` + * Argument Description [default] + * -------- ---------------------- + * keys List of 2-3 model keys to merge together. All models must use the same base type. + * merged_model_name Name for the merged model [Concat model names] + * alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] + * force If true, force the merge even if the models were generated by different versions of the diffusers library [False] + * interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] + * merge_dest_directory Specify a directory to store the merged model in [models directory] + * ``` */ put: operations["merge"]; }; @@ -815,6 +818,12 @@ export type components = { */ type?: "basemetadata"; }; + /** + * BaseModelType + * @description Base model type. + * @enum {string} + */ + BaseModelType: "any" | "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner"; /** Batch */ Batch: { /** @@ -1163,19 +1172,6 @@ export type components = { }; /** Body_import_model */ Body_import_model: { - /** - * Location - * @description A model path, repo_id or URL to import - */ - location: string; - /** - * Prediction Type - * @description Prediction type for SDv2 checkpoints and rare SDv1 checkpoints - */ - prediction_type?: ("v_prediction" | "epsilon" | "sample") | null; - }; - /** Body_import_model_record */ - Body_import_model_record: { /** Source */ source: components["schemas"]["LocalModelSource"] | components["schemas"]["HFModelSource"] | components["schemas"]["CivitaiModelSource"] | components["schemas"]["URLModelSource"]; /** @@ -1216,11 +1212,6 @@ export type components = { */ merge_dest_directory?: string | null; }; - /** Body_merge_models */ - Body_merge_models: { - /** @description Model configuration */ - body: components["schemas"]["MergeModelsBody"]; - }; /** Body_parse_dynamicprompts */ Body_parse_dynamicprompts: { /** @@ -1412,11 +1403,18 @@ export type components = { * @description Model config for ClipVision. 
*/ CLIPVisionDiffusersConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default clip_vision @@ -1444,45 +1442,29 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; - }; - /** CLIPVisionModelDiffusersConfig */ - CLIPVisionModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default clip_vision - * @constant - */ - model_type: "clip_vision"; - /** Path */ - path: string; - /** Description */ - description?: string | null; /** - * Model Format - * @constant + * Last Modified + * @description timestamp for modification time */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; + last_modified?: number | null; }; /** CLIPVisionModelField */ CLIPVisionModelField: { /** - * Model Name - * @description Name of the CLIP Vision image encoder model + * Key + * @description Key to the CLIP Vision image encoder model */ - model_name: string; - /** @description Base model (usually 'Any') */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; /** * CV2 Infill @@ -2538,11 +2520,18 @@ export type components = { * @description Model config for ControlNet models (diffusers version). */ ControlNetCheckpointConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default controlnet @@ -2571,13 +2560,21 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** * Config * @description path to the checkpoint model config file @@ -2589,11 +2586,18 @@ export type components = { * @description Model config for ControlNet models (diffusers version). 
*/ ControlNetDiffusersConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default controlnet @@ -2622,13 +2626,23 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; + /** @default */ + repo_variant?: components["schemas"]["ModelRepoVariant"] | null; }; /** * ControlNet @@ -2695,64 +2709,16 @@ export type components = { */ type: "controlnet"; }; - /** ControlNetModelCheckpointConfig */ - ControlNetModelCheckpointConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default controlnet - * @constant - */ - model_type: "controlnet"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "checkpoint"; - error?: components["schemas"]["ModelError"] | null; - /** Config */ - config: string; - }; - /** ControlNetModelDiffusersConfig */ - ControlNetModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default controlnet - * @constant - */ - model_type: "controlnet"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - }; /** * ControlNetModelField * @description ControlNet model field */ ControlNetModelField: { /** - * Model Name - * @description Name of the ControlNet model + * Key + * @description Model config record key for the ControlNet model */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; /** * ControlOutput @@ -4246,7 +4212,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["ImageCropInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | 
components["schemas"]["ConditioningInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | 
components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["IdealSizeInvocation"]; + [key: string]: components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["T2IAdapterInvocation"] | 
components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["MergeMetadataInvocation"] | 
components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"]; }; /** * Edges @@ -4283,7 +4249,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["MetadataOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["String2Output"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FloatOutput"] | 
components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["ClipSkipInvocationOutput"]; + [key: string]: components["schemas"]["ImageCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["String2Output"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["CLIPOutput"]; }; /** * Errors @@ -4477,11 +4443,18 @@ export type components = { * @description Model config for IP Adaptor format models. 
*/ IPAdapterConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default ip_adapter @@ -4509,13 +4482,23 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; + /** Image Encoder Model Id */ + image_encoder_model_id: string; }; /** IPAdapterField */ IPAdapterField: { @@ -4632,34 +4615,10 @@ export type components = { /** IPAdapterModelField */ IPAdapterModelField: { /** - * Model Name - * @description Name of the IP-Adapter model - */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - }; - /** IPAdapterModelInvokeAIConfig */ - IPAdapterModelInvokeAIConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default ip_adapter - * @constant - */ - model_type: "ip_adapter"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant + * Key + * @description Key to the IP-Adapter model */ - model_format: "invokeai"; - error?: components["schemas"]["ModelError"] | null; + key: string; }; /** IPAdapterOutput */ IPAdapterOutput: { @@ -6562,11 +6521,18 @@ export type components = { * @description Model config for LoRA/Lycoris models. 
*/ LoRAConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default lora @@ -6594,13 +6560,21 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; }; /** * LoRAMetadataField @@ -6615,42 +6589,17 @@ export type components = { */ weight: number; }; - /** LoRAModelConfig */ - LoRAModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default lora - * @constant - */ - model_type: "lora"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - model_format: components["schemas"]["LoRAModelFormat"]; - error?: components["schemas"]["ModelError"] | null; - }; /** * LoRAModelField * @description LoRA model field */ LoRAModelField: { /** - * Model Name - * @description Name of the LoRA model + * Key + * @description LoRA model key */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; - /** - * LoRAModelFormat - * @enum {string} - */ - LoRAModelFormat: "lycoris" | "diffusers"; /** * LocalModelSource * @description A local file or directory path. @@ -6678,16 +6627,12 @@ export type components = { /** LoraInfo */ LoraInfo: { /** - * Model Name - * @description Info to load submodel + * Key + * @description Key of model as returned by ModelRecordServiceBase.get_model() */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description Info to load submodel */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; + key: string; /** @description Info to load submodel */ - submodel?: components["schemas"]["SubModelType"] | null; + submodel_type?: components["schemas"]["SubModelType"] | null; /** * Weight * @description Lora's weight which to use when apply to model @@ -6771,11 +6716,18 @@ export type components = { * @description Model config for main checkpoint models. 
*/ MainCheckpointConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default main @@ -6804,17 +6756,32 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** Vae */ vae?: string | null; /** @default normal */ - variant?: components["schemas"]["invokeai__backend__model_manager__config__ModelVariantType"]; + variant?: components["schemas"]["ModelVariantType"]; + /** @default epsilon */ + prediction_type?: components["schemas"]["SchedulerPredictionType"]; + /** + * Upcast Attention + * @default false + */ + upcast_attention?: boolean; /** * Ztsnr Training * @default false @@ -6831,11 +6798,18 @@ export type components = { * @description Model config for main diffusers models. */ MainDiffusersConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default main @@ -6864,29 +6838,39 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** Vae */ vae?: string | null; /** @default normal */ - variant?: components["schemas"]["invokeai__backend__model_manager__config__ModelVariantType"]; + variant?: components["schemas"]["ModelVariantType"]; + /** @default epsilon */ + prediction_type?: components["schemas"]["SchedulerPredictionType"]; /** - * Ztsnr Training + * Upcast Attention * @default false */ - ztsnr_training?: boolean; - /** @default epsilon */ - prediction_type?: components["schemas"]["invokeai__backend__model_manager__config__SchedulerPredictionType"]; + upcast_attention?: boolean; /** - * Upcast Attention + * Ztsnr Training * @default false */ - upcast_attention?: boolean; + ztsnr_training?: boolean; + /** @default */ + repo_variant?: components["schemas"]["ModelRepoVariant"] | null; }; /** * MainModelField @@ -6894,14 +6878,10 @@ export type components = { */ MainModelField: { /** - * Model Name - * @description Name of the model + * Key + * @description Model key */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** 
@description Model Type */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; + key: string; }; /** * Main Model @@ -7153,38 +7133,6 @@ export type components = { */ type: "merge_metadata"; }; - /** MergeModelsBody */ - MergeModelsBody: { - /** - * Model Names - * @description model name - */ - model_names: string[]; - /** - * Merged Model Name - * @description Name of destination model - */ - merged_model_name: string | null; - /** - * Alpha - * @description Alpha weighting strength to apply to 2d and 3d models - * @default 0.5 - */ - alpha?: number | null; - /** @description Interpolation method */ - interp: components["schemas"]["MergeInterpolationMethod"] | null; - /** - * Force - * @description Force merging of models created with different versions of diffusers - * @default false - */ - force?: boolean | null; - /** - * Merge Dest Directory - * @description Save the merged model to the designated directory (with 'merged_model_name' appended) - */ - merge_dest_directory?: string | null; - }; /** * Merge Tiles to Image * @description Merge multiple tile images into a single image. @@ -7459,11 +7407,6 @@ export type components = { */ type: "mlsd_image_processor"; }; - /** - * ModelError - * @constant - */ - ModelError: "not_found"; /** * ModelFormat * @description Storage format of model. @@ -7473,16 +7416,12 @@ export type components = { /** ModelInfo */ ModelInfo: { /** - * Model Name - * @description Info to load submodel + * Key + * @description Key of model as returned by ModelRecordServiceBase.get_model() */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description Info to load submodel */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; + key: string; /** @description Info to load submodel */ - submodel?: components["schemas"]["SubModelType"] | null; + submodel_type?: components["schemas"]["SubModelType"] | null; }; /** * ModelInstallJob @@ -7508,7 +7447,7 @@ export type components = { * Config Out * @description After successful installation, this will hold the configuration object. 
*/ - config_out?: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"] | null; + config_out?: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"] | null; /** * Inplace * @description Leave model in its current location; otherwise install under models directory @@ -7587,7 +7526,7 @@ export type components = { * @description Various hugging face variants on the diffusers format. * @enum {string} */ - ModelRepoVariant: "default" | "fp16" | "fp32" | "onnx" | "openvino" | "flax"; + ModelRepoVariant: "" | "fp16" | "fp32" | "onnx" | "openvino" | "flax"; /** * ModelSummary * @description A short summary of models for UI listing purposes. @@ -7599,9 +7538,9 @@ export type components = { */ key: string; /** @description model type */ - type: components["schemas"]["invokeai__backend__model_manager__config__ModelType"]; + type: components["schemas"]["ModelType"]; /** @description base model */ - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + base: components["schemas"]["BaseModelType"]; /** @description model format */ format: components["schemas"]["ModelFormat"]; /** @@ -7620,6 +7559,26 @@ export type components = { */ tags: string[]; }; + /** + * ModelType + * @description Model type. + * @enum {string} + */ + ModelType: "onnx" | "main" | "vae" | "lora" | "controlnet" | "embedding" | "ip_adapter" | "clip_vision" | "t2i_adapter"; + /** + * ModelVariantType + * @description Variant type. + * @enum {string} + */ + ModelVariantType: "normal" | "inpaint" | "depth"; + /** + * ModelsList + * @description Return list of configs. 
+ */ + ModelsList: { + /** Models */ + models: ((components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"])[]; + }; /** * Multiply Integers * @description Multiplies two numbers @@ -7808,9 +7767,15 @@ export type components = { * @description Model config for ONNX format models based on sd-1. */ ONNXSD1Config: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; /** * Base @@ -7845,38 +7810,52 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** Vae */ vae?: string | null; /** @default normal */ - variant?: components["schemas"]["invokeai__backend__model_manager__config__ModelVariantType"]; - /** - * Ztsnr Training - * @default false - */ - ztsnr_training?: boolean; + variant?: components["schemas"]["ModelVariantType"]; /** @default epsilon */ - prediction_type?: components["schemas"]["invokeai__backend__model_manager__config__SchedulerPredictionType"]; + prediction_type?: components["schemas"]["SchedulerPredictionType"]; /** * Upcast Attention * @default false */ upcast_attention?: boolean; + /** + * Ztsnr Training + * @default false + */ + ztsnr_training?: boolean; }; /** * ONNXSD2Config * @description Model config for ONNX format models based on sd-2. 
*/ ONNXSD2Config: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; /** * Base @@ -7911,78 +7890,117 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; /** Vae */ vae?: string | null; /** @default normal */ - variant?: components["schemas"]["invokeai__backend__model_manager__config__ModelVariantType"]; - /** - * Ztsnr Training - * @default false - */ - ztsnr_training?: boolean; + variant?: components["schemas"]["ModelVariantType"]; /** @default v_prediction */ - prediction_type?: components["schemas"]["invokeai__backend__model_manager__config__SchedulerPredictionType"]; + prediction_type?: components["schemas"]["SchedulerPredictionType"]; /** * Upcast Attention * @default true */ upcast_attention?: boolean; + /** + * Ztsnr Training + * @default false + */ + ztsnr_training?: boolean; }; - /** ONNXStableDiffusion1ModelConfig */ - ONNXStableDiffusion1ModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + /** + * ONNXSDXLConfig + * @description Model config for ONNX format models based on sdxl. + */ + ONNXSDXLConfig: { /** - * Model Type - * @default onnx - * @constant + * Path + * @description filesystem path to the model file or directory */ - model_type: "onnx"; - /** Path */ path: string; - /** Description */ - description?: string | null; /** - * Model Format + * Name + * @description model name + */ + name: string; + /** + * Base + * @default sdxl * @constant */ - model_format: "onnx"; - error?: components["schemas"]["ModelError"] | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** ONNXStableDiffusion2ModelConfig */ - ONNXStableDiffusion2ModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + base?: "sdxl"; /** - * Model Type + * Type * @default onnx * @constant */ - model_type: "onnx"; - /** Path */ - path: string; - /** Description */ + type?: "onnx"; + /** + * Format + * @enum {string} + */ + format: "onnx" | "olive"; + /** + * Key + * @description unique key for model + * @default + */ + key?: string; + /** + * Original Hash + * @description original fasthash of model contents + */ + original_hash?: string | null; + /** + * Current Hash + * @description current fasthash of model contents + */ + current_hash?: string | null; + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** - * Model Format - * @constant + * Source + * @description model original source (path, URL or repo_id) + */ + source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; + /** Vae */ + vae?: string | null; + /** @default normal */ + variant?: components["schemas"]["ModelVariantType"]; + /** 
@default v_prediction */ + prediction_type?: components["schemas"]["SchedulerPredictionType"]; + /** + * Upcast Attention + * @default false + */ + upcast_attention?: boolean; + /** + * Ztsnr Training + * @default false */ - model_format: "onnx"; - error?: components["schemas"]["ModelError"] | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - prediction_type: components["schemas"]["invokeai__backend__model_management__models__base__SchedulerPredictionType"]; - /** Upcast Attention */ - upcast_attention: boolean; + ztsnr_training?: boolean; }; /** OffsetPaginatedResults[BoardDTO] */ OffsetPaginatedResults_BoardDTO_: { @@ -9119,6 +9137,12 @@ export type components = { */ type: "scheduler_output"; }; + /** + * SchedulerPredictionType + * @description Scheduler prediction type. + * @enum {string} + */ + SchedulerPredictionType: "epsilon" | "v_prediction" | "sample"; /** * Seamless * @description Applies the seamless transformation to the Model UNet and VAE. @@ -9468,162 +9492,6 @@ export type components = { */ type: "show_image"; }; - /** StableDiffusion1ModelCheckpointConfig */ - StableDiffusion1ModelCheckpointConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "checkpoint"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - /** Config */ - config: string; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusion1ModelDiffusersConfig */ - StableDiffusion1ModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusion2ModelCheckpointConfig */ - StableDiffusion2ModelCheckpointConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "checkpoint"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - /** Config */ - config: string; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusion2ModelDiffusersConfig */ - StableDiffusion2ModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - 
description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusionXLModelCheckpointConfig */ - StableDiffusionXLModelCheckpointConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "checkpoint"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - /** Config */ - config: string; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; - /** StableDiffusionXLModelDiffusersConfig */ - StableDiffusionXLModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default main - * @constant - */ - model_type: "main"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - /** Vae */ - vae?: string | null; - variant: components["schemas"]["invokeai__backend__model_management__models__base__ModelVariantType"]; - }; /** * Step Param Easing * @description Experimental per-step parameter easing for denoising steps @@ -10079,6 +9947,7 @@ export type components = { }; /** * SubModelType + * @description Submodel type. * @enum {string} */ SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "vae_decoder" | "vae_encoder" | "scheduler" | "safety_checker"; @@ -10216,37 +10085,13 @@ export type components = { */ type: "t2i_adapter"; }; - /** T2IAdapterModelDiffusersConfig */ - T2IAdapterModelDiffusersConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + /** T2IAdapterModelField */ + T2IAdapterModelField: { /** - * Model Type - * @default t2i_adapter - * @constant + * Key + * @description Model record key for the T2I-Adapter model */ - model_type: "t2i_adapter"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** - * Model Format - * @constant - */ - model_format: "diffusers"; - error?: components["schemas"]["ModelError"] | null; - }; - /** T2IAdapterModelField */ - T2IAdapterModelField: { - /** - * Model Name - * @description Name of the T2I-Adapter model - */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; /** T2IAdapterOutput */ T2IAdapterOutput: { @@ -10267,11 +10112,18 @@ export type components = { * @description Model config for T2I. 
*/ T2IConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default t2i_adapter @@ -10299,13 +10151,21 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; }; /** TBLR */ TBLR: { @@ -10323,11 +10183,18 @@ export type components = { * @description Model config for textual inversion embeddings. */ TextualInversionConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default embedding @@ -10355,32 +10222,21 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; - }; - /** TextualInversionModelConfig */ - TextualInversionModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; /** - * Model Type - * @default embedding - * @constant + * Last Modified + * @description timestamp for modification time */ - model_type: "embedding"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - /** Model Format */ - model_format: null; - error?: components["schemas"]["ModelError"] | null; + last_modified?: number | null; }; /** Tile */ Tile: { @@ -10546,7 +10402,7 @@ export type components = { }; /** * UNetOutput - * @description Base class for invocations that output a UNet field + * @description Base class for invocations that output a UNet field. */ UNetOutput: { /** @@ -10646,12 +10502,10 @@ export type components = { */ VAEModelField: { /** - * Model Name - * @description Name of the model + * Key + * @description Model's key */ - model_name: string; - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; + key: string; }; /** * VAEOutput @@ -10675,11 +10529,18 @@ export type components = { * @description Model config for standalone VAE models. 
*/ VaeCheckpointConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default vae @@ -10708,24 +10569,39 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; }; /** * VaeDiffusersConfig * @description Model config for standalone VAE models (diffusers version). */ VaeDiffusersConfig: { - /** Path */ + /** + * Path + * @description filesystem path to the model file or directory + */ path: string; - /** Name */ + /** + * Name + * @description model name + */ name: string; - base: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"]; + /** @description base model */ + base: components["schemas"]["BaseModelType"]; /** * Type * @default vae @@ -10754,13 +10630,21 @@ export type components = { * @description current fasthash of model contents */ current_hash?: string | null; - /** Description */ + /** + * Description + * @description human readable description of the model + */ description?: string | null; /** * Source - * @description Model download source (URL or repo_id) + * @description model original source (path, URL or repo_id) */ source?: string | null; + /** + * Last Modified + * @description timestamp for modification time + */ + last_modified?: number | null; }; /** VaeField */ VaeField: { @@ -10806,29 +10690,6 @@ export type components = { */ type: "vae_loader"; }; - /** VaeModelConfig */ - VaeModelConfig: { - /** Model Name */ - model_name: string; - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** - * Model Type - * @default vae - * @constant - */ - model_type: "vae"; - /** Path */ - path: string; - /** Description */ - description?: string | null; - model_format: components["schemas"]["VaeModelFormat"]; - error?: components["schemas"]["ModelError"] | null; - }; - /** - * VaeModelFormat - * @enum {string} - */ - VaeModelFormat: "checkpoint" | "diffusers"; /** ValidationError */ ValidationError: { /** Location */ @@ -11085,63 +10946,6 @@ export type components = { */ type: "zoe_depth_image_processor"; }; - /** - * ModelsList - * @description Return list of configs. 
- */ - invokeai__app__api__routers__model_records__ModelsList: { - /** Models */ - models: ((components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"])[]; - }; - /** ModelsList */ - invokeai__app__api__routers__models__ModelsList: { - /** Models */ - models: (components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"])[]; - }; - /** - * BaseModelType - * @enum {string} - */ - invokeai__backend__model_management__models__base__BaseModelType: "any" | "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner"; - /** - * ModelType - * @enum {string} - */ - invokeai__backend__model_management__models__base__ModelType: "onnx" | "main" | "vae" | "lora" | "controlnet" | "embedding" | "ip_adapter" | "clip_vision" | "t2i_adapter"; - /** - * ModelVariantType - * @enum {string} - */ - invokeai__backend__model_management__models__base__ModelVariantType: "normal" | "inpaint" | "depth"; - /** - * SchedulerPredictionType - * @enum {string} - */ - invokeai__backend__model_management__models__base__SchedulerPredictionType: "epsilon" | "v_prediction" | "sample"; - /** - * BaseModelType - * @description Base model type. - * @enum {string} - */ - invokeai__backend__model_manager__config__BaseModelType: "any" | "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner"; - /** - * ModelType - * @description Model type. - * @enum {string} - */ - invokeai__backend__model_manager__config__ModelType: "onnx" | "main" | "vae" | "lora" | "controlnet" | "embedding" | "ip_adapter" | "clip_vision" | "t2i_adapter"; - /** - * ModelVariantType - * @description Variant type. - * @enum {string} - */ - invokeai__backend__model_manager__config__ModelVariantType: "normal" | "inpaint" | "depth"; - /** - * SchedulerPredictionType - * @description Scheduler prediction type. - * @enum {string} - */ - invokeai__backend__model_manager__config__SchedulerPredictionType: "epsilon" | "v_prediction" | "sample"; /** * Classification * @description The classification of an Invocation. 
@@ -11309,17 +11113,17 @@ export type components = { */ UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; /** - * StableDiffusionXLModelFormat + * VaeModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; + VaeModelFormat: "checkpoint" | "diffusers"; /** - * T2IAdapterModelFormat + * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ - T2IAdapterModelFormat: "diffusers"; + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** * CLIPVisionModelFormat * @description An enumeration. @@ -11327,11 +11131,11 @@ export type components = { */ CLIPVisionModelFormat: "diffusers"; /** - * ControlNetModelFormat + * StableDiffusionXLModelFormat * @description An enumeration. * @enum {string} */ - ControlNetModelFormat: "checkpoint" | "diffusers"; + StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusion1ModelFormat * @description An enumeration. @@ -11339,23 +11143,35 @@ export type components = { */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** - * IPAdapterModelFormat + * ControlNetModelFormat * @description An enumeration. * @enum {string} */ - IPAdapterModelFormat: "invokeai"; + ControlNetModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusion2ModelFormat + * StableDiffusionOnnxModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** - * StableDiffusionOnnxModelFormat + * T2IAdapterModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; + T2IAdapterModelFormat: "diffusers"; + /** + * LoRAModelFormat + * @description An enumeration. + * @enum {string} + */ + LoRAModelFormat: "lycoris" | "diffusers"; + /** + * IPAdapterModelFormat + * @description An enumeration. 
+ * @enum {string} + */ + IPAdapterModelFormat: "invokeai"; }; responses: never; parameters: never; @@ -11426,25 +11242,63 @@ export type operations = { }; }; /** - * List Models - * @description Gets a list of models + * List Model Records + * @description Get a list of models. */ - list_models: { + list_model_records: { parameters: { query?: { /** @description Base models to include */ - base_models?: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"][] | null; + base_models?: components["schemas"]["BaseModelType"][] | null; /** @description The type of model to get */ - model_type?: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"] | null; + model_type?: components["schemas"]["ModelType"] | null; + /** @description Exact match on the name of the model */ + model_name?: string | null; + /** @description Exact match on the format of the model (e.g. 'diffusers') */ + model_format?: components["schemas"]["ModelFormat"] | null; }; }; responses: { /** @description Successful Response */ 200: { content: { - "application/json": components["schemas"]["invokeai__app__api__routers__models__ModelsList"]; + "application/json": components["schemas"]["ModelsList"]; + }; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + /** + * Get Model Record + * @description Get a model record + */ + get_model_record: { + parameters: { + path: { + /** @description Key of the model record to fetch. */ + key: string; + }; + }; + responses: { + /** @description The model configuration was retrieved successfully */ + 200: { + content: { + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; + /** @description Bad request */ + 400: { + content: never; + }; + /** @description The model could not be found */ + 404: { + content: never; + }; /** @description Validation Error */ 422: { content: { @@ -11454,18 +11308,17 @@ export type operations = { }; }; /** - * Delete Model - * @description Delete Model + * Del Model Record + * @description Delete model record from database. + * + * The configuration record will be removed. The corresponding weights files will be + * deleted as well if they reside within the InvokeAI "models" directory. */ - del_model: { + del_model_record: { parameters: { path: { - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description The type of model */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; - /** @description model name */ - model_name: string; + /** @description Unique key of model to remove from model registry. 
*/ + key: string; }; }; responses: { @@ -11486,30 +11339,38 @@ export type operations = { }; }; /** - * Update Model + * Update Model Record * @description Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. */ - update_model: { + update_model_record: { parameters: { path: { - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description The type of model */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; - /** @description model name */ - model_name: string; + /** @description Unique key of model */ + key: string; }; }; requestBody: { content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; + /** + * @example { + * "path": "/path/to/model", + * "name": "model_name", + * "base": "sd-1", + * "type": "main", + * "format": "checkpoint", + * "config": "configs/stable-diffusion/v1-inference.yaml", + * "description": "Model description", + * "variant": "normal" + * } + */ + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; responses: { /** @description The model was updated successfully */ 200: { content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | 
components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; /** @description Bad request */ @@ -11533,387 +11394,32 @@ export type operations = { }; }; /** - * Import Model - * @description Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically + * List Model Summary + * @description Gets a page of model summary data. */ - import_model: { - requestBody: { - content: { - "application/json": components["schemas"]["Body_import_model"]; + list_model_summary: { + parameters: { + query?: { + /** @description The page to get */ + page?: number; + /** @description The number of models per page */ + per_page?: number; + /** @description The attribute to order by */ + order_by?: components["schemas"]["ModelRecordOrderBy"]; }; }; responses: { - /** @description The model imported successfully */ - 201: { + /** @description Successful Response */ + 200: { content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; + "application/json": components["schemas"]["PaginatedResults_ModelSummary_"]; }; }; - /** @description The model could not be found */ - 404: { - content: never; - }; - /** @description There is already a model corresponding to this path or repo_id */ - 409: { - content: never; - }; - /** @description Unrecognized file/folder format */ - 415: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - /** @description The model appeared to import successfully, but could not be found in the model manager */ - 424: { - content: never; - }; - }; - }; - /** - * Add Model - * @description Add a model using the configuration information appropriate for its type. 
Only local models can be added by path - */ - add_model: { - requestBody: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - responses: { - /** @description The model added successfully */ - 201: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - /** @description The model could not be found */ - 404: { - content: never; - }; - /** @description There is already a model corresponding to this path or repo_id */ - 409: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - /** @description The model appeared to add successfully, but could not be found in the model manager */ - 424: { - content: never; - }; - }; - }; - /** - * Convert Model - * @description Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none. 
- */ - convert_model: { - parameters: { - query?: { - /** @description Save the converted model to the designated directory */ - convert_dest_directory?: string | null; - }; - path: { - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - /** @description The type of model */ - model_type: components["schemas"]["invokeai__backend__model_management__models__base__ModelType"]; - /** @description model name */ - model_name: string; - }; - }; - responses: { - /** @description Model converted successfully */ - 200: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - /** @description Bad request */ - 400: { - content: never; - }; - /** @description Model not found */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** Search For Models */ - search_for_models: { - parameters: { - query: { - /** @description Directory path to search for models */ - search_path: string; - }; - }; - responses: { - /** @description Directory searched successfully */ - 200: { - content: { - "application/json": string[]; - }; - }; - /** @description Invalid directory path */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * List Ckpt Configs - * @description Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT. - */ - list_ckpt_configs: { - responses: { - /** @description paths retrieved successfully */ - 200: { - content: { - "application/json": string[]; - }; - }; - }; - }; - /** - * Sync To Config - * @description Call after making changes to models.yaml, autoimport directories or models directory to synchronize - * in-memory data structures with disk data structures. 
- */ - sync_to_config: { - responses: { - /** @description synchronization successful */ - 201: { - content: { - "application/json": boolean; - }; - }; - }; - }; - /** - * Merge Models - * @description Convert a checkpoint model into a diffusers model - */ - merge_models: { - parameters: { - path: { - /** @description Base model */ - base_model: components["schemas"]["invokeai__backend__model_management__models__base__BaseModelType"]; - }; - }; - requestBody: { - content: { - "application/json": components["schemas"]["Body_merge_models"]; - }; - }; - responses: { - /** @description Model converted successfully */ - 200: { - content: { - "application/json": components["schemas"]["ONNXStableDiffusion1ModelConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelCheckpointConfig"] | components["schemas"]["ControlNetModelDiffusersConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["IPAdapterModelInvokeAIConfig"] | components["schemas"]["CLIPVisionModelDiffusersConfig"] | components["schemas"]["T2IAdapterModelDiffusersConfig"] | components["schemas"]["ONNXStableDiffusion2ModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusionXLModelCheckpointConfig"] | components["schemas"]["StableDiffusionXLModelDiffusersConfig"]; - }; - }; - /** @description Incompatible models */ - 400: { - content: never; - }; - /** @description One or more models not found */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * List Model Records - * @description Get a list of models. - */ - list_model_records: { - parameters: { - query?: { - /** @description Base models to include */ - base_models?: components["schemas"]["invokeai__backend__model_manager__config__BaseModelType"][] | null; - /** @description The type of model to get */ - model_type?: components["schemas"]["invokeai__backend__model_manager__config__ModelType"] | null; - /** @description Exact match on the name of the model */ - model_name?: string | null; - /** @description Exact match on the format of the model (e.g. 'diffusers') */ - model_format?: components["schemas"]["ModelFormat"] | null; - }; - }; - responses: { - /** @description Successful Response */ - 200: { - content: { - "application/json": components["schemas"]["invokeai__app__api__routers__model_records__ModelsList"]; - }; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * Get Model Record - * @description Get a model record - */ - get_model_record: { - parameters: { - path: { - /** @description Key of the model record to fetch. 
*/ - key: string; - }; - }; - responses: { - /** @description Success */ - 200: { - content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; - }; - }; - /** @description Bad request */ - 400: { - content: never; - }; - /** @description The model could not be found */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * Del Model Record - * @description Delete model record from database. - * - * The configuration record will be removed. The corresponding weights files will be - * deleted as well if they reside within the InvokeAI "models" directory. - */ - del_model_record: { - parameters: { - path: { - /** @description Unique key of model to remove from model registry. */ - key: string; - }; - }; - responses: { - /** @description Model deleted successfully */ - 204: { - content: never; - }; - /** @description Model not found */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * Update Model Record - * @description Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. 
- */ - update_model_record: { - parameters: { - path: { - /** @description Unique key of model */ - key: string; - }; - }; - requestBody: { - content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; - }; - }; - responses: { - /** @description The model was updated successfully */ - 200: { - content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; - }; - }; - /** @description Bad request */ - 400: { - content: never; - }; - /** @description The model could not be found */ - 404: { - content: never; - }; - /** @description There is already a model corresponding to the new name */ - 409: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - /** - * List Model Summary - * @description Gets a page of model summary data. 
- */ - list_model_summary: { - parameters: { - query?: { - /** @description The page to get */ - page?: number; - /** @description The number of models per page */ - per_page?: number; - /** @description The attribute to order by */ - order_by?: components["schemas"]["ModelRecordOrderBy"]; - }; - }; - responses: { - /** @description Successful Response */ - 200: { - content: { - "application/json": components["schemas"]["PaginatedResults_ModelSummary_"]; - }; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; }; }; }; @@ -11929,7 +11435,7 @@ export type operations = { }; }; responses: { - /** @description Success */ + /** @description The model metadata was retrieved successfully */ 200: { content: { "application/json": (components["schemas"]["BaseMetadata"] | components["schemas"]["HuggingFaceMetadata"] | components["schemas"]["CivitaiMetadata"]) | null; @@ -11980,7 +11486,7 @@ export type operations = { /** @description Successful Response */ 200: { content: { - "application/json": components["schemas"]["invokeai__app__api__routers__model_records__ModelsList"]; + "application/json": components["schemas"]["ModelsList"]; }; }; /** @description Validation Error */ @@ -11998,14 +11504,26 @@ export type operations = { add_model_record: { requestBody: { content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; + /** + * @example { + * "path": "/path/to/model", + * "name": "model_name", + * "base": "sd-1", + * "type": "main", + * "format": "checkpoint", + * "config": "configs/stable-diffusion/v1-inference.yaml", + * "description": "Model description", + * "variant": "normal" + * } + */ + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; responses: { /** @description The model added successfully */ 201: { content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | 
components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; /** @description There is already a model corresponding to this path or repo_id */ @@ -12025,22 +11543,81 @@ export type operations = { }; }; /** - * List Model Install Jobs - * @description Return list of model install jobs. + * Heuristic Import + * @description Install a model using a string identifier. + * + * `source` can be any of the following. + * + * 1. A path on the local filesystem ('C:\users\fred\model.safetensors') + * 2. A Url pointing to a single downloadable model file + * 3. A HuggingFace repo_id with any of the following formats: + * - model/name + * - model/name:fp16:vae + * - model/name::vae -- use default precision + * - model/name:fp16:path/to/model.safetensors + * - model/name::path/to/model.safetensors + * + * `config` is an optional dict containing model configuration values that will override + * the ones that are probed automatically. + * + * `access_token` is an optional access token for use with Urls that require + * authentication. + * + * Models will be downloaded, probed, configured and installed in a + * series of background threads. The return object has `status` attribute + * that can be used to monitor progress. + * + * See the documentation for `import_model_record` for more information on + * interpreting the job information returned by this route. */ - list_model_install_jobs: { + heuristic_import_model: { + parameters: { + query: { + source: string; + access_token?: string | null; + }; + }; + requestBody?: { + content: { + /** + * @example { + * "name": "modelT", + * "description": "antique cars" + * } + */ + "application/json": Record | null; + }; + }; responses: { - /** @description Successful Response */ - 200: { + /** @description The model imported successfully */ + 201: { content: { - "application/json": components["schemas"]["ModelInstallJob"][]; + "application/json": components["schemas"]["ModelInstallJob"]; }; }; + /** @description There is already a model corresponding to this path or repo_id */ + 409: { + content: never; + }; + /** @description Unrecognized file/folder format */ + 415: { + content: never; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + /** @description The model appeared to import successfully, but could not be found in the model manager */ + 424: { + content: never; + }; }; }; /** * Import Model - * @description Add a model using its local path, repo_id, or remote URL. + * @description Install a model using its local path, repo_id, or remote URL. 
* * Models will be downloaded, probed, configured and installed in a * series of background threads. The return object has `status` attribute @@ -12051,32 +11628,38 @@ export type operations = { * appropriate value: * * * To install a local path using LocalModelSource, pass a source of form: - * `{ + * ``` + * { * "type": "local", * "path": "/path/to/model", * "inplace": false - * }` - * The "inplace" flag, if true, will register the model in place in its - * current filesystem location. Otherwise, the model will be copied - * into the InvokeAI models directory. + * } + * ``` + * The "inplace" flag, if true, will register the model in place in its + * current filesystem location. Otherwise, the model will be copied + * into the InvokeAI models directory. * * * To install a HuggingFace repo_id using HFModelSource, pass a source of form: - * `{ + * ``` + * { * "type": "hf", * "repo_id": "stabilityai/stable-diffusion-2.0", * "variant": "fp16", * "subfolder": "vae", * "access_token": "f5820a918aaf01" - * }` - * The `variant`, `subfolder` and `access_token` fields are optional. + * } + * ``` + * The `variant`, `subfolder` and `access_token` fields are optional. * * * To install a remote model using an arbitrary URL, pass: - * `{ + * ``` + * { * "type": "url", * "url": "http://www.civitai.com/models/123456", * "access_token": "f5820a918aaf01" - * }` - * The `access_token` field is optonal + * } + * ``` + * The `access_token` field is optonal * * The model's configuration record will be probed and filled in * automatically. To override the default guesses, pass "metadata" @@ -12085,19 +11668,19 @@ export type operations = { * Installation occurs in the background. Either use list_model_install_jobs() * to poll for completion, or listen on the event bus for the following events: * - * "model_install_running" - * "model_install_completed" - * "model_install_error" + * * "model_install_running" + * * "model_install_completed" + * * "model_install_error" * * On successful completion, the event's payload will contain the field "key" * containing the installed ID of the model. On an error, the event's payload * will contain the fields "error_type" and "error" describing the nature of the * error and its traceback, respectively. */ - import_model_record: { + import_model: { requestBody: { content: { - "application/json": components["schemas"]["Body_import_model_record"]; + "application/json": components["schemas"]["Body_import_model"]; }; }; responses: { @@ -12127,6 +11710,37 @@ export type operations = { }; }; }; + /** + * List Model Install Jobs + * @description Return the list of model install jobs. + * + * Install jobs have a numeric `id`, a `status`, and other fields that provide information on + * the nature of the job and its progress. The `status` is one of: + * + * * "waiting" -- Job is waiting in the queue to run + * * "downloading" -- Model file(s) are downloading + * * "running" -- Model has downloaded and the model probing and registration process is running + * * "completed" -- Installation completed successfully + * * "error" -- An error occurred. Details will be in the "error_type" and "error" fields. + * * "cancelled" -- Job was cancelled before completion. + * + * Once completed, information about the model such as its size, base + * model, type, and metadata can be retrieved from the `config_out` + * field. For multi-file models such as diffusers, information on individual files + * can be retrieved from `download_parts`. + * + * See the example and schema below for more information. 
+ */ + list_model_install_jobs: { + responses: { + /** @description Successful Response */ + 200: { + content: { + "application/json": components["schemas"]["ModelInstallJob"][]; + }; + }; + }; + }; /** * Prune Model Install Jobs * @description Prune all completed and errored jobs from the install job list. @@ -12151,7 +11765,8 @@ export type operations = { }; /** * Get Model Install Job - * @description Return model install job corresponding to the given source. + * @description Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs' + * for information on the format of the return value. */ get_model_install_job: { parameters: { @@ -12234,16 +11849,59 @@ export type operations = { }; }; }; + /** + * Convert Model + * @description Permanently convert a model into diffusers format, replacing the safetensors version. + * Note that during the conversion process the key and model hash will change. + * The return value is the model configuration for the converted model. + */ + convert_model: { + parameters: { + path: { + /** @description Unique key of the safetensors main model to convert to diffusers format. */ + key: string; + }; + }; + responses: { + /** @description Model converted successfully */ + 200: { + content: { + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; + }; + }; + /** @description Bad request */ + 400: { + content: never; + }; + /** @description Model not found */ + 404: { + content: never; + }; + /** @description There is already a model registered at this location */ + 409: { + content: never; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; /** * Merge - * @description Merge diffusers models. - * - * keys: List of 2-3 model keys to merge together. All models must use the same base type. - * merged_model_name: Name for the merged model [Concat model names] - * alpha: Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] - * force: If true, force the merge even if the models were generated by different versions of the diffusers library [False] - * interp: Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] - * merge_dest_directory: Specify a directory to store the merged model in [models directory] + * @description Merge diffusers models. The process is controlled by a set parameters provided in the body of the request. + * ``` + * Argument Description [default] + * -------- ---------------------- + * keys List of 2-3 model keys to merge together. All models must use the same base type. + * merged_model_name Name for the merged model [Concat model names] + * alpha Alpha value (0.0-1.0). 
Higher values give more weight to the second model [0.5] + * force If true, force the merge even if the models were generated by different versions of the diffusers library [False] + * interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] + * merge_dest_directory Specify a directory to store the merged model in [models directory] + * ``` */ merge: { requestBody: { @@ -12252,12 +11910,24 @@ export type operations = { }; }; responses: { - /** @description Successful Response */ + /** @description Model converted successfully */ 200: { content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; + "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; }; }; + /** @description Bad request */ + 400: { + content: never; + }; + /** @description Model not found */ + 404: { + content: never; + }; + /** @description There is already a model registered at this location */ + 409: { + content: never; + }; /** @description Validation Error */ 422: { content: { From f08571384f99294b92092625e83a64c1e0ae6308 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Feb 2024 22:16:11 +1100 Subject: [PATCH 098/340] tests(ui): enable vitest type testing This is useful for the zod schemas and types we have created to match the backend. 
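
As an illustration (not part of this diff), a type test under this setup looks roughly like the sketch below: `tsafe` is added as a dev dependency in this commit and `typecheck` is enabled in vite.config.mts, so type-only test files (the `*.test-d.ts` files added in the next commit) are checked at the type level. `ImageField` and `S` are the real exports exercised by the tests that follow; the snippet is only a sketch of the pattern, not code from this patch.

```
import type { ImageField } from 'features/nodes/types/common';
import type { S } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { describe, test } from 'vitest';

describe('common types', () => {
  // Compile-time assertion: fails type checking if the zod-derived ImageField
  // type ever drifts from the server's generated S['ImageField'] type.
  test('ImageField', () => assert<Equals<ImageField, S['ImageField']>>());
});
```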
--- invokeai/frontend/web/.gitignore | 3 +++ invokeai/frontend/web/package.json | 1 + invokeai/frontend/web/pnpm-lock.yaml | 7 +++++++ invokeai/frontend/web/vite.config.mts | 5 ++++- 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/.gitignore b/invokeai/frontend/web/.gitignore index 8e7ebc76a1f..3e8a372bc77 100644 --- a/invokeai/frontend/web/.gitignore +++ b/invokeai/frontend/web/.gitignore @@ -41,3 +41,6 @@ stats.html # Yalc .yalc yalc.lock + +# vitest +tsconfig.vitest-temp.json \ No newline at end of file diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index b2838e538ce..cea13350d26 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -154,6 +154,7 @@ "rollup-plugin-visualizer": "^5.12.0", "storybook": "^7.6.10", "ts-toolbelt": "^9.6.0", + "tsafe": "^1.6.6", "typescript": "^5.3.3", "vite": "^5.0.12", "vite-plugin-css-injected-by-js": "^3.3.1", diff --git a/invokeai/frontend/web/pnpm-lock.yaml b/invokeai/frontend/web/pnpm-lock.yaml index f3bf68cf1da..0ec2e47a0cd 100644 --- a/invokeai/frontend/web/pnpm-lock.yaml +++ b/invokeai/frontend/web/pnpm-lock.yaml @@ -300,6 +300,9 @@ devDependencies: ts-toolbelt: specifier: ^9.6.0 version: 9.6.0 + tsafe: + specifier: ^1.6.6 + version: 1.6.6 typescript: specifier: ^5.3.3 version: 5.3.3 @@ -13505,6 +13508,10 @@ packages: resolution: {integrity: sha512-nsZd8ZeNUzukXPlJmTBwUAuABDe/9qtVDelJeT/qW0ow3ZS3BsQJtNkan1802aM9Uf68/Y8ljw86Hu0h5IUW3w==} dev: true + /tsafe@1.6.6: + resolution: {integrity: sha512-gzkapsdbMNwBnTIjgO758GujLCj031IgHK/PKr2mrmkCSJMhSOR5FeOuSxKLMUoYc0vAA4RGEYYbjt/v6afD3g==} + dev: true + /tsconfck@3.0.1(typescript@5.3.3): resolution: {integrity: sha512-7ppiBlF3UEddCLeI1JRx5m2Ryq+xk4JrZuq4EuYXykipebaq1dV0Fhgr1hb7CkmHt32QSgOZlcqVLEtHBG4/mg==} engines: {node: ^18 || >=20} diff --git a/invokeai/frontend/web/vite.config.mts b/invokeai/frontend/web/vite.config.mts index 325c6467dee..f4dbae71232 100644 --- a/invokeai/frontend/web/vite.config.mts +++ b/invokeai/frontend/web/vite.config.mts @@ -84,7 +84,10 @@ export default defineConfig(({ mode }) => { }, }, test: { - // + typecheck: { + enabled: true, + ignoreSourceErrors: true, + }, }, }; }); From 8b09eab9ede4c12958b672aff8229b2fa1f60945 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Feb 2024 22:16:55 +1100 Subject: [PATCH 099/340] tests(ui): add type tests --- .../src/features/nodes/types/common.test-d.ts | 69 +++++++++++++++++++ .../features/nodes/types/workflow.test-d.ts | 18 +++++ 2 files changed, 87 insertions(+) create mode 100644 invokeai/frontend/web/src/features/nodes/types/common.test-d.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/workflow.test-d.ts diff --git a/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts b/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts new file mode 100644 index 00000000000..7f28e864a13 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts @@ -0,0 +1,69 @@ +import type { + BaseModel, + BoardField, + Classification, + CLIPField, + ColorField, + ControlField, + ControlNetModelField, + ImageField, + ImageOutput, + IPAdapterField, + IPAdapterModelField, + LoraInfo, + LoRAModelField, + MainModelField, + ModelInfo, + ModelType, + ProgressImage, + SchedulerField, + SDXLRefinerModelField, + SubModelType, + T2IAdapterField, + T2IAdapterModelField, + UNetField, + VAEField, +} from 'features/nodes/types/common'; 
+import type { S } from 'services/api/types'; +import type { Equals, Extends } from 'tsafe'; +import { assert } from 'tsafe'; +import { describe, test } from 'vitest'; + +/** + * These types originate from the server and are recreated as zod schemas manually, for use at runtime. + * The tests ensure that the types are correctly recreated. + */ + +describe('Common types', () => { + // Complex field types + test('ImageField', () => assert>()); + test('BoardField', () => assert>()); + test('ColorField', () => assert>()); + test('SchedulerField', () => assert>>()); + test('UNetField', () => assert>()); + test('CLIPField', () => assert>()); + test('MainModelField', () => assert>()); + test('SDXLRefinerModelField', () => assert>()); + test('VAEField', () => assert>()); + test('ControlField', () => assert>()); + // @ts-expect-error TODO(psyche): fix types + test('IPAdapterField', () => assert>()); + test('T2IAdapterField', () => assert>()); + test('LoRAModelField', () => assert>()); + test('ControlNetModelField', () => assert>()); + test('IPAdapterModelField', () => assert>()); + test('T2IAdapterModelField', () => assert>()); + + // Model component types + test('BaseModel', () => assert>()); + test('ModelType', () => assert>()); + test('SubModelType', () => assert>()); + test('ModelInfo', () => assert>()); + + // Misc types + test('LoraInfo', () => assert>()); + // @ts-expect-error TODO(psyche): There is no `ProgressImage` in the server types yet + test('ProgressImage', () => assert>()); + test('ImageOutput', () => assert>()); + test('Classification', () => assert>()); +}); diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.test-d.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.test-d.ts new file mode 100644 index 00000000000..7cb1ea230ce --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.test-d.ts @@ -0,0 +1,18 @@ +import type { WorkflowCategory, WorkflowV3, XYPosition } from 'features/nodes/types/workflow'; +import type * as ReactFlow from 'reactflow'; +import type { S } from 'services/api/types'; +import type { Equals, Extends } from 'tsafe'; +import { assert } from 'tsafe'; +import { describe, test } from 'vitest'; + +/** + * These types originate from the server and are recreated as zod schemas manually, for use at runtime. + * The tests ensure that the types are correctly recreated. + */ + +describe('Workflow types', () => { + test('XYPosition', () => assert>()); + test('WorkflowCategory', () => assert>()); + // @ts-expect-error TODO(psyche): Need to revise server types! 
+ test('WorkflowV3', () => assert>()); +}); From aee44d6b3f372fdec16e9381992ec5efbecce89d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 15 Feb 2024 22:17:16 +1100 Subject: [PATCH 100/340] fix(ui): update model types --- .../web/src/features/nodes/types/common.ts | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index b5244743799..ef579fce8cf 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -52,27 +52,29 @@ export type SchedulerField = z.infer; // #region Model-related schemas export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); -export const zModelType = z.enum(['main', 'vae', 'lora', 'controlnet', 'embedding']); +export const zModelType = z.enum([ + 'main', + 'vae', + 'lora', + 'controlnet', + 'embedding', + 'ip_adapter', + 'clip_vision', + 't2i_adapter', + 'onnx', // TODO(psyche): Remove this when removed from backend +]); export const zModelName = z.string().min(3); export const zModelIdentifier = z.object({ - model_name: zModelName, - base_model: zBaseModel, + key: z.string().min(1), }); export type BaseModel = z.infer; export type ModelType = z.infer; export type ModelIdentifier = z.infer; -export const zMainModelField = z.object({ - model_name: zModelName, - base_model: zBaseModel, - model_type: z.literal('main'), -}); -export const zSDXLRefinerModelField = z.object({ - model_name: z.string().min(1), - base_model: z.literal('sdxl-refiner'), - model_type: z.literal('main'), -}); +export const zMainModelField = zModelIdentifier; export type MainModelField = z.infer; + +export const zSDXLRefinerModelField = zModelIdentifier; export type SDXLRefinerModelField = z.infer; export const zSubModelType = z.enum([ @@ -92,8 +94,7 @@ export type SubModelType = z.infer; export const zVAEModelField = zModelIdentifier; export const zModelInfo = zModelIdentifier.extend({ - model_type: zModelType, - submodel: zSubModelType.optional(), + submodel_type: zSubModelType.nullish(), }); export type ModelInfo = z.infer; From 13d3a01933a69658a1fa08c024a23e3eb87c3250 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 16 Feb 2024 22:51:47 +1100 Subject: [PATCH 101/340] fix(nodes): fix t2i adapter model loading --- invokeai/app/invocations/latent.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 5dd0eb074d5..1f21b539dc9 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -509,19 +509,20 @@ def run_t2i_adapters( t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key) + t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key) + t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key) image = context.images.get_pil(t2i_adapter_field.image.image_name) # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. 
- if t2i_adapter_model_info.base == BaseModelType.StableDiffusion1: + if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1: max_unet_downscale = 8 - elif t2i_adapter_model_info.base == BaseModelType.StableDiffusionXL: + elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL: max_unet_downscale = 4 else: - raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_info.base}'.") + raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.") t2i_adapter_model: T2IAdapter - with t2i_adapter_model_info as t2i_adapter_model: + with t2i_adapter_loaded_model as t2i_adapter_model: total_downscale_factor = t2i_adapter_model.total_downscale_factor # Resize the T2I-Adapter input image. From b79fac79701ab802b790ebc35c7975f27d3ebca5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 16 Feb 2024 18:56:02 +1100 Subject: [PATCH 102/340] feat(ui): update model identifier to be key (wip) - Update most model identifiers to be `{key: string}` instead of name/base/type. Doesn't change the model select components yet. - Update model _parameters_, stored in redux, to be `{key: string, base: BaseModel}` - we need to store the base model to be able to check model compatibility. May want to store the whole config? Not sure... --- .../frontend/web/.storybook/ReduxInit.tsx | 8 +- .../listeners/enqueueRequestedLinear.ts | 2 +- .../listeners/modelSelected.ts | 10 +- .../listeners/modelsLoaded.ts | 43 +-- .../common/hooks/useGroupedModelCombobox.ts | 6 +- .../src/common/hooks/useIsReadyToEnqueue.ts | 2 +- .../web/src/common/hooks/useModelCombobox.ts | 6 +- .../src/features/canvas/store/canvasSlice.ts | 2 +- .../parameters/ParamControlAdapterModel.tsx | 17 +- .../hooks/useAddControlAdapter.ts | 4 +- .../store/controlAdaptersSlice.ts | 6 +- .../features/embedding/EmbeddingSelect.tsx | 8 +- .../ImageMetadataActions.tsx | 14 +- .../src/features/lora/components/LoRACard.tsx | 2 +- .../src/features/lora/components/LoRAList.tsx | 2 +- .../features/lora/components/LoRASelect.tsx | 6 +- .../web/src/features/lora/store/loraSlice.ts | 35 ++- .../subpanels/ModelManagerPanel.tsx | 10 +- .../ModelManagerPanel/CheckpointModelEdit.tsx | 4 +- .../ModelManagerPanel/DiffusersModelEdit.tsx | 4 +- .../ModelManagerPanel/LoRAModelEdit.tsx | 14 +- .../subpanels/ModelManagerPanel/ModelList.tsx | 6 +- .../ModelManagerPanel/ModelListItem.tsx | 4 +- .../ControlNetModelFieldInputComponent.tsx | 4 +- .../IPAdapterModelFieldInputComponent.tsx | 4 +- .../inputs/LoRAModelFieldInputComponent.tsx | 4 +- .../inputs/MainModelFieldInputComponent.tsx | 4 +- .../RefinerModelFieldInputComponent.tsx | 4 +- .../SDXLMainModelFieldInputComponent.tsx | 4 +- .../T2IAdapterModelFieldInputComponent.tsx | 4 +- .../inputs/VAEModelFieldInputComponent.tsx | 4 +- .../web/src/features/nodes/types/common.ts | 16 +- .../util/graph/addControlNetToLinearGraph.ts | 2 +- .../util/graph/addIPAdapterToLinearGraph.ts | 2 +- .../nodes/util/graph/addLoRAsToGraph.ts | 9 +- .../nodes/util/graph/addSDXLLoRAstoGraph.ts | 9 +- .../util/graph/addT2IAdapterToLinearGraph.ts | 2 +- .../nodes/util/graph/buildCanvasGraph.ts | 8 +- .../util/graph/buildLinearBatchConfig.ts | 2 +- .../components/Advanced/ParamClipSkip.tsx | 6 +- .../components/Core/ParamPositivePrompt.tsx | 2 +- .../MainModel/ParamMainModelSelect.tsx | 4 +- .../VAEModel/ParamVAEModelSelect.tsx | 6 +- .../parameters/hooks/useRecallParameters.ts | 29 +- .../parameters/store/generationSlice.ts | 
6 +- .../parameters/types/parameterSchemas.ts | 15 +- .../parameters/util/optimalDimension.ts | 6 +- .../ParamSDXLRefinerModelSelect.tsx | 6 +- .../AdvancedSettingsAccordion.tsx | 3 +- .../GenerationSettingsAccordion.tsx | 5 +- .../ImageSettingsAccordion.tsx | 2 +- .../ui/components/ParametersPanel.tsx | 2 +- .../web/src/services/api/endpoints/models.ts | 286 +++++------------- .../frontend/web/src/services/api/types.ts | 47 ++- 54 files changed, 268 insertions(+), 454 deletions(-) diff --git a/invokeai/frontend/web/.storybook/ReduxInit.tsx b/invokeai/frontend/web/.storybook/ReduxInit.tsx index 55d01322427..7d3f8e0d2bf 100644 --- a/invokeai/frontend/web/.storybook/ReduxInit.tsx +++ b/invokeai/frontend/web/.storybook/ReduxInit.tsx @@ -10,13 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => { const dispatch = useAppDispatch(); useGlobalModifiersInit(); useEffect(() => { - dispatch( - modelChanged({ - model_name: 'test_model', - base_model: 'sd-1', - model_type: 'main', - }) - ); + dispatch(modelChanged({ key: 'test_model', base: 'sd-1' })); }, []); return props.children; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index e1e13fadbed..d1cb692c982 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -19,7 +19,7 @@ export const addEnqueueRequestedLinear = () => { let graph; - if (model && model.base_model === 'sdxl') { + if (model && model.base === 'sdxl') { if (action.payload.tabName === 'txt2img') { graph = buildLinearSDXLTextToImageGraph(state); } else { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index 7638c5522af..35e2ad5f9bc 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -30,8 +30,8 @@ export const addModelSelectedListener = () => { const newModel = result.data; - const newBaseModel = newModel.base_model; - const didBaseModelChange = state.generation.model?.base_model !== newBaseModel; + const newBaseModel = newModel.base; + const didBaseModelChange = state.generation.model?.base !== newBaseModel; if (didBaseModelChange) { // we may need to reset some incompatible submodels @@ -39,7 +39,7 @@ export const addModelSelectedListener = () => { // handle incompatible loras forEach(state.lora.loras, (lora, id) => { - if (lora.base_model !== newBaseModel) { + if (lora.base !== newBaseModel) { dispatch(loraRemoved(id)); modelsCleared += 1; } @@ -47,14 +47,14 @@ export const addModelSelectedListener = () => { // handle incompatible vae const { vae } = state.generation; - if (vae && vae.base_model !== newBaseModel) { + if (vae && vae.base !== newBaseModel) { dispatch(vaeSelected(null)); modelsCleared += 1; } // handle incompatible controlnets selectControlAdapterAll(state.controlAdapters).forEach((ca) => { - if (ca.model?.base_model !== newBaseModel) { + if (ca.model?.base !== newBaseModel) { dispatch(controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false })); modelsCleared += 1; } diff --git 
a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index 0ffe88cd078..366644fa685 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -34,14 +34,7 @@ export const addModelsLoadedListener = () => { return; } - const isCurrentModelAvailable = currentModel - ? models.some( - (m) => - m.model_name === currentModel.model_name && - m.base_model === currentModel.base_model && - m.model_type === currentModel.model_type - ) - : false; + const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false; if (isCurrentModelAvailable) { return; @@ -74,14 +67,7 @@ export const addModelsLoadedListener = () => { return; } - const isCurrentModelAvailable = currentModel - ? models.some( - (m) => - m.model_name === currentModel.model_name && - m.base_model === currentModel.base_model && - m.model_type === currentModel.model_type - ) - : false; + const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false; if (!isCurrentModelAvailable) { dispatch(refinerModelChanged(null)); @@ -103,10 +89,7 @@ export const addModelsLoadedListener = () => { return; } - const isCurrentVAEAvailable = some( - action.payload.entities, - (m) => m?.model_name === currentVae?.model_name && m?.base_model === currentVae?.base_model - ); + const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key); if (isCurrentVAEAvailable) { return; @@ -140,10 +123,7 @@ export const addModelsLoadedListener = () => { const loras = getState().lora.loras; forEach(loras, (lora, id) => { - const isLoRAAvailable = some( - action.payload.entities, - (m) => m?.model_name === lora?.model_name && m?.base_model === lora?.base_model - ); + const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.key); if (isLoRAAvailable) { return; @@ -161,10 +141,7 @@ export const addModelsLoadedListener = () => { log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`); selectAllControlNets(getState().controlAdapters).forEach((ca) => { - const isModelAvailable = some( - action.payload.entities, - (m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model - ); + const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key); if (isModelAvailable) { return; @@ -182,10 +159,7 @@ export const addModelsLoadedListener = () => { log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`); selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => { - const isModelAvailable = some( - action.payload.entities, - (m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model - ); + const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key); if (isModelAvailable) { return; @@ -203,10 +177,7 @@ export const addModelsLoadedListener = () => { log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`); selectAllIPAdapters(getState().controlAdapters).forEach((ca) => { - const isModelAvailable = some( - action.payload.entities, - (m) => m?.model_name === ca?.model?.model_name && m?.base_model === ca?.model?.base_model - ); 
+ const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key); if (isModelAvailable) { return; diff --git a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts index eb55db79ca8..875ce1f1c4c 100644 --- a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts @@ -5,10 +5,10 @@ import type { GroupBase } from 'chakra-react-select'; import { groupBy, map, reduce } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { AnyModelConfigEntity } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/endpoints/models'; import { getModelId } from 'services/api/endpoints/models'; -type UseGroupedModelComboboxArg = { +type UseGroupedModelComboboxArg = { modelEntities: EntityState | undefined; selectedModel?: Pick | null; onChange: (value: T | null) => void; @@ -24,7 +24,7 @@ type UseGroupedModelComboboxReturn = { noOptionsMessage: () => string; }; -export const useGroupedModelCombobox = ( +export const useGroupedModelCombobox = ( arg: UseGroupedModelComboboxArg ): UseGroupedModelComboboxReturn => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index baa704e75ca..b31efed970d 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -105,7 +105,7 @@ const selector = createMemoizedSelector( number: i + 1, }) ); - } else if (ca.model.base_model !== model?.base_model) { + } else if (ca.model.base !== model?.base) { // This should never happen, just a sanity check reasons.push( i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { diff --git a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts index 880b3163791..341fed1e47e 100644 --- a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts @@ -3,10 +3,10 @@ import type { EntityState } from '@reduxjs/toolkit'; import { map } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { AnyModelConfigEntity } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/endpoints/models'; import { getModelId } from 'services/api/endpoints/models'; -type UseModelComboboxArg = { +type UseModelComboboxArg = { modelEntities: EntityState | undefined; selectedModel?: Pick | null; onChange: (value: T | null) => void; @@ -23,7 +23,7 @@ type UseModelComboboxReturn = { noOptionsMessage: () => string; }; -export const useModelCombobox = ( +export const useModelCombobox = ( arg: UseModelComboboxArg ): UseModelComboboxReturn => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index cd734d3f00e..f50d52c1bff 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -626,7 +626,7 @@ export const canvasSlice = createSlice({ }, extraReducers: (builder) => { builder.addCase(modelChanged, (state, action) => { - if 
(action.meta.previousModel?.base_model === action.payload?.base_model) { + if (action.meta.previousModel?.base === action.payload?.base) { // The base model hasn't changed, we don't need to optimize the size return; } diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx index 13851b143c2..a3202384458 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx +++ b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx @@ -11,12 +11,7 @@ import { selectGenerationSlice } from 'features/parameters/store/generationSlice import { pick } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { - ControlNetModelConfigEntity, - IPAdapterModelConfigEntity, - T2IAdapterModelConfigEntity, -} from 'services/api/endpoints/models'; -import type { AnyModelConfig } from 'services/api/types'; +import type { AnyModelConfig, ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types'; type ParamControlAdapterModelProps = { id: string; @@ -29,21 +24,21 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { const controlAdapterType = useControlAdapterType(id); const model = useControlAdapterModel(id); const dispatch = useAppDispatch(); - const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model); + const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const mainModel = useAppSelector(selectMainModel); const { t } = useTranslation(); const models = useControlAdapterModelEntities(controlAdapterType); const _onChange = useCallback( - (model: ControlNetModelConfigEntity | IPAdapterModelConfigEntity | T2IAdapterModelConfigEntity | null) => { + (model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => { if (!model) { return; } dispatch( controlAdapterModelChanged({ id, - model: pick(model, 'base_model', 'model_name'), + model: pick(model, 'base', 'key'), }) ); }, @@ -57,7 +52,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { const getIsDisabled = useCallback( (model: AnyModelConfig): boolean => { - const isCompatible = currentBaseModel === model.base_model; + const isCompatible = currentBaseModel === model.base; const hasMainModel = Boolean(currentBaseModel); return !hasMainModel || !isCompatible; }, @@ -73,7 +68,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { return ( - + { - const baseModel = useAppSelector((s) => s.generation.model?.base_model); + const baseModel = useAppSelector((s) => s.generation.model?.base); const dispatch = useAppDispatch(); const models = useControlAdapterModels(type); const firstModel = useMemo(() => { // prefer to use a model that matches the base model - const firstCompatibleModel = models.filter((m) => (baseModel ? m.base_model === baseModel : true))[0]; + const firstCompatibleModel = models.filter((m) => (baseModel ? 
m.base === baseModel : true))[0]; if (firstCompatibleModel) { return firstCompatibleModel; diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts index 49b07f16a12..fce94ad0196 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts @@ -236,7 +236,8 @@ export const controlAdaptersSlice = createSlice({ let processorType: ControlAdapterProcessorType | undefined = undefined; for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) { - if (model.model_name.includes(modelSubstring)) { + // TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType + if (model.key.includes(modelSubstring)) { processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring]; break; } @@ -359,7 +360,8 @@ export const controlAdaptersSlice = createSlice({ let processorType: ControlAdapterProcessorType | undefined = undefined; for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) { - if (cn.model?.model_name.includes(modelSubstring)) { + // TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType + if (cn.model?.key.includes(modelSubstring)) { processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring]; break; } diff --git a/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx b/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx index ffe9d63360a..426ddd21e23 100644 --- a/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx +++ b/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx @@ -6,18 +6,18 @@ import type { EmbeddingSelectProps } from 'features/embedding/types'; import { t } from 'i18next'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import type { TextualInversionModelConfigEntity } from 'services/api/endpoints/models'; import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; +import type { TextualInversionConfig } from 'services/api/types'; const noOptionsMessage = () => t('embedding.noMatchingEmbedding'); export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps) => { const { t } = useTranslation(); - const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model); + const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const getIsDisabled = useCallback( - (embedding: TextualInversionModelConfigEntity): boolean => { + (embedding: TextualInversionConfig): boolean => { const isCompatible = currentBaseModel === embedding.base_model; const hasMainModel = Boolean(currentBaseModel); return !hasMainModel || !isCompatible; @@ -27,7 +27,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps const { data, isLoading } = useGetTextualInversionModelsQuery(); const _onChange = useCallback( - (embedding: TextualInversionModelConfigEntity | null) => { + (embedding: TextualInversionConfig | null) => { if (!embedding) { return; } diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index e9a1461186d..5907ba07000 100644 --- 
a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -208,8 +208,8 @@ const ImageMetadataActions = (props: Props) => { {metadata.seed !== undefined && metadata.seed !== null && ( )} - {metadata.model !== undefined && metadata.model !== null && metadata.model.model_name && ( - + {metadata.model !== undefined && metadata.model !== null && metadata.model.key && ( + )} {metadata.width && ( @@ -222,7 +222,7 @@ const ImageMetadataActions = (props: Props) => { )} {metadata.steps && ( @@ -269,7 +269,7 @@ const ImageMetadataActions = (props: Props) => { ); @@ -279,7 +279,7 @@ const ImageMetadataActions = (props: Props) => { ))} @@ -287,7 +287,7 @@ const ImageMetadataActions = (props: Props) => { ))} @@ -295,7 +295,7 @@ const ImageMetadataActions = (props: Props) => { ))} diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx index 28bd8afe95d..81e0027b2d1 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx @@ -44,7 +44,7 @@ export const LoRACard = memo((props: LoRACardProps) => { - {lora.model_name} + {lora.key} diff --git a/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx b/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx index 9f37454d163..7bcd5378055 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRAList.tsx @@ -18,7 +18,7 @@ export const LoRAList = memo(() => { return ( {lorasArray.map((lora) => ( - + ))} ); diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index ed70a4d44a1..069c557aefa 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -7,7 +7,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras); @@ -19,7 +19,7 @@ const LoRASelect = () => { const addedLoRAs = useAppSelector(selectAddedLoRAs); const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model); - const getIsDisabled = (lora: LoRAModelConfigEntity): boolean => { + const getIsDisabled = (lora: LoRAConfig): boolean => { const isCompatible = currentBaseModel === lora.base_model; const isAdded = Boolean(addedLoRAs[lora.id]); const hasMainModel = Boolean(currentBaseModel); @@ -27,7 +27,7 @@ const LoRASelect = () => { }; const _onChange = useCallback( - (lora: LoRAModelConfigEntity | null) => { + (lora: LoRAConfig | null) => { if (!lora) { return; } diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index ab1b140a7cc..dd455e12c39 100644 --- 
a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -2,10 +2,9 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; -import type { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/types'; export type LoRA = ParameterLoRAModel & { - id: string; weight: number; isEnabled?: boolean; }; @@ -29,40 +28,40 @@ export const loraSlice = createSlice({ name: 'lora', initialState: initialLoraState, reducers: { - loraAdded: (state, action: PayloadAction) => { - const { model_name, id, base_model } = action.payload; - state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig }; + loraAdded: (state, action: PayloadAction) => { + const { key, base } = action.payload; + state.loras[key] = { key, base, ...defaultLoRAConfig }; }, - loraRecalled: (state, action: PayloadAction) => { - const { model_name, id, base_model, weight } = action.payload; - state.loras[id] = { id, model_name, base_model, weight, isEnabled: true }; + loraRecalled: (state, action: PayloadAction) => { + const { key, base, weight } = action.payload; + state.loras[key] = { key, base, weight, isEnabled: true }; }, loraRemoved: (state, action: PayloadAction) => { - const id = action.payload; - delete state.loras[id]; + const key = action.payload; + delete state.loras[key]; }, lorasCleared: (state) => { state.loras = {}; }, - loraWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => { - const { id, weight } = action.payload; - const lora = state.loras[id]; + loraWeightChanged: (state, action: PayloadAction<{ key: string; weight: number }>) => { + const { key, weight } = action.payload; + const lora = state.loras[key]; if (!lora) { return; } lora.weight = weight; }, loraWeightReset: (state, action: PayloadAction) => { - const id = action.payload; - const lora = state.loras[id]; + const key = action.payload; + const lora = state.loras[key]; if (!lora) { return; } lora.weight = defaultLoRAConfig.weight; }, - loraIsEnabledChanged: (state, action: PayloadAction>) => { - const { id, isEnabled } = action.payload; - const lora = state.loras[id]; + loraIsEnabledChanged: (state, action: PayloadAction>) => { + const { key, isEnabled } = action.payload; + const lora = state.loras[key]; if (!lora) { return; } diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx index 6b9abdbfec7..7501151ba49 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx @@ -3,9 +3,9 @@ import { memo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { ALL_BASE_MODELS } from 'services/api/constants'; import type { - DiffusersModelConfigEntity, - LoRAModelConfigEntity, - MainModelConfigEntity, + DiffusersModelConfig, + LoRAConfig, + MainModelConfig, } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models'; @@ -38,7 +38,7 @@ const ModelManagerPanel = () => { }; type ModelEditProps = { - model: MainModelConfigEntity | LoRAModelConfigEntity | undefined; + model: 
MainModelConfig | LoRAConfig | undefined; }; const ModelEdit = (props: ModelEditProps) => { @@ -50,7 +50,7 @@ const ModelEdit = (props: ModelEditProps) => { } if (model?.model_format === 'diffusers') { - return ; + return ; } if (model?.model_type === 'lora') { diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index f4d271187db..43707308e0e 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -21,14 +21,14 @@ import { memo, useCallback, useEffect, useState } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { CheckpointModelConfigEntity } from 'services/api/endpoints/models'; +import type { CheckpointModelConfig } from 'services/api/endpoints/models'; import { useGetCheckpointConfigsQuery, useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import type { CheckpointModelConfig } from 'services/api/types'; import ModelConvert from './ModelConvert'; type CheckpointModelEditProps = { - model: CheckpointModelConfigEntity; + model: CheckpointModelConfig; }; const CheckpointModelEdit = (props: CheckpointModelEditProps) => { diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index 4670f321570..bf6349234f5 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -9,12 +9,12 @@ import { memo, useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { DiffusersModelConfigEntity } from 'services/api/endpoints/models'; +import type { DiffusersModelConfig } from 'services/api/endpoints/models'; import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import type { DiffusersModelConfig } from 'services/api/types'; type DiffusersModelEditProps = { - model: DiffusersModelConfigEntity; + model: DiffusersModelConfig; }; const DiffusersModelEdit = (props: DiffusersModelEditProps) => { diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx index 2baf735beeb..75151cd0012 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx @@ -8,12 +8,12 @@ import { memo, useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/endpoints/models'; import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models'; -import type { LoRAModelConfig } from 
'services/api/types'; +import type { LoRAConfig } from 'services/api/types'; type LoRAModelEditProps = { - model: LoRAModelConfigEntity; + model: LoRAConfig; }; const LoRAModelEdit = (props: LoRAModelEditProps) => { @@ -30,7 +30,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { control, formState: { errors }, reset, - } = useForm({ + } = useForm({ defaultValues: { model_name: model.model_name ? model.model_name : '', base_model: model.base_model, @@ -42,7 +42,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { mode: 'onChange', }); - const onSubmit = useCallback>( + const onSubmit = useCallback>( (values) => { const responseBody = { base_model: model.base_model, @@ -53,7 +53,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { updateLoRAModel(responseBody) .unwrap() .then((payload) => { - reset(payload as LoRAModelConfig, { keepDefaultValues: true }); + reset(payload as LoRAConfig, { keepDefaultValues: true }); dispatch( addToast( makeToast({ @@ -106,7 +106,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { {t('modelManager.description')} - control={control} name="base_model" /> + control={control} name="base_model" /> {t('modelManager.modelLocation')} diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx index 94db3d20c33..dd74bb0c23f 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -5,7 +5,7 @@ import type { ChangeEvent, PropsWithChildren } from 'react'; import { memo, useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { ALL_BASE_MODELS } from 'services/api/constants'; -import type { LoRAModelConfigEntity, MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models'; import ModelListItem from './ModelListItem'; @@ -127,7 +127,7 @@ const ModelList = (props: ModelListProps) => { export default memo(ModelList); -const modelsFilter = ( +const modelsFilter = ( data: EntityState | undefined, model_type: ModelType, model_format: ModelFormat | undefined, @@ -163,7 +163,7 @@ StyledModelContainer.displayName = 'StyledModelContainer'; type ModelListWrapperProps = { title: string; - modelList: MainModelConfigEntity[] | LoRAModelConfigEntity[]; + modelList: MainModelConfig[] | LoRAConfig[]; selected: ModelListProps; }; diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index fdd13e09f55..835499d25ad 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -15,11 +15,11 @@ import { makeToast } from 'features/system/util/makeToast'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiTrashSimpleBold } from 'react-icons/pi'; -import type { LoRAModelConfigEntity, MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models'; import { 
useDeleteLoRAModelsMutation, useDeleteMainModelsMutation } from 'services/api/endpoints/models'; type ModelListItemProps = { - model: MainModelConfigEntity | LoRAModelConfigEntity; + model: MainModelConfig | LoRAConfig; isSelected: boolean; setSelectedModelId: (v: string | undefined) => void; }; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx index 22024f3d3ce..53d800e7b62 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { ControlNetModelConfigEntity } from 'services/api/endpoints/models'; +import type { ControlNetConfig } from 'services/api/endpoints/models'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -17,7 +17,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => { const { data, isLoading } = useGetControlNetModelsQuery(); const _onChange = useCallback( - (value: ControlNetModelConfigEntity | null) => { + (value: ControlNetConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx index 2cde3472475..3f195ceb32c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { IPAdapterModelConfigEntity } from 'services/api/endpoints/models'; +import type { IPAdapterConfig } from 'services/api/endpoints/models'; import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -17,7 +17,7 @@ const IPAdapterModelFieldInputComponent = ( const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(); const _onChange = useCallback( - (value: IPAdapterModelConfigEntity | null) => { + (value: IPAdapterConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx index 96208d68d44..eeb07fa08e2 100644 --- 
a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { LoRAModelConfigEntity } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -16,7 +16,7 @@ const LoRAModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetLoRAModelsQuery(); const _onChange = useCallback( - (value: LoRAModelConfigEntity | null) => { + (value: LoRAConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx index 64c6970cae3..7ddde08816c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; import { NON_SDXL_MAIN_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const MainModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetMainModelsQuery(NON_SDXL_MAIN_MODELS); const _onChange = useCallback( - (value: MainModelConfigEntity | null) => { + (value: MainModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx index 98901af38b2..9b5a1138d4a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx @@ -9,7 +9,7 @@ import type { } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; import { REFINER_BASE_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import type { 
FieldComponentProps } from './types'; @@ -21,7 +21,7 @@ const RefinerModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS); const _onChange = useCallback( - (value: MainModelConfigEntity | null) => { + (value: MainModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx index f5bc7ac3e40..cf353619e8f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; import { SDXL_MAIN_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetMainModelsQuery(SDXL_MAIN_MODELS); const _onChange = useCallback( - (value: MainModelConfigEntity | null) => { + (value: MainModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx index 9baf0d2d61e..8402c56343a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx @@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { T2IAdapterModelConfigEntity } from 'services/api/endpoints/models'; +import type { T2IAdapterConfig } from 'services/api/endpoints/models'; import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const T2IAdapterModelFieldInputComponent = ( const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(); const _onChange = useCallback( - (value: T2IAdapterModelConfigEntity | null) => { + (value: T2IAdapterConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx 
b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx index 070178f32a7..af09f2d8f20 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx @@ -5,7 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManager/components/SyncModel import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; -import type { VaeModelConfigEntity } from 'services/api/endpoints/models'; +import type { VAEConfig } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import type { FieldComponentProps } from './types'; @@ -17,7 +17,7 @@ const VAEModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetVaeModelsQuery(); const _onChange = useCallback( - (value: VaeModelConfigEntity | null) => { + (value: VAEConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index ef579fce8cf..891bd29bc86 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -67,11 +67,13 @@ export const zModelName = z.string().min(3); export const zModelIdentifier = z.object({ key: z.string().min(1), }); +export const zModelFieldBase = zModelIdentifier; +export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel }); export type BaseModel = z.infer; export type ModelType = z.infer; export type ModelIdentifier = z.infer; - -export const zMainModelField = zModelIdentifier; +export type ModelIdentifierWithBase = z.infer; +export const zMainModelField = zModelFieldBase; export type MainModelField = z.infer; export const zSDXLRefinerModelField = zModelIdentifier; @@ -91,23 +93,23 @@ export const zSubModelType = z.enum([ ]); export type SubModelType = z.infer; -export const zVAEModelField = zModelIdentifier; +export const zVAEModelField = zModelFieldBase; export const zModelInfo = zModelIdentifier.extend({ submodel_type: zSubModelType.nullish(), }); export type ModelInfo = z.infer; -export const zLoRAModelField = zModelIdentifier; +export const zLoRAModelField = zModelFieldBase; export type LoRAModelField = z.infer; -export const zControlNetModelField = zModelIdentifier; +export const zControlNetModelField = zModelFieldBase; export type ControlNetModelField = z.infer; -export const zIPAdapterModelField = zModelIdentifier; +export const zIPAdapterModelField = zModelFieldBase; export type IPAdapterModelField = z.infer; -export const zT2IAdapterModelField = zModelIdentifier; +export const zT2IAdapterModelField = zModelFieldBase; export type T2IAdapterModelField = z.infer; export const zLoraInfo = zModelInfo.extend({ diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts index d862b0986ea..1853d3722cd 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts +++ 
b/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts @@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata'; export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => { const validControlNets = selectValidControlNets(state.controlAdapters).filter( - (ca) => ca.model?.base_model === state.generation.model?.base_model + (ca) => ca.model?.base === state.generation.model?.base ); // const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts index 3a79b78c6ed..b51ac1bd52a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts @@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata'; export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => { const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter( - (ca) => ca.model?.base_model === state.generation.model?.base_model + (ca) => ca.model?.base === state.generation.model?.base ); if (validIPAdapters.length) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts index 3ed71b7529a..95bba9b4410 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts @@ -28,6 +28,7 @@ export const addLoRAsToGraph = ( * So we need to inject a LoRA chain into the graph. */ + // TODO(MM2): check base model const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? false); const loraCount = size(enabledLoRAs); @@ -48,19 +49,19 @@ export const addLoRAsToGraph = ( const loraMetadata: CoreMetadataInvocation['loras'] = []; enabledLoRAs.forEach((lora) => { - const { model_name, base_model, weight } = lora; - const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`; + const { key, weight } = lora; + const currentLoraNodeId = `${LORA_LOADER}_${key}`; const loraLoaderNode: LoraLoaderInvocation = { type: 'lora_loader', id: currentLoraNodeId, is_intermediate: true, - lora: { model_name, base_model }, + lora: { key }, weight, }; loraMetadata.push({ - lora: { model_name, base_model }, + lora: { key }, weight, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts index 95535689222..7874b059c91 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts @@ -31,6 +31,7 @@ export const addSDXLLoRAsToGraph = ( * So we need to inject a LoRA chain into the graph. */ + // TODO(MM2): check base model const enabledLoRAs = filter(state.lora.loras, (l) => l.isEnabled ?? 
false); const loraCount = size(enabledLoRAs); @@ -60,20 +61,20 @@ export const addSDXLLoRAsToGraph = ( let currentLoraIndex = 0; enabledLoRAs.forEach((lora) => { - const { model_name, base_model, weight } = lora; - const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`; + const { key, weight } = lora; + const currentLoraNodeId = `${LORA_LOADER}_${key}`; const loraLoaderNode: SDXLLoraLoaderInvocation = { type: 'sdxl_lora_loader', id: currentLoraNodeId, is_intermediate: true, - lora: { model_name, base_model }, + lora: { key }, weight, }; loraMetadata.push( zLoRAMetadataItem.parse({ - lora: { model_name, base_model }, + lora: { key }, weight, }) ); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts index d35f72a2b46..84002337d78 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts @@ -14,7 +14,7 @@ import { upsertMetadata } from './metadata'; export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => { const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter( - (ca) => ca.model?.base_model === state.generation.model?.base_model + (ca) => ca.model?.base === state.generation.model?.base ); if (validT2IAdapters.length) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts index 2b64f4898b1..4ce2e4d6733 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts @@ -19,7 +19,7 @@ export const buildCanvasGraph = ( let graph: NonNullableGraph; if (generationMode === 'txt2img') { - if (state.generation.model && state.generation.model.base_model === 'sdxl') { + if (state.generation.model && state.generation.model.base === 'sdxl') { graph = buildCanvasSDXLTextToImageGraph(state); } else { graph = buildCanvasTextToImageGraph(state); @@ -28,7 +28,7 @@ export const buildCanvasGraph = ( if (!canvasInitImage) { throw new Error('Missing canvas init image'); } - if (state.generation.model && state.generation.model.base_model === 'sdxl') { + if (state.generation.model && state.generation.model.base === 'sdxl') { graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage); } else { graph = buildCanvasImageToImageGraph(state, canvasInitImage); @@ -37,7 +37,7 @@ export const buildCanvasGraph = ( if (!canvasInitImage || !canvasMaskImage) { throw new Error('Missing canvas init and mask images'); } - if (state.generation.model && state.generation.model.base_model === 'sdxl') { + if (state.generation.model && state.generation.model.base === 'sdxl') { graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage); } else { graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage); @@ -46,7 +46,7 @@ export const buildCanvasGraph = ( if (!canvasInitImage) { throw new Error('Missing canvas init image'); } - if (state.generation.model && state.generation.model.base_model === 'sdxl') { + if (state.generation.model && state.generation.model.base === 'sdxl') { graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage); } else { graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage); diff --git 
a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts index d0e331fb469..9fcc6afaa07 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts @@ -105,7 +105,7 @@ export const prepareLinearUIBatch = (state: RootState, graph: NonNullableGraph, }); } - if (shouldConcatSDXLStylePrompt && model?.base_model === 'sdxl') { + if (shouldConcatSDXLStylePrompt && model?.base === 'sdxl') { if (graph.nodes[POSITIVE_CONDITIONING]) { firstBatchDatumList.push({ node_path: POSITIVE_CONDITIONING, diff --git a/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx b/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx index 621ed56ef6f..c23d5416137 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Advanced/ParamClipSkip.tsx @@ -29,17 +29,17 @@ const ParamClipSkip = () => { if (!model) { return CLIP_SKIP_MAP['sd-1'].maxClip; } - return CLIP_SKIP_MAP[model.base_model].maxClip; + return CLIP_SKIP_MAP[model.base].maxClip; }, [model]); const sliderMarks = useMemo(() => { if (!model) { return CLIP_SKIP_MAP['sd-1'].markers; } - return CLIP_SKIP_MAP[model.base_model].markers; + return CLIP_SKIP_MAP[model.base].markers; }, [model]); - if (model?.base_model === 'sdxl') { + if (model?.base === 'sdxl') { return null; } diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx index ae81f78fd10..a1852bfafee 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx @@ -15,7 +15,7 @@ import { useTranslation } from 'react-i18next'; export const ParamPositivePrompt = memo(() => { const dispatch = useAppDispatch(); const prompt = useAppSelector((s) => s.generation.positivePrompt); - const baseModel = useAppSelector((s) => s.generation.model)?.base_model; + const baseModel = useAppSelector((s) => s.generation.model)?.base; const textareaRef = useRef(null); const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx index c6c77b5fe95..18f780bdeeb 100644 --- a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx @@ -9,7 +9,7 @@ import { pick } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { getModelId, mainModelsAdapterSelectors, useGetMainModelsQuery } from 'services/api/endpoints/models'; const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model); @@ -26,7 +26,7 @@ const ParamMainModelSelect = () => { return mainModelsAdapterSelectors.selectById(data, 
getModelId(model))?.description; }, [data, model]); const _onChange = useCallback( - (model: MainModelConfigEntity | null) => { + (model: MainModelConfig | null) => { if (!model) { return; } diff --git a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx index f290378aa82..cc0164153d8 100644 --- a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx @@ -7,7 +7,7 @@ import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/ge import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import type { VaeModelConfigEntity } from 'services/api/endpoints/models'; +import type { VAEConfig } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; const selector = createMemoizedSelector(selectGenerationSlice, (generation) => { @@ -21,7 +21,7 @@ const ParamVAEModelSelect = () => { const { model, vae } = useAppSelector(selector); const { data, isLoading } = useGetVaeModelsQuery(); const getIsDisabled = useCallback( - (vae: VaeModelConfigEntity): boolean => { + (vae: VAEConfig): boolean => { const isCompatible = model?.base_model === vae.base_model; const hasMainModel = Boolean(model?.base_model); return !hasMainModel || !isCompatible; @@ -29,7 +29,7 @@ const ParamVAEModelSelect = () => { [model?.base_model] ); const _onChange = useCallback( - (vae: VaeModelConfigEntity | null) => { + (vae: VAEConfig | null) => { dispatch(vaeSelected(vae ? pick(vae, 'base_model', 'model_name') : null)); }, [dispatch] diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 5a9fa6c66d9..c8b17816bb5 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -464,17 +464,15 @@ export const useRecallParameters = () => { return { lora: null, error: 'Invalid LoRA model' }; } - const { base_model, model_name } = loraMetadataItem.lora; + const { lora } = loraMetadataItem; - const matchingLoRA = loraModels - ? loraModelsAdapterSelectors.selectById(loraModels, `${base_model}/lora/${model_name}`) - : undefined; + const matchingLoRA = loraModels ? loraModelsAdapterSelectors.selectById(loraModels, lora.key) : undefined; if (!matchingLoRA) { return { lora: null, error: 'LoRA model is not installed' }; } - const isCompatibleBaseModel = matchingLoRA?.base_model === (newModel ?? model)?.base_model; + const isCompatibleBaseModel = matchingLoRA?.base === (newModel ?? model)?.base; if (!isCompatibleBaseModel) { return { @@ -520,17 +518,14 @@ export const useRecallParameters = () => { controlnetMetadataItem; const matchingControlNetModel = controlNetModels - ? controlNetModelsAdapterSelectors.selectById( - controlNetModels, - `${control_model.base_model}/controlnet/${control_model.model_name}` - ) + ? controlNetModelsAdapterSelectors.selectById(controlNetModels, control_model.key) : undefined; if (!matchingControlNetModel) { return { controlnet: null, error: 'ControlNet model is not installed' }; } - const isCompatibleBaseModel = matchingControlNetModel?.base_model === (newModel ?? 
model)?.base_model; + const isCompatibleBaseModel = matchingControlNetModel?.base === (newModel ?? model)?.base; if (!isCompatibleBaseModel) { return { @@ -597,17 +592,14 @@ export const useRecallParameters = () => { t2iAdapterMetadataItem; const matchingT2IAdapterModel = t2iAdapterModels - ? t2iAdapterModelsAdapterSelectors.selectById( - t2iAdapterModels, - `${t2i_adapter_model.base_model}/t2i_adapter/${t2i_adapter_model.model_name}` - ) + ? t2iAdapterModelsAdapterSelectors.selectById(t2iAdapterModels, t2i_adapter_model.key) : undefined; if (!matchingT2IAdapterModel) { return { controlnet: null, error: 'ControlNet model is not installed' }; } - const isCompatibleBaseModel = matchingT2IAdapterModel?.base_model === (newModel ?? model)?.base_model; + const isCompatibleBaseModel = matchingT2IAdapterModel?.base === (newModel ?? model)?.base; if (!isCompatibleBaseModel) { return { @@ -672,17 +664,14 @@ export const useRecallParameters = () => { const { image, ip_adapter_model, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem; const matchingIPAdapterModel = ipAdapterModels - ? ipAdapterModelsAdapterSelectors.selectById( - ipAdapterModels, - `${ip_adapter_model.base_model}/ip_adapter/${ip_adapter_model.model_name}` - ) + ? ipAdapterModelsAdapterSelectors.selectById(ipAdapterModels, ip_adapter_model.key) : undefined; if (!matchingIPAdapterModel) { return { ipAdapter: null, error: 'IP Adapter model is not installed' }; } - const isCompatibleBaseModel = matchingIPAdapterModel?.base_model === (newModel ?? model)?.base_model; + const isCompatibleBaseModel = matchingIPAdapterModel?.base === (newModel ?? model)?.base; if (!isCompatibleBaseModel) { return { diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index df98943cd3d..1666a34d6ab 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -158,15 +158,15 @@ export const generationSlice = createSlice({ // Clamp ClipSkip Based On Selected Model // TODO(psyche): remove this special handling when https://github.com/invoke-ai/InvokeAI/issues/4583 is resolved // WIP PR here: https://github.com/invoke-ai/InvokeAI/pull/4624 - if (newModel.base_model === 'sdxl') { + if (newModel.base === 'sdxl') { // We don't support clip skip for SDXL yet - it's not in the graphs state.clipSkip = 0; } else { - const { maxClip } = CLIP_SKIP_MAP[newModel.base_model]; + const { maxClip } = CLIP_SKIP_MAP[newModel.base]; state.clipSkip = clamp(state.clipSkip, 0, maxClip); } - if (action.meta.previousModel?.base_model === newModel.base_model) { + if (action.meta.previousModel?.base === newModel.base) { // The base model hasn't changed, we don't need to optimize the size return; } diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index 7a5efe9dcfc..abd8ee28103 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -1,5 +1,6 @@ import { NUMPY_RAND_MAX } from 'app/constants'; import { + zBaseModel, zControlNetModelField, zIPAdapterModelField, zLoRAModelField, @@ -104,48 +105,48 @@ export const isParameterAspectRatio = (val: unknown): val is ParameterAspectRati // #endregion // #region Model -export const zParameterModel = 
zMainModelField; +export const zParameterModel = zMainModelField.extend({ base: zBaseModel }); export type ParameterModel = z.infer; export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success; // #endregion // #region SDXL Refiner Model -export const zParameterSDXLRefinerModel = zSDXLRefinerModelField; +export const zParameterSDXLRefinerModel = zSDXLRefinerModelField.extend({ base: zBaseModel }); export type ParameterSDXLRefinerModel = z.infer; export const isParameterSDXLRefinerModel = (val: unknown): val is ParameterSDXLRefinerModel => zParameterSDXLRefinerModel.safeParse(val).success; // #endregion // #region VAE Model -export const zParameterVAEModel = zVAEModelField; +export const zParameterVAEModel = zVAEModelField.extend({ base: zBaseModel }); export type ParameterVAEModel = z.infer; export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel => zParameterVAEModel.safeParse(val).success; // #endregion // #region LoRA Model -export const zParameterLoRAModel = zLoRAModelField; +export const zParameterLoRAModel = zLoRAModelField.extend({ base: zBaseModel }); export type ParameterLoRAModel = z.infer; export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel => zParameterLoRAModel.safeParse(val).success; // #endregion // #region ControlNet Model -export const zParameterControlNetModel = zControlNetModelField; +export const zParameterControlNetModel = zControlNetModelField.extend({ base: zBaseModel }); export type ParameterControlNetModel = z.infer; export const isParameterControlNetModel = (val: unknown): val is ParameterControlNetModel => zParameterControlNetModel.safeParse(val).success; // #endregion // #region IP Adapter Model -export const zParameterIPAdapterModel = zIPAdapterModelField; +export const zParameterIPAdapterModel = zIPAdapterModelField.extend({ base: zBaseModel }); export type ParameterIPAdapterModel = z.infer; export const isParameterIPAdapterModel = (val: unknown): val is ParameterIPAdapterModel => zParameterIPAdapterModel.safeParse(val).success; // #endregion // #region T2I Adapter Model -export const zParameterT2IAdapterModel = zT2IAdapterModelField; +export const zParameterT2IAdapterModel = zT2IAdapterModelField.extend({ base: zBaseModel }); export type ParameterT2IAdapterModel = z.infer; export const isParameterT2IAdapterModel = (val: unknown): val is ParameterT2IAdapterModel => zParameterT2IAdapterModel.safeParse(val).success; diff --git a/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts b/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts index 1c550eb8a40..92b4f182727 100644 --- a/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts +++ b/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts @@ -1,12 +1,12 @@ -import type { ModelIdentifier } from 'features/nodes/types/common'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; /** * Gets the optimal dimension for a givel model, based on the model's base_model * @param model The model identifier * @returns The optimal dimension for the model */ -export const getOptimalDimension = (model?: ModelIdentifier | null): number => - model?.base_model === 'sdxl' ? 1024 : 512; +export const getOptimalDimension = (model?: ModelIdentifierWithBase | null): number => + model?.base === 'sdxl' ? 
1024 : 512; const MIN_AREA_FACTOR = 0.8; const MAX_AREA_FACTOR = 1.2; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx index 5559ec76b7e..4c542515573 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx @@ -7,12 +7,12 @@ import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSl import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { REFINER_BASE_MODELS } from 'services/api/constants'; -import type { MainModelConfigEntity } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel); -const optionsFilter = (model: MainModelConfigEntity) => model.base_model === 'sdxl-refiner'; +const optionsFilter = (model: MainModelConfig) => model.base_model === 'sdxl-refiner'; const ParamSDXLRefinerModelSelect = () => { const dispatch = useAppDispatch(); @@ -20,7 +20,7 @@ const ParamSDXLRefinerModelSelect = () => { const { t } = useTranslation(); const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS); const _onChange = useCallback( - (model: MainModelConfigEntity | null) => { + (model: MainModelConfig | null) => { if (!model) { dispatch(refinerModelChanged(null)); return; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx index bceee915cde..fc8c54576c5 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx @@ -24,7 +24,8 @@ const formLabelProps2: FormLabelProps = { const selectBadges = createMemoizedSelector(selectGenerationSlice, (generation) => { const badges: (string | number)[] = []; if (generation.vae) { - let vaeBadge = generation.vae.model_name; + // TODO(MM2): Fetch the vae name + let vaeBadge = generation.vae.key; if (generation.vaePrecision === 'fp16') { vaeBadge += ` ${generation.vaePrecision}`; } diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx index 077875a8a76..cda7dcf6e92 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx @@ -35,9 +35,10 @@ const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationS const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length; const loraTabBadges = enabledLoRAsCount ? 
[enabledLoRAsCount] : []; const accordionBadges: (string | number)[] = []; + // TODO(MM2): fetch model name if (generation.model) { - accordionBadges.push(generation.model.model_name); - accordionBadges.push(generation.model.base_model); + accordionBadges.push(generation.model.key); + accordionBadges.push(generation.model.base); } return { loraTabBadges, accordionBadges }; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx index 2f778fe717c..8f876850e86 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/ImageSettingsAccordion/ImageSettingsAccordion.tsx @@ -56,7 +56,7 @@ const selector = createMemoizedSelector( if (hrfEnabled) { badges.push('HiRes Fix'); } - return { badges, activeTabName, isSDXL: model?.base_model === 'sdxl' }; + return { badges, activeTabName, isSDXL: model?.base === 'sdxl' }; } ); diff --git a/invokeai/frontend/web/src/features/ui/components/ParametersPanel.tsx b/invokeai/frontend/web/src/features/ui/components/ParametersPanel.tsx index d52b2d9000e..a74d132bd60 100644 --- a/invokeai/frontend/web/src/features/ui/components/ParametersPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/ParametersPanel.tsx @@ -22,7 +22,7 @@ const overlayScrollbarsStyles: CSSProperties = { const ParametersPanel = () => { const activeTabName = useAppSelector(activeTabNameSelector); - const isSDXL = useAppSelector((s) => s.generation.model?.base_model === 'sdxl'); + const isSDXL = useAppSelector((s) => s.generation.model?.base === 'sdxl'); return ( diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index c11c8b45e59..97e221454d5 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,64 +1,26 @@ -import type { EntityState } from '@reduxjs/toolkit'; +import type { EntityAdapter, EntityState } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; -import { cloneDeep } from 'lodash-es'; import queryString from 'query-string'; import type { operations, paths } from 'services/api/schema'; import type { AnyModelConfig, BaseModelType, - CheckpointModelConfig, - ControlNetModelConfig, - DiffusersModelConfig, + ControlNetConfig, ImportModelConfig, - IPAdapterModelConfig, - LoRAModelConfig, + IPAdapterConfig, + LoRAConfig, MainModelConfig, MergeModelConfig, ModelType, - T2IAdapterModelConfig, - TextualInversionModelConfig, - VaeModelConfig, + T2IAdapterConfig, + TextualInversionConfig, + VAEConfig, } from 'services/api/types'; -import type { ApiTagDescription } from '..'; +import type { ApiTagDescription, tagTypes } from '..'; import { api, LIST_TAG } from '..'; -export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string }; -export type CheckpointModelConfigEntity = CheckpointModelConfig & { - id: string; -}; -export type MainModelConfigEntity = DiffusersModelConfigEntity | CheckpointModelConfigEntity; - -export type LoRAModelConfigEntity = LoRAModelConfig & { id: string }; - -export type ControlNetModelConfigEntity = ControlNetModelConfig & { - id: string; -}; - -export type 
IPAdapterModelConfigEntity = IPAdapterModelConfig & { - id: string; -}; - -export type T2IAdapterModelConfigEntity = T2IAdapterModelConfig & { - id: string; -}; - -export type TextualInversionModelConfigEntity = TextualInversionModelConfig & { - id: string; -}; - -export type VaeModelConfigEntity = VaeModelConfig & { id: string }; - -export type AnyModelConfigEntity = - | MainModelConfigEntity - | LoRAModelConfigEntity - | ControlNetModelConfigEntity - | IPAdapterModelConfigEntity - | T2IAdapterModelConfigEntity - | TextualInversionModelConfigEntity - | VaeModelConfigEntity; - type UpdateMainModelArg = { base_model: BaseModelType; model_name: string; @@ -68,11 +30,11 @@ type UpdateMainModelArg = { type UpdateLoRAModelArg = { base_model: BaseModelType; model_name: string; - body: LoRAModelConfig; + body: LoRAConfig; }; type UpdateMainModelResponse = - paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json']; + paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; type UpdateLoRAModelResponse = UpdateMainModelResponse; @@ -128,59 +90,71 @@ type CheckpointConfigsResponse = type SearchFolderArg = operations['search_for_models']['parameters']['query']; -export const mainModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const mainModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const loraModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const loraModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const controlNetModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const controlNetModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const ipAdapterModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const ipAdapterModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const t2iAdapterModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const t2iAdapterModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const textualInversionModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const textualInversionModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); 
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors( undefined, getSelectorsOptions ); -export const vaeModelsAdapter = createEntityAdapter({ - sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), +export const vaeModelsAdapter = createEntityAdapter({ + selectId: (entity) => entity.key, + sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const getModelId = ({ - base_model, - model_type, - model_name, -}: Pick) => `${base_model}/${model_type}/${model_name}`; - -const createModelEntities = (models: AnyModelConfig[]): T[] => { - const entityArray: T[] = []; - models.forEach((model) => { - const entity = { - ...cloneDeep(model), - id: getModelId(model), - } as T; - entityArray.push(entity); - }); - return entityArray; -}; +const buildProvidesTags = + (tagType: (typeof tagTypes)[number]) => + (result: EntityState | undefined) => { + const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model']; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: tagType, + id, + })) + ); + } + + return tags; + }; + +const buildTransformResponse = + (adapter: EntityAdapter) => + (response: { models: T[] }) => { + return adapter.setAll(adapter.getInitialState(), response.models); + }; export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ - getMainModels: build.query, BaseModelType[]>({ + getMainModels: build.query, BaseModelType[]>({ query: (base_models) => { const params = { model_type: 'main', @@ -190,24 +164,8 @@ export const modelsApi = api.injectEndpoints({ const query = queryString.stringify(params, { arrayFormat: 'none' }); return `models/?${query}`; }, - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'MainModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'MainModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: MainModelConfig[] }) => { - const entities = createModelEntities(response.models); - return mainModelsAdapter.setAll(mainModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('MainModel'), + transformResponse: buildTransformResponse(mainModelsAdapter), }), updateMainModels: build.mutation({ query: ({ base_model, model_name, body }) => { @@ -277,26 +235,10 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['Model'], }), - getLoRAModels: build.query, void>({ + getLoRAModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'lora' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'LoRAModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'LoRAModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: LoRAModelConfig[] }) => { - const entities = createModelEntities(response.models); - return loraModelsAdapter.setAll(loraModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('LoRAModel'), + transformResponse: buildTransformResponse(loraModelsAdapter), }), updateLoRAModels: build.mutation({ query: ({ base_model, model_name, body }) => { @@ -317,110 +259,30 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }], }), - getControlNetModels: build.query, void>({ + 
getControlNetModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'ControlNetModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'ControlNetModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: ControlNetModelConfig[] }) => { - const entities = createModelEntities(response.models); - return controlNetModelsAdapter.setAll(controlNetModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('ControlNetModel'), + transformResponse: buildTransformResponse(controlNetModelsAdapter), }), - getIPAdapterModels: build.query, void>({ + getIPAdapterModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'ip_adapter' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'IPAdapterModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'IPAdapterModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: IPAdapterModelConfig[] }) => { - const entities = createModelEntities(response.models); - return ipAdapterModelsAdapter.setAll(ipAdapterModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('IPAdapterModel'), + transformResponse: buildTransformResponse(ipAdapterModelsAdapter), }), - getT2IAdapterModels: build.query, void>({ + getT2IAdapterModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 't2i_adapter' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'T2IAdapterModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'T2IAdapterModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: T2IAdapterModelConfig[] }) => { - const entities = createModelEntities(response.models); - return t2iAdapterModelsAdapter.setAll(t2iAdapterModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('T2IAdapterModel'), + transformResponse: buildTransformResponse(t2iAdapterModelsAdapter), }), - getVaeModels: build.query, void>({ + getVaeModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'vae' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'VaeModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'VaeModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: VaeModelConfig[] }) => { - const entities = createModelEntities(response.models); - return vaeModelsAdapter.setAll(vaeModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('VaeModel'), + transformResponse: buildTransformResponse(vaeModelsAdapter), }), - getTextualInversionModels: build.query, void>({ + getTextualInversionModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'embedding' } }), - providesTags: (result) => { - const tags: ApiTagDescription[] = [{ type: 'TextualInversionModel', id: LIST_TAG }, 'Model']; - - if (result) { - tags.push( - ...result.ids.map((id) => ({ - type: 'TextualInversionModel' as const, - id, - })) - ); - } - - return tags; - }, - transformResponse: (response: { models: TextualInversionModelConfig[] }) => { - const entities = 
createModelEntities(response.models); - return textualInversionModelsAdapter.setAll(textualInversionModelsAdapter.getInitialState(), entities); - }, + providesTags: buildProvidesTags('TextualInversionModel'), + transformResponse: buildTransformResponse(textualInversionModelsAdapter), }), getModelsInFolder: build.query({ query: (arg) => { diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index f9a1decf655..7a02cc55681 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -2,6 +2,7 @@ import type { UseToastOptions } from '@invoke-ai/ui-library'; import type { EntityState } from '@reduxjs/toolkit'; import type { components, paths } from 'services/api/schema'; import type { O } from 'ts-toolbelt'; +import type { SetRequired } from 'type-fest'; export type S = components['schemas']; @@ -54,40 +55,34 @@ export type LoRAModelFormat = S['LoRAModelFormat']; export type ControlNetModelField = S['ControlNetModelField']; export type IPAdapterModelField = S['IPAdapterModelField']; export type T2IAdapterModelField = S['T2IAdapterModelField']; -export type ModelsList = S['invokeai__app__api__routers__models__ModelsList']; export type ControlField = S['ControlField']; export type IPAdapterField = S['IPAdapterField']; // Model Configs -export type LoRAModelConfig = S['LoRAModelConfig']; -export type VaeModelConfig = S['VaeModelConfig']; -export type ControlNetModelCheckpointConfig = S['ControlNetModelCheckpointConfig']; -export type ControlNetModelDiffusersConfig = S['ControlNetModelDiffusersConfig']; -export type ControlNetModelConfig = ControlNetModelCheckpointConfig | ControlNetModelDiffusersConfig; -export type IPAdapterModelInvokeAIConfig = S['IPAdapterModelInvokeAIConfig']; -export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig; -export type T2IAdapterModelDiffusersConfig = S['T2IAdapterModelDiffusersConfig']; -export type T2IAdapterModelConfig = T2IAdapterModelDiffusersConfig; -export type TextualInversionModelConfig = S['TextualInversionModelConfig']; -export type DiffusersModelConfig = - | S['StableDiffusion1ModelDiffusersConfig'] - | S['StableDiffusion2ModelDiffusersConfig'] - | S['StableDiffusionXLModelDiffusersConfig']; -export type CheckpointModelConfig = - | S['StableDiffusion1ModelCheckpointConfig'] - | S['StableDiffusion2ModelCheckpointConfig'] - | S['StableDiffusionXLModelCheckpointConfig']; + +// TODO(MM2): Can we make key required in the pydantic model? 
+type KeyRequired = SetRequired; +export type LoRAConfig = KeyRequired; +// TODO(MM2): Can we rename this from Vae -> VAE +export type VAEConfig = KeyRequired | KeyRequired; +export type ControlNetConfig = KeyRequired | KeyRequired; +export type IPAdapterConfig = KeyRequired; +// TODO(MM2): Can we rename this to T2IAdapterConfig +export type T2IAdapterConfig = KeyRequired; +export type TextualInversionConfig = KeyRequired; +export type DiffusersModelConfig = KeyRequired; +export type CheckpointModelConfig = KeyRequired; export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; export type AnyModelConfig = - | LoRAModelConfig - | VaeModelConfig - | ControlNetModelConfig - | IPAdapterModelConfig - | T2IAdapterModelConfig - | TextualInversionModelConfig + | LoRAConfig + | VAEConfig + | ControlNetConfig + | IPAdapterConfig + | T2IAdapterConfig + | TextualInversionConfig | MainModelConfig; -export type MergeModelConfig = S['Body_merge_models']; +export type MergeModelConfig = S['Body_merge']; export type ImportModelConfig = S['Body_import_model']; // Graphs From 389082023a5590f98babbec1b589c4886dfdc49d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 16 Feb 2024 21:57:30 +1100 Subject: [PATCH 103/340] refactor(ui): url builders for each router The MM2 router is at `api/v2/models`. URL builder utils make this a bit easier to manage. --- .../web/src/services/api/endpoints/appInfo.ts | 24 +++++--- .../web/src/services/api/endpoints/boards.ts | 20 +++++-- .../web/src/services/api/endpoints/images.ts | 56 ++++++++++++------- .../web/src/services/api/endpoints/models.ts | 53 +++++++++++------- .../web/src/services/api/endpoints/queue.ts | 40 +++++++------ .../src/services/api/endpoints/utilities.ts | 12 +++- .../src/services/api/endpoints/workflows.ts | 20 +++++-- .../frontend/web/src/services/api/index.ts | 5 +- .../frontend/web/src/services/api/schema.ts | 32 +++++------ .../frontend/web/src/services/api/types.ts | 6 +- .../frontend/web/src/services/api/util.ts | 3 +- invokeai/frontend/web/vite.config.mts | 6 +- 12 files changed, 176 insertions(+), 101 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts index c0916f568e8..a7efaafcc82 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts @@ -3,27 +3,35 @@ import type { OpenAPIV3_1 } from 'openapi-types'; import type { paths } from 'services/api/schema'; import type { AppConfig, AppDependencyVersions, AppVersion } from 'services/api/types'; -import { api } from '..'; +import { api, buildV1Url } from '..'; + +/** + * Builds an endpoint URL for the app router + * @example + * buildAppInfoUrl('some-path') + * // '/api/v1/app/some-path' + */ +const buildAppInfoUrl = (path: string = '') => buildV1Url(`app/${path}`); export const appInfoApi = api.injectEndpoints({ endpoints: (build) => ({ getAppVersion: build.query({ query: () => ({ - url: `app/version`, + url: buildAppInfoUrl('version'), method: 'GET', }), providesTags: ['FetchOnReconnect'], }), getAppDeps: build.query({ query: () => ({ - url: `app/app_deps`, + url: buildAppInfoUrl('app_deps'), method: 'GET', }), providesTags: ['FetchOnReconnect'], }), getAppConfig: build.query({ query: () => ({ - url: `app/config`, + url: buildAppInfoUrl('config'), method: 'GET', }), providesTags: ['FetchOnReconnect'], @@ -33,28 +41,28 @@ export const appInfoApi = 
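// Illustration only (not part of the patch): a sketch of the URL-builder pattern the
// commit message above describes. The buildV1Url/buildV2Url helpers are taken from the
// index.ts change later in this patch, and the per-router wrappers mirror the ones
// added in the endpoint files below.
const buildV1Url = (path: string): string => `api/v1/${path}`;
const buildV2Url = (path: string): string => `api/v2/${path}`;

// Each router gets a small wrapper so endpoints never hard-code URL prefixes:
const buildAppInfoUrl = (path: string = '') => buildV1Url(`app/${path}`);
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);

buildAppInfoUrl('version'); // 'api/v1/app/version'
buildModelsUrl('add'); // 'api/v2/models/add' (the MM2 models router lives under /api/v2)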
api.injectEndpoints({ void >({ query: () => ({ - url: `app/invocation_cache/status`, + url: buildAppInfoUrl('invocation_cache/status'), method: 'GET', }), providesTags: ['InvocationCacheStatus', 'FetchOnReconnect'], }), clearInvocationCache: build.mutation({ query: () => ({ - url: `app/invocation_cache`, + url: buildAppInfoUrl('invocation_cache'), method: 'DELETE', }), invalidatesTags: ['InvocationCacheStatus'], }), enableInvocationCache: build.mutation({ query: () => ({ - url: `app/invocation_cache/enable`, + url: buildAppInfoUrl('invocation_cache/enable'), method: 'PUT', }), invalidatesTags: ['InvocationCacheStatus'], }), disableInvocationCache: build.mutation({ query: () => ({ - url: `app/invocation_cache/disable`, + url: buildAppInfoUrl('invocation_cache/disable'), method: 'PUT', }), invalidatesTags: ['InvocationCacheStatus'], diff --git a/invokeai/frontend/web/src/services/api/endpoints/boards.ts b/invokeai/frontend/web/src/services/api/endpoints/boards.ts index 6977a2bd53a..8efda867373 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/boards.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/boards.ts @@ -9,7 +9,15 @@ import type { import { getListImagesUrl } from 'services/api/util'; import type { ApiTagDescription } from '..'; -import { api, LIST_TAG } from '..'; +import { api, buildV1Url, LIST_TAG } from '..'; + +/** + * Builds an endpoint URL for the boards router + * @example + * buildBoardsUrl('some-path') + * // '/api/v1/boards/some-path' + */ +export const buildBoardsUrl = (path: string = '') => buildV1Url(`boards/${path}`); export const boardsApi = api.injectEndpoints({ endpoints: (build) => ({ @@ -17,7 +25,7 @@ export const boardsApi = api.injectEndpoints({ * Boards Queries */ listBoards: build.query({ - query: (arg) => ({ url: 'boards/', params: arg }), + query: (arg) => ({ url: buildBoardsUrl(), params: arg }), providesTags: (result) => { // any list of boards const tags: ApiTagDescription[] = [{ type: 'Board', id: LIST_TAG }, 'FetchOnReconnect']; @@ -38,7 +46,7 @@ export const boardsApi = api.injectEndpoints({ listAllBoards: build.query, void>({ query: () => ({ - url: 'boards/', + url: buildBoardsUrl(), params: { all: true }, }), providesTags: (result) => { @@ -61,7 +69,7 @@ export const boardsApi = api.injectEndpoints({ listAllImageNamesForBoard: build.query, string>({ query: (board_id) => ({ - url: `boards/${board_id}/image_names`, + url: buildBoardsUrl(`${board_id}/image_names`), }), providesTags: (result, error, arg) => [{ type: 'ImageNameList', id: arg }, 'FetchOnReconnect'], keepUnusedDataFor: 0, @@ -107,7 +115,7 @@ export const boardsApi = api.injectEndpoints({ createBoard: build.mutation({ query: (board_name) => ({ - url: `boards/`, + url: buildBoardsUrl(), method: 'POST', params: { board_name }, }), @@ -116,7 +124,7 @@ export const boardsApi = api.injectEndpoints({ updateBoard: build.mutation({ query: ({ board_id, changes }) => ({ - url: `boards/${board_id}`, + url: buildBoardsUrl(board_id), method: 'PATCH', body: changes, }), diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index 181c8d23fc8..49eb28390ff 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -26,8 +26,24 @@ import { } from 'services/api/util'; import type { ApiTagDescription } from '..'; -import { api, LIST_TAG } from '..'; -import { boardsApi } from './boards'; +import { api, buildV1Url, LIST_TAG } from '..'; 
+import { boardsApi, buildBoardsUrl } from './boards'; + +/** + * Builds an endpoint URL for the images router + * @example + * buildImagesUrl('some-path') + * // '/api/v1/images/some-path' + */ +const buildImagesUrl = (path: string = '') => buildV1Url(`images/${path}`); + +/** + * Builds an endpoint URL for the board_images router + * @example + * buildBoardImagesUrl('some-path') + * // '/api/v1/board_images/some-path' + */ +const buildBoardImagesUrl = (path: string = '') => buildV1Url(`board_images/${path}`); export const imagesApi = api.injectEndpoints({ endpoints: (build) => ({ @@ -90,20 +106,20 @@ export const imagesApi = api.injectEndpoints({ keepUnusedDataFor: 86400, }), getIntermediatesCount: build.query({ - query: () => ({ url: 'images/intermediates' }), + query: () => ({ url: buildImagesUrl('intermediates') }), providesTags: ['IntermediatesCount', 'FetchOnReconnect'], }), clearIntermediates: build.mutation({ - query: () => ({ url: `images/intermediates`, method: 'DELETE' }), + query: () => ({ url: buildImagesUrl('intermediates'), method: 'DELETE' }), invalidatesTags: ['IntermediatesCount'], }), getImageDTO: build.query({ - query: (image_name) => ({ url: `images/i/${image_name}` }), + query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}`) }), providesTags: (result, error, image_name) => [{ type: 'Image', id: image_name }], keepUnusedDataFor: 86400, // 24 hours }), getImageMetadata: build.query({ - query: (image_name) => ({ url: `images/i/${image_name}/metadata` }), + query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/metadata`) }), providesTags: (result, error, image_name) => [{ type: 'ImageMetadata', id: image_name }], transformResponse: ( response: paths['/api/v1/images/i/{image_name}/metadata']['get']['responses']['200']['content']['application/json'] @@ -130,7 +146,7 @@ export const imagesApi = api.injectEndpoints({ }), deleteImage: build.mutation({ query: ({ image_name }) => ({ - url: `images/i/${image_name}`, + url: buildImagesUrl(`i/${image_name}`), method: 'DELETE', }), async onQueryStarted(imageDTO, { dispatch, queryFulfilled }) { @@ -185,7 +201,7 @@ export const imagesApi = api.injectEndpoints({ query: ({ imageDTOs }) => { const image_names = imageDTOs.map((imageDTO) => imageDTO.image_name); return { - url: `images/delete`, + url: buildImagesUrl('delete'), method: 'POST', body: { image_names, @@ -258,7 +274,7 @@ export const imagesApi = api.injectEndpoints({ */ changeImageIsIntermediate: build.mutation({ query: ({ imageDTO, is_intermediate }) => ({ - url: `images/i/${imageDTO.image_name}`, + url: buildImagesUrl(`i/${imageDTO.image_name}`), method: 'PATCH', body: { is_intermediate }, }), @@ -380,7 +396,7 @@ export const imagesApi = api.injectEndpoints({ */ changeImageSessionId: build.mutation({ query: ({ imageDTO, session_id }) => ({ - url: `images/i/${imageDTO.image_name}`, + url: buildImagesUrl(`i/${imageDTO.image_name}`), method: 'PATCH', body: { session_id }, }), @@ -417,7 +433,7 @@ export const imagesApi = api.injectEndpoints({ { imageDTOs: ImageDTO[] } >({ query: ({ imageDTOs: images }) => ({ - url: `images/star`, + url: buildImagesUrl('star'), method: 'POST', body: { image_names: images.map((img) => img.image_name) }, }), @@ -511,7 +527,7 @@ export const imagesApi = api.injectEndpoints({ { imageDTOs: ImageDTO[] } >({ query: ({ imageDTOs: images }) => ({ - url: `images/unstar`, + url: buildImagesUrl('unstar'), method: 'POST', body: { image_names: images.map((img) => img.image_name) }, }), @@ -611,7 +627,7 @@ export const imagesApi = 
api.injectEndpoints({ const formData = new FormData(); formData.append('file', file); return { - url: `images/upload`, + url: buildImagesUrl('upload'), method: 'POST', body: formData, params: { @@ -674,7 +690,7 @@ export const imagesApi = api.injectEndpoints({ }), deleteBoard: build.mutation({ - query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }), + query: (board_id) => ({ url: buildBoardsUrl(board_id), method: 'DELETE' }), invalidatesTags: () => [ { type: 'Board', id: LIST_TAG }, // invalidate the 'No Board' cache @@ -764,7 +780,7 @@ export const imagesApi = api.injectEndpoints({ deleteBoardAndImages: build.mutation({ query: (board_id) => ({ - url: `boards/${board_id}`, + url: buildBoardsUrl(board_id), method: 'DELETE', params: { include_images: true }, }), @@ -840,7 +856,7 @@ export const imagesApi = api.injectEndpoints({ query: ({ board_id, imageDTO }) => { const { image_name } = imageDTO; return { - url: `board_images/`, + url: buildBoardImagesUrl(), method: 'POST', body: { board_id, image_name }, }; @@ -961,7 +977,7 @@ export const imagesApi = api.injectEndpoints({ query: ({ imageDTO }) => { const { image_name } = imageDTO; return { - url: `board_images/`, + url: buildBoardImagesUrl(), method: 'DELETE', body: { image_name }, }; @@ -1080,7 +1096,7 @@ export const imagesApi = api.injectEndpoints({ } >({ query: ({ board_id, imageDTOs }) => ({ - url: `board_images/batch`, + url: buildBoardImagesUrl('batch'), method: 'POST', body: { image_names: imageDTOs.map((i) => i.image_name), @@ -1197,7 +1213,7 @@ export const imagesApi = api.injectEndpoints({ } >({ query: ({ imageDTOs }) => ({ - url: `board_images/batch/delete`, + url: buildBoardImagesUrl('batch/delete'), method: 'POST', body: { image_names: imageDTOs.map((i) => i.image_name), @@ -1321,7 +1337,7 @@ export const imagesApi = api.injectEndpoints({ components['schemas']['Body_download_images_from_list'] >({ query: ({ image_names, board_id }) => ({ - url: `images/download`, + url: buildImagesUrl('download'), method: 'POST', body: { image_names, diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 97e221454d5..9a7f1080564 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -19,7 +19,10 @@ import type { } from 'services/api/types'; import type { ApiTagDescription, tagTypes } from '..'; -import { api, LIST_TAG } from '..'; +import { api, buildV2Url, LIST_TAG } from '..'; + +/* eslint-disable @typescript-eslint/no-explicit-any */ +export const getModelId = (input: any): any => input; type UpdateMainModelArg = { base_model: BaseModelType; @@ -36,6 +39,8 @@ type UpdateLoRAModelArg = { type UpdateMainModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; +type ListModelsArg = NonNullable; + type UpdateLoRAModelResponse = UpdateMainModelResponse; type DeleteMainModelArg = { @@ -152,17 +157,25 @@ const buildTransformResponse = return adapter.setAll(adapter.getInitialState(), response.models); }; +/** + * Builds an endpoint URL for the models router + * @example + * buildModelsUrl('some-path') + * // '/api/v1/models/some-path' + */ +const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`); + export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ getMainModels: build.query, BaseModelType[]>({ query: (base_models) => { - const params = { + const params: ListModelsArg = { 
model_type: 'main', base_models, }; const query = queryString.stringify(params, { arrayFormat: 'none' }); - return `models/?${query}`; + return buildModelsUrl(`?${query}`); }, providesTags: buildProvidesTags('MainModel'), transformResponse: buildTransformResponse(mainModelsAdapter), @@ -170,7 +183,7 @@ export const modelsApi = api.injectEndpoints({ updateMainModels: build.mutation({ query: ({ base_model, model_name, body }) => { return { - url: `models/${base_model}/main/${model_name}`, + url: buildModelsUrl(`${base_model}/main/${model_name}`), method: 'PATCH', body: body, }; @@ -180,7 +193,7 @@ export const modelsApi = api.injectEndpoints({ importMainModels: build.mutation({ query: ({ body }) => { return { - url: `models/import`, + url: buildModelsUrl('import'), method: 'POST', body: body, }; @@ -190,7 +203,7 @@ export const modelsApi = api.injectEndpoints({ addMainModels: build.mutation({ query: ({ body }) => { return { - url: `models/add`, + url: buildModelsUrl('add'), method: 'POST', body: body, }; @@ -200,7 +213,7 @@ export const modelsApi = api.injectEndpoints({ deleteMainModels: build.mutation({ query: ({ base_model, model_name, model_type }) => { return { - url: `models/${base_model}/${model_type}/${model_name}`, + url: buildModelsUrl(`${base_model}/${model_type}/${model_name}`), method: 'DELETE', }; }, @@ -209,7 +222,7 @@ export const modelsApi = api.injectEndpoints({ convertMainModels: build.mutation({ query: ({ base_model, model_name, convert_dest_directory }) => { return { - url: `models/convert/${base_model}/main/${model_name}`, + url: buildModelsUrl(`convert/${base_model}/main/${model_name}`), method: 'PUT', params: { convert_dest_directory }, }; @@ -219,7 +232,7 @@ export const modelsApi = api.injectEndpoints({ mergeMainModels: build.mutation({ query: ({ base_model, body }) => { return { - url: `models/merge/${base_model}`, + url: buildModelsUrl(`merge/${base_model}`), method: 'PUT', body: body, }; @@ -229,21 +242,21 @@ export const modelsApi = api.injectEndpoints({ syncModels: build.mutation({ query: () => { return { - url: `models/sync`, + url: buildModelsUrl('sync'), method: 'POST', }; }, invalidatesTags: ['Model'], }), getLoRAModels: build.query, void>({ - query: () => ({ url: 'models/', params: { model_type: 'lora' } }), + query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }), providesTags: buildProvidesTags('LoRAModel'), transformResponse: buildTransformResponse(loraModelsAdapter), }), updateLoRAModels: build.mutation({ query: ({ base_model, model_name, body }) => { return { - url: `models/${base_model}/lora/${model_name}`, + url: buildModelsUrl(`${base_model}/lora/${model_name}`), method: 'PATCH', body: body, }; @@ -253,34 +266,34 @@ export const modelsApi = api.injectEndpoints({ deleteLoRAModels: build.mutation({ query: ({ base_model, model_name }) => { return { - url: `models/${base_model}/lora/${model_name}`, + url: buildModelsUrl(`${base_model}/lora/${model_name}`), method: 'DELETE', }; }, invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }], }), getControlNetModels: build.query, void>({ - query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }), + query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }), providesTags: buildProvidesTags('ControlNetModel'), transformResponse: buildTransformResponse(controlNetModelsAdapter), }), getIPAdapterModels: build.query, void>({ - query: () => ({ url: 'models/', params: { model_type: 'ip_adapter' } }), + query: () => ({ url: buildModelsUrl(), params: { model_type: 
'ip_adapter' } }), providesTags: buildProvidesTags('IPAdapterModel'), transformResponse: buildTransformResponse(ipAdapterModelsAdapter), }), getT2IAdapterModels: build.query, void>({ - query: () => ({ url: 'models/', params: { model_type: 't2i_adapter' } }), + query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }), providesTags: buildProvidesTags('T2IAdapterModel'), transformResponse: buildTransformResponse(t2iAdapterModelsAdapter), }), getVaeModels: build.query, void>({ - query: () => ({ url: 'models/', params: { model_type: 'vae' } }), + query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }), providesTags: buildProvidesTags('VaeModel'), transformResponse: buildTransformResponse(vaeModelsAdapter), }), getTextualInversionModels: build.query, void>({ - query: () => ({ url: 'models/', params: { model_type: 'embedding' } }), + query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }), providesTags: buildProvidesTags('TextualInversionModel'), transformResponse: buildTransformResponse(textualInversionModelsAdapter), }), @@ -288,14 +301,14 @@ export const modelsApi = api.injectEndpoints({ query: (arg) => { const folderQueryStr = queryString.stringify(arg, {}); return { - url: `/models/search?${folderQueryStr}`, + url: buildModelsUrl(`search?${folderQueryStr}`), }; }, }), getCheckpointConfigs: build.query({ query: () => { return { - url: `/models/ckpt_confs`, + url: buildModelsUrl(`ckpt_confs`), }; }, }), diff --git a/invokeai/frontend/web/src/services/api/endpoints/queue.ts b/invokeai/frontend/web/src/services/api/endpoints/queue.ts index 6c0798a936b..385aa8ad12d 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/queue.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/queue.ts @@ -7,7 +7,15 @@ import queryString from 'query-string'; import type { components, paths } from 'services/api/schema'; import type { ApiTagDescription } from '..'; -import { api } from '..'; +import { api, buildV1Url } from '..'; + +/** + * Builds an endpoint URL for the queue router + * @example + * buildQueueUrl('some-path') + * // '/api/v1/queue/queue_id/some-path' + */ +const buildQueueUrl = (path: string = '') => buildV1Url(`queue/${$queueId.get()}/${path}`); const getListQueueItemsUrl = (queryArgs?: paths['/api/v1/queue/{queue_id}/list']['get']['parameters']['query']) => { const query = queryArgs @@ -17,10 +25,10 @@ const getListQueueItemsUrl = (queryArgs?: paths['/api/v1/queue/{queue_id}/list'] : undefined; if (query) { - return `queue/${$queueId.get()}/list?${query}`; + return buildQueueUrl(`list?${query}`); } - return `queue/${$queueId.get()}/list`; + return buildQueueUrl('list'); }; export type SessionQueueItemStatus = NonNullable< @@ -58,7 +66,7 @@ export const queueApi = api.injectEndpoints({ paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['requestBody']['content']['application/json'] >({ query: (arg) => ({ - url: `queue/${$queueId.get()}/enqueue_batch`, + url: buildQueueUrl('enqueue_batch'), body: arg, method: 'POST', }), @@ -78,7 +86,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/processor/resume`, + url: buildQueueUrl('processor/resume'), method: 'PUT', }), invalidatesTags: ['CurrentSessionQueueItem', 'SessionQueueStatus'], @@ -88,7 +96,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/processor/pause`, + url: buildQueueUrl('processor/pause'), method: 'PUT', }), invalidatesTags: 
['CurrentSessionQueueItem', 'SessionQueueStatus'], @@ -98,7 +106,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/prune`, + url: buildQueueUrl('prune'), method: 'PUT', }), invalidatesTags: ['SessionQueueStatus', 'BatchStatus'], @@ -117,7 +125,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/clear`, + url: buildQueueUrl('clear'), method: 'PUT', }), invalidatesTags: [ @@ -142,7 +150,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/current`, + url: buildQueueUrl('current'), method: 'GET', }), providesTags: (result) => { @@ -158,7 +166,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/next`, + url: buildQueueUrl('next'), method: 'GET', }), providesTags: (result) => { @@ -174,7 +182,7 @@ export const queueApi = api.injectEndpoints({ void >({ query: () => ({ - url: `queue/${$queueId.get()}/status`, + url: buildQueueUrl('status'), method: 'GET', }), providesTags: ['SessionQueueStatus', 'FetchOnReconnect'], @@ -184,7 +192,7 @@ export const queueApi = api.injectEndpoints({ { batch_id: string } >({ query: ({ batch_id }) => ({ - url: `queue/${$queueId.get()}/b/${batch_id}/status`, + url: buildQueueUrl(`/b/${batch_id}/status`), method: 'GET', }), providesTags: (result) => { @@ -200,7 +208,7 @@ export const queueApi = api.injectEndpoints({ number >({ query: (item_id) => ({ - url: `queue/${$queueId.get()}/i/${item_id}`, + url: buildQueueUrl(`i/${item_id}`), method: 'GET', }), providesTags: (result) => { @@ -216,7 +224,7 @@ export const queueApi = api.injectEndpoints({ number >({ query: (item_id) => ({ - url: `queue/${$queueId.get()}/i/${item_id}/cancel`, + url: buildQueueUrl(`i/${item_id}/cancel`), method: 'PUT', }), onQueryStarted: async (item_id, { dispatch, queryFulfilled }) => { @@ -253,7 +261,7 @@ export const queueApi = api.injectEndpoints({ paths['/api/v1/queue/{queue_id}/cancel_by_batch_ids']['put']['requestBody']['content']['application/json'] >({ query: (body) => ({ - url: `queue/${$queueId.get()}/cancel_by_batch_ids`, + url: buildQueueUrl('cancel_by_batch_ids'), method: 'PUT', body, }), @@ -279,7 +287,7 @@ export const queueApi = api.injectEndpoints({ method: 'GET', }), serializeQueryArgs: () => { - return `queue/${$queueId.get()}/list`; + return buildQueueUrl('list'); }, transformResponse: (response: components['schemas']['CursorPaginatedResults_SessionQueueItemDTO_']) => queueItemsAdapter.addMany( diff --git a/invokeai/frontend/web/src/services/api/endpoints/utilities.ts b/invokeai/frontend/web/src/services/api/endpoints/utilities.ts index c08ee62dc9f..309dd2dc79f 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/utilities.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/utilities.ts @@ -1,6 +1,14 @@ import type { components } from 'services/api/schema'; -import { api } from '..'; +import { api, buildV1Url } from '..'; + +/** + * Builds an endpoint URL for the utilities router + * @example + * buildUtilitiesUrl('some-path') + * // '/api/v1/utilities/some-path' + */ +const buildUtilitiesUrl = (path: string = '') => buildV1Url(`utilities/${path}`); export const utilitiesApi = api.injectEndpoints({ endpoints: (build) => ({ @@ -9,7 +17,7 @@ export const utilitiesApi = api.injectEndpoints({ { prompt: string; max_prompts: number } >({ query: (arg) => ({ - url: 'utilities/dynamicprompts', + url: buildUtilitiesUrl('dynamicprompts'), body: arg, 
method: 'POST', }), diff --git a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts index c382f7e1111..1e64809e5a7 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts @@ -1,6 +1,14 @@ import type { paths } from 'services/api/schema'; -import { api, LIST_TAG } from '..'; +import { api, buildV1Url, LIST_TAG } from '..'; + +/** + * Builds an endpoint URL for the workflows router + * @example + * buildWorkflowsUrl('some-path') + * // '/api/v1/workflows/some-path' + */ +const buildWorkflowsUrl = (path: string = '') => buildV1Url(`workflows/${path}`); export const workflowsApi = api.injectEndpoints({ endpoints: (build) => ({ @@ -8,7 +16,7 @@ export const workflowsApi = api.injectEndpoints({ paths['/api/v1/workflows/i/{workflow_id}']['get']['responses']['200']['content']['application/json'], string >({ - query: (workflow_id) => `workflows/i/${workflow_id}`, + query: (workflow_id) => buildWorkflowsUrl(`i/${workflow_id}`), providesTags: (result, error, workflow_id) => [{ type: 'Workflow', id: workflow_id }, 'FetchOnReconnect'], onQueryStarted: async (arg, api) => { const { dispatch, queryFulfilled } = api; @@ -22,7 +30,7 @@ export const workflowsApi = api.injectEndpoints({ }), deleteWorkflow: build.mutation({ query: (workflow_id) => ({ - url: `workflows/i/${workflow_id}`, + url: buildWorkflowsUrl(`i/${workflow_id}`), method: 'DELETE', }), invalidatesTags: (result, error, workflow_id) => [ @@ -36,7 +44,7 @@ export const workflowsApi = api.injectEndpoints({ paths['/api/v1/workflows/']['post']['requestBody']['content']['application/json']['workflow'] >({ query: (workflow) => ({ - url: 'workflows/', + url: buildWorkflowsUrl(), method: 'POST', body: { workflow }, }), @@ -50,7 +58,7 @@ export const workflowsApi = api.injectEndpoints({ paths['/api/v1/workflows/i/{workflow_id}']['patch']['requestBody']['content']['application/json']['workflow'] >({ query: (workflow) => ({ - url: `workflows/i/${workflow.id}`, + url: buildWorkflowsUrl(`i/${workflow.id}`), method: 'PATCH', body: { workflow }, }), @@ -65,7 +73,7 @@ export const workflowsApi = api.injectEndpoints({ NonNullable >({ query: (params) => ({ - url: 'workflows/', + url: buildWorkflowsUrl(), params, }), providesTags: ['FetchOnReconnect', { type: 'Workflow', id: LIST_TAG }], diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 6a342bb72d7..3584bdec453 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -57,7 +57,7 @@ const dynamicBaseQuery: BaseQueryFn `api/v1/${path}`; +export const buildV2Url = (path: string): string => `api/v2/${path}`; diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 3393e74d486..40fc262be26 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -4212,7 +4212,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | 
components["schemas"]["SaveImageInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["IterateInvocation"] | 
components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"]; + [key: string]: components["schemas"]["ControlNetInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | 
components["schemas"]["InfillColorInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["IterateInvocation"] | 
components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CvInpaintInvocation"]; }; /** * Edges @@ -4249,7 +4249,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["ImageCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["String2Output"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | 
components["schemas"]["IPAdapterOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["CLIPOutput"]; + [key: string]: components["schemas"]["SchedulerOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["String2Output"] | components["schemas"]["IntegerOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["IterateInvocationOutput"]; }; /** * Errors @@ -11119,17 +11119,11 @@ export type components = { */ VaeModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusion2ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; - /** - * CLIPVisionModelFormat + * T2IAdapterModelFormat * @description An enumeration. * @enum {string} */ - CLIPVisionModelFormat: "diffusers"; + T2IAdapterModelFormat: "diffusers"; /** * StableDiffusionXLModelFormat * @description An enumeration. @@ -11143,29 +11137,35 @@ export type components = { */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** - * ControlNetModelFormat + * StableDiffusionOnnxModelFormat * @description An enumeration. 
* @enum {string} */ - ControlNetModelFormat: "checkpoint" | "diffusers"; + StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** - * StableDiffusionOnnxModelFormat + * ControlNetModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; + ControlNetModelFormat: "checkpoint" | "diffusers"; /** - * T2IAdapterModelFormat + * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ - T2IAdapterModelFormat: "diffusers"; + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** * LoRAModelFormat * @description An enumeration. * @enum {string} */ LoRAModelFormat: "lycoris" | "diffusers"; + /** + * CLIPVisionModelFormat + * @description An enumeration. + * @enum {string} + */ + CLIPVisionModelFormat: "diffusers"; /** * IPAdapterModelFormat * @description An enumeration. diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 7a02cc55681..4ae2f9b594e 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -61,11 +61,13 @@ export type IPAdapterField = S['IPAdapterField']; // Model Configs // TODO(MM2): Can we make key required in the pydantic model? -type KeyRequired = SetRequired; +type KeyRequired = SetRequired; export type LoRAConfig = KeyRequired; // TODO(MM2): Can we rename this from Vae -> VAE export type VAEConfig = KeyRequired | KeyRequired; -export type ControlNetConfig = KeyRequired | KeyRequired; +export type ControlNetConfig = + | KeyRequired + | KeyRequired; export type IPAdapterConfig = KeyRequired; // TODO(MM2): Can we rename this to T2IAdapterConfig export type T2IAdapterConfig = KeyRequired; diff --git a/invokeai/frontend/web/src/services/api/util.ts b/invokeai/frontend/web/src/services/api/util.ts index f7f36f46308..a7a5d6451e7 100644 --- a/invokeai/frontend/web/src/services/api/util.ts +++ b/invokeai/frontend/web/src/services/api/util.ts @@ -3,6 +3,7 @@ import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import { dateComparator } from 'common/util/dateComparator'; import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types'; import queryString from 'query-string'; +import { buildV1Url } from 'services/api'; import type { ImageCache, ImageDTO, ListImagesArgs } from './types'; @@ -79,4 +80,4 @@ export const imagesSelectors = imagesAdapter.getSelectors(undefined, getSelector // Helper to create the url for the listImages endpoint. Also we use it to create the cache key. 
export const getListImagesUrl = (queryArgs: ListImagesArgs) => - `images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`; + buildV1Url(`images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`); diff --git a/invokeai/frontend/web/vite.config.mts b/invokeai/frontend/web/vite.config.mts index f4dbae71232..32e3e1f64fe 100644 --- a/invokeai/frontend/web/vite.config.mts +++ b/invokeai/frontend/web/vite.config.mts @@ -76,9 +76,9 @@ export default defineConfig(({ mode }) => { changeOrigin: true, }, // proxy nodes api - '/api/v1': { - target: 'http://127.0.0.1:9090/api/v1', - rewrite: (path) => path.replace(/^\/api\/v1/, ''), + '/api/': { + target: 'http://127.0.0.1:9090/api/', + rewrite: (path) => path.replace(/^\/api/, ''), changeOrigin: true, }, }, From c8c0b1fdc53fba3ee83c72a52bdc924091b0536e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 16 Feb 2024 22:41:09 +1100 Subject: [PATCH 104/340] feat(ui): fix main model & control adapter model selects --- .../src/common/hooks/useModelCustomSelect.ts | 88 +++++++++++++++++++ .../parameters/ParamControlAdapterModel.tsx | 49 +++-------- .../hooks/useControlAdapterModelEntities.ts | 23 ----- .../hooks/useControlAdapterModelQuery.ts | 26 ++++++ .../hooks/useControlAdapterType.ts | 10 ++- .../subpanels/ModelManagerPanel.tsx | 6 +- .../MainModel/ParamMainModelSelect.tsx | 45 ++++------ .../src/features/parameters/store/actions.ts | 5 +- .../features/parameters/types/constants.ts | 4 +- 9 files changed, 154 insertions(+), 102 deletions(-) create mode 100644 invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts delete mode 100644 invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelEntities.ts create mode 100644 invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelQuery.ts diff --git a/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts b/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts new file mode 100644 index 00000000000..07ea98a2747 --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts @@ -0,0 +1,88 @@ +import type { Item } from '@invoke-ai/ui-library'; +import type { EntityState } from '@reduxjs/toolkit'; +import { EMPTY_ARRAY } from 'app/store/util'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; +import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants'; +import { filter } from 'lodash-es'; +import { useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import type { AnyModelConfig } from 'services/api/types'; + +type UseModelCustomSelectArg = { + data: EntityState | undefined; + isLoading: boolean; + selectedModel?: ModelIdentifierWithBase | null; + onChange: (value: T | null) => void; + modelFilter?: (model: T) => boolean; + isModelDisabled?: (model: T) => boolean; +}; + +type UseModelCustomSelectReturn = { + selectedItem: Item | null; + items: Item[]; + onChange: (item: Item | null) => void; + placeholder: string; +}; + +const modelFilterDefault = () => true; +const isModelDisabledDefault = () => false; + +export const useModelCustomSelect = ({ + data, + isLoading, + selectedModel, + onChange, + modelFilter = modelFilterDefault, + isModelDisabled = isModelDisabledDefault, +}: UseModelCustomSelectArg): UseModelCustomSelectReturn => { + const { t } = useTranslation(); + + const items: Item[] = useMemo( + () => + data + ? 
filter(data.entities, modelFilter).map((m) => ({ + label: m.name, + value: m.key, + description: m.description, + group: MODEL_TYPE_SHORT_MAP[m.base], + isDisabled: isModelDisabled(m), + })) + : EMPTY_ARRAY, + [data, isModelDisabled, modelFilter] + ); + + const _onChange = useCallback( + (item: Item | null) => { + if (!item || !data) { + return; + } + const model = data.entities[item.value]; + if (!model) { + return; + } + onChange(model); + }, + [data, onChange] + ); + + const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]); + + const placeholder = useMemo(() => { + if (isLoading) { + return t('common.loading'); + } + + if (items.length === 0) { + return t('models.noModelsAvailable'); + } + + return t('models.selectModel'); + }, [isLoading, items, t]); + + return { + items, + onChange: _onChange, + selectedItem, + placeholder, + }; +}; diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx index a3202384458..696bf47b2a6 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx +++ b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx @@ -1,34 +1,27 @@ -import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library'; -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { CustomSelect, FormControl } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect'; import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled'; import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel'; -import { useControlAdapterModelEntities } from 'features/controlAdapters/hooks/useControlAdapterModelEntities'; +import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery'; import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType'; import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { pick } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; -import { useTranslation } from 'react-i18next'; -import type { AnyModelConfig, ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types'; type ParamControlAdapterModelProps = { id: string; }; -const selectMainModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model); - const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { const isEnabled = useControlAdapterIsEnabled(id); const controlAdapterType = useControlAdapterType(id); const model = useControlAdapterModel(id); const dispatch = useAppDispatch(); const currentBaseModel = useAppSelector((s) => s.generation.model?.base); - const mainModel = useAppSelector(selectMainModel); - const { t } = useTranslation(); - const models = useControlAdapterModelEntities(controlAdapterType); + const { data, 
isLoading } = useControlAdapterModelQuery(controlAdapterType); const _onChange = useCallback( (model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => { @@ -50,34 +43,18 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { [controlAdapterType, model] ); - const getIsDisabled = useCallback( - (model: AnyModelConfig): boolean => { - const isCompatible = currentBaseModel === model.base; - const hasMainModel = Boolean(currentBaseModel); - return !hasMainModel || !isCompatible; - }, - [currentBaseModel] - ); - - const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({ - modelEntities: models, - onChange: _onChange, + const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({ + data, + isLoading, selectedModel, - getIsDisabled, + onChange: _onChange, + modelFilter: (model) => model.base === currentBaseModel, }); return ( - - - - - + + + ); }; diff --git a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelEntities.ts b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelEntities.ts deleted file mode 100644 index 0c8baaacc2d..00000000000 --- a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelEntities.ts +++ /dev/null @@ -1,23 +0,0 @@ -import type { ControlAdapterType } from 'features/controlAdapters/store/types'; -import { - useGetControlNetModelsQuery, - useGetIPAdapterModelsQuery, - useGetT2IAdapterModelsQuery, -} from 'services/api/endpoints/models'; - -export const useControlAdapterModelEntities = (type?: ControlAdapterType) => { - const { data: controlNetModelsData } = useGetControlNetModelsQuery(); - const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery(); - const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery(); - - if (type === 'controlnet') { - return controlNetModelsData; - } - if (type === 't2i_adapter') { - return t2iAdapterModelsData; - } - if (type === 'ip_adapter') { - return ipAdapterModelsData; - } - return; -}; diff --git a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelQuery.ts b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelQuery.ts new file mode 100644 index 00000000000..1d092497af6 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModelQuery.ts @@ -0,0 +1,26 @@ +import type { ControlAdapterType } from 'features/controlAdapters/store/types'; +import { + useGetControlNetModelsQuery, + useGetIPAdapterModelsQuery, + useGetT2IAdapterModelsQuery, +} from 'services/api/endpoints/models'; + +export const useControlAdapterModelQuery = (type: ControlAdapterType) => { + const controlNetModelsQuery = useGetControlNetModelsQuery(); + const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery(); + const ipAdapterModelsQuery = useGetIPAdapterModelsQuery(); + + if (type === 'controlnet') { + return controlNetModelsQuery; + } + if (type === 't2i_adapter') { + return t2iAdapterModelsQuery; + } + if (type === 'ip_adapter') { + return ipAdapterModelsQuery; + } + + // Assert that the end of the function is not reachable. 
+ const exhaustiveCheck: never = type; + return exhaustiveCheck; +}; diff --git a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterType.ts b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterType.ts index 4e15dc9e64e..fe818f32875 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterType.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterType.ts @@ -5,14 +5,16 @@ import { selectControlAdaptersSlice, } from 'features/controlAdapters/store/controlAdaptersSlice'; import { useMemo } from 'react'; +import { assert } from 'tsafe'; export const useControlAdapterType = (id: string) => { const selector = useMemo( () => - createMemoizedSelector( - selectControlAdaptersSlice, - (controlAdapters) => selectControlAdapterById(controlAdapters, id)?.type - ), + createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => { + const type = selectControlAdapterById(controlAdapters, id)?.type; + assert(type !== undefined, `Control adapter with id ${id} not found`); + return type; + }), [id] ); diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx index 7501151ba49..15149b339b9 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx @@ -2,11 +2,7 @@ import { Flex, Text } from '@invoke-ai/ui-library'; import { memo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { ALL_BASE_MODELS } from 'services/api/constants'; -import type { - DiffusersModelConfig, - LoRAConfig, - MainModelConfig, -} from 'services/api/endpoints/models'; +import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; diff --git a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx index 18f780bdeeb..0759020cc82 100644 --- a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx @@ -1,62 +1,47 @@ -import { Box, Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library'; +import { CustomSelect, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect'; import { modelSelected } from 'features/parameters/store/actions'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; -import { pick } from 'lodash-es'; -import { memo, useCallback, useMemo } from 'react'; +import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; -import type { MainModelConfig } 
from 'services/api/endpoints/models'; -import { getModelId, mainModelsAdapterSelectors, useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/types'; const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model); const ParamMainModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const model = useAppSelector(selectModel); + const selectedModel = useAppSelector(selectModel); const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS); - const tooltipLabel = useMemo(() => { - if (!data || !model) { - return; - } - return mainModelsAdapterSelectors.selectById(data, getModelId(model))?.description; - }, [data, model]); + const _onChange = useCallback( (model: MainModelConfig | null) => { if (!model) { return; } - dispatch(modelSelected(pick(model, ['base_model', 'model_name', 'model_type']))); + dispatch(modelSelected({ key: model.key, base: model.base })); }, [dispatch] ); - const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ - modelEntities: data, - onChange: _onChange, - selectedModel: model, + + const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({ + data, isLoading, + selectedModel, + onChange: _onChange, }); return ( - + {t('modelManager.model')} - - - - - + ); }; diff --git a/invokeai/frontend/web/src/features/parameters/store/actions.ts b/invokeai/frontend/web/src/features/parameters/store/actions.ts index 0d4dda0e87c..f7bf127c057 100644 --- a/invokeai/frontend/web/src/features/parameters/store/actions.ts +++ b/invokeai/frontend/web/src/features/parameters/store/actions.ts @@ -1,6 +1,7 @@ import { createAction } from '@reduxjs/toolkit'; -import type { ImageDTO, MainModelField } from 'services/api/types'; +import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; +import type { ImageDTO } from 'services/api/types'; export const initialImageSelected = createAction('generation/initialImageSelected'); -export const modelSelected = createAction('generation/modelSelected'); +export const modelSelected = createAction('generation/modelSelected'); diff --git a/invokeai/frontend/web/src/features/parameters/types/constants.ts b/invokeai/frontend/web/src/features/parameters/types/constants.ts index a74807a959c..bd78759b390 100644 --- a/invokeai/frontend/web/src/features/parameters/types/constants.ts +++ b/invokeai/frontend/web/src/features/parameters/types/constants.ts @@ -17,8 +17,8 @@ export const MODEL_TYPE_MAP = { */ export const MODEL_TYPE_SHORT_MAP = { any: 'Any', - 'sd-1': 'SD1', - 'sd-2': 'SD2', + 'sd-1': 'SD1.X', + 'sd-2': 'SD2.X', sdxl: 'SDXL', 'sdxl-refiner': 'SDXLR', }; From 8773260da9e749926aa9cb61c4e957b9116a48fc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 16 Feb 2024 22:42:15 +1100 Subject: [PATCH 105/340] chore(ui): lint --- invokeai/frontend/web/src/common/hooks/useModelCombobox.ts | 4 +--- invokeai/frontend/web/src/services/api/endpoints/workflows.ts | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts index 341fed1e47e..07e6aeb34c4 100644 --- a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts @@ -23,9 
+23,7 @@ type UseModelComboboxReturn = { noOptionsMessage: () => string; }; -export const useModelCombobox = ( - arg: UseModelComboboxArg -): UseModelComboboxReturn => { +export const useModelCombobox = (arg: UseModelComboboxArg): UseModelComboboxReturn => { const { t } = useTranslation(); const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg; const options = useMemo(() => { diff --git a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts index 1e64809e5a7..0280e2ebc46 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts @@ -5,7 +5,7 @@ import { api, buildV1Url, LIST_TAG } from '..'; /** * Builds an endpoint URL for the workflows router * @example - * buildWorkflowsUrl('some-path') + * buildWorkflowsUrl('some-path') * // '/api/v1/workflows/some-path' */ const buildWorkflowsUrl = (path: string = '') => buildV1Url(`workflows/${path}`); From a8dc0f107eb2d6e1ea6baa7076b015fa7c7d2231 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 15 Feb 2024 22:41:29 -0500 Subject: [PATCH 106/340] Fix issues identified during PR review by RyanjDick and brandonrising - ModelMetadataStoreService is now injected into ModelRecordStoreService (these two services are really joined at the hip, and should someday be merged) - ModelRecordStoreService is now injected into ModelManagerService - Reduced timeout value for the various installer and download wait*() methods - Introduced a mock model manager for testing - Replaced bare print() statements with _logger calls in the install helper backend - Removed unused code from model loader init file - Made `locker` a private variable in the `LoadedModel` object. - Fixed up model merge frontend (will be deprecated anyway!)
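For reference, the first two bullets boil down to the wiring now done in invokeai/app/api/dependencies.py (see the hunk below). This is a minimal sketch of that wiring, not the complete initializer: db, configuration, download_queue_service and events are created earlier in the same initialize() method, and only the model-manager portion is shown.

    from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
    from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
    from invokeai.app.services.model_records import ModelRecordServiceSQL

    # The metadata store is built first and injected into the record store,
    # which is then passed to the model manager instead of the raw database.
    model_metadata_service = ModelMetadataStoreSQL(db=db)
    model_manager = ModelManagerService.build_model_manager(
        app_config=configuration,
        model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
        download_queue=download_queue_service,
        events=events,
    )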
--- invokeai/app/api/dependencies.py | 8 +- .../app/services/download/download_default.py | 2 +- .../invocation_stats_default.py | 2 - .../model_install/model_install_base.py | 6 +- .../model_install/model_install_default.py | 15 +- .../model_manager/model_manager_default.py | 15 +- .../app/services/model_metadata/__init__.py | 9 + .../model_metadata/metadata_store_base.py | 65 +++++ .../model_metadata/metadata_store_sql.py | 222 ++++++++++++++++++ .../model_records/model_records_base.py | 6 +- .../model_records/model_records_sql.py | 18 +- invokeai/backend/install/install_helper.py | 13 +- .../backend/model_manager/load/__init__.py | 17 -- .../backend/model_manager/load/load_base.py | 8 +- .../model_manager/load/load_default.py | 2 +- invokeai/backend/model_manager/merge.py | 5 +- .../model_manager/metadata/__init__.py | 5 +- invokeai/frontend/merge/merge_diffusers.py | 133 +++++++---- tests/aa_nodes/test_invoker.py | 3 +- .../model_records/test_model_records_sql.py | 3 +- .../model_manager_2_fixtures.py | 15 +- .../model_metadata/test_model_metadata.py | 8 +- 22 files changed, 449 insertions(+), 131 deletions(-) create mode 100644 invokeai/app/services/model_metadata/__init__.py create mode 100644 invokeai/app/services/model_metadata/metadata_store_base.py create mode 100644 invokeai/app/services/model_metadata/metadata_store_sql.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 378961a0557..8e79b26e2d9 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -28,6 +28,8 @@ from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invoker import Invoker from ..services.model_manager.model_manager_default import ModelManagerService +from ..services.model_metadata import ModelMetadataStoreSQL +from ..services.model_records import ModelRecordServiceSQL from ..services.names.names_default import SimpleNameService from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue @@ -94,8 +96,12 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ) download_queue_service = DownloadQueueService(event_bus=events) + model_metadata_service = ModelMetadataStoreSQL(db=db) model_manager = ModelManagerService.build_model_manager( - app_config=configuration, db=db, download_queue=download_queue_service, events=events + app_config=configuration, + model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service), + download_queue=download_queue_service, + events=events, ) names = SimpleNameService() performance_statistics = InvocationStatsService() diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 6d5cedbcad8..50cac80d094 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -194,7 +194,7 @@ def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: """Block until the indicated job has reached terminal state, or when timeout limit reached.""" start = time.time() while not job.in_terminal_state: - if self._job_completed_event.wait(timeout=5): # in case we miss an event + if self._job_completed_event.wait(timeout=0.25): # in case we miss an event self._job_completed_event.clear() if timeout 
> 0 and time.time() - start > timeout: raise TimeoutError("Timeout exceeded") diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index 6c893021de4..486a1ca5b3e 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -46,8 +46,6 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st # This is to handle case of the model manager not being initialized, which happens # during some tests. services = self._invoker.services - if services.model_manager is None or services.model_manager.load is None: - yield None if not self._stats.get(graph_execution_state_id): # First time we're seeing this graph_execution_state_id. self._stats[graph_execution_state_id] = GraphExecutionStats() diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 39ea8c4a0d1..2f03db0af72 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -18,7 +18,9 @@ from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + +from ..model_metadata import ModelMetadataStoreBase class InstallStatus(str, Enum): @@ -243,7 +245,7 @@ def __init__( app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - metadata_store: ModelMetadataStore, + metadata_store: ModelMetadataStoreBase, event_bus: Optional["EventServiceBase"] = None, ): """ diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 20a85a82a14..7dee8bfd8cb 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -20,7 +20,7 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker -from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL +from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, @@ -33,7 +33,6 @@ AnyModelRepoMetadata, CivitaiMetadataFetch, HuggingFaceMetadataFetch, - ModelMetadataStore, ModelMetadataWithFiles, RemoteModelFile, ) @@ -65,7 +64,6 @@ def __init__( app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - metadata_store: Optional[ModelMetadataStore] = None, event_bus: Optional[EventServiceBase] = None, session: Optional[Session] = None, ): @@ -93,14 +91,7 @@ def __init__( self._running = False self._session = session self._next_job_id = 0 - # There may not necessarily be a metadata store initialized - # so we create one and initialize it with the same sql database - # used by the record store service. 
- if metadata_store: - self._metadata_store = metadata_store - else: - assert isinstance(record_store, ModelRecordServiceSQL) - self._metadata_store = ModelMetadataStore(record_store.db) + self._metadata_store = record_store.metadata_store # for convenience @property def app_config(self) -> InvokeAIAppConfig: # noqa D102 @@ -259,7 +250,7 @@ def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa """Block until all installation jobs are done.""" start = time.time() while len(self._download_cache) > 0: - if self._downloads_changed_event.wait(timeout=5): # in case we miss an event + if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event self._downloads_changed_event.clear() if timeout > 0 and time.time() - start > timeout: raise TimeoutError("Timeout exceeded") diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 028d4af6159..b96341be69e 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -5,7 +5,6 @@ from invokeai.app.services.invoker import Invoker from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache -from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger from ..config import InvokeAIAppConfig @@ -13,8 +12,7 @@ from ..events.events_base import EventServiceBase from ..model_install import ModelInstallService, ModelInstallServiceBase from ..model_load import ModelLoadService, ModelLoadServiceBase -from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL -from ..shared.sqlite.sqlite_database import SqliteDatabase +from ..model_records import ModelRecordServiceBase from .model_manager_base import ModelManagerServiceBase @@ -64,7 +62,7 @@ def stop(self, invoker: Invoker) -> None: def build_model_manager( cls, app_config: InvokeAIAppConfig, - db: SqliteDatabase, + model_record_service: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, events: EventServiceBase, ) -> Self: @@ -82,19 +80,16 @@ def build_model_manager( convert_cache = ModelConvertCache( cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size ) - record_store = ModelRecordServiceSQL(db=db) loader = ModelLoadService( app_config=app_config, - record_store=record_store, + record_store=model_record_service, ram_cache=ram_cache, convert_cache=convert_cache, ) - record_store._loader = loader # yeah, there is a circular reference here installer = ModelInstallService( app_config=app_config, - record_store=record_store, + record_store=model_record_service, download_queue=download_queue, - metadata_store=ModelMetadataStore(db=db), event_bus=events, ) - return cls(store=record_store, install=installer, load=loader) + return cls(store=model_record_service, install=installer, load=loader) diff --git a/invokeai/app/services/model_metadata/__init__.py b/invokeai/app/services/model_metadata/__init__.py new file mode 100644 index 00000000000..981c96b709b --- /dev/null +++ b/invokeai/app/services/model_metadata/__init__.py @@ -0,0 +1,9 @@ +"""Init file for ModelMetadataStoreService module.""" + +from .metadata_store_base import ModelMetadataStoreBase +from .metadata_store_sql import ModelMetadataStoreSQL + +__all__ = [ + "ModelMetadataStoreBase", + "ModelMetadataStoreSQL", +] diff --git a/invokeai/app/services/model_metadata/metadata_store_base.py 
b/invokeai/app/services/model_metadata/metadata_store_base.py new file mode 100644 index 00000000000..e0e4381b099 --- /dev/null +++ b/invokeai/app/services/model_metadata/metadata_store_base.py @@ -0,0 +1,65 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +""" +Storage for Model Metadata +""" + +from abc import ABC, abstractmethod +from typing import List, Set, Tuple + +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + + +class ModelMetadataStoreBase(ABC): + """Store, search and fetch model metadata retrieved from remote repositories.""" + + @abstractmethod + def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: + """ + Add a block of repo metadata to a model record. + + The model record config must already exist in the database with the + same key. Otherwise a FOREIGN KEY constraint exception will be raised. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to store + """ + + @abstractmethod + def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: + """Retrieve the ModelRepoMetadata corresponding to model key.""" + + @abstractmethod + def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata + """Dump out all the metadata.""" + + @abstractmethod + def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: + """ + Update metadata corresponding to the model with the indicated key. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to update + """ + + @abstractmethod + def list_tags(self) -> Set[str]: + """Return all tags in the tags table.""" + + @abstractmethod + def search_by_tag(self, tags: Set[str]) -> Set[str]: + """Return the keys of models containing all of the listed tags.""" + + @abstractmethod + def search_by_author(self, author: str) -> Set[str]: + """Return the keys of models authored by the indicated author.""" + + @abstractmethod + def search_by_name(self, name: str) -> Set[str]: + """ + Return the keys of models with the indicated name. + + Note that this is the name of the model given to it by + the remote source. The user may have changed the local + name. The local name will be located in the model config + record object. + """ diff --git a/invokeai/app/services/model_metadata/metadata_store_sql.py b/invokeai/app/services/model_metadata/metadata_store_sql.py new file mode 100644 index 00000000000..afe9d2c8c69 --- /dev/null +++ b/invokeai/app/services/model_metadata/metadata_store_sql.py @@ -0,0 +1,222 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team +""" +SQL Storage for Model Metadata +""" + +import sqlite3 +from typing import List, Optional, Set, Tuple + +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException +from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase + +from .metadata_store_base import ModelMetadataStoreBase + + +class ModelMetadataStoreSQL(ModelMetadataStoreBase): + """Store, search and fetch model metadata retrieved from remote repositories.""" + + def __init__(self, db: SqliteDatabase): + """ + Initialize a new object from preexisting sqlite3 connection and threading lock objects. 
+ + :param conn: sqlite3 connection object + :param lock: threading Lock object + """ + super().__init__() + self._db = db + self._cursor = self._db.conn.cursor() + + def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None: + """ + Add a block of repo metadata to a model record. + + The model record config must already exist in the database with the + same key. Otherwise a FOREIGN KEY constraint exception will be raised. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to store + """ + json_serialized = metadata.model_dump_json() + with self._db.lock: + try: + self._cursor.execute( + """--sql + INSERT INTO model_metadata( + id, + metadata + ) + VALUES (?,?); + """, + ( + model_key, + json_serialized, + ), + ) + self._update_tags(model_key, metadata.tags) + self._db.conn.commit() + except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table + self._db.conn.rollback() + raise UnknownMetadataException from excp + except sqlite3.Error as excp: + self._db.conn.rollback() + raise excp + + def get_metadata(self, model_key: str) -> AnyModelRepoMetadata: + """Retrieve the ModelRepoMetadata corresponding to model key.""" + with self._db.lock: + self._cursor.execute( + """--sql + SELECT metadata FROM model_metadata + WHERE id=?; + """, + (model_key,), + ) + rows = self._cursor.fetchone() + if not rows: + raise UnknownMetadataException("model metadata not found") + return ModelMetadataFetchBase.from_json(rows[0]) + + def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata + """Dump out all the metadata.""" + with self._db.lock: + self._cursor.execute( + """--sql + SELECT id,metadata FROM model_metadata; + """, + (), + ) + rows = self._cursor.fetchall() + return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows] + + def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata: + """ + Update metadata corresponding to the model with the indicated key. + + :param model_key: Existing model key in the `model_config` table + :param metadata: ModelRepoMetadata object to update + """ + json_serialized = metadata.model_dump_json() # turn it into a json string. + with self._db.lock: + try: + self._cursor.execute( + """--sql + UPDATE model_metadata + SET + metadata=? 
+ WHERE id=?; + """, + (json_serialized, model_key), + ) + if self._cursor.rowcount == 0: + raise UnknownMetadataException("model metadata not found") + self._update_tags(model_key, metadata.tags) + self._db.conn.commit() + except sqlite3.Error as e: + self._db.conn.rollback() + raise e + + return self.get_metadata(model_key) + + def list_tags(self) -> Set[str]: + """Return all tags in the tags table.""" + self._cursor.execute( + """--sql + select tag_text from tags; + """ + ) + return {x[0] for x in self._cursor.fetchall()} + + def search_by_tag(self, tags: Set[str]) -> Set[str]: + """Return the keys of models containing all of the listed tags.""" + with self._db.lock: + try: + matches: Optional[Set[str]] = None + for tag in tags: + self._cursor.execute( + """--sql + SELECT a.model_id FROM model_tags AS a, + tags AS b + WHERE a.tag_id=b.tag_id + AND b.tag_text=?; + """, + (tag,), + ) + model_keys = {x[0] for x in self._cursor.fetchall()} + if matches is None: + matches = model_keys + matches = matches.intersection(model_keys) + except sqlite3.Error as e: + raise e + return matches if matches else set() + + def search_by_author(self, author: str) -> Set[str]: + """Return the keys of models authored by the indicated author.""" + self._cursor.execute( + """--sql + SELECT id FROM model_metadata + WHERE author=?; + """, + (author,), + ) + return {x[0] for x in self._cursor.fetchall()} + + def search_by_name(self, name: str) -> Set[str]: + """ + Return the keys of models with the indicated name. + + Note that this is the name of the model given to it by + the remote source. The user may have changed the local + name. The local name will be located in the model config + record object. + """ + self._cursor.execute( + """--sql + SELECT id FROM model_metadata + WHERE name=?; + """, + (name,), + ) + return {x[0] for x in self._cursor.fetchall()} + + def _update_tags(self, model_key: str, tags: Set[str]) -> None: + """Update tags for the model referenced by model_key.""" + # remove previous tags from this model + self._cursor.execute( + """--sql + DELETE FROM model_tags + WHERE model_id=?; + """, + (model_key,), + ) + + for tag in tags: + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO tags ( + tag_text + ) + VALUES (?); + """, + (tag,), + ) + self._cursor.execute( + """--sql + SELECT tag_id + FROM tags + WHERE tag_text = ? 
+ LIMIT 1; + """, + (tag,), + ) + tag_id = self._cursor.fetchone()[0] + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO model_tags ( + model_id, + tag_id + ) + VALUES (?,?); + """, + (model_key, tag_id), + ) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index b2eacc524b7..d6014db448a 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -17,7 +17,9 @@ ModelFormat, ModelType, ) -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + +from ..model_metadata import ModelMetadataStoreBase class DuplicateModelException(Exception): @@ -109,7 +111,7 @@ def get_model(self, key: str) -> AnyModelConfig: @property @abstractmethod - def metadata_store(self) -> ModelMetadataStore: + def metadata_store(self) -> ModelMetadataStoreBase: """Return a ModelMetadataStore initialized on the same database.""" pass diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 84a14123838..dcd1114655b 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -54,8 +54,9 @@ ModelFormat, ModelType, ) -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException +from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_records_base import ( DuplicateModelException, @@ -69,7 +70,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """Implementation of the ModelConfigStore ABC using a SQL database.""" - def __init__(self, db: SqliteDatabase): + def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase): """ Initialize a new object from preexisting sqlite3 connection and threading lock objects. @@ -78,6 +79,7 @@ def __init__(self, db: SqliteDatabase): super().__init__() self._db = db self._cursor = db.conn.cursor() + self._metadata_store = metadata_store @property def db(self) -> SqliteDatabase: @@ -157,7 +159,7 @@ def del_model(self, key: str) -> None: self._db.conn.rollback() raise e - def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig: + def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: """ Update the model, returning the updated version. @@ -307,9 +309,9 @@ def search_by_hash(self, hash: str) -> List[AnyModelConfig]: return results @property - def metadata_store(self) -> ModelMetadataStore: + def metadata_store(self) -> ModelMetadataStoreBase: """Return a ModelMetadataStore initialized on the same database.""" - return ModelMetadataStore(self._db) + return self._metadata_store def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: """ @@ -330,18 +332,18 @@ def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]: :param tags: Set of tags to search for. All tags must be present. 
""" - store = ModelMetadataStore(self._db) + store = ModelMetadataStoreSQL(self._db) keys = store.search_by_tag(tags) return [self.get_model(x) for x in keys] def list_tags(self) -> Set[str]: """Return a unique set of all the model tags in the metadata database.""" - store = ModelMetadataStore(self._db) + store = ModelMetadataStoreSQL(self._db) return store.list_tags() def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: """List metadata for all models that have it.""" - store = ModelMetadataStore(self._db) + store = ModelMetadataStoreSQL(self._db) return store.list_all_metadata() def list_models( diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 9c386c209ce..3623b623a94 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -25,6 +25,7 @@ ModelSource, URLModelSource, ) +from invokeai.app.services.model_metadata import ModelMetadataStoreSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.backend.model_manager import ( @@ -45,7 +46,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService logger = InvokeAILogger.get_logger(config=app_config) image_files = DiskImageFileStorage(f"{app_config.output_path}/images") db = init_db(config=app_config, logger=logger, image_files=image_files) - obj: ModelRecordServiceBase = ModelRecordServiceSQL(db) + obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) return obj @@ -54,12 +55,10 @@ def initialize_installer( ) -> ModelInstallServiceBase: """Return an initialized ModelInstallService object.""" record_store = initialize_record_store(app_config) - metadata_store = record_store.metadata_store download_queue = DownloadQueueService() installer = ModelInstallService( app_config=app_config, record_store=record_store, - metadata_store=metadata_store, download_queue=download_queue, event_bus=event_bus, ) @@ -287,14 +286,14 @@ def add_or_delete(self, selections: InstallSelections) -> None: model_name=model_name, ) if len(matches) > 1: - print( - f"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate." + self._logger.error( + "{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. 
sd-1/main/my_model) to disambiguate" ) elif not matches: - print(f"{model_to_remove}: unknown model") + self._logger.error(f"{model_to_remove}: unknown model") else: for m in matches: - print(f"Deleting {m.type}:{m.name}") + self._logger.info(f"Deleting {m.type}:{m.name}") installer.delete(m.key) installer.wait_for_installs() diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index 966a739237a..a3a840b6259 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -4,10 +4,6 @@ """ from importlib import import_module from pathlib import Path -from typing import Optional - -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.util.logging import InvokeAILogger from .convert_cache.convert_cache_default import ModelConvertCache from .load_base import AnyModelLoader, LoadedModel @@ -19,16 +15,3 @@ import_module(f"{__package__}.model_loaders.{module}") __all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"] - - -def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader: - app_config = app_config or InvokeAIAppConfig.get_config() - logger = InvokeAILogger.get_logger(config=app_config) - return AnyModelLoader( - app_config=app_config, - logger=logger, - ram_cache=ModelCache( - logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size - ), - convert_cache=ModelConvertCache(app_config.models_convert_cache_path), - ) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 7649dee762b..4c5e899aa3b 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -39,21 +39,21 @@ class LoadedModel: """Context manager object that mediates transfer from RAM<->VRAM.""" config: AnyModelConfig - locker: ModelLockerBase + _locker: ModelLockerBase def __enter__(self) -> AnyModel: """Context entry.""" - self.locker.lock() + self._locker.lock() return self.model def __exit__(self, *args: Any, **kwargs: Any) -> None: """Context exit.""" - self.locker.unlock() + self._locker.unlock() @property def model(self) -> AnyModel: """Return the model without locking it.""" - return self.locker.model + return self._locker.model class ModelLoaderBase(ABC): diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 1dac121a300..79c9311de1d 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -75,7 +75,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo model_path = self._convert_if_needed(model_config, model_path, submodel_type) locker = self._load_if_needed(model_config, model_path, submodel_type) - return LoadedModel(config=model_config, locker=locker) + return LoadedModel(config=model_config, _locker=locker) def _get_model_path( self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index 2c94af4af3b..108f1f0e6f7 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -39,10 +39,7 @@ class ModelMerger(object): def __init__(self, installer: ModelInstallServiceBase): """ - Initialize a ModelMerger object. 
- - :param store: Underlying storage manager for the running process. - :param config: InvokeAIAppConfig object (if not provided, default will be selected). + Initialize a ModelMerger object with the model installer. """ self._installer = installer diff --git a/invokeai/backend/model_manager/metadata/__init__.py b/invokeai/backend/model_manager/metadata/__init__.py index 672e378c7fe..a35e55f3d24 100644 --- a/invokeai/backend/model_manager/metadata/__init__.py +++ b/invokeai/backend/model_manager/metadata/__init__.py @@ -18,7 +18,7 @@ if data.allow_commercial_use: print("Commercial use of this model is allowed") """ -from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch +from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase from .metadata_base import ( AnyModelRepoMetadata, AnyModelRepoMetadataValidator, @@ -31,7 +31,6 @@ RemoteModelFile, UnknownMetadataException, ) -from .metadata_store import ModelMetadataStore __all__ = [ "AnyModelRepoMetadata", @@ -42,7 +41,7 @@ "HuggingFaceMetadata", "HuggingFaceMetadataFetch", "LicenseRestrictions", - "ModelMetadataStore", + "ModelMetadataFetchBase", "BaseMetadata", "ModelMetadataWithFiles", "RemoteModelFile", diff --git a/invokeai/frontend/merge/merge_diffusers.py b/invokeai/frontend/merge/merge_diffusers.py index 92b98b52f96..5484040674d 100644 --- a/invokeai/frontend/merge/merge_diffusers.py +++ b/invokeai/frontend/merge/merge_diffusers.py @@ -6,20 +6,40 @@ """ import argparse import curses +import re import sys from argparse import Namespace from pathlib import Path -from typing import List +from typing import List, Optional, Tuple import npyscreen from npyscreen import widget -import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType +from invokeai.app.services.download import DownloadQueueService +from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage +from invokeai.app.services.model_install import ModelInstallService +from invokeai.app.services.model_metadata import ModelMetadataStoreSQL +from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL +from invokeai.app.services.shared.sqlite.sqlite_util import init_db +from invokeai.backend.model_manager import ( + BaseModelType, + ModelFormat, + ModelType, + ModelVariantType, +) +from invokeai.backend.model_manager.merge import ModelMerger +from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox config = InvokeAIAppConfig.get_config() +logger = InvokeAILogger.get_logger() + +BASE_TYPES = [ + (BaseModelType.StableDiffusion1, "Models Built on SD-1.x"), + (BaseModelType.StableDiffusion2, "Models Built on SD-2.x"), + (BaseModelType.StableDiffusionXL, "Models Built on SDXL"), +] def _parse_args() -> Namespace: @@ -48,7 +68,7 @@ def _parse_args() -> Namespace: parser.add_argument( "--base_model", type=str, - choices=[x.value for x in BaseModelType], + choices=[x[0].value for x in BASE_TYPES], help="The base model shared by the models to be merged", ) parser.add_argument( @@ -98,17 +118,17 @@ def __init__(self, parentApp, name): super().__init__(parentApp, name) @property - def model_manager(self): - return self.parentApp.model_manager + def record_store(self): + return self.parentApp.record_store def afterEditing(self): self.parentApp.setNextForm(None) def 
create(self): window_height, window_width = curses.initscr().getmaxyx() - - self.model_names = self.get_model_names() self.current_base = 0 + self.models = self.get_models(BASE_TYPES[self.current_base][0]) + self.model_names = [x[1] for x in self.models] max_width = max([len(x) for x in self.model_names]) max_width += 6 horizontal_layout = max_width * 3 < window_width @@ -128,11 +148,7 @@ def create(self): self.nextrely += 1 self.base_select = self.add_widget_intelligent( SingleSelectColumns, - values=[ - "Models Built on SD-1.x", - "Models Built on SD-2.x", - "Models Built on SDXL", - ], + values=[x[1] for x in BASE_TYPES], value=[self.current_base], columns=4, max_height=2, @@ -263,21 +279,20 @@ def on_cancel(self): sys.exit(0) def marshall_arguments(self) -> dict: - model_names = self.model_names + model_keys = [x[0] for x in self.models] models = [ - model_names[self.model1.value[0]], - model_names[self.model2.value[0]], + model_keys[self.model1.value[0]], + model_keys[self.model2.value[0]], ] if self.model3.value[0] > 0: - models.append(model_names[self.model3.value[0] - 1]) + models.append(model_keys[self.model3.value[0] - 1]) interp = "add_difference" else: interp = self.interpolations[self.merge_method.value[0]] - bases = ["sd-1", "sd-2", "sdxl"] args = { - "model_names": models, - "base_model": BaseModelType(bases[self.base_select.value[0]]), + "model_keys": models, + "base_model": tuple(BaseModelType)[self.base_select.value[0]], "alpha": self.alpha.value, "interp": interp, "force": self.force.value, @@ -311,18 +326,18 @@ def validate_field_values(self) -> bool: else: return True - def get_model_names(self, base_model: BaseModelType = BaseModelType.StableDiffusion1) -> List[str]: - model_names = [ - info["model_name"] - for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model) - if info["model_format"] == "diffusers" + def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name + models = [ + (x.key, x.name) + for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model) + if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal") ] - return sorted(model_names) + return sorted(models, key=lambda x: x[1]) - def _populate_models(self, value=None): - bases = ["sd-1", "sd-2", "sdxl"] - base_model = BaseModelType(bases[value[0]]) - self.model_names = self.get_model_names(base_model) + def _populate_models(self, value: List[int]): + base_model = BASE_TYPES[value[0]][0] + self.models = self.get_models(base_model) + self.model_names = [x[1] for x in self.models] models_plus_none = self.model_names.copy() models_plus_none.insert(0, "None") @@ -334,24 +349,24 @@ def _populate_models(self, value=None): class Mergeapp(npyscreen.NPSAppManaged): - def __init__(self, model_manager: ModelManager): + def __init__(self, record_store: ModelRecordServiceBase): super().__init__() - self.model_manager = model_manager + self.record_store = record_store def onStart(self): npyscreen.setTheme(npyscreen.Themes.ElegantTheme) self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings") -def run_gui(args: Namespace): - model_manager = ModelManager(config.model_conf_path) - mergeapp = Mergeapp(model_manager) +def run_gui(args: Namespace) -> None: + record_store: ModelRecordServiceBase = get_config_store() + mergeapp = Mergeapp(record_store) mergeapp.run() - args = mergeapp.merge_arguments - merger = ModelMerger(model_manager) + merger = 
get_model_merger(record_store) merger.merge_diffusion_models_and_save(**args) - logger.info(f'Models merged into new model: "{args["merged_model_name"]}".') + merged_model_name = args["merged_model_name"] + logger.info(f'Models merged into new model: "{merged_model_name}".') def run_cli(args: Namespace): @@ -364,20 +379,54 @@ def run_cli(args: Namespace): args.merged_model_name = "+".join(args.model_names) logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"') - model_manager = ModelManager(config.model_conf_path) + record_store: ModelRecordServiceBase = get_config_store() assert ( - not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber + len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' - merger = ModelMerger(model_manager) - merger.merge_diffusion_models_and_save(**vars(args)) + merger = get_model_merger(record_store) + model_keys = [] + for name in args.model_names: + if len(name) == 32 and re.match(r"^[0-9a-f]$", name): + model_keys.append(name) + else: + models = record_store.search_by_attr( + model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model) + ) + assert len(models) > 0, f"{name}: Unknown model" + assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead." + model_keys.append(models[0].key) + + merger.merge_diffusion_models_and_save( + alpha=args.alpha, + model_keys=model_keys, + merged_model_name=args.merged_model_name, + interp=args.interp, + force=args.force, + ) logger.info(f'Models merged into new model: "{args.merged_model_name}".') +def get_config_store() -> ModelRecordServiceSQL: + output_path = config.output_path + assert output_path is not None + image_files = DiskImageFileStorage(output_path / "images") + db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files) + return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) + + +def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger: + installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService()) + installer.start() + return ModelMerger(installer) + + def main(): args = _parse_args() if args.root_dir: config.parse_args(["--root", str(args.root_dir)]) + else: + config.parse_args([]) try: if args.front_end: diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index 774f7501dc2..f67b5a2ac55 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -1,4 +1,5 @@ import logging +from unittest.mock import Mock import pytest @@ -64,7 +65,7 @@ def mock_services() -> InvocationServices: images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), logger=logging, # type: ignore - model_manager=None, # type: ignore + model_manager=Mock(), # type: ignore download_queue=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 46afe0105b5..852e1da979c 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -8,6 +8,7 @@ import pytest from invokeai.app.services.config import InvokeAIAppConfig 
+from invokeai.app.services.model_metadata import ModelMetadataStoreSQL from invokeai.app.services.model_records import ( DuplicateModelException, ModelRecordOrderBy, @@ -36,7 +37,7 @@ def store( config = InvokeAIAppConfig(root=datadir) logger = InvokeAILogger.get_logger(config=config) db = create_mock_sqlite_database(config, logger) - return ModelRecordServiceSQL(db) + return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) def example_config() -> TextualInversionConfig: diff --git a/tests/backend/model_manager_2/model_manager_2_fixtures.py b/tests/backend/model_manager_2/model_manager_2_fixtures.py index d85eab67dd3..ebdc9cb5cd6 100644 --- a/tests/backend/model_manager_2/model_manager_2_fixtures.py +++ b/tests/backend/model_manager_2/model_manager_2_fixtures.py @@ -14,6 +14,7 @@ from invokeai.app.services.download import DownloadQueueService from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase +from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from invokeai.app.services.model_records import ModelRecordServiceSQL from invokeai.backend.model_manager.config import ( BaseModelType, @@ -21,7 +22,6 @@ ModelType, ) from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache -from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger from tests.backend.model_manager_2.model_metadata.metadata_examples import ( RepoCivitaiModelMetadata1, @@ -104,7 +104,7 @@ def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordS def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL: logger = InvokeAILogger.get_logger(config=mm2_app_config) db = create_mock_sqlite_database(mm2_app_config, logger) - store = ModelRecordServiceSQL(db) + store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) # add five simple config records to the database raw1 = { "path": "/tmp/foo1", @@ -163,15 +163,14 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL @pytest.fixture -def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore: - db = mm2_record_store._db # to ensure we are sharing the same database - return ModelMetadataStore(db) +def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase: + return mm2_record_store.metadata_store @pytest.fixture def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: """This fixtures defines a series of mock URLs for testing download and installation.""" - sess = TestSession() + sess: Session = TestSession() sess.mount( "https://test.com/missing_model.safetensors", TestAdapter( @@ -258,8 +257,7 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> Mo logger = InvokeAILogger.get_logger() db = create_mock_sqlite_database(mm2_app_config, logger) events = DummyEventService() - store = ModelRecordServiceSQL(db) - metadata_store = ModelMetadataStore(db) + store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) download_queue = DownloadQueueService(requests_session=mm2_session) download_queue.start() @@ -268,7 +266,6 @@ def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> Mo app_config=mm2_app_config, record_store=store, download_queue=download_queue, - metadata_store=metadata_store, event_bus=events, session=mm2_session, ) diff 
--git a/tests/backend/model_manager_2/model_metadata/test_model_metadata.py b/tests/backend/model_manager_2/model_metadata/test_model_metadata.py
index 5a2ec937673..f61eab1b5d0 100644
--- a/tests/backend/model_manager_2/model_metadata/test_model_metadata.py
+++ b/tests/backend/model_manager_2/model_metadata/test_model_metadata.py
@@ -8,6 +8,7 @@
 from pydantic.networks import HttpUrl
 from requests.sessions import Session
 
+from invokeai.app.services.model_metadata import ModelMetadataStoreBase
 from invokeai.backend.model_manager.config import ModelRepoVariant
 from invokeai.backend.model_manager.metadata import (
     CivitaiMetadata,
@@ -15,14 +16,13 @@
     CommercialUsage,
     HuggingFaceMetadata,
     HuggingFaceMetadataFetch,
-    ModelMetadataStore,
     UnknownMetadataException,
 )
 from invokeai.backend.model_manager.util import select_hf_files
 from tests.backend.model_manager_2.model_manager_2_fixtures import *  # noqa F403
 
 
-def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
+def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
     tags = {"text-to-image", "diffusers"}
     input_metadata = HuggingFaceMetadata(
         name="sdxl-vae",
@@ -40,7 +40,7 @@ def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None
     assert mm2_metadata_store.list_tags() == tags
 
 
-def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
+def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
     input_metadata = HuggingFaceMetadata(
         name="sdxl-vae",
         author="stabilityai",
@@ -57,7 +57,7 @@ def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
     assert input_metadata == output_metadata
 
 
-def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None:
+def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
     metadata1 = HuggingFaceMetadata(
         name="sdxl-vae",
         author="stabilityai",

From a5d4d1ee47d61062b055f714b9b8b2be94b08040 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sat, 17 Feb 2024 11:45:32 -0500
Subject: [PATCH 107/340] Tidy names and locations of modules

- Rename old "model_management" directory to "model_management_OLD" in order
  to catch dangling references to the original model manager.
- Caught and fixed most dangling references (still checking)
- Rename lora, textual_inversion and model_patcher modules
- Introduce a RawModel base class to simplify the Union returned by the
  model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and a
  finalizer to the queue and installer fixtures that will stop the services
  and release threads.
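
For readers skimming the patch, here is a minimal sketch of the RawModel pattern this commit introduces. Only the class names, the inheritance relationship, and the shape of the AnyModel union are taken from the diff below; the class bodies are illustrative stand-ins (the real wrappers carry layer dictionaries, embedding tensors, and device/dtype handling).

from typing import Union

import torch
from diffusers import ModelMixin  # base class of diffusers/transformers models


class RawModel:
    """Marker base class for InvokeAI's wrappers around torch modules."""


class LoRAModelRaw(RawModel):
    """Stand-in for the real LoRA wrapper, which holds per-layer weight deltas."""


class TextualInversionModelRaw(RawModel):
    """Stand-in for the real TI wrapper, which holds embedding tensors."""


# ModelMixin covers diffusers/transformers models, RawModel covers the InvokeAI
# wrappers (LoRA, textual inversion, IP-Adapter, ONNX runtime), and plain
# torch.nn.Module covers everything else.
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]


def is_raw_wrapper(model: AnyModel) -> bool:
    # The model patcher can test one base class instead of several concrete classes.
    return isinstance(model, RawModel)

Because config.py now only needs RawModel, it no longer imports IPAdapter or the old EmbeddingModelRaw, which is what breaks the circular-import chain described in the raw_model.py docstring further down in this patch.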
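
The last bullet refers to pytest's finalizer hook. The updated fixture bodies are not reproduced in this excerpt, so the snippet below is only a generic sketch of that teardown pattern; FakeQueueService is a hypothetical stand-in, not an InvokeAI class.

import pytest


class FakeQueueService:
    """Hypothetical service that owns worker threads (stand-in for the real queue/installer)."""

    def __init__(self) -> None:
        self.running = False

    def start(self) -> None:
        self.running = True

    def stop(self) -> None:
        self.running = False


@pytest.fixture
def queue(request: pytest.FixtureRequest) -> FakeQueueService:
    service = FakeQueueService()
    service.start()
    # The finalizer runs at teardown even if the test fails, so the service's
    # worker threads are always stopped and released.
    request.addfinalizer(service.stop)
    return service


def test_queue_is_running(queue: FakeQueueService) -> None:
    assert queue.running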
--- invokeai/app/invocations/compel.py | 6 +- invokeai/app/invocations/latent.py | 4 +- .../services/model_load/model_load_default.py | 23 +- .../app/services/model_manager/__init__.py | 3 +- invokeai/backend/ip_adapter/ip_adapter.py | 3 +- invokeai/backend/{embeddings => }/lora.py | 7 +- .../README.md | 0 .../__init__.py | 0 .../convert_ckpt_to_diffusers.py | 0 .../detect_baked_in_vae.py | 0 .../lora.py | 0 .../memory_snapshot.py | 0 .../model_cache.py | 0 .../model_load_optimizations.py | 0 .../model_manager.py | 0 .../model_merge.py | 0 .../model_probe.py | 0 .../model_search.py | 0 .../models/__init__.py | 0 .../models/base.py | 0 .../models/clip_vision.py | 0 .../models/controlnet.py | 0 .../models/ip_adapter.py | 0 .../models/lora.py | 0 .../models/sdxl.py | 0 .../models/stable_diffusion.py | 0 .../models/stable_diffusion_onnx.py | 0 .../models/t2i_adapter.py | 0 .../models/textual_inversion.py | 0 .../models/vae.py | 0 .../seamless.py | 0 .../util.py | 0 invokeai/backend/model_manager/config.py | 9 +- .../libc_util.py | 0 .../model_manager/load/memory_snapshot.py | 4 +- .../model_manager/load/model_loaders/lora.py | 2 +- .../load/model_loaders/textual_inversion.py | 2 +- invokeai/backend/model_manager/probe.py | 9 +- .../backend/model_manager/util/libc_util.py | 75 ++ .../backend/model_manager/util/model_util.py | 129 +++ .../backend/{embeddings => }/model_patcher.py | 0 invokeai/backend/onnx/onnx_runtime.py | 1 + invokeai/backend/raw_model.py | 14 + .../{embeddings => }/textual_inversion.py | 6 +- invokeai/backend/util/test_utils.py | 45 +- invokeai/configs/INITIAL_MODELS.yaml.OLD | 153 ---- invokeai/configs/models.yaml.example | 47 - .../frontend/install/model_install.py.OLD | 845 ------------------ .../frontend/merge/merge_diffusers.py.OLD | 438 --------- .../model_install/test_model_install.py | 2 +- .../model_records/test_model_records_sql.py | 2 +- tests/backend/ip_adapter/test_ip_adapter.py | 2 +- .../data/invokeai_root/README | 0 .../stable-diffusion/v1-inference.yaml | 0 .../data/invokeai_root/databases/README | 0 .../data/invokeai_root/models/README | 0 .../test-diffusers-main/model_index.json | 0 .../scheduler/scheduler_config.json | 0 .../text_encoder/config.json | 0 .../text_encoder/model.fp16.safetensors | 0 .../text_encoder/model.safetensors | 0 .../text_encoder_2/config.json | 0 .../text_encoder_2/model.fp16.safetensors | 0 .../text_encoder_2/model.safetensors | 0 .../test-diffusers-main/tokenizer/merges.txt | 0 .../tokenizer/special_tokens_map.json | 0 .../tokenizer/tokenizer_config.json | 0 .../test-diffusers-main/tokenizer/vocab.json | 0 .../tokenizer_2/merges.txt | 0 .../tokenizer_2/special_tokens_map.json | 0 .../tokenizer_2/tokenizer_config.json | 0 .../tokenizer_2/vocab.json | 0 .../test-diffusers-main/unet/config.json | 0 .../diffusion_pytorch_model.fp16.safetensors | 0 .../unet/diffusion_pytorch_model.safetensors | 0 .../test-diffusers-main/vae/config.json | 0 .../diffusion_pytorch_model.fp16.safetensors | 0 .../vae/diffusion_pytorch_model.safetensors | 0 .../test_files/test_embedding.safetensors | Bin .../model_loading/test_model_load.py | 11 +- .../model_manager_fixtures.py} | 101 ++- .../model_metadata/metadata_examples.py | 0 .../model_metadata/test_model_metadata.py | 2 +- .../test_libc_util.py | 2 +- .../test_lora.py | 4 +- .../test_memory_snapshot.py | 6 +- .../test_model_load_optimization.py | 2 +- .../util/test_hf_model_select.py | 0 tests/conftest.py | 5 - 89 files changed, 355 insertions(+), 1609 deletions(-) rename invokeai/backend/{embeddings => 
}/lora.py (99%) rename invokeai/backend/{model_management => model_management_OLD}/README.md (100%) rename invokeai/backend/{model_management => model_management_OLD}/__init__.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/convert_ckpt_to_diffusers.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/detect_baked_in_vae.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/lora.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/memory_snapshot.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/model_cache.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/model_load_optimizations.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/model_manager.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/model_merge.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/model_probe.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/model_search.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/__init__.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/base.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/clip_vision.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/controlnet.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/ip_adapter.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/lora.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/sdxl.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/stable_diffusion.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/stable_diffusion_onnx.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/t2i_adapter.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/textual_inversion.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/models/vae.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/seamless.py (100%) rename invokeai/backend/{model_management => model_management_OLD}/util.py (100%) rename invokeai/backend/{model_management => model_manager}/libc_util.py (100%) create mode 100644 invokeai/backend/model_manager/util/libc_util.py create mode 100644 invokeai/backend/model_manager/util/model_util.py rename invokeai/backend/{embeddings => }/model_patcher.py (100%) create mode 100644 invokeai/backend/raw_model.py rename invokeai/backend/{embeddings => }/textual_inversion.py (97%) delete mode 100644 invokeai/configs/INITIAL_MODELS.yaml.OLD delete mode 100644 invokeai/configs/models.yaml.example delete mode 100644 invokeai/frontend/install/model_install.py.OLD delete mode 100644 invokeai/frontend/merge/merge_diffusers.py.OLD rename tests/backend/{model_manager_2 => model_manager}/data/invokeai_root/README (100%) rename tests/backend/{model_manager_2 => model_manager}/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml (100%) rename tests/backend/{model_manager_2 => model_manager}/data/invokeai_root/databases/README (100%) rename tests/backend/{model_manager_2 => model_manager}/data/invokeai_root/models/README (100%) rename tests/backend/{model_manager_2 => 
model_manager}/data/test_files/test-diffusers-main/model_index.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/scheduler/scheduler_config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/text_encoder/config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/text_encoder/model.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/text_encoder_2/config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer/merges.txt (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer/vocab.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer_2/merges.txt (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/tokenizer_2/vocab.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/unet/config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/vae/config.json (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/data/test_files/test_embedding.safetensors (100%) rename tests/backend/{model_manager_2 => model_manager}/model_loading/test_model_load.py (61%) rename tests/backend/{model_manager_2/model_manager_2_fixtures.py => model_manager/model_manager_fixtures.py} (80%) rename tests/backend/{model_manager_2 => model_manager}/model_metadata/metadata_examples.py (100%) rename tests/backend/{model_manager_2 => model_manager}/model_metadata/test_model_metadata.py (99%) rename tests/backend/{model_management => model_manager}/test_libc_util.py (88%) rename tests/backend/{model_management => model_manager}/test_lora.py (96%) rename tests/backend/{model_management => 
model_manager}/test_memory_snapshot.py (87%) rename tests/backend/{model_management => model_manager}/test_model_load_optimization.py (96%) rename tests/backend/{model_manager_2 => model_manager}/util/test_hf_model_select.py (100%) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 5159d5b89c5..593121ba60b 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -17,9 +17,9 @@ from invokeai.app.services.model_records import UnknownModelException from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt -from invokeai.backend.embeddings.lora import LoRAModelRaw -from invokeai.backend.embeddings.model_patcher import ModelPatcher -from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw +from invokeai.backend.lora import LoRAModelRaw +from invokeai.backend.model_patcher import ModelPatcher +from invokeai.backend.textual_inversion import TextualInversionModelRaw from invokeai.backend.model_manager import ModelType from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 1f21b539dc9..bfe7255b628 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -50,10 +50,10 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import prepare_control_image -from invokeai.backend.embeddings.lora import LoRAModelRaw -from invokeai.backend.embeddings.model_patcher import ModelPatcher from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus +from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import BaseModelType, LoadedModel +from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo from invokeai.backend.util.silence_warnings import SilenceWarnings diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 29b297c8145..fa96a4672d1 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -21,11 +21,11 @@ class ModelLoadService(ModelLoadServiceBase): """Wrapper around AnyModelLoader.""" def __init__( - self, - app_config: InvokeAIAppConfig, - record_store: ModelRecordServiceBase, - ram_cache: Optional[ModelCacheBase[AnyModel]] = None, - convert_cache: Optional[ModelConvertCacheBase] = None, + self, + app_config: InvokeAIAppConfig, + record_store: ModelRecordServiceBase, + ram_cache: ModelCacheBase[AnyModel], + convert_cache: ModelConvertCacheBase, ): """Initialize the model load service.""" logger = InvokeAILogger.get_logger(self.__class__.__name__) @@ -34,17 +34,8 @@ def __init__( self._any_loader = AnyModelLoader( app_config=app_config, logger=logger, - ram_cache=ram_cache - or ModelCache( - max_cache_size=app_config.ram_cache_size, - max_vram_cache_size=app_config.vram_cache_size, - logger=logger, - ), - convert_cache=convert_cache - or ModelConvertCache( - cache_path=app_config.models_convert_cache_path, - 
max_size=app_config.convert_cache_size, - ), + ram_cache=ram_cache, + convert_cache=convert_cache, ) def start(self, invoker: Invoker) -> None: diff --git a/invokeai/app/services/model_manager/__init__.py b/invokeai/app/services/model_manager/__init__.py index 5e281922a8b..66707493f71 100644 --- a/invokeai/app/services/model_manager/__init__.py +++ b/invokeai/app/services/model_manager/__init__.py @@ -3,9 +3,10 @@ from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import LoadedModel -from .model_manager_default import ModelManagerService +from .model_manager_default import ModelManagerServiceBase, ModelManagerService __all__ = [ + "ModelManagerServiceBase", "ModelManagerService", "AnyModel", "AnyModelConfig", diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index b4706ea99c0..3ba6fc5a23c 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -10,6 +10,7 @@ from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights from .resampler import Resampler +from ..raw_model import RawModel class ImageProjModel(torch.nn.Module): @@ -91,7 +92,7 @@ def forward(self, image_embeds): return clip_extra_context_tokens -class IPAdapter: +class IPAdapter(RawModel): """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf""" def __init__( diff --git a/invokeai/backend/embeddings/lora.py b/invokeai/backend/lora.py similarity index 99% rename from invokeai/backend/embeddings/lora.py rename to invokeai/backend/lora.py index 3c7ef074efe..fb0c23067fb 100644 --- a/invokeai/backend/embeddings/lora.py +++ b/invokeai/backend/lora.py @@ -10,8 +10,7 @@ from typing_extensions import Self from invokeai.backend.model_manager import BaseModelType - -from .embedding_base import EmbeddingModelRaw +from .raw_model import RawModel class LoRALayerBase: @@ -367,9 +366,7 @@ def to( AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] - -# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix -class LoRAModelRaw(EmbeddingModelRaw): # (torch.nn.Module): +class LoRAModelRaw(RawModel): # (torch.nn.Module): _name: str layers: Dict[str, AnyLoRALayer] diff --git a/invokeai/backend/model_management/README.md b/invokeai/backend/model_management_OLD/README.md similarity index 100% rename from invokeai/backend/model_management/README.md rename to invokeai/backend/model_management_OLD/README.md diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management_OLD/__init__.py similarity index 100% rename from invokeai/backend/model_management/__init__.py rename to invokeai/backend/model_management_OLD/__init__.py diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management_OLD/convert_ckpt_to_diffusers.py similarity index 100% rename from invokeai/backend/model_management/convert_ckpt_to_diffusers.py rename to invokeai/backend/model_management_OLD/convert_ckpt_to_diffusers.py diff --git a/invokeai/backend/model_management/detect_baked_in_vae.py b/invokeai/backend/model_management_OLD/detect_baked_in_vae.py similarity index 100% rename from invokeai/backend/model_management/detect_baked_in_vae.py rename to invokeai/backend/model_management_OLD/detect_baked_in_vae.py diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management_OLD/lora.py similarity index 100% rename 
from invokeai/backend/model_management/lora.py rename to invokeai/backend/model_management_OLD/lora.py diff --git a/invokeai/backend/model_management/memory_snapshot.py b/invokeai/backend/model_management_OLD/memory_snapshot.py similarity index 100% rename from invokeai/backend/model_management/memory_snapshot.py rename to invokeai/backend/model_management_OLD/memory_snapshot.py diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management_OLD/model_cache.py similarity index 100% rename from invokeai/backend/model_management/model_cache.py rename to invokeai/backend/model_management_OLD/model_cache.py diff --git a/invokeai/backend/model_management/model_load_optimizations.py b/invokeai/backend/model_management_OLD/model_load_optimizations.py similarity index 100% rename from invokeai/backend/model_management/model_load_optimizations.py rename to invokeai/backend/model_management_OLD/model_load_optimizations.py diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management_OLD/model_manager.py similarity index 100% rename from invokeai/backend/model_management/model_manager.py rename to invokeai/backend/model_management_OLD/model_manager.py diff --git a/invokeai/backend/model_management/model_merge.py b/invokeai/backend/model_management_OLD/model_merge.py similarity index 100% rename from invokeai/backend/model_management/model_merge.py rename to invokeai/backend/model_management_OLD/model_merge.py diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management_OLD/model_probe.py similarity index 100% rename from invokeai/backend/model_management/model_probe.py rename to invokeai/backend/model_management_OLD/model_probe.py diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management_OLD/model_search.py similarity index 100% rename from invokeai/backend/model_management/model_search.py rename to invokeai/backend/model_management_OLD/model_search.py diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management_OLD/models/__init__.py similarity index 100% rename from invokeai/backend/model_management/models/__init__.py rename to invokeai/backend/model_management_OLD/models/__init__.py diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management_OLD/models/base.py similarity index 100% rename from invokeai/backend/model_management/models/base.py rename to invokeai/backend/model_management_OLD/models/base.py diff --git a/invokeai/backend/model_management/models/clip_vision.py b/invokeai/backend/model_management_OLD/models/clip_vision.py similarity index 100% rename from invokeai/backend/model_management/models/clip_vision.py rename to invokeai/backend/model_management_OLD/models/clip_vision.py diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management_OLD/models/controlnet.py similarity index 100% rename from invokeai/backend/model_management/models/controlnet.py rename to invokeai/backend/model_management_OLD/models/controlnet.py diff --git a/invokeai/backend/model_management/models/ip_adapter.py b/invokeai/backend/model_management_OLD/models/ip_adapter.py similarity index 100% rename from invokeai/backend/model_management/models/ip_adapter.py rename to invokeai/backend/model_management_OLD/models/ip_adapter.py diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management_OLD/models/lora.py 
similarity index 100% rename from invokeai/backend/model_management/models/lora.py rename to invokeai/backend/model_management_OLD/models/lora.py diff --git a/invokeai/backend/model_management/models/sdxl.py b/invokeai/backend/model_management_OLD/models/sdxl.py similarity index 100% rename from invokeai/backend/model_management/models/sdxl.py rename to invokeai/backend/model_management_OLD/models/sdxl.py diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management_OLD/models/stable_diffusion.py similarity index 100% rename from invokeai/backend/model_management/models/stable_diffusion.py rename to invokeai/backend/model_management_OLD/models/stable_diffusion.py diff --git a/invokeai/backend/model_management/models/stable_diffusion_onnx.py b/invokeai/backend/model_management_OLD/models/stable_diffusion_onnx.py similarity index 100% rename from invokeai/backend/model_management/models/stable_diffusion_onnx.py rename to invokeai/backend/model_management_OLD/models/stable_diffusion_onnx.py diff --git a/invokeai/backend/model_management/models/t2i_adapter.py b/invokeai/backend/model_management_OLD/models/t2i_adapter.py similarity index 100% rename from invokeai/backend/model_management/models/t2i_adapter.py rename to invokeai/backend/model_management_OLD/models/t2i_adapter.py diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management_OLD/models/textual_inversion.py similarity index 100% rename from invokeai/backend/model_management/models/textual_inversion.py rename to invokeai/backend/model_management_OLD/models/textual_inversion.py diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management_OLD/models/vae.py similarity index 100% rename from invokeai/backend/model_management/models/vae.py rename to invokeai/backend/model_management_OLD/models/vae.py diff --git a/invokeai/backend/model_management/seamless.py b/invokeai/backend/model_management_OLD/seamless.py similarity index 100% rename from invokeai/backend/model_management/seamless.py rename to invokeai/backend/model_management_OLD/seamless.py diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_management_OLD/util.py similarity index 100% rename from invokeai/backend/model_management/util.py rename to invokeai/backend/model_management_OLD/util.py diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 42921f0b32c..bc4848b0a50 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -28,12 +28,11 @@ from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from typing_extensions import Annotated, Any, Dict -from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel +from ..raw_model import RawModel -from ..embeddings.embedding_base import EmbeddingModelRaw -from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus - -AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus, EmbeddingModelRaw] +# ModelMixin is the base class for all diffusers and transformers models +# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime +AnyModel = Union[ModelMixin, RawModel, torch.nn.Module] class InvalidModelConfigException(Exception): diff --git a/invokeai/backend/model_management/libc_util.py b/invokeai/backend/model_manager/libc_util.py similarity index 100% rename from 
invokeai/backend/model_management/libc_util.py rename to invokeai/backend/model_manager/libc_util.py diff --git a/invokeai/backend/model_manager/load/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py index 346f5dc4247..209d7166f36 100644 --- a/invokeai/backend/model_manager/load/memory_snapshot.py +++ b/invokeai/backend/model_manager/load/memory_snapshot.py @@ -5,7 +5,7 @@ import torch from typing_extensions import Self -from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 +from ..util.libc_util import LibcUtil, Struct_mallinfo2 GB = 2**30 # 1 GB @@ -97,4 +97,4 @@ def get_msg_line(prefix: str, val1: int, val2: int) -> str: if snapshot_1.vram is not None and snapshot_2.vram is not None: msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) - return "\n" + msg if len(msg) > 0 else msg + return msg diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index d8e5f920e24..6ff2dcc9182 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -7,7 +7,7 @@ from typing import Optional, Tuple from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.embeddings.lora import LoRAModelRaw +from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager import ( AnyModel, AnyModelConfig, diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py index 6635f6b43fe..94767479609 100644 --- a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py +++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Optional, Tuple -from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw +from invokeai.backend.textual_inversion import TextualInversionModelRaw from invokeai.backend.model_manager import ( AnyModel, AnyModelConfig, diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 2c2066d7c52..d511ffa875f 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -8,9 +8,7 @@ from picklescan.scanner import scan_file_path import invokeai.backend.util.logging as logger -from invokeai.backend.model_management.models.base import read_checkpoint_meta -from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat -from invokeai.backend.model_management.util import lora_token_vector_length +from .util.model_util import lora_token_vector_length, read_checkpoint_meta from invokeai.backend.util.util import SilenceWarnings from .config import ( @@ -55,7 +53,6 @@ }, } - class ProbeBase(object): """Base class for probes.""" @@ -653,8 +650,8 @@ def get_base_type(self) -> BaseModelType: class IPAdapterFolderProbe(FolderProbeBase): - def get_format(self) -> IPAdapterModelFormat: - return IPAdapterModelFormat.InvokeAI.value + def get_format(self) -> ModelFormat: + return ModelFormat.InvokeAI def get_base_type(self) -> BaseModelType: model_file = self.model_path / "ip_adapter.bin" diff --git a/invokeai/backend/model_manager/util/libc_util.py b/invokeai/backend/model_manager/util/libc_util.py new file mode 100644 index 00000000000..1fbcae0a93c --- /dev/null +++ b/invokeai/backend/model_manager/util/libc_util.py @@ -0,0 +1,75 @@ +import ctypes + + +class 
Struct_mallinfo2(ctypes.Structure): + """A ctypes Structure that matches the libc mallinfo2 struct. + + Docs: + - https://man7.org/linux/man-pages/man3/mallinfo.3.html + - https://www.gnu.org/software/libc/manual/html_node/Statistics-of-Malloc.html + + struct mallinfo2 { + size_t arena; /* Non-mmapped space allocated (bytes) */ + size_t ordblks; /* Number of free chunks */ + size_t smblks; /* Number of free fastbin blocks */ + size_t hblks; /* Number of mmapped regions */ + size_t hblkhd; /* Space allocated in mmapped regions (bytes) */ + size_t usmblks; /* See below */ + size_t fsmblks; /* Space in freed fastbin blocks (bytes) */ + size_t uordblks; /* Total allocated space (bytes) */ + size_t fordblks; /* Total free space (bytes) */ + size_t keepcost; /* Top-most, releasable space (bytes) */ + }; + """ + + _fields_ = [ + ("arena", ctypes.c_size_t), + ("ordblks", ctypes.c_size_t), + ("smblks", ctypes.c_size_t), + ("hblks", ctypes.c_size_t), + ("hblkhd", ctypes.c_size_t), + ("usmblks", ctypes.c_size_t), + ("fsmblks", ctypes.c_size_t), + ("uordblks", ctypes.c_size_t), + ("fordblks", ctypes.c_size_t), + ("keepcost", ctypes.c_size_t), + ] + + def __str__(self): + s = "" + s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n" + s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n" + s += f"{'smblks': <10}= {(self.smblks): >15} # Number of free fastbin blocks \n" + s += f"{'hblks': <10}= {(self.hblks): >15} # Number of mmapped regions \n" + s += f"{'hblkhd': <10}= {(self.hblkhd/2**30):15.5f} # Space allocated in mmapped regions (GB)\n" + s += f"{'usmblks': <10}= {(self.usmblks): >15} # Unused\n" + s += f"{'fsmblks': <10}= {(self.fsmblks/2**30):15.5f} # Space in freed fastbin blocks (GB)\n" + s += ( + f"{'uordblks': <10}= {(self.uordblks/2**30):15.5f} # Space used by in-use allocations (non-mmapped)" + " (GB)\n" + ) + s += f"{'fordblks': <10}= {(self.fordblks/2**30):15.5f} # Space in free blocks (non-mmapped) (GB)\n" + s += f"{'keepcost': <10}= {(self.keepcost/2**30):15.5f} # Top-most, releasable space (GB)\n" + return s + + +class LibcUtil: + """A utility class for interacting with the C Standard Library (`libc`) via ctypes. + + Note that this class will raise on __init__() if 'libc.so.6' can't be found. Take care to handle environments where + this shared library is not available. + + TODO: Improve cross-OS compatibility of this class. + """ + + def __init__(self): + self._libc = ctypes.cdll.LoadLibrary("libc.so.6") + + def mallinfo2(self) -> Struct_mallinfo2: + """Calls `libc` `mallinfo2`. 
+ + Docs: https://man7.org/linux/man-pages/man3/mallinfo.3.html + """ + mallinfo2 = self._libc.mallinfo2 + mallinfo2.restype = Struct_mallinfo2 + return mallinfo2() diff --git a/invokeai/backend/model_manager/util/model_util.py b/invokeai/backend/model_manager/util/model_util.py new file mode 100644 index 00000000000..6847a40878c --- /dev/null +++ b/invokeai/backend/model_manager/util/model_util.py @@ -0,0 +1,129 @@ +"""Utilities for parsing model files, used mostly by probe.py""" + +import json +import torch +from typing import Union +from pathlib import Path +from picklescan.scanner import scan_file_path + +def _fast_safetensors_reader(path: str): + checkpoint = {} + device = torch.device("meta") + with open(path, "rb") as f: + definition_len = int.from_bytes(f.read(8), "little") + definition_json = f.read(definition_len) + definition = json.loads(definition_json) + + if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in { + "pt", + "torch", + "pytorch", + }: + raise Exception("Supported only pytorch safetensors files") + definition.pop("__metadata__", None) + + for key, info in definition.items(): + dtype = { + "I8": torch.int8, + "I16": torch.int16, + "I32": torch.int32, + "I64": torch.int64, + "F16": torch.float16, + "F32": torch.float32, + "F64": torch.float64, + }[info["dtype"]] + + checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device) + + return checkpoint + +def read_checkpoint_meta(path: Union[str, Path], scan: bool = False): + if str(path).endswith(".safetensors"): + try: + checkpoint = _fast_safetensors_reader(path) + except Exception: + # TODO: create issue for support "meta"? + checkpoint = safetensors.torch.load_file(path, device="cpu") + else: + if scan: + scan_result = scan_file_path(path) + if scan_result.infected_files != 0: + raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.') + checkpoint = torch.load(path, map_location=torch.device("meta")) + return checkpoint + +def lora_token_vector_length(checkpoint: dict) -> int: + """ + Given a checkpoint in memory, return the lora token vector length + + :param checkpoint: The checkpoint + """ + + def _get_shape_1(key: str, tensor, checkpoint) -> int: + lora_token_vector_length = None + + if "." 
not in key: + return lora_token_vector_length # wrong key format + model_key, lora_key = key.split(".", 1) + + # check lora/locon + if lora_key == "lora_down.weight": + lora_token_vector_length = tensor.shape[1] + + # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes) + elif lora_key in ["hada_w1_b", "hada_w2_b"]: + lora_token_vector_length = tensor.shape[1] + + # check lokr (don't worry about lokr_t2 as it used only in 4d shapes) + elif "lokr_" in lora_key: + if model_key + ".lokr_w1" in checkpoint: + _lokr_w1 = checkpoint[model_key + ".lokr_w1"] + elif model_key + "lokr_w1_b" in checkpoint: + _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"] + else: + return lora_token_vector_length # unknown format + + if model_key + ".lokr_w2" in checkpoint: + _lokr_w2 = checkpoint[model_key + ".lokr_w2"] + elif model_key + "lokr_w2_b" in checkpoint: + _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"] + else: + return lora_token_vector_length # unknown format + + lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1] + + elif lora_key == "diff": + lora_token_vector_length = tensor.shape[1] + + # ia3 can be detected only by shape[0] in text encoder + elif lora_key == "weight" and "lora_unet_" not in model_key: + lora_token_vector_length = tensor.shape[0] + + return lora_token_vector_length + + lora_token_vector_length = None + lora_te1_length = None + lora_te2_length = None + for key, tensor in checkpoint.items(): + if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key): + lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) + elif key.startswith("lora_unet_") and ( + "time_emb_proj.lora_down" in key + ): # recognizes format at https://civitai.com/models/224641 + lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) + elif key.startswith("lora_te") and "_self_attn_" in key: + tmp_length = _get_shape_1(key, tensor, checkpoint) + if key.startswith("lora_te_"): + lora_token_vector_length = tmp_length + elif key.startswith("lora_te1_"): + lora_te1_length = tmp_length + elif key.startswith("lora_te2_"): + lora_te2_length = tmp_length + + if lora_te1_length is not None and lora_te2_length is not None: + lora_token_vector_length = lora_te1_length + lora_te2_length + + if lora_token_vector_length is not None: + break + + return lora_token_vector_length diff --git a/invokeai/backend/embeddings/model_patcher.py b/invokeai/backend/model_patcher.py similarity index 100% rename from invokeai/backend/embeddings/model_patcher.py rename to invokeai/backend/model_patcher.py diff --git a/invokeai/backend/onnx/onnx_runtime.py b/invokeai/backend/onnx/onnx_runtime.py index f79fa015692..9b2096abdf0 100644 --- a/invokeai/backend/onnx/onnx_runtime.py +++ b/invokeai/backend/onnx/onnx_runtime.py @@ -8,6 +8,7 @@ import onnx from onnx import numpy_helper from onnxruntime import InferenceSession, SessionOptions, get_available_providers +from ..raw_model import RawModel ONNX_WEIGHTS_NAME = "model.onnx" diff --git a/invokeai/backend/raw_model.py b/invokeai/backend/raw_model.py new file mode 100644 index 00000000000..2e224d538b3 --- /dev/null +++ b/invokeai/backend/raw_model.py @@ -0,0 +1,14 @@ +"""Base class for 'Raw' models. + +The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw, +and is used for type checking of calls to the model patcher. 
Its main purpose +is to avoid a circular import issues when lora.py tries to import BaseModelType +from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw +from lora.py. + +The term 'raw' was introduced to describe a wrapper around a torch.nn.Module +that adds additional methods and attributes. +""" + +class RawModel: + """Base class for 'Raw' model wrappers.""" diff --git a/invokeai/backend/embeddings/textual_inversion.py b/invokeai/backend/textual_inversion.py similarity index 97% rename from invokeai/backend/embeddings/textual_inversion.py rename to invokeai/backend/textual_inversion.py index 389edff039d..9a4fa0b5402 100644 --- a/invokeai/backend/embeddings/textual_inversion.py +++ b/invokeai/backend/textual_inversion.py @@ -8,11 +8,9 @@ from safetensors.torch import load_file from transformers import CLIPTokenizer from typing_extensions import Self +from .raw_model import RawModel -from .embedding_base import EmbeddingModelRaw - - -class TextualInversionModelRaw(EmbeddingModelRaw): +class TextualInversionModelRaw(RawModel): embedding: torch.Tensor # [n, 768]|[n, 1280] embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py index 685603cedc6..a3def182c8c 100644 --- a/invokeai/backend/util/test_utils.py +++ b/invokeai/backend/util/test_utils.py @@ -5,10 +5,9 @@ import pytest import torch -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.backend.install.model_install_backend import ModelInstall -from invokeai.backend.model_management.model_manager import LoadedModelInfo -from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType +from invokeai.app.services.model_manager import ModelManagerServiceBase +from invokeai.app.services.model_records import UnknownModelException +from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType @pytest.fixture(scope="session") @@ -16,31 +15,20 @@ def torch_device(): return "cuda" if torch.cuda.is_available() else "cpu" -@pytest.fixture(scope="module") -def model_installer(): - """A global ModelInstall pytest fixture to be used by many tests.""" - # HACK(ryand): InvokeAIAppConfig.get_config() returns a singleton config object. This can lead to weird interactions - # between tests that need to alter the config. For example, some tests change the 'root' directory in the config, - # which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround, - # we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using - # a singleton. - return ModelInstall(InvokeAIAppConfig.get_config(log_level="info")) - - def install_and_load_model( - model_installer: ModelInstall, + model_manager: ModelManagerServiceBase, model_path_id_or_url: Union[str, Path], model_name: str, base_model: BaseModelType, model_type: ModelType, submodel_type: Optional[SubModelType] = None, -) -> LoadedModelInfo: - """Install a model if it is not already installed, then get the LoadedModelInfo for that model. +) -> LoadedModel: + """Install a model if it is not already installed, then get the LoadedModel for that model. This is intended as a utility function for tests. Args: - model_installer (ModelInstall): The model installer. 
+ mm2_model_manager (ModelManagerServiceBase): The model manager model_path_id_or_url (Union[str, Path]): The path, HF ID, URL, etc. where the model can be installed from if it is not already installed. model_name (str): The model name, forwarded to ModelManager.get_model(...). @@ -51,16 +39,23 @@ def install_and_load_model( Returns: LoadedModelInfo """ - # If the requested model is already installed, return its LoadedModelInfo. - with contextlib.suppress(ModelNotFoundException): - return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type) + # If the requested model is already installed, return its LoadedModel + with contextlib.suppress(UnknownModelException): + # TODO: Replace with wrapper call + loaded_model: LoadedModel = model_manager.load.load_model_by_attr( + model_name=model_name, base_model=base_model, model_type=model_type + ) + return loaded_model # Install the requested model. - model_installer.heuristic_import(model_path_id_or_url) + job = model_manager.install.heuristic_import(model_path_id_or_url) + model_manager.install.wait_for_job(job, timeout=10) + assert job.complete try: - return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type) - except ModelNotFoundException as e: + loaded_model = model_manager.load.load_model_by_config(job.config_out) + return loaded_model + except UnknownModelException as e: raise Exception( "Failed to get model info after installing it. There could be a mismatch between the requested model and" f" the installation id ('{model_path_id_or_url}'). Error: {e}" diff --git a/invokeai/configs/INITIAL_MODELS.yaml.OLD b/invokeai/configs/INITIAL_MODELS.yaml.OLD deleted file mode 100644 index c230665e3a6..00000000000 --- a/invokeai/configs/INITIAL_MODELS.yaml.OLD +++ /dev/null @@ -1,153 +0,0 @@ -# This file predefines a few models that the user may want to install. 
-sd-1/main/stable-diffusion-v1-5: - description: Stable Diffusion version 1.5 diffusers model (4.27 GB) - repo_id: runwayml/stable-diffusion-v1-5 - recommended: True - default: True -sd-1/main/stable-diffusion-v1-5-inpainting: - description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB) - repo_id: runwayml/stable-diffusion-inpainting - recommended: True -sd-2/main/stable-diffusion-2-1: - description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-1 - recommended: False -sd-2/main/stable-diffusion-2-inpainting: - description: Stable Diffusion version 2.0 inpainting model (5.21 GB) - repo_id: stabilityai/stable-diffusion-2-inpainting - recommended: False -sdxl/main/stable-diffusion-xl-base-1-0: - description: Stable Diffusion XL base model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-base-1.0 - recommended: True -sdxl-refiner/main/stable-diffusion-xl-refiner-1-0: - description: Stable Diffusion XL refiner model (12 GB) - repo_id: stabilityai/stable-diffusion-xl-refiner-1.0 - recommended: False -sdxl/vae/sdxl-1-0-vae-fix: - description: Fine tuned version of the SDXL-1.0 VAE - repo_id: madebyollin/sdxl-vae-fp16-fix - recommended: True -sd-1/main/Analog-Diffusion: - description: An SD-1.5 model trained on diverse analog photographs (2.13 GB) - repo_id: wavymulder/Analog-Diffusion - recommended: False -sd-1/main/Deliberate_v5: - description: Versatile model that produces detailed images up to 768px (4.27 GB) - path: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors - recommended: False -sd-1/main/Dungeons-and-Diffusion: - description: Dungeons & Dragons characters (2.13 GB) - repo_id: 0xJustin/Dungeons-and-Diffusion - recommended: False -sd-1/main/dreamlike-photoreal-2: - description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB) - repo_id: dreamlike-art/dreamlike-photoreal-2.0 - recommended: False -sd-1/main/Inkpunk-Diffusion: - description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB) - repo_id: Envvi/Inkpunk-Diffusion - recommended: False -sd-1/main/openjourney: - description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB) - repo_id: prompthero/openjourney - recommended: False -sd-1/main/seek.art_MEGA: - repo_id: coreco/seek.art_MEGA - description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB) - recommended: False -sd-1/main/trinart_stable_diffusion_v2: - description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB) - repo_id: naclbit/trinart_stable_diffusion_v2 - recommended: False -sd-1/controlnet/qrcode_monster: - repo_id: monster-labs/control_v1p_sd15_qrcode_monster - subfolder: v2 -sd-1/controlnet/canny: - repo_id: lllyasviel/control_v11p_sd15_canny - recommended: True -sd-1/controlnet/inpaint: - repo_id: lllyasviel/control_v11p_sd15_inpaint -sd-1/controlnet/mlsd: - repo_id: lllyasviel/control_v11p_sd15_mlsd -sd-1/controlnet/depth: - repo_id: lllyasviel/control_v11f1p_sd15_depth - recommended: True -sd-1/controlnet/normal_bae: - repo_id: lllyasviel/control_v11p_sd15_normalbae -sd-1/controlnet/seg: - repo_id: lllyasviel/control_v11p_sd15_seg -sd-1/controlnet/lineart: - repo_id: lllyasviel/control_v11p_sd15_lineart - recommended: True -sd-1/controlnet/lineart_anime: - repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime -sd-1/controlnet/openpose: - 
repo_id: lllyasviel/control_v11p_sd15_openpose - recommended: True -sd-1/controlnet/scribble: - repo_id: lllyasviel/control_v11p_sd15_scribble - recommended: False -sd-1/controlnet/softedge: - repo_id: lllyasviel/control_v11p_sd15_softedge -sd-1/controlnet/shuffle: - repo_id: lllyasviel/control_v11e_sd15_shuffle -sd-1/controlnet/tile: - repo_id: lllyasviel/control_v11f1e_sd15_tile -sd-1/controlnet/ip2p: - repo_id: lllyasviel/control_v11e_sd15_ip2p -sd-1/t2i_adapter/canny-sd15: - repo_id: TencentARC/t2iadapter_canny_sd15v2 -sd-1/t2i_adapter/sketch-sd15: - repo_id: TencentARC/t2iadapter_sketch_sd15v2 -sd-1/t2i_adapter/depth-sd15: - repo_id: TencentARC/t2iadapter_depth_sd15v2 -sd-1/t2i_adapter/zoedepth-sd15: - repo_id: TencentARC/t2iadapter_zoedepth_sd15v1 -sdxl/t2i_adapter/canny-sdxl: - repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0 -sdxl/t2i_adapter/zoedepth-sdxl: - repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0 -sdxl/t2i_adapter/lineart-sdxl: - repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0 -sdxl/t2i_adapter/sketch-sdxl: - repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0 -sd-1/embedding/EasyNegative: - path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors - recommended: True -sd-1/embedding/ahx-beta-453407d: - repo_id: sd-concepts-library/ahx-beta-453407d -sd-1/lora/Ink scenery: - path: https://civitai.com/api/download/models/83390 -sd-1/ip_adapter/ip_adapter_sd15: - repo_id: InvokeAI/ip_adapter_sd15 - recommended: True - requires: - - InvokeAI/ip_adapter_sd_image_encoder - description: IP-Adapter for SD 1.5 models -sd-1/ip_adapter/ip_adapter_plus_sd15: - repo_id: InvokeAI/ip_adapter_plus_sd15 - recommended: False - requires: - - InvokeAI/ip_adapter_sd_image_encoder - description: Refined IP-Adapter for SD 1.5 models -sd-1/ip_adapter/ip_adapter_plus_face_sd15: - repo_id: InvokeAI/ip_adapter_plus_face_sd15 - recommended: False - requires: - - InvokeAI/ip_adapter_sd_image_encoder - description: Refined IP-Adapter for SD 1.5 models, adapted for faces -sdxl/ip_adapter/ip_adapter_sdxl: - repo_id: InvokeAI/ip_adapter_sdxl - recommended: False - requires: - - InvokeAI/ip_adapter_sdxl_image_encoder - description: IP-Adapter for SDXL models -any/clip_vision/ip_adapter_sd_image_encoder: - repo_id: InvokeAI/ip_adapter_sd_image_encoder - recommended: False - description: Required model for using IP-Adapters with SD-1/2 models -any/clip_vision/ip_adapter_sdxl_image_encoder: - repo_id: InvokeAI/ip_adapter_sdxl_image_encoder - recommended: False - description: Required model for using IP-Adapters with SDXL models diff --git a/invokeai/configs/models.yaml.example b/invokeai/configs/models.yaml.example deleted file mode 100644 index 98f8f77e62c..00000000000 --- a/invokeai/configs/models.yaml.example +++ /dev/null @@ -1,47 +0,0 @@ -# This file describes the alternative machine learning models -# available to InvokeAI script. -# -# To add a new model, follow the examples below. Each -# model requires a model config file, a weights file, -# and the width and height of the images it -# was trained on. 
-diffusers-1.4: - description: 🤗🧨 Stable Diffusion v1.4 - format: diffusers - repo_id: CompVis/stable-diffusion-v1-4 -diffusers-1.5: - description: 🤗🧨 Stable Diffusion v1.5 - format: diffusers - repo_id: runwayml/stable-diffusion-v1-5 - default: true -diffusers-1.5+mse: - description: 🤗🧨 Stable Diffusion v1.5 + MSE-finetuned VAE - format: diffusers - repo_id: runwayml/stable-diffusion-v1-5 - vae: - repo_id: stabilityai/sd-vae-ft-mse -diffusers-inpainting-1.5: - description: 🤗🧨 inpainting for Stable Diffusion v1.5 - format: diffusers - repo_id: runwayml/stable-diffusion-inpainting -stable-diffusion-1.5: - description: The newest Stable Diffusion version 1.5 weight file (4.27 GB) - weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt - config: configs/stable-diffusion/v1-inference.yaml - width: 512 - height: 512 - vae: ./models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt -stable-diffusion-1.4: - description: Stable Diffusion inference model version 1.4 - config: configs/stable-diffusion/v1-inference.yaml - weights: models/ldm/stable-diffusion-v1/sd-v1-4.ckpt - vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt - width: 512 - height: 512 -inpainting-1.5: - weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt - config: configs/stable-diffusion/v1-inpainting-inference.yaml - vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt - description: RunwayML SD 1.5 model optimized for inpainting - width: 512 - height: 512 diff --git a/invokeai/frontend/install/model_install.py.OLD b/invokeai/frontend/install/model_install.py.OLD deleted file mode 100644 index e23538ffd66..00000000000 --- a/invokeai/frontend/install/model_install.py.OLD +++ /dev/null @@ -1,845 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein) -# Before running stable-diffusion on an internet-isolated machine, -# run this script from one with internet connectivity. The -# two machines must share a common .cache directory. - -""" -This is the npyscreen frontend to the model installation application. -The work is actually done in backend code in model_install_backend.py. 
-""" - -import argparse -import curses -import logging -import sys -import textwrap -import traceback -from argparse import Namespace -from multiprocessing import Process -from multiprocessing.connection import Connection, Pipe -from pathlib import Path -from shutil import get_terminal_size -from typing import Optional - -import npyscreen -import torch -from npyscreen import widget - -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType -from invokeai.backend.model_management import ModelManager, ModelType -from invokeai.backend.util import choose_precision, choose_torch_device -from invokeai.backend.util.logging import InvokeAILogger -from invokeai.frontend.install.widgets import ( - MIN_COLS, - MIN_LINES, - BufferBox, - CenteredTitleText, - CyclingForm, - MultiSelectColumns, - SingleSelectColumns, - TextBox, - WindowTooSmallException, - select_stable_diffusion_config_file, - set_min_terminal_size, -) - -config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger() - -# build a table mapping all non-printable characters to None -# for stripping control characters -# from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python -NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable()} - -# maximum number of installed models we can display before overflowing vertically -MAX_OTHER_MODELS = 72 - - -def make_printable(s: str) -> str: - """Replace non-printable characters in a string""" - return s.translate(NOPRINT_TRANS_TABLE) - - -class addModelsForm(CyclingForm, npyscreen.FormMultiPage): - # for responsive resizing set to False, but this seems to cause a crash! - FIX_MINIMUM_SIZE_WHEN_CREATED = True - - # for persistence - current_tab = 0 - - def __init__(self, parentApp, name, multipage=False, *args, **keywords): - self.multipage = multipage - self.subprocess = None - super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad? - - def create(self): - self.keypress_timeout = 10 - self.counter = 0 - self.subprocess_connection = None - - if not config.model_conf_path.exists(): - with open(config.model_conf_path, "w") as file: - print("# InvokeAI model configuration file", file=file) - self.installer = ModelInstall(config) - self.all_models = self.installer.all_models() - self.starter_models = self.installer.starter_models() - self.model_labels = self._get_model_labels() - window_width, window_height = get_terminal_size() - - self.nextrely -= 1 - self.add_widget_intelligent( - npyscreen.FixedText, - value="Use ctrl-N and ctrl-P to move to the ext and

revious fields. Cursor keys navigate, and selects.", - editable=False, - color="CAUTION", - ) - self.nextrely += 1 - self.tabs = self.add_widget_intelligent( - SingleSelectColumns, - values=[ - "STARTERS", - "MAINS", - "CONTROLNETS", - "T2I-ADAPTERS", - "IP-ADAPTERS", - "LORAS", - "TI EMBEDDINGS", - ], - value=[self.current_tab], - columns=7, - max_height=2, - relx=8, - scroll_exit=True, - ) - self.tabs.on_changed = self._toggle_tables - - top_of_table = self.nextrely - self.starter_pipelines = self.add_starter_pipelines() - bottom_of_table = self.nextrely - - self.nextrely = top_of_table - self.pipeline_models = self.add_pipeline_widgets( - model_type=ModelType.Main, window_width=window_width, exclude=self.starter_models - ) - # self.pipeline_models['autoload_pending'] = True - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = top_of_table - self.controlnet_models = self.add_model_widgets( - model_type=ModelType.ControlNet, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = top_of_table - self.t2i_models = self.add_model_widgets( - model_type=ModelType.T2IAdapter, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - self.nextrely = top_of_table - self.ipadapter_models = self.add_model_widgets( - model_type=ModelType.IPAdapter, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = top_of_table - self.lora_models = self.add_model_widgets( - model_type=ModelType.Lora, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = top_of_table - self.ti_models = self.add_model_widgets( - model_type=ModelType.TextualInversion, - window_width=window_width, - ) - bottom_of_table = max(bottom_of_table, self.nextrely) - - self.nextrely = bottom_of_table + 1 - - self.monitor = self.add_widget_intelligent( - BufferBox, - name="Log Messages", - editable=False, - max_height=6, - ) - - self.nextrely += 1 - done_label = "APPLY CHANGES" - back_label = "BACK" - cancel_label = "CANCEL" - current_position = self.nextrely - if self.multipage: - self.back_button = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=back_label, - when_pressed_function=self.on_back, - ) - else: - self.nextrely = current_position - self.cancel_button = self.add_widget_intelligent( - npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel - ) - self.nextrely = current_position - self.ok_button = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=done_label, - relx=(window_width - len(done_label)) // 2, - when_pressed_function=self.on_execute, - ) - - label = "APPLY CHANGES & EXIT" - self.nextrely = current_position - self.done = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=label, - relx=window_width - len(label) - 15, - when_pressed_function=self.on_done, - ) - - # This restores the selected page on return from an installation - for _i in range(1, self.current_tab + 1): - self.tabs.h_cursor_line_down(1) - self._toggle_tables([self.current_tab]) - - ############# diffusers tab ########## - def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: - """Add widgets responsible for selecting diffusers models""" - widgets = {} - models = self.all_models - starters = self.starter_models - starter_model_labels = self.model_labels - - self.installed_models = sorted([x for x in starters if models[x].installed]) - - widgets.update( - 
label1=self.add_widget_intelligent( - CenteredTitleText, - name="Select from a starter set of Stable Diffusion models from HuggingFace.", - editable=False, - labelColor="CAUTION", - ) - ) - - self.nextrely -= 1 - # if user has already installed some initial models, then don't patronize them - # by showing more recommendations - show_recommended = len(self.installed_models) == 0 - keys = [x for x in models.keys() if x in starters] - widgets.update( - models_selected=self.add_widget_intelligent( - MultiSelectColumns, - columns=1, - name="Install Starter Models", - values=[starter_model_labels[x] for x in keys], - value=[ - keys.index(x) - for x in keys - if (show_recommended and models[x].recommended) or (x in self.installed_models) - ], - max_height=len(starters) + 1, - relx=4, - scroll_exit=True, - ), - models=keys, - ) - - self.nextrely += 1 - return widgets - - ############# Add a set of model install widgets ######## - def add_model_widgets( - self, - model_type: ModelType, - window_width: int = 120, - install_prompt: str = None, - exclude: set = None, - ) -> dict[str, npyscreen.widget]: - """Generic code to create model selection widgets""" - if exclude is None: - exclude = set() - widgets = {} - model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude] - model_labels = [self.model_labels[x] for x in model_list] - - show_recommended = len(self.installed_models) == 0 - truncated = False - if len(model_list) > 0: - max_width = max([len(x) for x in model_labels]) - columns = window_width // (max_width + 8) # 8 characters for "[x] " and padding - columns = min(len(model_list), columns) or 1 - prompt = ( - install_prompt - or f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk." - ) - - widgets.update( - label1=self.add_widget_intelligent( - CenteredTitleText, - name=prompt, - editable=False, - labelColor="CAUTION", - ) - ) - - if len(model_labels) > MAX_OTHER_MODELS: - model_labels = model_labels[0:MAX_OTHER_MODELS] - truncated = True - - widgets.update( - models_selected=self.add_widget_intelligent( - MultiSelectColumns, - columns=columns, - name=f"Install {model_type} Models", - values=model_labels, - value=[ - model_list.index(x) - for x in model_list - if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed - ], - max_height=len(model_list) // columns + 1, - relx=4, - scroll_exit=True, - ), - models=model_list, - ) - - if truncated: - widgets.update( - warning_message=self.add_widget_intelligent( - npyscreen.FixedText, - value=f"Too many models to display (max={MAX_OTHER_MODELS}). Some are not displayed.", - editable=False, - color="CAUTION", - ) - ) - - self.nextrely += 1 - widgets.update( - download_ids=self.add_widget_intelligent( - TextBox, - name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):", - max_height=4, - scroll_exit=True, - editable=True, - ) - ) - return widgets - - ### Tab for arbitrary diffusers widgets ### - def add_pipeline_widgets( - self, - model_type: ModelType = ModelType.Main, - window_width: int = 120, - **kwargs, - ) -> dict[str, npyscreen.widget]: - """Similar to add_model_widgets() but adds some additional widgets at the bottom - to support the autoload directory""" - widgets = self.add_model_widgets( - model_type=model_type, - window_width=window_width, - install_prompt=f"Installed {model_type.value.title()} models. 
Unchecked models in the InvokeAI root directory will be deleted. Enter URLs, paths or repo_ids to import.", - **kwargs, - ) - - return widgets - - def resize(self): - super().resize() - if s := self.starter_pipelines.get("models_selected"): - keys = [x for x in self.all_models.keys() if x in self.starter_models] - s.values = [self.model_labels[x] for x in keys] - - def _toggle_tables(self, value=None): - selected_tab = value[0] - widgets = [ - self.starter_pipelines, - self.pipeline_models, - self.controlnet_models, - self.t2i_models, - self.ipadapter_models, - self.lora_models, - self.ti_models, - ] - - for group in widgets: - for _k, v in group.items(): - try: - v.hidden = True - v.editable = False - except Exception: - pass - for _k, v in widgets[selected_tab].items(): - try: - v.hidden = False - if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)): - v.editable = True - except Exception: - pass - self.__class__.current_tab = selected_tab # for persistence - self.display() - - def _get_model_labels(self) -> dict[str, str]: - window_width, window_height = get_terminal_size() - checkbox_width = 4 - spacing_width = 2 - - models = self.all_models - label_width = max([len(models[x].name) for x in models]) - description_width = window_width - label_width - checkbox_width - spacing_width - - result = {} - for x in models.keys(): - description = models[x].description - description = ( - description[0 : description_width - 3] + "..." - if description and len(description) > description_width - else description - if description - else "" - ) - result[x] = f"%-{label_width}s %s" % (models[x].name, description) - return result - - def _get_columns(self) -> int: - window_width, window_height = get_terminal_size() - cols = 4 if window_width > 240 else 3 if window_width > 160 else 2 if window_width > 80 else 1 - return min(cols, len(self.installed_models)) - - def confirm_deletions(self, selections: InstallSelections) -> bool: - remove_models = selections.remove_models - if len(remove_models) > 0: - mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models]) - return npyscreen.notify_ok_cancel( - f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}" - ) - else: - return True - - def on_execute(self): - self.marshall_arguments() - app = self.parentApp - if not self.confirm_deletions(app.install_selections): - return - - self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True) - self.ok_button.hidden = True - self.display() - - # TO DO: Spawn a worker thread, not a subprocess - parent_conn, child_conn = Pipe() - p = Process( - target=process_and_execute, - kwargs={ - "opt": app.program_opts, - "selections": app.install_selections, - "conn_out": child_conn, - }, - ) - p.start() - child_conn.close() - self.subprocess_connection = parent_conn - self.subprocess = p - app.install_selections = InstallSelections() - - def on_back(self): - self.parentApp.switchFormPrevious() - self.editing = False - - def on_cancel(self): - self.parentApp.setNextForm(None) - self.parentApp.user_cancelled = True - self.editing = False - - def on_done(self): - self.marshall_arguments() - if not self.confirm_deletions(self.parentApp.install_selections): - return - self.parentApp.setNextForm(None) - self.parentApp.user_cancelled = False - self.editing = False - - ########## This routine monitors the child process that is performing model installation and removal ##### - def while_waiting(self): - """Called during idle periods. 
Main task is to update the Log Messages box with messages - from the child process that does the actual installation/removal""" - c = self.subprocess_connection - if not c: - return - - monitor_widget = self.monitor.entry_widget - while c.poll(): - try: - data = c.recv_bytes().decode("utf-8") - data.strip("\n") - - # processing child is requesting user input to select the - # right configuration file - if data.startswith("*need v2 config"): - _, model_path, *_ = data.split(":", 2) - self._return_v2_config(model_path) - - # processing child is done - elif data == "*done*": - self._close_subprocess_and_regenerate_form() - break - - # update the log message box - else: - data = make_printable(data) - data = data.replace("[A", "") - monitor_widget.buffer( - textwrap.wrap( - data, - width=monitor_widget.width, - subsequent_indent=" ", - ), - scroll_end=True, - ) - self.display() - except (EOFError, OSError): - self.subprocess_connection = None - - def _return_v2_config(self, model_path: str): - c = self.subprocess_connection - model_name = Path(model_path).name - message = select_stable_diffusion_config_file(model_name=model_name) - c.send_bytes(message.encode("utf-8")) - - def _close_subprocess_and_regenerate_form(self): - app = self.parentApp - self.subprocess_connection.close() - self.subprocess_connection = None - self.monitor.entry_widget.buffer(["** Action Complete **"]) - self.display() - - # rebuild the form, saving and restoring some of the fields that need to be preserved. - saved_messages = self.monitor.entry_widget.values - - app.main_form = app.addForm( - "MAIN", - addModelsForm, - name="Install Stable Diffusion Models", - multipage=self.multipage, - ) - app.switchForm("MAIN") - - app.main_form.monitor.entry_widget.values = saved_messages - app.main_form.monitor.entry_widget.buffer([""], scroll_end=True) - # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir - # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan - - def marshall_arguments(self): - """ - Assemble arguments and store as attributes of the application: - .starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml - True => Install - False => Remove - .scan_directory: Path to a directory of models to scan and import - .autoscan_on_startup: True if invokeai should scan and import at startup time - .import_model_paths: list of URLs, repo_ids and file paths to import - """ - selections = self.parentApp.install_selections - all_models = self.all_models - - # Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove - ui_sections = [ - self.starter_pipelines, - self.pipeline_models, - self.controlnet_models, - self.t2i_models, - self.ipadapter_models, - self.lora_models, - self.ti_models, - ] - for section in ui_sections: - if "models_selected" not in section: - continue - selected = {section["models"][x] for x in section["models_selected"].value} - models_to_install = [x for x in selected if not self.all_models[x].installed] - models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] - selections.remove_models.extend(models_to_remove) - selections.install_models.extend( - all_models[x].path or all_models[x].repo_id - for x in models_to_install - if all_models[x].path or all_models[x].repo_id - ) - - # models located in the 'download_ids" section - for section in ui_sections: - if downloads := section.get("download_ids"): - selections.install_models.extend(downloads.value.split()) - - # NOT NEEDED - DONE IN 
BACKEND NOW - # # special case for the ipadapter_models. If any of the adapters are - # # chosen, then we add the corresponding encoder(s) to the install list. - # section = self.ipadapter_models - # if section.get("models_selected"): - # selected_adapters = [ - # self.all_models[section["models"][x]].name for x in section.get("models_selected").value - # ] - # encoders = [] - # if any(["sdxl" in x for x in selected_adapters]): - # encoders.append("ip_adapter_sdxl_image_encoder") - # if any(["sd15" in x for x in selected_adapters]): - # encoders.append("ip_adapter_sd_image_encoder") - # for encoder in encoders: - # key = f"any/clip_vision/{encoder}" - # repo_id = f"InvokeAI/{encoder}" - # if key not in self.all_models: - # selections.install_models.append(repo_id) - - -class AddModelApplication(npyscreen.NPSAppManaged): - def __init__(self, opt): - super().__init__() - self.program_opts = opt - self.user_cancelled = False - # self.autoload_pending = True - self.install_selections = InstallSelections() - - def onStart(self): - npyscreen.setTheme(npyscreen.Themes.DefaultTheme) - self.main_form = self.addForm( - "MAIN", - addModelsForm, - name="Install Stable Diffusion Models", - cycle_widgets=False, - ) - - -class StderrToMessage: - def __init__(self, connection: Connection): - self.connection = connection - - def write(self, data: str): - self.connection.send_bytes(data.encode("utf-8")) - - def flush(self): - pass - - -# -------------------------------------------------------- -def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType: - if tui_conn: - logger.debug("Waiting for user response...") - return _ask_user_for_pt_tui(model_path, tui_conn) - else: - return _ask_user_for_pt_cmdline(model_path) - - -def _ask_user_for_pt_cmdline(model_path: Path) -> Optional[SchedulerPredictionType]: - choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] - print( - f""" -Please select the scheduler prediction type of the checkpoint named {model_path.name}: -[1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images -[2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models -[3] Accept the best guess; you can fix it in the Web UI later -""" - ) - choice = None - ok = False - while not ok: - try: - choice = input("select [3]> ").strip() - if not choice: - return None - choice = choices[int(choice) - 1] - ok = True - except (ValueError, IndexError): - print(f"{choice} is not a valid choice") - except EOFError: - return - return choice - - -def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType: - tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8")) - # note that we don't do any status checking here - response = tui_conn.recv_bytes().decode("utf-8") - if response is None: - return None - elif response == "epsilon": - return SchedulerPredictionType.epsilon - elif response == "v": - return SchedulerPredictionType.VPrediction - elif response == "guess": - return None - else: - return None - - -# -------------------------------------------------------- -def process_and_execute( - opt: Namespace, - selections: InstallSelections, - conn_out: Connection = None, -): - # need to reinitialize config in subprocess - config = InvokeAIAppConfig.get_config() - args = ["--root", opt.root] if opt.root else [] - config.parse_args(args) - - # set up so that stderr is sent to conn_out - if conn_out: - translator = StderrToMessage(conn_out) - 
sys.stderr = translator - sys.stdout = translator - logger = InvokeAILogger.get_logger() - logger.handlers.clear() - logger.addHandler(logging.StreamHandler(translator)) - - installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out)) - installer.install(selections) - - if conn_out: - conn_out.send_bytes("*done*".encode("utf-8")) - conn_out.close() - - -# -------------------------------------------------------- -def select_and_download_models(opt: Namespace): - precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) - config.precision = precision - installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type) - if opt.list_models: - installer.list_models(opt.list_models) - elif opt.add or opt.delete: - selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or []) - installer.install(selections) - elif opt.default_only: - selections = InstallSelections(install_models=installer.default_model()) - installer.install(selections) - elif opt.yes_to_all: - selections = InstallSelections(install_models=installer.recommended_models()) - installer.install(selections) - - # this is where the TUI is called - else: - # needed to support the probe() method running under a subprocess - torch.multiprocessing.set_start_method("spawn") - - if not set_min_terminal_size(MIN_COLS, MIN_LINES): - raise WindowTooSmallException( - "Could not increase terminal size. Try running again with a larger window or smaller font size." - ) - - installApp = AddModelApplication(opt) - try: - installApp.run() - except KeyboardInterrupt as e: - if hasattr(installApp, "main_form"): - if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive(): - logger.info("Terminating subprocesses") - installApp.main_form.subprocess.terminate() - installApp.main_form.subprocess = None - raise e - process_and_execute(opt, installApp.install_selections) - - -# ------------------------------------- -def main(): - parser = argparse.ArgumentParser(description="InvokeAI model downloader") - parser.add_argument( - "--add", - nargs="*", - help="List of URLs, local paths or repo_ids of models to install", - ) - parser.add_argument( - "--delete", - nargs="*", - help="List of names of models to idelete", - ) - parser.add_argument( - "--full-precision", - dest="full_precision", - action=argparse.BooleanOptionalAction, - type=bool, - default=False, - help="use 32-bit weights instead of faster 16-bit weights", - ) - parser.add_argument( - "--yes", - "-y", - dest="yes_to_all", - action="store_true", - help='answer "yes" to all prompts', - ) - parser.add_argument( - "--default_only", - action="store_true", - help="Only install the default model", - ) - parser.add_argument( - "--list-models", - choices=[x.value for x in ModelType], - help="list installed models", - ) - parser.add_argument( - "--config_file", - "-c", - dest="config_file", - type=str, - default=None, - help="path to configuration file to create", - ) - parser.add_argument( - "--root_dir", - dest="root", - type=str, - default=None, - help="path to root of install directory", - ) - opt = parser.parse_args() - - invoke_args = [] - if opt.root: - invoke_args.extend(["--root", opt.root]) - if opt.full_precision: - invoke_args.extend(["--precision", "float32"]) - config.parse_args(invoke_args) - logger = InvokeAILogger().get_logger(config=config) - - if not config.model_conf_path.exists(): - logger.info("Your InvokeAI root directory 
is not set up. Calling invokeai-configure.") - from invokeai.frontend.install.invokeai_configure import invokeai_configure - - invokeai_configure() - sys.exit(0) - - try: - select_and_download_models(opt) - except AssertionError as e: - logger.error(e) - sys.exit(-1) - except KeyboardInterrupt: - curses.nocbreak() - curses.echo() - curses.endwin() - logger.info("Goodbye! Come back soon.") - except WindowTooSmallException as e: - logger.error(str(e)) - except widget.NotEnoughSpaceForWidget as e: - if str(e).startswith("Height of 1 allocated"): - logger.error("Insufficient vertical space for the interface. Please make your window taller and try again") - input("Press any key to continue...") - except Exception as e: - if str(e).startswith("addwstr"): - logger.error( - "Insufficient horizontal space for the interface. Please make your window wider and try again." - ) - else: - print(f"An exception has occurred: {str(e)} Details:") - print(traceback.format_exc(), file=sys.stderr) - input("Press any key to continue...") - - -# ------------------------------------- -if __name__ == "__main__": - main() diff --git a/invokeai/frontend/merge/merge_diffusers.py.OLD b/invokeai/frontend/merge/merge_diffusers.py.OLD deleted file mode 100644 index b365198f879..00000000000 --- a/invokeai/frontend/merge/merge_diffusers.py.OLD +++ /dev/null @@ -1,438 +0,0 @@ -""" -invokeai.frontend.merge exports a single function called merge_diffusion_models(). - -It merges 2-3 models together and create a new InvokeAI-registered diffusion model. - -Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team -""" -import argparse -import curses -import re -import sys -from argparse import Namespace -from pathlib import Path -from typing import List, Optional, Tuple - -import npyscreen -from npyscreen import widget - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_install import ModelInstallServiceBase -from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.backend.install.install_helper import initialize_installer -from invokeai.backend.model_manager import ( - BaseModelType, - ModelFormat, - ModelType, - ModelVariantType, -) -from invokeai.backend.model_manager.merge import ModelMerger -from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox - -config = InvokeAIAppConfig.get_config() - -BASE_TYPES = [ - (BaseModelType.StableDiffusion1, "Models Built on SD-1.x"), - (BaseModelType.StableDiffusion2, "Models Built on SD-2.x"), - (BaseModelType.StableDiffusionXL, "Models Built on SDXL"), -] - - -def _parse_args() -> Namespace: - parser = argparse.ArgumentParser(description="InvokeAI model merging") - parser.add_argument( - "--root_dir", - type=Path, - default=config.root, - help="Path to the invokeai runtime directory", - ) - parser.add_argument( - "--front_end", - "--gui", - dest="front_end", - action="store_true", - default=False, - help="Activate the text-based graphical front end for collecting parameters. 
Aside from --root_dir, other parameters will be ignored.", - ) - parser.add_argument( - "--models", - dest="model_names", - type=str, - nargs="+", - help="Two to three model names to be merged", - ) - parser.add_argument( - "--base_model", - type=str, - choices=[x[0].value for x in BASE_TYPES], - help="The base model shared by the models to be merged", - ) - parser.add_argument( - "--merged_model_name", - "--destination", - dest="merged_model_name", - type=str, - help="Name of the output model. If not specified, will be the concatenation of the input model names.", - ) - parser.add_argument( - "--alpha", - type=float, - default=0.5, - help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models", - ) - parser.add_argument( - "--interpolation", - dest="interp", - type=str, - choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"], - default="weighted_sum", - help='Interpolation method to use. If three models are present, only "add_difference" will work.', - ) - parser.add_argument( - "--force", - action="store_true", - help="Try to merge models even if they are incompatible with each other", - ) - parser.add_argument( - "--clobber", - "--overwrite", - dest="clobber", - action="store_true", - help="Overwrite the merged model if --merged_model_name already exists", - ) - return parser.parse_args() - - -# ------------------------- GUI HERE ------------------------- -class mergeModelsForm(npyscreen.FormMultiPageAction): - interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"] - - def __init__(self, parentApp, name): - self.parentApp = parentApp - self.ALLOW_RESIZE = True - self.FIX_MINIMUM_SIZE_WHEN_CREATED = False - super().__init__(parentApp, name) - - @property - def model_record_store(self) -> ModelRecordServiceBase: - installer: ModelInstallServiceBase = self.parentApp.installer - return installer.record_store - - def afterEditing(self) -> None: - self.parentApp.setNextForm(None) - - def create(self) -> None: - window_height, window_width = curses.initscr().getmaxyx() - self.current_base = 0 - self.models = self.get_models(BASE_TYPES[self.current_base][0]) - self.model_names = [x[1] for x in self.models] - max_width = max([len(x) for x in self.model_names]) - max_width += 6 - horizontal_layout = max_width * 3 < window_width - - self.add_widget_intelligent( - npyscreen.FixedText, - color="CONTROL", - value="Select two models to merge and optionally a third.", - editable=False, - ) - self.add_widget_intelligent( - npyscreen.FixedText, - color="CONTROL", - value="Use up and down arrows to move, to select an item, and to move from one field to the next.", - editable=False, - ) - self.nextrely += 1 - self.base_select = self.add_widget_intelligent( - SingleSelectColumns, - values=[x[1] for x in BASE_TYPES], - value=[self.current_base], - columns=4, - max_height=2, - relx=8, - scroll_exit=True, - ) - self.base_select.on_changed = self._populate_models - self.add_widget_intelligent( - npyscreen.FixedText, - value="MODEL 1", - color="GOOD", - editable=False, - rely=6 if horizontal_layout else None, - ) - self.model1 = self.add_widget_intelligent( - npyscreen.SelectOne, - values=self.model_names, - value=0, - max_height=len(self.model_names), - max_width=max_width, - scroll_exit=True, - rely=7, - ) - self.add_widget_intelligent( - npyscreen.FixedText, - value="MODEL 2", - color="GOOD", - editable=False, - relx=max_width + 3 if horizontal_layout else None, - rely=6 if 
horizontal_layout else None, - ) - self.model2 = self.add_widget_intelligent( - npyscreen.SelectOne, - name="(2)", - values=self.model_names, - value=1, - max_height=len(self.model_names), - max_width=max_width, - relx=max_width + 3 if horizontal_layout else None, - rely=7 if horizontal_layout else None, - scroll_exit=True, - ) - self.add_widget_intelligent( - npyscreen.FixedText, - value="MODEL 3", - color="GOOD", - editable=False, - relx=max_width * 2 + 3 if horizontal_layout else None, - rely=6 if horizontal_layout else None, - ) - models_plus_none = self.model_names.copy() - models_plus_none.insert(0, "None") - self.model3 = self.add_widget_intelligent( - npyscreen.SelectOne, - name="(3)", - values=models_plus_none, - value=0, - max_height=len(self.model_names) + 1, - max_width=max_width, - scroll_exit=True, - relx=max_width * 2 + 3 if horizontal_layout else None, - rely=7 if horizontal_layout else None, - ) - for m in [self.model1, self.model2, self.model3]: - m.when_value_edited = self.models_changed - self.merged_model_name = self.add_widget_intelligent( - TextBox, - name="Name for merged model:", - labelColor="CONTROL", - max_height=3, - value="", - scroll_exit=True, - ) - self.force = self.add_widget_intelligent( - npyscreen.Checkbox, - name="Force merge of models created by different diffusers library versions", - labelColor="CONTROL", - value=True, - scroll_exit=True, - ) - self.nextrely += 1 - self.merge_method = self.add_widget_intelligent( - npyscreen.TitleSelectOne, - name="Merge Method:", - values=self.interpolations, - value=0, - labelColor="CONTROL", - max_height=len(self.interpolations) + 1, - scroll_exit=True, - ) - self.alpha = self.add_widget_intelligent( - FloatTitleSlider, - name="Weight (alpha) to assign to second and third models:", - out_of=1.0, - step=0.01, - lowest=0, - value=0.5, - labelColor="CONTROL", - scroll_exit=True, - ) - self.model1.editing = True - - def models_changed(self) -> None: - models = self.model1.values - selected_model1 = self.model1.value[0] - selected_model2 = self.model2.value[0] - selected_model3 = self.model3.value[0] - merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}" - self.merged_model_name.value = merged_model_name - - if selected_model3 > 0: - self.merge_method.values = ["add_difference ( A+(B-C) )"] - self.merged_model_name.value += f"+{models[selected_model3 -1]}" # In model3 there is one more element in the list (None). So we have to subtract one. 
- else: - self.merge_method.values = self.interpolations - self.merge_method.value = 0 - - def on_ok(self) -> None: - if self.validate_field_values() and self.check_for_overwrite(): - self.parentApp.setNextForm(None) - self.editing = False - self.parentApp.merge_arguments = self.marshall_arguments() - npyscreen.notify("Starting the merge...") - else: - self.editing = True - - def on_cancel(self) -> None: - sys.exit(0) - - def marshall_arguments(self) -> dict: - model_keys = [x[0] for x in self.models] - models = [ - model_keys[self.model1.value[0]], - model_keys[self.model2.value[0]], - ] - if self.model3.value[0] > 0: - models.append(model_keys[self.model3.value[0] - 1]) - interp = "add_difference" - else: - interp = self.interpolations[self.merge_method.value[0]] - - args = { - "model_keys": models, - "alpha": self.alpha.value, - "interp": interp, - "force": self.force.value, - "merged_model_name": self.merged_model_name.value, - } - return args - - def check_for_overwrite(self) -> bool: - model_out = self.merged_model_name.value - if model_out not in self.model_names: - return True - else: - result: bool = npyscreen.notify_yes_no( - f"The chosen merged model destination, {model_out}, is already in use. Overwrite?" - ) - return result - - def validate_field_values(self) -> bool: - bad_fields = [] - model_names = self.model_names - selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]} - if self.model3.value[0] > 0: - selected_models.add(model_names[self.model3.value[0] - 1]) - if len(selected_models) < 2: - bad_fields.append(f"Please select two or three DIFFERENT models to compare. You selected {selected_models}") - if len(bad_fields) > 0: - message = "The following problems were detected and must be corrected:" - for problem in bad_fields: - message += f"\n* {problem}" - npyscreen.notify_confirm(message) - return False - else: - return True - - def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name - models = [ - (x.key, x.name) - for x in self.model_record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model) - if x.format == ModelFormat("diffusers") - and hasattr(x, "variant") - and x.variant == ModelVariantType("normal") - ] - return sorted(models, key=lambda x: x[1]) - - def _populate_models(self, value: List[int]) -> None: - base_model = BASE_TYPES[value[0]][0] - self.models = self.get_models(base_model) - self.model_names = [x[1] for x in self.models] - - models_plus_none = self.model_names.copy() - models_plus_none.insert(0, "None") - self.model1.values = self.model_names - self.model2.values = self.model_names - self.model3.values = models_plus_none - - self.display() - - -# npyscreen is untyped and causes mypy to get naggy -class Mergeapp(npyscreen.NPSAppManaged): # type: ignore - def __init__(self, installer: ModelInstallServiceBase): - """Initialize the npyscreen application.""" - super().__init__() - self.installer = installer - - def onStart(self) -> None: - npyscreen.setTheme(npyscreen.Themes.ElegantTheme) - self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings") - - -def run_gui(args: Namespace) -> None: - installer = initialize_installer(config) - mergeapp = Mergeapp(installer) - mergeapp.run() - merge_args = mergeapp.merge_arguments - merger = ModelMerger(installer) - merger.merge_diffusion_models_and_save(**merge_args) - logger.info(f'Models merged into new model: "{merge_args.merged_model_name}".') - - -def run_cli(args: Namespace) -> None: - 
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1" - assert ( - args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3 - ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage." - - if not args.merged_model_name: - args.merged_model_name = "+".join(args.model_names) - logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"') - - installer = initialize_installer(config) - store = installer.record_store - assert ( - len(store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber - ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' - - merger = ModelMerger(installer) - model_keys = [] - for name in args.model_names: - if len(name) == 32 and re.match(r"^[0-9a-f]$", name): - model_keys.append(name) - else: - models = store.search_by_attr( - model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model) - ) - assert len(models) > 0, f"{name}: Unknown model" - assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead." - model_keys.append(models[0].key) - - merger.merge_diffusion_models_and_save( - alpha=args.alpha, - model_keys=model_keys, - merged_model_name=args.merged_model_name, - interp=args.interp, - force=args.force, - ) - logger.info(f'Models merged into new model: "{args.merged_model_name}".') - - -def main() -> None: - args = _parse_args() - if args.root_dir: - config.parse_args(["--root", str(args.root_dir)]) - else: - config.parse_args([]) - - try: - if args.front_end: - run_gui(args) - else: - run_cli(args) - except widget.NotEnoughSpaceForWidget as e: - if str(e).startswith("Height of 1 allocated"): - logger.error("You need to have at least two diffusers models defined in models.yaml in order to merge") - else: - logger.error("Not enough room for the user interface. 
Try making this window larger.") - sys.exit(-1) - except Exception as e: - logger.error(str(e)) - sys.exit(-1) - except KeyboardInterrupt: - sys.exit(-1) - - -if __name__ == "__main__": - main() diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 5694432ebdc..55f7e865410 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -20,7 +20,7 @@ ) from invokeai.app.services.model_records import UnknownModelException from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType -from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 OS = platform.uname().system diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 852e1da979c..57515ac81b1 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -26,7 +26,7 @@ ) from invokeai.backend.model_manager.metadata import BaseMetadata from invokeai.backend.util.logging import InvokeAILogger -from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.fixtures.sqlite_database import create_mock_sqlite_database diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py index 6a3ec510a2c..9ed3c9bc507 100644 --- a/tests/backend/ip_adapter/test_ip_adapter.py +++ b/tests/backend/ip_adapter/test_ip_adapter.py @@ -2,7 +2,7 @@ import torch from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher -from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType from invokeai.backend.util.test_utils import install_and_load_model diff --git a/tests/backend/model_manager_2/data/invokeai_root/README b/tests/backend/model_manager/data/invokeai_root/README similarity index 100% rename from tests/backend/model_manager_2/data/invokeai_root/README rename to tests/backend/model_manager/data/invokeai_root/README diff --git a/tests/backend/model_manager_2/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml b/tests/backend/model_manager/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml similarity index 100% rename from tests/backend/model_manager_2/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml rename to tests/backend/model_manager/data/invokeai_root/configs/stable-diffusion/v1-inference.yaml diff --git a/tests/backend/model_manager_2/data/invokeai_root/databases/README b/tests/backend/model_manager/data/invokeai_root/databases/README similarity index 100% rename from tests/backend/model_manager_2/data/invokeai_root/databases/README rename to tests/backend/model_manager/data/invokeai_root/databases/README diff --git a/tests/backend/model_manager_2/data/invokeai_root/models/README b/tests/backend/model_manager/data/invokeai_root/models/README similarity index 100% rename from tests/backend/model_manager_2/data/invokeai_root/models/README rename to tests/backend/model_manager/data/invokeai_root/models/README diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/model_index.json 
b/tests/backend/model_manager/data/test_files/test-diffusers-main/model_index.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/model_index.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/model_index.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/scheduler/scheduler_config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/scheduler/scheduler_config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/scheduler/scheduler_config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/scheduler/scheduler_config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/model.fp16.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/model.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/model.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder/model.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder/model.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/model.fp16.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/text_encoder_2/model.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/merges.txt 
b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/merges.txt similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/merges.txt rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/merges.txt diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/special_tokens_map.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/tokenizer_config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/vocab.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/vocab.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer/vocab.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer/vocab.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/merges.txt b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/merges.txt similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/merges.txt rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/merges.txt diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/special_tokens_map.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/tokenizer_config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/vocab.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/vocab.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/tokenizer_2/vocab.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/tokenizer_2/vocab.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/unet/config.json similarity index 100% 
rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/unet/config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.fp16.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/unet/diffusion_pytorch_model.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/config.json b/tests/backend/model_manager/data/test_files/test-diffusers-main/vae/config.json similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/config.json rename to tests/backend/model_manager/data/test_files/test-diffusers-main/vae/config.json diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.fp16.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors b/tests/backend/model_manager/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors rename to tests/backend/model_manager/data/test_files/test-diffusers-main/vae/diffusion_pytorch_model.safetensors diff --git a/tests/backend/model_manager_2/data/test_files/test_embedding.safetensors b/tests/backend/model_manager/data/test_files/test_embedding.safetensors similarity index 100% rename from tests/backend/model_manager_2/data/test_files/test_embedding.safetensors rename to tests/backend/model_manager/data/test_files/test_embedding.safetensors diff --git a/tests/backend/model_manager_2/model_loading/test_model_load.py b/tests/backend/model_manager/model_loading/test_model_load.py similarity index 61% rename from tests/backend/model_manager_2/model_loading/test_model_load.py rename to tests/backend/model_manager/model_loading/test_model_load.py index a7a64e91ac0..38d9b8afb8c 100644 --- a/tests/backend/model_manager_2/model_loading/test_model_load.py +++ b/tests/backend/model_manager/model_loading/test_model_load.py @@ -5,17 +5,16 @@ from pathlib import Path from invokeai.app.services.model_install import ModelInstallServiceBase -from invokeai.backend.embeddings.textual_inversion import 
TextualInversionModelRaw -from invokeai.backend.model_manager.load import AnyModelLoader -from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 +from invokeai.app.services.model_load import ModelLoadServiceBase +from invokeai.backend.textual_inversion import TextualInversionModelRaw +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 - -def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: AnyModelLoader, embedding_file: Path): +def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase, embedding_file: Path): store = mm2_installer.record_store matches = store.search_by_attr(model_name="test_embedding") assert len(matches) == 0 key = mm2_installer.register_path(embedding_file) - loaded_model = mm2_loader.load_model(store.get_model(key)) + loaded_model = mm2_loader.load_model_by_config(store.get_model(key)) assert loaded_model is not None assert loaded_model.config.key == key with loaded_model as model: diff --git a/tests/backend/model_manager_2/model_manager_2_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py similarity index 80% rename from tests/backend/model_manager_2/model_manager_2_fixtures.py rename to tests/backend/model_manager/model_manager_fixtures.py index ebdc9cb5cd6..5f7f44c0188 100644 --- a/tests/backend/model_manager_2/model_manager_2_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -6,24 +6,27 @@ from typing import Any, Dict, List import pytest +from pytest import FixtureRequest from pydantic import BaseModel from requests.sessions import Session from requests_testadapter import TestAdapter, TestSession from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadQueueService +from invokeai.app.services.download import DownloadQueueServiceBase, DownloadQueueService from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.model_manager import ModelManagerServiceBase, ModelManagerService +from invokeai.app.services.model_load import ModelLoadServiceBase, ModelLoadService from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL -from invokeai.app.services.model_records import ModelRecordServiceSQL +from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.backend.model_manager.config import ( BaseModelType, ModelFormat, ModelType, ) -from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache +from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.backend.util.logging import InvokeAILogger -from tests.backend.model_manager_2.model_metadata.metadata_examples import ( +from tests.backend.model_manager.model_metadata.metadata_examples import ( RepoCivitaiModelMetadata1, RepoCivitaiVersionMetadata1, RepoHFMetadata1, @@ -86,22 +89,71 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig: app_config = InvokeAIAppConfig( root=mm2_root_dir, models_dir=mm2_root_dir / "models", + log_level="info", ) return app_config @pytest.fixture -def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceSQL) -> AnyModelLoader: - logger = InvokeAILogger.get_logger(config=mm2_app_config) +def mm2_download_queue(mm2_session: Session, + request: FixtureRequest + ) -> 
DownloadQueueServiceBase: + download_queue = DownloadQueueService(requests_session=mm2_session) + download_queue.start() + + def stop_queue() -> None: + download_queue.stop() + + request.addfinalizer(stop_queue) + return download_queue + +@pytest.fixture +def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase: + return mm2_record_store.metadata_store + +@pytest.fixture +def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase: ram_cache = ModelCache( - logger=logger, max_cache_size=mm2_app_config.ram_cache_size, max_vram_cache_size=mm2_app_config.vram_cache_size + logger=InvokeAILogger.get_logger(), + max_cache_size=mm2_app_config.ram_cache_size, + max_vram_cache_size=mm2_app_config.vram_cache_size ) convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path) - return AnyModelLoader(app_config=mm2_app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache) + return ModelLoadService(app_config=mm2_app_config, + record_store=mm2_record_store, + ram_cache=ram_cache, + convert_cache=convert_cache, + ) + +@pytest.fixture +def mm2_installer(mm2_app_config: InvokeAIAppConfig, + mm2_download_queue: DownloadQueueServiceBase, + mm2_session: Session, + request: FixtureRequest, + ) -> ModelInstallServiceBase: + logger = InvokeAILogger.get_logger() + db = create_mock_sqlite_database(mm2_app_config, logger) + events = DummyEventService() + store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) + + installer = ModelInstallService( + app_config=mm2_app_config, + record_store=store, + download_queue=mm2_download_queue, + event_bus=events, + session=mm2_session, + ) + installer.start() + + def stop_installer() -> None: + installer.stop() + + request.addfinalizer(stop_installer) + return installer @pytest.fixture -def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL: +def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: logger = InvokeAILogger.get_logger(config=mm2_app_config) db = create_mock_sqlite_database(mm2_app_config, logger) store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) @@ -161,11 +213,15 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL store.add_model("test_config_5", raw5) return store - @pytest.fixture -def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase: - return mm2_record_store.metadata_store - +def mm2_model_manager(mm2_record_store: ModelRecordServiceBase, + mm2_installer: ModelInstallServiceBase, + mm2_loader: ModelLoadServiceBase) -> ModelManagerServiceBase: + return ModelManagerService( + store=mm2_record_store, + install=mm2_installer, + load=mm2_loader + ) @pytest.fixture def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: @@ -252,22 +308,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: return sess -@pytest.fixture -def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> ModelInstallServiceBase: - logger = InvokeAILogger.get_logger() - db = create_mock_sqlite_database(mm2_app_config, logger) - events = DummyEventService() - store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) - - download_queue = DownloadQueueService(requests_session=mm2_session) - download_queue.start() - - installer = ModelInstallService( - app_config=mm2_app_config, - record_store=store, - download_queue=download_queue, - event_bus=events, - session=mm2_session, - ) - 
installer.start() - return installer diff --git a/tests/backend/model_manager_2/model_metadata/metadata_examples.py b/tests/backend/model_manager/model_metadata/metadata_examples.py similarity index 100% rename from tests/backend/model_manager_2/model_metadata/metadata_examples.py rename to tests/backend/model_manager/model_metadata/metadata_examples.py diff --git a/tests/backend/model_manager_2/model_metadata/test_model_metadata.py b/tests/backend/model_manager/model_metadata/test_model_metadata.py similarity index 99% rename from tests/backend/model_manager_2/model_metadata/test_model_metadata.py rename to tests/backend/model_manager/model_metadata/test_model_metadata.py index f61eab1b5d0..09b18916d38 100644 --- a/tests/backend/model_manager_2/model_metadata/test_model_metadata.py +++ b/tests/backend/model_manager/model_metadata/test_model_metadata.py @@ -19,7 +19,7 @@ UnknownMetadataException, ) from invokeai.backend.model_manager.util import select_hf_files -from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None: diff --git a/tests/backend/model_management/test_libc_util.py b/tests/backend/model_manager/test_libc_util.py similarity index 88% rename from tests/backend/model_management/test_libc_util.py rename to tests/backend/model_manager/test_libc_util.py index e13a2fd3a2f..4309dc7c34c 100644 --- a/tests/backend/model_management/test_libc_util.py +++ b/tests/backend/model_manager/test_libc_util.py @@ -1,6 +1,6 @@ import pytest -from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 +from invokeai.backend.model_manager.util.libc_util import LibcUtil, Struct_mallinfo2 def test_libc_util_mallinfo2(): diff --git a/tests/backend/model_management/test_lora.py b/tests/backend/model_manager/test_lora.py similarity index 96% rename from tests/backend/model_management/test_lora.py rename to tests/backend/model_manager/test_lora.py index 14bcc87c892..e124bb68efc 100644 --- a/tests/backend/model_management/test_lora.py +++ b/tests/backend/model_manager/test_lora.py @@ -5,8 +5,8 @@ import pytest import torch -from invokeai.backend.model_management.lora import ModelPatcher -from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw +from invokeai.backend.model_patcher import ModelPatcher +from invokeai.backend.lora import LoRALayer, LoRAModelRaw @pytest.mark.parametrize( diff --git a/tests/backend/model_management/test_memory_snapshot.py b/tests/backend/model_manager/test_memory_snapshot.py similarity index 87% rename from tests/backend/model_management/test_memory_snapshot.py rename to tests/backend/model_manager/test_memory_snapshot.py index 216cd62171d..87ec8c34ee0 100644 --- a/tests/backend/model_management/test_memory_snapshot.py +++ b/tests/backend/model_manager/test_memory_snapshot.py @@ -1,8 +1,7 @@ import pytest -from invokeai.backend.model_management.libc_util import Struct_mallinfo2 -from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff - +from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2 +from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff def test_memory_snapshot_capture(): """Smoke test of MemorySnapshot.capture().""" @@ -26,6 +25,7 @@ def test_memory_snapshot_capture(): def test_get_pretty_snapshot_diff(snapshot_1, 
snapshot_2): """Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields.""" msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2) + print(msg) expected_lines = 0 if snapshot_1 is not None and snapshot_2 is not None: diff --git a/tests/backend/model_management/test_model_load_optimization.py b/tests/backend/model_manager/test_model_load_optimization.py similarity index 96% rename from tests/backend/model_management/test_model_load_optimization.py rename to tests/backend/model_manager/test_model_load_optimization.py index a4fe1dd5974..f627f3a2982 100644 --- a/tests/backend/model_management/test_model_load_optimization.py +++ b/tests/backend/model_manager/test_model_load_optimization.py @@ -1,7 +1,7 @@ import pytest import torch -from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init +from invokeai.backend.model_manager.load.optimizations import _no_op, skip_torch_weight_init @pytest.mark.parametrize( diff --git a/tests/backend/model_manager_2/util/test_hf_model_select.py b/tests/backend/model_manager/util/test_hf_model_select.py similarity index 100% rename from tests/backend/model_manager_2/util/test_hf_model_select.py rename to tests/backend/model_manager/util/test_hf_model_select.py diff --git a/tests/conftest.py b/tests/conftest.py index 6e7d559be44..1c816002296 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,2 @@ # conftest.py is a special pytest file. Fixtures defined in this file will be accessible to all tests in this directory # without needing to explicitly import them. (https://docs.pytest.org/en/6.2.x/fixture.html) - - -# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not -# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures. 
-from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401 From 50b989ee5b008a127f09aa9182b04089a0870f8a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 17:27:42 +1100 Subject: [PATCH 108/340] final tidying before marking PR as ready for review - Replace AnyModelLoader with ModelLoaderRegistry - Fix type check errors in multiple files - Remove apparently unneeded `get_model_config_enum()` method from model manager - Remove last vestiges of old model manager - Updated tests and documentation resolve conflict with seamless.py --- docs/contributing/MODEL_MANAGER.md | 129 +- .../{model_manager_v2.py => model_manager.py} | 38 +- invokeai/app/api/routers/models.py | 426 ---- invokeai/app/api_app.py | 36 +- invokeai/app/invocations/compel.py | 4 +- invokeai/app/services/config/config_base.py | 2 +- .../invocation_stats/invocation_stats_base.py | 11 +- .../model_install/model_install_base.py | 2 +- .../services/model_load/model_load_base.py | 48 +- .../services/model_load/model_load_default.py | 100 +- .../app/services/model_manager/__init__.py | 2 +- .../model_manager/model_manager_base.py | 33 + .../model_manager/model_manager_default.py | 60 +- .../app/services/shared/invocation_context.py | 4 +- invokeai/backend/install/migrate_to_3.py | 591 ------ .../backend/install/model_install_backend.py | 637 ------ invokeai/backend/ip_adapter/ip_adapter.py | 2 +- invokeai/backend/lora.py | 2 + .../backend/model_management_OLD/README.md | 27 - .../backend/model_management_OLD/__init__.py | 20 - .../convert_ckpt_to_diffusers.py | 1739 ----------------- .../detect_baked_in_vae.py | 31 - invokeai/backend/model_management_OLD/lora.py | 582 ------ .../model_management_OLD/memory_snapshot.py | 99 - .../model_management_OLD/model_cache.py | 553 ------ .../model_load_optimizations.py | 30 - .../model_management_OLD/model_manager.py | 1121 ----------- .../model_management_OLD/model_merge.py | 140 -- .../model_management_OLD/model_probe.py | 664 ------- .../model_management_OLD/model_search.py | 112 -- .../model_management_OLD/models/__init__.py | 167 -- .../model_management_OLD/models/base.py | 681 ------- .../models/clip_vision.py | 82 - .../model_management_OLD/models/controlnet.py | 162 -- .../model_management_OLD/models/ip_adapter.py | 98 - .../model_management_OLD/models/lora.py | 696 ------- .../model_management_OLD/models/sdxl.py | 148 -- .../models/stable_diffusion.py | 337 ---- .../models/stable_diffusion_onnx.py | 150 -- .../models/t2i_adapter.py | 102 - .../models/textual_inversion.py | 87 - .../model_management_OLD/models/vae.py | 179 -- .../backend/model_management_OLD/seamless.py | 84 - invokeai/backend/model_management_OLD/util.py | 79 - invokeai/backend/model_manager/__init__.py | 40 +- .../backend/model_manager/load/__init__.py | 14 +- .../backend/model_manager/load/load_base.py | 130 +- .../model_manager/load/load_default.py | 54 +- .../model_manager/load/memory_snapshot.py | 2 +- .../load/model_loader_registry.py | 122 ++ .../load/model_loaders/controlnet.py | 6 +- .../load/model_loaders/generic_diffusers.py | 66 +- .../load/model_loaders/ip_adapter.py | 5 +- .../model_manager/load/model_loaders/lora.py | 8 +- .../model_manager/load/model_loaders/onnx.py | 13 +- .../load/model_loaders/stable_diffusion.py | 13 +- .../load/model_loaders/textual_inversion.py | 12 +- .../model_manager/load/model_loaders/vae.py | 8 +- .../model_manager/load/optimizations.py | 13 +- 
invokeai/backend/model_manager/merge.py | 4 +- .../model_manager/metadata/metadata_base.py | 9 +- invokeai/backend/model_manager/probe.py | 3 +- invokeai/backend/model_manager/search.py | 6 +- .../backend/model_manager/util/libc_util.py | 7 +- .../backend/model_manager/util/model_util.py | 20 +- invokeai/backend/onnx/onnx_runtime.py | 3 +- invokeai/backend/raw_model.py | 1 + invokeai/backend/stable_diffusion/seamless.py | 94 +- invokeai/backend/textual_inversion.py | 2 + invokeai/backend/util/test_utils.py | 4 +- .../model_loading/test_model_load.py | 21 +- .../model_manager/model_manager_fixtures.py | 54 +- tests/backend/model_manager/test_lora.py | 2 +- .../model_manager/test_memory_snapshot.py | 3 +- 74 files changed, 673 insertions(+), 10363 deletions(-) rename invokeai/app/api/routers/{model_manager_v2.py => model_manager.py} (97%) delete mode 100644 invokeai/app/api/routers/models.py delete mode 100644 invokeai/backend/install/migrate_to_3.py delete mode 100644 invokeai/backend/install/model_install_backend.py delete mode 100644 invokeai/backend/model_management_OLD/README.md delete mode 100644 invokeai/backend/model_management_OLD/__init__.py delete mode 100644 invokeai/backend/model_management_OLD/convert_ckpt_to_diffusers.py delete mode 100644 invokeai/backend/model_management_OLD/detect_baked_in_vae.py delete mode 100644 invokeai/backend/model_management_OLD/lora.py delete mode 100644 invokeai/backend/model_management_OLD/memory_snapshot.py delete mode 100644 invokeai/backend/model_management_OLD/model_cache.py delete mode 100644 invokeai/backend/model_management_OLD/model_load_optimizations.py delete mode 100644 invokeai/backend/model_management_OLD/model_manager.py delete mode 100644 invokeai/backend/model_management_OLD/model_merge.py delete mode 100644 invokeai/backend/model_management_OLD/model_probe.py delete mode 100644 invokeai/backend/model_management_OLD/model_search.py delete mode 100644 invokeai/backend/model_management_OLD/models/__init__.py delete mode 100644 invokeai/backend/model_management_OLD/models/base.py delete mode 100644 invokeai/backend/model_management_OLD/models/clip_vision.py delete mode 100644 invokeai/backend/model_management_OLD/models/controlnet.py delete mode 100644 invokeai/backend/model_management_OLD/models/ip_adapter.py delete mode 100644 invokeai/backend/model_management_OLD/models/lora.py delete mode 100644 invokeai/backend/model_management_OLD/models/sdxl.py delete mode 100644 invokeai/backend/model_management_OLD/models/stable_diffusion.py delete mode 100644 invokeai/backend/model_management_OLD/models/stable_diffusion_onnx.py delete mode 100644 invokeai/backend/model_management_OLD/models/t2i_adapter.py delete mode 100644 invokeai/backend/model_management_OLD/models/textual_inversion.py delete mode 100644 invokeai/backend/model_management_OLD/models/vae.py delete mode 100644 invokeai/backend/model_management_OLD/seamless.py delete mode 100644 invokeai/backend/model_management_OLD/util.py create mode 100644 invokeai/backend/model_manager/load/model_loader_registry.py diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index b19699de73d..8351904b619 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -1531,23 +1531,29 @@ Here is a typical initialization pattern: ``` from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.app.services.model_load import ModelLoadService +from 
invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegistry config = InvokeAIAppConfig.get_config() -store = ModelRecordServiceBase.open(config) -loader = ModelLoadService(config, store) +ram_cache = ModelCache( + max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger +) +convert_cache = ModelConvertCache( + cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size +) +loader = ModelLoadService( + app_config=config, + ram_cache=ram_cache, + convert_cache=convert_cache, + registry=ModelLoaderRegistry +) ``` -Note that we are relying on the contents of the application -configuration to choose the implementation of -`ModelRecordServiceBase`. +### load_model(model_config, [submodel_type], [context]) -> LoadedModel -### load_model_by_key(key, [submodel_type], [context]) -> LoadedModel - -The `load_model_by_key()` method receives the unique key that -identifies the model. It loads the model into memory, gets the model -ready for use, and returns a `LoadedModel` object. +The `load_model()` method takes an `AnyModelConfig` returned by +`ModelRecordService.get_model()` and returns the corresponding loaded +model. It loads the model into memory, gets the model ready for use, +and returns a `LoadedModel` object. The optional second argument, `subtype` is a `SubModelType` string enum, such as "vae". It is mandatory when used with a main model, and @@ -1593,25 +1599,6 @@ with model_info as vae: - `ModelNotFoundException` -- key in database but model not found at path - `NotImplementedException` -- the loader doesn't know how to load this type of model -### load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel - -This is similar to `load_model_by_key`, but instead it accepts the -combination of the model's name, type and base, which it passes to the -model record config store for retrieval. If successful, this method -returns a `LoadedModel`. It can raise the following exceptions: - -``` -UnknownModelException -- model with these attributes not known -NotImplementedException -- the loader doesn't know how to load this type of model -ValueError -- more than one model matches this combination of base/type/name -``` - -### load_model_by_config(config, [submodel], [context]) -> LoadedModel - -This method takes an `AnyModelConfig` returned by -ModelRecordService.get_model() and returns the corresponding loaded -model. It may raise a `NotImplementedException`. - ### Emitting model loading events When the `context` argument is passed to `load_model_*()`, it will @@ -1656,7 +1643,7 @@ onnx models. To install a new loader, place it in `invokeai/backend/model_manager/load/model_loaders`. Inherit from -`ModelLoader` and use the `@AnyModelLoader.register()` decorator to +`ModelLoader` and use the `@ModelLoaderRegistry.register()` decorator to indicate what type of models the loader can handle. Here is a complete example from `generic_diffusers.py`, which is able @@ -1674,12 +1661,11 @@ from invokeai.backend.model_manager import ( ModelType, SubModelType, ) -from ..load_base import AnyModelLoader -from ..load_default import ModelLoader +from .. 
import ModelLoader, ModelLoaderRegistry -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) class GenericDiffusersLoader(ModelLoader): """Class to load simple diffusers models.""" @@ -1728,3 +1714,74 @@ model. It does whatever it needs to do to get the model into diffusers format, and returns the Path of the resulting model. (The path should ordinarily be the same as `output_path`.) +## The ModelManagerService object + +For convenience, the API provides a `ModelManagerService` object which +gives a single point of access to the major model manager +services. This object is created at initialization time and can be +found in the global `ApiDependencies.invoker.services.model_manager` +object, or in `context.services.model_manager` from within an +invocation. + +In the examples below, we have retrieved the manager using: +``` +mm = ApiDependencies.invoker.services.model_manager +``` + +The following properties and methods will be available: + +### mm.store + +This retrieves the `ModelRecordService` associated with the +manager. Example: + +``` +configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5') +``` + +### mm.install + +This retrieves the `ModelInstallService` associated with the manager. +Example: + +``` +job = mm.install.heuristic_import(`https://civitai.com/models/58390/detail-tweaker-lora-lora`) +``` + +### mm.load + +This retrieves the `ModelLoaderService` associated with the manager. Example: + +``` +configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5') +assert len(configs) > 0 + +loaded_model = mm.load.load_model(configs[0]) +``` + +The model manager also offers a few convenience shortcuts for loading +models: + +### mm.load_model_by_config(model_config, [submodel], [context]) -> LoadedModel + +Same as `mm.load.load_model()`. + +### mm.load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel + +This accepts the combination of the model's name, type and base, which +it passes to the model record config store for retrieval. If a unique +model config is found, this method returns a `LoadedModel`. It can +raise the following exceptions: + +``` +UnknownModelException -- model with these attributes not known +NotImplementedException -- the loader doesn't know how to load this type of model +ValueError -- more than one model matches this combination of base/type/name +``` + +### mm.load_model_by_key(key, [submodel], [context]) -> LoadedModel + +This method takes a model key, looks it up using the +`ModelRecordServiceBase` object in `mm.store`, and passes the returned +model configuration to `load_model_by_config()`. It may raise a +`NotImplementedException`. 
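
For clarity, here is a minimal, hypothetical sketch of how the `load_model_by_attr()` shortcut and its documented exceptions might be used together. This is an editorial illustration, not part of the patch: the import paths and enum members (`UnknownModelException`, `BaseModelType.StableDiffusion1`, `ModelType.Main`, `SubModelType.Vae`) are assumptions taken from modules referenced elsewhere in this PR, and `mm` is the `ModelManagerService` retrieved as shown above.

```
# Hedged sketch of attribute-based loading via the ModelManagerService.
# Import paths and enum members are assumed from modules touched by this PR.
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType


def load_vae_by_name(mm, model_name: str):
    """Return a LoadedModel for the VAE submodel of a named main model, or None."""
    try:
        return mm.load_model_by_attr(
            model_name=model_name,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.Main,
            submodel=SubModelType.Vae,
        )
    except UnknownModelException:
        # No model with this name/base/type combination is registered.
        return None
    except ValueError:
        # More than one model matched; fall back to an unambiguous key lookup
        # with mm.load_model_by_key() instead of guessing.
        raise
```

The attribute-based lookup keeps node code readable, while the explicit `ValueError` branch surfaces the ambiguity case rather than silently picking one of several matching models.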
diff --git a/invokeai/app/api/routers/model_manager_v2.py b/invokeai/app/api/routers/model_manager.py similarity index 97% rename from invokeai/app/api/routers/model_manager_v2.py rename to invokeai/app/api/routers/model_manager.py index 2471e0d8c9b..6b7111dd2ce 100644 --- a/invokeai/app/api/routers/model_manager_v2.py +++ b/invokeai/app/api/routers/model_manager.py @@ -35,7 +35,7 @@ from ..dependencies import ApiDependencies -model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"]) +model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"]) class ModelsList(BaseModel): @@ -135,7 +135,7 @@ class ModelTagSet(BaseModel): ############################################################################## -@model_manager_v2_router.get( +@model_manager_router.get( "/", operation_id="list_model_records", ) @@ -164,7 +164,7 @@ async def list_model_records( return ModelsList(models=found_models) -@model_manager_v2_router.get( +@model_manager_router.get( "/i/{key}", operation_id="get_model_record", responses={ @@ -188,7 +188,7 @@ async def get_model_record( raise HTTPException(status_code=404, detail=str(e)) -@model_manager_v2_router.get("/summary", operation_id="list_model_summary") +@model_manager_router.get("/summary", operation_id="list_model_summary") async def list_model_summary( page: int = Query(default=0, description="The page to get"), per_page: int = Query(default=10, description="The number of models per page"), @@ -200,7 +200,7 @@ async def list_model_summary( return results -@model_manager_v2_router.get( +@model_manager_router.get( "/meta/i/{key}", operation_id="get_model_metadata", responses={ @@ -223,7 +223,7 @@ async def get_model_metadata( return result -@model_manager_v2_router.get( +@model_manager_router.get( "/tags", operation_id="list_tags", ) @@ -234,7 +234,7 @@ async def list_tags() -> Set[str]: return result -@model_manager_v2_router.get( +@model_manager_router.get( "/tags/search", operation_id="search_by_metadata_tags", ) @@ -247,7 +247,7 @@ async def search_by_metadata_tags( return ModelsList(models=results) -@model_manager_v2_router.patch( +@model_manager_router.patch( "/i/{key}", operation_id="update_model_record", responses={ @@ -281,7 +281,7 @@ async def update_model_record( return model_response -@model_manager_v2_router.delete( +@model_manager_router.delete( "/i/{key}", operation_id="del_model_record", responses={ @@ -311,7 +311,7 @@ async def del_model_record( raise HTTPException(status_code=404, detail=str(e)) -@model_manager_v2_router.post( +@model_manager_router.post( "/i/", operation_id="add_model_record", responses={ @@ -349,7 +349,7 @@ async def add_model_record( return result -@model_manager_v2_router.post( +@model_manager_router.post( "/heuristic_import", operation_id="heuristic_import_model", responses={ @@ -416,7 +416,7 @@ async def heuristic_import( return result -@model_manager_v2_router.post( +@model_manager_router.post( "/install", operation_id="import_model", responses={ @@ -516,7 +516,7 @@ async def import_model( return result -@model_manager_v2_router.get( +@model_manager_router.get( "/import", operation_id="list_model_install_jobs", ) @@ -544,7 +544,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]: return jobs -@model_manager_v2_router.get( +@model_manager_router.get( "/import/{id}", operation_id="get_model_install_job", responses={ @@ -564,7 +564,7 @@ async def get_model_install_job(id: int = Path(description="Model install id")) raise HTTPException(status_code=404, detail=str(e)) 
-@model_manager_v2_router.delete( +@model_manager_router.delete( "/import/{id}", operation_id="cancel_model_install_job", responses={ @@ -583,7 +583,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job installer.cancel_job(job) -@model_manager_v2_router.patch( +@model_manager_router.patch( "/import", operation_id="prune_model_install_jobs", responses={ @@ -597,7 +597,7 @@ async def prune_model_install_jobs() -> Response: return Response(status_code=204) -@model_manager_v2_router.patch( +@model_manager_router.patch( "/sync", operation_id="sync_models_to_config", responses={ @@ -616,7 +616,7 @@ async def sync_models_to_config() -> Response: return Response(status_code=204) -@model_manager_v2_router.put( +@model_manager_router.put( "/convert/{key}", operation_id="convert_model", responses={ @@ -694,7 +694,7 @@ async def convert_model( return new_config -@model_manager_v2_router.put( +@model_manager_router.put( "/merge", operation_id="merge", responses={ diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py deleted file mode 100644 index 0aa7aa0ecba..00000000000 --- a/invokeai/app/api/routers/models.py +++ /dev/null @@ -1,426 +0,0 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein - -import pathlib -from typing import Annotated, List, Literal, Optional, Union - -from fastapi import Body, Path, Query, Response -from fastapi.routing import APIRouter -from pydantic import BaseModel, ConfigDict, Field, TypeAdapter -from starlette.exceptions import HTTPException - -from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType -from invokeai.backend.model_management.models import ( - OPENAPI_MODEL_CONFIGS, - InvalidModelException, - ModelNotFoundException, - SchedulerPredictionType, -) - -from ..dependencies import ApiDependencies - -models_router = APIRouter(prefix="/v1/models", tags=["models"]) - -UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse) - -ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ImportModelResponseValidator = TypeAdapter(ImportModelResponse) - -ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse) - -MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] - - -class ModelsList(BaseModel): - models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] - - model_config = ConfigDict(use_enum_values=True) - - -ModelsListValidator = TypeAdapter(ModelsList) - - -@models_router.get( - "/", - operation_id="list_models", - responses={200: {"model": ModelsList}}, -) -async def list_models( - base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), - model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), -) -> ModelsList: - """Gets a list of models""" - if base_models and len(base_models) > 0: - models_raw = [] - for base_model in base_models: - models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)) - else: - models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type) - models = ModelsListValidator.validate_python({"models": models_raw}) - return models - - -@models_router.patch( - "/{base_model}/{model_type}/{model_name}", - 
operation_id="update_model", - responses={ - 200: {"description": "The model was updated successfully"}, - 400: {"description": "Bad request"}, - 404: {"description": "The model could not be found"}, - 409: {"description": "There is already a model corresponding to the new name"}, - }, - status_code=200, - response_model=UpdateModelResponse, -) -async def update_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), -) -> UpdateModelResponse: - """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" - logger = ApiDependencies.invoker.services.logger - - try: - previous_info = ApiDependencies.invoker.services.model_manager.list_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - - # rename operation requested - if info.model_name != model_name or info.base_model != base_model: - ApiDependencies.invoker.services.model_manager.rename_model( - base_model=base_model, - model_type=model_type, - model_name=model_name, - new_name=info.model_name, - new_base=info.base_model, - ) - logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}") - # update information to support an update of attributes - model_name = info.model_name - base_model = info.base_model - new_info = ApiDependencies.invoker.services.model_manager.list_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - if new_info.get("path") != previous_info.get( - "path" - ): # model manager moved model path during rename - don't overwrite it - info.path = new_info.get("path") - - # replace empty string values with None/null to avoid phenomenon of vae: '' - info_dict = info.model_dump() - info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()} - - ApiDependencies.invoker.services.model_manager.update_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - model_attributes=info_dict, - ) - - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - model_response = UpdateModelResponseValidator.validate_python(model_raw) - except ModelNotFoundException as e: - raise HTTPException(status_code=404, detail=str(e)) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - except Exception as e: - logger.error(str(e)) - raise HTTPException(status_code=400, detail=str(e)) - - return model_response - - -@models_router.post( - "/import", - operation_id="import_model", - responses={ - 201: {"description": "The model imported successfully"}, - 404: {"description": "The model could not be found"}, - 415: {"description": "Unrecognized file/folder format"}, - 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"}, - 409: {"description": "There is already a model corresponding to this path or repo_id"}, - }, - status_code=201, - response_model=ImportModelResponse, -) -async def import_model( - location: str = Body(description="A model path, repo_id or URL to import"), - prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body( - description="Prediction type for SDv2 checkpoints and rare SDv1 
checkpoints", - default=None, - ), -) -> ImportModelResponse: - """Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically""" - - location = location.strip("\"' ") - items_to_import = {location} - prediction_types = {x.value: x for x in SchedulerPredictionType} - logger = ApiDependencies.invoker.services.logger - - try: - installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( - items_to_import=items_to_import, - prediction_type_helper=lambda x: prediction_types.get(prediction_type), - ) - info = installed_models.get(location) - - if not info: - logger.error("Import failed") - raise HTTPException(status_code=415) - - logger.info(f"Successfully imported {location}, got {info}") - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=info.name, base_model=info.base_model, model_type=info.model_type - ) - return ImportModelResponseValidator.validate_python(model_raw) - - except ModelNotFoundException as e: - logger.error(str(e)) - raise HTTPException(status_code=404, detail=str(e)) - except InvalidModelException as e: - logger.error(str(e)) - raise HTTPException(status_code=415) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - - -@models_router.post( - "/add", - operation_id="add_model", - responses={ - 201: {"description": "The model added successfully"}, - 404: {"description": "The model could not be found"}, - 424: {"description": "The model appeared to add successfully, but could not be found in the model manager"}, - 409: {"description": "There is already a model corresponding to this path or repo_id"}, - }, - status_code=201, - response_model=ImportModelResponse, -) -async def add_model( - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), -) -> ImportModelResponse: - """Add a model using the configuration information appropriate for its type. 
Only local models can be added by path""" - - logger = ApiDependencies.invoker.services.logger - - try: - ApiDependencies.invoker.services.model_manager.add_model( - info.model_name, - info.base_model, - info.model_type, - model_attributes=info.model_dump(), - ) - logger.info(f"Successfully added {info.model_name}") - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=info.model_name, - base_model=info.base_model, - model_type=info.model_type, - ) - return ImportModelResponseValidator.validate_python(model_raw) - except ModelNotFoundException as e: - logger.error(str(e)) - raise HTTPException(status_code=404, detail=str(e)) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) - - -@models_router.delete( - "/{base_model}/{model_type}/{model_name}", - operation_id="del_model", - responses={ - 204: {"description": "Model deleted successfully"}, - 404: {"description": "Model not found"}, - }, - status_code=204, - response_model=None, -) -async def delete_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), -) -> Response: - """Delete Model""" - logger = ApiDependencies.invoker.services.logger - - try: - ApiDependencies.invoker.services.model_manager.del_model( - model_name, base_model=base_model, model_type=model_type - ) - logger.info(f"Deleted model: {model_name}") - return Response(status_code=204) - except ModelNotFoundException as e: - logger.error(str(e)) - raise HTTPException(status_code=404, detail=str(e)) - - -@models_router.put( - "/convert/{base_model}/{model_type}/{model_name}", - operation_id="convert_model", - responses={ - 200: {"description": "Model converted successfully"}, - 400: {"description": "Bad request"}, - 404: {"description": "Model not found"}, - }, - status_code=200, - response_model=ConvertModelResponse, -) -async def convert_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), - convert_dest_directory: Optional[str] = Query( - default=None, description="Save the converted model to the designated directory" - ), -) -> ConvertModelResponse: - """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.""" - logger = ApiDependencies.invoker.services.logger - try: - logger.info(f"Converting model: {model_name}") - dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None - ApiDependencies.invoker.services.model_manager.convert_model( - model_name, - base_model=base_model, - model_type=model_type, - convert_dest_directory=dest, - ) - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name, base_model=base_model, model_type=model_type - ) - response = ConvertModelResponseValidator.validate_python(model_raw) - except ModelNotFoundException as e: - raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}") - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - return response - - -@models_router.get( - "/search", - operation_id="search_for_models", - responses={ - 200: {"description": "Directory searched successfully"}, - 404: {"description": "Invalid directory path"}, - }, - status_code=200, - response_model=List[pathlib.Path], -) -async def 
search_for_models( - search_path: pathlib.Path = Query(description="Directory path to search for models"), -) -> List[pathlib.Path]: - if not search_path.is_dir(): - raise HTTPException( - status_code=404, - detail=f"The search path '{search_path}' does not exist or is not directory", - ) - return ApiDependencies.invoker.services.model_manager.search_for_models(search_path) - - -@models_router.get( - "/ckpt_confs", - operation_id="list_ckpt_configs", - responses={ - 200: {"description": "paths retrieved successfully"}, - }, - status_code=200, - response_model=List[pathlib.Path], -) -async def list_ckpt_configs() -> List[pathlib.Path]: - """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.""" - return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs() - - -@models_router.post( - "/sync", - operation_id="sync_to_config", - responses={ - 201: {"description": "synchronization successful"}, - }, - status_code=201, - response_model=bool, -) -async def sync_to_config() -> bool: - """Call after making changes to models.yaml, autoimport directories or models directory to synchronize - in-memory data structures with disk data structures.""" - ApiDependencies.invoker.services.model_manager.sync_to_config() - return True - - -# There's some weird pydantic-fastapi behaviour that requires this to be a separate class -# TODO: After a few updates, see if it works inside the route operation handler? -class MergeModelsBody(BaseModel): - model_names: List[str] = Field(description="model name", min_length=2, max_length=3) - merged_model_name: Optional[str] = Field(description="Name of destination model") - alpha: Optional[float] = Field(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5) - interp: Optional[MergeInterpolationMethod] = Field(description="Interpolation method") - force: Optional[bool] = Field( - description="Force merging of models created with different versions of diffusers", - default=False, - ) - - merge_dest_directory: Optional[str] = Field( - description="Save the merged model to the designated directory (with 'merged_model_name' appended)", - default=None, - ) - - model_config = ConfigDict(protected_namespaces=()) - - -@models_router.put( - "/merge/{base_model}", - operation_id="merge_models", - responses={ - 200: {"description": "Model converted successfully"}, - 400: {"description": "Incompatible models"}, - 404: {"description": "One or more models not found"}, - }, - status_code=200, - response_model=MergeModelResponse, -) -async def merge_models( - body: Annotated[MergeModelsBody, Body(description="Model configuration", embed=True)], - base_model: BaseModelType = Path(description="Base model"), -) -> MergeModelResponse: - """Convert a checkpoint model into a diffusers model""" - logger = ApiDependencies.invoker.services.logger - try: - logger.info( - f"Merging models: {body.model_names} into {body.merge_dest_directory or ''}/{body.merged_model_name}" - ) - dest = pathlib.Path(body.merge_dest_directory) if body.merge_dest_directory else None - result = ApiDependencies.invoker.services.model_manager.merge_models( - model_names=body.model_names, - base_model=base_model, - merged_model_name=body.merged_model_name or "+".join(body.model_names), - alpha=body.alpha, - interp=body.interp, - force=body.force, - merge_dest_directory=dest, - ) - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - result.name, - base_model=base_model, - 
model_type=ModelType.Main, - ) - response = ConvertModelResponseValidator.validate_python(model_raw) - except ModelNotFoundException: - raise HTTPException( - status_code=404, - detail=f"One or more of the models '{body.model_names}' not found", - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - return response diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 1831b54c13c..149d47fb962 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -48,7 +48,7 @@ boards, download_queue, images, - model_manager_v2, + model_manager, session_queue, sessions, utilities, @@ -113,7 +113,7 @@ async def shutdown_event() -> None: app.include_router(sessions.session_router, prefix="/api") app.include_router(utilities.utilities_router, prefix="/api") -app.include_router(model_manager_v2.model_manager_v2_router, prefix="/api") +app.include_router(model_manager.model_manager_router, prefix="/api") app.include_router(download_queue.download_queue_router, prefix="/api") app.include_router(images.images_router, prefix="/api") app.include_router(boards.boards_router, prefix="/api") @@ -175,21 +175,23 @@ def custom_openapi() -> dict[str, Any]: invoker_schema["class"] = "invocation" openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output" - from invokeai.backend.model_management.models import get_model_config_enums - - for model_config_format_enum in set(get_model_config_enums()): - name = model_config_format_enum.__qualname__ - - if name in openapi_schema["components"]["schemas"]: - # print(f"Config with name {name} already defined") - continue - - openapi_schema["components"]["schemas"][name] = { - "title": name, - "description": "An enumeration.", - "type": "string", - "enum": [v.value for v in model_config_format_enum], - } + # This code no longer seems to be necessary? 
+ # Leave it here just in case + # + # from invokeai.backend.model_manager import get_model_config_formats + # formats = get_model_config_formats() + # for model_config_name, enum_set in formats.items(): + + # if model_config_name in openapi_schema["components"]["schemas"]: + # # print(f"Config with name {name} already defined") + # continue + + # openapi_schema["components"]["schemas"][model_config_name] = { + # "title": model_config_name, + # "description": "An enumeration.", + # "type": "string", + # "enum": [v.value for v in enum_set], + # } app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 593121ba60b..517da4375e1 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -18,15 +18,15 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt from invokeai.backend.lora import LoRAModelRaw -from invokeai.backend.model_patcher import ModelPatcher -from invokeai.backend.textual_inversion import TextualInversionModelRaw from invokeai.backend.model_manager import ModelType +from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, ExtraConditioningInfo, SDXLConditioningInfo, ) +from invokeai.backend.textual_inversion import TextualInversionModelRaw from invokeai.backend.util.devices import torch_dtype from .baseinvocation import ( diff --git a/invokeai/app/services/config/config_base.py b/invokeai/app/services/config/config_base.py index 983df6b4684..c73aa438096 100644 --- a/invokeai/app/services/config/config_base.py +++ b/invokeai/app/services/config/config_base.py @@ -68,7 +68,7 @@ def to_yaml(self) -> str: return OmegaConf.to_yaml(conf) @classmethod - def add_parser_arguments(cls, parser) -> None: + def add_parser_arguments(cls, parser: ArgumentParser) -> None: """Dynamically create arguments for a settings parser.""" if "type" in get_type_hints(cls): settings_stanza = get_args(get_type_hints(cls)["type"])[0] diff --git a/invokeai/app/services/invocation_stats/invocation_stats_base.py b/invokeai/app/services/invocation_stats/invocation_stats_base.py index 22624a6579a..ec8a453323d 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_base.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_base.py @@ -29,8 +29,8 @@ """ from abc import ABC, abstractmethod -from contextlib import AbstractContextManager from pathlib import Path +from typing import Iterator from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary @@ -40,18 +40,17 @@ class InvocationStatsServiceBase(ABC): "Abstract base class for recording node memory/time performance statistics" @abstractmethod - def __init__(self): + def __init__(self) -> None: """ Initialize the InvocationStatsService and reset counters to zero """ - pass @abstractmethod def collect_stats( self, invocation: BaseInvocation, graph_execution_state_id: str, - ) -> AbstractContextManager: + ) -> Iterator[None]: """ Return a context object that will capture the statistics on the execution of invocaation. Use with: to place around the part of the code that executes the invocation. 
@@ -61,7 +60,7 @@ def collect_stats( pass @abstractmethod - def reset_stats(self, graph_execution_state_id: str): + def reset_stats(self, graph_execution_state_id: str) -> None: """ Reset all statistics for the indicated graph. :param graph_execution_state_id: The id of the session whose stats to reset. @@ -70,7 +69,7 @@ def reset_stats(self, graph_execution_state_id: str): pass @abstractmethod - def log_stats(self, graph_execution_state_id: str): + def log_stats(self, graph_execution_state_id: str) -> None: """ Write out the accumulated statistics to the log or somewhere else. :param graph_execution_state_id: The id of the session whose stats to log. diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 2f03db0af72..080219af75e 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -14,7 +14,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase -from invokeai.app.services.events import EventServiceBase +from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index f4dd905135a..cc80333e932 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -5,7 +5,7 @@ from typing import Optional from invokeai.app.services.shared.invocation_context import InvocationContextData -from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase @@ -15,23 +15,7 @@ class ModelLoadServiceBase(ABC): """Wrapper around AnyModelLoader.""" @abstractmethod - def load_model_by_key( - self, - key: str, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - """ - Given a model's key, load it and return the LoadedModel object. - - :param key: Key of model config to be fetched. - :param submodel: For main (pipeline models), the submodel to fetch. - :param context_data: Invocation context data used for event reporting - """ - pass - - @abstractmethod - def load_model_by_config( + def load_model( self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, @@ -44,34 +28,6 @@ def load_model_by_config( :param submodel: For main (pipeline models), the submodel to fetch. :param context_data: Invocation context data used for event reporting """ - pass - - @abstractmethod - def load_model_by_attr( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - """ - Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. 
- - This is provided for API compatability with the get_model() method - in the original model manager. However, note that LoadedModel is - not the same as the original ModelInfo that ws returned. - - :param model_name: Name of to be fetched. - :param base_model: Base model - :param model_type: Type of the model - :param submodel: For main (pipeline models), the submodel to fetch - :param context_data: The invocation context data. - - Exceptions: UnknownModelException -- model with these attributes not known - NotImplementedException -- a model loader was not provided at initialization time - ValueError -- more than one model matches this combination - """ @property @abstractmethod diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index fa96a4672d1..15c6283d8af 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -1,15 +1,18 @@ # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team """Implementation of model loader service.""" -from typing import Optional +from typing import Optional, Type from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException from invokeai.app.services.invoker import Invoker -from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException from invokeai.app.services.shared.invocation_context import InvocationContextData -from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType -from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache +from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType +from invokeai.backend.model_manager.load import ( + LoadedModel, + ModelLoaderRegistry, + ModelLoaderRegistryBase, +) from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase from invokeai.backend.util.logging import InvokeAILogger @@ -18,25 +21,23 @@ class ModelLoadService(ModelLoadServiceBase): - """Wrapper around AnyModelLoader.""" + """Wrapper around ModelLoaderRegistry.""" def __init__( - self, - app_config: InvokeAIAppConfig, - record_store: ModelRecordServiceBase, - ram_cache: ModelCacheBase[AnyModel], - convert_cache: ModelConvertCacheBase, + self, + app_config: InvokeAIAppConfig, + ram_cache: ModelCacheBase[AnyModel], + convert_cache: ModelConvertCacheBase, + registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry, ): """Initialize the model load service.""" logger = InvokeAILogger.get_logger(self.__class__.__name__) logger.setLevel(app_config.log_level.upper()) - self._store = record_store - self._any_loader = AnyModelLoader( - app_config=app_config, - logger=logger, - ram_cache=ram_cache, - convert_cache=convert_cache, - ) + self._logger = logger + self._app_config = app_config + self._ram_cache = ram_cache + self._convert_cache = convert_cache + self._registry = registry def start(self, invoker: Invoker) -> None: self._invoker = invoker @@ -44,63 +45,14 @@ def start(self, invoker: Invoker) -> None: @property def ram_cache(self) -> ModelCacheBase[AnyModel]: """Return the RAM cache used by this loader.""" - return self._any_loader.ram_cache + return self._ram_cache @property def convert_cache(self) -> ModelConvertCacheBase: 
"""Return the checkpoint convert cache used by this loader.""" - return self._any_loader.convert_cache - - def load_model_by_key( - self, - key: str, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - """ - Given a model's key, load it and return the LoadedModel object. - - :param key: Key of model config to be fetched. - :param submodel: For main (pipeline models), the submodel to fetch. - :param context: Invocation context used for event reporting - """ - config = self._store.get_model(key) - return self.load_model_by_config(config, submodel_type, context_data) + return self._convert_cache - def load_model_by_attr( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: - """ - Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. - - This is provided for API compatability with the get_model() method - in the original model manager. However, note that LoadedModel is - not the same as the original ModelInfo that ws returned. - - :param model_name: Name of to be fetched. - :param base_model: Base model - :param model_type: Type of the model - :param submodel: For main (pipeline models), the submodel to fetch - :param context: The invocation context. - - Exceptions: UnknownModelException -- model with this key not known - NotImplementedException -- a model loader was not provided at initialization time - ValueError -- more than one model matches this combination - """ - configs = self._store.search_by_attr(model_name, base_model, model_type) - if len(configs) == 0: - raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model") - elif len(configs) > 1: - raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.") - else: - return self.load_model_by_key(configs[0].key, submodel) - - def load_model_by_config( + def load_model( self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, @@ -118,7 +70,15 @@ def load_model_by_config( context_data=context_data, model_config=model_config, ) - loaded_model = self._any_loader.load_model(model_config, submodel_type) + + implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore + loaded_model: LoadedModel = implementation( + app_config=self._app_config, + logger=self._logger, + ram_cache=self._ram_cache, + convert_cache=self._convert_cache, + ).load_model(model_config, submodel_type) + if context_data: self._emit_load_event( context_data=context_data, diff --git a/invokeai/app/services/model_manager/__init__.py b/invokeai/app/services/model_manager/__init__.py index 66707493f71..5455577266a 100644 --- a/invokeai/app/services/model_manager/__init__.py +++ b/invokeai/app/services/model_manager/__init__.py @@ -3,7 +3,7 @@ from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import LoadedModel -from .model_manager_default import ModelManagerServiceBase, ModelManagerService +from .model_manager_default import ModelManagerService, ModelManagerServiceBase __all__ = [ "ModelManagerServiceBase", diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index 1116c82ff1f..c25aa6fb47c 
100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -1,10 +1,14 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team from abc import ABC, abstractmethod +from typing import Optional from typing_extensions import Self from invokeai.app.services.invoker import Invoker +from invokeai.app.services.shared.invocation_context import InvocationContextData +from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager.load.load_base import LoadedModel from ..config import InvokeAIAppConfig from ..download import DownloadQueueServiceBase @@ -65,3 +69,32 @@ def start(self, invoker: Invoker) -> None: @abstractmethod def stop(self, invoker: Invoker) -> None: pass + + @abstractmethod + def load_model_by_config( + self, + model_config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + pass + + @abstractmethod + def load_model_by_key( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + pass + + @abstractmethod + def load_model_by_attr( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + pass diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index b96341be69e..d029f9e0339 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,10 +1,14 @@ # Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Team """Implementation of ModelManagerServiceBase.""" +from typing import Optional + from typing_extensions import Self from invokeai.app.services.invoker import Invoker -from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache +from invokeai.app.services.shared.invocation_context import InvocationContextData +from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType +from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry from invokeai.backend.util.logging import InvokeAILogger from ..config import InvokeAIAppConfig @@ -12,7 +16,7 @@ from ..events.events_base import EventServiceBase from ..model_install import ModelInstallService, ModelInstallServiceBase from ..model_load import ModelLoadService, ModelLoadServiceBase -from ..model_records import ModelRecordServiceBase +from ..model_records import ModelRecordServiceBase, UnknownModelException from .model_manager_base import ModelManagerServiceBase @@ -58,6 +62,56 @@ def stop(self, invoker: Invoker) -> None: if hasattr(service, "stop"): service.stop(invoker) + def load_model_by_config( + self, + model_config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + return self.load.load_model(model_config, submodel_type, context_data) + + def load_model_by_key( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + config = self.store.get_model(key) + return self.load.load_model(config, submodel_type, context_data) + + def load_model_by_attr( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: Optional[SubModelType] = None, + context_data: Optional[InvocationContextData] = None, + ) -> LoadedModel: + """ + Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. + + This is provided for API compatability with the get_model() method + in the original model manager. However, note that LoadedModel is + not the same as the original ModelInfo that ws returned. + + :param model_name: Name of to be fetched. + :param base_model: Base model + :param model_type: Type of the model + :param submodel: For main (pipeline models), the submodel to fetch + :param context: The invocation context. 
+ + Exceptions: UnknownModelException -- model with this key not known + NotImplementedException -- a model loader was not provided at initialization time + ValueError -- more than one model matches this combination + """ + configs = self.store.search_by_attr(model_name, base_model, model_type) + if len(configs) == 0: + raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model") + elif len(configs) > 1: + raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.") + else: + return self.load.load_model(configs[0], submodel, context_data) + @classmethod def build_model_manager( cls, @@ -82,9 +136,9 @@ def build_model_manager( ) loader = ModelLoadService( app_config=app_config, - record_store=model_record_service, ram_cache=ram_cache, convert_cache=convert_cache, + registry=ModelLoaderRegistry, ) installer = ModelInstallService( app_config=app_config, diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 089d09f825c..1395427a97e 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -281,7 +281,7 @@ def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> Loaded # The model manager emits events as it loads the model. It needs the context data to build # the event payloads. - return self._services.model_manager.load.load_model_by_key( + return self._services.model_manager.load_model_by_key( key=key, submodel_type=submodel_type, context_data=self._context_data ) @@ -296,7 +296,7 @@ def load_by_attrs( :param model_type: Type of the model :param submodel: For main (pipeline models), the submodel to fetch """ - return self._services.model_manager.load.load_model_by_attr( + return self._services.model_manager.load_model_by_attr( model_name=model_name, base_model=base_model, model_type=model_type, diff --git a/invokeai/backend/install/migrate_to_3.py b/invokeai/backend/install/migrate_to_3.py deleted file mode 100644 index e15eb23f5b2..00000000000 --- a/invokeai/backend/install/migrate_to_3.py +++ /dev/null @@ -1,591 +0,0 @@ -""" -Migrate the models directory and models.yaml file from an existing -InvokeAI 2.3 installation to 3.0.0. 
-""" - -import argparse -import os -import shutil -import warnings -from dataclasses import dataclass -from pathlib import Path -from typing import Union - -import diffusers -import transformers -import yaml -from diffusers import AutoencoderKL, StableDiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from omegaconf import DictConfig, OmegaConf -from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import ModelManager -from invokeai.backend.model_management.model_probe import BaseModelType, ModelProbe, ModelProbeInfo, ModelType - -warnings.filterwarnings("ignore") -transformers.logging.set_verbosity_error() -diffusers.logging.set_verbosity_error() - - -# holder for paths that we will migrate -@dataclass -class ModelPaths: - models: Path - embeddings: Path - loras: Path - controlnets: Path - - -class MigrateTo3(object): - def __init__( - self, - from_root: Path, - to_models: Path, - model_manager: ModelManager, - src_paths: ModelPaths, - ): - self.root_directory = from_root - self.dest_models = to_models - self.mgr = model_manager - self.src_paths = src_paths - - @classmethod - def initialize_yaml(cls, yaml_file: Path): - with open(yaml_file, "w") as file: - file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) - - def create_directory_structure(self): - """ - Create the basic directory structure for the models folder. - """ - for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]: - for model_type in [ - ModelType.Main, - ModelType.Vae, - ModelType.Lora, - ModelType.ControlNet, - ModelType.TextualInversion, - ]: - path = self.dest_models / model_base.value / model_type.value - path.mkdir(parents=True, exist_ok=True) - path = self.dest_models / "core" - path.mkdir(parents=True, exist_ok=True) - - @staticmethod - def copy_file(src: Path, dest: Path): - """ - copy a single file with logging - """ - if dest.exists(): - logger.info(f"Skipping existing {str(dest)}") - return - logger.info(f"Copying {str(src)} to {str(dest)}") - try: - shutil.copy(src, dest) - except Exception as e: - logger.error(f"COPY FAILED: {str(e)}") - - @staticmethod - def copy_dir(src: Path, dest: Path): - """ - Recursively copy a directory with logging - """ - if dest.exists(): - logger.info(f"Skipping existing {str(dest)}") - return - - logger.info(f"Copying {str(src)} to {str(dest)}") - try: - shutil.copytree(src, dest) - except Exception as e: - logger.error(f"COPY FAILED: {str(e)}") - - def migrate_models(self, src_dir: Path): - """ - Recursively walk through src directory, probe anything - that looks like a model, and copy the model into the - appropriate location within the destination models directory. 
- """ - directories_scanned = set() - for root, dirs, files in os.walk(src_dir, followlinks=True): - for d in dirs: - try: - model = Path(root, d) - info = ModelProbe().heuristic_probe(model) - if not info: - continue - dest = self._model_probe_to_path(info) / model.name - self.copy_dir(model, dest) - directories_scanned.add(model) - except Exception as e: - logger.error(str(e)) - except KeyboardInterrupt: - raise - for f in files: - # don't copy raw learned_embeds.bin or pytorch_lora_weights.bin - # let them be copied as part of a tree copy operation - try: - if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}: - continue - model = Path(root, f) - if model.parent in directories_scanned: - continue - info = ModelProbe().heuristic_probe(model) - if not info: - continue - dest = self._model_probe_to_path(info) / f - self.copy_file(model, dest) - except Exception as e: - logger.error(str(e)) - except KeyboardInterrupt: - raise - - def migrate_support_models(self): - """ - Copy the clipseg, upscaler, and restoration models to their new - locations. - """ - dest_directory = self.dest_models - if (self.root_directory / "models/clipseg").exists(): - self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg") - if (self.root_directory / "models/realesrgan").exists(): - self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan") - for d in ["codeformer", "gfpgan"]: - path = self.root_directory / "models" / d - if path.exists(): - self.copy_dir(path, dest_directory / f"core/face_restoration/{d}") - - def migrate_tuning_models(self): - """ - Migrate the embeddings, loras and controlnets directories to their new homes. - """ - for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]: - if not src: - continue - if src.is_dir(): - logger.info(f"Scanning {src}") - self.migrate_models(src) - else: - logger.info(f"{src} directory not found; skipping") - continue - - def migrate_conversion_models(self): - """ - Migrate all the models that are needed by the ckpt_to_diffusers conversion - script. 
- """ - - dest_directory = self.dest_models - kwargs = { - "cache_dir": self.root_directory / "models/hub", - # local_files_only = True - } - try: - logger.info("Migrating core tokenizers and text encoders") - target_dir = dest_directory / "core" / "convert" - - self._migrate_pretrained( - BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs - ) - - # sd-1 - repo_id = "openai/clip-vit-large-patch14" - self._migrate_pretrained( - CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs - ) - self._migrate_pretrained( - CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs - ) - - # sd-2 - repo_id = "stabilityai/stable-diffusion-2" - self._migrate_pretrained( - CLIPTokenizer, - repo_id=repo_id, - dest=target_dir / "stable-diffusion-2-clip" / "tokenizer", - **{"subfolder": "tokenizer", **kwargs}, - ) - self._migrate_pretrained( - CLIPTextModel, - repo_id=repo_id, - dest=target_dir / "stable-diffusion-2-clip" / "text_encoder", - **{"subfolder": "text_encoder", **kwargs}, - ) - - # VAE - logger.info("Migrating stable diffusion VAE") - self._migrate_pretrained( - AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs - ) - - # safety checking - logger.info("Migrating safety checker") - repo_id = "CompVis/stable-diffusion-safety-checker" - self._migrate_pretrained( - AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs - ) - self._migrate_pretrained( - StableDiffusionSafetyChecker, - repo_id=repo_id, - dest=target_dir / "stable-diffusion-safety-checker", - **kwargs, - ) - except KeyboardInterrupt: - raise - except Exception as e: - logger.error(str(e)) - - def _model_probe_to_path(self, info: ModelProbeInfo) -> Path: - return Path(self.dest_models, info.base_type.value, info.model_type.value) - - def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs): - if dest.exists() and not force: - logger.info(f"Skipping existing {dest}") - return - model = model_class.from_pretrained(repo_id, **kwargs) - self._save_pretrained(model, dest, overwrite=force) - - def _save_pretrained(self, model, dest: Path, overwrite: bool = False): - model_name = dest.name - if overwrite: - model.save_pretrained(dest, safe_serialization=True) - else: - download_path = dest.with_name(f"{model_name}.downloading") - model.save_pretrained(download_path, safe_serialization=True) - download_path.replace(dest) - - def _download_vae(self, repo_id: str, subfolder: str = None) -> Path: - vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder) - info = ModelProbe().heuristic_probe(vae) - _, model_name = repo_id.split("/") - dest = self._model_probe_to_path(info) / self.unique_name(model_name, info) - vae.save_pretrained(dest, safe_serialization=True) - return dest - - def _vae_path(self, vae: Union[str, dict]) -> Path: - """ - Convert 2.3 VAE stanza to a straight path. 
- """ - vae_path = None - - # First get a path - if isinstance(vae, str): - vae_path = vae - - elif isinstance(vae, DictConfig): - if p := vae.get("path"): - vae_path = p - elif repo_id := vae.get("repo_id"): - if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded - vae_path = "models/core/convert/sd-vae-ft-mse" - return vae_path - else: - vae_path = self._download_vae(repo_id, vae.get("subfolder")) - - assert vae_path is not None, "Couldn't find VAE for this model" - - # if the VAE is in the old models directory, then we must move it into the new - # one. VAEs outside of this directory can stay where they are. - vae_path = Path(vae_path) - if vae_path.is_relative_to(self.src_paths.models): - info = ModelProbe().heuristic_probe(vae_path) - dest = self._model_probe_to_path(info) / vae_path.name - if not dest.exists(): - if vae_path.is_dir(): - self.copy_dir(vae_path, dest) - else: - self.copy_file(vae_path, dest) - vae_path = dest - - if vae_path.is_relative_to(self.dest_models): - rel_path = vae_path.relative_to(self.dest_models) - return Path("models", rel_path) - else: - return vae_path - - def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config): - """ - Migrate a locally-cached diffusers pipeline identified with a repo_id - """ - dest_dir = self.dest_models - - cache = self.root_directory / "models/hub" - kwargs = { - "cache_dir": cache, - "safety_checker": None, - # local_files_only = True, - } - - owner, repo_name = repo_id.split("/") - model_name = model_name or repo_name - model = cache / "--".join(["models", owner, repo_name]) - - if len(list(model.glob("snapshots/**/model_index.json"))) == 0: - return - revisions = [x.name for x in model.glob("refs/*")] - - # if an fp16 is available we use that - revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0] - pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs) - - info = ModelProbe().heuristic_probe(pipeline) - if not info: - return - - if self.mgr.model_exists(model_name, info.base_type, info.model_type): - logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.") - return - - dest = self._model_probe_to_path(info) / model_name - self._save_pretrained(pipeline, dest) - - rel_path = Path("models", dest.relative_to(dest_dir)) - self._add_model(model_name, info, rel_path, **extra_config) - - def migrate_path(self, location: Path, model_name: str = None, **extra_config): - """ - Migrate a model referred to using 'weights' or 'path' - """ - - # handle relative paths - dest_dir = self.dest_models - location = self.root_directory / location - model_name = model_name or location.stem - - info = ModelProbe().heuristic_probe(location) - if not info: - return - - if self.mgr.model_exists(model_name, info.base_type, info.model_type): - logger.warning(f"A model named {model_name} already exists at the destination. 
Skipping migration.") - return - - # uh oh, weights is in the old models directory - move it into the new one - if Path(location).is_relative_to(self.src_paths.models): - dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name) - if location.is_dir(): - self.copy_dir(location, dest) - else: - self.copy_file(location, dest) - location = Path("models", info.base_type.value, info.model_type.value, location.name) - - self._add_model(model_name, info, location, **extra_config) - - def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config): - if info.model_type != ModelType.Main: - return - - self.mgr.add_model( - model_name=model_name, - base_model=info.base_type, - model_type=info.model_type, - clobber=True, - model_attributes={ - "path": str(location), - "description": f"A {info.base_type.value} {info.model_type.value} model", - "model_format": info.format, - "variant": info.variant_type.value, - **extra_config, - }, - ) - - def migrate_defined_models(self): - """ - Migrate models defined in models.yaml - """ - # find any models referred to in old models.yaml - conf = OmegaConf.load(self.root_directory / "configs/models.yaml") - - for model_name, stanza in conf.items(): - try: - passthru_args = {} - - if vae := stanza.get("vae"): - try: - passthru_args["vae"] = str(self._vae_path(vae)) - except Exception as e: - logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"') - logger.warning(str(e)) - - if config := stanza.get("config"): - passthru_args["config"] = config - - if description := stanza.get("description"): - passthru_args["description"] = description - - if repo_id := stanza.get("repo_id"): - logger.info(f"Migrating diffusers model {model_name}") - self.migrate_repo_id(repo_id, model_name, **passthru_args) - - elif location := stanza.get("weights"): - logger.info(f"Migrating checkpoint model {model_name}") - self.migrate_path(Path(location), model_name, **passthru_args) - - elif location := stanza.get("path"): - logger.info(f"Migrating diffusers model {model_name}") - self.migrate_path(Path(location), model_name, **passthru_args) - - except KeyboardInterrupt: - raise - except Exception as e: - logger.error(str(e)) - - def migrate(self): - self.create_directory_structure() - # the configure script is doing this - self.migrate_support_models() - self.migrate_conversion_models() - self.migrate_tuning_models() - self.migrate_defined_models() - - -def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths: - """ - Returns tuple of (embedding_path, lora_path, controlnet_path) - """ - parser = argparse.ArgumentParser(fromfile_prefix_chars="@") - parser.add_argument( - "--embedding_directory", - "--embedding_path", - type=Path, - dest="embedding_path", - default=Path("embeddings"), - ) - parser.add_argument( - "--lora_directory", - dest="lora_path", - type=Path, - default=Path("loras"), - ) - opt, _ = parser.parse_known_args([f"@{str(initfile)}"]) - return ModelPaths( - models=root / "models", - embeddings=root / str(opt.embedding_path).strip('"'), - loras=root / str(opt.lora_path).strip('"'), - controlnets=root / "controlnets", - ) - - -def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths: - """ - Returns tuple of (embedding_path, lora_path, controlnet_path) - """ - # Don't use the config object because it is unforgiving of version updates - # Just use omegaconf directly - opt = OmegaConf.load(initfile) - paths = opt.InvokeAI.Paths - models = paths.get("models_dir", "models") - 
embeddings = paths.get("embedding_dir", "embeddings") - loras = paths.get("lora_dir", "loras") - controlnets = paths.get("controlnet_dir", "controlnets") - return ModelPaths( - models=root / models if models else None, - embeddings=root / embeddings if embeddings else None, - loras=root / loras if loras else None, - controlnets=root / controlnets if controlnets else None, - ) - - -def get_legacy_embeddings(root: Path) -> ModelPaths: - path = root / "invokeai.init" - if path.exists(): - return _parse_legacy_initfile(root, path) - path = root / "invokeai.yaml" - if path.exists(): - return _parse_legacy_yamlfile(root, path) - - -def do_migrate(src_directory: Path, dest_directory: Path): - """ - Migrate models from src to dest InvokeAI root directories - """ - config_file = dest_directory / "configs" / "models.yaml.3" - dest_models = dest_directory / "models.3" - - version_3 = (dest_directory / "models" / "core").exists() - - # Here we create the destination models.yaml file. - # If we are writing into a version 3 directory and the - # file already exists, then we write into a copy of it to - # avoid deleting its previous customizations. Otherwise we - # create a new empty one. - if version_3: # write into the dest directory - try: - shutil.copy(dest_directory / "configs" / "models.yaml", config_file) - except Exception: - MigrateTo3.initialize_yaml(config_file) - mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory - (dest_directory / "models").replace(dest_models) - else: - MigrateTo3.initialize_yaml(config_file) - mgr = ModelManager(config_file) - - paths = get_legacy_embeddings(src_directory) - migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths) - migrator.migrate() - print("Migration successful.") - - if not version_3: - (dest_directory / "models").replace(src_directory / "models.orig") - print(f"Original models directory moved to {dest_directory}/models.orig") - - (dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig") - print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig") - - config_file.replace(config_file.with_suffix("")) - dest_models.replace(dest_models.with_suffix("")) - - -def main(): - parser = argparse.ArgumentParser( - prog="invokeai-migrate3", - description=""" -This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format -'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a - -The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively. 
-It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure -script, which will perform a full upgrade in place.""", - ) - parser.add_argument( - "--from-directory", - dest="src_root", - type=Path, - required=True, - help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")', - ) - parser.add_argument( - "--to-directory", - dest="dest_root", - type=Path, - required=True, - help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")', - ) - args = parser.parse_args() - src_root = args.src_root - assert src_root.is_dir(), f"{src_root} is not a valid directory" - assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory" - assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory" - assert (src_root / "invokeai.init").exists() or ( - src_root / "invokeai.yaml" - ).exists(), f"{src_root} does not contain an InvokeAI init file." - - dest_root = args.dest_root - assert dest_root.is_dir(), f"{dest_root} is not a valid directory" - config = InvokeAIAppConfig.get_config() - config.parse_args(["--root", str(dest_root)]) - - # TODO: revisit - don't rely on invokeai.yaml to exist yet! - dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists() - if not dest_is_setup: - from invokeai.backend.install.invokeai_configure import initialize_rootdir - - initialize_rootdir(dest_root, True) - - do_migrate(src_root, dest_root) - - -if __name__ == "__main__": - main() diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py deleted file mode 100644 index fdbe714f62c..00000000000 --- a/invokeai/backend/install/model_install_backend.py +++ /dev/null @@ -1,637 +0,0 @@ -""" -Utility (backend) functions used by model_install.py -""" -import os -import re -import shutil -import warnings -from dataclasses import dataclass, field -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Callable, Dict, List, Optional, Set, Union - -import requests -import torch -from diffusers import DiffusionPipeline -from diffusers import logging as dlogging -from huggingface_hub import HfApi, HfFolder, hf_hub_url -from omegaconf import OmegaConf -from tqdm import tqdm - -import invokeai.configs as configs -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType -from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo, SchedulerPredictionType -from invokeai.backend.util import download_with_resume -from invokeai.backend.util.devices import choose_torch_device, torch_dtype - -from ..util.logging import InvokeAILogger - -warnings.filterwarnings("ignore") - -# --------------------------globals----------------------- -config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.get_logger(name="InvokeAI") - -# the initial "configs" dir is now bundled in the `invokeai.configs` package -Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml" - -Config_preamble = """ -# This file describes the alternative machine learning models -# available to InvokeAI script. -# -# To add a new model, follow the examples below. Each -# model requires a model config file, a weights file, -# and the width and height of the images it -# was trained on. 
-""" - -LEGACY_CONFIGS = { - BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v1-inference.yaml", - SchedulerPredictionType.VPrediction: "v1-inference-v.yaml", - }, - ModelVariantType.Inpaint: { - SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml", - SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml", - }, - }, - BaseModelType.StableDiffusion2: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v2-inference.yaml", - SchedulerPredictionType.VPrediction: "v2-inference-v.yaml", - }, - ModelVariantType.Inpaint: { - SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml", - SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml", - }, - }, - BaseModelType.StableDiffusionXL: { - ModelVariantType.Normal: "sd_xl_base.yaml", - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelVariantType.Normal: "sd_xl_refiner.yaml", - }, -} - - -@dataclass -class InstallSelections: - install_models: List[str] = field(default_factory=list) - remove_models: List[str] = field(default_factory=list) - - -@dataclass -class ModelLoadInfo: - name: str - model_type: ModelType - base_type: BaseModelType - path: Optional[Path] = None - repo_id: Optional[str] = None - subfolder: Optional[str] = None - description: str = "" - installed: bool = False - recommended: bool = False - default: bool = False - requires: Optional[List[str]] = field(default_factory=list) - - -class ModelInstall(object): - def __init__( - self, - config: InvokeAIAppConfig, - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - model_manager: Optional[ModelManager] = None, - access_token: Optional[str] = None, - civitai_api_key: Optional[str] = None, - ): - self.config = config - self.mgr = model_manager or ModelManager(config.model_conf_path) - self.datasets = OmegaConf.load(Dataset_path) - self.prediction_helper = prediction_type_helper - self.access_token = access_token or HfFolder.get_token() - self.civitai_api_key = civitai_api_key or config.civitai_api_key - self.reverse_paths = self._reverse_paths(self.datasets) - - def all_models(self) -> Dict[str, ModelLoadInfo]: - """ - Return dict of model_key=>ModelLoadInfo objects. - This method consolidates and simplifies the entries in both - models.yaml and INITIAL_MODELS.yaml so that they can - be treated uniformly. It also sorts the models alphabetically - by their name, to improve the display somewhat. 
- """ - model_dict = {} - - # first populate with the entries in INITIAL_MODELS.yaml - for key, value in self.datasets.items(): - name, base, model_type = ModelManager.parse_key(key) - value["name"] = name - value["base_type"] = base - value["model_type"] = model_type - model_info = ModelLoadInfo(**value) - if model_info.subfolder and model_info.repo_id: - model_info.repo_id += f":{model_info.subfolder}" - model_dict[key] = model_info - - # supplement with entries in models.yaml - installed_models = list(self.mgr.list_models()) - - for md in installed_models: - base = md["base_model"] - model_type = md["model_type"] - name = md["model_name"] - key = ModelManager.create_key(name, base, model_type) - if key in model_dict: - model_dict[key].installed = True - else: - model_dict[key] = ModelLoadInfo( - name=name, - base_type=base, - model_type=model_type, - path=value.get("path"), - installed=True, - ) - return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())} - - def _is_autoloaded(self, model_info: dict) -> bool: - path = model_info.get("path") - if not path: - return False - for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]: - if autodir_path := getattr(self.config, autodir): - autodir_path = self.config.root_path / autodir_path - if Path(path).is_relative_to(autodir_path): - return True - return False - - def list_models(self, model_type): - installed = self.mgr.list_models(model_type=model_type) - print() - print(f"Installed models of type `{model_type}`:") - print(f"{'Model Key':50} Model Path") - for i in installed: - print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}") - print() - - # logic here a little reversed to maintain backward compatibility - def starter_models(self, all_models: bool = False) -> Set[str]: - models = set() - for key, _value in self.datasets.items(): - name, base, model_type = ModelManager.parse_key(key) - if all_models or model_type in [ModelType.Main, ModelType.Vae]: - models.add(key) - return models - - def recommended_models(self) -> Set[str]: - starters = self.starter_models(all_models=True) - return {x for x in starters if self.datasets[x].get("recommended", False)} - - def default_model(self) -> str: - starters = self.starter_models() - defaults = [x for x in starters if self.datasets[x].get("default", False)] - return defaults[0] - - def install(self, selections: InstallSelections): - verbosity = dlogging.get_verbosity() # quench NSFW nags - dlogging.set_verbosity_error() - - job = 1 - jobs = len(selections.remove_models) + len(selections.install_models) - - # remove requested models - for key in selections.remove_models: - name, base, mtype = self.mgr.parse_key(key) - logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]") - try: - self.mgr.del_model(name, base, mtype) - except FileNotFoundError as e: - logger.warning(e) - job += 1 - - # add requested models - self._remove_installed(selections.install_models) - self._add_required_models(selections.install_models) - for path in selections.install_models: - logger.info(f"Installing {path} [{job}/{jobs}]") - try: - self.heuristic_import(path) - except (ValueError, KeyError) as e: - logger.error(str(e)) - job += 1 - - dlogging.set_verbosity(verbosity) - self.mgr.commit() - - def heuristic_import( - self, - model_path_id_or_url: Union[str, Path], - models_installed: Set[Path] = None, - ) -> Dict[str, AddModelResult]: - """ - :param model_path_id_or_url: A Path to a local model to import, or a 
string representing its repo_id or URL - :param models_installed: Set of installed models, used for recursive invocation - Returns a set of dict objects corresponding to newly-created stanzas in models.yaml. - """ - - if not models_installed: - models_installed = {} - - model_path_id_or_url = str(model_path_id_or_url).strip("\"' ") - - # A little hack to allow nested routines to retrieve info on the requested ID - self.current_id = model_path_id_or_url - path = Path(model_path_id_or_url) - - # fix relative paths - if path.exists() and not path.is_absolute(): - path = path.absolute() # make relative to current WD - - # checkpoint file, or similar - if path.is_file(): - models_installed.update({str(path): self._install_path(path)}) - - # folders style or similar - elif path.is_dir() and any( - (path / x).exists() - for x in { - "config.json", - "model_index.json", - "learned_embeds.bin", - "pytorch_lora_weights.bin", - "pytorch_lora_weights.safetensors", - } - ): - models_installed.update({str(model_path_id_or_url): self._install_path(path)}) - - # recursive scan - elif path.is_dir(): - for child in path.iterdir(): - self.heuristic_import(child, models_installed=models_installed) - - # huggingface repo - elif len(str(model_path_id_or_url).split("/")) == 2: - models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))}) - - # a URL - elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")): - models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)}) - - else: - raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping") - - return models_installed - - def _remove_installed(self, model_list: List[str]): - all_models = self.all_models() - models_to_remove = [] - - for path in model_list: - key = self.reverse_paths.get(path) - if key and all_models[key].installed: - models_to_remove.append(path) - - for path in models_to_remove: - logger.warning(f"{path} already installed. Skipping") - model_list.remove(path) - - def _add_required_models(self, model_list: List[str]): - additional_models = [] - all_models = self.all_models() - for path in model_list: - if not (key := self.reverse_paths.get(path)): - continue - for requirement in all_models[key].requires: - requirement_key = self.reverse_paths.get(requirement) - if not all_models[requirement_key].installed: - additional_models.append(requirement) - model_list.extend(additional_models) - - # install a model from a local path. The optional info parameter is there to prevent - # the model from being probed twice in the event that it has already been probed. 
- def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult: - info = info or ModelProbe().heuristic_probe(path, self.prediction_helper) - if not info: - logger.warning(f"Unable to parse format of {path}") - return None - model_name = path.stem if path.is_file() else path.name - if self.mgr.model_exists(model_name, info.base_type, info.model_type): - raise ValueError(f'A model named "{model_name}" is already installed.') - attributes = self._make_attributes(path, info) - return self.mgr.add_model( - model_name=model_name, - base_model=info.base_type, - model_type=info.model_type, - model_attributes=attributes, - ) - - def _install_url(self, url: str) -> AddModelResult: - with TemporaryDirectory(dir=self.config.models_path) as staging: - CIVITAI_RE = r".*civitai.com.*" - civit_url = re.match(CIVITAI_RE, url, re.IGNORECASE) - location = download_with_resume( - url, Path(staging), access_token=self.civitai_api_key if civit_url else None - ) - if not location: - logger.error(f"Unable to download {url}. Skipping.") - info = ModelProbe().heuristic_probe(location, self.prediction_helper) - dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name - dest.parent.mkdir(parents=True, exist_ok=True) - models_path = shutil.move(location, dest) - - # staged version will be garbage-collected at this time - return self._install_path(Path(models_path), info) - - def _install_repo(self, repo_id: str) -> AddModelResult: - # hack to recover models stored in subfolders -- - # Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster - subfolder = None - if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id): - repo_id = match.group(1) - subfolder = match.group(2) - - hinfo = HfApi().model_info(repo_id) - - # we try to figure out how to download this most economically - # list all the files in the repo - files = [x.rfilename for x in hinfo.siblings] - if subfolder: - files = [x for x in files if x.startswith(f"{subfolder}/")] - prefix = f"{subfolder}/" if subfolder else "" - - location = None - - with TemporaryDirectory(dir=self.config.models_path) as staging: - staging = Path(staging) - if f"{prefix}model_index.json" in files: - location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder) # pipeline - elif f"{prefix}unet/model.onnx" in files: - location = self._download_hf_model(repo_id, files, staging) - else: - for suffix in ["safetensors", "bin"]: - if f"{prefix}pytorch_lora_weights.{suffix}" in files: - location = self._download_hf_model( - repo_id, [f"pytorch_lora_weights.{suffix}"], staging, subfolder=subfolder - ) # LoRA - break - elif ( - self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files - ): # vae, controlnet or some other standalone - files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - elif f"{prefix}diffusion_pytorch_model.{suffix}" in files: - files = ["config.json", f"diffusion_pytorch_model.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - elif f"{prefix}learned_embeds.{suffix}" in files: - location = self._download_hf_model( - repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder - ) - break - elif ( - f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files - ): # IP-Adapter - files = ["image_encoder.txt", f"ip_adapter.{suffix}"] - location = 
self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files: - # This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted - # by InvokeAI for use with IP-Adapters. - files = ["config.json", f"model.{suffix}"] - location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) - break - if not location: - logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.") - return {} - - info = ModelProbe().heuristic_probe(location, self.prediction_helper) - if not info: - logger.warning(f"Could not probe {location}. Skipping install.") - return {} - dest = ( - self.config.models_path - / info.base_type.value - / info.model_type.value - / self._get_model_name(repo_id, location) - ) - if dest.exists(): - shutil.rmtree(dest) - shutil.copytree(location, dest) - return self._install_path(dest, info) - - def _get_model_name(self, path_name: str, location: Path) -> str: - """ - Calculate a name for the model - primitive implementation. - """ - if key := self.reverse_paths.get(path_name): - (name, base, mtype) = ModelManager.parse_key(key) - return name - elif location.is_dir(): - return location.name - else: - return location.stem - - def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict: - model_name = path.name if path.is_dir() else path.stem - description = f"{info.base_type.value} {info.model_type.value} model {model_name}" - if key := self.reverse_paths.get(self.current_id): - if key in self.datasets: - description = self.datasets[key].get("description") or description - - rel_path = self.relative_to_root(path, self.config.models_path) - - attributes = { - "path": str(rel_path), - "description": str(description), - "model_format": info.format, - } - legacy_conf = None - if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX: - attributes.update( - { - "variant": info.variant_type, - } - ) - if info.format == "checkpoint": - try: - possible_conf = path.with_suffix(".yaml") - if possible_conf.exists(): - legacy_conf = str(self.relative_to_root(possible_conf)) - elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]: - legacy_conf = Path( - self.config.legacy_conf_dir, - LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type], - ) - else: - legacy_conf = Path( - self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type] - ) - except KeyError: - legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess - - if info.model_type == ModelType.ControlNet and info.format == "checkpoint": - possible_conf = path.with_suffix(".yaml") - if possible_conf.exists(): - legacy_conf = str(self.relative_to_root(possible_conf)) - else: - legacy_conf = Path( - self.config.root_path, - "configs/controlnet", - ("cldm_v15.yaml" if info.base_type == BaseModelType("sd-1") else "cldm_v21.yaml"), - ) - - if legacy_conf: - attributes.update({"config": str(legacy_conf)}) - return attributes - - def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path: - root = root or self.config.root_path - if path.is_relative_to(root): - return path.relative_to(root) - else: - return path - - def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path: - """ - Retrieve a StableDiffusion model from cache or remote and then - does a save_pretrained() to the indicated staging area. 
- """ - _, name = repo_id.split("/") - precision = torch_dtype(choose_torch_device()) - variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"] - - model = None - for variant in variants: - try: - model = DiffusionPipeline.from_pretrained( - repo_id, - variant=variant, - torch_dtype=precision, - safety_checker=None, - subfolder=subfolder, - ) - except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors - if "fp16" not in str(e): - print(e) - - if model: - break - - if not model: - logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.") - return None - model.save_pretrained(staging / name, safe_serialization=True) - return staging / name - - def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path: - _, name = repo_id.split("/") - location = staging / name - paths = [] - for filename in files: - filePath = Path(filename) - p = hf_download_with_resume( - repo_id, - model_dir=location / filePath.parent, - model_name=filePath.name, - access_token=self.access_token, - subfolder=filePath.parent / subfolder if subfolder else filePath.parent, - ) - if p: - paths.append(p) - else: - logger.warning(f"Could not download {filename} from {repo_id}.") - - return location if len(paths) > 0 else None - - @classmethod - def _reverse_paths(cls, datasets) -> dict: - """ - Reverse mapping from repo_id/path to destination name. - """ - return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()} - - -# ------------------------------------- -def yes_or_no(prompt: str, default_yes=True): - default = "y" if default_yes else "n" - response = input(f"{prompt} [{default}] ") or default - if default_yes: - return response[0] not in ("n", "N") - else: - return response[0] in ("y", "Y") - - -# --------------------------------------------- -def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs): - logger = InvokeAILogger.get_logger("InvokeAI") - logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage()) - - model = model_class.from_pretrained( - model_name, - resume_download=True, - **kwargs, - ) - model.save_pretrained(destination, safe_serialization=True) - return destination - - -# --------------------------------------------- -def hf_download_with_resume( - repo_id: str, - model_dir: str, - model_name: str, - model_dest: Path = None, - access_token: str = None, - subfolder: str = None, -) -> Path: - model_dest = model_dest or Path(os.path.join(model_dir, model_name)) - os.makedirs(model_dir, exist_ok=True) - - url = hf_hub_url(repo_id, model_name, subfolder=subfolder) - - header = {"Authorization": f"Bearer {access_token}"} if access_token else {} - open_mode = "wb" - exist_size = 0 - - if os.path.exists(model_dest): - exist_size = os.path.getsize(model_dest) - header["Range"] = f"bytes={exist_size}-" - open_mode = "ab" - - resp = requests.get(url, headers=header, stream=True) - total = int(resp.headers.get("content-length", 0)) - - if resp.status_code == 416: # "range not satisfiable", which means nothing to return - logger.info(f"{model_name}: complete file found. Skipping.") - return model_dest - elif resp.status_code == 404: - logger.warning("File not found") - return None - elif resp.status_code != 200: - logger.warning(f"{model_name}: {resp.reason}") - elif exist_size > 0: - logger.info(f"{model_name}: partial file found. 
Resuming...") - else: - logger.info(f"{model_name}: Downloading...") - - try: - with ( - open(model_dest, open_mode) as file, - tqdm( - desc=model_name, - initial=exist_size, - total=total + exist_size, - unit="iB", - unit_scale=True, - unit_divisor=1000, - ) as bar, - ): - for data in resp.iter_content(chunk_size=1024): - size = file.write(data) - bar.update(size) - except Exception as e: - logger.error(f"An error occurred while downloading {model_name}: {str(e)}") - return None - return model_dest diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 3ba6fc5a23c..e51966c779c 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -9,8 +9,8 @@ from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights -from .resampler import Resampler from ..raw_model import RawModel +from .resampler import Resampler class ImageProjModel(torch.nn.Module): diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index fb0c23067fb..0b7128034a2 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -10,6 +10,7 @@ from typing_extensions import Self from invokeai.backend.model_manager import BaseModelType + from .raw_model import RawModel @@ -366,6 +367,7 @@ def to( AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] + class LoRAModelRaw(RawModel): # (torch.nn.Module): _name: str layers: Dict[str, AnyLoRALayer] diff --git a/invokeai/backend/model_management_OLD/README.md b/invokeai/backend/model_management_OLD/README.md deleted file mode 100644 index 0d94f39642e..00000000000 --- a/invokeai/backend/model_management_OLD/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# Model Cache - -## `glibc` Memory Allocator Fragmentation - -Python (and PyTorch) relies on the memory allocator from the C Standard Library (`libc`). On linux, with the GNU C Standard Library implementation (`glibc`), our memory access patterns have been observed to cause severe memory fragmentation. This fragmentation results in large amounts of memory that has been freed but can't be released back to the OS. Loading models from disk and moving them between CPU/CUDA seem to be the operations that contribute most to the fragmentation. This memory fragmentation issue can result in OOM crashes during frequent model switching, even if `max_cache_size` is set to a reasonable value (e.g. a OOM crash with `max_cache_size=16` on a system with 32GB of RAM). - -This problem may also exist on other OSes, and other `libc` implementations. But, at the time of writing, it has only been investigated on linux with `glibc`. - -To better understand how the `glibc` memory allocator works, see these references: -- Basics: https://www.gnu.org/software/libc/manual/html_node/The-GNU-Allocator.html -- Details: https://sourceware.org/glibc/wiki/MallocInternals - -Note the differences between memory allocated as chunks in an arena vs. memory allocated with `mmap`. Under `glibc`'s default configuration, most model tensors get allocated as chunks in an arena making them vulnerable to the problem of fragmentation. - -We can work around this memory fragmentation issue by setting the following env var: - -```bash -# Force blocks >1MB to be allocated with `mmap` so that they are released to the system immediately when they are freed. 
-MALLOC_MMAP_THRESHOLD_=1048576 -``` - -See the following references for more information about the `malloc` tunable parameters: -- https://www.gnu.org/software/libc/manual/html_node/Malloc-Tunable-Parameters.html -- https://www.gnu.org/software/libc/manual/html_node/Memory-Allocation-Tunables.html -- https://man7.org/linux/man-pages/man3/mallopt.3.html - -The model cache emits debug logs that provide visibility into the state of the `libc` memory allocator. See the `LibcUtil` class for more info on how these `libc` malloc stats are collected. diff --git a/invokeai/backend/model_management_OLD/__init__.py b/invokeai/backend/model_management_OLD/__init__.py deleted file mode 100644 index d523a7a0c8d..00000000000 --- a/invokeai/backend/model_management_OLD/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# ruff: noqa: I001, F401 -""" -Initialization file for invokeai.backend.model_management -""" -# This import must be first -from .model_manager import AddModelResult, LoadedModelInfo, ModelManager, SchedulerPredictionType -from .lora import ModelPatcher, ONNXModelPatcher -from .model_cache import ModelCache - -from .models import ( - BaseModelType, - DuplicateModelException, - ModelNotFoundException, - ModelType, - ModelVariantType, - SubModelType, -) - -# This import must be last -from .model_merge import MergeInterpolationMethod, ModelMerger diff --git a/invokeai/backend/model_management_OLD/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management_OLD/convert_ckpt_to_diffusers.py deleted file mode 100644 index 6878218f679..00000000000 --- a/invokeai/backend/model_management_OLD/convert_ckpt_to_diffusers.py +++ /dev/null @@ -1,1739 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -# Adapted for use in InvokeAI by Lincoln Stein, July 2023 -# -""" Conversion script for the Stable Diffusion checkpoints.""" - -import re -from contextlib import nullcontext -from io import BytesIO -from pathlib import Path -from typing import Optional, Union - -import requests -import torch -from diffusers.models import AutoencoderKL, ControlNetModel, PriorTransformer, UNet2DConditionModel -from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel -from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer -from diffusers.schedulers import ( - DDIMScheduler, - DDPMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - UnCLIPScheduler, -) -from diffusers.utils import is_accelerate_available -from picklescan.scanner import scan_file_path -from transformers import ( - AutoFeatureExtractor, - BertTokenizerFast, - CLIPImageProcessor, - CLIPTextConfig, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionConfig, - CLIPVisionModelWithProjection, -) - -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.util.logging import InvokeAILogger - -from .models import BaseModelType, ModelVariantType - -try: - from omegaconf import OmegaConf - from omegaconf.dictconfig import DictConfig -except ImportError: - raise ImportError( - "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." - ) - -if is_accelerate_available(): - from accelerate import init_empty_weights - from accelerate.utils import set_module_tensor_to_device - -logger = InvokeAILogger.get_logger(__name__) -CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core/convert" - - -def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. 
- """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - new_item = new_item.replace("q.weight", "to_q.weight") - new_item = new_item.replace("q.bias", "to_q.bias") - - new_item = new_item.replace("k.weight", "to_k.weight") - new_item = new_item.replace("k.bias", "to_k.bias") - - new_item = new_item.replace("v.weight", "to_v.weight") - new_item = new_item.replace("v.bias", "to_v.bias") - - new_item = new_item.replace("proj_out.weight", "to_out.0.weight") - new_item = new_item.replace("proj_out.bias", "to_out.0.bias") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits - attention layers, and takes into account additional replacements that may arise. - - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - - # Splits the attention layers into three variables. 
- if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) - - for path in paths: - new_path = path["new"] - - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) - shape = old_checkpoint[path["old"]].shape - if is_attn_weight and len(shape) == 3: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - elif is_attn_weight and len(shape) == 4: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): - """ - Creates a config for the diffusers based on the config of the LDM model. 
- """ - if controlnet: - unet_params = original_config.model.params.control_stage_config.params - else: - if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: - unet_params = original_config.model.params.unet_config.params - else: - unet_params = original_config.model.params.network_config.params - - vae_params = original_config.model.params.first_stage_config.params.ddconfig - - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for _i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - if unet_params.transformer_depth is not None: - transformer_layers_per_block = ( - unet_params.transformer_depth - if isinstance(unet_params.transformer_depth, int) - else list(unet_params.transformer_depth) - ) - else: - transformer_layers_per_block = 1 - - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) - - head_dim = unet_params.num_heads if "num_heads" in unet_params else None - use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False - ) - if use_linear_projection: - # stable diffusion 2-base-512 and 2-768 - if head_dim is None: - head_dim_mult = unet_params.model_channels // unet_params.num_head_channels - head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] - - class_embed_type = None - addition_embed_type = None - addition_time_embed_dim = None - projection_class_embeddings_input_dim = None - context_dim = None - - if unet_params.context_dim is not None: - context_dim = ( - unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] - ) - - if "num_classes" in unet_params: - if unet_params.num_classes == "sequential": - if context_dim in [2048, 1280]: - # SDXL - addition_embed_type = "text_time" - addition_time_embed_dim = 256 - else: - class_embed_type = "projection" - assert "adm_in_channels" in unet_params - projection_class_embeddings_input_dim = unet_params.adm_in_channels - else: - raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") - - config = { - "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params.in_channels, - "down_block_types": tuple(down_block_types), - "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params.num_res_blocks, - "cross_attention_dim": context_dim, - "attention_head_dim": head_dim, - "use_linear_projection": use_linear_projection, - "class_embed_type": class_embed_type, - "addition_embed_type": addition_embed_type, - "addition_time_embed_dim": addition_time_embed_dim, - "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, - "transformer_layers_per_block": transformer_layers_per_block, - } - - if controlnet: - config["conditioning_channels"] = unet_params.hint_channels - else: - config["out_channels"] = unet_params.out_channels - config["up_block_types"] = tuple(up_block_types) - - return config - - -def create_vae_diffusers_config(original_config, image_size: int): - """ - 
Creates a config for the diffusers based on the config of the LDM model. - """ - vae_params = original_config.model.params.first_stage_config.params.ddconfig - _ = original_config.model.params.first_stage_config.params.embed_dim - - block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = { - "sample_size": image_size, - "in_channels": vae_params.in_channels, - "out_channels": vae_params.out_ch, - "down_block_types": tuple(down_block_types), - "up_block_types": tuple(up_block_types), - "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params.z_channels, - "layers_per_block": vae_params.num_res_blocks, - } - return config - - -def create_diffusers_schedular(original_config): - schedular = DDIMScheduler( - num_train_timesteps=original_config.model.params.timesteps, - beta_start=original_config.model.params.linear_start, - beta_end=original_config.model.params.linear_end, - beta_schedule="scaled_linear", - ) - return schedular - - -def create_ldm_bert_config(original_config): - bert_params = original_config.model.parms.cond_stage_config.params - config = LDMBertConfig( - d_model=bert_params.n_embed, - encoder_layers=bert_params.n_layer, - encoder_ffn_dim=bert_params.n_embed * 4, - ) - return config - - -def convert_ldm_unet_checkpoint( - checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False -): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - if skip_extract_state_dict: - unet_state_dict = checkpoint - else: - # extract state_dict for UNet - unet_state_dict = {} - keys = list(checkpoint.keys()) - - if controlnet: - unet_key = "control_model." - else: - unet_key = "model.diffusion_model." - - # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA - if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: - logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") - logger.warning( - "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" - " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." - ) - for key in keys: - if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) - else: - if sum(k.startswith("model_ema") for k in keys) > 100: - logger.warning( - "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" - " weights (usually better for inference), please make sure to add the `--extract_ema` flag." - ) - - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - if config["class_embed_type"] is None: - # No parameters to port - ... 
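# A minimal standalone sketch of the EMA key layout handled in the extraction loop above:
# the checkpoint stores EMA weights under "model_ema." plus the original key with its
# leading "model." segment dropped and the remaining dots stripped (the same construction
# as flat_ema_key). The example key is hypothetical.
def ema_key_for(key: str) -> str:
    return "model_ema." + "".join(key.split(".")[1:])

assert ema_key_for("model.diffusion_model.out.2.weight") == "model_ema.diffusion_modelout2weight"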
- elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": - new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - else: - raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") - - if config["addition_embed_type"] == "text_time": - new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - if not controlnet: - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, 
new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - output_block_list = {k: sorted(v) for k, v in output_block_list.items()} - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - - # Clear attentions as they have been attributed above. 
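# A minimal sketch of the index arithmetic used in the loops above to map LDM
# "input_blocks" onto diffusers "down_blocks", assuming layers_per_block == 2 (the value
# used by standard SD 1.x/2.x UNets). For those UNets the downsample "op" occupies the
# last slot of each of the first three blocks; input_blocks.0 holds conv_in.
layers_per_block = 2
for i in range(1, 12):  # input_blocks 1..11
    block_id = (i - 1) // (layers_per_block + 1)
    layer_in_block_id = (i - 1) % (layers_per_block + 1)
    kind = "downsampler" if layer_in_block_id == layers_per_block else "resnet"
    print(f"input_blocks.{i} -> down_blocks.{block_id}, {kind} slot {layer_in_block_id}")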
- if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - if controlnet: - # conditioning embedding - - orig_index = 0 - - new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - orig_index += 2 - - diffusers_index = 0 - - while diffusers_index < 6: - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - diffusers_index += 1 - orig_index += 2 - - new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - # down blocks - for i in range(num_input_blocks): - new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") - new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") - - # mid block - new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") - new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") - - return new_checkpoint - - -def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - keys = list(checkpoint.keys()) - vae_key = "first_stage_model." 
if any(k.startswith("first_stage_model.") for k in keys) else "" - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": 
"mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -def convert_ldm_bert_checkpoint(checkpoint, config): - def _copy_attn_layer(hf_attn_layer, pt_attn_layer): - hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight - hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight - hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight - - hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight - hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias - - def _copy_linear(hf_linear, pt_linear): - hf_linear.weight = pt_linear.weight - hf_linear.bias = pt_linear.bias - - def _copy_layer(hf_layer, pt_layer): - # copy layer norms - _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) - _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) - - # copy attn - _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) - - # copy MLP - pt_mlp = pt_layer[1][1] - _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) - _copy_linear(hf_layer.fc2, pt_mlp.net[2]) - - def _copy_layers(hf_layers, pt_layers): - for i, hf_layer in enumerate(hf_layers): - if i != 0: - i += i - pt_layer = pt_layers[i : i + 2] - _copy_layer(hf_layer, pt_layer) - - hf_model = LDMBertModel(config).eval() - - # copy embeds - hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight - hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight - - # copy layer norm - _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) - - # copy hidden layers - _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) - - _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) - - return hf_model - - -def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): - if text_encoder is None: - config = 
CLIPTextConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - text_model = CLIPTextModel(config) - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] - - for key in keys: - for prefix in remove_prefixes: - if key.startswith(prefix): - text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] - - if is_accelerate_available(): - for param_name, param in text_model_dict.items(): - set_module_tensor_to_device(text_model, param_name, "cpu", value=param) - else: - text_model.load_state_dict(text_model_dict) - - return text_model - - -textenc_conversion_lst = [ - ("positional_embedding", "text_model.embeddings.position_embedding.weight"), - ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), - ("ln_final.weight", "text_model.final_layer_norm.weight"), - ("ln_final.bias", "text_model.final_layer_norm.bias"), - ("text_projection", "text_projection.weight"), -] -textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} - -textenc_transformer_conversion_lst = [ - # (stable-diffusion, HF Diffusers) - ("resblocks.", "text_model.encoder.layers."), - ("ln_1", "layer_norm1"), - ("ln_2", "layer_norm2"), - (".c_fc.", ".fc1."), - (".c_proj.", ".fc2."), - (".attn", ".self_attn"), - ("ln_final.", "transformer.text_model.final_layer_norm."), - ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), - ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), -] -protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} -textenc_pattern = re.compile("|".join(protected.keys())) - - -def convert_paint_by_example_checkpoint(checkpoint): - config = CLIPVisionConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - model = PaintByExampleImageEncoder(config) - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - - # load clip vision - model.model.load_state_dict(text_model_dict) - - # load mapper - keys_mapper = { - k[len("cond_stage_model.mapper.res") :]: v - for k, v in checkpoint.items() - if k.startswith("cond_stage_model.mapper") - } - - MAPPING = { - "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], - "attn.c_proj": ["attn1.to_out.0"], - "ln_1": ["norm1"], - "ln_2": ["norm3"], - "mlp.c_fc": ["ff.net.0.proj"], - "mlp.c_proj": ["ff.net.2"], - } - - mapped_weights = {} - for key, value in keys_mapper.items(): - prefix = key[: len("blocks.i")] - suffix = key.split(prefix)[-1].split(".")[-1] - name = key.split(prefix)[-1].split(suffix)[0][1:-1] - mapped_names = MAPPING[name] - - num_splits = len(mapped_names) - for i, mapped_name in enumerate(mapped_names): - new_name = ".".join([prefix, mapped_name, suffix]) - shape = value.shape[0] // num_splits - mapped_weights[new_name] = value[i * shape : (i + 1) * shape] - - model.mapper.load_state_dict(mapped_weights) - - # load final layer norm - model.final_layer_norm.load_state_dict( - { - "bias": checkpoint["cond_stage_model.final_ln.bias"], - "weight": checkpoint["cond_stage_model.final_ln.weight"], - } - ) - - # load final proj - model.proj_out.load_state_dict( - { - "bias": checkpoint["proj_out.bias"], - "weight": checkpoint["proj_out.weight"], - } - ) - - # load 
uncond vector - model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) - return model - - -def convert_open_clip_checkpoint( - checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs -): - # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") - # text_model = CLIPTextModelWithProjection.from_pretrained( - # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280 - # ) - config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs) - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) - - keys = list(checkpoint.keys()) - - keys_to_ignore = [] - if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23: - # make sure to remove all keys > 22 - keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")] - keys_to_ignore += ["cond_stage_model.model.text_projection"] - - text_model_dict = {} - - if prefix + "text_projection" in checkpoint: - d_model = int(checkpoint[prefix + "text_projection"].shape[0]) - else: - d_model = 1024 - - text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") - - for key in keys: - if key in keys_to_ignore: - continue - if key[len(prefix) :] in textenc_conversion_map: - if key.endswith("text_projection"): - value = checkpoint[key].T.contiguous() - else: - value = checkpoint[key] - - text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value - - if key.startswith(prefix + "transformer."): - new_key = key[len(prefix + "transformer.") :] - if new_key.endswith(".in_proj_weight"): - new_key = new_key[: -len(".in_proj_weight")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] - text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] - text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] - elif new_key.endswith(".in_proj_bias"): - new_key = new_key[: -len(".in_proj_bias")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] - text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] - text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] - else: - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - - text_model_dict[new_key] = checkpoint[key] - - if is_accelerate_available(): - for param_name, param in text_model_dict.items(): - set_module_tensor_to_device(text_model, param_name, "cpu", value=param) - else: - text_model.load_state_dict(text_model_dict) - - return text_model - - -def stable_unclip_image_encoder(original_config): - """ - Returns the image processor and clip image encoder for the img2img unclip pipeline. - - We currently know of two types of stable unclip models which separately use the clip and the openclip image - encoders. 
- """ - - image_embedder_config = original_config.model.params.embedder_config - - sd_clip_image_embedder_class = image_embedder_config.target - sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] - - if sd_clip_image_embedder_class == "ClipImageEmbedder": - clip_model_name = image_embedder_config.params.model - - if clip_model_name == "ViT-L/14": - feature_extractor = CLIPImageProcessor() - image_encoder = CLIPVisionModelWithProjection.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - else: - raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") - - elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": - feature_extractor = CLIPImageProcessor() - # InvokeAI doesn't use CLIPVisionModelWithProjection so it isn't in the core - if this code is hit a download will occur - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - CONVERT_MODEL_ROOT / "CLIP-ViT-H-14-laion2B-s32B-b79K" - ) - else: - raise NotImplementedError( - f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" - ) - - return feature_extractor, image_encoder - - -def stable_unclip_image_noising_components( - original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None -): - """ - Returns the noising components for the img2img and txt2img unclip pipelines. - - Converts the stability noise augmentor into - 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats - 2. a `DDPMScheduler` for holding the noise schedule - - If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. - """ - noise_aug_config = original_config.model.params.noise_aug_config - noise_aug_class = noise_aug_config.target - noise_aug_class = noise_aug_class.split(".")[-1] - - if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": - noise_aug_config = noise_aug_config.params - embedding_dim = noise_aug_config.timestep_dim - max_noise_level = noise_aug_config.noise_schedule_config.timesteps - beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule - - image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) - image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) - - if "clip_stats_path" in noise_aug_config: - if clip_stats_path is None: - raise ValueError("This stable unclip config requires a `clip_stats_path`") - - clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) - clip_mean = clip_mean[None, :] - clip_std = clip_std[None, :] - - clip_stats_state_dict = { - "mean": clip_mean, - "std": clip_std, - } - - image_normalizer.load_state_dict(clip_stats_state_dict) - else: - raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") - - return image_normalizer, image_noising_scheduler - - -def convert_controlnet_checkpoint( - checkpoint, - original_config, - checkpoint_path, - image_size, - upcast_attention, - extract_ema, - use_linear_projection=None, - cross_attention_dim=None, - precision: Optional[torch.dtype] = None, -): - ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) - ctrlnet_config["upcast_attention"] = upcast_attention - - ctrlnet_config.pop("sample_size") - original_config = ctrlnet_config.copy() - - ctrlnet_config.pop("addition_embed_type") - ctrlnet_config.pop("addition_time_embed_dim") - ctrlnet_config.pop("transformer_layers_per_block") - - if 
use_linear_projection is not None: - ctrlnet_config["use_linear_projection"] = use_linear_projection - - if cross_attention_dim is not None: - ctrlnet_config["cross_attention_dim"] = cross_attention_dim - - controlnet = ControlNetModel(**ctrlnet_config) - - # Some controlnet ckpt files are distributed independently from the rest of the - # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ - if "time_embed.0.weight" in checkpoint: - skip_extract_state_dict = True - else: - skip_extract_state_dict = False - - converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, - original_config, - path=checkpoint_path, - extract_ema=extract_ema, - controlnet=True, - skip_extract_state_dict=skip_extract_state_dict, - ) - - controlnet.load_state_dict(converted_ctrl_checkpoint) - - return controlnet.to(precision) - - -def download_from_original_stable_diffusion_ckpt( - checkpoint_path: str, - model_version: BaseModelType, - model_variant: ModelVariantType, - original_config_file: str = None, - image_size: Optional[int] = None, - prediction_type: str = None, - model_type: str = None, - extract_ema: bool = False, - precision: Optional[torch.dtype] = None, - scheduler_type: str = "pndm", - num_in_channels: Optional[int] = None, - upcast_attention: Optional[bool] = None, - device: str = None, - from_safetensors: bool = False, - stable_unclip: Optional[str] = None, - stable_unclip_prior: Optional[str] = None, - clip_stats_path: Optional[str] = None, - controlnet: Optional[bool] = None, - load_safety_checker: bool = True, - pipeline_class: DiffusionPipeline = None, - local_files_only=False, - vae_path=None, - text_encoder=None, - tokenizer=None, - scan_needed: bool = True, -) -> DiffusionPipeline: - """ - Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` - config file. - - Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the - global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is - recommended that you override the default values and/or supply an `original_config_file` wherever possible. - - Args: - checkpoint_path (`str`): Path to `.ckpt` file. - original_config_file (`str`): - Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically - inferred by looking for a key that only exists in SD2.0 models. - image_size (`int`, *optional*, defaults to 512): - The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 - Base. Use 768 for Stable Diffusion v2. - prediction_type (`str`, *optional*): - The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable - Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. - num_in_channels (`int`, *optional*, defaults to None): - The number of input channels. If `None`, it will be automatically inferred. - scheduler_type (`str`, *optional*, defaults to 'pndm'): - Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", - "ddim"]`. - model_type (`str`, *optional*, defaults to `None`): - The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", - "FrozenCLIPEmbedder", "PaintByExample"]`. - is_img2img (`bool`, *optional*, defaults to `False`): - Whether the model should be loaded as an img2img pipeline. 
- extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for - checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to - `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for - inference. Non-EMA weights are usually better to continue fine-tuning. - upcast_attention (`bool`, *optional*, defaults to `None`): - Whether the attention computation should always be upcasted. This is necessary when running stable - diffusion 2.1. - device (`str`, *optional*, defaults to `None`): - The device to use. Pass `None` to determine automatically. - from_safetensors (`str`, *optional*, defaults to `False`): - If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. - load_safety_checker (`bool`, *optional*, defaults to `True`): - Whether to load the safety checker or not. Defaults to `True`. - pipeline_class (`str`, *optional*, defaults to `None`): - The pipeline class to use. Pass `None` to determine automatically. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): - An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) - to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) - variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. - tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): - An instance of - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) - to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if - needed. - precision (`torch.dtype`, *optional*, defauts to `None`): - If not provided the precision will be set to the precision of the original file. - return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. - """ - - # import pipelines here to avoid circular import error when using from_single_file method - from diffusers import ( - LDMTextToImagePipeline, - PaintByExamplePipeline, - StableDiffusionControlNetPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionPipeline, - StableDiffusionXLImg2ImgPipeline, - StableDiffusionXLPipeline, - StableUnCLIPImg2ImgPipeline, - StableUnCLIPPipeline, - ) - - if pipeline_class is None: - pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline - - if prediction_type == "v-prediction": - prediction_type = "v_prediction" - - if from_safetensors: - from safetensors.torch import load_file as safe_load - - checkpoint = safe_load(checkpoint_path, device="cpu") - else: - if scan_needed: - # scan model - scan_result = scan_file_path(checkpoint_path) - if scan_result.infected_files != 0: - raise Exception("The model {checkpoint_path} is potentially infected by malware. 
Aborting import.") - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path, map_location=device) - else: - checkpoint = torch.load(checkpoint_path, map_location=device) - - # Sometimes models don't have the global_step item - if "global_step" in checkpoint: - global_step = checkpoint["global_step"] - else: - logger.debug("global_step key not found in model") - global_step = None - - # NOTE: this while loop isn't great but this controlnet checkpoint has one additional - # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 - while "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}") - - precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias" - logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}") - precision = precision or checkpoint[precision_probing_key].dtype - - if original_config_file is None: - key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" - key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" - - # model_type = "v1" - config_url = ( - "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" - ) - - if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: - # model_type = "v2" - config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" - - if global_step == 110000: - # v2.1 needs to upcast attention - upcast_attention = True - elif key_name_sd_xl_base in checkpoint: - # only base xl has two text embedders - config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" - elif key_name_sd_xl_refiner in checkpoint: - # only refiner xl has embedder and one text embedders - config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" - - original_config_file = BytesIO(requests.get(config_url).content) - - original_config = OmegaConf.load(original_config_file) - if original_config["model"]["params"].get("use_ema") is not None: - extract_ema = original_config["model"]["params"]["use_ema"] - - if ( - model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1] - and original_config["model"]["params"].get("parameterization") == "v" - ): - prediction_type = "v_prediction" - upcast_attention = True - image_size = 768 if model_version == BaseModelType.StableDiffusion2 else 512 - else: - prediction_type = "epsilon" - upcast_attention = False - image_size = 512 - - # Convert the text model. 
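# A condensed restatement of the key-based config inference above, as a standalone sketch.
# It assumes `state_dict` is an already-unwrapped checkpoint dict whose values are tensors,
# and returns the name of the original-architecture yaml that would be fetched.
def guess_config(state_dict: dict) -> str:
    key_v2 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
    key_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
    key_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
    if key_v2 in state_dict and state_dict[key_v2].shape[-1] == 1024:
        return "v2-inference-v.yaml"
    if key_xl_base in state_dict:  # only the SDXL base model has two text embedders
        return "sd_xl_base.yaml"
    if key_xl_refiner in state_dict:
        return "sd_xl_refiner.yaml"
    return "v1-inference.yaml"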
- if ( - model_type is None - and "cond_stage_config" in original_config.model.params - and original_config.model.params.cond_stage_config is not None - ): - model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] - logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") - elif model_type is None and original_config.model.params.network_config is not None: - if original_config.model.params.network_config.params.context_dim == 2048: - model_type = "SDXL" - else: - model_type = "SDXL-Refiner" - if image_size is None: - image_size = 1024 - - if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline: - num_in_channels = 9 - elif num_in_channels is None: - num_in_channels = 4 - - if "unet_config" in original_config.model.params: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels - - if ( - "parameterization" in original_config["model"]["params"] - and original_config["model"]["params"]["parameterization"] == "v" - ): - if prediction_type is None: - # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` - # as it relies on a brittle global step parameter here - prediction_type = "epsilon" if global_step == 875000 else "v_prediction" - if image_size is None: - # NOTE: For stable diffusion 2 base one has to pass `image_size==512` - # as it relies on a brittle global step parameter here - image_size = 512 if global_step == 875000 else 768 - else: - if prediction_type is None: - prediction_type = "epsilon" - if image_size is None: - image_size = 512 - - if controlnet is None and "control_stage_config" in original_config.model.params: - controlnet = convert_controlnet_checkpoint( - checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema - ) - - num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 - - if model_type in ["SDXL", "SDXL-Refiner"]: - scheduler_dict = { - "beta_schedule": "scaled_linear", - "beta_start": 0.00085, - "beta_end": 0.012, - "interpolation_type": "linear", - "num_train_timesteps": num_train_timesteps, - "prediction_type": "epsilon", - "sample_max_value": 1.0, - "set_alpha_to_one": False, - "skip_prk_steps": True, - "steps_offset": 1, - "timestep_spacing": "leading", - } - scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) - scheduler_type = "euler" - else: - beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 - beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 - scheduler = DDIMScheduler( - beta_end=beta_end, - beta_schedule="scaled_linear", - beta_start=beta_start, - num_train_timesteps=num_train_timesteps, - steps_offset=1, - clip_sample=False, - set_alpha_to_one=False, - prediction_type=prediction_type, - ) - # make sure scheduler works correctly with DDIM - scheduler.register_to_config(clip_sample=False) - - if scheduler_type == "pndm": - config = dict(scheduler.config) - config["skip_prk_steps"] = True - scheduler = PNDMScheduler.from_config(config) - elif scheduler_type == "lms": - scheduler = LMSDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "heun": - scheduler = HeunDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler": - scheduler = EulerDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler-ancestral": - scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "dpm": - 
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) - elif scheduler_type == "ddim": - scheduler = scheduler - else: - raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") - - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(original_config, image_size=image_size) - unet_config["upcast_attention"] = upcast_attention - converted_unet_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema - ) - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - unet = UNet2DConditionModel(**unet_config) - - if is_accelerate_available(): - for param_name, param in converted_unet_checkpoint.items(): - set_module_tensor_to_device(unet, param_name, "cpu", value=param) - else: - unet.load_state_dict(converted_unet_checkpoint) - - # Convert the VAE model. - if vae_path is None: - vae_config = create_vae_diffusers_config(original_config, image_size=image_size) - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) - - if ( - "model" in original_config - and "params" in original_config.model - and "scale_factor" in original_config.model.params - ): - vae_scaling_factor = original_config.model.params.scale_factor - else: - vae_scaling_factor = 0.18215 # default SD scaling factor - - vae_config["scaling_factor"] = vae_scaling_factor - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - vae = AutoencoderKL(**vae_config) - - if is_accelerate_available(): - for param_name, param in converted_vae_checkpoint.items(): - set_module_tensor_to_device(vae, param_name, "cpu", value=param) - else: - vae.load_state_dict(converted_vae_checkpoint) - else: - vae = AutoencoderKL.from_pretrained(vae_path) - - if model_type == "FrozenOpenCLIPEmbedder": - config_name = "stabilityai/stable-diffusion-2" - config_kwargs = {"subfolder": "text_encoder"} - - text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs) - tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer") - - if stable_unclip is None: - if controlnet: - pipe = pipeline_class( - vae=vae.to(precision), - text_encoder=text_model.to(precision), - tokenizer=tokenizer, - unet=unet.to(precision), - scheduler=scheduler, - controlnet=controlnet, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) - else: - pipe = pipeline_class( - vae=vae.to(precision), - text_encoder=text_model.to(precision), - tokenizer=tokenizer, - unet=unet.to(precision), - scheduler=scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) - else: - image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( - original_config, clip_stats_path=clip_stats_path, device=device - ) - - if stable_unclip == "img2img": - feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) - - pipe = StableUnCLIPImg2ImgPipeline( - # image encoding components - feature_extractor=feature_extractor, - image_encoder=image_encoder, - # image noising components - image_normalizer=image_normalizer, - image_noising_scheduler=image_noising_scheduler, - # regular denoising components - tokenizer=tokenizer, - text_encoder=text_model.to(precision), - unet=unet.to(precision), - scheduler=scheduler, - # vae - vae=vae, - ) - elif stable_unclip == "txt2img": - if stable_unclip_prior is None or stable_unclip_prior == 
"karlo": - karlo_model = "kakaobrain/karlo-v1-alpha" - prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior") - - prior_tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - prior_text_model = CLIPTextModelWithProjection.from_pretrained( - CONVERT_MODEL_ROOT / "clip-vit-large-patch14" - ) - - prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler") - prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) - else: - raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") - - pipe = StableUnCLIPPipeline( - # prior components - prior_tokenizer=prior_tokenizer, - prior_text_encoder=prior_text_model, - prior=prior, - prior_scheduler=prior_scheduler, - # image noising components - image_normalizer=image_normalizer, - image_noising_scheduler=image_noising_scheduler, - # regular denoising components - tokenizer=tokenizer, - text_encoder=text_model, - unet=unet, - scheduler=scheduler, - # vae - vae=vae, - ) - else: - raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") - elif model_type == "PaintByExample": - vision_model = convert_paint_by_example_checkpoint(checkpoint) - tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - feature_extractor = AutoFeatureExtractor.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker") - pipe = PaintByExamplePipeline( - vae=vae, - image_encoder=vision_model, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=feature_extractor, - ) - elif model_type == "FrozenCLIPEmbedder": - text_model = convert_ldm_clip_checkpoint( - checkpoint, local_files_only=local_files_only, text_encoder=text_encoder - ) - tokenizer = ( - CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - if tokenizer is None - else tokenizer - ) - - if load_safety_checker: - safety_checker = StableDiffusionSafetyChecker.from_pretrained( - CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" - ) - feature_extractor = AutoFeatureExtractor.from_pretrained( - CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" - ) - else: - safety_checker = None - feature_extractor = None - - if controlnet: - pipe = pipeline_class( - vae=vae.to(precision), - text_encoder=text_model.to(precision), - tokenizer=tokenizer, - unet=unet.to(precision), - controlnet=controlnet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - else: - pipe = pipeline_class( - vae=vae.to(precision), - text_encoder=text_model.to(precision), - tokenizer=tokenizer, - unet=unet.to(precision), - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - elif model_type in ["SDXL", "SDXL-Refiner"]: - if model_type == "SDXL": - tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) - - tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k" - tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!") - - config_name = tokenizer_name - config_kwargs = {"projection_dim": 1280} - text_encoder_2 = convert_open_clip_checkpoint( - checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs - ) - - pipe = StableDiffusionXLPipeline( - vae=vae.to(precision), - text_encoder=text_encoder.to(precision), - 
tokenizer=tokenizer, - text_encoder_2=text_encoder_2.to(precision), - tokenizer_2=tokenizer_2, - unet=unet.to(precision), - scheduler=scheduler, - force_zeros_for_empty_prompt=True, - ) - else: - tokenizer = None - text_encoder = None - tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k" - tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!") - - config_name = tokenizer_name - config_kwargs = {"projection_dim": 1280} - text_encoder_2 = convert_open_clip_checkpoint( - checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs - ) - - pipe = StableDiffusionXLImg2ImgPipeline( - vae=vae.to(precision), - text_encoder=text_encoder, - tokenizer=tokenizer, - text_encoder_2=text_encoder_2, - tokenizer_2=tokenizer_2, - unet=unet.to(precision), - scheduler=scheduler, - requires_aesthetics_score=True, - force_zeros_for_empty_prompt=False, - ) - else: - text_config = create_ldm_bert_config(original_config) - text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) - tokenizer = BertTokenizerFast.from_pretrained(CONVERT_MODEL_ROOT / "bert-base-uncased") - pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) - - return pipe - - -def download_controlnet_from_original_ckpt( - checkpoint_path: str, - original_config_file: str, - image_size: int = 512, - extract_ema: bool = False, - precision: Optional[torch.dtype] = None, - num_in_channels: Optional[int] = None, - upcast_attention: Optional[bool] = None, - device: str = None, - from_safetensors: bool = False, - use_linear_projection: Optional[bool] = None, - cross_attention_dim: Optional[bool] = None, - scan_needed: bool = False, -) -> DiffusionPipeline: - if from_safetensors: - from safetensors import safe_open - - checkpoint = {} - with safe_open(checkpoint_path, framework="pt", device="cpu") as f: - for key in f.keys(): - checkpoint[key] = f.get_tensor(key) - else: - if scan_needed: - # scan model - scan_result = scan_file_path(checkpoint_path) - if scan_result.infected_files != 0: - raise Exception("The model {checkpoint_path} is potentially infected by malware. 
Aborting import.") - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path, map_location=device) - else: - checkpoint = torch.load(checkpoint_path, map_location=device) - - # NOTE: this while loop isn't great but this controlnet checkpoint has one additional - # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 - while "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - # use original precision - precision_probing_key = "input_blocks.0.0.bias" - ckpt_precision = checkpoint[precision_probing_key].dtype - logger.debug(f"original controlnet precision = {ckpt_precision}") - precision = precision or ckpt_precision - - original_config = OmegaConf.load(original_config_file) - - if num_in_channels is not None: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels - - if "control_stage_config" not in original_config.model.params: - raise ValueError("`control_stage_config` not present in original config") - - controlnet = convert_controlnet_checkpoint( - checkpoint, - original_config, - checkpoint_path, - image_size, - upcast_attention, - extract_ema, - use_linear_projection=use_linear_projection, - cross_attention_dim=cross_attention_dim, - ) - - return controlnet.to(precision) - - -def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL: - vae_config = create_vae_diffusers_config(vae_config, image_size=image_size) - - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) - - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) - return vae - - -def convert_ckpt_to_diffusers( - checkpoint_path: Union[str, Path], - dump_path: Union[str, Path], - use_safetensors: bool = True, - **kwargs, -): - """ - Takes all the arguments of download_from_original_stable_diffusion_ckpt(), - and in addition a path-like object indicating the location of the desired diffusers - model to be written. - """ - pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs) - - pipe.save_pretrained( - dump_path, - safe_serialization=use_safetensors, - ) - - -def convert_controlnet_to_diffusers( - checkpoint_path: Union[str, Path], - dump_path: Union[str, Path], - **kwargs, -): - """ - Takes all the arguments of download_controlnet_from_original_ckpt(), - and in addition a path-like object indicating the location of the desired diffusers - model to be written. - """ - pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs) - - pipe.save_pretrained(dump_path, safe_serialization=True) diff --git a/invokeai/backend/model_management_OLD/detect_baked_in_vae.py b/invokeai/backend/model_management_OLD/detect_baked_in_vae.py deleted file mode 100644 index 9118438548d..00000000000 --- a/invokeai/backend/model_management_OLD/detect_baked_in_vae.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team -""" -This module exports the function has_baked_in_sdxl_vae(). -It returns True if an SDXL checkpoint model has the original SDXL 1.0 VAE, -which doesn't work properly in fp16 mode. 
-""" - -import hashlib -from pathlib import Path - -from safetensors.torch import load_file - -SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51" - - -def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool: - """Return true if the checkpoint contains a custom (non SDXL-1.0) VAE.""" - hash = _vae_hash(checkpoint_path) - return hash != SDXL_1_0_VAE_HASH - - -def _vae_hash(checkpoint_path: Path) -> str: - checkpoint = load_file(checkpoint_path, device="cpu") - vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")] - hash = hashlib.new("sha256") - for key in vae_keys: - value = checkpoint[key] - hash.update(bytes(key, "UTF-8")) - hash.update(bytes(str(value), "UTF-8")) - - return hash.hexdigest() diff --git a/invokeai/backend/model_management_OLD/lora.py b/invokeai/backend/model_management_OLD/lora.py deleted file mode 100644 index aed5eb60d57..00000000000 --- a/invokeai/backend/model_management_OLD/lora.py +++ /dev/null @@ -1,582 +0,0 @@ -from __future__ import annotations - -import pickle -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -from compel.embeddings_provider import BaseTextualInversionManager -from diffusers.models import UNet2DConditionModel -from safetensors.torch import load_file -from transformers import CLIPTextModel, CLIPTokenizer - -from invokeai.app.shared.models import FreeUConfig -from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init - -from .models.lora import LoRAModel - -""" -loras = [ - (lora_model1, 0.7), - (lora_model2, 0.4), -] -with LoRAHelper.apply_lora_unet(unet, loras): - # unet with applied loras -# unmodified unet - -""" - - -# TODO: rename smth like ModelPatcher and add TI method? -class ModelPatcher: - @staticmethod - def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: - assert "." not in lora_key - - if not lora_key.startswith(prefix): - raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}") - - module = model - module_key = "" - key_parts = lora_key[len(prefix) :].split("_") - - submodule_name = key_parts.pop(0) - - while len(key_parts) > 0: - try: - module = module.get_submodule(submodule_name) - module_key += "." + submodule_name - submodule_name = key_parts.pop(0) - except Exception: - submodule_name += "_" + key_parts.pop(0) - - module = module.get_submodule(submodule_name) - module_key = (module_key + "." 
+ submodule_name).lstrip(".") - - return (module_key, module) - - @classmethod - @contextmanager - def apply_lora_unet( - cls, - unet: UNet2DConditionModel, - loras: List[Tuple[LoRAModel, float]], - ): - with cls.apply_lora(unet, loras, "lora_unet_"): - yield - - @classmethod - @contextmanager - def apply_lora_text_encoder( - cls, - text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], - ): - with cls.apply_lora(text_encoder, loras, "lora_te_"): - yield - - @classmethod - @contextmanager - def apply_sdxl_lora_text_encoder( - cls, - text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], - ): - with cls.apply_lora(text_encoder, loras, "lora_te1_"): - yield - - @classmethod - @contextmanager - def apply_sdxl_lora_text_encoder2( - cls, - text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModel, float]], - ): - with cls.apply_lora(text_encoder, loras, "lora_te2_"): - yield - - @classmethod - @contextmanager - def apply_lora( - cls, - model: torch.nn.Module, - loras: List[Tuple[LoRAModel, float]], # THIS IS INCORRECT. IT IS ACTUALLY A LoRAModelRaw - prefix: str, - ): - original_weights = {} - try: - with torch.no_grad(): - for lora, lora_weight in loras: - # assert lora.device.type == "cpu" - for layer_key, layer in lora.layers.items(): - if not layer_key.startswith(prefix): - continue - - # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This - # should be improved in the following ways: - # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a - # LoRA model is applied. - # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the - # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA - # weights to have valid keys. - module_key, module = cls._resolve_lora_key(model, layer_key, prefix) - - # All of the LoRA weight calculations will be done on the same device as the module weight. - # (Performance will be best if this is a CUDA device.) - device = module.weight.device - dtype = module.weight.dtype - - if module_key not in original_weights: - original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) - - layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 - - # We intentionally move to the target device first, then cast. Experimentally, this was found to - # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the - # same thing in a single call to '.to(...)'. - layer.to(device=device) - layer.to(dtype=torch.float32) - # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA - # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. 
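# The statements below implement the core patch step: compute a scaled low-rank
# delta, add it to the module weight in place, and rely on `original_weights`
# plus the `finally` block to undo the change on exit. A reduced, self-contained
# sketch of that backup/patch/restore pattern, with a plain nn.Linear standing in
# for a resolved submodule and a zero tensor standing in for a real LoRA delta
# (the helper name `apply_delta` is illustrative, not part of this module):
import torch

def apply_delta(module: torch.nn.Module, delta: torch.Tensor, scale: float) -> torch.Tensor:
    backup = module.weight.detach().to(device="cpu", copy=True)  # back up the original weight
    with torch.no_grad():
        module.weight += (delta * scale).to(dtype=module.weight.dtype)  # W' = W + scale * dW
    return backup

linear = torch.nn.Linear(8, 8)
backup = apply_delta(linear, torch.zeros(8, 8), 0.7)
try:
    pass  # run inference with the patched module here
finally:
    with torch.no_grad():
        linear.weight.copy_(backup)  # restore the original weight, as the `finally` block below does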
- layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) - layer.to(device="cpu") - - if module.weight.shape != layer_weight.shape: - # TODO: debug on lycoris - layer_weight = layer_weight.reshape(module.weight.shape) - - module.weight += layer_weight.to(dtype=dtype) - - yield # wait for context manager exit - - finally: - with torch.no_grad(): - for module_key, weight in original_weights.items(): - model.get_submodule(module_key).weight.copy_(weight) - - @classmethod - @contextmanager - def apply_ti( - cls, - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, - ti_list: List[Tuple[str, Any]], - ) -> Tuple[CLIPTokenizer, TextualInversionManager]: - init_tokens_count = None - new_tokens_added = None - - # TODO: This is required since Transformers 4.32 see - # https://github.com/huggingface/transformers/pull/25088 - # More information by NVIDIA: - # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc - # This value might need to be changed in the future and take the GPUs model into account as there seem - # to be ideal values for different GPUS. This value is temporary! - # For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817 - pad_to_multiple_of = 8 - - try: - # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a - # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after - # exiting this `apply_ti(...)` context manager. - # - # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`, - # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs). - ti_tokenizer = pickle.loads(pickle.dumps(tokenizer)) - ti_manager = TextualInversionManager(ti_tokenizer) - init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings - - def _get_trigger(ti_name, index): - trigger = ti_name - if index > 0: - trigger += f"-!pad-{i}" - return f"<{trigger}>" - - def _get_ti_embedding(model_embeddings, ti): - print(f"DEBUG: model_embeddings={type(model_embeddings)}, ti={type(ti)}") - print(f"DEBUG: is it an nn.Module? {isinstance(model_embeddings, torch.nn.Module)}") - # for SDXL models, select the embedding that matches the text encoder's dimensions - if ti.embedding_2 is not None: - return ( - ti.embedding_2 - if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0] - else ti.embedding - ) - else: - print(f"DEBUG: ti.embedding={type(ti.embedding)}") - return ti.embedding - - # modify tokenizer - new_tokens_added = 0 - for ti_name, ti in ti_list: - ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti) - - for i in range(ti_embedding.shape[0]): - new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) - - # Modify text_encoder. - # resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of - # this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some - # time. 
- with skip_torch_weight_init(): - text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of) - model_embeddings = text_encoder.get_input_embeddings() - - for ti_name, ti in ti_list: - ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti) - - ti_tokens = [] - for i in range(ti_embedding.shape[0]): - embedding = ti_embedding[i] - trigger = _get_trigger(ti_name, i) - - token_id = ti_tokenizer.convert_tokens_to_ids(trigger) - if token_id == ti_tokenizer.unk_token_id: - raise RuntimeError(f"Unable to find token id for token '{trigger}'") - - if model_embeddings.weight.data[token_id].shape != embedding.shape: - raise ValueError( - f"Cannot load embedding for {trigger}. It was trained on a model with token dimension" - f" {embedding.shape[0]}, but the current model has token dimension" - f" {model_embeddings.weight.data[token_id].shape[0]}." - ) - - model_embeddings.weight.data[token_id] = embedding.to( - device=text_encoder.device, dtype=text_encoder.dtype - ) - ti_tokens.append(token_id) - - if len(ti_tokens) > 1: - ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:] - - yield ti_tokenizer, ti_manager - - finally: - if init_tokens_count and new_tokens_added: - text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of) - - @classmethod - @contextmanager - def apply_clip_skip( - cls, - text_encoder: CLIPTextModel, - clip_skip: int, - ): - skipped_layers = [] - try: - for _i in range(clip_skip): - skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1)) - - yield - - finally: - while len(skipped_layers) > 0: - text_encoder.text_model.encoder.layers.append(skipped_layers.pop()) - - @classmethod - @contextmanager - def apply_freeu( - cls, - unet: UNet2DConditionModel, - freeu_config: Optional[FreeUConfig] = None, - ): - did_apply_freeu = False - try: - if freeu_config is not None: - unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2) - did_apply_freeu = True - - yield - - finally: - if did_apply_freeu: - unet.disable_freeu() - - -class TextualInversionModel: - embedding: torch.Tensor # [n, 768]|[n, 1280] - embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models - - @classmethod - def from_checkpoint( - cls, - file_path: Union[str, Path], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - if not isinstance(file_path, Path): - file_path = Path(file_path) - - result = cls() # TODO: - - if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") - else: - state_dict = torch.load(file_path, map_location="cpu") - - # both v1 and v2 format embeddings - # difference mostly in metadata - if "string_to_param" in state_dict: - if len(state_dict["string_to_param"]) > 1: - print( - f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. 
The first', - " token will be used.", - ) - - result.embedding = next(iter(state_dict["string_to_param"].values())) - - # v3 (easynegative) - elif "emb_params" in state_dict: - result.embedding = state_dict["emb_params"] - - # v5(sdxl safetensors file) - elif "clip_g" in state_dict and "clip_l" in state_dict: - result.embedding = state_dict["clip_g"] - result.embedding_2 = state_dict["clip_l"] - - # v4(diffusers bin files) - else: - result.embedding = next(iter(state_dict.values())) - - if len(result.embedding.shape) == 1: - result.embedding = result.embedding.unsqueeze(0) - - if not isinstance(result.embedding, torch.Tensor): - raise ValueError(f"Invalid embeddings file: {file_path.name}") - - return result - - -class TextualInversionManager(BaseTextualInversionManager): - pad_tokens: Dict[int, List[int]] - tokenizer: CLIPTokenizer - - def __init__(self, tokenizer: CLIPTokenizer): - self.pad_tokens = {} - self.tokenizer = tokenizer - - def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: - if len(self.pad_tokens) == 0: - return token_ids - - if token_ids[0] == self.tokenizer.bos_token_id: - raise ValueError("token_ids must not start with bos_token_id") - if token_ids[-1] == self.tokenizer.eos_token_id: - raise ValueError("token_ids must not end with eos_token_id") - - new_token_ids = [] - for token_id in token_ids: - new_token_ids.append(token_id) - if token_id in self.pad_tokens: - new_token_ids.extend(self.pad_tokens[token_id]) - - # Do not exceed the max model input size - # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(), - # which first removes and then adds back the start and end tokens. - max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2 - if len(new_token_ids) > max_length: - new_token_ids = new_token_ids[0:max_length] - - return new_token_ids - - -class ONNXModelPatcher: - from diffusers import OnnxRuntimeModel - - from .models.base import IAIOnnxRuntimeModel - - @classmethod - @contextmanager - def apply_lora_unet( - cls, - unet: OnnxRuntimeModel, - loras: List[Tuple[LoRAModel, float]], - ): - with cls.apply_lora(unet, loras, "lora_unet_"): - yield - - @classmethod - @contextmanager - def apply_lora_text_encoder( - cls, - text_encoder: OnnxRuntimeModel, - loras: List[Tuple[LoRAModel, float]], - ): - with cls.apply_lora(text_encoder, loras, "lora_te_"): - yield - - # based on - # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323 - @classmethod - @contextmanager - def apply_lora( - cls, - model: IAIOnnxRuntimeModel, - loras: List[Tuple[LoRAModel, float]], - prefix: str, - ): - from .models.base import IAIOnnxRuntimeModel - - if not isinstance(model, IAIOnnxRuntimeModel): - raise Exception("Only IAIOnnxRuntimeModel models supported") - - orig_weights = {} - - try: - blended_loras = {} - - for lora, lora_weight in loras: - for layer_key, layer in lora.layers.items(): - if not layer_key.startswith(prefix): - continue - - layer.to(dtype=torch.float32) - layer_key = layer_key.replace(prefix, "") - # TODO: rewrite to pass original tensor weight(required by ia3) - layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight - if layer_key is blended_loras: - blended_loras[layer_key] += layer_weight - else: - blended_loras[layer_key] = layer_weight - - node_names = {} - for node in model.nodes.values(): - node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name - - for 
layer_key, lora_weight in blended_loras.items(): - conv_key = layer_key + "_Conv" - gemm_key = layer_key + "_Gemm" - matmul_key = layer_key + "_MatMul" - - if conv_key in node_names or gemm_key in node_names: - if conv_key in node_names: - conv_node = model.nodes[node_names[conv_key]] - else: - conv_node = model.nodes[node_names[gemm_key]] - - weight_name = [n for n in conv_node.input if ".weight" in n][0] - orig_weight = model.tensors[weight_name] - - if orig_weight.shape[-2:] == (1, 1): - if lora_weight.shape[-2:] == (1, 1): - new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2)) - else: - new_weight = orig_weight.squeeze((3, 2)) + lora_weight - - new_weight = np.expand_dims(new_weight, (2, 3)) - else: - if orig_weight.shape != lora_weight.shape: - new_weight = orig_weight + lora_weight.reshape(orig_weight.shape) - else: - new_weight = orig_weight + lora_weight - - orig_weights[weight_name] = orig_weight - model.tensors[weight_name] = new_weight.astype(orig_weight.dtype) - - elif matmul_key in node_names: - weight_node = model.nodes[node_names[matmul_key]] - matmul_name = [n for n in weight_node.input if "MatMul" in n][0] - - orig_weight = model.tensors[matmul_name] - new_weight = orig_weight + lora_weight.transpose() - - orig_weights[matmul_name] = orig_weight - model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype) - - else: - # warn? err? - pass - - yield - - finally: - # restore original weights - for name, orig_weight in orig_weights.items(): - model.tensors[name] = orig_weight - - @classmethod - @contextmanager - def apply_ti( - cls, - tokenizer: CLIPTokenizer, - text_encoder: IAIOnnxRuntimeModel, - ti_list: List[Tuple[str, Any]], - ) -> Tuple[CLIPTokenizer, TextualInversionManager]: - from .models.base import IAIOnnxRuntimeModel - - if not isinstance(text_encoder, IAIOnnxRuntimeModel): - raise Exception("Only IAIOnnxRuntimeModel models supported") - - orig_embeddings = None - - try: - # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a - # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after - # exiting this `apply_ti(...)` context manager. - # - # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`, - # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs). 
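# A self-contained sketch of the pickle-roundtrip deep copy described in the
# comment above; the helper name is illustrative, and the exact speedup depends
# on the object being copied and on the machine.
import copy
import pickle
import time

def clone_via_pickle(obj):
    # Behaves like copy.deepcopy(obj) for picklable objects, but is typically
    # much faster for large, pickle-friendly objects such as tokenizers.
    return pickle.loads(pickle.dumps(obj))

data = {"vocab": {f"token_{i}": i for i in range(200_000)}}
t0 = time.perf_counter()
copy.deepcopy(data)
t1 = time.perf_counter()
clone_via_pickle(data)
t2 = time.perf_counter()
print(f"deepcopy: {t1 - t0:.3f}s, pickle roundtrip: {t2 - t1:.3f}s")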
- ti_tokenizer = pickle.loads(pickle.dumps(tokenizer)) - ti_manager = TextualInversionManager(ti_tokenizer) - - def _get_trigger(ti_name, index): - trigger = ti_name - if index > 0: - trigger += f"-!pad-{i}" - return f"<{trigger}>" - - # modify text_encoder - orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"] - - # modify tokenizer - new_tokens_added = 0 - for ti_name, ti in ti_list: - if ti.embedding_2 is not None: - ti_embedding = ( - ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding - ) - else: - ti_embedding = ti.embedding - - for i in range(ti_embedding.shape[0]): - new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) - - embeddings = np.concatenate( - (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))), - axis=0, - ) - - for ti_name, _ in ti_list: - ti_tokens = [] - for i in range(ti_embedding.shape[0]): - embedding = ti_embedding[i].detach().numpy() - trigger = _get_trigger(ti_name, i) - - token_id = ti_tokenizer.convert_tokens_to_ids(trigger) - if token_id == ti_tokenizer.unk_token_id: - raise RuntimeError(f"Unable to find token id for token '{trigger}'") - - if embeddings[token_id].shape != embedding.shape: - raise ValueError( - f"Cannot load embedding for {trigger}. It was trained on a model with token dimension" - f" {embedding.shape[0]}, but the current model has token dimension" - f" {embeddings[token_id].shape[0]}." - ) - - embeddings[token_id] = embedding - ti_tokens.append(token_id) - - if len(ti_tokens) > 1: - ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:] - - text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype( - orig_embeddings.dtype - ) - - yield ti_tokenizer, ti_manager - - finally: - # restore - if orig_embeddings is not None: - text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings diff --git a/invokeai/backend/model_management_OLD/memory_snapshot.py b/invokeai/backend/model_management_OLD/memory_snapshot.py deleted file mode 100644 index fe54af191ce..00000000000 --- a/invokeai/backend/model_management_OLD/memory_snapshot.py +++ /dev/null @@ -1,99 +0,0 @@ -import gc -from typing import Optional - -import psutil -import torch - -from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 - -GB = 2**30 # 1 GB - - -class MemorySnapshot: - """A snapshot of RAM and VRAM usage. All values are in bytes.""" - - def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[Struct_mallinfo2]): - """Initialize a MemorySnapshot. - - Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`. - - Args: - process_ram (int): CPU RAM used by the current process. - vram (Optional[int]): VRAM used by torch. - malloc_info (Optional[Struct_mallinfo2]): Malloc info obtained from LibcUtil. - """ - self.process_ram = process_ram - self.vram = vram - self.malloc_info = malloc_info - - @classmethod - def capture(cls, run_garbage_collector: bool = True): - """Capture and return a MemorySnapshot. - - Note: This function has significant overhead, particularly if `run_garbage_collector == True`. - - Args: - run_garbage_collector (bool, optional): If true, gc.collect() will be run before checking the process RAM - usage. Defaults to True. 
- - Returns: - MemorySnapshot - """ - if run_garbage_collector: - gc.collect() - - # According to the psutil docs (https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info), rss is - # supported on all platforms. - process_ram = psutil.Process().memory_info().rss - - if torch.cuda.is_available(): - vram = torch.cuda.memory_allocated() - else: - # TODO: We could add support for mps.current_allocated_memory() as well. Leaving out for now until we have - # time to test it properly. - vram = None - - try: - malloc_info = LibcUtil().mallinfo2() - except (OSError, AttributeError): - # OSError: This is expected in environments that do not have the 'libc.so.6' shared library. - # AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33) - # TODO: Does `mallinfo` work? - malloc_info = None - - return cls(process_ram, vram, malloc_info) - - -def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str: - """Get a pretty string describing the difference between two `MemorySnapshot`s.""" - - def get_msg_line(prefix: str, val1: int, val2: int): - diff = val2 - val1 - return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n" - - msg = "" - - if snapshot_1 is None or snapshot_2 is None: - return msg - - msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram) - - if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None: - msg += get_msg_line("libc mmap allocated", snapshot_1.malloc_info.hblkhd, snapshot_2.malloc_info.hblkhd) - - msg += get_msg_line("libc arena used", snapshot_1.malloc_info.uordblks, snapshot_2.malloc_info.uordblks) - - msg += get_msg_line("libc arena free", snapshot_1.malloc_info.fordblks, snapshot_2.malloc_info.fordblks) - - libc_total_allocated_1 = snapshot_1.malloc_info.arena + snapshot_1.malloc_info.hblkhd - libc_total_allocated_2 = snapshot_2.malloc_info.arena + snapshot_2.malloc_info.hblkhd - msg += get_msg_line("libc total allocated", libc_total_allocated_1, libc_total_allocated_2) - - libc_total_used_1 = snapshot_1.malloc_info.uordblks + snapshot_1.malloc_info.hblkhd - libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd - msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2) - - if snapshot_1.vram is not None and snapshot_2.vram is not None: - msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) - - return msg diff --git a/invokeai/backend/model_management_OLD/model_cache.py b/invokeai/backend/model_management_OLD/model_cache.py deleted file mode 100644 index 2a7f4b5a95e..00000000000 --- a/invokeai/backend/model_management_OLD/model_cache.py +++ /dev/null @@ -1,553 +0,0 @@ -""" -Manage a RAM cache of diffusion/transformer models for fast switching. -They are moved between GPU VRAM and CPU RAM as necessary. If the cache -grows larger than a preset maximum, then the least recently used -model will be cleared and (re)loaded from disk when next needed. - -The cache returns context manager generators designed to load the -model into the GPU within the context, and unload outside the -context. 
Use like this: - - cache = ModelCache(max_cache_size=7.5) - with cache.get_model('runwayml/stable-diffusion-1-5') as SD1, - cache.get_model('stabilityai/stable-diffusion-2') as SD2: - do_something_in_GPU(SD1,SD2) - - -""" - -import gc -import hashlib -import math -import os -import sys -import time -from contextlib import suppress -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Dict, Optional, Type, Union, types - -import torch - -import invokeai.backend.util.logging as logger -from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff -from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init - -from ..util.devices import choose_torch_device -from .models import BaseModelType, ModelBase, ModelType, SubModelType - -if choose_torch_device() == torch.device("mps"): - from torch import mps - -# Maximum size of the cache, in gigs -# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously -DEFAULT_MAX_CACHE_SIZE = 6.0 - -# amount of GPU memory to hold in reserve for use by generations (GB) -DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75 - -# actual size of a gig -GIG = 1073741824 -# Size of a MB in bytes. -MB = 2**20 - - -@dataclass -class CacheStats(object): - hits: int = 0 # cache hits - misses: int = 0 # cache misses - high_watermark: int = 0 # amount of cache used - in_cache: int = 0 # number of models in cache - cleared: int = 0 # number of models cleared to make space - cache_size: int = 0 # total size of cache - # {submodel_key => size} - loaded_model_sizes: Dict[str, int] = field(default_factory=dict) - - -class ModelLocker(object): - "Forward declaration" - - pass - - -class ModelCache(object): - "Forward declaration" - - pass - - -class _CacheRecord: - size: int - model: Any - cache: ModelCache - _locks: int - - def __init__(self, cache, model: Any, size: int): - self.size = size - self.model = model - self.cache = cache - self._locks = 0 - - def lock(self): - self._locks += 1 - - def unlock(self): - self._locks -= 1 - assert self._locks >= 0 - - @property - def locked(self): - return self._locks > 0 - - @property - def loaded(self): - if self.model is not None and hasattr(self.model, "device"): - return self.model.device != self.cache.storage_device - else: - return False - - -class ModelCache(object): - def __init__( - self, - max_cache_size: float = DEFAULT_MAX_CACHE_SIZE, - max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE, - execution_device: torch.device = torch.device("cuda"), - storage_device: torch.device = torch.device("cpu"), - precision: torch.dtype = torch.float16, - sequential_offload: bool = False, - lazy_offloading: bool = True, - sha_chunksize: int = 16777216, - logger: types.ModuleType = logger, - log_memory_usage: bool = False, - ): - """ - :param max_cache_size: Maximum size of the RAM cache [6.0 GB] - :param execution_device: Torch device to load active model into [torch.device('cuda')] - :param storage_device: Torch device to save inactive model in [torch.device('cpu')] - :param precision: Precision for loaded models [torch.float16] - :param lazy_offloading: Keep model in VRAM until another model needs to be loaded - :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially - :param sha_chunksize: Chunksize to use when calculating sha256 model hash - :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache - operation, and 
the result will be logged (at debug level). There is a time cost to capturing the memory - snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's - behaviour. - """ - self.model_infos: Dict[str, ModelBase] = {} - # allow lazy offloading only when vram cache enabled - self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0 - self.precision: torch.dtype = precision - self.max_cache_size: float = max_cache_size - self.max_vram_cache_size: float = max_vram_cache_size - self.execution_device: torch.device = execution_device - self.storage_device: torch.device = storage_device - self.sha_chunksize = sha_chunksize - self.logger = logger - self._log_memory_usage = log_memory_usage - - # used for stats collection - self.stats = None - - self._cached_models = {} - self._cache_stack = [] - - def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: - if self._log_memory_usage: - return MemorySnapshot.capture() - return None - - def get_key( - self, - model_path: str, - base_model: BaseModelType, - model_type: ModelType, - submodel_type: Optional[SubModelType] = None, - ): - key = f"{model_path}:{base_model}:{model_type}" - if submodel_type: - key += f":{submodel_type}" - return key - - def _get_model_info( - self, - model_path: str, - model_class: Type[ModelBase], - base_model: BaseModelType, - model_type: ModelType, - ): - model_info_key = self.get_key( - model_path=model_path, - base_model=base_model, - model_type=model_type, - submodel_type=None, - ) - - if model_info_key not in self.model_infos: - self.model_infos[model_info_key] = model_class( - model_path, - base_model, - model_type, - ) - - return self.model_infos[model_info_key] - - # TODO: args - def get_model( - self, - model_path: Union[str, Path], - model_class: Type[ModelBase], - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - gpu_load: bool = True, - ) -> Any: - if not isinstance(model_path, Path): - model_path = Path(model_path) - - if not os.path.exists(model_path): - raise Exception(f"Model not found: {model_path}") - - model_info = self._get_model_info( - model_path=model_path, - model_class=model_class, - base_model=base_model, - model_type=model_type, - ) - key = self.get_key( - model_path=model_path, - base_model=base_model, - model_type=model_type, - submodel_type=submodel, - ) - # TODO: lock for no copies on simultaneous calls? - cache_entry = self._cached_models.get(key, None) - if cache_entry is None: - self.logger.info( - f"Loading model {model_path}, type" - f" {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}" - ) - if self.stats: - self.stats.misses += 1 - - self_reported_model_size_before_load = model_info.get_size(submodel) - # Remove old models from the cache to make room for the new model. - self._make_cache_room(self_reported_model_size_before_load) - - # Load the model from disk and capture a memory snapshot before/after. 
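# The lines below time the disk load and bracket it with memory snapshots so the
# RAM/VRAM delta can be logged. A reduced sketch of that measure-around-an-operation
# pattern using only psutil (the real MemorySnapshot also records VRAM and libc
# malloc stats when available); the helper name and the dummy allocation are illustrative:
import time
import psutil

def measure(operation):
    rss_before = psutil.Process().memory_info().rss
    start = time.time()
    result = operation()
    elapsed = time.time() - start
    rss_delta = psutil.Process().memory_info().rss - rss_before
    return result, elapsed, rss_delta

_, elapsed, rss_delta = measure(lambda: bytearray(256 * 2**20))  # allocate ~256 MB to have something to measure
print(f"operation took {elapsed:.2f}s and grew RSS by {rss_delta / 2**30:.3f} GB")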
- start_load_time = time.time() - snapshot_before = self._capture_memory_snapshot() - with skip_torch_weight_init(): - model = model_info.get_model(child_type=submodel, torch_dtype=self.precision) - snapshot_after = self._capture_memory_snapshot() - end_load_time = time.time() - - self_reported_model_size_after_load = model_info.get_size(submodel) - - self.logger.debug( - f"Moved model '{key}' from disk to cpu in {(end_load_time-start_load_time):.2f}s.\n" - f"Self-reported size before/after load: {(self_reported_model_size_before_load/GIG):.3f}GB /" - f" {(self_reported_model_size_after_load/GIG):.3f}GB.\n" - f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" - ) - - if abs(self_reported_model_size_after_load - self_reported_model_size_before_load) > 10 * MB: - self.logger.debug( - f"Model '{key}' mis-reported its size before load. Self-reported size before/after load:" - f" {(self_reported_model_size_before_load/GIG):.2f}GB /" - f" {(self_reported_model_size_after_load/GIG):.2f}GB." - ) - - cache_entry = _CacheRecord(self, model, self_reported_model_size_after_load) - self._cached_models[key] = cache_entry - else: - if self.stats: - self.stats.hits += 1 - - if self.stats: - self.stats.cache_size = self.max_cache_size * GIG - self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size()) - self.stats.in_cache = len(self._cached_models) - self.stats.loaded_model_sizes[key] = max( - self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel) - ) - - with suppress(Exception): - self._cache_stack.remove(key) - self._cache_stack.append(key) - - return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size) - - def _move_model_to_device(self, key: str, target_device: torch.device): - cache_entry = self._cached_models[key] - - source_device = cache_entry.model.device - # Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support - # multi-GPU. - if torch.device(source_device).type == torch.device(target_device).type: - return - - start_model_to_time = time.time() - snapshot_before = self._capture_memory_snapshot() - cache_entry.model.to(target_device) - snapshot_after = self._capture_memory_snapshot() - end_model_to_time = time.time() - self.logger.debug( - f"Moved model '{key}' from {source_device} to" - f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n" - f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n" - f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" - ) - - if ( - snapshot_before is not None - and snapshot_after is not None - and snapshot_before.vram is not None - and snapshot_after.vram is not None - ): - vram_change = abs(snapshot_before.vram - snapshot_after.vram) - - # If the estimated model size does not match the change in VRAM, log a warning. - if not math.isclose( - vram_change, - cache_entry.size, - rel_tol=0.1, - abs_tol=10 * MB, - ): - self.logger.debug( - f"Moving model '{key}' from {source_device} to" - f" {target_device} caused an unexpected change in VRAM usage. The model's" - " estimated size may be incorrect. 
Estimated model size:" - f" {(cache_entry.size/GIG):.3f} GB.\n" - f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" - ) - - class ModelLocker(object): - def __init__(self, cache, key, model, gpu_load, size_needed): - """ - :param cache: The model_cache object - :param key: The key of the model to lock in GPU - :param model: The model to lock - :param gpu_load: True if load into gpu - :param size_needed: Size of the model to load - """ - self.gpu_load = gpu_load - self.cache = cache - self.key = key - self.model = model - self.size_needed = size_needed - self.cache_entry = self.cache._cached_models[self.key] - - def __enter__(self) -> Any: - if not hasattr(self.model, "to"): - return self.model - - # NOTE that the model has to have the to() method in order for this - # code to move it into GPU! - if self.gpu_load: - self.cache_entry.lock() - - try: - if self.cache.lazy_offloading: - self.cache._offload_unlocked_models(self.size_needed) - - self.cache._move_model_to_device(self.key, self.cache.execution_device) - - self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}") - self.cache._print_cuda_stats() - - except Exception: - self.cache_entry.unlock() - raise - - # TODO: not fully understand - # in the event that the caller wants the model in RAM, we - # move it into CPU if it is in GPU and not locked - elif self.cache_entry.loaded and not self.cache_entry.locked: - self.cache._move_model_to_device(self.key, self.cache.storage_device) - - return self.model - - def __exit__(self, type, value, traceback): - if not hasattr(self.model, "to"): - return - - self.cache_entry.unlock() - if not self.cache.lazy_offloading: - self.cache._offload_unlocked_models() - self.cache._print_cuda_stats() - - # TODO: should it be called untrack_model? - def uncache_model(self, cache_id: str): - with suppress(ValueError): - self._cache_stack.remove(cache_id) - self._cached_models.pop(cache_id, None) - - def model_hash( - self, - model_path: Union[str, Path], - ) -> str: - """ - Given the HF repo id or path to a model on disk, returns a unique - hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs - - :param model_path: Path to model file/directory on disk. 
- """ - return self._local_model_hash(model_path) - - def cache_size(self) -> float: - """Return the current size of the cache, in GB.""" - return self._cache_size() / GIG - - def _has_cuda(self) -> bool: - return self.execution_device.type == "cuda" - - def _print_cuda_stats(self): - vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) - ram = "%4.2fG" % self.cache_size() - - cached_models = 0 - loaded_models = 0 - locked_models = 0 - for model_info in self._cached_models.values(): - cached_models += 1 - if model_info.loaded: - loaded_models += 1 - if model_info.locked: - locked_models += 1 - - self.logger.debug( - f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ =" - f" {cached_models}/{loaded_models}/{locked_models}" - ) - - def _cache_size(self) -> int: - return sum([m.size for m in self._cached_models.values()]) - - def _make_cache_room(self, model_size): - # calculate how much memory this model will require - # multiplier = 2 if self.precision==torch.float32 else 1 - bytes_needed = model_size - maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes - current_size = self._cache_size() - - if current_size + bytes_needed > maximum_size: - self.logger.debug( - f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional" - f" {(bytes_needed/GIG):.2f} GB" - ) - - self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}") - - pos = 0 - models_cleared = 0 - while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): - model_key = self._cache_stack[pos] - cache_entry = self._cached_models[model_key] - - refs = sys.getrefcount(cache_entry.model) - - # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly - # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way: - # https://docs.python.org/3/library/gc.html#gc.get_referrers - - # manualy clear local variable references of just finished function calls - # for some reason python don't want to collect it even by gc.collect() immidiately - if refs > 2: - while True: - cleared = False - for referrer in gc.get_referrers(cache_entry.model): - if type(referrer).__name__ == "frame": - # RuntimeError: cannot clear an executing frame - with suppress(RuntimeError): - referrer.clear() - cleared = True - # break - - # repeat if referrers changes(due to frame clear), else exit loop - if cleared: - gc.collect() - else: - break - - device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None - self.logger.debug( - f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}," - f" refs: {refs}" - ) - - # Expected refs: - # 1 from cache_entry - # 1 from getrefcount function - # 1 from onnx runtime object - if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2): - self.logger.debug( - f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" - ) - current_size -= cache_entry.size - models_cleared += 1 - if self.stats: - self.stats.cleared += 1 - del self._cache_stack[pos] - del self._cached_models[model_key] - del cache_entry - - else: - pos += 1 - - if models_cleared > 0: - # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but - # there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. 
(The time cost - # is high even if no garbage gets collected.) - # - # Calling gc.collect(...) when a model is cleared seems like a good middle-ground: - # - If models had to be cleared, it's a signal that we are close to our memory limit. - # - If models were cleared, there's a good chance that there's a significant amount of garbage to be - # collected. - # - # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up - # immediately when their reference count hits 0. - gc.collect() - - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() - - self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}") - - def _offload_unlocked_models(self, size_needed: int = 0): - reserved = self.max_vram_cache_size * GIG - vram_in_use = torch.cuda.memory_allocated() - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") - for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): - if vram_in_use <= reserved: - break - if not cache_entry.locked and cache_entry.loaded: - self._move_model_to_device(model_key, self.storage_device) - - vram_in_use = torch.cuda.memory_allocated() - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") - - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() - - def _local_model_hash(self, model_path: Union[str, Path]) -> str: - sha = hashlib.sha256() - path = Path(model_path) - - hashpath = path / "checksum.sha256" - if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime: - with open(hashpath) as f: - hash = f.read() - return hash - - self.logger.debug(f"computing hash of model {path.name}") - for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")): - with open(file, "rb") as f: - while chunk := f.read(self.sha_chunksize): - sha.update(chunk) - hash = sha.hexdigest() - with open(hashpath, "w") as f: - f.write(hash) - return hash diff --git a/invokeai/backend/model_management_OLD/model_load_optimizations.py b/invokeai/backend/model_management_OLD/model_load_optimizations.py deleted file mode 100644 index a46d262175f..00000000000 --- a/invokeai/backend/model_management_OLD/model_load_optimizations.py +++ /dev/null @@ -1,30 +0,0 @@ -from contextlib import contextmanager - -import torch - - -def _no_op(*args, **kwargs): - pass - - -@contextmanager -def skip_torch_weight_init(): - """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) - to skip weight initialization. - - By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular - distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is - completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager - monkey-patches common torch layers to skip the weight initialization step. 
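# A self-contained sketch of the monkey-patch described in the docstring above,
# applied to a single illustrative layer type; the real context manager patches
# Linear, _ConvNd and Embedding together and restores them in a finally block.
import time
import torch

def _no_init(*args, **kwargs):
    pass

saved_reset = torch.nn.Linear.reset_parameters
torch.nn.Linear.reset_parameters = _no_init
try:
    start = time.perf_counter()
    layer = torch.nn.Linear(4096, 4096)  # weights are left uninitialized
    print(f"constructed without init in {time.perf_counter() - start:.3f}s")
finally:
    torch.nn.Linear.reset_parameters = saved_reset  # always restore the original method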
- """ - torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding] - saved_functions = [m.reset_parameters for m in torch_modules] - - try: - for torch_module in torch_modules: - torch_module.reset_parameters = _no_op - - yield None - finally: - for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True): - torch_module.reset_parameters = saved_function diff --git a/invokeai/backend/model_management_OLD/model_manager.py b/invokeai/backend/model_management_OLD/model_manager.py deleted file mode 100644 index 84d93f15fa8..00000000000 --- a/invokeai/backend/model_management_OLD/model_manager.py +++ /dev/null @@ -1,1121 +0,0 @@ -"""This module manages the InvokeAI `models.yaml` file, mapping -symbolic diffusers model names to the paths and repo_ids used by the -underlying `from_pretrained()` call. - -SYNOPSIS: - - mgr = ModelManager('/home/phi/invokeai/configs/models.yaml') - sd1_5 = mgr.get_model('stable-diffusion-v1-5', - model_type=ModelType.Main, - base_model=BaseModelType.StableDiffusion1, - submodel_type=SubModelType.Unet) - with sd1_5 as unet: - run_some_inference(unet) - -FETCHING MODELS: - -Models are described using four attributes: - - 1) model_name -- the symbolic name for the model - - 2) ModelType -- an enum describing the type of the model. Currently - defined types are: - ModelType.Main -- a full model capable of generating images - ModelType.Vae -- a VAE model - ModelType.Lora -- a LoRA or LyCORIS fine-tune - ModelType.TextualInversion -- a textual inversion embedding - ModelType.ControlNet -- a ControlNet model - ModelType.IPAdapter -- an IPAdapter model - - 3) BaseModelType -- an enum indicating the stable diffusion base model, one of: - BaseModelType.StableDiffusion1 - BaseModelType.StableDiffusion2 - - 4) SubModelType (optional) -- an enum that refers to one of the submodels contained - within the main model. Values are: - - SubModelType.UNet - SubModelType.TextEncoder - SubModelType.Tokenizer - SubModelType.Scheduler - SubModelType.SafetyChecker - -To fetch a model, use `manager.get_model()`. This takes the symbolic -name of the model, the ModelType, the BaseModelType and the -SubModelType. The latter is required for ModelType.Main. - -get_model() will return a ModelInfo object that can then be used in -context to retrieve the model and move it into GPU VRAM (on GPU -systems). - -A typical example is: - - sd1_5 = mgr.get_model('stable-diffusion-v1-5', - model_type=ModelType.Main, - base_model=BaseModelType.StableDiffusion1, - submodel_type=SubModelType.UNet) - with sd1_5 as unet: - run_some_inference(unet) - -The ModelInfo object provides a number of useful fields describing the -model, including: - - name -- symbolic name of the model - base_model -- base model (BaseModelType) - type -- model type (ModelType) - location -- path to the model file - precision -- torch precision of the model - hash -- unique sha256 checksum for this model - -SUBMODELS: - -When fetching a main model, you must specify the submodel. Retrieval -of full pipelines is not supported. - - vae_info = mgr.get_model('stable-diffusion-1.5', - model_type = ModelType.Main, - base_model = BaseModelType.StableDiffusion1, - submodel_type = SubModelType.Vae - ) - with vae_info as vae: - do_something(vae) - -This rule does not apply to controlnets, embeddings, loras and standalone -VAEs, which do not have submodels. 
- -LISTING MODELS - -The model_names() method will return a list of Tuples describing each -model it knows about: - - >> mgr.model_names() - [ - ('stable-diffusion-1.5', , ), - ('stable-diffusion-2.1', , ), - ('inpaint', , ) - ('Ink scenery', , ) - ... - ] - -The tuple is in the correct order to pass to get_model(): - - for m in mgr.model_names(): - info = get_model(*m) - -In contrast, the list_models() method returns a list of dicts, each -providing information about a model defined in models.yaml. For example: - - >>> models = mgr.list_models() - >>> json.dumps(models[0]) - {"path": "/home/lstein/invokeai-main/models/sd-1/controlnet/canny", - "model_format": "diffusers", - "name": "canny", - "base_model": "sd-1", - "type": "controlnet" - } - -You can filter by model type and base model as shown here: - - - controlnets = mgr.list_models(model_type=ModelType.ControlNet, - base_model=BaseModelType.StableDiffusion1) - for c in controlnets: - name = c['name'] - format = c['model_format'] - path = c['path'] - type = c['type'] - # etc - -ADDING AND REMOVING MODELS - -At startup time, the `models` directory will be scanned for -checkpoints, diffusers pipelines, controlnets, LoRAs and TI -embeddings. New entries will be added to the model manager and defunct -ones removed. Anything that is a main model (ModelType.Main) will be -added to models.yaml. For scanning to succeed, files need to be in -their proper places. For example, a controlnet folder built on the -stable diffusion 2 base, will need to be placed in -`models/sd-2/controlnet`. - -Layout of the `models` directory: - - models - ├── sd-1 - │ ├── controlnet - │ ├── lora - │ ├── main - │ └── embedding - ├── sd-2 - │ ├── controlnet - │ ├── lora - │ ├── main - │ └── embedding - └── core - ├── face_reconstruction - │ ├── codeformer - │ └── gfpgan - ├── sd-conversion - │ ├── clip-vit-large-patch14 - tokenizer, text_encoder subdirs - │ ├── stable-diffusion-2 - tokenizer, text_encoder subdirs - │ └── stable-diffusion-safety-checker - └── upscaling - └─── esrgan - - - -class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are not listed -explicitly in models.yaml, but are added to the in-memory data -structure at initialization time by scanning the models directory. The -in-memory data structure can be resynchronized by calling -`manager.scan_models_directory()`. - -Files and folders placed inside the `autoimport` paths (paths -defined in `invokeai.yaml`) will also be scanned for new models at -initialization time and added to `models.yaml`. Files will not be -moved from this location but preserved in-place. These directories -are: - - configuration default description - ------------- ------- ----------- - autoimport_dir autoimport/main main models - lora_dir autoimport/lora LoRA/LyCORIS models - embedding_dir autoimport/embedding TI embeddings - controlnet_dir autoimport/controlnet ControlNet models - -In actuality, models located in any of these directories are scanned -to determine their type, so it isn't strictly necessary to organize -the different types in this way. This entry in `invokeai.yaml` will -recursively scan all subdirectories within `autoimport`, scan models -files it finds, and import them if recognized. - - Paths: - autoimport_dir: autoimport - -A model can be manually added using `add_model()` using the model's -name, base model, type and a dict of model attributes. See -`invokeai/backend/model_management/models` for the attributes required -by each model type. 
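A hypothetical add_model() call, for illustration only (the attribute keys shown
are not authoritative; the exact schema for each model type is defined in the
models package referenced above):

    result = mgr.add_model(
        model_name='my-fine-tune',
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
        model_attributes={
            'path': 'sd-1/main/my-fine-tune.safetensors',
            'description': 'an example checkpoint',
            'format': 'checkpoint',
        },
    )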
- -A model can be deleted using `del_model()`, providing the same -identifying information as `get_model()` - -The `heuristic_import()` method will take a set of strings -corresponding to local paths, remote URLs, and repo_ids, probe the -object to determine what type of model it is (if any), and import new -models into the manager. If passed a directory, it will recursively -scan it for models to import. The return value is a set of the models -successfully added. - -MODELS.YAML - -The general format of a models.yaml section is: - - type-of-model/name-of-model: - path: /path/to/local/file/or/directory - description: a description - format: diffusers|checkpoint - variant: normal|inpaint|depth - -The type of model is given in the stanza key, and is one of -{main, vae, lora, controlnet, textual} - -The format indicates whether the model is organized as a diffusers -folder with model subdirectories, or is contained in a single -checkpoint or safetensors file. - -The path points to a file or directory on disk. If a relative path, -the root is the InvokeAI ROOTDIR. - -""" -from __future__ import annotations - -import hashlib -import os -import textwrap -import types -from dataclasses import dataclass -from pathlib import Path -from shutil import move, rmtree -from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union, cast - -import torch -import yaml -from omegaconf import OmegaConf -from omegaconf.dictconfig import DictConfig -from pydantic import BaseModel, ConfigDict, Field - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.util import CUDA_DEVICE, Chdir - -from .model_cache import ModelCache, ModelLocker -from .model_search import ModelSearch -from .models import ( - MODEL_CLASSES, - BaseModelType, - DuplicateModelException, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelError, - ModelNotFoundException, - ModelType, - SchedulerPredictionType, - SubModelType, -) - -# We are only starting to number the config file with release 3. -# The config file version doesn't have to start at release version, but it will help -# reduce confusion. -CONFIG_FILE_VERSION = "3.0.0" - - -@dataclass -class LoadedModelInfo: - context: ModelLocker - name: str - base_model: BaseModelType - type: ModelType - hash: str - location: Union[Path, str] - precision: torch.dtype - _cache: Optional[ModelCache] = None - - def __enter__(self): - return self.context.__enter__() - - def __exit__(self, *args, **kwargs): - self.context.__exit__(*args, **kwargs) - - -class AddModelResult(BaseModel): - name: str = Field(description="The name of the model after installation") - model_type: ModelType = Field(description="The type of model") - base_model: BaseModelType = Field(description="The base model") - config: ModelConfigBase = Field(description="The configuration of the model") - - model_config = ConfigDict(protected_namespaces=()) - - -MAX_CACHE_SIZE = 6.0 # GB - - -class ConfigMeta(BaseModel): - version: str - - -class ModelManager(object): - """ - High-level interface to model management. - """ - - logger: types.ModuleType = logger - - def __init__( - self, - config: Union[Path, DictConfig, str], - device_type: torch.device = CUDA_DEVICE, - precision: torch.dtype = torch.float16, - max_cache_size=MAX_CACHE_SIZE, - sequential_offload=False, - logger: types.ModuleType = logger, - ): - """ - Initialize with the path to the models.yaml config file. 
- Optional parameters are the torch device type, precision, max_models, - and sequential_offload boolean. Note that the default device - type and precision are set up for a CUDA system running at half precision. - """ - self.config_path = None - if isinstance(config, (str, Path)): - self.config_path = Path(config) - if not self.config_path.exists(): - logger.warning(f"The file {self.config_path} was not found. Initializing a new file") - self.initialize_model_config(self.config_path) - config = OmegaConf.load(self.config_path) - - elif not isinstance(config, DictConfig): - raise ValueError("config argument must be an OmegaConf object, a Path or a string") - - self.config_meta = ConfigMeta(**config.pop("__metadata__")) - # TODO: metadata not found - # TODO: version check - - self.app_config = InvokeAIAppConfig.get_config() - self.logger = logger - self.cache = ModelCache( - max_cache_size=max_cache_size, - max_vram_cache_size=self.app_config.vram_cache_size, - lazy_offloading=self.app_config.lazy_offload, - execution_device=device_type, - precision=precision, - sequential_offload=sequential_offload, - logger=logger, - log_memory_usage=self.app_config.log_memory_usage, - ) - - self._read_models(config) - - def _read_models(self, config: Optional[DictConfig] = None): - if not config: - if self.config_path: - config = OmegaConf.load(self.config_path) - else: - return - - self.models = {} - for model_key, model_config in config.items(): - if model_key.startswith("_"): - continue - model_name, base_model, model_type = self.parse_key(model_key) - model_class = self._get_implementation(base_model, model_type) - # alias for config file - model_config["model_format"] = model_config.pop("format") - self.models[model_key] = model_class.create_config(**model_config) - - # check config version number and update on disk/RAM if necessary - self.cache_keys = {} - - # add controlnet, lora and textual_inversion models from disk - self.scan_models_directory() - - def sync_to_config(self): - """ - Call this when `models.yaml` has been changed externally. - This will reinitialize internal data structures - """ - # Reread models directory; note that this will reinitialize the cache, - # causing otherwise unreferenced models to be removed from memory - self._read_models() - - def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool: - """ - Given a model name, returns True if it is a valid identifier. - - :param model_name: symbolic name of the model in models.yaml - :param model_type: ModelType enum indicating the type of model to return - :param base_model: BaseModelType enum indicating the base model used by this model - :param rescan: if True, scan_models_directory - """ - model_key = self.create_key(model_name, base_model, model_type) - exists = model_key in self.models - - # if model not found try to find it (maybe file just pasted) - if rescan and not exists: - self.scan_models_directory(base_model=base_model, model_type=model_type) - exists = self.model_exists(model_name, base_model, model_type, rescan=False) - - return exists - - @classmethod - def create_key( - cls, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> str: - # In 3.11, the behavior of (str,enum) when interpolated into a - # string has changed. The next two lines are defensive. 
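# The two defensive casts that follow re-wrap the incoming values in their enum
# types: for an enum member this is a no-op, and for a raw string it both
# validates the value and guarantees that `.value` is available when building the key.
# A tiny illustration with an assumed str-backed enum standing in for the real ones:
from enum import Enum

class DemoBaseModelType(str, Enum):  # illustrative stand-in, not the real enum
    StableDiffusion1 = "sd-1"

assert DemoBaseModelType("sd-1") is DemoBaseModelType.StableDiffusion1  # string -> member
assert DemoBaseModelType(DemoBaseModelType.StableDiffusion1) is DemoBaseModelType.StableDiffusion1  # member -> same member
print(f"{DemoBaseModelType('sd-1').value}/main/my-model")  # the key is built from .value, not str(member)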
- base_model = BaseModelType(base_model) - model_type = ModelType(model_type) - return f"{base_model.value}/{model_type.value}/{model_name}" - - @classmethod - def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]: - base_model_str, model_type_str, model_name = model_key.split("/", 2) - try: - model_type = ModelType(model_type_str) - except Exception: - raise Exception(f"Unknown model type: {model_type_str}") - - try: - base_model = BaseModelType(base_model_str) - except Exception: - raise Exception(f"Unknown base model: {base_model_str}") - - return (model_name, base_model, model_type) - - def _get_model_cache_path(self, model_path): - return self.resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest()) - - @classmethod - def initialize_model_config(cls, config_path: Path): - """Create empty config file""" - with open(config_path, "w") as yaml_file: - yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) - - def get_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel_type: Optional[SubModelType] = None, - ) -> LoadedModelInfo: - """Given a model named identified in models.yaml, return - an ModelInfo object describing it. - :param model_name: symbolic name of the model in models.yaml - :param model_type: ModelType enum indicating the type of model to return - :param base_model: BaseModelType enum indicating the base model used by this model - :param submodel_type: an ModelType enum indicating the portion of - the model to retrieve (e.g. ModelType.Vae) - """ - model_key = self.create_key(model_name, base_model, model_type) - - if not self.model_exists(model_name, base_model, model_type, rescan=True): - raise ModelNotFoundException(f"Model not found - {model_key}") - - model_config = self._get_model_config(base_model, model_name, model_type) - - model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) - - if is_submodel_override: - model_type = submodel_type - submodel_type = None - - model_class = self._get_implementation(base_model, model_type) - - if not model_path.exists(): - if model_class.save_to_config: - self.models[model_key].error = ModelError.NotFound - raise Exception(f'Files for model "{model_key}" not found at {model_path}') - - else: - self.models.pop(model_key, None) - raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}') - - # TODO: path - # TODO: is it accurate to use path as id - dst_convert_path = self._get_model_cache_path(model_path) - - model_path = model_class.convert_if_required( - base_model=base_model, - model_path=str(model_path), # TODO: refactor str/Path types logic - output_path=dst_convert_path, - config=model_config, - ) - - model_context = self.cache.get_model( - model_path=model_path, - model_class=model_class, - base_model=base_model, - model_type=model_type, - submodel_type=submodel_type, - ) - - if model_key not in self.cache_keys: - self.cache_keys[model_key] = set() - self.cache_keys[model_key].add(model_context.key) - - model_hash = "" # TODO: - - return LoadedModelInfo( - context=model_context, - name=model_name, - base_model=base_model, - type=submodel_type or model_type, - hash=model_hash, - location=model_path, # TODO: - precision=self.cache.precision, - _cache=self.cache, - ) - - def _get_model_path( - self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None - ) -> (Path, bool): - """Extract a model's filesystem path from its config. 
- - :return: The fully qualified Path of the module (or submodule). - """ - model_path = model_config.path - is_submodel_override = False - - # Does the config explicitly override the submodel? - if submodel_type is not None and hasattr(model_config, submodel_type): - submodel_path = getattr(model_config, submodel_type) - if submodel_path is not None and len(submodel_path) > 0: - model_path = getattr(model_config, submodel_type) - is_submodel_override = True - - model_path = self.resolve_model_path(model_path) - return model_path, is_submodel_override - - def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase: - """Get a model's config object.""" - model_key = self.create_key(model_name, base_model, model_type) - try: - model_config = self.models[model_key] - except KeyError: - raise ModelNotFoundException(f"Model not found - {model_key}") - return model_config - - def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: - """Get the concrete implementation class for a specific model type.""" - model_class = MODEL_CLASSES[base_model][model_type] - return model_class - - def _instantiate( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel_type: Optional[SubModelType] = None, - ) -> ModelBase: - """Make a new instance of this model, without loading it.""" - model_config = self._get_model_config(base_model, model_name, model_type) - model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) - # FIXME: do non-overriden submodels get the right class? - constructor = self._get_implementation(base_model, model_type) - instance = constructor(model_path, base_model, model_type) - return instance - - def model_info( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> Union[dict, None]: - """ - Given a model name returns the OmegaConf (dict-like) object describing it. - """ - model_key = self.create_key(model_name, base_model, model_type) - if model_key in self.models: - return self.models[model_key].model_dump(exclude_defaults=True) - else: - return None # TODO: None or empty dict on not found - - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Return a list of (str, BaseModelType, ModelType) corresponding to all models - known to the configuration. - """ - return [(self.parse_key(x)) for x in self.models.keys()] - - def list_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> Union[dict, None]: - """ - Returns a dict describing one installed model, using - the combined format of the list_models() method. - """ - models = self.list_models(base_model, model_type, model_name) - if len(models) >= 1: - return models[0] - else: - return None - - def list_models( - self, - base_model: Optional[BaseModelType] = None, - model_type: Optional[ModelType] = None, - model_name: Optional[str] = None, - ) -> list[dict]: - """ - Return a list of models. 
- """ - - model_keys = ( - [self.create_key(model_name, base_model, model_type)] - if model_name and base_model and model_type - else sorted(self.models, key=str.casefold) - ) - models = [] - for model_key in model_keys: - model_config = self.models.get(model_key) - if not model_config: - self.logger.error(f"Unknown model {model_name}") - raise ModelNotFoundException(f"Unknown model {model_name}") - - cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) - if base_model is not None and cur_base_model != base_model: - continue - if model_type is not None and cur_model_type != model_type: - continue - - model_dict = dict( - **model_config.model_dump(exclude_defaults=True), - # OpenAPIModelInfoBase - model_name=cur_model_name, - base_model=cur_base_model, - model_type=cur_model_type, - ) - - # expose paths as absolute to help web UI - if path := model_dict.get("path"): - model_dict["path"] = str(self.resolve_model_path(path)) - models.append(model_dict) - - return models - - def print_models(self) -> None: - """ - Print a table of models and their descriptions. This needs to be redone - """ - # TODO: redo - for model_dict in self.list_models(): - for _model_name, model_info in model_dict.items(): - line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}' - print(line) - - # Tested - LS - def del_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ): - """ - Delete the named model. - """ - model_key = self.create_key(model_name, base_model, model_type) - model_cfg = self.models.pop(model_key, None) - - if model_cfg is None: - raise ModelNotFoundException(f"Unknown model {model_key}") - - # note: it not garantie to release memory(model can has other references) - cache_ids = self.cache_keys.pop(model_key, []) - for cache_id in cache_ids: - self.cache.uncache_model(cache_id) - - # if model inside invoke models folder - delete files - model_path = self.resolve_model_path(model_cfg.path) - cache_path = self._get_model_cache_path(model_path) - if cache_path.exists(): - rmtree(str(cache_path)) - - if model_path.is_relative_to(self.app_config.models_path): - if model_path.is_dir(): - rmtree(str(model_path)) - else: - model_path.unlink() - self.commit() - - # LS: tested - def add_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - model_attributes: dict, - clobber: bool = False, - ) -> AddModelResult: - """ - Update the named model with a dictionary of attributes. Will fail with an - assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory and the - method will return True. Will fail with an assertion error if provided - attributes are incorrect or the model name is missing. - - The returned dict has the same format as the dict returned by - model_info(). 
- """ - # relativize paths as they go in - this makes it easier to move the models directory around - if path := model_attributes.get("path"): - model_attributes["path"] = str(self.relative_model_path(Path(path))) - - model_class = self._get_implementation(base_model, model_type) - model_config = model_class.create_config(**model_attributes) - model_key = self.create_key(model_name, base_model, model_type) - - if model_key in self.models and not clobber: - raise Exception(f'Attempt to overwrite existing model definition "{model_key}"') - - old_model = self.models.pop(model_key, None) - if old_model is not None: - # TODO: if path changed and old_model.path inside models folder should we delete this too? - - # remove conversion cache as config changed - old_model_path = self.resolve_model_path(old_model.path) - old_model_cache = self._get_model_cache_path(old_model_path) - if old_model_cache.exists(): - if old_model_cache.is_dir(): - rmtree(str(old_model_cache)) - else: - old_model_cache.unlink() - - # remove in-memory cache - # note: it not guaranteed to release memory(model can has other references) - cache_ids = self.cache_keys.pop(model_key, []) - for cache_id in cache_ids: - self.cache.uncache_model(cache_id) - - self.models[model_key] = model_config - self.commit() - - return AddModelResult( - name=model_name, - model_type=model_type, - base_model=base_model, - config=model_config, - ) - - def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: Optional[str] = None, - new_base: Optional[BaseModelType] = None, - ) -> None: - """ - Rename or rebase a model. - """ - if new_name is None and new_base is None: - self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.") - return - - model_key = self.create_key(model_name, base_model, model_type) - model_cfg = self.models.get(model_key, None) - if not model_cfg: - raise ModelNotFoundException(f"Unknown model: {model_key}") - - old_path = self.resolve_model_path(model_cfg.path) - new_name = new_name or model_name - new_base = new_base or base_model - new_key = self.create_key(new_name, new_base, model_type) - if new_key in self.models: - raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"') - - # if this is a model file/directory that we manage ourselves, we need to move it - if old_path.is_relative_to(self.app_config.models_path): - # keep the suffix! 
- if old_path.is_file(): - new_name = Path(new_name).with_suffix(old_path.suffix).as_posix() - new_path = self.resolve_model_path( - Path( - BaseModelType(new_base).value, - ModelType(model_type).value, - new_name, - ) - ) - move(old_path, new_path) - model_cfg.path = str(new_path.relative_to(self.app_config.models_path)) - - # clean up caches - old_model_cache = self._get_model_cache_path(old_path) - if old_model_cache.exists(): - if old_model_cache.is_dir(): - rmtree(str(old_model_cache)) - else: - old_model_cache.unlink() - - cache_ids = self.cache_keys.pop(model_key, []) - for cache_id in cache_ids: - self.cache.uncache_model(cache_id) - - self.models.pop(model_key, None) # delete - self.models[new_key] = model_cfg - self.commit() - - def convert_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: Literal[ModelType.Main, ModelType.Vae], - dest_directory: Optional[Path] = None, - ) -> AddModelResult: - """ - Convert a checkpoint file into a diffusers folder, deleting the cached - version and deleting the original checkpoint file if it is in the models - directory. - :param model_name: Name of the model to convert - :param base_model: Base model type - :param model_type: Type of model ['vae' or 'main'] - - This will raise a ValueError unless the model is a checkpoint. - """ - info = self.model_info(model_name, base_model, model_type) - - if info is None: - raise FileNotFoundError(f"model not found: {model_name}") - - if info["model_format"] != "checkpoint": - raise ValueError(f"not a checkpoint format model: {model_name}") - - # We are taking advantage of a side effect of get_model() that converts check points - # into cached diffusers directories stored at `location`. It doesn't matter - # what submodeltype we request here, so we get the smallest. - submodel = {"submodel_type": SubModelType.Scheduler} if model_type == ModelType.Main else {} - model = self.get_model( - model_name, - base_model, - model_type, - **submodel, - ) - checkpoint_path = self.resolve_model_path(info["path"]) - old_diffusers_path = self.resolve_model_path(model.location) - new_diffusers_path = ( - dest_directory or self.app_config.models_path / base_model.value / model_type.value - ) / model_name - if new_diffusers_path.exists(): - raise ValueError(f"A diffusers model already exists at {new_diffusers_path}") - - try: - move(old_diffusers_path, new_diffusers_path) - info["model_format"] = "diffusers" - info["path"] = ( - str(new_diffusers_path) - if dest_directory - else str(new_diffusers_path.relative_to(self.app_config.models_path)) - ) - info.pop("config") - - result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True) - except Exception: - # something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error! 
- rmtree(new_diffusers_path)
- raise
-
- if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
- checkpoint_path.unlink()
-
- return result
-
- def resolve_model_path(self, path: Union[Path, str]) -> Path:
- """Resolve a path against the configured models_path and return the absolute path."""
- return self.app_config.models_path / path
-
- def relative_model_path(self, model_path: Path) -> Path:
- if model_path.is_relative_to(self.app_config.models_path):
- model_path = model_path.relative_to(self.app_config.models_path)
- return model_path
-
- def search_models(self, search_folder):
- self.logger.info(f"Finding Models In: {search_folder}")
- models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
- models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
-
- ckpt_files = [x for x in models_folder_ckpt if x.is_file()]
- safetensor_files = [x for x in models_folder_safetensors if x.is_file()]
-
- files = ckpt_files + safetensor_files
-
- found_models = []
- for file in files:
- location = str(file.resolve()).replace("\\", "/")
- if "model.safetensors" not in location and "diffusion_pytorch_model.safetensors" not in location:
- found_models.append({"name": file.stem, "location": location})
-
- return search_folder, found_models
-
- def commit(self, conf_file: Optional[Path] = None) -> None:
- """
- Write current configuration out to the indicated file.
- """
- data_to_save = {}
- data_to_save["__metadata__"] = self.config_meta.model_dump()
-
- for model_key, model_config in self.models.items():
- model_name, base_model, model_type = self.parse_key(model_key)
- model_class = self._get_implementation(base_model, model_type)
- if model_class.save_to_config:
- # TODO: or exclude_unset better fits here?
- data_to_save[model_key] = cast(BaseModel, model_config).model_dump(
- exclude_defaults=True, exclude={"error"}, mode="json"
- )
- # alias for config file
- data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
-
- yaml_str = OmegaConf.to_yaml(data_to_save)
- config_file_path = conf_file or self.config_path
- assert config_file_path is not None, "no config file path to write to"
- config_file_path = self.app_config.root_path / config_file_path
- tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
- try:
- with open(tmpfile, "w", encoding="utf-8") as outfile:
- outfile.write(self.preamble())
- outfile.write(yaml_str)
- os.replace(tmpfile, config_file_path)
- except OSError as err:
- self.logger.warning(f"Could not modify the config file at {config_file_path}")
- self.logger.warning(err)
-
- def preamble(self) -> str:
- """
- Returns the preamble for the config file.
- """
- return textwrap.dedent(
- """
- # This file describes the alternative machine learning models
- # available to the InvokeAI script.
- #
- # To add a new model, follow the examples below. Each
- # model requires a model config file, a weights file,
- # and the width and height of the images it
- # was trained on.
- """ - ) - - def scan_models_directory( - self, - base_model: Optional[BaseModelType] = None, - model_type: Optional[ModelType] = None, - ): - loaded_files = set() - new_models_found = False - - self.logger.info(f"Scanning {self.app_config.models_path} for new models") - with Chdir(self.app_config.models_path): - for model_key, model_config in list(self.models.items()): - model_name, cur_base_model, cur_model_type = self.parse_key(model_key) - - # Patch for relative path bug in older models.yaml - paths should not - # be starting with a hard-coded 'models'. This will also fix up - # models.yaml when committed. - if model_config.path.startswith("models"): - model_config.path = str(Path(*Path(model_config.path).parts[1:])) - - model_path = self.resolve_model_path(model_config.path).absolute() - if not model_path.exists(): - model_class = self._get_implementation(cur_base_model, cur_model_type) - if model_class.save_to_config: - model_config.error = ModelError.NotFound - self.models.pop(model_key, None) - else: - self.models.pop(model_key, None) - else: - loaded_files.add(model_path) - - for cur_base_model in BaseModelType: - if base_model is not None and cur_base_model != base_model: - continue - - for cur_model_type in ModelType: - if model_type is not None and cur_model_type != model_type: - continue - model_class = self._get_implementation(cur_base_model, cur_model_type) - models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value)) - - if not models_dir.exists(): - continue # TODO: or create all folders? - - for model_path in models_dir.iterdir(): - if model_path not in loaded_files: # TODO: check - if model_path.name.startswith("."): - continue - model_name = model_path.name if model_path.is_dir() else model_path.stem - model_key = self.create_key(model_name, cur_base_model, cur_model_type) - - try: - if model_key in self.models: - raise DuplicateModelException(f"Model with key {model_key} added twice") - - model_path = self.relative_model_path(model_path) - model_config: ModelConfigBase = model_class.probe_config( - str(model_path), model_base=cur_base_model - ) - self.models[model_key] = model_config - new_models_found = True - except DuplicateModelException as e: - self.logger.warning(e) - except InvalidModelException as e: - self.logger.warning(f"Not a valid model: {model_path}. {e}") - except NotImplementedError as e: - self.logger.warning(e) - except Exception as e: - self.logger.warning(f"Error loading model {model_path}. {e}") - - imported_models = self.scan_autoimport_directory() - if (new_models_found or imported_models) and self.config_path: - self.commit() - - def scan_autoimport_directory(self) -> Dict[str, AddModelResult]: - """ - Scan the autoimport directory (if defined) and import new models, delete defunct models. 
- """ - # avoid circular import - from invokeai.backend.install.model_install_backend import ModelInstall - from invokeai.frontend.install.model_install import ask_user_for_prediction_type - - class ScanAndImport(ModelSearch): - def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall): - super().__init__(directories, logger) - self.installer = installer - self.ignore = ignore - - def on_search_started(self): - self.new_models_found = {} - - def on_model_found(self, model: Path): - if model not in self.ignore: - self.new_models_found.update(self.installer.heuristic_import(model)) - - def on_search_completed(self): - self.logger.info( - f"Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models" - ) - - def models_found(self): - return self.new_models_found - - config = self.app_config - - # LS: hacky - # Patch in the SD VAE from core so that it is available for use by the UI - try: - self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))}) - except Exception: - pass - - installer = ModelInstall( - config=self.app_config, - model_manager=self, - prediction_type_helper=ask_user_for_prediction_type, - ) - known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()} - directories = { - config.root_path / x - for x in [ - config.autoimport_dir, - config.lora_dir, - config.embedding_dir, - config.controlnet_dir, - ] - if x - } - scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer) - scanner.search() - - return scanner.models_found() - - def heuristic_import( - self, - items_to_import: Set[str], - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> Dict[str, AddModelResult]: - """Import a list of paths, repo_ids or URLs. Returns the set of - successfully imported items. - :param items_to_import: Set of strings corresponding to models to be imported. - :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. - - The prediction type helper is necessary to distinguish between - models based on Stable Diffusion 2 Base (requiring - SchedulerPredictionType.Epsilson) and Stable Diffusion 768 - (requiring SchedulerPredictionType.VPrediction). It is - generally impossible to do this programmatically, so the - prediction_type_helper usually asks the user to choose. - - The result is a set of successfully installed models. Each element - of the set is a dict corresponding to the newly-created OmegaConf stanza for - that model. 
- - May return the following exceptions: - - ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL - - ValueError - a corresponding model already exists - """ - # avoid circular import here - from invokeai.backend.install.model_install_backend import ModelInstall - - successfully_installed = {} - - installer = ModelInstall( - config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self - ) - for thing in items_to_import: - installed = installer.heuristic_import(thing) - successfully_installed.update(installed) - self.commit() - return successfully_installed diff --git a/invokeai/backend/model_management_OLD/model_merge.py b/invokeai/backend/model_management_OLD/model_merge.py deleted file mode 100644 index a9f0a23618e..00000000000 --- a/invokeai/backend/model_management_OLD/model_merge.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -invokeai.backend.model_management.model_merge exports: -merge_diffusion_models() -- combine multiple models by location and return a pipeline object -merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml - -Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team -""" - -import warnings -from enum import Enum -from pathlib import Path -from typing import List, Optional, Union - -from diffusers import DiffusionPipeline -from diffusers import logging as dlogging - -import invokeai.backend.util.logging as logger - -from ...backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType - - -class MergeInterpolationMethod(str, Enum): - WeightedSum = "weighted_sum" - Sigmoid = "sigmoid" - InvSigmoid = "inv_sigmoid" - AddDifference = "add_difference" - - -class ModelMerger(object): - def __init__(self, manager: ModelManager): - self.manager = manager - - def merge_diffusion_models( - self, - model_paths: List[Path], - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - **kwargs, - ) -> DiffusionPipeline: - """ - :param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids - :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - :param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. - :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. 
- - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - verbosity = dlogging.get_verbosity() - dlogging.set_verbosity_error() - - pipe = DiffusionPipeline.from_pretrained( - model_paths[0], - custom_pipeline="checkpoint_merger", - ) - merged_pipe = pipe.merge( - pretrained_model_name_or_path_list=model_paths, - alpha=alpha, - interp=interp.value if interp else None, # diffusers API treats None as "weighted sum" - force=force, - **kwargs, - ) - dlogging.set_verbosity(verbosity) - return merged_pipe - - def merge_diffusion_models_and_save( - self, - model_names: List[str], - base_model: Union[BaseModelType, str], - merged_model_name: str, - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - merge_dest_directory: Optional[Path] = None, - **kwargs, - ) -> AddModelResult: - """ - :param models: up to three models, designated by their InvokeAI models.yaml model name - :param base_model: base model (must be the same for all merged models!) - :param merged_model_name: name for new model - :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - :param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C). - :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - model_paths = [] - config = self.manager.app_config - base_model = BaseModelType(base_model) - vae = None - - for mod in model_names: - info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main) - assert info, f"model {mod}, base_model {base_model}, is unknown" - assert ( - info["model_format"] == "diffusers" - ), f"{mod} is not a diffusers model. 
It must be optimized before merging" - assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged" - assert ( - len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference - ), "When merging three models, only the 'add_difference' merge method is supported" - # pick up the first model's vae - if mod == model_names[0]: - vae = info.get("vae") - model_paths.extend([(config.root_path / info["path"]).as_posix()]) - - merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp) - logger.debug(f"interp = {interp}, merge_method={merge_method}") - merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, **kwargs) - dump_path = ( - Path(merge_dest_directory) - if merge_dest_directory - else config.models_path / base_model.value / ModelType.Main.value - ) - dump_path.mkdir(parents=True, exist_ok=True) - dump_path = (dump_path / merged_model_name).as_posix() - - merged_pipe.save_pretrained(dump_path, safe_serialization=True) - attributes = { - "path": dump_path, - "description": f"Merge of models {', '.join(model_names)}", - "model_format": "diffusers", - "variant": ModelVariantType.Normal.value, - "vae": vae, - } - return self.manager.add_model( - merged_model_name, - base_model=base_model, - model_type=ModelType.Main, - model_attributes=attributes, - clobber=True, - ) diff --git a/invokeai/backend/model_management_OLD/model_probe.py b/invokeai/backend/model_management_OLD/model_probe.py deleted file mode 100644 index 74b1b72d317..00000000000 --- a/invokeai/backend/model_management_OLD/model_probe.py +++ /dev/null @@ -1,664 +0,0 @@ -import json -import re -from dataclasses import dataclass -from pathlib import Path -from typing import Callable, Dict, Literal, Optional, Union - -import safetensors.torch -import torch -from diffusers import ConfigMixin, ModelMixin -from picklescan.scanner import scan_file_path - -from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat - -from .models import ( - BaseModelType, - InvalidModelException, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SilenceWarnings, -) -from .models.base import read_checkpoint_meta -from .util import lora_token_vector_length - - -@dataclass -class ModelProbeInfo(object): - model_type: ModelType - base_type: BaseModelType - variant_type: ModelVariantType - prediction_type: SchedulerPredictionType - upcast_attention: bool - format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"] - image_size: int - name: Optional[str] = None - description: Optional[str] = None - - -class ProbeBase(object): - """forward declaration""" - - pass - - -class ModelProbe(object): - PROBES = { - "diffusers": {}, - "checkpoint": {}, - "onnx": {}, - } - - CLASS2TYPE = { - "StableDiffusionPipeline": ModelType.Main, - "StableDiffusionInpaintPipeline": ModelType.Main, - "StableDiffusionXLPipeline": ModelType.Main, - "StableDiffusionXLImg2ImgPipeline": ModelType.Main, - "StableDiffusionXLInpaintPipeline": ModelType.Main, - "LatentConsistencyModelPipeline": ModelType.Main, - "AutoencoderKL": ModelType.Vae, - "AutoencoderTiny": ModelType.Vae, - "ControlNetModel": ModelType.ControlNet, - "CLIPVisionModelWithProjection": ModelType.CLIPVision, - "T2IAdapter": ModelType.T2IAdapter, - } - - @classmethod - def register_probe( - cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase - ): - cls.PROBES[format][model_type] = probe_class - - @classmethod - def 
heuristic_probe( - cls, - model: Union[Dict, ModelMixin, Path], - prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None, - ) -> ModelProbeInfo: - if isinstance(model, Path): - return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper) - elif isinstance(model, (dict, ModelMixin, ConfigMixin)): - return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper) - else: - raise InvalidModelException("model parameter {model} is neither a Path, nor a model") - - @classmethod - def probe( - cls, - model_path: Path, - model: Optional[Union[Dict, ModelMixin]] = None, - prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, - ) -> ModelProbeInfo: - """ - Probe the model at model_path and return sufficient information about it - to place it somewhere in the models directory hierarchy. If the model is - already loaded into memory, you may provide it as model in order to avoid - opening it a second time. The prediction_type_helper callable is a function that receives - the path to the model and returns the SchedulerPredictionType. - """ - if model_path: - format_type = "diffusers" if model_path.is_dir() else "checkpoint" - else: - format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint" - model_info = None - try: - model_type = ( - cls.get_model_type_from_folder(model_path, model) - if format_type == "diffusers" - else cls.get_model_type_from_checkpoint(model_path, model) - ) - format_type = "onnx" if model_type == ModelType.ONNX else format_type - probe_class = cls.PROBES[format_type].get(model_type) - if not probe_class: - return None - probe = probe_class(model_path, model, prediction_type_helper) - base_type = probe.get_base_type() - variant_type = probe.get_variant_type() - prediction_type = probe.get_scheduler_prediction_type() - name = cls.get_model_name(model_path) - description = f"{base_type.value} {model_type.value} model {name}" - format = probe.get_format() - model_info = ModelProbeInfo( - model_type=model_type, - base_type=base_type, - variant_type=variant_type, - prediction_type=prediction_type, - name=name, - description=description, - upcast_attention=( - base_type == BaseModelType.StableDiffusion2 - and prediction_type == SchedulerPredictionType.VPrediction - ), - format=format, - image_size=( - 1024 - if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}) - else ( - 768 - if ( - base_type == BaseModelType.StableDiffusion2 - and prediction_type == SchedulerPredictionType.VPrediction - ) - else 512 - ) - ), - ) - except Exception: - raise - - return model_info - - @classmethod - def get_model_name(cls, model_path: Path) -> str: - if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}: - return model_path.stem - else: - return model_path.name - - @classmethod - def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType: - if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"): - return None - - if model_path.name == "learned_embeds.bin": - return ModelType.TextualInversion - - ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True) - ckpt = ckpt.get("state_dict", ckpt) - - for key in ckpt.keys(): - if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}): - return ModelType.Main - elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}): - return ModelType.Vae - elif 
any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): - return ModelType.Lora - elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}): - return ModelType.Lora - elif any(key.startswith(v) for v in {"control_model", "input_blocks"}): - return ModelType.ControlNet - elif key in {"emb_params", "string_to_param"}: - return ModelType.TextualInversion - - else: - # diffusers-ti - if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): - return ModelType.TextualInversion - - raise InvalidModelException(f"Unable to determine model type for {model_path}") - - @classmethod - def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType: - """ - Get the model type of a hugging-face style folder. - """ - class_name = None - error_hint = None - if model: - class_name = model.__class__.__name__ - else: - for suffix in ["bin", "safetensors"]: - if (folder_path / f"learned_embeds.{suffix}").exists(): - return ModelType.TextualInversion - if (folder_path / f"pytorch_lora_weights.{suffix}").exists(): - return ModelType.Lora - if (folder_path / "unet/model.onnx").exists(): - return ModelType.ONNX - if (folder_path / "image_encoder.txt").exists(): - return ModelType.IPAdapter - - i = folder_path / "model_index.json" - c = folder_path / "config.json" - config_path = i if i.exists() else c if c.exists() else None - - if config_path: - with open(config_path, "r") as file: - conf = json.load(file) - if "_class_name" in conf: - class_name = conf["_class_name"] - elif "architectures" in conf: - class_name = conf["architectures"][0] - else: - class_name = None - else: - error_hint = f"No model_index.json or config.json found in {folder_path}." - - if class_name and (type := cls.CLASS2TYPE.get(class_name)): - return type - else: - error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]" - - # give up - raise InvalidModelException( - f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "") - ) - - @classmethod - def _scan_and_load_checkpoint(cls, model_path: Path) -> dict: - with SilenceWarnings(): - if model_path.suffix.endswith((".ckpt", ".pt", ".bin")): - cls._scan_model(model_path, model_path) - return torch.load(model_path, map_location="cpu") - else: - return safetensors.torch.load_file(model_path) - - @classmethod - def _scan_model(cls, model_name, checkpoint): - """ - Apply picklescanner to the indicated checkpoint and issue a warning - and option to exit if an infected file is identified. - """ - # scan model - scan_result = scan_file_path(checkpoint) - if scan_result.infected_files != 0: - raise Exception("The model {model_name} is potentially infected by malware. 
Aborting import.") - - -# ##################################################3 -# Checkpoint probing -# ##################################################3 -class ProbeBase(object): - def get_base_type(self) -> BaseModelType: - pass - - def get_variant_type(self) -> ModelVariantType: - pass - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - pass - - def get_format(self) -> str: - pass - - -class CheckpointProbeBase(ProbeBase): - def __init__( - self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None - ) -> BaseModelType: - self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path) - self.checkpoint_path = checkpoint_path - self.helper = helper - - def get_base_type(self) -> BaseModelType: - pass - - def get_format(self) -> str: - return "checkpoint" - - def get_variant_type(self) -> ModelVariantType: - model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint) - if model_type != ModelType.Main: - return ModelVariantType.Normal - state_dict = self.checkpoint.get("state_dict") or self.checkpoint - in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - else: - raise InvalidModelException( - f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}" - ) - - -class PipelineCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 768: - return BaseModelType.StableDiffusion1 - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - return BaseModelType.StableDiffusion2 - key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: - return BaseModelType.StableDiffusionXL - elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: - return BaseModelType.StableDiffusionXLRefiner - else: - raise InvalidModelException("Cannot determine base type") - - def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: - """Return model prediction type.""" - # if there is a .yaml associated with this checkpoint, then we do not need - # to probe for the prediction type as it will be ignored. 
- if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists(): - return None - - type = self.get_base_type() - if type == BaseModelType.StableDiffusion2: - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - if "global_step" in checkpoint: - if checkpoint["global_step"] == 220000: - return SchedulerPredictionType.Epsilon - elif checkpoint["global_step"] == 110000: - return SchedulerPredictionType.VPrediction - if self.helper and self.checkpoint_path: - if helper_guess := self.helper(self.checkpoint_path): - return helper_guess - return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts - - elif type == BaseModelType.StableDiffusion1: - if self.helper and self.checkpoint_path: - if helper_guess := self.helper(self.checkpoint_path): - return helper_guess - return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts - else: - return None - - -class VaeCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - # I can't find any standalone 2.X VAEs to test with! - return BaseModelType.StableDiffusion1 - - -class LoRACheckpointProbe(CheckpointProbeBase): - def get_format(self) -> str: - return "lycoris" - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - token_vector_length = lora_token_vector_length(checkpoint) - - if token_vector_length == 768: - return BaseModelType.StableDiffusion1 - elif token_vector_length == 1024: - return BaseModelType.StableDiffusion2 - elif token_vector_length == 1280: - return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 - elif token_vector_length == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}") - - -class TextualInversionCheckpointProbe(CheckpointProbeBase): - def get_format(self) -> str: - return None - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - if "string_to_token" in checkpoint: - token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1] - elif "emb_params" in checkpoint: - token_dim = checkpoint["emb_params"].shape[-1] - elif "clip_g" in checkpoint: - token_dim = checkpoint["clip_g"].shape[-1] - else: - token_dim = list(checkpoint.values())[0].shape[-1] - if token_dim == 768: - return BaseModelType.StableDiffusion1 - elif token_dim == 1024: - return BaseModelType.StableDiffusion2 - elif token_dim == 1280: - return BaseModelType.StableDiffusionXL - else: - return None - - -class ControlNetCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - for key_name in ( - "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - ): - if key_name not in checkpoint: - continue - if checkpoint[key_name].shape[-1] == 768: - return BaseModelType.StableDiffusion1 - elif checkpoint[key_name].shape[-1] == 1024: - return BaseModelType.StableDiffusion2 - elif self.checkpoint_path and self.helper: - return self.helper(self.checkpoint_path) - raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}") - - -class IPAdapterCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class 
CLIPVisionCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class T2IAdapterCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -######################################################## -# classes for probing folders -####################################################### -class FolderProbeBase(ProbeBase): - def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used - self.model = model - self.folder_path = folder_path - - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - def get_format(self) -> str: - return "diffusers" - - -class PipelineFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - if self.model: - unet_conf = self.model.unet.config - else: - with open(self.folder_path / "unet" / "config.json", "r") as file: - unet_conf = json.load(file) - if unet_conf["cross_attention_dim"] == 768: - return BaseModelType.StableDiffusion1 - elif unet_conf["cross_attention_dim"] == 1024: - return BaseModelType.StableDiffusion2 - elif unet_conf["cross_attention_dim"] == 1280: - return BaseModelType.StableDiffusionXLRefiner - elif unet_conf["cross_attention_dim"] == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelException(f"Unknown base model for {self.folder_path}") - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - if self.model: - scheduler_conf = self.model.scheduler.config - else: - with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file: - scheduler_conf = json.load(file) - if scheduler_conf["prediction_type"] == "v_prediction": - return SchedulerPredictionType.VPrediction - elif scheduler_conf["prediction_type"] == "epsilon": - return SchedulerPredictionType.Epsilon - else: - return None - - def get_variant_type(self) -> ModelVariantType: - # This only works for pipelines! Any kind of - # exception results in our returning the - # "normal" variant type - try: - if self.model: - conf = self.model.unet.config - else: - config_file = self.folder_path / "unet" / "config.json" - with open(config_file, "r") as file: - conf = json.load(file) - - in_channels = conf["in_channels"] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - except Exception: - pass - return ModelVariantType.Normal - - -class VaeFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - if self._config_looks_like_sdxl(): - return BaseModelType.StableDiffusionXL - elif self._name_looks_like_sdxl(): - # but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down - # by a factor of 8), we can't necessarily tell them apart by config hyperparameters. - return BaseModelType.StableDiffusionXL - else: - return BaseModelType.StableDiffusion1 - - def _config_looks_like_sdxl(self) -> bool: - # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. 
- config_file = self.folder_path / "config.json" - if not config_file.exists(): - raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") - with open(config_file, "r") as file: - config = json.load(file) - return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] - - def _name_looks_like_sdxl(self) -> bool: - return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE)) - - def _guess_name(self) -> str: - name = self.folder_path.name - if name == "vae": - name = self.folder_path.parent.name - return name - - -class TextualInversionFolderProbe(FolderProbeBase): - def get_format(self) -> str: - return None - - def get_base_type(self) -> BaseModelType: - path = self.folder_path / "learned_embeds.bin" - if not path.exists(): - return None - checkpoint = ModelProbe._scan_and_load_checkpoint(path) - return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type() - - -class ONNXFolderProbe(FolderProbeBase): - def get_format(self) -> str: - return "onnx" - - def get_base_type(self) -> BaseModelType: - return BaseModelType.StableDiffusion1 - - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - -class ControlNetFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - config_file = self.folder_path / "config.json" - if not config_file.exists(): - raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") - with open(config_file, "r") as file: - config = json.load(file) - # no obvious way to distinguish between sd2-base and sd2-768 - dimension = config["cross_attention_dim"] - base_model = ( - BaseModelType.StableDiffusion1 - if dimension == 768 - else ( - BaseModelType.StableDiffusion2 - if dimension == 1024 - else BaseModelType.StableDiffusionXL - if dimension == 2048 - else None - ) - ) - if not base_model: - raise InvalidModelException(f"Unable to determine model base for {self.folder_path}") - return base_model - - -class LoRAFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - model_file = None - for suffix in ["safetensors", "bin"]: - base_file = self.folder_path / f"pytorch_lora_weights.{suffix}" - if base_file.exists(): - model_file = base_file - break - if not model_file: - raise InvalidModelException("Unknown LoRA format encountered") - return LoRACheckpointProbe(model_file, None).get_base_type() - - -class IPAdapterFolderProbe(FolderProbeBase): - def get_format(self) -> str: - return IPAdapterModelFormat.InvokeAI.value - - def get_base_type(self) -> BaseModelType: - model_file = self.folder_path / "ip_adapter.bin" - if not model_file.exists(): - raise InvalidModelException("Unknown IP-Adapter model format.") - - state_dict = torch.load(model_file, map_location="cpu") - cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] - if cross_attention_dim == 768: - return BaseModelType.StableDiffusion1 - elif cross_attention_dim == 1024: - return BaseModelType.StableDiffusion2 - elif cross_attention_dim == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.") - - -class CLIPVisionFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class T2IAdapterFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - config_file = self.folder_path / "config.json" - if not config_file.exists(): - raise 
InvalidModelException(f"Cannot determine base type for {self.folder_path}") - with open(config_file, "r") as file: - config = json.load(file) - - adapter_type = config.get("adapter_type", None) - if adapter_type == "full_adapter_xl": - return BaseModelType.StableDiffusionXL - elif adapter_type == "full_adapter" or "light_adapter": - # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter. - return BaseModelType.StableDiffusion1 - else: - raise InvalidModelException( - f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})." - ) - - -############## register probe classes ###### -ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) - -ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe) - -ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) diff --git a/invokeai/backend/model_management_OLD/model_search.py b/invokeai/backend/model_management_OLD/model_search.py deleted file mode 100644 index e125c3ced7f..00000000000 --- a/invokeai/backend/model_management_OLD/model_search.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2023, Lincoln D. Stein and the InvokeAI Team -""" -Abstract base class for recursive directory search for models. -""" - -import os -from abc import ABC, abstractmethod -from pathlib import Path -from typing import List, Set, types - -import invokeai.backend.util.logging as logger - - -class ModelSearch(ABC): - def __init__(self, directories: List[Path], logger: types.ModuleType = logger): - """ - Initialize a recursive model directory search. - :param directories: List of directory Paths to recurse through - :param logger: Logger to use - """ - self.directories = directories - self.logger = logger - self._items_scanned = 0 - self._models_found = 0 - self._scanned_dirs = set() - self._scanned_paths = set() - self._pruned_paths = set() - - @abstractmethod - def on_search_started(self): - """ - Called before the scan starts. - """ - pass - - @abstractmethod - def on_model_found(self, model: Path): - """ - Process a found model. Raise an exception if something goes wrong. - :param model: Model to process - could be a directory or checkpoint. - """ - pass - - @abstractmethod - def on_search_completed(self): - """ - Perform some activity when the scan is completed. 
May use instance - variables, items_scanned and models_found - """ - pass - - def search(self): - self.on_search_started() - for dir in self.directories: - self.walk_directory(dir) - self.on_search_completed() - - def walk_directory(self, path: Path): - for root, dirs, files in os.walk(path, followlinks=True): - if str(Path(root).name).startswith("."): - self._pruned_paths.add(root) - if any(Path(root).is_relative_to(x) for x in self._pruned_paths): - continue - - self._items_scanned += len(dirs) + len(files) - for d in dirs: - path = Path(root) / d - if path in self._scanned_paths or path.parent in self._scanned_dirs: - self._scanned_dirs.add(path) - continue - if any( - (path / x).exists() - for x in { - "config.json", - "model_index.json", - "learned_embeds.bin", - "pytorch_lora_weights.bin", - "image_encoder.txt", - } - ): - try: - self.on_model_found(path) - self._models_found += 1 - self._scanned_dirs.add(path) - except Exception as e: - self.logger.warning(f"Failed to process '{path}': {e}") - - for f in files: - path = Path(root) / f - if path.parent in self._scanned_dirs: - continue - if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}: - try: - self.on_model_found(path) - self._models_found += 1 - except Exception as e: - self.logger.warning(f"Failed to process '{path}': {e}") - - -class FindModels(ModelSearch): - def on_search_started(self): - self.models_found: Set[Path] = set() - - def on_model_found(self, model: Path): - self.models_found.add(model) - - def on_search_completed(self): - pass - - def list_models(self) -> List[Path]: - self.search() - return list(self.models_found) diff --git a/invokeai/backend/model_management_OLD/models/__init__.py b/invokeai/backend/model_management_OLD/models/__init__.py deleted file mode 100644 index 5f9b13b96f1..00000000000 --- a/invokeai/backend/model_management_OLD/models/__init__.py +++ /dev/null @@ -1,167 +0,0 @@ -import inspect -from enum import Enum -from typing import Literal, get_origin - -from pydantic import BaseModel, ConfigDict, create_model - -from .base import ( # noqa: F401 - BaseModelType, - DuplicateModelException, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelError, - ModelNotFoundException, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SilenceWarnings, - SubModelType, -) -from .clip_vision import CLIPVisionModel -from .controlnet import ControlNetModel # TODO: -from .ip_adapter import IPAdapterModel -from .lora import LoRAModel -from .sdxl import StableDiffusionXLModel -from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model -from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model -from .t2i_adapter import T2IAdapterModel -from .textual_inversion import TextualInversionModel -from .vae import VaeModel - -MODEL_CLASSES = { - BaseModelType.StableDiffusion1: { - ModelType.ONNX: ONNXStableDiffusion1Model, - ModelType.Main: StableDiffusion1Model, - ModelType.Vae: VaeModel, - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.IPAdapter: IPAdapterModel, - ModelType.CLIPVision: CLIPVisionModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - BaseModelType.StableDiffusion2: { - ModelType.ONNX: ONNXStableDiffusion2Model, - ModelType.Main: StableDiffusion2Model, - ModelType.Vae: VaeModel, - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.IPAdapter: IPAdapterModel, - 
ModelType.CLIPVision: CLIPVisionModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - BaseModelType.StableDiffusionXL: { - ModelType.Main: StableDiffusionXLModel, - ModelType.Vae: VaeModel, - # will not work until support written - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.ONNX: ONNXStableDiffusion2Model, - ModelType.IPAdapter: IPAdapterModel, - ModelType.CLIPVision: CLIPVisionModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelType.Main: StableDiffusionXLModel, - ModelType.Vae: VaeModel, - # will not work until support written - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.ONNX: ONNXStableDiffusion2Model, - ModelType.IPAdapter: IPAdapterModel, - ModelType.CLIPVision: CLIPVisionModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - BaseModelType.Any: { - ModelType.CLIPVision: CLIPVisionModel, - # The following model types are not expected to be used with BaseModelType.Any. - ModelType.ONNX: ONNXStableDiffusion2Model, - ModelType.Main: StableDiffusion2Model, - ModelType.Vae: VaeModel, - ModelType.Lora: LoRAModel, - ModelType.ControlNet: ControlNetModel, - ModelType.TextualInversion: TextualInversionModel, - ModelType.IPAdapter: IPAdapterModel, - ModelType.T2IAdapter: T2IAdapterModel, - }, - # BaseModelType.Kandinsky2_1: { - # ModelType.Main: Kandinsky2_1Model, - # ModelType.MoVQ: MoVQModel, - # ModelType.Lora: LoRAModel, - # ModelType.ControlNet: ControlNetModel, - # ModelType.TextualInversion: TextualInversionModel, - # }, -} - -MODEL_CONFIGS = [] -OPENAPI_MODEL_CONFIGS = [] - - -class OpenAPIModelInfoBase(BaseModel): - model_name: str - base_model: BaseModelType - model_type: ModelType - - model_config = ConfigDict(protected_namespaces=()) - - -for _base_model, models in MODEL_CLASSES.items(): - for model_type, model_class in models.items(): - model_configs = set(model_class._get_configs().values()) - model_configs.discard(None) - MODEL_CONFIGS.extend(model_configs) - - # LS: sort to get the checkpoint configs first, which makes - # for a better template in the Swagger docs - for cfg in sorted(model_configs, key=lambda x: str(x)): - model_name, cfg_name = cfg.__qualname__.split(".")[-2:] - openapi_cfg_name = model_name + cfg_name - if openapi_cfg_name in vars(): - continue - - api_wrapper = create_model( - openapi_cfg_name, - __base__=(cfg, OpenAPIModelInfoBase), - model_type=(Literal[model_type], model_type), # type: ignore - ) - vars()[openapi_cfg_name] = api_wrapper - OPENAPI_MODEL_CONFIGS.append(api_wrapper) - - -def get_model_config_enums(): - enums = [] - - for model_config in MODEL_CONFIGS: - if hasattr(inspect, "get_annotations"): - fields = inspect.get_annotations(model_config) - else: - fields = model_config.__annotations__ - try: - field = fields["model_format"] - except Exception: - raise Exception("format field not found") - - # model_format: None - # model_format: SomeModelFormat - # model_format: Literal[SomeModelFormat.Diffusers] - # model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint] - - if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum): - enums.append(field) - - elif get_origin(field) is Literal and all( - isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__ - ): - enums.append(type(field.__args__[0])) - - elif field is None: - pass - - else: - raise 
Exception(f"Unsupported format definition in {model_configs.__qualname__}") - - return enums diff --git a/invokeai/backend/model_management_OLD/models/base.py b/invokeai/backend/model_management_OLD/models/base.py deleted file mode 100644 index 7807cb9a542..00000000000 --- a/invokeai/backend/model_management_OLD/models/base.py +++ /dev/null @@ -1,681 +0,0 @@ -import inspect -import json -import os -import sys -import typing -import warnings -from abc import ABCMeta, abstractmethod -from contextlib import suppress -from enum import Enum -from pathlib import Path -from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union - -import numpy as np -import onnx -import safetensors.torch -import torch -from diffusers import ConfigMixin, DiffusionPipeline -from diffusers import logging as diffusers_logging -from onnx import numpy_helper -from onnxruntime import InferenceSession, SessionOptions, get_available_providers -from picklescan.scanner import scan_file_path -from pydantic import BaseModel, ConfigDict, Field -from transformers import logging as transformers_logging - - -class DuplicateModelException(Exception): - pass - - -class InvalidModelException(Exception): - pass - - -class ModelNotFoundException(Exception): - pass - - -class BaseModelType(str, Enum): - Any = "any" # For models that are not associated with any particular base model. - StableDiffusion1 = "sd-1" - StableDiffusion2 = "sd-2" - StableDiffusionXL = "sdxl" - StableDiffusionXLRefiner = "sdxl-refiner" - # Kandinsky2_1 = "kandinsky-2.1" - - -class ModelType(str, Enum): - ONNX = "onnx" - Main = "main" - Vae = "vae" - Lora = "lora" - ControlNet = "controlnet" # used by model_probe - TextualInversion = "embedding" - IPAdapter = "ip_adapter" - CLIPVision = "clip_vision" - T2IAdapter = "t2i_adapter" - - -class SubModelType(str, Enum): - UNet = "unet" - TextEncoder = "text_encoder" - TextEncoder2 = "text_encoder_2" - Tokenizer = "tokenizer" - Tokenizer2 = "tokenizer_2" - Vae = "vae" - VaeDecoder = "vae_decoder" - VaeEncoder = "vae_encoder" - Scheduler = "scheduler" - SafetyChecker = "safety_checker" - # MoVQ = "movq" - - -class ModelVariantType(str, Enum): - Normal = "normal" - Inpaint = "inpaint" - Depth = "depth" - - -class SchedulerPredictionType(str, Enum): - Epsilon = "epsilon" - VPrediction = "v_prediction" - Sample = "sample" - - -class ModelError(str, Enum): - NotFound = "not_found" - - -def model_config_json_schema_extra(schema: dict[str, Any]) -> None: - if "required" not in schema: - schema["required"] = [] - schema["required"].append("model_type") - - -class ModelConfigBase(BaseModel): - path: str # or Path - description: Optional[str] = Field(None) - model_format: Optional[str] = Field(None) - error: Optional[ModelError] = Field(None) - - model_config = ConfigDict( - use_enum_values=True, protected_namespaces=(), json_schema_extra=model_config_json_schema_extra - ) - - -class EmptyConfigLoader(ConfigMixin): - @classmethod - def load_config(cls, *args, **kwargs): - cls.config_name = kwargs.pop("config_name") - return super().load_config(*args, **kwargs) - - -T_co = TypeVar("T_co", covariant=True) - - -class classproperty(Generic[T_co]): - def __init__(self, fget: Callable[[Any], T_co]) -> None: - self.fget = fget - - def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co: - return self.fget(owner) - - def __set__(self, instance: Optional[Any], value: Any) -> None: - raise AttributeError("cannot set attribute") - - -class ModelBase(metaclass=ABCMeta): - # model_path: str - # 
base_model: BaseModelType - # model_type: ModelType - - def __init__( - self, - model_path: str, - base_model: BaseModelType, - model_type: ModelType, - ): - self.model_path = model_path - self.base_model = base_model - self.model_type = model_type - - def _hf_definition_to_type(self, subtypes: List[str]) -> Type: - if len(subtypes) < 2: - raise Exception("Invalid subfolder definition!") - if all(t is None for t in subtypes): - return None - elif any(t is None for t in subtypes): - raise Exception(f"Unsupported definition: {subtypes}") - - if subtypes[0] in ["diffusers", "transformers"]: - res_type = sys.modules[subtypes[0]] - subtypes = subtypes[1:] - - else: - res_type = sys.modules["diffusers"] - res_type = res_type.pipelines - - for subtype in subtypes: - res_type = getattr(res_type, subtype) - return res_type - - @classmethod - def _get_configs(cls): - with suppress(Exception): - return cls.__configs - - configs = {} - for name in dir(cls): - if name.startswith("__"): - continue - - value = getattr(cls, name) - if not isinstance(value, type) or not issubclass(value, ModelConfigBase): - continue - - if hasattr(inspect, "get_annotations"): - fields = inspect.get_annotations(value) - else: - fields = value.__annotations__ - try: - field = fields["model_format"] - except Exception: - raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})") - - if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum): - for model_format in field: - configs[model_format.value] = value - - elif typing.get_origin(field) is Literal and all( - isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__ - ): - for model_format in field.__args__: - configs[model_format.value] = value - - elif field is None: - configs[None] = value - - else: - raise Exception(f"Unsupported format definition in {cls.__qualname__}") - - cls.__configs = configs - return cls.__configs - - @classmethod - def create_config(cls, **kwargs) -> ModelConfigBase: - if "model_format" not in kwargs: - raise Exception("Field 'model_format' not found in model config") - - configs = cls._get_configs() - return configs[kwargs["model_format"]](**kwargs) - - @classmethod - def probe_config(cls, path: str, **kwargs) -> ModelConfigBase: - return cls.create_config( - path=path, - model_format=cls.detect_format(path), - ) - - @classmethod - @abstractmethod - def detect_format(cls, path: str) -> str: - raise NotImplementedError() - - @classproperty - @abstractmethod - def save_to_config(cls) -> bool: - raise NotImplementedError() - - @abstractmethod - def get_size(self, child_type: Optional[SubModelType] = None) -> int: - raise NotImplementedError() - - @abstractmethod - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ) -> Any: - raise NotImplementedError() - - -class DiffusersModel(ModelBase): - # child_types: Dict[str, Type] - # child_sizes: Dict[str, int] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - super().__init__(model_path, base_model, model_type) - - self.child_types: Dict[str, Type] = {} - self.child_sizes: Dict[str, int] = {} - - try: - config_data = DiffusionPipeline.load_config(self.model_path) - # config_data = json.loads(os.path.join(self.model_path, "model_index.json")) - except Exception: - raise Exception("Invalid diffusers model! 
(model_index.json not found or invalid)") - - config_data.pop("_ignore_files", None) - - # retrieve all folder_names that contain relevant files - child_components = [k for k, v in config_data.items() if isinstance(v, list)] - - for child_name in child_components: - child_type = self._hf_definition_to_type(config_data[child_name]) - self.child_types[child_name] = child_type - self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is None: - return sum(self.child_sizes.values()) - else: - return self.child_sizes[child_type] - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - # return pipeline in different function to pass more arguments - if child_type is None: - raise Exception("Child model type can't be null on diffusers model") - if child_type not in self.child_types: - return None # TODO: or raise - - if torch_dtype == torch.float16: - variants = ["fp16", None] - else: - variants = [None, "fp16"] - - # TODO: better error handling(differentiate not found from others) - for variant in variants: - try: - # TODO: set cache_dir to /dev/null to be sure that cache not used? - model = self.child_types[child_type].from_pretrained( - self.model_path, - subfolder=child_type.value, - torch_dtype=torch_dtype, - variant=variant, - local_files_only=True, - ) - break - except Exception as e: - if not str(e).startswith("Error no file"): - print("====ERR LOAD====") - print(f"{variant}: {e}") - pass - else: - raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model") - - # calc more accurate size - self.child_sizes[child_type] = calc_model_size_by_data(model) - return model - - # def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str: - - -def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, variant: Optional[str] = None): - if subfolder is not None: - model_path = os.path.join(model_path, subfolder) - - # this can happen when, for example, the safety checker - # is not downloaded. - if not os.path.exists(model_path): - return 0 - - all_files = os.listdir(model_path) - all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))] - - fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f} - bit8_files = {f for f in all_files if ".8bit." 
in f or ".8bit-" in f} - other_files = set(all_files) - fp16_files - bit8_files - - if variant is None: - files = other_files - elif variant == "fp16": - files = fp16_files - elif variant == "8bit": - files = bit8_files - else: - raise NotImplementedError(f"Unknown variant: {variant}") - - # try read from index if exists - index_postfix = ".index.json" - if variant is not None: - index_postfix = f".index.{variant}.json" - - for file in files: - if not file.endswith(index_postfix): - continue - try: - with open(os.path.join(model_path, file), "r") as f: - index_data = json.loads(f.read()) - return int(index_data["metadata"]["total_size"]) - except Exception: - pass - - # calculate files size if there is no index file - formats = [ - (".safetensors",), # safetensors - (".bin",), # torch - (".onnx", ".pb"), # onnx - (".msgpack",), # flax - (".ckpt",), # tf - (".h5",), # tf2 - ] - - for file_format in formats: - model_files = [f for f in files if f.endswith(file_format)] - if len(model_files) == 0: - continue - - model_size = 0 - for model_file in model_files: - file_stats = os.stat(os.path.join(model_path, model_file)) - model_size += file_stats.st_size - return model_size - - # raise NotImplementedError(f"Unknown model structure! Files: {all_files}") - return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu - - -def calc_model_size_by_data(model) -> int: - if isinstance(model, DiffusionPipeline): - return _calc_pipeline_by_data(model) - elif isinstance(model, torch.nn.Module): - return _calc_model_by_data(model) - elif isinstance(model, IAIOnnxRuntimeModel): - return _calc_onnx_model_by_data(model) - else: - return 0 - - -def _calc_pipeline_by_data(pipeline) -> int: - res = 0 - for submodel_key in pipeline.components.keys(): - submodel = getattr(pipeline, submodel_key) - if submodel is not None and isinstance(submodel, torch.nn.Module): - res += _calc_model_by_data(submodel) - return res - - -def _calc_model_by_data(model) -> int: - mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()]) - mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()]) - mem = mem_params + mem_bufs # in bytes - return mem - - -def _calc_onnx_model_by_data(model) -> int: - tensor_size = model.tensors.size() * 2 # The session doubles this - mem = tensor_size # in bytes - return mem - - -def _fast_safetensors_reader(path: str): - checkpoint = {} - device = torch.device("meta") - with open(path, "rb") as f: - definition_len = int.from_bytes(f.read(8), "little") - definition_json = f.read(definition_len) - definition = json.loads(definition_json) - - if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in { - "pt", - "torch", - "pytorch", - }: - raise Exception("Supported only pytorch safetensors files") - definition.pop("__metadata__", None) - - for key, info in definition.items(): - dtype = { - "I8": torch.int8, - "I16": torch.int16, - "I32": torch.int32, - "I64": torch.int64, - "F16": torch.float16, - "F32": torch.float32, - "F64": torch.float64, - }[info["dtype"]] - - checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device) - - return checkpoint - - -def read_checkpoint_meta(path: Union[str, Path], scan: bool = False): - if str(path).endswith(".safetensors"): - try: - checkpoint = _fast_safetensors_reader(path) - except Exception: - # TODO: create issue for support "meta"? 
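For reference, _fast_safetensors_reader above leans on the safetensors file layout: an 8-byte little-endian header length followed by a JSON header describing every tensor. A minimal standalone sketch that reads only that header without loading any weights; the function name and path are illustrative, not from this patch:

    import json

    def read_safetensors_header(path: str) -> dict:
        with open(path, "rb") as f:
            header_len = int.from_bytes(f.read(8), "little")  # u64 header size
            header = json.loads(f.read(header_len))
        header.pop("__metadata__", None)  # optional file-level metadata
        return header

    # Usage:
    # for name, info in read_safetensors_header("model.safetensors").items():
    #     print(name, info["dtype"], info["shape"])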
- checkpoint = safetensors.torch.load_file(path, device="cpu") - else: - if scan: - scan_result = scan_file_path(path) - if scan_result.infected_files != 0: - raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.') - checkpoint = torch.load(path, map_location=torch.device("meta")) - return checkpoint - - -class SilenceWarnings(object): - def __init__(self): - self.transformers_verbosity = transformers_logging.get_verbosity() - self.diffusers_verbosity = diffusers_logging.get_verbosity() - - def __enter__(self): - transformers_logging.set_verbosity_error() - diffusers_logging.set_verbosity_error() - warnings.simplefilter("ignore") - - def __exit__(self, type, value, traceback): - transformers_logging.set_verbosity(self.transformers_verbosity) - diffusers_logging.set_verbosity(self.diffusers_verbosity) - warnings.simplefilter("default") - - -ONNX_WEIGHTS_NAME = "model.onnx" - - -class IAIOnnxRuntimeModel: - class _tensor_access: - def __init__(self, model): - self.model = model - self.indexes = {} - for idx, obj in enumerate(self.model.proto.graph.initializer): - self.indexes[obj.name] = idx - - def __getitem__(self, key: str): - value = self.model.proto.graph.initializer[self.indexes[key]] - return numpy_helper.to_array(value) - - def __setitem__(self, key: str, value: np.ndarray): - new_node = numpy_helper.from_array(value) - # set_external_data(new_node, location="in-memory-location") - new_node.name = key - # new_node.ClearField("raw_data") - del self.model.proto.graph.initializer[self.indexes[key]] - self.model.proto.graph.initializer.insert(self.indexes[key], new_node) - # self.model.data[key] = OrtValue.ortvalue_from_numpy(value) - - # __delitem__ - - def __contains__(self, key: str): - return self.indexes[key] in self.model.proto.graph.initializer - - def items(self): - raise NotImplementedError("tensor.items") - # return [(obj.name, obj) for obj in self.raw_proto] - - def keys(self): - return self.indexes.keys() - - def values(self): - raise NotImplementedError("tensor.values") - # return [obj for obj in self.raw_proto] - - def size(self): - bytesSum = 0 - for node in self.model.proto.graph.initializer: - bytesSum += sys.getsizeof(node.raw_data) - return bytesSum - - class _access_helper: - def __init__(self, raw_proto): - self.indexes = {} - self.raw_proto = raw_proto - for idx, obj in enumerate(raw_proto): - self.indexes[obj.name] = idx - - def __getitem__(self, key: str): - return self.raw_proto[self.indexes[key]] - - def __setitem__(self, key: str, value): - index = self.indexes[key] - del self.raw_proto[index] - self.raw_proto.insert(index, value) - - # __delitem__ - - def __contains__(self, key: str): - return key in self.indexes - - def items(self): - return [(obj.name, obj) for obj in self.raw_proto] - - def keys(self): - return self.indexes.keys() - - def values(self): - return list(self.raw_proto) - - def __init__(self, model_path: str, provider: Optional[str]): - self.path = model_path - self.session = None - self.provider = provider - """ - self.data_path = self.path + "_data" - if not os.path.exists(self.data_path): - print(f"Moving model tensors to separate file: {self.data_path}") - tmp_proto = onnx.load(model_path, load_external_data=True) - onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False) - del tmp_proto - gc.collect() - - self.proto = onnx.load(model_path, load_external_data=False) - """ - - 
self.proto = onnx.load(model_path, load_external_data=True) - # self.data = dict() - # for tensor in self.proto.graph.initializer: - # name = tensor.name - - # if tensor.HasField("raw_data"): - # npt = numpy_helper.to_array(tensor) - # orv = OrtValue.ortvalue_from_numpy(npt) - # # self.data[name] = orv - # # set_external_data(tensor, location="in-memory-location") - # tensor.name = name - # # tensor.ClearField("raw_data") - - self.nodes = self._access_helper(self.proto.graph.node) - # self.initializers = self._access_helper(self.proto.graph.initializer) - # print(self.proto.graph.input) - # print(self.proto.graph.initializer) - - self.tensors = self._tensor_access(self) - - # TODO: integrate with model manager/cache - def create_session(self, height=None, width=None): - if self.session is None or self.session_width != width or self.session_height != height: - # onnx.save(self.proto, "tmp.onnx") - # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False) - # TODO: something to be able to get weight when they already moved outside of model proto - # (trimmed_model, external_data) = buffer_external_data_tensors(self.proto) - sess = SessionOptions() - # self._external_data.update(**external_data) - # sess.add_external_initializers(list(self.data.keys()), list(self.data.values())) - # sess.enable_profiling = True - - # sess.intra_op_num_threads = 1 - # sess.inter_op_num_threads = 1 - # sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL - # sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL - # sess.enable_cpu_mem_arena = True - # sess.enable_mem_pattern = True - # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code - self.session_height = height - self.session_width = width - if height and width: - sess.add_free_dimension_override_by_name("unet_sample_batch", 2) - sess.add_free_dimension_override_by_name("unet_sample_channels", 4) - sess.add_free_dimension_override_by_name("unet_hidden_batch", 2) - sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77) - sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height) - sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width) - sess.add_free_dimension_override_by_name("unet_time_batch", 1) - providers = [] - if self.provider: - providers.append(self.provider) - else: - providers = get_available_providers() - if "TensorrtExecutionProvider" in providers: - providers.remove("TensorrtExecutionProvider") - try: - self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess) - except Exception as e: - raise e - # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options) - # self.io_binding = self.session.io_binding() - - def release_session(self): - self.session = None - import gc - - gc.collect() - return - - def __call__(self, **kwargs): - if self.session is None: - raise Exception("You should call create_session before running model") - - inputs = {k: np.array(v) for k, v in kwargs.items()} - # output_names = self.session.get_outputs() - # for k in inputs: - # self.io_binding.bind_cpu_input(k, inputs[k]) - # for name in output_names: - # self.io_binding.bind_output(name.name) - # self.session.run_with_iobinding(self.io_binding, None) - # return self.io_binding.copy_outputs_to_cpu() - return self.session.run(None, 
inputs) - - # compatability with diffusers load code - @classmethod - def from_pretrained( - cls, - model_id: Union[str, Path], - subfolder: Union[str, Path] = None, - file_name: Optional[str] = None, - provider: Optional[str] = None, - sess_options: Optional["SessionOptions"] = None, - **kwargs, - ): - file_name = file_name or ONNX_WEIGHTS_NAME - - if os.path.isdir(model_id): - model_path = model_id - if subfolder is not None: - model_path = os.path.join(model_path, subfolder) - model_path = os.path.join(model_path, file_name) - - else: - model_path = model_id - - # load model from local directory - if not os.path.isfile(model_path): - raise Exception(f"Model not found: {model_path}") - - # TODO: session options - return cls(model_path, provider=provider) diff --git a/invokeai/backend/model_management_OLD/models/clip_vision.py b/invokeai/backend/model_management_OLD/models/clip_vision.py deleted file mode 100644 index 2276c6beed1..00000000000 --- a/invokeai/backend/model_management_OLD/models/clip_vision.py +++ /dev/null @@ -1,82 +0,0 @@ -import os -from enum import Enum -from typing import Literal, Optional - -import torch -from transformers import CLIPVisionModelWithProjection - -from invokeai.backend.model_management.models.base import ( - BaseModelType, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelType, - SubModelType, - calc_model_size_by_data, - calc_model_size_by_fs, - classproperty, -) - - -class CLIPVisionModelFormat(str, Enum): - Diffusers = "diffusers" - - -class CLIPVisionModel(ModelBase): - class DiffusersConfig(ModelConfigBase): - model_format: Literal[CLIPVisionModelFormat.Diffusers] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.CLIPVision - super().__init__(model_path, base_model, model_type) - - self.model_size = calc_model_size_by_fs(self.model_path) - - @classmethod - def detect_format(cls, path: str) -> str: - if not os.path.exists(path): - raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.") - - if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")): - return CLIPVisionModelFormat.Diffusers - - raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}") - - @classproperty - def save_to_config(cls) -> bool: - return True - - def get_size(self, child_type: Optional[SubModelType] = None) -> int: - if child_type is not None: - raise ValueError("There are no child models in a CLIP Vision model.") - - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ) -> CLIPVisionModelWithProjection: - if child_type is not None: - raise ValueError("There are no child models in a CLIP Vision model.") - - model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype) - - # Calculate a more accurate model size. 
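For reference, IAIOnnxRuntimeModel above boils down to the standard onnxruntime pattern: choose providers, build an InferenceSession, and run it on numpy inputs keyed by input name. A minimal sketch under that assumption; "model.onnx" and the zero-filled feeds are placeholders, not values from this patch:

    import numpy as np
    from onnxruntime import InferenceSession, SessionOptions, get_available_providers

    providers = [p for p in get_available_providers() if p != "TensorrtExecutionProvider"]
    session = InferenceSession("model.onnx", sess_options=SessionOptions(), providers=providers)

    # Feed every declared input; symbolic dims are faked as 1 for illustration.
    feeds = {
        inp.name: np.zeros([d if isinstance(d, int) else 1 for d in inp.shape], dtype=np.float32)
        for inp in session.get_inputs()
    }
    outputs = session.run(None, feeds)  # None -> return all graph outputs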
- self.model_size = calc_model_size_by_data(model) - - return model - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - format = cls.detect_format(model_path) - if format == CLIPVisionModelFormat.Diffusers: - return model_path - else: - raise ValueError(f"Unsupported format: '{format}'.") diff --git a/invokeai/backend/model_management_OLD/models/controlnet.py b/invokeai/backend/model_management_OLD/models/controlnet.py deleted file mode 100644 index 3b534cb9d14..00000000000 --- a/invokeai/backend/model_management_OLD/models/controlnet.py +++ /dev/null @@ -1,162 +0,0 @@ -import os -from enum import Enum -from pathlib import Path -from typing import Literal, Optional - -import torch - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig - -from .base import ( - BaseModelType, - EmptyConfigLoader, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - SubModelType, - calc_model_size_by_data, - calc_model_size_by_fs, - classproperty, -) - - -class ControlNetModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class ControlNetModel(ModelBase): - # model_class: Type - # model_size: int - - class DiffusersConfig(ModelConfigBase): - model_format: Literal[ControlNetModelFormat.Diffusers] - - class CheckpointConfig(ModelConfigBase): - model_format: Literal[ControlNetModelFormat.Checkpoint] - config: str - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.ControlNet - super().__init__(model_path, base_model, model_type) - - try: - config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") - # config = json.loads(os.path.join(self.model_path, "config.json")) - except Exception: - raise Exception("Invalid controlnet model! (config.json not found or invalid)") - - model_class_name = config.get("_class_name", None) - if model_class_name not in {"ControlNetModel"}: - raise Exception(f"Invalid ControlNet model! 
Unknown _class_name: {model_class_name}") - - try: - self.model_class = self._hf_definition_to_type(["diffusers", model_class_name]) - self.model_size = calc_model_size_by_fs(self.model_path) - except Exception: - raise Exception("Invalid ControlNet model!") - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in controlnet model") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There are no child models in controlnet model") - - model = None - for variant in ["fp16", None]: - try: - model = self.model_class.from_pretrained( - self.model_path, - torch_dtype=torch_dtype, - variant=variant, - ) - break - except Exception: - pass - if not model: - raise ModelNotFoundException() - - # calc more accurate size - self.model_size = calc_model_size_by_data(model) - return model - - @classproperty - def save_to_config(cls) -> bool: - return False - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException() - - if os.path.isdir(path): - if os.path.exists(os.path.join(path, "config.json")): - return ControlNetModelFormat.Diffusers - - if os.path.isfile(path): - if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]): - return ControlNetModelFormat.Checkpoint - - raise InvalidModelException(f"Not a valid model: {path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint: - return _convert_controlnet_ckpt_and_cache( - model_path=model_path, - model_config=config.config, - output_path=output_path, - base_model=base_model, - ) - else: - return model_path - - -def _convert_controlnet_ckpt_and_cache( - model_path: str, - output_path: str, - base_model: BaseModelType, - model_config: str, -) -> str: - """ - Convert the controlnet from checkpoint format to diffusers format, - cache it to disk, and return Path to converted - file. If already on disk then just returns Path. 
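For reference, the detect_format classmethods in these model classes share one heuristic: a directory containing config.json (or model_index.json) is a diffusers model, a single weights file with a known suffix is a checkpoint. A minimal sketch of that heuristic; the enum and function names are illustrative, not from this patch:

    import os
    from enum import Enum

    class Format(str, Enum):
        Checkpoint = "checkpoint"
        Diffusers = "diffusers"

    def detect_format(path: str) -> Format:
        if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
            return Format.Diffusers
        if os.path.isfile(path) and any(path.endswith(f".{ext}") for ext in ("safetensors", "ckpt", "pt", "pth")):
            return Format.Checkpoint
        raise ValueError(f"Not a recognized model: {path}")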
- """ - app_config = InvokeAIAppConfig.get_config() - weights = app_config.root_path / model_path - output_path = Path(output_path) - - logger.info(f"Converting {weights} to diffusers format") - # return cached version if it exists - if output_path.exists(): - return output_path - - # to avoid circular import errors - from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers - - convert_controlnet_to_diffusers( - weights, - output_path, - original_config_file=app_config.root_path / model_config, - image_size=512, - scan_needed=True, - from_safetensors=weights.suffix == ".safetensors", - ) - return output_path diff --git a/invokeai/backend/model_management_OLD/models/ip_adapter.py b/invokeai/backend/model_management_OLD/models/ip_adapter.py deleted file mode 100644 index c60edd0abe3..00000000000 --- a/invokeai/backend/model_management_OLD/models/ip_adapter.py +++ /dev/null @@ -1,98 +0,0 @@ -import os -import typing -from enum import Enum -from typing import Literal, Optional - -import torch - -from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter -from invokeai.backend.model_management.models.base import ( - BaseModelType, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelType, - SubModelType, - calc_model_size_by_fs, - classproperty, -) - - -class IPAdapterModelFormat(str, Enum): - # The custom IP-Adapter model format defined by InvokeAI. - InvokeAI = "invokeai" - - -class IPAdapterModel(ModelBase): - class InvokeAIConfig(ModelConfigBase): - model_format: Literal[IPAdapterModelFormat.InvokeAI] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.IPAdapter - super().__init__(model_path, base_model, model_type) - - self.model_size = calc_model_size_by_fs(self.model_path) - - @classmethod - def detect_format(cls, path: str) -> str: - if not os.path.exists(path): - raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.") - - if os.path.isdir(path): - model_file = os.path.join(path, "ip_adapter.bin") - image_encoder_config_file = os.path.join(path, "image_encoder.txt") - if os.path.exists(model_file) and os.path.exists(image_encoder_config_file): - return IPAdapterModelFormat.InvokeAI - - raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}") - - @classproperty - def save_to_config(cls) -> bool: - return True - - def get_size(self, child_type: Optional[SubModelType] = None) -> int: - if child_type is not None: - raise ValueError("There are no child models in an IP-Adapter model.") - - return self.model_size - - def get_model( - self, - torch_dtype: torch.dtype, - child_type: Optional[SubModelType] = None, - ) -> typing.Union[IPAdapter, IPAdapterPlus]: - if child_type is not None: - raise ValueError("There are no child models in an IP-Adapter model.") - - model = build_ip_adapter( - ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), - device=torch.device("cpu"), - dtype=torch_dtype, - ) - - self.model_size = model.calc_size() - return model - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - format = cls.detect_format(model_path) - if format == IPAdapterModelFormat.InvokeAI: - return model_path - else: - raise ValueError(f"Unsupported format: '{format}'.") - - -def get_ip_adapter_image_encoder_model_id(model_path: str): - """Read the ID of the image encoder associated with the IP-Adapter at 
`model_path`.""" - image_encoder_config_file = os.path.join(model_path, "image_encoder.txt") - - with open(image_encoder_config_file, "r") as f: - image_encoder_model = f.readline().strip() - - return image_encoder_model diff --git a/invokeai/backend/model_management_OLD/models/lora.py b/invokeai/backend/model_management_OLD/models/lora.py deleted file mode 100644 index b110d75d220..00000000000 --- a/invokeai/backend/model_management_OLD/models/lora.py +++ /dev/null @@ -1,696 +0,0 @@ -import bisect -import os -from enum import Enum -from pathlib import Path -from typing import Dict, Optional, Union - -import torch -from safetensors.torch import load_file - -from .base import ( - BaseModelType, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - SubModelType, - classproperty, -) - - -class LoRAModelFormat(str, Enum): - LyCORIS = "lycoris" - Diffusers = "diffusers" - - -class LoRAModel(ModelBase): - # model_size: int - - class Config(ModelConfigBase): - model_format: LoRAModelFormat # TODO: - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.Lora - super().__init__(model_path, base_model, model_type) - - self.model_size = os.path.getsize(self.model_path) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in lora") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There is no child models in lora") - - model = LoRAModelRaw.from_checkpoint( - file_path=self.model_path, - dtype=torch_dtype, - base_model=self.base_model, - ) - - self.model_size = model.calc_size() - return model - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException() - - if os.path.isdir(path): - for ext in ["safetensors", "bin"]: - if os.path.exists(os.path.join(path, f"pytorch_lora_weights.{ext}")): - return LoRAModelFormat.Diffusers - - if os.path.isfile(path): - if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): - return LoRAModelFormat.LyCORIS - - raise InvalidModelException(f"Not a valid model: {path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if cls.detect_format(model_path) == LoRAModelFormat.Diffusers: - for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder - path = Path(model_path, f"pytorch_lora_weights.{ext}") - if path.exists(): - return path - else: - return model_path - - -class LoRALayerBase: - # rank: Optional[int] - # alpha: Optional[float] - # bias: Optional[torch.Tensor] - # layer_key: str - - # @property - # def scale(self): - # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 - - def __init__( - self, - layer_key: str, - values: dict, - ): - if "alpha" in values: - self.alpha = values["alpha"].item() - else: - self.alpha = None - - if "bias_indices" in values and "bias_values" in values and "bias_size" in values: - self.bias = torch.sparse_coo_tensor( - values["bias_indices"], - values["bias_values"], - tuple(values["bias_size"]), - ) - - else: - self.bias = None - - self.rank = None # set in layer implementation - self.layer_key = layer_key - - 
def get_weight(self, orig_weight: torch.Tensor): - raise NotImplementedError() - - def calc_size(self) -> int: - model_size = 0 - for val in [self.bias]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - if self.bias is not None: - self.bias = self.bias.to(device=device, dtype=dtype) - - -# TODO: find and debug lora/locon with bias -class LoRALayer(LoRALayerBase): - # up: torch.Tensor - # mid: Optional[torch.Tensor] - # down: torch.Tensor - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.up = values["lora_up.weight"] - self.down = values["lora_down.weight"] - if "lora_mid.weight" in values: - self.mid = values["lora_mid.weight"] - else: - self.mid = None - - self.rank = self.down.shape[0] - - def get_weight(self, orig_weight: torch.Tensor): - if self.mid is not None: - up = self.up.reshape(self.up.shape[0], self.up.shape[1]) - down = self.down.reshape(self.down.shape[0], self.down.shape[1]) - weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) - else: - weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.up, self.mid, self.down]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.up = self.up.to(device=device, dtype=dtype) - self.down = self.down.to(device=device, dtype=dtype) - - if self.mid is not None: - self.mid = self.mid.to(device=device, dtype=dtype) - - -class LoHALayer(LoRALayerBase): - # w1_a: torch.Tensor - # w1_b: torch.Tensor - # w2_a: torch.Tensor - # w2_b: torch.Tensor - # t1: Optional[torch.Tensor] = None - # t2: Optional[torch.Tensor] = None - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.w1_a = values["hada_w1_a"] - self.w1_b = values["hada_w1_b"] - self.w2_a = values["hada_w2_a"] - self.w2_b = values["hada_w2_b"] - - if "hada_t1" in values: - self.t1 = values["hada_t1"] - else: - self.t1 = None - - if "hada_t2" in values: - self.t2 = values["hada_t2"] - else: - self.t2 = None - - self.rank = self.w1_b.shape[0] - - def get_weight(self, orig_weight: torch.Tensor): - if self.t1 is None: - weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) - - else: - rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) - rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a) - weight = rebuild1 * rebuild2 - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.w1_a = self.w1_a.to(device=device, dtype=dtype) - self.w1_b = self.w1_b.to(device=device, dtype=dtype) - if self.t1 is not None: - self.t1 = self.t1.to(device=device, dtype=dtype) - - self.w2_a = self.w2_a.to(device=device, dtype=dtype) - self.w2_b = self.w2_b.to(device=device, 
dtype=dtype) - if self.t2 is not None: - self.t2 = self.t2.to(device=device, dtype=dtype) - - -class LoKRLayer(LoRALayerBase): - # w1: Optional[torch.Tensor] = None - # w1_a: Optional[torch.Tensor] = None - # w1_b: Optional[torch.Tensor] = None - # w2: Optional[torch.Tensor] = None - # w2_a: Optional[torch.Tensor] = None - # w2_b: Optional[torch.Tensor] = None - # t2: Optional[torch.Tensor] = None - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - if "lokr_w1" in values: - self.w1 = values["lokr_w1"] - self.w1_a = None - self.w1_b = None - else: - self.w1 = None - self.w1_a = values["lokr_w1_a"] - self.w1_b = values["lokr_w1_b"] - - if "lokr_w2" in values: - self.w2 = values["lokr_w2"] - self.w2_a = None - self.w2_b = None - else: - self.w2 = None - self.w2_a = values["lokr_w2_a"] - self.w2_b = values["lokr_w2_b"] - - if "lokr_t2" in values: - self.t2 = values["lokr_t2"] - else: - self.t2 = None - - if "lokr_w1_b" in values: - self.rank = values["lokr_w1_b"].shape[0] - elif "lokr_w2_b" in values: - self.rank = values["lokr_w2_b"].shape[0] - else: - self.rank = None # unscaled - - def get_weight(self, orig_weight: torch.Tensor): - w1 = self.w1 - if w1 is None: - w1 = self.w1_a @ self.w1_b - - w2 = self.w2 - if w2 is None: - if self.t2 is None: - w2 = self.w2_a @ self.w2_b - else: - w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) - - if len(w2.shape) == 4: - w1 = w1.unsqueeze(2).unsqueeze(2) - w2 = w2.contiguous() - weight = torch.kron(w1, w2) - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - if self.w1 is not None: - self.w1 = self.w1.to(device=device, dtype=dtype) - else: - self.w1_a = self.w1_a.to(device=device, dtype=dtype) - self.w1_b = self.w1_b.to(device=device, dtype=dtype) - - if self.w2 is not None: - self.w2 = self.w2.to(device=device, dtype=dtype) - else: - self.w2_a = self.w2_a.to(device=device, dtype=dtype) - self.w2_b = self.w2_b.to(device=device, dtype=dtype) - - if self.t2 is not None: - self.t2 = self.t2.to(device=device, dtype=dtype) - - -class FullLayer(LoRALayerBase): - # weight: torch.Tensor - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.weight = values["diff"] - - if len(values.keys()) > 1: - _keys = list(values.keys()) - _keys.remove("diff") - raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}") - - self.rank = None # unscaled - - def get_weight(self, orig_weight: torch.Tensor): - return self.weight - - def calc_size(self) -> int: - model_size = super().calc_size() - model_size += self.weight.nelement() * self.weight.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.weight = self.weight.to(device=device, dtype=dtype) - - -class IA3Layer(LoRALayerBase): - # weight: torch.Tensor - # on_input: torch.Tensor - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.weight = values["weight"] - self.on_input = values["on_input"] - - self.rank = None # 
unscaled - - def get_weight(self, orig_weight: torch.Tensor): - weight = self.weight - if not self.on_input: - weight = weight.reshape(-1, 1) - return orig_weight * weight - - def calc_size(self) -> int: - model_size = super().calc_size() - model_size += self.weight.nelement() * self.weight.element_size() - model_size += self.on_input.nelement() * self.on_input.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.weight = self.weight.to(device=device, dtype=dtype) - self.on_input = self.on_input.to(device=device, dtype=dtype) - - -# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix -class LoRAModelRaw: # (torch.nn.Module): - _name: str - layers: Dict[str, LoRALayer] - - def __init__( - self, - name: str, - layers: Dict[str, LoRALayer], - ): - self._name = name - self.layers = layers - - @property - def name(self): - return self._name - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - # TODO: try revert if exception? - for _key, layer in self.layers.items(): - layer.to(device=device, dtype=dtype) - - def calc_size(self) -> int: - model_size = 0 - for _, layer in self.layers.items(): - model_size += layer.calc_size() - return model_size - - @classmethod - def _convert_sdxl_keys_to_diffusers_format(cls, state_dict): - """Convert the keys of an SDXL LoRA state_dict to diffusers format. - - The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in - diffusers format, then this function will have no effect. - - This function is adapted from: - https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409 - - Args: - state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict. - - Raises: - ValueError: If state_dict contains an unrecognized key, or not all keys could be converted. - - Returns: - Dict[str, Tensor]: The diffusers-format state_dict. - """ - converted_count = 0 # The number of Stability AI keys converted to diffusers format. - not_converted_count = 0 # The number of keys that were not converted. - - # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes. - # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for - # `input_blocks_4_1_proj_in`. - stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP) - stability_unet_keys.sort() - - new_state_dict = {} - for full_key, value in state_dict.items(): - if full_key.startswith("lora_unet_"): - search_key = full_key.replace("lora_unet_", "") - # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix. - position = bisect.bisect_right(stability_unet_keys, search_key) - map_key = stability_unet_keys[position - 1] - # Now, check if the map_key *actually* matches the search_key. - if search_key.startswith(map_key): - new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key]) - new_state_dict[new_key] = value - converted_count += 1 - else: - new_state_dict[full_key] = value - not_converted_count += 1 - elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): - # The CLIP text encoders have the same keys in both Stability AI and diffusers formats. 
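For reference, LoRALayer.get_weight above reconstructs the LoRA delta as up @ down (optionally routed through a mid tensor), and the commented-out scale property suggests the usual alpha / rank scaling. A minimal sketch of applying such a delta to a base weight; the function name, shapes, and sample tensors are illustrative, not from this patch:

    from typing import Optional
    import torch

    def apply_lora(orig_weight: torch.Tensor, up: torch.Tensor, down: torch.Tensor,
                   alpha: Optional[float], strength: float = 1.0) -> torch.Tensor:
        rank = down.shape[0]
        scale = (alpha / rank) if alpha else 1.0
        # up: (out, rank), down: (rank, in) -> delta: (out, in)
        delta = up.reshape(up.shape[0], -1) @ down.reshape(down.shape[0], -1)
        return orig_weight + strength * scale * delta.reshape(orig_weight.shape)

    base = torch.zeros(8, 16)
    up, down = torch.randn(8, 4), torch.randn(4, 16)
    patched = apply_lora(base, up, down, alpha=4.0)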
- new_state_dict[full_key] = value - continue - else: - raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.") - - if converted_count > 0 and not_converted_count > 0: - raise ValueError( - f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count}," - f" not_converted={not_converted_count}" - ) - - return new_state_dict - - @classmethod - def from_checkpoint( - cls, - file_path: Union[str, Path], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - base_model: Optional[BaseModelType] = None, - ): - device = device or torch.device("cpu") - dtype = dtype or torch.float32 - - if isinstance(file_path, str): - file_path = Path(file_path) - - model = cls( - name=file_path.stem, # TODO: - layers={}, - ) - - if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") - else: - state_dict = torch.load(file_path, map_location="cpu") - - state_dict = cls._group_state(state_dict) - - if base_model == BaseModelType.StableDiffusionXL: - state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) - - for layer_key, values in state_dict.items(): - # lora and locon - if "lora_down.weight" in values: - layer = LoRALayer(layer_key, values) - - # loha - elif "hada_w1_b" in values: - layer = LoHALayer(layer_key, values) - - # lokr - elif "lokr_w1_b" in values or "lokr_w1" in values: - layer = LoKRLayer(layer_key, values) - - # diff - elif "diff" in values: - layer = FullLayer(layer_key, values) - - # ia3 - elif "weight" in values and "on_input" in values: - layer = IA3Layer(layer_key, values) - - else: - print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") - raise Exception("Unknown lora format!") - - # lower memory consumption by removing already parsed layer values - state_dict[layer_key].clear() - - layer.to(device=device, dtype=dtype) - model.layers[layer_key] = layer - - return model - - @staticmethod - def _group_state(state_dict: dict): - state_dict_groupped = {} - - for key, value in state_dict.items(): - stem, leaf = key.split(".", 1) - if stem not in state_dict_groupped: - state_dict_groupped[stem] = {} - state_dict_groupped[stem][leaf] = value - - return state_dict_groupped - - -# code from -# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 -def make_sdxl_unet_conversion_map(): - """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.""" - unet_conversion_map_layer = [] - - for i in range(3): # num_blocks is 3 in sdxl - # loop over downblocks/upblocks - for j in range(2): - # loop over resnets/attentions for downblocks - hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." - unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) - - if i < 3: - # no attention layers in down_blocks.3 - hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." - unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) - - for j in range(3): - # loop over resnets/attentions for upblocks - hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 
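For reference, _convert_sdxl_keys_to_diffusers_format above uses bisect_right over the sorted Stability AI prefixes to find the single candidate prefix a key could start with, then confirms it with startswith. A minimal standalone sketch of that lookup; the mapping entry is made up for illustration:

    import bisect

    prefix_map = {"input_blocks_4_1": "down_blocks_1_attentions_0"}
    sorted_prefixes = sorted(prefix_map)

    def longest_prefix(key: str):
        pos = bisect.bisect_right(sorted_prefixes, key)
        if pos == 0:
            return None
        candidate = sorted_prefixes[pos - 1]
        return candidate if key.startswith(candidate) else None

    print(longest_prefix("input_blocks_4_1_proj_in"))  # "input_blocks_4_1"
    print(longest_prefix("lora_te1_text_model"))       # None (no matching prefix)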
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) - - # if i > 0: commentout for sdxl - # no attention layers in up_blocks.0 - hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." - sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." - unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) - - if i < 3: - # no downsample in down_blocks.3 - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." - unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) - - # no upsample in up_blocks.3 - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl - unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) - - hf_mid_atn_prefix = "mid_block.attentions.0." - sd_mid_atn_prefix = "middle_block.1." - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) - - for j in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." - unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - unet_conversion_map_resnet = [ - # (stable-diffusion, HF Diffusers) - ("in_layers.0.", "norm1."), - ("in_layers.2.", "conv1."), - ("out_layers.0.", "norm2."), - ("out_layers.3.", "conv2."), - ("emb_layers.1.", "time_emb_proj."), - ("skip_connection.", "conv_shortcut."), - ] - - unet_conversion_map = [] - for sd, hf in unet_conversion_map_layer: - if "resnets" in hf: - for sd_res, hf_res in unet_conversion_map_resnet: - unet_conversion_map.append((sd + sd_res, hf + hf_res)) - else: - unet_conversion_map.append((sd, hf)) - - for j in range(2): - hf_time_embed_prefix = f"time_embedding.linear_{j+1}." - sd_time_embed_prefix = f"time_embed.{j*2}." - unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) - - for j in range(2): - hf_label_embed_prefix = f"add_embedding.linear_{j+1}." - sd_label_embed_prefix = f"label_emb.0.{j*2}." 
- unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) - - unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) - unet_conversion_map.append(("out.0.", "conv_norm_out.")) - unet_conversion_map.append(("out.2.", "conv_out.")) - - return unet_conversion_map - - -SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = { - sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map() -} diff --git a/invokeai/backend/model_management_OLD/models/sdxl.py b/invokeai/backend/model_management_OLD/models/sdxl.py deleted file mode 100644 index 01e9420fed7..00000000000 --- a/invokeai/backend/model_management_OLD/models/sdxl.py +++ /dev/null @@ -1,148 +0,0 @@ -import json -import os -from enum import Enum -from pathlib import Path -from typing import Literal, Optional - -from omegaconf import OmegaConf -from pydantic import Field - -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management.detect_baked_in_vae import has_baked_in_sdxl_vae -from invokeai.backend.util.logging import InvokeAILogger - -from .base import ( - BaseModelType, - DiffusersModel, - InvalidModelException, - ModelConfigBase, - ModelType, - ModelVariantType, - classproperty, - read_checkpoint_meta, -) - - -class StableDiffusionXLModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class StableDiffusionXLModel(DiffusersModel): - # TODO: check that configs overwriten properly - class DiffusersConfig(ModelConfigBase): - model_format: Literal[StableDiffusionXLModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType - - class CheckpointConfig(ModelConfigBase): - model_format: Literal[StableDiffusionXLModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner} - assert model_type == ModelType.Main - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusionXL, - model_type=ModelType.Main, - ) - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - ckpt_config_path = kwargs.get("config", None) - if model_format == StableDiffusionXLModelFormat.Checkpoint: - if ckpt_config_path: - ckpt_config = OmegaConf.load(ckpt_config_path) - in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] - - else: - checkpoint = read_checkpoint_meta(path) - checkpoint = checkpoint.get("state_dict", checkpoint) - in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - - elif model_format == StableDiffusionXLModelFormat.Diffusers: - unet_config_path = os.path.join(path, "unet", "config.json") - if os.path.exists(unet_config_path): - with open(unet_config_path, "r") as f: - unet_config = json.loads(f.read()) - in_channels = unet_config["in_channels"] - - else: - raise InvalidModelException(f"{path} is not a recognized Stable Diffusion diffusers model") - - else: - raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}") - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 5: - variant = ModelVariantType.Depth - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 2.* model format") - - if ckpt_config_path is None: - # avoid circular 
import - from .stable_diffusion import _select_ckpt_config - - ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant) - - return cls.create_config( - path=path, - model_format=model_format, - config=ckpt_config_path, - variant=variant, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - if os.path.isdir(model_path): - return StableDiffusionXLModelFormat.Diffusers - else: - return StableDiffusionXLModelFormat.Checkpoint - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - # The convert script adapted from the diffusers package uses - # strings for the base model type. To avoid making too many - # source code changes, we simply translate here - if Path(output_path).exists(): - return output_path - - if isinstance(config, cls.CheckpointConfig): - from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache - - # Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed, - # then we bake it into the converted model unless there is already - # a nonstandard VAE installed. - kwargs = {} - app_config = InvokeAIAppConfig.get_config() - vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix" - if vae_path.exists() and not has_baked_in_sdxl_vae(Path(model_path)): - InvokeAILogger.get_logger().warning("No baked-in VAE detected. Inserting sdxl-vae-fp16-fix.") - kwargs["vae_path"] = vae_path - - return _convert_ckpt_and_cache( - version=base_model, - model_config=config, - output_path=output_path, - use_safetensors=True, - **kwargs, - ) - else: - return model_path diff --git a/invokeai/backend/model_management_OLD/models/stable_diffusion.py b/invokeai/backend/model_management_OLD/models/stable_diffusion.py deleted file mode 100644 index a38a44fccf7..00000000000 --- a/invokeai/backend/model_management_OLD/models/stable_diffusion.py +++ /dev/null @@ -1,337 +0,0 @@ -import json -import os -from enum import Enum -from pathlib import Path -from typing import Literal, Optional, Union - -from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline -from omegaconf import OmegaConf -from pydantic import Field - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig - -from .base import ( - BaseModelType, - DiffusersModel, - InvalidModelException, - ModelConfigBase, - ModelNotFoundException, - ModelType, - ModelVariantType, - SilenceWarnings, - classproperty, - read_checkpoint_meta, -) -from .sdxl import StableDiffusionXLModel - - -class StableDiffusion1ModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class StableDiffusion1Model(DiffusersModel): - class DiffusersConfig(ModelConfigBase): - model_format: Literal[StableDiffusion1ModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType - - class CheckpointConfig(ModelConfigBase): - model_format: Literal[StableDiffusion1ModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusion1 - assert model_type == ModelType.Main - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusion1, - model_type=ModelType.Main, - ) - - @classmethod - def 
probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - ckpt_config_path = kwargs.get("config", None) - if model_format == StableDiffusion1ModelFormat.Checkpoint: - if ckpt_config_path: - ckpt_config = OmegaConf.load(ckpt_config_path) - ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] - - else: - checkpoint = read_checkpoint_meta(path) - checkpoint = checkpoint.get("state_dict", checkpoint) - in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - - elif model_format == StableDiffusion1ModelFormat.Diffusers: - unet_config_path = os.path.join(path, "unet", "config.json") - if os.path.exists(unet_config_path): - with open(unet_config_path, "r") as f: - unet_config = json.loads(f.read()) - in_channels = unet_config["in_channels"] - - else: - raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format") - - else: - raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}") - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 1.* model format") - - if ckpt_config_path is None: - ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion1, variant) - - return cls.create_config( - path=path, - model_format=model_format, - config=ckpt_config_path, - variant=variant, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - if not os.path.exists(model_path): - raise ModelNotFoundException() - - if os.path.isdir(model_path): - if os.path.exists(os.path.join(model_path, "model_index.json")): - return StableDiffusion1ModelFormat.Diffusers - - if os.path.isfile(model_path): - if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): - return StableDiffusion1ModelFormat.Checkpoint - - raise InvalidModelException(f"Not a valid model: {model_path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if isinstance(config, cls.CheckpointConfig): - return _convert_ckpt_and_cache( - version=BaseModelType.StableDiffusion1, - model_config=config, - load_safety_checker=False, - output_path=output_path, - ) - else: - return model_path - - -class StableDiffusion2ModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class StableDiffusion2Model(DiffusersModel): - # TODO: check that configs overwriten properly - class DiffusersConfig(ModelConfigBase): - model_format: Literal[StableDiffusion2ModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType - - class CheckpointConfig(ModelConfigBase): - model_format: Literal[StableDiffusion2ModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusion2 - assert model_type == ModelType.Main - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusion2, - model_type=ModelType.Main, - ) - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - ckpt_config_path = kwargs.get("config", None) - if model_format == StableDiffusion2ModelFormat.Checkpoint: - if ckpt_config_path: - ckpt_config = 
OmegaConf.load(ckpt_config_path) - ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] - - else: - checkpoint = read_checkpoint_meta(path) - checkpoint = checkpoint.get("state_dict", checkpoint) - in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - - elif model_format == StableDiffusion2ModelFormat.Diffusers: - unet_config_path = os.path.join(path, "unet", "config.json") - if os.path.exists(unet_config_path): - with open(unet_config_path, "r") as f: - unet_config = json.loads(f.read()) - in_channels = unet_config["in_channels"] - - else: - raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)") - - else: - raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}") - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 5: - variant = ModelVariantType.Depth - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 2.* model format") - - if ckpt_config_path is None: - ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion2, variant) - - return cls.create_config( - path=path, - model_format=model_format, - config=ckpt_config_path, - variant=variant, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - if not os.path.exists(model_path): - raise ModelNotFoundException() - - if os.path.isdir(model_path): - if os.path.exists(os.path.join(model_path, "model_index.json")): - return StableDiffusion2ModelFormat.Diffusers - - if os.path.isfile(model_path): - if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): - return StableDiffusion2ModelFormat.Checkpoint - - raise InvalidModelException(f"Not a valid model: {model_path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if isinstance(config, cls.CheckpointConfig): - return _convert_ckpt_and_cache( - version=BaseModelType.StableDiffusion2, - model_config=config, - output_path=output_path, - ) - else: - return model_path - - -# TODO: rework -# pass precision - currently defaulting to fp16 -def _convert_ckpt_and_cache( - version: BaseModelType, - model_config: Union[ - StableDiffusion1Model.CheckpointConfig, - StableDiffusion2Model.CheckpointConfig, - StableDiffusionXLModel.CheckpointConfig, - ], - output_path: str, - use_save_model: bool = False, - **kwargs, -) -> str: - """ - Convert the checkpoint model indicated in mconfig into a - diffusers, cache it to disk, and return Path to converted - file. If already on disk then just returns Path. 
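For reference, a minimal sketch of the convert-once-and-cache behaviour this helper describes (illustrative names only, not the removed helper's real signature):

from pathlib import Path

def convert_and_cache(weights: Path, output_dir: Path, convert) -> Path:
    # Return the cached diffusers conversion if it already exists on disk.
    if output_dir.exists():
        return output_dir
    output_dir.parent.mkdir(parents=True, exist_ok=True)
    pipeline = convert(weights)  # e.g. a thin wrapper around convert_ckpt_to_diffusers()
    pipeline.save_pretrained(output_dir, safe_serialization=True)
    return output_dir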
- """ - app_config = InvokeAIAppConfig.get_config() - - weights = app_config.models_path / model_config.path - config_file = app_config.root_path / model_config.config - output_path = Path(output_path) - variant = model_config.variant - pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline - - # return cached version if it exists - if output_path.exists(): - return output_path - - # to avoid circular import errors - from ...util.devices import choose_torch_device, torch_dtype - from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers - - model_base_to_model_type = { - BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder", - BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder", - BaseModelType.StableDiffusionXL: "SDXL", - BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner", - } - logger.info(f"Converting {weights} to diffusers format") - with SilenceWarnings(): - convert_ckpt_to_diffusers( - weights, - output_path, - model_type=model_base_to_model_type[version], - model_version=version, - model_variant=model_config.variant, - original_config_file=config_file, - extract_ema=True, - scan_needed=True, - pipeline_class=pipeline_class, - from_safetensors=weights.suffix == ".safetensors", - precision=torch_dtype(choose_torch_device()), - **kwargs, - ) - return output_path - - -def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType): - ckpt_configs = { - BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: "v1-inference.yaml", - ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", - }, - BaseModelType.StableDiffusion2: { - ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512) - ModelVariantType.Inpaint: "v2-inpainting-inference.yaml", - ModelVariantType.Depth: "v2-midas-inference.yaml", - }, - BaseModelType.StableDiffusionXL: { - ModelVariantType.Normal: "sd_xl_base.yaml", - ModelVariantType.Inpaint: None, - ModelVariantType.Depth: None, - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelVariantType.Normal: "sd_xl_refiner.yaml", - ModelVariantType.Inpaint: None, - ModelVariantType.Depth: None, - }, - } - - app_config = InvokeAIAppConfig.get_config() - try: - config_path = app_config.legacy_conf_path / ckpt_configs[version][variant] - if config_path.is_relative_to(app_config.root_path): - config_path = config_path.relative_to(app_config.root_path) - return str(config_path) - - except Exception: - return None diff --git a/invokeai/backend/model_management_OLD/models/stable_diffusion_onnx.py b/invokeai/backend/model_management_OLD/models/stable_diffusion_onnx.py deleted file mode 100644 index 2d0dd22c43a..00000000000 --- a/invokeai/backend/model_management_OLD/models/stable_diffusion_onnx.py +++ /dev/null @@ -1,150 +0,0 @@ -from enum import Enum -from typing import Literal - -from diffusers import OnnxRuntimeModel - -from .base import ( - BaseModelType, - DiffusersModel, - IAIOnnxRuntimeModel, - ModelConfigBase, - ModelType, - ModelVariantType, - SchedulerPredictionType, - classproperty, -) - - -class StableDiffusionOnnxModelFormat(str, Enum): - Olive = "olive" - Onnx = "onnx" - - -class ONNXStableDiffusion1Model(DiffusersModel): - class Config(ModelConfigBase): - model_format: Literal[StableDiffusionOnnxModelFormat.Onnx] - variant: ModelVariantType - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusion1 - assert model_type == ModelType.ONNX - super().__init__( - 
model_path=model_path, - base_model=BaseModelType.StableDiffusion1, - model_type=ModelType.ONNX, - ) - - for child_name, child_type in self.child_types.items(): - if child_type is OnnxRuntimeModel: - self.child_types[child_name] = IAIOnnxRuntimeModel - - # TODO: check that no optimum models provided - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - in_channels = 4 # TODO: - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 1.* model format") - - return cls.create_config( - path=path, - model_format=model_format, - variant=variant, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - # TODO: Detect onnx vs olive - return StableDiffusionOnnxModelFormat.Onnx - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - return model_path - - -class ONNXStableDiffusion2Model(DiffusersModel): - # TODO: check that configs overwriten properly - class Config(ModelConfigBase): - model_format: Literal[StableDiffusionOnnxModelFormat.Onnx] - variant: ModelVariantType - prediction_type: SchedulerPredictionType - upcast_attention: bool - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusion2 - assert model_type == ModelType.ONNX - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusion2, - model_type=ModelType.ONNX, - ) - - for child_name, child_type in self.child_types.items(): - if child_type is OnnxRuntimeModel: - self.child_types[child_name] = IAIOnnxRuntimeModel - # TODO: check that no optimum models provided - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - in_channels = 4 # TODO: - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 5: - variant = ModelVariantType.Depth - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 2.* model format") - - if variant == ModelVariantType.Normal: - prediction_type = SchedulerPredictionType.VPrediction - upcast_attention = True - - else: - prediction_type = SchedulerPredictionType.Epsilon - upcast_attention = False - - return cls.create_config( - path=path, - model_format=model_format, - variant=variant, - prediction_type=prediction_type, - upcast_attention=upcast_attention, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - # TODO: Detect onnx vs olive - return StableDiffusionOnnxModelFormat.Onnx - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - return model_path diff --git a/invokeai/backend/model_management_OLD/models/t2i_adapter.py b/invokeai/backend/model_management_OLD/models/t2i_adapter.py deleted file mode 100644 index 4adb9901f99..00000000000 --- a/invokeai/backend/model_management_OLD/models/t2i_adapter.py +++ /dev/null @@ -1,102 +0,0 @@ -import os -from enum import Enum -from typing import Literal, Optional - -import torch -from diffusers import T2IAdapter - -from invokeai.backend.model_management.models.base import ( - BaseModelType, - 
EmptyConfigLoader, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - SubModelType, - calc_model_size_by_data, - calc_model_size_by_fs, - classproperty, -) - - -class T2IAdapterModelFormat(str, Enum): - Diffusers = "diffusers" - - -class T2IAdapterModel(ModelBase): - class DiffusersConfig(ModelConfigBase): - model_format: Literal[T2IAdapterModelFormat.Diffusers] - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.T2IAdapter - super().__init__(model_path, base_model, model_type) - - config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") - - model_class_name = config.get("_class_name", None) - if model_class_name not in {"T2IAdapter"}: - raise InvalidModelException(f"Invalid T2I-Adapter model. Unknown _class_name: '{model_class_name}'.") - - self.model_class = self._hf_definition_to_type(["diffusers", model_class_name]) - self.model_size = calc_model_size_by_fs(self.model_path) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ) -> T2IAdapter: - if child_type is not None: - raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.") - - model = None - for variant in ["fp16", None]: - try: - model = self.model_class.from_pretrained( - self.model_path, - torch_dtype=torch_dtype, - variant=variant, - ) - break - except Exception: - pass - if not model: - raise ModelNotFoundException() - - # Calculate a more accurate size after loading the model into memory. 
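The loop above tries the fp16 weight variant first and silently falls back to the default weights. A standalone, hedged sketch of that idiom with diffusers' from_pretrained (names and error handling are illustrative):

from typing import Optional

import torch
from diffusers import T2IAdapter

def load_t2i_adapter_with_fallback(model_path: str, torch_dtype: Optional[torch.dtype]) -> T2IAdapter:
    last_error: Optional[Exception] = None
    for variant in ("fp16", None):
        try:
            # Missing variant files raise; fall through to the next candidate.
            return T2IAdapter.from_pretrained(model_path, torch_dtype=torch_dtype, variant=variant)
        except Exception as e:
            last_error = e
    raise RuntimeError(f"Could not load a T2I-Adapter from {model_path!r}") from last_error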
- self.model_size = calc_model_size_by_data(model) - return model - - @classproperty - def save_to_config(cls) -> bool: - return False - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException(f"Model not found at '{path}'.") - - if os.path.isdir(path): - if os.path.exists(os.path.join(path, "config.json")): - return T2IAdapterModelFormat.Diffusers - - raise InvalidModelException(f"Unsupported T2I-Adapter format: '{path}'.") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - format = cls.detect_format(model_path) - if format == T2IAdapterModelFormat.Diffusers: - return model_path - else: - raise ValueError(f"Unsupported format: '{format}'.") diff --git a/invokeai/backend/model_management_OLD/models/textual_inversion.py b/invokeai/backend/model_management_OLD/models/textual_inversion.py deleted file mode 100644 index 99358704b8d..00000000000 --- a/invokeai/backend/model_management_OLD/models/textual_inversion.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -from typing import Optional - -import torch - -# TODO: naming -from ..lora import TextualInversionModel as TextualInversionModelRaw -from .base import ( - BaseModelType, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - SubModelType, - classproperty, -) - - -class TextualInversionModel(ModelBase): - # model_size: int - - class Config(ModelConfigBase): - model_format: None - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.TextualInversion - super().__init__(model_path, base_model, model_type) - - self.model_size = os.path.getsize(self.model_path) - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in textual inversion") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There is no child models in textual inversion") - - checkpoint_path = self.model_path - if os.path.isdir(checkpoint_path): - checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin") - - if not os.path.exists(checkpoint_path): - raise ModelNotFoundException() - - model = TextualInversionModelRaw.from_checkpoint( - file_path=checkpoint_path, - dtype=torch_dtype, - ) - - self.model_size = model.embedding.nelement() * model.embedding.element_size() - return model - - @classproperty - def save_to_config(cls) -> bool: - return False - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException() - - if os.path.isdir(path): - if os.path.exists(os.path.join(path, "learned_embeds.bin")): - return None # diffusers-ti - - if os.path.isfile(path): - if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]): - return None - - raise InvalidModelException(f"Not a valid model: {path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - return model_path diff --git a/invokeai/backend/model_management_OLD/models/vae.py b/invokeai/backend/model_management_OLD/models/vae.py deleted file mode 100644 index 8cc37e67a73..00000000000 --- a/invokeai/backend/model_management_OLD/models/vae.py +++ 
/dev/null @@ -1,179 +0,0 @@ -import os -from enum import Enum -from pathlib import Path -from typing import Optional - -import safetensors -import torch -from omegaconf import OmegaConf - -from invokeai.app.services.config import InvokeAIAppConfig - -from .base import ( - BaseModelType, - EmptyConfigLoader, - InvalidModelException, - ModelBase, - ModelConfigBase, - ModelNotFoundException, - ModelType, - ModelVariantType, - SubModelType, - calc_model_size_by_data, - calc_model_size_by_fs, - classproperty, -) - - -class VaeModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - - -class VaeModel(ModelBase): - # vae_class: Type - # model_size: int - - class Config(ModelConfigBase): - model_format: VaeModelFormat - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert model_type == ModelType.Vae - super().__init__(model_path, base_model, model_type) - - try: - config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") - # config = json.loads(os.path.join(self.model_path, "config.json")) - except Exception: - raise Exception("Invalid vae model! (config.json not found or invalid)") - - try: - vae_class_name = config.get("_class_name", "AutoencoderKL") - self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name]) - self.model_size = calc_model_size_by_fs(self.model_path) - except Exception: - raise Exception("Invalid vae model! (Unkown vae type)") - - def get_size(self, child_type: Optional[SubModelType] = None): - if child_type is not None: - raise Exception("There is no child models in vae model") - return self.model_size - - def get_model( - self, - torch_dtype: Optional[torch.dtype], - child_type: Optional[SubModelType] = None, - ): - if child_type is not None: - raise Exception("There is no child models in vae model") - - model = self.vae_class.from_pretrained( - self.model_path, - torch_dtype=torch_dtype, - ) - # calc more accurate size - self.model_size = calc_model_size_by_data(model) - return model - - @classproperty - def save_to_config(cls) -> bool: - return False - - @classmethod - def detect_format(cls, path: str): - if not os.path.exists(path): - raise ModelNotFoundException(f"Does not exist as local file: {path}") - - if os.path.isdir(path): - if os.path.exists(os.path.join(path, "config.json")): - return VaeModelFormat.Diffusers - - if os.path.isfile(path): - if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]): - return VaeModelFormat.Checkpoint - - raise InvalidModelException(f"Not a valid model: {path}") - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, # empty config or config of parent model - base_model: BaseModelType, - ) -> str: - if cls.detect_format(model_path) == VaeModelFormat.Checkpoint: - return _convert_vae_ckpt_and_cache( - weights_path=model_path, - output_path=output_path, - base_model=base_model, - model_config=config, - ) - else: - return model_path - - -# TODO: rework -def _convert_vae_ckpt_and_cache( - weights_path: str, - output_path: str, - base_model: BaseModelType, - model_config: ModelConfigBase, -) -> str: - """ - Convert the VAE indicated in mconfig into a diffusers AutoencoderKL - object, cache it to disk, and return Path to converted - file. If already on disk then just returns Path. 
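The body below loads the VAE checkpoint from either a .safetensors file or a pickled .ckpt/.pt file and unwraps a nested "state_dict" key; the same pattern recurs throughout this module. A condensed, hedged version of that step (helper name is illustrative):

from pathlib import Path

import safetensors.torch
import torch

def load_vae_state_dict(weights_path: Path) -> dict:
    # safetensors files are loaded directly; everything else goes through torch.load.
    if weights_path.suffix == ".safetensors":
        checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
    else:
        checkpoint = torch.load(weights_path, map_location="cpu")
    # Some checkpoints nest the weights under "state_dict".
    return checkpoint.get("state_dict", checkpoint)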
- """ - app_config = InvokeAIAppConfig.get_config() - weights_path = app_config.root_dir / weights_path - output_path = Path(output_path) - - """ - this size used only in when tiling enabled to separate input in tiles - sizes in configs from stable diffusion githubs(1 and 2) set to 256 - on huggingface it: - 1.5 - 512 - 1.5-inpainting - 256 - 2-inpainting - 512 - 2-depth - 256 - 2-base - 512 - 2 - 768 - 2.1-base - 768 - 2.1 - 768 - """ - image_size = 512 - - # return cached version if it exists - if output_path.exists(): - return output_path - - if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: - from .stable_diffusion import _select_ckpt_config - - # all sd models use same vae settings - config_file = _select_ckpt_config(base_model, ModelVariantType.Normal) - else: - raise Exception(f"Vae conversion not supported for model type: {base_model}") - - # this avoids circular import error - from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers - - if weights_path.suffix == ".safetensors": - checkpoint = safetensors.torch.load_file(weights_path, device="cpu") - else: - checkpoint = torch.load(weights_path, map_location="cpu") - - # sometimes weights are hidden under "state_dict", and sometimes not - if "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - config = OmegaConf.load(app_config.root_path / config_file) - - vae_model = convert_ldm_vae_to_diffusers( - checkpoint=checkpoint, - vae_config=config, - image_size=image_size, - ) - vae_model.save_pretrained(output_path, safe_serialization=True) - return output_path diff --git a/invokeai/backend/model_management_OLD/seamless.py b/invokeai/backend/model_management_OLD/seamless.py deleted file mode 100644 index fb9112b56dc..00000000000 --- a/invokeai/backend/model_management_OLD/seamless.py +++ /dev/null @@ -1,84 +0,0 @@ -from __future__ import annotations - -from contextlib import contextmanager -from typing import Callable, List, Union - -import torch.nn as nn -from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel - - -def _conv_forward_asymmetric(self, input, weight, bias): - """ - Patch for Conv2d._conv_forward that supports asymmetric padding - """ - working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]) - working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]) - return nn.functional.conv2d( - working, - weight, - bias, - self.stride, - nn.modules.utils._pair(0), - self.dilation, - self.groups, - ) - - -@contextmanager -def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): - # Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor - to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = [] - try: - # Hard coded to skip down block layers, allowing for seamless tiling at the expense of prompt adherence - skipped_layers = 1 - for m_name, m in model.named_modules(): - if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): - continue - - if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." 
in m_name: - # down_blocks.1.resnets.1.conv1 - _, block_num, _, resnet_num, submodule_name = m_name.split(".") - block_num = int(block_num) - resnet_num = int(resnet_num) - - if block_num >= len(model.down_blocks) - skipped_layers: - continue - - # Skip the second resnet (could be configurable) - if resnet_num > 0: - continue - - # Skip Conv2d layers (could be configurable) - if submodule_name == "conv2": - continue - - m.asymmetric_padding_mode = {} - m.asymmetric_padding = {} - m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" - m.asymmetric_padding["x"] = ( - m._reversed_padding_repeated_twice[0], - m._reversed_padding_repeated_twice[1], - 0, - 0, - ) - m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" - m.asymmetric_padding["y"] = ( - 0, - 0, - m._reversed_padding_repeated_twice[2], - m._reversed_padding_repeated_twice[3], - ) - - to_restore.append((m, m._conv_forward)) - m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) - - yield - - finally: - for module, orig_conv_forward in to_restore: - module._conv_forward = orig_conv_forward - if hasattr(module, "asymmetric_padding_mode"): - del module.asymmetric_padding_mode - if hasattr(module, "asymmetric_padding"): - del module.asymmetric_padding diff --git a/invokeai/backend/model_management_OLD/util.py b/invokeai/backend/model_management_OLD/util.py deleted file mode 100644 index f4737d9f0b5..00000000000 --- a/invokeai/backend/model_management_OLD/util.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2023 The InvokeAI Development Team -"""Utilities used by the Model Manager""" - - -def lora_token_vector_length(checkpoint: dict) -> int: - """ - Given a checkpoint in memory, return the lora token vector length - - :param checkpoint: The checkpoint - """ - - def _get_shape_1(key: str, tensor, checkpoint) -> int: - lora_token_vector_length = None - - if "." not in key: - return lora_token_vector_length # wrong key format - model_key, lora_key = key.split(".", 1) - - # check lora/locon - if lora_key == "lora_down.weight": - lora_token_vector_length = tensor.shape[1] - - # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes) - elif lora_key in ["hada_w1_b", "hada_w2_b"]: - lora_token_vector_length = tensor.shape[1] - - # check lokr (don't worry about lokr_t2 as it used only in 4d shapes) - elif "lokr_" in lora_key: - if model_key + ".lokr_w1" in checkpoint: - _lokr_w1 = checkpoint[model_key + ".lokr_w1"] - elif model_key + "lokr_w1_b" in checkpoint: - _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"] - else: - return lora_token_vector_length # unknown format - - if model_key + ".lokr_w2" in checkpoint: - _lokr_w2 = checkpoint[model_key + ".lokr_w2"] - elif model_key + "lokr_w2_b" in checkpoint: - _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"] - else: - return lora_token_vector_length # unknown format - - lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1] - - elif lora_key == "diff": - lora_token_vector_length = tensor.shape[1] - - # ia3 can be detected only by shape[0] in text encoder - elif lora_key == "weight" and "lora_unet_" not in model_key: - lora_token_vector_length = tensor.shape[0] - - return lora_token_vector_length - - lora_token_vector_length = None - lora_te1_length = None - lora_te2_length = None - for key, tensor in checkpoint.items(): - if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." 
in key): - lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) - elif key.startswith("lora_unet_") and ( - "time_emb_proj.lora_down" in key - ): # recognizes format at https://civitai.com/models/224641 - lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) - elif key.startswith("lora_te") and "_self_attn_" in key: - tmp_length = _get_shape_1(key, tensor, checkpoint) - if key.startswith("lora_te_"): - lora_token_vector_length = tmp_length - elif key.startswith("lora_te1_"): - lora_te1_length = tmp_length - elif key.startswith("lora_te2_"): - lora_te2_length = tmp_length - - if lora_te1_length is not None and lora_te2_length is not None: - lora_token_vector_length = lora_te1_length + lora_te2_length - - if lora_token_vector_length is not None: - break - - return lora_token_vector_length diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index 98cc5054c73..88356d04686 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -1,5 +1,4 @@ """Re-export frequently-used symbols from the Model Manager backend.""" - from .config import ( AnyModel, AnyModelConfig, @@ -33,3 +32,42 @@ "SchedulerPredictionType", "SubModelType", ] + +########## to help populate the openapi_schema with format enums for each config ########### +# This code is no longer necessary? +# leave it here just in case +# +# import inspect +# from enum import Enum +# from typing import Any, Iterable, Dict, get_args, Set +# def _expand(something: Any) -> Iterable[type]: +# if isinstance(something, type): +# yield something +# else: +# for x in get_args(something): +# for y in _expand(x): +# yield y + +# def _find_format(cls: type) -> Iterable[Enum]: +# if hasattr(inspect, "get_annotations"): +# fields = inspect.get_annotations(cls) +# else: +# fields = cls.__annotations__ +# if "format" in fields: +# for x in get_args(fields["format"]): +# yield x +# for parent_class in cls.__bases__: +# for x in _find_format(parent_class): +# yield x +# return None + +# def get_model_config_formats() -> Dict[str, Set[Enum]]: +# result: Dict[str, Set[Enum]] = {} +# for model_config in _expand(AnyModelConfig): +# for field in _find_format(model_config): +# if field is None: +# continue +# if not result.get(model_config.__qualname__): +# result[model_config.__qualname__] = set() +# result[model_config.__qualname__].add(field) +# return result diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index a3a840b6259..a0421017db9 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -6,12 +6,22 @@ from pathlib import Path from .convert_cache.convert_cache_default import ModelConvertCache -from .load_base import AnyModelLoader, LoadedModel +from .load_base import LoadedModel, ModelLoaderBase +from .load_default import ModelLoader from .model_cache.model_cache_default import ModelCache +from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase # This registers the subclasses that implement loaders of specific model types loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"] for module in loaders: import_module(f"{__package__}.model_loaders.{module}") -__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"] +__all__ = [ + "LoadedModel", + "ModelCache", + "ModelConvertCache", + "ModelLoaderBase", + "ModelLoader", + 
"ModelLoaderRegistryBase", + "ModelLoaderRegistry", +] diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 4c5e899aa3b..b8ce56eb16d 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -1,37 +1,22 @@ # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team """ Base class for model loading in InvokeAI. - -Use like this: - - loader = AnyModelLoader(...) - loaded_model = loader.get_model('019ab39adfa1840455') - with loaded_model as model: # context manager moves model into VRAM - # do something with loaded_model """ -import hashlib from abc import ABC, abstractmethod from dataclasses import dataclass from logging import Logger from pathlib import Path -from typing import Any, Callable, Dict, Optional, Tuple, Type +from typing import Any, Optional from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_manager.config import ( AnyModel, AnyModelConfig, - BaseModelType, - ModelConfigBase, - ModelFormat, - ModelType, SubModelType, - VaeCheckpointConfig, - VaeDiffusersConfig, ) from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase -from invokeai.backend.util.logging import InvokeAILogger @dataclass @@ -56,6 +41,14 @@ def model(self) -> AnyModel: return self._locker.model +# TODO(MM2): +# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't +# know about. I think the problem may be related to this class being an ABC. +# +# For example, GenericDiffusersLoader defines `get_hf_load_class()`, and StableDiffusionDiffusersModel attempts to +# call it. However, the method is not defined in the ABC, so it is not guaranteed to be implemented. + + class ModelLoaderBase(ABC): """Abstract base class for loading models into RAM/VRAM.""" @@ -71,7 +64,7 @@ def __init__( pass @abstractmethod - def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Return a model given its confguration. @@ -90,106 +83,3 @@ def get_size_fs( ) -> int: """Return size in bytes of the model, calculated before loading.""" pass - - -# TO DO: Better name? 
-class AnyModelLoader: - """This class manages the model loaders and invokes the correct one to load a model of given base and type.""" - - # this tracks the loader subclasses - _registry: Dict[str, Type[ModelLoaderBase]] = {} - _logger: Logger = InvokeAILogger.get_logger() - - def __init__( - self, - app_config: InvokeAIAppConfig, - logger: Logger, - ram_cache: ModelCacheBase[AnyModel], - convert_cache: ModelConvertCacheBase, - ): - """Initialize AnyModelLoader with its dependencies.""" - self._app_config = app_config - self._logger = logger - self._ram_cache = ram_cache - self._convert_cache = convert_cache - - @property - def ram_cache(self) -> ModelCacheBase[AnyModel]: - """Return the RAM cache associated used by the loaders.""" - return self._ram_cache - - @property - def convert_cache(self) -> ModelConvertCacheBase: - """Return the convert cache associated used by the loaders.""" - return self._convert_cache - - def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel: - """ - Return a model given its configuration. - - :param key: model key, as known to the config backend - :param submodel_type: an ModelType enum indicating the portion of - the model to retrieve (e.g. ModelType.Vae) - """ - implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type) - return implementation( - app_config=self._app_config, - logger=self._logger, - ram_cache=self._ram_cache, - convert_cache=self._convert_cache, - ).load_model(model_config, submodel_type) - - @staticmethod - def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str: - return "-".join([base.value, type.value, format.value]) - - @classmethod - def get_implementation( - cls, config: ModelConfigBase, submodel_type: Optional[SubModelType] - ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: - """Get subclass of ModelLoaderBase registered to handle base and type.""" - # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned - conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type) - - key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type - key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any - implementation = cls._registry.get(key1) or cls._registry.get(key2) - if not implementation: - raise NotImplementedError( - f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}" - ) - return implementation, conf2, submodel_type - - @classmethod - def _handle_subtype_overrides( - cls, config: ModelConfigBase, submodel_type: Optional[SubModelType] - ) -> Tuple[ModelConfigBase, Optional[SubModelType]]: - if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None: - model_path = Path(config.vae) - config_class = ( - VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig - ) - hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest() - new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash) - submodel_type = None - else: - new_conf = config - return new_conf, submodel_type - - @classmethod - def register( - cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any - ) -> Callable[[Type[ModelLoaderBase]], 
Type[ModelLoaderBase]]: - """Define a decorator which registers the subclass of loader.""" - - def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: - cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}") - key = cls._to_registry_key(base, type, format) - if key in cls._registry: - raise Exception( - f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}" - ) - cls._registry[key] = subclass - return subclass - - return decorator diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 79c9311de1d..642cffaf4be 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -1,13 +1,9 @@ # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team """Default implementation of model loading in InvokeAI.""" -import sys from logging import Logger from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -from diffusers import ModelMixin -from diffusers.configuration_utils import ConfigMixin +from typing import Optional, Tuple from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_manager import ( @@ -25,17 +21,6 @@ from invokeai.backend.util.devices import choose_torch_device, torch_dtype -class ConfigLoader(ConfigMixin): - """Subclass of ConfigMixin for loading diffusers configuration files.""" - - @classmethod - def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: - """Load a diffusrs ConfigMixin configuration.""" - cls.config_name = kwargs.pop("config_name") - # Diffusers doesn't provide typing info - return super().load_config(*args, **kwargs) # type: ignore - - # TO DO: The loader is not thread safe! class ModelLoader(ModelLoaderBase): """Default implementation of ModelLoaderBase.""" @@ -137,43 +122,6 @@ def get_size_fs( variant=config.repo_variant if hasattr(config, "repo_variant") else None, ) - def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]: - return ConfigLoader.load_config(model_path, config_name=config_name) - - # TO DO: Add exception handling - def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type - if module in ["diffusers", "transformers"]: - res_type = sys.modules[module] - else: - res_type = sys.modules["diffusers"].pipelines - result: ModelMixin = getattr(res_type, class_name) - return result - - # TO DO: Add exception handling - def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: - if submodel_type: - try: - config = self._load_diffusers_config(model_path, config_name="model_index.json") - module, class_name = config[submodel_type.value] - return self._hf_definition_to_type(module=module, class_name=class_name) - except KeyError as e: - raise InvalidModelConfigException( - f'The "{submodel_type}" submodel is not available for this model.' 
- ) from e - else: - try: - config = self._load_diffusers_config(model_path, config_name="config.json") - class_name = config.get("_class_name", None) - if class_name: - return self._hf_definition_to_type(module="diffusers", class_name=class_name) - if config.get("model_type", None) == "clip_vision_model": - class_name = config.get("architectures")[0] - return self._hf_definition_to_type(module="transformers", class_name=class_name) - if not class_name: - raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json") - except KeyError as e: - raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e - # This needs to be implemented in subclasses that handle checkpoints def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: raise NotImplementedError diff --git a/invokeai/backend/model_manager/load/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py index 209d7166f36..195e39361b4 100644 --- a/invokeai/backend/model_manager/load/memory_snapshot.py +++ b/invokeai/backend/model_manager/load/memory_snapshot.py @@ -55,7 +55,7 @@ def capture(cls, run_garbage_collector: bool = True) -> Self: vram = None try: - malloc_info = LibcUtil().mallinfo2() # type: ignore + malloc_info = LibcUtil().mallinfo2() except (OSError, AttributeError): # OSError: This is expected in environments that do not have the 'libc.so.6' shared library. # AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33) diff --git a/invokeai/backend/model_manager/load/model_loader_registry.py b/invokeai/backend/model_manager/load/model_loader_registry.py new file mode 100644 index 00000000000..ce1110e749b --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loader_registry.py @@ -0,0 +1,122 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team +""" +This module implements a system in which model loaders register the +type, base and format of models that they know how to load. + +Use like this: + + cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore + loaded_model = cls( + app_config=app_config, + logger=logger, + ram_cache=ram_cache, + convert_cache=convert_cache + ).load_model(model_config, submodel_type) + +""" +import hashlib +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Callable, Dict, Optional, Tuple, Type + +from ..config import ( + AnyModelConfig, + BaseModelType, + ModelConfigBase, + ModelFormat, + ModelType, + SubModelType, + VaeCheckpointConfig, + VaeDiffusersConfig, +) +from . import ModelLoaderBase + + +class ModelLoaderRegistryBase(ABC): + """This class allows model loaders to register their type, base and format.""" + + @classmethod + @abstractmethod + def register( + cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any + ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]: + """Define a decorator which registers the subclass of loader.""" + + @classmethod + @abstractmethod + def get_implementation( + cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] + ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: + """ + Get subclass of ModelLoaderBase registered to handle base and type. 
+ + Parameters: + :param config: Model configuration record, as returned by ModelRecordService + :param submodel_type: Submodel to fetch (main models only) + :return: tuple(loader_class, model_config, submodel_type) + + Note that the returned model config may be different from one what passed + in, in the event that a submodel type is provided. + """ + + +class ModelLoaderRegistry: + """ + This class allows model loaders to register their type, base and format. + """ + + _registry: Dict[str, Type[ModelLoaderBase]] = {} + + @classmethod + def register( + cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any + ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]: + """Define a decorator which registers the subclass of loader.""" + + def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: + key = cls._to_registry_key(base, type, format) + if key in cls._registry: + raise Exception( + f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}" + ) + cls._registry[key] = subclass + return subclass + + return decorator + + @classmethod + def get_implementation( + cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] + ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: + """Get subclass of ModelLoaderBase registered to handle base and type.""" + # We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned + conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type) + + key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type + key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any + implementation = cls._registry.get(key1) or cls._registry.get(key2) + if not implementation: + raise NotImplementedError( + f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}" + ) + return implementation, conf2, submodel_type + + @classmethod + def _handle_subtype_overrides( + cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] + ) -> Tuple[ModelConfigBase, Optional[SubModelType]]: + if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None: + model_path = Path(config.vae) + config_class = ( + VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig + ) + hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest() + new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash) + submodel_type = None + else: + new_conf = config + return new_conf, submodel_type + + @staticmethod + def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str: + return "-".join([base.value, type.value, format.value]) diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py index d446d079336..43393f5a847 100644 --- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -13,13 +13,13 @@ ModelType, ) from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers -from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from .. 
import ModelLoaderRegistry from .generic_diffusers import GenericDiffusersLoader -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) class ControlnetLoader(GenericDiffusersLoader): """Class to load ControlNet models.""" diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index 114e317f3c6..9a9b25aec53 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -1,24 +1,27 @@ # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team """Class for simple diffusers model loading in InvokeAI.""" +import sys from pathlib import Path -from typing import Optional +from typing import Any, Dict, Optional + +from diffusers import ConfigMixin, ModelMixin from invokeai.backend.model_manager import ( AnyModel, BaseModelType, + InvalidModelConfigException, ModelFormat, ModelRepoVariant, ModelType, SubModelType, ) -from ..load_base import AnyModelLoader -from ..load_default import ModelLoader +from .. import ModelLoader, ModelLoaderRegistry -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) class GenericDiffusersLoader(ModelLoader): """Class to load simple diffusers models.""" @@ -28,9 +31,60 @@ def _load_model( model_variant: Optional[ModelRepoVariant] = None, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - model_class = self._get_hf_load_class(model_path) + model_class = self.get_hf_load_class(model_path) if submodel_type is not None: raise Exception(f"There are no submodels in models of type {model_class}") variant = model_variant.value if model_variant else None result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore return result + + # TO DO: Add exception handling + def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: + """Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load.""" + if submodel_type: + try: + config = self._load_diffusers_config(model_path, config_name="model_index.json") + module, class_name = config[submodel_type.value] + result = self._hf_definition_to_type(module=module, class_name=class_name) + except KeyError as e: + raise InvalidModelConfigException( + f'The "{submodel_type}" submodel is not available for this model.' 
+ ) from e + else: + try: + config = self._load_diffusers_config(model_path, config_name="config.json") + class_name = config.get("_class_name", None) + if class_name: + result = self._hf_definition_to_type(module="diffusers", class_name=class_name) + if config.get("model_type", None) == "clip_vision_model": + class_name = config.get("architectures") + assert class_name is not None + result = self._hf_definition_to_type(module="transformers", class_name=class_name[0]) + if not class_name: + raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json") + except KeyError as e: + raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e + return result + + # TO DO: Add exception handling + def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type + if module in ["diffusers", "transformers"]: + res_type = sys.modules[module] + else: + res_type = sys.modules["diffusers"].pipelines + result: ModelMixin = getattr(res_type, class_name) + return result + + def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]: + return ConfigLoader.load_config(model_path, config_name=config_name) + + +class ConfigLoader(ConfigMixin): + """Subclass of ConfigMixin for loading diffusers configuration files.""" + + @classmethod + def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: + """Load a diffusrs ConfigMixin configuration.""" + cls.config_name = kwargs.pop("config_name") + # Diffusers doesn't provide typing info + return super().load_config(*args, **kwargs) # type: ignore diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py index 27ced41c1e9..7d25e9d218c 100644 --- a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py +++ b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py @@ -15,11 +15,10 @@ ModelType, SubModelType, ) -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI) class IPAdapterInvokeAILoader(ModelLoader): """Class to load IP Adapter diffusers models.""" diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index 6ff2dcc9182..fe804ef5654 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -18,13 +18,13 @@ SubModelType, ) from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase +from .. 
import ModelLoader, ModelLoaderRegistry -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris) + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris) class LoraLoader(ModelLoader): """Class to load LoRA models.""" diff --git a/invokeai/backend/model_manager/load/model_loaders/onnx.py b/invokeai/backend/model_manager/load/model_loaders/onnx.py index 935a6b7c953..38f0274acc6 100644 --- a/invokeai/backend/model_manager/load/model_loaders/onnx.py +++ b/invokeai/backend/model_manager/load/model_loaders/onnx.py @@ -13,13 +13,14 @@ ModelType, SubModelType, ) -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader +from .. import ModelLoaderRegistry +from .generic_diffusers import GenericDiffusersLoader -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive) -class OnnyxDiffusersModel(ModelLoader): + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive) +class OnnyxDiffusersModel(GenericDiffusersLoader): """Class to load onnx models.""" def _load_model( @@ -30,7 +31,7 @@ def _load_model( ) -> AnyModel: if not submodel_type is not None: raise Exception("A submodel type must be provided when loading onnx pipelines.") - load_class = self._get_hf_load_class(model_path, submodel_type) + load_class = self.get_hf_load_class(model_path, submodel_type) variant = model_variant.value if model_variant else None model_path = model_path / submodel_type.value result: AnyModel = load_class.from_pretrained( diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index 23b4e1fccd6..5884f84e8da 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -19,13 +19,14 @@ ) from invokeai.backend.model_manager.config import MainCheckpointConfig from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader +from .. 
import ModelLoaderRegistry +from .generic_diffusers import GenericDiffusersLoader -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint) -class StableDiffusionDiffusersModel(ModelLoader): + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint) +class StableDiffusionDiffusersModel(GenericDiffusersLoader): """Class to load main models.""" model_base_to_model_type = { @@ -43,7 +44,7 @@ def _load_model( ) -> AnyModel: if not submodel_type is not None: raise Exception("A submodel type must be provided when loading main pipelines.") - load_class = self._get_hf_load_class(model_path, submodel_type) + load_class = self.get_hf_load_class(model_path, submodel_type) variant = model_variant.value if model_variant else None model_path = model_path / submodel_type.value result: AnyModel = load_class.from_pretrained( diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py index 94767479609..094d4d7c5c3 100644 --- a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py +++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py @@ -5,7 +5,6 @@ from pathlib import Path from typing import Optional, Tuple -from invokeai.backend.textual_inversion import TextualInversionModelRaw from invokeai.backend.model_manager import ( AnyModel, AnyModelConfig, @@ -15,12 +14,15 @@ ModelType, SubModelType, ) -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.textual_inversion import TextualInversionModelRaw +from .. import ModelLoader, ModelLoaderRegistry -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile) -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder) + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile) +@ModelLoaderRegistry.register( + base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder +) class TextualInversionLoader(ModelLoader): """Class to load TI models.""" diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index 3983ea75950..7ade1494eb1 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -14,14 +14,14 @@ ModelType, ) from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers -from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from .. 
import ModelLoaderRegistry from .generic_diffusers import GenericDiffusersLoader -@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) -@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) -@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) class VaeLoader(GenericDiffusersLoader): """Class to load VAE models.""" diff --git a/invokeai/backend/model_manager/load/optimizations.py b/invokeai/backend/model_manager/load/optimizations.py index a46d262175f..030fcfa639a 100644 --- a/invokeai/backend/model_manager/load/optimizations.py +++ b/invokeai/backend/model_manager/load/optimizations.py @@ -1,16 +1,16 @@ from contextlib import contextmanager +from typing import Any, Generator import torch -def _no_op(*args, **kwargs): +def _no_op(*args: Any, **kwargs: Any) -> None: pass @contextmanager -def skip_torch_weight_init(): - """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) - to skip weight initialization. +def skip_torch_weight_init() -> Generator[None, None, None]: + """Monkey patch several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) to skip weight initialization. By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is @@ -18,13 +18,14 @@ def skip_torch_weight_init(): monkey-patches common torch layers to skip the weight initialization step. 
""" torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding] - saved_functions = [m.reset_parameters for m in torch_modules] + saved_functions = [hasattr(m, "reset_parameters") and m.reset_parameters for m in torch_modules] try: for torch_module in torch_modules: + assert hasattr(torch_module, "reset_parameters") torch_module.reset_parameters = _no_op - yield None finally: for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True): + assert hasattr(torch_module, "reset_parameters") torch_module.reset_parameters = saved_function diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index 108f1f0e6f7..7063cb907d2 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -13,7 +13,7 @@ import torch from diffusers import AutoPipelineForText2Image -from diffusers import logging as dlogging +from diffusers.utils import logging as dlogging from invokeai.app.services.model_install import ModelInstallServiceBase from invokeai.backend.util.devices import choose_torch_device, torch_dtype @@ -76,7 +76,7 @@ def merge_diffusion_models( custom_pipeline="checkpoint_merger", torch_dtype=dtype, variant=variant, - ) + ) # type: ignore merged_pipe = pipe.merge( pretrained_model_name_or_path_list=model_paths, alpha=alpha, diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 5c3afcdc960..6e410d82220 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -54,8 +54,8 @@ class LicenseRestrictions(BaseModel): AllowDifferentLicense: bool = Field( description="if true, derivatives of this model be redistributed under a different license", default=False ) - AllowCommercialUse: CommercialUsage = Field( - description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set + AllowCommercialUse: Optional[CommercialUsage] = Field( + description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default=None ) @@ -139,7 +139,10 @@ def credit_required(self) -> bool: @property def allow_commercial_use(self) -> bool: """Return True if commercial use is allowed.""" - return self.restrictions.AllowCommercialUse != CommercialUsage("None") + if self.restrictions.AllowCommercialUse is None: + return False + else: + return self.restrictions.AllowCommercialUse != CommercialUsage("None") @property def allow_derivatives(self) -> bool: diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index d511ffa875f..7de4289466d 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -8,7 +8,6 @@ from picklescan.scanner import scan_file_path import invokeai.backend.util.logging as logger -from .util.model_util import lora_token_vector_length, read_checkpoint_meta from invokeai.backend.util.util import SilenceWarnings from .config import ( @@ -23,6 +22,7 @@ SchedulerPredictionType, ) from .hash import FastModelHash +from .util.model_util import lora_token_vector_length, read_checkpoint_meta CkptType = Dict[str, Any] @@ -53,6 +53,7 @@ }, } + class ProbeBase(object): """Base class for probes.""" diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index 0ead22b743f..f7ef2e049d4 100644 --- a/invokeai/backend/model_manager/search.py +++ 
b/invokeai/backend/model_manager/search.py @@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase): # returns all models that have 'anime' in the path """ - models_found: Optional[Set[Path]] = Field(default=None) - scanned_dirs: Optional[Set[Path]] = Field(default=None) - pruned_paths: Optional[Set[Path]] = Field(default=None) + models_found: Set[Path] = Field(default_factory=set) + scanned_dirs: Set[Path] = Field(default_factory=set) + pruned_paths: Set[Path] = Field(default_factory=set) def search_started(self) -> None: self.models_found = set() diff --git a/invokeai/backend/model_manager/util/libc_util.py b/invokeai/backend/model_manager/util/libc_util.py index 1fbcae0a93c..ef1ac2f8a4b 100644 --- a/invokeai/backend/model_manager/util/libc_util.py +++ b/invokeai/backend/model_manager/util/libc_util.py @@ -35,7 +35,7 @@ class Struct_mallinfo2(ctypes.Structure): ("keepcost", ctypes.c_size_t), ] - def __str__(self): + def __str__(self) -> str: s = "" s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n" s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n" @@ -62,7 +62,7 @@ class LibcUtil: TODO: Improve cross-OS compatibility of this class. """ - def __init__(self): + def __init__(self) -> None: self._libc = ctypes.cdll.LoadLibrary("libc.so.6") def mallinfo2(self) -> Struct_mallinfo2: @@ -72,4 +72,5 @@ def mallinfo2(self) -> Struct_mallinfo2: """ mallinfo2 = self._libc.mallinfo2 mallinfo2.restype = Struct_mallinfo2 - return mallinfo2() + result: Struct_mallinfo2 = mallinfo2() + return result diff --git a/invokeai/backend/model_manager/util/model_util.py b/invokeai/backend/model_manager/util/model_util.py index 6847a40878c..2e448520e56 100644 --- a/invokeai/backend/model_manager/util/model_util.py +++ b/invokeai/backend/model_manager/util/model_util.py @@ -1,12 +1,15 @@ """Utilities for parsing model files, used mostly by probe.py""" import json -import torch -from typing import Union from pathlib import Path +from typing import Dict, Optional, Union + +import safetensors +import torch from picklescan.scanner import scan_file_path -def _fast_safetensors_reader(path: str): + +def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]: checkpoint = {} device = torch.device("meta") with open(path, "rb") as f: @@ -37,10 +40,12 @@ def _fast_safetensors_reader(path: str): return checkpoint -def read_checkpoint_meta(path: Union[str, Path], scan: bool = False): + +def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]: if str(path).endswith(".safetensors"): try: - checkpoint = _fast_safetensors_reader(path) + path_str = path.as_posix() if isinstance(path, Path) else path + checkpoint = _fast_safetensors_reader(path_str) except Exception: # TODO: create issue for support "meta"? checkpoint = safetensors.torch.load_file(path, device="cpu") @@ -52,14 +57,15 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False): checkpoint = torch.load(path, map_location=torch.device("meta")) return checkpoint -def lora_token_vector_length(checkpoint: dict) -> int: + +def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[int]: """ Given a checkpoint in memory, return the lora token vector length :param checkpoint: The checkpoint """ - def _get_shape_1(key: str, tensor, checkpoint) -> int: + def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Tensor]) -> Optional[int]: lora_token_vector_length = None if "." 
not in key: diff --git a/invokeai/backend/onnx/onnx_runtime.py b/invokeai/backend/onnx/onnx_runtime.py index 9b2096abdf0..8916865dd52 100644 --- a/invokeai/backend/onnx/onnx_runtime.py +++ b/invokeai/backend/onnx/onnx_runtime.py @@ -8,6 +8,7 @@ import onnx from onnx import numpy_helper from onnxruntime import InferenceSession, SessionOptions, get_available_providers + from ..raw_model import RawModel ONNX_WEIGHTS_NAME = "model.onnx" @@ -15,7 +16,7 @@ # NOTE FROM LS: This was copied from Stalker's original implementation. # I have not yet gone through and fixed all the type hints -class IAIOnnxRuntimeModel: +class IAIOnnxRuntimeModel(RawModel): class _tensor_access: def __init__(self, model): # type: ignore self.model = model diff --git a/invokeai/backend/raw_model.py b/invokeai/backend/raw_model.py index 2e224d538b3..d0dc50c4560 100644 --- a/invokeai/backend/raw_model.py +++ b/invokeai/backend/raw_model.py @@ -10,5 +10,6 @@ that adds additional methods and attributes. """ + class RawModel: """Base class for 'Raw' model wrappers.""" diff --git a/invokeai/backend/stable_diffusion/seamless.py b/invokeai/backend/stable_diffusion/seamless.py index bfdf9e0c536..fb9112b56dc 100644 --- a/invokeai/backend/stable_diffusion/seamless.py +++ b/invokeai/backend/stable_diffusion/seamless.py @@ -1,10 +1,11 @@ from __future__ import annotations from contextlib import contextmanager -from typing import List, Union +from typing import Callable, List, Union import torch.nn as nn -from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel def _conv_forward_asymmetric(self, input, weight, bias): @@ -26,70 +27,51 @@ def _conv_forward_asymmetric(self, input, weight, bias): @contextmanager def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): + # Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor + to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = [] try: - to_restore = [] - + # Hard coded to skip down block layers, allowing for seamless tiling at the expense of prompt adherence + skipped_layers = 1 for m_name, m in model.named_modules(): - if isinstance(model, UNet2DConditionModel): - if ".attentions." in m_name: - continue + if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + continue - if ".resnets." in m_name: - if ".conv2" in m_name: - continue - if ".conv_shortcut" in m_name: - continue - - """ - if isinstance(model, UNet2DConditionModel): - if False and ".upsamplers." in m_name: - continue + if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name: + # down_blocks.1.resnets.1.conv1 + _, block_num, _, resnet_num, submodule_name = m_name.split(".") + block_num = int(block_num) + resnet_num = int(resnet_num) - if False and ".downsamplers." in m_name: + if block_num >= len(model.down_blocks) - skipped_layers: continue - if True and ".resnets." in m_name: - if True and ".conv1" in m_name: - if False and "down_blocks" in m_name: - continue - if False and "mid_block" in m_name: - continue - if False and "up_blocks" in m_name: - continue - - if True and ".conv2" in m_name: - continue - - if True and ".conv_shortcut" in m_name: - continue - - if True and ".attentions." 
in m_name: + # Skip the second resnet (could be configurable) + if resnet_num > 0: continue - if False and m_name in ["conv_in", "conv_out"]: + # Skip Conv2d layers (could be configurable) + if submodule_name == "conv2": continue - """ - - if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): - m.asymmetric_padding_mode = {} - m.asymmetric_padding = {} - m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" - m.asymmetric_padding["x"] = ( - m._reversed_padding_repeated_twice[0], - m._reversed_padding_repeated_twice[1], - 0, - 0, - ) - m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" - m.asymmetric_padding["y"] = ( - 0, - 0, - m._reversed_padding_repeated_twice[2], - m._reversed_padding_repeated_twice[3], - ) - to_restore.append((m, m._conv_forward)) - m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) + m.asymmetric_padding_mode = {} + m.asymmetric_padding = {} + m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" + m.asymmetric_padding["x"] = ( + m._reversed_padding_repeated_twice[0], + m._reversed_padding_repeated_twice[1], + 0, + 0, + ) + m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" + m.asymmetric_padding["y"] = ( + 0, + 0, + m._reversed_padding_repeated_twice[2], + m._reversed_padding_repeated_twice[3], + ) + + to_restore.append((m, m._conv_forward)) + m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) yield diff --git a/invokeai/backend/textual_inversion.py b/invokeai/backend/textual_inversion.py index 9a4fa0b5402..f7390979bbc 100644 --- a/invokeai/backend/textual_inversion.py +++ b/invokeai/backend/textual_inversion.py @@ -8,8 +8,10 @@ from safetensors.torch import load_file from transformers import CLIPTokenizer from typing_extensions import Self + from .raw_model import RawModel + class TextualInversionModelRaw(RawModel): embedding: torch.Tensor # [n, 768]|[n, 1280] embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models diff --git a/invokeai/backend/util/test_utils.py b/invokeai/backend/util/test_utils.py index a3def182c8c..0d76c4633cf 100644 --- a/invokeai/backend/util/test_utils.py +++ b/invokeai/backend/util/test_utils.py @@ -42,7 +42,7 @@ def install_and_load_model( # If the requested model is already installed, return its LoadedModel with contextlib.suppress(UnknownModelException): # TODO: Replace with wrapper call - loaded_model: LoadedModel = model_manager.load.load_model_by_attr( + loaded_model: LoadedModel = model_manager.load_model_by_attr( model_name=model_name, base_model=base_model, model_type=model_type ) return loaded_model @@ -53,7 +53,7 @@ def install_and_load_model( assert job.complete try: - loaded_model = model_manager.load.load_model_by_config(job.config_out) + loaded_model = model_manager.load_model_by_config(job.config_out) return loaded_model except UnknownModelException as e: raise Exception( diff --git a/tests/backend/model_manager/model_loading/test_model_load.py b/tests/backend/model_manager/model_loading/test_model_load.py index 38d9b8afb8c..c1fde504eae 100644 --- a/tests/backend/model_manager/model_loading/test_model_load.py +++ b/tests/backend/model_manager/model_loading/test_model_load.py @@ -4,18 +4,27 @@ from pathlib import Path -from invokeai.app.services.model_install import ModelInstallServiceBase -from invokeai.app.services.model_load import ModelLoadServiceBase +from invokeai.app.services.model_manager import ModelManagerServiceBase from 
invokeai.backend.textual_inversion import TextualInversionModelRaw from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 -def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase, embedding_file: Path): - store = mm2_installer.record_store + +def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Path): + store = mm2_model_manager.store matches = store.search_by_attr(model_name="test_embedding") assert len(matches) == 0 - key = mm2_installer.register_path(embedding_file) - loaded_model = mm2_loader.load_model_by_config(store.get_model(key)) + key = mm2_model_manager.install.register_path(embedding_file) + loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key)) assert loaded_model is not None assert loaded_model.config.key == key with loaded_model as model: assert isinstance(model, TextualInversionModelRaw) + loaded_model_2 = mm2_model_manager.load_model_by_key(key) + assert loaded_model.config.key == loaded_model_2.config.key + + loaded_model_3 = mm2_model_manager.load_model_by_attr( + model_name=loaded_model.config.name, + model_type=loaded_model.config.type, + base_model=loaded_model.config.base, + ) + assert loaded_model.config.key == loaded_model_3.config.key diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 5f7f44c0188..df54e2f9267 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -6,17 +6,17 @@ from typing import Any, Dict, List import pytest -from pytest import FixtureRequest from pydantic import BaseModel +from pytest import FixtureRequest from requests.sessions import Session from requests_testadapter import TestAdapter, TestSession from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadQueueServiceBase, DownloadQueueService +from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase from invokeai.app.services.events.events_base import EventServiceBase -from invokeai.app.services.model_manager import ModelManagerServiceBase, ModelManagerService -from invokeai.app.services.model_load import ModelLoadServiceBase, ModelLoadService from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase +from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase +from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL from invokeai.backend.model_manager.config import ( @@ -95,9 +95,7 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig: @pytest.fixture -def mm2_download_queue(mm2_session: Session, - request: FixtureRequest - ) -> DownloadQueueServiceBase: +def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase: download_queue = DownloadQueueService(requests_session=mm2_session) download_queue.start() @@ -107,30 +105,34 @@ def stop_queue() -> None: request.addfinalizer(stop_queue) return download_queue + @pytest.fixture def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase: return mm2_record_store.metadata_store + @pytest.fixture def mm2_loader(mm2_app_config: InvokeAIAppConfig, 
mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase: ram_cache = ModelCache( logger=InvokeAILogger.get_logger(), max_cache_size=mm2_app_config.ram_cache_size, - max_vram_cache_size=mm2_app_config.vram_cache_size + max_vram_cache_size=mm2_app_config.vram_cache_size, ) convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path) - return ModelLoadService(app_config=mm2_app_config, - record_store=mm2_record_store, - ram_cache=ram_cache, - convert_cache=convert_cache, - ) + return ModelLoadService( + app_config=mm2_app_config, + ram_cache=ram_cache, + convert_cache=convert_cache, + ) + @pytest.fixture -def mm2_installer(mm2_app_config: InvokeAIAppConfig, - mm2_download_queue: DownloadQueueServiceBase, - mm2_session: Session, - request: FixtureRequest, - ) -> ModelInstallServiceBase: +def mm2_installer( + mm2_app_config: InvokeAIAppConfig, + mm2_download_queue: DownloadQueueServiceBase, + mm2_session: Session, + request: FixtureRequest, +) -> ModelInstallServiceBase: logger = InvokeAILogger.get_logger() db = create_mock_sqlite_database(mm2_app_config, logger) events = DummyEventService() @@ -213,15 +215,13 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas store.add_model("test_config_5", raw5) return store + @pytest.fixture -def mm2_model_manager(mm2_record_store: ModelRecordServiceBase, - mm2_installer: ModelInstallServiceBase, - mm2_loader: ModelLoadServiceBase) -> ModelManagerServiceBase: - return ModelManagerService( - store=mm2_record_store, - install=mm2_installer, - load=mm2_loader - ) +def mm2_model_manager( + mm2_record_store: ModelRecordServiceBase, mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase +) -> ModelManagerServiceBase: + return ModelManagerService(store=mm2_record_store, install=mm2_installer, load=mm2_loader) + @pytest.fixture def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: @@ -306,5 +306,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: ), ) return sess - - diff --git a/tests/backend/model_manager/test_lora.py b/tests/backend/model_manager/test_lora.py index e124bb68efc..114a4cfdcff 100644 --- a/tests/backend/model_manager/test_lora.py +++ b/tests/backend/model_manager/test_lora.py @@ -5,8 +5,8 @@ import pytest import torch -from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.lora import LoRALayer, LoRAModelRaw +from invokeai.backend.model_patcher import ModelPatcher @pytest.mark.parametrize( diff --git a/tests/backend/model_manager/test_memory_snapshot.py b/tests/backend/model_manager/test_memory_snapshot.py index 87ec8c34ee0..d31ae79b668 100644 --- a/tests/backend/model_manager/test_memory_snapshot.py +++ b/tests/backend/model_manager/test_memory_snapshot.py @@ -1,7 +1,8 @@ import pytest -from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2 + def test_memory_snapshot_capture(): """Smoke test of MemorySnapshot.capture().""" From 5231bf4e9405f79205c99e65ef6f1812efb1b134 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Mon, 19 Feb 2024 00:56:27 +0530 Subject: [PATCH 109/340] fix: Alpha channel causing issue with DW Processor --- .../app/invocations/controlnet_image_processors.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git 
a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 8542134fff0..5b15981caab 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -40,12 +40,7 @@ from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector -from .baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, - invocation, - invocation_output, -) +from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"] CONTROLNET_RESIZE_VALUES = Literal[ @@ -593,9 +588,7 @@ def run_processor(self, image: Image.Image): depth_anything_detector = DepthAnythingDetector() depth_anything_detector.load_model(model_size=self.model_size) - if image.mode == "RGBA": - image = image.convert("RGB") - + image = image.convert("RGB") if image.mode != "RGB" else image processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload) return processed_image @@ -615,7 +608,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation): draw_hands: bool = InputField(default=False) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image): + image = image.convert("RGB") if image.mode != "RGB" else image dw_openpose = DWOpenposeDetector() processed_image = dw_openpose( image, From 1e3369dbf9992da6ae0ae764391ab39df5037dfe Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 18 Feb 2024 15:42:58 -0500 Subject: [PATCH 110/340] feat(nodes): format option for get_image method Also default CNet preprocessors to "RGB" --- .../app/invocations/controlnet_image_processors.py | 2 +- invokeai/app/services/shared/invocation_context.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 5b15981caab..1ef5352db6e 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -144,7 +144,7 @@ def run_processor(self, image: Image.Image) -> Image.Image: return image def invoke(self, context: InvocationContext) -> ImageOutput: - raw_image = context.images.get_pil(self.image.image_name) + raw_image = context.images.get_pil(self.image.image_name, "RGB") # image type should be PIL.PngImagePlugin.PngImageFile ? processed_image = self.run_processor(raw_image) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 1395427a97e..d217a865afc 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -194,13 +194,20 @@ def save( node_id=self._context_data.invocation.id, ) - def get_pil(self, image_name: str) -> Image: + def get_pil(self, image_name: str, format: str | None = None) -> Image: """ Gets an image as a PIL Image object. :param image_name: The name of the image to get. - """ - return self._services.images.get_pil_image(image_name) + :param format: The color format to convert the image to. If None, the original format is used. 
+ """ + image = self._services.images.get_pil_image(image_name) + if format and format != image.mode: + try: + image = image.convert(format) + except ValueError: + self._services.logger.warning(f"Could not convert image from {image.mode} to {format}. Using original format.") + return image def get_metadata(self, image_name: str) -> Optional[MetadataField]: """ From 4d9e8fecb7d2a4f2439aacda038da103a70a23b1 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 18 Feb 2024 16:29:31 -0500 Subject: [PATCH 111/340] fix(nodes): canny preprocessor uses RGBA again --- .../app/invocations/controlnet_image_processors.py | 10 +++++++++- .../app/invocations/custom_nodes/InvokeAI_DemoFusion | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) create mode 160000 invokeai/app/invocations/custom_nodes/InvokeAI_DemoFusion diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 1ef5352db6e..797ea62f7c3 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -143,8 +143,12 @@ def run_processor(self, image: Image.Image) -> Image.Image: # superclass just passes through image without processing return image + def load_image(self, context: InvocationContext) -> Image.Image: + # allows override for any special formatting specific to the preprocessor + return context.images.get_pil(self.image.image_name, "RGB") + def invoke(self, context: InvocationContext) -> ImageOutput: - raw_image = context.images.get_pil(self.image.image_name, "RGB") + raw_image = self.load_image(context) # image type should be PIL.PngImagePlugin.PngImageFile ? processed_image = self.run_processor(raw_image) @@ -181,6 +185,10 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation): default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)" ) + def load_image(self, context: InvocationContext) -> Image.Image: + # Keep alpha channel for Canny processing to detect edges of transparent areas + return context.images.get_pil(self.image.image_name, "RGBA") + def run_processor(self, image): canny_processor = CannyDetector() processed_image = canny_processor(image, self.low_threshold, self.high_threshold) diff --git a/invokeai/app/invocations/custom_nodes/InvokeAI_DemoFusion b/invokeai/app/invocations/custom_nodes/InvokeAI_DemoFusion new file mode 160000 index 00000000000..aae207914f0 --- /dev/null +++ b/invokeai/app/invocations/custom_nodes/InvokeAI_DemoFusion @@ -0,0 +1 @@ +Subproject commit aae207914f08f77324691ae984fae6dabb0b8976 From fce7e006dfca4cc443f7e42d8994dc2673e563e7 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 18 Feb 2024 16:48:05 -0500 Subject: [PATCH 112/340] fix: removed custom module --- invokeai/app/invocations/custom_nodes/InvokeAI_DemoFusion | 1 - 1 file changed, 1 deletion(-) delete mode 160000 invokeai/app/invocations/custom_nodes/InvokeAI_DemoFusion diff --git a/invokeai/app/invocations/custom_nodes/InvokeAI_DemoFusion b/invokeai/app/invocations/custom_nodes/InvokeAI_DemoFusion deleted file mode 160000 index aae207914f0..00000000000 --- a/invokeai/app/invocations/custom_nodes/InvokeAI_DemoFusion +++ /dev/null @@ -1 +0,0 @@ -Subproject commit aae207914f08f77324691ae984fae6dabb0b8976 From eae72894cce9839d87cb38fc477846ead17a77e8 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 18 Feb 2024 16:56:46 -0500 Subject: [PATCH 113/340] chore(invocations): use IMAGE_MODES constant literal --- invokeai/app/invocations/constants.py | 3 
+++ invokeai/app/invocations/image.py | 4 +--- invokeai/app/services/shared/invocation_context.py | 11 ++++++----- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/invokeai/app/invocations/constants.py b/invokeai/app/invocations/constants.py index 795e7a3b604..fca5a2ec7f1 100644 --- a/invokeai/app/invocations/constants.py +++ b/invokeai/app/invocations/constants.py @@ -12,3 +12,6 @@ SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] """A literal type representing the valid scheduler names.""" + +IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] +"""A literal type for PIL image modes supported by Invoke""" \ No newline at end of file diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index f5ad5515a68..1f3b5b7368c 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -16,6 +16,7 @@ WithMetadata, ) from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark @@ -263,9 +264,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput: return ImageOutput.build(image_dto) -IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] - - @invocation( "img_conv", title="Convert Image Mode", diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index d217a865afc..2383785ad4d 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -6,6 +6,7 @@ from torch import Tensor from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata +from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin @@ -194,19 +195,19 @@ def save( node_id=self._context_data.invocation.id, ) - def get_pil(self, image_name: str, format: str | None = None) -> Image: + def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image: """ Gets an image as a PIL Image object. :param image_name: The name of the image to get. - :param format: The color format to convert the image to. If None, the original format is used. + :param mode: The color mode to convert the image to. If None, the original mode is used. """ image = self._services.images.get_pil_image(image_name) - if format and format != image.mode: + if mode and mode != image.mode: try: - image = image.convert(format) + image = image.convert(mode) except ValueError: - self._services.logger.warning(f"Could not convert image from {image.mode} to {format}. Using original format.") + self._services.logger.warning(f"Could not convert image from {image.mode} to {mode}. 
Using original mode instead.") return image def get_metadata(self, image_name: str) -> Optional[MetadataField]: From 490d77280b307e3ce42484e5077a8fcc8837810a Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 18 Feb 2024 23:10:38 -0500 Subject: [PATCH 114/340] chore(invocations): remove redundant RGB conversions --- invokeai/app/invocations/controlnet_image_processors.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 797ea62f7c3..1e998e4b616 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -423,10 +423,6 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection") def run_processor(self, image): - # MediaPipeFaceDetector throws an error if image has alpha channel - # so convert to RGB if needed - if image.mode == "RGBA": - image = image.convert("RGB") mediapipe_face_processor = MediapipeFaceDetector() processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence) return processed_image @@ -595,8 +591,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation): def run_processor(self, image: Image.Image): depth_anything_detector = DepthAnythingDetector() depth_anything_detector.load_model(model_size=self.model_size) - - image = image.convert("RGB") if image.mode != "RGB" else image + processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload) return processed_image @@ -617,7 +612,6 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation): image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) def run_processor(self, image: Image.Image): - image = image.convert("RGB") if image.mode != "RGB" else image dw_openpose = DWOpenposeDetector() processed_image = dw_openpose( image, From 07af61b348c726aeeb989a73109ee6caeb5d3a70 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 18 Feb 2024 23:11:36 -0500 Subject: [PATCH 115/340] chore: ruff formatting --- invokeai/app/invocations/constants.py | 2 +- invokeai/app/invocations/controlnet_image_processors.py | 2 +- invokeai/app/invocations/image.py | 2 +- invokeai/app/services/shared/invocation_context.py | 6 ++++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/invokeai/app/invocations/constants.py b/invokeai/app/invocations/constants.py index fca5a2ec7f1..cebe0eb30fc 100644 --- a/invokeai/app/invocations/constants.py +++ b/invokeai/app/invocations/constants.py @@ -14,4 +14,4 @@ """A literal type representing the valid scheduler names.""" IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] -"""A literal type for PIL image modes supported by Invoke""" \ No newline at end of file +"""A literal type for PIL image modes supported by Invoke""" diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 1e998e4b616..8774f2fb275 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -591,7 +591,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation): def run_processor(self, image: Image.Image): depth_anything_detector = 
DepthAnythingDetector() depth_anything_detector.load_model(model_size=self.model_size) - + processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload) return processed_image diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 1f3b5b7368c..a0c41161c32 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,6 +7,7 @@ import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps +from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.invocations.fields import ( ColorField, FieldDescriptions, @@ -16,7 +17,6 @@ WithMetadata, ) from invokeai.app.invocations.primitives import ImageOutput -from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 2383785ad4d..43ecb2c543e 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -5,8 +5,8 @@ from PIL.Image import Image from torch import Tensor -from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata from invokeai.app.invocations.constants import IMAGE_MODES +from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin @@ -207,7 +207,9 @@ def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image: try: image = image.convert(mode) except ValueError: - self._services.logger.warning(f"Could not convert image from {image.mode} to {mode}. Using original mode instead.") + self._services.logger.warning( + f"Could not convert image from {image.mode} to {mode}. Using original mode instead." 
+ ) return image def get_metadata(self, image_name: str) -> Optional[MetadataField]: From 111cb58af77d640e2bcab9816175fc0e25fffc3d Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Sun, 18 Feb 2024 23:15:28 -0500 Subject: [PATCH 116/340] one more redundant RGB convert removed --- invokeai/app/invocations/controlnet_image_processors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 8774f2fb275..9eba3acdcaa 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -552,7 +552,6 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation): color_map_tile_size: int = InputField(default=64, ge=0, description=FieldDescriptions.tile_size) def run_processor(self, image: Image.Image): - image = image.convert("RGB") np_image = np.array(image, dtype=np.uint8) height, width = np_image.shape[:2] From 464ac9effbb32ed247e3709cc0f59ddd2c15c369 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 18 Feb 2024 18:02:58 -0500 Subject: [PATCH 117/340] remove errant def that was crashing invokeai-configure --- .../backend/install/invokeai_configure.py | 25 ++++++++----------- invokeai/frontend/install/model_install.py | 2 +- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index 4dfa2b070c0..ac3e583de3e 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -84,6 +84,8 @@ def get_literal_fields(field: str) -> Tuple[Any]: MAX_VRAM /= GB MAX_RAM = psutil.virtual_memory().total / GB +FORCE_FULL_PRECISION = False + INIT_FILE_PREAMBLE = """# InvokeAI initialization file # This is the InvokeAI initialization file, which contains command-line default values. # Feel free to edit. 
If anything goes wrong, you can re-initialize this file by deleting @@ -112,9 +114,6 @@ def postscript(errors: Set[str]) -> None: Web UI: invokeai-web -Command-line client: - invokeai - If you installed using an installation script, run: {config.root_path}/invoke.{"bat" if sys.platform == "win32" else "sh"} @@ -408,7 +407,7 @@ def create(self): begin_entry_at=3, max_height=2, relx=30, - max_width=56, + max_width=80, scroll_exit=True, ) self.add_widget_intelligent( @@ -664,7 +663,6 @@ def marshall_arguments(self) -> Namespace: generation_options = [GENERATION_OPT_CHOICES[x] for x in self.generation_options.value] for v in GENERATION_OPT_CHOICES: setattr(new_opts, v, v in generation_options) - return new_opts @@ -695,9 +693,6 @@ def onStart(self): cycle_widgets=False, ) - def new_opts(self) -> Namespace: - return self.options.marshall_arguments() - def default_ramcache() -> float: """Run a heuristic for the default RAM cache based on installed RAM.""" @@ -712,6 +707,7 @@ def default_ramcache() -> float: def default_startup_options(init_file: Path) -> InvokeAIAppConfig: opts = InvokeAIAppConfig.get_config() opts.ram = default_ramcache() + opts.precision = "float32" if FORCE_FULL_PRECISION else choose_precision(torch.device(choose_torch_device())) return opts @@ -760,7 +756,8 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False): def run_console_ui( program_opts: Namespace, initfile: Path, install_helper: InstallHelper ) -> Tuple[Optional[Namespace], Optional[InstallSelections]]: - invokeai_opts = default_startup_options(initfile) + first_time = not (config.root_path / "invokeai.yaml").exists() + invokeai_opts = default_startup_options(initfile) if first_time else config invokeai_opts.root = program_opts.root if not set_min_terminal_size(MIN_COLS, MIN_LINES): @@ -773,7 +770,7 @@ def run_console_ui( if editApp.user_cancelled: return (None, None) else: - return (editApp.new_opts(), editApp.install_selections) + return (editApp.new_opts, editApp.install_selections) # ------------------------------------- @@ -785,7 +782,7 @@ def write_opts(opts: InvokeAIAppConfig, init_file: Path) -> None: new_config = InvokeAIAppConfig.get_config() new_config.root = config.root - for key, value in opts.model_dump().items(): + for key, value in vars(opts).items(): if hasattr(new_config, key): setattr(new_config, key, value) @@ -869,7 +866,8 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool: # ------------------------------------- -def main(): +def main() -> None: + global FORCE_FULL_PRECISION # FIXME parser = argparse.ArgumentParser(description="InvokeAI model downloader") parser.add_argument( "--skip-sd-weights", @@ -921,17 +919,16 @@ def main(): help="path to root of install directory", ) opt = parser.parse_args() - invoke_args = [] if opt.root: invoke_args.extend(["--root", opt.root]) if opt.full_precision: invoke_args.extend(["--precision", "float32"]) config.parse_args(invoke_args) - config.precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) logger = InvokeAILogger().get_logger(config=config) errors = set() + FORCE_FULL_PRECISION = opt.full_precision # FIXME global try: # if we do a root migration/upgrade, then we are keeping previous diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index 20b630dfc62..3a4d66ae0a0 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -43,7 +43,7 @@ warnings.filterwarnings("ignore", 
category=UserWarning) # noqa: E402 config = InvokeAIAppConfig.get_config() logger = InvokeAILogger.get_logger("ModelInstallService") -logger.setLevel("WARNING") +# logger.setLevel("WARNING") # logger.setLevel('DEBUG') # build a table mapping all non-printable characters to None From 8c10f252cc9f0344478260794f505732a12eba54 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 17 Feb 2024 11:22:08 +1100 Subject: [PATCH 118/340] feat(nodes): JIT graph nodes validation We use pydantic to validate a union of valid invocations when instantiating a graph. Previously, we constructed the union while creating the `Graph` class. This introduces a dependency on the order of imports. For example, consider a setup where we have 3 invocations in the app: - Python executes the module where `FirstInvocation` is defined, registering `FirstInvocation`. - Python executes the module where `SecondInvocation` is defined, registering `SecondInvocation`. - Python executes the module where `Graph` is defined. A union of invocations is created and used to define the `Graph.nodes` field. The union contains `FirstInvocation` and `SecondInvocation`. - Python executes the module where `ThirdInvocation` is defined, registering `ThirdInvocation`. - A graph is created that includes `ThirdInvocation`. Pydantic validates the graph using the union, which does not know about `ThirdInvocation`, raising a `ValidationError` about an unknown invocation type. This scenario has been particularly problematic in tests, where we may create invocations dynamically. The test files have to be structured in such a way that the imports happen in the right order. It's a major pain. This PR refactors the validation of graph nodes to resolve this issue: - `BaseInvocation` gets a new method `get_typeadapter`. This builds a pydantic `TypeAdapter` for the union of all registered invocations, caching it after the first call. - `Graph.nodes`'s type is widened to `dict[str, BaseInvocation]`. This actually is a nice bonus, because we get better type hints whenever we reference `some_graph.nodes`. - A "plain" field validator takes over the validation logic for `Graph.nodes`. "Plain" validators totally override pydantic's own validation logic. The validator grabs the `TypeAdapter` from `BaseInvocation`, then validates each node with it. The validation is identical to the previous implementation - we get the same errors. `BaseInvocationOutput` gets the same treatment. 
--- invokeai/app/invocations/baseinvocation.py | 45 ++++++++++++++++------ invokeai/app/invocations/compel.py | 2 +- invokeai/app/services/shared/graph.py | 41 +++++++++++++------- 3 files changed, 63 insertions(+), 25 deletions(-)
diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py
index 3243714937f..5edae5342df 100644
--- a/invokeai/app/invocations/baseinvocation.py
+++ b/invokeai/app/invocations/baseinvocation.py
@@ -8,13 +8,26 @@ from abc import ABC, abstractmethod from enum import Enum from inspect import signature
-from types import UnionType
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Annotated,
+ Any,
+ Callable,
+ ClassVar,
+ Iterable,
+ Literal,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+) import semver
-from pydantic import BaseModel, ConfigDict, Field, create_model
+from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined
+from typing_extensions import TypeAliasType from invokeai.app.invocations.fields import ( FieldKind,
@@ -84,6 +97,7 @@ class BaseInvocationOutput(BaseModel): """ _output_classes: ClassVar[set[BaseInvocationOutput]] = set()
+ _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None @classmethod def register_output(cls, output: BaseInvocationOutput) -> None:
@@ -96,10 +110,14 @@ def get_outputs(cls) -> Iterable[BaseInvocationOutput]: return cls._output_classes @classmethod
- def get_outputs_union(cls) -> UnionType:
- """Gets a union of all invocation outputs."""
- outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
- return outputs_union # type: ignore [return-value]
+ def get_typeadapter(cls) -> TypeAdapter[Any]:
+ """Gets a pydantic TypeAdapter for the union of all invocation output types."""
+ if not cls._typeadapter:
+ InvocationOutputsUnion = TypeAliasType(
+ "InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
+ )
+ cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
+ return cls._typeadapter @classmethod def get_output_types(cls) -> Iterable[str]:
@@ -148,6 +166,7 @@ class BaseInvocation(ABC, BaseModel): """ _invocation_classes: ClassVar[set[BaseInvocation]] = set()
+ _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None @classmethod def get_type(cls) -> str:
@@ -160,10 +179,14 @@ def register_invocation(cls, invocation: BaseInvocation) -> None: cls._invocation_classes.add(invocation) @classmethod
- def get_invocations_union(cls) -> UnionType:
- """Gets a union of all invocation types."""
- invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
- return invocations_union # type: ignore [return-value]
+ def get_typeadapter(cls) -> TypeAdapter[Any]:
+ """Gets a pydantic TypeAdapter for the union of all invocation types."""
+ if not cls._typeadapter:
+ InvocationsUnion = TypeAliasType(
+ "InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
+ )
+ cls._typeadapter = TypeAdapter(InvocationsUnion)
+ return cls._typeadapter @classmethod def get_invocations(cls) -> Iterable[BaseInvocation]:
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 517da4375e1..47be380626b 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -417,7 +417,7 @@ class
ClipSkipInvocation(BaseInvocation): """Skip layers in clip text_encoder model.""" clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") - skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers) + skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers) def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: self.clip.skipped_layers += self.skipped_layers diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 3df230f5ee7..3066af0e503 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -2,10 +2,15 @@ import copy import itertools -from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints +from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints import networkx as nx -from pydantic import BaseModel, ConfigDict, field_validator, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + field_validator, + model_validator, +) from pydantic.fields import Field # Importing * is bad karma but needed here for node detection @@ -260,21 +265,24 @@ def invoke(self, context: InvocationContext) -> CollectInvocationOutput: return CollectInvocationOutput(collection=copy.copy(self.collection)) -InvocationsUnion: Any = BaseInvocation.get_invocations_union() -InvocationOutputsUnion: Any = BaseInvocationOutput.get_outputs_union() - - class Graph(BaseModel): id: str = Field(description="The id of this graph", default_factory=uuid_string) # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me - nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field( - description="The nodes in this graph", default_factory=dict - ) + nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict) edges: list[Edge] = Field( description="The connections between nodes and their fields in this graph", default_factory=list, ) + @field_validator("nodes", mode="plain") + @classmethod + def validate_nodes(cls, v: dict[str, Any]): + nodes: dict[str, BaseInvocation] = {} + typeadapter = BaseInvocation.get_typeadapter() + for node_id, node in v.items(): + nodes[node_id] = typeadapter.validate_python(node) + return nodes + def add_node(self, node: BaseInvocation) -> None: """Adds a node to a graph @@ -824,9 +832,7 @@ class GraphExecutionState(BaseModel): ) # The results of executed nodes - results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field( - description="The results of node executions", default_factory=dict - ) + results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict) # Errors raised when executing nodes errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict) @@ -843,6 +849,15 @@ class GraphExecutionState(BaseModel): default_factory=dict, ) + @field_validator("results", mode="plain") + @classmethod + def validate_results(cls, v: dict[str, BaseInvocationOutput]): + results: dict[str, BaseInvocationOutput] = {} + typeadapter = BaseInvocationOutput.get_typeadapter() + for result_id, result in v.items(): + results[result_id] = typeadapter.validate_python(result) + return results + @field_validator("graph") def graph_is_valid(cls, v: Graph): """Validates that the graph is valid""" @@ -1247,6 +1262,6 
@@ def validate_exposed_nodes(cls, values): return values -GraphInvocation.model_rebuild(force=True) Graph.model_rebuild(force=True) +GraphInvocation.model_rebuild(force=True) GraphExecutionState.model_rebuild(force=True) From 29a5195b65614e4be6274ec8727b2761c0c1d911 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 17 Feb 2024 18:19:22 +1100 Subject: [PATCH 119/340] fix(nodes): fix OpenAPI schema generation The change to `Graph.nodes` and `GraphExecutionState.results` validation requires some finagling to get the OpenAPI schema generation to work. See new comments for details. --- invokeai/app/api_app.py | 3 +- invokeai/app/services/shared/graph.py | 90 +++++++++++++++++++++++++-- 2 files changed, 86 insertions(+), 7 deletions(-) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 149d47fb962..65607c436a5 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -151,6 +151,8 @@ def custom_openapi() -> dict[str, Any]: # TODO: note that we assume the schema_key here is the TYPE.__name__ # This could break in some cases, figure out a better way to do it output_type_titles[schema_key] = output_schema["title"] + openapi_schema["components"]["schemas"][schema_key] = output_schema + openapi_schema["components"]["schemas"][schema_key]["class"] = "output" # Add Node Editor UI helper schemas ui_config_schemas = models_json_schema( @@ -173,7 +175,6 @@ def custom_openapi() -> dict[str, Any]: outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"} invoker_schema["output"] = outputs_ref invoker_schema["class"] = "invocation" - openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output" # This code no longer seems to be necessary? # Leave it here just in case diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 3066af0e503..1b53f642225 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -2,16 +2,19 @@ import copy import itertools -from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints +from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints import networkx as nx from pydantic import ( BaseModel, ConfigDict, + GetJsonSchemaHandler, field_validator, model_validator, ) from pydantic.fields import Field +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import CoreSchema # Importing * is bad karma but needed here for node detection from invokeai.app.invocations import * # noqa: F401 F403 @@ -277,12 +280,61 @@ class Graph(BaseModel): @field_validator("nodes", mode="plain") @classmethod def validate_nodes(cls, v: dict[str, Any]): + """Validates the nodes in the graph by retrieving a union of all node types and validating each node.""" + + # Invocations register themselves as their python modules are executed. The union of all invocations is + # constructed at runtime. We use pydantic to validate `Graph.nodes` using that union. + # + # It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If + # we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing + # invocations will cause a graph to fail if they are used. + # + # We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the + # pydantic validation entirely.
This allows us to validate the nodes using the union of invocations at runtime. + # + # This same pattern is used in `GraphExecutionState`. + nodes: dict[str, BaseInvocation] = {} typeadapter = BaseInvocation.get_typeadapter() for node_id, node in v.items(): nodes[node_id] = typeadapter.validate_python(node) return nodes + @classmethod + def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + # We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for + # fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to + # the generated schema as options for the `nodes` field. + # + # The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and + # with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as + # expected. + # + # You might be tempted to do something like this: + # + # ```py + # cloned_model = create_model(cls.__name__, __base__=cls, nodes=...) + # delattr(cloned_model, "validate_nodes") + # cloned_model.model_rebuild(force=True) + # json_schema = handler(cloned_model.__pydantic_core_schema__) + # ``` + # + # Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts + # to build the JSON Schema for the cloned model. Instead, we have to manually clone the model. + # + # This same pattern is used in `GraphExecutionState`. + + class Graph(BaseModel): + id: Optional[str] = Field(default=None, description="The id of this graph") + nodes: dict[ + str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")] + ] = Field(description="The nodes in this graph") + edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph") + + json_schema = handler(Graph.__pydantic_core_schema__) + json_schema = handler.resolve_ref_schema(json_schema) + return json_schema + def add_node(self, node: BaseInvocation) -> None: """Adds a node to a graph @@ -852,6 +904,9 @@ class GraphExecutionState(BaseModel): @field_validator("results", mode="plain") @classmethod def validate_results(cls, v: dict[str, BaseInvocationOutput]): + """Validates the results in the GES by retrieving a union of all output types and validating each result.""" + + # See the comment in `Graph.validate_nodes` for an explanation of this logic. results: dict[str, BaseInvocationOutput] = {} typeadapter = BaseInvocationOutput.get_typeadapter() for result_id, result in v.items(): @@ -864,6 +919,34 @@ def graph_is_valid(cls, v: Graph): v.validate_self() return v + @classmethod + def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + # See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic. 
+ class GraphExecutionState(BaseModel): + """Tracks the state of a graph execution""" + + id: str = Field(description="The id of the execution state") + graph: Graph = Field(description="The graph being executed") + execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes") + executed: set[str] = Field(description="The set of node ids that have been executed") + executed_history: list[str] = Field( + description="The list of node ids that have been executed, in order of execution" + ) + results: dict[ + str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")] + ] = Field(description="The results of node executions") + errors: dict[str, str] = Field(description="Errors raised when executing nodes") + prepared_source_mapping: dict[str, str] = Field( + description="The map of prepared nodes to original graph nodes" + ) + source_prepared_mapping: dict[str, set[str]] = Field( + description="The map of original graph nodes to prepared nodes" + ) + + json_schema = handler(GraphExecutionState.__pydantic_core_schema__) + json_schema = handler.resolve_ref_schema(json_schema) + return json_schema + model_config = ConfigDict( json_schema_extra={ "required": [ @@ -1260,8 +1343,3 @@ def validate_exposed_nodes(cls, values): ) return values - - -Graph.model_rebuild(force=True) -GraphInvocation.model_rebuild(force=True) -GraphExecutionState.model_rebuild(force=True) From c2dbe7a98bfae532441c199912a8dd969b322b35 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 17 Feb 2024 19:56:13 +1100 Subject: [PATCH 120/340] tidy(nodes): remove GraphInvocation `GraphInvocation` is a node that can contain a whole graph. It is removed for a number of reasons: 1. This feature was unused (the UI doesn't support it) and there is no plan for it to be used. The use-case it served is known in other node execution engines as "node groups" or "blocks" - a self-contained group of nodes, which has group inputs and outputs. This is a planned feature that will be handled client-side. 2. It adds substantial complexity to the graph processing logic. It's probably not enough to have a measurable performance impact but it does make it harder to work in the graph logic. 3. It allows for graphs to be recursive, and the improved invocations union handling does not play well with it. Actually, it works fine within `graph.py` but not in the tests for some reason. I do not understand why. There's probably a workaround, but I took this as encouragement to remove `GraphInvocation` from the app since we don't use it. --- invokeai/app/services/shared/graph.py | 292 +++++++------------------- tests/aa_nodes/test_invoker.py | 13 +- tests/aa_nodes/test_node_graph.py | 168 ++++++++------- tests/aa_nodes/test_session_queue.py | 47 ++--- 4 files changed, 181 insertions(+), 339 deletions(-) diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 1b53f642225..4df9f0c4b04 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -184,10 +184,6 @@ class NodeIdMismatchError(ValueError): pass -class InvalidSubGraphError(ValueError): - pass - - class CyclicalGraphError(ValueError): pass @@ -196,25 +192,6 @@ class UnknownGraphValidationError(ValueError): pass -# TODO: Create and use an Empty output? 
-@invocation_output("graph_output") -class GraphInvocationOutput(BaseInvocationOutput): - pass - - -# TODO: Fill this out and move to invocations -@invocation("graph", version="1.0.0") -class GraphInvocation(BaseInvocation): - """Execute a graph""" - - # TODO: figure out how to create a default here - graph: "Graph" = InputField(description="The graph to run", default=None) - - def invoke(self, context: InvocationContext) -> GraphInvocationOutput: - """Invoke with provided services and return outputs.""" - return GraphInvocationOutput() - - @invocation_output("iterate_output") class IterateInvocationOutput(BaseInvocationOutput): """Used to connect iteration outputs. Will be expanded to a specific output.""" @@ -346,41 +323,21 @@ def add_node(self, node: BaseInvocation) -> None: self.nodes[node.id] = node - def _get_graph_and_node(self, node_path: str) -> tuple["Graph", str]: - """Returns the graph and node id for a node path.""" - # Materialized graphs may have nodes at the top level - if node_path in self.nodes: - return (self, node_path) - - node_id = node_path if "." not in node_path else node_path[: node_path.index(".")] - if node_id not in self.nodes: - raise NodeNotFoundError(f"Node {node_path} not found in graph") - - node = self.nodes[node_id] - - if not isinstance(node, GraphInvocation): - # There's more node path left but this isn't a graph - failure - raise NodeNotFoundError("Node path terminated early at a non-graph node") - - return node.graph._get_graph_and_node(node_path[node_path.index(".") + 1 :]) - - def delete_node(self, node_path: str) -> None: + def delete_node(self, node_id: str) -> None: """Deletes a node from a graph""" try: - graph, node_id = self._get_graph_and_node(node_path) - # Delete edges for this node - input_edges = self._get_input_edges_and_graphs(node_path) - output_edges = self._get_output_edges_and_graphs(node_path) + input_edges = self._get_input_edges(node_id) + output_edges = self._get_output_edges(node_id) - for edge_graph, _, edge in input_edges: - edge_graph.delete_edge(edge) + for edge in input_edges: + self.delete_edge(edge) - for edge_graph, _, edge in output_edges: - edge_graph.delete_edge(edge) + for edge in output_edges: + self.delete_edge(edge) - del graph.nodes[node_id] + del self.nodes[node_id] except NodeNotFoundError: pass # Ignore, not doesn't exist (should this throw?) 
@@ -430,13 +387,6 @@ def validate_self(self) -> None: if k != v.id: raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}") - # Validate all subgraphs - for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)): - try: - gn.graph.validate_self() - except Exception as e: - raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e - # Validate that all edges match nodes and fields in the graph for edge in self.edges: source_node = self.nodes.get(edge.source.node_id, None) @@ -498,7 +448,6 @@ def is_valid(self) -> bool: except ( DuplicateNodeIdError, NodeIdMismatchError, - InvalidSubGraphError, NodeNotFoundError, NodeFieldNotFoundError, CyclicalGraphError, @@ -519,7 +468,7 @@ def _is_destination_field_list_of_Any(self, edge: Edge) -> bool: def _validate_edge(self, edge: Edge): """Validates that a new edge doesn't create a cycle in the graph""" - # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly) + # Validate that the nodes exist try: from_node = self.get_node(edge.source.node_id) to_node = self.get_node(edge.destination.node_id) @@ -586,171 +535,90 @@ def _validate_edge(self, edge: Edge): f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" ) - def has_node(self, node_path: str) -> bool: + def has_node(self, node_id: str) -> bool: """Determines whether or not a node exists in the graph.""" try: - n = self.get_node(node_path) - if n is not None: - return True - else: - return False + _ = self.get_node(node_id) + return True except NodeNotFoundError: return False - def get_node(self, node_path: str) -> BaseInvocation: - """Gets a node from the graph using a node path.""" - # Materialized graphs may have nodes at the top level - graph, node_id = self._get_graph_and_node(node_path) - return graph.nodes[node_id] - - def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str: - return node_id if prefix is None or prefix == "" else f"{prefix}.{node_id}" + def get_node(self, node_id: str) -> BaseInvocation: + """Gets a node from the graph.""" + try: + return self.nodes[node_id] + except KeyError as e: + raise NodeNotFoundError(f"Node {node_id} not found in graph") from e - def update_node(self, node_path: str, new_node: BaseInvocation) -> None: + def update_node(self, node_id: str, new_node: BaseInvocation) -> None: """Updates a node in the graph.""" - graph, node_id = self._get_graph_and_node(node_path) - node = graph.nodes[node_id] + node = self.nodes[node_id] # Ensure the node type matches the new node if type(node) is not type(new_node): - raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}") + raise TypeError(f"Node {node_id} is type {type(node)} but new node is type {type(new_node)}") # Ensure the new id is either the same or is not in the graph - prefix = None if "." 
not in node_path else node_path[: node_path.rindex(".")] - new_path = self._get_node_path(new_node.id, prefix=prefix) - if new_node.id != node.id and self.has_node(new_path): - raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph") + if new_node.id != node.id and self.has_node(new_node.id): + raise NodeAlreadyInGraphError(f"Node with id {new_node.id} already exists in graph") # Set the new node in the graph - graph.nodes[new_node.id] = new_node + self.nodes[new_node.id] = new_node if new_node.id != node.id: - input_edges = self._get_input_edges_and_graphs(node_path) - output_edges = self._get_output_edges_and_graphs(node_path) + input_edges = self._get_input_edges(node_id) + output_edges = self._get_output_edges(node_id) # Delete node and all edges - graph.delete_node(node_path) + self.delete_node(node_id) # Create new edges for each input and output - for graph, _, edge in input_edges: - # Remove the graph prefix from the node path - new_graph_node_path = ( - new_node.id - if "." not in edge.destination.node_id - else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}' - ) - graph.add_edge( + for edge in input_edges: + self.add_edge( Edge( source=edge.source, - destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field), + destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field), ) ) - for graph, _, edge in output_edges: - # Remove the graph prefix from the node path - new_graph_node_path = ( - new_node.id - if "." not in edge.source.node_id - else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}' - ) - graph.add_edge( + for edge in output_edges: + self.add_edge( Edge( - source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field), + source=EdgeConnection(node_id=new_node.id, field=edge.source.field), destination=edge.destination, ) ) - def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]: - """Gets all input edges for a node""" - edges = self._get_input_edges_and_graphs(node_path) - - # Filter to edges that match the field - filtered_edges = (e for e in edges if field is None or e[2].destination.field == field) - - # Create full node paths for each edge - return [ - Edge( - source=EdgeConnection( - node_id=self._get_node_path(e.source.node_id, prefix=prefix), - field=e.source.field, - ), - destination=EdgeConnection( - node_id=self._get_node_path(e.destination.node_id, prefix=prefix), - field=e.destination.field, - ), - ) - for _, prefix, e in filtered_edges - ] + def _get_input_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]: + """Gets all input edges for a node. If field is provided, only edges to that field are returned.""" - def _get_input_edges_and_graphs( - self, node_path: str, prefix: Optional[str] = None - ) -> list[tuple["Graph", Union[str, None], Edge]]: - """Gets all input edges for a node along with the graph they are in and the graph's path""" - edges = [] + edges = [e for e in self.edges if e.destination.node_id == node_id] - # Return any input edges that appear in this graph - edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]) + if field is None: + return edges - node_id = node_path if "." 
not in node_path else node_path[: node_path.index(".")] - node = self.nodes[node_id] + filtered_edges = [e for e in edges if e.destination.field == field] - if isinstance(node, GraphInvocation): - graph = node.graph - graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix) - graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path) - edges.extend(graph_edges) - - return edges - - def _get_output_edges(self, node_path: str, field: str) -> list[Edge]: - """Gets all output edges for a node""" - edges = self._get_output_edges_and_graphs(node_path) - - # Filter to edges that match the field - filtered_edges = (e for e in edges if e[2].source.field == field) - - # Create full node paths for each edge - return [ - Edge( - source=EdgeConnection( - node_id=self._get_node_path(e.source.node_id, prefix=prefix), - field=e.source.field, - ), - destination=EdgeConnection( - node_id=self._get_node_path(e.destination.node_id, prefix=prefix), - field=e.destination.field, - ), - ) - for _, prefix, e in filtered_edges - ] + return filtered_edges - def _get_output_edges_and_graphs( - self, node_path: str, prefix: Optional[str] = None - ) -> list[tuple["Graph", Union[str, None], Edge]]: - """Gets all output edges for a node along with the graph they are in and the graph's path""" - edges = [] + def _get_output_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]: + """Gets all output edges for a node. If field is provided, only edges from that field are returned.""" + edges = [e for e in self.edges if e.source.node_id == node_id] - # Return any input edges that appear in this graph - edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path]) + if field is None: + return edges - node_id = node_path if "." 
not in node_path else node_path[: node_path.index(".")] - node = self.nodes[node_id] + filtered_edges = [e for e in edges if e.source.field == field] - if isinstance(node, GraphInvocation): - graph = node.graph - graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix) - graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path) - edges.extend(graph_edges) - - return edges + return filtered_edges def _is_iterator_connection_valid( self, - node_path: str, + node_id: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, ) -> bool: - inputs = [e.source for e in self._get_input_edges(node_path, "collection")] - outputs = [e.destination for e in self._get_output_edges(node_path, "item")] + inputs = [e.source for e in self._get_input_edges(node_id, "collection")] + outputs = [e.destination for e in self._get_output_edges(node_id, "item")] if new_input is not None: inputs.append(new_input) @@ -778,12 +646,12 @@ def _is_iterator_connection_valid( def _is_collector_connection_valid( self, - node_path: str, + node_id: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, ) -> bool: - inputs = [e.source for e in self._get_input_edges(node_path, "item")] - outputs = [e.destination for e in self._get_output_edges(node_path, "collection")] + inputs = [e.source for e in self._get_input_edges(node_id, "item")] + outputs = [e.destination for e in self._get_output_edges(node_id, "collection")] if new_input is not None: inputs.append(new_input) @@ -839,27 +707,17 @@ def nx_graph_with_data(self) -> nx.DiGraph: g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges}) return g - def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph: + def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph: """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)""" g = nx_graph or nx.DiGraph() # Add all nodes from this graph except graph/iteration nodes - g.add_nodes_from( - [ - self._get_node_path(n.id, prefix) - for n in self.nodes.values() - if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation) - ] - ) - - # Expand graph nodes - for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)): - g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix)) + g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)]) # TODO: figure out if iteration nodes need to be expanded unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges} - g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges]) + g.add_edges_from([(e[0], e[1]) for e in unique_edges]) return g @@ -1017,17 +875,17 @@ def has_error(self) -> bool: """Returns true if the graph has any errors""" return len(self.errors) > 0 - def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: + def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: """Prepares an iteration node and connects all edges, returning the new node id""" - node = self.graph.get_node(node_path) + node = self.graph.get_node(node_id) self_iteration_count = -1 # If this is an iterator node, we must create a copy for each iteration if isinstance(node, 
IterateInvocation): # Get input collection edge (should error if there are no inputs) - input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection"))) + input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection"))) input_collection_prepared_node_id = next( n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id ) @@ -1041,7 +899,7 @@ def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[ return new_nodes # Get all input edges - input_edges = self.graph._get_input_edges(node_path) + input_edges = self.graph._get_input_edges(node_id) # Create new edges for this iteration # For collect nodes, this may contain multiple inputs to the same field @@ -1068,10 +926,10 @@ def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[ # Add to execution graph self.execution_graph.add_node(new_node) - self.prepared_source_mapping[new_node.id] = node_path - if node_path not in self.source_prepared_mapping: - self.source_prepared_mapping[node_path] = set() - self.source_prepared_mapping[node_path].add(new_node.id) + self.prepared_source_mapping[new_node.id] = node_id + if node_id not in self.source_prepared_mapping: + self.source_prepared_mapping[node_id] = set() + self.source_prepared_mapping[node_id].add(new_node.id) # Add new edges to execution graph for edge in new_edges: @@ -1175,13 +1033,13 @@ def _prepare(self) -> Optional[str]: def _get_iteration_node( self, - source_node_path: str, + source_node_id: str, graph: nx.DiGraph, execution_graph: nx.DiGraph, prepared_iterator_nodes: list[str], ) -> Optional[str]: """Gets the prepared version of the specified source node that matches every iteration specified""" - prepared_nodes = self.source_prepared_mapping[source_node_path] + prepared_nodes = self.source_prepared_mapping[source_node_id] if len(prepared_nodes) == 1: return next(iter(prepared_nodes)) @@ -1192,7 +1050,7 @@ def _get_iteration_node( # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source) iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes] - parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)] + parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)] return next( (n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)), @@ -1261,19 +1119,19 @@ def _is_node_updatable(self, node_id: str) -> bool: def add_node(self, node: BaseInvocation) -> None: self.graph.add_node(node) - def update_node(self, node_path: str, new_node: BaseInvocation) -> None: - if not self._is_node_updatable(node_path): + def update_node(self, node_id: str, new_node: BaseInvocation) -> None: + if not self._is_node_updatable(node_id): raise NodeAlreadyExecutedError( - f"Node {node_path} has already been prepared or executed and cannot be updated" + f"Node {node_id} has already been prepared or executed and cannot be updated" ) - self.graph.update_node(node_path, new_node) + self.graph.update_node(node_id, new_node) - def delete_node(self, node_path: str) -> None: - if not self._is_node_updatable(node_path): + def delete_node(self, node_id: str) -> None: + if not self._is_node_updatable(node_id): raise NodeAlreadyExecutedError( - f"Node {node_path} has already been prepared or executed and cannot be deleted" + f"Node {node_id} has already been 
prepared or executed and cannot be deleted" ) - self.graph.delete_node(node_path) + self.graph.delete_node(node_id) def add_edge(self, edge: Edge) -> None: if not self._is_node_updatable(edge.destination.node_id): diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index f67b5a2ac55..38fcf859a58 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -23,7 +23,7 @@ from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService from invokeai.app.services.invoker import Invoker from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID -from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation +from invokeai.app.services.shared.graph import Graph, GraphExecutionState @pytest.fixture @@ -35,17 +35,6 @@ def simple_graph(): return g -@pytest.fixture -def graph_with_subgraph(): - sub_g = Graph() - sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) - sub_g.add_node(TextToImageTestInvocation(id="2")) - sub_g.add_edge(create_edge("1", "prompt", "2", "prompt")) - g = Graph() - g.add_node(GraphInvocation(id="1", graph=sub_g)) - return g - - # This must be defined here to avoid issues with the dynamic creation of the union of all invocation types # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # the test invocations. diff --git a/tests/aa_nodes/test_node_graph.py b/tests/aa_nodes/test_node_graph.py index 12a181f392f..94682962adf 100644 --- a/tests/aa_nodes/test_node_graph.py +++ b/tests/aa_nodes/test_node_graph.py @@ -8,8 +8,6 @@ invocation, invocation_output, ) -from invokeai.app.invocations.image import ShowImageInvocation -from invokeai.app.invocations.math import AddInvocation, SubtractInvocation from invokeai.app.invocations.primitives import ( FloatCollectionInvocation, FloatInvocation, @@ -17,13 +15,11 @@ StringInvocation, ) from invokeai.app.invocations.upscale import ESRGANInvocation -from invokeai.app.services.shared.default_graphs import create_text_to_image from invokeai.app.services.shared.graph import ( CollectInvocation, Edge, EdgeConnection, Graph, - GraphInvocation, InvalidEdgeError, IterateInvocation, NodeAlreadyInGraphError, @@ -425,19 +421,19 @@ def test_graph_invalid_if_edges_reference_missing_nodes(): assert g.is_valid() is False -def test_graph_invalid_if_subgraph_invalid(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() +# def test_graph_invalid_if_subgraph_invalid(): +# g = Graph() +# n1 = GraphInvocation(id="1") +# n1.graph = Graph() - n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi") - n1.graph.nodes[n1_1.id] = n1_1 - e1 = create_edge("1", "image", "2", "image") - n1.graph.edges.append(e1) +# n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi") +# n1.graph.nodes[n1_1.id] = n1_1 +# e1 = create_edge("1", "image", "2", "image") +# n1.graph.edges.append(e1) - g.nodes[n1.id] = n1 +# g.nodes[n1.id] = n1 - assert g.is_valid() is False +# assert g.is_valid() is False def test_graph_invalid_if_has_cycle(): @@ -466,108 +462,108 @@ def test_graph_invalid_with_invalid_connection(): assert g.is_valid() is False -# TODO: Subgraph operations -def test_graph_gets_subgraph_node(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() +# # TODO: Subgraph operations +# def test_graph_gets_subgraph_node(): +# g = Graph() +# n1 = GraphInvocation(id="1") +# n1.graph = Graph() - n1_1 = TextToImageTestInvocation(id="1", 
prompt="Banana sushi") - n1.graph.add_node(n1_1) +# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") +# n1.graph.add_node(n1_1) - g.add_node(n1) +# g.add_node(n1) - result = g.get_node("1.1") +# result = g.get_node("1.1") - assert result is not None - assert result.id == "1" - assert result == n1_1 +# assert result is not None +# assert result.id == "1" +# assert result == n1_1 -def test_graph_expands_subgraph(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() +# def test_graph_expands_subgraph(): +# g = Graph() +# n1 = GraphInvocation(id="1") +# n1.graph = Graph() - n1_1 = AddInvocation(id="1", a=1, b=2) - n1_2 = SubtractInvocation(id="2", b=3) - n1.graph.add_node(n1_1) - n1.graph.add_node(n1_2) - n1.graph.add_edge(create_edge("1", "value", "2", "a")) +# n1_1 = AddInvocation(id="1", a=1, b=2) +# n1_2 = SubtractInvocation(id="2", b=3) +# n1.graph.add_node(n1_1) +# n1.graph.add_node(n1_2) +# n1.graph.add_edge(create_edge("1", "value", "2", "a")) - g.add_node(n1) +# g.add_node(n1) - n2 = AddInvocation(id="2", b=5) - g.add_node(n2) - g.add_edge(create_edge("1.2", "value", "2", "a")) +# n2 = AddInvocation(id="2", b=5) +# g.add_node(n2) +# g.add_edge(create_edge("1.2", "value", "2", "a")) - dg = g.nx_graph_flat() - assert set(dg.nodes) == {"1.1", "1.2", "2"} - assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")} +# dg = g.nx_graph_flat() +# assert set(dg.nodes) == {"1.1", "1.2", "2"} +# assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")} -def test_graph_subgraph_t2i(): - g = Graph() - n1 = GraphInvocation(id="1") +# def test_graph_subgraph_t2i(): +# g = Graph() +# n1 = GraphInvocation(id="1") - # Get text to image default graph - lg = create_text_to_image() - n1.graph = lg.graph +# # Get text to image default graph +# lg = create_text_to_image() +# n1.graph = lg.graph - g.add_node(n1) +# g.add_node(n1) - n2 = IntegerInvocation(id="2", value=512) - n3 = IntegerInvocation(id="3", value=256) +# n2 = IntegerInvocation(id="2", value=512) +# n3 = IntegerInvocation(id="3", value=256) - g.add_node(n2) - g.add_node(n3) +# g.add_node(n2) +# g.add_node(n3) - g.add_edge(create_edge("2", "value", "1.width", "value")) - g.add_edge(create_edge("3", "value", "1.height", "value")) +# g.add_edge(create_edge("2", "value", "1.width", "value")) +# g.add_edge(create_edge("3", "value", "1.height", "value")) - n4 = ShowImageInvocation(id="4") - g.add_node(n4) - g.add_edge(create_edge("1.8", "image", "4", "image")) +# n4 = ShowImageInvocation(id="4") +# g.add_node(n4) +# g.add_edge(create_edge("1.8", "image", "4", "image")) - # Validate - dg = g.nx_graph_flat() - assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"} - expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges] - expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")]) - print(expected_edges) - print(list(dg.edges)) - assert set(dg.edges) == set(expected_edges) +# # Validate +# dg = g.nx_graph_flat() +# assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"} +# expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges] +# expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")]) +# print(expected_edges) +# print(list(dg.edges)) +# assert set(dg.edges) == set(expected_edges) -def test_graph_fails_to_get_missing_subgraph_node(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = 
Graph() +# def test_graph_fails_to_get_missing_subgraph_node(): +# g = Graph() +# n1 = GraphInvocation(id="1") +# n1.graph = Graph() - n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") - n1.graph.add_node(n1_1) +# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") +# n1.graph.add_node(n1_1) - g.add_node(n1) +# g.add_node(n1) - with pytest.raises(NodeNotFoundError): - _ = g.get_node("1.2") +# with pytest.raises(NodeNotFoundError): +# _ = g.get_node("1.2") -def test_graph_fails_to_enumerate_non_subgraph_node(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() +# def test_graph_fails_to_enumerate_non_subgraph_node(): +# g = Graph() +# n1 = GraphInvocation(id="1") +# n1.graph = Graph() - n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") - n1.graph.add_node(n1_1) +# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") +# n1.graph.add_node(n1_1) - g.add_node(n1) +# g.add_node(n1) - n2 = ESRGANInvocation(id="2") - g.add_node(n2) +# n2 = ESRGANInvocation(id="2") +# g.add_node(n2) - with pytest.raises(NodeNotFoundError): - _ = g.get_node("2.1") +# with pytest.raises(NodeNotFoundError): +# _ = g.get_node("2.1") def test_graph_gets_networkx_graph(): diff --git a/tests/aa_nodes/test_session_queue.py b/tests/aa_nodes/test_session_queue.py index b15bb9df360..bfe6444de8c 100644 --- a/tests/aa_nodes/test_session_queue.py +++ b/tests/aa_nodes/test_session_queue.py @@ -8,10 +8,9 @@ NodeFieldValue, calc_session_count, create_session_nfv_tuples, - populate_graph, prepare_values_to_insert, ) -from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation +from invokeai.app.services.shared.graph import Graph, GraphExecutionState from tests.aa_nodes.test_nodes import PromptTestInvocation @@ -39,28 +38,28 @@ def batch_graph() -> Graph: return g -def test_populate_graph_with_subgraph(): - g1 = Graph() - g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) - g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi")) - n1 = PromptTestInvocation(id="1", prompt="Banana snake") - subgraph = Graph() - subgraph.add_node(n1) - g1.add_node(GraphInvocation(id="3", graph=subgraph)) - - nfvs = [ - NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"), - NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"), - NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"), - ] - - g2 = populate_graph(g1, nfvs) - - # do not mutate g1 - assert g1 is not g2 - assert g2.get_node("1").prompt == "Strawberry sushi" - assert g2.get_node("2").prompt == "Strawberry sunday" - assert g2.get_node("3.1").prompt == "Strawberry snake" +# def test_populate_graph_with_subgraph(): +# g1 = Graph() +# g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) +# g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi")) +# n1 = PromptTestInvocation(id="1", prompt="Banana snake") +# subgraph = Graph() +# subgraph.add_node(n1) +# g1.add_node(GraphInvocation(id="3", graph=subgraph)) + +# nfvs = [ +# NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"), +# NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"), +# NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"), +# ] + +# g2 = populate_graph(g1, nfvs) + +# # do not mutate g1 +# assert g1 is not g2 +# assert g2.get_node("1").prompt == "Strawberry sushi" +# assert g2.get_node("2").prompt == "Strawberry sunday" +# assert 
g2.get_node("3.1").prompt == "Strawberry snake" def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph): From ab89437e189b9d45d75bf965818a491017e3fbbe Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 17 Feb 2024 19:59:21 +1100 Subject: [PATCH 121/340] tidy(nodes): move node tests to parent dir Thanks to the resolution of the import vs union issue, we can put tests anywhere. --- tests/aa_nodes/__init__.py | 0 tests/{aa_nodes => }/test_graph_execution_state.py | 0 tests/{aa_nodes => }/test_invoker.py | 0 tests/{aa_nodes => }/test_node_graph.py | 0 tests/{aa_nodes => }/test_nodes.py | 0 tests/{aa_nodes => }/test_session_queue.py | 3 ++- 6 files changed, 2 insertions(+), 1 deletion(-) delete mode 100644 tests/aa_nodes/__init__.py rename tests/{aa_nodes => }/test_graph_execution_state.py (100%) rename tests/{aa_nodes => }/test_invoker.py (100%) rename tests/{aa_nodes => }/test_node_graph.py (100%) rename tests/{aa_nodes => }/test_nodes.py (100%) rename tests/{aa_nodes => }/test_session_queue.py (99%) diff --git a/tests/aa_nodes/__init__.py b/tests/aa_nodes/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/test_graph_execution_state.py similarity index 100% rename from tests/aa_nodes/test_graph_execution_state.py rename to tests/test_graph_execution_state.py diff --git a/tests/aa_nodes/test_invoker.py b/tests/test_invoker.py similarity index 100% rename from tests/aa_nodes/test_invoker.py rename to tests/test_invoker.py diff --git a/tests/aa_nodes/test_node_graph.py b/tests/test_node_graph.py similarity index 100% rename from tests/aa_nodes/test_node_graph.py rename to tests/test_node_graph.py diff --git a/tests/aa_nodes/test_nodes.py b/tests/test_nodes.py similarity index 100% rename from tests/aa_nodes/test_nodes.py rename to tests/test_nodes.py diff --git a/tests/aa_nodes/test_session_queue.py b/tests/test_session_queue.py similarity index 99% rename from tests/aa_nodes/test_session_queue.py rename to tests/test_session_queue.py index bfe6444de8c..48b980539c7 100644 --- a/tests/aa_nodes/test_session_queue.py +++ b/tests/test_session_queue.py @@ -11,7 +11,8 @@ prepare_values_to_insert, ) from invokeai.app.services.shared.graph import Graph, GraphExecutionState -from tests.aa_nodes.test_nodes import PromptTestInvocation + +from .test_nodes import PromptTestInvocation @pytest.fixture From baafcd25df624fb616ed6b2b5594a3af0d7f6af9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 17 Feb 2024 20:00:58 +1100 Subject: [PATCH 122/340] tidy(nodes): remove LibraryGraphs The workflow library supersedes this unused feature. 
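The "import vs union issue" mentioned in patch 121 above refers to the lazily constructed invocation union described in the `validate_nodes` comments earlier: the union must be assembled only after every module has had a chance to register its invocation classes. A simplified, self-contained sketch of that idea follows; the names (`BaseNode`, `registry`, `AddNode`, `MultiplyNode`) are illustrative stand-ins, not the actual InvokeAI classes.

```py
from typing import Annotated, ClassVar, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class BaseNode(BaseModel):
    """Toy stand-in for BaseInvocation; subclasses register themselves as they are defined."""

    registry: ClassVar[list[type["BaseNode"]]] = []

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        BaseNode.registry.append(cls)

    @classmethod
    def get_typeadapter(cls) -> TypeAdapter:
        # The union is assembled at call time, so it contains every subclass
        # defined so far; import order no longer matters.
        union = Union[tuple(cls.registry)]  # type: ignore[valid-type]
        return TypeAdapter(Annotated[union, Field(discriminator="type")])


class AddNode(BaseNode):
    type: Literal["add"] = "add"
    a: int = 0
    b: int = 0


class MultiplyNode(BaseNode):
    type: Literal["mul"] = "mul"
    a: int = 1
    b: int = 1


node = BaseNode.get_typeadapter().validate_python({"type": "add", "a": 1, "b": 2})
assert isinstance(node, AddNode)
```

Because the `TypeAdapter` is built on demand, any node class defined before validation runs (including ones defined inside a test module) is picked up automatically, which is why the tests can now live anywhere.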
--- .../app/services/shared/default_graphs.py | 92 ------------------- invokeai/app/services/shared/graph.py | 56 ----------- 2 files changed, 148 deletions(-) delete mode 100644 invokeai/app/services/shared/default_graphs.py diff --git a/invokeai/app/services/shared/default_graphs.py b/invokeai/app/services/shared/default_graphs.py deleted file mode 100644 index 7e62c6d0a1b..00000000000 --- a/invokeai/app/services/shared/default_graphs.py +++ /dev/null @@ -1,92 +0,0 @@ -from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC - -from ...invocations.compel import CompelInvocation -from ...invocations.image import ImageNSFWBlurInvocation -from ...invocations.latent import DenoiseLatentsInvocation, LatentsToImageInvocation -from ...invocations.noise import NoiseInvocation -from ...invocations.primitives import IntegerInvocation -from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph - -default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74" - - -def create_text_to_image() -> LibraryGraph: - graph = Graph( - nodes={ - "width": IntegerInvocation(id="width", value=512), - "height": IntegerInvocation(id="height", value=512), - "seed": IntegerInvocation(id="seed", value=-1), - "3": NoiseInvocation(id="3"), - "4": CompelInvocation(id="4"), - "5": CompelInvocation(id="5"), - "6": DenoiseLatentsInvocation(id="6"), - "7": LatentsToImageInvocation(id="7"), - "8": ImageNSFWBlurInvocation(id="8"), - }, - edges=[ - Edge( - source=EdgeConnection(node_id="width", field="value"), - destination=EdgeConnection(node_id="3", field="width"), - ), - Edge( - source=EdgeConnection(node_id="height", field="value"), - destination=EdgeConnection(node_id="3", field="height"), - ), - Edge( - source=EdgeConnection(node_id="seed", field="value"), - destination=EdgeConnection(node_id="3", field="seed"), - ), - Edge( - source=EdgeConnection(node_id="3", field="noise"), - destination=EdgeConnection(node_id="6", field="noise"), - ), - Edge( - source=EdgeConnection(node_id="6", field="latents"), - destination=EdgeConnection(node_id="7", field="latents"), - ), - Edge( - source=EdgeConnection(node_id="4", field="conditioning"), - destination=EdgeConnection(node_id="6", field="positive_conditioning"), - ), - Edge( - source=EdgeConnection(node_id="5", field="conditioning"), - destination=EdgeConnection(node_id="6", field="negative_conditioning"), - ), - Edge( - source=EdgeConnection(node_id="7", field="image"), - destination=EdgeConnection(node_id="8", field="image"), - ), - ], - ) - return LibraryGraph( - id=default_text_to_image_graph_id, - name="t2i", - description="Converts text to an image", - graph=graph, - exposed_inputs=[ - ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"), - ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"), - ExposedNodeInput(node_path="width", field="value", alias="width"), - ExposedNodeInput(node_path="height", field="value", alias="height"), - ExposedNodeInput(node_path="seed", field="value", alias="seed"), - ], - exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")], - ) - - -def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]: - """Creates the default system graphs, or adds new versions if the old ones don't match""" - - # TODO: Uncomment this when we are ready to fix this up to prevent breaking changes - graphs: list[LibraryGraph] = [] - - text_to_image = graph_library.get(default_text_to_image_graph_id) - 
- # TODO: Check if the graph is the same as the default one, and if not, update it - # if text_to_image is None: - text_to_image = create_text_to_image() - graph_library.set(text_to_image) - - graphs.append(text_to_image) - - return graphs diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 4df9f0c4b04..5380c2e795d 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -10,7 +10,6 @@ ConfigDict, GetJsonSchemaHandler, field_validator, - model_validator, ) from pydantic.fields import Field from pydantic.json_schema import JsonSchemaValue @@ -1146,58 +1145,3 @@ def delete_edge(self, edge: Edge) -> None: f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted" ) self.graph.delete_edge(edge) - - -class ExposedNodeInput(BaseModel): - node_path: str = Field(description="The node path to the node with the input") - field: str = Field(description="The field name of the input") - alias: str = Field(description="The alias of the input") - - -class ExposedNodeOutput(BaseModel): - node_path: str = Field(description="The node path to the node with the output") - field: str = Field(description="The field name of the output") - alias: str = Field(description="The alias of the output") - - -class LibraryGraph(BaseModel): - id: str = Field(description="The unique identifier for this library graph", default_factory=uuid_string) - graph: Graph = Field(description="The graph") - name: str = Field(description="The name of the graph") - description: str = Field(description="The description of the graph") - exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list) - exposed_outputs: list[ExposedNodeOutput] = Field( - description="The outputs exposed by this graph", default_factory=list - ) - - @field_validator("exposed_inputs", "exposed_outputs") - def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]): - if len(v) != len({i.alias for i in v}): - raise ValueError("Duplicate exposed alias") - return v - - @model_validator(mode="after") - def validate_exposed_nodes(cls, values): - graph = values.graph - - # Validate exposed inputs - for exposed_input in values.exposed_inputs: - if not graph.has_node(exposed_input.node_path): - raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist") - node = graph.get_node(exposed_input.node_path) - if get_input_field(node, exposed_input.field) is None: - raise ValueError( - f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}" - ) - - # Validate exposed outputs - for exposed_output in values.exposed_outputs: - if not graph.has_node(exposed_output.node_path): - raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist") - node = graph.get_node(exposed_output.node_path) - if get_output_field(node, exposed_output.field) is None: - raise ValueError( - f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}" - ) - - return values From eff1dec7d447c397fed3d132367cd06fc07b3f38 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 17 Feb 2024 20:02:37 +1100 Subject: [PATCH 123/340] tidy(nodes): remove no-op model_config Because we now customize the JSON Schema creation for GraphExecutionState, the model_config did nothing. 
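For context on why that config became dead code: once a model defines `__get_pydantic_json_schema__` and builds its schema from a locally defined clone (as the earlier patch does for `Graph` and `GraphExecutionState`), anything set via the outer model's `model_config` never reaches the emitted schema. A rough sketch of that behaviour, using illustrative names (`Outer`, `Clone`) rather than the real classes:

```py
from pydantic import BaseModel, ConfigDict, Field, GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema


class Outer(BaseModel):
    # Never consulted: the hook below builds the schema from Clone, not from Outer.
    model_config = ConfigDict(json_schema_extra={"x-note": "this never appears"})

    id: str = Field(default="", description="An id")

    @classmethod
    def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
        # Same shape as the Graph/GraphExecutionState workaround: hand pydantic a
        # manually defined clone and return its schema instead of Outer's own.
        class Clone(BaseModel):
            id: str = Field(default="", description="An id")

        json_schema = handler(Clone.__pydantic_core_schema__)
        return handler.resolve_ref_schema(json_schema)


print(Outer.model_json_schema())  # no "x-note" key: the config extra was bypassed
```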
--- invokeai/app/services/shared/graph.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 5380c2e795d..e3941d9ca37 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -7,7 +7,6 @@ import networkx as nx from pydantic import ( BaseModel, - ConfigDict, GetJsonSchemaHandler, field_validator, ) @@ -804,22 +803,6 @@ class GraphExecutionState(BaseModel): json_schema = handler.resolve_ref_schema(json_schema) return json_schema - model_config = ConfigDict( - json_schema_extra={ - "required": [ - "id", - "graph", - "execution_graph", - "executed", - "executed_history", - "results", - "errors", - "prepared_source_mapping", - "source_prepared_mapping", - ] - } - ) - def next(self) -> Optional[BaseInvocation]: """Gets the next node ready to execute.""" From 15daad7ab37c2bb855218e00d8f6bb063faa4df9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 17 Feb 2024 20:02:51 +1100 Subject: [PATCH 124/340] chore(ui): typegen --- .../frontend/web/src/services/api/schema.ts | 125 ++++++++---------- 1 file changed, 52 insertions(+), 73 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 40fc262be26..47a257ffe6c 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -2011,8 +2011,9 @@ export type components = { /** * CLIP * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * @default null */ - clip?: components["schemas"]["ClipField"] | null; + clip: components["schemas"]["ClipField"] | null; /** * type * @default clip_skip_output @@ -3264,6 +3265,7 @@ export type components = { /** * Masked Latents Name * @description The name of the masked image latents + * @default null */ masked_latents_name?: string | null; }; @@ -4211,14 +4213,14 @@ export type components = { * Nodes * @description The nodes in this graph */ - nodes?: { - [key: string]: components["schemas"]["ControlNetInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["T2IAdapterInvocation"] | 
components["schemas"]["MaskCombineInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | 
components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CvInpaintInvocation"]; + nodes: { + [key: string]: components["schemas"]["ImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["SubtractInvocation"] | 
components["schemas"]["StringInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | 
components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["CompelInvocation"]; }; /** * Edges * @description The connections between nodes and their fields in this graph */ - edges?: components["schemas"]["Edge"][]; + edges: components["schemas"]["Edge"][]; }; /** * GraphExecutionState @@ -4249,7 +4251,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["SchedulerOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["String2Output"] | components["schemas"]["IntegerOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["IterateInvocationOutput"]; + [key: string]: components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["ColorCollectionOutput"] | 
components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["String2Output"] | components["schemas"]["ControlOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["LatentsCollectionOutput"]; }; /** * Errors @@ -4273,46 +4275,6 @@ export type components = { [key: string]: string[]; }; }; - /** - * GraphInvocation - * @description Execute a graph - */ - GraphInvocation: { - /** - * Id - * @description The id of this instance of an invocation. Must be unique among all instances of invocations. - */ - id: string; - /** - * Is Intermediate - * @description Whether or not this is an intermediate invocation. - * @default false - */ - is_intermediate?: boolean; - /** - * Use Cache - * @description Whether or not to use the cache - * @default true - */ - use_cache?: boolean; - /** @description The graph to run */ - graph?: components["schemas"]["Graph"]; - /** - * type - * @default graph - * @constant - */ - type: "graph"; - }; - /** GraphInvocationOutput */ - GraphInvocationOutput: { - /** - * type - * @default graph_output - * @constant - */ - type: "graph_output"; - }; /** * HFModelSource * @description A HuggingFace repo_id with optional variant, sub-folder and access token. 
@@ -6218,6 +6180,7 @@ export type components = { /** * Seed * @description Seed used to generate this latents + * @default null */ seed?: number | null; }; @@ -6631,7 +6594,10 @@ export type components = { * @description Key of model as returned by ModelRecordServiceBase.get_model() */ key: string; - /** @description Info to load submodel */ + /** + * @description Info to load submodel + * @default null + */ submodel_type?: components["schemas"]["SubModelType"] | null; /** * Weight @@ -6697,13 +6663,15 @@ export type components = { /** * UNet * @description UNet (scheduler, LoRAs) + * @default null */ - unet?: components["schemas"]["UNetField"] | null; + unet: components["schemas"]["UNetField"] | null; /** * CLIP * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * @default null */ - clip?: components["schemas"]["ClipField"] | null; + clip: components["schemas"]["ClipField"] | null; /** * type * @default lora_loader_output @@ -7420,7 +7388,10 @@ export type components = { * @description Key of model as returned by ModelRecordServiceBase.get_model() */ key: string; - /** @description Info to load submodel */ + /** + * @description Info to load submodel + * @default null + */ submodel_type?: components["schemas"]["SubModelType"] | null; }; /** @@ -8794,18 +8765,21 @@ export type components = { /** * UNet * @description UNet (scheduler, LoRAs) + * @default null */ - unet?: components["schemas"]["UNetField"] | null; + unet: components["schemas"]["UNetField"] | null; /** * CLIP 1 * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * @default null */ - clip?: components["schemas"]["ClipField"] | null; + clip: components["schemas"]["ClipField"] | null; /** * CLIP 2 * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * @default null */ - clip2?: components["schemas"]["ClipField"] | null; + clip2: components["schemas"]["ClipField"] | null; /** * type * @default sdxl_lora_loader_output @@ -9202,13 +9176,15 @@ export type components = { /** * UNet * @description UNet (scheduler, LoRAs) + * @default null */ - unet?: components["schemas"]["UNetField"] | null; + unet: components["schemas"]["UNetField"] | null; /** * VAE * @description VAE + * @default null */ - vae?: components["schemas"]["VaeField"] | null; + vae: components["schemas"]["VaeField"] | null; /** * type * @default seamless_output @@ -10397,7 +10373,10 @@ export type components = { * @description Axes("x" and "y") to which apply seamless */ seamless_axes?: string[]; - /** @description FreeU configuration */ + /** + * @description FreeU configuration + * @default null + */ freeu_config?: components["schemas"]["FreeUConfig"] | null; }; /** @@ -11113,17 +11092,17 @@ export type components = { */ UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | 
"DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; /** - * VaeModelFormat + * ControlNetModelFormat * @description An enumeration. * @enum {string} */ - VaeModelFormat: "checkpoint" | "diffusers"; + ControlNetModelFormat: "checkpoint" | "diffusers"; /** - * T2IAdapterModelFormat + * LoRAModelFormat * @description An enumeration. * @enum {string} */ - T2IAdapterModelFormat: "diffusers"; + LoRAModelFormat: "lycoris" | "diffusers"; /** * StableDiffusionXLModelFormat * @description An enumeration. @@ -11131,47 +11110,47 @@ export type components = { */ StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusion1ModelFormat + * IPAdapterModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + IPAdapterModelFormat: "invokeai"; /** - * StableDiffusionOnnxModelFormat + * T2IAdapterModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; + T2IAdapterModelFormat: "diffusers"; /** - * ControlNetModelFormat + * StableDiffusion1ModelFormat * @description An enumeration. * @enum {string} */ - ControlNetModelFormat: "checkpoint" | "diffusers"; + StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** - * StableDiffusion2ModelFormat + * CLIPVisionModelFormat * @description An enumeration. * @enum {string} */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + CLIPVisionModelFormat: "diffusers"; /** - * LoRAModelFormat + * StableDiffusionOnnxModelFormat * @description An enumeration. * @enum {string} */ - LoRAModelFormat: "lycoris" | "diffusers"; + StableDiffusionOnnxModelFormat: "olive" | "onnx"; /** - * CLIPVisionModelFormat + * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ - CLIPVisionModelFormat: "diffusers"; + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** - * IPAdapterModelFormat + * VaeModelFormat * @description An enumeration. 
* @enum {string} */ - IPAdapterModelFormat: "invokeai"; + VaeModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; From ec3b2f502eda827ce58ed76fee6ba03c8b305367 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 17 Feb 2024 20:08:22 +1100 Subject: [PATCH 125/340] tidy(nodes): remove commented tests --- tests/test_node_graph.py | 119 ------------------------------------ tests/test_session_queue.py | 24 -------- 2 files changed, 143 deletions(-) diff --git a/tests/test_node_graph.py b/tests/test_node_graph.py index 94682962adf..87a4948af40 100644 --- a/tests/test_node_graph.py +++ b/tests/test_node_graph.py @@ -421,21 +421,6 @@ def test_graph_invalid_if_edges_reference_missing_nodes(): assert g.is_valid() is False -# def test_graph_invalid_if_subgraph_invalid(): -# g = Graph() -# n1 = GraphInvocation(id="1") -# n1.graph = Graph() - -# n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi") -# n1.graph.nodes[n1_1.id] = n1_1 -# e1 = create_edge("1", "image", "2", "image") -# n1.graph.edges.append(e1) - -# g.nodes[n1.id] = n1 - -# assert g.is_valid() is False - - def test_graph_invalid_if_has_cycle(): g = Graph() n1 = ESRGANInvocation(id="1") @@ -462,110 +447,6 @@ def test_graph_invalid_with_invalid_connection(): assert g.is_valid() is False -# # TODO: Subgraph operations -# def test_graph_gets_subgraph_node(): -# g = Graph() -# n1 = GraphInvocation(id="1") -# n1.graph = Graph() - -# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") -# n1.graph.add_node(n1_1) - -# g.add_node(n1) - -# result = g.get_node("1.1") - -# assert result is not None -# assert result.id == "1" -# assert result == n1_1 - - -# def test_graph_expands_subgraph(): -# g = Graph() -# n1 = GraphInvocation(id="1") -# n1.graph = Graph() - -# n1_1 = AddInvocation(id="1", a=1, b=2) -# n1_2 = SubtractInvocation(id="2", b=3) -# n1.graph.add_node(n1_1) -# n1.graph.add_node(n1_2) -# n1.graph.add_edge(create_edge("1", "value", "2", "a")) - -# g.add_node(n1) - -# n2 = AddInvocation(id="2", b=5) -# g.add_node(n2) -# g.add_edge(create_edge("1.2", "value", "2", "a")) - -# dg = g.nx_graph_flat() -# assert set(dg.nodes) == {"1.1", "1.2", "2"} -# assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")} - - -# def test_graph_subgraph_t2i(): -# g = Graph() -# n1 = GraphInvocation(id="1") - -# # Get text to image default graph -# lg = create_text_to_image() -# n1.graph = lg.graph - -# g.add_node(n1) - -# n2 = IntegerInvocation(id="2", value=512) -# n3 = IntegerInvocation(id="3", value=256) - -# g.add_node(n2) -# g.add_node(n3) - -# g.add_edge(create_edge("2", "value", "1.width", "value")) -# g.add_edge(create_edge("3", "value", "1.height", "value")) - -# n4 = ShowImageInvocation(id="4") -# g.add_node(n4) -# g.add_edge(create_edge("1.8", "image", "4", "image")) - -# # Validate -# dg = g.nx_graph_flat() -# assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"} -# expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges] -# expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")]) -# print(expected_edges) -# print(list(dg.edges)) -# assert set(dg.edges) == set(expected_edges) - - -# def test_graph_fails_to_get_missing_subgraph_node(): -# g = Graph() -# n1 = GraphInvocation(id="1") -# n1.graph = Graph() - -# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") -# n1.graph.add_node(n1_1) - -# g.add_node(n1) - -# with 
pytest.raises(NodeNotFoundError): -# _ = g.get_node("1.2") - - -# def test_graph_fails_to_enumerate_non_subgraph_node(): -# g = Graph() -# n1 = GraphInvocation(id="1") -# n1.graph = Graph() - -# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") -# n1.graph.add_node(n1_1) - -# g.add_node(n1) - -# n2 = ESRGANInvocation(id="2") -# g.add_node(n2) - -# with pytest.raises(NodeNotFoundError): -# _ = g.get_node("2.1") - - def test_graph_gets_networkx_graph(): g = Graph() n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") diff --git a/tests/test_session_queue.py b/tests/test_session_queue.py index 48b980539c7..bf26b9b0026 100644 --- a/tests/test_session_queue.py +++ b/tests/test_session_queue.py @@ -39,30 +39,6 @@ def batch_graph() -> Graph: return g -# def test_populate_graph_with_subgraph(): -# g1 = Graph() -# g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) -# g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi")) -# n1 = PromptTestInvocation(id="1", prompt="Banana snake") -# subgraph = Graph() -# subgraph.add_node(n1) -# g1.add_node(GraphInvocation(id="3", graph=subgraph)) - -# nfvs = [ -# NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"), -# NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"), -# NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"), -# ] - -# g2 = populate_graph(g1, nfvs) - -# # do not mutate g1 -# assert g1 is not g2 -# assert g2.get_node("1").prompt == "Strawberry sushi" -# assert g2.get_node("2").prompt == "Strawberry sunday" -# assert g2.get_node("3.1").prompt == "Strawberry snake" - - def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph): b = Batch(graph=batch_graph, data=batch_data_collection, runs=2) t = list(create_session_nfv_tuples(batch=b, maximum=1000)) From 484c883d3b15166fcf4293ecac357aa44b1ad97b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 01:41:04 +1100 Subject: [PATCH 126/340] refactor(nodes): merge processors Consolidate graph processing logic into session processor. With graphs as the unit of work, and the session queue distributing graphs, we no longer need the invocation queue or processor. Instead, the session processor dequeues the next session and processes it in a simple loop, greatly simplifying the app. - Remove `graph_execution_manager` service. - Remove `queue` (invocation queue) service. - Remove `processor` (invocation processor) service. - Remove queue-related logic from `Invoker`. It now only starts and stops the services, providing them with access to other services. - Remove unused `invocation_retrieval_error` and `session_retrieval_error` events, these are no longer needed. - Clean up stats service now that it is less coupled to the rest of the app. - Refactor cancellation logic - cancellations now originate from session queue (i.e. HTTP cancel endpoint) and are emitted as events. Processor gets the events and sets the canceled event. Access to this event is provided to the invocation context for e.g. the step callback. - Remove `sessions` router; it provided access to `graph_executions` but that no longer exists. 
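
To illustrate the new model, here is a minimal sketch of the dequeue-and-run loop and the
event-based cancellation described above. This is not the code added in this patch; the
`Session`/`SessionQueue` protocols, `process_loop`, and the 0.5 second poll interval are
stand-ins invented purely for the example.

    import time
    from threading import Event
    from typing import Optional, Protocol


    class Session(Protocol):
        # Illustrative stand-in for a queued graph execution session.
        def next_invocation(self) -> Optional[object]: ...
        def run(self, invocation: object) -> None: ...


    class SessionQueue(Protocol):
        # Illustrative stand-in for the session queue service.
        def dequeue(self) -> Optional[Session]: ...


    def process_loop(queue: SessionQueue, stop: Event, cancel: Event, poll_interval: float = 0.5) -> None:
        """Dequeue sessions and run each one's invocations until stopped or canceled."""
        while not stop.is_set():
            session = queue.dequeue()
            if session is None:
                # Nothing queued; sleep briefly so we don't hammer the queue.
                time.sleep(poll_interval)
                continue
            cancel.clear()
            while not stop.is_set() and not cancel.is_set():
                invocation = session.next_invocation()
                if invocation is None:
                    break  # session is complete
                session.run(invocation)

The point of the sketch is that cancellation becomes an `Event` the loop checks between
invocations, rather than a separate invocation queue and processor pair.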
--- invokeai/app/api/dependencies.py | 10 - invokeai/app/api/routers/sessions.py | 276 ------------------ invokeai/app/api_app.py | 3 - invokeai/app/services/events/events_base.py | 48 +-- .../services/invocation_processor/__init__.py | 0 .../invocation_processor_base.py | 5 - .../invocation_processor_common.py | 15 - .../invocation_processor_default.py | 241 --------------- .../app/services/invocation_queue/__init__.py | 0 .../invocation_queue/invocation_queue_base.py | 26 -- .../invocation_queue_common.py | 23 -- .../invocation_queue_memory.py | 44 --- invokeai/app/services/invocation_services.py | 10 - .../invocation_stats/invocation_stats_base.py | 10 +- .../invocation_stats_default.py | 43 +-- invokeai/app/services/invoker.py | 52 ---- .../services/model_load/model_load_default.py | 3 - .../session_processor_common.py | 14 + .../session_processor_default.py | 251 +++++++++++----- .../session_queue/session_queue_sqlite.py | 5 +- .../app/services/shared/invocation_context.py | 20 +- invokeai/app/util/step_callback.py | 9 +- 22 files changed, 228 insertions(+), 880 deletions(-) delete mode 100644 invokeai/app/api/routers/sessions.py delete mode 100644 invokeai/app/services/invocation_processor/__init__.py delete mode 100644 invokeai/app/services/invocation_processor/invocation_processor_base.py delete mode 100644 invokeai/app/services/invocation_processor/invocation_processor_common.py delete mode 100644 invokeai/app/services/invocation_processor/invocation_processor_default.py delete mode 100644 invokeai/app/services/invocation_queue/__init__.py delete mode 100644 invokeai/app/services/invocation_queue/invocation_queue_base.py delete mode 100644 invokeai/app/services/invocation_queue/invocation_queue_common.py delete mode 100644 invokeai/app/services/invocation_queue/invocation_queue_memory.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 8e79b26e2d9..a9132516a86 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -4,7 +4,6 @@ import torch -from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.shared.sqlite.sqlite_util import init_db @@ -22,8 +21,6 @@ from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage from ..services.images.images_default import ImageService from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache -from ..services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor -from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue from ..services.invocation_services import InvocationServices from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invoker import Invoker @@ -33,7 +30,6 @@ from ..services.names.names_default import SimpleNameService from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue -from ..services.shared.graph import GraphExecutionState from ..services.urls.urls_default import LocalUrlService from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from .events import FastAPIEventService @@ -85,7 +81,6 @@ def 
initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger board_records = SqliteBoardRecordStorage(db=db) boards = BoardService() events = FastAPIEventService(event_handler_id) - graph_execution_manager = ItemStorageMemory[GraphExecutionState]() image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) @@ -105,8 +100,6 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger ) names = SimpleNameService() performance_statistics = InvocationStatsService() - processor = DefaultInvocationProcessor() - queue = MemoryInvocationQueue() session_processor = DefaultSessionProcessor() session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() @@ -119,7 +112,6 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger boards=boards, configuration=configuration, events=events, - graph_execution_manager=graph_execution_manager, image_files=image_files, image_records=image_records, images=images, @@ -129,8 +121,6 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger download_queue=download_queue_service, names=names, performance_statistics=performance_statistics, - processor=processor, - queue=queue, session_processor=session_processor, session_queue=session_queue, urls=urls, diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py deleted file mode 100644 index fb850d0b2b8..00000000000 --- a/invokeai/app/api/routers/sessions.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - - -from fastapi import HTTPException, Path -from fastapi.routing import APIRouter - -from ...services.shared.graph import GraphExecutionState -from ..dependencies import ApiDependencies - -session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"]) - - -# @session_router.post( -# "/", -# operation_id="create_session", -# responses={ -# 200: {"model": GraphExecutionState}, -# 400: {"description": "Invalid json"}, -# }, -# deprecated=True, -# ) -# async def create_session( -# queue_id: str = Query(default="", description="The id of the queue to associate the session with"), -# graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"), -# ) -> GraphExecutionState: -# """Creates a new session, optionally initializing it with an invocation graph""" -# session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph) -# return session - - -# @session_router.get( -# "/", -# operation_id="list_sessions", -# responses={200: {"model": PaginatedResults[GraphExecutionState]}}, -# deprecated=True, -# ) -# async def list_sessions( -# page: int = Query(default=0, description="The page of results to get"), -# per_page: int = Query(default=10, description="The number of results per page"), -# query: str = Query(default="", description="The query string to search for"), -# ) -> PaginatedResults[GraphExecutionState]: -# """Gets a list of sessions, optionally searching""" -# if query == "": -# result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page) -# else: -# result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page) -# return result - - -@session_router.get( - "/{session_id}", - operation_id="get_session", - responses={ - 200: {"model": GraphExecutionState}, - 404: {"description": "Session not found"}, - }, -) -async def 
get_session( - session_id: str = Path(description="The id of the session to get"), -) -> GraphExecutionState: - """Gets a session""" - session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) - if session is None: - raise HTTPException(status_code=404) - else: - return session - - -# @session_router.post( -# "/{session_id}/nodes", -# operation_id="add_node", -# responses={ -# 200: {"model": str}, -# 400: {"description": "Invalid node or link"}, -# 404: {"description": "Session not found"}, -# }, -# deprecated=True, -# ) -# async def add_node( -# session_id: str = Path(description="The id of the session"), -# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore -# description="The node to add" -# ), -# ) -> str: -# """Adds a node to the graph""" -# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) -# if session is None: -# raise HTTPException(status_code=404) - -# try: -# session.add_node(node) -# ApiDependencies.invoker.services.graph_execution_manager.set( -# session -# ) # TODO: can this be done automatically, or add node through an API? -# return session.id -# except NodeAlreadyExecutedError: -# raise HTTPException(status_code=400) -# except IndexError: -# raise HTTPException(status_code=400) - - -# @session_router.put( -# "/{session_id}/nodes/{node_path}", -# operation_id="update_node", -# responses={ -# 200: {"model": GraphExecutionState}, -# 400: {"description": "Invalid node or link"}, -# 404: {"description": "Session not found"}, -# }, -# deprecated=True, -# ) -# async def update_node( -# session_id: str = Path(description="The id of the session"), -# node_path: str = Path(description="The path to the node in the graph"), -# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore -# description="The new node" -# ), -# ) -> GraphExecutionState: -# """Updates a node in the graph and removes all linked edges""" -# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) -# if session is None: -# raise HTTPException(status_code=404) - -# try: -# session.update_node(node_path, node) -# ApiDependencies.invoker.services.graph_execution_manager.set( -# session -# ) # TODO: can this be done automatically, or add node through an API? -# return session -# except NodeAlreadyExecutedError: -# raise HTTPException(status_code=400) -# except IndexError: -# raise HTTPException(status_code=400) - - -# @session_router.delete( -# "/{session_id}/nodes/{node_path}", -# operation_id="delete_node", -# responses={ -# 200: {"model": GraphExecutionState}, -# 400: {"description": "Invalid node or link"}, -# 404: {"description": "Session not found"}, -# }, -# deprecated=True, -# ) -# async def delete_node( -# session_id: str = Path(description="The id of the session"), -# node_path: str = Path(description="The path to the node to delete"), -# ) -> GraphExecutionState: -# """Deletes a node in the graph and removes all linked edges""" -# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) -# if session is None: -# raise HTTPException(status_code=404) - -# try: -# session.delete_node(node_path) -# ApiDependencies.invoker.services.graph_execution_manager.set( -# session -# ) # TODO: can this be done automatically, or add node through an API? 
-# return session -# except NodeAlreadyExecutedError: -# raise HTTPException(status_code=400) -# except IndexError: -# raise HTTPException(status_code=400) - - -# @session_router.post( -# "/{session_id}/edges", -# operation_id="add_edge", -# responses={ -# 200: {"model": GraphExecutionState}, -# 400: {"description": "Invalid node or link"}, -# 404: {"description": "Session not found"}, -# }, -# deprecated=True, -# ) -# async def add_edge( -# session_id: str = Path(description="The id of the session"), -# edge: Edge = Body(description="The edge to add"), -# ) -> GraphExecutionState: -# """Adds an edge to the graph""" -# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) -# if session is None: -# raise HTTPException(status_code=404) - -# try: -# session.add_edge(edge) -# ApiDependencies.invoker.services.graph_execution_manager.set( -# session -# ) # TODO: can this be done automatically, or add node through an API? -# return session -# except NodeAlreadyExecutedError: -# raise HTTPException(status_code=400) -# except IndexError: -# raise HTTPException(status_code=400) - - -# # TODO: the edge being in the path here is really ugly, find a better solution -# @session_router.delete( -# "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}", -# operation_id="delete_edge", -# responses={ -# 200: {"model": GraphExecutionState}, -# 400: {"description": "Invalid node or link"}, -# 404: {"description": "Session not found"}, -# }, -# deprecated=True, -# ) -# async def delete_edge( -# session_id: str = Path(description="The id of the session"), -# from_node_id: str = Path(description="The id of the node the edge is coming from"), -# from_field: str = Path(description="The field of the node the edge is coming from"), -# to_node_id: str = Path(description="The id of the node the edge is going to"), -# to_field: str = Path(description="The field of the node the edge is going to"), -# ) -> GraphExecutionState: -# """Deletes an edge from the graph""" -# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) -# if session is None: -# raise HTTPException(status_code=404) - -# try: -# edge = Edge( -# source=EdgeConnection(node_id=from_node_id, field=from_field), -# destination=EdgeConnection(node_id=to_node_id, field=to_field), -# ) -# session.delete_edge(edge) -# ApiDependencies.invoker.services.graph_execution_manager.set( -# session -# ) # TODO: can this be done automatically, or add node through an API? 
-# return session -# except NodeAlreadyExecutedError: -# raise HTTPException(status_code=400) -# except IndexError: -# raise HTTPException(status_code=400) - - -# @session_router.put( -# "/{session_id}/invoke", -# operation_id="invoke_session", -# responses={ -# 200: {"model": None}, -# 202: {"description": "The invocation is queued"}, -# 400: {"description": "The session has no invocations ready to invoke"}, -# 404: {"description": "Session not found"}, -# }, -# deprecated=True, -# ) -# async def invoke_session( -# queue_id: str = Query(description="The id of the queue to associate the session with"), -# session_id: str = Path(description="The id of the session to invoke"), -# all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"), -# ) -> Response: -# """Invokes a session""" -# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) -# if session is None: -# raise HTTPException(status_code=404) - -# if session.is_complete(): -# raise HTTPException(status_code=400) - -# ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all) -# return Response(status_code=202) - - -# @session_router.delete( -# "/{session_id}/invoke", -# operation_id="cancel_session_invoke", -# responses={202: {"description": "The invocation is canceled"}}, -# deprecated=True, -# ) -# async def cancel_session_invoke( -# session_id: str = Path(description="The id of the session to cancel"), -# ) -> Response: -# """Invokes a session""" -# ApiDependencies.invoker.cancel(session_id) -# return Response(status_code=202) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 65607c436a5..f6b08ddba66 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -50,7 +50,6 @@ images, model_manager, session_queue, - sessions, utilities, workflows, ) @@ -110,8 +109,6 @@ async def shutdown_event() -> None: # Include all routers -app.include_router(sessions.session_router, prefix="/api") - app.include_router(utilities.utilities_router, prefix="/api") app.include_router(model_manager.model_manager_router, prefix="/api") app.include_router(download_queue.download_queue_router, prefix="/api") diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 90d9068b88c..5355fe22987 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union -from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage +from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_queue.session_queue_common import ( BatchStatus, EnqueueBatchResult, @@ -204,52 +204,6 @@ def emit_model_load_completed( }, ) - def emit_session_retrieval_error( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - error_type: str, - error: str, - ) -> None: - """Emitted when session retrieval fails""" - self.__emit_queue_event( - event_name="session_retrieval_error", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "error_type": error_type, - "error": error, - }, - ) - - def emit_invocation_retrieval_error( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node_id: str, - error_type: str, - error: str, - ) 
-> None: - """Emitted when invocation retrieval fails""" - self.__emit_queue_event( - event_name="invocation_retrieval_error", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node_id": node_id, - "error_type": error_type, - "error": error, - }, - ) - def emit_session_canceled( self, queue_id: str, diff --git a/invokeai/app/services/invocation_processor/__init__.py b/invokeai/app/services/invocation_processor/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/invokeai/app/services/invocation_processor/invocation_processor_base.py b/invokeai/app/services/invocation_processor/invocation_processor_base.py deleted file mode 100644 index 7947a201dd2..00000000000 --- a/invokeai/app/services/invocation_processor/invocation_processor_base.py +++ /dev/null @@ -1,5 +0,0 @@ -from abc import ABC - - -class InvocationProcessorABC(ABC): # noqa: B024 - pass diff --git a/invokeai/app/services/invocation_processor/invocation_processor_common.py b/invokeai/app/services/invocation_processor/invocation_processor_common.py deleted file mode 100644 index 347f6c73234..00000000000 --- a/invokeai/app/services/invocation_processor/invocation_processor_common.py +++ /dev/null @@ -1,15 +0,0 @@ -from pydantic import BaseModel, Field - - -class ProgressImage(BaseModel): - """The progress image sent intermittently during processing""" - - width: int = Field(description="The effective width of the image in pixels") - height: int = Field(description="The effective height of the image in pixels") - dataURL: str = Field(description="The image data as a b64 data URL") - - -class CanceledException(Exception): - """Execution canceled by user.""" - - pass diff --git a/invokeai/app/services/invocation_processor/invocation_processor_default.py b/invokeai/app/services/invocation_processor/invocation_processor_default.py deleted file mode 100644 index d2ebe235e63..00000000000 --- a/invokeai/app/services/invocation_processor/invocation_processor_default.py +++ /dev/null @@ -1,241 +0,0 @@ -import time -import traceback -from contextlib import suppress -from threading import BoundedSemaphore, Event, Thread -from typing import Optional - -import invokeai.backend.util.logging as logger -from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem -from invokeai.app.services.invocation_stats.invocation_stats_common import ( - GESStatsNotFoundError, -) -from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context -from invokeai.app.util.profiler import Profiler - -from ..invoker import Invoker -from .invocation_processor_base import InvocationProcessorABC -from .invocation_processor_common import CanceledException - - -class DefaultInvocationProcessor(InvocationProcessorABC): - __invoker_thread: Thread - __stop_event: Event - __invoker: Invoker - __threadLimit: BoundedSemaphore - - def start(self, invoker: Invoker) -> None: - # if we do want multithreading at some point, we could make this configurable - self.__threadLimit = BoundedSemaphore(1) - self.__invoker = invoker - self.__stop_event = Event() - self.__invoker_thread = Thread( - name="invoker_processor", - target=self.__process, - kwargs={"stop_event": self.__stop_event}, - ) - self.__invoker_thread.daemon = True # TODO: make async and do not use threads - self.__invoker_thread.start() - - def stop(self, *args, **kwargs) -> None: - self.__stop_event.set() - - 
def __process(self, stop_event: Event): - try: - self.__threadLimit.acquire() - queue_item: Optional[InvocationQueueItem] = None - - profiler = ( - Profiler( - logger=self.__invoker.services.logger, - output_dir=self.__invoker.services.configuration.profiles_path, - prefix=self.__invoker.services.configuration.profile_prefix, - ) - if self.__invoker.services.configuration.profile_graphs - else None - ) - - def stats_cleanup(graph_execution_state_id: str) -> None: - if profiler: - profile_path = profiler.stop() - stats_path = profile_path.with_suffix(".json") - self.__invoker.services.performance_statistics.dump_stats( - graph_execution_state_id=graph_execution_state_id, output_path=stats_path - ) - with suppress(GESStatsNotFoundError): - self.__invoker.services.performance_statistics.log_stats(graph_execution_state_id) - self.__invoker.services.performance_statistics.reset_stats(graph_execution_state_id) - - while not stop_event.is_set(): - try: - queue_item = self.__invoker.services.queue.get() - except Exception as e: - self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e) - - if not queue_item: # Probably stopping - # do not hammer the queue - time.sleep(0.5) - continue - - if profiler and profiler.profile_id != queue_item.graph_execution_state_id: - profiler.start(profile_id=queue_item.graph_execution_state_id) - - try: - graph_execution_state = self.__invoker.services.graph_execution_manager.get( - queue_item.graph_execution_state_id - ) - except Exception as e: - self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e) - self.__invoker.services.events.emit_session_retrieval_error( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=queue_item.graph_execution_state_id, - error_type=e.__class__.__name__, - error=traceback.format_exc(), - ) - continue - - try: - invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id) - except Exception as e: - self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e) - self.__invoker.services.events.emit_invocation_retrieval_error( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=queue_item.graph_execution_state_id, - node_id=queue_item.invocation_id, - error_type=e.__class__.__name__, - error=traceback.format_exc(), - ) - continue - - # get the source node id to provide to clients (the prepared node id is not as useful) - source_node_id = graph_execution_state.prepared_source_mapping[invocation.id] - - # Send starting event - self.__invoker.services.events.emit_invocation_started( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=graph_execution_state.id, - node=invocation.model_dump(), - source_node_id=source_node_id, - ) - - # Invoke - try: - graph_id = graph_execution_state.id - with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id): - # use the internal invoke_internal(), which wraps the node's invoke() method, - # which handles a few things: - # - nodes that require a value, but get it only from a connection - # - referencing the invocation cache instead of executing the node - context_data = InvocationContextData( - invocation=invocation, - 
session_id=graph_id, - workflow=queue_item.workflow, - source_node_id=source_node_id, - queue_id=queue_item.session_queue_id, - queue_item_id=queue_item.session_queue_item_id, - batch_id=queue_item.session_queue_batch_id, - ) - context = build_invocation_context( - services=self.__invoker.services, - context_data=context_data, - ) - outputs = invocation.invoke_internal(context=context, services=self.__invoker.services) - - # Check queue to see if this is canceled, and skip if so - if self.__invoker.services.queue.is_canceled(graph_execution_state.id): - continue - - # Save outputs and history - graph_execution_state.complete(invocation.id, outputs) - - # Save the state changes - self.__invoker.services.graph_execution_manager.set(graph_execution_state) - - # Send complete event - self.__invoker.services.events.emit_invocation_complete( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=graph_execution_state.id, - node=invocation.model_dump(), - source_node_id=source_node_id, - result=outputs.model_dump(), - ) - - except KeyboardInterrupt: - pass - - except CanceledException: - stats_cleanup(graph_execution_state.id) - pass - - except Exception as e: - error = traceback.format_exc() - logger.error(error) - - # Save error - graph_execution_state.set_node_error(invocation.id, error) - - # Save the state changes - self.__invoker.services.graph_execution_manager.set(graph_execution_state) - - self.__invoker.services.logger.error("Error while invoking:\n%s" % e) - # Send error event - self.__invoker.services.events.emit_invocation_error( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=graph_execution_state.id, - node=invocation.model_dump(), - source_node_id=source_node_id, - error_type=e.__class__.__name__, - error=error, - ) - pass - - # Check queue to see if this is canceled, and skip if so - if self.__invoker.services.queue.is_canceled(graph_execution_state.id): - continue - - # Queue any further commands if invoking all - is_complete = graph_execution_state.is_complete() - if queue_item.invoke_all and not is_complete: - try: - self.__invoker.invoke( - session_queue_batch_id=queue_item.session_queue_batch_id, - session_queue_item_id=queue_item.session_queue_item_id, - session_queue_id=queue_item.session_queue_id, - graph_execution_state=graph_execution_state, - workflow=queue_item.workflow, - invoke_all=True, - ) - except Exception as e: - self.__invoker.services.logger.error("Error while invoking:\n%s" % e) - self.__invoker.services.events.emit_invocation_error( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=graph_execution_state.id, - node=invocation.model_dump(), - source_node_id=source_node_id, - error_type=e.__class__.__name__, - error=traceback.format_exc(), - ) - elif is_complete: - self.__invoker.services.events.emit_graph_execution_complete( - queue_batch_id=queue_item.session_queue_batch_id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - graph_execution_state_id=graph_execution_state.id, - ) - stats_cleanup(graph_execution_state.id) - - except KeyboardInterrupt: - pass # Log something? 
KeyboardInterrupt is probably not going to be seen by the processor - finally: - self.__threadLimit.release() diff --git a/invokeai/app/services/invocation_queue/__init__.py b/invokeai/app/services/invocation_queue/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/invokeai/app/services/invocation_queue/invocation_queue_base.py b/invokeai/app/services/invocation_queue/invocation_queue_base.py deleted file mode 100644 index 09f4875c5f6..00000000000 --- a/invokeai/app/services/invocation_queue/invocation_queue_base.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -from abc import ABC, abstractmethod -from typing import Optional - -from .invocation_queue_common import InvocationQueueItem - - -class InvocationQueueABC(ABC): - """Abstract base class for all invocation queues""" - - @abstractmethod - def get(self) -> InvocationQueueItem: - pass - - @abstractmethod - def put(self, item: Optional[InvocationQueueItem]) -> None: - pass - - @abstractmethod - def cancel(self, graph_execution_state_id: str) -> None: - pass - - @abstractmethod - def is_canceled(self, graph_execution_state_id: str) -> bool: - pass diff --git a/invokeai/app/services/invocation_queue/invocation_queue_common.py b/invokeai/app/services/invocation_queue/invocation_queue_common.py deleted file mode 100644 index 696f6a981d7..00000000000 --- a/invokeai/app/services/invocation_queue/invocation_queue_common.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -import time -from typing import Optional - -from pydantic import BaseModel, Field - -from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID - - -class InvocationQueueItem(BaseModel): - graph_execution_state_id: str = Field(description="The ID of the graph execution state") - invocation_id: str = Field(description="The ID of the node being invoked") - session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came") - session_queue_item_id: int = Field( - description="The ID of session queue item from which this invocation queue item came" - ) - session_queue_batch_id: str = Field( - description="The ID of the session batch from which this invocation queue item came" - ) - workflow: Optional[WorkflowWithoutID] = Field(description="The workflow associated with this queue item") - invoke_all: bool = Field(default=False) - timestamp: float = Field(default_factory=time.time) diff --git a/invokeai/app/services/invocation_queue/invocation_queue_memory.py b/invokeai/app/services/invocation_queue/invocation_queue_memory.py deleted file mode 100644 index 8d6fff70524..00000000000 --- a/invokeai/app/services/invocation_queue/invocation_queue_memory.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -import time -from queue import Queue -from typing import Optional - -from .invocation_queue_base import InvocationQueueABC -from .invocation_queue_common import InvocationQueueItem - - -class MemoryInvocationQueue(InvocationQueueABC): - __queue: Queue - __cancellations: dict[str, float] - - def __init__(self): - self.__queue = Queue() - self.__cancellations = {} - - def get(self) -> InvocationQueueItem: - item = self.__queue.get() - - while ( - isinstance(item, InvocationQueueItem) - and item.graph_execution_state_id in self.__cancellations - and self.__cancellations[item.graph_execution_state_id] > item.timestamp - 
): - item = self.__queue.get() - - # Clear old items - for graph_execution_state_id in list(self.__cancellations.keys()): - if self.__cancellations[graph_execution_state_id] < item.timestamp: - del self.__cancellations[graph_execution_state_id] - - return item - - def put(self, item: Optional[InvocationQueueItem]) -> None: - self.__queue.put(item) - - def cancel(self, graph_execution_state_id: str) -> None: - if graph_execution_state_id not in self.__cancellations: - self.__cancellations[graph_execution_state_id] = time.time() - - def is_canceled(self, graph_execution_state_id: str) -> bool: - return graph_execution_state_id in self.__cancellations diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 0a1fa1e9222..04fe71a3eb3 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -23,15 +23,11 @@ from .image_records.image_records_base import ImageRecordStorageBase from .images.images_base import ImageServiceABC from .invocation_cache.invocation_cache_base import InvocationCacheBase - from .invocation_processor.invocation_processor_base import InvocationProcessorABC - from .invocation_queue.invocation_queue_base import InvocationQueueABC from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase - from .item_storage.item_storage_base import ItemStorageABC from .model_manager.model_manager_base import ModelManagerServiceBase from .names.names_base import NameServiceBase from .session_processor.session_processor_base import SessionProcessorBase from .session_queue.session_queue_base import SessionQueueBase - from .shared.graph import GraphExecutionState from .urls.urls_base import UrlServiceBase from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase @@ -47,16 +43,13 @@ def __init__( board_records: "BoardRecordStorageBase", configuration: "InvokeAIAppConfig", events: "EventServiceBase", - graph_execution_manager: "ItemStorageABC[GraphExecutionState]", images: "ImageServiceABC", image_files: "ImageFileStorageBase", image_records: "ImageRecordStorageBase", logger: "Logger", model_manager: "ModelManagerServiceBase", download_queue: "DownloadQueueServiceBase", - processor: "InvocationProcessorABC", performance_statistics: "InvocationStatsServiceBase", - queue: "InvocationQueueABC", session_queue: "SessionQueueBase", session_processor: "SessionProcessorBase", invocation_cache: "InvocationCacheBase", @@ -72,16 +65,13 @@ def __init__( self.board_records = board_records self.configuration = configuration self.events = events - self.graph_execution_manager = graph_execution_manager self.images = images self.image_files = image_files self.image_records = image_records self.logger = logger self.model_manager = model_manager self.download_queue = download_queue - self.processor = processor self.performance_statistics = performance_statistics - self.queue = queue self.session_queue = session_queue self.session_processor = session_processor self.invocation_cache = invocation_cache diff --git a/invokeai/app/services/invocation_stats/invocation_stats_base.py b/invokeai/app/services/invocation_stats/invocation_stats_base.py index ec8a453323d..b28220e74c4 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_base.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_base.py @@ -3,7 +3,7 @@ Usage: -statistics = InvocationStatsService(graph_execution_manager) +statistics = InvocationStatsService() with 
statistics.collect_stats(invocation, graph_execution_state.id): ... execute graphs... statistics.log_stats() @@ -60,12 +60,8 @@ def collect_stats( pass @abstractmethod - def reset_stats(self, graph_execution_state_id: str) -> None: - """ - Reset all statistics for the indicated graph. - :param graph_execution_state_id: The id of the session whose stats to reset. - :raises GESStatsNotFoundError: if the graph isn't tracked in the stats. - """ + def reset_stats(self): + """Reset all stored statistics.""" pass @abstractmethod diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index 486a1ca5b3e..06a5b675c31 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -10,7 +10,6 @@ import invokeai.backend.util.logging as logger from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.invoker import Invoker -from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError from invokeai.backend.model_manager.load.model_cache import CacheStats from .invocation_stats_base import InvocationStatsServiceBase @@ -51,9 +50,6 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st self._stats[graph_execution_state_id] = GraphExecutionStats() self._cache_stats[graph_execution_state_id] = CacheStats() - # Prune stale stats. There should be none since we're starting a new graph, but just in case. - self._prune_stale_stats() - # Record state before the invocation. start_time = time.time() start_ram = psutil.Process().memory_info().rss @@ -78,42 +74,9 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st ) self._stats[graph_execution_state_id].add_node_execution_stats(node_stats) - def _prune_stale_stats(self) -> None: - """Check all graphs being tracked and prune any that have completed/errored. - - This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so - for now we call this function periodically to prevent them from accumulating. - """ - to_prune: list[str] = [] - for graph_execution_state_id in self._stats: - try: - graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id) - except ItemNotFoundError: - # TODO(ryand): What would cause this? Should this exception just be allowed to propagate? - logger.warning(f"Failed to get graph state for {graph_execution_state_id}.") - continue - - if not graph_execution_state.is_complete(): - # The graph is still running, don't prune it. - continue - - to_prune.append(graph_execution_state_id) - - for graph_execution_state_id in to_prune: - del self._stats[graph_execution_state_id] - del self._cache_stats[graph_execution_state_id] - - if len(to_prune) > 0: - logger.info(f"Pruned stale graph stats for {to_prune}.") - - def reset_stats(self, graph_execution_state_id: str): - try: - del self._stats[graph_execution_state_id] - del self._cache_stats[graph_execution_state_id] - except KeyError as e: - raise GESStatsNotFoundError( - f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}." 
- ) from e + def reset_stats(self): + self._stats = {} + self._cache_stats = {} def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary: graph_stats_summary = self._get_graph_summary(graph_execution_state_id) diff --git a/invokeai/app/services/invoker.py b/invokeai/app/services/invoker.py index a04c6f2059d..527afb37f44 100644 --- a/invokeai/app/services/invoker.py +++ b/invokeai/app/services/invoker.py @@ -1,12 +1,7 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from typing import Optional -from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID - -from .invocation_queue.invocation_queue_common import InvocationQueueItem from .invocation_services import InvocationServices -from .shared.graph import Graph, GraphExecutionState class Invoker: @@ -18,51 +13,6 @@ def __init__(self, services: InvocationServices): self.services = services self._start() - def invoke( - self, - session_queue_id: str, - session_queue_item_id: int, - session_queue_batch_id: str, - graph_execution_state: GraphExecutionState, - workflow: Optional[WorkflowWithoutID] = None, - invoke_all: bool = False, - ) -> Optional[str]: - """Determines the next node to invoke and enqueues it, preparing if needed. - Returns the id of the queued node, or `None` if there are no nodes left to enqueue.""" - - # Get the next invocation - invocation = graph_execution_state.next() - if not invocation: - return None - - # Save the execution state - self.services.graph_execution_manager.set(graph_execution_state) - - # Queue the invocation - self.services.queue.put( - InvocationQueueItem( - session_queue_id=session_queue_id, - session_queue_item_id=session_queue_item_id, - session_queue_batch_id=session_queue_batch_id, - graph_execution_state_id=graph_execution_state.id, - invocation_id=invocation.id, - workflow=workflow, - invoke_all=invoke_all, - ) - ) - - return invocation.id - - def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState: - """Creates a new execution state for the given graph""" - new_state = GraphExecutionState(graph=Graph() if graph is None else graph) - self.services.graph_execution_manager.set(new_state) - return new_state - - def cancel(self, graph_execution_state_id: str) -> None: - """Cancels the given execution state""" - self.services.queue.cancel(graph_execution_state_id) - def __start_service(self, service) -> None: # Call start() method on any services that have it start_op = getattr(service, "start", None) @@ -85,5 +35,3 @@ def stop(self) -> None: # First stop all services for service in vars(self.services): self.__stop_service(getattr(self.services, service)) - - self.services.queue.put(None) diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 15c6283d8af..24ab10b4273 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -4,7 +4,6 @@ from typing import Optional, Type from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType @@ -95,8 +94,6 @@ def _emit_load_event( ) -> None: if not self._invoker: return - if 
self._invoker.services.queue.is_canceled(context_data.session_id): - raise CanceledException() if not loaded: self._invoker.services.events.emit_model_load_started( diff --git a/invokeai/app/services/session_processor/session_processor_common.py b/invokeai/app/services/session_processor/session_processor_common.py index 00195a773f0..0ca51de517c 100644 --- a/invokeai/app/services/session_processor/session_processor_common.py +++ b/invokeai/app/services/session_processor/session_processor_common.py @@ -4,3 +4,17 @@ class SessionProcessorStatus(BaseModel): is_started: bool = Field(description="Whether the session processor is started") is_processing: bool = Field(description="Whether a session is being processed") + + +class CanceledException(Exception): + """Execution canceled by user.""" + + pass + + +class ProgressImage(BaseModel): + """The progress image sent intermittently during processing""" + + width: int = Field(description="The effective width of the image in pixels") + height: int = Field(description="The effective height of the image in pixels") + dataURL: str = Field(description="The image data as a b64 data URL") diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 32e94a305dc..dd34c782523 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -1,4 +1,5 @@ import traceback +from contextlib import suppress from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent from typing import Optional @@ -7,7 +8,11 @@ from fastapi_events.typing import Event as FastAPIEvent from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError +from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem +from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context +from invokeai.app.util.profiler import Profiler from ..invoker import Invoker from .session_processor_base import SessionProcessorBase @@ -19,123 +24,237 @@ class DefaultSessionProcessor(SessionProcessorBase): def start(self, invoker: Invoker) -> None: - self.__invoker: Invoker = invoker - self.__queue_item: Optional[SessionQueueItem] = None + self._invoker: Invoker = invoker + self._queue_item: Optional[SessionQueueItem] = None - self.__resume_event = ThreadEvent() - self.__stop_event = ThreadEvent() - self.__poll_now_event = ThreadEvent() + self._resume_event = ThreadEvent() + self._stop_event = ThreadEvent() + self._poll_now_event = ThreadEvent() + self._cancel_event = ThreadEvent() local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event) - self.__threadLimit = BoundedSemaphore(THREAD_LIMIT) - self.__thread = Thread( + self._thread_limit = BoundedSemaphore(THREAD_LIMIT) + self._thread = Thread( name="session_processor", - target=self.__process, + target=self._process, kwargs={ - "stop_event": self.__stop_event, - "poll_now_event": self.__poll_now_event, - "resume_event": self.__resume_event, + "stop_event": self._stop_event, + "poll_now_event": self._poll_now_event, + "resume_event": self._resume_event, + "cancel_event": self._cancel_event, }, ) - self.__thread.start() + self._thread.start() def 
stop(self, *args, **kwargs) -> None: - self.__stop_event.set() + self._stop_event.set() def _poll_now(self) -> None: - self.__poll_now_event.set() + self._poll_now_event.set() async def _on_queue_event(self, event: FastAPIEvent) -> None: event_name = event[1]["event"] - # This was a match statement, but match is not supported on python 3.9 - if event_name in [ - "graph_execution_state_complete", - "invocation_error", - "session_retrieval_error", - "invocation_retrieval_error", - ]: - self.__queue_item = None - self._poll_now() - elif ( - event_name == "session_canceled" - and self.__queue_item is not None - and self.__queue_item.session_id == event[1]["data"]["graph_execution_state_id"] - ): - self.__queue_item = None + if event_name == "session_canceled" or event_name == "queue_cleared": + # These both mean we should cancel the current session. + self._cancel_event.set() self._poll_now() elif event_name == "batch_enqueued": self._poll_now() - elif event_name == "queue_cleared": - self.__queue_item = None - self._poll_now() def resume(self) -> SessionProcessorStatus: - if not self.__resume_event.is_set(): - self.__resume_event.set() + if not self._resume_event.is_set(): + self._resume_event.set() return self.get_status() def pause(self) -> SessionProcessorStatus: - if self.__resume_event.is_set(): - self.__resume_event.clear() + if self._resume_event.is_set(): + self._resume_event.clear() return self.get_status() def get_status(self) -> SessionProcessorStatus: return SessionProcessorStatus( - is_started=self.__resume_event.is_set(), - is_processing=self.__queue_item is not None, + is_started=self._resume_event.is_set(), + is_processing=self._queue_item is not None, ) - def __process( + def _process( self, stop_event: ThreadEvent, poll_now_event: ThreadEvent, resume_event: ThreadEvent, + cancel_event: ThreadEvent, ): + # Outermost processor try block; any unhandled exception is a fatal processor error try: + self._thread_limit.acquire() stop_event.clear() resume_event.set() - self.__threadLimit.acquire() - queue_item: Optional[SessionQueueItem] = None + cancel_event.clear() + + # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, + # the profiler will create a new profile for each session. + profiler = ( + Profiler( + logger=self._invoker.services.logger, + output_dir=self._invoker.services.configuration.profiles_path, + prefix=self._invoker.services.configuration.profile_prefix, + ) + if self._invoker.services.configuration.profile_graphs + else None + ) + + # Helper function to stop the profiler and save the stats + def stats_cleanup(graph_execution_state_id: str) -> None: + if profiler: + profile_path = profiler.stop() + stats_path = profile_path.with_suffix(".json") + self._invoker.services.performance_statistics.dump_stats( + graph_execution_state_id=graph_execution_state_id, output_path=stats_path + ) + # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor + # we don't care about that - suppress the error. 
+ with suppress(GESStatsNotFoundError): + self._invoker.services.performance_statistics.log_stats(graph_execution_state_id) + self._invoker.services.performance_statistics.reset_stats() + while not stop_event.is_set(): poll_now_event.clear() + # Middle processor try block; any unhandled exception is a non-fatal processor error try: - # do not dequeue if there is already a session running - if self.__queue_item is None and resume_event.is_set(): - queue_item = self.__invoker.services.session_queue.dequeue() - - if queue_item is not None: - self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}") - self.__queue_item = queue_item - self.__invoker.services.graph_execution_manager.set(queue_item.session) - self.__invoker.invoke( - session_queue_batch_id=queue_item.batch_id, - session_queue_id=queue_item.queue_id, - session_queue_item_id=queue_item.item_id, - graph_execution_state=queue_item.session, - workflow=queue_item.workflow, - invoke_all=True, + # Get the next session to process + self._queue_item = self._invoker.services.session_queue.dequeue() + if self._queue_item is not None and resume_event.is_set(): + self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") + cancel_event.clear() + + # If profiling is enabled, start the profiler + if profiler is not None: + profiler.start(profile_id=self._queue_item.session_id) + + # Prepare invocations and take the first + invocation = self._queue_item.session.next() + + # Loop over invocations until the session is complete or canceled + while invocation is not None and not cancel_event.is_set(): + # get the source node id to provide to clients (the prepared node id is not as useful) + source_node_id = self._queue_item.session.prepared_source_mapping[invocation.id] + + # Send starting event + self._invoker.services.events.emit_invocation_started( + queue_batch_id=self._queue_item.batch_id, + queue_item_id=self._queue_item.item_id, + queue_id=self._queue_item.queue_id, + graph_execution_state_id=self._queue_item.session_id, + node=invocation.model_dump(), + source_node_id=source_node_id, ) - queue_item = None - if queue_item is None: - self.__invoker.services.logger.debug("Waiting for next polling interval or event") + # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph + try: + with self._invoker.services.performance_statistics.collect_stats( + invocation, self._queue_item.session.id + ): + # Build invocation context (the node-facing API) + context_data = InvocationContextData( + invocation=invocation, + source_node_id=source_node_id, + session_id=self._queue_item.session.id, + workflow=self._queue_item.workflow, + queue_id=self._queue_item.queue_id, + queue_item_id=self._queue_item.item_id, + batch_id=self._queue_item.batch_id, + ) + context = build_invocation_context( + context_data=context_data, + services=self._invoker.services, + cancel_event=self._cancel_event, + ) + + # Invoke the node + outputs = invocation.invoke_internal( + context=context, services=self._invoker.services + ) + + # Save outputs and history + self._queue_item.session.complete(invocation.id, outputs) + + # Send complete event + self._invoker.services.events.emit_invocation_complete( + queue_batch_id=self._queue_item.batch_id, + queue_item_id=self._queue_item.item_id, + queue_id=self._queue_item.queue_id, + graph_execution_state_id=self._queue_item.session.id, + node=invocation.model_dump(), + source_node_id=source_node_id, + result=outputs.model_dump(), + ) + + except 
KeyboardInterrupt: + pass + + except CanceledException: + pass + + except Exception as e: + error = traceback.format_exc() + + # Save error + self._queue_item.session.set_node_error(invocation.id, error) + self._invoker.services.logger.error("Error while invoking:\n%s" % e) + + # Send error event + self._invoker.services.events.emit_invocation_error( + queue_batch_id=self._queue_item.session_id, + queue_item_id=self._queue_item.item_id, + queue_id=self._queue_item.queue_id, + graph_execution_state_id=self._queue_item.session.id, + node=invocation.model_dump(), + source_node_id=source_node_id, + error_type=e.__class__.__name__, + error=error, + ) + pass + + if self._queue_item.session.is_complete() or cancel_event.is_set(): + # Send complete event + self._invoker.services.events.emit_graph_execution_complete( + queue_batch_id=self._queue_item.batch_id, + queue_item_id=self._queue_item.item_id, + queue_id=self._queue_item.queue_id, + graph_execution_state_id=self._queue_item.session.id, + ) + # Save the stats and stop the profiler if it's running + stats_cleanup(self._queue_item.session.id) + invocation = None + else: + # Prepare the next invocation + invocation = self._queue_item.session.next() + + # The session is complete, immediately poll for next session + self._queue_item = None + poll_now_event.set() + else: + # The queue was empty, wait for next polling interval or event to try again + self._invoker.services.logger.debug("Waiting for next polling interval or event") poll_now_event.wait(POLLING_INTERVAL) continue except Exception as e: - self.__invoker.services.logger.error(f"Error in session processor: {e}") - if queue_item is not None: - self.__invoker.services.session_queue.cancel_queue_item( - queue_item.item_id, error=traceback.format_exc() + # Non-fatal error in processor, cancel the queue item and wait for next polling interval or event + self._invoker.services.logger.error(f"Error in session processor: {e}") + if self._queue_item is not None: + self._invoker.services.session_queue.cancel_queue_item( + self._queue_item.item_id, error=traceback.format_exc() ) poll_now_event.wait(POLLING_INTERVAL) continue except Exception as e: - self.__invoker.services.logger.error(f"Fatal Error in session processor: {e}") + # Fatal error in processor, log and pass - we're done here + self._invoker.services.logger.error(f"Fatal Error in session processor: {e}") pass finally: stop_event.clear() poll_now_event.clear() - self.__queue_item = None - self.__threadLimit.release() + self._queue_item = None + self._thread_limit.release() diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 64642690e9c..7af9f0e08cd 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -60,7 +60,7 @@ async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent: # This was a match statement, but match is not supported on python 3.9 if event_name == "graph_execution_state_complete": await self._handle_complete_event(event) - elif event_name in ["invocation_error", "session_retrieval_error", "invocation_retrieval_error"]: + elif event_name == "invocation_error": await self._handle_error_event(event) elif event_name == "session_canceled": await self._handle_cancel_event(event) @@ -429,7 +429,6 @@ def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> Sessio if queue_item.status not in ["canceled", "failed", "completed"]: 
status = "failed" if error is not None else "canceled" queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here - self.__invoker.services.queue.cancel(queue_item.session_id) self.__invoker.services.events.emit_session_canceled( queue_item_id=queue_item.item_id, queue_id=queue_item.queue_id, @@ -471,7 +470,6 @@ def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBa ) self.__conn.commit() if current_queue_item is not None and current_queue_item.batch_id in batch_ids: - self.__invoker.services.queue.cancel(current_queue_item.session_id) self.__invoker.services.events.emit_session_canceled( queue_item_id=current_queue_item.item_id, queue_id=current_queue_item.queue_id, @@ -523,7 +521,6 @@ def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: ) self.__conn.commit() if current_queue_item is not None and current_queue_item.queue_id == queue_id: - self.__invoker.services.queue.cancel(current_queue_item.session_id) self.__invoker.services.events.emit_session_canceled( queue_item_id=current_queue_item.item_id, queue_id=current_queue_item.queue_id, diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 43ecb2c543e..4606bd9e03b 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,6 +1,7 @@ +import threading from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Callable, Optional from PIL.Image import Image from torch import Tensor @@ -370,6 +371,12 @@ def get(self) -> InvokeAIAppConfig: class UtilInterface(InvocationContextInterface): + def __init__( + self, services: InvocationServices, context_data: InvocationContextData, is_canceled: Callable[[], bool] + ) -> None: + super().__init__(services, context_data) + self._is_canceled = is_canceled + def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: """ The step callback emits a progress event with the current step, the total number of @@ -390,8 +397,8 @@ def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_m context_data=self._context_data, intermediate_state=intermediate_state, base_model=base_model, - invocation_queue=self._services.queue, events=self._services.events, + is_canceled=self._is_canceled, ) @@ -412,6 +419,7 @@ def __init__( boards: BoardsInterface, context_data: InvocationContextData, services: InvocationServices, + is_canceled: Callable[[], bool], ) -> None: self.images = images """Provides methods to save, get and update images and their metadata.""" @@ -433,11 +441,13 @@ def __init__( """Provides data about the current queue item and invocation. This is an internal API and may change without warning.""" self._services = services """Provides access to the full application services. This is an internal API and may change without warning.""" + self._is_canceled = is_canceled def build_invocation_context( services: InvocationServices, context_data: InvocationContextData, + cancel_event: threading.Event, ) -> InvocationContext: """ Builds the invocation context for a specific invocation execution. @@ -446,12 +456,15 @@ def build_invocation_context( :param invocation_context_data: The invocation context data. 
""" + def is_canceled() -> bool: + return cancel_event.is_set() + logger = LoggerInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data) tensors = TensorsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data) config = ConfigInterface(services=services, context_data=context_data) - util = UtilInterface(services=services, context_data=context_data) + util = UtilInterface(services=services, context_data=context_data, is_canceled=is_canceled) conditioning = ConditioningInterface(services=services, context_data=context_data) boards = BoardsInterface(services=services, context_data=context_data) @@ -466,6 +479,7 @@ def build_invocation_context( conditioning=conditioning, services=services, boards=boards, + is_canceled=is_canceled, ) return ctx diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 33d00ca3660..9c9f5254a47 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,9 +1,9 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable import torch from PIL import Image -from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage +from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage from invokeai.backend.model_manager.config import BaseModelType from ...backend.stable_diffusion import PipelineIntermediateState @@ -11,7 +11,6 @@ if TYPE_CHECKING: from invokeai.app.services.events.events_base import EventServiceBase - from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC from invokeai.app.services.shared.invocation_context import InvocationContextData @@ -34,10 +33,10 @@ def stable_diffusion_step_callback( context_data: "InvocationContextData", intermediate_state: PipelineIntermediateState, base_model: BaseModelType, - invocation_queue: "InvocationQueueABC", events: "EventServiceBase", + is_canceled: Callable[[], bool], ) -> None: - if invocation_queue.is_canceled(context_data.session_id): + if is_canceled(): raise CanceledException # Some schedulers report not only the noisy latents at the current timestep, From fe702a9e80b375e2fd32c451ef96cfeb615251c4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 11:25:00 +1100 Subject: [PATCH 127/340] feat(nodes): promote `is_canceled` to public node API --- invokeai/app/services/shared/invocation_context.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 4606bd9e03b..317cbdbb23d 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -437,11 +437,12 @@ def __init__( """Provides utility methods.""" self.boards = boards """Provides methods to interact with boards.""" + self.is_canceled = is_canceled + """Checks if the current invocation has been canceled.""" self._data = context_data """Provides data about the current queue item and invocation. This is an internal API and may change without warning.""" self._services = services """Provides access to the full application services. 
This is an internal API and may change without warning.""" - self._is_canceled = is_canceled def build_invocation_context( @@ -457,6 +458,7 @@ def build_invocation_context( """ def is_canceled() -> bool: + """Checks if the current invocation has been canceled.""" return cancel_event.is_set() logger = LoggerInterface(services=services, context_data=context_data) From b4c8183b4d043d3070cb40382981de320ddb2e83 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 11:42:40 +1100 Subject: [PATCH 128/340] chore(nodes): add comments for cancel state --- .../session_processor/session_processor_default.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index dd34c782523..2ff76c06e4c 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -192,9 +192,20 @@ def stats_cleanup(graph_execution_state_id: str) -> None: ) except KeyboardInterrupt: + # TODO(psyche): should we set the cancel event here and/or cancel the queue item? pass except CanceledException: + # When the user cancels the graph, we first set the cancel event. The event is checked + # between invocations, in this loop. Some invocations are long-running, and we need to + # be able to cancel them mid-execution. + # + # For example, denoising is a long-running invocation with many steps. A step callback + # is executed after each step. This step callback checks if the canceled event is set, + # then raises a CanceledException to stop execution immediately. + # + # When we get a CanceledException, we don't need to do anything - just pass and let the + # loop go to its next iteration, and the cancel event will be handled correctly. 
pass except Exception as e: From 29483c39ca78d1a691c2b84141e85891d293ea1a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 11:42:53 +1100 Subject: [PATCH 129/340] feat(nodes): better invocation error messages --- .../services/session_processor/session_processor_default.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 2ff76c06e4c..f235544405a 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -213,7 +213,9 @@ def stats_cleanup(graph_execution_state_id: str) -> None: # Save error self._queue_item.session.set_node_error(invocation.id, error) - self._invoker.services.logger.error("Error while invoking:\n%s" % e) + self._invoker.services.logger.error( + f"Error while invoking session {self._queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}" + ) # Send error event self._invoker.services.events.emit_invocation_error( From d4cef6c8466b0debc3d721c6da4bc90947df6137 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 11:44:08 +1100 Subject: [PATCH 130/340] tidy(nodes): remove extraneous comments --- invokeai/app/services/shared/invocation_context.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 317cbdbb23d..6b6379dc5d3 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -388,11 +388,6 @@ def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_m :param base_model: The base model for the current denoising step. """ - # The step callback needs access to the events and the invocation queue services, but this - # represents a dangerous level of access. - # - # We wrap the step callback so that nodes do not have direct access to these services. - stable_diffusion_step_callback( context_data=self._context_data, intermediate_state=intermediate_state, @@ -458,7 +453,6 @@ def build_invocation_context( """ def is_canceled() -> bool: - """Checks if the current invocation has been canceled.""" return cancel_event.is_set() logger = LoggerInterface(services=services, context_data=context_data) From 9ea5019705b6eab70d11c22cf14524cc32646763 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 11:51:50 +1100 Subject: [PATCH 131/340] feat(nodes): add whole queue_item to InvocationContextData No reason to not have the whole thing in there. 
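
As a rough sketch of the resulting shape (field names are taken from the diff
below; the accessor examples are illustrative only, not part of the patch):

    from dataclasses import dataclass

    @dataclass
    class InvocationContextData:
        queue_item: "SessionQueueItem"    # the full queue item being executed
        invocation: "BaseInvocation"      # the invocation being executed
        source_invocation_id: str         # id of the invocation this one was prepared from

    # Session-scoped values are reached through the queue item, e.g.:
    #   data.queue_item.session_id, data.queue_item.workflow, data.queue_item.queue_id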
--- .../session_processor_default.py | 16 +++++--------- .../app/services/shared/invocation_context.py | 22 ++++++------------- invokeai/app/util/step_callback.py | 10 ++++----- 3 files changed, 18 insertions(+), 30 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index f235544405a..e49f79bcf3c 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -139,7 +139,7 @@ def stats_cleanup(graph_execution_state_id: str) -> None: # Loop over invocations until the session is complete or canceled while invocation is not None and not cancel_event.is_set(): # get the source node id to provide to clients (the prepared node id is not as useful) - source_node_id = self._queue_item.session.prepared_source_mapping[invocation.id] + source_invocation_id = self._queue_item.session.prepared_source_mapping[invocation.id] # Send starting event self._invoker.services.events.emit_invocation_started( @@ -148,7 +148,7 @@ def stats_cleanup(graph_execution_state_id: str) -> None: queue_id=self._queue_item.queue_id, graph_execution_state_id=self._queue_item.session_id, node=invocation.model_dump(), - source_node_id=source_node_id, + source_node_id=source_invocation_id, ) # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph @@ -159,12 +159,8 @@ def stats_cleanup(graph_execution_state_id: str) -> None: # Build invocation context (the node-facing API) context_data = InvocationContextData( invocation=invocation, - source_node_id=source_node_id, - session_id=self._queue_item.session.id, - workflow=self._queue_item.workflow, - queue_id=self._queue_item.queue_id, - queue_item_id=self._queue_item.item_id, - batch_id=self._queue_item.batch_id, + source_invocation_id=source_invocation_id, + queue_item=self._queue_item, ) context = build_invocation_context( context_data=context_data, @@ -187,7 +183,7 @@ def stats_cleanup(graph_execution_state_id: str) -> None: queue_id=self._queue_item.queue_id, graph_execution_state_id=self._queue_item.session.id, node=invocation.model_dump(), - source_node_id=source_node_id, + source_node_id=source_invocation_id, result=outputs.model_dump(), ) @@ -224,7 +220,7 @@ def stats_cleanup(graph_execution_state_id: str) -> None: queue_id=self._queue_item.queue_id, graph_execution_state_id=self._queue_item.session.id, node=invocation.model_dump(), - source_node_id=source_node_id, + source_node_id=source_invocation_id, error_type=e.__class__.__name__, error=error, ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 6b6379dc5d3..6b314d10bff 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -13,7 +13,6 @@ from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invocation_services import InvocationServices -from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.model_manager.load.load_base import LoadedModel @@ -23,6 
+22,7 @@ if TYPE_CHECKING: from invokeai.app.invocations.baseinvocation import BaseInvocation + from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem """ The InvocationContext provides access to various services and data about the current invocation. @@ -49,20 +49,12 @@ @dataclass class InvocationContextData: + queue_item: "SessionQueueItem" + """The queue item that is being executed.""" invocation: "BaseInvocation" """The invocation that is being executed.""" - session_id: str - """The session that is being executed.""" - queue_id: str - """The queue in which the session is being executed.""" - source_node_id: str - """The ID of the node from which the currently executing invocation was prepared.""" - queue_item_id: int - """The ID of the queue item that is being executed.""" - batch_id: str - """The ID of the batch that is being executed.""" - workflow: Optional[WorkflowWithoutID] = None - """The workflow associated with this queue item, if any.""" + source_invocation_id: str + """The ID of the invocation from which the currently executing invocation was prepared.""" class InvocationContextInterface: @@ -191,8 +183,8 @@ def save( board_id=board_id_, metadata=metadata_, image_origin=ResourceOrigin.INTERNAL, - workflow=self._context_data.workflow, - session_id=self._context_data.session_id, + workflow=self._context_data.queue_item.workflow, + session_id=self._context_data.queue_item.session_id, node_id=self._context_data.invocation.id, ) diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 9c9f5254a47..8cb59f5b3aa 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -114,12 +114,12 @@ def stable_diffusion_step_callback( dataURL = image_to_dataURL(image, image_format="JPEG") events.emit_generator_progress( - queue_id=context_data.queue_id, - queue_item_id=context_data.queue_item_id, - queue_batch_id=context_data.batch_id, - graph_execution_state_id=context_data.session_id, + queue_id=context_data.queue_item.queue_id, + queue_item_id=context_data.queue_item.item_id, + queue_batch_id=context_data.queue_item.batch_id, + graph_execution_state_id=context_data.queue_item.session_id, node_id=context_data.invocation.id, - source_node_id=context_data.source_node_id, + source_node_id=context_data.source_invocation_id, progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), step=intermediate_state.step, order=intermediate_state.order, From cadb40816685a1927afa0e38a18fcccb75b458ef Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 11:54:16 +1100 Subject: [PATCH 132/340] refactor(nodes): move is_canceled to `context.util` --- .../app/services/shared/invocation_context.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 6b314d10bff..994c99dc45b 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,7 +1,7 @@ import threading from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Optional from PIL.Image import Image from torch import Tensor @@ -364,10 +364,14 @@ def get(self) -> InvokeAIAppConfig: class UtilInterface(InvocationContextInterface): def __init__( - self, services: InvocationServices, 
context_data: InvocationContextData, is_canceled: Callable[[], bool] + self, services: InvocationServices, context_data: InvocationContextData, cancel_event: threading.Event ) -> None: super().__init__(services, context_data) - self._is_canceled = is_canceled + self._cancel_event = cancel_event + + def is_canceled(self) -> bool: + """Checks if the current invocation has been canceled.""" + return self._cancel_event.is_set() def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: """ @@ -385,7 +389,7 @@ def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_m intermediate_state=intermediate_state, base_model=base_model, events=self._services.events, - is_canceled=self._is_canceled, + is_canceled=self.is_canceled, ) @@ -406,7 +410,6 @@ def __init__( boards: BoardsInterface, context_data: InvocationContextData, services: InvocationServices, - is_canceled: Callable[[], bool], ) -> None: self.images = images """Provides methods to save, get and update images and their metadata.""" @@ -424,8 +427,6 @@ def __init__( """Provides utility methods.""" self.boards = boards """Provides methods to interact with boards.""" - self.is_canceled = is_canceled - """Checks if the current invocation has been canceled.""" self._data = context_data """Provides data about the current queue item and invocation. This is an internal API and may change without warning.""" self._services = services @@ -444,15 +445,12 @@ def build_invocation_context( :param invocation_context_data: The invocation context data. """ - def is_canceled() -> bool: - return cancel_event.is_set() - logger = LoggerInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data) tensors = TensorsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data) config = ConfigInterface(services=services, context_data=context_data) - util = UtilInterface(services=services, context_data=context_data, is_canceled=is_canceled) + util = UtilInterface(services=services, context_data=context_data, cancel_event=cancel_event) conditioning = ConditioningInterface(services=services, context_data=context_data) boards = BoardsInterface(services=services, context_data=context_data) @@ -467,7 +465,6 @@ def is_canceled() -> bool: conditioning=conditioning, services=services, boards=boards, - is_canceled=is_canceled, ) return ctx From e0919dd9adb4178893b13503f9908dbd4e553eb2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 11:56:54 +1100 Subject: [PATCH 133/340] chore(nodes): "context_data" -> "data" Changed within InvocationContext, for brevity. 
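
In practice this just shortens the attribute held by the interface base class;
a minimal sketch based on the diff below (not a literal excerpt):

    class InvocationContextInterface:
        def __init__(self, services: InvocationServices, data: InvocationContextData) -> None:
            self._services = services
            self._data = data  # previously self._context_data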
--- .../session_processor_default.py | 4 +- .../app/services/shared/invocation_context.py | 54 +++++++++---------- tests/test_graph_execution_state.py | 2 +- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index e49f79bcf3c..dc08fc83458 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -157,13 +157,13 @@ def stats_cleanup(graph_execution_state_id: str) -> None: invocation, self._queue_item.session.id ): # Build invocation context (the node-facing API) - context_data = InvocationContextData( + data = InvocationContextData( invocation=invocation, source_invocation_id=source_invocation_id, queue_item=self._queue_item, ) context = build_invocation_context( - context_data=context_data, + data=data, services=self._invoker.services, cancel_event=self._cancel_event, ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 994c99dc45b..f8425523bf2 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -58,9 +58,9 @@ class InvocationContextData: class InvocationContextInterface: - def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + def __init__(self, services: InvocationServices, data: InvocationContextData) -> None: self._services = services - self._context_data = context_data + self._data = data class BoardsInterface(InvocationContextInterface): @@ -166,26 +166,26 @@ def save( metadata_ = None if metadata: metadata_ = metadata - elif isinstance(self._context_data.invocation, WithMetadata): - metadata_ = self._context_data.invocation.metadata + elif isinstance(self._data.invocation, WithMetadata): + metadata_ = self._data.invocation.metadata # If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None. board_id_ = None if board_id: board_id_ = board_id - elif isinstance(self._context_data.invocation, WithBoard) and self._context_data.invocation.board: - board_id_ = self._context_data.invocation.board.board_id + elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board: + board_id_ = self._data.invocation.board.board_id return self._services.images.create( image=image, - is_intermediate=self._context_data.invocation.is_intermediate, + is_intermediate=self._data.invocation.is_intermediate, image_category=image_category, board_id=board_id_, metadata=metadata_, image_origin=ResourceOrigin.INTERNAL, - workflow=self._context_data.queue_item.workflow, - session_id=self._context_data.queue_item.session_id, - node_id=self._context_data.invocation.id, + workflow=self._data.queue_item.workflow, + session_id=self._data.queue_item.session_id, + node_id=self._data.invocation.id, ) def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image: @@ -285,7 +285,7 @@ def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> Loaded # the event payloads. 
return self._services.model_manager.load_model_by_key( - key=key, submodel_type=submodel_type, context_data=self._context_data + key=key, submodel_type=submodel_type, context_data=self._data ) def load_by_attrs( @@ -304,7 +304,7 @@ def load_by_attrs( base_model=base_model, model_type=model_type, submodel=submodel, - context_data=self._context_data, + context_data=self._data, ) def get_config(self, key: str) -> AnyModelConfig: @@ -364,9 +364,9 @@ def get(self) -> InvokeAIAppConfig: class UtilInterface(InvocationContextInterface): def __init__( - self, services: InvocationServices, context_data: InvocationContextData, cancel_event: threading.Event + self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event ) -> None: - super().__init__(services, context_data) + super().__init__(services, data) self._cancel_event = cancel_event def is_canceled(self) -> bool: @@ -385,7 +385,7 @@ def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_m """ stable_diffusion_step_callback( - context_data=self._context_data, + context_data=self._data, intermediate_state=intermediate_state, base_model=base_model, events=self._services.events, @@ -408,7 +408,7 @@ def __init__( config: ConfigInterface, util: UtilInterface, boards: BoardsInterface, - context_data: InvocationContextData, + data: InvocationContextData, services: InvocationServices, ) -> None: self.images = images @@ -427,7 +427,7 @@ def __init__( """Provides utility methods.""" self.boards = boards """Provides methods to interact with boards.""" - self._data = context_data + self._data = data """Provides data about the current queue item and invocation. This is an internal API and may change without warning.""" self._services = services """Provides access to the full application services. This is an internal API and may change without warning.""" @@ -435,7 +435,7 @@ def __init__( def build_invocation_context( services: InvocationServices, - context_data: InvocationContextData, + data: InvocationContextData, cancel_event: threading.Event, ) -> InvocationContext: """ @@ -445,14 +445,14 @@ def build_invocation_context( :param invocation_context_data: The invocation context data. 
""" - logger = LoggerInterface(services=services, context_data=context_data) - images = ImagesInterface(services=services, context_data=context_data) - tensors = TensorsInterface(services=services, context_data=context_data) - models = ModelsInterface(services=services, context_data=context_data) - config = ConfigInterface(services=services, context_data=context_data) - util = UtilInterface(services=services, context_data=context_data, cancel_event=cancel_event) - conditioning = ConditioningInterface(services=services, context_data=context_data) - boards = BoardsInterface(services=services, context_data=context_data) + logger = LoggerInterface(services=services, data=data) + images = ImagesInterface(services=services, data=data) + tensors = TensorsInterface(services=services, data=data) + models = ModelsInterface(services=services, data=data) + config = ConfigInterface(services=services, data=data) + util = UtilInterface(services=services, data=data, cancel_event=cancel_event) + conditioning = ConditioningInterface(services=services, data=data) + boards = BoardsInterface(services=services, data=data) ctx = InvocationContext( images=images, @@ -460,7 +460,7 @@ def build_invocation_context( config=config, tensors=tensors, models=models, - context_data=context_data, + data=data, util=util, conditioning=conditioning, services=services, diff --git a/tests/test_graph_execution_state.py b/tests/test_graph_execution_state.py index f839a4a8785..9cff502acf7 100644 --- a/tests/test_graph_execution_state.py +++ b/tests/test_graph_execution_state.py @@ -86,7 +86,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B InvocationContext( conditioning=None, config=None, - context_data=None, + data=None, images=None, tensors=None, logger=None, From d2038e45fdb5592402e0d6f5558b40dadf679ebe Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 12:03:43 +1100 Subject: [PATCH 134/340] chore(nodes): better comments for invocation context --- .../app/services/shared/invocation_context.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index f8425523bf2..31064a5e7cc 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -247,7 +247,7 @@ def save(self, conditioning_data: ConditioningFieldData) -> str: """ Saves a conditioning data object, returning its name. - :param conditioning_context_data: The conditioning data to save. + :param conditioning_data: The conditioning data to save. 
""" name = self._services.conditioning.save(obj=conditioning_data) @@ -412,25 +412,25 @@ def __init__( services: InvocationServices, ) -> None: self.images = images - """Provides methods to save, get and update images and their metadata.""" + """Methods to save, get and update images and their metadata.""" self.tensors = tensors - """Provides methods to save and get tensors, including image, noise, masks, and masked images.""" + """Methods to save and get tensors, including image, noise, masks, and masked images.""" self.conditioning = conditioning - """Provides methods to save and get conditioning data.""" + """Methods to save and get conditioning data.""" self.models = models - """Provides methods to check if a model exists, get a model, and get a model's info.""" + """Methods to check if a model exists, get a model, and get a model's info.""" self.logger = logger - """Provides access to the app logger.""" + """The app logger.""" self.config = config - """Provides access to the app's config.""" + """The app config.""" self.util = util - """Provides utility methods.""" + """Utility methods, including a method to check if an invocation was canceled and step callbacks.""" self.boards = boards - """Provides methods to interact with boards.""" + """Methods to interact with boards.""" self._data = data - """Provides data about the current queue item and invocation. This is an internal API and may change without warning.""" + """An internal API providing access to data about the current queue item and invocation. You probably shouldn't use this. It may change without warning.""" self._services = services - """Provides access to the full application services. This is an internal API and may change without warning.""" + """An internal API providing access to all application services. You probably shouldn't use this. It may change without warning.""" def build_invocation_context( @@ -441,8 +441,8 @@ def build_invocation_context( """ Builds the invocation context for a specific invocation execution. - :param invocation_services: The invocation services to wrap. - :param invocation_context_data: The invocation context data. + :param services: The invocation services to wrap. + :param data: The invocation context data. """ logger = LoggerInterface(services=services, data=data) From 26f0af47076c3a9337553e8db2f6d7e1ffbcfbec Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 12:29:08 +1100 Subject: [PATCH 135/340] tests(nodes): fix tests following removal of services --- tests/test_graph_execution_state.py | 89 ++++++--------- tests/test_invoker.py | 163 ---------------------------- 2 files changed, 35 insertions(+), 217 deletions(-) delete mode 100644 tests/test_invoker.py diff --git a/tests/test_graph_execution_state.py b/tests/test_graph_execution_state.py index 9cff502acf7..2e88178424a 100644 --- a/tests/test_graph_execution_state.py +++ b/tests/test_graph_execution_state.py @@ -1,9 +1,9 @@ import logging +from typing import Optional +from unittest.mock import Mock import pytest -from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory - # This import must happen before other invoke imports or test in other files(!!) 
break from .test_nodes import ( # isort: split PromptCollectionTestInvocation, @@ -17,8 +17,6 @@ from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache -from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor -from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService from invokeai.app.services.shared.graph import ( @@ -28,11 +26,11 @@ IterateInvocation, ) -from .test_invoker import create_edge +from .test_nodes import create_edge @pytest.fixture -def simple_graph(): +def simple_graph() -> Graph: g = Graph() g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) g.add_node(TextToImageTestInvocation(id="2")) @@ -47,7 +45,6 @@ def simple_graph(): def mock_services() -> InvocationServices: configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) # NOTE: none of these are actually called by the test invocations - graph_execution_manager = ItemStorageMemory[GraphExecutionState]() return InvocationServices( board_image_records=None, # type: ignore board_images=None, # type: ignore @@ -55,7 +52,6 @@ def mock_services() -> InvocationServices: boards=None, # type: ignore configuration=configuration, events=TestEventService(), - graph_execution_manager=graph_execution_manager, image_files=None, # type: ignore image_records=None, # type: ignore images=None, # type: ignore @@ -65,47 +61,32 @@ def mock_services() -> InvocationServices: download_queue=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), - processor=DefaultInvocationProcessor(), - queue=MemoryInvocationQueue(), session_processor=None, # type: ignore session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore - tensors=None, - conditioning=None, + tensors=None, # type: ignore + conditioning=None, # type: ignore ) -def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: +def invoke_next(g: GraphExecutionState) -> tuple[Optional[BaseInvocation], Optional[BaseInvocationOutput]]: n = g.next() if n is None: return (None, None) print(f"invoking {n.id}: {type(n)}") - o = n.invoke( - InvocationContext( - conditioning=None, - config=None, - data=None, - images=None, - tensors=None, - logger=None, - models=None, - util=None, - boards=None, - services=None, - ) - ) + o = n.invoke(Mock(InvocationContext)) g.complete(n.id, o) return (n, o) -def test_graph_state_executes_in_order(simple_graph, mock_services): +def test_graph_state_executes_in_order(simple_graph: Graph): g = GraphExecutionState(graph=simple_graph) - n1 = invoke_next(g, mock_services) - n2 = invoke_next(g, mock_services) + n1 = invoke_next(g) + n2 = invoke_next(g) n3 = g.next() assert g.prepared_source_mapping[n1[0].id] == "1" @@ -115,18 +96,18 @@ def test_graph_state_executes_in_order(simple_graph, mock_services): assert n2[0].prompt == n1[0].prompt -def test_graph_is_complete(simple_graph, mock_services): +def test_graph_is_complete(simple_graph: Graph): g = GraphExecutionState(graph=simple_graph) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) + 
_ = invoke_next(g) + _ = invoke_next(g) _ = g.next() assert g.is_complete() -def test_graph_is_not_complete(simple_graph, mock_services): +def test_graph_is_not_complete(simple_graph: Graph): g = GraphExecutionState(graph=simple_graph) - _ = invoke_next(g, mock_services) + _ = invoke_next(g) _ = g.next() assert not g.is_complete() @@ -135,7 +116,7 @@ def test_graph_is_not_complete(simple_graph, mock_services): # TODO: test completion with iterators/subgraphs -def test_graph_state_expands_iterator(mock_services): +def test_graph_state_expands_iterator(): graph = Graph() graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1)) graph.add_node(IterateInvocation(id="1")) @@ -147,7 +128,7 @@ def test_graph_state_expands_iterator(mock_services): g = GraphExecutionState(graph=graph) while not g.is_complete(): - invoke_next(g, mock_services) + invoke_next(g) prepared_add_nodes = g.source_prepared_mapping["3"] results = {g.results[n].value for n in prepared_add_nodes} @@ -155,7 +136,7 @@ def test_graph_state_expands_iterator(mock_services): assert results == expected -def test_graph_state_collects(mock_services): +def test_graph_state_collects(): graph = Graph() test_prompts = ["Banana sushi", "Cat sushi"] graph.add_node(PromptCollectionTestInvocation(id="1", collection=list(test_prompts))) @@ -167,19 +148,19 @@ def test_graph_state_collects(mock_services): graph.add_edge(create_edge("3", "prompt", "4", "item")) g = GraphExecutionState(graph=graph) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) - n6 = invoke_next(g, mock_services) + _ = invoke_next(g) + _ = invoke_next(g) + _ = invoke_next(g) + _ = invoke_next(g) + _ = invoke_next(g) + n6 = invoke_next(g) assert isinstance(n6[0], CollectInvocation) assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts) -def test_graph_state_prepares_eagerly(mock_services): +def test_graph_state_prepares_eagerly(): """Tests that all prepareable nodes are prepared""" graph = Graph() @@ -208,7 +189,7 @@ def test_graph_state_prepares_eagerly(mock_services): assert "prompt_iterated" not in g.source_prepared_mapping -def test_graph_executes_depth_first(mock_services): +def test_graph_executes_depth_first(): """Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch""" graph = Graph() @@ -222,14 +203,14 @@ def test_graph_executes_depth_first(mock_services): graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt")) g = GraphExecutionState(graph=graph) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) - _ = invoke_next(g, mock_services) + _ = invoke_next(g) + _ = invoke_next(g) + _ = invoke_next(g) + _ = invoke_next(g) # Because ordering is not guaranteed, we cannot compare results directly. # Instead, we must count the number of results. 
- def get_completed_count(g, id): + def get_completed_count(g: GraphExecutionState, id: str): ids = list(g.source_prepared_mapping[id]) completed_ids = [i for i in g.executed if i in ids] return len(completed_ids) @@ -238,17 +219,17 @@ def get_completed_count(g, id): assert get_completed_count(g, "prompt_iterated") == 1 assert get_completed_count(g, "prompt_successor") == 0 - _ = invoke_next(g, mock_services) + _ = invoke_next(g) assert get_completed_count(g, "prompt_iterated") == 1 assert get_completed_count(g, "prompt_successor") == 1 - _ = invoke_next(g, mock_services) + _ = invoke_next(g) assert get_completed_count(g, "prompt_iterated") == 2 assert get_completed_count(g, "prompt_successor") == 1 - _ = invoke_next(g, mock_services) + _ = invoke_next(g) assert get_completed_count(g, "prompt_iterated") == 2 assert get_completed_count(g, "prompt_successor") == 2 diff --git a/tests/test_invoker.py b/tests/test_invoker.py deleted file mode 100644 index 38fcf859a58..00000000000 --- a/tests/test_invoker.py +++ /dev/null @@ -1,163 +0,0 @@ -import logging -from unittest.mock import Mock - -import pytest - -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory - -# This import must happen before other invoke imports or test in other files(!!) break -from .test_nodes import ( # isort: split - ErrorInvocation, - PromptTestInvocation, - TestEventService, - TextToImageTestInvocation, - create_edge, - wait_until, -) - -from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache -from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor -from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue -from invokeai.app.services.invocation_services import InvocationServices -from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService -from invokeai.app.services.invoker import Invoker -from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID -from invokeai.app.services.shared.graph import Graph, GraphExecutionState - - -@pytest.fixture -def simple_graph(): - g = Graph() - g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) - g.add_node(TextToImageTestInvocation(id="2")) - g.add_edge(create_edge("1", "prompt", "2", "prompt")) - return g - - -# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types -# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate -# the test invocations. 
-@pytest.fixture -def mock_services() -> InvocationServices: - configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) - return InvocationServices( - board_image_records=None, # type: ignore - board_images=None, # type: ignore - board_records=None, # type: ignore - boards=None, # type: ignore - configuration=configuration, - events=TestEventService(), - graph_execution_manager=ItemStorageMemory[GraphExecutionState](), - image_files=None, # type: ignore - image_records=None, # type: ignore - images=None, # type: ignore - invocation_cache=MemoryInvocationCache(max_cache_size=0), - logger=logging, # type: ignore - model_manager=Mock(), # type: ignore - download_queue=None, # type: ignore - names=None, # type: ignore - performance_statistics=InvocationStatsService(), - processor=DefaultInvocationProcessor(), - queue=MemoryInvocationQueue(), - session_processor=None, # type: ignore - session_queue=None, # type: ignore - urls=None, # type: ignore - workflow_records=None, # type: ignore - tensors=None, - conditioning=None, - ) - - -@pytest.fixture() -def mock_invoker(mock_services: InvocationServices) -> Invoker: - return Invoker(services=mock_services) - - -def test_can_create_graph_state(mock_invoker: Invoker): - g = mock_invoker.create_execution_state() - mock_invoker.stop() - - assert g is not None - assert isinstance(g, GraphExecutionState) - - -def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph=simple_graph) - mock_invoker.stop() - - assert g is not None - assert isinstance(g, GraphExecutionState) - assert g.graph == simple_graph - - -# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") -def test_can_invoke(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph=simple_graph) - invocation_id = mock_invoker.invoke( - session_queue_batch_id="1", - session_queue_item_id=1, - session_queue_id=DEFAULT_QUEUE_ID, - graph_execution_state=g, - ) - assert invocation_id is not None - - def has_executed_any(g: GraphExecutionState): - g = mock_invoker.services.graph_execution_manager.get(g.id) - return len(g.executed) > 0 - - wait_until(lambda: has_executed_any(g), timeout=5, interval=1) - mock_invoker.stop() - - g = mock_invoker.services.graph_execution_manager.get(g.id) - assert len(g.executed) > 0 - - -# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") -def test_can_invoke_all(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph=simple_graph) - invocation_id = mock_invoker.invoke( - session_queue_batch_id="1", - session_queue_item_id=1, - session_queue_id=DEFAULT_QUEUE_ID, - graph_execution_state=g, - invoke_all=True, - ) - assert invocation_id is not None - - def has_executed_all(g: GraphExecutionState): - g = mock_invoker.services.graph_execution_manager.get(g.id) - return g.is_complete() - - wait_until(lambda: has_executed_all(g), timeout=5, interval=1) - mock_invoker.stop() - - g = mock_invoker.services.graph_execution_manager.get(g.id) - assert g.is_complete() - - -# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") -def test_handles_errors(mock_invoker: Invoker): - g = mock_invoker.create_execution_state() - g.graph.add_node(ErrorInvocation(id="1")) - - mock_invoker.invoke( - session_queue_batch_id="1", - session_queue_item_id=1, - session_queue_id=DEFAULT_QUEUE_ID, - graph_execution_state=g, - invoke_all=True, - ) - - def 
has_executed_all(g: GraphExecutionState): - g = mock_invoker.services.graph_execution_manager.get(g.id) - return g.is_complete() - - wait_until(lambda: has_executed_all(g), timeout=5, interval=1) - mock_invoker.stop() - - g = mock_invoker.services.graph_execution_manager.get(g.id) - assert g.has_error() - assert g.is_complete() - - assert all((i in g.errors for i in g.source_prepared_mapping["1"])) From 475cfb61d731b34edbe5e3fede7b3167e46fff85 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 12:33:16 +1100 Subject: [PATCH 136/340] feat(nodes): make processor thread limit and polling interval configurable --- .../session_processor_default.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index dc08fc83458..3035a74a5a9 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -18,12 +18,9 @@ from .session_processor_base import SessionProcessorBase from .session_processor_common import SessionProcessorStatus -POLLING_INTERVAL = 1 -THREAD_LIMIT = 1 - class DefaultSessionProcessor(SessionProcessorBase): - def start(self, invoker: Invoker) -> None: + def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None: self._invoker: Invoker = invoker self._queue_item: Optional[SessionQueueItem] = None @@ -34,7 +31,10 @@ def start(self, invoker: Invoker) -> None: local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event) - self._thread_limit = BoundedSemaphore(THREAD_LIMIT) + self._thread_limit = thread_limit + self._thread_semaphore = BoundedSemaphore(thread_limit) + self._polling_interval = polling_interval + self._thread = Thread( name="session_processor", target=self._process, @@ -88,7 +88,7 @@ def _process( ): # Outermost processor try block; any unhandled exception is a fatal processor error try: - self._thread_limit.acquire() + self._thread_semaphore.acquire() stop_event.clear() resume_event.set() cancel_event.clear() @@ -247,7 +247,7 @@ def stats_cleanup(graph_execution_state_id: str) -> None: else: # The queue was empty, wait for next polling interval or event to try again self._invoker.services.logger.debug("Waiting for next polling interval or event") - poll_now_event.wait(POLLING_INTERVAL) + poll_now_event.wait(self._polling_interval) continue except Exception as e: # Non-fatal error in processor, cancel the queue item and wait for next polling interval or event @@ -256,7 +256,7 @@ def stats_cleanup(graph_execution_state_id: str) -> None: self._invoker.services.session_queue.cancel_queue_item( self._queue_item.item_id, error=traceback.format_exc() ) - poll_now_event.wait(POLLING_INTERVAL) + poll_now_event.wait(self._polling_interval) continue except Exception as e: # Fatal error in processor, log and pass - we're done here @@ -266,4 +266,4 @@ def stats_cleanup(graph_execution_state_id: str) -> None: stop_event.clear() poll_now_event.clear() self._queue_item = None - self._thread_limit.release() + self._thread_semaphore.release() From 1b8922f7e2db4f2d8430604e20909ce6d23dc1a5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 15:58:53 +1100 Subject: [PATCH 137/340] feat(nodes): improved error messages in processor 
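The diff below drops the `except Exception as e` bindings and logs the full traceback rather than only the exception message. A minimal sketch of that logging pattern, with illustrative names rather than the exact processor code:

    import logging
    import traceback

    logger = logging.getLogger("session_processor")

    def run_queue_item() -> None:
        raise RuntimeError("simulated invocation failure")  # stand-in for the real per-item work

    try:
        run_queue_item()
    except Exception:
        # Log the complete traceback rather than str(e), so failures raised
        # deep inside an invocation remain diagnosable from the log output.
        logger.error("Non-fatal error in session processor:\n%s", traceback.format_exc())
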
--- .../session_processor/session_processor_default.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 3035a74a5a9..7d761e627f4 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -249,18 +249,20 @@ def stats_cleanup(graph_execution_state_id: str) -> None: self._invoker.services.logger.debug("Waiting for next polling interval or event") poll_now_event.wait(self._polling_interval) continue - except Exception as e: + except Exception: # Non-fatal error in processor, cancel the queue item and wait for next polling interval or event - self._invoker.services.logger.error(f"Error in session processor: {e}") + self._invoker.services.logger.error( + f"Non-fatal error in session processor:\n{traceback.format_exc()}" + ) if self._queue_item is not None: self._invoker.services.session_queue.cancel_queue_item( self._queue_item.item_id, error=traceback.format_exc() ) poll_now_event.wait(self._polling_interval) continue - except Exception as e: + except Exception: # Fatal error in processor, log and pass - we're done here - self._invoker.services.logger.error(f"Fatal Error in session processor: {e}") + self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}") pass finally: stop_event.clear() From c07ace5d07a0a0d2b7b09ea7b341a918ebc551dd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 16:47:55 +1100 Subject: [PATCH 138/340] feat(nodes): making invocation class var in processor --- .../session_processor_default.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 7d761e627f4..9ba726bff3b 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -7,6 +7,7 @@ from fastapi_events.handlers.local import local_handler from fastapi_events.typing import Event as FastAPIEvent +from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError from invokeai.app.services.session_processor.session_processor_common import CanceledException @@ -23,6 +24,7 @@ class DefaultSessionProcessor(SessionProcessorBase): def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None: self._invoker: Invoker = invoker self._queue_item: Optional[SessionQueueItem] = None + self._invocation: Optional[BaseInvocation] = None self._resume_event = ThreadEvent() self._stop_event = ThreadEvent() @@ -134,12 +136,12 @@ def stats_cleanup(graph_execution_state_id: str) -> None: profiler.start(profile_id=self._queue_item.session_id) # Prepare invocations and take the first - invocation = self._queue_item.session.next() + self._invocation = self._queue_item.session.next() # Loop over invocations until the session is complete or canceled - while invocation is not None and not cancel_event.is_set(): + while self._invocation is not None and not cancel_event.is_set(): # get the source 
node id to provide to clients (the prepared node id is not as useful) - source_invocation_id = self._queue_item.session.prepared_source_mapping[invocation.id] + source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id] # Send starting event self._invoker.services.events.emit_invocation_started( @@ -147,18 +149,18 @@ def stats_cleanup(graph_execution_state_id: str) -> None: queue_item_id=self._queue_item.item_id, queue_id=self._queue_item.queue_id, graph_execution_state_id=self._queue_item.session_id, - node=invocation.model_dump(), + node=self._invocation.model_dump(), source_node_id=source_invocation_id, ) # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph try: with self._invoker.services.performance_statistics.collect_stats( - invocation, self._queue_item.session.id + self._invocation, self._queue_item.session.id ): # Build invocation context (the node-facing API) data = InvocationContextData( - invocation=invocation, + invocation=self._invocation, source_invocation_id=source_invocation_id, queue_item=self._queue_item, ) @@ -169,12 +171,12 @@ def stats_cleanup(graph_execution_state_id: str) -> None: ) # Invoke the node - outputs = invocation.invoke_internal( + outputs = self._invocation.invoke_internal( context=context, services=self._invoker.services ) # Save outputs and history - self._queue_item.session.complete(invocation.id, outputs) + self._queue_item.session.complete(self._invocation.id, outputs) # Send complete event self._invoker.services.events.emit_invocation_complete( @@ -182,7 +184,7 @@ def stats_cleanup(graph_execution_state_id: str) -> None: queue_item_id=self._queue_item.item_id, queue_id=self._queue_item.queue_id, graph_execution_state_id=self._queue_item.session.id, - node=invocation.model_dump(), + node=self._invocation.model_dump(), source_node_id=source_invocation_id, result=outputs.model_dump(), ) @@ -208,9 +210,9 @@ def stats_cleanup(graph_execution_state_id: str) -> None: error = traceback.format_exc() # Save error - self._queue_item.session.set_node_error(invocation.id, error) + self._queue_item.session.set_node_error(self._invocation.id, error) self._invoker.services.logger.error( - f"Error while invoking session {self._queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}" + f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}" ) # Send error event @@ -219,7 +221,7 @@ def stats_cleanup(graph_execution_state_id: str) -> None: queue_item_id=self._queue_item.item_id, queue_id=self._queue_item.queue_id, graph_execution_state_id=self._queue_item.session.id, - node=invocation.model_dump(), + node=self._invocation.model_dump(), source_node_id=source_invocation_id, error_type=e.__class__.__name__, error=error, @@ -236,10 +238,10 @@ def stats_cleanup(graph_execution_state_id: str) -> None: ) # Save the stats and stop the profiler if it's running stats_cleanup(self._queue_item.session.id) - invocation = None + self._invocation = None else: # Prepare the next invocation - invocation = self._queue_item.session.next() + self._invocation = self._queue_item.session.next() # The session is complete, immediately poll for next session self._queue_item = None From c2bd829e78f269c62bd619274f69b3de72c31f75 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 18 Feb 2024 16:51:58 +1100 Subject: [PATCH 139/340] fix(nodes): fix model load 
events was accessing incorrect properties in event data --- .../services/model_load/model_load_default.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 24ab10b4273..3ff7898c0e4 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -97,17 +97,17 @@ def _emit_load_event( if not loaded: self._invoker.services.events.emit_model_load_started( - queue_id=context_data.queue_id, - queue_item_id=context_data.queue_item_id, - queue_batch_id=context_data.batch_id, - graph_execution_state_id=context_data.session_id, + queue_id=context_data.queue_item.queue_id, + queue_item_id=context_data.queue_item.item_id, + queue_batch_id=context_data.queue_item.batch_id, + graph_execution_state_id=context_data.queue_item.session_id, model_config=model_config, ) else: self._invoker.services.events.emit_model_load_completed( - queue_id=context_data.queue_id, - queue_item_id=context_data.queue_item_id, - queue_batch_id=context_data.batch_id, - graph_execution_state_id=context_data.session_id, + queue_id=context_data.queue_item.queue_id, + queue_item_id=context_data.queue_item.item_id, + queue_batch_id=context_data.queue_item.batch_id, + graph_execution_state_id=context_data.queue_item.session_id, model_config=model_config, ) From 269c0e682351c584c87a8566a95728fb6b660ef7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 19 Feb 2024 12:57:05 +1100 Subject: [PATCH 140/340] fix(nodes): fix typing on stats service context manager --- .../app/services/invocation_stats/invocation_stats_base.py | 4 ++-- .../app/services/invocation_stats/invocation_stats_default.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/invokeai/app/services/invocation_stats/invocation_stats_base.py b/invokeai/app/services/invocation_stats/invocation_stats_base.py index b28220e74c4..3266d985fef 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_base.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_base.py @@ -30,7 +30,7 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Iterator +from typing import ContextManager from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary @@ -50,7 +50,7 @@ def collect_stats( self, invocation: BaseInvocation, graph_execution_state_id: str, - ) -> Iterator[None]: + ) -> ContextManager[None]: """ Return a context object that will capture the statistics on the execution of invocaation. Use with: to place around the part of the code that executes the invocation. 
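The hunk above retypes the abstract `collect_stats` to return `ContextManager[None]`, while the hunk below keeps the `@contextmanager`-decorated implementation annotated as a generator. A minimal, self-contained sketch of that typing pattern follows; the class and method names are illustrative, not the real service:

    from contextlib import contextmanager
    from typing import ContextManager, Generator

    class StatsServiceBase:
        # Callers see a context manager, regardless of how subclasses build it.
        def collect_stats(self, graph_execution_state_id: str) -> ContextManager[None]:
            raise NotImplementedError

    class StatsService(StatsServiceBase):
        # A @contextmanager function is written as a generator; the decorator
        # wraps it so the call still hands a ContextManager[None] to callers.
        @contextmanager
        def collect_stats(self, graph_execution_state_id: str) -> Generator[None, None, None]:
            try:
                yield None
            finally:
                pass  # timing/memory bookkeeping would go here

    with StatsService().collect_stats("some-graph-id"):
        pass  # the invocation would execute inside this block
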
diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index 06a5b675c31..5a41f1f5d6b 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -2,7 +2,7 @@ import time from contextlib import contextmanager from pathlib import Path -from typing import Iterator +from typing import Generator import psutil import torch @@ -41,7 +41,7 @@ def start(self, invoker: Invoker) -> None: self._invoker = invoker @contextmanager - def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]: + def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Generator[None, None, None]: # This is to handle case of the model manager not being initialized, which happens # during some tests. services = self._invoker.services From b8f3b4f9eb90eb96bdaa1cfc78284006a21cfc3b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 19 Feb 2024 13:09:00 +1100 Subject: [PATCH 141/340] tidy(nodes): clean up profiler/stats in processor, better comments --- .../session_processor_default.py | 65 ++++++++++--------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 9ba726bff3b..cff7bb6c6c5 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -37,6 +37,18 @@ def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = self._thread_semaphore = BoundedSemaphore(thread_limit) self._polling_interval = polling_interval + # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, + # the profiler will create a new profile for each session. + self._profiler = ( + Profiler( + logger=self._invoker.services.logger, + output_dir=self._invoker.services.configuration.profiles_path, + prefix=self._invoker.services.configuration.profile_prefix, + ) + if self._invoker.services.configuration.profile_graphs + else None + ) + self._thread = Thread( name="session_processor", target=self._process, @@ -95,32 +107,6 @@ def _process( resume_event.set() cancel_event.clear() - # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, - # the profiler will create a new profile for each session. - profiler = ( - Profiler( - logger=self._invoker.services.logger, - output_dir=self._invoker.services.configuration.profiles_path, - prefix=self._invoker.services.configuration.profile_prefix, - ) - if self._invoker.services.configuration.profile_graphs - else None - ) - - # Helper function to stop the profiler and save the stats - def stats_cleanup(graph_execution_state_id: str) -> None: - if profiler: - profile_path = profiler.stop() - stats_path = profile_path.with_suffix(".json") - self._invoker.services.performance_statistics.dump_stats( - graph_execution_state_id=graph_execution_state_id, output_path=stats_path - ) - # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor - # we don't care about that - suppress the error. 
- with suppress(GESStatsNotFoundError): - self._invoker.services.performance_statistics.log_stats(graph_execution_state_id) - self._invoker.services.performance_statistics.reset_stats() - while not stop_event.is_set(): poll_now_event.clear() # Middle processor try block; any unhandled exception is a non-fatal processor error @@ -132,8 +118,8 @@ def stats_cleanup(graph_execution_state_id: str) -> None: cancel_event.clear() # If profiling is enabled, start the profiler - if profiler is not None: - profiler.start(profile_id=self._queue_item.session_id) + if self._profiler is not None: + self._profiler.start(profile_id=self._queue_item.session_id) # Prepare invocations and take the first self._invocation = self._queue_item.session.next() @@ -228,6 +214,7 @@ def stats_cleanup(graph_execution_state_id: str) -> None: ) pass + # The session is complete if the all invocations are complete or there was an error if self._queue_item.session.is_complete() or cancel_event.is_set(): # Send complete event self._invoker.services.events.emit_graph_execution_complete( @@ -236,8 +223,20 @@ def stats_cleanup(graph_execution_state_id: str) -> None: queue_id=self._queue_item.queue_id, graph_execution_state_id=self._queue_item.session.id, ) - # Save the stats and stop the profiler if it's running - stats_cleanup(self._queue_item.session.id) + # If we are profiling, stop the profiler and dump the profile & stats + if self._profiler: + profile_path = self._profiler.stop() + stats_path = profile_path.with_suffix(".json") + self._invoker.services.performance_statistics.dump_stats( + graph_execution_state_id=self._queue_item.session.id, output_path=stats_path + ) + # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor + # we don't care about that - suppress the error. 
+ with suppress(GESStatsNotFoundError): + self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id) + self._invoker.services.performance_statistics.reset_stats() + + # Set the invocation to None to prepare for the next session self._invocation = None else: # Prepare the next invocation @@ -252,14 +251,18 @@ def stats_cleanup(graph_execution_state_id: str) -> None: poll_now_event.wait(self._polling_interval) continue except Exception: - # Non-fatal error in processor, cancel the queue item and wait for next polling interval or event + # Non-fatal error in processor self._invoker.services.logger.error( f"Non-fatal error in session processor:\n{traceback.format_exc()}" ) + # Cancel the queue item if self._queue_item is not None: self._invoker.services.session_queue.cancel_queue_item( self._queue_item.item_id, error=traceback.format_exc() ) + # Reset the invocation to None to prepare for the next session + self._invocation = None + # Immediately poll for next queue item poll_now_event.wait(self._polling_interval) continue except Exception: From 934fc82920b46426c5e1546d33983e174d295654 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 19 Feb 2024 13:16:06 +1100 Subject: [PATCH 142/340] chore(nodes): update TODO comment --- .../app/services/session_processor/session_processor_default.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index cff7bb6c6c5..c0b98220c87 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -176,7 +176,7 @@ def _process( ) except KeyboardInterrupt: - # TODO(psyche): should we set the cancel event here and/or cancel the queue item? 
+ # TODO(MM2): Create an event for this pass except CanceledException: From 71ff5b07a716603a747d5c69add099a2b4e016ef Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 19 Feb 2024 11:40:53 -0500 Subject: [PATCH 143/340] Add a few convenience targets to Makefile - "test" to run pytests - "frontend-install" to reinstall pnpm's node modeuls --- Makefile | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/Makefile b/Makefile index 10d7a257c55..c3eec094f79 100644 --- a/Makefile +++ b/Makefile @@ -6,33 +6,44 @@ default: help help: @echo Developer commands: @echo - @echo "ruff Run ruff, fixing any safely-fixable errors and formatting" - @echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting" - @echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors" - @echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports" - @echo "frontend-build Build the frontend in order to run on localhost:9090" - @echo "frontend-dev Run the frontend in developer mode on localhost:5173" - @echo "installer-zip Build the installer .zip file for the current version" - @echo "tag-release Tag the GitHub repository with the current version (use at release time only!)" + @echo "ruff Run ruff, fixing any safely-fixable errors and formatting" + @echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting" + @echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors" + @echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports" + @echo "test" Run the unit tests. + @echo "frontend-install" Install the pnpm modules needed for the front end + @echo "frontend-build Build the frontend in order to run on localhost:9090" + @echo "frontend-dev Run the frontend in developer mode on localhost:5173" + @echo "installer-zip Build the installer .zip file for the current version" + @echo "tag-release Tag the GitHub repository with the current version (use at release time only!)" # Runs ruff, fixing any safely-fixable errors and formatting ruff: - ruff check . --fix - ruff format . + ruff check . --fix + ruff format . # Runs ruff, fixing all errors it can fix and formatting ruff-unsafe: - ruff check . --fix --unsafe-fixes - ruff format . + ruff check . --fix --unsafe-fixes + ruff format . 
# Runs mypy, using the config in pyproject.toml mypy: - mypy scripts/invokeai-web.py + mypy scripts/invokeai-web.py # Runs mypy, ignoring the config in pyproject.toml but still ignoring missing (untyped) imports # (many files are ignored by the config, so this is useful for checking all files) mypy-all: - mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports + mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports + +# Run the unit tests +test: + pytest ./tests + +# Install the pnpm modules needed for the front end +frontend-install: + rm -rf invokeai/frontend/web/node_modules + cd invokeai/frontend/web && pnpm install # Build the frontend frontend-build: From 4b5fdb66e23f1d347ddf0388c203d6180e5a1c97 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 10:16:25 +1100 Subject: [PATCH 144/340] fix(ui): fix low-hanging fruit types --- .../listenerMiddleware/listeners/controlNetImageProcessed.ts | 1 + invokeai/frontend/web/src/services/api/endpoints/models.ts | 2 +- invokeai/frontend/web/src/services/api/types.ts | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts index fba274beb87..73de9312976 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts @@ -51,6 +51,7 @@ export const addControlNetImageProcessedListener = () => { image: { image_name: ca.controlImage }, }, }, + edges: [], }, runs: 1, }, diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 9a7f1080564..57cf7dacfc1 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -39,7 +39,7 @@ type UpdateLoRAModelArg = { type UpdateMainModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; -type ListModelsArg = NonNullable; +type ListModelsArg = NonNullable; type UpdateLoRAModelResponse = UpdateMainModelResponse; diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 4ae2f9b594e..aaa70a26848 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -164,7 +164,6 @@ export type IntegerOutput = S['IntegerOutput']; export type IterateInvocationOutput = S['IterateInvocationOutput']; export type CollectInvocationOutput = S['CollectInvocationOutput']; export type LatentsOutput = S['LatentsOutput']; -export type GraphInvocationOutput = S['GraphInvocationOutput']; // Post-image upload actions, controls workflows when images are uploaded From 0ad4537af5add45e443912f22441bde69caebc7e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 10:42:14 +1100 Subject: [PATCH 145/340] chore(ui): bump `@invoke-ai/ui-library` --- invokeai/frontend/web/package.json | 2 +- invokeai/frontend/web/pnpm-lock.yaml | 110 ++++++++++++++++----------- 2 files changed, 68 insertions(+), 44 deletions(-) diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json 
index cea13350d26..a9f37f76ad6 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -57,7 +57,7 @@ "@dnd-kit/sortable": "^8.0.0", "@dnd-kit/utilities": "^3.2.2", "@fontsource-variable/inter": "^5.0.16", - "@invoke-ai/ui-library": "^0.0.18", + "@invoke-ai/ui-library": "^0.0.21", "@mantine/form": "6.0.21", "@nanostores/react": "^0.7.1", "@reduxjs/toolkit": "2.0.1", diff --git a/invokeai/frontend/web/pnpm-lock.yaml b/invokeai/frontend/web/pnpm-lock.yaml index 0ec2e47a0cd..4f9902299cb 100644 --- a/invokeai/frontend/web/pnpm-lock.yaml +++ b/invokeai/frontend/web/pnpm-lock.yaml @@ -32,8 +32,8 @@ dependencies: specifier: ^5.0.16 version: 5.0.16 '@invoke-ai/ui-library': - specifier: ^0.0.18 - version: 0.0.18(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.16)(@internationalized/date@3.5.1)(@types/react@18.2.48)(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0) + specifier: ^0.0.21 + version: 0.0.21(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.16)(@internationalized/date@3.5.2)(@types/react@18.2.48)(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0) '@mantine/form': specifier: 6.0.21 version: 6.0.21(react@18.2.0) @@ -344,7 +344,7 @@ packages: '@jridgewell/trace-mapping': 0.3.21 dev: true - /@ark-ui/anatomy@1.3.0(@internationalized/date@3.5.1): + /@ark-ui/anatomy@1.3.0(@internationalized/date@3.5.2): resolution: {integrity: sha512-1yG2MrzUlix6KthjQMCNiHnkXrWwEdFAX6D+HqGJaNu0XvaGul2J+wDNtjsdX+gxiWu1nXXEEOAWlFVYMUf65w==} dependencies: '@zag-js/accordion': 0.32.1 @@ -356,7 +356,7 @@ packages: '@zag-js/color-utils': 0.32.1 '@zag-js/combobox': 0.32.1 '@zag-js/date-picker': 0.32.1 - '@zag-js/date-utils': 0.32.1(@internationalized/date@3.5.1) + '@zag-js/date-utils': 0.32.1(@internationalized/date@3.5.2) '@zag-js/dialog': 0.32.1 '@zag-js/editable': 0.32.1 '@zag-js/file-upload': 0.32.1 @@ -383,13 +383,13 @@ packages: - '@internationalized/date' dev: false - /@ark-ui/react@1.3.0(@internationalized/date@3.5.1)(react-dom@18.2.0)(react@18.2.0): + /@ark-ui/react@1.3.0(@internationalized/date@3.5.2)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-JHjNoIX50+mUCTaEGMjfGQWGGi31pKsV646jZJlR/1xohpYJigzg8BvO97cTsVk8fwtur+cm11gz3Nf7f5QUnA==} peerDependencies: react: '>=18.0.0' react-dom: '>=18.0.0' dependencies: - '@ark-ui/anatomy': 1.3.0(@internationalized/date@3.5.1) + '@ark-ui/anatomy': 1.3.0(@internationalized/date@3.5.2) '@zag-js/accordion': 0.32.1 '@zag-js/avatar': 0.32.1 '@zag-js/carousel': 0.32.1 @@ -399,7 +399,7 @@ packages: '@zag-js/combobox': 0.32.1 '@zag-js/core': 0.32.1 '@zag-js/date-picker': 0.32.1 - '@zag-js/date-utils': 0.32.1(@internationalized/date@3.5.1) + '@zag-js/date-utils': 0.32.1(@internationalized/date@3.5.2) '@zag-js/dialog': 0.32.1 '@zag-js/editable': 0.32.1 '@zag-js/file-upload': 0.32.1 @@ -1709,6 +1709,13 @@ packages: dependencies: regenerator-runtime: 0.14.1 + /@babel/runtime@7.23.9: + resolution: {integrity: sha512-0CX6F+BI2s9dkUqr08KFrAIZgNFj75rdBU/DjCyYLIaV/quFjkk6T+EJ2LkZHyZTbEV4L5p97mNkUsHl2wLFAw==} + engines: {node: '>=6.9.0'} + dependencies: + regenerator-runtime: 0.14.1 + dev: false + /@babel/template@7.22.15: resolution: {integrity: sha512-QPErUVm4uyJa60rkI73qneDacvdvzxshT3kksGqlGWYdOTIUOwJ7RDUL8sGqslY1uXWSL6xMFKEXDS3ox2uF0w==} engines: {node: 
'>=6.9.0'} @@ -1969,7 +1976,7 @@ packages: dependencies: '@chakra-ui/dom-utils': 2.1.0 react: 18.2.0 - react-focus-lock: 2.9.6(@types/react@18.2.48)(react@18.2.0) + react-focus-lock: 2.11.1(@types/react@18.2.48)(react@18.2.0) transitivePeerDependencies: - '@types/react' dev: false @@ -3013,7 +3020,7 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 '@emotion/babel-plugin': 11.11.0 '@emotion/is-prop-valid': 1.2.1 '@emotion/react': 11.11.3(@types/react@18.2.48)(react@18.2.0) @@ -3560,16 +3567,16 @@ packages: resolution: {integrity: sha512-dvuCeX5fC9dXgJn9t+X5atfmgQAzUOWqS1254Gh0m6i8wKd10ebXkfNKiRK+1GWi/yTvvLDHpoxLr0xxxeslWw==} dev: true - /@internationalized/date@3.5.1: - resolution: {integrity: sha512-LUQIfwU9e+Fmutc/DpRTGXSdgYZLBegi4wygCWDSVmUdLTaMHsQyASDiJtREwanwKuQLq0hY76fCJ9J/9I2xOQ==} + /@internationalized/date@3.5.2: + resolution: {integrity: sha512-vo1yOMUt2hzp63IutEaTUxROdvQg1qlMRsbCvbay2AK2Gai7wIgCyK5weEX3nHkiLgo4qCXHijFNC/ILhlRpOQ==} dependencies: - '@swc/helpers': 0.5.3 + '@swc/helpers': 0.5.6 dev: false - /@internationalized/number@3.5.0: - resolution: {integrity: sha512-ZY1BW8HT9WKYvaubbuqXbbDdHhOUMfE2zHHFJeTppid0S+pc8HtdIxFxaYMsGjCb4UsF+MEJ4n2TfU7iHnUK8w==} + /@internationalized/number@3.5.1: + resolution: {integrity: sha512-N0fPU/nz15SwR9IbfJ5xaS9Ss/O5h1sVXMZf43vc9mxEG48ovglvvzBjF53aHlq20uoR6c+88CrIXipU/LSzwg==} dependencies: - '@swc/helpers': 0.5.3 + '@swc/helpers': 0.5.6 dev: false /@invoke-ai/eslint-config-react@0.0.13(@typescript-eslint/eslint-plugin@6.19.0)(@typescript-eslint/parser@6.19.0)(eslint-config-prettier@9.1.0)(eslint-plugin-import@2.29.1)(eslint-plugin-react-hooks@4.6.0)(eslint-plugin-react-refresh@0.4.5)(eslint-plugin-react@7.33.2)(eslint-plugin-simple-import-sort@10.0.0)(eslint-plugin-storybook@0.6.15)(eslint-plugin-unused-imports@3.0.0)(eslint@8.56.0): @@ -3608,14 +3615,14 @@ packages: prettier: 3.2.4 dev: true - /@invoke-ai/ui-library@0.0.18(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.16)(@internationalized/date@3.5.1)(@types/react@18.2.48)(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-Yme+2+pzYy3TPb7ZT0hYmBwahH29ZRSVIxLKSexh3BsbJXbTzGssRQU78QvK6Ymxemgbso3P8Rs+IW0zNhQKjQ==} + /@invoke-ai/ui-library@0.0.21(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.16)(@internationalized/date@3.5.2)(@types/react@18.2.48)(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-tCvgkBPDt0gNq+8IcR03e/Mw7R8Mb/SMXTqx3FEIxlTQEo93A/D38dKXeDCzTdx4sQ+sknfB+JLBbHs6sg5hhQ==} peerDependencies: '@fontsource-variable/inter': ^5.0.16 react: ^18.2.0 react-dom: ^18.2.0 dependencies: - '@ark-ui/react': 1.3.0(@internationalized/date@3.5.1)(react-dom@18.2.0)(react@18.2.0) + '@ark-ui/react': 1.3.0(@internationalized/date@3.5.2)(react-dom@18.2.0)(react@18.2.0) '@chakra-ui/anatomy': 2.2.2 '@chakra-ui/icons': 2.1.1(@chakra-ui/system@2.6.2)(react@18.2.0) '@chakra-ui/layout': 2.3.1(@chakra-ui/system@2.6.2)(react@18.2.0) @@ -3631,11 +3638,11 @@ packages: framer-motion: 10.18.0(react-dom@18.2.0)(react@18.2.0) lodash-es: 4.17.21 nanostores: 0.9.5 - overlayscrollbars: 2.4.7 - overlayscrollbars-react: 0.5.4(overlayscrollbars@2.4.7)(react@18.2.0) + overlayscrollbars: 2.5.0 + 
overlayscrollbars-react: 0.5.4(overlayscrollbars@2.5.0)(react@18.2.0) react: 18.2.0 react-dom: 18.2.0(react@18.2.0) - react-i18next: 14.0.1(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0) + react-i18next: 14.0.5(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0) react-icons: 5.0.1(react@18.2.0) react-select: 5.8.0(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) transitivePeerDependencies: @@ -5634,8 +5641,8 @@ packages: resolution: {integrity: sha512-9F4ys4C74eSTEUNndnER3VJ15oru2NumfQxS8geE+f3eB5xvfxpWyqE5XlVnxb/R14uoXi6SLbBwwiDSkv+XEw==} dev: true - /@swc/helpers@0.5.3: - resolution: {integrity: sha512-FaruWX6KdudYloq1AHD/4nU+UsMTdNE8CKyrseXWEcgjDAbvkwJg2QGPAnfIJLIWsjZOSPLOAykK6fuYp4vp4A==} + /@swc/helpers@0.5.6: + resolution: {integrity: sha512-aYX01Ke9hunpoCexYAgQucEpARGQ5w/cqHFrIR+e9gdKb1QWTsVJuTJ2ozQzIAxLyRQe/m+2RqzkyOOGiMKRQA==} dependencies: tslib: 2.6.2 dev: false @@ -6754,10 +6761,10 @@ packages: /@zag-js/date-picker@0.32.1: resolution: {integrity: sha512-n/hYmF+/R4+NuyfPRzCgeuLT6LJihKSuKzK29STPWy3sC/tBBHiqhNv1/4UKbatHUJXdBW2XF+N8Rw08RffcFQ==} dependencies: - '@internationalized/date': 3.5.1 + '@internationalized/date': 3.5.2 '@zag-js/anatomy': 0.32.1 '@zag-js/core': 0.32.1 - '@zag-js/date-utils': 0.32.1(@internationalized/date@3.5.1) + '@zag-js/date-utils': 0.32.1(@internationalized/date@3.5.2) '@zag-js/dismissable': 0.32.1 '@zag-js/dom-event': 0.32.1 '@zag-js/dom-query': 0.32.1 @@ -6769,12 +6776,12 @@ packages: '@zag-js/utils': 0.32.1 dev: false - /@zag-js/date-utils@0.32.1(@internationalized/date@3.5.1): + /@zag-js/date-utils@0.32.1(@internationalized/date@3.5.2): resolution: {integrity: sha512-dbBDRSVr5pRUw3rXndyGuSshZiWqQI5JQO4D2KIFGkXzorj6WzoOpcO910Z7AdM/9cCAMpCjUrka8d8o9BpJBg==} peerDependencies: '@internationalized/date': '>=3.0.0' dependencies: - '@internationalized/date': 3.5.1 + '@internationalized/date': 3.5.2 dev: false /@zag-js/dialog@0.32.1: @@ -6917,7 +6924,7 @@ packages: /@zag-js/number-input@0.32.1: resolution: {integrity: sha512-atyIOvoMITb4hZtQym7yD6I7grvPW83UeMFO8hCQg3HWwd2zR4+63mouWuyMoWb4QrzVFRVQBaU8OG5xGlknEw==} dependencies: - '@internationalized/number': 3.5.0 + '@internationalized/number': 3.5.1 '@zag-js/anatomy': 0.32.1 '@zag-js/core': 0.32.1 '@zag-js/dom-event': 0.32.1 @@ -9537,8 +9544,8 @@ packages: engines: {node: '>=0.4.0'} dev: true - /focus-lock@1.0.0: - resolution: {integrity: sha512-a8Ge6cdKh9za/GZR/qtigTAk7SrGore56EFcoMshClsh7FLk1zwszc/ltuMfKhx56qeuyL/jWQ4J4axou0iJ9w==} + /focus-lock@1.3.2: + resolution: {integrity: sha512-kFI92jZVqa8rP4Yer2sLNlUDcOdEFxYum2tIIr4eCH0XF+pOmlg0xiY4tkbDmHJXt3phtbJoWs1L6PgUVk97rA==} engines: {node: '>=10'} dependencies: tslib: 2.6.2 @@ -11431,13 +11438,13 @@ packages: react: 18.2.0 dev: false - /overlayscrollbars-react@0.5.4(overlayscrollbars@2.4.7)(react@18.2.0): + /overlayscrollbars-react@0.5.4(overlayscrollbars@2.5.0)(react@18.2.0): resolution: {integrity: sha512-FPKx9XnXovTnI4+2JXig5uEaTLSEJ6svOwPzIfBBXTHBRNsz2+WhYUmfM0K/BNYxjgDEwuPm+NQhEoOA0RoG1g==} peerDependencies: overlayscrollbars: ^2.0.0 react: '>=16.8.0' dependencies: - overlayscrollbars: 2.4.7 + overlayscrollbars: 2.5.0 react: 18.2.0 dev: false @@ -11445,8 +11452,8 @@ packages: resolution: {integrity: sha512-C7tmhetwMv9frEvIT/RfkAVEgbjRNz/Gh2zE8BVmN+jl35GRaAnz73rlGQCMRoC2arpACAXyMNnJkzHb7GBrcA==} dev: false - /overlayscrollbars@2.4.7: - resolution: {integrity: sha512-02X2/nHno35dzebCx+EO2tRDaKAOltZqUKdUqvq3Pt8htCuhJbYi+mjr0CYerVeGRRoZ2Uo6/8XrNg//DJJ+GA==} + /overlayscrollbars@2.5.0: + resolution: {integrity: 
sha512-CWVC2dwS07XZfLHDm5GmZN1iYggiJ8Vufnvzwt0gwR9Yz1hVckKeTxg7VILZeYVGhDYJHZ1Xc8Xfys5dWZ1qiA==} dev: false /p-limit@2.3.0: @@ -11949,7 +11956,7 @@ packages: peerDependencies: react: ^15.3.0 || ^16.0.0 || ^17.0.0 || ^18.0.0 dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 react: 18.2.0 dev: false @@ -12035,8 +12042,8 @@ packages: resolution: {integrity: sha512-nsO+KSNgo1SbJqJEYRE9ERzo7YtYbou/OqjSQKxV7jcKox7+usiUVZOAC+XnDOABXggQTno0Y1CpVnuWEc1boQ==} dev: false - /react-focus-lock@2.9.6(@types/react@18.2.48)(react@18.2.0): - resolution: {integrity: sha512-B7gYnCjHNrNYwY2juS71dHbf0+UpXXojt02svxybj8N5bxceAkzPChKEncHuratjUHkIFNCn06k2qj1DRlzTug==} + /react-focus-lock@2.11.1(@types/react@18.2.48)(react@18.2.0): + resolution: {integrity: sha512-IXLwnTBrLTlKTpASZXqqXJ8oymWrgAlOfuuDYN4XCuN1YJ72dwX198UCaF1QqGUk5C3QOnlMik//n3ufcfe8Ig==} peerDependencies: '@types/react': ^16.8.0 || ^17.0.0 || ^18.0.0 react: ^16.8.0 || ^17.0.0 || ^18.0.0 @@ -12044,9 +12051,9 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 '@types/react': 18.2.48 - focus-lock: 1.0.0 + focus-lock: 1.3.2 prop-types: 15.8.1 react: 18.2.0 react-clientside-effect: 1.2.6(react@18.2.0) @@ -12093,8 +12100,8 @@ packages: react-dom: 18.2.0(react@18.2.0) dev: false - /react-i18next@14.0.1(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-TMV8hFismBmpMdIehoFHin/okfvgjFhp723RYgIqB4XyhDobVMyukyM3Z8wtTRmajyFMZrBl/OaaXF2P6WjUAw==} + /react-i18next@14.0.5(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-5+bQSeEtgJrMBABBL5lO7jPdSNAbeAZ+MlFWDw//7FnVacuVu3l9EeWFzBQvZsKy+cihkbThWOAThEdH8YjGEw==} peerDependencies: i18next: '>= 23.2.3' react: '>= 16.8.0' @@ -12106,7 +12113,7 @@ packages: react-native: optional: true dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 html-parse-stringify: 3.0.1 i18next: 23.7.16 react: 18.2.0 @@ -12204,6 +12211,23 @@ packages: react: 18.2.0 react-style-singleton: 2.2.1(@types/react@18.2.48)(react@18.2.0) tslib: 2.6.2 + dev: true + + /react-remove-scroll-bar@2.3.5(@types/react@18.2.48)(react@18.2.0): + resolution: {integrity: sha512-3cqjOqg6s0XbOjWvmasmqHch+RLxIEk2r/70rzGXuz3iIGQsQheEQyqYCBb5EECoD01Vo2SIbDqW4paLeLTASw==} + engines: {node: '>=10'} + peerDependencies: + '@types/react': ^16.8.0 || ^17.0.0 || ^18.0.0 + react: ^16.8.0 || ^17.0.0 || ^18.0.0 + peerDependenciesMeta: + '@types/react': + optional: true + dependencies: + '@types/react': 18.2.48 + react: 18.2.0 + react-style-singleton: 2.2.1(@types/react@18.2.48)(react@18.2.0) + tslib: 2.6.2 + dev: false /react-remove-scroll@2.5.5(@types/react@18.2.48)(react@18.2.0): resolution: {integrity: sha512-ImKhrzJJsyXJfBZ4bzu8Bwpka14c/fQt0k+cyFp/PBhTfyDnU5hjOtM4AG/0AMyy8oKzOTR0lDgJIM7pYXI0kw==} @@ -12236,7 +12260,7 @@ packages: dependencies: '@types/react': 18.2.48 react: 18.2.0 - react-remove-scroll-bar: 2.3.4(@types/react@18.2.48)(react@18.2.0) + react-remove-scroll-bar: 2.3.5(@types/react@18.2.48)(react@18.2.0) react-style-singleton: 2.2.1(@types/react@18.2.48)(react@18.2.0) tslib: 2.6.2 use-callback-ref: 1.3.1(@types/react@18.2.48)(react@18.2.0) From e31ee85ddb02a78e57b46252c12e4e5fd5f065db Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 10:42:44 +1100 Subject: [PATCH 146/340] fix(ui): get lora select working --- .../web/src/common/hooks/useGroupedModelCombobox.ts | 11 ++++++----- .../web/src/features/lora/components/LoRACard.tsx | 12 
++++++------ .../web/src/features/lora/components/LoRASelect.tsx | 8 ++++---- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts index 875ce1f1c4c..140cf3eaa69 100644 --- a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts @@ -2,15 +2,16 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import type { EntityState } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import type { GroupBase } from 'chakra-react-select'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; import { groupBy, map, reduce } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { AnyModelConfig } from 'services/api/endpoints/models'; import { getModelId } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; type UseGroupedModelComboboxArg = { modelEntities: EntityState | undefined; - selectedModel?: Pick | null; + selectedModel?: ModelIdentifierWithBase | null; onChange: (value: T | null) => void; getIsDisabled?: (model: T) => boolean; isLoading?: boolean; @@ -28,7 +29,7 @@ export const useGroupedModelCombobox = ( arg: UseGroupedModelComboboxArg ): UseGroupedModelComboboxReturn => { const { t } = useTranslation(); - const base_model = useAppSelector((s) => s.generation.model?.base_model ?? 'sdxl'); + const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl'); const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg; const options = useMemo[]>(() => { if (!modelEntities) { @@ -42,8 +43,8 @@ export const useGroupedModelCombobox = ( acc.push({ label, options: val.map((model) => ({ - label: model.model_name, - value: model.id, + label: model.name, + value: model.key, isDisabled: getIsDisabled ? 
getIsDisabled(model) : false, })), }); diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx index 81e0027b2d1..71ce1457864 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx @@ -26,18 +26,18 @@ export const LoRACard = memo((props: LoRACardProps) => { const handleChange = useCallback( (v: number) => { - dispatch(loraWeightChanged({ id: lora.id, weight: v })); + dispatch(loraWeightChanged({ key: lora.key, weight: v })); }, - [dispatch, lora.id] + [dispatch, lora.key] ); const handleSetLoraToggle = useCallback(() => { - dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: !lora.isEnabled })); - }, [dispatch, lora.id, lora.isEnabled]); + dispatch(loraIsEnabledChanged({ key: lora.key, isEnabled: !lora.isEnabled })); + }, [dispatch, lora.key, lora.isEnabled]); const handleRemoveLora = useCallback(() => { - dispatch(loraRemoved(lora.id)); - }, [dispatch, lora.id]); + dispatch(loraRemoved(lora.key)); + }, [dispatch, lora.key]); return ( diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index 069c557aefa..b58751ca5e2 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -7,8 +7,8 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { LoRAConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/types'; const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras); @@ -17,11 +17,11 @@ const LoRASelect = () => { const { data, isLoading } = useGetLoRAModelsQuery(); const { t } = useTranslation(); const addedLoRAs = useAppSelector(selectAddedLoRAs); - const currentBaseModel = useAppSelector((s) => s.generation.model?.base_model); + const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const getIsDisabled = (lora: LoRAConfig): boolean => { - const isCompatible = currentBaseModel === lora.base_model; - const isAdded = Boolean(addedLoRAs[lora.id]); + const isCompatible = currentBaseModel === lora.base; + const isAdded = Boolean(addedLoRAs[lora.key]); const hasMainModel = Boolean(currentBaseModel); return !hasMainModel || !isCompatible || isAdded; }; From 53e13f7b440dc20df82bba71edcdce95ef7d5609 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 10:45:05 +1100 Subject: [PATCH 147/340] fix(ui): get embedding select working --- .../frontend/web/src/features/embedding/EmbeddingSelect.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx b/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx index 426ddd21e23..fd05edc4667 100644 --- a/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx +++ b/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx @@ -18,7 +18,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps const getIsDisabled = useCallback( 
(embedding: TextualInversionConfig): boolean => { - const isCompatible = currentBaseModel === embedding.base_model; + const isCompatible = currentBaseModel === embedding.base; const hasMainModel = Boolean(currentBaseModel); return !hasMainModel || !isCompatible; }, @@ -31,7 +31,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps if (!embedding) { return; } - onSelect(embedding.model_name); + onSelect(embedding.key); }, [onSelect] ); From 47935692266a319a3b09c796dc18b2bbbe991c91 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 11:14:08 +1100 Subject: [PATCH 148/340] fix(ui): get vae model select working --- .../common/hooks/useGroupedModelCombobox.ts | 4 +--- .../web/src/common/hooks/useModelCombobox.ts | 12 ++++++------ .../VAEModel/ParamVAEModelSelect.tsx | 18 ++++++++++-------- .../web/src/services/api/endpoints/models.ts | 3 --- 4 files changed, 17 insertions(+), 20 deletions(-) diff --git a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts index 140cf3eaa69..fc5bc455eef 100644 --- a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts @@ -6,7 +6,6 @@ import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; import { groupBy, map, reduce } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { getModelId } from 'services/api/endpoints/models'; import type { AnyModelConfig } from 'services/api/types'; type UseGroupedModelComboboxArg = { @@ -58,8 +57,7 @@ export const useGroupedModelCombobox = ( const value = useMemo( () => - options.flatMap((o) => o.options).find((m) => (selectedModel ? m.value === getModelId(selectedModel) : false)) ?? - null, + options.flatMap((o) => o.options).find((m) => (selectedModel ? m.value === selectedModel.key : false)) ?? null, [options, selectedModel] ); diff --git a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts index 07e6aeb34c4..e0718d64132 100644 --- a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts @@ -1,14 +1,14 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import type { EntityState } from '@reduxjs/toolkit'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; import { map } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { AnyModelConfig } from 'services/api/endpoints/models'; -import { getModelId } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; type UseModelComboboxArg = { modelEntities: EntityState | undefined; - selectedModel?: Pick | null; + selectedModel?: ModelIdentifierWithBase | null; onChange: (value: T | null) => void; getIsDisabled?: (model: T) => boolean; optionsFilter?: (model: T) => boolean; @@ -33,14 +33,14 @@ export const useModelCombobox = (arg: UseModelCombobox return map(modelEntities.entities) .filter(optionsFilter) .map((model) => ({ - label: model.model_name, - value: model.id, + label: model.name, + value: model.key, isDisabled: getIsDisabled ? 
getIsDisabled(model) : false, })); }, [optionsFilter, getIsDisabled, modelEntities]); const value = useMemo( - () => options.find((m) => (selectedModel ? m.value === getModelId(selectedModel) : false)), + () => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)), [options, selectedModel] ); diff --git a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx index cc0164153d8..1810c3ff68a 100644 --- a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx @@ -7,8 +7,8 @@ import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/ge import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import type { VAEConfig } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; +import type { VAEConfig } from 'services/api/types'; const selector = createMemoizedSelector(selectGenerationSlice, (generation) => { const { model, vae } = generation; @@ -22,25 +22,27 @@ const ParamVAEModelSelect = () => { const { data, isLoading } = useGetVaeModelsQuery(); const getIsDisabled = useCallback( (vae: VAEConfig): boolean => { - const isCompatible = model?.base_model === vae.base_model; - const hasMainModel = Boolean(model?.base_model); + const isCompatible = model?.base === vae.base; + const hasMainModel = Boolean(model?.base); return !hasMainModel || !isCompatible; }, - [model?.base_model] + [model?.base] ); const _onChange = useCallback( (vae: VAEConfig | null) => { - dispatch(vaeSelected(vae ? pick(vae, 'base_model', 'model_name') : null)); + dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null)); }, [dispatch] ); - const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({ modelEntities: data, onChange: _onChange, - selectedModel: vae ? { ...vae, model_type: 'vae' } : null, + selectedModel: vae ? 
pick(vae, 'key', 'base') : null, isLoading, getIsDisabled, }); + + console.log(value) return ( @@ -50,7 +52,7 @@ const ParamVAEModelSelect = () => { input; - type UpdateMainModelArg = { base_model: BaseModelType; model_name: string; From ebc3b1dc002cd219625d2804517831e48284c230 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 11:16:17 +1100 Subject: [PATCH 149/340] fix(ui): get refiner model select working --- .../SDXLRefiner/ParamSDXLRefinerModelSelect.tsx | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx index 4c542515573..e5978ca21b4 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx @@ -4,15 +4,16 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { useModelCombobox } from 'common/hooks/useModelCombobox'; import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice'; +import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { REFINER_BASE_MODELS } from 'services/api/constants'; -import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/types'; const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel); -const optionsFilter = (model: MainModelConfig) => model.base_model === 'sdxl-refiner'; +const optionsFilter = (model: MainModelConfig) => model.base === 'sdxl-refiner'; const ParamSDXLRefinerModelSelect = () => { const dispatch = useAppDispatch(); @@ -25,13 +26,7 @@ const ParamSDXLRefinerModelSelect = () => { dispatch(refinerModelChanged(null)); return; } - dispatch( - refinerModelChanged({ - base_model: 'sdxl-refiner', - model_name: model.model_name, - model_type: model.model_type, - }) - ); + dispatch(refinerModelChanged(pick(model, ['key', 'base']))); }, [dispatch] ); From f6735d04d693f7d42a2946128f4f08c12d5ac5a7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 11:32:58 +1100 Subject: [PATCH 150/340] fix(ui): get workflow editor model selects working --- .../inputs/ControlNetModelFieldInputComponent.tsx | 5 +++-- .../inputs/IPAdapterModelFieldInputComponent.tsx | 5 +++-- .../fields/inputs/LoRAModelFieldInputComponent.tsx | 5 +++-- .../fields/inputs/MainModelFieldInputComponent.tsx | 2 +- .../inputs/RefinerModelFieldInputComponent.tsx | 2 +- .../inputs/SDXLMainModelFieldInputComponent.tsx | 2 +- .../inputs/T2IAdapterModelFieldInputComponent.tsx | 5 +++-- .../fields/inputs/VAEModelFieldInputComponent.tsx | 5 +++-- .../frontend/web/src/features/nodes/types/common.ts | 12 ++++++------ 9 files changed, 24 insertions(+), 19 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx 
b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx index 53d800e7b62..1951ec60d36 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx @@ -3,9 +3,10 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field'; +import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; -import type { ControlNetConfig } from 'services/api/endpoints/models'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; +import type { ControlNetConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -35,7 +36,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => { const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ modelEntities: data, onChange: _onChange, - selectedModel: field.value ? { ...field.value, model_type: 'controlnet' } : undefined, + selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined, isLoading, }); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx index 3f195ceb32c..137f751fca1 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx @@ -3,9 +3,10 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field'; +import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; -import type { IPAdapterConfig } from 'services/api/endpoints/models'; import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models'; +import type { IPAdapterConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -35,7 +36,7 @@ const IPAdapterModelFieldInputComponent = ( const { options, value, onChange } = useGroupedModelCombobox({ modelEntities: ipAdapterModels, onChange: _onChange, - selectedModel: field.value ? { ...field.value, model_type: 'ip_adapter' } : undefined, + selectedModel: field.value ? 
pick(field.value, ['key', 'base']) : undefined, }); return ( diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx index eeb07fa08e2..5f6318de9e5 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx @@ -3,9 +3,10 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field'; +import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; -import type { LoRAConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; +import type { LoRAConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -34,7 +35,7 @@ const LoRAModelFieldInputComponent = (props: Props) => { const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ modelEntities: data, onChange: _onChange, - selectedModel: field.value ? { ...field.value, model_type: 'lora' } : undefined, + selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined, isLoading, }); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx index 7ddde08816c..1cb0658b81f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx @@ -6,8 +6,8 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; import { NON_SDXL_MAIN_MODELS } from 'services/api/constants'; -import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx index 9b5a1138d4a..be2b4a4d4f3 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx @@ -9,8 +9,8 @@ import type { } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; import { REFINER_BASE_MODELS } from 'services/api/constants'; -import type { MainModelConfig } from 
'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx index cf353619e8f..d0d7754606b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx @@ -6,8 +6,8 @@ import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback } from 'react'; import { SDXL_MAIN_MODELS } from 'services/api/constants'; -import type { MainModelConfig } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import type { MainModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx index 8402c56343a..9115f22c145 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx @@ -3,9 +3,10 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field'; +import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; -import type { T2IAdapterConfig } from 'services/api/endpoints/models'; import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models'; +import type { T2IAdapterConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -36,7 +37,7 @@ const T2IAdapterModelFieldInputComponent = ( const { options, value, onChange } = useGroupedModelCombobox({ modelEntities: t2iAdapterModels, onChange: _onChange, - selectedModel: field.value ? { ...field.value, model_type: 't2i_adapter' } : undefined, + selectedModel: field.value ? 
pick(field.value, ['key', 'base']) : undefined, }); return ( diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx index af09f2d8f20..87272f48b9b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx @@ -4,9 +4,10 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton'; import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice'; import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field'; +import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; -import type { VAEConfig } from 'services/api/endpoints/models'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; +import type { VAEConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -34,7 +35,7 @@ const VAEModelFieldInputComponent = (props: Props) => { const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ modelEntities: data, onChange: _onChange, - selectedModel: field.value ? { ...field.value, model_type: 'vae' } : null, + selectedModel: field.value ? pick(field.value, ['key', 'base']) : null, isLoading, }); diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 891bd29bc86..d5d04deaa54 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -73,7 +73,7 @@ export type BaseModel = z.infer; export type ModelType = z.infer; export type ModelIdentifier = z.infer; export type ModelIdentifierWithBase = z.infer; -export const zMainModelField = zModelFieldBase; +export const zMainModelField = zModelIdentifierWithBase; export type MainModelField = z.infer; export const zSDXLRefinerModelField = zModelIdentifier; @@ -93,23 +93,23 @@ export const zSubModelType = z.enum([ ]); export type SubModelType = z.infer; -export const zVAEModelField = zModelFieldBase; +export const zVAEModelField = zModelIdentifierWithBase; export const zModelInfo = zModelIdentifier.extend({ submodel_type: zSubModelType.nullish(), }); export type ModelInfo = z.infer; -export const zLoRAModelField = zModelFieldBase; +export const zLoRAModelField = zModelIdentifierWithBase; export type LoRAModelField = z.infer; -export const zControlNetModelField = zModelFieldBase; +export const zControlNetModelField = zModelIdentifierWithBase; export type ControlNetModelField = z.infer; -export const zIPAdapterModelField = zModelFieldBase; +export const zIPAdapterModelField = zModelIdentifierWithBase; export type IPAdapterModelField = z.infer; -export const zT2IAdapterModelField = zModelFieldBase; +export const zT2IAdapterModelField = zModelIdentifierWithBase; export type T2IAdapterModelField = z.infer; export const zLoraInfo = zModelInfo.extend({ From f0d476f45da29de422299dbb60def1bafc3ffbef Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 7 Jan 2024 19:54:58 -0500 Subject: [PATCH 151/340] groundwork for the bulk_download_service 
---
 .../app/services/bulk_download/__init__.py    |  0
 .../bulk_download/bulk_download_base.py       | 32 +++++++++++++++++++
 .../bulk_download/bulk_download_common.py     | 21 ++++++++++++
 3 files changed, 53 insertions(+)
 create mode 100644 invokeai/app/services/bulk_download/__init__.py
 create mode 100644 invokeai/app/services/bulk_download/bulk_download_base.py
 create mode 100644 invokeai/app/services/bulk_download/bulk_download_common.py

diff --git a/invokeai/app/services/bulk_download/__init__.py b/invokeai/app/services/bulk_download/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py
new file mode 100644
index 00000000000..54c87714374
--- /dev/null
+++ b/invokeai/app/services/bulk_download/bulk_download_base.py
@@ -0,0 +1,32 @@
+from pathlib import Path
+from typing import Optional, Union
+
+from abc import ABC, abstractmethod
+
+from invokeai.app.services.events.events_base import EventServiceBase
+from invokeai.app.services.invoker import Invoker
+
+class BulkDownloadBase(ABC):
+
+    @abstractmethod
+    def __init__(
+        self,
+        output_folder: Union[str, Path],
+        event_bus: Optional["EventServiceBase"] = None,
+    ):
+        """
+        Create BulkDownloadBase object.
+
+        :param output_folder: The path to the output folder where the bulk download files can be temporarily stored.
+        :param event_bus: InvokeAI event bus for reporting events to.
+        """
+
+    @abstractmethod
+    def start(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> str:
+        """
+        Starts a bulk download job.
+
+        :param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies.
+        :param image_names: A list of image names to include in the zip file.
+        :param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
+ """ \ No newline at end of file diff --git a/invokeai/app/services/bulk_download/bulk_download_common.py b/invokeai/app/services/bulk_download/bulk_download_common.py new file mode 100644 index 00000000000..3ac1f5bba88 --- /dev/null +++ b/invokeai/app/services/bulk_download/bulk_download_common.py @@ -0,0 +1,21 @@ + +class BulkDownloadException(Exception): + """Exception raised when a bulk download fails.""" + + def __init__(self, message="Bulk download failed"): + super().__init__(message) + self.message = message + +class BulkDownloadTargetException(BulkDownloadException): + """Exception raised when a bulk download target is not found.""" + + def __init__(self, message="The bulk download target was not found"): + super().__init__(message) + self.message = message + +class BulkDownloadParametersException(BulkDownloadException): + """Exception raised when a bulk download parameter is invalid.""" + + def __init__(self, message="The bulk download parameters are invalid, either an array of image names or a board id must be provided"): + super().__init__(message) + self.message = message \ No newline at end of file From 506a87a11349d5c9b5dbf6643cab1299bf17f6a5 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 7 Jan 2024 19:55:59 -0500 Subject: [PATCH 152/340] adding socket events for bulk download --- invokeai/app/api/sockets.py | 30 ++++++++++++++++++-- invokeai/app/services/events/events_base.py | 31 +++++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index e651e435591..c5d9ace8d26 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -12,16 +12,27 @@ class SocketIO: __sio: AsyncServer __app: ASGIApp + __sub_queue: str = "subscribe_queue" + __unsub_queue: str = "unsubscribe_queue" + + __sub_bulk_download: str = "subscribe_bulk_download" + __unsub_bulk_download: str = "unsubscribe_bulk_download" + + def __init__(self, app: FastAPI): self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") app.mount("/ws", self.__app) - self.__sio.on("subscribe_queue", handler=self._handle_sub_queue) - self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue) + self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue) + self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue) local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) + self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download) + self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download) + local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event) + async def _handle_queue_event(self, event: Event): await self.__sio.emit( event=event[1]["event"], @@ -39,3 +50,18 @@ async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None: async def _handle_model_event(self, event: Event) -> None: await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) + + async def _handle_bulk_download_event(self, event: Event): + await self.__sio.emit( + event=event[1]["event"], + data=event[1]["data"], + room=event[1]["data"]["bulk_download_id"], + ) + + async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs): + if "bulk_download_id" in data: + await 
self.__sio.enter_room(sid, data["bulk_download_id"]) + + async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs): + if "bulk_download_id" in data: + await self.__sio.leave_room(sid, data["bulk_download_id"]) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 5355fe22987..0a0668b2740 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -16,6 +16,7 @@ class EventServiceBase: queue_event: str = "queue_event" + bulk_download_event: str = "bulk_download_event" download_event: str = "download_event" model_event: str = "model_event" @@ -24,6 +25,14 @@ class EventServiceBase: def dispatch(self, event_name: str, payload: Any) -> None: pass + def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None: + """Bulk download events are emitted to a room with queue_id as the room name""" + payload["timestamp"] = get_timestamp() + self.dispatch( + event_name=EventServiceBase.bulk_download_event, + payload={"event": event_name, "data": payload}, + ) + def __emit_queue_event(self, event_name: str, payload: dict) -> None: """Queue events are emitted to a room with queue_id as the room name""" payload["timestamp"] = get_timestamp() @@ -430,3 +439,25 @@ def emit_model_install_error( "error": error, }, ) + + def emit_bulk_download_started(self, bulk_download_id: str) -> None: + """Emitted when a bulk download starts""" + self._emit_bulk_download_event( + event_name="bulk_download_started", + payload={"bulk_download_id": bulk_download_id, } + ) + + def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None: + """Emitted when a bulk download completes""" + self._emit_bulk_download_event( + event_name="bulk_download_completed", + payload={"bulk_download_id": bulk_download_id, + "file_path": file_path} + ) + + def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None: + """Emitted when a bulk download fails""" + self._emit_bulk_download_event( + event_name="bulk_download_failed", + payload={"bulk_download_id": bulk_download_id, "error": error} + ) From 27311d4182f0dbedd1d77fec5d927ad92936020f Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 7 Jan 2024 21:29:42 -0500 Subject: [PATCH 153/340] implementation of bulkdownload background task --- invokeai/app/api/dependencies.py | 3 + invokeai/app/api/routers/images.py | 14 ++- .../bulk_download/bulk_download_base.py | 2 +- .../bulk_download/bulk_download_defauilt.py | 114 ++++++++++++++++++ invokeai/app/services/invocation_services.py | 3 + 5 files changed, 130 insertions(+), 6 deletions(-) create mode 100644 invokeai/app/services/bulk_download/bulk_download_defauilt.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index a9132516a86..ab09d1e5d7b 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -15,6 +15,7 @@ from ..services.board_images.board_images_default import BoardImagesService from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage from ..services.boards.boards_default import BoardService +from ..services.bulk_download.bulk_download_defauilt import BulkDownloadService from ..services.config import InvokeAIAppConfig from ..services.download import DownloadQueueService from ..services.image_files.image_files_disk import DiskImageFileStorage @@ -81,6 +82,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger board_records = 
SqliteBoardRecordStorage(db=db) boards = BoardService() events = FastAPIEventService(event_handler_id) + bulk_download = BulkDownloadService(output_folder=f"{output_folder}", event_bus=events) image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) @@ -110,6 +112,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger board_images=board_images, board_records=board_records, boards=boards, + bulk_download=bulk_download, configuration=configuration, events=events, image_files=image_files, diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index cc60ad1be83..2a8e1e7ec7f 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -2,7 +2,7 @@ import traceback from typing import Optional -from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile +from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi.responses import FileResponse from fastapi.routing import APIRouter from PIL import Image @@ -10,6 +10,7 @@ from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin +from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator @@ -372,19 +373,22 @@ async def unstar_images_in_list( except Exception: raise HTTPException(status_code=500, detail="Failed to unstar images") - class ImagesDownloaded(BaseModel): response: Optional[str] = Field( description="If defined, the message to display to the user when images begin downloading" ) -@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded) +@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202) async def download_images_from_list( + background_tasks: BackgroundTasks, image_names: list[str] = Body(description="The list of names of images to download", embed=True), board_id: Optional[str] = Body( default=None, description="The board from which image should be downloaded from", embed=True ), ) -> ImagesDownloaded: - # return ImagesDownloaded(response="Your images are downloading") - raise HTTPException(status_code=501, detail="Endpoint is not yet implemented") + if (image_names is None or len(image_names) == 0) and board_id is None: + raise HTTPException(status_code=400, detail="No images or board id specified.") + background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id) + return ImagesDownloaded(response="Your images are preparing to be downloaded") + diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 54c87714374..b788020bba2 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -22,7 +22,7 @@ def __init__( """ @abstractmethod - def start(self, invoker: Invoker, image_names: 
list[str], board_id: Optional[str]) -> str: + def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: """ Starts a a bulk download job. diff --git a/invokeai/app/services/bulk_download/bulk_download_defauilt.py b/invokeai/app/services/bulk_download/bulk_download_defauilt.py new file mode 100644 index 00000000000..ebeaa4be5a7 --- /dev/null +++ b/invokeai/app/services/bulk_download/bulk_download_defauilt.py @@ -0,0 +1,114 @@ +from pathlib import Path +from typing import Optional, Union +import uuid +from zipfile import ZipFile + +from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException +from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadException +from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException +from invokeai.app.services.invoker import Invoker + +from .bulk_download_base import BulkDownloadBase + +class BulkDownloadService(BulkDownloadBase): + + __output_folder: Path + __bulk_downloads_folder: Path + __event_bus: Optional[EventServiceBase] + + def __init__(self, + output_folder: Union[str, Path], + event_bus: Optional[EventServiceBase] = None,): + """ + Initialize the downloader object. + + :param event_bus: Optional EventService object + """ + self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder) + self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads" + self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) + self.__event_bus = event_bus + + + def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: + """ + Create a zip file containing the images specified by the given image names or board id. + + param: image_names: A list of image names to include in the zip file. + param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. + """ + bulk_download_id = str(uuid.uuid4()) + + self._signal_job_started(bulk_download_id) + try: + board_name: Union[str, None] = None + if board_id: + image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id) + if board_id == "none": + board_id = "Uncategorized" + image_names_to_paths: dict[str, str] = self._get_image_name_to_path_map(invoker, image_names) + file_path: str = self._create_zip_file(image_names_to_paths, bulk_download_id) + self._signal_job_completed(bulk_download_id, file_path) + except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: + self._signal_job_failed(bulk_download_id, e) + except Exception as e: + self._signal_job_failed(bulk_download_id, e) + + def _get_image_name_to_path_map(self, invoker: Invoker, image_names: list[str]) -> dict[str, str]: + """ + Create a map of image names to their paths. + :param image_names: A list of image names. + """ + image_names_to_paths: dict[str, str] = {} + for image_name in image_names: + image_names_to_paths[image_name] = invoker.services.images.get_path(image_name) + return image_names_to_paths + + + + def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str: + """ + Create a zip file containing the images specified by the given image names or board id. + If download with the same bulk_download_id already exists, it will be overwritten. 
+ """ + + zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip") + + with ZipFile(zip_file_path, "w") as zip_file: + for image_name, image_path in image_names_to_paths.items(): + zip_file.write(image_path, arcname=image_name) + + return str(zip_file_path) + + + def _signal_job_started(self, bulk_download_id: str) -> None: + """Signal that a bulk download job has started.""" + if self.__event_bus: + assert bulk_download_id is not None + self.__event_bus.emit_bulk_download_started( + bulk_download_id=bulk_download_id, + ) + + + def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None: + """Signal that a bulk download job has completed.""" + if self.__event_bus: + assert bulk_download_id is not None + assert file_path is not None + self.__event_bus.emit_bulk_download_completed( + bulk_download_id=bulk_download_id, + file_path=file_path, + ) + + def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None: + """Signal that a bulk download job has failed.""" + if self.__event_bus: + assert bulk_download_id is not None + assert exception is not None + self.__event_bus.emit_bulk_download_failed( + bulk_download_id=bulk_download_id, + error=str(exception), + ) + + \ No newline at end of file diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 04fe71a3eb3..a560696692e 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -16,6 +16,7 @@ from .board_images.board_images_base import BoardImagesServiceABC from .board_records.board_records_base import BoardRecordStorageBase from .boards.boards_base import BoardServiceABC + from .bulk_download.bulk_download_base import BulkDownloadBase from .config import InvokeAIAppConfig from .download import DownloadQueueServiceBase from .events.events_base import EventServiceBase @@ -41,6 +42,7 @@ def __init__( board_image_records: "BoardImageRecordStorageBase", boards: "BoardServiceABC", board_records: "BoardRecordStorageBase", + bulk_download: "BulkDownloadBase", configuration: "InvokeAIAppConfig", events: "EventServiceBase", images: "ImageServiceABC", @@ -63,6 +65,7 @@ def __init__( self.board_image_records = board_image_records self.boards = boards self.board_records = board_records + self.bulk_download = bulk_download self.configuration = configuration self.events = events self.images = images From 6f85ced8f352038225b81f2d0b055572d2846572 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 7 Jan 2024 22:17:03 -0500 Subject: [PATCH 154/340] linted and styling --- invokeai/app/api/routers/images.py | 11 ++++--- invokeai/app/api/sockets.py | 1 - .../bulk_download/bulk_download_base.py | 9 +++--- .../bulk_download/bulk_download_common.py | 10 ++++-- .../bulk_download/bulk_download_defauilt.py | 31 +++++++------------ invokeai/app/services/events/events_base.py | 15 +++++---- 6 files changed, 37 insertions(+), 40 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 2a8e1e7ec7f..e32f7fb9eec 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -10,7 +10,6 @@ from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin -from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException from invokeai.app.services.images.images_common 
import ImageDTO, ImageUrlsDTO from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator @@ -373,13 +372,16 @@ async def unstar_images_in_list( except Exception: raise HTTPException(status_code=500, detail="Failed to unstar images") + class ImagesDownloaded(BaseModel): response: Optional[str] = Field( description="If defined, the message to display to the user when images begin downloading" ) -@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202) +@images_router.post( + "/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202 +) async def download_images_from_list( background_tasks: BackgroundTasks, image_names: list[str] = Body(description="The list of names of images to download", embed=True), @@ -389,6 +391,7 @@ async def download_images_from_list( ) -> ImagesDownloaded: if (image_names is None or len(image_names) == 0) and board_id is None: raise HTTPException(status_code=400, detail="No images or board id specified.") - background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id) + background_tasks.add_task( + ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id + ) return ImagesDownloaded(response="Your images are preparing to be downloaded") - diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index c5d9ace8d26..463545d9bc1 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -18,7 +18,6 @@ class SocketIO: __sub_bulk_download: str = "subscribe_bulk_download" __unsub_bulk_download: str = "unsubscribe_bulk_download" - def __init__(self, app: FastAPI): self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index b788020bba2..fc45aff2806 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -1,13 +1,12 @@ +from abc import ABC, abstractmethod from pathlib import Path from typing import Optional, Union -from abc import ABC, abstractmethod - from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker -class BulkDownloadBase(ABC): +class BulkDownloadBase(ABC): @abstractmethod def __init__( self, @@ -25,8 +24,8 @@ def __init__( def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: """ Starts a a bulk download job. - + :param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies. :param image_names: A list of image names to include in the zip file. :param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. 
- """ \ No newline at end of file + """ diff --git a/invokeai/app/services/bulk_download/bulk_download_common.py b/invokeai/app/services/bulk_download/bulk_download_common.py index 3ac1f5bba88..23a0589dafd 100644 --- a/invokeai/app/services/bulk_download/bulk_download_common.py +++ b/invokeai/app/services/bulk_download/bulk_download_common.py @@ -1,4 +1,3 @@ - class BulkDownloadException(Exception): """Exception raised when a bulk download fails.""" @@ -6,6 +5,7 @@ def __init__(self, message="Bulk download failed"): super().__init__(message) self.message = message + class BulkDownloadTargetException(BulkDownloadException): """Exception raised when a bulk download target is not found.""" @@ -13,9 +13,13 @@ def __init__(self, message="The bulk download target was not found"): super().__init__(message) self.message = message + class BulkDownloadParametersException(BulkDownloadException): """Exception raised when a bulk download parameter is invalid.""" - def __init__(self, message="The bulk download parameters are invalid, either an array of image names or a board id must be provided"): + def __init__( + self, + message="The bulk download parameters are invalid, either an array of image names or a board id must be provided", + ): super().__init__(message) - self.message = message \ No newline at end of file + self.message = message diff --git a/invokeai/app/services/bulk_download/bulk_download_defauilt.py b/invokeai/app/services/bulk_download/bulk_download_defauilt.py index ebeaa4be5a7..8321f5069db 100644 --- a/invokeai/app/services/bulk_download/bulk_download_defauilt.py +++ b/invokeai/app/services/bulk_download/bulk_download_defauilt.py @@ -1,6 +1,6 @@ +import uuid from pathlib import Path from typing import Optional, Union -import uuid from zipfile import ZipFile from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException @@ -11,15 +11,17 @@ from .bulk_download_base import BulkDownloadBase -class BulkDownloadService(BulkDownloadBase): +class BulkDownloadService(BulkDownloadBase): __output_folder: Path __bulk_downloads_folder: Path __event_bus: Optional[EventServiceBase] - def __init__(self, - output_folder: Union[str, Path], - event_bus: Optional[EventServiceBase] = None,): + def __init__( + self, + output_folder: Union[str, Path], + event_bus: Optional[EventServiceBase] = None, + ): """ Initialize the downloader object. @@ -30,8 +32,7 @@ def __init__(self, self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) self.__event_bus = event_bus - - def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: + def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: """ Create a zip file containing the images specified by the given image names or board id. @@ -39,10 +40,8 @@ def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[ param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. 
""" bulk_download_id = str(uuid.uuid4()) - - self._signal_job_started(bulk_download_id) + try: - board_name: Union[str, None] = None if board_id: image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id) if board_id == "none": @@ -64,24 +63,21 @@ def _get_image_name_to_path_map(self, invoker: Invoker, image_names: list[str]) for image_name in image_names: image_names_to_paths[image_name] = invoker.services.images.get_path(image_name) return image_names_to_paths - - def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str: """ - Create a zip file containing the images specified by the given image names or board id. + Create a zip file containing the images specified by the given image names or board id. If download with the same bulk_download_id already exists, it will be overwritten. """ zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip") - + with ZipFile(zip_file_path, "w") as zip_file: for image_name, image_path in image_names_to_paths.items(): zip_file.write(image_path, arcname=image_name) return str(zip_file_path) - def _signal_job_started(self, bulk_download_id: str) -> None: """Signal that a bulk download job has started.""" if self.__event_bus: @@ -90,7 +86,6 @@ def _signal_job_started(self, bulk_download_id: str) -> None: bulk_download_id=bulk_download_id, ) - def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None: """Signal that a bulk download job has completed.""" if self.__event_bus: @@ -100,7 +95,7 @@ def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None: bulk_download_id=bulk_download_id, file_path=file_path, ) - + def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None: """Signal that a bulk download job has failed.""" if self.__event_bus: @@ -110,5 +105,3 @@ def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> Non bulk_download_id=bulk_download_id, error=str(exception), ) - - \ No newline at end of file diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 0a0668b2740..597a56d9442 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -444,20 +444,19 @@ def emit_bulk_download_started(self, bulk_download_id: str) -> None: """Emitted when a bulk download starts""" self._emit_bulk_download_event( event_name="bulk_download_started", - payload={"bulk_download_id": bulk_download_id, } + payload={ + "bulk_download_id": bulk_download_id, + }, ) - + def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None: """Emitted when a bulk download completes""" self._emit_bulk_download_event( - event_name="bulk_download_completed", - payload={"bulk_download_id": bulk_download_id, - "file_path": file_path} + event_name="bulk_download_completed", payload={"bulk_download_id": bulk_download_id, "file_path": file_path} ) - + def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None: """Emitted when a bulk download fails""" self._emit_bulk_download_event( - event_name="bulk_download_failed", - payload={"bulk_download_id": bulk_download_id, "error": error} + event_name="bulk_download_failed", payload={"bulk_download_id": bulk_download_id, "error": error} ) From cd69f77d64714b1974c2ff4f0711b2eed656ca49 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sat, 13 Jan 2024 23:35:33 -0500 Subject: [PATCH 155/340] reworking some of the logic to use a 
default room, adding endpoint to download file on complete --- invokeai/app/api/dependencies.py | 2 +- invokeai/app/api/routers/images.py | 33 +++++++++ .../bulk_download/bulk_download_base.py | 25 +++++++ .../bulk_download/bulk_download_common.py | 3 + ...d_defauilt.py => bulk_download_default.py} | 72 +++++++++++++++---- invokeai/app/services/events/events_base.py | 23 ++++-- 6 files changed, 137 insertions(+), 21 deletions(-) rename invokeai/app/services/bulk_download/{bulk_download_defauilt.py => bulk_download_default.py} (60%) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index ab09d1e5d7b..aaa08a2498d 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -15,7 +15,7 @@ from ..services.board_images.board_images_default import BoardImagesService from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage from ..services.boards.boards_default import BoardService -from ..services.bulk_download.bulk_download_defauilt import BulkDownloadService +from ..services.bulk_download.bulk_download_default import BulkDownloadService from ..services.config import InvokeAIAppConfig from ..services.download import DownloadQueueService from ..services.image_files.image_files_disk import DiskImageFileStorage diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index e32f7fb9eec..43392dd4719 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -395,3 +395,36 @@ async def download_images_from_list( ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id ) return ImagesDownloaded(response="Your images are preparing to be downloaded") + + +@images_router.api_route( + "/download/{bulk_download_item_name}", + methods=["GET"], + operation_id="get_bulk_download_item", + response_class=Response, + responses={ + 200: { + "description": "Return the complete bulk download item", + "content": {"application/zip": {}}, + }, + 404: {"description": "Image not found"}, + }, +) +async def get_bulk_download_item( + bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"), +) -> FileResponse: + """Gets a bulk download zip file""" + + try: + path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name) + + response = FileResponse( + path, + media_type="application/zip", + filename=bulk_download_item_name, + content_disposition_type="inline", + ) + response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}" + return response + except Exception: + raise HTTPException(status_code=404) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index fc45aff2806..8a9ea1f3f22 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -29,3 +29,28 @@ def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[s :param image_names: A list of image names to include in the zip file. :param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. """ + + @abstractmethod + def get_path(self, bulk_download_item_id: str) -> str: + """ + Get the path to the bulk download file. + + :param bulk_download_item_id: The ID of the bulk download item. + :return: The path to the bulk download file. 
+ """ + + @abstractmethod + def stop(self, *args, **kwargs) -> None: + """ + Stops the BulkDownloadService and cleans up all the remnants. + + This method is responsible for stopping the BulkDownloadService and performing any necessary cleanup + operations to remove any remnants or resources associated with the service. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + None + """ diff --git a/invokeai/app/services/bulk_download/bulk_download_common.py b/invokeai/app/services/bulk_download/bulk_download_common.py index 23a0589dafd..37b80073bee 100644 --- a/invokeai/app/services/bulk_download/bulk_download_common.py +++ b/invokeai/app/services/bulk_download/bulk_download_common.py @@ -1,3 +1,6 @@ +DEFAULT_BULK_DOWNLOAD_ID = "default" + + class BulkDownloadException(Exception): """Exception raised when a bulk download fails.""" diff --git a/invokeai/app/services/bulk_download/bulk_download_defauilt.py b/invokeai/app/services/bulk_download/bulk_download_default.py similarity index 60% rename from invokeai/app/services/bulk_download/bulk_download_defauilt.py rename to invokeai/app/services/bulk_download/bulk_download_default.py index 8321f5069db..561fd173a81 100644 --- a/invokeai/app/services/bulk_download/bulk_download_defauilt.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -4,7 +4,11 @@ from zipfile import ZipFile from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException -from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadException +from invokeai.app.services.bulk_download.bulk_download_common import ( + DEFAULT_BULK_DOWNLOAD_ID, + BulkDownloadException, + BulkDownloadTargetException, +) from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException from invokeai.app.services.invoker import Invoker @@ -32,6 +36,32 @@ def __init__( self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) self.__event_bus = event_bus + def get_path(self, bulk_download_item_name: str) -> str: + """ + Get the path to the bulk download file. + + :param bulk_download_item_name: The name of the bulk download item. + :return: The path to the bulk download file. + """ + path = str(self.__bulk_downloads_folder / bulk_download_item_name) + if not self.validate_path(path): + raise BulkDownloadTargetException() + return path + + def get_bulk_download_item_name(self, bulk_download_item_id: str) -> str: + """ + Get the name of the bulk download item. + + :param bulk_download_item_id: The ID of the bulk download item. + :return: The name of the bulk download item. + """ + return bulk_download_item_id + ".zip" + + def validate_path(self, path: Union[str, Path]) -> bool: + """Validates the path given for a bulk download.""" + path = path if isinstance(path, Path) else Path(path) + return path.exists() + def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: """ Create a zip file containing the images specified by the given image names or board id. @@ -39,7 +69,9 @@ def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[s param: image_names: A list of image names to include in the zip file. param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. 
""" - bulk_download_id = str(uuid.uuid4()) + bulk_download_id = DEFAULT_BULK_DOWNLOAD_ID + bulk_download_item_id = str(uuid.uuid4()) + self._signal_job_started(bulk_download_id, bulk_download_item_id) try: if board_id: @@ -47,12 +79,12 @@ def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[s if board_id == "none": board_id = "Uncategorized" image_names_to_paths: dict[str, str] = self._get_image_name_to_path_map(invoker, image_names) - file_path: str = self._create_zip_file(image_names_to_paths, bulk_download_id) - self._signal_job_completed(bulk_download_id, file_path) + bulk_download_item_name: str = self._create_zip_file(image_names_to_paths, bulk_download_item_id) + self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: - self._signal_job_failed(bulk_download_id, e) + self._signal_job_failed(bulk_download_id, bulk_download_item_id, e) except Exception as e: - self._signal_job_failed(bulk_download_id, e) + self._signal_job_failed(bulk_download_id, bulk_download_item_id, e) def _get_image_name_to_path_map(self, invoker: Invoker, image_names: list[str]) -> dict[str, str]: """ @@ -64,44 +96,54 @@ def _get_image_name_to_path_map(self, invoker: Invoker, image_names: list[str]) image_names_to_paths[image_name] = invoker.services.images.get_path(image_name) return image_names_to_paths - def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_id: str) -> str: + def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_item_id: str) -> str: """ Create a zip file containing the images specified by the given image names or board id. If download with the same bulk_download_id already exists, it will be overwritten. - """ - zip_file_path = self.__bulk_downloads_folder / (bulk_download_id + ".zip") + :return: The name of the zip file. 
+ """ + zip_file_name = bulk_download_item_id + ".zip" + zip_file_path = self.__bulk_downloads_folder / (zip_file_name) with ZipFile(zip_file_path, "w") as zip_file: for image_name, image_path in image_names_to_paths.items(): zip_file.write(image_path, arcname=image_name) - return str(zip_file_path) + return str(zip_file_name) - def _signal_job_started(self, bulk_download_id: str) -> None: + def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None: """Signal that a bulk download job has started.""" if self.__event_bus: assert bulk_download_id is not None self.__event_bus.emit_bulk_download_started( bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, ) - def _signal_job_completed(self, bulk_download_id: str, file_path: str) -> None: + def _signal_job_completed( + self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> None: """Signal that a bulk download job has completed.""" if self.__event_bus: assert bulk_download_id is not None - assert file_path is not None + assert bulk_download_item_name is not None self.__event_bus.emit_bulk_download_completed( bulk_download_id=bulk_download_id, - file_path=file_path, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, ) - def _signal_job_failed(self, bulk_download_id: str, exception: Exception) -> None: + def _signal_job_failed(self, bulk_download_id: str, bulk_download_item_id: str, exception: Exception) -> None: """Signal that a bulk download job has failed.""" if self.__event_bus: assert bulk_download_id is not None assert exception is not None self.__event_bus.emit_bulk_download_failed( bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, error=str(exception), ) + + def stop(self, *args, **kwargs): + pass diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 597a56d9442..3cc3ba2f28f 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -440,23 +440,36 @@ def emit_model_install_error( }, ) - def emit_bulk_download_started(self, bulk_download_id: str) -> None: + def emit_bulk_download_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None: """Emitted when a bulk download starts""" self._emit_bulk_download_event( event_name="bulk_download_started", payload={ "bulk_download_id": bulk_download_id, + "bulk_download_item_id": bulk_download_item_id, }, ) - def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None: + def emit_bulk_download_completed( + self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> None: """Emitted when a bulk download completes""" self._emit_bulk_download_event( - event_name="bulk_download_completed", payload={"bulk_download_id": bulk_download_id, "file_path": file_path} + event_name="bulk_download_completed", + payload={ + "bulk_download_id": bulk_download_id, + "bulk_download_item_id": bulk_download_item_id, + "bulk_download_item_name": bulk_download_item_name, + }, ) - def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None: + def emit_bulk_download_failed(self, bulk_download_id: str, bulk_download_item_id: str, error: str) -> None: """Emitted when a bulk download fails""" self._emit_bulk_download_event( - event_name="bulk_download_failed", payload={"bulk_download_id": bulk_download_id, "error": error} + event_name="bulk_download_failed", + 
payload={ + "bulk_download_id": bulk_download_id, + "bulk_download_item_id": bulk_download_item_id, + "error": error, + }, ) From b3264dbc8ca742bf309ba928f2015d4cd897aa3d Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 14 Jan 2024 00:21:00 -0500 Subject: [PATCH 156/340] using the board name to download boards --- .../bulk_download/bulk_download_base.py | 4 +- .../bulk_download/bulk_download_default.py | 39 ++++++++++++------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 8a9ea1f3f22..366a5fec5fe 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -31,11 +31,11 @@ def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[s """ @abstractmethod - def get_path(self, bulk_download_item_id: str) -> str: + def get_path(self, bulk_download_item_name: str) -> str: """ Get the path to the bulk download file. - :param bulk_download_item_id: The ID of the bulk download item. + :param bulk_download_item_name: The name of the bulk download item. :return: The path to the bulk download file. """ diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 561fd173a81..b80b8cc2f5e 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -48,15 +48,6 @@ def get_path(self, bulk_download_item_name: str) -> str: raise BulkDownloadTargetException() return path - def get_bulk_download_item_name(self, bulk_download_item_id: str) -> str: - """ - Get the name of the bulk download item. - - :param bulk_download_item_id: The ID of the bulk download item. - :return: The name of the bulk download item. - """ - return bulk_download_item_id + ".zip" - def validate_path(self, path: Union[str, Path]) -> bool: """Validates the path given for a bulk download.""" path = path if isinstance(path, Path) else Path(path) @@ -69,17 +60,27 @@ def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[s param: image_names: A list of image names to include in the zip file. param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. 
""" - bulk_download_id = DEFAULT_BULK_DOWNLOAD_ID - bulk_download_item_id = str(uuid.uuid4()) - self._signal_job_started(bulk_download_id, bulk_download_item_id) + + bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID + bulk_download_item_id: str = str(uuid.uuid4()) if board_id is None else board_id try: + board_name: str = "" if board_id: image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id) if board_id == "none": board_id = "Uncategorized" + board_name = "Uncategorized" + else: + board_name = invoker.services.board_records.get(board_id).board_name + board_name = self._clean_string_to_path_safe(board_name) + + self._signal_job_started(bulk_download_id, bulk_download_item_id) + image_names_to_paths: dict[str, str] = self._get_image_name_to_path_map(invoker, image_names) - bulk_download_item_name: str = self._create_zip_file(image_names_to_paths, bulk_download_item_id) + bulk_download_item_name: str = self._create_zip_file( + image_names_to_paths, bulk_download_item_id if board_id is None else board_name + ) self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: self._signal_job_failed(bulk_download_id, bulk_download_item_id, e) @@ -112,6 +113,10 @@ def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_i return str(zip_file_name) + def _clean_string_to_path_safe(self, s: str) -> str: + """Clean a string to be path safe.""" + return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " "]).rstrip() + def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None: """Signal that a bulk download job has started.""" if self.__event_bus: @@ -146,4 +151,10 @@ def _signal_job_failed(self, bulk_download_id: str, bulk_download_item_id: str, ) def stop(self, *args, **kwargs): - pass + """Stop the bulk download service and delete the files in the bulk download folder.""" + # Get all the files in the bulk downloads folder + files = self.__bulk_downloads_folder.glob("*") + + # Delete all the files + for file in files: + file.unlink() From 4f3d1260fa80b77379d575117e643c08c36826f9 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 14 Jan 2024 01:33:43 -0500 Subject: [PATCH 157/340] fixing issue where default board did not return images --- .../bulk_download/bulk_download_default.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index b80b8cc2f5e..36d2b350b94 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -10,7 +10,7 @@ BulkDownloadTargetException, ) from invokeai.app.services.events.events_base import EventServiceBase -from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException +from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordNotFoundException from invokeai.app.services.invoker import Invoker from .bulk_download_base import BulkDownloadBase @@ -67,7 +67,17 @@ def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[s try: board_name: str = "" if board_id: - image_names = invoker.services.board_image_records.get_all_board_image_names_for_board(board_id) + # -1 is the default value for limit, which means no limit, 
is_intermediate only gives us completed images + image_names = [ + img.image_name + for img in invoker.services.images.get_many( + offset=0, + limit=-1, + board_id=board_id, + is_intermediate=False, + categories=[ImageCategory.GENERAL], + ).items + ] if board_id == "none": board_id = "Uncategorized" board_name = "Uncategorized" From dee70062563cfda2191694e25dfc979637ba2804 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 14 Jan 2024 03:09:35 -0500 Subject: [PATCH 158/340] refactoring bulkdownload to consider image category --- invokeai/app/api/routers/images.py | 4 +- .../bulk_download/bulk_download_base.py | 13 +++- .../bulk_download/bulk_download_default.py | 62 +++++++++---------- 3 files changed, 43 insertions(+), 36 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 43392dd4719..236961fa9e4 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -391,9 +391,7 @@ async def download_images_from_list( ) -> ImagesDownloaded: if (image_names is None or len(image_names) == 0) and board_id is None: raise HTTPException(status_code=400, detail="No images or board id specified.") - background_tasks.add_task( - ApiDependencies.invoker.services.bulk_download.handler, ApiDependencies.invoker, image_names, board_id - ) + background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, image_names, board_id) return ImagesDownloaded(response="Your images are preparing to be downloaded") diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 366a5fec5fe..880345fe982 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -7,6 +7,17 @@ class BulkDownloadBase(ABC): + @abstractmethod + def start(self, invoker: Invoker) -> None: + """ + Starts the BulkDownloadService. + + This method is responsible for starting the BulkDownloadService and performing any necessary initialization + operations to prepare the service for use. + + param: invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies. + """ + @abstractmethod def __init__( self, @@ -21,7 +32,7 @@ def __init__( """ @abstractmethod - def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: + def handler(self, image_names: list[str], board_id: Optional[str]) -> None: """ Starts a a bulk download job. 
diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 36d2b350b94..ffc26dfa54b 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -10,7 +10,8 @@ BulkDownloadTargetException, ) from invokeai.app.services.events.events_base import EventServiceBase -from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordNotFoundException +from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException +from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invoker import Invoker from .bulk_download_base import BulkDownloadBase @@ -20,6 +21,10 @@ class BulkDownloadService(BulkDownloadBase): __output_folder: Path __bulk_downloads_folder: Path __event_bus: Optional[EventServiceBase] + __invoker: Invoker + + def start(self, invoker: Invoker) -> None: + self.__invoker = invoker def __init__( self, @@ -53,7 +58,7 @@ def validate_path(self, path: Union[str, Path]) -> bool: path = path if isinstance(path, Path) else Path(path) return path.exists() - def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[str]) -> None: + def handler(self, image_names: list[str], board_id: Optional[str]) -> None: """ Create a zip file containing the images specified by the given image names or board id. @@ -64,50 +69,40 @@ def handler(self, invoker: Invoker, image_names: list[str], board_id: Optional[s bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID bulk_download_item_id: str = str(uuid.uuid4()) if board_id is None else board_id + self._signal_job_started(bulk_download_id, bulk_download_item_id) + try: board_name: str = "" + image_dtos: list[ImageDTO] = [] + if board_id: - # -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images - image_names = [ - img.image_name - for img in invoker.services.images.get_many( - offset=0, - limit=-1, - board_id=board_id, - is_intermediate=False, - categories=[ImageCategory.GENERAL], - ).items - ] if board_id == "none": - board_id = "Uncategorized" board_name = "Uncategorized" else: - board_name = invoker.services.board_records.get(board_id).board_name + board_name = self.__invoker.services.board_records.get(board_id).board_name board_name = self._clean_string_to_path_safe(board_name) - self._signal_job_started(bulk_download_id, bulk_download_item_id) - - image_names_to_paths: dict[str, str] = self._get_image_name_to_path_map(invoker, image_names) + # -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images + image_dtos = self.__invoker.services.images.get_many( + offset=0, + limit=-1, + board_id=board_id, + is_intermediate=False, + ).items + else: + image_dtos = [self.__invoker.services.images.get_dto(image_name) for image_name in image_names] bulk_download_item_name: str = self._create_zip_file( - image_names_to_paths, bulk_download_item_id if board_id is None else board_name + image_dtos, bulk_download_item_id if board_id is None else board_name ) self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: self._signal_job_failed(bulk_download_id, bulk_download_item_id, e) except Exception as e: self._signal_job_failed(bulk_download_id, bulk_download_item_id, e) + 
self.__invoker.services.logger.error("Problem bulk downloading images.") + raise e - def _get_image_name_to_path_map(self, invoker: Invoker, image_names: list[str]) -> dict[str, str]: - """ - Create a map of image names to their paths. - :param image_names: A list of image names. - """ - image_names_to_paths: dict[str, str] = {} - for image_name in image_names: - image_names_to_paths[image_name] = invoker.services.images.get_path(image_name) - return image_names_to_paths - - def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_item_id: str) -> str: + def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str: """ Create a zip file containing the images specified by the given image names or board id. If download with the same bulk_download_id already exists, it will be overwritten. @@ -118,11 +113,14 @@ def _create_zip_file(self, image_names_to_paths: dict[str, str], bulk_download_i zip_file_path = self.__bulk_downloads_folder / (zip_file_name) with ZipFile(zip_file_path, "w") as zip_file: - for image_name, image_path in image_names_to_paths.items(): - zip_file.write(image_path, arcname=image_name) + for image_dto in image_dtos: + image_zip_path = Path(image_dto.image_category.value) / image_dto.image_name + image_path = self.__invoker.services.images.get_path(image_dto.image_name) + zip_file.write(image_path, arcname=image_zip_path) return str(zip_file_name) + # from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string def _clean_string_to_path_safe(self, s: str) -> str: """Clean a string to be path safe.""" return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " "]).rstrip() From eb278fda0072756b281ba60e5041f7a9221f1b62 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 14 Jan 2024 22:58:07 -0500 Subject: [PATCH 159/340] refactoring dummy event service, DRY principal; adding bulk_download_event to existing invoker tests --- .../services/download/test_download_queue.py | 28 ++------------- .../model_install/test_model_install.py | 1 - tests/fixtures/event_service.py | 34 +++++++++++++++++++ tests/test_graph_execution_state.py | 1 + 4 files changed, 37 insertions(+), 27 deletions(-) create mode 100644 tests/fixtures/event_service.py diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 93a3832b51e..34408ac5aed 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -2,19 +2,17 @@ import re import time from pathlib import Path -from typing import Any, Dict, List import pytest -from pydantic import BaseModel from pydantic.networks import AnyHttpUrl from requests.sessions import Session from requests_testadapter import TestAdapter, TestSession from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService -from invokeai.app.services.events.events_base import EventServiceBase +from tests.fixtures.event_service import DummyEventService # Prevent pytest deprecation warnings -TestAdapter.__test__ = False +TestAdapter.__test__ = False # type: ignore @pytest.fixture @@ -52,28 +50,6 @@ def session() -> Session: return sess -class DummyEvent(BaseModel): - """Dummy Event to use with Dummy Event service.""" - - event_name: str - payload: Dict[str, Any] - - -# A dummy event service for testing event issuing -class DummyEventService(EventServiceBase): - """Dummy event service for testing.""" - - events: List[DummyEvent] - - 
def __init__(self) -> None: - super().__init__() - self.events = [] - - def dispatch(self, event_name: str, payload: Any) -> None: - """Dispatch an event by appending it to self.events.""" - self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"])) - - def test_basic_queue_download(tmp_path: Path, session: Session) -> None: events = set() diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 55f7e865410..14c8ed5c84d 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -10,7 +10,6 @@ from pydantic.networks import Url from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ( InstallStatus, LocalModelSource, diff --git a/tests/fixtures/event_service.py b/tests/fixtures/event_service.py new file mode 100644 index 00000000000..71262be3f92 --- /dev/null +++ b/tests/fixtures/event_service.py @@ -0,0 +1,34 @@ +from typing import Any, Dict, List + +import pytest +from pydantic import BaseModel + +from invokeai.app.services.events.events_base import EventServiceBase + + +class DummyEvent(BaseModel): + """Dummy Event to use with Dummy Event service.""" + + event_name: str + payload: Dict[str, Any] + + +# A dummy event service for testing event issuing +class DummyEventService(EventServiceBase): + """Dummy event service for testing.""" + + events: List[DummyEvent] + + def __init__(self) -> None: + super().__init__() + self.events = [] + + def dispatch(self, event_name: str, payload: Any) -> None: + """Dispatch an event by appending it to self.events.""" + self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"])) + + +@pytest.fixture +def mock_event_service() -> EventServiceBase: + """Create a dummy event service.""" + return DummyEventService() diff --git a/tests/test_graph_execution_state.py b/tests/test_graph_execution_state.py index 2e88178424a..9a350374311 100644 --- a/tests/test_graph_execution_state.py +++ b/tests/test_graph_execution_state.py @@ -50,6 +50,7 @@ def mock_services() -> InvocationServices: board_images=None, # type: ignore board_records=None, # type: ignore boards=None, # type: ignore + bulk_download=None, # type: ignore configuration=configuration, events=TestEventService(), image_files=None, # type: ignore From fe824f05269bcd53dadc93ad7453182fc4a14b9f Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Mon, 15 Jan 2024 12:59:45 -0500 Subject: [PATCH 160/340] refactoring bulk_download to be better managed --- invokeai/app/api/dependencies.py | 2 +- .../bulk_download/bulk_download_base.py | 8 +----- .../bulk_download/bulk_download_default.py | 26 +++++++++---------- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index aaa08a2498d..984fd8e2670 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -82,7 +82,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger board_records = SqliteBoardRecordStorage(db=db) boards = BoardService() events = FastAPIEventService(event_handler_id) - bulk_download = BulkDownloadService(output_folder=f"{output_folder}", event_bus=events) + bulk_download = BulkDownloadService(output_folder=f"{output_folder}") image_records = SqliteImageRecordStorage(db=db) images = ImageService() 
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 880345fe982..7a4aa0661c0 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -2,7 +2,6 @@ from pathlib import Path from typing import Optional, Union -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker @@ -19,16 +18,11 @@ def start(self, invoker: Invoker) -> None: """ @abstractmethod - def __init__( - self, - output_folder: Union[str, Path], - event_bus: Optional["EventServiceBase"] = None, - ): + def __init__(self, output_folder: Union[str, Path]): """ Create BulkDownloadBase object. :param output_folder: The path to the output folder where the bulk download files can be temporarily stored. - :param event_bus: InvokeAI event bus for reporting events to. """ @abstractmethod diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index ffc26dfa54b..a9ea12bfd6f 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -1,4 +1,3 @@ -import uuid from pathlib import Path from typing import Optional, Union from zipfile import ZipFile @@ -13,6 +12,7 @@ from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invoker import Invoker +from invokeai.app.util.misc import uuid_string from .bulk_download_base import BulkDownloadBase @@ -20,26 +20,23 @@ class BulkDownloadService(BulkDownloadBase): __output_folder: Path __bulk_downloads_folder: Path - __event_bus: Optional[EventServiceBase] + __event_bus: EventServiceBase __invoker: Invoker def start(self, invoker: Invoker) -> None: self.__invoker = invoker + self.__event_bus = invoker.services.events def __init__( self, output_folder: Union[str, Path], - event_bus: Optional[EventServiceBase] = None, ): """ Initialize the downloader object. 
- - :param event_bus: Optional EventService object """ self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads" self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) - self.__event_bus = event_bus def get_path(self, bulk_download_item_name: str) -> str: """ @@ -67,7 +64,7 @@ def handler(self, image_names: list[str], board_id: Optional[str]) -> None: """ bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID - bulk_download_item_id: str = str(uuid.uuid4()) if board_id is None else board_id + bulk_download_item_id: str = uuid_string() if board_id is None else board_id self._signal_job_started(bulk_download_id, bulk_download_item_id) @@ -76,10 +73,7 @@ def handler(self, image_names: list[str], board_id: Optional[str]) -> None: image_dtos: list[ImageDTO] = [] if board_id: - if board_id == "none": - board_name = "Uncategorized" - else: - board_name = self.__invoker.services.board_records.get(board_id).board_name + board_name = self._get_board_name(board_id) board_name = self._clean_string_to_path_safe(board_name) # -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images @@ -102,6 +96,12 @@ def handler(self, image_names: list[str], board_id: Optional[str]) -> None: self.__invoker.services.logger.error("Problem bulk downloading images.") raise e + def _get_board_name(self, board_id: str) -> str: + if board_id == "none": + return "Uncategorized" + + return self.__invoker.services.board_records.get(board_id).board_name + def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str: """ Create a zip file containing the images specified by the given image names or board id. 
@@ -115,8 +115,8 @@ def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: st with ZipFile(zip_file_path, "w") as zip_file: for image_dto in image_dtos: image_zip_path = Path(image_dto.image_category.value) / image_dto.image_name - image_path = self.__invoker.services.images.get_path(image_dto.image_name) - zip_file.write(image_path, arcname=image_zip_path) + image_disk_path = self.__invoker.services.images.get_path(image_dto.image_name) + zip_file.write(image_disk_path, arcname=image_zip_path) return str(zip_file_name) From a405e5ef0c5ceff15787107414f2ad0baaedcfbe Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Mon, 15 Jan 2024 13:01:06 -0500 Subject: [PATCH 161/340] 97% test coverage on bulk_download --- .../bulk_download/test_bulk_download.py | 319 ++++++++++++++++++ tests/fixtures/event_service.py | 2 +- 2 files changed, 320 insertions(+), 1 deletion(-) create mode 100644 tests/app/services/bulk_download/test_bulk_download.py diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py new file mode 100644 index 00000000000..4f476c21bed --- /dev/null +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -0,0 +1,319 @@ +import os +from pathlib import Path +from typing import Any +from zipfile import ZipFile + +import pytest + +from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException +from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage +from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException +from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.image_records.image_records_common import ( + ImageCategory, + ImageRecordNotFoundException, + ResourceOrigin, +) +from invokeai.app.services.images.images_common import ImageDTO +from invokeai.app.services.images.images_default import ImageService +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.shared.pagination import OffsetPaginatedResults +from invokeai.backend.util.logging import InvokeAILogger +from tests.fixtures.event_service import DummyEventService, mock_event_service # noqa: F401,F811 +from tests.fixtures.sqlite_database import create_mock_sqlite_database + + +@pytest.fixture +def mock_image_dto() -> ImageDTO: + """Create a mock ImageDTO.""" + return ImageDTO( + image_name="mock_image.png", + board_id="12345", + image_url="None", + width=100, + height=100, + thumbnail_url="None", + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + created_at="None", + updated_at="None", + starred=False, + has_workflow=False, + is_intermediate=False, + ) + + +@pytest.fixture +def mock_services(mock_event_service: DummyEventService) -> InvocationServices: + configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) + logger = InvokeAILogger.get_logger() + db = create_mock_sqlite_database(configuration, logger) + + return InvocationServices( + board_image_records=None, # type: ignore + board_images=None, # type: ignore + board_records=SqliteBoardRecordStorage(db=db), + boards=None, # type: ignore + bulk_download=None, # type: ignore + configuration=None, # type: ignore + events=mock_event_service, + graph_execution_manager=None, # type: 
ignore + image_files=None, # type: ignore + image_records=None, # type: ignore + images=ImageService(), + invocation_cache=None, # type: ignore + latents=None, # type: ignore + logger=logger, + model_manager=None, # type: ignore + model_records=None, # type: ignore + download_queue=None, # type: ignore + model_install=None, # type: ignore + names=None, # type: ignore + performance_statistics=None, # type: ignore + processor=None, # type: ignore + queue=None, # type: ignore + session_processor=None, # type: ignore + session_queue=None, # type: ignore + urls=None, # type: ignore + workflow_records=None, # type: ignore + ) + + +@pytest.fixture() +def mock_invoker(mock_services: InvocationServices) -> Invoker: + return Invoker(services=mock_services) + + +def test_get_path_when_file_exists(tmp_path: Path) -> None: + """Test get_path when the file exists.""" + + # Create a directory at tmp_path/bulk_downloads + test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads" + test_bulk_downloads_dir.mkdir(parents=True, exist_ok=True) + + # Create a file at tmp_path/bulk_downloads/test.zip + test_file_path: Path = test_bulk_downloads_dir / "test.zip" + test_file_path.touch() + + bulk_download_service = BulkDownloadService(tmp_path) + assert bulk_download_service.get_path("test.zip") == str(test_file_path) + + +def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None: + """Test get_path when the file does not exist.""" + + bulk_download_service = BulkDownloadService(tmp_path) + with pytest.raises(BulkDownloadTargetException): + bulk_download_service.get_path("test") + + +def test_bulk_downloads_dir_created_at_start(tmp_path: Path) -> None: + """Test that the bulk_downloads directory is created at start.""" + + BulkDownloadService(tmp_path) + assert (tmp_path / "bulk_downloads").exists() + + +def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Test that the handler creates the zip file correctly when given a list of image names.""" + + expected_zip_path, expected_image_path, mock_image_contents = prepare_handler_test( + tmp_path, monkeypatch, mock_image_dto, mock_invoker + ) + + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + bulk_download_service.handler([mock_image_dto.image_name], None) + + assert_handler_success( + expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events + ) + + +def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Test that the handler creates the zip file correctly when given a board id.""" + + expected_zip_path, expected_image_path, mock_image_contents = prepare_handler_test( + tmp_path, monkeypatch, mock_image_dto, mock_invoker + ) + + def mock_board_get(*args, **kwargs): + return BoardRecord(board_id="12345", board_name="test", created_at="None", updated_at="None") + + monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get) + + def mock_get_many(*args, **kwargs): + return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto]) + + monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) + + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + bulk_download_service.handler([], "test") + + assert_handler_success( + expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events + ) + + +def 
test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Test that the handler creates the zip file correctly when given a board id.""" + + _, expected_image_path, mock_image_contents = prepare_handler_test( + tmp_path, monkeypatch, mock_image_dto, mock_invoker + ) + expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip" + + def mock_get_many(*args, **kwargs): + return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto]) + + monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) + + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + bulk_download_service.handler([], "none") + + assert_handler_success( + expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events + ) + + +def prepare_handler_test(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Prepare the test for the handler tests.""" + + def mock_uuid_string(): + return "test" + + # You have to patch the function within the module it's being imported into. This is strange, but it works. + # See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/ + monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", mock_uuid_string) + + expected_zip_path: Path = tmp_path / "bulk_downloads" / "test.zip" + expected_image_path: Path = ( + tmp_path / "bulk_downloads" / mock_image_dto.image_category.value / mock_image_dto.image_name + ) + + # Mock the get_dto method so that when the image dto needs to be retrieved it is returned + def mock_get_dto(*args, **kwargs): + return mock_image_dto + + monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto) + + # Create a mock image file so that the contents of the zip file are not empty + mock_image_path: Path = tmp_path / mock_image_dto.image_name + mock_image_contents: str = "Totally an image" + mock_image_path.write_text(mock_image_contents) + + def mock_get_path(*args, **kwargs): + return str(mock_image_path) + + monkeypatch.setattr(mock_invoker.services.images, "get_path", mock_get_path) + + return expected_zip_path, expected_image_path, mock_image_contents + + +def assert_handler_success( + expected_zip_path: Path, + expected_image_path: Path, + mock_image_contents: str, + tmp_path: Path, + event_bus: DummyEventService, +): + """Assert that the handler was successful.""" + # Check that the zip file was created + assert expected_zip_path.exists() + assert expected_zip_path.is_file() + assert expected_zip_path.stat().st_size > 0 + + # Check that the zip contents are expected + with ZipFile(expected_zip_path, "r") as zip_file: + zip_file.extractall(tmp_path / "bulk_downloads") + assert expected_image_path.exists() + assert expected_image_path.is_file() + assert expected_image_path.stat().st_size > 0 + assert expected_image_path.read_text() == mock_image_contents + + # Check that the correct events were emitted + assert len(event_bus.events) == 2 + assert event_bus.events[0].event_name == "bulk_download_started" + assert event_bus.events[1].event_name == "bulk_download_completed" + assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path) + + +def test_stop(tmp_path: Path) -> None: + """Test that the stop method removes the bulk_downloads directory.""" + + bulk_download_service = BulkDownloadService(tmp_path) + + mock_file: Path = tmp_path 
/ "bulk_downloads" / "test.zip" + mock_file.write_text("contents") + + bulk_download_service.stop() + + assert (tmp_path / "bulk_downloads").exists() + assert len(os.listdir(tmp_path / "bulk_downloads")) == 0 + + +def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Test that the handler emits an error event when the image is not found.""" + exception: Exception = ImageRecordNotFoundException("Image not found") + + def mock_get_dto(*args, **kwargs): + raise exception + + monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto) + + execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception) + + +def test_handler_on_board_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Test that the handler emits an error event when the image is not found.""" + + exception: Exception = BoardRecordNotFoundException("Image not found") + + def mock_get_board_name(*args, **kwargs): + raise exception + + monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name) + + execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception) + + +def test_handler_on_generic_exception( + tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker +): + """Test that the handler emits an error event when the image is not found.""" + + exception: Exception = Exception("Generic exception") + + def mock_get_board_name(*args, **kwargs): + raise exception + + monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name) + + with pytest.raises(Exception): + execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception) + + event_bus: DummyEventService = mock_invoker.services.events + + assert len(event_bus.events) == 2 + assert event_bus.events[0].event_name == "bulk_download_started" + assert event_bus.events[1].event_name == "bulk_download_failed" + assert event_bus.events[1].payload["error"] == exception.__str__() + + +def execute_handler_test_on_error( + tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker, error: Exception +): + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + bulk_download_service.handler([mock_image_dto.image_name], None) + + event_bus: DummyEventService = mock_invoker.services.events + + assert len(event_bus.events) == 2 + assert event_bus.events[0].event_name == "bulk_download_started" + assert event_bus.events[1].event_name == "bulk_download_failed" + assert event_bus.events[1].payload["error"] == error.__str__() diff --git a/tests/fixtures/event_service.py b/tests/fixtures/event_service.py index 71262be3f92..0a09fa0d64a 100644 --- a/tests/fixtures/event_service.py +++ b/tests/fixtures/event_service.py @@ -29,6 +29,6 @@ def dispatch(self, event_name: str, payload: Any) -> None: @pytest.fixture -def mock_event_service() -> EventServiceBase: +def mock_event_service() -> DummyEventService: """Create a dummy event service.""" return DummyEventService() From 3c157144f231f632ec5cc15ad68684956ab60cb6 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Mon, 15 Jan 2024 14:37:42 -0500 Subject: [PATCH 162/340] replacing import removed during rebase --- tests/app/services/model_install/test_model_install.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/app/services/model_install/test_model_install.py 
b/tests/app/services/model_install/test_model_install.py index 14c8ed5c84d..55f7e865410 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -10,6 +10,7 @@ from pydantic.networks import Url from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ( InstallStatus, LocalModelSource, From faa16b960230480caec207586d8f9bfe803abb25 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Mon, 15 Jan 2024 15:58:43 -0500 Subject: [PATCH 163/340] cleaning up bulk download zip after the response is complete --- invokeai/app/api/routers/images.py | 3 +- .../bulk_download/bulk_download_base.py | 8 ++++ .../bulk_download/bulk_download_default.py | 43 +++++++++++-------- .../bulk_download/test_bulk_download.py | 16 ++++++- 4 files changed, 51 insertions(+), 19 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 236961fa9e4..d11c89c749b 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -409,10 +409,10 @@ async def download_images_from_list( }, ) async def get_bulk_download_item( + background_tasks: BackgroundTasks, bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"), ) -> FileResponse: """Gets a bulk download zip file""" - try: path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name) @@ -423,6 +423,7 @@ async def get_bulk_download_item( content_disposition_type="inline", ) response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}" + background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name) return response except Exception: raise HTTPException(status_code=404) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 7a4aa0661c0..a1071f254ab 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -59,3 +59,11 @@ def stop(self, *args, **kwargs) -> None: Returns: None """ + + @abstractmethod + def delete(self, bulk_download_item_name: str) -> None: + """ + Delete the bulk download file. + + :param bulk_download_item_name: The name of the bulk download item. + """ diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index a9ea12bfd6f..a0abb6743ad 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -38,23 +38,6 @@ def __init__( self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads" self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) - def get_path(self, bulk_download_item_name: str) -> str: - """ - Get the path to the bulk download file. - - :param bulk_download_item_name: The name of the bulk download item. - :return: The path to the bulk download file. 
- """ - path = str(self.__bulk_downloads_folder / bulk_download_item_name) - if not self.validate_path(path): - raise BulkDownloadTargetException() - return path - - def validate_path(self, path: Union[str, Path]) -> bool: - """Validates the path given for a bulk download.""" - path = path if isinstance(path, Path) else Path(path) - return path.exists() - def handler(self, image_names: list[str], board_id: Optional[str]) -> None: """ Create a zip file containing the images specified by the given image names or board id. @@ -166,3 +149,29 @@ def stop(self, *args, **kwargs): # Delete all the files for file in files: file.unlink() + + def delete(self, bulk_download_item_name: str) -> None: + """ + Delete the bulk download file. + + :param bulk_download_item_name: The name of the bulk download item. + """ + path = self.get_path(bulk_download_item_name) + Path(path).unlink() + + def get_path(self, bulk_download_item_name: str) -> str: + """ + Get the path to the bulk download file. + + :param bulk_download_item_name: The name of the bulk download item. + :return: The path to the bulk download file. + """ + path = str(self.__bulk_downloads_folder / bulk_download_item_name) + if not self.validate_path(path): + raise BulkDownloadTargetException() + return path + + def validate_path(self, path: Union[str, Path]) -> bool: + """Validates the path given for a bulk download.""" + path = path if isinstance(path, Path) else Path(path) + return path.exists() diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index 4f476c21bed..4c9dc426124 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -293,7 +293,7 @@ def mock_get_board_name(*args, **kwargs): monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name) - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017 execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception) event_bus: DummyEventService = mock_invoker.services.events @@ -317,3 +317,17 @@ def execute_handler_test_on_error( assert event_bus.events[0].event_name == "bulk_download_started" assert event_bus.events[1].event_name == "bulk_download_failed" assert event_bus.events[1].payload["error"] == error.__str__() + + +def test_delete(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): + """Test that the delete method removes the bulk download file.""" + + bulk_download_service = BulkDownloadService(tmp_path) + + mock_file: Path = tmp_path / "bulk_downloads" / "test.zip" + mock_file.write_text("contents") + + bulk_download_service.delete("test.zip") + + assert (tmp_path / "bulk_downloads").exists() + assert len(os.listdir(tmp_path / "bulk_downloads")) == 0 From 604aa921475982a36bd3978788cfbc7fa95ca3f9 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Mon, 15 Jan 2024 19:11:03 -0500 Subject: [PATCH 164/340] adding test coverage for new bulk download routes --- pyproject.toml | 1 + tests/app/routers/test_images.py | 145 ++++++++++++++++++ .../bulk_download/test_bulk_download.py | 2 +- 3 files changed, 147 insertions(+), 1 deletion(-) create mode 100644 tests/app/routers/test_images.py diff --git a/pyproject.toml b/pyproject.toml index f57607bc0af..5345851951f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,6 +119,7 @@ dependencies = [ "pytest-cov", "pytest-datadir", "requests_testadapter", + "httpx", ] [project.scripts] diff 
--git a/tests/app/routers/test_images.py b/tests/app/routers/test_images.py new file mode 100644 index 00000000000..040ae01914d --- /dev/null +++ b/tests/app/routers/test_images.py @@ -0,0 +1,145 @@ +from pathlib import Path +from typing import Any + +import pytest +from fastapi import BackgroundTasks +from fastapi.testclient import TestClient + +from invokeai.app.api.dependencies import ApiDependencies +from invokeai.app.api_app import app +from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage +from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.images.images_default import ImageService +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.invoker import Invoker +from invokeai.backend.util.logging import InvokeAILogger +from tests.fixtures.sqlite_database import create_mock_sqlite_database + +client = TestClient(app) + + +@pytest.fixture +def mock_services(tmp_path: Path) -> InvocationServices: + configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) + logger = InvokeAILogger.get_logger() + db = create_mock_sqlite_database(configuration, logger) + + return InvocationServices( + board_image_records=None, # type: ignore + board_images=None, # type: ignore + board_records=SqliteBoardRecordStorage(db=db), + boards=None, # type: ignore + bulk_download=BulkDownloadService(tmp_path), + configuration=None, # type: ignore + events=None, # type: ignore + graph_execution_manager=None, # type: ignore + image_files=None, # type: ignore + image_records=None, # type: ignore + images=ImageService(), + invocation_cache=None, # type: ignore + latents=None, # type: ignore + logger=logger, + model_manager=None, # type: ignore + model_records=None, # type: ignore + download_queue=None, # type: ignore + model_install=None, # type: ignore + names=None, # type: ignore + performance_statistics=None, # type: ignore + processor=None, # type: ignore + queue=None, # type: ignore + session_processor=None, # type: ignore + session_queue=None, # type: ignore + urls=None, # type: ignore + workflow_records=None, # type: ignore + ) + + +@pytest.fixture() +def mock_invoker(mock_services: InvocationServices) -> Invoker: + return Invoker(services=mock_services) + + +class MockApiDependencies(ApiDependencies): + invoker: Invoker + + def __init__(self, invoker) -> None: + self.invoker = invoker + + +def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None: + prepare_download_images_test(monkeypatch, mock_invoker) + + response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]}) + + assert response.status_code == 202 + + +def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None: + prepare_download_images_test(monkeypatch, mock_invoker) + + response = client.post("/api/v1/images/download", json={"image_names": [], "board_id": "test"}) + + assert response.status_code == 202 + + +def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None: + monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker)) + + def mock_add_task(*args, **kwargs): + return None + + monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task) + + +def test_download_images_with_empty_image_list_and_no_board_id(monkeypatch: Any, mock_invoker: 
Invoker) -> None: + prepare_download_images_test(monkeypatch, mock_invoker) + + response = client.post("/api/v1/images/download", json={"image_names": []}) + + assert response.status_code == 400 + + +def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker) -> None: + mock_file: Path = tmp_path / "test.zip" + mock_file.write_text("contents") + + monkeypatch.setattr(mock_invoker.services.bulk_download, "get_path", lambda x: str(mock_file)) + monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker)) + + def mock_add_task(*args, **kwargs): + return None + + monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task) + + response = client.get("/api/v1/images/download/test.zip") + + assert response.status_code == 200 + assert response.content == b"contents" + + +def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker) -> None: + monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker)) + + def mock_add_task(*args, **kwargs): + return None + + monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task) + + response = client.get("/api/v1/images/download/test.zip") + + assert response.status_code == 404 + + +def test_get_bulk_download_image_image_deleted_after_response( + monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path +) -> None: + mock_file: Path = tmp_path / "test.zip" + mock_file.write_text("contents") + + monkeypatch.setattr(mock_invoker.services.bulk_download, "get_path", lambda x: str(mock_file)) + monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker)) + + client.get("/api/v1/images/download/test.zip") + + assert not (tmp_path / "test.zip").exists() diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index 4c9dc426124..bc6eb8d41ca 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -293,7 +293,7 @@ def mock_get_board_name(*args, **kwargs): monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_board_name) - with pytest.raises(Exception): # noqa: B017 + with pytest.raises(Exception): # noqa: B017 execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception) event_bus: DummyEventService = mock_invoker.services.events From 2d3ee91dec08003d68222e9f6e72d41544855310 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 28 Jan 2024 00:38:01 -0500 Subject: [PATCH 165/340] narrowing bulk_download stop service scope --- .../bulk_download/bulk_download_default.py | 4 ++-- .../bulk_download/test_bulk_download.py | 20 ++++++++++++++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index a0abb6743ad..87966ad6223 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -143,8 +143,8 @@ def _signal_job_failed(self, bulk_download_id: str, bulk_download_item_id: str, def stop(self, *args, **kwargs): """Stop the bulk download service and delete the files in the bulk download folder.""" - # Get all the files in the bulk downloads folder - files = self.__bulk_downloads_folder.glob("*") + # Get all the files in the bulk downloads folder, only .zip files + files = 
self.__bulk_downloads_folder.glob("*.zip") # Delete all the files for file in files: diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index bc6eb8d41ca..184519866a0 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -319,7 +319,7 @@ def execute_handler_test_on_error( assert event_bus.events[1].payload["error"] == error.__str__() -def test_delete(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): +def test_delete(tmp_path: Path): """Test that the delete method removes the bulk download file.""" bulk_download_service = BulkDownloadService(tmp_path) @@ -331,3 +331,21 @@ def test_delete(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock assert (tmp_path / "bulk_downloads").exists() assert len(os.listdir(tmp_path / "bulk_downloads")) == 0 + +def test_stop(tmp_path: Path): + """Test that the delete method removes the bulk download file.""" + + bulk_download_service = BulkDownloadService(tmp_path) + + mock_file: Path = tmp_path / "bulk_downloads" / "test.zip" + mock_file.write_text("contents") + + mock_dir: Path = tmp_path / "bulk_downloads" / "test" + mock_dir.mkdir(parents=True, exist_ok=True) + + + bulk_download_service.stop() + + assert (tmp_path / "bulk_downloads").exists() + assert mock_dir.exists() + assert len(os.listdir(tmp_path / "bulk_downloads")) == 1 From aefd4c4bf5ee5f37b52ffaee11fead12ffe56871 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 28 Jan 2024 01:23:38 -0500 Subject: [PATCH 166/340] returning the bulk_download_item_name on response for possible polling --- invokeai/app/api/routers/images.py | 28 +++++-- .../bulk_download/bulk_download_base.py | 11 ++- .../bulk_download/bulk_download_default.py | 9 ++- tests/app/routers/test_images.py | 22 +++++- .../bulk_download/test_bulk_download.py | 79 ++++++++++++++----- 5 files changed, 116 insertions(+), 33 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index d11c89c749b..c12556aed62 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -1,6 +1,6 @@ import io import traceback -from typing import Optional +from typing import Optional, cast from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi.responses import FileResponse @@ -13,6 +13,7 @@ from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator +from invokeai.app.util.misc import uuid_string from ..dependencies import ApiDependencies @@ -377,6 +378,7 @@ class ImagesDownloaded(BaseModel): response: Optional[str] = Field( description="If defined, the message to display to the user when images begin downloading" ) + bulk_download_item_name: str = Field(description="The bulk download item name of the bulk download item") @images_router.post( @@ -384,15 +386,31 @@ class ImagesDownloaded(BaseModel): ) async def download_images_from_list( background_tasks: BackgroundTasks, - image_names: list[str] = Body(description="The list of names of images to download", embed=True), + image_names: Optional[list[str]] = Body( + default=None, description="The list of names of images to download", embed=True + ), board_id: Optional[str] 
= Body( default=None, description="The board from which image should be downloaded from", embed=True ), ) -> ImagesDownloaded: if (image_names is None or len(image_names) == 0) and board_id is None: raise HTTPException(status_code=400, detail="No images or board id specified.") - background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.handler, image_names, board_id) - return ImagesDownloaded(response="Your images are preparing to be downloaded") + bulk_download_item_id: str = uuid_string() if board_id is None else board_id + board_name: str = ( + "" if board_id is None else ApiDependencies.invoker.services.board_records.get(board_id).board_name + ) + + # Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries + background_tasks.add_task( + ApiDependencies.invoker.services.bulk_download.handler, + cast(list[str], image_names), + board_id, + bulk_download_item_id, + ) + return ImagesDownloaded( + response="Your images are preparing to be downloaded", + bulk_download_item_name=bulk_download_item_id if board_id is None else board_name + ".zip", + ) @images_router.api_route( @@ -410,7 +428,7 @@ async def download_images_from_list( ) async def get_bulk_download_item( background_tasks: BackgroundTasks, - bulk_download_item_name: str = Path(description="The bulk_download_item_id of the bulk download item to get"), + bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"), ) -> FileResponse: """Gets a bulk download zip file""" try: diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index a1071f254ab..d6b0e62211b 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -26,7 +26,7 @@ def __init__(self, output_folder: Union[str, Path]): """ @abstractmethod - def handler(self, image_names: list[str], board_id: Optional[str]) -> None: + def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None: """ Starts a a bulk download job. @@ -44,6 +44,15 @@ def get_path(self, bulk_download_item_name: str) -> str: :return: The path to the bulk download file. """ + @abstractmethod + def get_board_name(self, board_id: str) -> str: + """ + Get the name of the board. + + :param board_id: The ID of the board. + :return: The name of the board. + """ + @abstractmethod def stop(self, *args, **kwargs) -> None: """ diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 87966ad6223..be70dea2c1d 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -38,7 +38,7 @@ def __init__( self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads" self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) - def handler(self, image_names: list[str], board_id: Optional[str]) -> None: + def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None: """ Create a zip file containing the images specified by the given image names or board id. 
@@ -47,7 +47,8 @@ def handler(self, image_names: list[str], board_id: Optional[str]) -> None: """ bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID - bulk_download_item_id: str = uuid_string() if board_id is None else board_id + if bulk_download_item_id is None: + bulk_download_item_id = uuid_string() if board_id is None else board_id self._signal_job_started(bulk_download_id, bulk_download_item_id) @@ -56,7 +57,7 @@ def handler(self, image_names: list[str], board_id: Optional[str]) -> None: image_dtos: list[ImageDTO] = [] if board_id: - board_name = self._get_board_name(board_id) + board_name = self.get_board_name(board_id) board_name = self._clean_string_to_path_safe(board_name) # -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images @@ -79,7 +80,7 @@ def handler(self, image_names: list[str], board_id: Optional[str]) -> None: self.__invoker.services.logger.error("Problem bulk downloading images.") raise e - def _get_board_name(self, board_id: str) -> str: + def get_board_name(self, board_id: str) -> str: if board_id == "none": return "Uncategorized" diff --git a/tests/app/routers/test_images.py b/tests/app/routers/test_images.py index 040ae01914d..a709daf24e3 100644 --- a/tests/app/routers/test_images.py +++ b/tests/app/routers/test_images.py @@ -7,6 +7,7 @@ from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api_app import app +from invokeai.app.services.board_records.board_records_common import BoardRecord from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService from invokeai.app.services.config.config_default import InvokeAIAppConfig @@ -70,17 +71,32 @@ def __init__(self, invoker) -> None: def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None: prepare_download_images_test(monkeypatch, mock_invoker) - response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]}) + def mock_uuid_string(): + return "test" + + # You have to patch the function within the module it's being imported into. This is strange, but it works. 
+ # See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/ + monkeypatch.setattr("invokeai.app.api.routers.images.uuid_string", mock_uuid_string) + response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]}) + json_response = response.json() assert response.status_code == 202 + assert json_response["bulk_download_item_name"] == "test" def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None: - prepare_download_images_test(monkeypatch, mock_invoker) + expected_board_name = "test" + + def mock_get(*args, **kwargs): + return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None") - response = client.post("/api/v1/images/download", json={"image_names": [], "board_id": "test"}) + monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get) + prepare_download_images_test(monkeypatch, mock_invoker) + response = client.post("/api/v1/images/download", json={"board_id": "test"}) + json_response = response.json() assert response.status_code == 202 + assert json_response["bulk_download_item_name"] == "test.zip" def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None: diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index 184519866a0..7909c44214e 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -125,7 +125,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service.start(mock_invoker) - bulk_download_service.handler([mock_image_dto.image_name], None) + bulk_download_service.handler([mock_image_dto.image_name], None, None) assert_handler_success( expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events @@ -151,7 +151,7 @@ def mock_get_many(*args, **kwargs): bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service.start(mock_invoker) - bulk_download_service.handler([], "test") + bulk_download_service.handler([], "test", None) assert_handler_success( expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events @@ -173,7 +173,31 @@ def mock_get_many(*args, **kwargs): bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service.start(mock_invoker) - bulk_download_service.handler([], "none") + bulk_download_service.handler([], "none", None) + + assert_handler_success( + expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events + ) + + +def test_handler_bulk_download__item_id_given( + tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker +): + """Test that the handler creates the zip file correctly when given a pregenerated bulk download item id.""" + + _, expected_image_path, mock_image_contents = prepare_handler_test( + tmp_path, monkeypatch, mock_image_dto, mock_invoker + ) + expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip" + + def mock_get_many(*args, **kwargs): + return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto]) + + monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) + + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + 
bulk_download_service.handler([mock_image_dto.image_name], None, "test_id") assert_handler_success( expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events @@ -242,20 +266,6 @@ def assert_handler_success( assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path) -def test_stop(tmp_path: Path) -> None: - """Test that the stop method removes the bulk_downloads directory.""" - - bulk_download_service = BulkDownloadService(tmp_path) - - mock_file: Path = tmp_path / "bulk_downloads" / "test.zip" - mock_file.write_text("contents") - - bulk_download_service.stop() - - assert (tmp_path / "bulk_downloads").exists() - assert len(os.listdir(tmp_path / "bulk_downloads")) == 0 - - def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): """Test that the handler emits an error event when the image is not found.""" exception: Exception = ImageRecordNotFoundException("Image not found") @@ -309,7 +319,7 @@ def execute_handler_test_on_error( ): bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service.start(mock_invoker) - bulk_download_service.handler([mock_image_dto.image_name], None) + bulk_download_service.handler([mock_image_dto.image_name], None, None) event_bus: DummyEventService = mock_invoker.services.events @@ -319,6 +329,35 @@ def execute_handler_test_on_error( assert event_bus.events[1].payload["error"] == error.__str__() +def test_get_board_name(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker): + """Test that the get_board_name function returns the correct board name.""" + + expected_board_name = "board1" + + def mock_get(*args, **kwargs): + return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None") + + monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get) + + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + board_name = bulk_download_service.get_board_name("12345") + + assert board_name == expected_board_name + + +def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker): + """Test that the get_board_name function returns the correct board name.""" + + expected_board_name = "Uncategorized" + + bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service.start(mock_invoker) + board_name = bulk_download_service.get_board_name("none") + + assert board_name == expected_board_name + + def test_delete(tmp_path: Path): """Test that the delete method removes the bulk download file.""" @@ -332,8 +371,9 @@ def test_delete(tmp_path: Path): assert (tmp_path / "bulk_downloads").exists() assert len(os.listdir(tmp_path / "bulk_downloads")) == 0 + def test_stop(tmp_path: Path): - """Test that the delete method removes the bulk download file.""" + """Test that the stop method removes the bulk download file and not any directories.""" bulk_download_service = BulkDownloadService(tmp_path) @@ -343,7 +383,6 @@ def test_stop(tmp_path: Path): mock_dir: Path = tmp_path / "bulk_downloads" / "test" mock_dir.mkdir(parents=True, exist_ok=True) - bulk_download_service.stop() assert (tmp_path / "bulk_downloads").exists() From 3080029f164df1c8465289ae662b4379ebea63ef Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 28 Jan 2024 18:59:56 -0500 Subject: [PATCH 167/340] using temp directory for downloads --- invokeai/app/api/routers/images.py | 2 +- .../bulk_download/bulk_download_base.py | 2 +- 
.../bulk_download/bulk_download_default.py | 18 +++++----- .../bulk_download/test_bulk_download.py | 35 ++++++++++++++----- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index c12556aed62..69a76e40624 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -397,7 +397,7 @@ async def download_images_from_list( raise HTTPException(status_code=400, detail="No images or board id specified.") bulk_download_item_id: str = uuid_string() if board_id is None else board_id board_name: str = ( - "" if board_id is None else ApiDependencies.invoker.services.board_records.get(board_id).board_name + "" if board_id is None else ApiDependencies.invoker.services.bulk_download.get_clean_board_name(board_id) ) # Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index d6b0e62211b..89b2e737720 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -45,7 +45,7 @@ def get_path(self, bulk_download_item_name: str) -> str: """ @abstractmethod - def get_board_name(self, board_id: str) -> str: + def get_clean_board_name(self, board_id: str) -> str: """ Get the name of the board. diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index be70dea2c1d..fe76a123337 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -1,4 +1,5 @@ from pathlib import Path +from tempfile import TemporaryDirectory from typing import Optional, Union from zipfile import ZipFile @@ -19,6 +20,7 @@ class BulkDownloadService(BulkDownloadBase): __output_folder: Path + __temp_directory: TemporaryDirectory __bulk_downloads_folder: Path __event_bus: EventServiceBase __invoker: Invoker @@ -35,7 +37,8 @@ def __init__( Initialize the downloader object. 
""" self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder) - self.__bulk_downloads_folder = self.__output_folder / "bulk_downloads" + self.__temp_directory = TemporaryDirectory(dir=self.__output_folder) + self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads" self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None: @@ -57,8 +60,7 @@ def handler(self, image_names: list[str], board_id: Optional[str], bulk_download image_dtos: list[ImageDTO] = [] if board_id: - board_name = self.get_board_name(board_id) - board_name = self._clean_string_to_path_safe(board_name) + board_name = self.get_clean_board_name(board_id) # -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images image_dtos = self.__invoker.services.images.get_many( @@ -80,11 +82,11 @@ def handler(self, image_names: list[str], board_id: Optional[str], bulk_download self.__invoker.services.logger.error("Problem bulk downloading images.") raise e - def get_board_name(self, board_id: str) -> str: + def get_clean_board_name(self, board_id: str) -> str: if board_id == "none": return "Uncategorized" - return self.__invoker.services.board_records.get(board_id).board_name + return self._clean_string_to_path_safe(self.__invoker.services.board_records.get(board_id).board_name) def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str: """ @@ -145,11 +147,7 @@ def _signal_job_failed(self, bulk_download_id: str, bulk_download_item_id: str, def stop(self, *args, **kwargs): """Stop the bulk download service and delete the files in the bulk download folder.""" # Get all the files in the bulk downloads folder, only .zip files - files = self.__bulk_downloads_folder.glob("*.zip") - - # Delete all the files - for file in files: - file.unlink() + self.__temp_directory.cleanup() def delete(self, bulk_download_item_name: str) -> None: """ diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index 7909c44214e..3cd2123232f 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -1,5 +1,6 @@ import os from pathlib import Path +from tempfile import TemporaryDirectory from typing import Any from zipfile import ZipFile @@ -86,9 +87,28 @@ def mock_invoker(mock_services: InvocationServices) -> Invoker: return Invoker(services=mock_services) +@pytest.fixture(autouse=True) +def mock_temporary_directory(monkeypatch: Any, tmp_path: Path): + """Mock the TemporaryDirectory class so that it uses the tmp_path fixture.""" + + class MockTemporaryDirectory(TemporaryDirectory): + def __init__(self): + super().__init__(dir=tmp_path) + self.name = tmp_path + + def mock_TemporaryDirectory(*args, **kwargs): + return MockTemporaryDirectory() + + monkeypatch.setattr( + "invokeai.app.services.bulk_download.bulk_download_default.TemporaryDirectory", mock_TemporaryDirectory + ) + + def test_get_path_when_file_exists(tmp_path: Path) -> None: """Test get_path when the file exists.""" + bulk_download_service = BulkDownloadService(tmp_path) + # Create a directory at tmp_path/bulk_downloads test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads" test_bulk_downloads_dir.mkdir(parents=True, exist_ok=True) @@ -97,7 +117,6 @@ def test_get_path_when_file_exists(tmp_path: Path) 
-> None: test_file_path: Path = test_bulk_downloads_dir / "test.zip" test_file_path.touch() - bulk_download_service = BulkDownloadService(tmp_path) assert bulk_download_service.get_path("test.zip") == str(test_file_path) @@ -164,7 +183,6 @@ def test_handler_board_id_default(tmp_path: Path, monkeypatch: Any, mock_image_d _, expected_image_path, mock_image_contents = prepare_handler_test( tmp_path, monkeypatch, mock_image_dto, mock_invoker ) - expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip" def mock_get_many(*args, **kwargs): return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto]) @@ -175,6 +193,8 @@ def mock_get_many(*args, **kwargs): bulk_download_service.start(mock_invoker) bulk_download_service.handler([], "none", None) + expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip" + assert_handler_success( expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events ) @@ -188,7 +208,6 @@ def test_handler_bulk_download__item_id_given( _, expected_image_path, mock_image_contents = prepare_handler_test( tmp_path, monkeypatch, mock_image_dto, mock_invoker ) - expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip" def mock_get_many(*args, **kwargs): return OffsetPaginatedResults(limit=-1, total=1, offset=0, items=[mock_image_dto]) @@ -199,6 +218,8 @@ def mock_get_many(*args, **kwargs): bulk_download_service.start(mock_invoker) bulk_download_service.handler([mock_image_dto.image_name], None, "test_id") + expected_zip_path: Path = tmp_path / "bulk_downloads" / "test_id.zip" + assert_handler_success( expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events ) @@ -341,7 +362,7 @@ def mock_get(*args, **kwargs): bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service.start(mock_invoker) - board_name = bulk_download_service.get_board_name("12345") + board_name = bulk_download_service.get_clean_board_name("12345") assert board_name == expected_board_name @@ -353,7 +374,7 @@ def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker): bulk_download_service = BulkDownloadService(tmp_path) bulk_download_service.start(mock_invoker) - board_name = bulk_download_service.get_board_name("none") + board_name = bulk_download_service.get_clean_board_name("none") assert board_name == expected_board_name @@ -385,6 +406,4 @@ def test_stop(tmp_path: Path): bulk_download_service.stop() - assert (tmp_path / "bulk_downloads").exists() - assert mock_dir.exists() - assert len(os.listdir(tmp_path / "bulk_downloads")) == 1 + assert not (tmp_path / "bulk_downloads").exists() From e0051f42d0d941391e440042f8fd72ff58abb3eb Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Fri, 16 Feb 2024 14:10:50 -0500 Subject: [PATCH 168/340] updating imports to satisfy ruff --- tests/app/services/bulk_download/test_bulk_download.py | 2 +- tests/conftest.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index 3cd2123232f..924385f7e1a 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -22,7 +22,7 @@ from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.backend.util.logging import InvokeAILogger -from tests.fixtures.event_service import 
DummyEventService, mock_event_service # noqa: F401,F811 +from tests.fixtures.event_service import DummyEventService from tests.fixtures.sqlite_database import create_mock_sqlite_database diff --git a/tests/conftest.py b/tests/conftest.py index 1c816002296..85fecfe440f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,2 +1,8 @@ # conftest.py is a special pytest file. Fixtures defined in this file will be accessible to all tests in this directory # without needing to explicitly import them. (https://docs.pytest.org/en/6.2.x/fixture.html) + + +# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not +# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures. +from invokeai.backend.util.test_utils import torch_device # noqa: F401 +from tests.fixtures.event_service import mock_event_service # noqa: F401 From 62d899ffe87619cecd15c81d4743923c9f930eae Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Fri, 16 Feb 2024 15:50:48 -0500 Subject: [PATCH 169/340] moving the responsibility of cleaning up board names to the service not the route --- invokeai/app/api/routers/images.py | 8 +-- .../bulk_download/bulk_download_base.py | 18 ++--- .../bulk_download/bulk_download_default.py | 23 +++--- tests/app/routers/test_images.py | 13 ++-- .../bulk_download/test_bulk_download.py | 71 ++++++++++--------- 5 files changed, 63 insertions(+), 70 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 69a76e40624..d1c64648de6 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -13,7 +13,6 @@ from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator -from invokeai.app.util.misc import uuid_string from ..dependencies import ApiDependencies @@ -395,10 +394,7 @@ async def download_images_from_list( ) -> ImagesDownloaded: if (image_names is None or len(image_names) == 0) and board_id is None: raise HTTPException(status_code=400, detail="No images or board id specified.") - bulk_download_item_id: str = uuid_string() if board_id is None else board_id - board_name: str = ( - "" if board_id is None else ApiDependencies.invoker.services.bulk_download.get_clean_board_name(board_id) - ) + bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id) # Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries background_tasks.add_task( @@ -409,7 +405,7 @@ async def download_images_from_list( ) return ImagesDownloaded( response="Your images are preparing to be downloaded", - bulk_download_item_name=bulk_download_item_id if board_id is None else board_name + ".zip", + bulk_download_item_name=bulk_download_item_id + ".zip", ) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 89b2e737720..5199652ad41 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -30,9 +30,9 @@ def handler(self, image_names: list[str], board_id: Optional[str], bulk_download """ Starts a a bulk download job. 
- :param invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies. :param image_names: A list of image names to include in the zip file. :param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. + :param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated. """ @abstractmethod @@ -45,12 +45,12 @@ def get_path(self, bulk_download_item_name: str) -> str: """ @abstractmethod - def get_clean_board_name(self, board_id: str) -> str: + def generate_item_id(self, board_id: Optional[str]) -> str: """ - Get the name of the board. + Generate an item ID for a bulk download item. - :param board_id: The ID of the board. - :return: The name of the board. + :param board_id: The ID of the board whose name is to be included in the item id. + :return: The generated item ID. """ @abstractmethod @@ -61,12 +61,8 @@ def stop(self, *args, **kwargs) -> None: This method is responsible for stopping the BulkDownloadService and performing any necessary cleanup operations to remove any remnants or resources associated with the service. - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - None + :param *args: Variable length argument list. + :param **kwargs: Arbitrary keyword arguments. """ @abstractmethod diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index fe76a123337..406bd7d9972 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -50,19 +50,15 @@ def handler(self, image_names: list[str], board_id: Optional[str], bulk_download """ bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID - if bulk_download_item_id is None: - bulk_download_item_id = uuid_string() if board_id is None else board_id + bulk_download_item_id = uuid_string() if bulk_download_item_id is None else bulk_download_item_id self._signal_job_started(bulk_download_id, bulk_download_item_id) try: - board_name: str = "" image_dtos: list[ImageDTO] = [] if board_id: - board_name = self.get_clean_board_name(board_id) - - # -1 is the default value for limit, which means no limit, is_intermediate only gives us completed images + # -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images image_dtos = self.__invoker.services.images.get_many( offset=0, limit=-1, @@ -71,9 +67,7 @@ def handler(self, image_names: list[str], board_id: Optional[str], bulk_download ).items else: image_dtos = [self.__invoker.services.images.get_dto(image_name) for image_name in image_names] - bulk_download_item_name: str = self._create_zip_file( - image_dtos, bulk_download_item_id if board_id is None else board_name - ) + bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id) self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: self._signal_job_failed(bulk_download_id, bulk_download_item_id, e) @@ -82,7 +76,10 @@ def handler(self, image_names: list[str], board_id: Optional[str], bulk_download self.__invoker.services.logger.error("Problem bulk downloading images.") raise e - def get_clean_board_name(self, 
board_id: str) -> str: + def generate_item_id(self, board_id: Optional[str]) -> str: + return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string() + + def _get_clean_board_name(self, board_id: str) -> str: if board_id == "none": return "Uncategorized" @@ -109,7 +106,7 @@ def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: st # from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string def _clean_string_to_path_safe(self, s: str) -> str: """Clean a string to be path safe.""" - return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " "]).rstrip() + return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip() def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None: """Signal that a bulk download job has started.""" @@ -166,11 +163,11 @@ def get_path(self, bulk_download_item_name: str) -> str: :return: The path to the bulk download file. """ path = str(self.__bulk_downloads_folder / bulk_download_item_name) - if not self.validate_path(path): + if not self._is_valid_path(path): raise BulkDownloadTargetException() return path - def validate_path(self, path: Union[str, Path]) -> bool: + def _is_valid_path(self, path: Union[str, Path]) -> bool: """Validates the path given for a bulk download.""" path = path if isinstance(path, Path) else Path(path) return path.exists() diff --git a/tests/app/routers/test_images.py b/tests/app/routers/test_images.py index a709daf24e3..e8521bf1322 100644 --- a/tests/app/routers/test_images.py +++ b/tests/app/routers/test_images.py @@ -71,17 +71,10 @@ def __init__(self, invoker) -> None: def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None: prepare_download_images_test(monkeypatch, mock_invoker) - def mock_uuid_string(): - return "test" - - # You have to patch the function within the module it's being imported into. This is strange, but it works. 
- # See http://www.gregreda.com/2021/06/28/mocking-imported-module-function-python/ - monkeypatch.setattr("invokeai.app.api.routers.images.uuid_string", mock_uuid_string) - response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]}) json_response = response.json() assert response.status_code == 202 - assert json_response["bulk_download_item_name"] == "test" + assert json_response["bulk_download_item_name"] == "test.zip" def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None: @@ -101,6 +94,10 @@ def mock_get(*args, **kwargs): def prepare_download_images_test(monkeypatch: Any, mock_invoker: Invoker) -> None: monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr( + "invokeai.app.api.routers.images.ApiDependencies.invoker.services.bulk_download.generate_item_id", + lambda arg: "test", + ) def mock_add_task(*args, **kwargs): return None diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index 924385f7e1a..d70510cd91a 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -151,6 +151,42 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I ) +def test_generate_id(monkeypatch: Any): + """Test that the generate_id method generates a unique id.""" + + bulk_download_service = BulkDownloadService("test") + + monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test") + + assert bulk_download_service.generate_item_id(None) == "test" + + +def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker): + """Test that the generate_id method generates a unique id with a board id.""" + + bulk_download_service = BulkDownloadService("test") + bulk_download_service.start(mock_invoker) + + def mock_board_get(*args, **kwargs): + return BoardRecord(board_id="12345", board_name="test_board_name", created_at="None", updated_at="None") + + monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get) + + monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test") + + assert bulk_download_service.generate_item_id("12345") == "test_board_name_test" + + +def test_generate_id_with_default_board_id(monkeypatch: Any): + """Test that the generate_id method generates a unique id with a board id.""" + + bulk_download_service = BulkDownloadService("test") + + monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test") + + assert bulk_download_service.generate_item_id("none") == "Uncategorized_test" + + def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): """Test that the handler creates the zip file correctly when given a board id.""" @@ -159,7 +195,7 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag ) def mock_board_get(*args, **kwargs): - return BoardRecord(board_id="12345", board_name="test", created_at="None", updated_at="None") + return BoardRecord(board_id="12345", board_name="test_board_name", created_at="None", updated_at="None") monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get) @@ -193,14 +229,14 @@ def mock_get_many(*args, **kwargs): bulk_download_service.start(mock_invoker) 
bulk_download_service.handler([], "none", None) - expected_zip_path: Path = tmp_path / "bulk_downloads" / "Uncategorized.zip" + expected_zip_path: Path = tmp_path / "bulk_downloads" / "test.zip" assert_handler_success( expected_zip_path, expected_image_path, mock_image_contents, tmp_path, mock_invoker.services.events ) -def test_handler_bulk_download__item_id_given( +def test_handler_bulk_download_item_id_given( tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker ): """Test that the handler creates the zip file correctly when given a pregenerated bulk download item id.""" @@ -350,35 +386,6 @@ def execute_handler_test_on_error( assert event_bus.events[1].payload["error"] == error.__str__() -def test_get_board_name(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker): - """Test that the get_board_name function returns the correct board name.""" - - expected_board_name = "board1" - - def mock_get(*args, **kwargs): - return BoardRecord(board_id="12345", board_name=expected_board_name, created_at="None", updated_at="None") - - monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_get) - - bulk_download_service = BulkDownloadService(tmp_path) - bulk_download_service.start(mock_invoker) - board_name = bulk_download_service.get_clean_board_name("12345") - - assert board_name == expected_board_name - - -def test_get_board_name_default(tmp_path: Path, mock_invoker: Invoker): - """Test that the get_board_name function returns the correct board name.""" - - expected_board_name = "Uncategorized" - - bulk_download_service = BulkDownloadService(tmp_path) - bulk_download_service.start(mock_invoker) - board_name = bulk_download_service.get_clean_board_name("none") - - assert board_name == expected_board_name - - def test_delete(tmp_path: Path): """Test that the delete method removes the bulk download file.""" From 45ffdce37128de7d8006c79c91ee392008c65bdc Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Fri, 16 Feb 2024 22:49:38 -0500 Subject: [PATCH 170/340] relocating event_service fixture due to import ordering --- tests/app/services/bulk_download/test_bulk_download.py | 6 ++++++ tests/conftest.py | 1 - tests/fixtures/event_service.py | 7 ------- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index d70510cd91a..b7480091d99 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -46,6 +46,12 @@ def mock_image_dto() -> ImageDTO: ) +@pytest.fixture +def mock_event_service() -> DummyEventService: + """Create a dummy event service.""" + return DummyEventService() + + @pytest.fixture def mock_services(mock_event_service: DummyEventService) -> InvocationServices: configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) diff --git a/tests/conftest.py b/tests/conftest.py index 85fecfe440f..873ccc13fd2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,4 +5,3 @@ # We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not # play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures. 
from invokeai.backend.util.test_utils import torch_device # noqa: F401 -from tests.fixtures.event_service import mock_event_service # noqa: F401 diff --git a/tests/fixtures/event_service.py b/tests/fixtures/event_service.py index 0a09fa0d64a..8f6a45c38fc 100644 --- a/tests/fixtures/event_service.py +++ b/tests/fixtures/event_service.py @@ -1,6 +1,5 @@ from typing import Any, Dict, List -import pytest from pydantic import BaseModel from invokeai.app.services.events.events_base import EventServiceBase @@ -26,9 +25,3 @@ def __init__(self) -> None: def dispatch(self, event_name: str, payload: Any) -> None: """Dispatch an event by appending it to self.events.""" self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"])) - - -@pytest.fixture -def mock_event_service() -> DummyEventService: - """Create a dummy event service.""" - return DummyEventService() From c2c93a031470a816d0970a9ef8ebc8cdaf0bec70 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sat, 17 Feb 2024 00:29:05 -0500 Subject: [PATCH 171/340] removing dependency on an output folder, embrace python temp folder for bulk download --- invokeai/app/api/dependencies.py | 2 +- .../bulk_download/bulk_download_base.py | 7 ++--- .../bulk_download/bulk_download_default.py | 9 ++----- tests/app/routers/test_images.py | 2 +- .../bulk_download/test_bulk_download.py | 26 +++++++++---------- 5 files changed, 19 insertions(+), 27 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 984fd8e2670..95407291ec0 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -82,7 +82,7 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger board_records = SqliteBoardRecordStorage(db=db) boards = BoardService() events = FastAPIEventService(event_handler_id) - bulk_download = BulkDownloadService(output_folder=f"{output_folder}") + bulk_download = BulkDownloadService() image_records = SqliteImageRecordStorage(db=db) images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 5199652ad41..d889e2ed0ee 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -from pathlib import Path -from typing import Optional, Union +from typing import Optional from invokeai.app.services.invoker import Invoker @@ -18,11 +17,9 @@ def start(self, invoker: Invoker) -> None: """ @abstractmethod - def __init__(self, output_folder: Union[str, Path]): + def __init__(self): """ Create BulkDownloadBase object. - - :param output_folder: The path to the output folder where the bulk download files can be temporarily stored. 
""" @abstractmethod diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 406bd7d9972..4f5bfb087ff 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -19,7 +19,6 @@ class BulkDownloadService(BulkDownloadBase): - __output_folder: Path __temp_directory: TemporaryDirectory __bulk_downloads_folder: Path __event_bus: EventServiceBase @@ -29,15 +28,11 @@ def start(self, invoker: Invoker) -> None: self.__invoker = invoker self.__event_bus = invoker.services.events - def __init__( - self, - output_folder: Union[str, Path], - ): + def __init__(self): """ Initialize the downloader object. """ - self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder) - self.__temp_directory = TemporaryDirectory(dir=self.__output_folder) + self.__temp_directory = TemporaryDirectory() self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads" self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) diff --git a/tests/app/routers/test_images.py b/tests/app/routers/test_images.py index e8521bf1322..67297a116f0 100644 --- a/tests/app/routers/test_images.py +++ b/tests/app/routers/test_images.py @@ -31,7 +31,7 @@ def mock_services(tmp_path: Path) -> InvocationServices: board_images=None, # type: ignore board_records=SqliteBoardRecordStorage(db=db), boards=None, # type: ignore - bulk_download=BulkDownloadService(tmp_path), + bulk_download=BulkDownloadService(), configuration=None, # type: ignore events=None, # type: ignore graph_execution_manager=None, # type: ignore diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index b7480091d99..3e8b7fd2eb6 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -113,7 +113,7 @@ def mock_TemporaryDirectory(*args, **kwargs): def test_get_path_when_file_exists(tmp_path: Path) -> None: """Test get_path when the file exists.""" - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() # Create a directory at tmp_path/bulk_downloads test_bulk_downloads_dir: Path = tmp_path / "bulk_downloads" @@ -129,7 +129,7 @@ def test_get_path_when_file_exists(tmp_path: Path) -> None: def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None: """Test get_path when the file does not exist.""" - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() with pytest.raises(BulkDownloadTargetException): bulk_download_service.get_path("test") @@ -137,7 +137,7 @@ def test_get_path_when_file_does_not_exist(tmp_path: Path) -> None: def test_bulk_downloads_dir_created_at_start(tmp_path: Path) -> None: """Test that the bulk_downloads directory is created at start.""" - BulkDownloadService(tmp_path) + BulkDownloadService() assert (tmp_path / "bulk_downloads").exists() @@ -148,7 +148,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, mock_image_dto: I tmp_path, monkeypatch, mock_image_dto, mock_invoker ) - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() bulk_download_service.start(mock_invoker) bulk_download_service.handler([mock_image_dto.image_name], None, None) @@ -160,7 +160,7 @@ def test_handler_image_names(tmp_path: Path, monkeypatch: Any, 
mock_image_dto: I def test_generate_id(monkeypatch: Any): """Test that the generate_id method generates a unique id.""" - bulk_download_service = BulkDownloadService("test") + bulk_download_service = BulkDownloadService() monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test") @@ -170,7 +170,7 @@ def test_generate_id(monkeypatch: Any): def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker): """Test that the generate_id method generates a unique id with a board id.""" - bulk_download_service = BulkDownloadService("test") + bulk_download_service = BulkDownloadService() bulk_download_service.start(mock_invoker) def mock_board_get(*args, **kwargs): @@ -186,7 +186,7 @@ def mock_board_get(*args, **kwargs): def test_generate_id_with_default_board_id(monkeypatch: Any): """Test that the generate_id method generates a unique id with a board id.""" - bulk_download_service = BulkDownloadService("test") + bulk_download_service = BulkDownloadService() monkeypatch.setattr("invokeai.app.services.bulk_download.bulk_download_default.uuid_string", lambda: "test") @@ -210,7 +210,7 @@ def mock_get_many(*args, **kwargs): monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() bulk_download_service.start(mock_invoker) bulk_download_service.handler([], "test", None) @@ -231,7 +231,7 @@ def mock_get_many(*args, **kwargs): monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() bulk_download_service.start(mock_invoker) bulk_download_service.handler([], "none", None) @@ -256,7 +256,7 @@ def mock_get_many(*args, **kwargs): monkeypatch.setattr(mock_invoker.services.images, "get_many", mock_get_many) - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() bulk_download_service.start(mock_invoker) bulk_download_service.handler([mock_image_dto.image_name], None, "test_id") @@ -380,7 +380,7 @@ def mock_get_board_name(*args, **kwargs): def execute_handler_test_on_error( tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker, error: Exception ): - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() bulk_download_service.start(mock_invoker) bulk_download_service.handler([mock_image_dto.image_name], None, None) @@ -395,7 +395,7 @@ def execute_handler_test_on_error( def test_delete(tmp_path: Path): """Test that the delete method removes the bulk download file.""" - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() mock_file: Path = tmp_path / "bulk_downloads" / "test.zip" mock_file.write_text("contents") @@ -409,7 +409,7 @@ def test_delete(tmp_path: Path): def test_stop(tmp_path: Path): """Test that the stop method removes the bulk download file and not any directories.""" - bulk_download_service = BulkDownloadService(tmp_path) + bulk_download_service = BulkDownloadService() mock_file: Path = tmp_path / "bulk_downloads" / "test.zip" mock_file.write_text("contents") From bc0d7d85aa5a021459f2a7864dbb2588485c9231 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sat, 17 Feb 2024 23:53:38 -0500 Subject: [PATCH 172/340] refactoring handlers to do null check --- invokeai/app/api/routers/images.py | 5 ++- 
.../bulk_download/bulk_download_base.py | 4 ++- .../bulk_download/bulk_download_default.py | 31 +++++++++++++------ 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index d1c64648de6..c3504b104d2 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -1,6 +1,6 @@ import io import traceback -from typing import Optional, cast +from typing import Optional from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi.responses import FileResponse @@ -396,10 +396,9 @@ async def download_images_from_list( raise HTTPException(status_code=400, detail="No images or board id specified.") bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id) - # Type narrowing handled above ^, we know that image_names is not None, trying to keep null checks at the boundaries background_tasks.add_task( ApiDependencies.invoker.services.bulk_download.handler, - cast(list[str], image_names), + image_names, board_id, bulk_download_item_id, ) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index d889e2ed0ee..80a2ddfb251 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -23,7 +23,9 @@ def __init__(self): """ @abstractmethod - def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None: + def handler( + self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] + ) -> None: """ Starts a a bulk download job. diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 4f5bfb087ff..72bb5a5d52f 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -7,6 +7,7 @@ from invokeai.app.services.bulk_download.bulk_download_common import ( DEFAULT_BULK_DOWNLOAD_ID, BulkDownloadException, + BulkDownloadParametersException, BulkDownloadTargetException, ) from invokeai.app.services.events.events_base import EventServiceBase @@ -36,7 +37,9 @@ def __init__(self): self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads" self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) - def handler(self, image_names: list[str], board_id: Optional[str], bulk_download_item_id: Optional[str]) -> None: + def handler( + self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] + ) -> None: """ Create a zip file containing the images specified by the given image names or board id. 
@@ -53,15 +56,12 @@ def handler(self, image_names: list[str], board_id: Optional[str], bulk_download image_dtos: list[ImageDTO] = [] if board_id: - # -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images - image_dtos = self.__invoker.services.images.get_many( - offset=0, - limit=-1, - board_id=board_id, - is_intermediate=False, - ).items + image_dtos = self._board_handler(board_id) + elif image_names: + image_dtos = self._image_handler(image_names) else: - image_dtos = [self.__invoker.services.images.get_dto(image_name) for image_name in image_names] + raise BulkDownloadParametersException() + bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id) self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: @@ -71,6 +71,19 @@ def handler(self, image_names: list[str], board_id: Optional[str], bulk_download self.__invoker.services.logger.error("Problem bulk downloading images.") raise e + def _image_handler(self, image_names: list[str]) -> list[ImageDTO]: + return [self.__invoker.services.images.get_dto(image_name) for image_name in image_names] + + def _board_handler(self, board_id: str) -> list[ImageDTO]: + # -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images + image_dtos = self.__invoker.services.images.get_many( + offset=0, + limit=-1, + board_id=board_id, + is_intermediate=False, + ).items + return image_dtos + def generate_item_id(self, board_id: Optional[str]) -> str: return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string() From b57336eff511555245205c43c27c40b9cb311d25 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Mon, 19 Feb 2024 13:54:48 -0500 Subject: [PATCH 173/340] adding bulk_download_item_name to socket events --- .../bulk_download/bulk_download_default.py | 24 ++++++++++++++----- invokeai/app/services/events/events_base.py | 10 ++++++-- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 72bb5a5d52f..4d0d2e7b0fa 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -49,8 +49,9 @@ def handler( bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID bulk_download_item_id = uuid_string() if bulk_download_item_id is None else bulk_download_item_id + bulk_download_item_name = bulk_download_item_id + ".zip" - self._signal_job_started(bulk_download_id, bulk_download_item_id) + self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name) try: image_dtos: list[ImageDTO] = [] @@ -64,10 +65,15 @@ def handler( bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id) self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) - except (ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException) as e: - self._signal_job_failed(bulk_download_id, bulk_download_item_id, e) + except ( + ImageRecordNotFoundException, + BoardRecordNotFoundException, + BulkDownloadException, + BulkDownloadParametersException, + ) as e: + self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e) except Exception as e: - 
self._signal_job_failed(bulk_download_id, bulk_download_item_id, e) + self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e) self.__invoker.services.logger.error("Problem bulk downloading images.") raise e @@ -116,13 +122,16 @@ def _clean_string_to_path_safe(self, s: str) -> str: """Clean a string to be path safe.""" return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip() - def _signal_job_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None: + def _signal_job_started( + self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> None: """Signal that a bulk download job has started.""" if self.__event_bus: assert bulk_download_id is not None self.__event_bus.emit_bulk_download_started( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, ) def _signal_job_completed( @@ -138,7 +147,9 @@ def _signal_job_completed( bulk_download_item_name=bulk_download_item_name, ) - def _signal_job_failed(self, bulk_download_id: str, bulk_download_item_id: str, exception: Exception) -> None: + def _signal_job_failed( + self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception + ) -> None: """Signal that a bulk download job has failed.""" if self.__event_bus: assert bulk_download_id is not None @@ -146,6 +157,7 @@ def _signal_job_failed(self, bulk_download_id: str, bulk_download_item_id: str, self.__event_bus.emit_bulk_download_failed( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, error=str(exception), ) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 3cc3ba2f28f..53df14330fa 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -440,13 +440,16 @@ def emit_model_install_error( }, ) - def emit_bulk_download_started(self, bulk_download_id: str, bulk_download_item_id: str) -> None: + def emit_bulk_download_started( + self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> None: """Emitted when a bulk download starts""" self._emit_bulk_download_event( event_name="bulk_download_started", payload={ "bulk_download_id": bulk_download_id, "bulk_download_item_id": bulk_download_item_id, + "bulk_download_item_name": bulk_download_item_name, }, ) @@ -463,13 +466,16 @@ def emit_bulk_download_completed( }, ) - def emit_bulk_download_failed(self, bulk_download_id: str, bulk_download_item_id: str, error: str) -> None: + def emit_bulk_download_failed( + self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str + ) -> None: """Emitted when a bulk download fails""" self._emit_bulk_download_event( event_name="bulk_download_failed", payload={ "bulk_download_id": bulk_download_id, "bulk_download_item_id": bulk_download_item_id, + "bulk_download_item_name": bulk_download_item_name, "error": error, }, ) From de8213cf9715c9b851ef9935b77a6db4850c2b30 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:47:04 +1100 Subject: [PATCH 174/340] tidy(bulk_download): clean up comments --- .../bulk_download/bulk_download_base.py | 12 ++-------- .../bulk_download/bulk_download_default.py | 23 ------------------- 2 files changed, 2 
insertions(+), 33 deletions(-) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 80a2ddfb251..f085d384a95 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -27,7 +27,7 @@ def handler( self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] ) -> None: """ - Starts a a bulk download job. + Create a zip file containing the images specified by the given image names or board id. :param image_names: A list of image names to include in the zip file. :param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. @@ -54,15 +54,7 @@ def generate_item_id(self, board_id: Optional[str]) -> str: @abstractmethod def stop(self, *args, **kwargs) -> None: - """ - Stops the BulkDownloadService and cleans up all the remnants. - - This method is responsible for stopping the BulkDownloadService and performing any necessary cleanup - operations to remove any remnants or resources associated with the service. - - :param *args: Variable length argument list. - :param **kwargs: Arbitrary keyword arguments. - """ + """Stops the BulkDownloadService and cleans up.""" @abstractmethod def delete(self, bulk_download_item_name: str) -> None: diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 4d0d2e7b0fa..670703db9b4 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -30,9 +30,6 @@ def start(self, invoker: Invoker) -> None: self.__event_bus = invoker.services.events def __init__(self): - """ - Initialize the downloader object. - """ self.__temp_directory = TemporaryDirectory() self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads" self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) @@ -40,13 +37,6 @@ def __init__(self): def handler( self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] ) -> None: - """ - Create a zip file containing the images specified by the given image names or board id. - - param: image_names: A list of image names to include in the zip file. - param: board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. - """ - bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID bulk_download_item_id = uuid_string() if bulk_download_item_id is None else bulk_download_item_id bulk_download_item_name = bulk_download_item_id + ".zip" @@ -162,26 +152,13 @@ def _signal_job_failed( ) def stop(self, *args, **kwargs): - """Stop the bulk download service and delete the files in the bulk download folder.""" - # Get all the files in the bulk downloads folder, only .zip files self.__temp_directory.cleanup() def delete(self, bulk_download_item_name: str) -> None: - """ - Delete the bulk download file. - - :param bulk_download_item_name: The name of the bulk download item. - """ path = self.get_path(bulk_download_item_name) Path(path).unlink() def get_path(self, bulk_download_item_name: str) -> str: - """ - Get the path to the bulk download file. - - :param bulk_download_item_name: The name of the bulk download item. - :return: The path to the bulk download file. 
- """ path = str(self.__bulk_downloads_folder / bulk_download_item_name) if not self._is_valid_path(path): raise BulkDownloadTargetException() From 9cd4d8c3912450d000f4ad5210c28f0651e4d852 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:48:51 +1100 Subject: [PATCH 175/340] tidy(bulk_download): remove extraneous abstract methods `start`, `stop` and `__init__` are not required in implementations of an ABC or service. --- .../bulk_download/bulk_download_base.py | 23 +------------------ 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index f085d384a95..617b611f566 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -1,26 +1,9 @@ from abc import ABC, abstractmethod from typing import Optional -from invokeai.app.services.invoker import Invoker - class BulkDownloadBase(ABC): - @abstractmethod - def start(self, invoker: Invoker) -> None: - """ - Starts the BulkDownloadService. - - This method is responsible for starting the BulkDownloadService and performing any necessary initialization - operations to prepare the service for use. - - param: invoker: The Invoker that holds all the services, required to be passed as a parameter to avoid circular dependencies. - """ - - @abstractmethod - def __init__(self): - """ - Create BulkDownloadBase object. - """ + """Responsible for creating a zip file containing the images specified by the given image names or board id.""" @abstractmethod def handler( @@ -52,10 +35,6 @@ def generate_item_id(self, board_id: Optional[str]) -> str: :return: The generated item ID. """ - @abstractmethod - def stop(self, *args, **kwargs) -> None: - """Stops the BulkDownloadService and cleans up.""" - @abstractmethod def delete(self, bulk_download_item_name: str) -> None: """ From 498ce5c3db5b201cf4e46fc38f76402befc327d5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:50:10 +1100 Subject: [PATCH 176/340] tidy(bulk_download): remove class-level attr annotations These can be misleading as they shadow actual assigned class attributes. This pattern is in the rest of the app but it shouldn't be. 
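
As an aside (a minimal, self-contained sketch with illustrative names, not taken from the codebase): a bare class-level annotation only records a type hint and creates no attribute, while an assignment does, which is why annotation-only names can be mistaken for real class attributes.

    class OnlyAnnotated:
        value: int       # annotation only; no attribute is created on the class

    class WithAssignment:
        value: int = 0   # this assignment does create a class attribute

    print(hasattr(OnlyAnnotated, "value"))   # False
    print(hasattr(WithAssignment, "value"))  # True
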
--- .../app/services/bulk_download/bulk_download_default.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 670703db9b4..b44501a8d9d 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -10,7 +10,6 @@ BulkDownloadParametersException, BulkDownloadTargetException, ) -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invoker import Invoker @@ -20,11 +19,6 @@ class BulkDownloadService(BulkDownloadBase): - __temp_directory: TemporaryDirectory - __bulk_downloads_folder: Path - __event_bus: EventServiceBase - __invoker: Invoker - def start(self, invoker: Invoker) -> None: self.__invoker = invoker self.__event_bus = invoker.services.events From de4c687cbb2c9b0aa607a22a1272f3936d372d61 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:52:39 +1100 Subject: [PATCH 177/340] tidy(bulk_download): use single underscore for private attrs Double underscores are used in the app but it doesn't actually do or convey anything that single underscores don't already do. Considered unpythonic except for actual dunder/magic methods. --- .../bulk_download/bulk_download_default.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index b44501a8d9d..f8475b8f6ef 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -20,13 +20,13 @@ class BulkDownloadService(BulkDownloadBase): def start(self, invoker: Invoker) -> None: - self.__invoker = invoker - self.__event_bus = invoker.services.events + self._invoker = invoker + self._event_bus = invoker.services.events def __init__(self): - self.__temp_directory = TemporaryDirectory() - self.__bulk_downloads_folder = Path(self.__temp_directory.name) / "bulk_downloads" - self.__bulk_downloads_folder.mkdir(parents=True, exist_ok=True) + self._temp_directory = TemporaryDirectory() + self._bulk_downloads_folder = Path(self._temp_directory.name) / "bulk_downloads" + self._bulk_downloads_folder.mkdir(parents=True, exist_ok=True) def handler( self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] @@ -58,15 +58,15 @@ def handler( self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e) except Exception as e: self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e) - self.__invoker.services.logger.error("Problem bulk downloading images.") + self._invoker.services.logger.error("Problem bulk downloading images.") raise e def _image_handler(self, image_names: list[str]) -> list[ImageDTO]: - return [self.__invoker.services.images.get_dto(image_name) for image_name in image_names] + return [self._invoker.services.images.get_dto(image_name) for image_name in image_names] def _board_handler(self, board_id: str) -> list[ImageDTO]: # -1 is the default value for limit, which means no limit, is_intermediate False only gives us 
completed images - image_dtos = self.__invoker.services.images.get_many( + image_dtos = self._invoker.services.images.get_many( offset=0, limit=-1, board_id=board_id, @@ -81,7 +81,7 @@ def _get_clean_board_name(self, board_id: str) -> str: if board_id == "none": return "Uncategorized" - return self._clean_string_to_path_safe(self.__invoker.services.board_records.get(board_id).board_name) + return self._clean_string_to_path_safe(self._invoker.services.board_records.get(board_id).board_name) def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str: """ @@ -91,12 +91,12 @@ def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: st :return: The name of the zip file. """ zip_file_name = bulk_download_item_id + ".zip" - zip_file_path = self.__bulk_downloads_folder / (zip_file_name) + zip_file_path = self._bulk_downloads_folder / (zip_file_name) with ZipFile(zip_file_path, "w") as zip_file: for image_dto in image_dtos: image_zip_path = Path(image_dto.image_category.value) / image_dto.image_name - image_disk_path = self.__invoker.services.images.get_path(image_dto.image_name) + image_disk_path = self._invoker.services.images.get_path(image_dto.image_name) zip_file.write(image_disk_path, arcname=image_zip_path) return str(zip_file_name) @@ -110,9 +110,9 @@ def _signal_job_started( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str ) -> None: """Signal that a bulk download job has started.""" - if self.__event_bus: + if self._event_bus: assert bulk_download_id is not None - self.__event_bus.emit_bulk_download_started( + self._event_bus.emit_bulk_download_started( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, @@ -122,10 +122,10 @@ def _signal_job_completed( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str ) -> None: """Signal that a bulk download job has completed.""" - if self.__event_bus: + if self._event_bus: assert bulk_download_id is not None assert bulk_download_item_name is not None - self.__event_bus.emit_bulk_download_completed( + self._event_bus.emit_bulk_download_completed( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, @@ -135,10 +135,10 @@ def _signal_job_failed( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception ) -> None: """Signal that a bulk download job has failed.""" - if self.__event_bus: + if self._event_bus: assert bulk_download_id is not None assert exception is not None - self.__event_bus.emit_bulk_download_failed( + self._event_bus.emit_bulk_download_failed( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, @@ -146,14 +146,14 @@ def _signal_job_failed( ) def stop(self, *args, **kwargs): - self.__temp_directory.cleanup() + self._temp_directory.cleanup() def delete(self, bulk_download_item_name: str) -> None: path = self.get_path(bulk_download_item_name) Path(path).unlink() def get_path(self, bulk_download_item_name: str) -> str: - path = str(self.__bulk_downloads_folder / bulk_download_item_name) + path = str(self._bulk_downloads_folder / bulk_download_item_name) if not self._is_valid_path(path): raise BulkDownloadTargetException() return path From 635fe78f474801adf5a9b3fd696d5c4747fe668c Mon Sep 17 00:00:00 2001 From: psychedelicious 
<4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:55:13 +1100 Subject: [PATCH 178/340] tidy(bulk_download): nit - use `or` as a coalescing operator Just a bit cleaner. --- invokeai/app/services/bulk_download/bulk_download_default.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index f8475b8f6ef..6c49957a5cf 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -32,7 +32,7 @@ def handler( self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] ) -> None: bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID - bulk_download_item_id = uuid_string() if bulk_download_item_id is None else bulk_download_item_id + bulk_download_item_id = bulk_download_item_id or uuid_string() bulk_download_item_name = bulk_download_item_id + ".zip" self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name) From 4cde91e4a2d61533166e2171724e295a1444bf5a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:58:21 +1100 Subject: [PATCH 179/340] tidy(bulk_download): do not rely on pagination API to get all images for board We can get all images for the board as a list of image names, then pass that to `_image_handler` to get the DTOs, decoupling from the pagination API. --- .../services/bulk_download/bulk_download_default.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 6c49957a5cf..9fad0c34434 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -65,14 +65,8 @@ def _image_handler(self, image_names: list[str]) -> list[ImageDTO]: return [self._invoker.services.images.get_dto(image_name) for image_name in image_names] def _board_handler(self, board_id: str) -> list[ImageDTO]: - # -1 is the default value for limit, which means no limit, is_intermediate False only gives us completed images - image_dtos = self._invoker.services.images.get_many( - offset=0, - limit=-1, - board_id=board_id, - is_intermediate=False, - ).items - return image_dtos + image_names = self._invoker.services.board_image_records.get_all_board_image_names_for_board(board_id) + return self._image_handler(image_names) def generate_item_id(self, board_id: Optional[str]) -> str: return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string() From a237a375b2e01b61fbbf4216037d04a6b196e6ce Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 18:05:52 +1100 Subject: [PATCH 180/340] tidy(bulk_download): don't store events service separately Using the invoker object directly leaves no ambiguity as to what `_events_bus` actually is. 
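
As an aside (a minimal sketch with hypothetical names, not part of this change): rather than caching invoker.services.events on the service during start(), the service keeps only the invoker and reads the events service at each call site, so the origin of the dependency stays explicit.

    from types import SimpleNamespace

    class DemoService:
        """Hypothetical service, used only for this sketch."""

        def start(self, invoker) -> None:
            # keep only the invoker; no separate event-bus attribute
            self._invoker = invoker

        def signal(self, message: str) -> None:
            # the events service is clearly read off the invoker
            self._invoker.services.events.emit(message)

    invoker = SimpleNamespace(services=SimpleNamespace(events=SimpleNamespace(emit=print)))
    svc = DemoService()
    svc.start(invoker)
    svc.signal("bulk_download_started")  # prints: bulk_download_started
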
--- .../services/bulk_download/bulk_download_default.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 9fad0c34434..04cec928f4a 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -21,7 +21,6 @@ class BulkDownloadService(BulkDownloadBase): def start(self, invoker: Invoker) -> None: self._invoker = invoker - self._event_bus = invoker.services.events def __init__(self): self._temp_directory = TemporaryDirectory() @@ -104,9 +103,9 @@ def _signal_job_started( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str ) -> None: """Signal that a bulk download job has started.""" - if self._event_bus: + if self._invoker: assert bulk_download_id is not None - self._event_bus.emit_bulk_download_started( + self._invoker.services.events.emit_bulk_download_started( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, @@ -116,10 +115,10 @@ def _signal_job_completed( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str ) -> None: """Signal that a bulk download job has completed.""" - if self._event_bus: + if self._invoker: assert bulk_download_id is not None assert bulk_download_item_name is not None - self._event_bus.emit_bulk_download_completed( + self._invoker.services.events.emit_bulk_download_completed( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, @@ -129,10 +128,10 @@ def _signal_job_failed( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception ) -> None: """Signal that a bulk download job has failed.""" - if self._event_bus: + if self._invoker: assert bulk_download_id is not None assert exception is not None - self._event_bus.emit_bulk_download_failed( + self._invoker.services.events.emit_bulk_download_failed( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, From f27840f70191e0e327d5cacc07f35fe6195337c7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 18:49:55 +1100 Subject: [PATCH 181/340] test: clean up & fix tests - Deduplicate the mock invocation services. This is possible now that the import order issue is resolved. - Merge `DummyEventService` into `TestEventService` and update all tests to use `TestEventService`. 
--- tests/app/routers/test_images.py | 49 ------------- .../bulk_download/test_bulk_download.py | 71 ++++--------------- .../services/download/test_download_queue.py | 6 +- tests/conftest.py | 55 +++++++++++++- tests/fixtures/event_service.py | 27 ------- tests/test_graph_execution_state.py | 39 ---------- tests/test_nodes.py | 15 ++-- 7 files changed, 78 insertions(+), 184 deletions(-) delete mode 100644 tests/fixtures/event_service.py diff --git a/tests/app/routers/test_images.py b/tests/app/routers/test_images.py index 67297a116f0..5cb8cf1c37b 100644 --- a/tests/app/routers/test_images.py +++ b/tests/app/routers/test_images.py @@ -1,66 +1,17 @@ from pathlib import Path from typing import Any -import pytest from fastapi import BackgroundTasks from fastapi.testclient import TestClient from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api_app import app from invokeai.app.services.board_records.board_records_common import BoardRecord -from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage -from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.images.images_default import ImageService -from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invoker import Invoker -from invokeai.backend.util.logging import InvokeAILogger -from tests.fixtures.sqlite_database import create_mock_sqlite_database client = TestClient(app) -@pytest.fixture -def mock_services(tmp_path: Path) -> InvocationServices: - configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) - logger = InvokeAILogger.get_logger() - db = create_mock_sqlite_database(configuration, logger) - - return InvocationServices( - board_image_records=None, # type: ignore - board_images=None, # type: ignore - board_records=SqliteBoardRecordStorage(db=db), - boards=None, # type: ignore - bulk_download=BulkDownloadService(), - configuration=None, # type: ignore - events=None, # type: ignore - graph_execution_manager=None, # type: ignore - image_files=None, # type: ignore - image_records=None, # type: ignore - images=ImageService(), - invocation_cache=None, # type: ignore - latents=None, # type: ignore - logger=logger, - model_manager=None, # type: ignore - model_records=None, # type: ignore - download_queue=None, # type: ignore - model_install=None, # type: ignore - names=None, # type: ignore - performance_statistics=None, # type: ignore - processor=None, # type: ignore - queue=None, # type: ignore - session_processor=None, # type: ignore - session_queue=None, # type: ignore - urls=None, # type: ignore - workflow_records=None, # type: ignore - ) - - -@pytest.fixture() -def mock_invoker(mock_services: InvocationServices) -> Invoker: - return Invoker(services=mock_services) - - class MockApiDependencies(ApiDependencies): invoker: Invoker diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index 3e8b7fd2eb6..b18f6e038d9 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -7,23 +7,17 @@ import pytest from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException -from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage from 
invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService -from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ( ImageCategory, ImageRecordNotFoundException, ResourceOrigin, ) from invokeai.app.services.images.images_common import ImageDTO -from invokeai.app.services.images.images_default import ImageService -from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.pagination import OffsetPaginatedResults -from invokeai.backend.util.logging import InvokeAILogger -from tests.fixtures.event_service import DummyEventService -from tests.fixtures.sqlite_database import create_mock_sqlite_database +from tests.test_nodes import TestEventService @pytest.fixture @@ -46,53 +40,6 @@ def mock_image_dto() -> ImageDTO: ) -@pytest.fixture -def mock_event_service() -> DummyEventService: - """Create a dummy event service.""" - return DummyEventService() - - -@pytest.fixture -def mock_services(mock_event_service: DummyEventService) -> InvocationServices: - configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) - logger = InvokeAILogger.get_logger() - db = create_mock_sqlite_database(configuration, logger) - - return InvocationServices( - board_image_records=None, # type: ignore - board_images=None, # type: ignore - board_records=SqliteBoardRecordStorage(db=db), - boards=None, # type: ignore - bulk_download=None, # type: ignore - configuration=None, # type: ignore - events=mock_event_service, - graph_execution_manager=None, # type: ignore - image_files=None, # type: ignore - image_records=None, # type: ignore - images=ImageService(), - invocation_cache=None, # type: ignore - latents=None, # type: ignore - logger=logger, - model_manager=None, # type: ignore - model_records=None, # type: ignore - download_queue=None, # type: ignore - model_install=None, # type: ignore - names=None, # type: ignore - performance_statistics=None, # type: ignore - processor=None, # type: ignore - queue=None, # type: ignore - session_processor=None, # type: ignore - session_queue=None, # type: ignore - urls=None, # type: ignore - workflow_records=None, # type: ignore - ) - - -@pytest.fixture() -def mock_invoker(mock_services: InvocationServices) -> Invoker: - return Invoker(services=mock_services) - - @pytest.fixture(autouse=True) def mock_temporary_directory(monkeypatch: Any, tmp_path: Path): """Mock the TemporaryDirectory class so that it uses the tmp_path fixture.""" @@ -288,6 +235,16 @@ def mock_get_dto(*args, **kwargs): monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto) + # This is used when preparing all images for a given board + def mock_get_all_board_image_names_for_board(*args, **kwargs): + return [mock_image_dto.image_name] + + monkeypatch.setattr( + mock_invoker.services.board_image_records, + "get_all_board_image_names_for_board", + mock_get_all_board_image_names_for_board, + ) + # Create a mock image file so that the contents of the zip file are not empty mock_image_path: Path = tmp_path / mock_image_dto.image_name mock_image_contents: str = "Totally an image" @@ -306,7 +263,7 @@ def assert_handler_success( expected_image_path: Path, mock_image_contents: str, tmp_path: Path, - event_bus: DummyEventService, + event_bus: TestEventService, ): """Assert that the handler 
was successful.""" # Check that the zip file was created @@ -369,7 +326,7 @@ def mock_get_board_name(*args, **kwargs): with pytest.raises(Exception): # noqa: B017 execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception) - event_bus: DummyEventService = mock_invoker.services.events + event_bus: TestEventService = mock_invoker.services.events assert len(event_bus.events) == 2 assert event_bus.events[0].event_name == "bulk_download_started" @@ -384,7 +341,7 @@ def execute_handler_test_on_error( bulk_download_service.start(mock_invoker) bulk_download_service.handler([mock_image_dto.image_name], None, None) - event_bus: DummyEventService = mock_invoker.services.events + event_bus: TestEventService = mock_invoker.services.events assert len(event_bus.events) == 2 assert event_bus.events[0].event_name == "bulk_download_started" diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 34408ac5aed..ff9b193b177 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -9,7 +9,7 @@ from requests_testadapter import TestAdapter, TestSession from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService -from tests.fixtures.event_service import DummyEventService +from tests.test_nodes import TestEventService # Prevent pytest deprecation warnings TestAdapter.__test__ = False # type: ignore @@ -101,7 +101,7 @@ def test_errors(tmp_path: Path, session: Session) -> None: def test_event_bus(tmp_path: Path, session: Session) -> None: - event_bus = DummyEventService() + event_bus = TestEventService() queue = DownloadQueueService(requests_session=session, event_bus=event_bus) queue.start() @@ -167,7 +167,7 @@ def broken_callback(job: DownloadJob) -> None: def test_cancel(tmp_path: Path, session: Session) -> None: - event_bus = DummyEventService() + event_bus = TestEventService() queue = DownloadQueueService(requests_session=session, event_bus=event_bus) queue.start() diff --git a/tests/conftest.py b/tests/conftest.py index 873ccc13fd2..a483b7529a1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,4 +4,57 @@ # We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not # play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures. 
-from invokeai.backend.util.test_utils import torch_device # noqa: F401 +import logging + +import pytest + +from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage +from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage +from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.images.images_default import ImageService +from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService +from invokeai.app.services.invoker import Invoker +from invokeai.backend.util.logging import InvokeAILogger +from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401 +from tests.test_nodes import TestEventService + + +@pytest.fixture +def mock_services() -> InvocationServices: + configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) + logger = InvokeAILogger.get_logger() + db = create_mock_sqlite_database(configuration, logger) + + # NOTE: none of these are actually called by the test invocations + return InvocationServices( + board_image_records=SqliteBoardImageRecordStorage(db=db), + board_images=None, # type: ignore + board_records=SqliteBoardRecordStorage(db=db), + boards=None, # type: ignore + bulk_download=BulkDownloadService(), + configuration=configuration, + events=TestEventService(), + image_files=None, # type: ignore + image_records=None, # type: ignore + images=ImageService(), + invocation_cache=MemoryInvocationCache(max_cache_size=0), + logger=logging, # type: ignore + model_manager=None, # type: ignore + download_queue=None, # type: ignore + names=None, # type: ignore + performance_statistics=InvocationStatsService(), + session_processor=None, # type: ignore + session_queue=None, # type: ignore + urls=None, # type: ignore + workflow_records=None, # type: ignore + tensors=None, # type: ignore + conditioning=None, # type: ignore + ) + + +@pytest.fixture() +def mock_invoker(mock_services: InvocationServices) -> Invoker: + return Invoker(services=mock_services) diff --git a/tests/fixtures/event_service.py b/tests/fixtures/event_service.py deleted file mode 100644 index 8f6a45c38fc..00000000000 --- a/tests/fixtures/event_service.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Any, Dict, List - -from pydantic import BaseModel - -from invokeai.app.services.events.events_base import EventServiceBase - - -class DummyEvent(BaseModel): - """Dummy Event to use with Dummy Event service.""" - - event_name: str - payload: Dict[str, Any] - - -# A dummy event service for testing event issuing -class DummyEventService(EventServiceBase): - """Dummy event service for testing.""" - - events: List[DummyEvent] - - def __init__(self) -> None: - super().__init__() - self.events = [] - - def dispatch(self, event_name: str, payload: Any) -> None: - """Dispatch an event by appending it to self.events.""" - self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"])) diff --git a/tests/test_graph_execution_state.py b/tests/test_graph_execution_state.py index 9a350374311..0bb15b17df6 100644 --- a/tests/test_graph_execution_state.py +++ b/tests/test_graph_execution_state.py @@ -1,4 +1,3 @@ -import logging from typing import 
Optional from unittest.mock import Mock @@ -8,17 +7,12 @@ from .test_nodes import ( # isort: split PromptCollectionTestInvocation, PromptTestInvocation, - TestEventService, TextToImageTestInvocation, ) from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache -from invokeai.app.services.invocation_services import InvocationServices -from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService from invokeai.app.services.shared.graph import ( CollectInvocation, Graph, @@ -38,39 +32,6 @@ def simple_graph() -> Graph: return g -# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types -# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate -# the test invocations. -@pytest.fixture -def mock_services() -> InvocationServices: - configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) - # NOTE: none of these are actually called by the test invocations - return InvocationServices( - board_image_records=None, # type: ignore - board_images=None, # type: ignore - board_records=None, # type: ignore - boards=None, # type: ignore - bulk_download=None, # type: ignore - configuration=configuration, - events=TestEventService(), - image_files=None, # type: ignore - image_records=None, # type: ignore - images=None, # type: ignore - invocation_cache=MemoryInvocationCache(max_cache_size=0), - logger=logging, # type: ignore - model_manager=None, # type: ignore - download_queue=None, # type: ignore - names=None, # type: ignore - performance_statistics=InvocationStatsService(), - session_processor=None, # type: ignore - session_queue=None, # type: ignore - urls=None, # type: ignore - workflow_records=None, # type: ignore - tensors=None, # type: ignore - conditioning=None, # type: ignore - ) - - def invoke_next(g: GraphExecutionState) -> tuple[Optional[BaseInvocation], Optional[BaseInvocationOutput]]: n = g.next() if n is None: diff --git a/tests/test_nodes.py b/tests/test_nodes.py index aab3d9c7b4b..e1fe8570405 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1,5 +1,7 @@ from typing import Any, Callable, Union +from pydantic import BaseModel + from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -115,25 +117,22 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg ) -class TestEvent: - event_name: str - payload: Any +class TestEvent(BaseModel): __test__ = False # not a pytest test case - def __init__(self, event_name: str, payload: Any): - self.event_name = event_name - self.payload = payload + event_name: str + payload: Any class TestEventService(EventServiceBase): - events: list __test__ = False # not a pytest test case def __init__(self): super().__init__() - self.events = [] + self.events: list[TestEvent] = [] def dispatch(self, event_name: str, payload: Any) -> None: + self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"])) pass From aa03b6af7d57b18a7958578835e286a88d816a8c Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Sun, 18 Feb 2024 00:01:15 -0500 Subject: [PATCH 182/340] setting up 
event listeners for bulk download socket --- .../app/store/nanostores/bulkDownloadId.ts | 9 +++++ .../src/features/system/store/configSlice.ts | 2 +- .../web/src/services/events/actions.ts | 15 ++++++++ .../frontend/web/src/services/events/types.ts | 34 +++++++++++++++++-- .../services/events/util/setEventListeners.ts | 18 ++++++++++ 5 files changed, 75 insertions(+), 3 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/nanostores/bulkDownloadId.ts diff --git a/invokeai/frontend/web/src/app/store/nanostores/bulkDownloadId.ts b/invokeai/frontend/web/src/app/store/nanostores/bulkDownloadId.ts new file mode 100644 index 00000000000..5615124493f --- /dev/null +++ b/invokeai/frontend/web/src/app/store/nanostores/bulkDownloadId.ts @@ -0,0 +1,9 @@ +import { atom } from 'nanostores'; + +export const DEFAULT_BULK_DOWNLOAD_ID = 'default'; + +/** + * The download id for a bulk download. Used for socket subscriptions. + */ + +export const $bulkDownloadId = atom(DEFAULT_BULK_DOWNLOAD_ID); diff --git a/invokeai/frontend/web/src/features/system/store/configSlice.ts b/invokeai/frontend/web/src/features/system/store/configSlice.ts index 1cf62e89c8c..94f1f1c64a1 100644 --- a/invokeai/frontend/web/src/features/system/store/configSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/configSlice.ts @@ -18,7 +18,7 @@ export const initialConfigState: AppConfig = { shouldUpdateImagesOnConnect: false, shouldFetchMetadataFromApi: false, disabledTabs: [], - disabledFeatures: ['lightbox', 'faceRestore', 'batches', 'bulkDownload'], + disabledFeatures: ['lightbox', 'faceRestore', 'batches'], disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'], nodesAllowlist: undefined, nodesDenylist: undefined, diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts index 101e928f797..b80363315e8 100644 --- a/invokeai/frontend/web/src/services/events/actions.ts +++ b/invokeai/frontend/web/src/services/events/actions.ts @@ -1,5 +1,8 @@ import { createAction } from '@reduxjs/toolkit'; import type { + BulkDownloadCompletedEvent, + BulkDownloadFailedEvent, + BulkDownloadStartedEvent, GeneratorProgressEvent, GraphExecutionStateCompleteEvent, InvocationCompleteEvent, @@ -64,3 +67,15 @@ export const socketInvocationRetrievalError = createAction<{ export const socketQueueItemStatusChanged = createAction<{ data: QueueItemStatusChangedEvent; }>('socket/socketQueueItemStatusChanged'); + +export const socketBulkDownloadStarted = createAction<{ + data: BulkDownloadStartedEvent; +}>('socket/socketBulkDownloadStarted'); + +export const socketBulkDownloadCompleted = createAction<{ + data: BulkDownloadCompletedEvent; +}>('socket/socketBulkDownloadCompleted'); + +export const socketBulkDownloadFailed = createAction<{ + data: BulkDownloadFailedEvent; +}>('socket/socketBulkDownloadFailed'); diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index 9579b6abc10..092132fea2e 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -156,7 +156,7 @@ export type InvocationRetrievalErrorEvent = { * * @example socket.on('queue_item_status_changed', (data: QueueItemStatusChangedEvent) => { ... 
} */ -export type QueueItemStatusChangedEvent = { +export type QueueItemStatusChangedEvent = { queue_id: string; queue_item: { queue_id: string; @@ -191,7 +191,7 @@ export type QueueItemStatusChangedEvent = { failed: number; canceled: number; total: number; - }; + }; }; export type ClientEmitSubscribeQueue = { @@ -202,6 +202,31 @@ export type ClientEmitUnsubscribeQueue = { queue_id: string; }; +export type BulkDownloadStartedEvent = { + bulk_download_id: string; + bulk_download_item_id: string; +}; + +export type BulkDownloadCompletedEvent = { + bulk_download_id: string; + bulk_download_item_id: string; + bulk_download_item_name: string; +}; + +export type BulkDownloadFailedEvent = { + bulk_download_id: string; + bulk_download_item_id: string; + error: string; +} + +export type ClientEmitSubscribeBulkDownload = { + bulk_download_id: string; +}; + +export type ClientEmitUnsubscribeBulkDownload = { + bulk_download_id: string; +}; + export type ServerToClientEvents = { generator_progress: (payload: GeneratorProgressEvent) => void; invocation_complete: (payload: InvocationCompleteEvent) => void; @@ -213,6 +238,9 @@ export type ServerToClientEvents = { session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void; invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void; queue_item_status_changed: (payload: QueueItemStatusChangedEvent) => void; + bulk_download_started: (payload: BulkDownloadStartedEvent) => void; + bulk_download_completed: (payload: BulkDownloadCompletedEvent) => void; + bulk_download_failed: (payload: BulkDownloadFailedEvent) => void; }; export type ClientToServerEvents = { @@ -220,4 +248,6 @@ export type ClientToServerEvents = { disconnect: () => void; subscribe_queue: (payload: ClientEmitSubscribeQueue) => void; unsubscribe_queue: (payload: ClientEmitUnsubscribeQueue) => void; + subscribe_bulk_download: (payload: ClientEmitSubscribeBulkDownload) => void; + unsubscribe_bulk_download: (payload: ClientEmitUnsubscribeBulkDownload) => void; }; diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index c66defee605..d851a185ff8 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -1,8 +1,12 @@ +import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId'; import { $queueId } from 'app/store/nanostores/queueId'; import type { AppDispatch } from 'app/store/store'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; import { + socketBulkDownloadCompleted, + socketBulkDownloadFailed, + socketBulkDownloadStarted, socketConnected, socketDisconnected, socketGeneratorProgress, @@ -34,6 +38,8 @@ export const setEventListeners = (arg: SetEventListenersArg) => { dispatch(socketConnected()); const queue_id = $queueId.get(); socket.emit('subscribe_queue', { queue_id }); + const bulk_download_id = $bulkDownloadId.get(); + socket.emit('subscribe_bulk_download', { bulk_download_id }); }); socket.on('connect_error', (error) => { @@ -150,4 +156,16 @@ export const setEventListeners = (arg: SetEventListenersArg) => { socket.on('queue_item_status_changed', (data) => { dispatch(socketQueueItemStatusChanged({ data })); }); + + socket.on('bulk_download_started', (data) => { + dispatch(socketBulkDownloadStarted({ data })); + }); + + socket.on('bulk_download_completed', (data) => { + 
dispatch(socketBulkDownloadCompleted({ data })); + }); + + socket.on('bulk_download_failed', (data) => { + dispatch(socketBulkDownloadFailed({ data })); + }); }; From 555c10a827fd7ea757a2977162537f135be9ba08 Mon Sep 17 00:00:00 2001 From: Stefan Tobler Date: Mon, 19 Feb 2024 14:03:26 -0500 Subject: [PATCH 183/340] implementing download for bulk_download events --- invokeai/frontend/web/public/locales/en.json | 2 + .../middleware/listenerMiddleware/index.ts | 4 ++ .../socketio/socketBulkDownloadComplete.ts | 41 +++++++++++++++++++ .../socketio/socketBulkDownloadFailed.ts | 32 +++++++++++++++ .../frontend/web/src/services/events/types.ts | 8 ++-- 5 files changed, 84 insertions(+), 3 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadComplete.ts create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadFailed.ts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 32c707f9080..32d3d382bd5 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -426,6 +426,8 @@ "downloadSelection": "Download Selection", "preparingDownload": "Preparing Download", "preparingDownloadFailed": "Problem Preparing Download", + "bulkDownloadStarting": "Beginning Download", + "bulkDownloadFailed": "Problem Preparing Download", "problemDeletingImages": "Problem Deleting Images", "problemDeletingImagesDesc": "One or more images could not be deleted" }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 322c4eb1ecd..07d9bb5df52 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -48,6 +48,8 @@ import { addInitialImageSelectedListener } from './listeners/initialImageSelecte import { addModelSelectedListener } from './listeners/modelSelected'; import { addModelsLoadedListener } from './listeners/modelsLoaded'; import { addDynamicPromptsListener } from './listeners/promptChanged'; +import { addBulkDownloadCompleteEventListener } from './listeners/socketio/socketBulkDownloadComplete'; +import { addBulkDownloadFailedEventListener } from './listeners/socketio/socketBulkDownloadFailed'; import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected'; import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected'; import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress'; @@ -137,6 +139,8 @@ addModelLoadEventListener(); addSessionRetrievalErrorEventListener(); addInvocationRetrievalErrorEventListener(); addSocketQueueItemStatusChangedEventListener(); +addBulkDownloadCompleteEventListener(); +addBulkDownloadFailedEventListener(); // ControlNet addControlNetImageProcessedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadComplete.ts new file mode 100644 index 00000000000..acdb61ff25f --- /dev/null +++ 
b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadComplete.ts @@ -0,0 +1,41 @@ +import { logger } from 'app/logging/logger'; +import { addToast } from 'features/system/store/systemSlice'; +import { t } from 'i18next'; +import { socketBulkDownloadCompleted } from 'services/events/actions'; + +import { startAppListening } from '../..'; + +const log = logger('socketio'); + +export const addBulkDownloadCompleteEventListener = () => { + startAppListening({ + actionCreator: socketBulkDownloadCompleted, + effect: async (action, { dispatch }) => { + log.debug(action.payload, 'Bulk download complete'); + + const bulk_download_item_name = action.payload.data.bulk_download_item_name; + + const url = `/api/v1/images/download/${bulk_download_item_name}`; + const a = document.createElement('a'); + a.style.display = 'none'; + a.href = url; + a.download = bulk_download_item_name; + document.body.appendChild(a); + a.click(); + + dispatch( + addToast({ + title: t('gallery.bulkDownloadStarting'), + status: 'success', + ...(action.payload + ? { + description: bulk_download_item_name, + duration: null, + isClosable: true, + } + : {}), + }) + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadFailed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadFailed.ts new file mode 100644 index 00000000000..a9a45c42ae6 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadFailed.ts @@ -0,0 +1,32 @@ +import { logger } from 'app/logging/logger'; +import { addToast } from 'features/system/store/systemSlice'; +import { t } from 'i18next'; +import { socketBulkDownloadFailed } from 'services/events/actions'; + +import { startAppListening } from '../..'; + +const log = logger('socketio'); + +export const addBulkDownloadFailedEventListener = () => { + startAppListening({ + actionCreator: socketBulkDownloadFailed, + effect: async (action, { dispatch }) => { + log.debug(action.payload, 'Bulk download error'); + + + dispatch( + addToast({ + title: t('gallery.bulkDownloadFailed'), + status: 'error', + ...(action.payload + ? { + description: action.payload.data.error, + duration: null, + isClosable: true, + } + : {}), + }) + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index 092132fea2e..4e68131ee91 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -156,7 +156,7 @@ export type InvocationRetrievalErrorEvent = { * * @example socket.on('queue_item_status_changed', (data: QueueItemStatusChangedEvent) => { ... 
} */ -export type QueueItemStatusChangedEvent = { +export type QueueItemStatusChangedEvent = { queue_id: string; queue_item: { queue_id: string; @@ -191,7 +191,7 @@ export type QueueItemStatusChangedEvent = { failed: number; canceled: number; total: number; - }; + }; }; export type ClientEmitSubscribeQueue = { @@ -205,6 +205,7 @@ export type ClientEmitUnsubscribeQueue = { export type BulkDownloadStartedEvent = { bulk_download_id: string; bulk_download_item_id: string; + bulk_download_item_name: string; }; export type BulkDownloadCompletedEvent = { @@ -216,8 +217,9 @@ export type BulkDownloadCompletedEvent = { export type BulkDownloadFailedEvent = { bulk_download_id: string; bulk_download_item_id: string; + bulk_download_item_name: string; error: string; -} +}; export type ClientEmitSubscribeBulkDownload = { bulk_download_id: string; From 32afcfcbfd8c0ff0c8b384774007361c28486121 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 22:15:01 +1100 Subject: [PATCH 184/340] feat(bulk_download): update response model, messages --- invokeai/app/api/routers/images.py | 13 ++++++------- .../services/bulk_download/bulk_download_common.py | 5 +---- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index c3504b104d2..dc8a04b7117 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -375,9 +375,11 @@ async def unstar_images_in_list( class ImagesDownloaded(BaseModel): response: Optional[str] = Field( - description="If defined, the message to display to the user when images begin downloading" + default=None, description="The message to display to the user when images begin downloading" + ) + bulk_download_item_name: Optional[str] = Field( + default=None, description="The name of the bulk download item for which events will be emitted" ) - bulk_download_item_name: str = Field(description="The bulk download item name of the bulk download item") @images_router.post( @@ -389,7 +391,7 @@ async def download_images_from_list( default=None, description="The list of names of images to download", embed=True ), board_id: Optional[str] = Body( - default=None, description="The board from which image should be downloaded from", embed=True + default=None, description="The board from which image should be downloaded", embed=True ), ) -> ImagesDownloaded: if (image_names is None or len(image_names) == 0) and board_id is None: @@ -402,10 +404,7 @@ async def download_images_from_list( board_id, bulk_download_item_id, ) - return ImagesDownloaded( - response="Your images are preparing to be downloaded", - bulk_download_item_name=bulk_download_item_id + ".zip", - ) + return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip") @images_router.api_route( diff --git a/invokeai/app/services/bulk_download/bulk_download_common.py b/invokeai/app/services/bulk_download/bulk_download_common.py index 37b80073bee..68724eb228b 100644 --- a/invokeai/app/services/bulk_download/bulk_download_common.py +++ b/invokeai/app/services/bulk_download/bulk_download_common.py @@ -20,9 +20,6 @@ def __init__(self, message="The bulk download target was not found"): class BulkDownloadParametersException(BulkDownloadException): """Exception raised when a bulk download parameter is invalid.""" - def __init__( - self, - message="The bulk download parameters are invalid, either an array of image names or a board id must be provided", - ): + def 
__init__(self, message="No image names or board ID provided"): super().__init__(message) self.message = message From 09c858cbe343ec28538878cd8203f9f7c964eb09 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 22:15:14 +1100 Subject: [PATCH 185/340] chore(ui): typegen --- .../frontend/web/src/services/api/schema.ts | 159 ++++++------------ 1 file changed, 52 insertions(+), 107 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 47a257ffe6c..2115e797689 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -5,13 +5,6 @@ export type paths = { - "/api/v1/sessions/{session_id}": { - /** - * Get Session - * @description Gets a session - */ - get: operations["get_session"]; - }; "/api/v1/utilities/dynamicprompts": { /** * Parse Dynamicprompts @@ -389,6 +382,13 @@ export type paths = { /** Download Images From List */ post: operations["download_images_from_list"]; }; + "/api/v1/images/download/{bulk_download_item_name}": { + /** + * Get Bulk Download Item + * @description Gets a bulk download zip file + */ + get: operations["get_bulk_download_item"]; + }; "/api/v1/boards/": { /** * List Boards @@ -1152,10 +1152,10 @@ export type components = { * Image Names * @description The list of names of images to download */ - image_names: string[]; + image_names?: string[] | null; /** * Board Id - * @description The board from which image should be downloaded from + * @description The board from which image should be downloaded */ board_id?: string | null; }; @@ -4208,13 +4208,13 @@ export type components = { * Id * @description The id of this graph */ - id?: string; + id?: string | null; /** * Nodes * @description The nodes in this graph */ nodes: { - [key: string]: components["schemas"]["ImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | 
components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["RandomFloatInvocation"] | 
components["schemas"]["RangeInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["CompelInvocation"]; + [key: string]: components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | 
components["schemas"]["CV2InfillInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | 
components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["FaceOffInvocation"]; }; /** * Edges @@ -4251,7 +4251,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["String2Output"] | components["schemas"]["ControlOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["LatentsCollectionOutput"]; + [key: string]: components["schemas"]["ModelLoaderOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["String2Output"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["FaceOffOutput"] | 
components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["StringCollectionOutput"]; }; /** * Errors @@ -5674,9 +5674,14 @@ export type components = { ImagesDownloaded: { /** * Response - * @description If defined, the message to display to the user when images begin downloading + * @description The message to display to the user when images begin downloading + */ + response?: string | null; + /** + * Bulk Download Item Name + * @description The name of the bulk download item for which events will be emitted */ - response: string | null; + bulk_download_item_name?: string | null; }; /** ImagesUpdatedFromListResult */ ImagesUpdatedFromListResult: { @@ -6377,7 +6382,7 @@ export type components = { */ AllowDifferentLicense?: boolean; /** @description Type of commercial use allowed or 'No' if no commercial use is allowed. 
*/ - AllowCommercialUse?: components["schemas"]["CommercialUsage"]; + AllowCommercialUse?: components["schemas"]["CommercialUsage"] | null; }; /** * Lineart Anime Processor @@ -11091,66 +11096,6 @@ export type components = { * @enum {string} */ UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; - /** - * ControlNetModelFormat - * @description An enumeration. - * @enum {string} - */ - ControlNetModelFormat: "checkpoint" | "diffusers"; - /** - * LoRAModelFormat - * @description An enumeration. - * @enum {string} - */ - LoRAModelFormat: "lycoris" | "diffusers"; - /** - * StableDiffusionXLModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; - /** - * IPAdapterModelFormat - * @description An enumeration. - * @enum {string} - */ - IPAdapterModelFormat: "invokeai"; - /** - * T2IAdapterModelFormat - * @description An enumeration. - * @enum {string} - */ - T2IAdapterModelFormat: "diffusers"; - /** - * StableDiffusion1ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; - /** - * CLIPVisionModelFormat - * @description An enumeration. - * @enum {string} - */ - CLIPVisionModelFormat: "diffusers"; - /** - * StableDiffusionOnnxModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusionOnnxModelFormat: "olive" | "onnx"; - /** - * StableDiffusion2ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; - /** - * VaeModelFormat - * @description An enumeration. 
- * @enum {string} - */ - VaeModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; @@ -11165,36 +11110,6 @@ export type external = Record; export type operations = { - /** - * Get Session - * @description Gets a session - */ - get_session: { - parameters: { - path: { - /** @description The id of the session to get */ - session_id: string; - }; - }; - responses: { - /** @description Successful Response */ - 200: { - content: { - "application/json": components["schemas"]["GraphExecutionState"]; - }; - }; - /** @description Session not found */ - 404: { - content: never; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; /** * Parse Dynamicprompts * @description Creates a batch process @@ -12451,14 +12366,14 @@ export type operations = { }; /** Download Images From List */ download_images_from_list: { - requestBody: { + requestBody?: { content: { "application/json": components["schemas"]["Body_download_images_from_list"]; }; }; responses: { /** @description Successful Response */ - 200: { + 202: { content: { "application/json": components["schemas"]["ImagesDownloaded"]; }; @@ -12471,6 +12386,36 @@ export type operations = { }; }; }; + /** + * Get Bulk Download Item + * @description Gets a bulk download zip file + */ + get_bulk_download_item: { + parameters: { + path: { + /** @description The bulk_download_item_name of the bulk download item to get */ + bulk_download_item_name: string; + }; + }; + responses: { + /** @description Return the complete bulk download item */ + 200: { + content: { + "application/zip": unknown; + }; + }; + /** @description Image not found */ + 404: { + content: never; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; /** * List Boards * @description Gets a list of boards From 0ce7d032d6473bfd5054e29be5cd977eb200841d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 22:26:38 +1100 Subject: [PATCH 186/340] feat(ui): revise bulk download listeners - Use a single listener for all of the to keep them in one spot - Use the bulk download item name as a toast id so we can update the existing toasts - Update handling to work with other environments - Move all bulk download handling from components to listener --- invokeai/frontend/web/public/locales/en.json | 9 +- .../middleware/listenerMiddleware/index.ts | 6 +- .../listeners/bulkDownload.ts | 118 ++++++++++++++++++ .../socketio/socketBulkDownloadComplete.ts | 41 ------ .../socketio/socketBulkDownloadFailed.ts | 32 ----- .../components/Boards/BoardContextMenu.tsx | 33 +---- .../MultipleSelectionMenuItems.tsx | 32 +---- 7 files changed, 131 insertions(+), 140 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.ts delete mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadComplete.ts delete mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadFailed.ts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 32d3d382bd5..9abf0b80aa2 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -424,10 +424,11 @@ "uploads": "Uploads", 
"deleteSelection": "Delete Selection", "downloadSelection": "Download Selection", - "preparingDownload": "Preparing Download", - "preparingDownloadFailed": "Problem Preparing Download", - "bulkDownloadStarting": "Beginning Download", - "bulkDownloadFailed": "Problem Preparing Download", + "bulkDownloadRequested": "Preparing Download", + "bulkDownloadRequestedDesc": "Your download request is being prepared. This may take a few moments.", + "bulkDownloadRequestFailed": "Problem Preparing Download", + "bulkDownloadStarting": "Download Starting", + "bulkDownloadFailed": "Download Failed", "problemDeletingImages": "Problem Deleting Images", "problemDeletingImagesDesc": "One or more images could not be deleted" }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 07d9bb5df52..23e23c11404 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -1,5 +1,6 @@ import type { ListenerEffect, TypedAddListener, TypedStartListening, UnknownAction } from '@reduxjs/toolkit'; import { addListener, createListenerMiddleware } from '@reduxjs/toolkit'; +import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload'; import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked'; import type { AppDispatch, RootState } from 'app/store/store'; @@ -48,8 +49,6 @@ import { addInitialImageSelectedListener } from './listeners/initialImageSelecte import { addModelSelectedListener } from './listeners/modelSelected'; import { addModelsLoadedListener } from './listeners/modelsLoaded'; import { addDynamicPromptsListener } from './listeners/promptChanged'; -import { addBulkDownloadCompleteEventListener } from './listeners/socketio/socketBulkDownloadComplete'; -import { addBulkDownloadFailedEventListener } from './listeners/socketio/socketBulkDownloadFailed'; import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected'; import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected'; import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress'; @@ -139,8 +138,7 @@ addModelLoadEventListener(); addSessionRetrievalErrorEventListener(); addInvocationRetrievalErrorEventListener(); addSocketQueueItemStatusChangedEventListener(); -addBulkDownloadCompleteEventListener(); -addBulkDownloadFailedEventListener(); +addBulkDownloadListeners(); // ControlNet addControlNetImageProcessedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.ts new file mode 100644 index 00000000000..39d7e574c2b --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.ts @@ -0,0 +1,118 @@ +import type { UseToastOptions } from '@invoke-ai/ui-library'; +import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library'; +import { logger } from 'app/logging/logger'; +import { startAppListening } from 'app/store/middleware/listenerMiddleware'; +import { t } from 'i18next'; +import { imagesApi } from 'services/api/endpoints/images'; 
+import { + socketBulkDownloadCompleted, + socketBulkDownloadFailed, + socketBulkDownloadStarted, +} from 'services/events/actions'; + +const log = logger('images'); + +const { toast } = createStandaloneToast({ + theme: theme, + defaultOptions: TOAST_OPTIONS.defaultOptions, +}); + +export const addBulkDownloadListeners = () => { + startAppListening({ + matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled, + effect: async (action) => { + log.debug(action.payload, 'Bulk download requested'); + + // If we have an item name, we are processing the bulk download locally and should use it as the toast id to + // prevent multiple toasts for the same item. + toast({ + id: action.payload.bulk_download_item_name ?? undefined, + title: t('gallery.bulkDownloadRequested'), + status: 'success', + // Show the response message if it exists, otherwise show the default message + description: action.payload.response || t('gallery.bulkDownloadRequestedDesc'), + duration: null, + isClosable: true, + }); + }, + }); + + startAppListening({ + matcher: imagesApi.endpoints.bulkDownloadImages.matchRejected, + effect: async () => { + log.debug('Bulk download request failed'); + + // There isn't any toast to update if we get this event. + toast({ + title: t('gallery.bulkDownloadRequestFailed'), + status: 'success', + isClosable: true, + }); + }, + }); + + startAppListening({ + actionCreator: socketBulkDownloadStarted, + effect: async (action) => { + // This should always happen immediately after the bulk download request, so we don't need to show a toast here. + log.debug(action.payload.data, 'Bulk download preparation started'); + }, + }); + + startAppListening({ + actionCreator: socketBulkDownloadCompleted, + effect: async (action) => { + log.debug(action.payload.data, 'Bulk download preparation completed'); + + const { bulk_download_item_name } = action.payload.data; + + // TODO(psyche): This URL may break in in some environments (e.g. 
Nvidia workbench) but we need to test it first + const url = `/api/v1/images/download/${bulk_download_item_name}`; + const a = document.createElement('a'); + a.style.display = 'none'; + a.href = url; + a.download = bulk_download_item_name; + document.body.appendChild(a); + a.click(); + + const toastOptions: UseToastOptions = { + id: bulk_download_item_name, + title: t('gallery.bulkDownloadStarting'), + status: 'success', + description: bulk_download_item_name, + duration: 5000, + isClosable: true, + }; + + if (toast.isActive(bulk_download_item_name)) { + toast.update(bulk_download_item_name, toastOptions); + } else { + toast(toastOptions); + } + }, + }); + + startAppListening({ + actionCreator: socketBulkDownloadFailed, + effect: async (action) => { + log.debug(action.payload.data, 'Bulk download preparation failed'); + + const { bulk_download_item_name } = action.payload.data; + + const toastOptions: UseToastOptions = { + id: bulk_download_item_name, + title: t('gallery.bulkDownloadFailed'), + status: 'error', + description: action.payload.data.error, + duration: null, + isClosable: true, + }; + + if (toast.isActive(bulk_download_item_name)) { + toast.update(bulk_download_item_name, toastOptions); + } else { + toast(toastOptions); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadComplete.ts deleted file mode 100644 index acdb61ff25f..00000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadComplete.ts +++ /dev/null @@ -1,41 +0,0 @@ -import { logger } from 'app/logging/logger'; -import { addToast } from 'features/system/store/systemSlice'; -import { t } from 'i18next'; -import { socketBulkDownloadCompleted } from 'services/events/actions'; - -import { startAppListening } from '../..'; - -const log = logger('socketio'); - -export const addBulkDownloadCompleteEventListener = () => { - startAppListening({ - actionCreator: socketBulkDownloadCompleted, - effect: async (action, { dispatch }) => { - log.debug(action.payload, 'Bulk download complete'); - - const bulk_download_item_name = action.payload.data.bulk_download_item_name; - - const url = `/api/v1/images/download/${bulk_download_item_name}`; - const a = document.createElement('a'); - a.style.display = 'none'; - a.href = url; - a.download = bulk_download_item_name; - document.body.appendChild(a); - a.click(); - - dispatch( - addToast({ - title: t('gallery.bulkDownloadStarting'), - status: 'success', - ...(action.payload - ? 
{ - description: bulk_download_item_name, - duration: null, - isClosable: true, - } - : {}), - }) - ); - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadFailed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadFailed.ts deleted file mode 100644 index a9a45c42ae6..00000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketBulkDownloadFailed.ts +++ /dev/null @@ -1,32 +0,0 @@ -import { logger } from 'app/logging/logger'; -import { addToast } from 'features/system/store/systemSlice'; -import { t } from 'i18next'; -import { socketBulkDownloadFailed } from 'services/events/actions'; - -import { startAppListening } from '../..'; - -const log = logger('socketio'); - -export const addBulkDownloadFailedEventListener = () => { - startAppListening({ - actionCreator: socketBulkDownloadFailed, - effect: async (action, { dispatch }) => { - log.debug(action.payload, 'Bulk download error'); - - - dispatch( - addToast({ - title: t('gallery.bulkDownloadFailed'), - status: 'error', - ...(action.payload - ? { - description: action.payload.data.error, - duration: null, - isClosable: true, - } - : {}), - }) - ); - }, - }); -}; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx index 490e8eac9ee..ad6c37532ef 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx @@ -5,7 +5,6 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { autoAddBoardIdChanged, selectGallerySlice } from 'features/gallery/store/gallerySlice'; import type { BoardId } from 'features/gallery/store/types'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import { addToast } from 'features/system/store/systemSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiDownloadBold, PiPlusBold } from 'react-icons/pi'; @@ -41,35 +40,9 @@ const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props dispatch(autoAddBoardIdChanged(board_id)); }, [board_id, dispatch]); - const handleBulkDownload = useCallback(async () => { - try { - const response = await bulkDownload({ - image_names: [], - board_id: board_id, - }).unwrap(); - - dispatch( - addToast({ - title: t('gallery.preparingDownload'), - status: 'success', - ...(response.response - ? 
{ - description: response.response, - duration: null, - isClosable: true, - } - : {}), - }) - ); - } catch { - dispatch( - addToast({ - title: t('gallery.preparingDownloadFailed'), - status: 'error', - }) - ); - } - }, [t, board_id, bulkDownload, dispatch]); + const handleBulkDownload = useCallback(() => { + bulkDownload({ image_names: [], board_id: board_id }); + }, [board_id, bulkDownload]); const renderMenuFunc = useCallback( () => ( diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx index e8f71c02f39..7b1fa73472d 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx @@ -5,7 +5,6 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { imagesToChangeSelected, isModalOpenChanged } from 'features/changeBoardModal/store/slice'; import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import { addToast } from 'features/system/store/systemSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiDownloadSimpleBold, PiFoldersBold, PiStarBold, PiStarFill, PiTrashSimpleBold } from 'react-icons/pi'; @@ -44,34 +43,9 @@ const MultipleSelectionMenuItems = () => { unstarImages({ imageDTOs: selection }); }, [unstarImages, selection]); - const handleBulkDownload = useCallback(async () => { - try { - const response = await bulkDownload({ - image_names: selection.map((img) => img.image_name), - }).unwrap(); - - dispatch( - addToast({ - title: t('gallery.preparingDownload'), - status: 'success', - ...(response.response - ? 
{ - description: response.response, - duration: null, - isClosable: true, - } - : {}), - }) - ); - } catch { - dispatch( - addToast({ - title: t('gallery.preparingDownloadFailed'), - status: 'error', - }) - ); - } - }, [t, selection, bulkDownload, dispatch]); + const handleBulkDownload = useCallback(() => { + bulkDownload({ image_names: selection.map((img) => img.image_name) }); + }, [selection, bulkDownload]); const areAllStarred = useMemo(() => { return selection.every((img) => img.starred); From 9dd4332863c800172cd3cd2d473e24338a12b104 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 22:40:16 +1100 Subject: [PATCH 187/340] feat(ui): do not subscribe to bulk download sio room if baseUrl is set --- .../web/src/services/events/util/setEventListeners.ts | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index d851a185ff8..1f27955ca73 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -1,3 +1,4 @@ +import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId'; import { $queueId } from 'app/store/nanostores/queueId'; import type { AppDispatch } from 'app/store/store'; @@ -38,8 +39,10 @@ export const setEventListeners = (arg: SetEventListenersArg) => { dispatch(socketConnected()); const queue_id = $queueId.get(); socket.emit('subscribe_queue', { queue_id }); - const bulk_download_id = $bulkDownloadId.get(); - socket.emit('subscribe_bulk_download', { bulk_download_id }); + if (!$baseUrl.get()) { + const bulk_download_id = $bulkDownloadId.get(); + socket.emit('subscribe_bulk_download', { bulk_download_id }); + } }); socket.on('connect_error', (error) => { From 71f7dcab16711ceb51373cd938ae5b10486d5c6a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 21 Feb 2024 09:44:16 +1100 Subject: [PATCH 188/340] fix(ui): fix package build --- invokeai/frontend/web/tsconfig.node.json | 2 +- invokeai/frontend/web/vite.config.mts | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/invokeai/frontend/web/tsconfig.node.json b/invokeai/frontend/web/tsconfig.node.json index b5936d415ca..046964021fc 100644 --- a/invokeai/frontend/web/tsconfig.node.json +++ b/invokeai/frontend/web/tsconfig.node.json @@ -5,5 +5,5 @@ "moduleResolution": "Node", "allowSyntheticDefaultImports": true }, - "include": ["vite.config.mts", "config/vite.app.config.mts", "config/vite.package.config.mts", "config/common.mts"] + "include": ["vite.config.mts"] } diff --git a/invokeai/frontend/web/vite.config.mts b/invokeai/frontend/web/vite.config.mts index 32e3e1f64fe..ae64b7198d4 100644 --- a/invokeai/frontend/web/vite.config.mts +++ b/invokeai/frontend/web/vite.config.mts @@ -26,7 +26,7 @@ export default defineConfig(({ mode }) => { build: { cssCodeSplit: true, lib: { - entry: path.resolve(__dirname, '../src/index.ts'), + entry: path.resolve(__dirname, './src/index.ts'), name: 'InvokeAIUI', fileName: (format) => `invoke-ai-ui.${format}.js`, }, @@ -44,12 +44,12 @@ export default defineConfig(({ mode }) => { }, resolve: { alias: { - app: path.resolve(__dirname, '../src/app'), - assets: path.resolve(__dirname, '../src/assets'), - common: path.resolve(__dirname, 
'../src/common'), - features: path.resolve(__dirname, '../src/features'), - services: path.resolve(__dirname, '../src/services'), - theme: path.resolve(__dirname, '../src/theme'), + app: path.resolve(__dirname, './src/app'), + assets: path.resolve(__dirname, './src/assets'), + common: path.resolve(__dirname, './src/common'), + features: path.resolve(__dirname, './src/features'), + services: path.resolve(__dirname, './src/services'), + theme: path.resolve(__dirname, './src/theme'), }, }, }; From 51d5f05b561058de83fc7dc1d742cc8087541781 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 21 Feb 2024 19:04:43 +1100 Subject: [PATCH 189/340] fix(nodes): fix TI loading --- invokeai/app/invocations/compel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 47be380626b..50f53225137 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -86,7 +86,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for trigger in extract_ti_triggers_from_prompt(self.prompt): name = trigger[1:-1] try: - loaded_model = context.models.load(**self.clip.text_encoder.model_dump()).model + loaded_model = context.models.load(key=name).model assert isinstance(loaded_model, TextualInversionModelRaw) ti_list.append((name, loaded_model)) except UnknownModelException: From 7fc8b82637fde57c717817d3ddedc35cad3998c6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 21 Feb 2024 19:42:36 +1100 Subject: [PATCH 190/340] fix(ui): use model names in badges --- .../src/features/lora/components/LoRACard.tsx | 6 +- .../web/src/features/nodes/types/common.ts | 8 ++- .../parameters/types/parameterSchemas.ts | 4 +- .../AdvancedSettingsAccordion.tsx | 56 ++++++++++--------- .../GenerationSettingsAccordion.tsx | 31 +++++----- .../web/src/services/api/endpoints/models.ts | 13 +++++ .../api/hooks/useSelectedModelConfig.ts | 14 +++++ .../frontend/web/src/services/api/index.ts | 1 + 8 files changed, 85 insertions(+), 48 deletions(-) create mode 100644 invokeai/frontend/web/src/services/api/hooks/useSelectedModelConfig.ts diff --git a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx index 71ce1457864..579d45054be 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRACard.tsx @@ -15,14 +15,16 @@ import type { LoRA } from 'features/lora/store/loraSlice'; import { loraIsEnabledChanged, loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice'; import { memo, useCallback } from 'react'; import { PiTrashSimpleBold } from 'react-icons/pi'; +import { useGetModelConfigQuery } from 'services/api/endpoints/models'; type LoRACardProps = { lora: LoRA; }; export const LoRACard = memo((props: LoRACardProps) => { - const dispatch = useAppDispatch(); const { lora } = props; + const dispatch = useAppDispatch(); + const { data: loraConfig } = useGetModelConfigQuery(lora.key); const handleChange = useCallback( (v: number) => { @@ -44,7 +46,7 @@ export const LoRACard = memo((props: LoRACardProps) => { - {lora.key} + {loraConfig?.name ?? 
lora.key.substring(0, 8)} diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index d5d04deaa54..b195ce4434c 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -67,6 +67,8 @@ export const zModelName = z.string().min(3); export const zModelIdentifier = z.object({ key: z.string().min(1), }); +export const isModelIdentifier = (field: unknown): field is ModelIdentifier => + zModelIdentifier.safeParse(field).success; export const zModelFieldBase = zModelIdentifier; export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel }); export type BaseModel = z.infer; @@ -141,7 +143,7 @@ export type VAEField = z.infer; // #region Control Adapters export const zControlField = z.object({ image: zImageField, - control_model: zControlNetModelField, + control_model: zModelFieldBase, control_weight: z.union([z.number(), z.array(z.number())]).optional(), begin_step_percent: z.number().optional(), end_step_percent: z.number().optional(), @@ -152,7 +154,7 @@ export type ControlField = z.infer; export const zIPAdapterField = z.object({ image: zImageField, - ip_adapter_model: zIPAdapterModelField, + ip_adapter_model: zModelFieldBase, weight: z.number(), begin_step_percent: z.number().optional(), end_step_percent: z.number().optional(), @@ -161,7 +163,7 @@ export type IPAdapterField = z.infer; export const zT2IAdapterField = z.object({ image: zImageField, - t2i_adapter_model: zT2IAdapterModelField, + t2i_adapter_model: zModelFieldBase, weight: z.union([z.number(), z.array(z.number())]).optional(), begin_step_percent: z.number().optional(), end_step_percent: z.number().optional(), diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index abd8ee28103..b30d5df147b 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -4,7 +4,7 @@ import { zControlNetModelField, zIPAdapterModelField, zLoRAModelField, - zMainModelField, + zModelIdentifierWithBase, zSchedulerField, zSDXLRefinerModelField, zT2IAdapterModelField, @@ -105,7 +105,7 @@ export const isParameterAspectRatio = (val: unknown): val is ParameterAspectRati // #endregion // #region Model -export const zParameterModel = zMainModelField.extend({ base: zBaseModel }); +export const zParameterModel = zModelIdentifierWithBase; export type ParameterModel = z.infer; export const isParameterModel = (val: unknown): val is ParameterModel => zParameterModel.safeParse(val).success; // #endregion diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx index fc8c54576c5..8b10d9bddd4 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion.tsx @@ -1,5 +1,6 @@ import type { FormLabelProps } from '@invoke-ai/ui-library'; import { Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-library'; +import { skipToken } from '@reduxjs/toolkit/query'; import { createMemoizedSelector } from 
'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import ParamCFGRescaleMultiplier from 'features/parameters/components/Advanced/ParamCFGRescaleMultiplier'; @@ -10,8 +11,9 @@ import ParamVAEModelSelect from 'features/parameters/components/VAEModel/ParamVA import ParamVAEPrecision from 'features/parameters/components/VAEModel/ParamVAEPrecision'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; -import { memo } from 'react'; +import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; +import { useGetModelConfigQuery } from 'services/api/endpoints/models'; const formLabelProps: FormLabelProps = { minW: '9.2rem', @@ -21,31 +23,35 @@ const formLabelProps2: FormLabelProps = { flexGrow: 1, }; -const selectBadges = createMemoizedSelector(selectGenerationSlice, (generation) => { - const badges: (string | number)[] = []; - if (generation.vae) { - // TODO(MM2): Fetch the vae name - let vaeBadge = generation.vae.key; - if (generation.vaePrecision === 'fp16') { - vaeBadge += ` ${generation.vaePrecision}`; - } - badges.push(vaeBadge); - } else if (generation.vaePrecision === 'fp16') { - badges.push(`VAE ${generation.vaePrecision}`); - } - if (generation.clipSkip) { - badges.push(`Skip ${generation.clipSkip}`); - } - if (generation.cfgRescaleMultiplier) { - badges.push(`Rescale ${generation.cfgRescaleMultiplier}`); - } - if (generation.seamlessXAxis || generation.seamlessYAxis) { - badges.push('seamless'); - } - return badges; -}); - export const AdvancedSettingsAccordion = memo(() => { + const vaeKey = useAppSelector((state) => state.generation.vae?.key); + const { data: vaeConfig } = useGetModelConfigQuery(vaeKey ?? 
skipToken); + const selectBadges = useMemo( + () => + createMemoizedSelector(selectGenerationSlice, (generation) => { + const badges: (string | number)[] = []; + if (vaeConfig) { + let vaeBadge = vaeConfig.name; + if (generation.vaePrecision === 'fp16') { + vaeBadge += ` ${generation.vaePrecision}`; + } + badges.push(vaeBadge); + } else if (generation.vaePrecision === 'fp16') { + badges.push(`VAE ${generation.vaePrecision}`); + } + if (generation.clipSkip) { + badges.push(`Skip ${generation.clipSkip}`); + } + if (generation.cfgRescaleMultiplier) { + badges.push(`Rescale ${generation.cfgRescaleMultiplier}`); + } + if (generation.seamlessXAxis || generation.seamlessYAxis) { + badges.push('seamless'); + } + return badges; + }), + [vaeConfig] + ); const badges = useAppSelector(selectBadges); const { t } = useTranslation(); const { isOpen, onToggle } = useStandaloneAccordionToggle({ diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx index cda7dcf6e92..d57e48f11e7 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx @@ -12,6 +12,7 @@ import { } from '@invoke-ai/ui-library'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { LoRAList } from 'features/lora/components/LoRAList'; import LoRASelect from 'features/lora/components/LoRASelect'; import { selectLoraSlice } from 'features/lora/store/loraSlice'; @@ -20,33 +21,31 @@ import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale'; import ParamScheduler from 'features/parameters/components/Core/ParamScheduler'; import ParamSteps from 'features/parameters/components/Core/ParamSteps'; import ParamMainModelSelect from 'features/parameters/components/MainModel/ParamMainModelSelect'; -import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle'; import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle'; import { filter } from 'lodash-es'; -import { memo } from 'react'; +import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; +import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig'; const formLabelProps: FormLabelProps = { minW: '4rem', }; -const badgesSelector = createMemoizedSelector(selectLoraSlice, selectGenerationSlice, (lora, generation) => { - const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length; - const loraTabBadges = enabledLoRAsCount ? 
[enabledLoRAsCount] : []; - const accordionBadges: (string | number)[] = []; - // TODO(MM2): fetch model name - if (generation.model) { - accordionBadges.push(generation.model.key); - accordionBadges.push(generation.model.base); - } - - return { loraTabBadges, accordionBadges }; -}); - export const GenerationSettingsAccordion = memo(() => { const { t } = useTranslation(); - const { loraTabBadges, accordionBadges } = useAppSelector(badgesSelector); + const modelConfig = useSelectedModelConfig(); + const selectBadges = useMemo( + () => + createMemoizedSelector(selectLoraSlice, (lora) => { + const enabledLoRAsCount = filter(lora.loras, (l) => !!l.isEnabled).length; + const loraTabBadges = enabledLoRAsCount ? [enabledLoRAsCount] : EMPTY_ARRAY; + const accordionBadges = modelConfig ? [modelConfig.name, modelConfig.base] : EMPTY_ARRAY; + return { loraTabBadges, accordionBadges }; + }), + [modelConfig] + ); + const { loraTabBadges, accordionBadges } = useAppSelector(selectBadges); const { isOpen: isOpenExpander, onToggle: onToggleExpander } = useExpanderToggle({ id: 'generation-settings-advanced', defaultIsOpen: false, diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 46be42d9e53..666e0c707d5 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -236,6 +236,18 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['Model'], }), + getModelConfig: build.query({ + query: (key) => buildModelsUrl(`i/${key}`), + providesTags: (result) => { + const tags: ApiTagDescription[] = ['Model']; + + if (result) { + tags.push({ type: 'ModelConfig', id: result.key }); + } + + return tags; + }, + }), syncModels: build.mutation({ query: () => { return { @@ -313,6 +325,7 @@ export const modelsApi = api.injectEndpoints({ }); export const { + useGetModelConfigQuery, useGetMainModelsQuery, useGetControlNetModelsQuery, useGetIPAdapterModelsQuery, diff --git a/invokeai/frontend/web/src/services/api/hooks/useSelectedModelConfig.ts b/invokeai/frontend/web/src/services/api/hooks/useSelectedModelConfig.ts new file mode 100644 index 00000000000..4a8d8d72e2a --- /dev/null +++ b/invokeai/frontend/web/src/services/api/hooks/useSelectedModelConfig.ts @@ -0,0 +1,14 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { skipToken } from '@reduxjs/toolkit/query'; +import { useAppSelector } from 'app/store/storeHooks'; +import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; +import { useGetModelConfigQuery } from 'services/api/endpoints/models'; + +const selectModelKey = createSelector(selectGenerationSlice, (generation) => generation.model?.key); + +export const useSelectedModelConfig = () => { + const key = useAppSelector(selectModelKey); + const { currentData: modelConfig } = useGetModelConfigQuery(key ?? 
skipToken); + + return modelConfig; +}; diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 3584bdec453..7eeee7b3c82 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -26,6 +26,7 @@ export const tagTypes = [ 'BatchStatus', 'InvocationCacheStatus', 'Model', + 'ModelConfig', 'T2IAdapterModel', 'MainModel', 'VaeModel', From 8c34ac23aef74a570412b82484af0a7b7f094fe5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 21 Feb 2024 19:42:49 +1100 Subject: [PATCH 191/340] fix(ui): handle new model format for metadata --- .../ImageMetadataActions.tsx | 44 +++-- .../ImageMetadataViewer/ImageMetadataItem.tsx | 46 ++++- .../web/src/features/nodes/types/metadata.ts | 4 +- .../parameters/hooks/useRecallParameters.ts | 161 ++++++++++++------ 4 files changed, 178 insertions(+), 77 deletions(-) diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 5907ba07000..7eec7e18757 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -1,3 +1,4 @@ +import { isModelIdentifier } from 'features/nodes/types/common'; import type { ControlNetMetadataItem, CoreMetadata, @@ -6,15 +7,10 @@ import type { T2IAdapterMetadataItem, } from 'features/nodes/types/metadata'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; -import { - isParameterControlNetModel, - isParameterLoRAModel, - isParameterT2IAdapterModel, -} from 'features/parameters/types/parameterSchemas'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import ImageMetadataItem from './ImageMetadataItem'; +import ImageMetadataItem, { ModelMetadataItem, VAEMetadataItem } from './ImageMetadataItem'; type Props = { metadata?: CoreMetadata; @@ -147,19 +143,19 @@ const ImageMetadataActions = (props: Props) => { const validControlNets: ControlNetMetadataItem[] = useMemo(() => { return metadata?.controlnets - ? metadata.controlnets.filter((controlnet) => isParameterControlNetModel(controlnet.control_model)) + ? metadata.controlnets.filter((controlnet) => isModelIdentifier(controlnet.control_model)) : []; }, [metadata?.controlnets]); const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => { return metadata?.ipAdapters - ? metadata.ipAdapters.filter((ipAdapter) => isParameterControlNetModel(ipAdapter.ip_adapter_model)) + ? metadata.ipAdapters.filter((ipAdapter) => isModelIdentifier(ipAdapter.ip_adapter_model)) : []; }, [metadata?.ipAdapters]); const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => { return metadata?.t2iAdapters - ? metadata.t2iAdapters.filter((t2iAdapter) => isParameterT2IAdapterModel(t2iAdapter.t2i_adapter_model)) + ? 
metadata.t2iAdapters.filter((t2iAdapter) => isModelIdentifier(t2iAdapter.t2i_adapter_model)) : []; }, [metadata?.t2iAdapters]); @@ -209,7 +205,7 @@ const ImageMetadataActions = (props: Props) => { )} {metadata.model !== undefined && metadata.model !== null && metadata.model.key && ( - + )} {metadata.width && ( @@ -220,11 +216,7 @@ const ImageMetadataActions = (props: Props) => { {metadata.scheduler && ( )} - + {metadata.steps && ( )} @@ -264,38 +256,42 @@ const ImageMetadataActions = (props: Props) => { )} {metadata.loras && metadata.loras.map((lora, index) => { - if (isParameterLoRAModel(lora.lora)) { + if (isModelIdentifier(lora.lora)) { return ( - ); } })} {validControlNets.map((controlnet, index) => ( - ))} {validIPAdapters.map((ipAdapter, index) => ( - ))} {validT2IAdapters.map((t2iAdapter, index) => ( - ))} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx index c6dbd162698..7d17a2ad3d3 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataItem.tsx @@ -1,8 +1,10 @@ import { ExternalLink, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library'; -import { memo, useCallback } from 'react'; +import { skipToken } from '@reduxjs/toolkit/query'; +import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { IoArrowUndoCircleOutline } from 'react-icons/io5'; import { PiCopyBold } from 'react-icons/pi'; +import { useGetModelConfigQuery } from 'services/api/endpoints/models'; type MetadataItemProps = { isLink?: boolean; @@ -18,8 +20,9 @@ type MetadataItemProps = { */ const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withCopy = false }: MetadataItemProps) => { const { t } = useTranslation(); - - const handleCopy = useCallback(() => navigator.clipboard.writeText(value.toString()), [value]); + const handleCopy = useCallback(() => { + navigator.clipboard.writeText(value?.toString()); + }, [value]); if (!value) { return null; @@ -68,3 +71,40 @@ const ImageMetadataItem = ({ label, value, onClick, isLink, labelPosition, withC }; export default memo(ImageMetadataItem); + +type VAEMetadataItemProps = { + label: string; + modelKey?: string; + onClick: () => void; +}; + +export const VAEMetadataItem = memo(({ label, modelKey, onClick }: VAEMetadataItemProps) => { + const { data: modelConfig } = useGetModelConfigQuery(modelKey ?? skipToken); + + return ( + + ); +}); + +VAEMetadataItem.displayName = 'VAEMetadataItem'; + +type ModelMetadataItemProps = { + label: string; + modelKey?: string; + + extra?: string; + onClick: () => void; +}; + +export const ModelMetadataItem = memo(({ label, modelKey, extra, onClick }: ModelMetadataItemProps) => { + const { data: modelConfig } = useGetModelConfigQuery(modelKey ?? skipToken); + const value = useMemo(() => { + if (modelConfig) { + return `${modelConfig.name}${extra ?? ''}`; + } + return `${modelKey}${extra ?? 
''}`; + }, [extra, modelConfig, modelKey]); + return ; +}); + +ModelMetadataItem.displayName = 'ModelMetadataItem'; diff --git a/invokeai/frontend/web/src/features/nodes/types/metadata.ts b/invokeai/frontend/web/src/features/nodes/types/metadata.ts index 0cc30499e38..493a0464b3f 100644 --- a/invokeai/frontend/web/src/features/nodes/types/metadata.ts +++ b/invokeai/frontend/web/src/features/nodes/types/metadata.ts @@ -3,8 +3,8 @@ import { z } from 'zod'; import { zControlField, zIPAdapterField, - zLoRAModelField, zMainModelField, + zModelFieldBase, zSDXLRefinerModelField, zT2IAdapterField, zVAEModelField, @@ -15,7 +15,7 @@ import { // - https://github.com/colinhacks/zod/issues/2106 // - https://github.com/colinhacks/zod/issues/2854 export const zLoRAMetadataItem = z.object({ - lora: zLoRAModelField.deepPartial(), + lora: zModelFieldBase.deepPartial(), weight: z.number(), }); const zControlNetMetadataItem = zControlField.deepPartial(); diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index c8b17816bb5..0d464cd9b94 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -11,6 +11,8 @@ import { } from 'features/controlAdapters/util/buildControlAdapter'; import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice'; import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice'; +import type { ModelIdentifier } from 'features/nodes/types/common'; +import { isModelIdentifier } from 'features/nodes/types/common'; import type { ControlNetMetadataItem, CoreMetadata, @@ -37,13 +39,9 @@ import type { ParameterModel } from 'features/parameters/types/parameterSchemas' import { isParameterCFGRescaleMultiplier, isParameterCFGScale, - isParameterControlNetModel, isParameterHeight, isParameterHRFEnabled, isParameterHRFMethod, - isParameterIPAdapterModel, - isParameterLoRAModel, - isParameterModel, isParameterNegativePrompt, isParameterNegativeStylePromptSDXL, isParameterPositivePrompt, @@ -56,7 +54,6 @@ import { isParameterSeed, isParameterSteps, isParameterStrength, - isParameterVAEModel, isParameterWidth, } from 'features/parameters/types/parameterSchemas'; import { @@ -73,15 +70,20 @@ import { import { isNil } from 'lodash-es'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; +import { ALL_BASE_MODELS } from 'services/api/constants'; import { controlNetModelsAdapterSelectors, ipAdapterModelsAdapterSelectors, loraModelsAdapterSelectors, + mainModelsAdapterSelectors, t2iAdapterModelsAdapterSelectors, useGetControlNetModelsQuery, useGetIPAdapterModelsQuery, useGetLoRAModelsQuery, + useGetMainModelsQuery, useGetT2IAdapterModelsQuery, + useGetVaeModelsQuery, + vaeModelsAdapterSelectors, } from 'services/api/endpoints/models'; import type { ImageDTO } from 'services/api/types'; import { v4 as uuidv4 } from 'uuid'; @@ -278,21 +280,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall model with toast - */ - const recallModel = useCallback( - (model: unknown) => { - if (!isParameterModel(model)) { - parameterNotSetToast(); - return; - } - dispatch(modelSelected(model)); - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - /** * Recall scheduler with toast */ @@ -308,25 +295,6 @@ export const 
useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall vae model - */ - const recallVaeModel = useCallback( - (vae: unknown) => { - if (!isParameterVAEModel(vae) && !isNil(vae)) { - parameterNotSetToast(); - return; - } - if (isNil(vae)) { - dispatch(vaeSelected(null)); - } else { - dispatch(vaeSelected(vae)); - } - parameterSetToast(); - }, - [dispatch, parameterSetToast, parameterNotSetToast] - ); - /** * Recall steps with toast */ @@ -452,6 +420,95 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); + const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS); + + const prepareMainModelMetadataItem = useCallback( + (model: ModelIdentifier) => { + const matchingModel = mainModels ? mainModelsAdapterSelectors.selectById(mainModels, model.key) : undefined; + + if (!matchingModel) { + return { model: null, error: 'Model is not installed' }; + } + + return { model: matchingModel, error: null }; + }, + [mainModels] + ); + + /** + * Recall model with toast + */ + const recallModel = useCallback( + (model: unknown) => { + if (!isModelIdentifier(model)) { + parameterNotSetToast(); + return; + } + + const result = prepareMainModelMetadataItem(model); + + if (!result.model) { + parameterNotSetToast(result.error); + return; + } + + dispatch(modelSelected(result.model)); + parameterSetToast(); + }, + [prepareMainModelMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] + ); + + const { data: vaeModels } = useGetVaeModelsQuery(); + + const prepareVAEMetadataItem = useCallback( + (vae: ModelIdentifier, newModel?: ParameterModel) => { + const matchingModel = vaeModels ? vaeModelsAdapterSelectors.selectById(vaeModels, vae.key) : undefined; + if (!matchingModel) { + return { vae: null, error: 'VAE model is not installed' }; + } + const isCompatibleBaseModel = matchingModel?.base === (newModel ?? 
model)?.base; + + if (!isCompatibleBaseModel) { + return { + vae: null, + error: 'VAE incompatible with currently-selected model', + }; + } + + return { vae: matchingModel, error: null }; + }, + [model, vaeModels] + ); + + /** + * Recall vae model + */ + const recallVaeModel = useCallback( + (vae: unknown) => { + if (!isModelIdentifier(vae) && !isNil(vae)) { + parameterNotSetToast(); + return; + } + + if (isNil(vae)) { + dispatch(vaeSelected(null)); + parameterSetToast(); + return; + } + + const result = prepareVAEMetadataItem(vae); + + if (!result.vae) { + parameterNotSetToast(result.error); + return; + } + + dispatch(vaeSelected(result.vae)); + parameterSetToast(); + }, + [prepareVAEMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] + ); + /** * Recall LoRA with toast */ @@ -460,7 +517,7 @@ export const useRecallParameters = () => { const prepareLoRAMetadataItem = useCallback( (loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => { - if (!isParameterLoRAModel(loraMetadataItem.lora)) { + if (!isModelIdentifier(loraMetadataItem.lora)) { return { lora: null, error: 'Invalid LoRA model' }; } @@ -510,7 +567,7 @@ export const useRecallParameters = () => { const prepareControlNetMetadataItem = useCallback( (controlnetMetadataItem: ControlNetMetadataItem, newModel?: ParameterModel) => { - if (!isParameterControlNetModel(controlnetMetadataItem.control_model)) { + if (!isModelIdentifier(controlnetMetadataItem.control_model)) { return { controlnet: null, error: 'Invalid ControlNet model' }; } @@ -584,7 +641,7 @@ export const useRecallParameters = () => { const prepareT2IAdapterMetadataItem = useCallback( (t2iAdapterMetadataItem: T2IAdapterMetadataItem, newModel?: ParameterModel) => { - if (!isParameterControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model)) { + if (!isModelIdentifier(t2iAdapterMetadataItem.t2i_adapter_model)) { return { controlnet: null, error: 'Invalid ControlNet model' }; } @@ -657,7 +714,7 @@ export const useRecallParameters = () => { const prepareIPAdapterMetadataItem = useCallback( (ipAdapterMetadataItem: IPAdapterMetadataItem, newModel?: ParameterModel) => { - if (!isParameterIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) { + if (!isModelIdentifier(ipAdapterMetadataItem?.ip_adapter_model)) { return { ipAdapter: null, error: 'Invalid IP Adapter model' }; } @@ -762,9 +819,12 @@ export const useRecallParameters = () => { let newModel: ParameterModel | undefined = undefined; - if (isParameterModel(model)) { - newModel = model; - dispatch(modelSelected(model)); + if (isModelIdentifier(model)) { + const result = prepareMainModelMetadataItem(model); + if (result.model) { + dispatch(modelSelected(result.model)); + newModel = result.model; + } } if (isParameterCFGScale(cfg_scale)) { @@ -786,11 +846,14 @@ export const useRecallParameters = () => { if (isParameterScheduler(scheduler)) { dispatch(setScheduler(scheduler)); } - if (isParameterVAEModel(vae) || isNil(vae)) { + if (isModelIdentifier(vae) || isNil(vae)) { if (isNil(vae)) { dispatch(vaeSelected(null)); } else { - dispatch(vaeSelected(vae)); + const result = prepareVAEMetadataItem(vae, newModel); + if (result.vae) { + dispatch(vaeSelected(result.vae)); + } } } @@ -898,6 +961,8 @@ export const useRecallParameters = () => { dispatch, allParameterSetToast, allParameterNotSetToast, + prepareMainModelMetadataItem, + prepareVAEMetadataItem, prepareLoRAMetadataItem, prepareControlNetMetadataItem, prepareIPAdapterMetadataItem, From 8967c865477269e1cb6ae844588be9cece8362c6 Mon Sep 17 00:00:00 2001 
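The recall helpers introduced above (prepareMainModelMetadataItem, prepareVAEMetadataItem, and the LoRA / ControlNet / IP-Adapter / T2I-Adapter equivalents) all share one lookup-and-validate shape: the metadata only carries a model identifier, so it is resolved against the installed-model list and the caller gets back either the full model config or an error string to surface via parameterNotSetToast. Below is a minimal standalone sketch of that shape, not the patch's actual code; ModelIdentifier and installedModels here are illustrative stand-ins for the RTK Query entity-adapter state (mainModelsAdapterSelectors, vaeModelsAdapterSelectors, etc.) used in the diff.

    // Sketch only: mirrors the { model | error } result shape used by the recall hooks above.
    type ModelIdentifier = { key: string; base?: string };

    type PrepareResult<T> = { model: T; error: null } | { model: null; error: string };

    function prepareModelMetadataItem<T extends ModelIdentifier>(
      identifier: ModelIdentifier,
      installedModels: T[],
      // Optional base of the newly-recalled main model, for compatibility checks (e.g. VAE).
      currentBase?: string
    ): PrepareResult<T> {
      // Metadata stores only an identifier; resolve it against what is actually installed.
      const matchingModel = installedModels.find((m) => m.key === identifier.key);
      if (!matchingModel) {
        return { model: null, error: 'Model is not installed' };
      }
      // Reject models whose base architecture does not match the current selection.
      if (currentBase && matchingModel.base !== currentBase) {
        return { model: null, error: 'Model incompatible with currently-selected model' };
      }
      return { model: matchingModel, error: null };
    }
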
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 21 Feb 2024 19:43:14 +1100 Subject: [PATCH 192/340] tidy(ui): remove debugging stmt --- .../parameters/components/VAEModel/ParamVAEModelSelect.tsx | 2 -- 1 file changed, 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx index 1810c3ff68a..4b9f2764bf0 100644 --- a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx @@ -41,8 +41,6 @@ const ParamVAEModelSelect = () => { isLoading, getIsDisabled, }); - - console.log(value) return ( From 9abb2ed430173181cf93c13ef0dcc48093d4badd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 21 Feb 2024 13:47:46 +1100 Subject: [PATCH 193/340] chore(ui): bump deps Notable updates: - Minor version of RTK includes customizable selectors for RTK Query, so we can remove the patch that was added to ensure only the LRU memoize function was used for perf reasons. Updated to use the LRU memoize function. - Major version of react-resizable-panels. No breaking changes, works great, and you can now resize all panels when dragging at the intersection point of panels. Cool! - Minor (?) version of nanostores. `action` API is removed, we were using it in one spot. Fixed. - @invoke-ai/eslint-config-react has all deps bumped and now has its dependent plugins/configs listed as normal dependencies (as opposed to peer deps). This means we can remove those packages from explicit dev deps. --- invokeai/frontend/web/package.json | 103 +- invokeai/frontend/web/pnpm-lock.yaml | 4094 +++++++++-------- .../store/enhancers/reduxRemember/driver.ts | 8 +- .../frontend/web/src/services/api/index.ts | 12 +- 4 files changed, 2203 insertions(+), 2014 deletions(-) diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index a9f37f76ad6..743cb1e09d6 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -59,52 +59,52 @@ "@fontsource-variable/inter": "^5.0.16", "@invoke-ai/ui-library": "^0.0.21", "@mantine/form": "6.0.21", - "@nanostores/react": "^0.7.1", - "@reduxjs/toolkit": "2.0.1", + "@nanostores/react": "^0.7.2", + "@reduxjs/toolkit": "2.2.1", "@roarr/browser-log-writer": "^1.3.0", "chakra-react-select": "^4.7.6", "compare-versions": "^6.1.0", "dateformat": "^5.0.3", - "framer-motion": "^10.18.0", - "i18next": "^23.7.16", - "i18next-http-backend": "^2.4.2", + "framer-motion": "^11.0.5", + "i18next": "^23.9.0", + "i18next-http-backend": "^2.4.3", "idb-keyval": "^6.2.1", "jsondiffpatch": "^0.6.0", - "konva": "^9.3.1", + "konva": "^9.3.3", "lodash-es": "^4.17.21", - "nanostores": "^0.9.5", + "nanostores": "^0.10.0", "new-github-issue-url": "^1.0.0", - "overlayscrollbars": "^2.4.6", - "overlayscrollbars-react": "^0.5.3", - "query-string": "^8.1.0", + "overlayscrollbars": "^2.5.0", + "overlayscrollbars-react": "^0.5.4", + "query-string": "^8.2.0", "react": "^18.2.0", "react-colorful": "^5.6.1", "react-dom": "^18.2.0", "react-dropzone": "^14.2.3", "react-error-boundary": "^4.0.12", - "react-hook-form": "^7.49.3", - "react-hotkeys-hook": "4.4.4", - "react-i18next": "^14.0.0", + "react-hook-form": "^7.50.1", + "react-hotkeys-hook": "4.5.0", + "react-i18next": "^14.0.5", "react-icons": "^5.0.1", 
"react-konva": "^18.2.10", "react-redux": "9.1.0", - "react-resizable-panels": "^1.0.9", + "react-resizable-panels": "^2.0.9", "react-select": "5.8.0", "react-textarea-autosize": "^8.5.3", - "react-use": "^17.4.3", - "react-virtuoso": "^4.6.2", - "reactflow": "^11.10.2", + "react-use": "^17.5.0", + "react-virtuoso": "^4.7.0", + "reactflow": "^11.10.4", "redux-dynamic-middlewares": "^2.2.0", "redux-remember": "^5.1.0", "roarr": "^7.21.0", "serialize-error": "^11.0.3", "socket.io-client": "^4.7.4", - "type-fest": "^4.9.0", + "type-fest": "^4.10.2", "use-debounce": "^10.0.0", "use-image": "^1.1.1", "uuid": "^9.0.1", "zod": "^3.22.4", - "zod-validation-error": "^3.0.0" + "zod-validation-error": "^3.0.2" }, "peerDependencies": { "@chakra-ui/react": "^2.8.2", @@ -113,59 +113,44 @@ "ts-toolbelt": "^9.6.0" }, "devDependencies": { - "@arthurgeron/eslint-plugin-react-usememo": "^2.2.3", - "@invoke-ai/eslint-config-react": "^0.0.13", - "@invoke-ai/prettier-config-react": "^0.0.6", - "@storybook/addon-docs": "^7.6.10", - "@storybook/addon-essentials": "^7.6.10", - "@storybook/addon-interactions": "^7.6.10", - "@storybook/addon-links": "^7.6.10", - "@storybook/addon-storysource": "^7.6.10", - "@storybook/blocks": "^7.6.10", - "@storybook/manager-api": "^7.6.10", - "@storybook/react": "^7.6.10", - "@storybook/react-vite": "^7.6.10", - "@storybook/test": "^7.6.10", - "@storybook/theming": "^7.6.10", + "@invoke-ai/eslint-config-react": "^0.0.14", + "@invoke-ai/prettier-config-react": "^0.0.7", + "@storybook/addon-docs": "^7.6.17", + "@storybook/addon-essentials": "^7.6.17", + "@storybook/addon-interactions": "^7.6.17", + "@storybook/addon-links": "^7.6.17", + "@storybook/addon-storysource": "^7.6.17", + "@storybook/blocks": "^7.6.17", + "@storybook/manager-api": "^7.6.17", + "@storybook/react": "^7.6.17", + "@storybook/react-vite": "^7.6.17", + "@storybook/test": "^7.6.17", + "@storybook/theming": "^7.6.17", "@types/dateformat": "^5.0.2", "@types/lodash-es": "^4.17.12", - "@types/node": "^20.11.5", - "@types/react": "^18.2.48", - "@types/react-dom": "^18.2.18", - "@types/uuid": "^9.0.7", - "@typescript-eslint/eslint-plugin": "^6.19.0", - "@typescript-eslint/parser": "^6.19.0", - "@vitejs/plugin-react-swc": "^3.5.0", + "@types/node": "^20.11.19", + "@types/react": "^18.2.57", + "@types/react-dom": "^18.2.19", + "@types/uuid": "^9.0.8", + "@vitejs/plugin-react-swc": "^3.6.0", "concurrently": "^8.2.2", "eslint": "^8.56.0", - "eslint-config-prettier": "^9.1.0", "eslint-plugin-i18next": "^6.0.3", - "eslint-plugin-import": "^2.29.1", "eslint-plugin-path": "^1.2.4", - "eslint-plugin-react": "^7.33.2", - "eslint-plugin-react-hooks": "^4.6.0", - "eslint-plugin-simple-import-sort": "^10.0.0", - "eslint-plugin-storybook": "^0.6.15", - "eslint-plugin-unused-imports": "^3.0.0", "madge": "^6.1.0", "openapi-types": "^12.1.3", - "openapi-typescript": "^6.7.3", - "prettier": "^3.2.4", + "openapi-typescript": "^6.7.4", + "prettier": "^3.2.5", "rollup-plugin-visualizer": "^5.12.0", - "storybook": "^7.6.10", + "storybook": "^7.6.17", "ts-toolbelt": "^9.6.0", "tsafe": "^1.6.6", "typescript": "^5.3.3", - "vite": "^5.0.12", - "vite-plugin-css-injected-by-js": "^3.3.1", - "vite-plugin-dts": "^3.7.1", + "vite": "^5.1.3", + "vite-plugin-css-injected-by-js": "^3.4.0", + "vite-plugin-dts": "^3.7.2", "vite-plugin-eslint": "^1.8.1", "vite-tsconfig-paths": "^4.3.1", - "vitest": "^1.2.2" - }, - "pnpm": { - "patchedDependencies": { - "reselect@5.0.1": "patches/reselect@5.0.1.patch" - } + "vitest": "^1.3.1" } } diff --git 
a/invokeai/frontend/web/pnpm-lock.yaml b/invokeai/frontend/web/pnpm-lock.yaml index 4f9902299cb..9e873102e6a 100644 --- a/invokeai/frontend/web/pnpm-lock.yaml +++ b/invokeai/frontend/web/pnpm-lock.yaml @@ -4,15 +4,10 @@ settings: autoInstallPeers: true excludeLinksFromLockfile: false -patchedDependencies: - reselect@5.0.1: - hash: kvbgwzjyy4x4fnh7znyocvb75q - path: patches/reselect@5.0.1.patch - dependencies: '@chakra-ui/react': specifier: ^2.8.2 - version: 2.8.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(@types/react@18.2.48)(framer-motion@10.18.0)(react-dom@18.2.0)(react@18.2.0) + version: 2.8.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(@types/react@18.2.57)(framer-motion@11.0.5)(react-dom@18.2.0)(react@18.2.0) '@chakra-ui/react-use-size': specifier: ^2.1.0 version: 2.1.0(react@18.2.0) @@ -33,22 +28,22 @@ dependencies: version: 5.0.16 '@invoke-ai/ui-library': specifier: ^0.0.21 - version: 0.0.21(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.16)(@internationalized/date@3.5.2)(@types/react@18.2.48)(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0) + version: 0.0.21(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.16)(@internationalized/date@3.5.2)(@types/react@18.2.57)(i18next@23.9.0)(react-dom@18.2.0)(react@18.2.0) '@mantine/form': specifier: 6.0.21 version: 6.0.21(react@18.2.0) '@nanostores/react': - specifier: ^0.7.1 - version: 0.7.1(nanostores@0.9.5)(react@18.2.0) + specifier: ^0.7.2 + version: 0.7.2(nanostores@0.10.0)(react@18.2.0) '@reduxjs/toolkit': - specifier: 2.0.1 - version: 2.0.1(react-redux@9.1.0)(react@18.2.0) + specifier: 2.2.1 + version: 2.2.1(react-redux@9.1.0)(react@18.2.0) '@roarr/browser-log-writer': specifier: ^1.3.0 version: 1.3.0 chakra-react-select: specifier: ^4.7.6 - version: 4.7.6(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/layout@2.3.1)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@emotion/react@11.11.3)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + version: 4.7.6(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/layout@2.3.1)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@emotion/react@11.11.3)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) compare-versions: specifier: ^6.1.0 version: 6.1.0 @@ -56,14 +51,14 @@ dependencies: specifier: ^5.0.3 version: 5.0.3 framer-motion: - specifier: ^10.18.0 - version: 10.18.0(react-dom@18.2.0)(react@18.2.0) + specifier: ^11.0.5 + version: 11.0.5(react-dom@18.2.0)(react@18.2.0) i18next: - specifier: ^23.7.16 - version: 23.7.16 + specifier: ^23.9.0 + version: 23.9.0 i18next-http-backend: - specifier: ^2.4.2 - version: 2.4.2 + specifier: ^2.4.3 + version: 2.4.3 idb-keyval: specifier: ^6.2.1 version: 6.2.1 @@ -71,26 +66,26 @@ dependencies: specifier: ^0.6.0 version: 0.6.0 konva: - specifier: ^9.3.1 - version: 9.3.1 + specifier: ^9.3.3 + version: 9.3.3 lodash-es: specifier: ^4.17.21 version: 4.17.21 nanostores: - specifier: ^0.9.5 - version: 0.9.5 + specifier: ^0.10.0 + version: 0.10.0 new-github-issue-url: specifier: ^1.0.0 version: 1.0.0 overlayscrollbars: - specifier: ^2.4.6 - version: 2.4.6 + specifier: ^2.5.0 + version: 2.5.0 overlayscrollbars-react: - specifier: ^0.5.3 
- version: 0.5.3(overlayscrollbars@2.4.6)(react@18.2.0) + specifier: ^0.5.4 + version: 0.5.4(overlayscrollbars@2.5.0)(react@18.2.0) query-string: - specifier: ^8.1.0 - version: 8.1.0 + specifier: ^8.2.0 + version: 8.2.0 react: specifier: ^18.2.0 version: 18.2.0 @@ -107,41 +102,41 @@ dependencies: specifier: ^4.0.12 version: 4.0.12(react@18.2.0) react-hook-form: - specifier: ^7.49.3 - version: 7.49.3(react@18.2.0) + specifier: ^7.50.1 + version: 7.50.1(react@18.2.0) react-hotkeys-hook: - specifier: 4.4.4 - version: 4.4.4(react-dom@18.2.0)(react@18.2.0) + specifier: 4.5.0 + version: 4.5.0(react-dom@18.2.0)(react@18.2.0) react-i18next: - specifier: ^14.0.0 - version: 14.0.0(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0) + specifier: ^14.0.5 + version: 14.0.5(i18next@23.9.0)(react-dom@18.2.0)(react@18.2.0) react-icons: specifier: ^5.0.1 version: 5.0.1(react@18.2.0) react-konva: specifier: ^18.2.10 - version: 18.2.10(konva@9.3.1)(react-dom@18.2.0)(react@18.2.0) + version: 18.2.10(konva@9.3.3)(react-dom@18.2.0)(react@18.2.0) react-redux: specifier: 9.1.0 - version: 9.1.0(@types/react@18.2.48)(react@18.2.0)(redux@5.0.1) + version: 9.1.0(@types/react@18.2.57)(react@18.2.0)(redux@5.0.1) react-resizable-panels: - specifier: ^1.0.9 - version: 1.0.9(react-dom@18.2.0)(react@18.2.0) + specifier: ^2.0.9 + version: 2.0.9(react-dom@18.2.0)(react@18.2.0) react-select: specifier: 5.8.0 - version: 5.8.0(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + version: 5.8.0(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) react-textarea-autosize: specifier: ^8.5.3 - version: 8.5.3(@types/react@18.2.48)(react@18.2.0) + version: 8.5.3(@types/react@18.2.57)(react@18.2.0) react-use: - specifier: ^17.4.3 - version: 17.4.3(react-dom@18.2.0)(react@18.2.0) + specifier: ^17.5.0 + version: 17.5.0(react-dom@18.2.0)(react@18.2.0) react-virtuoso: - specifier: ^4.6.2 - version: 4.6.2(react-dom@18.2.0)(react@18.2.0) + specifier: ^4.7.0 + version: 4.7.0(react-dom@18.2.0)(react@18.2.0) reactflow: - specifier: ^11.10.2 - version: 11.10.2(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + specifier: ^11.10.4 + version: 11.10.4(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) redux-dynamic-middlewares: specifier: ^2.2.0 version: 2.2.0 @@ -158,8 +153,8 @@ dependencies: specifier: ^4.7.4 version: 4.7.4 type-fest: - specifier: ^4.9.0 - version: 4.9.0 + specifier: ^4.10.2 + version: 4.10.2 use-debounce: specifier: ^10.0.0 version: 10.0.0(react@18.2.0) @@ -173,52 +168,49 @@ dependencies: specifier: ^3.22.4 version: 3.22.4 zod-validation-error: - specifier: ^3.0.0 - version: 3.0.0(zod@3.22.4) + specifier: ^3.0.2 + version: 3.0.2(zod@3.22.4) devDependencies: - '@arthurgeron/eslint-plugin-react-usememo': - specifier: ^2.2.3 - version: 2.2.3 '@invoke-ai/eslint-config-react': - specifier: ^0.0.13 - version: 0.0.13(@typescript-eslint/eslint-plugin@6.19.0)(@typescript-eslint/parser@6.19.0)(eslint-config-prettier@9.1.0)(eslint-plugin-import@2.29.1)(eslint-plugin-react-hooks@4.6.0)(eslint-plugin-react-refresh@0.4.5)(eslint-plugin-react@7.33.2)(eslint-plugin-simple-import-sort@10.0.0)(eslint-plugin-storybook@0.6.15)(eslint-plugin-unused-imports@3.0.0)(eslint@8.56.0) + specifier: ^0.0.14 + version: 0.0.14(eslint@8.56.0)(prettier@3.2.5)(typescript@5.3.3) '@invoke-ai/prettier-config-react': - specifier: ^0.0.6 - version: 0.0.6(prettier@3.2.4) + specifier: ^0.0.7 + version: 0.0.7(prettier@3.2.5) '@storybook/addon-docs': - specifier: ^7.6.10 - version: 
7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + specifier: ^7.6.17 + version: 7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) '@storybook/addon-essentials': - specifier: ^7.6.10 - version: 7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + specifier: ^7.6.17 + version: 7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) '@storybook/addon-interactions': - specifier: ^7.6.10 - version: 7.6.10 + specifier: ^7.6.17 + version: 7.6.17 '@storybook/addon-links': - specifier: ^7.6.10 - version: 7.6.10(react@18.2.0) + specifier: ^7.6.17 + version: 7.6.17(react@18.2.0) '@storybook/addon-storysource': - specifier: ^7.6.10 - version: 7.6.10 + specifier: ^7.6.17 + version: 7.6.17 '@storybook/blocks': - specifier: ^7.6.10 - version: 7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + specifier: ^7.6.17 + version: 7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) '@storybook/manager-api': - specifier: ^7.6.10 - version: 7.6.10(react-dom@18.2.0)(react@18.2.0) + specifier: ^7.6.17 + version: 7.6.17(react-dom@18.2.0)(react@18.2.0) '@storybook/react': - specifier: ^7.6.10 - version: 7.6.10(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3) + specifier: ^7.6.17 + version: 7.6.17(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3) '@storybook/react-vite': - specifier: ^7.6.10 - version: 7.6.10(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3)(vite@5.0.12) + specifier: ^7.6.17 + version: 7.6.17(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3)(vite@5.1.3) '@storybook/test': - specifier: ^7.6.10 - version: 7.6.10(vitest@1.2.2) + specifier: ^7.6.17 + version: 7.6.17(vitest@1.3.1) '@storybook/theming': - specifier: ^7.6.10 - version: 7.6.10(react-dom@18.2.0)(react@18.2.0) + specifier: ^7.6.17 + version: 7.6.17(react-dom@18.2.0)(react@18.2.0) '@types/dateformat': specifier: ^5.0.2 version: 5.0.2 @@ -226,59 +218,32 @@ devDependencies: specifier: ^4.17.12 version: 4.17.12 '@types/node': - specifier: ^20.11.5 - version: 20.11.5 + specifier: ^20.11.19 + version: 20.11.19 '@types/react': - specifier: ^18.2.48 - version: 18.2.48 + specifier: ^18.2.57 + version: 18.2.57 '@types/react-dom': - specifier: ^18.2.18 - version: 18.2.18 + specifier: ^18.2.19 + version: 18.2.19 '@types/uuid': - specifier: ^9.0.7 - version: 9.0.7 - '@typescript-eslint/eslint-plugin': - specifier: ^6.19.0 - version: 6.19.0(@typescript-eslint/parser@6.19.0)(eslint@8.56.0)(typescript@5.3.3) - '@typescript-eslint/parser': - specifier: ^6.19.0 - version: 6.19.0(eslint@8.56.0)(typescript@5.3.3) + specifier: ^9.0.8 + version: 9.0.8 '@vitejs/plugin-react-swc': - specifier: ^3.5.0 - version: 3.5.0(vite@5.0.12) + specifier: ^3.6.0 + version: 3.6.0(vite@5.1.3) concurrently: specifier: ^8.2.2 version: 8.2.2 eslint: specifier: ^8.56.0 version: 8.56.0 - eslint-config-prettier: - specifier: ^9.1.0 - version: 9.1.0(eslint@8.56.0) eslint-plugin-i18next: specifier: ^6.0.3 version: 6.0.3 - eslint-plugin-import: - specifier: ^2.29.1 - version: 2.29.1(@typescript-eslint/parser@6.19.0)(eslint@8.56.0) eslint-plugin-path: specifier: ^1.2.4 version: 1.2.4(eslint@8.56.0) - eslint-plugin-react: - specifier: ^7.33.2 - version: 7.33.2(eslint@8.56.0) - eslint-plugin-react-hooks: - specifier: ^4.6.0 - version: 4.6.0(eslint@8.56.0) - eslint-plugin-simple-import-sort: - specifier: ^10.0.0 - version: 10.0.0(eslint@8.56.0) - eslint-plugin-storybook: - specifier: 
^0.6.15 - version: 0.6.15(eslint@8.56.0)(typescript@5.3.3) - eslint-plugin-unused-imports: - specifier: ^3.0.0 - version: 3.0.0(@typescript-eslint/eslint-plugin@6.19.0)(eslint@8.56.0) madge: specifier: ^6.1.0 version: 6.1.0(typescript@5.3.3) @@ -286,17 +251,17 @@ devDependencies: specifier: ^12.1.3 version: 12.1.3 openapi-typescript: - specifier: ^6.7.3 - version: 6.7.3 + specifier: ^6.7.4 + version: 6.7.4 prettier: - specifier: ^3.2.4 - version: 3.2.4 + specifier: ^3.2.5 + version: 3.2.5 rollup-plugin-visualizer: specifier: ^5.12.0 version: 5.12.0 storybook: - specifier: ^7.6.10 - version: 7.6.10 + specifier: ^7.6.17 + version: 7.6.17 ts-toolbelt: specifier: ^9.6.0 version: 9.6.0 @@ -307,23 +272,23 @@ devDependencies: specifier: ^5.3.3 version: 5.3.3 vite: - specifier: ^5.0.12 - version: 5.0.12(@types/node@20.11.5) + specifier: ^5.1.3 + version: 5.1.3(@types/node@20.11.19) vite-plugin-css-injected-by-js: - specifier: ^3.3.1 - version: 3.3.1(vite@5.0.12) + specifier: ^3.4.0 + version: 3.4.0(vite@5.1.3) vite-plugin-dts: - specifier: ^3.7.1 - version: 3.7.1(@types/node@20.11.5)(typescript@5.3.3)(vite@5.0.12) + specifier: ^3.7.2 + version: 3.7.2(@types/node@20.11.19)(typescript@5.3.3)(vite@5.1.3) vite-plugin-eslint: specifier: ^1.8.1 - version: 1.8.1(eslint@8.56.0)(vite@5.0.12) + version: 1.8.1(eslint@8.56.0)(vite@5.1.3) vite-tsconfig-paths: specifier: ^4.3.1 - version: 4.3.1(typescript@5.3.3)(vite@5.0.12) + version: 4.3.1(typescript@5.3.3)(vite@5.1.3) vitest: - specifier: ^1.2.2 - version: 1.2.2(@types/node@20.11.5) + specifier: ^1.3.1 + version: 1.3.1(@types/node@20.11.19) packages: @@ -332,8 +297,8 @@ packages: engines: {node: '>=0.10.0'} dev: true - /@adobe/css-tools@4.3.2: - resolution: {integrity: sha512-DA5a1C0gD/pLOvhv33YMrbf2FK3oUzwNl9oOJqE4XVjuEtt6XIakRcsd7eLiOSPkp1kTRQGICTA8cKra/vFbjw==} + /@adobe/css-tools@4.3.3: + resolution: {integrity: sha512-rE0Pygv0sEZ4vBWHlAgJLGDU7Pm8xoO6p3wsEceb7GYAjScrOHpEo8KK/eVkAcnSM+slAEtXjA2JpdjLp4fJQQ==} dev: true /@ampproject/remapping@2.2.1: @@ -341,7 +306,7 @@ packages: engines: {node: '>=6.0.0'} dependencies: '@jridgewell/gen-mapping': 0.3.3 - '@jridgewell/trace-mapping': 0.3.21 + '@jridgewell/trace-mapping': 0.3.22 dev: true /@ark-ui/anatomy@1.3.0(@internationalized/date@3.5.2): @@ -430,13 +395,6 @@ packages: - '@internationalized/date' dev: false - /@arthurgeron/eslint-plugin-react-usememo@2.2.3: - resolution: {integrity: sha512-YJG+8hULmhHAxztaANswpa9hWNqEOSvbZcbd6R/JQzyNlEZ49Xh97kqZGuJGZ74rrmULckEO1m3Jh5ctqrGA2A==} - dependencies: - minimatch: 9.0.3 - uuid: 9.0.1 - dev: true - /@aw-web-design/x-default-browser@1.4.126: resolution: {integrity: sha512-Xk1sIhyNC/esHGGVjL/niHLowM0csl/kFO5uawBy4IrWwy0o1G8LGt3jP6nmWGz+USxeeqbihAmp/oVZju6wug==} hasBin: true @@ -456,20 +414,20 @@ packages: engines: {node: '>=6.9.0'} dev: true - /@babel/core@7.23.7: - resolution: {integrity: sha512-+UpDgowcmqe36d4NwqvKsyPMlOLNGMsfMmQ5WGCu+siCe3t3dfe9njrzGfdN4qq+bcNUt0+Vw6haRxBOycs4dw==} + /@babel/core@7.23.9: + resolution: {integrity: sha512-5q0175NOjddqpvvzU+kDiSOAk4PfdO6FvwCWoQ6RO7rTzEe8vlo+4HVfcnAREhD4npMs0e9uZypjTwzZPCf/cw==} engines: {node: '>=6.9.0'} dependencies: '@ampproject/remapping': 2.2.1 '@babel/code-frame': 7.23.5 '@babel/generator': 7.23.6 '@babel/helper-compilation-targets': 7.23.6 - '@babel/helper-module-transforms': 7.23.3(@babel/core@7.23.7) - '@babel/helpers': 7.23.8 - '@babel/parser': 7.23.6 - '@babel/template': 7.22.15 - '@babel/traverse': 7.23.7 - '@babel/types': 7.23.6 + '@babel/helper-module-transforms': 7.23.3(@babel/core@7.23.9) + 
'@babel/helpers': 7.23.9 + '@babel/parser': 7.23.9 + '@babel/template': 7.23.9 + '@babel/traverse': 7.23.9 + '@babel/types': 7.23.9 convert-source-map: 2.0.0 debug: 4.3.4 gensync: 1.0.0-beta.2 @@ -483,9 +441,9 @@ packages: resolution: {integrity: sha512-qrSfCYxYQB5owCmGLbl8XRpX1ytXlpueOb0N0UmQwA073KZxejgQTzAmJezxvpwQD9uGtK2shHdi55QT+MbjIw==} engines: {node: '>=6.9.0'} dependencies: - '@babel/types': 7.23.6 + '@babel/types': 7.23.9 '@jridgewell/gen-mapping': 0.3.3 - '@jridgewell/trace-mapping': 0.3.21 + '@jridgewell/trace-mapping': 0.3.22 jsesc: 2.5.2 dev: true @@ -493,14 +451,14 @@ packages: resolution: {integrity: sha512-LvBTxu8bQSQkcyKOU+a1btnNFQ1dMAd0R6PyW3arXes06F6QLWLIrd681bxRPIXlrMGR3XYnW9JyML7dP3qgxg==} engines: {node: '>=6.9.0'} dependencies: - '@babel/types': 7.23.6 + '@babel/types': 7.23.9 dev: true /@babel/helper-builder-binary-assignment-operator-visitor@7.22.15: resolution: {integrity: sha512-QkBXwGgaoC2GtGZRoma6kv7Szfv06khvhFav67ZExau2RaXzy8MpHSMO2PNoP2XtmQphJQRHFfg77Bq731Yizw==} engines: {node: '>=6.9.0'} dependencies: - '@babel/types': 7.23.6 + '@babel/types': 7.23.9 dev: true /@babel/helper-compilation-targets@7.23.6: @@ -509,62 +467,47 @@ packages: dependencies: '@babel/compat-data': 7.23.5 '@babel/helper-validator-option': 7.23.5 - browserslist: 4.22.2 + browserslist: 4.23.0 lru-cache: 5.1.1 semver: 6.3.1 dev: true - /@babel/helper-create-class-features-plugin@7.23.7(@babel/core@7.23.7): - resolution: {integrity: sha512-xCoqR/8+BoNnXOY7RVSgv6X+o7pmT5q1d+gGcRlXYkI+9B31glE4jeejhKVpA04O1AtzOt7OSQ6VYKP5FcRl9g==} + /@babel/helper-create-class-features-plugin@7.23.10(@babel/core@7.23.9): + resolution: {integrity: sha512-2XpP2XhkXzgxecPNEEK8Vz8Asj9aRxt08oKOqtiZoqV2UGZ5T+EkyP9sXQ9nwMxBIG34a7jmasVqoMop7VdPUw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-annotate-as-pure': 7.22.5 '@babel/helper-environment-visitor': 7.22.20 '@babel/helper-function-name': 7.23.0 '@babel/helper-member-expression-to-functions': 7.23.0 '@babel/helper-optimise-call-expression': 7.22.5 - '@babel/helper-replace-supers': 7.22.20(@babel/core@7.23.7) + '@babel/helper-replace-supers': 7.22.20(@babel/core@7.23.9) '@babel/helper-skip-transparent-expression-wrappers': 7.22.5 '@babel/helper-split-export-declaration': 7.22.6 semver: 6.3.1 dev: true - /@babel/helper-create-regexp-features-plugin@7.22.15(@babel/core@7.23.7): + /@babel/helper-create-regexp-features-plugin@7.22.15(@babel/core@7.23.9): resolution: {integrity: sha512-29FkPLFjn4TPEa3RE7GpW+qbE8tlsu3jntNYNfcGsc49LphF1PQIiD+vMZ1z1xVOKt+93khA9tc2JBs3kBjA7w==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-annotate-as-pure': 7.22.5 regexpu-core: 5.3.2 semver: 6.3.1 dev: true - /@babel/helper-define-polyfill-provider@0.4.4(@babel/core@7.23.7): - resolution: {integrity: sha512-QcJMILQCu2jm5TFPGA3lCpJJTeEP+mqeXooG/NZbg/h5FTFi6V0+99ahlRsW8/kRLyb24LZVCCiclDedhLKcBA==} - peerDependencies: - '@babel/core': ^7.4.0 || ^8.0.0-0 <8.0.0 - dependencies: - '@babel/core': 7.23.7 - '@babel/helper-compilation-targets': 7.23.6 - '@babel/helper-plugin-utils': 7.22.5 - debug: 4.3.4 - lodash.debounce: 4.0.8 - resolve: 1.22.8 - transitivePeerDependencies: - - supports-color - dev: true - - /@babel/helper-define-polyfill-provider@0.5.0(@babel/core@7.23.7): + /@babel/helper-define-polyfill-provider@0.5.0(@babel/core@7.23.9): resolution: {integrity: 
sha512-NovQquuQLAQ5HuyjCz7WQP9MjRj7dx++yspwiyUiGl9ZyadHRSql1HZh5ogRd8W8w6YM6EQ/NTB8rgjLt5W65Q==} peerDependencies: '@babel/core': ^7.4.0 || ^8.0.0-0 <8.0.0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-compilation-targets': 7.23.6 '@babel/helper-plugin-utils': 7.22.5 debug: 4.3.4 @@ -583,37 +526,37 @@ packages: resolution: {integrity: sha512-OErEqsrxjZTJciZ4Oo+eoZqeW9UIiOcuYKRJA4ZAgV9myA+pOXhhmpfNCKjEH/auVfEYVFJ6y1Tc4r0eIApqiw==} engines: {node: '>=6.9.0'} dependencies: - '@babel/template': 7.22.15 - '@babel/types': 7.23.6 + '@babel/template': 7.23.9 + '@babel/types': 7.23.9 dev: true /@babel/helper-hoist-variables@7.22.5: resolution: {integrity: sha512-wGjk9QZVzvknA6yKIUURb8zY3grXCcOZt+/7Wcy8O2uctxhplmUPkOdlgoNhmdVee2c92JXbf1xpMtVNbfoxRw==} engines: {node: '>=6.9.0'} dependencies: - '@babel/types': 7.23.6 + '@babel/types': 7.23.9 dev: true /@babel/helper-member-expression-to-functions@7.23.0: resolution: {integrity: sha512-6gfrPwh7OuT6gZyJZvd6WbTfrqAo7vm4xCzAXOusKqq/vWdKXphTpj5klHKNmRUU6/QRGlBsyU9mAIPaWHlqJA==} engines: {node: '>=6.9.0'} dependencies: - '@babel/types': 7.23.6 + '@babel/types': 7.23.9 dev: true /@babel/helper-module-imports@7.22.15: resolution: {integrity: sha512-0pYVBnDKZO2fnSPCrgM/6WMc7eS20Fbok+0r88fp+YtWVLZrp4CkafFGIp+W0VKw4a22sgebPT99y+FDNMdP4w==} engines: {node: '>=6.9.0'} dependencies: - '@babel/types': 7.23.6 + '@babel/types': 7.23.9 - /@babel/helper-module-transforms@7.23.3(@babel/core@7.23.7): + /@babel/helper-module-transforms@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-7bBs4ED9OmswdfDzpz4MpWgSrV7FXlc3zIagvLFjS5H+Mk7Snr21vQ6QwrsoCGMfNC4e4LQPdoULEt4ykz0SRQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-environment-visitor': 7.22.20 '@babel/helper-module-imports': 7.22.15 '@babel/helper-simple-access': 7.22.5 @@ -625,7 +568,7 @@ packages: resolution: {integrity: sha512-HBwaojN0xFRx4yIvpwGqxiV2tUfl7401jlok564NgB9EHS1y6QT17FmKWm4ztqjeVdXLuC4fSvHc5ePpQjoTbw==} engines: {node: '>=6.9.0'} dependencies: - '@babel/types': 7.23.6 + '@babel/types': 7.23.9 dev: true /@babel/helper-plugin-utils@7.22.5: @@ -633,25 +576,25 @@ packages: engines: {node: '>=6.9.0'} dev: true - /@babel/helper-remap-async-to-generator@7.22.20(@babel/core@7.23.7): + /@babel/helper-remap-async-to-generator@7.22.20(@babel/core@7.23.9): resolution: {integrity: sha512-pBGyV4uBqOns+0UvhsTO8qgl8hO89PmiDYv+/COyp1aeMcmfrfruz+/nCMFiYyFF/Knn0yfrC85ZzNFjembFTw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-annotate-as-pure': 7.22.5 '@babel/helper-environment-visitor': 7.22.20 '@babel/helper-wrap-function': 7.22.20 dev: true - /@babel/helper-replace-supers@7.22.20(@babel/core@7.23.7): + /@babel/helper-replace-supers@7.22.20(@babel/core@7.23.9): resolution: {integrity: sha512-qsW0In3dbwQUbK8kejJ4R7IHVGwHJlV6lpG6UA7a9hSa2YEiAib+N1T2kr6PEeUT+Fl7najmSOS6SmAwCHK6Tw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-environment-visitor': 7.22.20 '@babel/helper-member-expression-to-functions': 7.23.0 '@babel/helper-optimise-call-expression': 7.22.5 @@ -661,21 +604,21 @@ packages: resolution: {integrity: sha512-n0H99E/K+Bika3++WNL17POvo4rKWZ7lZEp1Q+fStVbUi8nxPQEBOlTmCOxW/0JsS56SKKQ+ojAe2pHKJHN35w==} engines: {node: '>=6.9.0'} dependencies: - '@babel/types': 7.23.6 + '@babel/types': 
7.23.9 dev: true /@babel/helper-skip-transparent-expression-wrappers@7.22.5: resolution: {integrity: sha512-tK14r66JZKiC43p8Ki33yLBVJKlQDFoA8GYN67lWCDCqoL6EMMSuM9b+Iff2jHaM/RRFYl7K+iiru7hbRqNx8Q==} engines: {node: '>=6.9.0'} dependencies: - '@babel/types': 7.23.6 + '@babel/types': 7.23.9 dev: true /@babel/helper-split-export-declaration@7.22.6: resolution: {integrity: sha512-AsUnxuLhRYsisFiaJwvp1QF+I3KjD5FOxut14q/GzovUe6orHLesW2C7d754kRm53h5gqrz6sFl6sxc4BVtE/g==} engines: {node: '>=6.9.0'} dependencies: - '@babel/types': 7.23.6 + '@babel/types': 7.23.9 dev: true /@babel/helper-string-parser@7.23.4: @@ -696,17 +639,17 @@ packages: engines: {node: '>=6.9.0'} dependencies: '@babel/helper-function-name': 7.23.0 - '@babel/template': 7.22.15 - '@babel/types': 7.23.6 + '@babel/template': 7.23.9 + '@babel/types': 7.23.9 dev: true - /@babel/helpers@7.23.8: - resolution: {integrity: sha512-KDqYz4PiOWvDFrdHLPhKtCThtIcKVy6avWD2oG4GEvyQ+XDZwHD4YQd+H2vNMnq2rkdxsDkU82T+Vk8U/WXHRQ==} + /@babel/helpers@7.23.9: + resolution: {integrity: sha512-87ICKgU5t5SzOT7sBMfCOZQ2rHjRU+Pcb9BoILMYz600W6DkVRLFBPwQ18gwUVvggqXivaUakpnxWQGbpywbBQ==} engines: {node: '>=6.9.0'} dependencies: - '@babel/template': 7.22.15 - '@babel/traverse': 7.23.7 - '@babel/types': 7.23.6 + '@babel/template': 7.23.9 + '@babel/traverse': 7.23.9 + '@babel/types': 7.23.9 transitivePeerDependencies: - supports-color dev: true @@ -719,966 +662,966 @@ packages: chalk: 2.4.2 js-tokens: 4.0.0 - /@babel/parser@7.23.6: - resolution: {integrity: sha512-Z2uID7YJ7oNvAI20O9X0bblw7Qqs8Q2hFy0R9tAfnfLkp5MW0UH9eUvnDSnFwKZ0AvgS1ucqR4KzvVHgnke1VQ==} + /@babel/parser@7.23.9: + resolution: {integrity: sha512-9tcKgqKbs3xGJ+NtKF2ndOBBLVwPjl1SHxPQkd36r3Dlirw3xWUeGaTbqr7uGZcTaxkVNwc+03SVP7aCdWrTlA==} engines: {node: '>=6.0.0'} hasBin: true dependencies: - '@babel/types': 7.23.6 + '@babel/types': 7.23.9 dev: true - /@babel/plugin-bugfix-safari-id-destructuring-collision-in-function-expression@7.23.3(@babel/core@7.23.7): + /@babel/plugin-bugfix-safari-id-destructuring-collision-in-function-expression@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-iRkKcCqb7iGnq9+3G6rZ+Ciz5VywC4XNRHe57lKM+jOeYAoR0lVqdeeDRfh0tQcTfw/+vBhHn926FmQhLtlFLQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-bugfix-v8-spread-parameters-in-optional-chaining@7.23.3(@babel/core@7.23.7): + /@babel/plugin-bugfix-v8-spread-parameters-in-optional-chaining@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-WwlxbfMNdVEpQjZmK5mhm7oSwD3dS6eU+Iwsi4Knl9wAletWem7kaRsGOG+8UEbRyqxY4SS5zvtfXwX+jMxUwQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.13.0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 '@babel/helper-skip-transparent-expression-wrappers': 7.22.5 - '@babel/plugin-transform-optional-chaining': 7.23.4(@babel/core@7.23.7) + '@babel/plugin-transform-optional-chaining': 7.23.4(@babel/core@7.23.9) dev: true - /@babel/plugin-bugfix-v8-static-class-fields-redefine-readonly@7.23.7(@babel/core@7.23.7): + /@babel/plugin-bugfix-v8-static-class-fields-redefine-readonly@7.23.7(@babel/core@7.23.9): resolution: {integrity: sha512-LlRT7HgaifEpQA1ZgLVOIJZZFVPWN5iReq/7/JixwBtwcoeVGDBD53ZV28rrsLYOZs1Y/EHhA8N/Z6aazHR8cw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-environment-visitor': 
7.22.20 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-proposal-private-property-in-object@7.21.0-placeholder-for-preset-env.2(@babel/core@7.23.7): + /@babel/plugin-proposal-private-property-in-object@7.21.0-placeholder-for-preset-env.2(@babel/core@7.23.9): resolution: {integrity: sha512-SOSkfJDddaM7mak6cPEpswyTRnuRltl429hMraQEglW+OkovnCzsiszTmsrlY//qLFjCpQDFRvjdm2wA5pPm9w==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 dev: true - /@babel/plugin-syntax-async-generators@7.8.4(@babel/core@7.23.7): + /@babel/plugin-syntax-async-generators@7.8.4(@babel/core@7.23.9): resolution: {integrity: sha512-tycmZxkGfZaxhMRbXlPXuVFpdWlXpir2W4AMhSJgRKzk/eDlIXOhb2LHWoLpDF7TEHylV5zNhykX6KAgHJmTNw==} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-class-properties@7.12.13(@babel/core@7.23.7): + /@babel/plugin-syntax-class-properties@7.12.13(@babel/core@7.23.9): resolution: {integrity: sha512-fm4idjKla0YahUNgFNLCB0qySdsoPiZP3iQE3rky0mBUtMZ23yDJ9SJdg6dXTSDnulOVqiF3Hgr9nbXvXTQZYA==} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-class-static-block@7.14.5(@babel/core@7.23.7): + /@babel/plugin-syntax-class-static-block@7.14.5(@babel/core@7.23.9): resolution: {integrity: sha512-b+YyPmr6ldyNnM6sqYeMWE+bgJcJpO6yS4QD7ymxgH34GBPNDM/THBh8iunyvKIZztiwLH4CJZ0RxTk9emgpjw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-dynamic-import@7.8.3(@babel/core@7.23.7): + /@babel/plugin-syntax-dynamic-import@7.8.3(@babel/core@7.23.9): resolution: {integrity: sha512-5gdGbFon+PszYzqs83S3E5mpi7/y/8M9eC90MRTZfduQOYW76ig6SOSPNe41IG5LoP3FGBn2N0RjVDSQiS94kQ==} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-export-namespace-from@7.8.3(@babel/core@7.23.7): + /@babel/plugin-syntax-export-namespace-from@7.8.3(@babel/core@7.23.9): resolution: {integrity: sha512-MXf5laXo6c1IbEbegDmzGPwGNTsHZmEy6QGznu5Sh2UCWvueywb2ee+CCE4zQiZstxU9BMoQO9i6zUFSY0Kj0Q==} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-flow@7.23.3(@babel/core@7.23.7): + /@babel/plugin-syntax-flow@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-YZiAIpkJAwQXBJLIQbRFayR5c+gJ35Vcz3bg954k7cd73zqjvhacJuL9RbrzPz8qPmZdgqP6EUKwy0PCNhaaPA==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-import-assertions@7.23.3(@babel/core@7.23.7): + /@babel/plugin-syntax-import-assertions@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-lPgDSU+SJLK3xmFDTV2ZRQAiM7UuUjGidwBywFavObCiZc1BeAAcMtHJKUya92hPHO+at63JJPLygilZard8jw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-import-attributes@7.23.3(@babel/core@7.23.7): + 
/@babel/plugin-syntax-import-attributes@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-pawnE0P9g10xgoP7yKr6CK63K2FMsTE+FZidZO/1PwRdzmAPVs+HS1mAURUsgaoxammTJvULUdIkEK0gOcU2tA==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-import-meta@7.10.4(@babel/core@7.23.7): + /@babel/plugin-syntax-import-meta@7.10.4(@babel/core@7.23.9): resolution: {integrity: sha512-Yqfm+XDx0+Prh3VSeEQCPU81yC+JWZ2pDPFSS4ZdpfZhp4MkFMaDC1UqseovEKwSUpnIL7+vK+Clp7bfh0iD7g==} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-json-strings@7.8.3(@babel/core@7.23.7): + /@babel/plugin-syntax-json-strings@7.8.3(@babel/core@7.23.9): resolution: {integrity: sha512-lY6kdGpWHvjoe2vk4WrAapEuBR69EMxZl+RoGRhrFGNYVK8mOPAW8VfbT/ZgrFbXlDNiiaxQnAtgVCZ6jv30EA==} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-jsx@7.23.3(@babel/core@7.23.7): + /@babel/plugin-syntax-jsx@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-EB2MELswq55OHUoRZLGg/zC7QWUKfNLpE57m/S2yr1uEneIgsTgrSzXP3NXEsMkVn76OlaVVnzN+ugObuYGwhg==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-logical-assignment-operators@7.10.4(@babel/core@7.23.7): + /@babel/plugin-syntax-logical-assignment-operators@7.10.4(@babel/core@7.23.9): resolution: {integrity: sha512-d8waShlpFDinQ5MtvGU9xDAOzKH47+FFoney2baFIoMr952hKOLp1HR7VszoZvOsV/4+RRszNY7D17ba0te0ig==} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-nullish-coalescing-operator@7.8.3(@babel/core@7.23.7): + /@babel/plugin-syntax-nullish-coalescing-operator@7.8.3(@babel/core@7.23.9): resolution: {integrity: sha512-aSff4zPII1u2QD7y+F8oDsz19ew4IGEJg9SVW+bqwpwtfFleiQDMdzA/R+UlWDzfnHFCxxleFT0PMIrR36XLNQ==} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-numeric-separator@7.10.4(@babel/core@7.23.7): + /@babel/plugin-syntax-numeric-separator@7.10.4(@babel/core@7.23.9): resolution: {integrity: sha512-9H6YdfkcK/uOnY/K7/aA2xpzaAgkQn37yzWUMRK7OaPOqOpGS1+n0H5hxT9AUw9EsSjPW8SVyMJwYRtWs3X3ug==} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-object-rest-spread@7.8.3(@babel/core@7.23.7): + /@babel/plugin-syntax-object-rest-spread@7.8.3(@babel/core@7.23.9): resolution: {integrity: sha512-XoqMijGZb9y3y2XskN+P1wUGiVwWZ5JmoDRwx5+3GmEplNyVM2s2Dg8ILFQm8rWM48orGy5YpI5Bl8U1y7ydlA==} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-optional-catch-binding@7.8.3(@babel/core@7.23.7): + /@babel/plugin-syntax-optional-catch-binding@7.8.3(@babel/core@7.23.9): resolution: {integrity: sha512-6VPD0Pc1lpTqw0aKoeRTMiB+kWhAoT24PA+ksWSBrFtl5SIRVpZlwN3NNPQjehA2E/91FV3RjLWoVTglWcSV3Q==} peerDependencies: 
'@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-optional-chaining@7.8.3(@babel/core@7.23.7): + /@babel/plugin-syntax-optional-chaining@7.8.3(@babel/core@7.23.9): resolution: {integrity: sha512-KoK9ErH1MBlCPxV0VANkXW2/dw4vlbGDrFgz8bmUsBGYkFRcbRwMh6cIJubdPrkxRwuGdtCk0v/wPTKbQgBjkg==} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-private-property-in-object@7.14.5(@babel/core@7.23.7): + /@babel/plugin-syntax-private-property-in-object@7.14.5(@babel/core@7.23.9): resolution: {integrity: sha512-0wVnp9dxJ72ZUJDV27ZfbSj6iHLoytYZmh3rFcxNnvsJF3ktkzLDZPy/mA17HGsaQT3/DQsWYX1f1QGWkCoVUg==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-top-level-await@7.14.5(@babel/core@7.23.7): + /@babel/plugin-syntax-top-level-await@7.14.5(@babel/core@7.23.9): resolution: {integrity: sha512-hx++upLv5U1rgYfwe1xBQUhRmU41NEvpUvrp8jkrSCdvGSnM5/qdRMtylJ6PG5OFkBaHkbTAKTnd3/YyESRHFw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-typescript@7.23.3(@babel/core@7.23.7): + /@babel/plugin-syntax-typescript@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-9EiNjVJOMwCO+43TqoTrgQ8jMwcAd0sWyXi9RPfIsLTj4R2MADDDQXELhffaUx/uJv2AYcxBgPwH6j4TIA4ytQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-syntax-unicode-sets-regex@7.18.6(@babel/core@7.23.7): + /@babel/plugin-syntax-unicode-sets-regex@7.18.6(@babel/core@7.23.9): resolution: {integrity: sha512-727YkEAPwSIQTv5im8QHz3upqp92JTWhidIC81Tdx4VJYIte/VndKf1qKrfnnhPLiPghStWfvC/iFaMCQu7Nqg==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0 dependencies: - '@babel/core': 7.23.7 - '@babel/helper-create-regexp-features-plugin': 7.22.15(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-create-regexp-features-plugin': 7.22.15(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-arrow-functions@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-arrow-functions@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-NzQcQrzaQPkaEwoTm4Mhyl8jI1huEL/WWIEvudjTCMJ9aBZNpsJbMASx7EQECtQQPS/DcnFpo0FIh3LvEO9cxQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-async-generator-functions@7.23.7(@babel/core@7.23.7): - resolution: {integrity: sha512-PdxEpL71bJp1byMG0va5gwQcXHxuEYC/BgI/e88mGTtohbZN28O5Yit0Plkkm/dBzCF/BxmbNcses1RH1T+urA==} + /@babel/plugin-transform-async-generator-functions@7.23.9(@babel/core@7.23.9): + resolution: {integrity: sha512-8Q3veQEDGe14dTYuwagbRtwxQDnytyg1JFu4/HwEMETeofocrB0U0ejBJIXoeG/t2oXZ8kzCyI0ZZfbT80VFNQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-environment-visitor': 7.22.20 '@babel/helper-plugin-utils': 7.22.5 - 
'@babel/helper-remap-async-to-generator': 7.22.20(@babel/core@7.23.7) - '@babel/plugin-syntax-async-generators': 7.8.4(@babel/core@7.23.7) + '@babel/helper-remap-async-to-generator': 7.22.20(@babel/core@7.23.9) + '@babel/plugin-syntax-async-generators': 7.8.4(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-async-to-generator@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-async-to-generator@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-A7LFsKi4U4fomjqXJlZg/u0ft/n8/7n7lpffUP/ZULx/DtV9SGlNKZolHH6PE8Xl1ngCc0M11OaeZptXVkfKSw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-module-imports': 7.22.15 '@babel/helper-plugin-utils': 7.22.5 - '@babel/helper-remap-async-to-generator': 7.22.20(@babel/core@7.23.7) + '@babel/helper-remap-async-to-generator': 7.22.20(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-block-scoped-functions@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-block-scoped-functions@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-vI+0sIaPIO6CNuM9Kk5VmXcMVRiOpDh7w2zZt9GXzmE/9KD70CUEVhvPR/etAeNK/FAEkhxQtXOzVF3EuRL41A==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-block-scoping@7.23.4(@babel/core@7.23.7): + /@babel/plugin-transform-block-scoping@7.23.4(@babel/core@7.23.9): resolution: {integrity: sha512-0QqbP6B6HOh7/8iNR4CQU2Th/bbRtBp4KS9vcaZd1fZ0wSh5Fyssg0UCIHwxh+ka+pNDREbVLQnHCMHKZfPwfw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-class-properties@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-class-properties@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-uM+AN8yCIjDPccsKGlw271xjJtGii+xQIF/uMPS8H15L12jZTsLfF4o5vNO7d/oUguOyfdikHGc/yi9ge4SGIg==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 - '@babel/helper-create-class-features-plugin': 7.23.7(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-create-class-features-plugin': 7.23.10(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-class-static-block@7.23.4(@babel/core@7.23.7): + /@babel/plugin-transform-class-static-block@7.23.4(@babel/core@7.23.9): resolution: {integrity: sha512-nsWu/1M+ggti1SOALj3hfx5FXzAY06fwPJsUZD4/A5e1bWi46VUIWtD+kOX6/IdhXGsXBWllLFDSnqSCdUNydQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.12.0 dependencies: - '@babel/core': 7.23.7 - '@babel/helper-create-class-features-plugin': 7.23.7(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-create-class-features-plugin': 7.23.10(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 - '@babel/plugin-syntax-class-static-block': 7.14.5(@babel/core@7.23.7) + '@babel/plugin-syntax-class-static-block': 7.14.5(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-classes@7.23.8(@babel/core@7.23.7): + /@babel/plugin-transform-classes@7.23.8(@babel/core@7.23.9): resolution: {integrity: sha512-yAYslGsY1bX6Knmg46RjiCiNSwJKv2IUC8qOdYKqMMr0491SXFhcHqOdRDeCRohOOIzwN/90C6mQ9qAKgrP7dg==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 
'@babel/helper-annotate-as-pure': 7.22.5 '@babel/helper-compilation-targets': 7.23.6 '@babel/helper-environment-visitor': 7.22.20 '@babel/helper-function-name': 7.23.0 '@babel/helper-plugin-utils': 7.22.5 - '@babel/helper-replace-supers': 7.22.20(@babel/core@7.23.7) + '@babel/helper-replace-supers': 7.22.20(@babel/core@7.23.9) '@babel/helper-split-export-declaration': 7.22.6 globals: 11.12.0 dev: true - /@babel/plugin-transform-computed-properties@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-computed-properties@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-dTj83UVTLw/+nbiHqQSFdwO9CbTtwq1DsDqm3CUEtDrZNET5rT5E6bIdTlOftDTDLMYxvxHNEYO4B9SLl8SLZw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 - '@babel/template': 7.22.15 + '@babel/template': 7.23.9 dev: true - /@babel/plugin-transform-destructuring@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-destructuring@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-n225npDqjDIr967cMScVKHXJs7rout1q+tt50inyBCPkyZ8KxeI6d+GIbSBTT/w/9WdlWDOej3V9HE5Lgk57gw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-dotall-regex@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-dotall-regex@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-vgnFYDHAKzFaTVp+mneDsIEbnJ2Np/9ng9iviHw3P/KVcgONxpNULEW/51Z/BaFojG2GI2GwwXck5uV1+1NOYQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 - '@babel/helper-create-regexp-features-plugin': 7.22.15(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-create-regexp-features-plugin': 7.22.15(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-duplicate-keys@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-duplicate-keys@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-RrqQ+BQmU3Oyav3J+7/myfvRCq7Tbz+kKLLshUmMwNlDHExbGL7ARhajvoBJEvc+fCguPPu887N+3RRXBVKZUA==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-dynamic-import@7.23.4(@babel/core@7.23.7): + /@babel/plugin-transform-dynamic-import@7.23.4(@babel/core@7.23.9): resolution: {integrity: sha512-V6jIbLhdJK86MaLh4Jpghi8ho5fGzt3imHOBu/x0jlBaPYqDoWz4RDXjmMOfnh+JWNaQleEAByZLV0QzBT4YQQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 - '@babel/plugin-syntax-dynamic-import': 7.8.3(@babel/core@7.23.7) + '@babel/plugin-syntax-dynamic-import': 7.8.3(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-exponentiation-operator@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-exponentiation-operator@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-5fhCsl1odX96u7ILKHBj4/Y8vipoqwsJMh4csSA8qFfxrZDEA4Ssku2DyNvMJSmZNOEBT750LfFPbtrnTP90BQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-builder-binary-assignment-operator-visitor': 7.22.15 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-export-namespace-from@7.23.4(@babel/core@7.23.7): + 
/@babel/plugin-transform-export-namespace-from@7.23.4(@babel/core@7.23.9): resolution: {integrity: sha512-GzuSBcKkx62dGzZI1WVgTWvkkz84FZO5TC5T8dl/Tht/rAla6Dg/Mz9Yhypg+ezVACf/rgDuQt3kbWEv7LdUDQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 - '@babel/plugin-syntax-export-namespace-from': 7.8.3(@babel/core@7.23.7) + '@babel/plugin-syntax-export-namespace-from': 7.8.3(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-flow-strip-types@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-flow-strip-types@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-26/pQTf9nQSNVJCrLB1IkHUKyPxR+lMrH2QDPG89+Znu9rAMbtrybdbWeE9bb7gzjmE5iXHEY+e0HUwM6Co93Q==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 - '@babel/plugin-syntax-flow': 7.23.3(@babel/core@7.23.7) + '@babel/plugin-syntax-flow': 7.23.3(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-for-of@7.23.6(@babel/core@7.23.7): + /@babel/plugin-transform-for-of@7.23.6(@babel/core@7.23.9): resolution: {integrity: sha512-aYH4ytZ0qSuBbpfhuofbg/e96oQ7U2w1Aw/UQmKT+1l39uEhUPoFS3fHevDc1G0OvewyDudfMKY1OulczHzWIw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 '@babel/helper-skip-transparent-expression-wrappers': 7.22.5 dev: true - /@babel/plugin-transform-function-name@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-function-name@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-I1QXp1LxIvt8yLaib49dRW5Okt7Q4oaxao6tFVKS/anCdEOMtYwWVKoiOA1p34GOWIZjUK0E+zCp7+l1pfQyiw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-compilation-targets': 7.23.6 '@babel/helper-function-name': 7.23.0 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-json-strings@7.23.4(@babel/core@7.23.7): + /@babel/plugin-transform-json-strings@7.23.4(@babel/core@7.23.9): resolution: {integrity: sha512-81nTOqM1dMwZ/aRXQ59zVubN9wHGqk6UtqRK+/q+ciXmRy8fSolhGVvG09HHRGo4l6fr/c4ZhXUQH0uFW7PZbg==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 - '@babel/plugin-syntax-json-strings': 7.8.3(@babel/core@7.23.7) + '@babel/plugin-syntax-json-strings': 7.8.3(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-literals@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-literals@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-wZ0PIXRxnwZvl9AYpqNUxpZ5BiTGrYt7kueGQ+N5FiQ7RCOD4cm8iShd6S6ggfVIWaJf2EMk8eRzAh52RfP4rQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-logical-assignment-operators@7.23.4(@babel/core@7.23.7): + /@babel/plugin-transform-logical-assignment-operators@7.23.4(@babel/core@7.23.9): resolution: {integrity: sha512-Mc/ALf1rmZTP4JKKEhUwiORU+vcfarFVLfcFiolKUo6sewoxSEgl36ak5t+4WamRsNr6nzjZXQjM35WsU+9vbg==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 - 
'@babel/plugin-syntax-logical-assignment-operators': 7.10.4(@babel/core@7.23.7) + '@babel/plugin-syntax-logical-assignment-operators': 7.10.4(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-member-expression-literals@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-member-expression-literals@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-sC3LdDBDi5x96LA+Ytekz2ZPk8i/Ck+DEuDbRAll5rknJ5XRTSaPKEYwomLcs1AA8wg9b3KjIQRsnApj+q51Ag==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-modules-amd@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-modules-amd@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-vJYQGxeKM4t8hYCKVBlZX/gtIY2I7mRGFNcm85sgXGMTBcoV3QdVtdpbcWEbzbfUIUZKwvgFT82mRvaQIebZzw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 - '@babel/helper-module-transforms': 7.23.3(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-module-transforms': 7.23.3(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-modules-commonjs@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-modules-commonjs@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-aVS0F65LKsdNOtcz6FRCpE4OgsP2OFnW46qNxNIX9h3wuzaNcSQsJysuMwqSibC98HPrf2vCgtxKNwS0DAlgcA==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 - '@babel/helper-module-transforms': 7.23.3(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-module-transforms': 7.23.3(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 '@babel/helper-simple-access': 7.22.5 dev: true - /@babel/plugin-transform-modules-systemjs@7.23.3(@babel/core@7.23.7): - resolution: {integrity: sha512-ZxyKGTkF9xT9YJuKQRo19ewf3pXpopuYQd8cDXqNzc3mUNbOME0RKMoZxviQk74hwzfQsEe66dE92MaZbdHKNQ==} + /@babel/plugin-transform-modules-systemjs@7.23.9(@babel/core@7.23.9): + resolution: {integrity: sha512-KDlPRM6sLo4o1FkiSlXoAa8edLXFsKKIda779fbLrvmeuc3itnjCtaO6RrtoaANsIJANj+Vk1zqbZIMhkCAHVw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-hoist-variables': 7.22.5 - '@babel/helper-module-transforms': 7.23.3(@babel/core@7.23.7) + '@babel/helper-module-transforms': 7.23.3(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 '@babel/helper-validator-identifier': 7.22.20 dev: true - /@babel/plugin-transform-modules-umd@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-modules-umd@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-zHsy9iXX2nIsCBFPud3jKn1IRPWg3Ing1qOZgeKV39m1ZgIdpJqvlWVeiHBZC6ITRG0MfskhYe9cLgntfSFPIg==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 - '@babel/helper-module-transforms': 7.23.3(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-module-transforms': 7.23.3(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-named-capturing-groups-regex@7.22.5(@babel/core@7.23.7): + /@babel/plugin-transform-named-capturing-groups-regex@7.22.5(@babel/core@7.23.9): resolution: {integrity: sha512-YgLLKmS3aUBhHaxp5hi1WJTgOUb/NCuDHzGT9z9WTt3YG+CPRhJs6nprbStx6DnWM4dh6gt7SU3sZodbZ08adQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0 dependencies: - '@babel/core': 7.23.7 - 
'@babel/helper-create-regexp-features-plugin': 7.22.15(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-create-regexp-features-plugin': 7.22.15(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-new-target@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-new-target@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-YJ3xKqtJMAT5/TIZnpAR3I+K+WaDowYbN3xyxI8zxx/Gsypwf9B9h0VB+1Nh6ACAAPRS5NSRje0uVv5i79HYGQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-nullish-coalescing-operator@7.23.4(@babel/core@7.23.7): + /@babel/plugin-transform-nullish-coalescing-operator@7.23.4(@babel/core@7.23.9): resolution: {integrity: sha512-jHE9EVVqHKAQx+VePv5LLGHjmHSJR76vawFPTdlxR/LVJPfOEGxREQwQfjuZEOPTwG92X3LINSh3M40Rv4zpVA==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 - '@babel/plugin-syntax-nullish-coalescing-operator': 7.8.3(@babel/core@7.23.7) + '@babel/plugin-syntax-nullish-coalescing-operator': 7.8.3(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-numeric-separator@7.23.4(@babel/core@7.23.7): + /@babel/plugin-transform-numeric-separator@7.23.4(@babel/core@7.23.9): resolution: {integrity: sha512-mps6auzgwjRrwKEZA05cOwuDc9FAzoyFS4ZsG/8F43bTLf/TgkJg7QXOrPO1JO599iA3qgK9MXdMGOEC8O1h6Q==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 - '@babel/plugin-syntax-numeric-separator': 7.10.4(@babel/core@7.23.7) + '@babel/plugin-syntax-numeric-separator': 7.10.4(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-object-rest-spread@7.23.4(@babel/core@7.23.7): + /@babel/plugin-transform-object-rest-spread@7.23.4(@babel/core@7.23.9): resolution: {integrity: sha512-9x9K1YyeQVw0iOXJlIzwm8ltobIIv7j2iLyP2jIhEbqPRQ7ScNgwQufU2I0Gq11VjyG4gI4yMXt2VFags+1N3g==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: '@babel/compat-data': 7.23.5 - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-compilation-targets': 7.23.6 '@babel/helper-plugin-utils': 7.22.5 - '@babel/plugin-syntax-object-rest-spread': 7.8.3(@babel/core@7.23.7) - '@babel/plugin-transform-parameters': 7.23.3(@babel/core@7.23.7) + '@babel/plugin-syntax-object-rest-spread': 7.8.3(@babel/core@7.23.9) + '@babel/plugin-transform-parameters': 7.23.3(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-object-super@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-object-super@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-BwQ8q0x2JG+3lxCVFohg+KbQM7plfpBwThdW9A6TMtWwLsbDA01Ek2Zb/AgDN39BiZsExm4qrXxjk+P1/fzGrA==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 - '@babel/helper-replace-supers': 7.22.20(@babel/core@7.23.7) + '@babel/helper-replace-supers': 7.22.20(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-optional-catch-binding@7.23.4(@babel/core@7.23.7): + /@babel/plugin-transform-optional-catch-binding@7.23.4(@babel/core@7.23.9): resolution: {integrity: sha512-XIq8t0rJPHf6Wvmbn9nFxU6ao4c7WhghTR5WyV8SrJfUFzyxhCm4nhC+iAp3HFhbAKLfYpgzhJ6t4XCtVwqO5A==} engines: {node: '>=6.9.0'} peerDependencies: 
'@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 - '@babel/plugin-syntax-optional-catch-binding': 7.8.3(@babel/core@7.23.7) + '@babel/plugin-syntax-optional-catch-binding': 7.8.3(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-optional-chaining@7.23.4(@babel/core@7.23.7): + /@babel/plugin-transform-optional-chaining@7.23.4(@babel/core@7.23.9): resolution: {integrity: sha512-ZU8y5zWOfjM5vZ+asjgAPwDaBjJzgufjES89Rs4Lpq63O300R/kOz30WCLo6BxxX6QVEilwSlpClnG5cZaikTA==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 '@babel/helper-skip-transparent-expression-wrappers': 7.22.5 - '@babel/plugin-syntax-optional-chaining': 7.8.3(@babel/core@7.23.7) + '@babel/plugin-syntax-optional-chaining': 7.8.3(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-parameters@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-parameters@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-09lMt6UsUb3/34BbECKVbVwrT9bO6lILWln237z7sLaWnMsTi7Yc9fhX5DLpkJzAGfaReXI22wP41SZmnAA3Vw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-private-methods@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-private-methods@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-UzqRcRtWsDMTLrRWFvUBDwmw06tCQH9Rl1uAjfh6ijMSmGYQ+fpdB+cnqRC8EMh5tuuxSv0/TejGL+7vyj+50g==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 - '@babel/helper-create-class-features-plugin': 7.23.7(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-create-class-features-plugin': 7.23.10(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-private-property-in-object@7.23.4(@babel/core@7.23.7): + /@babel/plugin-transform-private-property-in-object@7.23.4(@babel/core@7.23.9): resolution: {integrity: sha512-9G3K1YqTq3F4Vt88Djx1UZ79PDyj+yKRnUy7cZGSMe+a7jkwD259uKKuUzQlPkGam7R+8RJwh5z4xO27fA1o2A==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-annotate-as-pure': 7.22.5 - '@babel/helper-create-class-features-plugin': 7.23.7(@babel/core@7.23.7) + '@babel/helper-create-class-features-plugin': 7.23.10(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 - '@babel/plugin-syntax-private-property-in-object': 7.14.5(@babel/core@7.23.7) + '@babel/plugin-syntax-private-property-in-object': 7.14.5(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-property-literals@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-property-literals@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-jR3Jn3y7cZp4oEWPFAlRsSWjxKe4PZILGBSd4nis1TsC5qeSpb+nrtihJuDhNI7QHiVbUaiXa0X2RZY3/TI6Nw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-react-jsx-self@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-react-jsx-self@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-qXRvbeKDSfwnlJnanVRp0SfuWE5DQhwQr5xtLBzp56Wabyo+4CMosF6Kfp+eOD/4FYpql64XVJ2W0pVLlJZxOQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 
dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-react-jsx-source@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-react-jsx-source@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-91RS0MDnAWDNvGC6Wio5XYkyWI39FMFO+JK9+4AlgaTH+yWwVTsw7/sn6LK0lH7c5F+TFkpv/3LfCJ1Ydwof/g==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-regenerator@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-regenerator@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-KP+75h0KghBMcVpuKisx3XTu9Ncut8Q8TuvGO4IhY+9D5DFEckQefOuIsB/gQ2tG71lCke4NMrtIPS8pOj18BQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 regenerator-transform: 0.15.2 dev: true - /@babel/plugin-transform-reserved-words@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-reserved-words@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-QnNTazY54YqgGxwIexMZva9gqbPa15t/x9VS+0fsEFWplwVpXYZivtgl43Z1vMpc1bdPP2PP8siFeVcnFvA3Cg==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-shorthand-properties@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-shorthand-properties@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-ED2fgqZLmexWiN+YNFX26fx4gh5qHDhn1O2gvEhreLW2iI63Sqm4llRLCXALKrCnbN4Jy0VcMQZl/SAzqug/jg==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-spread@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-spread@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-VvfVYlrlBVu+77xVTOAoxQ6mZbnIq5FM0aGBSFEcIh03qHf+zNqA4DC/3XMUozTg7bZV3e3mZQ0i13VB6v5yUg==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 '@babel/helper-skip-transparent-expression-wrappers': 7.22.5 dev: true - /@babel/plugin-transform-sticky-regex@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-sticky-regex@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-HZOyN9g+rtvnOU3Yh7kSxXrKbzgrm5X4GncPY1QOquu7epga5MxKHVpYu2hvQnry/H+JjckSYRb93iNfsioAGg==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-template-literals@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-template-literals@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-Flok06AYNp7GV2oJPZZcP9vZdszev6vPBkHLwxwSpaIqx75wn6mUd3UFWsSsA0l8nXAKkyCmL/sR02m8RYGeHg==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-typeof-symbol@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-typeof-symbol@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-4t15ViVnaFdrPC74be1gXBSMzXk3B4Us9lP7uLRQHTFpV5Dvt33pn+2MyyNxmN3VTTm3oTrZVMUmuw3oBnQ2oQ==} engines: {node: '>=6.9.0'} 
peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-typescript@7.23.6(@babel/core@7.23.7): + /@babel/plugin-transform-typescript@7.23.6(@babel/core@7.23.9): resolution: {integrity: sha512-6cBG5mBvUu4VUD04OHKnYzbuHNP8huDsD3EDqqpIpsswTDoqHCjLoHb6+QgsV1WsT2nipRqCPgxD3LXnEO7XfA==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-annotate-as-pure': 7.22.5 - '@babel/helper-create-class-features-plugin': 7.23.7(@babel/core@7.23.7) + '@babel/helper-create-class-features-plugin': 7.23.10(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 - '@babel/plugin-syntax-typescript': 7.23.3(@babel/core@7.23.7) + '@babel/plugin-syntax-typescript': 7.23.3(@babel/core@7.23.9) dev: true - /@babel/plugin-transform-unicode-escapes@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-unicode-escapes@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-OMCUx/bU6ChE3r4+ZdylEqAjaQgHAgipgW8nsCfu5pGqDcFytVd91AwRvUJSBZDz0exPGgnjoqhgRYLRjFZc9Q==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-unicode-property-regex@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-unicode-property-regex@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-KcLIm+pDZkWZQAFJ9pdfmh89EwVfmNovFBcXko8szpBeF8z68kWIPeKlmSOkT9BXJxs2C0uk+5LxoxIv62MROA==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 - '@babel/helper-create-regexp-features-plugin': 7.22.15(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-create-regexp-features-plugin': 7.22.15(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-unicode-regex@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-unicode-regex@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-wMHpNA4x2cIA32b/ci3AfwNgheiva2W0WUKWTK7vBHBhDKfPsc5cFGNWm69WBqpwd86u1qwZ9PWevKqm1A3yAw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 - '@babel/helper-create-regexp-features-plugin': 7.22.15(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-create-regexp-features-plugin': 7.22.15(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/plugin-transform-unicode-sets-regex@7.23.3(@babel/core@7.23.7): + /@babel/plugin-transform-unicode-sets-regex@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-W7lliA/v9bNR83Qc3q1ip9CQMZ09CcHDbHfbLRDNuAhn1Mvkr1ZNF7hPmztMQvtTGVLJ9m8IZqWsTkXOml8dbw==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0 dependencies: - '@babel/core': 7.23.7 - '@babel/helper-create-regexp-features-plugin': 7.22.15(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-create-regexp-features-plugin': 7.22.15(@babel/core@7.23.9) '@babel/helper-plugin-utils': 7.22.5 dev: true - /@babel/preset-env@7.23.8(@babel/core@7.23.7): - resolution: {integrity: sha512-lFlpmkApLkEP6woIKprO6DO60RImpatTQKtz4sUcDjVcK8M8mQ4sZsuxaTMNOZf0sqAq/ReYW1ZBHnOQwKpLWA==} + /@babel/preset-env@7.23.9(@babel/core@7.23.9): + resolution: {integrity: sha512-3kBGTNBBk9DQiPoXYS0g0BYlwTQYUTifqgKTjxUwEUkduRT2QOa0FPGBJ+NROQhGyYO5BuTJwGvBnqKDykac6A==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 
dependencies: '@babel/compat-data': 7.23.5 - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-compilation-targets': 7.23.6 '@babel/helper-plugin-utils': 7.22.5 '@babel/helper-validator-option': 7.23.5 - '@babel/plugin-bugfix-safari-id-destructuring-collision-in-function-expression': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-bugfix-v8-spread-parameters-in-optional-chaining': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-bugfix-v8-static-class-fields-redefine-readonly': 7.23.7(@babel/core@7.23.7) - '@babel/plugin-proposal-private-property-in-object': 7.21.0-placeholder-for-preset-env.2(@babel/core@7.23.7) - '@babel/plugin-syntax-async-generators': 7.8.4(@babel/core@7.23.7) - '@babel/plugin-syntax-class-properties': 7.12.13(@babel/core@7.23.7) - '@babel/plugin-syntax-class-static-block': 7.14.5(@babel/core@7.23.7) - '@babel/plugin-syntax-dynamic-import': 7.8.3(@babel/core@7.23.7) - '@babel/plugin-syntax-export-namespace-from': 7.8.3(@babel/core@7.23.7) - '@babel/plugin-syntax-import-assertions': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-syntax-import-attributes': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-syntax-import-meta': 7.10.4(@babel/core@7.23.7) - '@babel/plugin-syntax-json-strings': 7.8.3(@babel/core@7.23.7) - '@babel/plugin-syntax-logical-assignment-operators': 7.10.4(@babel/core@7.23.7) - '@babel/plugin-syntax-nullish-coalescing-operator': 7.8.3(@babel/core@7.23.7) - '@babel/plugin-syntax-numeric-separator': 7.10.4(@babel/core@7.23.7) - '@babel/plugin-syntax-object-rest-spread': 7.8.3(@babel/core@7.23.7) - '@babel/plugin-syntax-optional-catch-binding': 7.8.3(@babel/core@7.23.7) - '@babel/plugin-syntax-optional-chaining': 7.8.3(@babel/core@7.23.7) - '@babel/plugin-syntax-private-property-in-object': 7.14.5(@babel/core@7.23.7) - '@babel/plugin-syntax-top-level-await': 7.14.5(@babel/core@7.23.7) - '@babel/plugin-syntax-unicode-sets-regex': 7.18.6(@babel/core@7.23.7) - '@babel/plugin-transform-arrow-functions': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-async-generator-functions': 7.23.7(@babel/core@7.23.7) - '@babel/plugin-transform-async-to-generator': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-block-scoped-functions': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-block-scoping': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-class-properties': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-class-static-block': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-classes': 7.23.8(@babel/core@7.23.7) - '@babel/plugin-transform-computed-properties': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-destructuring': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-dotall-regex': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-duplicate-keys': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-dynamic-import': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-exponentiation-operator': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-export-namespace-from': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-for-of': 7.23.6(@babel/core@7.23.7) - '@babel/plugin-transform-function-name': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-json-strings': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-literals': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-logical-assignment-operators': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-member-expression-literals': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-modules-amd': 7.23.3(@babel/core@7.23.7) - 
'@babel/plugin-transform-modules-commonjs': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-modules-systemjs': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-modules-umd': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-named-capturing-groups-regex': 7.22.5(@babel/core@7.23.7) - '@babel/plugin-transform-new-target': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-nullish-coalescing-operator': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-numeric-separator': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-object-rest-spread': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-object-super': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-optional-catch-binding': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-optional-chaining': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-parameters': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-private-methods': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-private-property-in-object': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-property-literals': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-regenerator': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-reserved-words': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-shorthand-properties': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-spread': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-sticky-regex': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-template-literals': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-typeof-symbol': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-unicode-escapes': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-unicode-property-regex': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-unicode-regex': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-unicode-sets-regex': 7.23.3(@babel/core@7.23.7) - '@babel/preset-modules': 0.1.6-no-external-plugins(@babel/core@7.23.7) - babel-plugin-polyfill-corejs2: 0.4.8(@babel/core@7.23.7) - babel-plugin-polyfill-corejs3: 0.8.7(@babel/core@7.23.7) - babel-plugin-polyfill-regenerator: 0.5.5(@babel/core@7.23.7) - core-js-compat: 3.35.0 + '@babel/plugin-bugfix-safari-id-destructuring-collision-in-function-expression': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-bugfix-v8-spread-parameters-in-optional-chaining': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-bugfix-v8-static-class-fields-redefine-readonly': 7.23.7(@babel/core@7.23.9) + '@babel/plugin-proposal-private-property-in-object': 7.21.0-placeholder-for-preset-env.2(@babel/core@7.23.9) + '@babel/plugin-syntax-async-generators': 7.8.4(@babel/core@7.23.9) + '@babel/plugin-syntax-class-properties': 7.12.13(@babel/core@7.23.9) + '@babel/plugin-syntax-class-static-block': 7.14.5(@babel/core@7.23.9) + '@babel/plugin-syntax-dynamic-import': 7.8.3(@babel/core@7.23.9) + '@babel/plugin-syntax-export-namespace-from': 7.8.3(@babel/core@7.23.9) + '@babel/plugin-syntax-import-assertions': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-syntax-import-attributes': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-syntax-import-meta': 7.10.4(@babel/core@7.23.9) + '@babel/plugin-syntax-json-strings': 7.8.3(@babel/core@7.23.9) + '@babel/plugin-syntax-logical-assignment-operators': 7.10.4(@babel/core@7.23.9) + '@babel/plugin-syntax-nullish-coalescing-operator': 7.8.3(@babel/core@7.23.9) + '@babel/plugin-syntax-numeric-separator': 7.10.4(@babel/core@7.23.9) + '@babel/plugin-syntax-object-rest-spread': 7.8.3(@babel/core@7.23.9) + 
'@babel/plugin-syntax-optional-catch-binding': 7.8.3(@babel/core@7.23.9) + '@babel/plugin-syntax-optional-chaining': 7.8.3(@babel/core@7.23.9) + '@babel/plugin-syntax-private-property-in-object': 7.14.5(@babel/core@7.23.9) + '@babel/plugin-syntax-top-level-await': 7.14.5(@babel/core@7.23.9) + '@babel/plugin-syntax-unicode-sets-regex': 7.18.6(@babel/core@7.23.9) + '@babel/plugin-transform-arrow-functions': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-async-generator-functions': 7.23.9(@babel/core@7.23.9) + '@babel/plugin-transform-async-to-generator': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-block-scoped-functions': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-block-scoping': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-class-properties': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-class-static-block': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-classes': 7.23.8(@babel/core@7.23.9) + '@babel/plugin-transform-computed-properties': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-destructuring': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-dotall-regex': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-duplicate-keys': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-dynamic-import': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-exponentiation-operator': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-export-namespace-from': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-for-of': 7.23.6(@babel/core@7.23.9) + '@babel/plugin-transform-function-name': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-json-strings': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-literals': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-logical-assignment-operators': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-member-expression-literals': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-modules-amd': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-modules-commonjs': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-modules-systemjs': 7.23.9(@babel/core@7.23.9) + '@babel/plugin-transform-modules-umd': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-named-capturing-groups-regex': 7.22.5(@babel/core@7.23.9) + '@babel/plugin-transform-new-target': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-nullish-coalescing-operator': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-numeric-separator': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-object-rest-spread': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-object-super': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-optional-catch-binding': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-optional-chaining': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-parameters': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-private-methods': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-private-property-in-object': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-property-literals': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-regenerator': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-reserved-words': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-shorthand-properties': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-spread': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-sticky-regex': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-template-literals': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-typeof-symbol': 
7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-unicode-escapes': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-unicode-property-regex': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-unicode-regex': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-unicode-sets-regex': 7.23.3(@babel/core@7.23.9) + '@babel/preset-modules': 0.1.6-no-external-plugins(@babel/core@7.23.9) + babel-plugin-polyfill-corejs2: 0.4.8(@babel/core@7.23.9) + babel-plugin-polyfill-corejs3: 0.9.0(@babel/core@7.23.9) + babel-plugin-polyfill-regenerator: 0.5.5(@babel/core@7.23.9) + core-js-compat: 3.36.0 semver: 6.3.1 transitivePeerDependencies: - supports-color dev: true - /@babel/preset-flow@7.23.3(@babel/core@7.23.7): + /@babel/preset-flow@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-7yn6hl8RIv+KNk6iIrGZ+D06VhVY35wLVf23Cz/mMu1zOr7u4MMP4j0nZ9tLf8+4ZFpnib8cFYgB/oYg9hfswA==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 '@babel/helper-validator-option': 7.23.5 - '@babel/plugin-transform-flow-strip-types': 7.23.3(@babel/core@7.23.7) + '@babel/plugin-transform-flow-strip-types': 7.23.3(@babel/core@7.23.9) dev: true - /@babel/preset-modules@0.1.6-no-external-plugins(@babel/core@7.23.7): + /@babel/preset-modules@0.1.6-no-external-plugins(@babel/core@7.23.9): resolution: {integrity: sha512-HrcgcIESLm9aIR842yhJ5RWan/gebQUJ6E/E5+rf0y9o6oj7w0Br+sWuL6kEQ/o/AdfvR1Je9jG18/gnpwjEyA==} peerDependencies: '@babel/core': ^7.0.0-0 || ^8.0.0-0 <8.0.0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 - '@babel/types': 7.23.6 + '@babel/types': 7.23.9 esutils: 2.0.3 dev: true - /@babel/preset-typescript@7.23.3(@babel/core@7.23.7): + /@babel/preset-typescript@7.23.3(@babel/core@7.23.9): resolution: {integrity: sha512-17oIGVlqz6CchO9RFYn5U6ZpWRZIngayYCtrPRSgANSwC2V1Jb+iP74nVxzzXJte8b8BYxrL1yY96xfhTBrNNQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@babel/helper-plugin-utils': 7.22.5 '@babel/helper-validator-option': 7.23.5 - '@babel/plugin-syntax-jsx': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-modules-commonjs': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-typescript': 7.23.6(@babel/core@7.23.7) + '@babel/plugin-syntax-jsx': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-modules-commonjs': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-typescript': 7.23.6(@babel/core@7.23.9) dev: true - /@babel/register@7.23.7(@babel/core@7.23.7): + /@babel/register@7.23.7(@babel/core@7.23.9): resolution: {integrity: sha512-EjJeB6+kvpk+Y5DAkEAmbOBEFkh9OASx0huoEkqYTFxAZHzOAX2Oh5uwAUuL2rUddqfM0SA+KPXV2TbzoZ2kvQ==} engines: {node: '>=6.9.0'} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 clone-deep: 4.0.1 find-cache-dir: 2.1.0 make-dir: 2.1.0 @@ -1690,43 +1633,23 @@ packages: resolution: {integrity: sha512-x/rqGMdzj+fWZvCOYForTghzbtqPDZ5gPwaoNGHdgDfF2QA/XZbCBp4Moo5scrkAMPhB7z26XM/AaHuIJdgauA==} dev: true - /@babel/runtime@7.23.6: - resolution: {integrity: sha512-zHd0eUrf5GZoOWVCXp6koAKQTfZV07eit6bGPmJgnZdnSAvvZee6zniW2XMF7Cmc4ISOOnPy3QaSiIJGJkVEDQ==} - engines: {node: '>=6.9.0'} - dependencies: - regenerator-runtime: 0.14.1 - - /@babel/runtime@7.23.7: - resolution: {integrity: sha512-w06OXVOFso7LcbzMiDGt+3X7Rh7Ho8MmgPoWU3rarH+8upf+wSU/grlGbWzQyr3DkdN6ZeuMFjpdwW0Q+HxobA==} - 
engines: {node: '>=6.9.0'} - dependencies: - regenerator-runtime: 0.14.1 - dev: false - - /@babel/runtime@7.23.8: - resolution: {integrity: sha512-Y7KbAP984rn1VGMbGqKmBLio9V7y5Je9GvU4rQPCPinCyNfUcToxIXl06d59URp/F3LwinvODxab5N/G6qggkw==} - engines: {node: '>=6.9.0'} - dependencies: - regenerator-runtime: 0.14.1 - /@babel/runtime@7.23.9: resolution: {integrity: sha512-0CX6F+BI2s9dkUqr08KFrAIZgNFj75rdBU/DjCyYLIaV/quFjkk6T+EJ2LkZHyZTbEV4L5p97mNkUsHl2wLFAw==} engines: {node: '>=6.9.0'} dependencies: regenerator-runtime: 0.14.1 - dev: false - /@babel/template@7.22.15: - resolution: {integrity: sha512-QPErUVm4uyJa60rkI73qneDacvdvzxshT3kksGqlGWYdOTIUOwJ7RDUL8sGqslY1uXWSL6xMFKEXDS3ox2uF0w==} + /@babel/template@7.23.9: + resolution: {integrity: sha512-+xrD2BWLpvHKNmX2QbpdpsBaWnRxahMwJjO+KZk2JOElj5nSmKezyS1B4u+QbHMTX69t4ukm6hh9lsYQ7GHCKA==} engines: {node: '>=6.9.0'} dependencies: '@babel/code-frame': 7.23.5 - '@babel/parser': 7.23.6 - '@babel/types': 7.23.6 + '@babel/parser': 7.23.9 + '@babel/types': 7.23.9 dev: true - /@babel/traverse@7.23.7: - resolution: {integrity: sha512-tY3mM8rH9jM0YHFGyfC0/xf+SB5eKUu7HPj7/k3fpi9dAlsMc5YbQvDi0Sh2QTPXqMhyaAtzAr807TIyfQrmyg==} + /@babel/traverse@7.23.9: + resolution: {integrity: sha512-I/4UJ9vs90OkBtY6iiiTORVMyIhJ4kAVmsKo9KFc8UOxMeUfi2hvtIBsET5u9GizXE6/GFSuKCTNfgCswuEjRg==} engines: {node: '>=6.9.0'} dependencies: '@babel/code-frame': 7.23.5 @@ -1735,16 +1658,16 @@ packages: '@babel/helper-function-name': 7.23.0 '@babel/helper-hoist-variables': 7.22.5 '@babel/helper-split-export-declaration': 7.22.6 - '@babel/parser': 7.23.6 - '@babel/types': 7.23.6 + '@babel/parser': 7.23.9 + '@babel/types': 7.23.9 debug: 4.3.4 globals: 11.12.0 transitivePeerDependencies: - supports-color dev: true - /@babel/types@7.23.6: - resolution: {integrity: sha512-+uarb83brBzPKN38NX1MkB6vb6+mwvR6amUulqAE7ccQw1pEl+bCia9TbdG1lsnFP7lZySvUn37CHyXQdfTwzg==} + /@babel/types@7.23.9: + resolution: {integrity: sha512-dQjSq/7HaSjRM43FFGnv5keM2HsxpmyV1PfaSVm0nzzjwwTmjOe6J4bC8e3+pTEIgHaHj+1ZlLThRJ2auc/w1Q==} engines: {node: '>=6.9.0'} dependencies: '@babel/helper-string-parser': 7.23.4 @@ -1774,6 +1697,25 @@ packages: react: 18.2.0 dev: false + /@chakra-ui/accordion@2.3.1(@chakra-ui/system@2.6.2)(framer-motion@11.0.5)(react@18.2.0): + resolution: {integrity: sha512-FSXRm8iClFyU+gVaXisOSEw0/4Q+qZbFRiuhIAkVU6Boj0FxAMrlo9a8AV5TuF77rgaHytCdHk0Ng+cyUijrag==} + peerDependencies: + '@chakra-ui/system': '>=2.0.0' + framer-motion: '>=4.0.0' + react: '>=18' + dependencies: + '@chakra-ui/descendant': 3.1.0(react@18.2.0) + '@chakra-ui/icon': 3.2.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/react-context': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-controllable-state': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-merge-refs': 2.1.0(react@18.2.0) + '@chakra-ui/shared-utils': 2.0.5 + '@chakra-ui/system': 2.6.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(react@18.2.0) + '@chakra-ui/transition': 2.1.0(framer-motion@11.0.5)(react@18.2.0) + framer-motion: 11.0.5(react-dom@18.2.0)(react@18.2.0) + react: 18.2.0 + dev: false + /@chakra-ui/alert@2.2.2(@chakra-ui/system@2.6.2)(react@18.2.0): resolution: {integrity: sha512-jHg4LYMRNOJH830ViLuicjb3F+v6iriE/2G5T+Sd0Hna04nukNJ1MxUmBPE+vI22me2dIflfelu2v9wdB6Pojw==} peerDependencies: @@ -1928,7 +1870,7 @@ packages: '@emotion/react': '>=10.0.35' react: '>=18' dependencies: - '@emotion/react': 11.11.3(@types/react@18.2.48)(react@18.2.0) + '@emotion/react': 11.11.3(@types/react@18.2.57)(react@18.2.0) react: 18.2.0 dev: false @@ -1969,14 +1911,14 @@ packages: 
resolution: {integrity: sha512-IGM/yGUHS+8TOQrZGpAKOJl/xGBrmRYJrmbHfUE7zrG3PpQyXvbLDP1M+RggkCFVgHlJi2wpYIf0QtQlU0XZfw==} dev: false - /@chakra-ui/focus-lock@2.1.0(@types/react@18.2.48)(react@18.2.0): + /@chakra-ui/focus-lock@2.1.0(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-EmGx4PhWGjm4dpjRqM4Aa+rCWBxP+Rq8Uc/nAVnD4YVqkEhBkrPTpui2lnjsuxqNaZ24fIAZ10cF1hlpemte/w==} peerDependencies: react: '>=18' dependencies: '@chakra-ui/dom-utils': 2.1.0 react: 18.2.0 - react-focus-lock: 2.11.1(@types/react@18.2.48)(react@18.2.0) + react-focus-lock: 2.11.1(@types/react@18.2.57)(react@18.2.0) transitivePeerDependencies: - '@types/react' dev: false @@ -2125,7 +2067,34 @@ packages: react: 18.2.0 dev: false - /@chakra-ui/modal@2.3.1(@chakra-ui/system@2.6.2)(@types/react@18.2.48)(framer-motion@10.18.0)(react-dom@18.2.0)(react@18.2.0): + /@chakra-ui/menu@2.2.1(@chakra-ui/system@2.6.2)(framer-motion@11.0.5)(react@18.2.0): + resolution: {integrity: sha512-lJS7XEObzJxsOwWQh7yfG4H8FzFPRP5hVPN/CL+JzytEINCSBvsCDHrYPQGp7jzpCi8vnTqQQGQe0f8dwnXd2g==} + peerDependencies: + '@chakra-ui/system': '>=2.0.0' + framer-motion: '>=4.0.0' + react: '>=18' + dependencies: + '@chakra-ui/clickable': 2.1.0(react@18.2.0) + '@chakra-ui/descendant': 3.1.0(react@18.2.0) + '@chakra-ui/lazy-utils': 2.0.5 + '@chakra-ui/popper': 3.1.0(react@18.2.0) + '@chakra-ui/react-children-utils': 2.0.6(react@18.2.0) + '@chakra-ui/react-context': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-animation-state': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-controllable-state': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-disclosure': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-focus-effect': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-merge-refs': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-outside-click': 2.2.0(react@18.2.0) + '@chakra-ui/react-use-update-effect': 2.1.0(react@18.2.0) + '@chakra-ui/shared-utils': 2.0.5 + '@chakra-ui/system': 2.6.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(react@18.2.0) + '@chakra-ui/transition': 2.1.0(framer-motion@11.0.5)(react@18.2.0) + framer-motion: 11.0.5(react-dom@18.2.0)(react@18.2.0) + react: 18.2.0 + dev: false + + /@chakra-ui/modal@2.3.1(@chakra-ui/system@2.6.2)(@types/react@18.2.57)(framer-motion@10.18.0)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-TQv1ZaiJMZN+rR9DK0snx/OPwmtaGH1HbZtlYt4W4s6CzyK541fxLRTjIXfEzIGpvNW+b6VFuFjbcR78p4DEoQ==} peerDependencies: '@chakra-ui/system': '>=2.0.0' @@ -2134,7 +2103,7 @@ packages: react-dom: '>=18' dependencies: '@chakra-ui/close-button': 2.1.1(@chakra-ui/system@2.6.2)(react@18.2.0) - '@chakra-ui/focus-lock': 2.1.0(@types/react@18.2.48)(react@18.2.0) + '@chakra-ui/focus-lock': 2.1.0(@types/react@18.2.57)(react@18.2.0) '@chakra-ui/portal': 2.1.0(react-dom@18.2.0)(react@18.2.0) '@chakra-ui/react-context': 2.1.0(react@18.2.0) '@chakra-ui/react-types': 2.0.7(react@18.2.0) @@ -2146,7 +2115,33 @@ packages: framer-motion: 10.18.0(react-dom@18.2.0)(react@18.2.0) react: 18.2.0 react-dom: 18.2.0(react@18.2.0) - react-remove-scroll: 2.5.7(@types/react@18.2.48)(react@18.2.0) + react-remove-scroll: 2.5.7(@types/react@18.2.57)(react@18.2.0) + transitivePeerDependencies: + - '@types/react' + dev: false + + /@chakra-ui/modal@2.3.1(@chakra-ui/system@2.6.2)(@types/react@18.2.57)(framer-motion@11.0.5)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-TQv1ZaiJMZN+rR9DK0snx/OPwmtaGH1HbZtlYt4W4s6CzyK541fxLRTjIXfEzIGpvNW+b6VFuFjbcR78p4DEoQ==} + peerDependencies: + '@chakra-ui/system': '>=2.0.0' + framer-motion: '>=4.0.0' + react: '>=18' + 
react-dom: '>=18' + dependencies: + '@chakra-ui/close-button': 2.1.1(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/focus-lock': 2.1.0(@types/react@18.2.57)(react@18.2.0) + '@chakra-ui/portal': 2.1.0(react-dom@18.2.0)(react@18.2.0) + '@chakra-ui/react-context': 2.1.0(react@18.2.0) + '@chakra-ui/react-types': 2.0.7(react@18.2.0) + '@chakra-ui/react-use-merge-refs': 2.1.0(react@18.2.0) + '@chakra-ui/shared-utils': 2.0.5 + '@chakra-ui/system': 2.6.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(react@18.2.0) + '@chakra-ui/transition': 2.1.0(framer-motion@11.0.5)(react@18.2.0) + aria-hidden: 1.2.3 + framer-motion: 11.0.5(react-dom@18.2.0)(react@18.2.0) + react: 18.2.0 + react-dom: 18.2.0(react@18.2.0) + react-remove-scroll: 2.5.7(@types/react@18.2.57)(react@18.2.0) transitivePeerDependencies: - '@types/react' dev: false @@ -2220,6 +2215,29 @@ packages: react: 18.2.0 dev: false + /@chakra-ui/popover@2.2.1(@chakra-ui/system@2.6.2)(framer-motion@11.0.5)(react@18.2.0): + resolution: {integrity: sha512-K+2ai2dD0ljvJnlrzesCDT9mNzLifE3noGKZ3QwLqd/K34Ym1W/0aL1ERSynrcG78NKoXS54SdEzkhCZ4Gn/Zg==} + peerDependencies: + '@chakra-ui/system': '>=2.0.0' + framer-motion: '>=4.0.0' + react: '>=18' + dependencies: + '@chakra-ui/close-button': 2.1.1(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/lazy-utils': 2.0.5 + '@chakra-ui/popper': 3.1.0(react@18.2.0) + '@chakra-ui/react-context': 2.1.0(react@18.2.0) + '@chakra-ui/react-types': 2.0.7(react@18.2.0) + '@chakra-ui/react-use-animation-state': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-disclosure': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-focus-effect': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-focus-on-pointer-down': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-merge-refs': 2.1.0(react@18.2.0) + '@chakra-ui/shared-utils': 2.0.5 + '@chakra-ui/system': 2.6.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(react@18.2.0) + framer-motion: 11.0.5(react-dom@18.2.0)(react@18.2.0) + react: 18.2.0 + dev: false + /@chakra-ui/popper@3.1.0(react@18.2.0): resolution: {integrity: sha512-ciDdpdYbeFG7og6/6J8lkTFxsSvwTdMLFkpVylAF6VNC22jssiWfquj2eyD4rJnzkRFPvIWJq8hvbfhsm+AjSg==} peerDependencies: @@ -2267,8 +2285,8 @@ packages: '@chakra-ui/react-env': 3.1.0(react@18.2.0) '@chakra-ui/system': 2.6.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(react@18.2.0) '@chakra-ui/utils': 2.0.15 - '@emotion/react': 11.11.3(@types/react@18.2.48)(react@18.2.0) - '@emotion/styled': 11.11.0(@emotion/react@11.11.3)(@types/react@18.2.48)(react@18.2.0) + '@emotion/react': 11.11.3(@types/react@18.2.57)(react@18.2.0) + '@emotion/styled': 11.11.0(@emotion/react@11.11.3)(@types/react@18.2.57)(react@18.2.0) react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: false @@ -2484,7 +2502,7 @@ packages: react: 18.2.0 dev: false - /@chakra-ui/react@2.8.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(@types/react@18.2.48)(framer-motion@10.18.0)(react-dom@18.2.0)(react@18.2.0): + /@chakra-ui/react@2.8.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(@types/react@18.2.57)(framer-motion@10.18.0)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-Hn0moyxxyCDKuR9ywYpqgX8dvjqwu9ArwpIb9wHNYjnODETjLwazgNIliCVBRcJvysGRiV51U2/JtJVrpeCjUQ==} peerDependencies: '@emotion/react': ^11.0.0 @@ -2505,7 +2523,7 @@ packages: '@chakra-ui/counter': 2.1.0(react@18.2.0) '@chakra-ui/css-reset': 2.3.0(@emotion/react@11.11.3)(react@18.2.0) '@chakra-ui/editable': 3.1.0(@chakra-ui/system@2.6.2)(react@18.2.0) - '@chakra-ui/focus-lock': 2.1.0(@types/react@18.2.48)(react@18.2.0) + '@chakra-ui/focus-lock': 
2.1.0(@types/react@18.2.57)(react@18.2.0) '@chakra-ui/form-control': 2.2.0(@chakra-ui/system@2.6.2)(react@18.2.0) '@chakra-ui/hooks': 2.2.1(react@18.2.0) '@chakra-ui/icon': 3.2.0(@chakra-ui/system@2.6.2)(react@18.2.0) @@ -2515,7 +2533,7 @@ packages: '@chakra-ui/live-region': 2.1.0(react@18.2.0) '@chakra-ui/media-query': 3.3.0(@chakra-ui/system@2.6.2)(react@18.2.0) '@chakra-ui/menu': 2.2.1(@chakra-ui/system@2.6.2)(framer-motion@10.18.0)(react@18.2.0) - '@chakra-ui/modal': 2.3.1(@chakra-ui/system@2.6.2)(@types/react@18.2.48)(framer-motion@10.18.0)(react-dom@18.2.0)(react@18.2.0) + '@chakra-ui/modal': 2.3.1(@chakra-ui/system@2.6.2)(@types/react@18.2.57)(framer-motion@10.18.0)(react-dom@18.2.0)(react@18.2.0) '@chakra-ui/number-input': 2.1.2(@chakra-ui/system@2.6.2)(react@18.2.0) '@chakra-ui/pin-input': 2.1.0(@chakra-ui/system@2.6.2)(react@18.2.0) '@chakra-ui/popover': 2.2.1(@chakra-ui/system@2.6.2)(framer-motion@10.18.0)(react@18.2.0) @@ -2546,8 +2564,8 @@ packages: '@chakra-ui/transition': 2.1.0(framer-motion@10.18.0)(react@18.2.0) '@chakra-ui/utils': 2.0.15 '@chakra-ui/visually-hidden': 2.2.0(@chakra-ui/system@2.6.2)(react@18.2.0) - '@emotion/react': 11.11.3(@types/react@18.2.48)(react@18.2.0) - '@emotion/styled': 11.11.0(@emotion/react@11.11.3)(@types/react@18.2.48)(react@18.2.0) + '@emotion/react': 11.11.3(@types/react@18.2.57)(react@18.2.0) + '@emotion/styled': 11.11.0(@emotion/react@11.11.3)(@types/react@18.2.57)(react@18.2.0) framer-motion: 10.18.0(react-dom@18.2.0)(react@18.2.0) react: 18.2.0 react-dom: 18.2.0(react@18.2.0) @@ -2555,6 +2573,77 @@ packages: - '@types/react' dev: false + /@chakra-ui/react@2.8.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(@types/react@18.2.57)(framer-motion@11.0.5)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-Hn0moyxxyCDKuR9ywYpqgX8dvjqwu9ArwpIb9wHNYjnODETjLwazgNIliCVBRcJvysGRiV51U2/JtJVrpeCjUQ==} + peerDependencies: + '@emotion/react': ^11.0.0 + '@emotion/styled': ^11.0.0 + framer-motion: '>=4.0.0' + react: '>=18' + react-dom: '>=18' + dependencies: + '@chakra-ui/accordion': 2.3.1(@chakra-ui/system@2.6.2)(framer-motion@11.0.5)(react@18.2.0) + '@chakra-ui/alert': 2.2.2(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/avatar': 2.3.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/breadcrumb': 2.2.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/button': 2.1.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/card': 2.2.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/checkbox': 2.3.2(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/close-button': 2.1.1(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/control-box': 2.1.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/counter': 2.1.0(react@18.2.0) + '@chakra-ui/css-reset': 2.3.0(@emotion/react@11.11.3)(react@18.2.0) + '@chakra-ui/editable': 3.1.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/focus-lock': 2.1.0(@types/react@18.2.57)(react@18.2.0) + '@chakra-ui/form-control': 2.2.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/hooks': 2.2.1(react@18.2.0) + '@chakra-ui/icon': 3.2.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/image': 2.1.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/input': 2.1.2(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/layout': 2.3.1(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/live-region': 2.1.0(react@18.2.0) + '@chakra-ui/media-query': 3.3.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/menu': 2.2.1(@chakra-ui/system@2.6.2)(framer-motion@11.0.5)(react@18.2.0) + 
'@chakra-ui/modal': 2.3.1(@chakra-ui/system@2.6.2)(@types/react@18.2.57)(framer-motion@11.0.5)(react-dom@18.2.0)(react@18.2.0) + '@chakra-ui/number-input': 2.1.2(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/pin-input': 2.1.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/popover': 2.2.1(@chakra-ui/system@2.6.2)(framer-motion@11.0.5)(react@18.2.0) + '@chakra-ui/popper': 3.1.0(react@18.2.0) + '@chakra-ui/portal': 2.1.0(react-dom@18.2.0)(react@18.2.0) + '@chakra-ui/progress': 2.2.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/provider': 2.4.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(react-dom@18.2.0)(react@18.2.0) + '@chakra-ui/radio': 2.1.2(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/react-env': 3.1.0(react@18.2.0) + '@chakra-ui/select': 2.1.2(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/skeleton': 2.1.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/skip-nav': 2.1.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/slider': 2.1.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/spinner': 2.1.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/stat': 2.1.1(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/stepper': 2.3.1(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/styled-system': 2.9.2 + '@chakra-ui/switch': 2.1.2(@chakra-ui/system@2.6.2)(framer-motion@11.0.5)(react@18.2.0) + '@chakra-ui/system': 2.6.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(react@18.2.0) + '@chakra-ui/table': 2.1.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/tabs': 3.0.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/tag': 3.1.1(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/textarea': 2.1.2(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/theme': 3.3.1(@chakra-ui/styled-system@2.9.2) + '@chakra-ui/theme-utils': 2.0.21 + '@chakra-ui/toast': 7.0.2(@chakra-ui/system@2.6.2)(framer-motion@11.0.5)(react-dom@18.2.0)(react@18.2.0) + '@chakra-ui/tooltip': 2.3.1(@chakra-ui/system@2.6.2)(framer-motion@11.0.5)(react-dom@18.2.0)(react@18.2.0) + '@chakra-ui/transition': 2.1.0(framer-motion@11.0.5)(react@18.2.0) + '@chakra-ui/utils': 2.0.15 + '@chakra-ui/visually-hidden': 2.2.0(@chakra-ui/system@2.6.2)(react@18.2.0) + '@emotion/react': 11.11.3(@types/react@18.2.57)(react@18.2.0) + '@emotion/styled': 11.11.0(@emotion/react@11.11.3)(@types/react@18.2.57)(react@18.2.0) + framer-motion: 11.0.5(react-dom@18.2.0)(react@18.2.0) + react: 18.2.0 + react-dom: 18.2.0(react@18.2.0) + transitivePeerDependencies: + - '@types/react' + dev: false + /@chakra-ui/select@2.1.2(@chakra-ui/system@2.6.2)(react@18.2.0): resolution: {integrity: sha512-ZwCb7LqKCVLJhru3DXvKXpZ7Pbu1TDZ7N0PdQ0Zj1oyVLJyrpef1u9HR5u0amOpqcH++Ugt0f5JSmirjNlctjA==} peerDependencies: @@ -2673,6 +2762,20 @@ packages: react: 18.2.0 dev: false + /@chakra-ui/switch@2.1.2(@chakra-ui/system@2.6.2)(framer-motion@11.0.5)(react@18.2.0): + resolution: {integrity: sha512-pgmi/CC+E1v31FcnQhsSGjJnOE2OcND4cKPyTE+0F+bmGm48Q/b5UmKD9Y+CmZsrt/7V3h8KNczowupfuBfIHA==} + peerDependencies: + '@chakra-ui/system': '>=2.0.0' + framer-motion: '>=4.0.0' + react: '>=18' + dependencies: + '@chakra-ui/checkbox': 2.3.2(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/shared-utils': 2.0.5 + '@chakra-ui/system': 2.6.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(react@18.2.0) + framer-motion: 11.0.5(react-dom@18.2.0)(react@18.2.0) + react: 18.2.0 + dev: false + /@chakra-ui/system@2.6.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(react@18.2.0): resolution: {integrity: 
sha512-EGtpoEjLrUu4W1fHD+a62XR+hzC5YfsWm+6lO0Kybcga3yYEij9beegO0jZgug27V+Rf7vns95VPVP6mFd/DEQ==} peerDependencies: @@ -2686,8 +2789,8 @@ packages: '@chakra-ui/styled-system': 2.9.2 '@chakra-ui/theme-utils': 2.0.21 '@chakra-ui/utils': 2.0.15 - '@emotion/react': 11.11.3(@types/react@18.2.48)(react@18.2.0) - '@emotion/styled': 11.11.0(@emotion/react@11.11.3)(@types/react@18.2.48)(react@18.2.0) + '@emotion/react': 11.11.3(@types/react@18.2.57)(react@18.2.0) + '@emotion/styled': 11.11.0(@emotion/react@11.11.3)(@types/react@18.2.57)(react@18.2.0) react: 18.2.0 react-fast-compare: 3.2.2 dev: false @@ -2801,6 +2904,29 @@ packages: react-dom: 18.2.0(react@18.2.0) dev: false + /@chakra-ui/toast@7.0.2(@chakra-ui/system@2.6.2)(framer-motion@11.0.5)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-yvRP8jFKRs/YnkuE41BVTq9nB2v/KDRmje9u6dgDmE5+1bFt3bwjdf9gVbif4u5Ve7F7BGk5E093ARRVtvLvXA==} + peerDependencies: + '@chakra-ui/system': 2.6.2 + framer-motion: '>=4.0.0' + react: '>=18' + react-dom: '>=18' + dependencies: + '@chakra-ui/alert': 2.2.2(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/close-button': 2.1.1(@chakra-ui/system@2.6.2)(react@18.2.0) + '@chakra-ui/portal': 2.1.0(react-dom@18.2.0)(react@18.2.0) + '@chakra-ui/react-context': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-timeout': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-update-effect': 2.1.0(react@18.2.0) + '@chakra-ui/shared-utils': 2.0.5 + '@chakra-ui/styled-system': 2.9.2 + '@chakra-ui/system': 2.6.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(react@18.2.0) + '@chakra-ui/theme': 3.3.1(@chakra-ui/styled-system@2.9.2) + framer-motion: 11.0.5(react-dom@18.2.0)(react@18.2.0) + react: 18.2.0 + react-dom: 18.2.0(react@18.2.0) + dev: false + /@chakra-ui/tooltip@2.3.1(@chakra-ui/system@2.6.2)(framer-motion@10.18.0)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-Rh39GBn/bL4kZpuEMPPRwYNnccRCL+w9OqamWHIB3Qboxs6h8cOyXfIdGxjo72lvhu1QI/a4KFqkM3St+WfC0A==} peerDependencies: @@ -2823,6 +2949,28 @@ packages: react-dom: 18.2.0(react@18.2.0) dev: false + /@chakra-ui/tooltip@2.3.1(@chakra-ui/system@2.6.2)(framer-motion@11.0.5)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-Rh39GBn/bL4kZpuEMPPRwYNnccRCL+w9OqamWHIB3Qboxs6h8cOyXfIdGxjo72lvhu1QI/a4KFqkM3St+WfC0A==} + peerDependencies: + '@chakra-ui/system': '>=2.0.0' + framer-motion: '>=4.0.0' + react: '>=18' + react-dom: '>=18' + dependencies: + '@chakra-ui/dom-utils': 2.1.0 + '@chakra-ui/popper': 3.1.0(react@18.2.0) + '@chakra-ui/portal': 2.1.0(react-dom@18.2.0)(react@18.2.0) + '@chakra-ui/react-types': 2.0.7(react@18.2.0) + '@chakra-ui/react-use-disclosure': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-event-listener': 2.1.0(react@18.2.0) + '@chakra-ui/react-use-merge-refs': 2.1.0(react@18.2.0) + '@chakra-ui/shared-utils': 2.0.5 + '@chakra-ui/system': 2.6.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(react@18.2.0) + framer-motion: 11.0.5(react-dom@18.2.0)(react@18.2.0) + react: 18.2.0 + react-dom: 18.2.0(react@18.2.0) + dev: false + /@chakra-ui/transition@2.1.0(framer-motion@10.18.0)(react@18.2.0): resolution: {integrity: sha512-orkT6T/Dt+/+kVwJNy7zwJ+U2xAZ3EU7M3XCs45RBvUnZDr/u9vdmaM/3D/rOpmQJWgQBwKPJleUXrYWUagEDQ==} peerDependencies: @@ -2834,6 +2982,17 @@ packages: react: 18.2.0 dev: false + /@chakra-ui/transition@2.1.0(framer-motion@11.0.5)(react@18.2.0): + resolution: {integrity: sha512-orkT6T/Dt+/+kVwJNy7zwJ+U2xAZ3EU7M3XCs45RBvUnZDr/u9vdmaM/3D/rOpmQJWgQBwKPJleUXrYWUagEDQ==} + peerDependencies: + framer-motion: '>=4.0.0' + react: '>=18' + 
dependencies: + '@chakra-ui/shared-utils': 2.0.5 + framer-motion: 11.0.5(react-dom@18.2.0)(react@18.2.0) + react: 18.2.0 + dev: false + /@chakra-ui/utils@2.0.15: resolution: {integrity: sha512-El4+jL0WSaYYs+rJbuYFDbjmfCcfGDmRY95GO4xwzit6YAPZBLcR65rOEwLps+XWluZTy1xdMrusg/hW0c1aAA==} dependencies: @@ -2925,7 +3084,7 @@ packages: resolution: {integrity: sha512-m4HEDZleaaCH+XgDDsPF15Ht6wTLsgDTeR3WYj9Q/k76JtWhrJjcP4+/XlG8LGT/Rol9qUfOIztXeA84ATpqPQ==} dependencies: '@babel/helper-module-imports': 7.22.15 - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 '@emotion/hash': 0.9.1 '@emotion/memoize': 0.8.1 '@emotion/serialize': 1.1.3 @@ -2975,7 +3134,7 @@ packages: resolution: {integrity: sha512-W2P2c/VRW1/1tLox0mVUalvnWXxavmv/Oum2aPsRcoDJuob75FC3Y8FbpfLwUegRcxINtGUMPq0tFCvYNTBXNA==} dev: false - /@emotion/react@11.11.3(@types/react@18.2.48)(react@18.2.0): + /@emotion/react@11.11.3(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-Cnn0kuq4DoONOMcnoVsTOR8E+AdnKFf//6kUWc4LCdnxj31pZWn7rIULd6Y7/Js1PiPHzn7SKCM9vB/jBni8eA==} peerDependencies: '@types/react': '*' @@ -2984,14 +3143,14 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 '@emotion/babel-plugin': 11.11.0 '@emotion/cache': 11.11.0 '@emotion/serialize': 1.1.3 '@emotion/use-insertion-effect-with-fallbacks': 1.0.1(react@18.2.0) '@emotion/utils': 1.2.1 '@emotion/weak-memoize': 0.3.1 - '@types/react': 18.2.48 + '@types/react': 18.2.57 hoist-non-react-statics: 3.3.2 react: 18.2.0 dev: false @@ -3010,7 +3169,7 @@ packages: resolution: {integrity: sha512-0QBtGvaqtWi+nx6doRwDdBIzhNdZrXUppvTM4dtZZWEGTXL/XE/yJxLMGlDT1Gt+UHH5IX1n+jkXyytE/av7OA==} dev: false - /@emotion/styled@11.11.0(@emotion/react@11.11.3)(@types/react@18.2.48)(react@18.2.0): + /@emotion/styled@11.11.0(@emotion/react@11.11.3)(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-hM5Nnvu9P3midq5aaXj4I+lnSfNi7Pmd4EWk1fOZ3pxookaQTNew6bp4JaCBYM4HVFZF9g7UjJmsUmC2JlxOng==} peerDependencies: '@emotion/react': ^11.0.0-rc.0 @@ -3023,11 +3182,11 @@ packages: '@babel/runtime': 7.23.9 '@emotion/babel-plugin': 11.11.0 '@emotion/is-prop-valid': 1.2.1 - '@emotion/react': 11.11.3(@types/react@18.2.48)(react@18.2.0) + '@emotion/react': 11.11.3(@types/react@18.2.57)(react@18.2.0) '@emotion/serialize': 1.1.3 '@emotion/use-insertion-effect-with-fallbacks': 1.0.1(react@18.2.0) '@emotion/utils': 1.2.1 - '@types/react': 18.2.48 + '@types/react': 18.2.57 react: 18.2.0 dev: false @@ -3050,8 +3209,8 @@ packages: resolution: {integrity: sha512-EsBwpc7hBUJWAsNPBmJy4hxWx12v6bshQsldrVmjxJoc3isbxhOrF2IcCpaXxfvq03NwkI7sbsOLXbYuqF/8Ww==} dev: false - /@esbuild/aix-ppc64@0.19.11: - resolution: {integrity: sha512-FnzU0LyE3ySQk7UntJO4+qIiQgI7KoODnZg5xzXIrFJlKd2P2gwHsHY4927xj9y5PJmJSzULiUCWmv7iWnNa7g==} + /@esbuild/aix-ppc64@0.19.12: + resolution: {integrity: sha512-bmoCYyWdEL3wDQIVbcyzRyeKLgk2WtWLTWz1ZIAZF/EGbNOwSA6ew3PftJ1PqMiOOGu0OyFMzG53L0zqIpPeNA==} engines: {node: '>=12'} cpu: [ppc64] os: [aix] @@ -3068,8 +3227,8 @@ packages: dev: true optional: true - /@esbuild/android-arm64@0.19.11: - resolution: {integrity: sha512-aiu7K/5JnLj//KOnOfEZ0D90obUkRzDMyqd/wNAUQ34m4YUPVhRZpnqKV9uqDGxT7cToSDnIHsGooyIczu9T+Q==} + /@esbuild/android-arm64@0.19.12: + resolution: {integrity: sha512-P0UVNGIienjZv3f5zq0DP3Nt2IE/3plFzuaS96vihvD0Hd6H/q4WXUGpCxD/E8YrSXfNyRPbpTq+T8ZQioSuPA==} engines: {node: '>=12'} cpu: [arm64] os: [android] @@ -3086,8 +3245,8 @@ packages: dev: true optional: true - /@esbuild/android-arm@0.19.11: - resolution: {integrity: 
sha512-5OVapq0ClabvKvQ58Bws8+wkLCV+Rxg7tUVbo9xu034Nm536QTII4YzhaFriQ7rMrorfnFKUsArD2lqKbFY4vw==} + /@esbuild/android-arm@0.19.12: + resolution: {integrity: sha512-qg/Lj1mu3CdQlDEEiWrlC4eaPZ1KztwGJ9B6J+/6G+/4ewxJg7gqj8eVYWvao1bXrqGiW2rsBZFSX3q2lcW05w==} engines: {node: '>=12'} cpu: [arm] os: [android] @@ -3104,8 +3263,8 @@ packages: dev: true optional: true - /@esbuild/android-x64@0.19.11: - resolution: {integrity: sha512-eccxjlfGw43WYoY9QgB82SgGgDbibcqyDTlk3l3C0jOVHKxrjdc9CTwDUQd0vkvYg5um0OH+GpxYvp39r+IPOg==} + /@esbuild/android-x64@0.19.12: + resolution: {integrity: sha512-3k7ZoUW6Q6YqhdhIaq/WZ7HwBpnFBlW905Fa4s4qWJyiNOgT1dOqDiVAQFwBH7gBRZr17gLrlFCRzF6jFh7Kew==} engines: {node: '>=12'} cpu: [x64] os: [android] @@ -3122,8 +3281,8 @@ packages: dev: true optional: true - /@esbuild/darwin-arm64@0.19.11: - resolution: {integrity: sha512-ETp87DRWuSt9KdDVkqSoKoLFHYTrkyz2+65fj9nfXsaV3bMhTCjtQfw3y+um88vGRKRiF7erPrh/ZuIdLUIVxQ==} + /@esbuild/darwin-arm64@0.19.12: + resolution: {integrity: sha512-B6IeSgZgtEzGC42jsI+YYu9Z3HKRxp8ZT3cqhvliEHovq8HSX2YX8lNocDn79gCKJXOSaEot9MVYky7AKjCs8g==} engines: {node: '>=12'} cpu: [arm64] os: [darwin] @@ -3140,8 +3299,8 @@ packages: dev: true optional: true - /@esbuild/darwin-x64@0.19.11: - resolution: {integrity: sha512-fkFUiS6IUK9WYUO/+22omwetaSNl5/A8giXvQlcinLIjVkxwTLSktbF5f/kJMftM2MJp9+fXqZ5ezS7+SALp4g==} + /@esbuild/darwin-x64@0.19.12: + resolution: {integrity: sha512-hKoVkKzFiToTgn+41qGhsUJXFlIjxI/jSYeZf3ugemDYZldIXIxhvwN6erJGlX4t5h417iFuheZ7l+YVn05N3A==} engines: {node: '>=12'} cpu: [x64] os: [darwin] @@ -3158,8 +3317,8 @@ packages: dev: true optional: true - /@esbuild/freebsd-arm64@0.19.11: - resolution: {integrity: sha512-lhoSp5K6bxKRNdXUtHoNc5HhbXVCS8V0iZmDvyWvYq9S5WSfTIHU2UGjcGt7UeS6iEYp9eeymIl5mJBn0yiuxA==} + /@esbuild/freebsd-arm64@0.19.12: + resolution: {integrity: sha512-4aRvFIXmwAcDBw9AueDQ2YnGmz5L6obe5kmPT8Vd+/+x/JMVKCgdcRwH6APrbpNXsPz+K653Qg8HB/oXvXVukA==} engines: {node: '>=12'} cpu: [arm64] os: [freebsd] @@ -3176,8 +3335,8 @@ packages: dev: true optional: true - /@esbuild/freebsd-x64@0.19.11: - resolution: {integrity: sha512-JkUqn44AffGXitVI6/AbQdoYAq0TEullFdqcMY/PCUZ36xJ9ZJRtQabzMA+Vi7r78+25ZIBosLTOKnUXBSi1Kw==} + /@esbuild/freebsd-x64@0.19.12: + resolution: {integrity: sha512-EYoXZ4d8xtBoVN7CEwWY2IN4ho76xjYXqSXMNccFSx2lgqOG/1TBPW0yPx1bJZk94qu3tX0fycJeeQsKovA8gg==} engines: {node: '>=12'} cpu: [x64] os: [freebsd] @@ -3194,8 +3353,8 @@ packages: dev: true optional: true - /@esbuild/linux-arm64@0.19.11: - resolution: {integrity: sha512-LneLg3ypEeveBSMuoa0kwMpCGmpu8XQUh+mL8XXwoYZ6Be2qBnVtcDI5azSvh7vioMDhoJFZzp9GWp9IWpYoUg==} + /@esbuild/linux-arm64@0.19.12: + resolution: {integrity: sha512-EoTjyYyLuVPfdPLsGVVVC8a0p1BFFvtpQDB/YLEhaXyf/5bczaGeN15QkR+O4S5LeJ92Tqotve7i1jn35qwvdA==} engines: {node: '>=12'} cpu: [arm64] os: [linux] @@ -3212,8 +3371,8 @@ packages: dev: true optional: true - /@esbuild/linux-arm@0.19.11: - resolution: {integrity: sha512-3CRkr9+vCV2XJbjwgzjPtO8T0SZUmRZla+UL1jw+XqHZPkPgZiyWvbDvl9rqAN8Zl7qJF0O/9ycMtjU67HN9/Q==} + /@esbuild/linux-arm@0.19.12: + resolution: {integrity: sha512-J5jPms//KhSNv+LO1S1TX1UWp1ucM6N6XuL6ITdKWElCu8wXP72l9MM0zDTzzeikVyqFE6U8YAV9/tFyj0ti+w==} engines: {node: '>=12'} cpu: [arm] os: [linux] @@ -3230,8 +3389,8 @@ packages: dev: true optional: true - /@esbuild/linux-ia32@0.19.11: - resolution: {integrity: sha512-caHy++CsD8Bgq2V5CodbJjFPEiDPq8JJmBdeyZ8GWVQMjRD0sU548nNdwPNvKjVpamYYVL40AORekgfIubwHoA==} + /@esbuild/linux-ia32@0.19.12: + resolution: {integrity: 
sha512-Thsa42rrP1+UIGaWz47uydHSBOgTUnwBwNq59khgIwktK6x60Hivfbux9iNR0eHCHzOLjLMLfUMLCypBkZXMHA==} engines: {node: '>=12'} cpu: [ia32] os: [linux] @@ -3248,8 +3407,8 @@ packages: dev: true optional: true - /@esbuild/linux-loong64@0.19.11: - resolution: {integrity: sha512-ppZSSLVpPrwHccvC6nQVZaSHlFsvCQyjnvirnVjbKSHuE5N24Yl8F3UwYUUR1UEPaFObGD2tSvVKbvR+uT1Nrg==} + /@esbuild/linux-loong64@0.19.12: + resolution: {integrity: sha512-LiXdXA0s3IqRRjm6rV6XaWATScKAXjI4R4LoDlvO7+yQqFdlr1Bax62sRwkVvRIrwXxvtYEHHI4dm50jAXkuAA==} engines: {node: '>=12'} cpu: [loong64] os: [linux] @@ -3266,8 +3425,8 @@ packages: dev: true optional: true - /@esbuild/linux-mips64el@0.19.11: - resolution: {integrity: sha512-B5x9j0OgjG+v1dF2DkH34lr+7Gmv0kzX6/V0afF41FkPMMqaQ77pH7CrhWeR22aEeHKaeZVtZ6yFwlxOKPVFyg==} + /@esbuild/linux-mips64el@0.19.12: + resolution: {integrity: sha512-fEnAuj5VGTanfJ07ff0gOA6IPsvrVHLVb6Lyd1g2/ed67oU1eFzL0r9WL7ZzscD+/N6i3dWumGE1Un4f7Amf+w==} engines: {node: '>=12'} cpu: [mips64el] os: [linux] @@ -3284,8 +3443,8 @@ packages: dev: true optional: true - /@esbuild/linux-ppc64@0.19.11: - resolution: {integrity: sha512-MHrZYLeCG8vXblMetWyttkdVRjQlQUb/oMgBNurVEnhj4YWOr4G5lmBfZjHYQHHN0g6yDmCAQRR8MUHldvvRDA==} + /@esbuild/linux-ppc64@0.19.12: + resolution: {integrity: sha512-nYJA2/QPimDQOh1rKWedNOe3Gfc8PabU7HT3iXWtNUbRzXS9+vgB0Fjaqr//XNbd82mCxHzik2qotuI89cfixg==} engines: {node: '>=12'} cpu: [ppc64] os: [linux] @@ -3302,8 +3461,8 @@ packages: dev: true optional: true - /@esbuild/linux-riscv64@0.19.11: - resolution: {integrity: sha512-f3DY++t94uVg141dozDu4CCUkYW+09rWtaWfnb3bqe4w5NqmZd6nPVBm+qbz7WaHZCoqXqHz5p6CM6qv3qnSSQ==} + /@esbuild/linux-riscv64@0.19.12: + resolution: {integrity: sha512-2MueBrlPQCw5dVJJpQdUYgeqIzDQgw3QtiAHUC4RBz9FXPrskyyU3VI1hw7C0BSKB9OduwSJ79FTCqtGMWqJHg==} engines: {node: '>=12'} cpu: [riscv64] os: [linux] @@ -3320,8 +3479,8 @@ packages: dev: true optional: true - /@esbuild/linux-s390x@0.19.11: - resolution: {integrity: sha512-A5xdUoyWJHMMlcSMcPGVLzYzpcY8QP1RtYzX5/bS4dvjBGVxdhuiYyFwp7z74ocV7WDc0n1harxmpq2ePOjI0Q==} + /@esbuild/linux-s390x@0.19.12: + resolution: {integrity: sha512-+Pil1Nv3Umes4m3AZKqA2anfhJiVmNCYkPchwFJNEJN5QxmTs1uzyy4TvmDrCRNT2ApwSari7ZIgrPeUx4UZDg==} engines: {node: '>=12'} cpu: [s390x] os: [linux] @@ -3338,8 +3497,8 @@ packages: dev: true optional: true - /@esbuild/linux-x64@0.19.11: - resolution: {integrity: sha512-grbyMlVCvJSfxFQUndw5mCtWs5LO1gUlwP4CDi4iJBbVpZcqLVT29FxgGuBJGSzyOxotFG4LoO5X+M1350zmPA==} + /@esbuild/linux-x64@0.19.12: + resolution: {integrity: sha512-B71g1QpxfwBvNrfyJdVDexenDIt1CiDN1TIXLbhOw0KhJzE78KIFGX6OJ9MrtC0oOqMWf+0xop4qEU8JrJTwCg==} engines: {node: '>=12'} cpu: [x64] os: [linux] @@ -3356,8 +3515,8 @@ packages: dev: true optional: true - /@esbuild/netbsd-x64@0.19.11: - resolution: {integrity: sha512-13jvrQZJc3P230OhU8xgwUnDeuC/9egsjTkXN49b3GcS5BKvJqZn86aGM8W9pd14Kd+u7HuFBMVtrNGhh6fHEQ==} + /@esbuild/netbsd-x64@0.19.12: + resolution: {integrity: sha512-3ltjQ7n1owJgFbuC61Oj++XhtzmymoCihNFgT84UAmJnxJfm4sYCiSLTXZtE00VWYpPMYc+ZQmB6xbSdVh0JWA==} engines: {node: '>=12'} cpu: [x64] os: [netbsd] @@ -3374,8 +3533,8 @@ packages: dev: true optional: true - /@esbuild/openbsd-x64@0.19.11: - resolution: {integrity: sha512-ysyOGZuTp6SNKPE11INDUeFVVQFrhcNDVUgSQVDzqsqX38DjhPEPATpid04LCoUr2WXhQTEZ8ct/EgJCUDpyNw==} + /@esbuild/openbsd-x64@0.19.12: + resolution: {integrity: sha512-RbrfTB9SWsr0kWmb9srfF+L933uMDdu9BIzdA7os2t0TXhCRjrQyCeOt6wVxr79CKD4c+p+YhCj31HBkYcXebw==} engines: {node: '>=12'} cpu: [x64] os: [openbsd] @@ -3392,8 +3551,8 @@ packages: dev: true optional: 
true - /@esbuild/sunos-x64@0.19.11: - resolution: {integrity: sha512-Hf+Sad9nVwvtxy4DXCZQqLpgmRTQqyFyhT3bZ4F2XlJCjxGmRFF0Shwn9rzhOYRB61w9VMXUkxlBy56dk9JJiQ==} + /@esbuild/sunos-x64@0.19.12: + resolution: {integrity: sha512-HKjJwRrW8uWtCQnQOz9qcU3mUZhTUQvi56Q8DPTLLB+DawoiQdjsYq+j+D3s9I8VFtDr+F9CjgXKKC4ss89IeA==} engines: {node: '>=12'} cpu: [x64] os: [sunos] @@ -3410,8 +3569,8 @@ packages: dev: true optional: true - /@esbuild/win32-arm64@0.19.11: - resolution: {integrity: sha512-0P58Sbi0LctOMOQbpEOvOL44Ne0sqbS0XWHMvvrg6NE5jQ1xguCSSw9jQeUk2lfrXYsKDdOe6K+oZiwKPilYPQ==} + /@esbuild/win32-arm64@0.19.12: + resolution: {integrity: sha512-URgtR1dJnmGvX864pn1B2YUYNzjmXkuJOIqG2HdU62MVS4EHpU2946OZoTMnRUHklGtJdJZ33QfzdjGACXhn1A==} engines: {node: '>=12'} cpu: [arm64] os: [win32] @@ -3428,8 +3587,8 @@ packages: dev: true optional: true - /@esbuild/win32-ia32@0.19.11: - resolution: {integrity: sha512-6YOrWS+sDJDmshdBIQU+Uoyh7pQKrdykdefC1avn76ss5c+RN6gut3LZA4E2cH5xUEp5/cA0+YxRaVtRAb0xBg==} + /@esbuild/win32-ia32@0.19.12: + resolution: {integrity: sha512-+ZOE6pUkMOJfmxmBZElNOx72NKpIa/HFOMGzu8fqzQJ5kgf6aTGrcJaFsNiVMH4JKpMipyK+7k0n2UXN7a8YKQ==} engines: {node: '>=12'} cpu: [ia32] os: [win32] @@ -3446,8 +3605,8 @@ packages: dev: true optional: true - /@esbuild/win32-x64@0.19.11: - resolution: {integrity: sha512-vfkhltrjCAb603XaFhqhAF4LGDi2M4OrCRrFusyQ+iTLQ/o60QQXxc9cZC/FFpihBI9N1Grn6SMKVJ4KP7Fuiw==} + /@esbuild/win32-x64@0.19.12: + resolution: {integrity: sha512-T1QyPSDCyMXaO3pzBkF96E8xMkiRYbUEZADd29SyPGabqxMViNoii+NcK7eWJAEoU6RZyEm5lVSIjTmcdoB9HA==} engines: {node: '>=12'} cpu: [x64] os: [win32] @@ -3478,7 +3637,7 @@ packages: debug: 4.3.4 espree: 9.6.1 globals: 13.24.0 - ignore: 5.3.0 + ignore: 5.3.1 import-fresh: 3.3.0 js-yaml: 4.1.0 minimatch: 3.1.2 @@ -3501,45 +3660,35 @@ packages: engines: {node: '>=14'} dev: true - /@floating-ui/core@1.5.2: - resolution: {integrity: sha512-Ii3MrfY/GAIN3OhXNzpCKaLxHQfJF9qvwq/kEJYdqDxeIHa01K8sldugal6TmeeXl+WMvhv9cnVzUTaFFJF09A==} - dependencies: - '@floating-ui/utils': 0.1.6 - dev: false - - /@floating-ui/core@1.5.3: - resolution: {integrity: sha512-O0WKDOo0yhJuugCx6trZQj5jVJ9yR0ystG2JaNAemYUWce+pmM6WUEFIibnWyEJKdrDxhm75NoSRME35FNaM/Q==} + /@floating-ui/core@1.6.0: + resolution: {integrity: sha512-PcF++MykgmTj3CIyOQbKA/hDzOAiqI3mhuoN44WRCopIs1sgoDoU4oty4Jtqaj/y3oDU6fnVSm4QG0a3t5i0+g==} dependencies: '@floating-ui/utils': 0.2.1 - /@floating-ui/dom@1.5.3: - resolution: {integrity: sha512-ClAbQnEqJAKCJOEbbLo5IUlZHkNszqhuxS4fHAVxRPXPya6Ysf2G8KypnYcOTpx6I8xcgF9bbHb6g/2KpbV8qA==} + /@floating-ui/dom@1.5.4: + resolution: {integrity: sha512-jByEsHIY+eEdCjnTVu+E3ephzTOzkQ8hgUfGwos+bg7NlH33Zc5uO+QHz1mrQUOgIKKDD1RtS201P9NvAfq3XQ==} dependencies: - '@floating-ui/core': 1.5.2 - '@floating-ui/utils': 0.1.6 + '@floating-ui/core': 1.6.0 + '@floating-ui/utils': 0.2.1 dev: false - /@floating-ui/dom@1.5.4: - resolution: {integrity: sha512-jByEsHIY+eEdCjnTVu+E3ephzTOzkQ8hgUfGwos+bg7NlH33Zc5uO+QHz1mrQUOgIKKDD1RtS201P9NvAfq3XQ==} + /@floating-ui/dom@1.6.3: + resolution: {integrity: sha512-RnDthu3mzPlQ31Ss/BTwQ1zjzIhr3lk1gZB1OC56h/1vEtaXkESrOqL5fQVMfXpwGtRwX+YsZBdyHtJMQnkArw==} dependencies: - '@floating-ui/core': 1.5.3 + '@floating-ui/core': 1.6.0 '@floating-ui/utils': 0.2.1 - /@floating-ui/react-dom@2.0.6(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-IB8aCRFxr8nFkdYZgH+Otd9EVQPJoynxeFRGTB8voPoZMRWo8XjYuCRgpI1btvuKY69XMiLnW+ym7zoBHM90Rw==} + /@floating-ui/react-dom@2.0.8(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: 
sha512-HOdqOt3R3OGeTKidaLvJKcgg75S6tibQ3Tif4eyd91QnIJWr0NLvoXFpJA/j8HqkFSL68GDca9AuyWEHlhyClw==} peerDependencies: react: '>=16.8.0' react-dom: '>=16.8.0' dependencies: - '@floating-ui/dom': 1.5.4 + '@floating-ui/dom': 1.6.3 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@floating-ui/utils@0.1.6: - resolution: {integrity: sha512-OfX7E2oUDYxtBvsuS4e/jSn4Q9Qb6DzgeYtsAdkPZ47znpoNsMgZw0+tVijiv3uGNR6dgNlty6r9rzIzHjtd/A==} - dev: false - /@floating-ui/utils@0.2.1: resolution: {integrity: sha512-9TANp6GPoMtYzQdt54kfAyMmz1+osLlXdg2ENroU7zzrtflTLrrC/lgrIfaSe+Wu0b89GKccT7vxXA0MoAIO+Q==} @@ -3547,11 +3696,11 @@ packages: resolution: {integrity: sha512-k+BUNqksTL+AN+o+OV7ILeiE9B5M5X+/jA7LWvCwjbV9ovXTqZyKRhA/x7uYv/ml8WQ0XNLBM7cRFIx4jW0/hg==} dev: false - /@humanwhocodes/config-array@0.11.13: - resolution: {integrity: sha512-JSBDMiDKSzQVngfRjOdFXgFfklaXI4K9nLF49Auh21lmBWRLIK3+xTErTWD4KU54pb6coM6ESE7Awz/FNU3zgQ==} + /@humanwhocodes/config-array@0.11.14: + resolution: {integrity: sha512-3T8LkOmg45BV5FICb15QQMsyUSWrQ8AygVfC7ZG32zOalnqrilm018ZVCw0eapXux8FtA33q8PSRSstjee3jSg==} engines: {node: '>=10.10.0'} dependencies: - '@humanwhocodes/object-schema': 2.0.1 + '@humanwhocodes/object-schema': 2.0.2 debug: 4.3.4 minimatch: 3.1.2 transitivePeerDependencies: @@ -3563,8 +3712,8 @@ packages: engines: {node: '>=12.22'} dev: true - /@humanwhocodes/object-schema@2.0.1: - resolution: {integrity: sha512-dvuCeX5fC9dXgJn9t+X5atfmgQAzUOWqS1254Gh0m6i8wKd10ebXkfNKiRK+1GWi/yTvvLDHpoxLr0xxxeslWw==} + /@humanwhocodes/object-schema@2.0.2: + resolution: {integrity: sha512-6EwiSjwWYP7pTckG6I5eyFANjPhmPjUX9JRLUSfNPC7FX7zK9gyZAfUEaECL6ALTpGX5AjnBq3C9XmVWPitNpw==} dev: true /@internationalized/date@3.5.2: @@ -3579,43 +3728,41 @@ packages: '@swc/helpers': 0.5.6 dev: false - /@invoke-ai/eslint-config-react@0.0.13(@typescript-eslint/eslint-plugin@6.19.0)(@typescript-eslint/parser@6.19.0)(eslint-config-prettier@9.1.0)(eslint-plugin-import@2.29.1)(eslint-plugin-react-hooks@4.6.0)(eslint-plugin-react-refresh@0.4.5)(eslint-plugin-react@7.33.2)(eslint-plugin-simple-import-sort@10.0.0)(eslint-plugin-storybook@0.6.15)(eslint-plugin-unused-imports@3.0.0)(eslint@8.56.0): - resolution: {integrity: sha512-dfo9k+wPHdvpy1z6ABoYXR/Ttzs1FAnbC46ttIxVhZuqDq8K5cLWznivrOfl7f0hJb8Cb8HiuQb4pHDxhHBDqA==} + /@invoke-ai/eslint-config-react@0.0.14(eslint@8.56.0)(prettier@3.2.5)(typescript@5.3.3): + resolution: {integrity: sha512-6ZUY9zgdDhv2WUoLdDKOQdU9ImnH0CBOFtRlOaNOh34IOsNRfn+JA7wqA0PKnkiNrlfPkIQWhn4GRJp68NT5bw==} peerDependencies: - '@typescript-eslint/eslint-plugin': ^6.19.0 - '@typescript-eslint/parser': ^6.19.0 eslint: ^8.56.0 - eslint-config-prettier: ^9.1.0 - eslint-plugin-import: ^2.29.1 - eslint-plugin-react: ^7.33.2 - eslint-plugin-react-hooks: ^4.6.0 - eslint-plugin-react-refresh: ^0.4.5 - eslint-plugin-simple-import-sort: ^10.0.0 - eslint-plugin-storybook: ^0.6.15 - eslint-plugin-unused-imports: ^3.0.0 - dependencies: - '@typescript-eslint/eslint-plugin': 6.19.0(@typescript-eslint/parser@6.19.0)(eslint@8.56.0)(typescript@5.3.3) - '@typescript-eslint/parser': 6.19.0(eslint@8.56.0)(typescript@5.3.3) + prettier: ^3.2.5 + typescript: ^5.3.3 + dependencies: + '@typescript-eslint/eslint-plugin': 7.0.2(@typescript-eslint/parser@7.0.2)(eslint@8.56.0)(typescript@5.3.3) + '@typescript-eslint/parser': 7.0.2(eslint@8.56.0)(typescript@5.3.3) eslint: 8.56.0 eslint-config-prettier: 9.1.0(eslint@8.56.0) - eslint-plugin-import: 2.29.1(@typescript-eslint/parser@6.19.0)(eslint@8.56.0) + eslint-plugin-import: 
2.29.1(@typescript-eslint/parser@7.0.2)(eslint@8.56.0) eslint-plugin-react: 7.33.2(eslint@8.56.0) eslint-plugin-react-hooks: 4.6.0(eslint@8.56.0) eslint-plugin-react-refresh: 0.4.5(eslint@8.56.0) - eslint-plugin-simple-import-sort: 10.0.0(eslint@8.56.0) - eslint-plugin-storybook: 0.6.15(eslint@8.56.0)(typescript@5.3.3) - eslint-plugin-unused-imports: 3.0.0(@typescript-eslint/eslint-plugin@6.19.0)(eslint@8.56.0) + eslint-plugin-simple-import-sort: 12.0.0(eslint@8.56.0) + eslint-plugin-storybook: 0.8.0(eslint@8.56.0)(typescript@5.3.3) + eslint-plugin-unused-imports: 3.1.0(@typescript-eslint/eslint-plugin@7.0.2)(eslint@8.56.0) + prettier: 3.2.5 + typescript: 5.3.3 + transitivePeerDependencies: + - eslint-import-resolver-typescript + - eslint-import-resolver-webpack + - supports-color dev: true - /@invoke-ai/prettier-config-react@0.0.6(prettier@3.2.4): - resolution: {integrity: sha512-qHE6GAw/Aka/8TLTN9U1U+8pxjaFe5irDv/uSgzqmrBR1rGiVyMp19pEficWRRt+03zYdquiiDjTmoabWQxY0Q==} + /@invoke-ai/prettier-config-react@0.0.7(prettier@3.2.5): + resolution: {integrity: sha512-vQeWzqwih116TBlIJII93L8ictj6uv7PxcSlAGNZrzG2UcaCFMsQqKCsB/qio26uihgv/EtvN6XAF96SnE0TKw==} peerDependencies: - prettier: ^3.2.4 + prettier: ^3.2.5 dependencies: - prettier: 3.2.4 + prettier: 3.2.5 dev: true - /@invoke-ai/ui-library@0.0.21(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.16)(@internationalized/date@3.5.2)(@types/react@18.2.48)(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0): + /@invoke-ai/ui-library@0.0.21(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.16)(@internationalized/date@3.5.2)(@types/react@18.2.57)(i18next@23.9.0)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-tCvgkBPDt0gNq+8IcR03e/Mw7R8Mb/SMXTqx3FEIxlTQEo93A/D38dKXeDCzTdx4sQ+sknfB+JLBbHs6sg5hhQ==} peerDependencies: '@fontsource-variable/inter': ^5.0.16 @@ -3627,14 +3774,14 @@ packages: '@chakra-ui/icons': 2.1.1(@chakra-ui/system@2.6.2)(react@18.2.0) '@chakra-ui/layout': 2.3.1(@chakra-ui/system@2.6.2)(react@18.2.0) '@chakra-ui/portal': 2.1.0(react-dom@18.2.0)(react@18.2.0) - '@chakra-ui/react': 2.8.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(@types/react@18.2.48)(framer-motion@10.18.0)(react-dom@18.2.0)(react@18.2.0) + '@chakra-ui/react': 2.8.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(@types/react@18.2.57)(framer-motion@10.18.0)(react-dom@18.2.0)(react@18.2.0) '@chakra-ui/styled-system': 2.9.2 '@chakra-ui/theme-tools': 2.1.2(@chakra-ui/styled-system@2.9.2) - '@emotion/react': 11.11.3(@types/react@18.2.48)(react@18.2.0) - '@emotion/styled': 11.11.0(@emotion/react@11.11.3)(@types/react@18.2.48)(react@18.2.0) + '@emotion/react': 11.11.3(@types/react@18.2.57)(react@18.2.0) + '@emotion/styled': 11.11.0(@emotion/react@11.11.3)(@types/react@18.2.57)(react@18.2.0) '@fontsource-variable/inter': 5.0.16 - '@nanostores/react': 0.7.1(nanostores@0.9.5)(react@18.2.0) - chakra-react-select: 4.7.6(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/layout@2.3.1)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@emotion/react@11.11.3)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + '@nanostores/react': 0.7.2(nanostores@0.9.5)(react@18.2.0) + chakra-react-select: 
4.7.6(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/layout@2.3.1)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@emotion/react@11.11.3)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) framer-motion: 10.18.0(react-dom@18.2.0)(react@18.2.0) lodash-es: 4.17.21 nanostores: 0.9.5 @@ -3642,9 +3789,9 @@ packages: overlayscrollbars-react: 0.5.4(overlayscrollbars@2.5.0)(react@18.2.0) react: 18.2.0 react-dom: 18.2.0(react@18.2.0) - react-i18next: 14.0.5(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0) + react-i18next: 14.0.5(i18next@23.9.0)(react-dom@18.2.0)(react@18.2.0) react-icons: 5.0.1(react@18.2.0) - react-select: 5.8.0(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + react-select: 5.8.0(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) transitivePeerDependencies: - '@chakra-ui/form-control' - '@chakra-ui/icon' @@ -3697,9 +3844,9 @@ packages: resolution: {integrity: sha512-ok/BTPFzFKVMwO5eOHRrvnBVHdRy9IrsrW1GpMaQ9MCnilNLXQKmAX8s1YXDFaai9xJpac2ySzV0YeRRECr2Vw==} engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0} dependencies: - '@babel/core': 7.23.7 + '@babel/core': 7.23.9 '@jest/types': 29.6.3 - '@jridgewell/trace-mapping': 0.3.21 + '@jridgewell/trace-mapping': 0.3.22 babel-plugin-istanbul: 6.1.1 chalk: 4.1.2 convert-source-map: 2.0.0 @@ -3722,7 +3869,7 @@ packages: dependencies: '@types/istanbul-lib-coverage': 2.0.6 '@types/istanbul-reports': 3.0.4 - '@types/node': 20.11.5 + '@types/node': 20.11.19 '@types/yargs': 16.0.9 chalk: 4.1.2 dev: true @@ -3734,12 +3881,12 @@ packages: '@jest/schemas': 29.6.3 '@types/istanbul-lib-coverage': 2.0.6 '@types/istanbul-reports': 3.0.4 - '@types/node': 20.11.5 + '@types/node': 20.11.19 '@types/yargs': 17.0.32 chalk: 4.1.2 dev: true - /@joshwooding/vite-plugin-react-docgen-typescript@0.3.0(typescript@5.3.3)(vite@5.0.12): + /@joshwooding/vite-plugin-react-docgen-typescript@0.3.0(typescript@5.3.3)(vite@5.1.3): resolution: {integrity: sha512-2D6y7fNvFmsLmRt6UCOFJPvFoPMJGT0Uh1Wg0RaigUp7kdQPs6yYn8Dmx6GZkOH/NW0yMTwRz/p0SRMMRo50vA==} peerDependencies: typescript: '>= 4.3.x' @@ -3753,7 +3900,7 @@ packages: magic-string: 0.27.0 react-docgen-typescript: 2.2.2(typescript@5.3.3) typescript: 5.3.3 - vite: 5.0.12(@types/node@20.11.5) + vite: 5.1.3(@types/node@20.11.19) dev: true /@jridgewell/gen-mapping@0.3.3: @@ -3762,11 +3909,11 @@ packages: dependencies: '@jridgewell/set-array': 1.1.2 '@jridgewell/sourcemap-codec': 1.4.15 - '@jridgewell/trace-mapping': 0.3.21 + '@jridgewell/trace-mapping': 0.3.22 dev: true - /@jridgewell/resolve-uri@3.1.1: - resolution: {integrity: sha512-dSYZh7HhCDtCKm4QakX0xFpsRDqjjtZf/kjI/v3T3Nwt5r8/qz/M19F9ySyOqU94SXBmeG9ttTul+YnR4LOxFA==} + /@jridgewell/resolve-uri@3.1.2: + resolution: {integrity: sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==} engines: {node: '>=6.0.0'} dev: true @@ -3778,10 +3925,10 @@ packages: /@jridgewell/sourcemap-codec@1.4.15: resolution: {integrity: sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg==} - /@jridgewell/trace-mapping@0.3.21: - resolution: {integrity: sha512-SRfKmRe1KvYnxjEMtxEr+J4HIeMX5YBg/qhRHpxEIGjhX1rshcHlnFUE9K0GazhVKWM7B+nARSkV8LuvJdJ5/g==} + /@jridgewell/trace-mapping@0.3.22: + resolution: {integrity: sha512-Wf963MzWtA2sjrNt+g18IAln9lKnlRp+K2eH4jjIoF1wYeq3aMREpG09xhlhdzS0EjwU7qmUJYangWa+151vZw==} dependencies: - '@jridgewell/resolve-uri': 3.1.1 + '@jridgewell/resolve-uri': 3.1.2 '@jridgewell/sourcemap-codec': 1.4.15 dev: 
true @@ -3804,29 +3951,29 @@ packages: peerDependencies: react: '>=16' dependencies: - '@types/mdx': 2.0.10 - '@types/react': 18.2.48 + '@types/mdx': 2.0.11 + '@types/react': 18.2.57 react: 18.2.0 dev: true - /@microsoft/api-extractor-model@7.28.3(@types/node@20.11.5): + /@microsoft/api-extractor-model@7.28.3(@types/node@20.11.19): resolution: {integrity: sha512-wT/kB2oDbdZXITyDh2SQLzaWwTOFbV326fP0pUwNW00WeliARs0qjmXBWmGWardEzp2U3/axkO3Lboqun6vrig==} dependencies: '@microsoft/tsdoc': 0.14.2 '@microsoft/tsdoc-config': 0.16.2 - '@rushstack/node-core-library': 3.62.0(@types/node@20.11.5) + '@rushstack/node-core-library': 3.62.0(@types/node@20.11.19) transitivePeerDependencies: - '@types/node' dev: true - /@microsoft/api-extractor@7.39.0(@types/node@20.11.5): + /@microsoft/api-extractor@7.39.0(@types/node@20.11.19): resolution: {integrity: sha512-PuXxzadgnvp+wdeZFPonssRAj/EW4Gm4s75TXzPk09h3wJ8RS3x7typf95B4vwZRrPTQBGopdUl+/vHvlPdAcg==} hasBin: true dependencies: - '@microsoft/api-extractor-model': 7.28.3(@types/node@20.11.5) + '@microsoft/api-extractor-model': 7.28.3(@types/node@20.11.19) '@microsoft/tsdoc': 0.14.2 '@microsoft/tsdoc-config': 0.16.2 - '@rushstack/node-core-library': 3.62.0(@types/node@20.11.5) + '@rushstack/node-core-library': 3.62.0(@types/node@20.11.19) '@rushstack/rig-package': 0.5.1 '@rushstack/ts-command-line': 4.17.1 colors: 1.2.5 @@ -3852,11 +3999,22 @@ packages: resolution: {integrity: sha512-9b8mPpKrfeGRuhFH5iO1iwCLeIIsV6+H1sRfxbkoGXIyQE2BTsPd9zqSqQJ+pv5sJ/hT5M1zvOFL02MnEezFug==} dev: true - /@nanostores/react@0.7.1(nanostores@0.9.5)(react@18.2.0): - resolution: {integrity: sha512-EXQg9N4MdI4eJQz/AZLIx3hxQ6BuBmV4Q55bCd5YCSgEOAW7tGTsIZxpRXxvxLXzflNvHTBvfrDNY38TlSVBkQ==} - engines: {node: ^16.0.0 || ^18.0.0 || >=20.0.0} + /@nanostores/react@0.7.2(nanostores@0.10.0)(react@18.2.0): + resolution: {integrity: sha512-e3OhHJFv3NMSFYDgREdlAQqkyBTHJM91s31kOZ4OvZwJKdFk5BLk0MLbh51EOGUz9QGX2aCHfy1RvweSi7fgwA==} + engines: {node: ^18.0.0 || >=20.0.0} peerDependencies: - nanostores: ^0.9.0 + nanostores: ^0.9.0 || ^0.10.0 + react: '>=18.0.0' + dependencies: + nanostores: 0.10.0 + react: 18.2.0 + dev: false + + /@nanostores/react@0.7.2(nanostores@0.9.5)(react@18.2.0): + resolution: {integrity: sha512-e3OhHJFv3NMSFYDgREdlAQqkyBTHJM91s31kOZ4OvZwJKdFk5BLk0MLbh51EOGUz9QGX2aCHfy1RvweSi7fgwA==} + engines: {node: ^18.0.0 || >=20.0.0} + peerDependencies: + nanostores: ^0.9.0 || ^0.10.0 react: '>=18.0.0' dependencies: nanostores: 0.9.5 @@ -3889,7 +4047,7 @@ packages: engines: {node: '>= 8'} dependencies: '@nodelib/fs.scandir': 2.1.5 - fastq: 1.16.0 + fastq: 1.17.1 dev: true /@pkgjs/parseargs@0.11.0: @@ -3906,16 +4064,16 @@ packages: /@radix-ui/number@1.0.1: resolution: {integrity: sha512-T5gIdVO2mmPW3NNhjNgEP3cqMXjXL9UbO0BzWcXfvdBs+BohbQxvd/K5hSVKmn9/lbTdsQVKbUcP5WLCwvUbBg==} dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 dev: true /@radix-ui/primitive@1.0.1: resolution: {integrity: sha512-yQ8oGX2GVsEYMWGxcovu1uGWPCxV5BFfeeYxqPmuAzUyLT9qmaMXSAhXpb0WrspIeqYzdJpkh2vHModJPgRIaw==} dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 dev: true - /@radix-ui/react-arrow@1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + /@radix-ui/react-arrow@1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-wSP+pHsB/jQRaL6voubsQ/ZlrGBHHrOjmBnr19hxYgtS0WvAFwZhK2WP/YY5yF9uKECCEEDGxuLxq1NBK51wFA==} peerDependencies: '@types/react': '*' @@ -3928,15 +4086,15 @@ packages: 
'@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@babel/runtime': 7.23.9 + '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@radix-ui/react-collection@1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + /@radix-ui/react-collection@1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-3SzW+0PW7yBBoQlT8wNcGtaxaD0XSu0uLUFgrtHY08Acx05TaHaOmVLR73c0j/cqpDy53KBMO7s0dx2wmOIDIA==} peerDependencies: '@types/react': '*' @@ -3949,18 +4107,18 @@ packages: '@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@radix-ui/react-compose-refs': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-context': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-slot': 1.0.2(@types/react@18.2.48)(react@18.2.0) - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@babel/runtime': 7.23.9 + '@radix-ui/react-compose-refs': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-context': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-slot': 1.0.2(@types/react@18.2.57)(react@18.2.0) + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@radix-ui/react-compose-refs@1.0.1(@types/react@18.2.48)(react@18.2.0): + /@radix-ui/react-compose-refs@1.0.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-fDSBgd44FKHa1FRMU59qBMPFcl2PZE+2nmqunj+BWFyYYjnhIDWL2ItDs3rrbJDQOtzt5nIebLCQc4QRfz6LJw==} peerDependencies: '@types/react': '*' @@ -3969,12 +4127,12 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@types/react': 18.2.48 + '@babel/runtime': 7.23.9 + '@types/react': 18.2.57 react: 18.2.0 dev: true - /@radix-ui/react-context@1.0.1(@types/react@18.2.48)(react@18.2.0): + /@radix-ui/react-context@1.0.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-ebbrdFoYTcuZ0v4wG5tedGnp9tzcV8awzsxYph7gXUyvnNLuTIcCk1q17JEbnVhXAKG9oX3KtchwiMIAYp9NLg==} peerDependencies: '@types/react': '*' @@ -3983,12 +4141,12 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@types/react': 18.2.48 + '@babel/runtime': 7.23.9 + '@types/react': 18.2.57 react: 18.2.0 dev: true - /@radix-ui/react-direction@1.0.1(@types/react@18.2.48)(react@18.2.0): + /@radix-ui/react-direction@1.0.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-RXcvnXgyvYvBEOhCBuddKecVkoMiI10Jcm5cTI7abJRAHYfFxeu+FBQs/DvdxSYucxR5mna0dNsL6QFlds5TMA==} peerDependencies: '@types/react': '*' @@ -3997,12 +4155,12 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@types/react': 18.2.48 + '@babel/runtime': 7.23.9 + '@types/react': 18.2.57 react: 18.2.0 dev: true - /@radix-ui/react-dismissable-layer@1.0.4(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + 
/@radix-ui/react-dismissable-layer@1.0.4(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-7UpBa/RKMoHJYjie1gkF1DlK8l1fdU/VKDpoS3rCCo8YBJR294GwcEHyxHw72yvphJ7ld0AXEcSLAzY2F/WyCg==} peerDependencies: '@types/react': '*' @@ -4015,19 +4173,19 @@ packages: '@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 '@radix-ui/primitive': 1.0.1 - '@radix-ui/react-compose-refs': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-use-escape-keydown': 1.0.3(@types/react@18.2.48)(react@18.2.0) - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@radix-ui/react-compose-refs': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-use-escape-keydown': 1.0.3(@types/react@18.2.57)(react@18.2.0) + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@radix-ui/react-focus-guards@1.0.1(@types/react@18.2.48)(react@18.2.0): + /@radix-ui/react-focus-guards@1.0.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-Rect2dWbQ8waGzhMavsIbmSVCgYxkXLxxR3ZvCX79JOglzdEy4JXMb98lq4hPxUbLr77nP0UOGf4rcMU+s1pUA==} peerDependencies: '@types/react': '*' @@ -4036,12 +4194,12 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@types/react': 18.2.48 + '@babel/runtime': 7.23.9 + '@types/react': 18.2.57 react: 18.2.0 dev: true - /@radix-ui/react-focus-scope@1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + /@radix-ui/react-focus-scope@1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-upXdPfqI4islj2CslyfUBNlaJCPybbqRHAi1KER7Isel9Q2AtSJ0zRBZv8mWQiFXD2nyAJ4BhC3yXgZ6kMBSrQ==} peerDependencies: '@types/react': '*' @@ -4054,17 +4212,17 @@ packages: '@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@radix-ui/react-compose-refs': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@babel/runtime': 7.23.9 + '@radix-ui/react-compose-refs': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@radix-ui/react-id@1.0.1(@types/react@18.2.48)(react@18.2.0): + /@radix-ui/react-id@1.0.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-tI7sT/kqYp8p96yGWY1OAnLHrqDgzHefRBKQ2YAkBS5ja7QLcZ9Z/uY7bEjPUatf8RomoXM8/1sMj1IJaE5UzQ==} peerDependencies: '@types/react': '*' @@ -4073,13 +4231,13 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@radix-ui/react-use-layout-effect': 
1.0.1(@types/react@18.2.48)(react@18.2.0) - '@types/react': 18.2.48 + '@babel/runtime': 7.23.9 + '@radix-ui/react-use-layout-effect': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@types/react': 18.2.57 react: 18.2.0 dev: true - /@radix-ui/react-popper@1.1.2(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + /@radix-ui/react-popper@1.1.2(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-1CnGGfFi/bbqtJZZ0P/NQY20xdG3E0LALJaLUEoKwPLwl6PPPfbeiCqMVQnhoFRAxjJj4RpBRJzDmUgsex2tSg==} peerDependencies: '@types/react': '*' @@ -4092,24 +4250,24 @@ packages: '@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@floating-ui/react-dom': 2.0.6(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-arrow': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-compose-refs': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-context': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-use-layout-effect': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-use-rect': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-use-size': 1.0.1(@types/react@18.2.48)(react@18.2.0) + '@babel/runtime': 7.23.9 + '@floating-ui/react-dom': 2.0.8(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-arrow': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-compose-refs': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-context': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-use-layout-effect': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-use-rect': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-use-size': 1.0.1(@types/react@18.2.57)(react@18.2.0) '@radix-ui/rect': 1.0.1 - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@radix-ui/react-portal@1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + /@radix-ui/react-portal@1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-xLYZeHrWoPmA5mEKEfZZevoVRK/Q43GfzRXkWV6qawIWWK8t6ifIiLQdd7rmQ4Vk1bmI21XhqF9BN3jWf+phpA==} peerDependencies: '@types/react': '*' @@ -4122,15 +4280,15 @@ packages: '@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@babel/runtime': 7.23.9 + '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@radix-ui/react-primitive@1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + 
/@radix-ui/react-primitive@1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-yi58uVyoAcK/Nq1inRY56ZSjKypBNKTa/1mcL8qdl6oJeEaDbOldlzrGn7P6Q3Id5d+SYNGc5AJgc4vGhjs5+g==} peerDependencies: '@types/react': '*' @@ -4143,15 +4301,15 @@ packages: '@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@radix-ui/react-slot': 1.0.2(@types/react@18.2.48)(react@18.2.0) - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@babel/runtime': 7.23.9 + '@radix-ui/react-slot': 1.0.2(@types/react@18.2.57)(react@18.2.0) + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@radix-ui/react-roving-focus@1.0.4(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + /@radix-ui/react-roving-focus@1.0.4(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-2mUg5Mgcu001VkGy+FfzZyzbmuUWzgWkj3rvv4yu+mLw03+mTzbxZHvfcGyFp2b8EkQeMkpRQ5FiA2Vr2O6TeQ==} peerDependencies: '@types/react': '*' @@ -4164,23 +4322,23 @@ packages: '@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 '@radix-ui/primitive': 1.0.1 - '@radix-ui/react-collection': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-compose-refs': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-context': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-direction': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-id': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-use-controllable-state': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@radix-ui/react-collection': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-compose-refs': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-context': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-direction': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-id': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-use-controllable-state': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@radix-ui/react-select@1.2.2(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + /@radix-ui/react-select@1.2.2(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-zI7McXr8fNaSrUY9mZe4x/HC0jTLY9fWNhO1oLWYMQGDXuV4UCivIGTxwioSzO0ZCYX9iSLyWmAh/1TOmX3Cnw==} peerDependencies: '@types/react': '*' @@ -4193,35 +4351,35 @@ packages: '@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 '@radix-ui/number': 1.0.1 '@radix-ui/primitive': 1.0.1 - '@radix-ui/react-collection': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-compose-refs': 
1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-context': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-direction': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-dismissable-layer': 1.0.4(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-focus-guards': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-focus-scope': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-id': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-popper': 1.1.2(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-portal': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-slot': 1.0.2(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-use-controllable-state': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-use-layout-effect': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-use-previous': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-visually-hidden': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@radix-ui/react-collection': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-compose-refs': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-context': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-direction': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-dismissable-layer': 1.0.4(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-focus-guards': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-focus-scope': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-id': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-popper': 1.1.2(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-portal': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-slot': 1.0.2(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-use-controllable-state': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-use-layout-effect': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-use-previous': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-visually-hidden': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 aria-hidden: 1.2.3 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) - react-remove-scroll: 2.5.5(@types/react@18.2.48)(react@18.2.0) + react-remove-scroll: 2.5.5(@types/react@18.2.57)(react@18.2.0) dev: true - /@radix-ui/react-separator@1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + 
/@radix-ui/react-separator@1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-itYmTy/kokS21aiV5+Z56MZB54KrhPgn6eHDKkFeOLR34HMN2s8PaN47qZZAGnvupcjxHaFZnW4pQEh0BvvVuw==} peerDependencies: '@types/react': '*' @@ -4234,15 +4392,15 @@ packages: '@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@babel/runtime': 7.23.9 + '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@radix-ui/react-slot@1.0.2(@types/react@18.2.48)(react@18.2.0): + /@radix-ui/react-slot@1.0.2(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-YeTpuq4deV+6DusvVUW4ivBgnkHwECUu0BiN43L5UCDFgdhsRUWAghhTF5MbvNTPzmiFOx90asDSUjWuCNapwg==} peerDependencies: '@types/react': '*' @@ -4251,13 +4409,13 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@radix-ui/react-compose-refs': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@types/react': 18.2.48 + '@babel/runtime': 7.23.9 + '@radix-ui/react-compose-refs': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@types/react': 18.2.57 react: 18.2.0 dev: true - /@radix-ui/react-toggle-group@1.0.4(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + /@radix-ui/react-toggle-group@1.0.4(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-Uaj/M/cMyiyT9Bx6fOZO0SAG4Cls0GptBWiBmBxofmDbNVnYYoyRWj/2M/6VCi/7qcXFWnHhRUfdfZFvvkuu8A==} peerDependencies: '@types/react': '*' @@ -4270,21 +4428,21 @@ packages: '@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 '@radix-ui/primitive': 1.0.1 - '@radix-ui/react-context': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-direction': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-roving-focus': 1.0.4(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-toggle': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-use-controllable-state': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@radix-ui/react-context': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-direction': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-roving-focus': 1.0.4(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-toggle': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-use-controllable-state': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@radix-ui/react-toggle@1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + 
/@radix-ui/react-toggle@1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-Pkqg3+Bc98ftZGsl60CLANXQBBQ4W3mTFS9EJvNxKMZ7magklKV69/id1mlAlOFDDfHvlCms0fx8fA4CMKDJHg==} peerDependencies: '@types/react': '*' @@ -4297,17 +4455,17 @@ packages: '@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 '@radix-ui/primitive': 1.0.1 - '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-use-controllable-state': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-use-controllable-state': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@radix-ui/react-toolbar@1.0.4(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + /@radix-ui/react-toolbar@1.0.4(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-tBgmM/O7a07xbaEkYJWYTXkIdU/1pW4/KZORR43toC/4XWyBCURK0ei9kMUdp+gTPPKBgYLxXmRSH1EVcIDp8Q==} peerDependencies: '@types/react': '*' @@ -4320,21 +4478,21 @@ packages: '@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 '@radix-ui/primitive': 1.0.1 - '@radix-ui/react-context': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-direction': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-roving-focus': 1.0.4(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-separator': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-toggle-group': 1.0.4(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@radix-ui/react-context': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-direction': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-roving-focus': 1.0.4(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-separator': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-toggle-group': 1.0.4(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@radix-ui/react-use-callback-ref@1.0.1(@types/react@18.2.48)(react@18.2.0): + /@radix-ui/react-use-callback-ref@1.0.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-D94LjX4Sp0xJFVaoQOd3OO9k7tpBYNOXdVhkltUbGv2Qb9OXdrg/CpsjlZv7ia14Sylv398LswWBVVu5nqKzAQ==} peerDependencies: '@types/react': '*' @@ -4343,12 +4501,12 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@types/react': 18.2.48 + '@babel/runtime': 7.23.9 + '@types/react': 18.2.57 react: 18.2.0 dev: true - 
/@radix-ui/react-use-controllable-state@1.0.1(@types/react@18.2.48)(react@18.2.0): + /@radix-ui/react-use-controllable-state@1.0.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-Svl5GY5FQeN758fWKrjM6Qb7asvXeiZltlT4U2gVfl8Gx5UAv2sMR0LWo8yhsIZh2oQ0eFdZ59aoOOMV7b47VA==} peerDependencies: '@types/react': '*' @@ -4357,13 +4515,13 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@types/react': 18.2.48 + '@babel/runtime': 7.23.9 + '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@types/react': 18.2.57 react: 18.2.0 dev: true - /@radix-ui/react-use-escape-keydown@1.0.3(@types/react@18.2.48)(react@18.2.0): + /@radix-ui/react-use-escape-keydown@1.0.3(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-vyL82j40hcFicA+M4Ex7hVkB9vHgSse1ZWomAqV2Je3RleKGO5iM8KMOEtfoSB0PnIelMd2lATjTGMYqN5ylTg==} peerDependencies: '@types/react': '*' @@ -4372,13 +4530,13 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@types/react': 18.2.48 + '@babel/runtime': 7.23.9 + '@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.2.57)(react@18.2.0) + '@types/react': 18.2.57 react: 18.2.0 dev: true - /@radix-ui/react-use-layout-effect@1.0.1(@types/react@18.2.48)(react@18.2.0): + /@radix-ui/react-use-layout-effect@1.0.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-v/5RegiJWYdoCvMnITBkNNx6bCj20fiaJnWtRkU18yITptraXjffz5Qbn05uOiQnOvi+dbkznkoaMltz1GnszQ==} peerDependencies: '@types/react': '*' @@ -4387,12 +4545,12 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@types/react': 18.2.48 + '@babel/runtime': 7.23.9 + '@types/react': 18.2.57 react: 18.2.0 dev: true - /@radix-ui/react-use-previous@1.0.1(@types/react@18.2.48)(react@18.2.0): + /@radix-ui/react-use-previous@1.0.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-cV5La9DPwiQ7S0gf/0qiD6YgNqM5Fk97Kdrlc5yBcrF3jyEZQwm7vYFqMo4IfeHgJXsRaMvLABFtd0OVEmZhDw==} peerDependencies: '@types/react': '*' @@ -4401,12 +4559,12 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@types/react': 18.2.48 + '@babel/runtime': 7.23.9 + '@types/react': 18.2.57 react: 18.2.0 dev: true - /@radix-ui/react-use-rect@1.0.1(@types/react@18.2.48)(react@18.2.0): + /@radix-ui/react-use-rect@1.0.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-Cq5DLuSiuYVKNU8orzJMbl15TXilTnJKUCltMVQg53BQOF1/C5toAaGrowkgksdBQ9H+SRL23g0HDmg9tvmxXw==} peerDependencies: '@types/react': '*' @@ -4415,13 +4573,13 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 '@radix-ui/rect': 1.0.1 - '@types/react': 18.2.48 + '@types/react': 18.2.57 react: 18.2.0 dev: true - /@radix-ui/react-use-size@1.0.1(@types/react@18.2.48)(react@18.2.0): + /@radix-ui/react-use-size@1.0.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-ibay+VqrgcaI6veAojjofPATwledXiSmX+C0KrBk/xgpX9rBzPV3OsfwlhQdUOFbh+LKQorLYT+xTXW9V8yd0g==} peerDependencies: '@types/react': '*' @@ -4430,13 +4588,13 @@ packages: '@types/react': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@radix-ui/react-use-layout-effect': 1.0.1(@types/react@18.2.48)(react@18.2.0) - '@types/react': 18.2.48 + '@babel/runtime': 7.23.9 + '@radix-ui/react-use-layout-effect': 
1.0.1(@types/react@18.2.57)(react@18.2.0) + '@types/react': 18.2.57 react: 18.2.0 dev: true - /@radix-ui/react-visually-hidden@1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + /@radix-ui/react-visually-hidden@1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-D4w41yN5YRKtu464TLnByKzMDG/JlMPHtfZgQAu9v6mNakUqGUI9vUrfQKz8NK41VMm/xbZbh76NUTVtIYqOMA==} peerDependencies: '@types/react': '*' @@ -4449,10 +4607,10 @@ packages: '@types/react-dom': optional: true dependencies: - '@babel/runtime': 7.23.8 - '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@types/react': 18.2.48 - '@types/react-dom': 18.2.18 + '@babel/runtime': 7.23.9 + '@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@types/react': 18.2.57 + '@types/react-dom': 18.2.19 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true @@ -4460,43 +4618,43 @@ packages: /@radix-ui/rect@1.0.1: resolution: {integrity: sha512-fyrgCaedtvMg9NK3en0pnOYJdtfwxUcNolezkNPUsoX57X8oQk+NkqcvzHXD2uKNij6GXmWU9NDru2IWjrO4BQ==} dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 dev: true - /@reactflow/background@11.3.7(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-PhkvoFtO/NXJgFtBvfbPwdR/6/dl25egQlFhKWS3T4aYa7rh80dvf6dF3t6+JXJS4q5ToYJizD2/n8/qylo1yQ==} + /@reactflow/background@11.3.9(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-byj/G9pEC8tN0wT/ptcl/LkEP/BBfa33/SvBkqE4XwyofckqF87lKp573qGlisfnsijwAbpDlf81PuFL41So4Q==} peerDependencies: react: '>=17' react-dom: '>=17' dependencies: - '@reactflow/core': 11.10.2(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + '@reactflow/core': 11.10.4(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) classcat: 5.0.4 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) - zustand: 4.4.7(@types/react@18.2.48)(react@18.2.0) + zustand: 4.5.1(@types/react@18.2.57)(react@18.2.0) transitivePeerDependencies: - '@types/react' - immer dev: false - /@reactflow/controls@11.2.7(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-mugzVALH/SuKlVKk+JCRm1OXQ+p8e9+k8PCTIaqL+nBl+lPF8KA4uMm8ApsOvhuSAb2A80ezewpyvYHr0qSYVA==} + /@reactflow/controls@11.2.9(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-e8nWplbYfOn83KN1BrxTXS17+enLyFnjZPbyDgHSRLtI5ZGPKF/8iRXV+VXb2LFVzlu4Wh3la/pkxtfP/0aguA==} peerDependencies: react: '>=17' react-dom: '>=17' dependencies: - '@reactflow/core': 11.10.2(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + '@reactflow/core': 11.10.4(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) classcat: 5.0.4 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) - zustand: 4.4.7(@types/react@18.2.48)(react@18.2.0) + zustand: 4.5.1(@types/react@18.2.57)(react@18.2.0) transitivePeerDependencies: - '@types/react' - immer dev: false - /@reactflow/core@11.10.2(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-/cbTxtFpfkIGReSVkcnQhS4Jx4VFY2AhPlJ5n0sbPtnR7OWowF9zodh5Yyzr4j1NOUoBgJ9h+UqGEwwY2dbAlw==} + /@reactflow/core@11.10.4(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-j3i9b2fsTX/sBbOm+RmNzYEFWbNx4jGWGuGooh2r1jQaE2eV+TLJgiG/VNOp0q5mBl9f6g1IXs3Gm86S9JfcGw==} peerDependencies: react: '>=17' react-dom: '>=17' @@ -4511,19 +4669,19 @@ packages: d3-zoom: 
3.0.0 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) - zustand: 4.4.7(@types/react@18.2.48)(react@18.2.0) + zustand: 4.5.1(@types/react@18.2.57)(react@18.2.0) transitivePeerDependencies: - '@types/react' - immer dev: false - /@reactflow/minimap@11.7.7(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-Pwqw31tJ663cJur6ypqyJU33nPckvTepmz96erdQZoHsfOyLmFj4nXT7afC30DJ48lp0nfNsw+028mlf7f/h4g==} + /@reactflow/minimap@11.7.9(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-le95jyTtt3TEtJ1qa7tZ5hyM4S7gaEQkW43cixcMOZLu33VAdc2aCpJg/fXcRrrf7moN2Mbl9WIMNXUKsp5ILA==} peerDependencies: react: '>=17' react-dom: '>=17' dependencies: - '@reactflow/core': 11.10.2(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + '@reactflow/core': 11.10.4(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) '@types/d3-selection': 3.0.10 '@types/d3-zoom': 3.0.8 classcat: 5.0.4 @@ -4531,48 +4689,48 @@ packages: d3-zoom: 3.0.0 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) - zustand: 4.4.7(@types/react@18.2.48)(react@18.2.0) + zustand: 4.5.1(@types/react@18.2.57)(react@18.2.0) transitivePeerDependencies: - '@types/react' - immer dev: false - /@reactflow/node-resizer@2.2.7(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-BMBstmWNiklHnnAjHu8irkiPQ8/k8nnjzqlTql4acbVhD6Tsdxx/t/saOkELmfQODqGZNiPw9+pHcAHgtE6oNQ==} + /@reactflow/node-resizer@2.2.9(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-HfickMm0hPDIHt9qH997nLdgLt0kayQyslKE0RS/GZvZ4UMQJlx/NRRyj5y47Qyg0NnC66KYOQWDM9LLzRTnUg==} peerDependencies: react: '>=17' react-dom: '>=17' dependencies: - '@reactflow/core': 11.10.2(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + '@reactflow/core': 11.10.4(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) classcat: 5.0.4 d3-drag: 3.0.0 d3-selection: 3.0.0 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) - zustand: 4.4.7(@types/react@18.2.48)(react@18.2.0) + zustand: 4.5.1(@types/react@18.2.57)(react@18.2.0) transitivePeerDependencies: - '@types/react' - immer dev: false - /@reactflow/node-toolbar@1.3.7(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-75moEQKg23YKA3A2DNSFhq719ZPmby5mpwOD+NO7ZffJ88oMS/2eY8l8qpA3hvb1PTBHDxyKazhJirW+f4t0Wg==} + /@reactflow/node-toolbar@1.3.9(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-VmgxKmToax4sX1biZ9LXA7cj/TBJ+E5cklLGwquCCVVxh+lxpZGTBF3a5FJGVHiUNBBtFsC8ldcSZIK4cAlQww==} peerDependencies: react: '>=17' react-dom: '>=17' dependencies: - '@reactflow/core': 11.10.2(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + '@reactflow/core': 11.10.4(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) classcat: 5.0.4 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) - zustand: 4.4.7(@types/react@18.2.48)(react@18.2.0) + zustand: 4.5.1(@types/react@18.2.57)(react@18.2.0) transitivePeerDependencies: - '@types/react' - immer dev: false - /@reduxjs/toolkit@2.0.1(react-redux@9.1.0)(react@18.2.0): - resolution: {integrity: sha512-fxIjrR9934cmS8YXIGd9e7s1XRsEU++aFc9DVNMFMRTM5Vtsg2DCRMj21eslGtDt43IUf9bJL3h5bwUlZleibA==} + /@reduxjs/toolkit@2.2.1(react-redux@9.1.0)(react@18.2.0): + resolution: {integrity: sha512-8CREoqJovQW/5I4yvvijm/emUiCCmcs4Ev4XPWd4mizSO+dD3g5G6w34QK5AGeNrSH7qM8Fl66j4vuV7dpOdkw==} peerDependencies: react: ^16.9.0 || ^17.0.0 || ^18 react-redux: ^7.2.1 || ^8.1.3 || ^9.0.0 @@ -4584,10 +4742,10 @@ packages: dependencies: immer: 10.0.3 react: 18.2.0 - react-redux: 
9.1.0(@types/react@18.2.48)(react@18.2.0)(redux@5.0.1) + react-redux: 9.1.0(@types/react@18.2.57)(react@18.2.0)(redux@5.0.1) redux: 5.0.1 redux-thunk: 3.1.0(redux@5.0.1) - reselect: 5.0.1(patch_hash=kvbgwzjyy4x4fnh7znyocvb75q) + reselect: 5.1.0 dev: false /@roarr/browser-log-writer@1.3.0: @@ -4621,111 +4779,111 @@ packages: picomatch: 2.3.1 dev: true - /@rollup/rollup-android-arm-eabi@4.9.4: - resolution: {integrity: sha512-ub/SN3yWqIv5CWiAZPHVS1DloyZsJbtXmX4HxUTIpS0BHm9pW5iYBo2mIZi+hE3AeiTzHz33blwSnhdUo+9NpA==} + /@rollup/rollup-android-arm-eabi@4.12.0: + resolution: {integrity: sha512-+ac02NL/2TCKRrJu2wffk1kZ+RyqxVUlbjSagNgPm94frxtr+XDL12E5Ll1enWskLrtrZ2r8L3wED1orIibV/w==} cpu: [arm] os: [android] requiresBuild: true dev: true optional: true - /@rollup/rollup-android-arm64@4.9.4: - resolution: {integrity: sha512-ehcBrOR5XTl0W0t2WxfTyHCR/3Cq2jfb+I4W+Ch8Y9b5G+vbAecVv0Fx/J1QKktOrgUYsIKxWAKgIpvw56IFNA==} + /@rollup/rollup-android-arm64@4.12.0: + resolution: {integrity: sha512-OBqcX2BMe6nvjQ0Nyp7cC90cnumt8PXmO7Dp3gfAju/6YwG0Tj74z1vKrfRz7qAv23nBcYM8BCbhrsWqO7PzQQ==} cpu: [arm64] os: [android] requiresBuild: true dev: true optional: true - /@rollup/rollup-darwin-arm64@4.9.4: - resolution: {integrity: sha512-1fzh1lWExwSTWy8vJPnNbNM02WZDS8AW3McEOb7wW+nPChLKf3WG2aG7fhaUmfX5FKw9zhsF5+MBwArGyNM7NA==} + /@rollup/rollup-darwin-arm64@4.12.0: + resolution: {integrity: sha512-X64tZd8dRE/QTrBIEs63kaOBG0b5GVEd3ccoLtyf6IdXtHdh8h+I56C2yC3PtC9Ucnv0CpNFJLqKFVgCYe0lOQ==} cpu: [arm64] os: [darwin] requiresBuild: true dev: true optional: true - /@rollup/rollup-darwin-x64@4.9.4: - resolution: {integrity: sha512-Gc6cukkF38RcYQ6uPdiXi70JB0f29CwcQ7+r4QpfNpQFVHXRd0DfWFidoGxjSx1DwOETM97JPz1RXL5ISSB0pA==} + /@rollup/rollup-darwin-x64@4.12.0: + resolution: {integrity: sha512-cc71KUZoVbUJmGP2cOuiZ9HSOP14AzBAThn3OU+9LcA1+IUqswJyR1cAJj3Mg55HbjZP6OLAIscbQsQLrpgTOg==} cpu: [x64] os: [darwin] requiresBuild: true dev: true optional: true - /@rollup/rollup-linux-arm-gnueabihf@4.9.4: - resolution: {integrity: sha512-g21RTeFzoTl8GxosHbnQZ0/JkuFIB13C3T7Y0HtKzOXmoHhewLbVTFBQZu+z5m9STH6FZ7L/oPgU4Nm5ErN2fw==} + /@rollup/rollup-linux-arm-gnueabihf@4.12.0: + resolution: {integrity: sha512-a6w/Y3hyyO6GlpKL2xJ4IOh/7d+APaqLYdMf86xnczU3nurFTaVN9s9jOXQg97BE4nYm/7Ga51rjec5nfRdrvA==} cpu: [arm] os: [linux] requiresBuild: true dev: true optional: true - /@rollup/rollup-linux-arm64-gnu@4.9.4: - resolution: {integrity: sha512-TVYVWD/SYwWzGGnbfTkrNpdE4HON46orgMNHCivlXmlsSGQOx/OHHYiQcMIOx38/GWgwr/po2LBn7wypkWw/Mg==} + /@rollup/rollup-linux-arm64-gnu@4.12.0: + resolution: {integrity: sha512-0fZBq27b+D7Ar5CQMofVN8sggOVhEtzFUwOwPppQt0k+VR+7UHMZZY4y+64WJ06XOhBTKXtQB/Sv0NwQMXyNAA==} cpu: [arm64] os: [linux] requiresBuild: true dev: true optional: true - /@rollup/rollup-linux-arm64-musl@4.9.4: - resolution: {integrity: sha512-XcKvuendwizYYhFxpvQ3xVpzje2HHImzg33wL9zvxtj77HvPStbSGI9czrdbfrf8DGMcNNReH9pVZv8qejAQ5A==} + /@rollup/rollup-linux-arm64-musl@4.12.0: + resolution: {integrity: sha512-eTvzUS3hhhlgeAv6bfigekzWZjaEX9xP9HhxB0Dvrdbkk5w/b+1Sxct2ZuDxNJKzsRStSq1EaEkVSEe7A7ipgQ==} cpu: [arm64] os: [linux] requiresBuild: true dev: true optional: true - /@rollup/rollup-linux-riscv64-gnu@4.9.4: - resolution: {integrity: sha512-LFHS/8Q+I9YA0yVETyjonMJ3UA+DczeBd/MqNEzsGSTdNvSJa1OJZcSH8GiXLvcizgp9AlHs2walqRcqzjOi3A==} + /@rollup/rollup-linux-riscv64-gnu@4.12.0: + resolution: {integrity: sha512-ix+qAB9qmrCRiaO71VFfY8rkiAZJL8zQRXveS27HS+pKdjwUfEhqo2+YF2oI+H/22Xsiski+qqwIBxVewLK7sw==} cpu: [riscv64] os: [linux] requiresBuild: true dev: true optional: true - 
/@rollup/rollup-linux-x64-gnu@4.9.4: - resolution: {integrity: sha512-dIYgo+j1+yfy81i0YVU5KnQrIJZE8ERomx17ReU4GREjGtDW4X+nvkBak2xAUpyqLs4eleDSj3RrV72fQos7zw==} + /@rollup/rollup-linux-x64-gnu@4.12.0: + resolution: {integrity: sha512-TenQhZVOtw/3qKOPa7d+QgkeM6xY0LtwzR8OplmyL5LrgTWIXpTQg2Q2ycBf8jm+SFW2Wt/DTn1gf7nFp3ssVA==} cpu: [x64] os: [linux] requiresBuild: true dev: true optional: true - /@rollup/rollup-linux-x64-musl@4.9.4: - resolution: {integrity: sha512-RoaYxjdHQ5TPjaPrLsfKqR3pakMr3JGqZ+jZM0zP2IkDtsGa4CqYaWSfQmZVgFUCgLrTnzX+cnHS3nfl+kB6ZQ==} + /@rollup/rollup-linux-x64-musl@4.12.0: + resolution: {integrity: sha512-LfFdRhNnW0zdMvdCb5FNuWlls2WbbSridJvxOvYWgSBOYZtgBfW9UGNJG//rwMqTX1xQE9BAodvMH9tAusKDUw==} cpu: [x64] os: [linux] requiresBuild: true dev: true optional: true - /@rollup/rollup-win32-arm64-msvc@4.9.4: - resolution: {integrity: sha512-T8Q3XHV+Jjf5e49B4EAaLKV74BbX7/qYBRQ8Wop/+TyyU0k+vSjiLVSHNWdVd1goMjZcbhDmYZUYW5RFqkBNHQ==} + /@rollup/rollup-win32-arm64-msvc@4.12.0: + resolution: {integrity: sha512-JPDxovheWNp6d7AHCgsUlkuCKvtu3RB55iNEkaQcf0ttsDU/JZF+iQnYcQJSk/7PtT4mjjVG8N1kpwnI9SLYaw==} cpu: [arm64] os: [win32] requiresBuild: true dev: true optional: true - /@rollup/rollup-win32-ia32-msvc@4.9.4: - resolution: {integrity: sha512-z+JQ7JirDUHAsMecVydnBPWLwJjbppU+7LZjffGf+Jvrxq+dVjIE7By163Sc9DKc3ADSU50qPVw0KonBS+a+HQ==} + /@rollup/rollup-win32-ia32-msvc@4.12.0: + resolution: {integrity: sha512-fjtuvMWRGJn1oZacG8IPnzIV6GF2/XG+h71FKn76OYFqySXInJtseAqdprVTDTyqPxQOG9Exak5/E9Z3+EJ8ZA==} cpu: [ia32] os: [win32] requiresBuild: true dev: true optional: true - /@rollup/rollup-win32-x64-msvc@4.9.4: - resolution: {integrity: sha512-LfdGXCV9rdEify1oxlN9eamvDSjv9md9ZVMAbNHA87xqIfFCxImxan9qZ8+Un54iK2nnqPlbnSi4R54ONtbWBw==} + /@rollup/rollup-win32-x64-msvc@4.12.0: + resolution: {integrity: sha512-ZYmr5mS2wd4Dew/JjT0Fqi2NPB/ZhZ2VvPp7SmvPZb4Y1CG/LRcS6tcRo2cYU7zLK5A7cdbhWnnWmUjoI4qapg==} cpu: [x64] os: [win32] requiresBuild: true dev: true optional: true - /@rushstack/node-core-library@3.62.0(@types/node@20.11.5): + /@rushstack/node-core-library@3.62.0(@types/node@20.11.19): resolution: {integrity: sha512-88aJn2h8UpSvdwuDXBv1/v1heM6GnBf3RjEy6ZPP7UnzHNCqOHA2Ut+ScYUbXcqIdfew9JlTAe3g+cnX9xQ/Aw==} peerDependencies: '@types/node': '*' @@ -4733,7 +4891,7 @@ packages: '@types/node': optional: true dependencies: - '@types/node': 20.11.5 + '@types/node': 20.11.19 colors: 1.2.5 fs-extra: 7.0.1 import-lazy: 4.0.0 @@ -4767,29 +4925,29 @@ packages: resolution: {integrity: sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==} dev: false - /@storybook/addon-actions@7.6.10: - resolution: {integrity: sha512-pcKmf0H/caGzKDy8cz1adNSjv+KOBWLJ11RzGExrWm+Ad5ACifwlsQPykJ3TQ/21sTd9IXVrE9uuq4LldEnPbg==} + /@storybook/addon-actions@7.6.17: + resolution: {integrity: sha512-TBphs4v6LRfyTpFo/WINF0TkMaE3rrNog7wW5mbz6n0j8o53kDN4o9ZEcygSL5zQX43CAaghQTeDCss7ueG7ZQ==} dependencies: - '@storybook/core-events': 7.6.10 + '@storybook/core-events': 7.6.17 '@storybook/global': 5.0.0 - '@types/uuid': 9.0.7 + '@types/uuid': 9.0.8 dequal: 2.0.3 - polished: 4.2.2 + polished: 4.3.1 uuid: 9.0.1 dev: true - /@storybook/addon-backgrounds@7.6.10: - resolution: {integrity: sha512-kGzsN1QkfyI8Cz7TErEx9OCB3PMzpCFGLd/iy7FreXwbMbeAQ3/9fYgKUsNOYgOhuTz7S09koZUWjS/WJuZGFA==} + /@storybook/addon-backgrounds@7.6.17: + resolution: {integrity: sha512-7dize7x8+37PH77kmt69b0xSaeDqOcZ4fpzW6+hk53hIaCVU26eGs4+j+743Xva31eOgZWNLupUhOpUDc6SqZw==} dependencies: '@storybook/global': 5.0.0 memoizerific: 1.11.3 ts-dedent: 2.2.0 
dev: true - /@storybook/addon-controls@7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-LjwCQRMWq1apLtFwDi6U8MI6ITUr+KhxJucZ60tfc58RgB2v8ayozyDAonFEONsx9YSR1dNIJ2Z/e2rWTBJeYA==} + /@storybook/addon-controls@7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-zR0aLaUF7FtV/nMRyfniFbCls/e0DAAoXACuOAUAwNAv0lbIS8AyZZiHSmKucCvziUQ6WceeCC7+du3C+9y0rQ==} dependencies: - '@storybook/blocks': 7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + '@storybook/blocks': 7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) lodash: 4.17.21 ts-dedent: 2.2.0 transitivePeerDependencies: @@ -4801,27 +4959,27 @@ packages: - supports-color dev: true - /@storybook/addon-docs@7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-GtyQ9bMx1AOOtl6ZS9vwK104HFRK+tqzxddRRxhXkpyeKu3olm9aMgXp35atE/3fJSqyyDm2vFtxxH8mzBA20A==} + /@storybook/addon-docs@7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-FKa4Mdy7nhgvEVZJHpMkHriDzpVHbohn87zv9NCL+Ctjs1iAmzGwxEm0culszyDS1HN2ToVoY0h8CSi2RSSZqA==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 dependencies: '@jest/transform': 29.7.0 '@mdx-js/react': 2.3.0(react@18.2.0) - '@storybook/blocks': 7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@storybook/client-logger': 7.6.10 - '@storybook/components': 7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@storybook/csf-plugin': 7.6.10 - '@storybook/csf-tools': 7.6.10 + '@storybook/blocks': 7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@storybook/client-logger': 7.6.17 + '@storybook/components': 7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@storybook/csf-plugin': 7.6.17 + '@storybook/csf-tools': 7.6.17 '@storybook/global': 5.0.0 '@storybook/mdx2-csf': 1.1.0 - '@storybook/node-logger': 7.6.10 - '@storybook/postinstall': 7.6.10 - '@storybook/preview-api': 7.6.10 - '@storybook/react-dom-shim': 7.6.10(react-dom@18.2.0)(react@18.2.0) - '@storybook/theming': 7.6.10(react-dom@18.2.0)(react@18.2.0) - '@storybook/types': 7.6.10 + '@storybook/node-logger': 7.6.17 + '@storybook/postinstall': 7.6.17 + '@storybook/preview-api': 7.6.17 + '@storybook/react-dom-shim': 7.6.17(react-dom@18.2.0)(react@18.2.0) + '@storybook/theming': 7.6.17(react-dom@18.2.0)(react@18.2.0) + '@storybook/types': 7.6.17 fs-extra: 11.2.0 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) @@ -4835,25 +4993,25 @@ packages: - supports-color dev: true - /@storybook/addon-essentials@7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-cjbuCCK/3dtUity0Uqi5LwbkgfxqCCE5x5mXZIk9lTMeDz5vB9q6M5nzncVDy8F8przF3NbDLLgxKlt8wjiICg==} + /@storybook/addon-essentials@7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-qlSpamxuYfT2taF953nC9QijGF2pSbg1ewMNpdwLTj16PTZvR/d8NCDMTJujI1bDwM2m18u8Yc43ibh5LEmxCw==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 dependencies: - '@storybook/addon-actions': 7.6.10 - '@storybook/addon-backgrounds': 7.6.10 - 
'@storybook/addon-controls': 7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@storybook/addon-docs': 7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@storybook/addon-highlight': 7.6.10 - '@storybook/addon-measure': 7.6.10 - '@storybook/addon-outline': 7.6.10 - '@storybook/addon-toolbars': 7.6.10 - '@storybook/addon-viewport': 7.6.10 - '@storybook/core-common': 7.6.10 - '@storybook/manager-api': 7.6.10(react-dom@18.2.0)(react@18.2.0) - '@storybook/node-logger': 7.6.10 - '@storybook/preview-api': 7.6.10 + '@storybook/addon-actions': 7.6.17 + '@storybook/addon-backgrounds': 7.6.17 + '@storybook/addon-controls': 7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@storybook/addon-docs': 7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@storybook/addon-highlight': 7.6.17 + '@storybook/addon-measure': 7.6.17 + '@storybook/addon-outline': 7.6.17 + '@storybook/addon-toolbars': 7.6.17 + '@storybook/addon-viewport': 7.6.17 + '@storybook/core-common': 7.6.17 + '@storybook/manager-api': 7.6.17(react-dom@18.2.0)(react@18.2.0) + '@storybook/node-logger': 7.6.17 + '@storybook/preview-api': 7.6.17 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) ts-dedent: 2.2.0 @@ -4864,24 +5022,24 @@ packages: - supports-color dev: true - /@storybook/addon-highlight@7.6.10: - resolution: {integrity: sha512-dIuS5QmoT1R+gFOcf6CoBa6D9UR5/wHCfPqPRH8dNNcCLtIGSHWQ4v964mS5OCq1Huj7CghmR15lOUk7SaYwUA==} + /@storybook/addon-highlight@7.6.17: + resolution: {integrity: sha512-R1yBPUUqGn+60aJakn8q+5Zt34E/gU3n3VmgPdryP0LJUdZ5q1/RZShoVDV+yYQ40htMH6oaCv3OyyPzFAGJ6A==} dependencies: '@storybook/global': 5.0.0 dev: true - /@storybook/addon-interactions@7.6.10: - resolution: {integrity: sha512-lEsAdP/PrOZK/KmRbZ/fU4RjEqDP+e/PBlVVVJT2QvHniWK/xxkjCD0axsHU/XuaeQRFhmg0/KR342PC/cIf9A==} + /@storybook/addon-interactions@7.6.17: + resolution: {integrity: sha512-6zlX+RDQ1PlA6fp7C+hun8t7h2RXfCGs5dGrhEenp2lqnR/rYuUJRC0tmKpkZBb8kZVcbSChzkB/JYkBjBCzpQ==} dependencies: '@storybook/global': 5.0.0 - '@storybook/types': 7.6.10 + '@storybook/types': 7.6.17 jest-mock: 27.5.1 - polished: 4.2.2 + polished: 4.3.1 ts-dedent: 2.2.0 dev: true - /@storybook/addon-links@7.6.10(react@18.2.0): - resolution: {integrity: sha512-s/WkSYHpr2pb9p57j6u/xDBg3TKJhBq55YMl0GB5gXgkRPIeuGbPhGJhm2yTGVFLvXgr/aHHnOxb/R/W8PiRhA==} + /@storybook/addon-links@7.6.17(react@18.2.0): + resolution: {integrity: sha512-iFUwKObRn0EKI0zMETsil2p9a/81rCuSMEWECsi+khkCAs1FUnD2cT6Ag5ydcNcBXsdtdfDJdtXQrkw+TSoStQ==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 peerDependenciesMeta: @@ -4894,62 +5052,62 @@ packages: ts-dedent: 2.2.0 dev: true - /@storybook/addon-measure@7.6.10: - resolution: {integrity: sha512-OVfTI56+kc4hLWfZ/YPV3WKj/aA9e4iKXYxZyPdhfX4Z8TgZdD1wv9Z6e8DKS0H5kuybYrHKHaID5ki6t7qz3w==} + /@storybook/addon-measure@7.6.17: + resolution: {integrity: sha512-O5vnHZNkduvZ95jf1UssbOl6ivIxzl5tv+4EpScPYId7w700bxWsJH+QX7ip6KlrCf2o3iUhmPe8bm05ghG2KA==} dependencies: '@storybook/global': 5.0.0 tiny-invariant: 1.3.1 dev: true - /@storybook/addon-outline@7.6.10: - resolution: {integrity: sha512-RVJrEoPArhI6zAIMNl1Gz0zrj84BTfEWYYz0yDWOTVgvN411ugsoIk1hw0671MOneXJ2RcQ9MFIeV/v6AVDQYg==} + /@storybook/addon-outline@7.6.17: + resolution: {integrity: sha512-9o9JXDsYjNaDgz/cY5+jv694+aik/1aiRGGvsCv68e1p/ob0glkGKav4lnJe2VJqD+gCmaARoD8GOJlhoQl8JQ==} dependencies: '@storybook/global': 5.0.0 ts-dedent: 2.2.0 dev: true - 
/@storybook/addon-storysource@7.6.10: - resolution: {integrity: sha512-ZtMiO26Bqd2oEovEeJ5ulvIL/rsAuHHpjAgBRZd/Byw25DQKY3GTqGtV474Wjm5tzj7HWhfk69fqAv87HnveCw==} + /@storybook/addon-storysource@7.6.17: + resolution: {integrity: sha512-8SZiIuIkRU9NQM3Y2mmE0m+bqtXQefzW8Z9DkPKwTJSJxVBvMZVMHjRiQcPn8ll6zhqQIaQiBj0ahlR8ZqrnqA==} dependencies: - '@storybook/source-loader': 7.6.10 + '@storybook/source-loader': 7.6.17 estraverse: 5.3.0 tiny-invariant: 1.3.1 dev: true - /@storybook/addon-toolbars@7.6.10: - resolution: {integrity: sha512-PaXY/oj9yxF7/H0CNdQKcioincyCkfeHpISZriZbZqhyqsjn3vca7RFEmsB88Q+ou6rMeqyA9st+6e2cx/Ct6A==} + /@storybook/addon-toolbars@7.6.17: + resolution: {integrity: sha512-UMrchbUHiyWrh6WuGnpy34Jqzkx/63B+MSgb3CW7YsQaXz64kE0Rol0TNSznnB+mYXplcqH+ndI4r4kFsmgwDg==} dev: true - /@storybook/addon-viewport@7.6.10: - resolution: {integrity: sha512-+bA6juC/lH4vEhk+w0rXakaG8JgLG4MOYrIudk5vJKQaC6X58LIM9N4kzIS2KSExRhkExXBPrWsnMfCo7uxmKg==} + /@storybook/addon-viewport@7.6.17: + resolution: {integrity: sha512-sA0QCcf4QAMixWvn8uvRYPfkKCSl6JajJaAspoPqXSxHEpK7uwOlpg3kqFU5XJJPXD0X957M+ONgNvBzYqSpEw==} dependencies: memoizerific: 1.11.3 dev: true - /@storybook/blocks@7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-oSIukGC3yuF8pojABC/HLu5tv2axZvf60TaUs8eDg7+NiiKhzYSPoMQxs5uMrKngl+EJDB92ESgWT9vvsfvIPg==} + /@storybook/blocks@7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-PsNVoe0bX1mMn4Kk3nbKZ0ItDZZ0YJnYAFJ6toAbsyBAbgzg1sce88sQinzvbn58/RT9MPKeWMPB45ZS7ggiNg==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 dependencies: - '@storybook/channels': 7.6.10 - '@storybook/client-logger': 7.6.10 - '@storybook/components': 7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@storybook/core-events': 7.6.10 + '@storybook/channels': 7.6.17 + '@storybook/client-logger': 7.6.17 + '@storybook/components': 7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@storybook/core-events': 7.6.17 '@storybook/csf': 0.1.2 - '@storybook/docs-tools': 7.6.10 + '@storybook/docs-tools': 7.6.17 '@storybook/global': 5.0.0 - '@storybook/manager-api': 7.6.10(react-dom@18.2.0)(react@18.2.0) - '@storybook/preview-api': 7.6.10 - '@storybook/theming': 7.6.10(react-dom@18.2.0)(react@18.2.0) - '@storybook/types': 7.6.10 + '@storybook/manager-api': 7.6.17(react-dom@18.2.0)(react@18.2.0) + '@storybook/preview-api': 7.6.17 + '@storybook/theming': 7.6.17(react-dom@18.2.0)(react@18.2.0) + '@storybook/types': 7.6.17 '@types/lodash': 4.14.202 color-convert: 2.0.1 dequal: 2.0.3 lodash: 4.17.21 - markdown-to-jsx: 7.4.0(react@18.2.0) + markdown-to-jsx: 7.4.1(react@18.2.0) memoizerific: 1.11.3 - polished: 4.2.2 + polished: 4.3.1 react: 18.2.0 react-colorful: 5.6.1(react-dom@18.2.0)(react@18.2.0) react-dom: 18.2.0(react@18.2.0) @@ -4964,13 +5122,13 @@ packages: - supports-color dev: true - /@storybook/builder-manager@7.6.10: - resolution: {integrity: sha512-f+YrjZwohGzvfDtH8BHzqM3xW0p4vjjg9u7uzRorqUiNIAAKHpfNrZ/WvwPlPYmrpAHt4xX/nXRJae4rFSygPw==} + /@storybook/builder-manager@7.6.17: + resolution: {integrity: sha512-Sj8hcDYiPCCMfeLzus37czl0zdrAxAz4IyYam2jBjVymrIrcDAFyL1OCZvnq33ft179QYQWhUs9qwzVmlR/ZWg==} dependencies: '@fal-works/esbuild-plugin-global-externals': 2.1.2 - '@storybook/core-common': 7.6.10 - '@storybook/manager': 7.6.10 - '@storybook/node-logger': 7.6.10 + 
'@storybook/core-common': 7.6.17 + '@storybook/manager': 7.6.17 + '@storybook/node-logger': 7.6.17 '@types/ejs': 3.1.5 '@types/find-cache-dir': 3.2.1 '@yarnpkg/esbuild-plugin-pnp': 3.0.0-rc.15(esbuild@0.18.20) @@ -4988,8 +5146,8 @@ packages: - supports-color dev: true - /@storybook/builder-vite@7.6.10(typescript@5.3.3)(vite@5.0.12): - resolution: {integrity: sha512-qxe19axiNJVdIKj943e1ucAmADwU42fTGgMSdBzzrvfH3pSOmx2057aIxRzd8YtBRnj327eeqpgCHYIDTunMYQ==} + /@storybook/builder-vite@7.6.17(typescript@5.3.3)(vite@5.1.3): + resolution: {integrity: sha512-2Q32qalI401EsKKr9Hkk8TAOcHEerqwsjCpQgTNJnCu6GgCVKoVUcb99oRbR9Vyg0xh+jb19XiWqqQujFtLYlQ==} peerDependencies: '@preact/preset-vite': '*' typescript: '>= 4.3.x' @@ -5003,64 +5161,64 @@ packages: vite-plugin-glimmerx: optional: true dependencies: - '@storybook/channels': 7.6.10 - '@storybook/client-logger': 7.6.10 - '@storybook/core-common': 7.6.10 - '@storybook/csf-plugin': 7.6.10 - '@storybook/node-logger': 7.6.10 - '@storybook/preview': 7.6.10 - '@storybook/preview-api': 7.6.10 - '@storybook/types': 7.6.10 + '@storybook/channels': 7.6.17 + '@storybook/client-logger': 7.6.17 + '@storybook/core-common': 7.6.17 + '@storybook/csf-plugin': 7.6.17 + '@storybook/node-logger': 7.6.17 + '@storybook/preview': 7.6.17 + '@storybook/preview-api': 7.6.17 + '@storybook/types': 7.6.17 '@types/find-cache-dir': 3.2.1 browser-assert: 1.2.1 es-module-lexer: 0.9.3 express: 4.18.2 find-cache-dir: 3.3.2 fs-extra: 11.2.0 - magic-string: 0.30.5 + magic-string: 0.30.7 rollup: 3.29.4 typescript: 5.3.3 - vite: 5.0.12(@types/node@20.11.5) + vite: 5.1.3(@types/node@20.11.19) transitivePeerDependencies: - encoding - supports-color dev: true - /@storybook/channels@7.6.10: - resolution: {integrity: sha512-ITCLhFuDBKgxetuKnWwYqMUWlU7zsfH3gEKZltTb+9/2OAWR7ez0iqU7H6bXP1ridm0DCKkt2UMWj2mmr9iQqg==} + /@storybook/channels@7.6.17: + resolution: {integrity: sha512-GFG40pzaSxk1hUr/J/TMqW5AFDDPUSu+HkeE/oqSWJbOodBOLJzHN6CReJS6y1DjYSZLNFt1jftPWZZInG/XUA==} dependencies: - '@storybook/client-logger': 7.6.10 - '@storybook/core-events': 7.6.10 + '@storybook/client-logger': 7.6.17 + '@storybook/core-events': 7.6.17 '@storybook/global': 5.0.0 qs: 6.11.2 telejson: 7.2.0 tiny-invariant: 1.3.1 dev: true - /@storybook/cli@7.6.10: - resolution: {integrity: sha512-pK1MEseMm73OMO2OVoSz79QWX8ymxgIGM8IeZTCo9gImiVRChMNDFYcv8yPWkjuyesY8c15CoO48aR7pdA1OjQ==} + /@storybook/cli@7.6.17: + resolution: {integrity: sha512-1sCo+nCqyR+nKfTcEidVu8XzNoECC7Y1l+uW38/r7s2f/TdDorXaIGAVrpjbSaXSoQpx5DxYJVaKCcQuOgqwcA==} hasBin: true dependencies: - '@babel/core': 7.23.7 - '@babel/preset-env': 7.23.8(@babel/core@7.23.7) - '@babel/types': 7.23.6 + '@babel/core': 7.23.9 + '@babel/preset-env': 7.23.9(@babel/core@7.23.9) + '@babel/types': 7.23.9 '@ndelangen/get-tarball': 3.0.9 - '@storybook/codemod': 7.6.10 - '@storybook/core-common': 7.6.10 - '@storybook/core-events': 7.6.10 - '@storybook/core-server': 7.6.10 - '@storybook/csf-tools': 7.6.10 - '@storybook/node-logger': 7.6.10 - '@storybook/telemetry': 7.6.10 - '@storybook/types': 7.6.10 - '@types/semver': 7.5.6 + '@storybook/codemod': 7.6.17 + '@storybook/core-common': 7.6.17 + '@storybook/core-events': 7.6.17 + '@storybook/core-server': 7.6.17 + '@storybook/csf-tools': 7.6.17 + '@storybook/node-logger': 7.6.17 + '@storybook/telemetry': 7.6.17 + '@storybook/types': 7.6.17 + '@types/semver': 7.5.7 '@yarnpkg/fslib': 2.10.3 '@yarnpkg/libzip': 2.3.0 chalk: 4.1.2 commander: 6.2.1 cross-spawn: 7.0.3 detect-indent: 6.1.0 - envinfo: 7.11.0 + envinfo: 7.11.1 execa: 5.1.1 express: 
4.18.2 find-up: 5.0.0 @@ -5069,14 +5227,14 @@ packages: get-port: 5.1.1 giget: 1.2.1 globby: 11.1.0 - jscodeshift: 0.15.1(@babel/preset-env@7.23.8) + jscodeshift: 0.15.1(@babel/preset-env@7.23.9) leven: 3.1.0 ora: 5.4.1 prettier: 2.8.8 prompts: 2.4.2 puppeteer-core: 2.1.1 read-pkg-up: 7.0.1 - semver: 7.5.4 + semver: 7.6.0 strip-json-comments: 3.1.1 tempy: 1.0.1 ts-dedent: 2.2.0 @@ -5088,26 +5246,26 @@ packages: - utf-8-validate dev: true - /@storybook/client-logger@7.6.10: - resolution: {integrity: sha512-U7bbpu21ntgePMz/mKM18qvCSWCUGCUlYru8mgVlXLCKqFqfTeP887+CsPEQf29aoE3cLgDrxqbRJ1wxX9kL9A==} + /@storybook/client-logger@7.6.17: + resolution: {integrity: sha512-6WBYqixAXNAXlSaBWwgljWpAu10tPRBJrcFvx2gPUne58EeMM20Gi/iHYBz2kMCY+JLAgeIH7ZxInqwO8vDwiQ==} dependencies: '@storybook/global': 5.0.0 dev: true - /@storybook/codemod@7.6.10: - resolution: {integrity: sha512-pzFR0nocBb94vN9QCJLC3C3dP734ZigqyPmd0ZCDj9Xce2ytfHK3v1lKB6TZWzKAZT8zztauECYxrbo4LVuagw==} + /@storybook/codemod@7.6.17: + resolution: {integrity: sha512-JuTmf2u3C4fCnjO7o3dqRgrq3ozNYfWlrRP8xuIdvT7niMap7a396hJtSKqS10FxCgKFcMAOsRgrCalH1dWxUg==} dependencies: - '@babel/core': 7.23.7 - '@babel/preset-env': 7.23.8(@babel/core@7.23.7) - '@babel/types': 7.23.6 + '@babel/core': 7.23.9 + '@babel/preset-env': 7.23.9(@babel/core@7.23.9) + '@babel/types': 7.23.9 '@storybook/csf': 0.1.2 - '@storybook/csf-tools': 7.6.10 - '@storybook/node-logger': 7.6.10 - '@storybook/types': 7.6.10 + '@storybook/csf-tools': 7.6.17 + '@storybook/node-logger': 7.6.17 + '@storybook/types': 7.6.17 '@types/cross-spawn': 6.0.6 cross-spawn: 7.0.3 globby: 11.1.0 - jscodeshift: 0.15.1(@babel/preset-env@7.23.8) + jscodeshift: 0.15.1(@babel/preset-env@7.23.9) lodash: 4.17.21 prettier: 2.8.8 recast: 0.23.4 @@ -5115,19 +5273,19 @@ packages: - supports-color dev: true - /@storybook/components@7.6.10(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-H5hF8pxwtbt0LxV24KMMsPlbYG9Oiui3ObvAQkvGu6q62EYxRPeNSrq3GBI5XEbI33OJY9bT24cVaZx18dXqwQ==} + /@storybook/components@7.6.17(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-lbh7GynMidA+CZcJnstVku6Nhs+YkqjYaZ+mKPugvlVhGVWv0DaaeQFVuZ8cJtUGJ/5FFU4Y+n+gylYUHkGBMA==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 dependencies: - '@radix-ui/react-select': 1.2.2(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@radix-ui/react-toolbar': 1.0.4(@types/react-dom@18.2.18)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@storybook/client-logger': 7.6.10 + '@radix-ui/react-select': 1.2.2(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@radix-ui/react-toolbar': 1.0.4(@types/react-dom@18.2.19)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@storybook/client-logger': 7.6.17 '@storybook/csf': 0.1.2 '@storybook/global': 5.0.0 - '@storybook/theming': 7.6.10(react-dom@18.2.0)(react@18.2.0) - '@storybook/types': 7.6.10 + '@storybook/theming': 7.6.17(react-dom@18.2.0)(react@18.2.0) + '@storybook/types': 7.6.17 memoizerific: 1.11.3 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) @@ -5138,21 +5296,21 @@ packages: - '@types/react-dom' dev: true - /@storybook/core-client@7.6.10: - resolution: {integrity: sha512-DjnzSzSNDmZyxyg6TxugzWQwOsW+n/iWVv6sHNEvEd5STr0mjuJjIEELmv58LIr5Lsre5+LEddqHsyuLyt8ubg==} + /@storybook/core-client@7.6.17: + resolution: {integrity: 
sha512-LuDbADK+DPNAOOCXOlvY09hdGVueXlDetsdOJ/DgYnSa9QSWv9Uv+F8QcEgR3QckZJbPlztKJIVLgP2n/Xkijw==} dependencies: - '@storybook/client-logger': 7.6.10 - '@storybook/preview-api': 7.6.10 + '@storybook/client-logger': 7.6.17 + '@storybook/preview-api': 7.6.17 dev: true - /@storybook/core-common@7.6.10: - resolution: {integrity: sha512-K3YWqjCKMnpvYsWNjOciwTH6zWbuuZzmOiipziZaVJ+sB1XYmH52Y3WGEm07TZI8AYK9DRgwA13dR/7W0nw72Q==} + /@storybook/core-common@7.6.17: + resolution: {integrity: sha512-me2TP3Q9/qzqCLoDHUSsUF+VS1MHxfHbTVF6vAz0D/COTxzsxLpu9TxTbzJoBCxse6XRb6wWI1RgF1mIcjic7g==} dependencies: - '@storybook/core-events': 7.6.10 - '@storybook/node-logger': 7.6.10 - '@storybook/types': 7.6.10 + '@storybook/core-events': 7.6.17 + '@storybook/node-logger': 7.6.17 + '@storybook/types': 7.6.17 '@types/find-cache-dir': 3.2.1 - '@types/node': 18.19.8 + '@types/node': 18.19.17 '@types/node-fetch': 2.6.11 '@types/pretty-hrtime': 1.0.3 chalk: 4.1.2 @@ -5176,34 +5334,34 @@ packages: - supports-color dev: true - /@storybook/core-events@7.6.10: - resolution: {integrity: sha512-yccDH67KoROrdZbRKwxgTswFMAco5nlCyxszCDASCLygGSV2Q2e+YuywrhchQl3U6joiWi3Ps1qWu56NeNafag==} + /@storybook/core-events@7.6.17: + resolution: {integrity: sha512-AriWMCm/k1cxlv10f+jZ1wavThTRpLaN3kY019kHWbYT9XgaSuLU67G7GPr3cGnJ6HuA6uhbzu8qtqVCd6OfXA==} dependencies: ts-dedent: 2.2.0 dev: true - /@storybook/core-server@7.6.10: - resolution: {integrity: sha512-2icnqJkn3vwq0eJPP0rNaHd7IOvxYf5q4lSVl2AWTxo/Ae19KhokI6j/2vvS2XQJMGQszwshlIwrZUNsj5p0yw==} + /@storybook/core-server@7.6.17: + resolution: {integrity: sha512-KWGhTTaL1Q14FolcoKKZgytlPJUbH6sbJ1Ptj/84EYWFewcnEgVs0Zlnh1VStRZg+Rd1WC1V4yVd/bbDzxrvQA==} dependencies: '@aw-web-design/x-default-browser': 1.4.126 '@discoveryjs/json-ext': 0.5.7 - '@storybook/builder-manager': 7.6.10 - '@storybook/channels': 7.6.10 - '@storybook/core-common': 7.6.10 - '@storybook/core-events': 7.6.10 + '@storybook/builder-manager': 7.6.17 + '@storybook/channels': 7.6.17 + '@storybook/core-common': 7.6.17 + '@storybook/core-events': 7.6.17 '@storybook/csf': 0.1.2 - '@storybook/csf-tools': 7.6.10 + '@storybook/csf-tools': 7.6.17 '@storybook/docs-mdx': 0.1.0 '@storybook/global': 5.0.0 - '@storybook/manager': 7.6.10 - '@storybook/node-logger': 7.6.10 - '@storybook/preview-api': 7.6.10 - '@storybook/telemetry': 7.6.10 - '@storybook/types': 7.6.10 + '@storybook/manager': 7.6.17 + '@storybook/node-logger': 7.6.17 + '@storybook/preview-api': 7.6.17 + '@storybook/telemetry': 7.6.17 + '@storybook/types': 7.6.17 '@types/detect-port': 1.3.5 - '@types/node': 18.19.8 + '@types/node': 18.19.17 '@types/pretty-hrtime': 1.0.3 - '@types/semver': 7.5.6 + '@types/semver': 7.5.7 better-opn: 3.0.2 chalk: 4.1.2 cli-table3: 0.6.3 @@ -5212,13 +5370,13 @@ packages: express: 4.18.2 fs-extra: 11.2.0 globby: 11.1.0 - ip: 2.0.0 + ip: 2.0.1 lodash: 4.17.21 open: 8.4.2 pretty-hrtime: 1.0.3 prompts: 2.4.2 read-pkg-up: 7.0.1 - semver: 7.5.4 + semver: 7.6.0 telejson: 7.2.0 tiny-invariant: 1.3.1 ts-dedent: 2.2.0 @@ -5233,24 +5391,24 @@ packages: - utf-8-validate dev: true - /@storybook/csf-plugin@7.6.10: - resolution: {integrity: sha512-Sc+zZg/BnPH2X28tthNaQBnDiFfO0QmfjVoOx0fGYM9SvY3P5ehzWwp5hMRBim6a/twOTzePADtqYL+t6GMqqg==} + /@storybook/csf-plugin@7.6.17: + resolution: {integrity: sha512-xTHv9BUh3bkDVCvcbmdfVF0/e96BdrEgqPJ3G3RmKbSzWLOkQ2U9yiPfHzT0KJWPhVwj12fjfZp0zunu+pcS6Q==} dependencies: - '@storybook/csf-tools': 7.6.10 - unplugin: 1.6.0 + '@storybook/csf-tools': 7.6.17 + unplugin: 1.7.1 transitivePeerDependencies: - supports-color dev: true - 
/@storybook/csf-tools@7.6.10: - resolution: {integrity: sha512-TnDNAwIALcN6SA4l00Cb67G02XMOrYU38bIpFJk5VMDX2dvgPjUtJNBuLmEbybGcOt7nPyyFIHzKcY5FCVGoWA==} + /@storybook/csf-tools@7.6.17: + resolution: {integrity: sha512-dAQtam0EBPeTJYcQPLxXgz4L9JFqD+HWbLFG9CmNIhMMjticrB0mpk1EFIS6vPXk/VsVWpBgMLD7dZlD6YMKcQ==} dependencies: '@babel/generator': 7.23.6 - '@babel/parser': 7.23.6 - '@babel/traverse': 7.23.7 - '@babel/types': 7.23.6 + '@babel/parser': 7.23.9 + '@babel/traverse': 7.23.9 + '@babel/types': 7.23.9 '@storybook/csf': 0.1.2 - '@storybook/types': 7.6.10 + '@storybook/types': 7.6.17 fs-extra: 11.2.0 recast: 0.23.4 ts-dedent: 2.2.0 @@ -5274,12 +5432,12 @@ packages: resolution: {integrity: sha512-JDaBR9lwVY4eSH5W8EGHrhODjygPd6QImRbwjAuJNEnY0Vw4ie3bPkeGfnacB3OBW6u/agqPv2aRlR46JcAQLg==} dev: true - /@storybook/docs-tools@7.6.10: - resolution: {integrity: sha512-UgbikducoXzqQHf2TozO0f2rshaeBNnShVbL5Ai4oW7pDymBmrfzdjGbF/milO7yxNKcoIByeoNmu384eBamgQ==} + /@storybook/docs-tools@7.6.17: + resolution: {integrity: sha512-bYrLoj06adqklyLkEwD32C0Ww6t+9ZVvrJHiVT42bIhTRpFiFPAetl1a9KPHtFLnfduh4n2IxIr1jv32ThPDTA==} dependencies: - '@storybook/core-common': 7.6.10 - '@storybook/preview-api': 7.6.10 - '@storybook/types': 7.6.10 + '@storybook/core-common': 7.6.17 + '@storybook/preview-api': 7.6.17 + '@storybook/types': 7.6.17 '@types/doctrine': 0.0.3 assert: 2.1.0 doctrine: 3.0.0 @@ -5293,33 +5451,33 @@ packages: resolution: {integrity: sha512-FcOqPAXACP0I3oJ/ws6/rrPT9WGhu915Cg8D02a9YxLo0DE9zI+a9A5gRGvmQ09fiWPukqI8ZAEoQEdWUKMQdQ==} dev: true - /@storybook/instrumenter@7.6.10: - resolution: {integrity: sha512-9FYXW1CKXnZ7yYmy2A6U0seqJMe1F7g55J28Vslk3ZLoGATFJ2BR0eoQS+cgfBly6djehjaVeuV3IcUYGnQ/6Q==} + /@storybook/instrumenter@7.6.17: + resolution: {integrity: sha512-zTLIPTt1fvlWgkIVUyQpF327iVE+EiPdpM0Or0aARaNfIikPRBTcjU+6cK96E+Ust2E1qKajEjIuv4i4lLQPng==} dependencies: - '@storybook/channels': 7.6.10 - '@storybook/client-logger': 7.6.10 - '@storybook/core-events': 7.6.10 + '@storybook/channels': 7.6.17 + '@storybook/client-logger': 7.6.17 + '@storybook/core-events': 7.6.17 '@storybook/global': 5.0.0 - '@storybook/preview-api': 7.6.10 + '@storybook/preview-api': 7.6.17 '@vitest/utils': 0.34.7 util: 0.12.5 dev: true - /@storybook/manager-api@7.6.10(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-8eGVpRlpunuFScDtc7nxpPJf/4kJBAAZlNdlhmX09j8M3voX6GpcxabBamSEX5pXZqhwxQCshD4IbqBmjvadlw==} + /@storybook/manager-api@7.6.17(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-IJIV1Yc6yw1dhCY4tReHCfBnUKDqEBnMyHp3mbXpsaHxnxJZrXO45WjRAZIKlQKhl/Ge1CrnznmHRCmYgqmrWg==} dependencies: - '@storybook/channels': 7.6.10 - '@storybook/client-logger': 7.6.10 - '@storybook/core-events': 7.6.10 + '@storybook/channels': 7.6.17 + '@storybook/client-logger': 7.6.17 + '@storybook/core-events': 7.6.17 '@storybook/csf': 0.1.2 '@storybook/global': 5.0.0 - '@storybook/router': 7.6.10 - '@storybook/theming': 7.6.10(react-dom@18.2.0)(react@18.2.0) - '@storybook/types': 7.6.10 + '@storybook/router': 7.6.17 + '@storybook/theming': 7.6.17(react-dom@18.2.0)(react@18.2.0) + '@storybook/types': 7.6.17 dequal: 2.0.3 lodash: 4.17.21 memoizerific: 1.11.3 - store2: 2.14.2 + store2: 2.14.3 telejson: 7.2.0 ts-dedent: 2.2.0 transitivePeerDependencies: @@ -5327,31 +5485,31 @@ packages: - react-dom dev: true - /@storybook/manager@7.6.10: - resolution: {integrity: sha512-Co3sLCbNYY6O4iH2ggmRDLCPWLj03JE5s/DOG8OVoXc6vBwTc/Qgiyrsxxp6BHQnPpM0mxL6aKAxE3UjsW/Nog==} + /@storybook/manager@7.6.17: + resolution: {integrity: 
sha512-A1LDDIqMpwRzq/dqkbbiza0QI04o4ZHCl2a3UMDZUV/+QLc2nsr2DAaLk4CVL4/cIc5zGqmIcaOTvprx2YKVBw==} dev: true /@storybook/mdx2-csf@1.1.0: resolution: {integrity: sha512-TXJJd5RAKakWx4BtpwvSNdgTDkKM6RkXU8GK34S/LhidQ5Pjz3wcnqb0TxEkfhK/ztbP8nKHqXFwLfa2CYkvQw==} dev: true - /@storybook/node-logger@7.6.10: - resolution: {integrity: sha512-ZBuqrv4bjJzKXyfRGFkVIi+z6ekn6rOPoQao4KmsfLNQAUUsEdR8Baw/zMnnU417zw5dSEaZdpuwx75SCQAeOA==} + /@storybook/node-logger@7.6.17: + resolution: {integrity: sha512-w59MQuXhhUNrUVmVkXhMwIg2nvFWjdDczLTwYLorhfsE36CWeUOY5QCZWQy0Qf/h+jz8Uo7Evy64qn18v9C4wA==} dev: true - /@storybook/postinstall@7.6.10: - resolution: {integrity: sha512-SMdXtednPCy3+SRJ7oN1OPN1oVFhj3ih+ChOEX8/kZ5J3nfmV3wLPtsZvFGUCf0KWQEP1xL+1Urv48mzMKcV/w==} + /@storybook/postinstall@7.6.17: + resolution: {integrity: sha512-WaWqB8o9vUc9aaVls+povQSVirf1Xd1LZcVhUKfAocAF3mzYUsnJsVqvnbjRj/F96UFVihOyDt9Zjl/9OvrCvQ==} dev: true - /@storybook/preview-api@7.6.10: - resolution: {integrity: sha512-5A3etoIwZCx05yuv3KSTv1wynN4SR4rrzaIs/CTBp3BC4q1RBL+Or/tClk0IJPXQMlx/4Y134GtNIBbkiDofpw==} + /@storybook/preview-api@7.6.17: + resolution: {integrity: sha512-wLfDdI9RWo1f2zzFe54yRhg+2YWyxLZvqdZnSQ45mTs4/7xXV5Wfbv3QNTtcdw8tT3U5KRTrN1mTfTCiRJc0Kw==} dependencies: - '@storybook/channels': 7.6.10 - '@storybook/client-logger': 7.6.10 - '@storybook/core-events': 7.6.10 + '@storybook/channels': 7.6.17 + '@storybook/client-logger': 7.6.17 + '@storybook/core-events': 7.6.17 '@storybook/csf': 0.1.2 '@storybook/global': 5.0.0 - '@storybook/types': 7.6.10 + '@storybook/types': 7.6.17 '@types/qs': 6.9.11 dequal: 2.0.3 lodash: 4.17.21 @@ -5362,12 +5520,12 @@ packages: util-deprecate: 1.0.2 dev: true - /@storybook/preview@7.6.10: - resolution: {integrity: sha512-F07BzVXTD3byq+KTWtvsw3pUu3fQbyiBNLFr2CnfU4XSdLKja5lDt8VqDQq70TayVQOf5qfUTzRd4M6pQkjw1w==} + /@storybook/preview@7.6.17: + resolution: {integrity: sha512-LvkMYK/y6alGjwRVNDIKL1lFlbyZ0H0c8iAbcQkiMoaFiujMQyVswMDKlWcj42Upfr/B1igydiruomc+eUt0mw==} dev: true - /@storybook/react-dom-shim@7.6.10(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-M+N/h6ximacaFdIDjMN2waNoWwApeVYTpFeoDppiFTvdBTXChyIuiPgYX9QSg7gDz92OaA52myGOot4wGvXVzg==} + /@storybook/react-dom-shim@7.6.17(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-32Sa/G+WnvaPiQ1Wvjjw5UM9rr2c4GDohwCcWVv3/LJuiFPqNS6zglAtmnsrlIBnUwRBMLMh/ekCTdqMiUmfDw==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 @@ -5376,24 +5534,24 @@ packages: react-dom: 18.2.0(react@18.2.0) dev: true - /@storybook/react-vite@7.6.10(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3)(vite@5.0.12): - resolution: {integrity: sha512-YE2+J1wy8nO+c6Nv/hBMu91Edew3K184L1KSnfoZV8vtq2074k1Me/8pfe0QNuq631AncpfCYNb37yBAXQ/80w==} + /@storybook/react-vite@7.6.17(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3)(vite@5.1.3): + resolution: {integrity: sha512-4dIm3CuRl44X1TLzN3WoZh/bChzJF7Ud28li9atj9C8db0bb/y0zl8cahrsRFoR7/LyfqdOVLqaztrnA5SsWfg==} engines: {node: '>=16'} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 vite: ^3.0.0 || ^4.0.0 || ^5.0.0 dependencies: - '@joshwooding/vite-plugin-react-docgen-typescript': 0.3.0(typescript@5.3.3)(vite@5.0.12) + '@joshwooding/vite-plugin-react-docgen-typescript': 0.3.0(typescript@5.3.3)(vite@5.1.3) '@rollup/pluginutils': 5.1.0 - '@storybook/builder-vite': 7.6.10(typescript@5.3.3)(vite@5.0.12) - '@storybook/react': 7.6.10(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3) - '@vitejs/plugin-react': 3.1.0(vite@5.0.12) - 
magic-string: 0.30.5 + '@storybook/builder-vite': 7.6.17(typescript@5.3.3)(vite@5.1.3) + '@storybook/react': 7.6.17(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3) + '@vitejs/plugin-react': 3.1.0(vite@5.1.3) + magic-string: 0.30.7 react: 18.2.0 react-docgen: 7.0.3 react-dom: 18.2.0(react@18.2.0) - vite: 5.0.12(@types/node@20.11.5) + vite: 5.1.3(@types/node@20.11.19) transitivePeerDependencies: - '@preact/preset-vite' - encoding @@ -5403,8 +5561,8 @@ packages: - vite-plugin-glimmerx dev: true - /@storybook/react@7.6.10(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3): - resolution: {integrity: sha512-wwBn1cg2uZWW4peqqBjjU7XGmFq8HdkVUtWwh6dpfgmlY1Aopi+vPgZt7pY9KkWcTOq5+DerMdSfwxukpc3ajQ==} + /@storybook/react@7.6.17(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3): + resolution: {integrity: sha512-lVqzQSU03rRJWYW+gK2gq6mSo3/qtnVICY8B8oP7gc36jVu4ksDIu45bTfukM618ODkUZy0vZe6T4engK3azjA==} engines: {node: '>=16.0.0'} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 @@ -5414,16 +5572,16 @@ packages: typescript: optional: true dependencies: - '@storybook/client-logger': 7.6.10 - '@storybook/core-client': 7.6.10 - '@storybook/docs-tools': 7.6.10 + '@storybook/client-logger': 7.6.17 + '@storybook/core-client': 7.6.17 + '@storybook/docs-tools': 7.6.17 '@storybook/global': 5.0.0 - '@storybook/preview-api': 7.6.10 - '@storybook/react-dom-shim': 7.6.10(react-dom@18.2.0)(react@18.2.0) - '@storybook/types': 7.6.10 + '@storybook/preview-api': 7.6.17 + '@storybook/react-dom-shim': 7.6.17(react-dom@18.2.0)(react@18.2.0) + '@storybook/types': 7.6.17 '@types/escodegen': 0.0.6 '@types/estree': 0.0.51 - '@types/node': 18.19.8 + '@types/node': 18.19.17 acorn: 7.4.1 acorn-jsx: 5.3.2(acorn@7.4.1) acorn-walk: 7.2.0 @@ -5443,30 +5601,30 @@ packages: - supports-color dev: true - /@storybook/router@7.6.10: - resolution: {integrity: sha512-G/H4Jn2+y8PDe8Zbq4DVxF/TPn0/goSItdILts39JENucHiuGBCjKjSWGBe1rkwKi1tUbB3yhxJVrLagxFEPpQ==} + /@storybook/router@7.6.17: + resolution: {integrity: sha512-GnyC0j6Wi5hT4qRhSyT8NPtJfGmf82uZw97LQRWeyYu5gWEshUdM7aj40XlNiScd5cZDp0owO1idduVF2k2l2A==} dependencies: - '@storybook/client-logger': 7.6.10 + '@storybook/client-logger': 7.6.17 memoizerific: 1.11.3 qs: 6.11.2 dev: true - /@storybook/source-loader@7.6.10: - resolution: {integrity: sha512-S3nOWyj+sdpsqJqKGIN3DKE1q+Q0KYxEyPlPCawMFazozUH7tOodTIqmHBqJZCSNqdC4M1S/qcL8vpP4PfXhuA==} + /@storybook/source-loader@7.6.17: + resolution: {integrity: sha512-90v1es7dHmHgkGbflPlaRBYcn2+mqdC8OG4QtyYqOUq6xsLsyg+5CX2rupfHbuSLw9r0A3o1ViOII2J/kWtFow==} dependencies: '@storybook/csf': 0.1.2 - '@storybook/types': 7.6.10 + '@storybook/types': 7.6.17 estraverse: 5.3.0 lodash: 4.17.21 prettier: 2.8.8 dev: true - /@storybook/telemetry@7.6.10: - resolution: {integrity: sha512-p3mOSUtIyy2tF1z6pQXxNh1JzYFcAm97nUgkwLzF07GfEdVAPM+ftRSLFbD93zVvLEkmLTlsTiiKaDvOY/lQWg==} + /@storybook/telemetry@7.6.17: + resolution: {integrity: sha512-WOcOAmmengYnGInH98Px44F47DSpLyk20BM+Z/IIQDzfttGOLlxNqBBG1XTEhNRn+AYuk4aZ2JEed2lCjVIxcA==} dependencies: - '@storybook/client-logger': 7.6.10 - '@storybook/core-common': 7.6.10 - '@storybook/csf-tools': 7.6.10 + '@storybook/client-logger': 7.6.17 + '@storybook/core-common': 7.6.17 + '@storybook/csf-tools': 7.6.17 chalk: 4.1.2 detect-package-manager: 2.0.1 fetch-retry: 5.0.6 @@ -5477,15 +5635,15 @@ packages: - supports-color dev: true - /@storybook/test@7.6.10(vitest@1.2.2): - resolution: {integrity: sha512-dn/T+HcWOBlVh3c74BHurp++BaqBoQgNbSIaXlYDpJoZ+DzNIoEQVsWFYm5gCbtKK27iFd4n52RiQI3f6Vblqw==} + 
/@storybook/test@7.6.17(vitest@1.3.1): + resolution: {integrity: sha512-WGrmUUtKiuq3bzDsN4MUvluGcX120jwczMik1GDTyxS+JBoe7P0t2Y8dDuVs/l3nZd1J7qY4z0RGxMDYqONIOw==} dependencies: - '@storybook/client-logger': 7.6.10 - '@storybook/core-events': 7.6.10 - '@storybook/instrumenter': 7.6.10 - '@storybook/preview-api': 7.6.10 + '@storybook/client-logger': 7.6.17 + '@storybook/core-events': 7.6.17 + '@storybook/instrumenter': 7.6.17 + '@storybook/preview-api': 7.6.17 '@testing-library/dom': 9.3.4 - '@testing-library/jest-dom': 6.2.0(vitest@1.2.2) + '@testing-library/jest-dom': 6.4.2(vitest@1.3.1) '@testing-library/user-event': 14.3.0(@testing-library/dom@9.3.4) '@types/chai': 4.3.11 '@vitest/expect': 0.34.7 @@ -5494,36 +5652,37 @@ packages: util: 0.12.5 transitivePeerDependencies: - '@jest/globals' + - '@types/bun' - '@types/jest' - jest - vitest dev: true - /@storybook/theming@7.6.10(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-f5tuy7yV3TOP3fIboSqpgLHy0wKayAw/M8HxX0jVET4Z4fWlFK0BiHJabQ+XEdAfQM97XhPFHB2IPbwsqhCEcQ==} + /@storybook/theming@7.6.17(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-ZbaBt3KAbmBtfjNqgMY7wPMBshhSJlhodyMNQypv+95xLD/R+Az6aBYbpVAOygLaUQaQk4ar7H/Ww6lFIoiFbA==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 dependencies: '@emotion/use-insertion-effect-with-fallbacks': 1.0.1(react@18.2.0) - '@storybook/client-logger': 7.6.10 + '@storybook/client-logger': 7.6.17 '@storybook/global': 5.0.0 memoizerific: 1.11.3 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: true - /@storybook/types@7.6.10: - resolution: {integrity: sha512-hcS2HloJblaMpCAj2axgGV+53kgSRYPT0a1PG1IHsZaYQILfHSMmBqM8XzXXYTsgf9250kz3dqFX1l0n3EqMlQ==} + /@storybook/types@7.6.17: + resolution: {integrity: sha512-GRY0xEJQ0PrL7DY2qCNUdIfUOE0Gsue6N+GBJw9ku1IUDFLJRDOF+4Dx2BvYcVCPI5XPqdWKlEyZdMdKjiQN7Q==} dependencies: - '@storybook/channels': 7.6.10 + '@storybook/channels': 7.6.17 '@types/babel__core': 7.20.5 '@types/express': 4.17.21 file-system-cache: 2.3.0 dev: true - /@swc/core-darwin-arm64@1.3.101: - resolution: {integrity: sha512-mNFK+uHNPRXSnfTOG34zJOeMl2waM4hF4a2NY7dkMXrPqw9CoJn4MwTXJcyMiSz1/BnNjjTCHF3Yhj0jPxmkzQ==} + /@swc/core-darwin-arm64@1.4.2: + resolution: {integrity: sha512-1uSdAn1MRK5C1m/TvLZ2RDvr0zLvochgrZ2xL+lRzugLlCTlSA+Q4TWtrZaOz+vnnFVliCpw7c7qu0JouhgQIw==} engines: {node: '>=10'} cpu: [arm64] os: [darwin] @@ -5531,8 +5690,8 @@ packages: dev: true optional: true - /@swc/core-darwin-x64@1.3.101: - resolution: {integrity: sha512-B085j8XOx73Fg15KsHvzYWG262bRweGr3JooO1aW5ec5pYbz5Ew9VS5JKYS03w2UBSxf2maWdbPz2UFAxg0whw==} + /@swc/core-darwin-x64@1.4.2: + resolution: {integrity: sha512-TYD28+dCQKeuxxcy7gLJUCFLqrwDZnHtC2z7cdeGfZpbI2mbfppfTf2wUPzqZk3gEC96zHd4Yr37V3Tvzar+lQ==} engines: {node: '>=10'} cpu: [x64] os: [darwin] @@ -5540,8 +5699,8 @@ packages: dev: true optional: true - /@swc/core-linux-arm-gnueabihf@1.3.101: - resolution: {integrity: sha512-9xLKRb6zSzRGPqdz52Hy5GuB1lSjmLqa0lST6MTFads3apmx4Vgs8Y5NuGhx/h2I8QM4jXdLbpqQlifpzTlSSw==} + /@swc/core-linux-arm-gnueabihf@1.4.2: + resolution: {integrity: sha512-Eyqipf7ZPGj0vplKHo8JUOoU1un2sg5PjJMpEesX0k+6HKE2T8pdyeyXODN0YTFqzndSa/J43EEPXm+rHAsLFQ==} engines: {node: '>=10'} cpu: [arm] os: [linux] @@ -5549,8 +5708,8 @@ packages: dev: true optional: true - /@swc/core-linux-arm64-gnu@1.3.101: - resolution: {integrity: sha512-oE+r1lo7g/vs96Weh2R5l971dt+ZLuhaUX+n3BfDdPxNHfObXgKMjO7E+QS5RbGjv/AwiPCxQmbdCp/xN5ICJA==} + /@swc/core-linux-arm64-gnu@1.4.2: + resolution: {integrity: 
sha512-wZn02DH8VYPv3FC0ub4my52Rttsus/rFw+UUfzdb3tHMHXB66LqN+rR0ssIOZrH6K+VLN6qpTw9VizjyoH0BxA==} engines: {node: '>=10'} cpu: [arm64] os: [linux] @@ -5558,8 +5717,8 @@ packages: dev: true optional: true - /@swc/core-linux-arm64-musl@1.3.101: - resolution: {integrity: sha512-OGjYG3H4BMOTnJWJyBIovCez6KiHF30zMIu4+lGJTCrxRI2fAjGLml3PEXj8tC3FMcud7U2WUn6TdG0/te2k6g==} + /@swc/core-linux-arm64-musl@1.4.2: + resolution: {integrity: sha512-3G0D5z9hUj9bXNcwmA1eGiFTwe5rWkuL3DsoviTj73TKLpk7u64ND0XjEfO0huVv4vVu9H1jodrKb7nvln/dlw==} engines: {node: '>=10'} cpu: [arm64] os: [linux] @@ -5567,8 +5726,8 @@ packages: dev: true optional: true - /@swc/core-linux-x64-gnu@1.3.101: - resolution: {integrity: sha512-/kBMcoF12PRO/lwa8Z7w4YyiKDcXQEiLvM+S3G9EvkoKYGgkkz4Q6PSNhF5rwg/E3+Hq5/9D2R+6nrkF287ihg==} + /@swc/core-linux-x64-gnu@1.4.2: + resolution: {integrity: sha512-LFxn9U8cjmYHw3jrdPNqPAkBGglKE3tCZ8rA7hYyp0BFxuo7L2ZcEnPm4RFpmSCCsExFH+LEJWuMGgWERoktvg==} engines: {node: '>=10'} cpu: [x64] os: [linux] @@ -5576,8 +5735,8 @@ packages: dev: true optional: true - /@swc/core-linux-x64-musl@1.3.101: - resolution: {integrity: sha512-kDN8lm4Eew0u1p+h1l3JzoeGgZPQ05qDE0czngnjmfpsH2sOZxVj1hdiCwS5lArpy7ktaLu5JdRnx70MkUzhXw==} + /@swc/core-linux-x64-musl@1.4.2: + resolution: {integrity: sha512-dp0fAmreeVVYTUcb4u9njTPrYzKnbIH0EhH2qvC9GOYNNREUu2GezSIDgonjOXkHiTCvopG4xU7y56XtXj4VrQ==} engines: {node: '>=10'} cpu: [x64] os: [linux] @@ -5585,8 +5744,8 @@ packages: dev: true optional: true - /@swc/core-win32-arm64-msvc@1.3.101: - resolution: {integrity: sha512-9Wn8TTLWwJKw63K/S+jjrZb9yoJfJwCE2RV5vPCCWmlMf3U1AXj5XuWOLUX+Rp2sGKau7wZKsvywhheWm+qndQ==} + /@swc/core-win32-arm64-msvc@1.4.2: + resolution: {integrity: sha512-HlVIiLMQkzthAdqMslQhDkoXJ5+AOLUSTV6fm6shFKZKqc/9cJvr4S8UveNERL9zUficA36yM3bbfo36McwnvQ==} engines: {node: '>=10'} cpu: [arm64] os: [win32] @@ -5594,8 +5753,8 @@ packages: dev: true optional: true - /@swc/core-win32-ia32-msvc@1.3.101: - resolution: {integrity: sha512-onO5KvICRVlu2xmr4//V2je9O2XgS1SGKpbX206KmmjcJhXN5EYLSxW9qgg+kgV5mip+sKTHTAu7IkzkAtElYA==} + /@swc/core-win32-ia32-msvc@1.4.2: + resolution: {integrity: sha512-WCF8faPGjCl4oIgugkp+kL9nl3nUATlzKXCEGFowMEmVVCFM0GsqlmGdPp1pjZoWc9tpYanoXQDnp5IvlDSLhA==} engines: {node: '>=10'} cpu: [ia32] os: [win32] @@ -5603,8 +5762,8 @@ packages: dev: true optional: true - /@swc/core-win32-x64-msvc@1.3.101: - resolution: {integrity: sha512-T3GeJtNQV00YmiVw/88/nxJ/H43CJvFnpvBHCVn17xbahiVUOPOduh3rc9LgAkKiNt/aV8vU3OJR+6PhfMR7UQ==} + /@swc/core-win32-x64-msvc@1.4.2: + resolution: {integrity: sha512-oV71rwiSpA5xre2C5570BhCsg1HF97SNLsZ/12xv7zayGzqr3yvFALFJN8tHKpqUdCB4FGPjoP3JFdV3i+1wUw==} engines: {node: '>=10'} cpu: [x64] os: [win32] @@ -5612,8 +5771,8 @@ packages: dev: true optional: true - /@swc/core@1.3.101: - resolution: {integrity: sha512-w5aQ9qYsd/IYmXADAnkXPGDMTqkQalIi+kfFf/MHRKTpaOL7DHjMXwPp/n8hJ0qNjRvchzmPtOqtPBiER50d8A==} + /@swc/core@1.4.2: + resolution: {integrity: sha512-vWgY07R/eqj1/a0vsRKLI9o9klGZfpLNOVEnrv4nrccxBgYPjcf22IWwAoaBJ+wpA7Q4fVjCUM8lP0m01dpxcg==} engines: {node: '>=10'} requiresBuild: true peerDependencies: @@ -5622,23 +5781,23 @@ packages: '@swc/helpers': optional: true dependencies: - '@swc/counter': 0.1.2 + '@swc/counter': 0.1.3 '@swc/types': 0.1.5 optionalDependencies: - '@swc/core-darwin-arm64': 1.3.101 - '@swc/core-darwin-x64': 1.3.101 - '@swc/core-linux-arm-gnueabihf': 1.3.101 - '@swc/core-linux-arm64-gnu': 1.3.101 - '@swc/core-linux-arm64-musl': 1.3.101 - '@swc/core-linux-x64-gnu': 1.3.101 - '@swc/core-linux-x64-musl': 1.3.101 - 
'@swc/core-win32-arm64-msvc': 1.3.101 - '@swc/core-win32-ia32-msvc': 1.3.101 - '@swc/core-win32-x64-msvc': 1.3.101 + '@swc/core-darwin-arm64': 1.4.2 + '@swc/core-darwin-x64': 1.4.2 + '@swc/core-linux-arm-gnueabihf': 1.4.2 + '@swc/core-linux-arm64-gnu': 1.4.2 + '@swc/core-linux-arm64-musl': 1.4.2 + '@swc/core-linux-x64-gnu': 1.4.2 + '@swc/core-linux-x64-musl': 1.4.2 + '@swc/core-win32-arm64-msvc': 1.4.2 + '@swc/core-win32-ia32-msvc': 1.4.2 + '@swc/core-win32-x64-msvc': 1.4.2 dev: true - /@swc/counter@0.1.2: - resolution: {integrity: sha512-9F4ys4C74eSTEUNndnER3VJ15oru2NumfQxS8geE+f3eB5xvfxpWyqE5XlVnxb/R14uoXi6SLbBwwiDSkv+XEw==} + /@swc/counter@0.1.3: + resolution: {integrity: sha512-e2BR4lsJkkRlKZ/qCHPw9ZaSxc0MVUd7gtbtaB7aMvHeJVYe8sOB8DBZkP2DtISHGSku9sCK6T6cnY0CtXrOCQ==} dev: true /@swc/helpers@0.5.6: @@ -5656,7 +5815,7 @@ packages: engines: {node: '>=14'} dependencies: '@babel/code-frame': 7.23.5 - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 '@types/aria-query': 5.0.4 aria-query: 5.1.3 chalk: 4.1.2 @@ -5665,17 +5824,20 @@ packages: pretty-format: 27.5.1 dev: true - /@testing-library/jest-dom@6.2.0(vitest@1.2.2): - resolution: {integrity: sha512-+BVQlJ9cmEn5RDMUS8c2+TU6giLvzaHZ8sU/x0Jj7fk+6/46wPdwlgOPcpxS17CjcanBi/3VmGMqVr2rmbUmNw==} + /@testing-library/jest-dom@6.4.2(vitest@1.3.1): + resolution: {integrity: sha512-CzqH0AFymEMG48CpzXFriYYkOjk6ZGPCLMhW9e9jg3KMCn5OfJecF8GtGW7yGfR/IgCe3SX8BSwjdzI6BBbZLw==} engines: {node: '>=14', npm: '>=6', yarn: '>=1'} peerDependencies: '@jest/globals': '>= 28' + '@types/bun': latest '@types/jest': '>= 28' jest: '>= 28' vitest: '>= 0.32' peerDependenciesMeta: '@jest/globals': optional: true + '@types/bun': + optional: true '@types/jest': optional: true jest: @@ -5683,15 +5845,15 @@ packages: vitest: optional: true dependencies: - '@adobe/css-tools': 4.3.2 - '@babel/runtime': 7.23.8 + '@adobe/css-tools': 4.3.3 + '@babel/runtime': 7.23.9 aria-query: 5.3.0 chalk: 3.0.0 css.escape: 1.5.1 dom-accessibility-api: 0.6.3 lodash: 4.17.21 redent: 3.0.0 - vitest: 1.2.2(@types/node@20.11.5) + vitest: 1.3.1(@types/node@20.11.19) dev: true /@testing-library/user-event@14.3.0(@testing-library/dom@9.3.4): @@ -5714,8 +5876,8 @@ packages: /@types/babel__core@7.20.5: resolution: {integrity: sha512-qoQprZvz5wQFJwMDqeseRXWv3rqMvhgpbXFfVyWhbx9X47POIA6i/+dXefEmZKoAgOaTdaIgNSMqMIU61yRyzA==} dependencies: - '@babel/parser': 7.23.6 - '@babel/types': 7.23.6 + '@babel/parser': 7.23.9 + '@babel/types': 7.23.9 '@types/babel__generator': 7.6.8 '@types/babel__template': 7.4.4 '@types/babel__traverse': 7.20.5 @@ -5724,27 +5886,27 @@ packages: /@types/babel__generator@7.6.8: resolution: {integrity: sha512-ASsj+tpEDsEiFr1arWrlN6V3mdfjRMZt6LtK/Vp/kreFLnr5QH5+DhvD5nINYZXzwJvXeGq+05iUXcAzVrqWtw==} dependencies: - '@babel/types': 7.23.6 + '@babel/types': 7.23.9 dev: true /@types/babel__template@7.4.4: resolution: {integrity: sha512-h/NUaSyG5EyxBIp8YRxo4RMe2/qQgvyowRwVMzhYhBCONbW8PUsg4lkFMrhgZhUe5z3L3MiLDuvyJ/CaPa2A8A==} dependencies: - '@babel/parser': 7.23.6 - '@babel/types': 7.23.6 + '@babel/parser': 7.23.9 + '@babel/types': 7.23.9 dev: true /@types/babel__traverse@7.20.5: resolution: {integrity: sha512-WXCyOcRtH37HAUkpXhUduaxdm82b4GSlyTqajXviN4EfiuPgNYR109xMCKvpl6zPIpua0DGlMEDCq+g8EdoheQ==} dependencies: - '@babel/types': 7.23.6 + '@babel/types': 7.23.9 dev: true /@types/body-parser@1.19.5: resolution: {integrity: sha512-fB3Zu92ucau0iQ0JMCFQE7b/dv8Ot07NI3KaZIkIUNXq82k4eBAqUaneXfleGY9JWskeS9y+u0nXMyspcuQrCg==} dependencies: '@types/connect': 3.4.38 - '@types/node': 20.11.5 + 
'@types/node': 20.11.19 dev: true /@types/chai@4.3.11: @@ -5754,13 +5916,13 @@ packages: /@types/connect@3.4.38: resolution: {integrity: sha512-K6uROf1LD88uDQqJCktA4yzL1YYAK6NgfsI0v/mTgyPKWsX1CnJ0XPSDhViejru1GcRkLWb8RlzFYJRqGUbaug==} dependencies: - '@types/node': 20.11.5 + '@types/node': 20.11.19 dev: true /@types/cross-spawn@6.0.6: resolution: {integrity: sha512-fXRhhUkG4H3TQk5dBhQ7m/JDdSNHKwR2BBia62lhwEIq9xGiQKLxd6LymNhn47SjXhsUEPmxi+PKw2OkW4LLjA==} dependencies: - '@types/node': 20.11.5 + '@types/node': 20.11.19 dev: true /@types/d3-array@3.2.1: @@ -5791,7 +5953,7 @@ packages: resolution: {integrity: sha512-BjzLgXGnCWjUSYGfH1cpdo41/hgdWETu4YxpezoztawmqsvCeep+8QGfiY6YbDvfgHz/DkjeIkkZVJavB4a3rg==} dependencies: '@types/d3-array': 3.2.1 - '@types/geojson': 7946.0.13 + '@types/geojson': 7946.0.14 dev: false /@types/d3-delaunay@6.0.4: @@ -5833,7 +5995,7 @@ packages: /@types/d3-geo@3.1.0: resolution: {integrity: sha512-856sckF0oP/diXtS4jNsiQw/UuK5fQG8l/a9VVLeSouf1/PPbBE1i1W852zVwKwYCBkFJJB7nCFTbk6UMEXBOQ==} dependencies: - '@types/geojson': 7946.0.13 + '@types/geojson': 7946.0.14 dev: false /@types/d3-hierarchy@3.1.6: @@ -5846,8 +6008,8 @@ packages: '@types/d3-color': 3.1.3 dev: false - /@types/d3-path@3.0.2: - resolution: {integrity: sha512-WAIEVlOCdd/NKRYTsqCpOMHQHemKBEINf8YXMYOtXH0GA7SY0dqMB78P3Uhgfy+4X+/Mlw2wDtlETkN6kQUCMA==} + /@types/d3-path@3.1.0: + resolution: {integrity: sha512-P2dlU/q51fkOc/Gfl3Ul9kicV7l+ra934qBFXCFhrZMOL6du1TM0pm1ThYvENukyOn5h9v+yMJ9Fn5JK4QozrQ==} dev: false /@types/d3-polygon@3.0.2: @@ -5879,7 +6041,7 @@ packages: /@types/d3-shape@3.1.6: resolution: {integrity: sha512-5KKk5aKGu2I+O6SONMYSNflgiP0WfZIQvVUMan50wHsLG1G94JlxEVnCpQARfTtzytuY0p/9PXXZb3I7giofIA==} dependencies: - '@types/d3-path': 3.0.2 + '@types/d3-path': 3.1.0 dev: false /@types/d3-time-format@4.0.3: @@ -5927,7 +6089,7 @@ packages: '@types/d3-geo': 3.1.0 '@types/d3-hierarchy': 3.1.6 '@types/d3-interpolate': 3.0.4 - '@types/d3-path': 3.0.2 + '@types/d3-path': 3.1.0 '@types/d3-polygon': 3.0.2 '@types/d3-quadtree': 3.0.6 '@types/d3-random': 3.0.3 @@ -5974,8 +6136,8 @@ packages: resolution: {integrity: sha512-AjwI4MvWx3HAOaZqYsjKWyEObT9lcVV0Y0V8nXo6cXzN8ZiMxVhf6F3d/UNvXVGKrEzL/Dluc5p+y9GkzlTWig==} dev: true - /@types/eslint@8.56.0: - resolution: {integrity: sha512-FlsN0p4FhuYRjIxpbdXovvHQhtlG05O1GG/RNWvdAxTboR438IOTwmrY/vLA+Xfgg06BTkP045M3vpFwTMv1dg==} + /@types/eslint@8.56.2: + resolution: {integrity: sha512-uQDwm1wFHmbBbCZCqAlq6Do9LYwByNZHWzXppSnay9SuwJ+VRbjkbLABer54kcPnMSlG6Fdiy2yaFXm/z9Z5gw==} dependencies: '@types/estree': 1.0.5 '@types/json-schema': 7.0.15 @@ -5989,10 +6151,10 @@ packages: resolution: {integrity: sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==} dev: true - /@types/express-serve-static-core@4.17.41: - resolution: {integrity: sha512-OaJ7XLaelTgrvlZD8/aa0vvvxZdUmlCn6MtWeB7TkiKW70BQLc9XEPpDLPdbo52ZhXUCrznlWdCHWxJWtdyajA==} + /@types/express-serve-static-core@4.17.43: + resolution: {integrity: sha512-oaYtiBirUOPQGSWNGPWnzyAFJ0BP3cwvN4oWZQY+zUBwpVIGsKUkpBpSztp74drYcjavs7SKFZ4DX1V2QeN8rg==} dependencies: - '@types/node': 20.11.5 + '@types/node': 20.11.19 '@types/qs': 6.9.11 '@types/range-parser': 1.2.7 '@types/send': 0.17.4 @@ -6002,7 +6164,7 @@ packages: resolution: {integrity: sha512-ejlPM315qwLpaQlQDTjPdsUFSc6ZsP4AN6AlWnogPjQ7CVi7PYF3YVz+CY3jE2pwYf7E/7HlDAN0rV2GxTG0HQ==} dependencies: '@types/body-parser': 1.19.5 - '@types/express-serve-static-core': 4.17.41 + '@types/express-serve-static-core': 4.17.43 '@types/qs': 6.9.11 
'@types/serve-static': 1.15.5 dev: true @@ -6011,21 +6173,21 @@ packages: resolution: {integrity: sha512-frsJrz2t/CeGifcu/6uRo4b+SzAwT4NYCVPu1GN8IB9XTzrpPkGuV0tmh9mN+/L0PklAlsC3u5Fxt0ju00LXIw==} dev: true - /@types/geojson@7946.0.13: - resolution: {integrity: sha512-bmrNrgKMOhM3WsafmbGmC+6dsF2Z308vLFsQ3a/bT8X8Sv5clVYpPars/UPq+sAaJP+5OoLAYgwbkS5QEJdLUQ==} + /@types/geojson@7946.0.14: + resolution: {integrity: sha512-WCfD5Ht3ZesJUsONdhvm84dmzWOiOzOAqOncN0++w0lBw1o8OuDNJF2McvvCef/yBqb/HYRahp1BYtODFQ8bRg==} dev: false /@types/glob@7.2.0: resolution: {integrity: sha512-ZUxbzKl0IfJILTS6t7ip5fQQM/J3TJYubDm3nMbgubNNYS62eXeUpoLUC8/7fJNiFYHTrGPQn7hspDUzIHX3UA==} dependencies: '@types/minimatch': 5.1.2 - '@types/node': 20.11.5 + '@types/node': 20.11.19 dev: true /@types/graceful-fs@4.1.9: resolution: {integrity: sha512-olP3sd1qOEe5dXTSaFvQG+02VdRXcdytWLAZsAq1PecU8uqQAhkrnbli7DagjtXKW/Bl7YJbUsa8MPcuc8LHEQ==} dependencies: - '@types/node': 20.11.5 + '@types/node': 20.11.19 dev: true /@types/http-errors@2.0.4: @@ -6075,8 +6237,8 @@ packages: /@types/lodash@4.14.202: resolution: {integrity: sha512-OvlIYQK9tNneDlS0VN54LLd5uiPCBOp7gS5Z0f1mjoJYBrtStzgmJBxONW3U6OZqdtNzZPmn9BS/7WI7BFFcFQ==} - /@types/mdx@2.0.10: - resolution: {integrity: sha512-Rllzc5KHk0Al5/WANwgSPl1/CwjqCy+AZrGd78zuK+jO9aDM6ffblZ+zIjgPNAaEBmlO0RYDvLNh7wD0zKVgEg==} + /@types/mdx@2.0.11: + resolution: {integrity: sha512-HM5bwOaIQJIQbAYfax35HCKxx7a3KrK3nBtIqJgSOitivTD1y3oW9P3rxY9RkXYPUk7y/AjAohfHKmFpGE79zw==} dev: true /@types/mime-types@2.1.4: @@ -6098,18 +6260,18 @@ packages: /@types/node-fetch@2.6.11: resolution: {integrity: sha512-24xFj9R5+rfQJLRyM56qh+wnVSYhyXC2tkoBndtY0U+vubqNsYXGjufB2nn8Q6gt0LrARwL6UBtMCSVCwl4B1g==} dependencies: - '@types/node': 20.11.5 + '@types/node': 20.11.19 form-data: 4.0.0 dev: true - /@types/node@18.19.8: - resolution: {integrity: sha512-g1pZtPhsvGVTwmeVoexWZLTQaOvXwoSq//pTL0DHeNzUDrFnir4fgETdhjhIxjVnN+hKOuh98+E1eMLnUXstFg==} + /@types/node@18.19.17: + resolution: {integrity: sha512-SzyGKgwPzuWp2SHhlpXKzCX0pIOfcI4V2eF37nNBJOhwlegQ83omtVQ1XxZpDE06V/d6AQvfQdPfnw0tRC//Ng==} dependencies: undici-types: 5.26.5 dev: true - /@types/node@20.11.5: - resolution: {integrity: sha512-g557vgQjUUfN76MZAN/dt1z3dzcUsimuysco0KeluHgrPdJXkP/XdAURgyO2W9fZWHRtRBiVKzKn8vyOAwlG+w==} + /@types/node@20.11.19: + resolution: {integrity: sha512-7xMnVEcZFu0DikYjWOlRq7NTPETrm7teqUT2WkQjrTIkEgUyyGdWsj/Zg8bEJt5TNklzbPD1X3fqfsHw3SpapQ==} dependencies: undici-types: 5.26.5 dev: true @@ -6137,26 +6299,26 @@ packages: resolution: {integrity: sha512-hKormJbkJqzQGhziax5PItDUTMAM9uE2XXQmM37dyd4hVM+5aVl7oVxMVUiVQn2oCQFN/LKCZdvSM0pFRqbSmQ==} dev: true - /@types/react-dom@18.2.18: - resolution: {integrity: sha512-TJxDm6OfAX2KJWJdMEVTwWke5Sc/E/RlnPGvGfS0W7+6ocy2xhDVQVh/KvC2Uf7kACs+gDytdusDSdWfWkaNzw==} + /@types/react-dom@18.2.19: + resolution: {integrity: sha512-aZvQL6uUbIJpjZk4U8JZGbau9KDeAwMfmhyWorxgBkqDIEf6ROjRozcmPIicqsUwPUjbkDfHKgGee1Lq65APcA==} dependencies: - '@types/react': 18.2.48 + '@types/react': 18.2.57 dev: true /@types/react-reconciler@0.28.8: resolution: {integrity: sha512-SN9c4kxXZonFhbX4hJrZy37yw9e7EIxcpHCxQv5JUS18wDE5ovkQKlqQEkufdJCCMfuI9BnjUJvhYeJ9x5Ra7g==} dependencies: - '@types/react': 18.2.48 + '@types/react': 18.2.57 dev: false /@types/react-transition-group@4.4.10: resolution: {integrity: sha512-hT/+s0VQs2ojCX823m60m5f0sL5idt9SO6Tj6Dg+rdphGPIeJbJ6CxvBYkgkGKrYeDjvIpKTR38UzmtHJOGW3Q==} dependencies: - '@types/react': 18.2.48 + '@types/react': 18.2.57 dev: false - /@types/react@18.2.48: - resolution: {integrity: 
sha512-qboRCl6Ie70DQQG9hhNREz81jqC1cs9EVNcjQ1AU+jH6NFfSAhVVbrrY/+nSF+Bsk4AOwm9Qa61InvMCyV+H3w==} + /@types/react@18.2.57: + resolution: {integrity: sha512-ZvQsktJgSYrQiMirAN60y4O/LRevIV8hUzSOSNB6gfR3/o3wCBFQx3sPwIYtuDMeiVgsSS3UzCV26tEzgnfvQw==} dependencies: '@types/prop-types': 15.7.11 '@types/scheduler': 0.16.8 @@ -6169,15 +6331,15 @@ packages: /@types/scheduler@0.16.8: resolution: {integrity: sha512-WZLiwShhwLRmeV6zH+GkbOFT6Z6VklCItrDioxUnv+u4Ll+8vKeFySoFyK/0ctcRpOmwAicELfmys1sDc/Rw+A==} - /@types/semver@7.5.6: - resolution: {integrity: sha512-dn1l8LaMea/IjDoHNd9J52uBbInB796CDffS6VdIxvqYCPSG0V0DzHp76GpaWnlhg88uYyPbXCDIowa86ybd5A==} + /@types/semver@7.5.7: + resolution: {integrity: sha512-/wdoPq1QqkSj9/QOeKkFquEuPzQbHTWAMPH/PaUMB+JuR31lXhlWXRZ52IpfDYVlDOUBvX09uBrPwxGT1hjNBg==} dev: true /@types/send@0.17.4: resolution: {integrity: sha512-x2EM6TJOybec7c52BX0ZspPodMsQUd5L6PRwOunVyVUhXiBSKf3AezDL8Dgvgt5o0UfKNfuA0eMLr2wLT4AiBA==} dependencies: '@types/mime': 1.3.5 - '@types/node': 20.11.5 + '@types/node': 20.11.19 dev: true /@types/serve-static@1.15.5: @@ -6185,7 +6347,7 @@ packages: dependencies: '@types/http-errors': 2.0.4 '@types/mime': 3.0.4 - '@types/node': 20.11.5 + '@types/node': 20.11.19 dev: true /@types/unist@2.0.10: @@ -6196,8 +6358,8 @@ packages: resolution: {integrity: sha512-EwmlvuaxPNej9+T4v5AuBPJa2x2UOJVdjCtDHgcDqitUeOtjnJKJ+apYjVcAoBEMjKW1VVFGZLUb5+qqa09XFA==} dev: false - /@types/uuid@9.0.7: - resolution: {integrity: sha512-WUtIVRUZ9i5dYXefDEAI7sh9/O7jGvHg7Df/5O/gtH3Yabe5odI3UWopVR1qbPXQtvOxWu3mM4XxlYeZtMWF4g==} + /@types/uuid@9.0.8: + resolution: {integrity: sha512-jg+97EGIcY9AGHJJRaaPVgetKDsrTgbRjQ5Msgjh/DQKEFl0DtyRr/VCOyD1T2R1MNeWPK/u7JoGhlDZnKBAfA==} dev: true /@types/yargs-parser@21.0.3: @@ -6216,49 +6378,49 @@ packages: '@types/yargs-parser': 21.0.3 dev: true - /@typescript-eslint/eslint-plugin@6.19.0(@typescript-eslint/parser@6.19.0)(eslint@8.56.0)(typescript@5.3.3): - resolution: {integrity: sha512-DUCUkQNklCQYnrBSSikjVChdc84/vMPDQSgJTHBZ64G9bA9w0Crc0rd2diujKbTdp6w2J47qkeHQLoi0rpLCdg==} + /@typescript-eslint/eslint-plugin@7.0.2(@typescript-eslint/parser@7.0.2)(eslint@8.56.0)(typescript@5.3.3): + resolution: {integrity: sha512-/XtVZJtbaphtdrWjr+CJclaCVGPtOdBpFEnvtNf/jRV0IiEemRrL0qABex/nEt8isYcnFacm3nPHYQwL+Wb7qg==} engines: {node: ^16.0.0 || >=18.0.0} peerDependencies: - '@typescript-eslint/parser': ^6.0.0 || ^6.0.0-alpha - eslint: ^7.0.0 || ^8.0.0 + '@typescript-eslint/parser': ^7.0.0 + eslint: ^8.56.0 typescript: '*' peerDependenciesMeta: typescript: optional: true dependencies: '@eslint-community/regexpp': 4.10.0 - '@typescript-eslint/parser': 6.19.0(eslint@8.56.0)(typescript@5.3.3) - '@typescript-eslint/scope-manager': 6.19.0 - '@typescript-eslint/type-utils': 6.19.0(eslint@8.56.0)(typescript@5.3.3) - '@typescript-eslint/utils': 6.19.0(eslint@8.56.0)(typescript@5.3.3) - '@typescript-eslint/visitor-keys': 6.19.0 + '@typescript-eslint/parser': 7.0.2(eslint@8.56.0)(typescript@5.3.3) + '@typescript-eslint/scope-manager': 7.0.2 + '@typescript-eslint/type-utils': 7.0.2(eslint@8.56.0)(typescript@5.3.3) + '@typescript-eslint/utils': 7.0.2(eslint@8.56.0)(typescript@5.3.3) + '@typescript-eslint/visitor-keys': 7.0.2 debug: 4.3.4 eslint: 8.56.0 graphemer: 1.4.0 - ignore: 5.3.0 + ignore: 5.3.1 natural-compare: 1.4.0 - semver: 7.5.4 - ts-api-utils: 1.0.3(typescript@5.3.3) + semver: 7.6.0 + ts-api-utils: 1.2.1(typescript@5.3.3) typescript: 5.3.3 transitivePeerDependencies: - supports-color dev: true - /@typescript-eslint/parser@6.19.0(eslint@8.56.0)(typescript@5.3.3): - 
resolution: {integrity: sha512-1DyBLG5SH7PYCd00QlroiW60YJ4rWMuUGa/JBV0iZuqi4l4IK3twKPq5ZkEebmGqRjXWVgsUzfd3+nZveewgow==} + /@typescript-eslint/parser@7.0.2(eslint@8.56.0)(typescript@5.3.3): + resolution: {integrity: sha512-GdwfDglCxSmU+QTS9vhz2Sop46ebNCXpPPvsByK7hu0rFGRHL+AusKQJ7SoN+LbLh6APFpQwHKmDSwN35Z700Q==} engines: {node: ^16.0.0 || >=18.0.0} peerDependencies: - eslint: ^7.0.0 || ^8.0.0 + eslint: ^8.56.0 typescript: '*' peerDependenciesMeta: typescript: optional: true dependencies: - '@typescript-eslint/scope-manager': 6.19.0 - '@typescript-eslint/types': 6.19.0 - '@typescript-eslint/typescript-estree': 6.19.0(typescript@5.3.3) - '@typescript-eslint/visitor-keys': 6.19.0 + '@typescript-eslint/scope-manager': 7.0.2 + '@typescript-eslint/types': 7.0.2 + '@typescript-eslint/typescript-estree': 7.0.2(typescript@5.3.3) + '@typescript-eslint/visitor-keys': 7.0.2 debug: 4.3.4 eslint: 8.56.0 typescript: 5.3.3 @@ -6274,29 +6436,29 @@ packages: '@typescript-eslint/visitor-keys': 5.62.0 dev: true - /@typescript-eslint/scope-manager@6.19.0: - resolution: {integrity: sha512-dO1XMhV2ehBI6QN8Ufi7I10wmUovmLU0Oru3n5LVlM2JuzB4M+dVphCPLkVpKvGij2j/pHBWuJ9piuXx+BhzxQ==} + /@typescript-eslint/scope-manager@7.0.2: + resolution: {integrity: sha512-l6sa2jF3h+qgN2qUMjVR3uCNGjWw4ahGfzIYsCtFrQJCjhbrDPdiihYT8FnnqFwsWX+20hK592yX9I2rxKTP4g==} engines: {node: ^16.0.0 || >=18.0.0} dependencies: - '@typescript-eslint/types': 6.19.0 - '@typescript-eslint/visitor-keys': 6.19.0 + '@typescript-eslint/types': 7.0.2 + '@typescript-eslint/visitor-keys': 7.0.2 dev: true - /@typescript-eslint/type-utils@6.19.0(eslint@8.56.0)(typescript@5.3.3): - resolution: {integrity: sha512-mcvS6WSWbjiSxKCwBcXtOM5pRkPQ6kcDds/juxcy/727IQr3xMEcwr/YLHW2A2+Fp5ql6khjbKBzOyjuPqGi/w==} + /@typescript-eslint/type-utils@7.0.2(eslint@8.56.0)(typescript@5.3.3): + resolution: {integrity: sha512-IKKDcFsKAYlk8Rs4wiFfEwJTQlHcdn8CLwLaxwd6zb8HNiMcQIFX9sWax2k4Cjj7l7mGS5N1zl7RCHOVwHq2VQ==} engines: {node: ^16.0.0 || >=18.0.0} peerDependencies: - eslint: ^7.0.0 || ^8.0.0 + eslint: ^8.56.0 typescript: '*' peerDependenciesMeta: typescript: optional: true dependencies: - '@typescript-eslint/typescript-estree': 6.19.0(typescript@5.3.3) - '@typescript-eslint/utils': 6.19.0(eslint@8.56.0)(typescript@5.3.3) + '@typescript-eslint/typescript-estree': 7.0.2(typescript@5.3.3) + '@typescript-eslint/utils': 7.0.2(eslint@8.56.0)(typescript@5.3.3) debug: 4.3.4 eslint: 8.56.0 - ts-api-utils: 1.0.3(typescript@5.3.3) + ts-api-utils: 1.2.1(typescript@5.3.3) typescript: 5.3.3 transitivePeerDependencies: - supports-color @@ -6312,8 +6474,8 @@ packages: engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} dev: true - /@typescript-eslint/types@6.19.0: - resolution: {integrity: sha512-lFviGV/vYhOy3m8BJ/nAKoAyNhInTdXpftonhWle66XHAtT1ouBlkjL496b5H5hb8dWXHwtypTqgtb/DEa+j5A==} + /@typescript-eslint/types@7.0.2: + resolution: {integrity: sha512-ZzcCQHj4JaXFjdOql6adYV4B/oFOFjPOC9XYwCaZFRvqN8Llfvv4gSxrkQkd2u4Ci62i2c6W6gkDwQJDaRc4nA==} engines: {node: ^16.0.0 || >=18.0.0} dev: true @@ -6331,7 +6493,7 @@ packages: debug: 4.3.4 globby: 11.1.0 is-glob: 4.0.3 - semver: 7.5.4 + semver: 7.6.0 tsutils: 3.21.0(typescript@3.9.10) typescript: 3.9.10 transitivePeerDependencies: @@ -6352,7 +6514,7 @@ packages: debug: 4.3.4 globby: 11.1.0 is-glob: 4.0.3 - semver: 7.5.4 + semver: 7.6.0 tsutils: 3.21.0(typescript@4.9.5) typescript: 4.9.5 transitivePeerDependencies: @@ -6373,15 +6535,15 @@ packages: debug: 4.3.4 globby: 11.1.0 is-glob: 4.0.3 - semver: 7.5.4 + semver: 7.6.0 tsutils: 3.21.0(typescript@5.3.3) 
typescript: 5.3.3 transitivePeerDependencies: - supports-color dev: true - /@typescript-eslint/typescript-estree@6.19.0(typescript@5.3.3): - resolution: {integrity: sha512-o/zefXIbbLBZ8YJ51NlkSAt2BamrK6XOmuxSR3hynMIzzyMY33KuJ9vuMdFSXW+H0tVvdF9qBPTHA91HDb4BIQ==} + /@typescript-eslint/typescript-estree@7.0.2(typescript@5.3.3): + resolution: {integrity: sha512-3AMc8khTcELFWcKcPc0xiLviEvvfzATpdPj/DXuOGIdQIIFybf4DMT1vKRbuAEOFMwhWt7NFLXRkbjsvKZQyvw==} engines: {node: ^16.0.0 || >=18.0.0} peerDependencies: typescript: '*' @@ -6389,14 +6551,14 @@ packages: typescript: optional: true dependencies: - '@typescript-eslint/types': 6.19.0 - '@typescript-eslint/visitor-keys': 6.19.0 + '@typescript-eslint/types': 7.0.2 + '@typescript-eslint/visitor-keys': 7.0.2 debug: 4.3.4 globby: 11.1.0 is-glob: 4.0.3 minimatch: 9.0.3 - semver: 7.5.4 - ts-api-utils: 1.0.3(typescript@5.3.3) + semver: 7.6.0 + ts-api-utils: 1.2.1(typescript@5.3.3) typescript: 5.3.3 transitivePeerDependencies: - supports-color @@ -6410,32 +6572,32 @@ packages: dependencies: '@eslint-community/eslint-utils': 4.4.0(eslint@8.56.0) '@types/json-schema': 7.0.15 - '@types/semver': 7.5.6 + '@types/semver': 7.5.7 '@typescript-eslint/scope-manager': 5.62.0 '@typescript-eslint/types': 5.62.0 '@typescript-eslint/typescript-estree': 5.62.0(typescript@5.3.3) eslint: 8.56.0 eslint-scope: 5.1.1 - semver: 7.5.4 + semver: 7.6.0 transitivePeerDependencies: - supports-color - typescript dev: true - /@typescript-eslint/utils@6.19.0(eslint@8.56.0)(typescript@5.3.3): - resolution: {integrity: sha512-QR41YXySiuN++/dC9UArYOg4X86OAYP83OWTewpVx5ct1IZhjjgTLocj7QNxGhWoTqknsgpl7L+hGygCO+sdYw==} + /@typescript-eslint/utils@7.0.2(eslint@8.56.0)(typescript@5.3.3): + resolution: {integrity: sha512-PZPIONBIB/X684bhT1XlrkjNZJIEevwkKDsdwfiu1WeqBxYEEdIgVDgm8/bbKHVu+6YOpeRqcfImTdImx/4Bsw==} engines: {node: ^16.0.0 || >=18.0.0} peerDependencies: - eslint: ^7.0.0 || ^8.0.0 + eslint: ^8.56.0 dependencies: '@eslint-community/eslint-utils': 4.4.0(eslint@8.56.0) '@types/json-schema': 7.0.15 - '@types/semver': 7.5.6 - '@typescript-eslint/scope-manager': 6.19.0 - '@typescript-eslint/types': 6.19.0 - '@typescript-eslint/typescript-estree': 6.19.0(typescript@5.3.3) + '@types/semver': 7.5.7 + '@typescript-eslint/scope-manager': 7.0.2 + '@typescript-eslint/types': 7.0.2 + '@typescript-eslint/typescript-estree': 7.0.2(typescript@5.3.3) eslint: 8.56.0 - semver: 7.5.4 + semver: 7.6.0 transitivePeerDependencies: - supports-color - typescript @@ -6457,11 +6619,11 @@ packages: eslint-visitor-keys: 3.4.3 dev: true - /@typescript-eslint/visitor-keys@6.19.0: - resolution: {integrity: sha512-hZaUCORLgubBvtGpp1JEFEazcuEdfxta9j4iUwdSAr7mEsYYAp3EAUyCZk3VEEqGj6W+AV4uWyrDGtrlawAsgQ==} + /@typescript-eslint/visitor-keys@7.0.2: + resolution: {integrity: sha512-8Y+YiBmqPighbm5xA2k4wKTxRzx9EkBu7Rlw+WHqMvRJ3RPz/BMBO9b2ru0LUNmXg120PHUXD5+SWFy2R8DqlQ==} engines: {node: ^16.0.0 || >=18.0.0} dependencies: - '@typescript-eslint/types': 6.19.0 + '@typescript-eslint/types': 7.0.2 eslint-visitor-keys: 3.4.3 dev: true @@ -6469,29 +6631,29 @@ packages: resolution: {integrity: sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ==} dev: true - /@vitejs/plugin-react-swc@3.5.0(vite@5.0.12): - resolution: {integrity: sha512-1PrOvAaDpqlCV+Up8RkAh9qaiUjoDUcjtttyhXDKw53XA6Ve16SOp6cCOpRs8Dj8DqUQs6eTW5YkLcLJjrXAig==} + /@vitejs/plugin-react-swc@3.6.0(vite@5.1.3): + resolution: {integrity: sha512-XFRbsGgpGxGzEV5i5+vRiro1bwcIaZDIdBRP16qwm+jP68ue/S8FJTBEgOeojtVDYrbSua3XFp71kC8VJE6v+g==} 
peerDependencies: vite: ^4 || ^5 dependencies: - '@swc/core': 1.3.101 - vite: 5.0.12(@types/node@20.11.5) + '@swc/core': 1.4.2 + vite: 5.1.3(@types/node@20.11.19) transitivePeerDependencies: - '@swc/helpers' dev: true - /@vitejs/plugin-react@3.1.0(vite@5.0.12): + /@vitejs/plugin-react@3.1.0(vite@5.1.3): resolution: {integrity: sha512-AfgcRL8ZBhAlc3BFdigClmTUMISmmzHn7sB2h9U1odvc5U/MjWXsAaz18b/WoppUTDBzxOJwo2VdClfUcItu9g==} engines: {node: ^14.18.0 || >=16.0.0} peerDependencies: vite: ^4.1.0-beta.0 dependencies: - '@babel/core': 7.23.7 - '@babel/plugin-transform-react-jsx-self': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-react-jsx-source': 7.23.3(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/plugin-transform-react-jsx-self': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-react-jsx-source': 7.23.3(@babel/core@7.23.9) magic-string: 0.27.0 react-refresh: 0.14.0 - vite: 5.0.12(@types/node@20.11.5) + vite: 5.1.3(@types/node@20.11.19) transitivePeerDependencies: - supports-color dev: true @@ -6504,26 +6666,26 @@ packages: chai: 4.4.1 dev: true - /@vitest/expect@1.2.2: - resolution: {integrity: sha512-3jpcdPAD7LwHUUiT2pZTj2U82I2Tcgg2oVPvKxhn6mDI2On6tfvPQTjAI4628GUGDZrCm4Zna9iQHm5cEexOAg==} + /@vitest/expect@1.3.1: + resolution: {integrity: sha512-xofQFwIzfdmLLlHa6ag0dPV8YsnKOCP1KdAeVVh34vSjN2dcUiXYCD9htu/9eM7t8Xln4v03U9HLxLpPlsXdZw==} dependencies: - '@vitest/spy': 1.2.2 - '@vitest/utils': 1.2.2 + '@vitest/spy': 1.3.1 + '@vitest/utils': 1.3.1 chai: 4.4.1 dev: true - /@vitest/runner@1.2.2: - resolution: {integrity: sha512-JctG7QZ4LSDXr5CsUweFgcpEvrcxOV1Gft7uHrvkQ+fsAVylmWQvnaAr/HDp3LAH1fztGMQZugIheTWjaGzYIg==} + /@vitest/runner@1.3.1: + resolution: {integrity: sha512-5FzF9c3jG/z5bgCnjr8j9LNq/9OxV2uEBAITOXfoe3rdZJTdO7jzThth7FXv/6b+kdY65tpRQB7WaKhNZwX+Kg==} dependencies: - '@vitest/utils': 1.2.2 + '@vitest/utils': 1.3.1 p-limit: 5.0.0 pathe: 1.1.2 dev: true - /@vitest/snapshot@1.2.2: - resolution: {integrity: sha512-SmGY4saEw1+bwE1th6S/cZmPxz/Q4JWsl7LvbQIky2tKE35US4gd0Mjzqfr84/4OD0tikGWaWdMja/nWL5NIPA==} + /@vitest/snapshot@1.3.1: + resolution: {integrity: sha512-EF++BZbt6RZmOlE3SuTPu/NfwBF6q4ABS37HHXzs2LUVPBLx2QoY/K0fKpRChSo8eLiuxcbCVfqKgx/dplCDuQ==} dependencies: - magic-string: 0.30.5 + magic-string: 0.30.7 pathe: 1.1.2 pretty-format: 29.7.0 dev: true @@ -6531,13 +6693,13 @@ packages: /@vitest/spy@0.34.7: resolution: {integrity: sha512-NMMSzOY2d8L0mcOt4XcliDOS1ISyGlAXuQtERWVOoVHnKwmG+kKhinAiGw3dTtMQWybfa89FG8Ucg9tiC/FhTQ==} dependencies: - tinyspy: 2.2.0 + tinyspy: 2.2.1 dev: true - /@vitest/spy@1.2.2: - resolution: {integrity: sha512-k9Gcahssw8d7X3pSLq3e3XEu/0L78mUkCjivUqCQeXJm9clfXR/Td8+AP+VC1O6fKPIDLcHDTAmBOINVuv6+7g==} + /@vitest/spy@1.3.1: + resolution: {integrity: sha512-xAcW+S099ylC9VLU7eZfdT9myV67Nor9w9zhf0mGCYJSO+zM2839tOeROTdikOi/8Qeusffvxb/MyBSOja1Uig==} dependencies: - tinyspy: 2.2.0 + tinyspy: 2.2.1 dev: true /@vitest/utils@0.34.7: @@ -6548,8 +6710,8 @@ packages: pretty-format: 29.7.0 dev: true - /@vitest/utils@1.2.2: - resolution: {integrity: sha512-WKITBHLsBHlpjnDQahr+XK6RE7MiAsgrIkr0pGhQ9ygoxBfUeG0lUG5iLlzqjmKSlBv3+j5EGsriBzh+C3Tq9g==} + /@vitest/utils@1.3.1: + resolution: {integrity: sha512-d3Waie/299qqRyHTm2DjADeTaNdNSVsnwHPWrs20JMpjh6eiVq7ggggweO8rc4arhf6rRkWuHKwvxGvejUXZZQ==} dependencies: diff-sequences: 29.6.3 estree-walker: 3.0.3 @@ -6576,21 +6738,21 @@ packages: path-browserify: 1.0.1 dev: true - /@vue/compiler-core@3.4.15: - resolution: {integrity: sha512-XcJQVOaxTKCnth1vCxEChteGuwG6wqnUHxAm1DO3gCz0+uXKaJNx8/digSz4dLALCy8n2lKq24jSUs8segoqIw==} + 
/@vue/compiler-core@3.4.19: + resolution: {integrity: sha512-gj81785z0JNzRcU0Mq98E56e4ltO1yf8k5PQ+tV/7YHnbZkrM0fyFyuttnN8ngJZjbpofWE/m4qjKBiLl8Ju4w==} dependencies: - '@babel/parser': 7.23.6 - '@vue/shared': 3.4.15 + '@babel/parser': 7.23.9 + '@vue/shared': 3.4.19 entities: 4.5.0 estree-walker: 2.0.2 source-map-js: 1.0.2 dev: true - /@vue/compiler-dom@3.4.15: - resolution: {integrity: sha512-wox0aasVV74zoXyblarOM3AZQz/Z+OunYcIHe1OsGclCHt8RsRm04DObjefaI82u6XDzv+qGWZ24tIsRAIi5MQ==} + /@vue/compiler-dom@3.4.19: + resolution: {integrity: sha512-vm6+cogWrshjqEHTzIDCp72DKtea8Ry/QVpQRYoyTIg9k7QZDX6D8+HGURjtmatfgM8xgCFtJJaOlCaRYRK3QA==} dependencies: - '@vue/compiler-core': 3.4.15 - '@vue/shared': 3.4.15 + '@vue/compiler-core': 3.4.19 + '@vue/shared': 3.4.19 dev: true /@vue/language-core@1.8.27(typescript@5.3.3): @@ -6603,8 +6765,8 @@ packages: dependencies: '@volar/language-core': 1.11.1 '@volar/source-map': 1.11.1 - '@vue/compiler-dom': 3.4.15 - '@vue/shared': 3.4.15 + '@vue/compiler-dom': 3.4.19 + '@vue/shared': 3.4.19 computeds: 0.0.1 minimatch: 9.0.3 muggle-string: 0.3.1 @@ -6613,8 +6775,8 @@ packages: vue-template-compiler: 2.7.16 dev: true - /@vue/shared@3.4.15: - resolution: {integrity: sha512-KzfPTxVaWfB+eGcGdbSf4CWdaXcGDqckoeXUh7SB3fZdEtzPCK2Vq9B/lRRL3yutax/LWITz+SwvgyOxz5V75g==} + /@vue/shared@3.4.19: + resolution: {integrity: sha512-/KliRRHMF6LoiThEy+4c1Z4KB/gbPrGjWwJR+crg2otgrf/egKzRaCPvJ51S5oetgsgXLfc4Rm5ZgrKHZrtMSw==} dev: true /@xobotyi/scrollbar-width@1.9.5: @@ -7224,12 +7386,12 @@ packages: acorn: 7.4.1 dev: true - /acorn-jsx@5.3.2(acorn@8.11.2): + /acorn-jsx@5.3.2(acorn@8.11.3): resolution: {integrity: sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==} peerDependencies: acorn: ^6.0.0 || ^7.0.0 || ^8.0.0 dependencies: - acorn: 8.11.2 + acorn: 8.11.3 dev: true /acorn-walk@7.2.0: @@ -7248,12 +7410,6 @@ packages: hasBin: true dev: true - /acorn@8.11.2: - resolution: {integrity: sha512-nc0Axzp/0FILLEVsm4fNwLCwMttvhEI263QtVPQcbpfZZ3ts0hLsZGOpE6czNlid7CJ9MlyH8reXkpsf3YUY4w==} - engines: {node: '>=0.4.0'} - hasBin: true - dev: true - /acorn@8.11.3: resolution: {integrity: sha512-Y9rRfJG5jcKOE0CLisYbojUjIrIEE7AGMzA/Sm4BslANhbS+cDMpgBdcPT91oJ7OuJ9hYJBx59RjbhxVnrF8Xg==} engines: {node: '>=0.4.0'} @@ -7373,11 +7529,12 @@ packages: dequal: 2.0.3 dev: true - /array-buffer-byte-length@1.0.0: - resolution: {integrity: sha512-LPuwb2P+NrQw3XhxGc36+XSvuBPopovXYTR9Ew++Du9Yb/bx5AzBfrIsBoj0EZUifjQU+sHL21sseZ3jerWO/A==} + /array-buffer-byte-length@1.0.1: + resolution: {integrity: sha512-ahC5W1xgou+KTXix4sAO8Ki12Q+jf4i0+tmk3sC+zgcynshkHxzpXdImBehiUYKKKDwvfFiJl1tZt6ewscS1Mg==} + engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 - is-array-buffer: 3.0.2 + call-bind: 1.0.7 + is-array-buffer: 3.0.4 dev: true /array-flatten@1.1.1: @@ -7388,10 +7545,10 @@ packages: resolution: {integrity: sha512-dlcsNBIiWhPkHdOEEKnehA+RNUWDc4UqFtnIXU4uuYDPtA4LDkr7qip2p0VvFAEXNDr0yWZ9PJyIRiGjRLQzwQ==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 - get-intrinsic: 1.2.2 + es-abstract: 1.22.4 + get-intrinsic: 1.2.4 is-string: 1.0.7 dev: true @@ -7400,24 +7557,35 @@ packages: engines: {node: '>=8'} dev: true - /array.prototype.findlastindex@1.2.3: - resolution: {integrity: sha512-LzLoiOMAxvy+Gd3BAq3B7VeIgPdo+Q8hthvKtXybMvRV0jrXfJM/t8mw7nNlpEcVlVUnCnM2KSX4XU5HmpodOA==} + /array.prototype.filter@1.0.3: + resolution: {integrity: 
sha512-VizNcj/RGJiUyQBgzwxzE5oHdeuXY5hSbbmKMlphj1cy1Vl7Pn2asCGbSrru6hSQjmCzqTBPVWAF/whmEOVHbw==} + engines: {node: '>= 0.4'} + dependencies: + call-bind: 1.0.7 + define-properties: 1.2.1 + es-abstract: 1.22.4 + es-array-method-boxes-properly: 1.0.0 + is-string: 1.0.7 + dev: true + + /array.prototype.findlastindex@1.2.4: + resolution: {integrity: sha512-hzvSHUshSpCflDR1QMUBLHGHP1VIEBegT4pix9H/Z92Xw3ySoy6c2qh7lJWTJnRJ8JCZ9bJNCgTyYaJGcJu6xQ==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 + es-abstract: 1.22.4 + es-errors: 1.3.0 es-shim-unscopables: 1.0.2 - get-intrinsic: 1.2.2 dev: true /array.prototype.flat@1.3.2: resolution: {integrity: sha512-djYB+Zx2vLewY8RWlNCUdHjDXs2XOgm602S9E7P/UpHgfeHL00cRiIF+IN/G/aUJ7kGPb6yO/ErDI5V2s8iycA==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 + es-abstract: 1.22.4 es-shim-unscopables: 1.0.2 dev: true @@ -7425,39 +7593,40 @@ packages: resolution: {integrity: sha512-Ewyx0c9PmpcsByhSW4r+9zDU7sGjFc86qf/kKtuSCRdhfbk0SNLLkaT5qvcHnRGgc5NP/ly/y+qkXkqONX54CQ==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 + es-abstract: 1.22.4 es-shim-unscopables: 1.0.2 dev: true - /array.prototype.tosorted@1.1.2: - resolution: {integrity: sha512-HuQCHOlk1Weat5jzStICBCd83NxiIMwqDg/dHEsoefabn/hJRj5pVdWcPUSpRrwhwxZOsQassMpgN/xRYFBMIg==} + /array.prototype.tosorted@1.1.3: + resolution: {integrity: sha512-/DdH4TiTmOKzyQbp/eadcCVexiCb36xJg7HshYOYJnNZFDj33GEv0P7GxsynpShhq4OLYJzbGcBDkLsDt7MnNg==} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 + es-abstract: 1.22.4 + es-errors: 1.3.0 es-shim-unscopables: 1.0.2 - get-intrinsic: 1.2.2 dev: true - /arraybuffer.prototype.slice@1.0.2: - resolution: {integrity: sha512-yMBKppFur/fbHu9/6USUe03bZ4knMYiwFBcyiaXB8Go0qNehwX6inYPzK9U0NeQvGxKthcmHcaR8P5MStSRBAw==} + /arraybuffer.prototype.slice@1.0.3: + resolution: {integrity: sha512-bMxMKAjg13EBSVscxTaYA4mRc5t1UAXa2kXiGTNfZ079HIWXEkKmkgFrh/nJqamaLSrXO5H4WFFkPEaLJWbs3A==} engines: {node: '>= 0.4'} dependencies: - array-buffer-byte-length: 1.0.0 - call-bind: 1.0.5 + array-buffer-byte-length: 1.0.1 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 - get-intrinsic: 1.2.2 - is-array-buffer: 3.0.2 + es-abstract: 1.22.4 + es-errors: 1.3.0 + get-intrinsic: 1.2.4 + is-array-buffer: 3.0.4 is-shared-array-buffer: 1.0.2 dev: true /assert@2.1.0: resolution: {integrity: sha512-eLHpSK/Y4nhMJ07gDaAzoX/XAKS8PSaojml3M0DM4JpV1LAi5JOJ/p6H/XWrl8L+DzVEvVCW1z3vWAaB9oTsQw==} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 is-nan: 1.3.2 object-is: 1.1.5 object.assign: 4.1.5 @@ -7512,17 +7681,19 @@ packages: engines: {node: '>=4'} dev: false - /available-typed-arrays@1.0.5: - resolution: {integrity: sha512-DMD0KiN46eipeziST1LPP/STfDU0sufISXmjSgvVsoU2tqxctQeASejWcfNtxYKqETM1UxQ8sp2OrSBWpHY6sw==} + /available-typed-arrays@1.0.7: + resolution: {integrity: sha512-wvUjBtSGN7+7SjNpq/9M2Tg350UZD3q62IFZLbRAR1bSMlCo1ZaeW+BJ+D090e4hIIZLBcTDWe4Mh4jvUDajzQ==} engines: {node: '>= 0.4'} + dependencies: + possible-typed-array-names: 1.0.0 dev: true - /babel-core@7.0.0-bridge.0(@babel/core@7.23.7): + /babel-core@7.0.0-bridge.0(@babel/core@7.23.9): resolution: {integrity: sha512-poPX9mZH/5CSanm50Q+1toVci6pv5KSRv/5TWCwtzQS5XEwn40BcCrgIeMFWP9CKKIniKXNxoIOnOq4VVlGXhg==} peerDependencies: '@babel/core': ^7.0.0-0 dependencies: - '@babel/core': 
7.23.7 + '@babel/core': 7.23.9 dev: true /babel-plugin-istanbul@6.1.1: @@ -7542,43 +7713,43 @@ packages: resolution: {integrity: sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==} engines: {node: '>=10', npm: '>=6'} dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 cosmiconfig: 7.1.0 resolve: 1.22.8 dev: false - /babel-plugin-polyfill-corejs2@0.4.8(@babel/core@7.23.7): + /babel-plugin-polyfill-corejs2@0.4.8(@babel/core@7.23.9): resolution: {integrity: sha512-OtIuQfafSzpo/LhnJaykc0R/MMnuLSSVjVYy9mHArIZ9qTCSZ6TpWCuEKZYVoN//t8HqBNScHrOtCrIK5IaGLg==} peerDependencies: '@babel/core': ^7.4.0 || ^8.0.0-0 <8.0.0 dependencies: '@babel/compat-data': 7.23.5 - '@babel/core': 7.23.7 - '@babel/helper-define-polyfill-provider': 0.5.0(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-define-polyfill-provider': 0.5.0(@babel/core@7.23.9) semver: 6.3.1 transitivePeerDependencies: - supports-color dev: true - /babel-plugin-polyfill-corejs3@0.8.7(@babel/core@7.23.7): - resolution: {integrity: sha512-KyDvZYxAzkC0Aj2dAPyDzi2Ym15e5JKZSK+maI7NAwSqofvuFglbSsxE7wUOvTg9oFVnHMzVzBKcqEb4PJgtOA==} + /babel-plugin-polyfill-corejs3@0.9.0(@babel/core@7.23.9): + resolution: {integrity: sha512-7nZPG1uzK2Ymhy/NbaOWTg3uibM2BmGASS4vHS4szRZAIR8R6GwA/xAujpdrXU5iyklrimWnLWU+BLF9suPTqg==} peerDependencies: '@babel/core': ^7.4.0 || ^8.0.0-0 <8.0.0 dependencies: - '@babel/core': 7.23.7 - '@babel/helper-define-polyfill-provider': 0.4.4(@babel/core@7.23.7) - core-js-compat: 3.35.0 + '@babel/core': 7.23.9 + '@babel/helper-define-polyfill-provider': 0.5.0(@babel/core@7.23.9) + core-js-compat: 3.36.0 transitivePeerDependencies: - supports-color dev: true - /babel-plugin-polyfill-regenerator@0.5.5(@babel/core@7.23.7): + /babel-plugin-polyfill-regenerator@0.5.5(@babel/core@7.23.9): resolution: {integrity: sha512-OJGYZlhLqBh2DDHeqAxWB1XIvr49CxiJ2gIt61/PU55CQK4Z58OzMqjDe1zwQdQk+rBYsRc+1rJmdajM3gimHg==} peerDependencies: '@babel/core': ^7.4.0 || ^8.0.0-0 <8.0.0 dependencies: - '@babel/core': 7.23.7 - '@babel/helper-define-polyfill-provider': 0.5.0(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/helper-define-polyfill-provider': 0.5.0(@babel/core@7.23.9) transitivePeerDependencies: - supports-color dev: true @@ -7677,15 +7848,15 @@ packages: pako: 0.2.9 dev: true - /browserslist@4.22.2: - resolution: {integrity: sha512-0UgcrvQmBDvZHFGdYUehrCNIazki7/lUP3kkoi/r3YB2amZbFM9J43ZRkJTXBUZK4gmx56+Sqk9+Vs9mwZx9+A==} + /browserslist@4.23.0: + resolution: {integrity: sha512-QW8HiM1shhT2GuzkvklfjcKDiWFXHOeFCIA/huJPwHsslwcydgk7X+z2zXpEijP98UCY7HbubZt5J2Zgvf0CaQ==} engines: {node: ^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7} hasBin: true dependencies: - caniuse-lite: 1.0.30001579 - electron-to-chromium: 1.4.639 + caniuse-lite: 1.0.30001588 + electron-to-chromium: 1.4.677 node-releases: 2.0.14 - update-browserslist-db: 1.0.13(browserslist@4.22.2) + update-browserslist-db: 1.0.13(browserslist@4.23.0) dev: true /bser@2.1.1: @@ -7724,12 +7895,15 @@ packages: engines: {node: '>=8'} dev: true - /call-bind@1.0.5: - resolution: {integrity: sha512-C3nQxfFZxFRVoJoGKKI8y3MOEo129NQ+FgQ08iye+Mk4zNZZGdjfs06bVTr+DBSlA66Q2VEcMki/cUCP4SercQ==} + /call-bind@1.0.7: + resolution: {integrity: sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==} + engines: {node: '>= 0.4'} dependencies: + es-define-property: 1.0.0 + es-errors: 1.3.0 function-bind: 1.1.2 - get-intrinsic: 1.2.2 - set-function-length: 1.1.1 + get-intrinsic: 1.2.4 + set-function-length: 
1.2.1 dev: true /callsites@3.1.0: @@ -7741,8 +7915,8 @@ packages: engines: {node: '>=6'} dev: true - /caniuse-lite@1.0.30001579: - resolution: {integrity: sha512-u5AUVkixruKHJjw/pj9wISlcMpgFWzSrczLZbrqBSxukQixmg0SJ5sZTpvaFvxU0HoQKd4yoyAogyrAz9pzJnA==} + /caniuse-lite@1.0.30001588: + resolution: {integrity: sha512-+hVY9jE44uKLkH0SrUTqxjxqNTOWHsbnQDIKjwkZ3lNTzUUVdBLBGXtj/q5Mp5u98r3droaZAewQuEDzjQdZlQ==} dev: true /chai@4.4.1: @@ -7758,7 +7932,7 @@ packages: type-detect: 4.0.8 dev: true - /chakra-react-select@4.7.6(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/layout@2.3.1)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@emotion/react@11.11.3)(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + /chakra-react-select@4.7.6(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/layout@2.3.1)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@emotion/react@11.11.3)(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-ZL43hyXPnWf1g/HjsZDecbeJ4F2Q6tTPYJozlKWkrQ7lIX7ORP0aZYwmc5/Wly4UNzMimj2Vuosl6MmIXH+G2g==} peerDependencies: '@chakra-ui/form-control': ^2.0.0 @@ -7776,13 +7950,13 @@ packages: '@chakra-ui/icon': 3.2.0(@chakra-ui/system@2.6.2)(react@18.2.0) '@chakra-ui/layout': 2.3.1(@chakra-ui/system@2.6.2)(react@18.2.0) '@chakra-ui/media-query': 3.3.0(@chakra-ui/system@2.6.2)(react@18.2.0) - '@chakra-ui/menu': 2.2.1(@chakra-ui/system@2.6.2)(framer-motion@10.18.0)(react@18.2.0) + '@chakra-ui/menu': 2.2.1(@chakra-ui/system@2.6.2)(framer-motion@11.0.5)(react@18.2.0) '@chakra-ui/spinner': 2.1.0(@chakra-ui/system@2.6.2)(react@18.2.0) '@chakra-ui/system': 2.6.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(react@18.2.0) - '@emotion/react': 11.11.3(@types/react@18.2.48)(react@18.2.0) + '@emotion/react': 11.11.3(@types/react@18.2.57)(react@18.2.0) react: 18.2.0 react-dom: 18.2.0(react@18.2.0) - react-select: 5.7.7(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + react-select: 5.7.7(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) transitivePeerDependencies: - '@types/react' dev: false @@ -7822,8 +7996,8 @@ packages: get-func-name: 2.0.2 dev: true - /chokidar@3.5.3: - resolution: {integrity: sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw==} + /chokidar@3.6.0: + resolution: {integrity: sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==} engines: {node: '>= 8.10.0'} dependencies: anymatch: 3.1.3 @@ -7851,8 +8025,8 @@ packages: engines: {node: '>=8'} dev: true - /citty@0.1.5: - resolution: {integrity: sha512-AS7n5NSc0OQVMV9v6wt3ByujNIrne0/cTjiC2MYqhvao57VNfiuVksTSr2p17nVOhEr2KtqiAkGwHcgMC/qUuQ==} + /citty@0.1.6: + resolution: {integrity: sha512-tskPPKEs8D2KPafUypv2gxwJP8h/OaJmC82QQGGDQcHvXX43xF2VDACcJVmZ0EuSxkpO9Kc4MlrA3q0+FG58AQ==} dependencies: consola: 3.2.3 dev: true @@ -8071,10 +8245,10 @@ packages: toggle-selection: 1.0.6 dev: false - /core-js-compat@3.35.0: - resolution: {integrity: sha512-5blwFAddknKeNgsjBzilkdQ0+YK8L1PfqPYq40NOYMYFSS38qj+hpTcLLWwpIwA2A5bje/x5jmVn2tzUMg9IVw==} + /core-js-compat@3.36.0: + resolution: {integrity: sha512-iV9Pd/PsgjNWBXeq8XRtWVSgz2tKAfhfvBs7qxYty+RlRd+OCksaWmOnc4JKrTc1cToXL1N0s3l/vwlxPtdElw==} dependencies: - browserslist: 4.22.2 + browserslist: 4.23.0 dev: true /core-util-is@1.0.3: @@ -8210,7 +8384,7 @@ packages: resolution: {integrity: 
sha512-fnULvOpxnC5/Vg3NCiWelDsLiUc9bRwAPs/+LfTLNvetFCtCTN+yQz15C/fs4AwX1R9K5GLtLfn8QW+dWisaAw==} engines: {node: '>=0.11'} dependencies: - '@babel/runtime': 7.23.6 + '@babel/runtime': 7.23.9 dev: true /dateformat@5.0.3: @@ -8271,12 +8445,12 @@ packages: resolution: {integrity: sha512-ZIwpnevOurS8bpT4192sqAowWM76JDKSHYzMLty3BZGSswgq6pBaH3DhCSW5xVAZICZyKdOBPjwww5wfgT/6PA==} engines: {node: '>= 0.4'} dependencies: - array-buffer-byte-length: 1.0.0 - call-bind: 1.0.5 + array-buffer-byte-length: 1.0.1 + call-bind: 1.0.7 es-get-iterator: 1.1.3 - get-intrinsic: 1.2.2 + get-intrinsic: 1.2.4 is-arguments: 1.1.1 - is-array-buffer: 3.0.2 + is-array-buffer: 3.0.4 is-date-object: 1.0.5 is-regex: 1.1.4 is-shared-array-buffer: 1.0.2 @@ -8284,11 +8458,11 @@ packages: object-is: 1.1.5 object-keys: 1.1.1 object.assign: 4.1.5 - regexp.prototype.flags: 1.5.1 - side-channel: 1.0.4 + regexp.prototype.flags: 1.5.2 + side-channel: 1.0.5 which-boxed-primitive: 1.0.2 which-collection: 1.0.1 - which-typed-array: 1.1.13 + which-typed-array: 1.1.14 dev: true /deep-extend@0.6.0: @@ -8314,13 +8488,13 @@ packages: clone: 1.0.4 dev: true - /define-data-property@1.1.1: - resolution: {integrity: sha512-E7uGkTzkk1d0ByLeSc6ZsFS79Axg+m1P/VsgYsxHgiuc3tFSj+MjMIwe90FC4lOAZzNBdY7kkO2P2wKdsQ1vgQ==} + /define-data-property@1.1.4: + resolution: {integrity: sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==} engines: {node: '>= 0.4'} dependencies: - get-intrinsic: 1.2.2 + es-define-property: 1.0.0 + es-errors: 1.3.0 gopd: 1.0.1 - has-property-descriptors: 1.0.1 /define-lazy-prop@2.0.0: resolution: {integrity: sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og==} @@ -8331,8 +8505,8 @@ packages: resolution: {integrity: sha512-8QmQKqEASLd5nx0U1B1okLElbUuuttJ/AnYmRXbbbGDWh6uS208EjD4Xqq/I9wK7u0v6O08XhTWnt5XtEbR6Dg==} engines: {node: '>= 0.4'} dependencies: - define-data-property: 1.1.1 - has-property-descriptors: 1.0.1 + define-data-property: 1.1.4 + has-property-descriptors: 1.0.2 object-keys: 1.1.1 /defu@6.1.4: @@ -8481,7 +8655,7 @@ packages: dependencies: debug: 4.3.4 is-url: 1.2.4 - postcss: 8.4.33 + postcss: 8.4.35 postcss-values-parser: 2.0.1 transitivePeerDependencies: - supports-color @@ -8492,8 +8666,8 @@ packages: engines: {node: ^12.0.0 || ^14.0.0 || >=16.0.0} dependencies: is-url: 1.2.4 - postcss: 8.4.33 - postcss-values-parser: 6.0.2(postcss@8.4.33) + postcss: 8.4.35 + postcss-values-parser: 6.0.2(postcss@8.4.35) dev: true /detective-sass@3.0.2: @@ -8611,7 +8785,7 @@ packages: /dom-helpers@5.2.1: resolution: {integrity: sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA==} dependencies: - '@babel/runtime': 7.23.7 + '@babel/runtime': 7.23.9 csstype: 3.1.3 dev: false @@ -8620,8 +8794,8 @@ packages: engines: {node: '>=12'} dev: true - /dotenv@16.3.1: - resolution: {integrity: sha512-IPzF4w4/Rd94bA9imS68tZBaYyBWSCE47V1RGuMrB94iyTOIEwRmVL2x/4An+6mETpLrKJ5hQkB8W4kFAadeIQ==} + /dotenv@16.4.5: + resolution: {integrity: sha512-ZmdL2rui+eB2YwhsWzjInR8LldtZHGDoQ1ugH85ppHKwpUHL7j7rN0Ti9NCnGiQbhaZ11FpR+7ao1dNsmduNUg==} engines: {node: '>=12'} dev: true @@ -8650,8 +8824,8 @@ packages: jake: 10.8.7 dev: true - /electron-to-chromium@1.4.639: - resolution: {integrity: sha512-CkKf3ZUVZchr+zDpAlNLEEy2NJJ9T64ULWaDgy3THXXlPVPkLu3VOs9Bac44nebVtdwl2geSj6AxTtGDOxoXhg==} + /electron-to-chromium@1.4.677: + resolution: {integrity: 
sha512-erDa3CaDzwJOpyvfKhOiJjBVNnMM0qxHq47RheVVwsSQrgBA9ZSGV9kdaOfZDPXcHzhG7lBxhj6A7KvfLJBd6Q==} dev: true /emoji-regex@8.0.0: @@ -8678,7 +8852,7 @@ packages: dependencies: '@socket.io/component-emitter': 3.1.0 debug: 4.3.4 - engine.io-parser: 5.2.1 + engine.io-parser: 5.2.2 ws: 8.11.0 xmlhttprequest-ssl: 2.0.0 transitivePeerDependencies: @@ -8687,8 +8861,8 @@ packages: - utf-8-validate dev: false - /engine.io-parser@5.2.1: - resolution: {integrity: sha512-9JktcM3u18nU9N2Lz3bWeBgxVgOKpw7yhRaoxQA3FUDZzzw+9WlA6p4G4u0RixNkg14fH7EfEc/RhpurtiROTQ==} + /engine.io-parser@5.2.2: + resolution: {integrity: sha512-RcyUFKA93/CXH20l4SoVvzZfrSDMOTUS3bWVpTt2FuFP+XYrL8i8oonHP7WInRyVHXh0n/ORtoeiE1os+8qkSw==} engines: {node: '>=10.0.0'} dev: false @@ -8705,8 +8879,8 @@ packages: engines: {node: '>=0.12'} dev: true - /envinfo@7.11.0: - resolution: {integrity: sha512-G9/6xF1FPbIw0TtalAMaVPpiq2aDEuKLXM314jPVAO9r2fo2a4BLqMNkmRS7O/xPPZ+COAhGIz3ETvHEV3eUcg==} + /envinfo@7.11.1: + resolution: {integrity: sha512-8PiZgZNIB4q/Lw4AhOvAfB/ityHAd2bli3lESSWmWSzSsl5dKpy5N1d1Rfkd2teq/g9xN90lc6o98DOjMeYHpg==} engines: {node: '>=4'} hasBin: true dev: true @@ -8722,56 +8896,72 @@ packages: stackframe: 1.3.4 dev: false - /es-abstract@1.22.3: - resolution: {integrity: sha512-eiiY8HQeYfYH2Con2berK+To6GrK2RxbPawDkGq4UiCQQfZHb6wX9qQqkbpPqaxQFcl8d9QzZqo0tGE0VcrdwA==} + /es-abstract@1.22.4: + resolution: {integrity: sha512-vZYJlk2u6qHYxBOTjAeg7qUxHdNfih64Uu2J8QqWgXZ2cri0ZpJAkzDUK/q593+mvKwlxyaxr6F1Q+3LKoQRgg==} engines: {node: '>= 0.4'} dependencies: - array-buffer-byte-length: 1.0.0 - arraybuffer.prototype.slice: 1.0.2 - available-typed-arrays: 1.0.5 - call-bind: 1.0.5 - es-set-tostringtag: 2.0.2 + array-buffer-byte-length: 1.0.1 + arraybuffer.prototype.slice: 1.0.3 + available-typed-arrays: 1.0.7 + call-bind: 1.0.7 + es-define-property: 1.0.0 + es-errors: 1.3.0 + es-set-tostringtag: 2.0.3 es-to-primitive: 1.2.1 function.prototype.name: 1.1.6 - get-intrinsic: 1.2.2 - get-symbol-description: 1.0.0 + get-intrinsic: 1.2.4 + get-symbol-description: 1.0.2 globalthis: 1.0.3 gopd: 1.0.1 - has-property-descriptors: 1.0.1 - has-proto: 1.0.1 + has-property-descriptors: 1.0.2 + has-proto: 1.0.3 has-symbols: 1.0.3 - hasown: 2.0.0 - internal-slot: 1.0.6 - is-array-buffer: 3.0.2 + hasown: 2.0.1 + internal-slot: 1.0.7 + is-array-buffer: 3.0.4 is-callable: 1.2.7 - is-negative-zero: 2.0.2 + is-negative-zero: 2.0.3 is-regex: 1.1.4 is-shared-array-buffer: 1.0.2 is-string: 1.0.7 - is-typed-array: 1.1.12 + is-typed-array: 1.1.13 is-weakref: 1.0.2 object-inspect: 1.13.1 object-keys: 1.1.1 object.assign: 4.1.5 - regexp.prototype.flags: 1.5.1 - safe-array-concat: 1.0.1 - safe-regex-test: 1.0.0 + regexp.prototype.flags: 1.5.2 + safe-array-concat: 1.1.0 + safe-regex-test: 1.0.3 string.prototype.trim: 1.2.8 string.prototype.trimend: 1.0.7 string.prototype.trimstart: 1.0.7 - typed-array-buffer: 1.0.0 + typed-array-buffer: 1.0.2 typed-array-byte-length: 1.0.0 - typed-array-byte-offset: 1.0.0 - typed-array-length: 1.0.4 + typed-array-byte-offset: 1.0.2 + typed-array-length: 1.0.5 unbox-primitive: 1.0.2 - which-typed-array: 1.1.13 + which-typed-array: 1.1.14 + dev: true + + /es-array-method-boxes-properly@1.0.0: + resolution: {integrity: sha512-wd6JXUmyHmt8T5a2xreUwKcGPq6f1f+WwIJkijUqiGcJz1qqnZgP6XIK+QyIWU5lT7imeNxUll48bziG+TSYcA==} dev: true + /es-define-property@1.0.0: + resolution: {integrity: sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==} + engines: {node: '>= 0.4'} + dependencies: + get-intrinsic: 1.2.4 + + 
/es-errors@1.3.0: + resolution: {integrity: sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==} + engines: {node: '>= 0.4'} + /es-get-iterator@1.1.3: resolution: {integrity: sha512-sPZmqHBe6JIiTfN5q2pEi//TwxmAFHwj/XEuYjTuse78i8KxaqMTTzxPoFKuzRpDpTJ+0NAbpfenkmH2rePtuw==} dependencies: - call-bind: 1.0.5 - get-intrinsic: 1.2.2 + call-bind: 1.0.7 + get-intrinsic: 1.2.4 has-symbols: 1.0.3 is-arguments: 1.1.1 is-map: 2.0.2 @@ -8781,42 +8971,44 @@ packages: stop-iteration-iterator: 1.0.0 dev: true - /es-iterator-helpers@1.0.15: - resolution: {integrity: sha512-GhoY8uYqd6iwUl2kgjTm4CZAf6oo5mHK7BPqx3rKgx893YSsy0LGHV6gfqqQvZt/8xM8xeOnfXBCfqclMKkJ5g==} + /es-iterator-helpers@1.0.17: + resolution: {integrity: sha512-lh7BsUqelv4KUbR5a/ZTaGGIMLCjPGPqJ6q+Oq24YP0RdyptX1uzm4vvaqzk7Zx3bpl/76YLTTDj9L7uYQ92oQ==} + engines: {node: '>= 0.4'} dependencies: asynciterator.prototype: 1.0.0 - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 - es-set-tostringtag: 2.0.2 + es-abstract: 1.22.4 + es-errors: 1.3.0 + es-set-tostringtag: 2.0.3 function-bind: 1.1.2 - get-intrinsic: 1.2.2 + get-intrinsic: 1.2.4 globalthis: 1.0.3 - has-property-descriptors: 1.0.1 - has-proto: 1.0.1 + has-property-descriptors: 1.0.2 + has-proto: 1.0.3 has-symbols: 1.0.3 - internal-slot: 1.0.6 + internal-slot: 1.0.7 iterator.prototype: 1.1.2 - safe-array-concat: 1.0.1 + safe-array-concat: 1.1.0 dev: true /es-module-lexer@0.9.3: resolution: {integrity: sha512-1HQ2M2sPtxwnvOvT1ZClHyQDiggdNjURWpY2we6aMKCQiUVxTmVs2UYPLIrD84sS+kMdUwfBSylbJPwNnBrnHQ==} dev: true - /es-set-tostringtag@2.0.2: - resolution: {integrity: sha512-BuDyupZt65P9D2D2vA/zqcI3G5xRsklm5N3xCwuiy+/vKy8i0ifdsQP1sLgO4tZDSCaQUSnmC48khknGMV3D2Q==} + /es-set-tostringtag@2.0.3: + resolution: {integrity: sha512-3T8uNMC3OQTHkFUsFq8r/BwAXLHvU/9O9mE0fBc/MY5iq/8H7ncvO947LmYA6ldWw9Uh8Yhf25zu6n7nML5QWQ==} engines: {node: '>= 0.4'} dependencies: - get-intrinsic: 1.2.2 - has-tostringtag: 1.0.0 - hasown: 2.0.0 + get-intrinsic: 1.2.4 + has-tostringtag: 1.0.2 + hasown: 2.0.1 dev: true /es-shim-unscopables@1.0.2: resolution: {integrity: sha512-J3yBRXCzDu4ULnQwxyToo/OjdMx6akgVC7K6few0a7F/0wLtmKKN7I73AH5T2836UuXRqN7Qg+IIUw/+YJksRw==} dependencies: - hasown: 2.0.0 + hasown: 2.0.1 dev: true /es-to-primitive@1.2.1: @@ -8873,39 +9065,39 @@ packages: '@esbuild/win32-x64': 0.18.20 dev: true - /esbuild@0.19.11: - resolution: {integrity: sha512-HJ96Hev2hX/6i5cDVwcqiJBBtuo9+FeIJOtZ9W1kA5M6AMJRHUZlpYZ1/SbEwtO0ioNAW8rUooVpC/WehY2SfA==} + /esbuild@0.19.12: + resolution: {integrity: sha512-aARqgq8roFBj054KvQr5f1sFu0D65G+miZRCuJyJ0G13Zwx7vRar5Zhn2tkQNzIXcBrNVsv/8stehpj+GAjgbg==} engines: {node: '>=12'} hasBin: true requiresBuild: true optionalDependencies: - '@esbuild/aix-ppc64': 0.19.11 - '@esbuild/android-arm': 0.19.11 - '@esbuild/android-arm64': 0.19.11 - '@esbuild/android-x64': 0.19.11 - '@esbuild/darwin-arm64': 0.19.11 - '@esbuild/darwin-x64': 0.19.11 - '@esbuild/freebsd-arm64': 0.19.11 - '@esbuild/freebsd-x64': 0.19.11 - '@esbuild/linux-arm': 0.19.11 - '@esbuild/linux-arm64': 0.19.11 - '@esbuild/linux-ia32': 0.19.11 - '@esbuild/linux-loong64': 0.19.11 - '@esbuild/linux-mips64el': 0.19.11 - '@esbuild/linux-ppc64': 0.19.11 - '@esbuild/linux-riscv64': 0.19.11 - '@esbuild/linux-s390x': 0.19.11 - '@esbuild/linux-x64': 0.19.11 - '@esbuild/netbsd-x64': 0.19.11 - '@esbuild/openbsd-x64': 0.19.11 - '@esbuild/sunos-x64': 0.19.11 - '@esbuild/win32-arm64': 0.19.11 - '@esbuild/win32-ia32': 0.19.11 - '@esbuild/win32-x64': 0.19.11 - dev: true - - 
/escalade@3.1.1: - resolution: {integrity: sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==} + '@esbuild/aix-ppc64': 0.19.12 + '@esbuild/android-arm': 0.19.12 + '@esbuild/android-arm64': 0.19.12 + '@esbuild/android-x64': 0.19.12 + '@esbuild/darwin-arm64': 0.19.12 + '@esbuild/darwin-x64': 0.19.12 + '@esbuild/freebsd-arm64': 0.19.12 + '@esbuild/freebsd-x64': 0.19.12 + '@esbuild/linux-arm': 0.19.12 + '@esbuild/linux-arm64': 0.19.12 + '@esbuild/linux-ia32': 0.19.12 + '@esbuild/linux-loong64': 0.19.12 + '@esbuild/linux-mips64el': 0.19.12 + '@esbuild/linux-ppc64': 0.19.12 + '@esbuild/linux-riscv64': 0.19.12 + '@esbuild/linux-s390x': 0.19.12 + '@esbuild/linux-x64': 0.19.12 + '@esbuild/netbsd-x64': 0.19.12 + '@esbuild/openbsd-x64': 0.19.12 + '@esbuild/sunos-x64': 0.19.12 + '@esbuild/win32-arm64': 0.19.12 + '@esbuild/win32-ia32': 0.19.12 + '@esbuild/win32-x64': 0.19.12 + dev: true + + /escalade@3.1.2: + resolution: {integrity: sha512-ErCHMCae19vR8vQGe50xIsVomy19rg6gFu3+r3jkEO46suLMWBksvVyoGgQV+jOfl84ZSOSlmv6Gxa89PmTGmA==} engines: {node: '>=6'} dev: true @@ -8952,7 +9144,7 @@ packages: - supports-color dev: true - /eslint-module-utils@2.8.0(@typescript-eslint/parser@6.19.0)(eslint-import-resolver-node@0.3.9)(eslint@8.56.0): + /eslint-module-utils@2.8.0(@typescript-eslint/parser@7.0.2)(eslint-import-resolver-node@0.3.9)(eslint@8.56.0): resolution: {integrity: sha512-aWajIYfsqCKRDgUfjEXNN/JlrzauMuSEy5sbd7WXbtW3EH6A6MpwEh42c7qD+MqQo9QMJ6fWLAeIJynx0g6OAw==} engines: {node: '>=4'} peerDependencies: @@ -8973,7 +9165,7 @@ packages: eslint-import-resolver-webpack: optional: true dependencies: - '@typescript-eslint/parser': 6.19.0(eslint@8.56.0)(typescript@5.3.3) + '@typescript-eslint/parser': 7.0.2(eslint@8.56.0)(typescript@5.3.3) debug: 3.2.7 eslint: 8.56.0 eslint-import-resolver-node: 0.3.9 @@ -8989,7 +9181,7 @@ packages: requireindex: 1.1.0 dev: true - /eslint-plugin-import@2.29.1(@typescript-eslint/parser@6.19.0)(eslint@8.56.0): + /eslint-plugin-import@2.29.1(@typescript-eslint/parser@7.0.2)(eslint@8.56.0): resolution: {integrity: sha512-BbPC0cuExzhiMo4Ff1BTVwHpjjv28C5R+btTOGaCRC7UEz801up0JadwkeSk5Ued6TG34uaczuVuH6qyy5YUxw==} engines: {node: '>=4'} peerDependencies: @@ -8999,22 +9191,22 @@ packages: '@typescript-eslint/parser': optional: true dependencies: - '@typescript-eslint/parser': 6.19.0(eslint@8.56.0)(typescript@5.3.3) + '@typescript-eslint/parser': 7.0.2(eslint@8.56.0)(typescript@5.3.3) array-includes: 3.1.7 - array.prototype.findlastindex: 1.2.3 + array.prototype.findlastindex: 1.2.4 array.prototype.flat: 1.3.2 array.prototype.flatmap: 1.3.2 debug: 3.2.7 doctrine: 2.1.0 eslint: 8.56.0 eslint-import-resolver-node: 0.3.9 - eslint-module-utils: 2.8.0(@typescript-eslint/parser@6.19.0)(eslint-import-resolver-node@0.3.9)(eslint@8.56.0) - hasown: 2.0.0 + eslint-module-utils: 2.8.0(@typescript-eslint/parser@7.0.2)(eslint-import-resolver-node@0.3.9)(eslint@8.56.0) + hasown: 2.0.1 is-core-module: 2.13.1 is-glob: 4.0.3 minimatch: 3.1.2 object.fromentries: 2.0.7 - object.groupby: 1.0.1 + object.groupby: 1.0.2 object.values: 1.1.7 semver: 6.3.1 tsconfig-paths: 3.15.0 @@ -9059,9 +9251,9 @@ packages: dependencies: array-includes: 3.1.7 array.prototype.flatmap: 1.3.2 - array.prototype.tosorted: 1.1.2 + array.prototype.tosorted: 1.1.3 doctrine: 2.1.0 - es-iterator-helpers: 1.0.15 + es-iterator-helpers: 1.0.17 eslint: 8.56.0 estraverse: 5.3.0 jsx-ast-utils: 3.3.5 @@ -9076,17 +9268,17 @@ packages: string.prototype.matchall: 4.0.10 dev: true - 
/eslint-plugin-simple-import-sort@10.0.0(eslint@8.56.0): - resolution: {integrity: sha512-AeTvO9UCMSNzIHRkg8S6c3RPy5YEwKWSQPx3DYghLedo2ZQxowPFLGDN1AZ2evfg6r6mjBSZSLxLFsWSu3acsw==} + /eslint-plugin-simple-import-sort@12.0.0(eslint@8.56.0): + resolution: {integrity: sha512-8o0dVEdAkYap0Cn5kNeklaKcT1nUsa3LITWEuFk3nJifOoD+5JQGoyDUW2W/iPWwBsNBJpyJS9y4je/BgxLcyQ==} peerDependencies: eslint: '>=5.0.0' dependencies: eslint: 8.56.0 dev: true - /eslint-plugin-storybook@0.6.15(eslint@8.56.0)(typescript@5.3.3): - resolution: {integrity: sha512-lAGqVAJGob47Griu29KXYowI4G7KwMoJDOkEip8ujikuDLxU+oWJ1l0WL6F2oDO4QiyUFXvtDkEkISMOPzo+7w==} - engines: {node: 12.x || 14.x || >= 16} + /eslint-plugin-storybook@0.8.0(eslint@8.56.0)(typescript@5.3.3): + resolution: {integrity: sha512-CZeVO5EzmPY7qghO2t64oaFM+8FTaD4uzOEjHKp516exyTKo+skKAL9GI3QALS2BXhyALJjNtwbmr1XinGE8bA==} + engines: {node: '>= 18'} peerDependencies: eslint: '>=6' dependencies: @@ -9100,17 +9292,17 @@ packages: - typescript dev: true - /eslint-plugin-unused-imports@3.0.0(@typescript-eslint/eslint-plugin@6.19.0)(eslint@8.56.0): - resolution: {integrity: sha512-sduiswLJfZHeeBJ+MQaG+xYzSWdRXoSw61DpU13mzWumCkR0ufD0HmO4kdNokjrkluMHpj/7PJeN35pgbhW3kw==} + /eslint-plugin-unused-imports@3.1.0(@typescript-eslint/eslint-plugin@7.0.2)(eslint@8.56.0): + resolution: {integrity: sha512-9l1YFCzXKkw1qtAru1RWUtG2EVDZY0a0eChKXcL+EZ5jitG7qxdctu4RnvhOJHv4xfmUf7h+JJPINlVpGhZMrw==} engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} peerDependencies: - '@typescript-eslint/eslint-plugin': ^6.0.0 - eslint: ^8.0.0 + '@typescript-eslint/eslint-plugin': 6 - 7 + eslint: '8' peerDependenciesMeta: '@typescript-eslint/eslint-plugin': optional: true dependencies: - '@typescript-eslint/eslint-plugin': 6.19.0(@typescript-eslint/parser@6.19.0)(eslint@8.56.0)(typescript@5.3.3) + '@typescript-eslint/eslint-plugin': 7.0.2(@typescript-eslint/parser@7.0.2)(eslint@8.56.0)(typescript@5.3.3) eslint: 8.56.0 eslint-rule-composer: 0.3.0 dev: true @@ -9155,7 +9347,7 @@ packages: '@eslint-community/regexpp': 4.10.0 '@eslint/eslintrc': 2.1.4 '@eslint/js': 8.56.0 - '@humanwhocodes/config-array': 0.11.13 + '@humanwhocodes/config-array': 0.11.14 '@humanwhocodes/module-importer': 1.0.1 '@nodelib/fs.walk': 1.2.8 '@ungap/structured-clone': 1.2.0 @@ -9176,7 +9368,7 @@ packages: glob-parent: 6.0.2 globals: 13.24.0 graphemer: 1.4.0 - ignore: 5.3.0 + ignore: 5.3.1 imurmurhash: 0.1.4 is-glob: 4.0.3 is-path-inside: 3.0.3 @@ -9197,8 +9389,8 @@ packages: resolution: {integrity: sha512-oruZaFkjorTpF32kDSI5/75ViwGeZginGGy2NoOSg3Q9bnwlnmDm4HLnkl0RE3n+njDXR037aY1+x58Z/zFdwQ==} engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} dependencies: - acorn: 8.11.2 - acorn-jsx: 5.3.2(acorn@8.11.2) + acorn: 8.11.3 + acorn-jsx: 5.3.2(acorn@8.11.3) eslint-visitor-keys: 3.4.3 dev: true @@ -9378,8 +9570,8 @@ packages: resolution: {integrity: sha512-bijHueCGd0LqqNK9b5oCMHc0MluJAx0cwqASgbWMvkO01lCYgIhacVRLcaDz3QnyYIRNJRDwMb41VuT6pHJ91Q==} dev: false - /fastq@1.16.0: - resolution: {integrity: sha512-ifCoaXsDrsdkWTtiNJX5uzHDsrck5TzfKKDcuFFTIrrc/BS076qgEIfoIy1VeZqViznfKiysPYTh/QeHtnIsYA==} + /fastq@1.17.1: + resolution: {integrity: sha512-sRVD3lWVIXWg6By68ZN7vho9a1pQcN/WBFaAAsDDFzlJjvoGx0P8z7V1t72grFJfJhu3YPZBuu25f7Kaw2jN1w==} dependencies: reusify: 1.0.4 dev: true @@ -9525,13 +9717,13 @@ packages: resolution: {integrity: sha512-CYcENa+FtcUKLmhhqyctpclsq7QF38pKjZHsGNiSQF5r4FtoKDWabFDl3hzaEQMvT1LHEysw5twgLvpYYb4vbw==} engines: {node: ^10.12.0 || >=12.0.0} dependencies: - flatted: 3.2.9 + flatted: 3.3.0 keyv: 4.5.4 rimraf: 3.0.2 dev: true 
- /flatted@3.2.9: - resolution: {integrity: sha512-36yxDn5H7OFZQla0/jFJmbIKTdZAQHngCedGxiMmpNfEZM0sdEeT+WczLQrjK6D7o2aiyLYDnkw0R3JK0Qv1RQ==} + /flatted@3.3.0: + resolution: {integrity: sha512-noqGuLw158+DuD9UPRKHpJ2hGxpFyDlYYrfM0mWt4XhT4n0lwzTLh70Tkdyy4kyTmyTT9Bv7bWAJqw7cgkEXDg==} dev: true /flatten@1.0.3: @@ -9539,13 +9731,13 @@ packages: deprecated: flatten is deprecated in favor of utility frameworks such as lodash. dev: true - /flow-parser@0.227.0: - resolution: {integrity: sha512-nOygtGKcX/siZK/lFzpfdHEfOkfGcTW7rNroR1Zsz6T/JxSahPALXVt5qVHq/fgvMJuv096BTKbgxN3PzVBaDA==} + /flow-parser@0.229.0: + resolution: {integrity: sha512-mOYmMuvJwAo/CvnMFEq4SHftq7E5188hYMTTxJyQOXk2nh+sgslRdYMw3wTthH+FMcFaZLtmBPuMu6IwztdoUQ==} engines: {node: '>=0.4.0'} dev: true - /focus-lock@1.3.2: - resolution: {integrity: sha512-kFI92jZVqa8rP4Yer2sLNlUDcOdEFxYum2tIIr4eCH0XF+pOmlg0xiY4tkbDmHJXt3phtbJoWs1L6PgUVk97rA==} + /focus-lock@1.3.3: + resolution: {integrity: sha512-hfXkZha7Xt4RQtrL1HBfspAuIj89Y0fb6GX0dfJilb8S2G/lvL4akPAcHq6xoD2NuZnDMCnZL/zQesMyeu6Psg==} engines: {node: '>=10'} dependencies: tslib: 2.6.2 @@ -9603,6 +9795,24 @@ packages: '@emotion/is-prop-valid': 0.8.8 dev: false + /framer-motion@11.0.5(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-Lb0EYbQcSK/pgyQUJm+KzsQrKrJRX9sFRyzl9hSr9gFG4Mk8yP7BjhuxvRXzblOM/+JxycrJdCDVmOQBsjpYlw==} + peerDependencies: + react: ^18.0.0 + react-dom: ^18.0.0 + peerDependenciesMeta: + react: + optional: true + react-dom: + optional: true + dependencies: + react: 18.2.0 + react-dom: 18.2.0(react@18.2.0) + tslib: 2.6.2 + optionalDependencies: + '@emotion/is-prop-valid': 0.8.8 + dev: false + /framesync@6.1.2: resolution: {integrity: sha512-jBTqhX6KaQVDyus8muwZbBeGGP0XgujBRbQ7gM7BRdS3CadCZIHiawyzYLnafYcvZIh5j8WE7cxZKFn7dXhu9g==} dependencies: @@ -9671,9 +9881,9 @@ packages: resolution: {integrity: sha512-Z5kx79swU5P27WEayXM1tBi5Ze/lbIyiNgU3qyXUOf9b2rgXYyF9Dy9Cx+IQv/Lc8WCG6L82zwUPpSS9hGehIg==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 + es-abstract: 1.22.4 functions-have-names: 1.2.3 dev: true @@ -9711,13 +9921,15 @@ packages: resolution: {integrity: sha512-8vXOvuE167CtIc3OyItco7N/dpRtBbYOsPsXCz7X/PMnlGjYjSGuZJgM1Y7mmew7BKf9BqvLX2tnOVy1BBUsxQ==} dev: true - /get-intrinsic@1.2.2: - resolution: {integrity: sha512-0gSo4ml/0j98Y3lngkFEot/zhiCeWsbYIlZ+uZOVgzLyLaUw7wxUL+nCTP0XJvJg1AXulJRI3UJi8GsbDuxdGA==} + /get-intrinsic@1.2.4: + resolution: {integrity: sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==} + engines: {node: '>= 0.4'} dependencies: + es-errors: 1.3.0 function-bind: 1.1.2 - has-proto: 1.0.1 + has-proto: 1.0.3 has-symbols: 1.0.3 - hasown: 2.0.0 + hasown: 2.0.1 /get-nonce@1.0.1: resolution: {integrity: sha512-FJhYRoDaiatfEkUK8HKlicmu/3SGFD51q3itKDGoSTysQJBnfOcxU5GxnhE1E6soB76MbT0MBtnKJuXyAx+96Q==} @@ -9752,23 +9964,24 @@ packages: engines: {node: '>=16'} dev: true - /get-symbol-description@1.0.0: - resolution: {integrity: sha512-2EmdH1YvIQiZpltCNgkuiUnyukzxM/R6NDJX31Ke3BG1Nq5b0S2PhX59UKi9vZpPDQVdqn+1IcaAwnzTT5vCjw==} + /get-symbol-description@1.0.2: + resolution: {integrity: sha512-g0QYk1dZBxGwk+Ngc+ltRH2IBp2f7zBkBMBJZCDerh6EhlhSR6+9irMCuT/09zD6qkarHUSn529sK/yL4S27mg==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 - get-intrinsic: 1.2.2 + call-bind: 1.0.7 + es-errors: 1.3.0 + get-intrinsic: 1.2.4 dev: true /giget@1.2.1: resolution: {integrity: 
sha512-4VG22mopWtIeHwogGSy1FViXVo0YT+m6BrqZfz0JJFwbSsePsCdOzdLIIli5BtMp7Xe8f/o2OmBpQX2NBOC24g==} hasBin: true dependencies: - citty: 0.1.5 + citty: 0.1.6 consola: 3.2.3 defu: 6.1.4 - node-fetch-native: 1.6.1 - nypm: 0.3.4 + node-fetch-native: 1.6.2 + nypm: 0.3.6 ohash: 1.1.3 pathe: 1.1.2 tar: 6.2.0 @@ -9854,7 +10067,7 @@ packages: array-union: 2.1.0 dir-glob: 3.0.1 fast-glob: 3.3.2 - ignore: 5.3.0 + ignore: 5.3.1 merge2: 1.4.1 slash: 3.0.0 dev: true @@ -9874,7 +10087,7 @@ packages: /gopd@1.0.1: resolution: {integrity: sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==} dependencies: - get-intrinsic: 1.2.2 + get-intrinsic: 1.2.4 /graceful-fs@4.2.11: resolution: {integrity: sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==} @@ -9922,28 +10135,28 @@ packages: engines: {node: '>=8'} dev: true - /has-property-descriptors@1.0.1: - resolution: {integrity: sha512-VsX8eaIewvas0xnvinAe9bw4WfIeODpGYikiWYLH+dma0Jw6KHYqWiWfhQlgOVK8D6PvjubK5Uc4P0iIhIcNVg==} + /has-property-descriptors@1.0.2: + resolution: {integrity: sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==} dependencies: - get-intrinsic: 1.2.2 + es-define-property: 1.0.0 - /has-proto@1.0.1: - resolution: {integrity: sha512-7qE+iP+O+bgF9clE5+UoBFzE65mlBiVj3tKCrlNQ0Ogwm0BjpT/gK4SlLYDMybDh5I3TCTKnPPa0oMG7JDYrhg==} + /has-proto@1.0.3: + resolution: {integrity: sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==} engines: {node: '>= 0.4'} /has-symbols@1.0.3: resolution: {integrity: sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==} engines: {node: '>= 0.4'} - /has-tostringtag@1.0.0: - resolution: {integrity: sha512-kFjcSNhnlGV1kyoGk7OXKSawH5JOb/LzUc5w9B02hOTO0dfFRjbHQKvg1d6cf3HbeUmtU9VbbV3qzZ2Teh97WQ==} + /has-tostringtag@1.0.2: + resolution: {integrity: sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==} engines: {node: '>= 0.4'} dependencies: has-symbols: 1.0.3 dev: true - /hasown@2.0.0: - resolution: {integrity: sha512-vUptKVTpIJhcczKBbgnS+RtcuYMB8+oNzPK2/Hp3hanz8JmpATdmmgLgSaadVREkDm+e2giHwY3ZRkyjSIDDFA==} + /hasown@2.0.1: + resolution: {integrity: sha512-1/th4MHjnwncwXsIW6QMzlvYL9kG5e/CpVvLRZe4XPa8TOUNbCELqmvhDmnkNsAjwaG4+I8gJJL0JBvTTLO9qA==} engines: {node: '>= 0.4'} dependencies: function-bind: 1.1.2 @@ -10009,18 +10222,18 @@ packages: resolution: {integrity: sha512-ygGZLjmXfPHj+ZWh6LwbC37l43MhfztxetbFCoYTM2VjkIUpeHgSNn7QIyVFj7YQ1Wl9Cbw5sholVJPzWvC2MQ==} dev: false - /i18next-http-backend@2.4.2: - resolution: {integrity: sha512-wKrgGcaFQ4EPjfzBTjzMU0rbFTYpa0S5gv9N/d8WBmWS64+IgJb7cHddMvV+tUkse7vUfco3eVs2lB+nJhPo3w==} + /i18next-http-backend@2.4.3: + resolution: {integrity: sha512-jo2M03O6n1/DNb51WSQ8PsQ0xEELzLZRdYUTbf17mLw3rVwnJF9hwNgMXvEFSxxb+N8dT+o0vtigA6s5mGWyPA==} dependencies: cross-fetch: 4.0.0 transitivePeerDependencies: - encoding dev: false - /i18next@23.7.16: - resolution: {integrity: sha512-SrqFkMn9W6Wb43ZJ9qrO6U2U4S80RsFMA7VYFSqp7oc7RllQOYDCdRfsse6A7Cq/V8MnpxKvJCYgM8++27n4Fw==} + /i18next@23.9.0: + resolution: {integrity: sha512-f3MUciKqwzNV//mHG6EtdSlC65+nqH/3zK8sOSWqNV6FVu2tmHhF/rFOp9UF8S4m1odojtuipKaKJrP0Loh60g==} dependencies: - '@babel/runtime': 7.23.7 + '@babel/runtime': 7.23.9 dev: false /iconv-lite@0.4.24: @@ -10038,8 +10251,8 @@ packages: resolution: {integrity: 
sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==} dev: true - /ignore@5.3.0: - resolution: {integrity: sha512-g7dmpshy+gD7mh88OC9NwSGTKoc3kyLAZQRU1mt53Aw/vnvfXnbC+F/7F7QoYVKbV+KNvJx8wArewKy1vXMtlg==} + /ignore@5.3.1: + resolution: {integrity: sha512-5Fytz/IraMjqpwfd34ke28PTVMjZjJG2MPn5t7OE4eUCUNf8BAa7b5WUS9/Qvr6mwOQS7Mk6vdsMno5he+T8Xw==} engines: {node: '>= 4'} dev: true @@ -10095,13 +10308,13 @@ packages: fast-loops: 1.1.3 dev: false - /internal-slot@1.0.6: - resolution: {integrity: sha512-Xj6dv+PsbtwyPpEflsejS+oIZxmMlV44zAhG479uYu89MsjcYOhCFnNyKrkJrihbsiasQyY0afoCl/9BLR65bg==} + /internal-slot@1.0.7: + resolution: {integrity: sha512-NGnrKwXzSms2qUUih/ILZ5JBqNTSa1+ZmP6flaIp6KmSElgE9qdndzS3cqjrDovwFdmwsGsLdeFgB6suw+1e9g==} engines: {node: '>= 0.4'} dependencies: - get-intrinsic: 1.2.2 - hasown: 2.0.0 - side-channel: 1.0.4 + es-errors: 1.3.0 + hasown: 2.0.1 + side-channel: 1.0.5 dev: true /invariant@2.2.4: @@ -10109,8 +10322,8 @@ packages: dependencies: loose-envify: 1.4.0 - /ip@2.0.0: - resolution: {integrity: sha512-WKa+XuLG1A1R0UWhl2+1XQSi+fZWMsYKffMZTTYsiZaUD8k2yDAj5atimTUD2TZkyCkNEeYE5NhFZmupOGtjYQ==} + /ip@2.0.1: + resolution: {integrity: sha512-lJUL9imLTNi1ZfXT+DU6rBBdbiKGBuay9B6xGSPVjUeQwaH1RIGqef8RZkUtHioLmSNpPR5M4HVKJGm1j8FWVQ==} dev: true /ipaddr.js@1.9.1: @@ -10127,16 +10340,16 @@ packages: resolution: {integrity: sha512-8Q7EARjzEnKpt/PCD7e1cgUS0a6X8u5tdSiMqXhojOdoV9TsMsiO+9VLC5vAmO8N7/GmXn7yjR8qnA6bVAEzfA==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 - has-tostringtag: 1.0.0 + call-bind: 1.0.7 + has-tostringtag: 1.0.2 dev: true - /is-array-buffer@3.0.2: - resolution: {integrity: sha512-y+FyyR/w8vfIRq4eQcM1EYgSTnmHXPqaF+IgzgraytCFq5Xh8lllDVmAZolPJiZttZLeFSINPYMaEJ7/vWUa1w==} + /is-array-buffer@3.0.4: + resolution: {integrity: sha512-wcjaerHw0ydZwfhiKbXJWLDY8A7yV7KhjQOpb83hGgGfId/aQa4TOvwyzn2PuswW2gPCYEL/nEAiSVpdOj1lXw==} + engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 - get-intrinsic: 1.2.2 - is-typed-array: 1.1.12 + call-bind: 1.0.7 + get-intrinsic: 1.2.4 dev: true /is-arrayish@0.2.1: @@ -10146,7 +10359,7 @@ packages: resolution: {integrity: sha512-Y1JXKrfykRJGdlDwdKlLpLyMIiWqWvuSd17TvZk68PLAOGOoF4Xyav1z0Xhoi+gCYjZVeC5SI+hYFOfvXmGRCA==} engines: {node: '>= 0.4'} dependencies: - has-tostringtag: 1.0.0 + has-tostringtag: 1.0.2 dev: true /is-bigint@1.0.4: @@ -10166,8 +10379,8 @@ packages: resolution: {integrity: sha512-gDYaKHJmnj4aWxyj6YHyXVpdQawtVLHU5cb+eztPGczf6cjuTdwve5ZIEfgXqH4e57An1D1AKf8CZ3kYrQRqYA==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 - has-tostringtag: 1.0.0 + call-bind: 1.0.7 + has-tostringtag: 1.0.2 dev: true /is-callable@1.2.7: @@ -10178,13 +10391,13 @@ packages: /is-core-module@2.13.1: resolution: {integrity: sha512-hHrIjvZsftOsvKSn2TRYl63zvxsgE0K+0mYMoH6gD4omR5IWB2KynivBQczo3+wF1cCkjzvptnI9Q0sPU66ilw==} dependencies: - hasown: 2.0.0 + hasown: 2.0.1 /is-date-object@1.0.5: resolution: {integrity: sha512-9YQaSxsAiSwcvS33MBk3wTCVnWK+HhF8VZR2jRxehM16QcVOdHqPn4VPHmRK4lSr38n9JriurInLcP90xsYNfQ==} engines: {node: '>= 0.4'} dependencies: - has-tostringtag: 1.0.0 + has-tostringtag: 1.0.2 dev: true /is-deflate@1.0.0: @@ -10205,7 +10418,7 @@ packages: /is-finalizationregistry@1.0.2: resolution: {integrity: sha512-0by5vtUJs8iFQb5TYUHHPudOR+qXYIMKtiUzvLIZITZUjknFmziyBJuLhVRc+Ds0dREFlskDNJKYIdIzu/9pfw==} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 dev: true /is-fullwidth-code-point@3.0.0: @@ -10217,7 +10430,7 @@ packages: resolution: {integrity: 
sha512-jsEjy9l3yiXEQ+PsXdmBwEPcOxaXWLspKdplFUVI9vq1iZgIekeC0L167qeu86czQaxed3q/Uzuw0swL0irL8A==} engines: {node: '>= 0.4'} dependencies: - has-tostringtag: 1.0.0 + has-tostringtag: 1.0.2 dev: true /is-glob@4.0.3: @@ -10245,12 +10458,12 @@ packages: resolution: {integrity: sha512-E+zBKpQ2t6MEo1VsonYmluk9NxGrbzpeeLC2xIViuO2EjU2xsXsBPwTr3Ykv9l08UYEVEdWeRZNouaZqF6RN0w==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 dev: true - /is-negative-zero@2.0.2: - resolution: {integrity: sha512-dqJvarLawXsFbNDeJW7zAz8ItJ9cd28YufuuFzh0G8pNHjJMnY08Dv7sYX2uF5UpQOwieAeOExEYAWWfu7ZZUA==} + /is-negative-zero@2.0.3: + resolution: {integrity: sha512-5KoIu2Ngpyek75jXodFvnafB6DJgr3u8uuK0LEZJjrU19DrMD3EVERaR8sjz8CCGgpZvxPl9SuE1GMVPFHx1mw==} engines: {node: '>= 0.4'} dev: true @@ -10258,7 +10471,7 @@ packages: resolution: {integrity: sha512-k1U0IRzLMo7ZlYIfzRu23Oh6MiIFasgpb9X76eqfFZAqwH44UI4KTBvBYIZ1dSL9ZzChTB9ShHfLkR4pdW5krQ==} engines: {node: '>= 0.4'} dependencies: - has-tostringtag: 1.0.0 + has-tostringtag: 1.0.2 dev: true /is-number@7.0.0: @@ -10297,8 +10510,8 @@ packages: resolution: {integrity: sha512-kvRdxDsxZjhzUX07ZnLydzS1TU/TJlTUHHY4YLL87e37oUA49DfkLqgy+VjFocowy29cKvcSiu+kIv728jTTVg==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 - has-tostringtag: 1.0.0 + call-bind: 1.0.7 + has-tostringtag: 1.0.2 dev: true /is-regexp@1.0.0: @@ -10317,7 +10530,7 @@ packages: /is-shared-array-buffer@1.0.2: resolution: {integrity: sha512-sqN2UDu1/0y6uvXyStCOzyhAjCSlHceFoMKJW8W9EU9cvic/QdsZ0kEU93HEy3IUEFZIiH/3w+AH/UQbPHNdhA==} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 dev: true /is-stream@2.0.1: @@ -10334,7 +10547,7 @@ packages: resolution: {integrity: sha512-tE2UXzivje6ofPW7l23cjDOMa09gb7xlAqG6jG5ej6uPV32TlWP3NKPigtaGeHNu9fohccRYvIiZMfOOnOYUtg==} engines: {node: '>= 0.4'} dependencies: - has-tostringtag: 1.0.0 + has-tostringtag: 1.0.2 dev: true /is-symbol@1.0.4: @@ -10344,11 +10557,11 @@ packages: has-symbols: 1.0.3 dev: true - /is-typed-array@1.1.12: - resolution: {integrity: sha512-Z14TF2JNG8Lss5/HMqt0//T9JeHXttXy5pH/DBU4vi98ozO2btxzq9MwYDZYnKwU8nRsz/+GVFVRDq3DkVuSPg==} + /is-typed-array@1.1.13: + resolution: {integrity: sha512-uZ25/bUAlUY5fR4OKT4rZQEBrzQWYV9ZJYGGsUmEJ6thodVJ1HX64ePQ6Z0qPWP+m+Uq6e9UugrE38jeYsDSMw==} engines: {node: '>= 0.4'} dependencies: - which-typed-array: 1.1.13 + which-typed-array: 1.1.14 dev: true /is-unicode-supported@0.1.0: @@ -10372,14 +10585,14 @@ packages: /is-weakref@1.0.2: resolution: {integrity: sha512-qctsuLZmIQ0+vSSMfoVvyFe2+GSEvnmZ2ezTup1SBse9+twCCeial6EEi3Nc2KFcf6+qz2FBPnjXsk8xhKSaPQ==} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 dev: true /is-weakset@2.0.2: resolution: {integrity: sha512-t2yVvttHkQktwnNNmBQ98AhENLdPUTDTE21uPqAQ0ARwQfGeQKRVS0NNurH7bTf7RrvcVn1OOge45CnBeHCSmg==} dependencies: - call-bind: 1.0.5 - get-intrinsic: 1.2.2 + call-bind: 1.0.7 + get-intrinsic: 1.2.4 dev: true /is-wsl@2.2.0: @@ -10415,8 +10628,8 @@ packages: resolution: {integrity: sha512-pzqtp31nLv/XFOzXGuvhCb8qhjmTVo5vjVk19XE4CRlSWz0KoeJ3bw9XsA7nOp9YBf4qHjwBxkDzKcME/J29Yg==} engines: {node: '>=8'} dependencies: - '@babel/core': 7.23.7 - '@babel/parser': 7.23.6 + '@babel/core': 7.23.9 + '@babel/parser': 7.23.9 '@istanbuljs/schema': 0.1.3 istanbul-lib-coverage: 3.2.2 semver: 6.3.1 @@ -10428,10 +10641,10 @@ packages: resolution: {integrity: sha512-DR33HMMr8EzwuRL8Y9D3u2BMj8+RqSE850jfGu59kS7tbmPLzGkZmVSfyCFSDxuZiEY6Rzt3T2NA/qU+NwVj1w==} dependencies: define-properties: 1.2.1 - get-intrinsic: 1.2.2 + get-intrinsic: 1.2.4 
has-symbols: 1.0.3 - reflect.getprototypeof: 1.0.4 - set-function-name: 2.0.1 + reflect.getprototypeof: 1.0.5 + set-function-name: 2.0.2 dev: true /its-fine@1.1.1(react@18.2.0): @@ -10469,7 +10682,7 @@ packages: dependencies: '@jest/types': 29.6.3 '@types/graceful-fs': 4.1.9 - '@types/node': 20.11.5 + '@types/node': 20.11.19 anymatch: 3.1.3 fb-watchman: 2.0.2 graceful-fs: 4.2.11 @@ -10487,7 +10700,7 @@ packages: engines: {node: ^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0} dependencies: '@jest/types': 27.5.1 - '@types/node': 20.11.5 + '@types/node': 20.11.19 dev: true /jest-regex-util@29.6.3: @@ -10500,7 +10713,7 @@ packages: engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0} dependencies: '@jest/types': 29.6.3 - '@types/node': 20.11.5 + '@types/node': 20.11.19 chalk: 4.1.2 ci-info: 3.9.0 graceful-fs: 4.2.11 @@ -10511,7 +10724,7 @@ packages: resolution: {integrity: sha512-eIz2msL/EzL9UFTFFx7jBTkeZfku0yUAyZZZmJ93H2TYEiroIx2PQjEXcwYtYl8zXCxb+PAmA2hLIt/6ZEkPHw==} engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0} dependencies: - '@types/node': 20.11.5 + '@types/node': 20.11.19 jest-util: 29.7.0 merge-stream: 2.0.0 supports-color: 8.1.1 @@ -10528,6 +10741,10 @@ packages: /js-tokens@4.0.0: resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} + /js-tokens@8.0.3: + resolution: {integrity: sha512-UfJMcSJc+SEXEl9lH/VLHSZbThQyLpw1vLO1Lb+j4RWDvG3N2f7yj3PVQA3cmkTBNldJ9eFnM+xEXxHIXrYiJw==} + dev: true + /js-yaml@3.14.1: resolution: {integrity: sha512-okMH7OXXJ7YrN9Ok3/SXrnu4iX9yOk+25nqX4imS2npuvTYDmo/QEZoqwZkYaIDk3jVvBOTOIEgEhaLOynBS9g==} hasBin: true @@ -10543,7 +10760,7 @@ packages: argparse: 2.0.1 dev: true - /jscodeshift@0.15.1(@babel/preset-env@7.23.8): + /jscodeshift@0.15.1(@babel/preset-env@7.23.9): resolution: {integrity: sha512-hIJfxUy8Rt4HkJn/zZPU9ChKfKZM1342waJ1QC2e2YsPcWhM+3BJ4dcfQCzArTrk1jJeNLB341H+qOcEHRxJZg==} hasBin: true peerDependencies: @@ -10552,20 +10769,20 @@ packages: '@babel/preset-env': optional: true dependencies: - '@babel/core': 7.23.7 - '@babel/parser': 7.23.6 - '@babel/plugin-transform-class-properties': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-modules-commonjs': 7.23.3(@babel/core@7.23.7) - '@babel/plugin-transform-nullish-coalescing-operator': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-optional-chaining': 7.23.4(@babel/core@7.23.7) - '@babel/plugin-transform-private-methods': 7.23.3(@babel/core@7.23.7) - '@babel/preset-env': 7.23.8(@babel/core@7.23.7) - '@babel/preset-flow': 7.23.3(@babel/core@7.23.7) - '@babel/preset-typescript': 7.23.3(@babel/core@7.23.7) - '@babel/register': 7.23.7(@babel/core@7.23.7) - babel-core: 7.0.0-bridge.0(@babel/core@7.23.7) + '@babel/core': 7.23.9 + '@babel/parser': 7.23.9 + '@babel/plugin-transform-class-properties': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-modules-commonjs': 7.23.3(@babel/core@7.23.9) + '@babel/plugin-transform-nullish-coalescing-operator': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-optional-chaining': 7.23.4(@babel/core@7.23.9) + '@babel/plugin-transform-private-methods': 7.23.3(@babel/core@7.23.9) + '@babel/preset-env': 7.23.9(@babel/core@7.23.9) + '@babel/preset-flow': 7.23.3(@babel/core@7.23.9) + '@babel/preset-typescript': 7.23.3(@babel/core@7.23.9) + '@babel/register': 7.23.7(@babel/core@7.23.9) + babel-core: 7.0.0-bridge.0(@babel/core@7.23.9) chalk: 4.1.2 - flow-parser: 0.227.0 + flow-parser: 0.229.0 graceful-fs: 4.2.11 micromatch: 4.0.5 neo-async: 2.6.2 @@ -10679,8 +10896,8 @@ packages: resolution: 
{integrity: sha512-Y+60/zizpJ3HRH8DCss+q95yr6145JXZo46OTpFvDZWLfRCE4qChOyk1b26nMaNpfHHgxagk9dXT5OP0Tfe+dQ==} dev: true - /konva@9.3.1: - resolution: {integrity: sha512-KXHJVUrYVWFIJUbnlw8QUZDBGC1jx6wwRsGaByPm/2yk78xw7hKquCMNEd9EtVqGz/jUkKFJAWom77TLB+zVOA==} + /konva@9.3.3: + resolution: {integrity: sha512-cg/AHxnfawZ1rKxygCnzx0TZY7hQiQiAKgAHPinEwMn49MVrBkeKLj2d0EaleoFG/0y0XhEKTD0dFZiPPdWlCQ==} dev: false /lazy-universal-dotenv@4.0.0: @@ -10688,7 +10905,7 @@ packages: engines: {node: '>=14.0.0'} dependencies: app-root-dir: 1.0.2 - dotenv: 16.3.1 + dotenv: 16.4.5 dotenv-expand: 10.0.0 dev: true @@ -10799,8 +11016,8 @@ packages: get-func-name: 2.0.2 dev: true - /lru-cache@10.1.0: - resolution: {integrity: sha512-/1clY/ui8CzjKFyjdvwPWJUYKiFVXG2I2cY0ssG7h4+hwk+XOIX7ZSG9Q7TW8TW3Kp3BUSqgFWBLgL4PJ+Blag==} + /lru-cache@10.2.0: + resolution: {integrity: sha512-2bIM8x+VAf6JT4bKAljS1qUWgMsqZRPGJS6FSahIMPVvctcNhyVp7AJu7quxOW9jwkryBReKZY5tY5JYv2n/7Q==} engines: {node: 14 || >=16.14} dev: true @@ -10852,7 +11069,7 @@ packages: pretty-ms: 7.0.1 rc: 1.2.8 stream-to-array: 2.3.0 - ts-graphviz: 1.8.1 + ts-graphviz: 1.8.2 typescript: 5.3.3 walkdir: 0.4.1 transitivePeerDependencies: @@ -10866,8 +11083,8 @@ packages: '@jridgewell/sourcemap-codec': 1.4.15 dev: true - /magic-string@0.30.5: - resolution: {integrity: sha512-7xlpfBaQaP/T6Vh8MO/EqXSW5En6INHEvEXQiuff7Gku0PWjU3uf6w/j9o7O+SpB5fOAkrI5HeoNgwjEO0pFsA==} + /magic-string@0.30.7: + resolution: {integrity: sha512-8vBuFF/I/+OSLRmdf2wwFCJCz+nSn0m6DPvGH1fS/KiQoSaR+sETbov0eIk9KhEKy8CYqIkIAnbohxT/4H0kuA==} engines: {node: '>=12'} dependencies: '@jridgewell/sourcemap-codec': 1.4.15 @@ -10898,8 +11115,8 @@ packages: resolution: {integrity: sha512-0aF7ZmVon1igznGI4VS30yugpduQW3y3GkcgGJOp7d8x8QrizhigUxjI/m2UojsXXto+jLAH3KSz+xOJTiORjg==} dev: true - /markdown-to-jsx@7.4.0(react@18.2.0): - resolution: {integrity: sha512-zilc+MIkVVXPyTb4iIUTIz9yyqfcWjszGXnwF9K/aiBWcHXFcmdEMTkG01/oQhwSCH7SY1BnG6+ev5BzWmbPrg==} + /markdown-to-jsx@7.4.1(react@18.2.0): + resolution: {integrity: sha512-GbrbkTnHp9u6+HqbPRFJbObi369AgJNXi/sGqq5HRsoZW063xR1XDCaConqq+whfEIAlzB1YPnOgsPc7B7bc/A==} engines: {node: '>= 10'} peerDependencies: react: '>= 0.14.0' @@ -11073,7 +11290,7 @@ packages: acorn: 8.11.3 pathe: 1.1.2 pkg-types: 1.0.3 - ufo: 1.3.2 + ufo: 1.4.0 dev: true /module-definition@3.4.0: @@ -11151,6 +11368,11 @@ packages: hasBin: true dev: true + /nanostores@0.10.0: + resolution: {integrity: sha512-Poy5+9wFXOD0jAstn4kv9n686U2BFw48z/W8lms8cS8lcbRz7BU20JxZ3e/kkKQVfRrkm4yLWCUA6GQINdvJCQ==} + engines: {node: ^18.0.0 || >=20.0.0} + dev: false + /nanostores@0.9.5: resolution: {integrity: sha512-Z+p+g8E7yzaWwOe5gEUB2Ox0rCEeXWYIZWmYvw/ajNYX8DlXdMvMDj8DWfM/subqPAcsf8l8Td4iAwO1DeIIRQ==} engines: {node: ^16.0.0 || ^18.0.0 || >=20.0.0} @@ -11191,8 +11413,8 @@ packages: minimatch: 3.1.2 dev: true - /node-fetch-native@1.6.1: - resolution: {integrity: sha512-bW9T/uJDPAJB2YNYEpWzE54U5O3MQidXsOyTfnbKYtTtFexRvGzb1waphBN4ZwP6EcIvYYEOwW0b72BpAqydTw==} + /node-fetch-native@1.6.2: + resolution: {integrity: sha512-69mtXOFZ6hSkYiXAVB5SqaRvrbITC/NPyqv7yuu/qw0nmgPyYbIMYYNIDhNtwPrzk0ptrimrLz/hhjvm4w5Z+w==} dev: true /node-fetch@2.7.0: @@ -11218,14 +11440,14 @@ packages: resolution: {integrity: sha512-8Q1hXew6ETzqKRAs3jjLioSxNfT1cx74ooiF8RlAONwVMcfq+UdzLC2eB5qcPldUxaE5w3ytLkrmV1TGddhZTA==} engines: {node: '>=6.0'} dependencies: - '@babel/parser': 7.23.6 + '@babel/parser': 7.23.9 dev: true /node-source-walk@5.0.2: resolution: {integrity: 
sha512-Y4jr/8SRS5hzEdZ7SGuvZGwfORvNsSsNRwDXx5WisiqzsVfeftDvRgfeqWNgZvWSJbgubTRVRYBzK6UO+ErqjA==} engines: {node: '>=12'} dependencies: - '@babel/parser': 7.23.6 + '@babel/parser': 7.23.9 dev: true /normalize-package-data@2.5.0: @@ -11256,15 +11478,15 @@ packages: path-key: 4.0.0 dev: true - /nypm@0.3.4: - resolution: {integrity: sha512-1JLkp/zHBrkS3pZ692IqOaIKSYHmQXgqfELk6YTOfVBnwealAmPA1q2kKK7PHJAHSMBozerThEFZXP3G6o7Ukg==} + /nypm@0.3.6: + resolution: {integrity: sha512-2CATJh3pd6CyNfU5VZM7qSwFu0ieyabkEdnogE30Obn1czrmOYiZ8DOZLe1yBdLKWoyD3Mcy2maUs+0MR3yVjQ==} engines: {node: ^14.16.0 || >=16.10.0} hasBin: true dependencies: - citty: 0.1.5 + citty: 0.1.6 execa: 8.0.1 pathe: 1.1.2 - ufo: 1.3.2 + ufo: 1.4.0 dev: true /object-assign@4.1.1: @@ -11279,7 +11501,7 @@ packages: resolution: {integrity: sha512-3cyDsyHgtmi7I7DfSSI2LDp6SK2lwvtbg0p0R1e0RvTqF5ceGx+K2dfSjm1bKDMVCFEDAQvy+o8c6a7VujOddw==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 dev: true @@ -11291,7 +11513,7 @@ packages: resolution: {integrity: sha512-byy+U7gp+FVwmyzKPYhW2h5l3crpmGsxl7X2s8y43IgxvG4g3QZ6CffDtsNQy1WsmZpQbO+ybo0AlW7TY6DcBQ==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 has-symbols: 1.0.3 object-keys: 1.1.1 @@ -11301,43 +11523,44 @@ packages: resolution: {integrity: sha512-jCBs/0plmPsOnrKAfFQXRG2NFjlhZgjjcBLSmTnEhU8U6vVTsVe8ANeQJCHTl3gSsI4J+0emOoCgoKlmQPMgmA==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 + es-abstract: 1.22.4 dev: true /object.fromentries@2.0.7: resolution: {integrity: sha512-UPbPHML6sL8PI/mOqPwsH4G6iyXcCGzLin8KvEPenOZN5lpCNBZZQ+V62vdjB1mQHrmqGQt5/OJzemUA+KJmEA==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 + es-abstract: 1.22.4 dev: true - /object.groupby@1.0.1: - resolution: {integrity: sha512-HqaQtqLnp/8Bn4GL16cj+CUYbnpe1bh0TtEaWvybszDG4tgxCJuRpV8VGuvNaI1fAnI4lUJzDG55MXcOH4JZcQ==} + /object.groupby@1.0.2: + resolution: {integrity: sha512-bzBq58S+x+uo0VjurFT0UktpKHOZmv4/xePiOA1nbB9pMqpGK7rUPNgf+1YC+7mE+0HzhTMqNUuCqvKhj6FnBw==} dependencies: - call-bind: 1.0.5 + array.prototype.filter: 1.0.3 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 - get-intrinsic: 1.2.2 + es-abstract: 1.22.4 + es-errors: 1.3.0 dev: true /object.hasown@1.1.3: resolution: {integrity: sha512-fFI4VcYpRHvSLXxP7yiZOMAd331cPfd2p7PFDVbgUsYOfCT3tICVqXWngbjr4m49OvsBwUBQ6O2uQoJvy3RexA==} dependencies: define-properties: 1.2.1 - es-abstract: 1.22.3 + es-abstract: 1.22.4 dev: true /object.values@1.1.7: resolution: {integrity: sha512-aU6xnDFYT3x17e/f0IiiwlGPTy2jzMySGfUB4fq6z7CV8l85CWHDk5ErhyhpfDHhrOMwGFhSQkhMGHaIotA6Ng==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 + es-abstract: 1.22.4 dev: true /ohash@1.1.3: @@ -11389,15 +11612,15 @@ packages: resolution: {integrity: sha512-N4YtSYJqghVu4iek2ZUvcN/0aqH1kRDuNqzcycDxhOUpg7GdvLa2F3DgS6yBNhInhv2r/6I0Flkn7CqL8+nIcw==} dev: true - /openapi-typescript@6.7.3: - resolution: {integrity: sha512-es3mGcDXV6TKPo6n3aohzHm0qxhLyR39MhF6mkD1FwFGjhxnqMqfSIgM0eCpInZvqatve4CxmXcMZw3jnnsaXw==} + /openapi-typescript@6.7.4: + resolution: {integrity: sha512-EZyeW9Wy7UDCKv0iYmKrq2pVZtquXiD/YHiUClAKqiMi42nodx/EQH11K6fLqjt1IZlJmVokrAsExsBMM2RROQ==} hasBin: true dependencies: ansi-colors: 4.1.3 fast-glob: 3.3.2 js-yaml: 4.1.0 supports-color: 9.4.0 - 
undici: 5.28.2 + undici: 5.28.3 yargs-parser: 21.1.1 dev: true @@ -11428,16 +11651,6 @@ packages: wcwidth: 1.0.1 dev: true - /overlayscrollbars-react@0.5.3(overlayscrollbars@2.4.6)(react@18.2.0): - resolution: {integrity: sha512-mq9D9tbfSeq0cti1kKMf3B3AzsEGwHcRIDX/K49CvYkHz/tKeU38GiahDkIPKTMEAp6lzKCo4x1eJZA6ZFYOxQ==} - peerDependencies: - overlayscrollbars: ^2.0.0 - react: '>=16.8.0' - dependencies: - overlayscrollbars: 2.4.6 - react: 18.2.0 - dev: false - /overlayscrollbars-react@0.5.4(overlayscrollbars@2.5.0)(react@18.2.0): resolution: {integrity: sha512-FPKx9XnXovTnI4+2JXig5uEaTLSEJ6svOwPzIfBBXTHBRNsz2+WhYUmfM0K/BNYxjgDEwuPm+NQhEoOA0RoG1g==} peerDependencies: @@ -11448,10 +11661,6 @@ packages: react: 18.2.0 dev: false - /overlayscrollbars@2.4.6: - resolution: {integrity: sha512-C7tmhetwMv9frEvIT/RfkAVEgbjRNz/Gh2zE8BVmN+jl35GRaAnz73rlGQCMRoC2arpACAXyMNnJkzHb7GBrcA==} - dev: false - /overlayscrollbars@2.5.0: resolution: {integrity: sha512-CWVC2dwS07XZfLHDm5GmZN1iYggiJ8Vufnvzwt0gwR9Yz1hVckKeTxg7VILZeYVGhDYJHZ1Xc8Xfys5dWZ1qiA==} dev: false @@ -11575,7 +11784,7 @@ packages: resolution: {integrity: sha512-MkhCqzzBEpPvxxQ71Md0b1Kk51W01lrYvlMzSUaIzNsODdd7mqhiimSZlr+VegAz5Z6Vzt9Xg2ttE//XBhH3EQ==} engines: {node: '>=16 || 14 >=14.17'} dependencies: - lru-cache: 10.1.0 + lru-cache: 10.2.0 minipass: 7.0.4 dev: true @@ -11660,11 +11869,16 @@ packages: engines: {node: '>=4'} dev: true - /polished@4.2.2: - resolution: {integrity: sha512-Sz2Lkdxz6F2Pgnpi9U5Ng/WdWAUZxmHrNPoVlm3aAemxoy2Qy7LGjQg4uf8qKelDAUW94F4np3iH2YPf2qefcQ==} + /polished@4.3.1: + resolution: {integrity: sha512-OBatVyC/N7SCW/FaDHrSd+vn0o5cS855TOmYi4OkdWUMSJCET/xip//ch8xGUvtr3i44X9LVyWwQlRMTN3pwSA==} engines: {node: '>=10'} dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 + dev: true + + /possible-typed-array-names@1.0.0: + resolution: {integrity: sha512-d7Uw+eZoloe0EHDIYoe+bQ5WXnGMOpmiZFTuMWCwpjzzkL2nTjcKiAk4hh8TjnGye2TwWOk3UXucZ+3rbmBa8Q==} + engines: {node: '>= 0.4'} dev: true /postcss-values-parser@2.0.1: @@ -11676,7 +11890,7 @@ packages: uniq: 1.0.1 dev: true - /postcss-values-parser@6.0.2(postcss@8.4.33): + /postcss-values-parser@6.0.2(postcss@8.4.35): resolution: {integrity: sha512-YLJpK0N1brcNJrs9WatuJFtHaV9q5aAOj+S4DI5S7jgHlRfm0PIbDCAFRYMQD5SHq7Fy6xsDhyutgS0QOAs0qw==} engines: {node: '>=10'} peerDependencies: @@ -11684,12 +11898,12 @@ packages: dependencies: color-name: 1.1.4 is-url-superb: 4.0.0 - postcss: 8.4.33 + postcss: 8.4.35 quote-unquote: 1.0.0 dev: true - /postcss@8.4.33: - resolution: {integrity: sha512-Kkpbhhdjw2qQs2O2DGX+8m5OVqEcbB9HRBvuYM9pgrjEFUg30A9LmXNlTAUj4S9kgtGyrMbTzVjH7E+s5Re2yg==} + /postcss@8.4.35: + resolution: {integrity: sha512-u5U8qYpBCpN13BsiEB0CbR1Hhh4Gc0zLFuedrHJKMctHCHAGrMdG0PRM/KErzAL3CU6/eckEtmHNB3x6e3c0vA==} engines: {node: ^10 || ^12 || >=14} dependencies: nanoid: 3.3.7 @@ -11751,8 +11965,8 @@ packages: hasBin: true dev: true - /prettier@3.2.4: - resolution: {integrity: sha512-FWu1oLHKCrtpO1ypU6J0SbK2d9Ckwysq6bHj/uaCP26DxrPpppCLQRGVuqAxSTvhF00AcvDRyYrLNW7ocBhFFQ==} + /prettier@3.2.5: + resolution: {integrity: sha512-3/GWa9aOC0YeD7LUfvOG2NiDyhOWRvt1k+rcKhOuYnMY24iiCphgneUfJDyFXd6rZCAnuLBv6UeAULtrhT/F4A==} engines: {node: '>=14'} hasBin: true dev: true @@ -11883,18 +12097,18 @@ packages: resolution: {integrity: sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==} engines: {node: '>=0.6'} dependencies: - side-channel: 1.0.4 + side-channel: 1.0.5 dev: true /qs@6.11.2: resolution: {integrity: 
sha512-tDNIz22aBzCDxLtVH++VnTfzxlfeK5CbqohpSqpJgj1Wg/cQbStNAz3NuqCs5vV+pjBsK4x4pN9HlVh7rcYRiA==} engines: {node: '>=0.6'} dependencies: - side-channel: 1.0.4 + side-channel: 1.0.5 dev: true - /query-string@8.1.0: - resolution: {integrity: sha512-BFQeWxJOZxZGix7y+SByG3F36dA0AbTy9o6pSmKFcFz7DAj0re9Frkty3saBn3nHo3D0oZJ/+rx3r8H8r8Jbpw==} + /query-string@8.2.0: + resolution: {integrity: sha512-tUZIw8J0CawM5wyGBiDOAp7ObdRQh4uBor/fUR9ZjmbZVvw95OD9If4w3MQxr99rg0DJZ/9CIORcpEqU5hQG7g==} engines: {node: '>=14.16'} dependencies: decode-uri-component: 0.4.1 @@ -11981,9 +12195,9 @@ packages: resolution: {integrity: sha512-i8aF1nyKInZnANZ4uZrH49qn1paRgBZ7wZiCNBMnenlPzEv0mRl+ShpTVEI6wZNl8sSc79xZkivtgLKQArcanQ==} engines: {node: '>=16.14.0'} dependencies: - '@babel/core': 7.23.7 - '@babel/traverse': 7.23.7 - '@babel/types': 7.23.6 + '@babel/core': 7.23.9 + '@babel/traverse': 7.23.9 + '@babel/types': 7.23.9 '@types/babel__core': 7.20.5 '@types/babel__traverse': 7.20.5 '@types/doctrine': 0.0.9 @@ -12034,7 +12248,7 @@ packages: peerDependencies: react: '>=16.13.1' dependencies: - '@babel/runtime': 7.23.6 + '@babel/runtime': 7.23.9 react: 18.2.0 dev: false @@ -12042,7 +12256,7 @@ packages: resolution: {integrity: sha512-nsO+KSNgo1SbJqJEYRE9ERzo7YtYbou/OqjSQKxV7jcKox7+usiUVZOAC+XnDOABXggQTno0Y1CpVnuWEc1boQ==} dev: false - /react-focus-lock@2.11.1(@types/react@18.2.48)(react@18.2.0): + /react-focus-lock@2.11.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-IXLwnTBrLTlKTpASZXqqXJ8oymWrgAlOfuuDYN4XCuN1YJ72dwX198UCaF1QqGUk5C3QOnlMik//n3ufcfe8Ig==} peerDependencies: '@types/react': ^16.8.0 || ^17.0.0 || ^18.0.0 @@ -12052,26 +12266,26 @@ packages: optional: true dependencies: '@babel/runtime': 7.23.9 - '@types/react': 18.2.48 - focus-lock: 1.3.2 + '@types/react': 18.2.57 + focus-lock: 1.3.3 prop-types: 15.8.1 react: 18.2.0 react-clientside-effect: 1.2.6(react@18.2.0) - use-callback-ref: 1.3.1(@types/react@18.2.48)(react@18.2.0) - use-sidecar: 1.1.2(@types/react@18.2.48)(react@18.2.0) + use-callback-ref: 1.3.1(@types/react@18.2.57)(react@18.2.0) + use-sidecar: 1.1.2(@types/react@18.2.57)(react@18.2.0) dev: false - /react-hook-form@7.49.3(react@18.2.0): - resolution: {integrity: sha512-foD6r3juidAT1cOZzpmD/gOKt7fRsDhXXZ0y28+Al1CHgX+AY1qIN9VSIIItXRq1dN68QrRwl1ORFlwjBaAqeQ==} - engines: {node: '>=18', pnpm: '8'} + /react-hook-form@7.50.1(react@18.2.0): + resolution: {integrity: sha512-3PCY82oE0WgeOgUtIr3nYNNtNvqtJ7BZjsbxh6TnYNbXButaD5WpjOmTjdxZfheuHKR68qfeFnEDVYoSSFPMTQ==} + engines: {node: '>=12.22.0'} peerDependencies: react: ^16.8.0 || ^17 || ^18 dependencies: react: 18.2.0 dev: false - /react-hotkeys-hook@4.4.4(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-wzZmqb/Obr0ds9Myc1sIFPJ52GA/Eeg/vXBWV0HA1LvHlVAW5Va3KB0q6EZNlNSHQWscWZ2K8+6w0GYSie2o7A==} + /react-hotkeys-hook@4.5.0(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-Samb85GSgAWFQNvVt3PS90LPPGSf9mkH/r4au81ZP1yOIFayLC3QAvqTgGtJ8YEDMXtPmaVBs6NgipHO6h4Mug==} peerDependencies: react: '>=16.8.1' react-dom: '>=16.8.1' @@ -12080,27 +12294,7 @@ packages: react-dom: 18.2.0(react@18.2.0) dev: false - /react-i18next@14.0.0(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-OCrS8rHNAmnr8ggGRDxjakzihrMW7HCbsplduTm3EuuQ6fyvWGT41ksZpqbduYoqJurBmEsEVZ1pILSUWkHZng==} - peerDependencies: - i18next: '>= 23.2.3' - react: '>= 16.8.0' - react-dom: '*' - react-native: '*' - peerDependenciesMeta: - react-dom: - optional: true - react-native: - optional: true - dependencies: - '@babel/runtime': 
7.23.7 - html-parse-stringify: 3.0.1 - i18next: 23.7.16 - react: 18.2.0 - react-dom: 18.2.0(react@18.2.0) - dev: false - - /react-i18next@14.0.5(i18next@23.7.16)(react-dom@18.2.0)(react@18.2.0): + /react-i18next@14.0.5(i18next@23.9.0)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-5+bQSeEtgJrMBABBL5lO7jPdSNAbeAZ+MlFWDw//7FnVacuVu3l9EeWFzBQvZsKy+cihkbThWOAThEdH8YjGEw==} peerDependencies: i18next: '>= 23.2.3' @@ -12115,7 +12309,7 @@ packages: dependencies: '@babel/runtime': 7.23.9 html-parse-stringify: 3.0.1 - i18next: 23.7.16 + i18next: 23.9.0 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) dev: false @@ -12143,7 +12337,7 @@ packages: resolution: {integrity: sha512-xWGDIW6x921xtzPkhiULtthJHoJvBbF3q26fzloPCK0hsvxtPVelvftw3zjbHWSkR2km9Z+4uxbDDK/6Zw9B8w==} dev: true - /react-konva@18.2.10(konva@9.3.1)(react-dom@18.2.0)(react@18.2.0): + /react-konva@18.2.10(konva@9.3.3)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-ohcX1BJINL43m4ynjZ24MxFI1syjBdrXhqVxYVDw2rKgr3yuS0x/6m1Y2Z4sl4T/gKhfreBx8KHisd0XC6OT1g==} peerDependencies: konva: ^8.0.1 || ^7.2.5 || ^9.0.0 @@ -12152,7 +12346,7 @@ packages: dependencies: '@types/react-reconciler': 0.28.8 its-fine: 1.1.1(react@18.2.0) - konva: 9.3.1 + konva: 9.3.3 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) react-reconciler: 0.29.0(react@18.2.0) @@ -12170,7 +12364,7 @@ packages: scheduler: 0.23.0 dev: false - /react-redux@9.1.0(@types/react@18.2.48)(react@18.2.0)(redux@5.0.1): + /react-redux@9.1.0(@types/react@18.2.57)(react@18.2.0)(redux@5.0.1): resolution: {integrity: sha512-6qoDzIO+gbrza8h3hjMA9aq4nwVFCKFtY2iLxCtVT38Swyy2C/dJCGBXHeHLtx6qlg/8qzc2MrhOeduf5K32wQ==} peerDependencies: '@types/react': ^18.2.25 @@ -12185,7 +12379,7 @@ packages: redux: optional: true dependencies: - '@types/react': 18.2.48 + '@types/react': 18.2.57 '@types/use-sync-external-store': 0.0.3 react: 18.2.0 redux: 5.0.1 @@ -12197,23 +12391,7 @@ packages: engines: {node: '>=0.10.0'} dev: true - /react-remove-scroll-bar@2.3.4(@types/react@18.2.48)(react@18.2.0): - resolution: {integrity: sha512-63C4YQBUt0m6ALadE9XV56hV8BgJWDmmTPY758iIJjfQKt2nYwoUrPk0LXRXcB/yIj82T1/Ixfdpdk68LwIB0A==} - engines: {node: '>=10'} - peerDependencies: - '@types/react': ^16.8.0 || ^17.0.0 || ^18.0.0 - react: ^16.8.0 || ^17.0.0 || ^18.0.0 - peerDependenciesMeta: - '@types/react': - optional: true - dependencies: - '@types/react': 18.2.48 - react: 18.2.0 - react-style-singleton: 2.2.1(@types/react@18.2.48)(react@18.2.0) - tslib: 2.6.2 - dev: true - - /react-remove-scroll-bar@2.3.5(@types/react@18.2.48)(react@18.2.0): + /react-remove-scroll-bar@2.3.5(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-3cqjOqg6s0XbOjWvmasmqHch+RLxIEk2r/70rzGXuz3iIGQsQheEQyqYCBb5EECoD01Vo2SIbDqW4paLeLTASw==} engines: {node: '>=10'} peerDependencies: @@ -12223,13 +12401,12 @@ packages: '@types/react': optional: true dependencies: - '@types/react': 18.2.48 + '@types/react': 18.2.57 react: 18.2.0 - react-style-singleton: 2.2.1(@types/react@18.2.48)(react@18.2.0) + react-style-singleton: 2.2.1(@types/react@18.2.57)(react@18.2.0) tslib: 2.6.2 - dev: false - /react-remove-scroll@2.5.5(@types/react@18.2.48)(react@18.2.0): + /react-remove-scroll@2.5.5(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-ImKhrzJJsyXJfBZ4bzu8Bwpka14c/fQt0k+cyFp/PBhTfyDnU5hjOtM4AG/0AMyy8oKzOTR0lDgJIM7pYXI0kw==} engines: {node: '>=10'} peerDependencies: @@ -12239,16 +12416,16 @@ packages: '@types/react': optional: true dependencies: - '@types/react': 18.2.48 + '@types/react': 18.2.57 react: 18.2.0 - 
react-remove-scroll-bar: 2.3.4(@types/react@18.2.48)(react@18.2.0) - react-style-singleton: 2.2.1(@types/react@18.2.48)(react@18.2.0) + react-remove-scroll-bar: 2.3.5(@types/react@18.2.57)(react@18.2.0) + react-style-singleton: 2.2.1(@types/react@18.2.57)(react@18.2.0) tslib: 2.6.2 - use-callback-ref: 1.3.1(@types/react@18.2.48)(react@18.2.0) - use-sidecar: 1.1.2(@types/react@18.2.48)(react@18.2.0) + use-callback-ref: 1.3.1(@types/react@18.2.57)(react@18.2.0) + use-sidecar: 1.1.2(@types/react@18.2.57)(react@18.2.0) dev: true - /react-remove-scroll@2.5.7(@types/react@18.2.48)(react@18.2.0): + /react-remove-scroll@2.5.7(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-FnrTWO4L7/Bhhf3CYBNArEG/yROV0tKmTv7/3h9QCFvH6sndeFf1wPqOcbFVu5VAulS5dV1wGT3GZZ/1GawqiA==} engines: {node: '>=10'} peerDependencies: @@ -12258,17 +12435,17 @@ packages: '@types/react': optional: true dependencies: - '@types/react': 18.2.48 + '@types/react': 18.2.57 react: 18.2.0 - react-remove-scroll-bar: 2.3.5(@types/react@18.2.48)(react@18.2.0) - react-style-singleton: 2.2.1(@types/react@18.2.48)(react@18.2.0) + react-remove-scroll-bar: 2.3.5(@types/react@18.2.57)(react@18.2.0) + react-style-singleton: 2.2.1(@types/react@18.2.57)(react@18.2.0) tslib: 2.6.2 - use-callback-ref: 1.3.1(@types/react@18.2.48)(react@18.2.0) - use-sidecar: 1.1.2(@types/react@18.2.48)(react@18.2.0) + use-callback-ref: 1.3.1(@types/react@18.2.57)(react@18.2.0) + use-sidecar: 1.1.2(@types/react@18.2.57)(react@18.2.0) dev: false - /react-resizable-panels@1.0.9(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-QPfW3L7yetEC6z04G9AYYFz5kBklh8rTWcOsVFImYMNUVhr1Y1r9Qc/20Yks2tA+lXMBWCUz4fkGEvbS7tpBSg==} + /react-resizable-panels@2.0.9(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-ZylBvs7oG7Y/INWw3oYGolqgpFvoPW8MPeg9l1fURDeKpxrmUuCHBUmPj47BdZ11MODImu3kZYXG85rbySab7w==} peerDependencies: react: ^16.14.0 || ^17.0.0 || ^18.0.0 react-dom: ^16.14.0 || ^17.0.0 || ^18.0.0 @@ -12277,49 +12454,49 @@ packages: react-dom: 18.2.0(react@18.2.0) dev: false - /react-select@5.7.7(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + /react-select@5.7.7(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-HhashZZJDRlfF/AKj0a0Lnfs3sRdw/46VJIRd8IbB9/Ovr74+ZIwkAdSBjSPXsFMG+u72c5xShqwLSKIJllzqw==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 dependencies: - '@babel/runtime': 7.23.6 + '@babel/runtime': 7.23.9 '@emotion/cache': 11.11.0 - '@emotion/react': 11.11.3(@types/react@18.2.48)(react@18.2.0) - '@floating-ui/dom': 1.5.3 + '@emotion/react': 11.11.3(@types/react@18.2.57)(react@18.2.0) + '@floating-ui/dom': 1.6.3 '@types/react-transition-group': 4.4.10 memoize-one: 6.0.0 prop-types: 15.8.1 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) react-transition-group: 4.4.5(react-dom@18.2.0)(react@18.2.0) - use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.48)(react@18.2.0) + use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.57)(react@18.2.0) transitivePeerDependencies: - '@types/react' dev: false - /react-select@5.8.0(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): + /react-select@5.8.0(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-TfjLDo58XrhP6VG5M/Mi56Us0Yt8X7xD6cDybC7yoRMUNm7BGO7qk8J0TLQOua/prb8vUOtsfnXZwfm30HGsAA==} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 dependencies: - '@babel/runtime': 7.23.7 + '@babel/runtime': 7.23.9 
'@emotion/cache': 11.11.0 - '@emotion/react': 11.11.3(@types/react@18.2.48)(react@18.2.0) - '@floating-ui/dom': 1.5.3 + '@emotion/react': 11.11.3(@types/react@18.2.57)(react@18.2.0) + '@floating-ui/dom': 1.6.3 '@types/react-transition-group': 4.4.10 memoize-one: 6.0.0 prop-types: 15.8.1 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) react-transition-group: 4.4.5(react-dom@18.2.0)(react@18.2.0) - use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.48)(react@18.2.0) + use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.57)(react@18.2.0) transitivePeerDependencies: - '@types/react' dev: false - /react-style-singleton@2.2.1(@types/react@18.2.48)(react@18.2.0): + /react-style-singleton@2.2.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-ZWj0fHEMyWkHzKYUr2Bs/4zU6XLmq9HsgBURm7g5pAVfyn49DgUiNgY2d4lXRlYSiCif9YBGpQleewkcqddc7g==} engines: {node: '>=10'} peerDependencies: @@ -12329,22 +12506,22 @@ packages: '@types/react': optional: true dependencies: - '@types/react': 18.2.48 + '@types/react': 18.2.57 get-nonce: 1.0.1 invariant: 2.2.4 react: 18.2.0 tslib: 2.6.2 - /react-textarea-autosize@8.5.3(@types/react@18.2.48)(react@18.2.0): + /react-textarea-autosize@8.5.3(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-XT1024o2pqCuZSuBt9FwHlaDeNtVrtCXu0Rnz88t1jUGheCLa3PhjE1GH8Ctm2axEtvdCl5SUHYschyQ0L5QHQ==} engines: {node: '>=10'} peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 dependencies: - '@babel/runtime': 7.23.6 + '@babel/runtime': 7.23.9 react: 18.2.0 use-composed-ref: 1.3.0(react@18.2.0) - use-latest: 1.2.1(@types/react@18.2.48)(react@18.2.0) + use-latest: 1.2.1(@types/react@18.2.57)(react@18.2.0) transitivePeerDependencies: - '@types/react' dev: false @@ -12355,7 +12532,7 @@ packages: react: '>=16.6.0' react-dom: '>=16.6.0' dependencies: - '@babel/runtime': 7.23.7 + '@babel/runtime': 7.23.9 dom-helpers: 5.2.1 loose-envify: 1.4.0 prop-types: 15.8.1 @@ -12373,8 +12550,8 @@ packages: tslib: 2.6.2 dev: false - /react-use@17.4.3(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-05Oyuwn4ZccdzLD4ttLbMe8TkobdKpOj7YCFE9VhVpbXrTWZpvCcMyroRw/Banh1RIcQRcM06tfzPpY5D9sTsQ==} + /react-use@17.5.0(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-PbfwSPMwp/hoL847rLnm/qkjg3sTRCvn6YhUZiHaUa3FA6/aNoFX79ul5Xt70O1rK+9GxSVqkY0eTwMdsR/bWg==} peerDependencies: react: '*' react-dom: '*' @@ -12397,8 +12574,8 @@ packages: tslib: 2.6.2 dev: false - /react-virtuoso@4.6.2(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-vvlqvzPif+MvBrJ09+hJJrVY0xJK9yran+A+/1iwY78k0YCVKsyoNPqoLxOxzYPggspNBNXqUXEcvckN29OxyQ==} + /react-virtuoso@4.7.0(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-cpgvI1rSOETGDMhqVAVDuH+XHbWO1uIGKv5I6l4CyC71xWYUeGrE5n7sgTZklROB4+Vbv85pcgfWloTlY48HGQ==} engines: {node: '>=10'} peerDependencies: react: '>=16 || >=17 || >= 18' @@ -12414,18 +12591,18 @@ packages: dependencies: loose-envify: 1.4.0 - /reactflow@11.10.2(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0): - resolution: {integrity: sha512-tqQJfPEiIkXonT3piVYf+F9CvABI5e28t5I6rpaLTnO8YVCAOh1h0f+ziDKz0Bx9Y2B/mFgyz+H7LZeUp/+lhQ==} + /reactflow@11.10.4(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0): + resolution: {integrity: sha512-0CApYhtYicXEDg/x2kvUHiUk26Qur8lAtTtiSlptNKuyEuGti6P1y5cS32YGaUoDMoCqkm/m+jcKkfMOvSCVRA==} peerDependencies: react: '>=17' react-dom: '>=17' dependencies: - '@reactflow/background': 11.3.7(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@reactflow/controls': 
11.2.7(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@reactflow/core': 11.10.2(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@reactflow/minimap': 11.7.7(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@reactflow/node-resizer': 2.2.7(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) - '@reactflow/node-toolbar': 1.3.7(@types/react@18.2.48)(react-dom@18.2.0)(react@18.2.0) + '@reactflow/background': 11.3.9(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@reactflow/controls': 11.2.9(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@reactflow/core': 11.10.4(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@reactflow/minimap': 11.7.9(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@reactflow/node-resizer': 2.2.9(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) + '@reactflow/node-toolbar': 1.3.9(@types/react@18.2.57)(react-dom@18.2.0)(react@18.2.0) react: 18.2.0 react-dom: 18.2.0(react@18.2.0) transitivePeerDependencies: @@ -12523,14 +12700,15 @@ packages: resolution: {integrity: sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==} dev: false - /reflect.getprototypeof@1.0.4: - resolution: {integrity: sha512-ECkTw8TmJwW60lOTR+ZkODISW6RQ8+2CL3COqtiJKLd6MmB45hN51HprHFziKLGkAuTGQhBb91V8cy+KHlaCjw==} + /reflect.getprototypeof@1.0.5: + resolution: {integrity: sha512-62wgfC8dJWrmxv44CA36pLDnP6KKl3Vhxb7PL+8+qrrFMMoJij4vgiMP8zV4O8+CBMXY1mHxI5fITGHXFHVmQQ==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 - get-intrinsic: 1.2.2 + es-abstract: 1.22.4 + es-errors: 1.3.0 + get-intrinsic: 1.2.4 globalthis: 1.0.3 which-builtin-type: 1.1.3 dev: true @@ -12552,16 +12730,17 @@ packages: /regenerator-transform@0.15.2: resolution: {integrity: sha512-hfMp2BoF0qOk3uc5V20ALGDS2ddjQaLrdl7xrGXvAIow7qeWRM2VA2HuCHkUKk9slq3VwEwLNK3DFBqDfPGYtg==} dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 dev: true - /regexp.prototype.flags@1.5.1: - resolution: {integrity: sha512-sy6TXMN+hnP/wMy+ISxg3krXx7BAtWVO4UouuCN/ziM9UEne0euamVNafDfvC83bRNr95y0V5iijeDQFUNpvrg==} + /regexp.prototype.flags@1.5.2: + resolution: {integrity: sha512-NcDiDkTLuPR+++OCKB0nWafEmhg/Da8aUPLPMQbK+bxKKCm1/S5he+AqYa4PlMCVBalb4/yxIRub6qkEx5yJbw==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - set-function-name: 2.0.1 + es-errors: 1.3.0 + set-function-name: 2.0.2 dev: true /regexpu-core@5.3.2: @@ -12630,10 +12809,9 @@ packages: hasBin: true dev: true - /reselect@5.0.1(patch_hash=kvbgwzjyy4x4fnh7znyocvb75q): - resolution: {integrity: sha512-D72j2ubjgHpvuCiORWkOUxndHJrxDaSolheiz5CO+roz8ka97/4msh2E8F5qay4GawR5vzBt5MkbDHT+Rdy/Wg==} + /reselect@5.1.0: + resolution: {integrity: sha512-aw7jcGLDpSgNDyWBQLv2cedml85qd95/iszJjN988zX1t7AVRJi19d9kto5+W7oCfQ94gyo40dVbT6g2k4/kXg==} dev: false - patched: true /resize-observer-polyfill@1.5.1: resolution: {integrity: sha512-LwZrotdHOo12nQuZlHEmtuXdqGoOD0OhaxopaNFxWzInpEgaLWoVuAMbTzixuosCx2nEG58ngzW3vxdWoxIgdg==} @@ -12757,33 +12935,33 @@ packages: fsevents: 2.3.3 dev: true - /rollup@4.9.4: - resolution: {integrity: sha512-2ztU7pY/lrQyXSCnnoU4ICjT/tCG9cdH3/G25ERqE3Lst6vl2BCM5hL2Nw+sslAvAf+ccKsAq1SkKQALyqhR7g==} + /rollup@4.12.0: + resolution: {integrity: sha512-wz66wn4t1OHIJw3+XU7mJJQV/2NAfw5OAk6G6Hoo3zcvz/XOfQ52Vgi+AN4Uxoxi0KBBwk2g8zPrTDA4btSB/Q==} engines: {node: '>=18.0.0', npm: '>=8.0.0'} hasBin: true dependencies: '@types/estree': 1.0.5 optionalDependencies: - 
'@rollup/rollup-android-arm-eabi': 4.9.4 - '@rollup/rollup-android-arm64': 4.9.4 - '@rollup/rollup-darwin-arm64': 4.9.4 - '@rollup/rollup-darwin-x64': 4.9.4 - '@rollup/rollup-linux-arm-gnueabihf': 4.9.4 - '@rollup/rollup-linux-arm64-gnu': 4.9.4 - '@rollup/rollup-linux-arm64-musl': 4.9.4 - '@rollup/rollup-linux-riscv64-gnu': 4.9.4 - '@rollup/rollup-linux-x64-gnu': 4.9.4 - '@rollup/rollup-linux-x64-musl': 4.9.4 - '@rollup/rollup-win32-arm64-msvc': 4.9.4 - '@rollup/rollup-win32-ia32-msvc': 4.9.4 - '@rollup/rollup-win32-x64-msvc': 4.9.4 + '@rollup/rollup-android-arm-eabi': 4.12.0 + '@rollup/rollup-android-arm64': 4.12.0 + '@rollup/rollup-darwin-arm64': 4.12.0 + '@rollup/rollup-darwin-x64': 4.12.0 + '@rollup/rollup-linux-arm-gnueabihf': 4.12.0 + '@rollup/rollup-linux-arm64-gnu': 4.12.0 + '@rollup/rollup-linux-arm64-musl': 4.12.0 + '@rollup/rollup-linux-riscv64-gnu': 4.12.0 + '@rollup/rollup-linux-x64-gnu': 4.12.0 + '@rollup/rollup-linux-x64-musl': 4.12.0 + '@rollup/rollup-win32-arm64-msvc': 4.12.0 + '@rollup/rollup-win32-ia32-msvc': 4.12.0 + '@rollup/rollup-win32-x64-msvc': 4.12.0 fsevents: 2.3.3 dev: true /rtl-css-js@1.16.1: resolution: {integrity: sha512-lRQgou1mu19e+Ya0LsTvKrVJ5TYUbqCVPAiImX3UfLTenarvPUl1QFdvu5Z3PYmHT9RCcwIfbjRQBntExyj3Zg==} dependencies: - '@babel/runtime': 7.23.8 + '@babel/runtime': 7.23.9 dev: false /run-parallel@1.2.0: @@ -12798,12 +12976,12 @@ packages: tslib: 2.6.2 dev: true - /safe-array-concat@1.0.1: - resolution: {integrity: sha512-6XbUAseYE2KtOuGueyeobCySj9L4+66Tn6KQMOPQJrAJEowYKW/YR/MGJZl7FdydUdaFu4LYyDZjxf4/Nmo23Q==} + /safe-array-concat@1.1.0: + resolution: {integrity: sha512-ZdQ0Jeb9Ofti4hbt5lX3T2JcAamT9hfzYU1MNB+z/jaEbB6wfFfPIR/zEORmZqobkCCJhSjodobH6WHNmJ97dg==} engines: {node: '>=0.4'} dependencies: - call-bind: 1.0.5 - get-intrinsic: 1.2.2 + call-bind: 1.0.7 + get-intrinsic: 1.2.4 has-symbols: 1.0.3 isarray: 2.0.5 dev: true @@ -12816,11 +12994,12 @@ packages: resolution: {integrity: sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==} dev: true - /safe-regex-test@1.0.0: - resolution: {integrity: sha512-JBUUzyOgEwXQY1NuPtvcj/qcBDbDmEvWufhlnXZIm75DEHp+afM1r1ujJpJsV/gSM4t59tpDyPi1sd6ZaPFfsA==} + /safe-regex-test@1.0.3: + resolution: {integrity: sha512-CdASjNJPvRa7roO6Ra/gLYBTzYzzPyyBXxIMdGW3USQLyjWEls2RgW5UBTXaQVp+OrpeCK3bLem8smtmheoRuw==} + engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 - get-intrinsic: 1.2.2 + call-bind: 1.0.7 + es-errors: 1.3.0 is-regex: 1.1.4 dev: true @@ -12873,6 +13052,14 @@ packages: lru-cache: 6.0.0 dev: true + /semver@7.6.0: + resolution: {integrity: sha512-EnwXhrlwXMk9gKu5/flx5sv/an57AkRplG3hTK68W7FRDN+k+OWBj65M7719OkA82XLBxrcX0KSHj+X5COhOVg==} + engines: {node: '>=10'} + hasBin: true + dependencies: + lru-cache: 6.0.0 + dev: true + /send@0.18.0: resolution: {integrity: sha512-qqWzuOjSFOuqPjFe4NOsMLafToQQwBSOEpS+FwEt3A2V3vKubTquT3vmLTQpFgMXp8AlFWFuP1qKaJZOtPpVXg==} engines: {node: '>= 0.8.0'} @@ -12913,23 +13100,26 @@ packages: - supports-color dev: true - /set-function-length@1.1.1: - resolution: {integrity: sha512-VoaqjbBJKiWtg4yRcKBQ7g7wnGnLV3M8oLvVWwOk2PdYY6PEFegR1vezXR0tw6fZGF9csVakIRjrJiy2veSBFQ==} + /set-function-length@1.2.1: + resolution: {integrity: sha512-j4t6ccc+VsKwYHso+kElc5neZpjtq9EnRICFZtWyBsLojhmeF/ZBd/elqm22WJh/BziDe/SBiOeAt0m2mfLD0g==} engines: {node: '>= 0.4'} dependencies: - define-data-property: 1.1.1 - get-intrinsic: 1.2.2 + define-data-property: 1.1.4 + es-errors: 1.3.0 + function-bind: 1.1.2 + get-intrinsic: 1.2.4 gopd: 1.0.1 - has-property-descriptors: 
1.0.1 + has-property-descriptors: 1.0.2 dev: true - /set-function-name@2.0.1: - resolution: {integrity: sha512-tMNCiqYVkXIZgc2Hnoy2IvC/f8ezc5koaRFkCjrpWzGpCd3qbZXPzVy9MAZzK1ch/X0jvSkojys3oqJN0qCmdA==} + /set-function-name@2.0.2: + resolution: {integrity: sha512-7PGFlmtwsEADb0WYyvCMa1t+yke6daIG4Wirafur5kcf+MhUnPms1UeR0CKQdTZD81yESwMHbtn+TR+dMviakQ==} engines: {node: '>= 0.4'} dependencies: - define-data-property: 1.1.1 + define-data-property: 1.1.4 + es-errors: 1.3.0 functions-have-names: 1.2.3 - has-property-descriptors: 1.0.1 + has-property-descriptors: 1.0.2 dev: true /set-harmonic-interval@1.0.1: @@ -12964,11 +13154,13 @@ packages: resolution: {integrity: sha512-6j1W9l1iAs/4xYBI1SYOVZyFcCis9b4KCLQ8fgAGG07QvzaRLVVRQvAy85yNmmZSjYjg4MWh4gNvlPujU/5LpA==} dev: true - /side-channel@1.0.4: - resolution: {integrity: sha512-q5XPytqFEIKHkGdiMIrY10mvLRvnQh42/+GoBlFW3b2LXLE2xxJpZFdm94we0BaoV3RwJyGqg5wS7epxTv0Zvw==} + /side-channel@1.0.5: + resolution: {integrity: sha512-QcgiIWV4WV7qWExbN5llt6frQB/lBven9pqliLXfGPB+K9ZYXxDozp0wLkHS24kWCm+6YXH/f0HhnObZnZOBnQ==} + engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 - get-intrinsic: 1.2.2 + call-bind: 1.0.7 + es-errors: 1.3.0 + get-intrinsic: 1.2.4 object-inspect: 1.13.1 dev: true @@ -13061,22 +13253,22 @@ packages: resolution: {integrity: sha512-kN9dJbvnySHULIluDHy32WHRUu3Og7B9sbY7tsFLctQkIqnMh3hErYgdMjTYuqmcXX+lK5T1lnUt3G7zNswmZA==} dependencies: spdx-expression-parse: 3.0.1 - spdx-license-ids: 3.0.16 + spdx-license-ids: 3.0.17 dev: true - /spdx-exceptions@2.3.0: - resolution: {integrity: sha512-/tTrYOC7PPI1nUAgx34hUpqXuyJG+DTHJTnIULG4rDygi4xu/tfgmq1e1cIRwRzwZgo4NLySi+ricLkZkw4i5A==} + /spdx-exceptions@2.5.0: + resolution: {integrity: sha512-PiU42r+xO4UbUS1buo3LPJkjlO7430Xn5SVAhdpzzsPHsjbYVflnnFdATgabnLude+Cqu25p6N+g2lw/PFsa4w==} dev: true /spdx-expression-parse@3.0.1: resolution: {integrity: sha512-cbqHunsQWnJNE6KhVSMsMeH5H/L9EpymbzqTQ3uLwNCLZ1Q481oWaofqH7nO6V07xlXwY6PhQdQ2IedWx/ZK4Q==} dependencies: - spdx-exceptions: 2.3.0 - spdx-license-ids: 3.0.16 + spdx-exceptions: 2.5.0 + spdx-license-ids: 3.0.17 dev: true - /spdx-license-ids@3.0.16: - resolution: {integrity: sha512-eWN+LnM3GR6gPu35WxNgbGl8rmY1AEmoMDvL/QD6zYmPWgywxWqJWNdLGT+ke8dKNWrcYgYjPpG5gbTfghP8rw==} + /spdx-license-ids@3.0.17: + resolution: {integrity: sha512-sh8PWc/ftMqAAdFiBu6Fy6JUOYjqDJBJvIhpfDMyHrr0Rbp5liZqd4TjtQ/RgfLjKFZb+LMx5hpml5qOWy0qvg==} dev: true /split-on-first@3.0.0: @@ -13130,18 +13322,18 @@ packages: resolution: {integrity: sha512-iCGQj+0l0HOdZ2AEeBADlsRC+vsnDsZsbdSiH1yNSjcfKM7fdpCMfqAL/dwF5BLiw/XhRft/Wax6zQbhq2BcjQ==} engines: {node: '>= 0.4'} dependencies: - internal-slot: 1.0.6 + internal-slot: 1.0.7 dev: true - /store2@2.14.2: - resolution: {integrity: sha512-siT1RiqlfQnGqgT/YzXVUNsom9S0H1OX+dpdGN1xkyYATo4I6sep5NmsRD/40s3IIOvlCq6akxkqG82urIZW1w==} + /store2@2.14.3: + resolution: {integrity: sha512-4QcZ+yx7nzEFiV4BMLnr/pRa5HYzNITX2ri0Zh6sT9EyQHbBHacC6YigllUPU9X3D0f/22QCgfokpKs52YRrUg==} dev: true - /storybook@7.6.10: - resolution: {integrity: sha512-ypFeGhQTUBBfqSUVZYh7wS5ghn3O2wILCiQc4459SeUpvUn+skcqw/TlrwGSoF5EWjDA7gtRrWDxO3mnlPt5Cw==} + /storybook@7.6.17: + resolution: {integrity: sha512-8+EIo91bwmeFWPg1eysrxXlhIYv3OsXrznTr4+4Eq0NikqAoq6oBhtlN5K2RGS2lBVF537eN+9jTCNbR+WrzDA==} hasBin: true dependencies: - '@storybook/cli': 7.6.10 + '@storybook/cli': 7.6.17 transitivePeerDependencies: - bufferutil - encoding @@ -13185,40 +13377,40 @@ packages: /string.prototype.matchall@4.0.10: resolution: {integrity: 
sha512-rGXbGmOEosIQi6Qva94HUjgPs9vKW+dkG7Y8Q5O2OYkWL6wFaTRZO8zM4mhP94uX55wgyrXzfS2aGtGzUL7EJQ==} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 - get-intrinsic: 1.2.2 + es-abstract: 1.22.4 + get-intrinsic: 1.2.4 has-symbols: 1.0.3 - internal-slot: 1.0.6 - regexp.prototype.flags: 1.5.1 - set-function-name: 2.0.1 - side-channel: 1.0.4 + internal-slot: 1.0.7 + regexp.prototype.flags: 1.5.2 + set-function-name: 2.0.2 + side-channel: 1.0.5 dev: true /string.prototype.trim@1.2.8: resolution: {integrity: sha512-lfjY4HcixfQXOfaqCvcBuOIapyaroTXhbkfJN3gcB1OtyupngWK4sEET9Knd0cXd28kTUqu/kHoV4HKSJdnjiQ==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 + es-abstract: 1.22.4 dev: true /string.prototype.trimend@1.0.7: resolution: {integrity: sha512-Ni79DqeB72ZFq1uH/L6zJ+DKZTkOtPIHovb3YZHQViE+HDouuU4mBrLOLDn5Dde3RF8qw5qVETEjhu9locMLvA==} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 + es-abstract: 1.22.4 dev: true /string.prototype.trimstart@1.0.7: resolution: {integrity: sha512-NGhtDFu3jCEm7B4Fy0DpLewdJQOZcQ0rGbwQ/+stjnrp2i+rlKeCvos9hOIeCmqwratM47OBxY7uFZzjxHXmrg==} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 define-properties: 1.2.1 - es-abstract: 1.22.3 + es-abstract: 1.22.4 dev: true /string_decoder@1.1.1: @@ -13295,10 +13487,10 @@ packages: engines: {node: '>=8'} dev: true - /strip-literal@1.3.0: - resolution: {integrity: sha512-PugKzOsyXpArk0yWmUwqOZecSO0GH0bPoctLcqNDH9J04pVW3lflYE0ujElBGTloevcxF5MofAOZ7C5l2b+wLg==} + /strip-literal@2.0.0: + resolution: {integrity: sha512-f9vHgsCWBq2ugHAkGMiiYY+AYG0D/cbloKKg0nhaaaSNsujdGIpVXCNsrJpCKr5M0f4aI31mr13UjY6GAuXCKA==} dependencies: - acorn: 8.11.3 + js-tokens: 8.0.3 dev: true /stylis@4.2.0: @@ -13460,8 +13652,8 @@ packages: engines: {node: '>=14.0.0'} dev: true - /tinyspy@2.2.0: - resolution: {integrity: sha512-d2eda04AN/cPOR89F7Xv5bK/jrQEhmcLFe6HFldoeO9AJtps+fqEnh486vnT/8y4bw38pSyxDcTCAq+Ks2aJTg==} + /tinyspy@2.2.1: + resolution: {integrity: sha512-KYad6Vy5VDWV4GH3fjpseMQ/XU2BhIYP7Vzd0LG44qRWm/Yt2WCOTicFdvmgo6gWaqooMQCawTtILVQJupKu7A==} engines: {node: '>=14.0.0'} dev: true @@ -13501,9 +13693,9 @@ packages: hasBin: true dev: true - /ts-api-utils@1.0.3(typescript@5.3.3): - resolution: {integrity: sha512-wNMeqtMz5NtwpT/UZGY5alT+VoKdSsOOP/kqHFcUW1P/VRhH2wJ48+DN2WwUliNbQ976ETwDL0Ifd2VVvgonvg==} - engines: {node: '>=16.13.0'} + /ts-api-utils@1.2.1(typescript@5.3.3): + resolution: {integrity: sha512-RIYA36cJn2WiH9Hy77hdF9r7oEwxAtB/TS9/S4Qd90Ap4z5FSiin5zEiTL44OII1Y3IIlEvxwxFUVgrHSZ/UpA==} + engines: {node: '>=16'} peerDependencies: typescript: '>=4.2.0' dependencies: @@ -13523,8 +13715,8 @@ packages: resolution: {integrity: sha512-tLJxacIQUM82IR7JO1UUkKlYuUTmoY9HBJAmNWFzheSlDS5SPMcNIepejHJa4BpPQLAcbRhRf3GDJzyj6rbKvA==} dev: false - /ts-graphviz@1.8.1: - resolution: {integrity: sha512-54/fe5iu0Jb6X0pmDmzsA2UHLfyHjUEUwfHtZcEOR0fZ6Myf+dFoO6eNsyL8CBDMJ9u7WWEewduVaiaXlvjSVw==} + /ts-graphviz@1.8.2: + resolution: {integrity: sha512-5YhbFoHmjxa7pgQLkB07MtGnGJ/yhvjmc9uhsnDBEICME6gkPf83SBwLDQqGDoCa3XzUMWLk1AU2Wn1u1naDtA==} engines: {node: '>=14.16'} dev: true @@ -13536,8 +13728,8 @@ packages: resolution: {integrity: sha512-gzkapsdbMNwBnTIjgO758GujLCj031IgHK/PKr2mrmkCSJMhSOR5FeOuSxKLMUoYc0vAA4RGEYYbjt/v6afD3g==} dev: true - /tsconfck@3.0.1(typescript@5.3.3): - resolution: {integrity: sha512-7ppiBlF3UEddCLeI1JRx5m2Ryq+xk4JrZuq4EuYXykipebaq1dV0Fhgr1hb7CkmHt32QSgOZlcqVLEtHBG4/mg==} + 
/tsconfck@3.0.2(typescript@5.3.3): + resolution: {integrity: sha512-6lWtFjwuhS3XI4HsX4Zg0izOI3FU/AI9EGVlPEUMDIhvLPMD4wkiof0WCoDgW7qY+Dy198g4d9miAqUHWHFH6Q==} engines: {node: ^18 || >=20} hasBin: true peerDependencies: @@ -13635,8 +13827,8 @@ packages: resolution: {integrity: sha512-RAH822pAdBgcNMAfWnCBU3CFZcfZ/i1eZjwFU/dsLKumyuuP3niueg2UAukXYF0E2AAoc82ZSSf9J0WQBinzHA==} engines: {node: '>=12.20'} - /type-fest@4.9.0: - resolution: {integrity: sha512-KS/6lh/ynPGiHD/LnAobrEFq3Ad4pBzOlJ1wAnJx9N4EYoqFhMfLIBjUT2UEx4wg5ZE+cC1ob6DCSpppVo+rtg==} + /type-fest@4.10.2: + resolution: {integrity: sha512-anpAG63wSpdEbLwOqH8L84urkL6PiVIov3EMmgIhhThevh9aiMQov+6Btx0wldNcvm4wV+e2/Rt1QdDwKHFbHw==} engines: {node: '>=16'} dev: false @@ -13648,42 +13840,47 @@ packages: mime-types: 2.1.35 dev: true - /typed-array-buffer@1.0.0: - resolution: {integrity: sha512-Y8KTSIglk9OZEr8zywiIHG/kmQ7KWyjseXs1CbSo8vC42w7hg2HgYTxSWwP0+is7bWDc1H+Fo026CpHFwm8tkw==} + /typed-array-buffer@1.0.2: + resolution: {integrity: sha512-gEymJYKZtKXzzBzM4jqa9w6Q1Jjm7x2d+sh19AdsD4wqnMPDYyvwpsIc2Q/835kHuo3BEQ7CjelGhfTsoBb2MQ==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 - get-intrinsic: 1.2.2 - is-typed-array: 1.1.12 + call-bind: 1.0.7 + es-errors: 1.3.0 + is-typed-array: 1.1.13 dev: true /typed-array-byte-length@1.0.0: resolution: {integrity: sha512-Or/+kvLxNpeQ9DtSydonMxCx+9ZXOswtwJn17SNLvhptaXYDJvkFFP5zbfU/uLmvnBJlI4yrnXRxpdWH/M5tNA==} engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 for-each: 0.3.3 - has-proto: 1.0.1 - is-typed-array: 1.1.12 + has-proto: 1.0.3 + is-typed-array: 1.1.13 dev: true - /typed-array-byte-offset@1.0.0: - resolution: {integrity: sha512-RD97prjEt9EL8YgAgpOkf3O4IF9lhJFr9g0htQkm0rchFp/Vx7LW5Q8fSXXub7BXAODyUQohRMyOc3faCPd0hg==} + /typed-array-byte-offset@1.0.2: + resolution: {integrity: sha512-Ous0vodHa56FviZucS2E63zkgtgrACj7omjwd/8lTEMEPFFyjfixMZ1ZXenpgCFBBt4EC1J2XsyVS2gkG0eTFA==} engines: {node: '>= 0.4'} dependencies: - available-typed-arrays: 1.0.5 - call-bind: 1.0.5 + available-typed-arrays: 1.0.7 + call-bind: 1.0.7 for-each: 0.3.3 - has-proto: 1.0.1 - is-typed-array: 1.1.12 + gopd: 1.0.1 + has-proto: 1.0.3 + is-typed-array: 1.1.13 dev: true - /typed-array-length@1.0.4: - resolution: {integrity: sha512-KjZypGq+I/H7HI5HlOoGHkWUUGq+Q0TPhQurLbyrVrvnKTBgzLhIJ7j6J/XTQOi0d1RjyZ0wdas8bKs2p0x3Ng==} + /typed-array-length@1.0.5: + resolution: {integrity: sha512-yMi0PlwuznKHxKmcpoOdeLwxBoVPkqZxd7q2FgMkmD3bNwvF5VW0+UlUQ1k1vmktTu4Yu13Q0RIxEP8+B+wloA==} + engines: {node: '>= 0.4'} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 for-each: 0.3.3 - is-typed-array: 1.1.12 + gopd: 1.0.1 + has-proto: 1.0.3 + is-typed-array: 1.1.13 + possible-typed-array-names: 1.0.0 dev: true /typedarray@0.0.6: @@ -13708,8 +13905,8 @@ packages: hasBin: true dev: true - /ufo@1.3.2: - resolution: {integrity: sha512-o+ORpgGwaYQXgqGDwd+hkS4PuZ3QnmqMMxRuajK/a38L6fTpcE5GPIfrf+L/KemFzfUpeUQc1rRS1iDBozvnFA==} + /ufo@1.4.0: + resolution: {integrity: sha512-Hhy+BhRBleFjpJ2vchUNN40qgkh0366FWJGqVLYBHev0vpHTrXSA0ryT+74UiW6KWsldNurQMKGqCm1M2zBciQ==} dev: true /uglify-js@3.17.4: @@ -13723,7 +13920,7 @@ packages: /unbox-primitive@1.0.2: resolution: {integrity: sha512-61pPlCD9h51VoreyJ0BReideM3MDKMKnh6+V9L08331ipq6Q8OFXZYiqP6n/tbHx4s5I9uRhcye6BrbkizkBDw==} dependencies: - call-bind: 1.0.5 + call-bind: 1.0.7 has-bigints: 1.0.2 has-symbols: 1.0.3 which-boxed-primitive: 1.0.2 @@ -13733,8 +13930,8 @@ packages: resolution: {integrity: 
sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} dev: true - /undici@5.28.2: - resolution: {integrity: sha512-wh1pHJHnUeQV5Xa8/kyQhO7WFa8M34l026L5P/+2TYiakvGy5Rdc8jWZVyG7ieht/0WgJLEd3kcU5gKx+6GC8w==} + /undici@5.28.3: + resolution: {integrity: sha512-3ItfzbrhDlINjaP0duwnNsKpDQk3acHI3gVJ1z4fmwMK31k5G9OVIAMLSIaP6w4FaGkaAkN6zaQO9LUvZ1t7VA==} engines: {node: '>=14.0'} dependencies: '@fastify/busboy': 2.1.0 @@ -13808,11 +14005,11 @@ packages: engines: {node: '>= 0.8'} dev: true - /unplugin@1.6.0: - resolution: {integrity: sha512-BfJEpWBu3aE/AyHx8VaNE/WgouoQxgH9baAiH82JjX8cqVyi3uJQstqwD5J+SZxIK326SZIhsSZlALXVBCknTQ==} + /unplugin@1.7.1: + resolution: {integrity: sha512-JqzORDAPxxs8ErLV4x+LL7bk5pk3YlcWqpSNsIkAZj972KzFZLClc/ekppahKkOczGkwIG6ElFgdOgOlK4tXZw==} dependencies: acorn: 8.11.3 - chokidar: 3.5.3 + chokidar: 3.6.0 webpack-sources: 3.2.3 webpack-virtual-modules: 0.6.1 dev: true @@ -13822,14 +14019,14 @@ packages: engines: {node: '>=8'} dev: true - /update-browserslist-db@1.0.13(browserslist@4.22.2): + /update-browserslist-db@1.0.13(browserslist@4.23.0): resolution: {integrity: sha512-xebP81SNcPuNpPP3uzeW1NYXxI3rxyJzF3pD6sH4jE7o/IX+WtSpwnVU+qIsDPyk0d3hmFQ7mjqc6AtV604hbg==} hasBin: true peerDependencies: browserslist: '>= 4.21.0' dependencies: - browserslist: 4.22.2 - escalade: 3.1.1 + browserslist: 4.23.0 + escalade: 3.1.2 picocolors: 1.0.0 dev: true @@ -13839,7 +14036,7 @@ packages: punycode: 2.3.1 dev: true - /use-callback-ref@1.3.1(@types/react@18.2.48)(react@18.2.0): + /use-callback-ref@1.3.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-Lg4Vx1XZQauB42Hw3kK7JM6yjVjgFmFC5/Ab797s79aARomD2nEErc4mCgM8EZrARLmmbWpi5DGCadmK50DcAQ==} engines: {node: '>=10'} peerDependencies: @@ -13849,7 +14046,7 @@ packages: '@types/react': optional: true dependencies: - '@types/react': 18.2.48 + '@types/react': 18.2.57 react: 18.2.0 tslib: 2.6.2 @@ -13880,7 +14077,7 @@ packages: react-dom: 18.2.0(react@18.2.0) dev: false - /use-isomorphic-layout-effect@1.1.2(@types/react@18.2.48)(react@18.2.0): + /use-isomorphic-layout-effect@1.1.2(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-49L8yCO3iGT/ZF9QttjwLF/ZD9Iwto5LnH5LmEdk/6cFmXddqi2ulF0edxTwjj+7mqvpVVGQWvbXZdn32wRSHA==} peerDependencies: '@types/react': '*' @@ -13889,11 +14086,11 @@ packages: '@types/react': optional: true dependencies: - '@types/react': 18.2.48 + '@types/react': 18.2.57 react: 18.2.0 dev: false - /use-latest@1.2.1(@types/react@18.2.48)(react@18.2.0): + /use-latest@1.2.1(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-xA+AVm/Wlg3e2P/JiItTziwS7FK92LWrDB0p+hgXloIMuVCeJJ8v6f0eeHyPZaJrM+usM1FkFfbNCrJGs8A/zw==} peerDependencies: '@types/react': '*' @@ -13902,9 +14099,9 @@ packages: '@types/react': optional: true dependencies: - '@types/react': 18.2.48 + '@types/react': 18.2.57 react: 18.2.0 - use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.48)(react@18.2.0) + use-isomorphic-layout-effect: 1.1.2(@types/react@18.2.57)(react@18.2.0) dev: false /use-resize-observer@9.1.0(react-dom@18.2.0)(react@18.2.0): @@ -13918,7 +14115,7 @@ packages: react-dom: 18.2.0(react@18.2.0) dev: true - /use-sidecar@1.1.2(@types/react@18.2.48)(react@18.2.0): + /use-sidecar@1.1.2(@types/react@18.2.57)(react@18.2.0): resolution: {integrity: sha512-epTbsLuzZ7lPClpz2TyryBfztm7m+28DlEv2ZCQ3MDr5ssiwyOwGH/e5F9CkfWjJ1t4clvI58yF822/GUkjjhw==} engines: {node: '>=10'} peerDependencies: @@ -13928,7 +14125,7 @@ packages: '@types/react': optional: true dependencies: - 
'@types/react': 18.2.48 + '@types/react': 18.2.57 detect-node-es: 1.1.0 react: 18.2.0 tslib: 2.6.2 @@ -13951,8 +14148,8 @@ packages: inherits: 2.0.4 is-arguments: 1.1.1 is-generator-function: 1.0.10 - is-typed-array: 1.1.12 - which-typed-array: 1.1.13 + is-typed-array: 1.1.13 + which-typed-array: 1.1.14 dev: true /utils-merge@1.0.1: @@ -13981,8 +14178,8 @@ packages: engines: {node: '>= 0.8'} dev: true - /vite-node@1.2.2(@types/node@20.11.5): - resolution: {integrity: sha512-1as4rDTgVWJO3n1uHmUYqq7nsFgINQ9u+mRcXpjeOMJUmviqNKjcZB7UfRZrlM7MjYXMKpuWp5oGkjaFLnjawg==} + /vite-node@1.3.1(@types/node@20.11.19): + resolution: {integrity: sha512-azbRrqRxlWTJEVbzInZCTchx0X69M/XPTCz4H+TLvlTcR/xH/3hkRqhOakT41fMJCMzXTu4UvegkZiEoJAWvng==} engines: {node: ^18.0.0 || >=20.0.0} hasBin: true dependencies: @@ -13990,7 +14187,7 @@ packages: debug: 4.3.4 pathe: 1.1.2 picocolors: 1.0.0 - vite: 5.0.12(@types/node@20.11.5) + vite: 5.1.3(@types/node@20.11.19) transitivePeerDependencies: - '@types/node' - less @@ -14002,16 +14199,16 @@ packages: - terser dev: true - /vite-plugin-css-injected-by-js@3.3.1(vite@5.0.12): - resolution: {integrity: sha512-PjM/X45DR3/V1K1fTRs8HtZHEQ55kIfdrn+dzaqNBFrOYO073SeSNCxp4j7gSYhV9NffVHaEnOL4myoko0ePAg==} + /vite-plugin-css-injected-by-js@3.4.0(vite@5.1.3): + resolution: {integrity: sha512-wS5+UYtJXQ/vNornsqTQxOLBVO/UjXU54ZsYMeX0mj2OrbStMQ4GLgvneVDQGPwyGJcm/ntBPawc2lA7xx+Lpg==} peerDependencies: vite: '>2.0.0-0' dependencies: - vite: 5.0.12(@types/node@20.11.5) + vite: 5.1.3(@types/node@20.11.19) dev: true - /vite-plugin-dts@3.7.1(@types/node@20.11.5)(typescript@5.3.3)(vite@5.0.12): - resolution: {integrity: sha512-VZJckNFpVfRAkmOxhGT5OgTUVWVXxkNQqLpBUuiNGAr9HbtvmvsPLo2JB3Xhn+o/Z9+CT6YZfYa4bX9SGR5hNw==} + /vite-plugin-dts@3.7.2(@types/node@20.11.19)(typescript@5.3.3)(vite@5.1.3): + resolution: {integrity: sha512-kg//1nDA01b8rufJf4TsvYN8LMkdwv0oBYpiQi6nRwpHyue+wTlhrBiqgipdFpMnW1oOYv6ywmzE5B0vg6vSEA==} engines: {node: ^14.18.0 || >=16.0.0} peerDependencies: typescript: '*' @@ -14020,13 +14217,13 @@ packages: vite: optional: true dependencies: - '@microsoft/api-extractor': 7.39.0(@types/node@20.11.5) + '@microsoft/api-extractor': 7.39.0(@types/node@20.11.19) '@rollup/pluginutils': 5.1.0 '@vue/language-core': 1.8.27(typescript@5.3.3) debug: 4.3.4 kolorist: 1.8.0 typescript: 5.3.3 - vite: 5.0.12(@types/node@20.11.5) + vite: 5.1.3(@types/node@20.11.19) vue-tsc: 1.8.27(typescript@5.3.3) transitivePeerDependencies: - '@types/node' @@ -14034,20 +14231,20 @@ packages: - supports-color dev: true - /vite-plugin-eslint@1.8.1(eslint@8.56.0)(vite@5.0.12): + /vite-plugin-eslint@1.8.1(eslint@8.56.0)(vite@5.1.3): resolution: {integrity: sha512-PqdMf3Y2fLO9FsNPmMX+//2BF5SF8nEWspZdgl4kSt7UvHDRHVVfHvxsD7ULYzZrJDGRxR81Nq7TOFgwMnUang==} peerDependencies: eslint: '>=7' vite: '>=2' dependencies: '@rollup/pluginutils': 4.2.1 - '@types/eslint': 8.56.0 + '@types/eslint': 8.56.2 eslint: 8.56.0 rollup: 2.79.1 - vite: 5.0.12(@types/node@20.11.5) + vite: 5.1.3(@types/node@20.11.19) dev: true - /vite-tsconfig-paths@4.3.1(typescript@5.3.3)(vite@5.0.12): + /vite-tsconfig-paths@4.3.1(typescript@5.3.3)(vite@5.1.3): resolution: {integrity: sha512-cfgJwcGOsIxXOLU/nELPny2/LUD/lcf1IbfyeKTv2bsupVbTH/xpFtdQlBmIP1GEK2CjjLxYhFfB+QODFAx5aw==} peerDependencies: vite: '*' @@ -14057,15 +14254,15 @@ packages: dependencies: debug: 4.3.4 globrex: 0.1.2 - tsconfck: 3.0.1(typescript@5.3.3) - vite: 5.0.12(@types/node@20.11.5) + tsconfck: 3.0.2(typescript@5.3.3) + vite: 5.1.3(@types/node@20.11.19) transitivePeerDependencies: - 
supports-color - typescript dev: true - /vite@5.0.12(@types/node@20.11.5): - resolution: {integrity: sha512-4hsnEkG3q0N4Tzf1+t6NdN9dg/L3BM+q8SWgbSPnJvrgH2kgdyzfVJwbR1ic69/4uMJJ/3dqDZZE5/WwqW8U1w==} + /vite@5.1.3(@types/node@20.11.19): + resolution: {integrity: sha512-UfmUD36DKkqhi/F75RrxvPpry+9+tTkrXfMNZD+SboZqBCMsxKtO52XeGzzuh7ioz+Eo/SYDBbdb0Z7vgcDJew==} engines: {node: ^18.0.0 || >=20.0.0} hasBin: true peerDependencies: @@ -14092,23 +14289,23 @@ packages: terser: optional: true dependencies: - '@types/node': 20.11.5 - esbuild: 0.19.11 - postcss: 8.4.33 - rollup: 4.9.4 + '@types/node': 20.11.19 + esbuild: 0.19.12 + postcss: 8.4.35 + rollup: 4.12.0 optionalDependencies: fsevents: 2.3.3 dev: true - /vitest@1.2.2(@types/node@20.11.5): - resolution: {integrity: sha512-d5Ouvrnms3GD9USIK36KG8OZ5bEvKEkITFtnGv56HFaSlbItJuYr7hv2Lkn903+AvRAgSixiamozUVfORUekjw==} + /vitest@1.3.1(@types/node@20.11.19): + resolution: {integrity: sha512-/1QJqXs8YbCrfv/GPQ05wAZf2eakUPLPa18vkJAKE7RXOKfVHqMZZ1WlTjiwl6Gcn65M5vpNUB6EFLnEdRdEXQ==} engines: {node: ^18.0.0 || >=20.0.0} hasBin: true peerDependencies: '@edge-runtime/vm': '*' '@types/node': ^18.0.0 || >=20.0.0 - '@vitest/browser': ^1.0.0 - '@vitest/ui': ^1.0.0 + '@vitest/browser': 1.3.1 + '@vitest/ui': 1.3.1 happy-dom: '*' jsdom: '*' peerDependenciesMeta: @@ -14125,27 +14322,26 @@ packages: jsdom: optional: true dependencies: - '@types/node': 20.11.5 - '@vitest/expect': 1.2.2 - '@vitest/runner': 1.2.2 - '@vitest/snapshot': 1.2.2 - '@vitest/spy': 1.2.2 - '@vitest/utils': 1.2.2 + '@types/node': 20.11.19 + '@vitest/expect': 1.3.1 + '@vitest/runner': 1.3.1 + '@vitest/snapshot': 1.3.1 + '@vitest/spy': 1.3.1 + '@vitest/utils': 1.3.1 acorn-walk: 8.3.2 - cac: 6.7.14 chai: 4.4.1 debug: 4.3.4 execa: 8.0.1 local-pkg: 0.5.0 - magic-string: 0.30.5 + magic-string: 0.30.7 pathe: 1.1.2 picocolors: 1.0.0 std-env: 3.7.0 - strip-literal: 1.3.0 + strip-literal: 2.0.0 tinybench: 2.6.0 tinypool: 0.8.2 - vite: 5.0.12(@types/node@20.11.5) - vite-node: 1.2.2(@types/node@20.11.5) + vite: 5.1.3(@types/node@20.11.19) + vite-node: 1.3.1(@types/node@20.11.19) why-is-node-running: 2.2.2 transitivePeerDependencies: - less @@ -14177,7 +14373,7 @@ packages: dependencies: '@volar/typescript': 1.11.1 '@vue/language-core': 1.8.27(typescript@5.3.3) - semver: 7.5.4 + semver: 7.6.0 typescript: 5.3.3 dev: true @@ -14239,7 +14435,7 @@ packages: engines: {node: '>= 0.4'} dependencies: function.prototype.name: 1.1.6 - has-tostringtag: 1.0.0 + has-tostringtag: 1.0.2 is-async-function: 2.0.0 is-date-object: 1.0.5 is-finalizationregistry: 1.0.2 @@ -14249,7 +14445,7 @@ packages: isarray: 2.0.5 which-boxed-primitive: 1.0.2 which-collection: 1.0.1 - which-typed-array: 1.1.13 + which-typed-array: 1.1.14 dev: true /which-collection@1.0.1: @@ -14261,15 +14457,15 @@ packages: is-weakset: 2.0.2 dev: true - /which-typed-array@1.1.13: - resolution: {integrity: sha512-P5Nra0qjSncduVPEAr7xhoF5guty49ArDTwzJ/yNuPIbZppyRxFQsRCWrocxIY+CnMVG+qfbU2FmDKyvSGClow==} + /which-typed-array@1.1.14: + resolution: {integrity: sha512-VnXFiIW8yNn9kIHN88xvZ4yOWchftKDsRJ8fEPacX/wl1lOvBrhsJ/OeJCXq7B0AaijRuqgzSKalJoPk+D8MPg==} engines: {node: '>= 0.4'} dependencies: - available-typed-arrays: 1.0.5 - call-bind: 1.0.5 + available-typed-arrays: 1.0.7 + call-bind: 1.0.7 for-each: 0.3.3 gopd: 1.0.1 - has-tostringtag: 1.0.0 + has-tostringtag: 1.0.2 dev: true /which@2.0.2: @@ -14409,7 +14605,7 @@ packages: engines: {node: '>=12'} dependencies: cliui: 8.0.1 - escalade: 3.1.1 + escalade: 3.1.2 get-caller-file: 2.0.5 require-directory: 2.1.1 
string-width: 4.2.3 @@ -14446,8 +14642,8 @@ packages: commander: 9.5.0 dev: true - /zod-validation-error@3.0.0(zod@3.22.4): - resolution: {integrity: sha512-x+agsJJG9rvC7axF0xqTEdZhJkLHyIZkdOAWDJSmwGPzxNHMHwtU6w2yDOAAP6yuSfTAUhAMJRBfhVGY64ySEQ==} + /zod-validation-error@3.0.2(zod@3.22.4): + resolution: {integrity: sha512-21xGaDmnU7lJZ4J63n5GXWqi+rTzGy3gDHbuZ1jP6xrK/DEQGyOqs/xW7eH96tIfCOYm+ecCuT0bfajBRKEVUw==} engines: {node: '>=18.0.0'} peerDependencies: zod: ^3.18.0 @@ -14459,12 +14655,12 @@ packages: resolution: {integrity: sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg==} dev: false - /zustand@4.4.7(@types/react@18.2.48)(react@18.2.0): - resolution: {integrity: sha512-QFJWJMdlETcI69paJwhSMJz7PPWjVP8Sjhclxmxmxv/RYI7ZOvR5BHX+ktH0we9gTWQMxcne8q1OY8xxz604gw==} + /zustand@4.5.1(@types/react@18.2.57)(react@18.2.0): + resolution: {integrity: sha512-XlauQmH64xXSC1qGYNv00ODaQ3B+tNPoy22jv2diYiP4eoDKr9LA+Bh5Bc3gplTrFdb6JVI+N4kc1DZ/tbtfPg==} engines: {node: '>=12.7.0'} peerDependencies: '@types/react': '>=16.8' - immer: '>=9.0' + immer: '>=9.0.6' react: '>=16.8' peerDependenciesMeta: '@types/react': @@ -14474,7 +14670,7 @@ packages: react: optional: true dependencies: - '@types/react': 18.2.48 + '@types/react': 18.2.57 react: 18.2.0 use-sync-external-store: 1.2.0(react@18.2.0) dev: false diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts index 61fbd015f8d..c8237664c2e 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/driver.ts @@ -2,15 +2,15 @@ import { StorageError } from 'app/store/enhancers/reduxRemember/errors'; import { $projectId } from 'app/store/nanostores/projectId'; import type { UseStore } from 'idb-keyval'; import { clear, createStore as createIDBKeyValStore, get, set } from 'idb-keyval'; -import { action, atom } from 'nanostores'; +import { atom } from 'nanostores'; import type { Driver } from 'redux-remember'; // Create a custom idb-keyval store (just needed to customize the name) export const $idbKeyValStore = atom(createIDBKeyValStore('invoke', 'invoke-store')); -export const clearIdbKeyValStore = action($idbKeyValStore, 'clear', (store) => { - clear(store.get()); -}); +export const clearIdbKeyValStore = () => { + clear($idbKeyValStore.get()); +}; // Create redux-remember driver, wrapping idb-keyval export const idbKeyValDriver: Driver = { diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 7eeee7b3c82..8a30b3226d3 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -1,6 +1,7 @@ +import { createSelectorCreator, lruMemoize } from '@reduxjs/toolkit'; import type { FetchBaseQueryArgs } from '@reduxjs/toolkit/dist/query/fetchBaseQuery'; import type { BaseQueryFn, FetchArgs, FetchBaseQueryError, TagDescription } from '@reduxjs/toolkit/query/react'; -import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react'; +import { buildCreateApi, coreModule, fetchBaseQuery, reactHooksModule } from '@reduxjs/toolkit/query/react'; import { $authToken } from 'app/store/nanostores/authToken'; import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $projectId } from 'app/store/nanostores/projectId'; @@ -85,7 +86,14 @@ baseUrl: baseUrl || window.location.href.replace(/\/$/, ''), return rawBaseQuery(args, api, 
extraOptions); }; -export const api = createApi({ +const createLruSelector = createSelectorCreator(lruMemoize); + +const customCreateApi = buildCreateApi( + coreModule({ createSelector: createLruSelector }), + reactHooksModule({ createSelector: createLruSelector }) +); + +export const api = customCreateApi({ baseQuery: dynamicBaseQuery, reducerPath: 'api', tagTypes, From 8ea0de92f10b66f62ba7b23f00a9573027acb1db Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 21 Feb 2024 11:54:02 -0500 Subject: [PATCH 194/340] Create /search endpoint, update model object structure in scan model page --- invokeai/app/api/routers/model_manager.py | 31 +++++++++++++++++++ .../subpanels/ImportModelsPanel.tsx | 2 +- .../subpanels/ModelManagerPanel/ModelList.tsx | 6 ++-- 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 6b7111dd2ce..be4ed75069f 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -32,6 +32,7 @@ ) from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata +from invokeai.backend.model_manager.search import ModelSearch from ..dependencies import ApiDependencies @@ -233,6 +234,36 @@ async def list_tags() -> Set[str]: result: Set[str] = record_store.list_tags() return result +@model_manager_router.get( + "/search", + operation_id="search_for_models", + responses={ + 200: {"description": "Directory searched successfully"}, + 404: {"description": "Invalid directory path"}, + }, + status_code=200, + response_model=List[pathlib.Path], +) +async def search_for_models( + search_path: str = Query(description="Directory path to search for models", default=None), +) -> List[pathlib.Path]: + path = pathlib.Path(search_path) + if not search_path or not path.is_dir(): + raise HTTPException( + status_code=404, + detail=f"The search path '{search_path}' does not exist or is not directory", + ) + + search = ModelSearch() + try: + models_found = list(search.search(path)) + except Exception as e: + raise HTTPException( + status_code=404, + detail=f"An error occurred while searching the directory: {e}", + ) + return models_found + @model_manager_router.get( "/tags/search", diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ImportModelsPanel.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ImportModelsPanel.tsx index 18fcef96148..960f798e791 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ImportModelsPanel.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ImportModelsPanel.tsx @@ -20,7 +20,7 @@ const ImportModelsPanel = () => { - diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx index dd74bb0c23f..49c9bc61122 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -139,10 +139,10 @@ const modelsFilter = ( return; } - const matchesFilter = model.model_name.toLowerCase().includes(nameFilter.toLowerCase()); + const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase()); - const matchesFormat = model_format === undefined || model.model_format === model_format; - const matchesType 
= model.model_type === model_type; + const matchesFormat = model_format === undefined || model.format === model_format; + const matchesType = model.type === model_type; if (matchesFilter && matchesFormat && matchesType) { filteredModels.push(model); From 0651e84fff90737e018129adbe6fa1ce058d7624 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Thu, 22 Feb 2024 09:08:18 -0500 Subject: [PATCH 195/340] rename endpoint for scanning --- invokeai/app/api/routers/model_manager.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index be4ed75069f..30b78f589dd 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -235,23 +235,23 @@ async def list_tags() -> Set[str]: return result @model_manager_router.get( - "/search", - operation_id="search_for_models", + "/scan_folder", + operation_id="scan_for_models", responses={ - 200: {"description": "Directory searched successfully"}, - 404: {"description": "Invalid directory path"}, + 200: {"description": "Directory scanned successfully"}, + 400: {"description": "Invalid directory path"}, }, status_code=200, response_model=List[pathlib.Path], ) -async def search_for_models( - search_path: str = Query(description="Directory path to search for models", default=None), +async def scan_for_models( + scan_path: str = Query(description="Directory path to search for models", default=None), ) -> List[pathlib.Path]: - path = pathlib.Path(search_path) - if not search_path or not path.is_dir(): + path = pathlib.Path(scan_path) + if not scan_path or not path.is_dir(): raise HTTPException( - status_code=404, - detail=f"The search path '{search_path}' does not exist or is not directory", + status_code=400, + detail=f"The search path '{scan_path}' does not exist or is not directory", ) search = ModelSearch() @@ -259,7 +259,7 @@ async def search_for_models( models_found = list(search.search(path)) except Exception as e: raise HTTPException( - status_code=404, + status_code=500, detail=f"An error occurred while searching the directory: {e}", ) return models_found From 03d0d190bb55a3eccd98aba6f38d2786c0b33953 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Thu, 22 Feb 2024 09:13:50 -0500 Subject: [PATCH 196/340] Run ruff --- invokeai/app/api/routers/model_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 30b78f589dd..aee457406a8 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -234,6 +234,7 @@ async def list_tags() -> Set[str]: result: Set[str] = record_store.list_tags() return result + @model_manager_router.get( "/scan_folder", operation_id="scan_for_models", From 69d93df0464a969aea4fd737a92ab515c4d78f83 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Tue, 20 Feb 2024 21:13:19 -0500 Subject: [PATCH 197/340] feat(nodes): added gradient mask node --- invokeai/app/invocations/fields.py | 1 + invokeai/app/invocations/latent.py | 69 +++++++++++++++++-- invokeai/app/invocations/primitives.py | 8 ++- .../stable_diffusion/diffusers_pipeline.py | 13 +++- 4 files changed, 80 insertions(+), 11 deletions(-) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 40d403c03d9..7f2d2783f21 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -199,6 +199,7 @@ class 
DenoiseMaskField(BaseModel): mask_name: str = Field(description="The name of the mask image") masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") + gradient: Optional[bool] = Field(default=False, description="Used for gradient inpainting") class LatentsField(BaseModel): diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index bfe7255b628..97d3c705d45 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -23,7 +23,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import SchedulerMixin as Scheduler -from PIL import Image +from PIL import Image, ImageFilter from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize @@ -128,7 +128,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation): ui_order=4, ) - def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor: + def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor: if mask_image.mode != "L": mask_image = mask_image.convert("L") mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) @@ -169,6 +169,62 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: return DenoiseMaskOutput.build( mask_name=mask_name, masked_latents_name=masked_latents_name, + gradient=False, + ) + + +@invocation( + "create_gradient_mask", + title="Create Gradient Mask", + tags=["mask", "denoise"], + category="latents", + version="1.0.0", +) +class CreateGradientMaskInvocation(BaseInvocation): + """Creates mask for denoising model run.""" + + mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1) + edge_radius: int = InputField( + default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2 + ) + coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3) + minimum_denoise: float = InputField( + default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4 + ) + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: + mask_image = context.images.get_pil(self.mask.image_name, mode="L") + if self.coherence_mode == "Box Blur": + blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius)) + else: # Gaussian Blur OR Staged + # Gaussian Blur uses standard deviation. 
1/2 radius is a good approximation + blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2)) + + mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) + blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False) + + # redistribute blur so that the edges are 0 and blur out to 1 + blur_tensor = (blur_tensor - 0.5) * 2 + + threshold = 1 - self.minimum_denoise + + if self.coherence_mode == "Staged": + # wherever the blur_tensor is masked to any degree, convert it to threshold + blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor) + else: + # wherever the blur_tensor is above threshold but less than 1, drop it to threshold + blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor) + + # multiply original mask to force actually masked regions to 0 + blur_tensor = mask_tensor * blur_tensor + + mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1)) + + return DenoiseMaskOutput.build( + mask_name=mask_name, + masked_latents_name=None, + gradient=True, ) @@ -606,9 +662,9 @@ def init_scheduler( def prep_inpaint_mask( self, context: InvocationContext, latents: torch.Tensor - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]: if self.denoise_mask is None: - return None, None + return None, None, False mask = context.tensors.load(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) @@ -617,7 +673,7 @@ def prep_inpaint_mask( else: masked_latents = None - return 1 - mask, masked_latents + return 1 - mask, masked_latents, self.denoise_mask.gradient @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: @@ -644,7 +700,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput: if seed is None: seed = 0 - mask, masked_latents = self.prep_inpaint_mask(context, latents) + mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents) # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets, # below. Investigate whether this is appropriate. 
@@ -732,6 +788,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: seed=seed, mask=mask, masked_latents=masked_latents, + gradient_mask=gradient_mask, num_inference_steps=num_inference_steps, conditioning_data=conditioning_data, control_data=controlnet_data, diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 43422134829..c761bb0895c 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -299,9 +299,13 @@ class DenoiseMaskOutput(BaseInvocationOutput): denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run") @classmethod - def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput": + def build( + cls, mask_name: str, masked_latents_name: Optional[str] = None, gradient: Optional[bool] = False + ) -> "DenoiseMaskOutput": return cls( - denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name), + denoise_mask=DenoiseMaskField( + mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=gradient + ), ) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index a85e3762dc4..fd3ecde47b7 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -86,6 +86,7 @@ class AddsMaskGuidance: mask_latents: torch.FloatTensor scheduler: SchedulerMixin noise: torch.Tensor + gradient_mask: bool def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput: output_class = step_output.__class__ # We'll create a new one with masked data. @@ -121,7 +122,12 @@ def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor: # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already? 
# mask_latents = self.scheduler.scale_model_input(mask_latents, t) mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size) - masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) + if self.gradient_mask: + threshhold = (t.item()) / self.scheduler.config.num_train_timesteps + mask_bool = mask > threshhold # I don't know when mask got inverted, but it did + masked_input = torch.where(mask_bool, latents, mask_latents) + else: + masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) return masked_input @@ -335,6 +341,7 @@ def latents_from_embeddings( t2i_adapter_data: Optional[list[T2IAdapterData]] = None, mask: Optional[torch.Tensor] = None, masked_latents: Optional[torch.Tensor] = None, + gradient_mask: Optional[bool] = False, seed: Optional[int] = None, ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: if init_timestep.shape[0] == 0: @@ -375,7 +382,7 @@ def latents_from_embeddings( self._unet_forward, mask, masked_latents ) else: - additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise)) + additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask)) try: latents, attention_map_saver = self.generate_latents_from_embeddings( @@ -392,7 +399,7 @@ def latents_from_embeddings( self.invokeai_diffuser.model_forward_callback = self._unet_forward # restore unmasked part - if mask is not None: + if mask is not None and not gradient_mask: latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)) return latents, attention_map_saver From 5fc0ed63fe3f5d4ac07face5b48329723f9497c9 Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Tue, 20 Feb 2024 21:47:25 -0500 Subject: [PATCH 198/340] chore: typing fix --- invokeai/app/invocations/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 7f2d2783f21..712ab415b0c 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -199,7 +199,7 @@ class DenoiseMaskField(BaseModel): mask_name: str = Field(description="The name of the mask image") masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") - gradient: Optional[bool] = Field(default=False, description="Used for gradient inpainting") + gradient: bool = Field(default=False, description="Used for gradient inpainting") class LatentsField(BaseModel): From 3b0bd6773edb7535d21462ce60724c92d2baf88b Mon Sep 17 00:00:00 2001 From: dunkeroni Date: Tue, 20 Feb 2024 22:13:01 -0500 Subject: [PATCH 199/340] chore: typing --- invokeai/app/invocations/primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index c761bb0895c..b80e34dc98f 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -300,7 +300,7 @@ class DenoiseMaskOutput(BaseInvocationOutput): @classmethod def build( - cls, mask_name: str, masked_latents_name: Optional[str] = None, gradient: Optional[bool] = False + cls, mask_name: str, masked_latents_name: Optional[str] = None, gradient: bool = False ) -> "DenoiseMaskOutput": return cls( denoise_mask=DenoiseMaskField( From 70cd3196c404c39d693567c976a5d9ebc6530a0b Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 21 Feb 2024 
10:18:30 -0500 Subject: [PATCH 200/340] remove startup dependency on legacy models.yaml file --- invokeai/app/services/config/config_base.py | 1 + .../app/services/config/config_default.py | 5 +- invokeai/backend/install/check_root.py | 1 - .../backend/install/invokeai_configure.py | 2 +- invokeai/backend/model_manager/merge.py | 4 +- .../training/textual_inversion_training.py | 2 +- invokeai/frontend/install/model_install.py | 2 +- invokeai/frontend/merge/merge_diffusers.py | 2 +- .../frontend/training/textual_inversion.py | 4 +- .../frontend/training/textual_inversion2.py | 454 ------------------ invokeai/frontend/web/public/locales/en.json | 2 +- 11 files changed, 11 insertions(+), 468 deletions(-) mode change 100755 => 100644 invokeai/frontend/training/textual_inversion.py delete mode 100644 invokeai/frontend/training/textual_inversion2.py diff --git a/invokeai/app/services/config/config_base.py b/invokeai/app/services/config/config_base.py index c73aa438096..20dac149374 100644 --- a/invokeai/app/services/config/config_base.py +++ b/invokeai/app/services/config/config_base.py @@ -156,6 +156,7 @@ def _excluded_from_yaml(cls) -> List[str]: "lora_dir", "embedding_dir", "controlnet_dir", + "conf_path", ] @classmethod diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 2af775372dd..01fd5e21792 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -30,7 +30,6 @@ lora_dir: null embedding_dir: null controlnet_dir: null - conf_path: configs/models.yaml models_dir: models legacy_conf_dir: configs/stable-diffusion db_dir: databases @@ -123,7 +122,6 @@ root_path - path to InvokeAI root output_path - path to default outputs directory - model_conf_path - path to models.yaml conf - alias for the above embedding_path - path to the embeddings directory lora_path - path to the LoRA directory @@ -163,7 +161,6 @@ class InvokeBatch(InvokeAISettings): InvokeAI: Paths: root: /home/lstein/invokeai-main - conf_path: configs/models.yaml legacy_conf_dir: configs/stable-diffusion outdir: outputs ... 
@@ -237,7 +234,6 @@ class InvokeAIAppConfig(InvokeAISettings): # PATHS root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths) autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths) - conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths) models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths) convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths) legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths) @@ -301,6 +297,7 @@ class InvokeAIAppConfig(InvokeAISettings): lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths) embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths) controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths) + conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths) # this is not referred to in the source code and can be removed entirely #free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance) diff --git a/invokeai/backend/install/check_root.py b/invokeai/backend/install/check_root.py index ee264016b45..cbf9976123c 100644 --- a/invokeai/backend/install/check_root.py +++ b/invokeai/backend/install/check_root.py @@ -8,7 +8,6 @@ def check_invokeai_root(config: InvokeAIAppConfig): try: - assert config.model_conf_path.exists(), f"{config.model_conf_path} not found" assert config.db_path.parent.exists(), f"{config.db_path.parent} not found" assert config.models_path.exists(), f"{config.models_path} not found" if not config.ignore_missing_core_models: diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index ac3e583de3e..53cca64a1a5 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -939,7 +939,7 @@ def main() -> None: # run this unconditionally in case new directories need to be added initialize_rootdir(config.root_path, opt.yes_to_all) - # this will initialize the models.yaml file if not present + # this will initialize and populate the models tables if not present install_helper = InstallHelper(config, logger) models_to_download = default_user_selections(opt, install_helper) diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index 7063cb907d2..1a3b9cb7de0 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -1,7 +1,7 @@ """ invokeai.backend.model_manager.merge exports: merge_diffusion_models() -- combine multiple models by location and return a 
pipeline object -merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml +merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to the models tables Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team """ @@ -101,7 +101,7 @@ def merge_diffusion_models_and_save( **kwargs: Any, ) -> AnyModelConfig: """ - :param models: up to three models, designated by their InvokeAI models.yaml model name + :param models: up to three models, designated by their registered InvokeAI model name :param merged_model_name: name for new model :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 diff --git a/invokeai/backend/training/textual_inversion_training.py b/invokeai/backend/training/textual_inversion_training.py index e31ce959c29..9a38c006a5b 100644 --- a/invokeai/backend/training/textual_inversion_training.py +++ b/invokeai/backend/training/textual_inversion_training.py @@ -120,7 +120,7 @@ def parse_args() -> Namespace: "--model", type=str, default="sd-1/main/stable-diffusion-v1-5", - help="Name of the diffusers model to train against, as defined in configs/models.yaml.", + help="Name of the diffusers model to train against.", ) model_group.add_argument( "--revision", diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index 3a4d66ae0a0..2f7fd0a1d03 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -455,7 +455,7 @@ def marshall_arguments(self) -> None: selections = self.parentApp.install_selections all_models = self.all_models - # Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove + # Defined models (in INITIAL_CONFIG.yaml or invokeai.db) to add/remove ui_sections = [ self.starter_pipelines, self.pipeline_models, diff --git a/invokeai/frontend/merge/merge_diffusers.py b/invokeai/frontend/merge/merge_diffusers.py index 5484040674d..ff9acd569ee 100644 --- a/invokeai/frontend/merge/merge_diffusers.py +++ b/invokeai/frontend/merge/merge_diffusers.py @@ -435,7 +435,7 @@ def main(): run_cli(args) except widget.NotEnoughSpaceForWidget as e: if str(e).startswith("Height of 1 allocated"): - logger.error("You need to have at least two diffusers models defined in models.yaml in order to merge") + logger.error("You need to have at least two diffusers models in order to merge") else: logger.error("Not enough room for the user interface. 
Try making this window larger.") sys.exit(-1) diff --git a/invokeai/frontend/training/textual_inversion.py b/invokeai/frontend/training/textual_inversion.py old mode 100755 new mode 100644 index 81b1081bb8a..6250b313b06 --- a/invokeai/frontend/training/textual_inversion.py +++ b/invokeai/frontend/training/textual_inversion.py @@ -261,7 +261,7 @@ def ok_cancel(self): def validate_field_values(self) -> bool: bad_fields = [] if self.model.value is None: - bad_fields.append("Model Name must correspond to a known model in models.yaml") + bad_fields.append("Model Name must correspond to a known model in invokeai.db") if not re.match("^[a-zA-Z0-9.-]+$", self.placeholder_token.value): bad_fields.append("Trigger term must only contain alphanumeric characters, the dot and hyphen") if self.train_data_dir.value is None: @@ -442,7 +442,7 @@ def main() -> None: pass except (widget.NotEnoughSpaceForWidget, Exception) as e: if str(e).startswith("Height of 1 allocated"): - logger.error("You need to have at least one diffusers models defined in models.yaml in order to train") + logger.error("You need to have at least one diffusers models defined in invokeai.db in order to train") elif str(e).startswith("addwstr"): logger.error("Not enough window space for the interface. Please make your window larger and try again.") else: diff --git a/invokeai/frontend/training/textual_inversion2.py b/invokeai/frontend/training/textual_inversion2.py deleted file mode 100644 index 81b1081bb8a..00000000000 --- a/invokeai/frontend/training/textual_inversion2.py +++ /dev/null @@ -1,454 +0,0 @@ -#!/usr/bin/env python - -""" -This is the frontend to "textual_inversion_training.py". - -Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team -""" - - -import os -import re -import shutil -import sys -import traceback -from argparse import Namespace -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import npyscreen -from npyscreen import widget -from omegaconf import OmegaConf - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.install.install_helper import initialize_installer -from invokeai.backend.model_manager import ModelType -from invokeai.backend.training import do_textual_inversion_training, parse_args - -TRAINING_DATA = "text-inversion-training-data" -TRAINING_DIR = "text-inversion-output" -CONF_FILE = "preferences.conf" -config = None - - -class textualInversionForm(npyscreen.FormMultiPageAction): - resolutions = [512, 768, 1024] - lr_schedulers = [ - "linear", - "cosine", - "cosine_with_restarts", - "polynomial", - "constant", - "constant_with_warmup", - ] - precisions = ["no", "fp16", "bf16"] - learnable_properties = ["object", "style"] - - def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, saved_args: Optional[Dict[str, str]] = None): - self.saved_args = saved_args or {} - super().__init__(parentApp, name) - - def afterEditing(self) -> None: - self.parentApp.setNextForm(None) - - def create(self) -> None: - self.model_names, default = self.get_model_names() - default_initializer_token = "★" - default_placeholder_token = "" - saved_args = self.saved_args - - assert config is not None - - try: - default = self.model_names.index(saved_args["model"]) - except Exception: - pass - - self.add_widget_intelligent( - npyscreen.FixedText, - value="Use ctrl-N and ctrl-P to move to the ext and
revious fields, cursor arrows to make a selection, and space to toggle checkboxes.", - editable=False, - ) - - self.model = self.add_widget_intelligent( - npyscreen.TitleSelectOne, - name="Model Name:", - values=sorted(self.model_names), - value=default, - max_height=len(self.model_names) + 1, - scroll_exit=True, - ) - self.placeholder_token = self.add_widget_intelligent( - npyscreen.TitleText, - name="Trigger Term:", - value="", # saved_args.get('placeholder_token',''), # to restore previous term - scroll_exit=True, - ) - self.placeholder_token.when_value_edited = self.initializer_changed - self.nextrely -= 1 - self.nextrelx += 30 - self.prompt_token = self.add_widget_intelligent( - npyscreen.FixedText, - name="Trigger term for use in prompt", - value="", - editable=False, - scroll_exit=True, - ) - self.nextrelx -= 30 - self.initializer_token = self.add_widget_intelligent( - npyscreen.TitleText, - name="Initializer:", - value=saved_args.get("initializer_token", default_initializer_token), - scroll_exit=True, - ) - self.resume_from_checkpoint = self.add_widget_intelligent( - npyscreen.Checkbox, - name="Resume from last saved checkpoint", - value=False, - scroll_exit=True, - ) - self.learnable_property = self.add_widget_intelligent( - npyscreen.TitleSelectOne, - name="Learnable property:", - values=self.learnable_properties, - value=self.learnable_properties.index(saved_args.get("learnable_property", "object")), - max_height=4, - scroll_exit=True, - ) - self.train_data_dir = self.add_widget_intelligent( - npyscreen.TitleFilename, - name="Data Training Directory:", - select_dir=True, - must_exist=False, - value=str( - saved_args.get( - "train_data_dir", - config.root_dir / TRAINING_DATA / default_placeholder_token, - ) - ), - scroll_exit=True, - ) - self.output_dir = self.add_widget_intelligent( - npyscreen.TitleFilename, - name="Output Destination Directory:", - select_dir=True, - must_exist=False, - value=str( - saved_args.get( - "output_dir", - config.root_dir / TRAINING_DIR / default_placeholder_token, - ) - ), - scroll_exit=True, - ) - self.resolution = self.add_widget_intelligent( - npyscreen.TitleSelectOne, - name="Image resolution (pixels):", - values=self.resolutions, - value=self.resolutions.index(saved_args.get("resolution", 512)), - max_height=4, - scroll_exit=True, - ) - self.center_crop = self.add_widget_intelligent( - npyscreen.Checkbox, - name="Center crop images before resizing to resolution", - value=saved_args.get("center_crop", False), - scroll_exit=True, - ) - self.mixed_precision = self.add_widget_intelligent( - npyscreen.TitleSelectOne, - name="Mixed Precision:", - values=self.precisions, - value=self.precisions.index(saved_args.get("mixed_precision", "fp16")), - max_height=4, - scroll_exit=True, - ) - self.num_train_epochs = self.add_widget_intelligent( - npyscreen.TitleSlider, - name="Number of training epochs:", - out_of=1000, - step=50, - lowest=1, - value=saved_args.get("num_train_epochs", 100), - scroll_exit=True, - ) - self.max_train_steps = self.add_widget_intelligent( - npyscreen.TitleSlider, - name="Max Training Steps:", - out_of=10000, - step=500, - lowest=1, - value=saved_args.get("max_train_steps", 3000), - scroll_exit=True, - ) - self.train_batch_size = self.add_widget_intelligent( - npyscreen.TitleSlider, - name="Batch Size (reduce if you run out of memory):", - out_of=50, - step=1, - lowest=1, - value=saved_args.get("train_batch_size", 8), - scroll_exit=True, - ) - self.gradient_accumulation_steps = self.add_widget_intelligent( - npyscreen.TitleSlider, 
- name="Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):", - out_of=10, - step=1, - lowest=1, - value=saved_args.get("gradient_accumulation_steps", 4), - scroll_exit=True, - ) - self.lr_warmup_steps = self.add_widget_intelligent( - npyscreen.TitleSlider, - name="Warmup Steps:", - out_of=100, - step=1, - lowest=0, - value=saved_args.get("lr_warmup_steps", 0), - scroll_exit=True, - ) - self.learning_rate = self.add_widget_intelligent( - npyscreen.TitleText, - name="Learning Rate:", - value=str( - saved_args.get("learning_rate", "5.0e-04"), - ), - scroll_exit=True, - ) - self.scale_lr = self.add_widget_intelligent( - npyscreen.Checkbox, - name="Scale learning rate by number GPUs, steps and batch size", - value=saved_args.get("scale_lr", True), - scroll_exit=True, - ) - self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent( - npyscreen.Checkbox, - name="Use xformers acceleration", - value=saved_args.get("enable_xformers_memory_efficient_attention", False), - scroll_exit=True, - ) - self.lr_scheduler = self.add_widget_intelligent( - npyscreen.TitleSelectOne, - name="Learning rate scheduler:", - values=self.lr_schedulers, - max_height=7, - value=self.lr_schedulers.index(saved_args.get("lr_scheduler", "constant")), - scroll_exit=True, - ) - self.model.editing = True - - def initializer_changed(self) -> None: - placeholder = self.placeholder_token.value - self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)" - self.train_data_dir.value = str(config.root_dir / TRAINING_DATA / placeholder) - self.output_dir.value = str(config.root_dir / TRAINING_DIR / placeholder) - self.resume_from_checkpoint.value = Path(self.output_dir.value).exists() - - def on_ok(self): - if self.validate_field_values(): - self.parentApp.setNextForm(None) - self.editing = False - self.parentApp.ti_arguments = self.marshall_arguments() - npyscreen.notify("Launching textual inversion training. 
This will take a while...") - else: - self.editing = True - - def ok_cancel(self): - sys.exit(0) - - def validate_field_values(self) -> bool: - bad_fields = [] - if self.model.value is None: - bad_fields.append("Model Name must correspond to a known model in models.yaml") - if not re.match("^[a-zA-Z0-9.-]+$", self.placeholder_token.value): - bad_fields.append("Trigger term must only contain alphanumeric characters, the dot and hyphen") - if self.train_data_dir.value is None: - bad_fields.append("Data Training Directory cannot be empty") - if self.output_dir.value is None: - bad_fields.append("The Output Destination Directory cannot be empty") - if len(bad_fields) > 0: - message = "The following problems were detected and must be corrected:" - for problem in bad_fields: - message += f"\n* {problem}" - npyscreen.notify_confirm(message) - return False - else: - return True - - def get_model_names(self) -> Tuple[List[str], int]: - global config - assert config is not None - installer = initialize_installer(config) - store = installer.record_store - main_models = store.search_by_attr(model_type=ModelType.Main) - model_names = [f"{x.base.value}/{x.type.value}/{x.name}" for x in main_models if x.format == "diffusers"] - default = 0 - return (model_names, default) - - def marshall_arguments(self) -> dict: - args = {} - - # the choices - args.update( - model=self.model_names[self.model.value[0]], - resolution=self.resolutions[self.resolution.value[0]], - lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]], - mixed_precision=self.precisions[self.mixed_precision.value[0]], - learnable_property=self.learnable_properties[self.learnable_property.value[0]], - ) - - # all the strings and booleans - for attr in ( - "initializer_token", - "placeholder_token", - "train_data_dir", - "output_dir", - "scale_lr", - "center_crop", - "enable_xformers_memory_efficient_attention", - ): - args[attr] = getattr(self, attr).value - - # all the integers - for attr in ( - "train_batch_size", - "gradient_accumulation_steps", - "num_train_epochs", - "max_train_steps", - "lr_warmup_steps", - ): - args[attr] = int(getattr(self, attr).value) - - # the floats (just one) - args.update(learning_rate=float(self.learning_rate.value)) - - # a special case - if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists(): - args["resume_from_checkpoint"] = "latest" - - return args - - -class MyApplication(npyscreen.NPSAppManaged): - def __init__(self, saved_args: Optional[Dict[str, str]] = None): - super().__init__() - self.ti_arguments = None - self.saved_args = saved_args - - def onStart(self): - npyscreen.setTheme(npyscreen.Themes.DefaultTheme) - self.main = self.addForm( - "MAIN", - textualInversionForm, - name="Textual Inversion Settings", - saved_args=self.saved_args, - ) - - -def copy_to_embeddings_folder(args: Dict[str, str]) -> None: - """ - Copy learned_embeds.bin into the embeddings folder, and offer to - delete the full model and checkpoints. - """ - assert config is not None - source = Path(args["output_dir"], "learned_embeds.bin") - dest_dir_name = args["placeholder_token"].strip("<>") - destination = config.root_dir / "embeddings" / dest_dir_name - os.makedirs(destination, exist_ok=True) - logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}") - shutil.copy(source, destination) - if (input("Delete training logs and intermediate checkpoints? 
[y] ") or "y").startswith(("y", "Y")): - shutil.rmtree(Path(args["output_dir"])) - else: - logger.info(f'Keeping {args["output_dir"]}') - - -def save_args(args: dict) -> None: - """ - Save the current argument values to an omegaconf file - """ - assert config is not None - dest_dir = config.root_dir / TRAINING_DIR - os.makedirs(dest_dir, exist_ok=True) - conf_file = dest_dir / CONF_FILE - conf = OmegaConf.create(args) - OmegaConf.save(config=conf, f=conf_file) - - -def previous_args() -> dict: - """ - Get the previous arguments used. - """ - assert config is not None - conf_file = config.root_dir / TRAINING_DIR / CONF_FILE - try: - conf = OmegaConf.load(conf_file) - conf["placeholder_token"] = conf["placeholder_token"].strip("<>") - except Exception: - conf = None - - return conf - - -def do_front_end() -> None: - global config - saved_args = previous_args() - myapplication = MyApplication(saved_args=saved_args) - myapplication.run() - - if my_args := myapplication.ti_arguments: - os.makedirs(my_args["output_dir"], exist_ok=True) - - # Automatically add angle brackets around the trigger - if not re.match("^<.+>$", my_args["placeholder_token"]): - my_args["placeholder_token"] = f"<{my_args['placeholder_token']}>" - - my_args["only_save_embeds"] = True - save_args(my_args) - - try: - print(my_args) - do_textual_inversion_training(config, **my_args) - copy_to_embeddings_folder(my_args) - except Exception as e: - logger.error("An exception occurred during training. The exception was:") - logger.error(str(e)) - logger.error("DETAILS:") - logger.error(traceback.format_exc()) - - -def main() -> None: - global config - - args: Namespace = parse_args() - config = InvokeAIAppConfig.get_config() - config.parse_args([]) - - # change root if needed - if args.root_dir: - config.root = args.root_dir - - try: - if args.front_end: - do_front_end() - else: - do_textual_inversion_training(config, **vars(args)) - except AssertionError as e: - logger.error(e) - sys.exit(-1) - except KeyboardInterrupt: - pass - except (widget.NotEnoughSpaceForWidget, Exception) as e: - if str(e).startswith("Height of 1 allocated"): - logger.error("You need to have at least one diffusers models defined in models.yaml in order to train") - elif str(e).startswith("addwstr"): - logger.error("Not enough window space for the interface. Please make your window larger and try again.") - else: - logger.error(e) - sys.exit(-1) - - -if __name__ == "__main__": - main() diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 9abf0b80aa2..a458563fd56 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -814,7 +814,7 @@ "simpleModelDesc": "Provide a path to a local Diffusers model, local checkpoint / safetensors model a HuggingFace Repo ID, or a checkpoint/diffusers model URL.", "statusConverting": "Converting", "syncModels": "Sync Models", - "syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you manually update your models.yaml file or add models to the InvokeAI root folder after the application has booted.", + "syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. 
This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.", "updateModel": "Update Model", "useCustomConfig": "Use Custom Config", "v1": "v1", From b6f0fc1603200956881e7359017b003206d50ad3 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 21 Feb 2024 11:46:23 -0500 Subject: [PATCH 201/340] fix repo-id for the Deliberate v5 model prevent lora and embedding file suffixes from being stripped during installation apply psychedelicious patch to get compel to load proper TI embedding --- .../app/services/model_install/model_install_default.py | 6 +++++- invokeai/configs/INITIAL_MODELS.yaml | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 7dee8bfd8cb..9b771c51596 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -154,8 +154,12 @@ def install_path( info: AnyModelConfig = self._probe_model(Path(model_path), config) old_hash = info.current_hash + + if preferred_name := config.get("name"): + preferred_name = Path(preferred_name).with_suffix(model_path.suffix) + dest_path = ( - self.app_config.models_path / info.base.value / info.type.value / (config.get("name") or model_path.name) + self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name) ) try: new_path = self._copy_model(model_path, dest_path) diff --git a/invokeai/configs/INITIAL_MODELS.yaml b/invokeai/configs/INITIAL_MODELS.yaml index ca2283ab811..811121d1ba1 100644 --- a/invokeai/configs/INITIAL_MODELS.yaml +++ b/invokeai/configs/INITIAL_MODELS.yaml @@ -34,7 +34,7 @@ sd-1/main/Analog-Diffusion: recommended: False sd-1/main/Deliberate: description: Versatile model that produces detailed images up to 768px (4.27 GB) - source: XpucT/Deliberate + source: stablediffusionapi/deliberate-v5 recommended: False sd-1/main/Dungeons-and-Diffusion: description: Dungeons & Dragons characters (2.13 GB) From d2a4ab5e7bedc46fc904631392388b03bdacfef9 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 21 Feb 2024 17:40:53 -0500 Subject: [PATCH 202/340] use official Deliberate download repo --- invokeai/configs/INITIAL_MODELS.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/configs/INITIAL_MODELS.yaml b/invokeai/configs/INITIAL_MODELS.yaml index 811121d1ba1..8ad788fba78 100644 --- a/invokeai/configs/INITIAL_MODELS.yaml +++ b/invokeai/configs/INITIAL_MODELS.yaml @@ -34,7 +34,7 @@ sd-1/main/Analog-Diffusion: recommended: False sd-1/main/Deliberate: description: Versatile model that produces detailed images up to 768px (4.27 GB) - source: stablediffusionapi/deliberate-v5 + source: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors?download=true recommended: False sd-1/main/Dungeons-and-Diffusion: description: Dungeons & Dragons characters (2.13 GB) From ea67507fce3d7cd543012cfe1a7ae7afe0289bac Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 21 Feb 2024 17:15:54 -0500 Subject: [PATCH 203/340] several small model install enhancements - Support extended HF repoid syntax in TUI. This allows installation of subfolders and safetensors files, as in `XpucT/Deliberate::Deliberate_v5.safetensors` - Add `error` and `error_traceback` properties to the install job objects. - Rename the `heuristic_import` route to `heuristic_install`. 
- Fix the example `config` input in the `heuristic_install` route. --- invokeai/app/api/routers/model_manager.py | 8 ++-- .../model_install/model_install_base.py | 13 ++++-- invokeai/backend/install/install_helper.py | 42 +++---------------- .../model_install/test_model_install.py | 2 +- 4 files changed, 20 insertions(+), 45 deletions(-) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index aee457406a8..f57f5f97b63 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -382,8 +382,8 @@ async def add_model_record( @model_manager_router.post( - "/heuristic_import", - operation_id="heuristic_import_model", + "/heuristic_install", + operation_id="heuristic_install_model", responses={ 201: {"description": "The model imported successfully"}, 415: {"description": "Unrecognized file/folder format"}, @@ -392,12 +392,12 @@ async def add_model_record( }, status_code=201, ) -async def heuristic_import( +async def heuristic_install( source: str, config: Optional[Dict[str, Any]] = Body( description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", default=None, - example={"name": "modelT", "description": "antique cars"}, + example={"name": "string", "description": "string"}, ), access_token: Optional[str] = None, ) -> ModelInstallJob: diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 080219af75e..d1e8e4f8e58 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -177,6 +177,12 @@ class ModelInstallJob(BaseModel): download_parts: Set[DownloadJob] = Field( default_factory=set, description="Download jobs contributing to this install" ) + error: Optional[str] = Field( + default=None, description="On an error condition, this field will contain the text of the exception" + ) + error_traceback: Optional[str] = Field( + default=None, description="On an error condition, this field will contain the exception traceback" + ) # internal flags and transitory settings _install_tmpdir: Optional[Path] = PrivateAttr(default=None) _exception: Optional[Exception] = PrivateAttr(default=None) @@ -184,6 +190,8 @@ class ModelInstallJob(BaseModel): def set_error(self, e: Exception) -> None: """Record the error and traceback from an exception.""" self._exception = e + self.error = str(e) + self.error_traceback = self._format_error(e) self.status = InstallStatus.ERROR def cancel(self) -> None: @@ -195,10 +203,9 @@ def error_type(self) -> Optional[str]: """Class name of the exception that led to status==ERROR.""" return self._exception.__class__.__name__ if self._exception else None - @property - def error(self) -> Optional[str]: + def _format_error(self, exception: Exception) -> str: """Error traceback.""" - return "".join(traceback.format_exception(self._exception)) if self._exception else None + return "".join(traceback.format_exception(exception)) @property def cancelled(self) -> bool: diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 3623b623a94..999dcdd1007 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -1,14 +1,11 @@ """Utility (backend) functions used by model_install.py""" -import re from logging import Logger from pathlib import Path from typing import Any, Dict, 
List, Optional import omegaconf -from huggingface_hub import HfFolder from pydantic import BaseModel, Field from pydantic.dataclasses import dataclass -from pydantic.networks import AnyHttpUrl from requests import HTTPError from tqdm import tqdm @@ -18,12 +15,8 @@ from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage from invokeai.app.services.model_install import ( - HFModelSource, - LocalModelSource, ModelInstallService, ModelInstallServiceBase, - ModelSource, - URLModelSource, ) from invokeai.app.services.model_metadata import ModelMetadataStoreSQL from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL @@ -31,7 +24,6 @@ from invokeai.backend.model_manager import ( BaseModelType, InvalidModelConfigException, - ModelRepoVariant, ModelType, ) from invokeai.backend.model_manager.metadata import UnknownMetadataException @@ -226,37 +218,13 @@ def _add_required_models(self, model_list: List[UnifiedModelInfo]) -> None: additional_models.append(reverse_source[requirement]) model_list.extend(additional_models) - def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource: - assert model_info.source - model_path_id_or_url = model_info.source.strip("\"' ") - model_path = Path(model_path_id_or_url) - - if model_path.exists(): # local file on disk - return LocalModelSource(path=model_path.absolute(), inplace=True) - - # parsing huggingface repo ids - # we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16" - variants = "|".join([x.lower() for x in ModelRepoVariant.__members__]) - if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url): - repo_id = match.group(1) - repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None - subfolder = Path(model_info.subfolder) if model_info.subfolder else None - return HFModelSource( - repo_id=repo_id, - access_token=HfFolder.get_token(), - subfolder=subfolder, - variant=repo_variant, - ) - if re.match(r"^(http|https):", model_path_id_or_url): - return URLModelSource(url=AnyHttpUrl(model_path_id_or_url)) - raise ValueError(f"Unsupported model source: {model_path_id_or_url}") - def add_or_delete(self, selections: InstallSelections) -> None: """Add or delete selected models.""" installer = self._installer self._add_required_models(selections.install_models) for model in selections.install_models: - source = self._make_install_source(model) + assert model.source + model_path_id_or_url = model.source.strip("\"' ") config = ( { "description": model.description, @@ -267,12 +235,12 @@ def add_or_delete(self, selections: InstallSelections) -> None: ) try: - installer.import_model( - source=source, + installer.heuristic_import( + source=model_path_id_or_url, config=config, ) except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e: - self._logger.warning(f"{source}: {e}") + self._logger.warning(f"{model.source}: {e}") for model_to_remove in selections.remove_models: parts = model_to_remove.split("/") diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 55f7e865410..80b106c5cb2 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -256,4 +256,4 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In assert job.error_type == "HTTPError" 
assert job.error assert "NOT FOUND" in job.error - assert "Traceback" in job.error + assert job.error_traceback.startswith("Traceback") From 10ba71b212d35670d2bc4425692501f36eebcaa2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 22 Feb 2024 17:19:08 +1100 Subject: [PATCH 204/340] feat(ui): replace `type-fest` with `utility-types` - The new package has more useful types - Only used `JsonObject` from `type-fest`; added an implementation of that type --- invokeai/frontend/web/package.json | 2 +- invokeai/frontend/web/pnpm-lock.yaml | 16 ++--- invokeai/frontend/web/src/app/store/store.ts | 4 +- invokeai/frontend/web/src/common/types.ts | 7 ++ .../src/features/nodes/util/graph/metadata.ts | 4 +- .../nodes/util/workflow/validateWorkflow.ts | 4 +- .../frontend/web/src/services/api/types.ts | 70 ++++++++++++++----- 7 files changed, 73 insertions(+), 34 deletions(-) create mode 100644 invokeai/frontend/web/src/common/types.ts diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index 743cb1e09d6..0bf236ee384 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -99,7 +99,6 @@ "roarr": "^7.21.0", "serialize-error": "^11.0.3", "socket.io-client": "^4.7.4", - "type-fest": "^4.10.2", "use-debounce": "^10.0.0", "use-image": "^1.1.1", "uuid": "^9.0.1", @@ -146,6 +145,7 @@ "ts-toolbelt": "^9.6.0", "tsafe": "^1.6.6", "typescript": "^5.3.3", + "utility-types": "^3.11.0", "vite": "^5.1.3", "vite-plugin-css-injected-by-js": "^3.4.0", "vite-plugin-dts": "^3.7.2", diff --git a/invokeai/frontend/web/pnpm-lock.yaml b/invokeai/frontend/web/pnpm-lock.yaml index 9e873102e6a..f2abdd87bf6 100644 --- a/invokeai/frontend/web/pnpm-lock.yaml +++ b/invokeai/frontend/web/pnpm-lock.yaml @@ -152,9 +152,6 @@ dependencies: socket.io-client: specifier: ^4.7.4 version: 4.7.4 - type-fest: - specifier: ^4.10.2 - version: 4.10.2 use-debounce: specifier: ^10.0.0 version: 10.0.0(react@18.2.0) @@ -271,6 +268,9 @@ devDependencies: typescript: specifier: ^5.3.3 version: 5.3.3 + utility-types: + specifier: ^3.11.0 + version: 3.11.0 vite: specifier: ^5.1.3 version: 5.1.3(@types/node@20.11.19) @@ -13827,11 +13827,6 @@ packages: resolution: {integrity: sha512-RAH822pAdBgcNMAfWnCBU3CFZcfZ/i1eZjwFU/dsLKumyuuP3niueg2UAukXYF0E2AAoc82ZSSf9J0WQBinzHA==} engines: {node: '>=12.20'} - /type-fest@4.10.2: - resolution: {integrity: sha512-anpAG63wSpdEbLwOqH8L84urkL6PiVIov3EMmgIhhThevh9aiMQov+6Btx0wldNcvm4wV+e2/Rt1QdDwKHFbHw==} - engines: {node: '>=16'} - dev: false - /type-is@1.6.18: resolution: {integrity: sha512-TkRKr9sUTxEH8MdfuCSP7VizJyzRNMjj2J2do2Jr3Kym598JVdEksuzPQCnlFPW4ky9Q+iA+ma9BGm06XQBy8g==} engines: {node: '>= 0.6'} @@ -14152,6 +14147,11 @@ packages: which-typed-array: 1.1.14 dev: true + /utility-types@3.11.0: + resolution: {integrity: sha512-6Z7Ma2aVEWisaL6TvBCy7P8rm2LQoPv6dJ7ecIaIixHcwfbJ0x7mWdbcwlIM5IGQxPZSFYeqRCqlOOeKoJYMkw==} + engines: {node: '>= 4'} + dev: true + /utils-merge@1.0.1: resolution: {integrity: sha512-pMZTvIkT1d+TFGvDOqodOclx0QWkkgi6Tdoa8gC8ffGAAqz9pzPTZWAybbsHHoED/ztMtkv/VoYTYyShUn81hA==} engines: {node: '>= 0.4.0'} diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 270662c3d21..16f1632d882 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -3,6 +3,7 @@ import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/too import { logger } from 
'app/logging/logger'; import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver'; import { errorHandler } from 'app/store/enhancers/reduxRemember/errors'; +import type { JSONObject } from 'common/types'; import { canvasPersistConfig, canvasSlice } from 'features/canvas/store/canvasSlice'; import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice'; import { @@ -32,7 +33,6 @@ import { rememberEnhancer, rememberReducer } from 'redux-remember'; import { serializeError } from 'serialize-error'; import { api } from 'services/api'; import { authToastMiddleware } from 'services/api/authToastMiddleware'; -import type { JsonObject } from 'type-fest'; import { STORAGE_PREFIX } from './constants'; import { actionSanitizer } from './middleware/devtools/actionSanitizer'; @@ -125,7 +125,7 @@ const unserialize: UnserializeFunction = (data, key) => { { persistedData: parsed, rehydratedData: transformed, - diff: diff(parsed, transformed) as JsonObject, // this is always serializable + diff: diff(parsed, transformed) as JSONObject, // this is always serializable }, `Rehydrated slice "${key}"` ); diff --git a/invokeai/frontend/web/src/common/types.ts b/invokeai/frontend/web/src/common/types.ts new file mode 100644 index 00000000000..29a411788dd --- /dev/null +++ b/invokeai/frontend/web/src/common/types.ts @@ -0,0 +1,7 @@ +export type JSONValue = string | number | boolean | null | JSONValue[] | { [key: string]: JSONValue }; + +export interface JSONObject { + [k: string]: JSONValue; +} + +export interface JSONArray extends Array {} diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/metadata.ts b/invokeai/frontend/web/src/features/nodes/util/graph/metadata.ts index 781ce57ebc9..c48f54d1917 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/metadata.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/metadata.ts @@ -1,5 +1,5 @@ +import type { JSONObject } from 'common/types'; import type { CoreMetadataInvocation, NonNullableGraph } from 'services/api/types'; -import type { JsonObject } from 'type-fest'; import { METADATA } from './constants'; @@ -30,7 +30,7 @@ export const addCoreMetadataNode = ( export const upsertMetadata = ( graph: NonNullableGraph, - metadata: Partial | JsonObject + metadata: Partial | JSONObject ): void => { const metadataNode = graph.nodes[METADATA] as CoreMetadataInvocation | undefined; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index 5096e588b06..b402f2f8af1 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -1,3 +1,4 @@ +import type { JSONObject } from 'common/types'; import { parseify } from 'common/util/serialize'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; @@ -5,14 +6,13 @@ import { isWorkflowInvocationNode } from 'features/nodes/types/workflow'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import { t } from 'i18next'; import { keyBy } from 'lodash-es'; -import type { JsonObject } from 'type-fest'; import { parseAndMigrateWorkflow } from './migrations'; type WorkflowWarning = { message: string; issues?: string[]; - data: JsonObject; + data: JSONObject; }; type ValidateWorkflowResult = { diff --git a/invokeai/frontend/web/src/services/api/types.ts 
b/invokeai/frontend/web/src/services/api/types.ts index aaa70a26848..d561173337c 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -2,7 +2,7 @@ import type { UseToastOptions } from '@invoke-ai/ui-library'; import type { EntityState } from '@reduxjs/toolkit'; import type { components, paths } from 'services/api/schema'; import type { O } from 'ts-toolbelt'; -import type { SetRequired } from 'type-fest'; +import type { Overwrite } from 'utility-types'; export type S = components['schemas']; @@ -61,28 +61,60 @@ export type IPAdapterField = S['IPAdapterField']; // Model Configs // TODO(MM2): Can we make key required in the pydantic model? -type KeyRequired = SetRequired; -export type LoRAConfig = KeyRequired; +export type LoRAModelConfig = S['LoRAConfig']; // TODO(MM2): Can we rename this from Vae -> VAE -export type VAEConfig = KeyRequired | KeyRequired; -export type ControlNetConfig = - | KeyRequired - | KeyRequired; -export type IPAdapterConfig = KeyRequired; +export type VAEModelConfig = S['VaeCheckpointConfig'] | S['VaeDiffusersConfig']; +export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig']; +export type IPAdapterModelConfig = S['IPAdapterConfig']; // TODO(MM2): Can we rename this to T2IAdapterConfig -export type T2IAdapterConfig = KeyRequired; -export type TextualInversionConfig = KeyRequired; -export type DiffusersModelConfig = KeyRequired; -export type CheckpointModelConfig = KeyRequired; +export type T2IAdapterModelConfig = S['T2IConfig']; +export type TextualInversionModelConfig = S['TextualInversionConfig']; +export type DiffusersModelConfig = S['MainDiffusersConfig']; +export type CheckpointModelConfig = S['MainCheckpointConfig']; export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; +export type RefinerMainModelConfig = Overwrite; +export type NonRefinerMainModelConfig = Overwrite; export type AnyModelConfig = - | LoRAConfig - | VAEConfig - | ControlNetConfig - | IPAdapterConfig - | T2IAdapterConfig - | TextualInversionConfig - | MainModelConfig; + | LoRAModelConfig + | VAEModelConfig + | ControlNetModelConfig + | IPAdapterModelConfig + | T2IAdapterModelConfig + | TextualInversionModelConfig + | RefinerMainModelConfig + | NonRefinerMainModelConfig; + +export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => { + return config.type === 'lora'; +}; + +export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => { + return config.type === 'vae'; +}; + +export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => { + return config.type === 'controlnet'; +}; + +export const isIPAdapterModelConfig = (config: AnyModelConfig): config is IPAdapterModelConfig => { + return config.type === 'ip_adapter'; +}; + +export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAdapterModelConfig => { + return config.type === 't2i_adapter'; +}; + +export const isTextualInversionModelConfig = (config: AnyModelConfig): config is TextualInversionModelConfig => { + return config.type === 'embedding'; +}; + +export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is NonRefinerMainModelConfig => { + return config.type === 'main' && config.base !== 'sdxl-refiner'; +}; + +export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is RefinerMainModelConfig => { + return config.type === 'main' && config.base === 
'sdxl-refiner'; +}; export type MergeModelConfig = S['Body_merge']; export type ImportModelConfig = S['Body_import_model']; From 7f81d89d082b1a4f8a8ffe5fd4987b22cc0a0b70 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 22 Feb 2024 17:21:37 +1100 Subject: [PATCH 205/340] fix(nodes): make fields on `ModelConfigBase` required The setup of `ModelConfigBase` means autogenerated types have critical fields flagged as nullable (like `key` and `base`). Need to manually flag them as required. --- invokeai/backend/model_manager/config.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index bc4848b0a50..fb0593a651a 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -138,9 +138,16 @@ class ModelConfigBase(BaseModel): source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None) last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time) + @staticmethod + def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: + schema["required"].extend( + ["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"] + ) + model_config = ConfigDict( use_enum_values=False, validate_assignment=True, + json_schema_extra=json_schema_extra, ) def update(self, attributes: Dict[str, Any]) -> None: From f8159770dff476b04368325e5af7c714ab5b820f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 22 Feb 2024 17:21:46 +1100 Subject: [PATCH 206/340] chore(ui): typegen --- .../frontend/web/src/services/api/schema.ts | 190 +++++++++--------- 1 file changed, 95 insertions(+), 95 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 2115e797689..3566fdb9e28 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1420,7 +1420,7 @@ export type components = { * @default clip_vision * @constant */ - type?: "clip_vision"; + type: "clip_vision"; /** * Format * @constant @@ -1431,17 +1431,17 @@ export type components = { * @description unique key for model * @default */ - key?: string; + key: string; /** * Original Hash * @description original fasthash of model contents */ - original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -1451,12 +1451,12 @@ export type components = { * Source * @description model original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; }; /** CLIPVisionModelField */ CLIPVisionModelField: { @@ -2538,29 +2538,29 @@ export type components = { * @default controlnet * @constant */ - type?: "controlnet"; + type: "controlnet"; /** * Format * @default checkpoint * @constant */ - format?: "checkpoint"; + format: "checkpoint"; /** * Key * @description unique key for model * @default */ - key?: string; + key: string; /** * Original Hash * @description original fasthash of model contents */ - 
original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -2570,12 +2570,12 @@ export type components = { * Source * @description model original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; /** * Config * @description path to the checkpoint model config file @@ -2604,29 +2604,29 @@ export type components = { * @default controlnet * @constant */ - type?: "controlnet"; + type: "controlnet"; /** * Format * @default diffusers * @constant */ - format?: "diffusers"; + format: "diffusers"; /** * Key * @description unique key for model * @default */ - key?: string; + key: string; /** * Original Hash * @description original fasthash of model contents */ - original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -2636,12 +2636,12 @@ export type components = { * Source * @description model original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; /** @default */ repo_variant?: components["schemas"]["ModelRepoVariant"] | null; }; @@ -4214,7 +4214,7 @@ export type components = { * @description The nodes in this graph */ nodes: { - [key: string]: components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | 
components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["FaceMaskInvocation"] | 
components["schemas"]["StringSplitInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["FaceOffInvocation"]; + [key: string]: components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | 
components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | 
components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["LatentsCollectionInvocation"]; }; /** * Edges @@ -4251,7 +4251,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["ModelLoaderOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["String2Output"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["StringCollectionOutput"]; + [key: string]: components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["CollectInvocationOutput"] | 
components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["String2Output"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IdealSizeOutput"]; }; /** * Errors @@ -4422,7 +4422,7 @@ export type components = { * @default ip_adapter * @constant */ - type?: "ip_adapter"; + type: "ip_adapter"; /** * Format * @constant @@ -4433,17 +4433,17 @@ export type components = { * @description unique key for model * @default */ - key?: string; + key: string; /** * Original Hash * @description original fasthash of model contents */ - original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -4453,12 +4453,12 @@ export type components = { * Source * @description model original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; /** Image Encoder Model Id */ image_encoder_model_id: string; }; @@ -6506,7 +6506,7 @@ export type components = { * @default lora * @constant */ - type?: "lora"; + type: "lora"; /** * Format * @enum {string} @@ -6517,17 +6517,17 @@ export type components = { * @description unique key for model * @default */ - key?: string; + key: string; /** * Original Hash * @description original fasthash of model contents */ - original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -6537,12 +6537,12 @@ export type components = { * Source * @description model original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; }; /** * LoRAMetadataField @@ -6706,29 +6706,29 @@ export type components = { * @default main * @constant */ - type?: "main"; + type: "main"; /** * Format * @default checkpoint * @constant */ - format?: "checkpoint"; + format: "checkpoint"; /** * Key * @description unique key for model * @default */ - key?: string; + key: string; /** * Original 
Hash * @description original fasthash of model contents */ - original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -6738,12 +6738,12 @@ export type components = { * Source * @description model original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; /** Vae */ vae?: string | null; /** @default normal */ @@ -6788,29 +6788,29 @@ export type components = { * @default main * @constant */ - type?: "main"; + type: "main"; /** * Format * @default diffusers * @constant */ - format?: "diffusers"; + format: "diffusers"; /** * Key * @description unique key for model * @default */ - key?: string; + key: string; /** * Original Hash * @description original fasthash of model contents */ - original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -6820,12 +6820,12 @@ export type components = { * Source * @description model original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; /** Vae */ vae?: string | null; /** @default normal */ @@ -7758,13 +7758,13 @@ export type components = { * @default sd-1 * @constant */ - base?: "sd-1"; + base: "sd-1"; /** * Type * @default onnx * @constant */ - type?: "onnx"; + type: "onnx"; /** * Format * @enum {string} @@ -7775,17 +7775,17 @@ export type components = { * @description unique key for model * @default */ - key?: string; + key: string; /** * Original Hash * @description original fasthash of model contents */ - original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -7795,12 +7795,12 @@ export type components = { * Source * @description model original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; /** Vae */ vae?: string | null; /** @default normal */ @@ -7838,13 +7838,13 @@ export type components = { * @default sd-2 * @constant */ - base?: "sd-2"; + base: "sd-2"; /** * Type * @default onnx * @constant */ - type?: "onnx"; + type: "onnx"; /** * Format * @enum {string} @@ -7855,17 +7855,17 @@ export type components = { * @description unique key for model * @default */ - key?: string; + key: string; /** * Original Hash * @description original fasthash of model contents */ - original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -7875,12 +7875,12 @@ export type components = { * Source * @description model 
original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; /** Vae */ vae?: string | null; /** @default normal */ @@ -7918,13 +7918,13 @@ export type components = { * @default sdxl * @constant */ - base?: "sdxl"; + base: "sdxl"; /** * Type * @default onnx * @constant */ - type?: "onnx"; + type: "onnx"; /** * Format * @enum {string} @@ -7935,17 +7935,17 @@ export type components = { * @description unique key for model * @default */ - key?: string; + key: string; /** * Original Hash * @description original fasthash of model contents */ - original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -7955,12 +7955,12 @@ export type components = { * Source * @description model original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; /** Vae */ vae?: string | null; /** @default normal */ @@ -10110,7 +10110,7 @@ export type components = { * @default t2i_adapter * @constant */ - type?: "t2i_adapter"; + type: "t2i_adapter"; /** * Format * @constant @@ -10121,17 +10121,17 @@ export type components = { * @description unique key for model * @default */ - key?: string; + key: string; /** * Original Hash * @description original fasthash of model contents */ - original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -10141,12 +10141,12 @@ export type components = { * Source * @description model original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; }; /** TBLR */ TBLR: { @@ -10181,7 +10181,7 @@ export type components = { * @default embedding * @constant */ - type?: "embedding"; + type: "embedding"; /** * Format * @enum {string} @@ -10192,17 +10192,17 @@ export type components = { * @description unique key for model * @default */ - key?: string; + key: string; /** * Original Hash * @description original fasthash of model contents */ - original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -10212,12 +10212,12 @@ export type components = { * Source * @description model original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; }; /** Tile */ Tile: { @@ -10530,29 +10530,29 @@ export type components = { * @default vae * @constant */ - type?: "vae"; + type: "vae"; /** * Format * @default checkpoint * @constant */ - format?: "checkpoint"; + format: "checkpoint"; /** * Key * @description unique key for 
model * @default */ - key?: string; + key: string; /** * Original Hash * @description original fasthash of model contents */ - original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -10562,12 +10562,12 @@ export type components = { * Source * @description model original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; }; /** * VaeDiffusersConfig @@ -10591,29 +10591,29 @@ export type components = { * @default vae * @constant */ - type?: "vae"; + type: "vae"; /** * Format * @default diffusers * @constant */ - format?: "diffusers"; + format: "diffusers"; /** * Key * @description unique key for model * @default */ - key?: string; + key: string; /** * Original Hash * @description original fasthash of model contents */ - original_hash?: string | null; + original_hash: string | null; /** * Current Hash * @description current fasthash of model contents */ - current_hash?: string | null; + current_hash: string | null; /** * Description * @description human readable description of the model @@ -10623,12 +10623,12 @@ export type components = { * Source * @description model original source (path, URL or repo_id) */ - source?: string | null; + source: string | null; /** * Last Modified * @description timestamp for modification time */ - last_modified?: number | null; + last_modified: number | null; }; /** VaeField */ VaeField: { From 247d97146e7fef5e84d0f12470e3bf1182aac1d2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 22 Feb 2024 17:33:20 +1100 Subject: [PATCH 207/340] feat(ui): refactor metadata handling Refactor of metadata recall handling. This is in preparation for a backwards compatibility layer for models. - Create helpers to fetch a model outside react (e.g. 
not in a hook) - Created helpers to parse model metadata - Renamed a lot of types that were confusing and/or had naming collisions --- .../web/src/app/store/nanostores/store.ts | 22 + .../parameters/ParamControlAdapterModel.tsx | 4 +- .../features/embedding/EmbeddingSelect.tsx | 6 +- .../features/lora/components/LoRASelect.tsx | 6 +- .../web/src/features/lora/store/loraSlice.ts | 9 +- .../ModelManagerPanel/LoRAModelEdit.tsx | 13 +- .../ControlNetModelFieldInputComponent.tsx | 4 +- .../IPAdapterModelFieldInputComponent.tsx | 4 +- .../inputs/LoRAModelFieldInputComponent.tsx | 4 +- .../T2IAdapterModelFieldInputComponent.tsx | 4 +- .../inputs/VAEModelFieldInputComponent.tsx | 4 +- .../VAEModel/ParamVAEModelSelect.tsx | 6 +- .../parameters/hooks/useRecallParameters.ts | 516 ++++-------------- .../parameters/util/modelFetchingHelpers.ts | 113 ++++ .../parameters/util/modelMetadataHelpers.ts | 150 +++++ .../web/src/services/api/endpoints/models.ts | 64 +-- 16 files changed, 443 insertions(+), 486 deletions(-) create mode 100644 invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts create mode 100644 invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts diff --git a/invokeai/frontend/web/src/app/store/nanostores/store.ts b/invokeai/frontend/web/src/app/store/nanostores/store.ts index aee0f0e6ef7..f4cd001c96f 100644 --- a/invokeai/frontend/web/src/app/store/nanostores/store.ts +++ b/invokeai/frontend/web/src/app/store/nanostores/store.ts @@ -8,4 +8,26 @@ declare global { } } +/** + * Raised when the redux store is unable to be retrieved. + */ +export class ReduxStoreNotInitialized extends Error { + /** + * Create ReduxStoreNotInitialized + * @param {String} message + */ + constructor(message = 'Redux store not initialized') { + super(message); + this.name = this.constructor.name; + } +} + export const $store = atom> | undefined>(); + +export const getStore = () => { + const store = $store.get(); + if (!store) { + throw new ReduxStoreNotInitialized(); + } + return store; +}; diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx index 696bf47b2a6..75372c350de 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx +++ b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx @@ -8,7 +8,7 @@ import { useControlAdapterType } from 'features/controlAdapters/hooks/useControl import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import { pick } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; -import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'services/api/types'; +import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; type ParamControlAdapterModelProps = { id: string; @@ -24,7 +24,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType); const _onChange = useCallback( - (model: ControlNetConfig | IPAdapterConfig | T2IAdapterConfig | null) => { + (model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => { if (!model) { return; } diff --git a/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx 
b/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx index fd05edc4667..a5ad358fa09 100644 --- a/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx +++ b/invokeai/frontend/web/src/features/embedding/EmbeddingSelect.tsx @@ -7,7 +7,7 @@ import { t } from 'i18next'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; -import type { TextualInversionConfig } from 'services/api/types'; +import type { TextualInversionModelConfig } from 'services/api/types'; const noOptionsMessage = () => t('embedding.noMatchingEmbedding'); @@ -17,7 +17,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const getIsDisabled = useCallback( - (embedding: TextualInversionConfig): boolean => { + (embedding: TextualInversionModelConfig): boolean => { const isCompatible = currentBaseModel === embedding.base; const hasMainModel = Boolean(currentBaseModel); return !hasMainModel || !isCompatible; @@ -27,7 +27,7 @@ export const EmbeddingSelect = memo(({ onSelect, onClose }: EmbeddingSelectProps const { data, isLoading } = useGetTextualInversionModelsQuery(); const _onChange = useCallback( - (embedding: TextualInversionConfig | null) => { + (embedding: TextualInversionModelConfig | null) => { if (!embedding) { return; } diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index b58751ca5e2..e7d40c5eafb 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -8,7 +8,7 @@ import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; -import type { LoRAConfig } from 'services/api/types'; +import type { LoRAModelConfig } from 'services/api/types'; const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras); @@ -19,7 +19,7 @@ const LoRASelect = () => { const addedLoRAs = useAppSelector(selectAddedLoRAs); const currentBaseModel = useAppSelector((s) => s.generation.model?.base); - const getIsDisabled = (lora: LoRAConfig): boolean => { + const getIsDisabled = (lora: LoRAModelConfig): boolean => { const isCompatible = currentBaseModel === lora.base; const isAdded = Boolean(addedLoRAs[lora.key]); const hasMainModel = Boolean(currentBaseModel); @@ -27,7 +27,7 @@ const LoRASelect = () => { }; const _onChange = useCallback( - (lora: LoRAConfig | null) => { + (lora: LoRAModelConfig | null) => { if (!lora) { return; } diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index dd455e12c39..377406b3e57 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -2,7 +2,7 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; -import type { LoRAConfig } from 'services/api/types'; +import type { LoRAModelConfig } from 'services/api/types'; 
export type LoRA = ParameterLoRAModel & { weight: number; @@ -28,13 +28,12 @@ export const loraSlice = createSlice({ name: 'lora', initialState: initialLoraState, reducers: { - loraAdded: (state, action: PayloadAction) => { + loraAdded: (state, action: PayloadAction) => { const { key, base } = action.payload; state.loras[key] = { key, base, ...defaultLoRAConfig }; }, - loraRecalled: (state, action: PayloadAction) => { - const { key, base, weight } = action.payload; - state.loras[key] = { key, base, weight, isEnabled: true }; + loraRecalled: (state, action: PayloadAction) => { + state.loras[action.payload.key] = action.payload; }, loraRemoved: (state, action: PayloadAction) => { const key = action.payload; diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx index 75151cd0012..1a8f235aaff 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx @@ -8,12 +8,11 @@ import { memo, useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { LoRAConfig } from 'services/api/endpoints/models'; import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models'; -import type { LoRAConfig } from 'services/api/types'; +import type { LoRAModelConfig } from 'services/api/types'; type LoRAModelEditProps = { - model: LoRAConfig; + model: LoRAModelConfig; }; const LoRAModelEdit = (props: LoRAModelEditProps) => { @@ -30,7 +29,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { control, formState: { errors }, reset, - } = useForm({ + } = useForm({ defaultValues: { model_name: model.model_name ? 
model.model_name : '', base_model: model.base_model, @@ -42,7 +41,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { mode: 'onChange', }); - const onSubmit = useCallback>( + const onSubmit = useCallback>( (values) => { const responseBody = { base_model: model.base_model, @@ -53,7 +52,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { updateLoRAModel(responseBody) .unwrap() .then((payload) => { - reset(payload as LoRAConfig, { keepDefaultValues: true }); + reset(payload as LoRAModelConfig, { keepDefaultValues: true }); dispatch( addToast( makeToast({ @@ -106,7 +105,7 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { {t('modelManager.description')} - control={control} name="base_model" /> + control={control} name="base_model" /> {t('modelManager.modelLocation')} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx index 1951ec60d36..29a1f93dd5e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTempla import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; -import type { ControlNetConfig } from 'services/api/types'; +import type { ControlNetModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => { const { data, isLoading } = useGetControlNetModelsQuery(); const _onChange = useCallback( - (value: ControlNetConfig | null) => { + (value: ControlNetModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx index 137f751fca1..d4f0ae3de11 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models'; -import type { IPAdapterConfig } from 'services/api/types'; +import type { IPAdapterModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const IPAdapterModelFieldInputComponent = ( const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(); const _onChange = useCallback( - (value: IPAdapterConfig | null) => { + (value: IPAdapterModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx 
b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx index 5f6318de9e5..9fd223e6940 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'f import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; -import type { LoRAConfig } from 'services/api/types'; +import type { LoRAModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -17,7 +17,7 @@ const LoRAModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetLoRAModelsQuery(); const _onChange = useCallback( - (value: LoRAConfig | null) => { + (value: LoRAModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx index 9115f22c145..a38356a0b87 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx @@ -6,7 +6,7 @@ import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTempla import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models'; -import type { T2IAdapterConfig } from 'services/api/types'; +import type { T2IAdapterModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -19,7 +19,7 @@ const T2IAdapterModelFieldInputComponent = ( const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(); const _onChange = useCallback( - (value: T2IAdapterConfig | null) => { + (value: T2IAdapterModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx index 87272f48b9b..272f7f5b354 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx @@ -7,7 +7,7 @@ import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'fea import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; -import type { VAEConfig } from 'services/api/types'; +import type { VAEModelConfig } from 'services/api/types'; import type { FieldComponentProps } from './types'; @@ -18,7 +18,7 @@ const VAEModelFieldInputComponent = (props: Props) => { const dispatch = useAppDispatch(); const { data, isLoading } = useGetVaeModelsQuery(); const _onChange = useCallback( - (value: 
VAEConfig | null) => { + (value: VAEModelConfig | null) => { if (!value) { return; } diff --git a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx index 4b9f2764bf0..4a630fa9cef 100644 --- a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx @@ -8,7 +8,7 @@ import { pick } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; -import type { VAEConfig } from 'services/api/types'; +import type { VAEModelConfig } from 'services/api/types'; const selector = createMemoizedSelector(selectGenerationSlice, (generation) => { const { model, vae } = generation; @@ -21,7 +21,7 @@ const ParamVAEModelSelect = () => { const { model, vae } = useAppSelector(selector); const { data, isLoading } = useGetVaeModelsQuery(); const getIsDisabled = useCallback( - (vae: VAEConfig): boolean => { + (vae: VAEModelConfig): boolean => { const isCompatible = model?.base === vae.base; const hasMainModel = Boolean(model?.base); return !hasMainModel || !isCompatible; @@ -29,7 +29,7 @@ const ParamVAEModelSelect = () => { [model?.base] ); const _onChange = useCallback( - (vae: VAEConfig | null) => { + (vae: VAEModelConfig | null) => { dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null)); }, [dispatch] diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 0d464cd9b94..0929fc1dc33 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -1,17 +1,9 @@ import { useAppToaster } from 'app/components/Toaster'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; import { controlAdapterRecalled, controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice'; -import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; -import { - initialControlNet, - initialIPAdapter, - initialT2IAdapter, -} from 'features/controlAdapters/util/buildControlAdapter'; import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice'; import { loraRecalled, lorasCleared } from 'features/lora/store/loraSlice'; -import type { ModelIdentifier } from 'features/nodes/types/common'; import { isModelIdentifier } from 'features/nodes/types/common'; import type { ControlNetMetadataItem, @@ -56,6 +48,14 @@ import { isParameterStrength, isParameterWidth, } from 'features/parameters/types/parameterSchemas'; +import { + prepareControlNetMetadataItem, + prepareIPAdapterMetadataItem, + prepareLoRAMetadataItem, + prepareMainModelMetadataItem, + prepareT2IAdapterMetadataItem, + prepareVAEMetadataItem, +} from 'features/parameters/util/modelMetadataHelpers'; import { refinerModelChanged, setNegativeStylePromptSDXL, @@ -70,23 +70,7 @@ import { import { isNil } from 'lodash-es'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { 
ALL_BASE_MODELS } from 'services/api/constants'; -import { - controlNetModelsAdapterSelectors, - ipAdapterModelsAdapterSelectors, - loraModelsAdapterSelectors, - mainModelsAdapterSelectors, - t2iAdapterModelsAdapterSelectors, - useGetControlNetModelsQuery, - useGetIPAdapterModelsQuery, - useGetLoRAModelsQuery, - useGetMainModelsQuery, - useGetT2IAdapterModelsQuery, - useGetVaeModelsQuery, - vaeModelsAdapterSelectors, -} from 'services/api/endpoints/models'; import type { ImageDTO } from 'services/api/types'; -import { v4 as uuidv4 } from 'uuid'; const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model); @@ -140,9 +124,6 @@ export const useRecallParameters = () => { [t, toaster] ); - /** - * Recall both prompts with toast - */ const recallBothPrompts = useCallback( (positivePrompt: unknown, negativePrompt: unknown, positiveStylePrompt: unknown, negativeStylePrompt: unknown) => { if ( @@ -175,9 +156,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall positive prompt with toast - */ const recallPositivePrompt = useCallback( (positivePrompt: unknown) => { if (!isParameterPositivePrompt(positivePrompt)) { @@ -190,9 +168,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall negative prompt with toast - */ const recallNegativePrompt = useCallback( (negativePrompt: unknown) => { if (!isParameterNegativePrompt(negativePrompt)) { @@ -205,9 +180,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall SDXL Positive Style Prompt with toast - */ const recallSDXLPositiveStylePrompt = useCallback( (positiveStylePrompt: unknown) => { if (!isParameterPositiveStylePromptSDXL(positiveStylePrompt)) { @@ -220,9 +192,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall SDXL Negative Style Prompt with toast - */ const recallSDXLNegativeStylePrompt = useCallback( (negativeStylePrompt: unknown) => { if (!isParameterNegativeStylePromptSDXL(negativeStylePrompt)) { @@ -235,9 +204,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall seed with toast - */ const recallSeed = useCallback( (seed: unknown) => { if (!isParameterSeed(seed)) { @@ -250,9 +216,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall CFG scale with toast - */ const recallCfgScale = useCallback( (cfgScale: unknown) => { if (!isParameterCFGScale(cfgScale)) { @@ -265,9 +228,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall CFG rescale multiplier with toast - */ const recallCfgRescaleMultiplier = useCallback( (cfgRescaleMultiplier: unknown) => { if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) { @@ -280,9 +240,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall scheduler with toast - */ const recallScheduler = useCallback( (scheduler: unknown) => { if (!isParameterScheduler(scheduler)) { @@ -295,9 +252,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall steps with toast - */ const recallSteps = useCallback( (steps: unknown) => { if (!isParameterSteps(steps)) { @@ -310,9 +264,6 @@ export const 
useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall width with toast - */ const recallWidth = useCallback( (width: unknown) => { if (!isParameterWidth(width)) { @@ -325,9 +276,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall height with toast - */ const recallHeight = useCallback( (height: unknown) => { if (!isParameterHeight(height)) { @@ -340,9 +288,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall width and height with toast - */ const recallWidthAndHeight = useCallback( (width: unknown, height: unknown) => { if (!isParameterWidth(width)) { @@ -360,9 +305,6 @@ export const useRecallParameters = () => { [dispatch, allParameterSetToast, allParameterNotSetToast] ); - /** - * Recall strength with toast - */ const recallStrength = useCallback( (strength: unknown) => { if (!isParameterStrength(strength)) { @@ -375,9 +317,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall high resolution enabled with toast - */ const recallHrfEnabled = useCallback( (hrfEnabled: unknown) => { if (!isParameterHRFEnabled(hrfEnabled)) { @@ -390,9 +329,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall high resolution strength with toast - */ const recallHrfStrength = useCallback( (hrfStrength: unknown) => { if (!isParameterStrength(hrfStrength)) { @@ -405,9 +341,6 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall high resolution method with toast - */ const recallHrfMethod = useCallback( (hrfMethod: unknown) => { if (!isParameterHRFMethod(hrfMethod)) { @@ -420,358 +353,95 @@ export const useRecallParameters = () => { [dispatch, parameterSetToast, parameterNotSetToast] ); - const { data: mainModels } = useGetMainModelsQuery(ALL_BASE_MODELS); - - const prepareMainModelMetadataItem = useCallback( - (model: ModelIdentifier) => { - const matchingModel = mainModels ? mainModelsAdapterSelectors.selectById(mainModels, model.key) : undefined; - - if (!matchingModel) { - return { model: null, error: 'Model is not installed' }; - } - - return { model: matchingModel, error: null }; - }, - [mainModels] - ); - - /** - * Recall model with toast - */ const recallModel = useCallback( - (model: unknown) => { - if (!isModelIdentifier(model)) { - parameterNotSetToast(); - return; - } - - const result = prepareMainModelMetadataItem(model); - - if (!result.model) { - parameterNotSetToast(result.error); + async (modelMetadataItem: unknown) => { + try { + const model = await prepareMainModelMetadataItem(modelMetadataItem); + dispatch(modelSelected(model)); + parameterSetToast(); + } catch (e) { + parameterNotSetToast((e as unknown as Error).message); return; } - - dispatch(modelSelected(result.model)); - parameterSetToast(); }, - [prepareMainModelMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] - ); - - const { data: vaeModels } = useGetVaeModelsQuery(); - - const prepareVAEMetadataItem = useCallback( - (vae: ModelIdentifier, newModel?: ParameterModel) => { - const matchingModel = vaeModels ? vaeModelsAdapterSelectors.selectById(vaeModels, vae.key) : undefined; - if (!matchingModel) { - return { vae: null, error: 'VAE model is not installed' }; - } - const isCompatibleBaseModel = matchingModel?.base === (newModel ?? 
model)?.base; - - if (!isCompatibleBaseModel) { - return { - vae: null, - error: 'VAE incompatible with currently-selected model', - }; - } - - return { vae: matchingModel, error: null }; - }, - [model, vaeModels] + [dispatch, parameterSetToast, parameterNotSetToast] ); - /** - * Recall vae model - */ const recallVaeModel = useCallback( - (vae: unknown) => { - if (!isModelIdentifier(vae) && !isNil(vae)) { - parameterNotSetToast(); - return; - } - - if (isNil(vae)) { + async (vaeMetadataItem: unknown) => { + if (isNil(vaeMetadataItem)) { dispatch(vaeSelected(null)); parameterSetToast(); return; } - - const result = prepareVAEMetadataItem(vae); - - if (!result.vae) { - parameterNotSetToast(result.error); + try { + const vae = await prepareVAEMetadataItem(vaeMetadataItem); + dispatch(vaeSelected(vae)); + parameterSetToast(); + } catch (e) { + parameterNotSetToast((e as unknown as Error).message); return; } - - dispatch(vaeSelected(result.vae)); - parameterSetToast(); - }, - [prepareVAEMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] - ); - - /** - * Recall LoRA with toast - */ - - const { data: loraModels } = useGetLoRAModelsQuery(undefined); - - const prepareLoRAMetadataItem = useCallback( - (loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => { - if (!isModelIdentifier(loraMetadataItem.lora)) { - return { lora: null, error: 'Invalid LoRA model' }; - } - - const { lora } = loraMetadataItem; - - const matchingLoRA = loraModels ? loraModelsAdapterSelectors.selectById(loraModels, lora.key) : undefined; - - if (!matchingLoRA) { - return { lora: null, error: 'LoRA model is not installed' }; - } - - const isCompatibleBaseModel = matchingLoRA?.base === (newModel ?? model)?.base; - - if (!isCompatibleBaseModel) { - return { - lora: null, - error: 'LoRA incompatible with currently-selected model', - }; - } - - return { lora: matchingLoRA, error: null }; }, - [loraModels, model] + [dispatch, parameterSetToast, parameterNotSetToast] ); const recallLoRA = useCallback( - (loraMetadataItem: LoRAMetadataItem) => { - const result = prepareLoRAMetadataItem(loraMetadataItem); - - if (!result.lora) { - parameterNotSetToast(result.error); + async (loraMetadataItem: LoRAMetadataItem) => { + try { + const lora = await prepareLoRAMetadataItem(loraMetadataItem, model?.base); + dispatch(loraRecalled(lora)); + parameterSetToast(); + } catch (e) { + parameterNotSetToast((e as unknown as Error).message); return; } - - dispatch(loraRecalled({ ...result.lora, weight: loraMetadataItem.weight })); - - parameterSetToast(); }, - [prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] - ); - - /** - * Recall ControlNet with toast - */ - - const { data: controlNetModels } = useGetControlNetModelsQuery(undefined); - - const prepareControlNetMetadataItem = useCallback( - (controlnetMetadataItem: ControlNetMetadataItem, newModel?: ParameterModel) => { - if (!isModelIdentifier(controlnetMetadataItem.control_model)) { - return { controlnet: null, error: 'Invalid ControlNet model' }; - } - - const { image, control_model, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } = - controlnetMetadataItem; - - const matchingControlNetModel = controlNetModels - ? controlNetModelsAdapterSelectors.selectById(controlNetModels, control_model.key) - : undefined; - - if (!matchingControlNetModel) { - return { controlnet: null, error: 'ControlNet model is not installed' }; - } - - const isCompatibleBaseModel = matchingControlNetModel?.base === (newModel ?? 
model)?.base; - - if (!isCompatibleBaseModel) { - return { - controlnet: null, - error: 'ControlNet incompatible with currently-selected model', - }; - } - - // We don't save the original image that was processed into a control image, only the processed image - const processorType = 'none'; - const processorNode = CONTROLNET_PROCESSORS.none.default; - - const controlnet: ControlNetConfig = { - type: 'controlnet', - isEnabled: true, - model: matchingControlNetModel, - weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight, - beginStepPct: begin_step_percent || initialControlNet.beginStepPct, - endStepPct: end_step_percent || initialControlNet.endStepPct, - controlMode: control_mode || initialControlNet.controlMode, - resizeMode: resize_mode || initialControlNet.resizeMode, - controlImage: image?.image_name || null, - processedControlImage: image?.image_name || null, - processorType, - processorNode, - shouldAutoConfig: true, - id: uuidv4(), - }; - - return { controlnet, error: null }; - }, - [controlNetModels, model] + [model?.base, dispatch, parameterSetToast, parameterNotSetToast] ); const recallControlNet = useCallback( - (controlnetMetadataItem: ControlNetMetadataItem) => { - const result = prepareControlNetMetadataItem(controlnetMetadataItem); - - if (!result.controlnet) { - parameterNotSetToast(result.error); + async (controlnetMetadataItem: ControlNetMetadataItem) => { + try { + const controlNetConfig = await prepareControlNetMetadataItem(controlnetMetadataItem, model?.base); + dispatch(controlAdapterRecalled(controlNetConfig)); + parameterSetToast(); + } catch (e) { + parameterNotSetToast((e as unknown as Error).message); return; } - - dispatch(controlAdapterRecalled(result.controlnet)); - - parameterSetToast(); }, - [prepareControlNetMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] - ); - - /** - * Recall T2I Adapter with toast - */ - - const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(undefined); - - const prepareT2IAdapterMetadataItem = useCallback( - (t2iAdapterMetadataItem: T2IAdapterMetadataItem, newModel?: ParameterModel) => { - if (!isModelIdentifier(t2iAdapterMetadataItem.t2i_adapter_model)) { - return { controlnet: null, error: 'Invalid ControlNet model' }; - } - - const { image, t2i_adapter_model, weight, begin_step_percent, end_step_percent, resize_mode } = - t2iAdapterMetadataItem; - - const matchingT2IAdapterModel = t2iAdapterModels - ? t2iAdapterModelsAdapterSelectors.selectById(t2iAdapterModels, t2i_adapter_model.key) - : undefined; - - if (!matchingT2IAdapterModel) { - return { controlnet: null, error: 'ControlNet model is not installed' }; - } - - const isCompatibleBaseModel = matchingT2IAdapterModel?.base === (newModel ?? model)?.base; - - if (!isCompatibleBaseModel) { - return { - t2iAdapter: null, - error: 'ControlNet incompatible with currently-selected model', - }; - } - - // We don't save the original image that was processed into a control image, only the processed image - const processorType = 'none'; - const processorNode = CONTROLNET_PROCESSORS.none.default; - - const t2iAdapter: T2IAdapterConfig = { - type: 't2i_adapter', - isEnabled: true, - model: matchingT2IAdapterModel, - weight: typeof weight === 'number' ? 
weight : initialT2IAdapter.weight, - beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct, - endStepPct: end_step_percent || initialT2IAdapter.endStepPct, - resizeMode: resize_mode || initialT2IAdapter.resizeMode, - controlImage: image?.image_name || null, - processedControlImage: image?.image_name || null, - processorType, - processorNode, - shouldAutoConfig: true, - id: uuidv4(), - }; - - return { t2iAdapter, error: null }; - }, - [model, t2iAdapterModels] + [model?.base, dispatch, parameterSetToast, parameterNotSetToast] ); const recallT2IAdapter = useCallback( - (t2iAdapterMetadataItem: T2IAdapterMetadataItem) => { - const result = prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem); - - if (!result.t2iAdapter) { - parameterNotSetToast(result.error); + async (t2iAdapterMetadataItem: T2IAdapterMetadataItem) => { + try { + const t2iAdapterConfig = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, model?.base); + dispatch(controlAdapterRecalled(t2iAdapterConfig)); + parameterSetToast(); + } catch (e) { + parameterNotSetToast((e as unknown as Error).message); return; } - - dispatch(controlAdapterRecalled(result.t2iAdapter)); - - parameterSetToast(); }, - [prepareT2IAdapterMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] - ); - - /** - * Recall IP Adapter with toast - */ - - const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(undefined); - - const prepareIPAdapterMetadataItem = useCallback( - (ipAdapterMetadataItem: IPAdapterMetadataItem, newModel?: ParameterModel) => { - if (!isModelIdentifier(ipAdapterMetadataItem?.ip_adapter_model)) { - return { ipAdapter: null, error: 'Invalid IP Adapter model' }; - } - - const { image, ip_adapter_model, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem; - - const matchingIPAdapterModel = ipAdapterModels - ? ipAdapterModelsAdapterSelectors.selectById(ipAdapterModels, ip_adapter_model.key) - : undefined; - - if (!matchingIPAdapterModel) { - return { ipAdapter: null, error: 'IP Adapter model is not installed' }; - } - - const isCompatibleBaseModel = matchingIPAdapterModel?.base === (newModel ?? model)?.base; - - if (!isCompatibleBaseModel) { - return { - ipAdapter: null, - error: 'IP Adapter incompatible with currently-selected model', - }; - } - - const ipAdapter: IPAdapterConfig = { - id: uuidv4(), - type: 'ip_adapter', - isEnabled: true, - controlImage: image?.image_name ?? null, - model: matchingIPAdapterModel, - weight: weight ?? initialIPAdapter.weight, - beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, - endStepPct: end_step_percent ?? 
initialIPAdapter.endStepPct, - }; - - return { ipAdapter, error: null }; - }, - [ipAdapterModels, model] + [model?.base, dispatch, parameterSetToast, parameterNotSetToast] ); const recallIPAdapter = useCallback( - (ipAdapterMetadataItem: IPAdapterMetadataItem) => { - const result = prepareIPAdapterMetadataItem(ipAdapterMetadataItem); - - if (!result.ipAdapter) { - parameterNotSetToast(result.error); + async (ipAdapterMetadataItem: IPAdapterMetadataItem) => { + try { + const ipAdapterConfig = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, model?.base); + dispatch(controlAdapterRecalled(ipAdapterConfig)); + parameterSetToast(); + } catch (e) { + parameterNotSetToast((e as unknown as Error).message); return; } - - dispatch(controlAdapterRecalled(result.ipAdapter)); - - parameterSetToast(); }, - [prepareIPAdapterMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] + [model?.base, dispatch, parameterSetToast, parameterNotSetToast] ); - /* - * Sets image as initial image with toast - */ const sendToImageToImage = useCallback( (image: ImageDTO) => { dispatch(initialImageSelected(image)); @@ -780,7 +450,7 @@ export const useRecallParameters = () => { ); const recallAllParameters = useCallback( - (metadata: CoreMetadata | undefined) => { + async (metadata: CoreMetadata | undefined) => { if (!metadata) { allParameterNotSetToast(); return; @@ -820,10 +490,12 @@ export const useRecallParameters = () => { let newModel: ParameterModel | undefined = undefined; if (isModelIdentifier(model)) { - const result = prepareMainModelMetadataItem(model); - if (result.model) { - dispatch(modelSelected(result.model)); - newModel = result.model; + try { + const _model = await prepareMainModelMetadataItem(model); + dispatch(modelSelected(_model)); + newModel = _model; + } catch { + return; } } @@ -850,9 +522,11 @@ export const useRecallParameters = () => { if (isNil(vae)) { dispatch(vaeSelected(null)); } else { - const result = prepareVAEMetadataItem(vae, newModel); - if (result.vae) { - dispatch(vaeSelected(result.vae)); + try { + const _vae = await prepareVAEMetadataItem(vae, newModel?.base); + dispatch(vaeSelected(_vae)); + } catch { + return; } } } @@ -926,48 +600,46 @@ export const useRecallParameters = () => { } dispatch(lorasCleared()); - loras?.forEach((lora) => { - const result = prepareLoRAMetadataItem(lora, newModel); - if (result.lora) { - dispatch(loraRecalled({ ...result.lora, weight: lora.weight })); + loras?.forEach(async (loraMetadataItem) => { + try { + const lora = await prepareLoRAMetadataItem(loraMetadataItem, newModel?.base); + dispatch(loraRecalled(lora)); + } catch { + return; } }); dispatch(controlAdaptersReset()); - controlnets?.forEach((controlnet) => { - const result = prepareControlNetMetadataItem(controlnet, newModel); - if (result.controlnet) { - dispatch(controlAdapterRecalled(result.controlnet)); + controlnets?.forEach(async (controlNetMetadataItem) => { + try { + const controlNet = await prepareControlNetMetadataItem(controlNetMetadataItem, newModel?.base); + dispatch(controlAdapterRecalled(controlNet)); + } catch { + return; } }); - ipAdapters?.forEach((ipAdapter) => { - const result = prepareIPAdapterMetadataItem(ipAdapter, newModel); - if (result.ipAdapter) { - dispatch(controlAdapterRecalled(result.ipAdapter)); + ipAdapters?.forEach(async (ipAdapterMetadataItem) => { + try { + const ipAdapter = await prepareIPAdapterMetadataItem(ipAdapterMetadataItem, newModel?.base); + dispatch(controlAdapterRecalled(ipAdapter)); + } catch { + return; } }); - 
t2iAdapters?.forEach((t2iAdapter) => { - const result = prepareT2IAdapterMetadataItem(t2iAdapter, newModel); - if (result.t2iAdapter) { - dispatch(controlAdapterRecalled(result.t2iAdapter)); + t2iAdapters?.forEach(async (t2iAdapterMetadataItem) => { + try { + const t2iAdapter = await prepareT2IAdapterMetadataItem(t2iAdapterMetadataItem, newModel?.base); + dispatch(controlAdapterRecalled(t2iAdapter)); + } catch { + return; } }); allParameterSetToast(); }, - [ - dispatch, - allParameterSetToast, - allParameterNotSetToast, - prepareMainModelMetadataItem, - prepareVAEMetadataItem, - prepareLoRAMetadataItem, - prepareControlNetMetadataItem, - prepareIPAdapterMetadataItem, - prepareT2IAdapterMetadataItem, - ] + [dispatch, allParameterSetToast, allParameterNotSetToast] ); return { diff --git a/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts b/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts new file mode 100644 index 00000000000..c7d25fed8b5 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/util/modelFetchingHelpers.ts @@ -0,0 +1,113 @@ +import { getStore } from 'app/store/nanostores/store'; +import { isModelIdentifier } from 'features/nodes/types/common'; +import { modelsApi } from 'services/api/endpoints/models'; +import type { AnyModelConfig, BaseModelType } from 'services/api/types'; +import { + isControlNetModelConfig, + isIPAdapterModelConfig, + isLoRAModelConfig, + isNonRefinerMainModelConfig, + isRefinerMainModelModelConfig, + isT2IAdapterModelConfig, + isTextualInversionModelConfig, + isVAEModelConfig, +} from 'services/api/types'; + +/** + * Raised when a model config is unable to be fetched. + */ +export class ModelConfigNotFoundError extends Error { + /** + * Create ModelConfigNotFoundError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * Raised when a fetched model config is of an unexpected type. 
+ */ +export class InvalidModelConfigError extends Error { + /** + * Create InvalidModelConfigError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +export const fetchModelConfig = async (key: string): Promise => { + const { dispatch } = getStore(); + try { + const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key)); + req.unsubscribe(); + return await req.unwrap(); + } catch { + throw new ModelConfigNotFoundError(`Unable to retrieve model config for key ${key}`); + } +}; + +export const fetchModelConfigWithTypeGuard = async ( + key: string, + typeGuard: (config: AnyModelConfig) => config is T +) => { + const modelConfig = await fetchModelConfig(key); + if (!typeGuard(modelConfig)) { + throw new InvalidModelConfigError(`Invalid model type for key ${key}: ${modelConfig.type}`); + } + return modelConfig; +}; + +export const fetchMainModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig); +}; + +export const fetchRefinerModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig); +}; + +export const fetchVAEModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isVAEModelConfig); +}; + +export const fetchLoRAModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isLoRAModelConfig); +}; + +export const fetchControlNetModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isControlNetModelConfig); +}; + +export const fetchIPAdapterModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig); +}; + +export const fetchT2IAdapterModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig); +}; + +export const fetchTextualInversionModel = async (key: string) => { + return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig); +}; + +export const isBaseCompatible = (sourceBase: BaseModelType, targetBase: BaseModelType) => { + return sourceBase === targetBase; +}; + +export const raiseIfBaseIncompatible = (sourceBase: BaseModelType, targetBase?: BaseModelType, message?: string) => { + if (targetBase && !isBaseCompatible(sourceBase, targetBase)) { + throw new InvalidModelConfigError(message || `Incompatible base models: ${sourceBase} and ${targetBase}`); + } +}; + +export const getModelKey = (modelIdentifier: unknown, message?: string): string => { + if (!isModelIdentifier(modelIdentifier)) { + throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`); + } + return modelIdentifier.key; +}; diff --git a/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts b/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts new file mode 100644 index 00000000000..722073366fc --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/util/modelMetadataHelpers.ts @@ -0,0 +1,150 @@ +import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; +import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types'; +import { + initialControlNet, + initialIPAdapter, + initialT2IAdapter, +} from 'features/controlAdapters/util/buildControlAdapter'; +import type { LoRA } from 'features/lora/store/loraSlice'; +import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; +import { zModelIdentifierWithBase } from 
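// Illustrative usage of the fetching helpers defined above (a hedged sketch, not
// part of the patch; `someKey` and `recallSomeLoRA` are hypothetical names):
//
//   import {
//     fetchLoRAModel,
//     InvalidModelConfigError,
//     ModelConfigNotFoundError,
//   } from 'features/parameters/util/modelFetchingHelpers';
//
//   const recallSomeLoRA = async (someKey: string) => {
//     try {
//       // Resolves the config through the getModelConfig RTK Query endpoint and
//       // narrows it with the isLoRAModelConfig type guard.
//       const lora = await fetchLoRAModel(someKey);
//       console.log(lora.base, lora.name);
//     } catch (e) {
//       if (e instanceof ModelConfigNotFoundError || e instanceof InvalidModelConfigError) {
//         // Surface (e as Error).message in a toast, as the recall hooks do.
//       }
//     }
//   };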
'features/nodes/types/common'; +import type { + ControlNetMetadataItem, + IPAdapterMetadataItem, + LoRAMetadataItem, + T2IAdapterMetadataItem, +} from 'features/nodes/types/metadata'; +import { + fetchControlNetModel, + fetchIPAdapterModel, + fetchLoRAModel, + fetchMainModel, + fetchRefinerModel, + fetchT2IAdapterModel, + fetchVAEModel, + getModelKey, + raiseIfBaseIncompatible, +} from 'features/parameters/util/modelFetchingHelpers'; +import type { BaseModelType } from 'services/api/types'; +import { v4 as uuidv4 } from 'uuid'; + +export const prepareMainModelMetadataItem = async (model: unknown): Promise => { + const key = getModelKey(model); + const mainModel = await fetchMainModel(key); + return zModelIdentifierWithBase.parse(mainModel); +}; + +export const prepareRefinerMetadataItem = async (model: unknown): Promise => { + const key = getModelKey(model); + const refinerModel = await fetchRefinerModel(key); + return zModelIdentifierWithBase.parse(refinerModel); +}; + +export const prepareVAEMetadataItem = async (vae: unknown, base?: BaseModelType): Promise => { + const key = getModelKey(vae); + const vaeModel = await fetchVAEModel(key); + raiseIfBaseIncompatible(vaeModel.base, base, 'VAE incompatible with currently-selected model'); + return zModelIdentifierWithBase.parse(vaeModel); +}; + +export const prepareLoRAMetadataItem = async ( + loraMetadataItem: LoRAMetadataItem, + base?: BaseModelType +): Promise => { + const key = getModelKey(loraMetadataItem.lora); + const loraModel = await fetchLoRAModel(key); + raiseIfBaseIncompatible(loraModel.base, base, 'LoRA incompatible with currently-selected model'); + return { key: loraModel.key, base: loraModel.base, weight: loraMetadataItem.weight, isEnabled: true }; +}; + +export const prepareControlNetMetadataItem = async ( + controlnetMetadataItem: ControlNetMetadataItem, + base?: BaseModelType +): Promise => { + const key = getModelKey(controlnetMetadataItem.control_model); + const controlNetModel = await fetchControlNetModel(key); + raiseIfBaseIncompatible(controlNetModel.base, base, 'ControlNet incompatible with currently-selected model'); + + const { image, control_weight, begin_step_percent, end_step_percent, control_mode, resize_mode } = + controlnetMetadataItem; + + // We don't save the original image that was processed into a control image, only the processed image + const processorType = 'none'; + const processorNode = CONTROLNET_PROCESSORS.none.default; + + const controlnet: ControlNetConfig = { + type: 'controlnet', + isEnabled: true, + model: zModelIdentifierWithBase.parse(controlNetModel), + weight: typeof control_weight === 'number' ? 
control_weight : initialControlNet.weight, + beginStepPct: begin_step_percent || initialControlNet.beginStepPct, + endStepPct: end_step_percent || initialControlNet.endStepPct, + controlMode: control_mode || initialControlNet.controlMode, + resizeMode: resize_mode || initialControlNet.resizeMode, + controlImage: image?.image_name || null, + processedControlImage: image?.image_name || null, + processorType, + processorNode, + shouldAutoConfig: true, + id: uuidv4(), + }; + + return controlnet; +}; + +export const prepareT2IAdapterMetadataItem = async ( + t2iAdapterMetadataItem: T2IAdapterMetadataItem, + base?: BaseModelType +): Promise => { + const key = getModelKey(t2iAdapterMetadataItem.t2i_adapter_model); + const t2iAdapterModel = await fetchT2IAdapterModel(key); + raiseIfBaseIncompatible(t2iAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model'); + + const { image, weight, begin_step_percent, end_step_percent, resize_mode } = t2iAdapterMetadataItem; + + // We don't save the original image that was processed into a control image, only the processed image + const processorType = 'none'; + const processorNode = CONTROLNET_PROCESSORS.none.default; + + const t2iAdapter: T2IAdapterConfig = { + type: 't2i_adapter', + isEnabled: true, + model: zModelIdentifierWithBase.parse(t2iAdapterModel), + weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight, + beginStepPct: begin_step_percent || initialT2IAdapter.beginStepPct, + endStepPct: end_step_percent || initialT2IAdapter.endStepPct, + resizeMode: resize_mode || initialT2IAdapter.resizeMode, + controlImage: image?.image_name || null, + processedControlImage: image?.image_name || null, + processorType, + processorNode, + shouldAutoConfig: true, + id: uuidv4(), + }; + + return t2iAdapter; +}; + +export const prepareIPAdapterMetadataItem = async ( + ipAdapterMetadataItem: IPAdapterMetadataItem, + base?: BaseModelType +): Promise => { + const key = getModelKey(ipAdapterMetadataItem?.ip_adapter_model); + const ipAdapterModel = await fetchIPAdapterModel(key); + raiseIfBaseIncompatible(ipAdapterModel.base, base, 'T2I Adapter incompatible with currently-selected model'); + + const { image, weight, begin_step_percent, end_step_percent } = ipAdapterMetadataItem; + + const ipAdapter: IPAdapterConfig = { + id: uuidv4(), + type: 'ip_adapter', + isEnabled: true, + controlImage: image?.image_name ?? null, + model: zModelIdentifierWithBase.parse(ipAdapterModel), + weight: weight ?? initialIPAdapter.weight, + beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, + endStepPct: end_step_percent ?? 
initialIPAdapter.endStepPct, + }; + + return ipAdapter; +}; diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 666e0c707d5..2bd1a0a2460 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -6,16 +6,16 @@ import type { operations, paths } from 'services/api/schema'; import type { AnyModelConfig, BaseModelType, - ControlNetConfig, + ControlNetModelConfig, ImportModelConfig, - IPAdapterConfig, - LoRAConfig, + IPAdapterModelConfig, + LoRAModelConfig, MainModelConfig, MergeModelConfig, ModelType, - T2IAdapterConfig, - TextualInversionConfig, - VAEConfig, + T2IAdapterModelConfig, + TextualInversionModelConfig, + VAEModelConfig, } from 'services/api/types'; import type { ApiTagDescription, tagTypes } from '..'; @@ -30,7 +30,7 @@ type UpdateMainModelArg = { type UpdateLoRAModelArg = { base_model: BaseModelType; model_name: string; - body: LoRAConfig; + body: LoRAModelConfig; }; type UpdateMainModelResponse = @@ -97,27 +97,27 @@ export const mainModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const loraModelsAdapter = createEntityAdapter({ +export const loraModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const controlNetModelsAdapter = createEntityAdapter({ +export const controlNetModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const ipAdapterModelsAdapter = createEntityAdapter({ +export const ipAdapterModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const t2iAdapterModelsAdapter = createEntityAdapter({ +export const t2iAdapterModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions); -export const textualInversionModelsAdapter = createEntityAdapter({ +export const textualInversionModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), }); @@ -125,7 +125,7 @@ export const textualInversionModelsAdapterSelectors = textualInversionModelsAdap undefined, getSelectorsOptions ); -export const vaeModelsAdapter = createEntityAdapter({ +export const vaeModelsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, sortComparer: (a, b) => a.name.localeCompare(b.name), }); @@ -162,6 +162,8 @@ const buildTransformResponse = */ const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`); +// TODO(psyche): Ideally we can share the cache between the `getXYZModels` queries and `getModelConfig` query + export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ getMainModels: build.query, 
BaseModelType[]>({ @@ -257,10 +259,10 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['Model'], }), - getLoRAModels: build.query, void>({ + getLoRAModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }), - providesTags: buildProvidesTags('LoRAModel'), - transformResponse: buildTransformResponse(loraModelsAdapter), + providesTags: buildProvidesTags('LoRAModel'), + transformResponse: buildTransformResponse(loraModelsAdapter), }), updateLoRAModels: build.mutation({ query: ({ base_model, model_name, body }) => { @@ -281,30 +283,30 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }], }), - getControlNetModels: build.query, void>({ + getControlNetModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }), - providesTags: buildProvidesTags('ControlNetModel'), - transformResponse: buildTransformResponse(controlNetModelsAdapter), + providesTags: buildProvidesTags('ControlNetModel'), + transformResponse: buildTransformResponse(controlNetModelsAdapter), }), - getIPAdapterModels: build.query, void>({ + getIPAdapterModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }), - providesTags: buildProvidesTags('IPAdapterModel'), - transformResponse: buildTransformResponse(ipAdapterModelsAdapter), + providesTags: buildProvidesTags('IPAdapterModel'), + transformResponse: buildTransformResponse(ipAdapterModelsAdapter), }), - getT2IAdapterModels: build.query, void>({ + getT2IAdapterModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }), - providesTags: buildProvidesTags('T2IAdapterModel'), - transformResponse: buildTransformResponse(t2iAdapterModelsAdapter), + providesTags: buildProvidesTags('T2IAdapterModel'), + transformResponse: buildTransformResponse(t2iAdapterModelsAdapter), }), - getVaeModels: build.query, void>({ + getVaeModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }), - providesTags: buildProvidesTags('VaeModel'), - transformResponse: buildTransformResponse(vaeModelsAdapter), + providesTags: buildProvidesTags('VaeModel'), + transformResponse: buildTransformResponse(vaeModelsAdapter), }), - getTextualInversionModels: build.query, void>({ + getTextualInversionModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }), - providesTags: buildProvidesTags('TextualInversionModel'), - transformResponse: buildTransformResponse(textualInversionModelsAdapter), + providesTags: buildProvidesTags('TextualInversionModel'), + transformResponse: buildTransformResponse(textualInversionModelsAdapter), }), getModelsInFolder: build.query({ query: (arg) => { From 46389cc7b19a4fa54327903a47c061e17e74ab9d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 22 Feb 2024 22:42:57 +1100 Subject: [PATCH 208/340] fix(ui): roll back utility-types It's `Required` util does not distribute over unions as expected. Also we have `ts-toolbelt` already for some utils. 
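A minimal sketch of the distribution issue referenced above (illustrative only; the type names are hypothetical stand-ins rather than the real configs in services/api/types.ts):

type DiffusersConfig = { format: 'diffusers'; base?: string };
type CheckpointConfig = { format: 'checkpoint'; base?: string };
type MainConfig = DiffusersConfig | CheckpointConfig;

// Mapping directly over the union collapses it to its shared keys: the result
// is a single object type and the individual union members are lost.
type Collapsed = { [K in keyof MainConfig]-?: MainConfig[K] };
// => { format: 'diffusers' | 'checkpoint'; base: string }

// Wrapping the same mapped type in a conditional forces distribution, so each
// member of the union is preserved.
type DistributiveRequired<T> = T extends unknown ? { [K in keyof T]-?: T[K] } : never;
type Kept = DistributiveRequired<MainConfig>;
// => { format: 'diffusers'; base: string } | { format: 'checkpoint'; base: string }

// The patch below sidesteps the utility and rebuilds the narrowed types
// explicitly with Omit plus an intersection, which needs no distribution.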
--- invokeai/frontend/web/package.json | 1 - invokeai/frontend/web/pnpm-lock.yaml | 8 -------- invokeai/frontend/web/src/services/api/types.ts | 5 ++--- 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index 0bf236ee384..48bad31000e 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -145,7 +145,6 @@ "ts-toolbelt": "^9.6.0", "tsafe": "^1.6.6", "typescript": "^5.3.3", - "utility-types": "^3.11.0", "vite": "^5.1.3", "vite-plugin-css-injected-by-js": "^3.4.0", "vite-plugin-dts": "^3.7.2", diff --git a/invokeai/frontend/web/pnpm-lock.yaml b/invokeai/frontend/web/pnpm-lock.yaml index f2abdd87bf6..d79d482d08d 100644 --- a/invokeai/frontend/web/pnpm-lock.yaml +++ b/invokeai/frontend/web/pnpm-lock.yaml @@ -268,9 +268,6 @@ devDependencies: typescript: specifier: ^5.3.3 version: 5.3.3 - utility-types: - specifier: ^3.11.0 - version: 3.11.0 vite: specifier: ^5.1.3 version: 5.1.3(@types/node@20.11.19) @@ -14147,11 +14144,6 @@ packages: which-typed-array: 1.1.14 dev: true - /utility-types@3.11.0: - resolution: {integrity: sha512-6Z7Ma2aVEWisaL6TvBCy7P8rm2LQoPv6dJ7ecIaIixHcwfbJ0x7mWdbcwlIM5IGQxPZSFYeqRCqlOOeKoJYMkw==} - engines: {node: '>= 4'} - dev: true - /utils-merge@1.0.1: resolution: {integrity: sha512-pMZTvIkT1d+TFGvDOqodOclx0QWkkgi6Tdoa8gC8ffGAAqz9pzPTZWAybbsHHoED/ztMtkv/VoYTYyShUn81hA==} engines: {node: '>= 0.4.0'} diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index d561173337c..6f6018d9743 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -2,7 +2,6 @@ import type { UseToastOptions } from '@invoke-ai/ui-library'; import type { EntityState } from '@reduxjs/toolkit'; import type { components, paths } from 'services/api/schema'; import type { O } from 'ts-toolbelt'; -import type { Overwrite } from 'utility-types'; export type S = components['schemas']; @@ -72,8 +71,8 @@ export type TextualInversionModelConfig = S['TextualInversionConfig']; export type DiffusersModelConfig = S['MainDiffusersConfig']; export type CheckpointModelConfig = S['MainCheckpointConfig']; export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; -export type RefinerMainModelConfig = Overwrite; -export type NonRefinerMainModelConfig = Overwrite; +export type RefinerMainModelConfig = Omit & { base: 'sdxl-refiner' }; +export type NonRefinerMainModelConfig = Omit & { base: 'any' | 'sd-1' | 'sd-2' | 'sdxl' }; export type AnyModelConfig = | LoRAModelConfig | VAEModelConfig From fda5b7d06c3f5a1afe8a5fa85f21b658d19e497b Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Thu, 22 Feb 2024 14:27:49 -0500 Subject: [PATCH 209/340] Allow users to run model manager without cuda --- .../app/services/model_install/model_install_default.py | 2 +- .../app/services/model_manager/model_manager_base.py | 5 ++++- .../app/services/model_manager/model_manager_default.py | 9 ++++++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 9b771c51596..c2718b5b2e1 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -543,7 +543,7 @@ def _register( self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None ) -> 
str: info = info or ModelProbe.probe(model_path, config) - key = self._create_key() + key = info.key or self._create_key() model_path = model_path.absolute() if model_path.is_relative_to(self.app_config.models_path): diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index c25aa6fb47c..938e14adcb5 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -1,5 +1,7 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team +import torch + from abc import ABC, abstractmethod from typing import Optional @@ -32,9 +34,10 @@ class ModelManagerServiceBase(ABC): def build_model_manager( cls, app_config: InvokeAIAppConfig, - db: SqliteDatabase, + model_record_service: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, events: EventServiceBase, + execution_device: torch.device, ) -> Self: """ Construct the model manager service instance. diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index d029f9e0339..22761115867 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,6 +1,8 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team """Implementation of ModelManagerServiceBase.""" +import torch + from typing import Optional from typing_extensions import Self @@ -9,6 +11,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry +from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger from ..config import InvokeAIAppConfig @@ -119,6 +122,7 @@ def build_model_manager( model_record_service: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, events: EventServiceBase, + execution_device: torch.device = choose_torch_device(), ) -> Self: """ Construct the model manager service instance. @@ -129,7 +133,10 @@ def build_model_manager( logger.setLevel(app_config.log_level.upper()) ram_cache = ModelCache( - max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size, logger=logger + max_cache_size=app_config.ram_cache_size, + max_vram_cache_size=app_config.vram_cache_size, + logger=logger, + execution_device=execution_device, ) convert_cache = ModelConvertCache( cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size From fa8d8471dfa9126e2f91ba3733ce80bd3245498e Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Thu, 22 Feb 2024 14:54:48 -0500 Subject: [PATCH 210/340] Run ruff --- invokeai/app/services/model_manager/model_manager_base.py | 4 +--- invokeai/app/services/model_manager/model_manager_default.py | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index 938e14adcb5..6e886df6527 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -1,10 +1,9 @@ # Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Team -import torch - from abc import ABC, abstractmethod from typing import Optional +import torch from typing_extensions import Self from invokeai.app.services.invoker import Invoker @@ -18,7 +17,6 @@ from ..model_install import ModelInstallServiceBase from ..model_load import ModelLoadServiceBase from ..model_records import ModelRecordServiceBase -from ..shared.sqlite.sqlite_database import SqliteDatabase class ModelManagerServiceBase(ABC): diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 22761115867..7d4b248323a 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,10 +1,9 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team """Implementation of ModelManagerServiceBase.""" -import torch - from typing import Optional +import torch from typing_extensions import Self from invokeai.app.services.invoker import Invoker From fee4cfb0ad32efa5aec8f59264ab05e4f0954ca0 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Fri, 23 Feb 2024 10:19:20 -0500 Subject: [PATCH 211/340] Remove passing keys in on register --- invokeai/app/services/model_install/model_install_default.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index c2718b5b2e1..9b771c51596 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -543,7 +543,7 @@ def _register( self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None ) -> str: info = info or ModelProbe.probe(model_path, config) - key = info.key or self._create_key() + key = self._create_key() model_path = model_path.absolute() if model_path.is_relative_to(self.app_config.models_path): From 8221a61ec19c4d28f9207d6ab713c52a9d0caae0 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Fri, 23 Feb 2024 12:57:54 -0500 Subject: [PATCH 212/340] Allow passing in key on register --- .../app/services/model_install/model_install_default.py | 8 +++++--- invokeai/backend/model_manager/probe.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 9b771c51596..2419fbe5daa 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -542,8 +542,10 @@ def _create_key(self) -> str: def _register( self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None ) -> str: - info = info or ModelProbe.probe(model_path, config) key = self._create_key() + if config and not config.get('key', None): + config['key'] = key + info = info or ModelProbe.probe(model_path, config) model_path = model_path.absolute() if model_path.is_relative_to(self.app_config.models_path): @@ -556,8 +558,8 @@ def _register( # make config relative to our root legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve() info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix() - self.record_store.add_model(key, info) - return key + self.record_store.add_model(info.key, info) + return info.key def _next_id(self) -> int: 
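# Illustrative sketch of the registration behaviour above (hedged; `installer`
# and the model path are hypothetical, and _register is an internal helper):
#
#   # A key supplied via the probe config is kept, assuming
#   # ModelConfigFactory.make_config applies the key it is given:
#   key = installer._register(
#       model_path=Path("/models/some-model.safetensors"),
#       config={"key": "my-stable-key"},
#   )
#   assert key == "my-stable-key"
#
#   # With no key in the config, _create_key() generates one, stores it in
#   # config["key"], and ModelProbe.probe() passes it on to
#   # ModelConfigFactory.make_config (see the probe.py hunk below).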
with self._lock: diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 7de4289466d..c33254ef4e5 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -188,7 +188,7 @@ def probe( and fields["prediction_type"] == SchedulerPredictionType.VPrediction ) - model_info = ModelConfigFactory.make_config(fields) + model_info = ModelConfigFactory.make_config(fields, key=fields.get("key", None)) return model_info @classmethod From 9953ff91bf11ab845807619e019d5cbf3b5cd4f5 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Tue, 20 Feb 2024 10:03:10 -0500 Subject: [PATCH 213/340] get old UI working somewhat with new endpoints --- .../subpanels/AddModelsPanel/AddModels.tsx | 6 +- .../AddModelsPanel/SimpleAddModels.tsx | 5 +- .../subpanels/ModelManagerPanel.tsx | 14 +-- .../ModelManagerPanel/CheckpointModelEdit.tsx | 37 +++---- .../ModelManagerPanel/DiffusersModelEdit.tsx | 32 +++--- .../ModelManagerPanel/LoRAModelEdit.tsx | 30 +++--- .../ModelManagerPanel/ModelConvert.tsx | 14 +-- .../subpanels/ModelManagerPanel/ModelList.tsx | 7 +- .../ModelManagerPanel/ModelListItem.tsx | 28 ++--- .../web/src/services/api/endpoints/models.ts | 102 +++++++----------- 10 files changed, 120 insertions(+), 155 deletions(-) diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/AddModelsPanel/AddModels.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/AddModelsPanel/AddModels.tsx index cb50334c992..9b4b95be9ec 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/AddModelsPanel/AddModels.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/AddModelsPanel/AddModels.tsx @@ -1,15 +1,18 @@ -import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library'; +import { Button, ButtonGroup, Flex, Text } from '@invoke-ai/ui-library'; import { memo, useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; import AdvancedAddModels from './AdvancedAddModels'; import SimpleAddModels from './SimpleAddModels'; +import { useGetModelImportsQuery } from '../../../../services/api/endpoints/models'; const AddModels = () => { const { t } = useTranslation(); const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>('simple'); const handleAddModelSimple = useCallback(() => setAddModelMode('simple'), []); const handleAddModelAdvanced = useCallback(() => setAddModelMode('advanced'), []); + const { data } = useGetModelImportsQuery({}); + console.log({ data }); return ( @@ -24,6 +27,7 @@ const AddModels = () => { {addModelMode === 'simple' && } {addModelMode === 'advanced' && } + {data?.map((model) => {model.status})} ); }; diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/AddModelsPanel/SimpleAddModels.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/AddModelsPanel/SimpleAddModels.tsx index d7f705aedce..0124d6d570a 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/AddModelsPanel/SimpleAddModels.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/AddModelsPanel/SimpleAddModels.tsx @@ -36,11 +36,10 @@ const SimpleAddModels = () => { const handleAddModelSubmit = (values: ExtendedImportModelConfig) => { const importModelResponseBody = { - location: values.location, - prediction_type: values.prediction_type === 'none' ? undefined : values.prediction_type, + config: values.prediction_type === 'none' ? 
undefined : values.prediction_type, }; - importMainModel({ body: importModelResponseBody }) + importMainModel({ source: values.location, config: importModelResponseBody }) .unwrap() .then((_) => { dispatch( diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx index 15149b339b9..dab4e0b872f 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx @@ -2,13 +2,13 @@ import { Flex, Text } from '@invoke-ai/ui-library'; import { memo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { ALL_BASE_MODELS } from 'services/api/constants'; -import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit'; import ModelList from './ModelManagerPanel/ModelList'; +import { DiffusersModelConfig, LoRAConfig, MainModelConfig } from '../../../services/api/types'; const ModelManagerPanel = () => { const [selectedModelId, setSelectedModelId] = useState(); @@ -41,16 +41,16 @@ const ModelEdit = (props: ModelEditProps) => { const { t } = useTranslation(); const { model } = props; - if (model?.model_format === 'checkpoint') { - return ; + if (model?.format === 'checkpoint') { + return ; } - if (model?.model_format === 'diffusers') { - return ; + if (model?.format === 'diffusers') { + return ; } - if (model?.model_type === 'lora') { - return ; + if (model?.type === 'lora') { + return ; } return ( diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 43707308e0e..0dd8a7add68 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -21,11 +21,9 @@ import { memo, useCallback, useEffect, useState } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { CheckpointModelConfig } from 'services/api/endpoints/models'; -import { useGetCheckpointConfigsQuery, useUpdateMainModelsMutation } from 'services/api/endpoints/models'; -import type { CheckpointModelConfig } from 'services/api/types'; - +import { useGetCheckpointConfigsQuery, useUpdateModelsMutation } from 'services/api/endpoints/models'; import ModelConvert from './ModelConvert'; +import { CheckpointModelConfig } from '../../../../services/api/types'; type CheckpointModelEditProps = { model: CheckpointModelConfig; @@ -34,7 +32,7 @@ type CheckpointModelEditProps = { const CheckpointModelEdit = (props: CheckpointModelEditProps) => { const { model } = props; - const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation(); + const [updateModel, { isLoading }] = useUpdateModelsMutation(); const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery(); const [useCustomConfig, setUseCustomConfig] = 
useState(false); @@ -56,12 +54,12 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => { reset, } = useForm({ defaultValues: { - model_name: model.model_name ? model.model_name : '', - base_model: model.base_model, - model_type: 'main', + name: model.name ? model.name : '', + base: model.base, + type: 'main', path: model.path ? model.path : '', description: model.description ? model.description : '', - model_format: 'checkpoint', + format: 'checkpoint', vae: model.vae ? model.vae : '', config: model.config ? model.config : '', variant: model.variant, @@ -74,11 +72,10 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => { const onSubmit = useCallback>( (values) => { const responseBody = { - base_model: model.base_model, - model_name: model.model_name, + key: model.key, body: values, }; - updateMainModel(responseBody) + updateModel(responseBody) .unwrap() .then((payload) => { reset(payload as CheckpointModelConfig, { keepDefaultValues: true }); @@ -103,7 +100,7 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => { ); }); }, - [dispatch, model.base_model, model.model_name, reset, t, updateMainModel] + [dispatch, model.key, reset, t, updateModel] ); return ( @@ -111,13 +108,13 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => { - {model.model_name} + {model.name} - {MODEL_TYPE_MAP[model.base_model]} {t('modelManager.model')} + {MODEL_TYPE_MAP[model.base]} {t('modelManager.model')} - {![''].includes(model.base_model) ? ( + {![''].includes(model.base) ? ( ) : ( @@ -130,20 +127,20 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {

- + {t('modelManager.name')} value.trim().length > 3 || 'Must be at least 3 characters', })} /> - {errors.model_name?.message && {errors.model_name?.message}} + {errors.name?.message && {errors.name?.message}} {t('modelManager.description')} - control={control} name="base_model" /> + control={control} name="base" /> control={control} name="variant" /> {t('modelManager.modelLocation')} diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index bf6349234f5..5be9a5631c6 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -9,9 +9,8 @@ import { memo, useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { DiffusersModelConfig } from 'services/api/endpoints/models'; -import { useUpdateMainModelsMutation } from 'services/api/endpoints/models'; import type { DiffusersModelConfig } from 'services/api/types'; +import { useUpdateModelsMutation } from '../../../../services/api/endpoints/models'; type DiffusersModelEditProps = { model: DiffusersModelConfig; @@ -20,7 +19,7 @@ type DiffusersModelEditProps = { const DiffusersModelEdit = (props: DiffusersModelEditProps) => { const { model } = props; - const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation(); + const [updateModel, { isLoading }] = useUpdateModelsMutation(); const dispatch = useAppDispatch(); const { t } = useTranslation(); @@ -33,12 +32,12 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => { reset, } = useForm({ defaultValues: { - model_name: model.model_name ? model.model_name : '', - base_model: model.base_model, - model_type: 'main', + name: model.name ? model.name : '', + base: model.base, + type: 'main', path: model.path ? model.path : '', description: model.description ? model.description : '', - model_format: 'diffusers', + format: 'diffusers', vae: model.vae ? 
model.vae : '', variant: model.variant, }, @@ -48,12 +47,11 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => { const onSubmit = useCallback>( (values) => { const responseBody = { - base_model: model.base_model, - model_name: model.model_name, + key: model.key, body: values, }; - updateMainModel(responseBody) + updateModel(responseBody) .unwrap() .then((payload) => { reset(payload as DiffusersModelConfig, { keepDefaultValues: true }); @@ -78,37 +76,37 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => { ); }); }, - [dispatch, model.base_model, model.model_name, reset, t, updateMainModel] + [dispatch, model.key, reset, t, updateModel] ); return ( - {model.model_name} + {model.name} - {MODEL_TYPE_MAP[model.base_model]} {t('modelManager.model')} + {MODEL_TYPE_MAP[model.base]} {t('modelManager.model')} - + {t('modelManager.name')} value.trim().length > 3 || 'Must be at least 3 characters', })} /> - {errors.model_name?.message && {errors.model_name?.message}} + {errors.name?.message && {errors.name?.message}} {t('modelManager.description')} - control={control} name="base_model" /> + control={control} name="base" /> control={control} name="variant" /> {t('modelManager.modelLocation')} diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx index 1a8f235aaff..edb73e82756 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx @@ -8,7 +8,6 @@ import { memo, useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models'; import type { LoRAModelConfig } from 'services/api/types'; type LoRAModelEditProps = { @@ -18,7 +17,7 @@ type LoRAModelEditProps = { const LoRAModelEdit = (props: LoRAModelEditProps) => { const { model } = props; - const [updateLoRAModel, { isLoading }] = useUpdateLoRAModelsMutation(); + const [updateModel, { isLoading }] = useUpdateModelsMutation(); const dispatch = useAppDispatch(); const { t } = useTranslation(); @@ -31,12 +30,12 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { reset, } = useForm({ defaultValues: { - model_name: model.model_name ? model.model_name : '', - base_model: model.base_model, - model_type: 'lora', + name: model.name ? model.name : '', + base: model.base, + type: 'lora', path: model.path ? model.path : '', description: model.description ? 
model.description : '', - model_format: model.model_format, + format: model.format, }, mode: 'onChange', }); @@ -44,12 +43,11 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { const onSubmit = useCallback>( (values) => { const responseBody = { - base_model: model.base_model, - model_name: model.model_name, + key: model.key, body: values, }; - updateLoRAModel(responseBody) + updateModel(responseBody) .unwrap() .then((payload) => { reset(payload as LoRAModelConfig, { keepDefaultValues: true }); @@ -74,17 +72,17 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { ); }); }, - [dispatch, model.base_model, model.model_name, reset, t, updateLoRAModel] + [dispatch, model.key, reset, t, updateModel] ); return ( - {model.model_name} + {model.name} - {MODEL_TYPE_MAP[model.base_model]} {t('modelManager.model')} ⋅ {LORA_MODEL_FORMAT_MAP[model.model_format]}{' '} + {MODEL_TYPE_MAP[model.base]} {t('modelManager.model')} ⋅ {LORA_MODEL_FORMAT_MAP[model.format]}{' '} {t('common.format')} @@ -92,20 +90,20 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => { - + {t('modelManager.name')} value.trim().length > 3 || 'Must be at least 3 characters', })} /> - {errors.model_name?.message && {errors.model_name?.message}} + {errors.name?.message && {errors.name?.message}} {t('modelManager.description')} - control={control} name="base_model" /> + control={control} name="base" /> {t('modelManager.modelLocation')} diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelConvert.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelConvert.tsx index 6e34d5039ee..9a2746abe67 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelConvert.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelConvert.tsx @@ -54,8 +54,8 @@ const ModelConvert = (props: ModelConvertProps) => { const modelConvertHandler = useCallback(() => { const queryArg = { - base_model: model.base_model, - model_name: model.model_name, + base_model: model.base, + model_name: model.name, convert_dest_directory: saveLocation === 'Custom' ? 
customSaveLocation : undefined, }; @@ -74,7 +74,7 @@ const ModelConvert = (props: ModelConvertProps) => { dispatch( addToast( makeToast({ - title: `${t('modelManager.convertingModelBegin')}: ${model.model_name}`, + title: `${t('modelManager.convertingModelBegin')}: ${model.name}`, status: 'info', }) ) @@ -86,7 +86,7 @@ const ModelConvert = (props: ModelConvertProps) => { dispatch( addToast( makeToast({ - title: `${t('modelManager.modelConverted')}: ${model.model_name}`, + title: `${t('modelManager.modelConverted')}: ${model.name}`, status: 'success', }) ) @@ -96,13 +96,13 @@ const ModelConvert = (props: ModelConvertProps) => { dispatch( addToast( makeToast({ - title: `${t('modelManager.modelConversionFailed')}: ${model.model_name}`, + title: `${t('modelManager.modelConversionFailed')}: ${model.name}`, status: 'error', }) ) ); }); - }, [convertModel, customSaveLocation, dispatch, model.base_model, model.model_name, saveLocation, t]); + }, [convertModel, customSaveLocation, dispatch, model.base, model.name, saveLocation, t]); return ( <> @@ -116,7 +116,7 @@ const ModelConvert = (props: ModelConvertProps) => { 🧨 {t('modelManager.convertToDiffusers')} { {modelList.map((model) => ( ))} diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index 835499d25ad..2014d889611 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -15,8 +15,8 @@ import { makeToast } from 'features/system/util/makeToast'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiTrashSimpleBold } from 'react-icons/pi'; -import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models'; -import { useDeleteLoRAModelsMutation, useDeleteMainModelsMutation } from 'services/api/endpoints/models'; +import { useDeleteModelsMutation } from 'services/api/endpoints/models'; +import { LoRAConfig, MainModelConfig } from '../../../../services/api/types'; type ModelListItemProps = { model: MainModelConfig | LoRAConfig; @@ -27,29 +27,23 @@ type ModelListItemProps = { const ModelListItem = (props: ModelListItemProps) => { const { t } = useTranslation(); const dispatch = useAppDispatch(); - const [deleteMainModel] = useDeleteMainModelsMutation(); - const [deleteLoRAModel] = useDeleteLoRAModelsMutation(); + const [deleteModel] = useDeleteModelsMutation(); const { isOpen, onOpen, onClose } = useDisclosure(); const { model, isSelected, setSelectedModelId } = props; const handleSelectModel = useCallback(() => { - setSelectedModelId(model.id); - }, [model.id, setSelectedModelId]); + setSelectedModelId(model.key); + }, [model.key, setSelectedModelId]); const handleModelDelete = useCallback(() => { - const method = { - main: deleteMainModel, - lora: deleteLoRAModel, - }[model.model_type]; - - method(model) + deleteModel({ key: model.key }) .unwrap() .then((_) => { dispatch( addToast( makeToast({ - title: `${t('modelManager.modelDeleted')}: ${model.model_name}`, + title: `${t('modelManager.modelDeleted')}: ${model.name}`, status: 'success', }) ) @@ -60,7 +54,7 @@ const ModelListItem = (props: ModelListItemProps) => { dispatch( addToast( makeToast({ - title: `${t('modelManager.modelDeleteFailed')}: ${model.model_name}`, + title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`, 
status: 'error', }) ) @@ -68,7 +62,7 @@ const ModelListItem = (props: ModelListItemProps) => { } }); setSelectedModelId(undefined); - }, [deleteMainModel, deleteLoRAModel, model, setSelectedModelId, dispatch, t]); + }, [deleteModel, model, setSelectedModelId, dispatch, t]); return ( @@ -85,10 +79,10 @@ const ModelListItem = (props: ModelListItemProps) => { > - {MODEL_TYPE_SHORT_MAP[model.base_model as keyof typeof MODEL_TYPE_SHORT_MAP]} + {MODEL_TYPE_SHORT_MAP[model.base as keyof typeof MODEL_TYPE_SHORT_MAP]} - {model.model_name} + {model.name} diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 2bd1a0a2460..6897768a1bd 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -7,12 +7,10 @@ import type { AnyModelConfig, BaseModelType, ControlNetModelConfig, - ImportModelConfig, IPAdapterModelConfig, LoRAModelConfig, MainModelConfig, MergeModelConfig, - ModelType, T2IAdapterModelConfig, TextualInversionModelConfig, VAEModelConfig, @@ -21,37 +19,21 @@ import type { import type { ApiTagDescription, tagTypes } from '..'; import { api, buildV2Url, LIST_TAG } from '..'; -type UpdateMainModelArg = { - base_model: BaseModelType; - model_name: string; - body: MainModelConfig; -}; - -type UpdateLoRAModelArg = { - base_model: BaseModelType; - model_name: string; - body: LoRAModelConfig; +type UpdateModelArg = { + key: NonNullable; + body: NonNullable; }; -type UpdateMainModelResponse = - paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; +type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; type ListModelsArg = NonNullable; -type UpdateLoRAModelResponse = UpdateMainModelResponse; - type DeleteMainModelArg = { - base_model: BaseModelType; - model_name: string; - model_type: ModelType; + key: string; }; type DeleteMainModelResponse = void; -type DeleteLoRAModelArg = DeleteMainModelArg; - -type DeleteLoRAModelResponse = void; - type ConvertMainModelArg = { base_model: BaseModelType; model_name: string; @@ -59,36 +41,40 @@ type ConvertMainModelArg = { }; type ConvertMainModelResponse = - paths['/api/v1/models/convert/{base_model}/{model_type}/{model_name}']['put']['responses']['200']['content']['application/json']; + paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json']; type MergeMainModelArg = { base_model: BaseModelType; body: MergeModelConfig; }; -type MergeMainModelResponse = - paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json']; +type MergeMainModelResponse = paths['/api/v2/models/merge']['put']['responses']['200']['content']['application/json']; type ImportMainModelArg = { - body: ImportModelConfig; + source: NonNullable; + access_token?: operations['heuristic_import_model']['parameters']['query']['access_token']; + config: NonNullable; }; type ImportMainModelResponse = - paths['/api/v1/models/import']['post']['responses']['201']['content']['application/json']; + paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json']; + +type ListImportModelsResponse = + paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json']; type AddMainModelArg = { body: MainModelConfig; }; -type AddMainModelResponse = 
paths['/api/v1/models/add']['post']['responses']['201']['content']['application/json']; +type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json']; -type SyncModelsResponse = paths['/api/v1/models/sync']['post']['responses']['201']['content']['application/json']; +type SyncModelsResponse = paths['/api/v2/models/sync']['post']['responses']['201']['content']['application/json']; export type SearchFolderResponse = - paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json']; + paths['/api/v2/models/search']['get']['responses']['200']['content']['application/json']; type CheckpointConfigsResponse = - paths['/api/v1/models/ckpt_confs']['get']['responses']['200']['content']['application/json']; + paths['/api/v2/models/ckpt_confs']['get']['responses']['200']['content']['application/json']; type SearchFolderArg = operations['search_for_models']['parameters']['query']; @@ -179,10 +165,10 @@ export const modelsApi = api.injectEndpoints({ providesTags: buildProvidesTags('MainModel'), transformResponse: buildTransformResponse(mainModelsAdapter), }), - updateMainModels: build.mutation({ - query: ({ base_model, model_name, body }) => { + updateModels: build.mutation({ + query: ({ key, body }) => { return { - url: buildModelsUrl(`${base_model}/main/${model_name}`), + url: buildModelsUrl(`i/${key}`), method: 'PATCH', body: body, }; @@ -190,11 +176,12 @@ export const modelsApi = api.injectEndpoints({ invalidatesTags: ['Model'], }), importMainModels: build.mutation({ - query: ({ body }) => { + query: ({ source, config, access_token }) => { return { - url: buildModelsUrl('import'), + url: buildModelsUrl('heuristic_import'), + params: { source, access_token }, method: 'POST', - body: body, + body: config, }; }, invalidatesTags: ['Model'], @@ -209,10 +196,10 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['Model'], }), - deleteMainModels: build.mutation({ - query: ({ base_model, model_name, model_type }) => { + deleteModels: build.mutation({ + query: ({ key }) => { return { - url: buildModelsUrl(`${base_model}/${model_type}/${model_name}`), + url: buildModelsUrl(`i/${key}`), method: 'DELETE', }; }, @@ -264,25 +251,6 @@ export const modelsApi = api.injectEndpoints({ providesTags: buildProvidesTags('LoRAModel'), transformResponse: buildTransformResponse(loraModelsAdapter), }), - updateLoRAModels: build.mutation({ - query: ({ base_model, model_name, body }) => { - return { - url: buildModelsUrl(`${base_model}/lora/${model_name}`), - method: 'PATCH', - body: body, - }; - }, - invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }], - }), - deleteLoRAModels: build.mutation({ - query: ({ base_model, model_name }) => { - return { - url: buildModelsUrl(`${base_model}/lora/${model_name}`), - method: 'DELETE', - }; - }, - invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }], - }), getControlNetModels: build.query, void>({ query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }), providesTags: buildProvidesTags('ControlNetModel'), @@ -316,6 +284,13 @@ export const modelsApi = api.injectEndpoints({ }; }, }), + getModelImports: build.query({ + query: (arg) => { + return { + url: buildModelsUrl(`import`), + }; + }, + }), getCheckpointConfigs: build.query({ query: () => { return { @@ -335,15 +310,14 @@ export const { useGetLoRAModelsQuery, useGetTextualInversionModelsQuery, useGetVaeModelsQuery, - useUpdateMainModelsMutation, - useDeleteMainModelsMutation, + useDeleteModelsMutation, + 
useUpdateModelsMutation, useImportMainModelsMutation, useAddMainModelsMutation, useConvertMainModelsMutation, useMergeMainModelsMutation, - useDeleteLoRAModelsMutation, - useUpdateLoRAModelsMutation, useSyncModelsMutation, useGetModelsInFolderQuery, useGetCheckpointConfigsQuery, + useGetModelImportsQuery, } = modelsApi; From 6a80706ce4e7afee12c8215ead9ab70eb2c1db88 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Tue, 20 Feb 2024 10:09:40 -0500 Subject: [PATCH 214/340] workspace for mary and jenn --- .../modelManagerV2/subpanels/ImportModels.tsx | 10 +++ .../modelManagerV2/subpanels/ModelManager.tsx | 44 ++++++++++++ .../ui/components/tabs/ModelManagerTab.tsx | 68 +++---------------- 3 files changed, 64 insertions(+), 58 deletions(-) create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx new file mode 100644 index 00000000000..2325fcf7dcc --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx @@ -0,0 +1,10 @@ +import { Box } from '@invoke-ai/ui-library'; + +//jenn's workspace +export const ImportModels = () => { + return ( + + Import Models + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx new file mode 100644 index 00000000000..59b01de0c0f --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx @@ -0,0 +1,44 @@ +import { + Box, + Button, + Flex, + Heading, + IconButton, + Input, + InputGroup, + InputRightElement, + Spacer, +} from '@invoke-ai/ui-library'; +import { t } from 'i18next'; +import { PiXBold } from 'react-icons/pi'; +import { SyncModelsIconButton } from '../../modelManager/components/SyncModels/SyncModelsIconButton'; + +export const ModelManager = () => { + return ( + + + + Model Manager + + + + + + + + + + + + + ( + + } /> + + ) + + + + + ); +}; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManagerTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManagerTab.tsx index 124c9b31a59..50db10fb572 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManagerTab.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManagerTab.tsx @@ -1,65 +1,17 @@ -import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library'; -import ImportModelsPanel from 'features/modelManager/subpanels/ImportModelsPanel'; -import MergeModelsPanel from 'features/modelManager/subpanels/MergeModelsPanel'; -import ModelManagerPanel from 'features/modelManager/subpanels/ModelManagerPanel'; -import ModelManagerSettingsPanel from 'features/modelManager/subpanels/ModelManagerSettingsPanel'; -import type { ReactNode } from 'react'; -import { memo, useMemo } from 'react'; +import { Flex, Box } from '@invoke-ai/ui-library'; +import { memo } from 'react'; import { useTranslation } from 'react-i18next'; - -type ModelManagerTabName = 'modelManager' | 'importModels' | 'mergeModels' | 'settings'; - -type ModelManagerTabInfo = { - id: ModelManagerTabName; - label: string; - content: ReactNode; -}; +import { ImportModels } from '../../../modelManagerV2/subpanels/ImportModels'; +import { ModelManager } from 
'../../../modelManagerV2/subpanels/ModelManager'; const ModelManagerTab = () => { - const { t } = useTranslation(); - - const tabs: ModelManagerTabInfo[] = useMemo( - () => [ - { - id: 'modelManager', - label: t('modelManager.modelManager'), - content: , - }, - { - id: 'importModels', - label: t('modelManager.importModels'), - content: , - }, - { - id: 'mergeModels', - label: t('modelManager.mergeModels'), - content: , - }, - { - id: 'settings', - label: t('modelManager.settings'), - content: , - }, - ], - [t] - ); return ( - - - {tabs.map((tab) => ( - - {tab.label} - - ))} - - - {tabs.map((tab) => ( - - {tab.content} - - ))} - - + + + + + + ); }; From f0d0f8df37e31db9c3c1997491162cddec14a6a5 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Tue, 20 Feb 2024 13:03:28 -0500 Subject: [PATCH 215/340] model list, filtering, searching --- invokeai/frontend/web/src/app/store/store.ts | 2 + .../subpanels/AddModelsPanel/AddModels.tsx | 2 +- .../subpanels/ModelManagerPanel.tsx | 2 +- .../ModelManagerPanel/CheckpointModelEdit.tsx | 3 +- .../ModelManagerPanel/DiffusersModelEdit.tsx | 2 +- .../ModelManagerPanel/LoRAModelEdit.tsx | 1 + .../subpanels/ModelManagerPanel/ModelList.tsx | 2 +- .../ModelManagerPanel/ModelListItem.tsx | 2 +- .../store/modelManagerV2Slice.ts | 54 ++++++ .../modelManagerV2/subpanels/ModelManager.tsx | 32 +--- .../subpanels/ModelManagerPanel/ModelList.tsx | 160 ++++++++++++++++++ .../ModelManagerPanel/ModelListHeader.tsx | 23 +++ .../ModelManagerPanel/ModelListItem.tsx | 115 +++++++++++++ .../ModelManagerPanel/ModelListNavigation.tsx | 52 ++++++ .../ModelManagerPanel/ModelListWrapper.tsx | 25 +++ .../ModelManagerPanel/ModelTypeFilter.tsx | 54 ++++++ .../web/src/features/modelManagerV2/types.ts | 14 ++ .../ui/components/tabs/ModelManagerTab.tsx | 7 +- 18 files changed, 517 insertions(+), 35 deletions(-) create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListHeader.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListWrapper.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/types.ts diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 16f1632d882..c63bc02e092 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -16,6 +16,7 @@ import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/galle import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice'; import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice'; import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice'; +import { modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice'; import { workflowPersistConfig, workflowSlice } from 
'features/nodes/store/workflowSlice'; import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice'; @@ -55,6 +56,7 @@ const allReducers = { [changeBoardModalSlice.name]: changeBoardModalSlice.reducer, [loraSlice.name]: loraSlice.reducer, [modelManagerSlice.name]: modelManagerSlice.reducer, + [modelManagerV2Slice.name]: modelManagerV2Slice.reducer, [sdxlSlice.name]: sdxlSlice.reducer, [queueSlice.name]: queueSlice.reducer, [workflowSlice.name]: workflowSlice.reducer, diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/AddModelsPanel/AddModels.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/AddModelsPanel/AddModels.tsx index 9b4b95be9ec..82ccb7f3090 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/AddModelsPanel/AddModels.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/AddModelsPanel/AddModels.tsx @@ -1,10 +1,10 @@ import { Button, ButtonGroup, Flex, Text } from '@invoke-ai/ui-library'; import { memo, useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { useGetModelImportsQuery } from 'services/api/endpoints/models'; import AdvancedAddModels from './AdvancedAddModels'; import SimpleAddModels from './SimpleAddModels'; -import { useGetModelImportsQuery } from '../../../../services/api/endpoints/models'; const AddModels = () => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx index dab4e0b872f..06b5b7db36b 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel.tsx @@ -3,12 +3,12 @@ import { memo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { ALL_BASE_MODELS } from 'services/api/constants'; import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models'; +import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/types'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit'; import ModelList from './ModelManagerPanel/ModelList'; -import { DiffusersModelConfig, LoRAConfig, MainModelConfig } from '../../../services/api/types'; const ModelManagerPanel = () => { const [selectedModelId, setSelectedModelId] = useState(); diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 0dd8a7add68..c24d660cc63 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -22,8 +22,9 @@ import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { useGetCheckpointConfigsQuery, useUpdateModelsMutation } from 'services/api/endpoints/models'; +import type { CheckpointModelConfig } from 'services/api/types'; + import ModelConvert from './ModelConvert'; -import { CheckpointModelConfig } from '../../../../services/api/types'; type 
CheckpointModelEditProps = { model: CheckpointModelConfig; diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx index 5be9a5631c6..b5023f0eff4 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/DiffusersModelEdit.tsx @@ -9,8 +9,8 @@ import { memo, useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; +import { useUpdateModelsMutation } from 'services/api/endpoints/models'; import type { DiffusersModelConfig } from 'services/api/types'; -import { useUpdateModelsMutation } from '../../../../services/api/endpoints/models'; type DiffusersModelEditProps = { model: DiffusersModelConfig; diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx index edb73e82756..81f2c4df29f 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/LoRAModelEdit.tsx @@ -8,6 +8,7 @@ import { memo, useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; +import { useUpdateModelsMutation } from 'services/api/endpoints/models'; import type { LoRAModelConfig } from 'services/api/types'; type LoRAModelEditProps = { diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx index c546831476f..b129b7310d9 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -7,9 +7,9 @@ import { useTranslation } from 'react-i18next'; import { ALL_BASE_MODELS } from 'services/api/constants'; // import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models'; +import type { LoRAConfig, MainModelConfig } from 'services/api/types'; import ModelListItem from './ModelListItem'; -import { LoRAConfig, MainModelConfig } from '../../../../services/api/types'; type ModelListProps = { selectedModelId: string | undefined; diff --git a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx index 2014d889611..08b9b61aea7 100644 --- a/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/modelManager/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -16,7 +16,7 @@ import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiTrashSimpleBold } from 'react-icons/pi'; import { useDeleteModelsMutation } from 'services/api/endpoints/models'; -import { LoRAConfig, MainModelConfig } from '../../../../services/api/types'; 
+import type { LoRAConfig, MainModelConfig } from 'services/api/types'; type ModelListItemProps = { model: MainModelConfig | LoRAConfig; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts new file mode 100644 index 00000000000..29b071e9b16 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts @@ -0,0 +1,54 @@ +import type { PayloadAction } from '@reduxjs/toolkit'; +import { createSlice } from '@reduxjs/toolkit'; +import type { PersistConfig, RootState } from 'app/store/store'; + + +type ModelManagerState = { + _version: 1; + selectedModelKey: string | null; + searchTerm: string; + filteredModelType: string | null; +}; + +export const initialModelManagerState: ModelManagerState = { + _version: 1, + selectedModelKey: null, + filteredModelType: null, + searchTerm: "" +}; + +export const modelManagerV2Slice = createSlice({ + name: 'modelmanagerV2', + initialState: initialModelManagerState, + reducers: { + setSelectedModelKey: (state, action: PayloadAction) => { + state.selectedModelKey = action.payload; + }, + setSearchTerm: (state, action: PayloadAction) => { + state.searchTerm = action.payload; + }, + + setFilteredModelType: (state, action: PayloadAction) => { + state.filteredModelType = action.payload; + }, + }, +}); + +export const { setSelectedModelKey, setSearchTerm, setFilteredModelType } = modelManagerV2Slice.actions; + +export const selectModelManagerSlice = (state: RootState) => state.modelmanager; + +/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ +export const migrateModelManagerState = (state: any): any => { + if (!('_version' in state)) { + state._version = 1; + } + return state; +}; + +export const modelManagerPersistConfig: PersistConfig = { + name: modelManagerV2Slice.name, + initialState: initialModelManagerState, + migrate: migrateModelManagerState, + persistDenylist: [], +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx index 59b01de0c0f..648f98dbb3e 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx @@ -1,17 +1,8 @@ -import { - Box, - Button, - Flex, - Heading, - IconButton, - Input, - InputGroup, - InputRightElement, - Spacer, -} from '@invoke-ai/ui-library'; -import { t } from 'i18next'; -import { PiXBold } from 'react-icons/pi'; -import { SyncModelsIconButton } from '../../modelManager/components/SyncModels/SyncModelsIconButton'; +import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library'; +import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton'; + +import ModelList from './ModelManagerPanel/ModelList'; +import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation'; export const ModelManager = () => { return ( @@ -27,17 +18,8 @@ export const ModelManager = () => { - - - - - ( - - } /> - - ) - - + + ); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx new file mode 100644 index 00000000000..bf43c01da20 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx @@ -0,0 +1,160 @@ 
+import { Flex, Spinner, Text } from '@invoke-ai/ui-library'; +import type { EntityState } from '@reduxjs/toolkit'; +import { useAppSelector } from 'app/store/storeHooks'; +import { forEach } from 'lodash-es'; +import { memo } from 'react'; +import { ALL_BASE_MODELS } from 'services/api/constants'; +import { + useGetControlNetModelsQuery, + useGetIPAdapterModelsQuery, + useGetLoRAModelsQuery, + useGetMainModelsQuery, + useGetT2IAdapterModelsQuery, + useGetTextualInversionModelsQuery, + useGetVaeModelsQuery, +} from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; + +import { ModelListWrapper } from './ModelListWrapper'; + +const ModelList = () => { + const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2); + + const { filteredMainModels, isLoadingMainModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { + selectFromResult: ({ data, isLoading }) => ({ + filteredMainModels: modelsFilter(data, searchTerm, filteredModelType), + isLoadingMainModels: isLoading, + }), + }); + + const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, { + selectFromResult: ({ data, isLoading }) => ({ + filteredLoraModels: modelsFilter(data, searchTerm, filteredModelType), + isLoadingLoraModels: isLoading, + }), + }); + + const { filteredTextualInversionModels, isLoadingTextualInversionModels } = useGetTextualInversionModelsQuery( + undefined, + { + selectFromResult: ({ data, isLoading }) => ({ + filteredTextualInversionModels: modelsFilter(data, searchTerm, filteredModelType), + isLoadingTextualInversionModels: isLoading, + }), + } + ); + + const { filteredControlnetModels, isLoadingControlnetModels } = useGetControlNetModelsQuery(undefined, { + selectFromResult: ({ data, isLoading }) => ({ + filteredControlnetModels: modelsFilter(data, searchTerm, filteredModelType), + isLoadingControlnetModels: isLoading, + }), + }); + + const { filteredT2iAdapterModels, isLoadingT2IAdapterModels } = useGetT2IAdapterModelsQuery(undefined, { + selectFromResult: ({ data, isLoading }) => ({ + filteredT2iAdapterModels: modelsFilter(data, searchTerm, filteredModelType), + isLoadingT2IAdapterModels: isLoading, + }), + }); + + const { filteredIpAdapterModels, isLoadingIpAdapterModels } = useGetIPAdapterModelsQuery(undefined, { + selectFromResult: ({ data, isLoading }) => ({ + filteredIpAdapterModels: modelsFilter(data, searchTerm, filteredModelType), + isLoadingIpAdapterModels: isLoading, + }), + }); + + const { filteredVaeModels, isLoadingVaeModels } = useGetVaeModelsQuery(undefined, { + selectFromResult: ({ data, isLoading }) => ({ + filteredVaeModels: modelsFilter(data, searchTerm, filteredModelType), + isLoadingVaeModels: isLoading, + }), + }); + + return ( + + + {/* Main Model List */} + {isLoadingMainModels && } + {!isLoadingMainModels && filteredMainModels.length > 0 && ( + + )} + {/* LoRAs List */} + {isLoadingLoraModels && } + {!isLoadingLoraModels && filteredLoraModels.length > 0 && ( + + )} + + {/* TI List */} + {isLoadingTextualInversionModels && } + {!isLoadingTextualInversionModels && filteredTextualInversionModels.length > 0 && ( + + )} + + {/* VAE List */} + {isLoadingVaeModels && } + {!isLoadingVaeModels && filteredVaeModels.length > 0 && ( + + )} + + {/* Controlnet List */} + {isLoadingControlnetModels && } + {!isLoadingControlnetModels && filteredControlnetModels.length > 0 && ( + + )} + {/* IP Adapter List */} + {isLoadingIpAdapterModels && } + {!isLoadingIpAdapterModels && filteredIpAdapterModels.length > 0 && ( + + )} + 
{/* T2I Adapters List */} + {isLoadingT2IAdapterModels && } + {!isLoadingT2IAdapterModels && filteredT2iAdapterModels.length > 0 && ( + + )} + + + ); +}; + +export default memo(ModelList); + +const modelsFilter = ( + data: EntityState | undefined, + nameFilter: string, + filteredModelType: string | null +): T[] => { + const filteredModels: T[] = []; + + forEach(data?.entities, (model) => { + if (!model) { + return; + } + + const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase()); + const matchesType = filteredModelType ? model.type === filteredModelType : true; + + if (matchesFilter && matchesType) { + filteredModels.push(model); + } + }); + return filteredModels; +}; + +const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => { + return ( + + + + {loadingMessage ? loadingMessage : 'Fetching...'} + + + ); +}); + +FetchingModelsLoader.displayName = 'FetchingModelsLoader'; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListHeader.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListHeader.tsx new file mode 100644 index 00000000000..874d1c9ac2d --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListHeader.tsx @@ -0,0 +1,23 @@ +import { Box, Divider, Text } from '@invoke-ai/ui-library'; + +export const ModelListHeader = ({ title }: { title: string }) => { + return ( + + + + + {title} + + + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx new file mode 100644 index 00000000000..5cc429ebcdc --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -0,0 +1,115 @@ +import { + Badge, + Button, + ConfirmationAlertDialog, + Flex, + IconButton, + Text, + Tooltip, + useDisclosure, +} from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice'; +import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants'; +import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiTrashSimpleBold } from 'react-icons/pi'; +import { useDeleteModelsMutation } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; + +type ModelListItemProps = { + model: AnyModelConfig; +}; + +const ModelListItem = (props: ModelListItemProps) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); + const [deleteModel] = useDeleteModelsMutation(); + const { isOpen, onOpen, onClose } = useDisclosure(); + + const { model } = props; + + const handleSelectModel = useCallback(() => { + dispatch(setSelectedModelKey(model.key)); + }, [model.key, dispatch]); + + const isSelected = useMemo(() => { + return selectedModelKey === model.key; + }, [selectedModelKey, model.key]); + + const handleModelDelete = useCallback(() => { + deleteModel({ key: model.key }) + .unwrap() + .then((_) => { + dispatch( + addToast( + makeToast({ + title: `${t('modelManager.modelDeleted')}: 
${model.name}`, + status: 'success', + }) + ) + ); + }) + .catch((error) => { + if (error) { + dispatch( + addToast( + makeToast({ + title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`, + status: 'error', + }) + ) + ); + } + }); + dispatch(setSelectedModelKey(null)); + }, [deleteModel, model, dispatch, t]); + + return ( + + + + + {MODEL_TYPE_SHORT_MAP[model.base as keyof typeof MODEL_TYPE_SHORT_MAP]} + + + {model.name} + + + + } + aria-label={t('modelManager.deleteConfig')} + colorScheme="error" + /> + + + {t('modelManager.deleteMsg1')} + {t('modelManager.deleteMsg2')} + + + + ); +}; + +export default memo(ModelListItem); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx new file mode 100644 index 00000000000..c0d06af2452 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx @@ -0,0 +1,52 @@ +import { Flex, IconButton,Input, InputGroup, InputRightElement, Spacer } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice'; +import { t } from 'i18next'; +import type { ChangeEventHandler} from 'react'; +import { useCallback } from 'react'; +import { PiXBold } from 'react-icons/pi'; + +import { ModelTypeFilter } from './ModelTypeFilter'; + +export const ModelListNavigation = () => { + const dispatch = useAppDispatch(); + const searchTerm = useAppSelector((s) => s.modelmanagerV2.searchTerm); + + const handleSearch: ChangeEventHandler = useCallback( + (event) => { + dispatch(setSearchTerm(event.target.value)); + }, + [dispatch] + ); + + const clearSearch = useCallback(() => { + dispatch(setSearchTerm('')); + }, [dispatch]); + + return ( + + + + + + + {!!searchTerm?.length && ( + + } + onClick={clearSearch} + /> + + )} + + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListWrapper.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListWrapper.tsx new file mode 100644 index 00000000000..24460e64533 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListWrapper.tsx @@ -0,0 +1,25 @@ +import { Flex } from '@invoke-ai/ui-library'; +import type { AnyModelConfig } from 'services/api/types'; + +import { ModelListHeader } from './ModelListHeader'; +import ModelListItem from './ModelListItem'; + +type ModelListWrapperProps = { + title: string; + modelList: AnyModelConfig[]; +}; + +export const ModelListWrapper = (props: ModelListWrapperProps) => { + const { title, modelList } = props; + return ( + + + + + {modelList.map((model) => ( + + ))} + + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx new file mode 100644 index 00000000000..0134ffc811d --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx @@ -0,0 +1,54 @@ +import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; 
+import { useCallback } from 'react'; +import { IoFilter } from 'react-icons/io5'; + +export const MODEL_TYPE_LABELS: { [key: string]: string } = { + main: 'Main', + lora: 'LoRA', + embedding: 'Textual Inversion', + controlnet: 'ControlNet', + vae: 'VAE', + t2i_adapter: 'T2I Adapter', + ip_adapter: 'IP Adapter', + clip_vision: 'Clip Vision', + onnx: 'Onnx', +}; + +export const ModelTypeFilter = () => { + const dispatch = useAppDispatch(); + const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType); + + const selectModelType = useCallback( + (option: string) => { + dispatch(setFilteredModelType(option)); + }, + [dispatch] + ); + + const clearModelType = useCallback(() => { + dispatch(setFilteredModelType(null)); + }, [dispatch]); + + return ( + + }> + {filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : 'All Models'} + + + All Models + {Object.keys(MODEL_TYPE_LABELS).map((option) => ( + + {MODEL_TYPE_LABELS[option]} + + ))} + + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/types.ts b/invokeai/frontend/web/src/features/modelManagerV2/types.ts new file mode 100644 index 00000000000..a209fbb8768 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/types.ts @@ -0,0 +1,14 @@ +import { z } from "zod"; + +export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +export const zModelType = z.enum([ + 'main', + 'vae', + 'lora', + 'controlnet', + 'embedding', + 'ip_adapter', + 'clip_vision', + 't2i_adapter', + 'onnx', // TODO(psyche): Remove this when removed from backend +]); \ No newline at end of file diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManagerTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManagerTab.tsx index 50db10fb572..8117631f22e 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManagerTab.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManagerTab.tsx @@ -1,8 +1,7 @@ -import { Flex, Box } from '@invoke-ai/ui-library'; +import { Box,Flex } from '@invoke-ai/ui-library'; +import { ImportModels } from 'features/modelManagerV2/subpanels/ImportModels'; +import { ModelManager } from 'features/modelManagerV2/subpanels/ModelManager'; import { memo } from 'react'; -import { useTranslation } from 'react-i18next'; -import { ImportModels } from '../../../modelManagerV2/subpanels/ImportModels'; -import { ModelManager } from '../../../modelManagerV2/subpanels/ModelManager'; const ModelManagerTab = () => { return ( From 35f870f432f6f0ae64f15bc279a8a69d2f008ba6 Mon Sep 17 00:00:00 2001 From: Jennifer Player Date: Tue, 20 Feb 2024 13:39:27 -0500 Subject: [PATCH 216/340] added import model form and importqueue --- invokeai/frontend/web/public/locales/en.json | 3 + .../modelManagerV2/subpanels/ImportModels.tsx | 83 +++++++++++++++++- .../modelManagerV2/subpanels/ImportQueue.tsx | 85 +++++++++++++++++++ 3 files changed, 167 insertions(+), 4 deletions(-) create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueue.tsx diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index a458563fd56..406df71e060 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -701,6 +701,7 @@ "availableModels": "Available Models", "baseModel": "Base Model", "cached": "cached", + "cancel": "Cancel", "cannotUseSpaces": "Cannot Use Spaces", "checkpointFolder": "Checkpoint Folder", "checkpointModels": 
"Checkpoints", @@ -743,6 +744,7 @@ "heightValidationMsg": "Default height of your model.", "ignoreMismatch": "Ignore Mismatches Between Selected Models", "importModels": "Import Models", + "importQueue": "Import Queue", "inpainting": "v1 Inpainting", "interpolationType": "Interpolation Type", "inverseSigmoid": "Inverse Sigmoid", @@ -796,6 +798,7 @@ "pickModelType": "Pick Model Type", "predictionType": "Prediction Type (for Stable Diffusion 2.x Models and occasional Stable Diffusion 1.x Models)", "quickAdd": "Quick Add", + "removeFromQueue": "Remove From Queue", "repo_id": "Repo ID", "repoIDValidationMsg": "Online repository of your model", "safetensorModels": "SafeTensors", diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx index 2325fcf7dcc..14ed0c6848d 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportModels.tsx @@ -1,10 +1,85 @@ -import { Box } from '@invoke-ai/ui-library'; +import { Button, Box, Flex, FormControl, FormLabel, Heading, Input, Text, Divider } from '@invoke-ai/ui-library'; +import { t } from 'i18next'; +import { CSSProperties } from 'react'; +import { useImportMainModelsMutation } from '../../../services/api/endpoints/models'; +import { useForm } from '@mantine/form'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; +import { ImportQueue } from './ImportQueue'; + +const formStyles: CSSProperties = { + width: '100%', +}; + +type ExtendedImportModelConfig = { + location: string; +}; -//jenn's workspace export const ImportModels = () => { + const dispatch = useAppDispatch(); + + const [importMainModel, { isLoading }] = useImportMainModelsMutation(); + + const addModelForm = useForm({ + initialValues: { + location: '', + }, + }); + + console.log('addModelForm', addModelForm.values.location) + + const handleAddModelSubmit = (values: ExtendedImportModelConfig) => { + importMainModel({ source: values.location, config: undefined }) + .unwrap() + .then((_) => { + dispatch( + addToast( + makeToast({ + title: t('toast.modelAddedSimple'), + status: 'success', + }) + ) + ); + addModelForm.reset(); + }) + .catch((error) => { + if (error) { + dispatch( + addToast( + makeToast({ + title: `${error.data.detail} `, + status: 'error', + }) + ) + ); + } + }); + }; + return ( - - Import Models + + + Add Model + + + handleAddModelSubmit(v))} style={formStyles}> + + + + {t('modelManager.modelLocation')} + + + + + + + + {t('modelManager.importQueue')} + + ); }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueue.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueue.tsx new file mode 100644 index 00000000000..4a0407ff1a0 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ImportQueue.tsx @@ -0,0 +1,85 @@ +import { + Button, + Box, + Flex, + FormControl, + FormLabel, + Heading, + IconButton, + Input, + InputGroup, + InputRightElement, + Progress, + Text, +} from '@invoke-ai/ui-library'; +import { t } from 'i18next'; +import { useMemo } from 'react'; +import { useGetModelImportsQuery } from '../../../services/api/endpoints/models'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { addToast } from 'features/system/store/systemSlice'; +import { 
makeToast } from 'features/system/util/makeToast'; +import { PiXBold } from 'react-icons/pi'; + +export const ImportQueue = () => { + const dispatch = useAppDispatch(); + + // start with this data then pull from sockets (idk how to do that yet, also might not even use this and just use socket) + const { data } = useGetModelImportsQuery(); + + const progressValues = useMemo(() => { + if (!data) { + return []; + } + const values = []; + for (let i = 0; i < data.length; i++) { + let value; + if (data[i] && data[i]?.bytes && data[i]?.total_bytes) { + value = (data[i]?.bytes / data[i]?.total_bytes) * 100; + } + values.push(value || undefined); + } + return values; + }, [data]); + + return ( + + + {data?.map((model, i) => ( + + + {model.source.repo_id} + + + {model.status} + {model.status === 'completed' ? ( + } + // onClick={handleRemove} + /> + ) : ( + } + // onClick={handleCancel} + colorScheme="error" + /> + )} + + ))} + + + ); +}; From c8e84a7dfc4b83e629d6d4da7e9d29390d5fa8d4 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Wed, 21 Feb 2024 09:39:02 -0500 Subject: [PATCH 217/340] single model view --- .../ModelManagerPanel/ModelListItem.tsx | 11 ++ .../modelManagerV2/subpanels/ModelPane.tsx | 13 ++ .../subpanels/ModelPanel/ModelAttrView.tsx | 15 +++ .../subpanels/ModelPanel/ModelView.tsx | 116 ++++++++++++++++++ .../web/src/features/modelManagerV2/types.ts | 14 --- .../ui/components/tabs/ModelManagerTab.tsx | 29 +++-- .../web/src/services/api/endpoints/models.ts | 11 ++ .../frontend/web/src/services/api/schema.ts | 13 +- 8 files changed, 195 insertions(+), 27 deletions(-) create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPane.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelAttrView.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx delete mode 100644 invokeai/frontend/web/src/features/modelManagerV2/types.ts diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx index 5cc429ebcdc..d6cb70f4e89 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -3,10 +3,12 @@ import { Button, ConfirmationAlertDialog, Flex, + Icon, IconButton, Text, Tooltip, useDisclosure, + Box, } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice'; @@ -15,6 +17,7 @@ import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; +import { IoWarning } from 'react-icons/io5'; import { PiTrashSimpleBold } from 'react-icons/pi'; import { useDeleteModelsMutation } from 'services/api/endpoints/models'; import type { AnyModelConfig } from 'services/api/types'; @@ -88,8 +91,16 @@ const ModelListItem = (props: ModelListItemProps) => { {model.name} + {model.format === 'checkpoint' && ( + + + + + + )} + } diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPane.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPane.tsx new file mode 100644 index 
00000000000..9756b357b35 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPane.tsx @@ -0,0 +1,13 @@ +import { Box } from '@invoke-ai/ui-library'; +import { useAppSelector } from '../../../app/store/storeHooks'; +import { ImportModels } from './ImportModels'; +import { ModelView } from './ModelPanel/ModelView'; + +export const ModelPane = () => { + const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); + return ( + + {selectedModelKey ? : } + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelAttrView.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelAttrView.tsx new file mode 100644 index 00000000000..f45bfca993a --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelAttrView.tsx @@ -0,0 +1,15 @@ +import { FormControl, FormLabel, Text } from '@invoke-ai/ui-library'; + +interface Props { + label: string; + value: string | null | undefined; +} + +export const ModelAttrView = ({ label, value }: Props) => { + return ( + + {label} + {value || '-'} + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx new file mode 100644 index 00000000000..83e4e5380f9 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx @@ -0,0 +1,116 @@ +import { skipToken } from '@reduxjs/toolkit/query'; +import { useAppSelector } from '../../../../app/store/storeHooks'; +import { useGetModelQuery } from '../../../../services/api/endpoints/models'; +import { Flex, Text, Heading } from '@invoke-ai/ui-library'; +import DataViewer from '../../../gallery/components/ImageMetadataViewer/DataViewer'; +import { useMemo } from 'react'; +import { + CheckpointModelConfig, + ControlNetConfig, + DiffusersModelConfig, + IPAdapterConfig, + LoRAConfig, + T2IAdapterConfig, + TextualInversionConfig, + VAEConfig, +} from '../../../../services/api/types'; +import { ModelAttrView } from './ModelAttrView'; + +export const ModelView = () => { + const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); + const { data, isLoading } = useGetModelQuery(selectedModelKey ?? 
skipToken); + + const modelConfigData = useMemo(() => { + if (!data) { + return null; + } + const modelFormat = data.config.format; + const modelType = data.config.type; + + if (modelType === 'main') { + if (modelFormat === 'diffusers') { + return data.config as DiffusersModelConfig; + } else if (modelFormat === 'checkpoint') { + return data.config as CheckpointModelConfig; + } + } + + switch (modelType) { + case 'lora': + return data.config as LoRAConfig; + case 'embedding': + return data.config as TextualInversionConfig; + case 't2i_adapter': + return data.config as T2IAdapterConfig; + case 'ip_adapter': + return data.config as IPAdapterConfig; + case 'controlnet': + return data.config as ControlNetConfig; + case 'vae': + return data.config as VAEConfig; + default: + return null; + } + }, [data]); + + if (isLoading) { + return Loading; + } + + if (!modelConfigData) { + return Something went wrong; + } + return ( + + + + {modelConfigData.name} + + {modelConfigData.source && Source: {modelConfigData.source}} + + + + + + + + + + + + + + + {modelConfigData.type === 'main' && ( + <> + + {modelConfigData.format === 'diffusers' && ( + + )} + {modelConfigData.format === 'checkpoint' && ( + + )} + + + + + + + + + + + + + )} + {modelConfigData.type === 'ip_adapter' && ( + + + + )} + + + {!!data?.metadata && } + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/types.ts b/invokeai/frontend/web/src/features/modelManagerV2/types.ts deleted file mode 100644 index a209fbb8768..00000000000 --- a/invokeai/frontend/web/src/features/modelManagerV2/types.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { z } from "zod"; - -export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); -export const zModelType = z.enum([ - 'main', - 'vae', - 'lora', - 'controlnet', - 'embedding', - 'ip_adapter', - 'clip_vision', - 't2i_adapter', - 'onnx', // TODO(psyche): Remove this when removed from backend -]); \ No newline at end of file diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManagerTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManagerTab.tsx index 8117631f22e..9245f5c60d5 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManagerTab.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManagerTab.tsx @@ -1,16 +1,25 @@ -import { Box,Flex } from '@invoke-ai/ui-library'; -import { ImportModels } from 'features/modelManagerV2/subpanels/ImportModels'; -import { ModelManager } from 'features/modelManagerV2/subpanels/ModelManager'; -import { memo } from 'react'; +import { Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Box, Button } from '@invoke-ai/ui-library'; +import ImportModelsPanel from 'features/modelManager/subpanels/ImportModelsPanel'; +import MergeModelsPanel from 'features/modelManager/subpanels/MergeModelsPanel'; +import ModelManagerPanel from 'features/modelManager/subpanels/ModelManagerPanel'; +import ModelManagerSettingsPanel from 'features/modelManager/subpanels/ModelManagerSettingsPanel'; +import type { ReactNode } from 'react'; +import { memo, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { SyncModelsIconButton } from '../../../modelManager/components/SyncModels/SyncModelsIconButton'; +import { ModelManager } from '../../../modelManagerV2/subpanels/ModelManager'; +import { ModelPane } from '../../../modelManagerV2/subpanels/ModelPane'; + +type ModelManagerTabName = 'modelManager' | 'importModels' | 'mergeModels' | 'settings'; const 
ModelManagerTab = () => { + const { t } = useTranslation(); + return ( - - - - - - + + + + ); }; diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 6897768a1bd..5db2dea00e6 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -26,6 +26,10 @@ type UpdateModelArg = { type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; + +type GetModelResponse = + paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json']; + type ListModelsArg = NonNullable; type DeleteMainModelArg = { @@ -165,6 +169,12 @@ export const modelsApi = api.injectEndpoints({ providesTags: buildProvidesTags('MainModel'), transformResponse: buildTransformResponse(mainModelsAdapter), }), + getModel: build.query({ + query: (key) => { + return buildModelsUrl(`i/${key}`); + }, + providesTags: ['Model'], + }), updateModels: build.mutation({ query: ({ key, body }) => { return { @@ -320,4 +330,5 @@ export const { useGetModelsInFolderQuery, useGetCheckpointConfigsQuery, useGetModelImportsQuery, + useGetModelQuery } = modelsApi; diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 3566fdb9e28..c9690649f82 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -22,7 +22,7 @@ export type paths = { "/api/v2/models/i/{key}": { /** * Get Model Record - * @description Get a model record + * @description Get a model record and metadata */ get: operations["get_model_record"]; /** @@ -4202,6 +4202,13 @@ export type components = { */ type: "freeu"; }; + /** GetModelResponse */ + GetModelResponse: { + /** Config */ + config: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; + /** Metadata */ + metadata: (components["schemas"]["BaseMetadata"] | components["schemas"]["HuggingFaceMetadata"] | components["schemas"]["CivitaiMetadata"]) | null; + }; /** Graph */ Graph: { /** @@ -11169,7 +11176,7 @@ export type operations = { }; /** * Get Model Record - * @description Get a model record + * @description Get a model record and metadata */ get_model_record: { parameters: { @@ -11182,7 +11189,7 @@ export type operations = { /** @description The model configuration was retrieved successfully */ 200: { content: { - "application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | 
components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"]; + "application/json": components["schemas"]["GetModelResponse"]; }; }; /** @description Bad request */ From 05e098b492c5edc51d3b2ab51e8bafa199e9980f Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Wed, 21 Feb 2024 09:41:50 -0500 Subject: [PATCH 218/340] hook up Add Model button --- .../modelManagerV2/subpanels/ModelManager.tsx | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx index 648f98dbb3e..900a6a9342c 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx @@ -3,8 +3,16 @@ import { SyncModelsIconButton } from 'features/modelManager/components/SyncModel import ModelList from './ModelManagerPanel/ModelList'; import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation'; +import { useCallback } from 'react'; +import { useAppDispatch } from '../../../app/store/storeHooks'; +import { setSelectedModelKey } from '../store/modelManagerV2Slice'; export const ModelManager = () => { + const dispatch = useAppDispatch(); + const handleClickAddModel = useCallback(() => { + dispatch(setSelectedModelKey(null)); + }, [dispatch]); + return ( @@ -13,7 +21,9 @@ export const ModelManager = () => { - +
From e73fc5dce4ff1930525faa02aa5e9870b31a305b Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Wed, 21 Feb 2024 13:56:14 -0500 Subject: [PATCH 219/340] edit view for model, depending on type and valid values --- .../store/modelManagerV2Slice.ts | 8 +- .../modelManagerV2/subpanels/ModelPane.tsx | 4 +- .../ModelPanel/Fields/BaseModelSelect.tsx | 29 +++ .../ModelPanel/Fields/BooleanSelect.tsx | 27 +++ .../ModelPanel/Fields/ModelFormatSelect.tsx | 53 +++++ .../ModelPanel/Fields/ModelTypeSelect.tsx | 33 +++ .../ModelPanel/Fields/ModelVariantSelect.tsx | 27 +++ .../Fields/PredictionTypeSelect.tsx | 28 +++ .../ModelPanel/Fields/RepoVariantSelect.tsx | 30 +++ .../subpanels/ModelPanel/Model.tsx | 8 + .../subpanels/ModelPanel/ModelEdit.tsx | 196 ++++++++++++++++++ .../subpanels/ModelPanel/ModelView.tsx | 136 +++++++----- .../features/parameters/types/constants.ts | 6 +- .../web/src/services/api/endpoints/models.ts | 11 +- .../frontend/web/src/services/api/schema.ts | 17 +- 15 files changed, 537 insertions(+), 76 deletions(-) create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx diff --git a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts index 29b071e9b16..83bd0cf8d52 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts @@ -6,6 +6,7 @@ import type { PersistConfig, RootState } from 'app/store/store'; type ModelManagerState = { _version: 1; selectedModelKey: string | null; + selectedModelMode: "edit" | "view", searchTerm: string; filteredModelType: string | null; }; @@ -13,6 +14,7 @@ type ModelManagerState = { export const initialModelManagerState: ModelManagerState = { _version: 1, selectedModelKey: null, + selectedModelMode: "view", filteredModelType: null, searchTerm: "" }; @@ -22,8 +24,12 @@ export const modelManagerV2Slice = createSlice({ initialState: initialModelManagerState, reducers: { setSelectedModelKey: (state, action: PayloadAction) => { + state.selectedModelMode = "view" state.selectedModelKey = action.payload; }, + setSelectedModelMode: (state, action: PayloadAction<"view" | "edit">) => { + state.selectedModelMode = action.payload; + }, setSearchTerm: (state, action: PayloadAction) => { state.searchTerm = action.payload; }, @@ -34,7 +40,7 @@ export const modelManagerV2Slice = createSlice({ }, }); -export const { setSelectedModelKey, setSearchTerm, setFilteredModelType } = 
modelManagerV2Slice.actions; +export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode } = modelManagerV2Slice.actions; export const selectModelManagerSlice = (state: RootState) => state.modelmanager; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPane.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPane.tsx index 9756b357b35..7658e741d3f 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPane.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPane.tsx @@ -1,13 +1,13 @@ import { Box } from '@invoke-ai/ui-library'; import { useAppSelector } from '../../../app/store/storeHooks'; import { ImportModels } from './ImportModels'; -import { ModelView } from './ModelPanel/ModelView'; +import { Model } from './ModelPanel/Model'; export const ModelPane = () => { const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); return ( - {selectedModelKey ? : } + {selectedModelKey ? : } ); }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx new file mode 100644 index 00000000000..da7333c2a8b --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx @@ -0,0 +1,29 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { typedMemo } from 'common/util/typedMemo'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import type { AnyModelConfig } from 'services/api/types'; + +const options: ComboboxOption[] = [ + { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] }, + { value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] }, + { value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] }, + { value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] }, +]; + +const BaseModelSelect = (props: UseControllerProps) => { + const { field } = useController(props); + const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]); + const onChange = useCallback( + (v) => { + field.onChange(v?.value); + }, + [field] + ); + return ; +}; + +export default typedMemo(BaseModelSelect); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect.tsx new file mode 100644 index 00000000000..d21ee895316 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect.tsx @@ -0,0 +1,27 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { Combobox } from '@invoke-ai/ui-library'; +import { typedMemo } from 'common/util/typedMemo'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import type { AnyModelConfig } from 'services/api/types'; + +const options: ComboboxOption[] = [ + { value: 'none', label: '-' }, + { value: true as any, label: 'True' }, + { value: false as any, label: 'False' }, +]; + +const BooleanSelect = (props: 
UseControllerProps) => { + const { field } = useController(props); + const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]); + const onChange = useCallback( + (v) => { + v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value); + }, + [field] + ); + return ; +}; + +export default typedMemo(BooleanSelect); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect.tsx new file mode 100644 index 00000000000..0552789a867 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect.tsx @@ -0,0 +1,53 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { typedMemo } from 'common/util/typedMemo'; +import { LORA_MODEL_FORMAT_MAP, MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import type { AnyModelConfig } from 'services/api/types'; + +const options: ComboboxOption[] = [ + { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] }, + { value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] }, + { value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] }, + { value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] }, +]; + +const ModelFormatSelect = (props: UseControllerProps) => { + const { field, formState } = useController(props); + + const onChange = useCallback( + (v) => { + field.onChange(v?.value); + }, + [field] + ); + + const options: ComboboxOption[] = useMemo(() => { + if (formState.defaultValues?.type === 'lora') { + return Object.keys(LORA_MODEL_FORMAT_MAP).map((format) => ({ + value: format, + label: LORA_MODEL_FORMAT_MAP[format], + })) as ComboboxOption[]; + } else if (formState.defaultValues?.type === 'embedding') { + return [ + { value: 'embedding_file', label: 'Embedding File' }, + { value: 'embedding_folder', label: 'Embedding Folder' }, + ]; + } else if (formState.defaultValues?.type === 'ip_adapter') { + return [{ value: 'invokeai', label: 'invokeai' }]; + } else { + return [ + { value: 'diffusers', label: 'Diffusers' }, + { value: 'checkpoint', label: 'Checkpoint' }, + ]; + } + }, [formState.defaultValues?.type]); + + const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]); + + return ; +}; + +export default typedMemo(ModelFormatSelect); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx new file mode 100644 index 00000000000..140bfa9fe08 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx @@ -0,0 +1,33 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { typedMemo } from 'common/util/typedMemo'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import type { AnyModelConfig } from 'services/api/types'; +import { 
MODEL_TYPE_LABELS } from '../../ModelManagerPanel/ModelTypeFilter'; + +const options: ComboboxOption[] = [ + { value: 'main', label: MODEL_TYPE_LABELS['main'] as string }, + { value: 'lora', label: MODEL_TYPE_LABELS['lora'] as string }, + { value: 'embedding', label: MODEL_TYPE_LABELS['embedding'] as string }, + { value: 'vae', label: MODEL_TYPE_LABELS['vae'] as string }, + { value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string }, + { value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string }, + { value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string }, +]; + +const ModelTypeSelect = (props: UseControllerProps) => { + const { field } = useController(props); + const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]); + const onChange = useCallback( + (v) => { + field.onChange(v?.value); + }, + [field] + ); + return ; +}; + +export default typedMemo(ModelTypeSelect); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx new file mode 100644 index 00000000000..7fb74b0bd91 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx @@ -0,0 +1,27 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { typedMemo } from 'common/util/typedMemo'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import type { AnyModelConfig, CheckpointModelConfig, DiffusersModelConfig } from 'services/api/types'; + +const options: ComboboxOption[] = [ + { value: 'normal', label: 'Normal' }, + { value: 'inpaint', label: 'Inpaint' }, + { value: 'depth', label: 'Depth' }, +]; + +const ModelVariantSelect = (props: UseControllerProps) => { + const { field } = useController(props); + const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]); + const onChange = useCallback( + (v) => { + field.onChange(v?.value); + }, + [field] + ); + return ; +}; + +export default typedMemo(ModelVariantSelect); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect.tsx new file mode 100644 index 00000000000..20667ab5bc5 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect.tsx @@ -0,0 +1,28 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { Combobox } from '@invoke-ai/ui-library'; +import { typedMemo } from 'common/util/typedMemo'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import type { AnyModelConfig } from 'services/api/types'; + +const options: ComboboxOption[] = [ + { value: 'none', label: '-' }, + { value: 'epsilon', label: 'epsilon' }, + { value: 'v_prediction', label: 'v_prediction' }, + { value: 'sample', label: 'sample' }, +]; + +const PredictionTypeSelect = (props: UseControllerProps) => { + const { field } = useController(props); + const value = useMemo(() => options.find((o) => o.value === field.value), 
[field.value]); + const onChange = useCallback( + (v) => { + v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value); + }, + [field] + ); + return ; +}; + +export default typedMemo(PredictionTypeSelect); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect.tsx new file mode 100644 index 00000000000..74793be789f --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect.tsx @@ -0,0 +1,30 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { Combobox } from '@invoke-ai/ui-library'; +import { typedMemo } from 'common/util/typedMemo'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import type { AnyModelConfig } from 'services/api/types'; + +const options: ComboboxOption[] = [ + { value: 'none', label: '-' }, + { value: 'fp16', label: 'fp16' }, + { value: 'fp32', label: 'fp32' }, + { value: 'onnx', label: 'onnx' }, + { value: 'openvino', label: 'openvino' }, + { value: 'flax', label: 'flax' }, +]; + +const RepoVariantSelect = (props: UseControllerProps) => { + const { field } = useController(props); + const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]); + const onChange = useCallback( + (v) => { + v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value); + }, + [field] + ); + return ; +}; + +export default typedMemo(RepoVariantSelect); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx new file mode 100644 index 00000000000..8a6f7ddee4d --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx @@ -0,0 +1,8 @@ +import { useAppSelector } from '../../../../app/store/storeHooks'; +import { ModelEdit } from './ModelEdit'; +import { ModelView } from './ModelView'; + +export const Model = () => { + const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode); + return selectedModelMode === 'view' ? 
: ; +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx new file mode 100644 index 00000000000..70d0596cd8a --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx @@ -0,0 +1,196 @@ +import { skipToken } from '@reduxjs/toolkit/query'; +import { useAppDispatch, useAppSelector } from '../../../../app/store/storeHooks'; +import { useGetModelQuery } from '../../../../services/api/endpoints/models'; +import { Flex, Text, Heading, Button, Input, FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library'; +import { useCallback, useMemo } from 'react'; +import { + AnyModelConfig, + CheckpointModelConfig, + ControlNetConfig, + DiffusersModelConfig, + IPAdapterConfig, + LoRAConfig, + T2IAdapterConfig, + TextualInversionConfig, + VAEConfig, +} from '../../../../services/api/types'; +import { setSelectedModelMode } from '../../store/modelManagerV2Slice'; +import BaseModelSelect from './Fields/BaseModelSelect'; +import { useForm } from 'react-hook-form'; +import ModelTypeSelect from './Fields/ModelTypeSelect'; +import ModelVariantSelect from './Fields/ModelVariantSelect'; +import RepoVariantSelect from './Fields/RepoVariantSelect'; +import PredictionTypeSelect from './Fields/PredictionTypeSelect'; +import BooleanSelect from './Fields/BooleanSelect'; +import ModelFormatSelect from './Fields/ModelFormatSelect'; + +export const ModelEdit = () => { + const dispatch = useAppDispatch(); + const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); + const { data, isLoading } = useGetModelQuery(selectedModelKey ?? skipToken); + + const modelData = useMemo(() => { + if (!data) { + return null; + } + const modelFormat = data.format; + const modelType = data.type; + + if (modelType === 'main') { + if (modelFormat === 'diffusers') { + return data as DiffusersModelConfig; + } else if (modelFormat === 'checkpoint') { + return data as CheckpointModelConfig; + } + } + + switch (modelType) { + case 'lora': + return data as LoRAConfig; + case 'embedding': + return data as TextualInversionConfig; + case 't2i_adapter': + return data as T2IAdapterConfig; + case 'ip_adapter': + return data as IPAdapterConfig; + case 'controlnet': + return data as ControlNetConfig; + case 'vae': + return data as VAEConfig; + default: + return data as DiffusersModelConfig; + } + }, [data]); + + const { + register, + handleSubmit, + control, + formState: { errors }, + reset, + } = useForm({ + defaultValues: { + ...modelData, + }, + mode: 'onChange', + }); + + const handleClickCancel = useCallback(() => { + dispatch(setSelectedModelMode('view')); + }, [dispatch]); + + if (isLoading) { + return Loading; + } + + if (!modelData) { + return Something went wrong; + } + return ( + + + value.trim().length > 3 || 'Must be at least 3 characters', + })} + size="lg" + /> + + + + + + + + + + Description +
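
The Fields components introduced in this patch all share one react-hook-form shape: `useController` binds a Combobox to a single key of the `AnyModelConfig` form, the raw form value is mapped back to its `ComboboxOption` for display, and optional fields translate a 'none' sentinel to `undefined`. A minimal sketch of that shared shape follows; the field shown ('variant') is just an example, and the Combobox props (value/options/onChange) are assumed from the ui-library's select API rather than taken from this diff.

```tsx
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';

const options: ComboboxOption[] = [
  { value: 'normal', label: 'Normal' },
  { value: 'inpaint', label: 'Inpaint' },
  { value: 'depth', label: 'Depth' },
];

// Shared shape of the ModelPanel field components: useController wires a
// Combobox to one property of the AnyModelConfig form state.
const ExampleVariantSelect = (props: UseControllerProps<AnyModelConfig>) => {
  const { field } = useController(props);

  // Map the raw form value back to its ComboboxOption for display.
  const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);

  const onChange = useCallback<ComboboxOnChange>(
    (v) => {
      field.onChange(v?.value);
    },
    [field]
  );

  // Prop names on Combobox (value/options/onChange) are assumptions.
  return (
    <FormControl>
      <FormLabel>Variant</FormLabel>
      <Combobox value={value} options={options} onChange={onChange} />
    </FormControl>
  );
};

export default typedMemo(ExampleVariantSelect);
```

Inside ModelEdit a field like this would be rendered with the form's `control` and the config key it edits, e.g. `<ExampleVariantSelect control={control} name="variant" />`, alongside the register-based text inputs such as the validated name field above.
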