diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index ad16812f9ab..17bedea6d3a 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -1,7 +1,3 @@ -from typing import Any - -from fastapi.responses import HTMLResponse - from .services.config import InvokeAIAppConfig # parse_args() must be called before any other imports. if it is not called first, consumers of the config @@ -16,6 +12,7 @@ import socket from inspect import signature from pathlib import Path + from typing import Any import uvicorn from fastapi import FastAPI @@ -23,7 +20,7 @@ from fastapi.middleware.gzip import GZipMiddleware from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.openapi.utils import get_openapi - from fastapi.responses import FileResponse + from fastapi.responses import FileResponse, HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware @@ -51,7 +48,12 @@ workflows, ) from .api.sockets import SocketIO - from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField + from .invocations.baseinvocation import ( + BaseInvocation, + InputFieldJSONSchemaExtra, + OutputFieldJSONSchemaExtra, + UIConfigBase, + ) if is_mps_available(): import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) @@ -147,7 +149,11 @@ def custom_openapi() -> dict[str, Any]: # Add Node Editor UI helper schemas ui_config_schemas = models_json_schema( - [(UIConfigBase, "serialization"), (_InputField, "serialization"), (_OutputField, "serialization")], + [ + (UIConfigBase, "serialization"), + (InputFieldJSONSchemaExtra, "serialization"), + (OutputFieldJSONSchemaExtra, "serialization"), + ], ref_template="#/components/schemas/{model}", ) for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items(): @@ -155,7 +161,7 @@ def custom_openapi() -> dict[str, Any]: # Add a reference to the output type to additionalProperties of the invoker schema for invoker in all_invocations: - invoker_name = invoker.__name__ + invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute output_type = signature(obj=invoker.invoke).return_annotation output_type_title = output_type_titles[output_type.__name__] invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"] diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 1b3e535d340..51849dcbccc 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI team from __future__ import annotations @@ -8,7 +8,7 @@ from enum import Enum from inspect import signature from types import UnionType -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast import semver from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model @@ -17,11 +17,17 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.misc import uuid_string +from invokeai.backend.util.logging import InvokeAILogger if TYPE_CHECKING: from ..services.invocation_services import InvocationServices +logger = InvokeAILogger.get_logger() + +CUSTOM_NODE_PACK_SUFFIX = "__invokeai-custom-node" + class InvalidVersionError(ValueError): pass @@ -31,7 +37,7 @@ class InvalidFieldError(TypeError): pass -class Input(str, Enum): +class Input(str, Enum, metaclass=MetaEnum): """ The type of input a field accepts. - `Input.Direct`: The field must have its value provided directly, when the invocation and field \ @@ -45,86 +51,124 @@ class Input(str, Enum): Any = "any" -class UIType(str, Enum): +class FieldKind(str, Enum, metaclass=MetaEnum): """ - Type hints for the UI. - If a field should be provided a data type that does not exactly match the python type of the field, \ - use this to provide the type that should be used instead. See the node development docs for detail \ - on adding a new field type, which involves client-side changes. + The kind of field. + - `Input`: An input field on a node. + - `Output`: An output field on a node. + - `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is + one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name + "metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic, + allowing "metadata" for that field. + - `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs, + but which are used to store information about the node. For example, the `id` and `type` fields are node + attributes. + + The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app + startup, and when generating the OpenAPI schema for the workflow editor. """ - # region Primitives - Boolean = "boolean" - Color = "ColorField" - Conditioning = "ConditioningField" - Control = "ControlField" - Float = "float" - Image = "ImageField" - Integer = "integer" - Latents = "LatentsField" - String = "string" - # endregion + Input = "input" + Output = "output" + Internal = "internal" + NodeAttribute = "node_attribute" - # region Collection Primitives - BooleanCollection = "BooleanCollection" - ColorCollection = "ColorCollection" - ConditioningCollection = "ConditioningCollection" - ControlCollection = "ControlCollection" - FloatCollection = "FloatCollection" - ImageCollection = "ImageCollection" - IntegerCollection = "IntegerCollection" - LatentsCollection = "LatentsCollection" - StringCollection = "StringCollection" - # endregion - # region Polymorphic Primitives - BooleanPolymorphic = "BooleanPolymorphic" - ColorPolymorphic = "ColorPolymorphic" - ConditioningPolymorphic = "ConditioningPolymorphic" - ControlPolymorphic = "ControlPolymorphic" - FloatPolymorphic = "FloatPolymorphic" - ImagePolymorphic = "ImagePolymorphic" - IntegerPolymorphic = "IntegerPolymorphic" - LatentsPolymorphic = "LatentsPolymorphic" - StringPolymorphic = "StringPolymorphic" - # endregion +class UIType(str, Enum, metaclass=MetaEnum): + """ + Type hints for the UI for situations in which the field type is not enough to infer the correct UI type. + + - Model Fields + The most common node-author-facing use will be for model fields. Internally, there is no difference + between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the + base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that + the field is an SDXL main model field. + + - Any Field + We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to + indicate that the field accepts any type. Use with caution. This cannot be used on outputs. + + - Scheduler Field + Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field. - # region Models - MainModel = "MainModelField" + - Internal Fields + Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate + handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These + should not be used by node authors. + + - DEPRECATED Fields + These types are deprecated and should not be used by node authors. A warning will be logged if one is + used, and the type will be ignored. They are included here for backwards compatibility. + """ + + # region Model Field Types SDXLMainModel = "SDXLMainModelField" SDXLRefinerModel = "SDXLRefinerModelField" ONNXModel = "ONNXModelField" - VaeModel = "VaeModelField" + VaeModel = "VAEModelField" LoRAModel = "LoRAModelField" ControlNetModel = "ControlNetModelField" IPAdapterModel = "IPAdapterModelField" - UNet = "UNetField" - Vae = "VaeField" - CLIP = "ClipField" # endregion - # region Iterate/Collect - Collection = "Collection" - CollectionItem = "CollectionItem" + # region Misc Field Types + Scheduler = "SchedulerField" + Any = "AnyField" + # endregion + + # region Internal Field Types + _Collection = "CollectionField" + _CollectionItem = "CollectionItemField" # endregion - # region Misc - Enum = "enum" - Scheduler = "Scheduler" - WorkflowField = "WorkflowField" - IsIntermediate = "IsIntermediate" - BoardField = "BoardField" - Any = "Any" - MetadataItem = "MetadataItem" - MetadataItemCollection = "MetadataItemCollection" - MetadataItemPolymorphic = "MetadataItemPolymorphic" - MetadataDict = "MetadataDict" + # region DEPRECATED + Boolean = "DEPRECATED_Boolean" + Color = "DEPRECATED_Color" + Conditioning = "DEPRECATED_Conditioning" + Control = "DEPRECATED_Control" + Float = "DEPRECATED_Float" + Image = "DEPRECATED_Image" + Integer = "DEPRECATED_Integer" + Latents = "DEPRECATED_Latents" + String = "DEPRECATED_String" + BooleanCollection = "DEPRECATED_BooleanCollection" + ColorCollection = "DEPRECATED_ColorCollection" + ConditioningCollection = "DEPRECATED_ConditioningCollection" + ControlCollection = "DEPRECATED_ControlCollection" + FloatCollection = "DEPRECATED_FloatCollection" + ImageCollection = "DEPRECATED_ImageCollection" + IntegerCollection = "DEPRECATED_IntegerCollection" + LatentsCollection = "DEPRECATED_LatentsCollection" + StringCollection = "DEPRECATED_StringCollection" + BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic" + ColorPolymorphic = "DEPRECATED_ColorPolymorphic" + ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic" + ControlPolymorphic = "DEPRECATED_ControlPolymorphic" + FloatPolymorphic = "DEPRECATED_FloatPolymorphic" + ImagePolymorphic = "DEPRECATED_ImagePolymorphic" + IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic" + LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic" + StringPolymorphic = "DEPRECATED_StringPolymorphic" + MainModel = "DEPRECATED_MainModel" + UNet = "DEPRECATED_UNet" + Vae = "DEPRECATED_Vae" + CLIP = "DEPRECATED_CLIP" + Collection = "DEPRECATED_Collection" + CollectionItem = "DEPRECATED_CollectionItem" + Enum = "DEPRECATED_Enum" + WorkflowField = "DEPRECATED_WorkflowField" + IsIntermediate = "DEPRECATED_IsIntermediate" + BoardField = "DEPRECATED_BoardField" + MetadataItem = "DEPRECATED_MetadataItem" + MetadataItemCollection = "DEPRECATED_MetadataItemCollection" + MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic" + MetadataDict = "DEPRECATED_MetadataDict" # endregion -class UIComponent(str, Enum): +class UIComponent(str, Enum, metaclass=MetaEnum): """ - The type of UI component to use for a field, used to override the default components, which are \ + The type of UI component to use for a field, used to override the default components, which are inferred from the field type. """ @@ -133,21 +177,22 @@ class UIComponent(str, Enum): Slider = "slider" -class _InputField(BaseModel): +class InputFieldJSONSchemaExtra(BaseModel): """ - *DO NOT USE* - This helper class is used to tell the client about our custom field attributes via OpenAPI - schema generation, and Typescript type generation from that schema. It serves no functional - purpose in the backend. + Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution, + and by the workflow editor during schema parsing and UI rendering. """ input: Input - ui_hidden: bool - ui_type: Optional[UIType] - ui_component: Optional[UIComponent] - ui_order: Optional[int] - ui_choice_labels: Optional[dict[str, str]] - item_default: Optional[Any] + orig_required: bool + field_kind: FieldKind + default: Optional[Any] = None + orig_default: Optional[Any] = None + ui_hidden: bool = False + ui_type: Optional[UIType] = None + ui_component: Optional[UIComponent] = None + ui_order: Optional[int] = None + ui_choice_labels: Optional[dict[str, str]] = None model_config = ConfigDict( validate_assignment=True, @@ -155,14 +200,13 @@ class _InputField(BaseModel): ) -class _OutputField(BaseModel): +class OutputFieldJSONSchemaExtra(BaseModel): """ - *DO NOT USE* - This helper class is used to tell the client about our custom field attributes via OpenAPI - schema generation, and Typescript type generation from that schema. It serves no functional - purpose in the backend. + Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor + during schema parsing and UI rendering. """ + field_kind: FieldKind ui_hidden: bool ui_type: Optional[UIType] ui_order: Optional[int] @@ -173,13 +217,9 @@ class _OutputField(BaseModel): ) -def get_type(klass: BaseModel) -> str: - """Helper function to get an invocation or invocation output's type. This is the default value of the `type` field.""" - return klass.model_fields["type"].default - - def InputField( # copied from pydantic's Field + # TODO: Can we support default_factory? default: Any = _Unset, default_factory: Callable[[], Any] | None = _Unset, title: str | None = _Unset, @@ -203,12 +243,11 @@ def InputField( ui_hidden: bool = False, ui_order: Optional[int] = None, ui_choice_labels: Optional[dict[str, str]] = None, - item_default: Optional[Any] = None, ) -> Any: """ Creates an input field for an invocation. - This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \ that adds a few extra parameters to support graph execution and the node editor UI. :param Input input: [Input.Any] The kind of input this field requires. \ @@ -228,28 +267,59 @@ def InputField( For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \ For this case, you could provide `UIComponent.Textarea`. - : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. - : param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. - : param bool item_default: [None] Specifies the default item value, if this is a collection input. \ - Ignored for non-collection fields. + :param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field. """ - json_schema_extra_: dict[str, Any] = { - "input": input, - "ui_type": ui_type, - "ui_component": ui_component, - "ui_hidden": ui_hidden, - "ui_order": ui_order, - "item_default": item_default, - "ui_choice_labels": ui_choice_labels, - "_field_kind": "input", - } + json_schema_extra_ = InputFieldJSONSchemaExtra( + input=input, + ui_type=ui_type, + ui_component=ui_component, + ui_hidden=ui_hidden, + ui_order=ui_order, + ui_choice_labels=ui_choice_labels, + field_kind=FieldKind.Input, + orig_required=True, + ) + + """ + There is a conflict between the typing of invocation definitions and the typing of an invocation's + `invoke()` function. + + On instantiation of a node, the invocation definition is used to create the python class. At this time, + any number of fields may be optional, because they may be provided by connections. + + On calling of `invoke()`, however, those fields may be required. + + For example, consider an ResizeImageInvocation with an `image: ImageField` field. + + `image` is required during the call to `invoke()`, but when the python class is instantiated, + the field may not be present. This is fine, because that image field will be provided by a + connection from an ancestor node, which outputs an image. + + This means we want to type the `image` field as optional for the node class definition, but required + for the `invoke()` function. + + If we use `typing.Optional` in the node class definition, the field will be typed as optional in the + `invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or + any static type analysis tools will complain. + + To get around this, in node class definitions, we type all fields correctly for the `invoke()` function, + but secretly make them optional in `InputField()`. We also store the original required bool and/or default + value. When we call `invoke()`, we use this stored information to do an additional check on the class. + """ + + if default_factory is not _Unset and default_factory is not None: + default = default_factory() + del default_factory + logger.warn('"default_factory" is not supported, calling it now to set "default"') + # These are the args we may wish pass to the pydantic `Field()` function field_args = { "default": default, - "default_factory": default_factory, "title": title, "description": description, "pattern": pattern, @@ -266,70 +336,34 @@ def InputField( "max_length": max_length, } - """ - Invocation definitions have their fields typed correctly for their `invoke()` functions. - This typing is often more specific than the actual invocation definition requires, because - fields may have values provided only by connections. - - For example, consider an ResizeImageInvocation with an `image: ImageField` field. - - `image` is required during the call to `invoke()`, but when the python class is instantiated, - the field may not be present. This is fine, because that image field will be provided by a - an ancestor node that outputs the image. - - So we'd like to type that `image` field as `Optional[ImageField]`. If we do that, however, then - we need to handle a lot of extra logic in the `invoke()` function to check if the field has a - value or not. This is very tedious. - - Ideally, the invocation definition would be able to specify that the field is required during - invocation, but optional during instantiation. So the field would be typed as `image: ImageField`, - but when calling the `invoke()` function, we raise an error if the field is not present. - - To do this, we need to do a bit of fanagling to make the pydantic field optional, and then do - extra validation when calling `invoke()`. - - There is some additional logic here to cleaning create the pydantic field via the wrapper. - """ - - # Filter out field args not provided + # We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined} - if (default is not PydanticUndefined) and (default_factory is not PydanticUndefined): - raise ValueError("Cannot specify both default and default_factory") - - # because we are manually making fields optional, we need to store the original required bool for reference later - if default is PydanticUndefined and default_factory is PydanticUndefined: - json_schema_extra_.update({"orig_required": True}) - else: - json_schema_extra_.update({"orig_required": False}) + # Because we are manually making fields optional, we need to store the original required bool for reference later + json_schema_extra_.orig_required = default is PydanticUndefined - # make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one - if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined: + # Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one + if input is Input.Any or input is Input.Connection: default_ = None if default is PydanticUndefined else default provided_args.update({"default": default_}) if default is not PydanticUndefined: - # before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value - json_schema_extra_.update({"default": default}) - json_schema_extra_.update({"orig_default": default}) - elif default is not PydanticUndefined and default_factory is PydanticUndefined: + # Before invoking, we'll check for the original default value and set it on the field if the field has no value + json_schema_extra_.default = default + json_schema_extra_.orig_default = default + elif default is not PydanticUndefined: default_ = default provided_args.update({"default": default_}) - json_schema_extra_.update({"orig_default": default_}) - elif default_factory is not PydanticUndefined: - provided_args.update({"default_factory": default_factory}) - # TODO: cannot serialize default_factory... - # json_schema_extra_.update(dict(orig_default_factory=default_factory)) + json_schema_extra_.orig_default = default_ return Field( **provided_args, - json_schema_extra=json_schema_extra_, + json_schema_extra=json_schema_extra_.model_dump(exclude_none=True), ) def OutputField( # copied from pydantic's Field default: Any = _Unset, - default_factory: Callable[[], Any] | None = _Unset, title: str | None = _Unset, description: str | None = _Unset, pattern: str | None = _Unset, @@ -362,13 +396,12 @@ def OutputField( `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. - : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ - : param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ """ return Field( default=default, - default_factory=default_factory, title=title, description=description, pattern=pattern, @@ -383,12 +416,12 @@ def OutputField( decimal_places=decimal_places, min_length=min_length, max_length=max_length, - json_schema_extra={ - "ui_type": ui_type, - "ui_hidden": ui_hidden, - "ui_order": ui_order, - "_field_kind": "output", - }, + json_schema_extra=OutputFieldJSONSchemaExtra( + ui_type=ui_type, + ui_hidden=ui_hidden, + ui_order=ui_order, + field_kind=FieldKind.Output, + ).model_dump(exclude_none=True), ) @@ -401,10 +434,10 @@ class UIConfigBase(BaseModel): tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags") title: Optional[str] = Field(default=None, description="The node's display name") category: Optional[str] = Field(default=None, description="The node's category") - version: Optional[str] = Field( - default=None, + version: str = Field( description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".', ) + node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node") model_config = ConfigDict( validate_assignment=True, @@ -447,29 +480,39 @@ class BaseInvocationOutput(BaseModel): @classmethod def register_output(cls, output: BaseInvocationOutput) -> None: + """Registers an invocation output.""" cls._output_classes.add(output) @classmethod def get_outputs(cls) -> Iterable[BaseInvocationOutput]: + """Gets all invocation outputs.""" return cls._output_classes @classmethod def get_outputs_union(cls) -> UnionType: + """Gets a union of all invocation outputs.""" outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type] return outputs_union # type: ignore [return-value] @classmethod def get_output_types(cls) -> Iterable[str]: - return (get_type(i) for i in BaseInvocationOutput.get_outputs()) + """Gets all invocation output types.""" + return (i.get_type() for i in BaseInvocationOutput.get_outputs()) @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: + """Adds various UI-facing attributes to the invocation output's OpenAPI schema.""" # Because we use a pydantic Literal field with default value for the invocation type, # it will be typed as optional in the OpenAPI schema. Make it required manually. if "required" not in schema or not isinstance(schema["required"], list): schema["required"] = [] schema["required"].extend(["type"]) + @classmethod + def get_type(cls) -> str: + """Gets the invocation output's type, as provided by the `@invocation_output` decorator.""" + return cls.model_fields["type"].default + model_config = ConfigDict( protected_namespaces=(), validate_assignment=True, @@ -499,21 +542,29 @@ class BaseInvocation(ABC, BaseModel): _invocation_classes: ClassVar[set[BaseInvocation]] = set() + @classmethod + def get_type(cls) -> str: + """Gets the invocation's type, as provided by the `@invocation` decorator.""" + return cls.model_fields["type"].default + @classmethod def register_invocation(cls, invocation: BaseInvocation) -> None: + """Registers an invocation.""" cls._invocation_classes.add(invocation) @classmethod def get_invocations_union(cls) -> UnionType: + """Gets a union of all invocation types.""" invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type] return invocations_union # type: ignore [return-value] @classmethod def get_invocations(cls) -> Iterable[BaseInvocation]: + """Gets all invocations, respecting the allowlist and denylist.""" app_config = InvokeAIAppConfig.get_config() allowed_invocations: set[BaseInvocation] = set() for sc in cls._invocation_classes: - invocation_type = get_type(sc) + invocation_type = sc.get_type() is_in_allowlist = ( invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True ) @@ -526,28 +577,32 @@ def get_invocations(cls) -> Iterable[BaseInvocation]: @classmethod def get_invocations_map(cls) -> dict[str, BaseInvocation]: - # Get the type strings out of the literals and into a dictionary - return {get_type(i): i for i in BaseInvocation.get_invocations()} + """Gets a map of all invocation types to their invocation classes.""" + return {i.get_type(): i for i in BaseInvocation.get_invocations()} @classmethod def get_invocation_types(cls) -> Iterable[str]: - return (get_type(i) for i in BaseInvocation.get_invocations()) + """Gets all invocation types.""" + return (i.get_type() for i in BaseInvocation.get_invocations()) @classmethod - def get_output_type(cls) -> BaseInvocationOutput: + def get_output_annotation(cls) -> BaseInvocationOutput: + """Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method).""" return signature(cls.invoke).return_annotation @staticmethod - def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: - # Add the various UI-facing attributes to the schema. These are used to build the invocation templates. - uiconfig = getattr(model_class, "UIConfig", None) - if uiconfig and hasattr(uiconfig, "title"): - schema["title"] = uiconfig.title - if uiconfig and hasattr(uiconfig, "tags"): - schema["tags"] = uiconfig.tags - if uiconfig and hasattr(uiconfig, "category"): - schema["category"] = uiconfig.category - if uiconfig and hasattr(uiconfig, "version"): + def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None: + """Adds various UI-facing attributes to the invocation's OpenAPI schema.""" + uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None)) + if uiconfig is not None: + if uiconfig.title is not None: + schema["title"] = uiconfig.title + if uiconfig.tags is not None: + schema["tags"] = uiconfig.tags + if uiconfig.category is not None: + schema["category"] = uiconfig.category + if uiconfig.node_pack is not None: + schema["node_pack"] = uiconfig.node_pack schema["version"] = uiconfig.version if "required" not in schema or not isinstance(schema["required"], list): schema["required"] = [] @@ -559,6 +614,10 @@ def invoke(self, context: InvocationContext) -> BaseInvocationOutput: pass def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: + """ + Internal invoke method, calls `invoke()` after some prep. + Handles optional fields that are required to call `invoke()` and invocation cache. + """ for field_name, field in self.model_fields.items(): if not field.json_schema_extra or callable(field.json_schema_extra): # something has gone terribly awry, we should always have this and it should be a dict @@ -598,21 +657,20 @@ def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}') return self.invoke(context) - def get_type(self) -> str: - return self.model_fields["type"].default - id: str = Field( default_factory=uuid_string, description="The id of this instance of an invocation. Must be unique among all instances of invocations.", - json_schema_extra={"_field_kind": "internal"}, + json_schema_extra={"field_kind": FieldKind.NodeAttribute}, ) is_intermediate: bool = Field( default=False, description="Whether or not this is an intermediate invocation.", - json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"}, + json_schema_extra={"ui_type": "IsIntermediate", "field_kind": FieldKind.NodeAttribute}, ) use_cache: bool = Field( - default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"} + default=True, + description="Whether or not to use the cache", + json_schema_extra={"field_kind": FieldKind.NodeAttribute}, ) UIConfig: ClassVar[Type[UIConfigBase]] @@ -629,12 +687,15 @@ def get_type(self) -> str: TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation) -RESERVED_INPUT_FIELD_NAMES = { +RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = { "id", "is_intermediate", "use_cache", "type", "workflow", +} + +RESERVED_INPUT_FIELD_NAMES = { "metadata", } @@ -652,40 +713,59 @@ class _Model(BaseModel): def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None: """ Validates the fields of an invocation or invocation output: - - must not override any pydantic reserved fields - - must be created via `InputField`, `OutputField`, or be an internal field defined in this file + - Must not override any pydantic reserved fields + - Must have a type annotation + - Must have a json_schema_extra dict + - Must have field_kind in json_schema_extra + - Field name must not be reserved, according to its field_kind """ for name, field in model_fields.items(): if name in RESERVED_PYDANTIC_FIELD_NAMES: raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)') - field_kind = ( - # _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file - field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None - ) + if not field.annotation: + raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)') + + if not isinstance(field.json_schema_extra, dict): + raise InvalidFieldError( + f'Invalid field definition for "{name}" on "{model_type}" (missing json_schema_extra dict)' + ) + + field_kind = field.json_schema_extra.get("field_kind", None) # must have a field_kind - if field_kind is None or field_kind not in {"input", "output", "internal"}: + if not isinstance(field_kind, FieldKind): raise InvalidFieldError( f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)' ) - if field_kind == "input" and name in RESERVED_INPUT_FIELD_NAMES: + if field_kind is FieldKind.Input and ( + name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES + ): raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)') - if field_kind == "output" and name in RESERVED_OUTPUT_FIELD_NAMES: + if field_kind is FieldKind.Output and name in RESERVED_OUTPUT_FIELD_NAMES: raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)') - # internal fields *must* be in the reserved list + if (field_kind is FieldKind.Internal) and name not in RESERVED_INPUT_FIELD_NAMES: + raise InvalidFieldError( + f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)' + ) + + # node attribute fields *must* be in the reserved list if ( - field_kind == "internal" - and name not in RESERVED_INPUT_FIELD_NAMES + field_kind is FieldKind.NodeAttribute + and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES and name not in RESERVED_OUTPUT_FIELD_NAMES ): raise InvalidFieldError( - f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)' + f'Invalid field name "{name}" on "{model_type}" (node attribute field without reserved name)' ) + ui_type = field.json_schema_extra.get("ui_type", None) + if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"): + logger.warn(f"\"UIType.{ui_type.split('_')[-1]}\" is deprecated, ignoring") + field.json_schema_extra.pop("ui_type") return None @@ -720,21 +800,30 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]: validate_fields(cls.model_fields, invocation_type) # Add OpenAPI schema extras - uiconf_name = cls.__qualname__ + ".UIConfig" - if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name: - cls.UIConfig = type(uiconf_name, (UIConfigBase,), {}) - if title is not None: - cls.UIConfig.title = title - if tags is not None: - cls.UIConfig.tags = tags - if category is not None: - cls.UIConfig.category = category + uiconfig_name = cls.__qualname__ + ".UIConfig" + if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name: + cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {}) + cls.UIConfig.title = title + cls.UIConfig.tags = tags + cls.UIConfig.category = category + + # Grab the node pack's name from the module name, if it's a custom node + module_name = cls.__module__.split(".")[0] + if module_name.endswith(CUSTOM_NODE_PACK_SUFFIX): + cls.UIConfig.node_pack = module_name.split(CUSTOM_NODE_PACK_SUFFIX)[0] + else: + cls.UIConfig.node_pack = None + if version is not None: try: semver.Version.parse(version) except ValueError as e: raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e cls.UIConfig.version = version + else: + logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"') + cls.UIConfig.version = "1.0.0" + if use_cache is not None: cls.model_fields["use_cache"].default = use_cache @@ -749,7 +838,7 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]: invocation_type_annotation = Literal[invocation_type] # type: ignore invocation_type_field = Field( - title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"} + title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute} ) docstring = cls.__doc__ @@ -795,7 +884,9 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]: # Add the output type to the model. output_type_annotation = Literal[output_type] # type: ignore - output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"}) + output_type_field = Field( + title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute} + ) docstring = cls.__doc__ cls = create_model( @@ -827,7 +918,7 @@ class WorkflowField(RootModel): class WithWorkflow(BaseModel): workflow: Optional[WorkflowField] = Field( - default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"} + default=None, description=FieldDescriptions.workflow, json_schema_extra={"field_kind": FieldKind.NodeAttribute} ) @@ -845,5 +936,11 @@ class MetadataField(RootModel): class WithMetadata(BaseModel): metadata: Optional[MetadataField] = Field( - default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"} + default=None, + description=FieldDescriptions.metadata, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Connection, + orig_required=False, + ).model_dump(exclude_none=True), ) diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index f26eebe1ff5..4c7b6f94cd4 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -5,7 +5,7 @@ from pydantic import ValidationInfo, field_validator from invokeai.app.invocations.primitives import IntegerCollectionOutput -from invokeai.app.util.misc import SEED_MAX, get_random_seed +from invokeai.app.util.misc import SEED_MAX from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation @@ -55,7 +55,7 @@ def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: title="Random Range", tags=["range", "integer", "random", "collection"], category="collections", - version="1.0.0", + version="1.0.1", use_cache=False, ) class RandomRangeInvocation(BaseInvocation): @@ -65,10 +65,10 @@ class RandomRangeInvocation(BaseInvocation): high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value") size: int = InputField(default=1, description="The number of values to generate") seed: int = InputField( + default=0, ge=0, le=SEED_MAX, description="The seed for the RNG (omit for random)", - default_factory=get_random_seed, ) def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: diff --git a/invokeai/app/invocations/custom_nodes/init.py b/invokeai/app/invocations/custom_nodes/init.py index c6708e95a7f..e26bdaf5683 100644 --- a/invokeai/app/invocations/custom_nodes/init.py +++ b/invokeai/app/invocations/custom_nodes/init.py @@ -6,6 +6,7 @@ from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path +from invokeai.app.invocations.baseinvocation import CUSTOM_NODE_PACK_SUFFIX from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.get_logger() @@ -32,13 +33,15 @@ if module_name in globals(): continue - # we have a legit module to import - spec = spec_from_file_location(module_name, init.absolute()) + # load the module, appending adding a suffix to identify it as a custom node pack + spec = spec_from_file_location(f"{module_name}{CUSTOM_NODE_PACK_SUFFIX}", init.absolute()) if spec is None or spec.loader is None: logger.warn(f"Could not load {init}") continue + logger.info(f"Loading node pack {module_name}") + module = module_from_spec(spec) sys.modules[spec.name] = module spec.loader.exec_module(module) @@ -47,5 +50,5 @@ del init, module_name - -logger.info(f"Loaded {loaded_count} modules from {Path(__file__).parent}") +if loaded_count > 0: + logger.info(f"Loaded {loaded_count} node packs from {Path(__file__).parent}") diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 9905aa1b5ec..0822a4ce2df 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -8,7 +8,7 @@ from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.util.misc import SEED_MAX, get_random_seed +from invokeai.app.util.misc import SEED_MAX from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch @@ -154,17 +154,17 @@ def invoke(self, context: InvocationContext) -> ImageOutput: ) -@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0") +@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.1") class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata): """Infills transparent areas of an image with tiles of the image""" image: ImageField = InputField(description="The image to infill") tile_size: int = InputField(default=32, ge=1, description="The tile size (px)") seed: int = InputField( + default=0, ge=0, le=SEED_MAX, description="The seed to use for tile generation (omit for random)", - default_factory=get_random_seed, ) def invoke(self, context: InvocationContext) -> ImageOutput: diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 485932e18dd..e0f582eab82 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -11,7 +11,6 @@ InputField, InvocationContext, OutputField, - UIType, invocation, invocation_output, ) @@ -67,7 +66,7 @@ class IPAdapterInvocation(BaseInvocation): # weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float) weight: Union[float, List[float]] = InputField( - default=1, ge=-1, description="The weight given to the IP-Adapter", ui_type=UIType.Float, title="Weight" + default=1, ge=-1, description="The weight given to the IP-Adapter", title="Weight" ) begin_step_percent: float = InputField( diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 9d4afb70204..d438bcae02e 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -274,7 +274,10 @@ class DenoiseLatentsInvocation(BaseInvocation): ui_order=7, ) latents: Optional[LatentsField] = InputField( - default=None, description=FieldDescriptions.latents, input=Input.Connection + default=None, + description=FieldDescriptions.latents, + input=Input.Connection, + ui_order=4, ) denoise_mask: Optional[DenoiseMaskField] = InputField( default=None, diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 8cce9bdb881..99dcc72999b 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -14,7 +14,6 @@ InputField, InvocationContext, OutputField, - UIType, invocation, invocation_output, ) @@ -395,7 +394,6 @@ class VaeLoaderInvocation(BaseInvocation): vae_model: VAEModelField = InputField( description=FieldDescriptions.vae_model, input=Input.Direct, - ui_type=UIType.VaeModel, title="VAE", ) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index e975b7bf22b..b1ee91e1cdf 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -6,7 +6,7 @@ from invokeai.app.invocations.latent import LatentsField from invokeai.app.shared.fields import FieldDescriptions -from invokeai.app.util.misc import SEED_MAX, get_random_seed +from invokeai.app.util.misc import SEED_MAX from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import ( @@ -83,16 +83,16 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): title="Noise", tags=["latents", "noise"], category="latents", - version="1.0.0", + version="1.0.1", ) class NoiseInvocation(BaseInvocation): """Generates latent noise.""" seed: int = InputField( + default=0, ge=0, le=SEED_MAX, description=FieldDescriptions.seed, - default_factory=get_random_seed, ) width: int = InputField( default=512, diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index ccfb7dcbb3e..afe8ff06d9d 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -62,12 +62,12 @@ def invoke(self, context: InvocationContext) -> BooleanOutput: title="Boolean Collection Primitive", tags=["primitives", "boolean", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class BooleanCollectionInvocation(BaseInvocation): """A collection of boolean primitive values""" - collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values") + collection: list[bool] = InputField(default=[], description="The collection of boolean values") def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: return BooleanCollectionOutput(collection=self.collection) @@ -111,12 +111,12 @@ def invoke(self, context: InvocationContext) -> IntegerOutput: title="Integer Collection Primitive", tags=["primitives", "integer", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class IntegerCollectionInvocation(BaseInvocation): """A collection of integer primitive values""" - collection: list[int] = InputField(default_factory=list, description="The collection of integer values") + collection: list[int] = InputField(default=[], description="The collection of integer values") def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=self.collection) @@ -158,12 +158,12 @@ def invoke(self, context: InvocationContext) -> FloatOutput: title="Float Collection Primitive", tags=["primitives", "float", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class FloatCollectionInvocation(BaseInvocation): """A collection of float primitive values""" - collection: list[float] = InputField(default_factory=list, description="The collection of float values") + collection: list[float] = InputField(default=[], description="The collection of float values") def invoke(self, context: InvocationContext) -> FloatCollectionOutput: return FloatCollectionOutput(collection=self.collection) @@ -205,12 +205,12 @@ def invoke(self, context: InvocationContext) -> StringOutput: title="String Collection Primitive", tags=["primitives", "string", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class StringCollectionInvocation(BaseInvocation): """A collection of string primitive values""" - collection: list[str] = InputField(default_factory=list, description="The collection of string values") + collection: list[str] = InputField(default=[], description="The collection of string values") def invoke(self, context: InvocationContext) -> StringCollectionOutput: return StringCollectionOutput(collection=self.collection) @@ -467,13 +467,13 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: title="Conditioning Collection Primitive", tags=["primitives", "conditioning", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class ConditioningCollectionInvocation(BaseInvocation): """A collection of conditioning tensor primitive values""" collection: list[ConditioningField] = InputField( - default_factory=list, + default=[], description="The collection of conditioning tensors", ) diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 8ff8ca762c2..2412a000798 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -9,7 +9,6 @@ InputField, InvocationContext, OutputField, - UIType, invocation, invocation_output, ) @@ -59,7 +58,7 @@ class T2IAdapterInvocation(BaseInvocation): ui_order=-1, ) weight: Union[float, list[float]] = InputField( - default=1, ge=0, description="The weight given to the T2I-Adapter", ui_type=UIType.Float, title="Weight" + default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight" ) begin_step_percent: float = InputField( default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)" diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 29af1e2333c..c825a840117 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -49,7 +49,7 @@ class Edge(BaseModel): def get_output_field(node: BaseInvocation, field: str) -> Any: node_type = type(node) - node_outputs = get_type_hints(node_type.get_output_type()) + node_outputs = get_type_hints(node_type.get_output_annotation()) node_output_field = node_outputs.get(field) or None return node_output_field @@ -188,7 +188,7 @@ class GraphInvocationOutput(BaseInvocationOutput): # TODO: Fill this out and move to invocations -@invocation("graph") +@invocation("graph", version="1.0.0") class GraphInvocation(BaseInvocation): """Execute a graph""" @@ -205,7 +205,7 @@ class IterateInvocationOutput(BaseInvocationOutput): """Used to connect iteration outputs. Will be expanded to a specific output.""" item: Any = OutputField( - description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem + description="The item being iterated over", title="Collection Item", ui_type=UIType._CollectionItem ) @@ -215,7 +215,7 @@ class IterateInvocation(BaseInvocation): """Iterates over a list of items""" collection: list[Any] = InputField( - description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection + description="The list of items to iterate over", default=[], ui_type=UIType._Collection ) index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True) @@ -227,7 +227,7 @@ def invoke(self, context: InvocationContext) -> IterateInvocationOutput: @invocation_output("collect_output") class CollectInvocationOutput(BaseInvocationOutput): collection: list[Any] = OutputField( - description="The collection of input items", title="Collection", ui_type=UIType.Collection + description="The collection of input items", title="Collection", ui_type=UIType._Collection ) @@ -238,12 +238,12 @@ class CollectInvocation(BaseInvocation): item: Optional[Any] = InputField( default=None, description="The item to collect (all inputs must be of the same type)", - ui_type=UIType.CollectionItem, + ui_type=UIType._CollectionItem, title="Collection Item", input=Input.Connection, ) collection: list[Any] = InputField( - description="The collection, will be provided on execution", default_factory=list, ui_hidden=True + description="The collection, will be provided on execution", default=[], ui_hidden=True ) def invoke(self, context: InvocationContext) -> CollectInvocationOutput: @@ -379,7 +379,7 @@ def validate_self(self) -> None: raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph") # output fields are not on the node object directly, they are on the output type - if edge.source.field not in source_node.get_output_type().model_fields: + if edge.source.field not in source_node.get_output_annotation().model_fields: raise NodeFieldNotFoundError( f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}" ) diff --git a/invokeai/frontend/web/.prettierignore b/invokeai/frontend/web/.prettierignore index bdf02d5c9eb..05782f1f536 100644 --- a/invokeai/frontend/web/.prettierignore +++ b/invokeai/frontend/web/.prettierignore @@ -9,6 +9,5 @@ index.html .yalc/ *.scss src/services/api/schema.d.ts -docs/ static/ src/theme/css/overlayscrollbars.css diff --git a/invokeai/frontend/web/docs/README.md b/invokeai/frontend/web/docs/README.md index 47839994197..2545206c6ab 100644 --- a/invokeai/frontend/web/docs/README.md +++ b/invokeai/frontend/web/docs/README.md @@ -13,6 +13,8 @@ - [Vite](#vite) - [i18next & Weblate](#i18next--weblate) - [openapi-typescript](#openapi-typescript) + - [reactflow](#reactflow) + - [zod](#zod) - [Client Types Generation](#client-types-generation) - [Package Scripts](#package-scripts) - [Contributing](#contributing) @@ -26,46 +28,54 @@ The UI is a fairly straightforward Typescript React app. ## Core Libraries -The app makes heavy use of a handful of libraries. +InvokeAI's UI is made possible by a number of excellent open-source libraries. The most heavily-used are listed below, but there are many others. ### Redux Toolkit -[Redux Toolkit](https://github.com/reduxjs/redux-toolkit) is used for state management and fetching/caching: +[Redux Toolkit] is used for state management and fetching/caching: - `RTK-Query` for data fetching and caching - `createAsyncThunk` for a couple other HTTP requests - `createEntityAdapter` to normalize things like images and models - `createListenerMiddleware` for async workflows -We use [redux-remember](https://github.com/zewish/redux-remember) for persistence. +We use [redux-remember] for persistence. ### Socket\.IO -[Socket\.IO](https://github.com/socketio/socket.io) is used for server-to-client events, like generation process and queue state changes. +[Socket.IO] is used for server-to-client events, like generation process and queue state changes. ### Chakra UI -[Chakra UI](https://github.com/chakra-ui/chakra-ui) is our primary UI library, but we also use a few components from [Mantine v6](https://v6.mantine.dev/). +[Chakra UI] is our primary UI library, but we also use a few components from [Mantine v6]. ### KonvaJS -[KonvaJS](https://github.com/konvajs/react-konva) powers the canvas. In the future, we'd like to explore [PixiJS](https://github.com/pixijs/pixijs) or WebGPU. +[KonvaJS] powers the canvas. In the future, we'd like to explore [PixiJS] or WebGPU. ### Vite -[Vite](https://github.com/vitejs/vite) is our bundler. +[Vite] is our bundler. ### i18next & Weblate -We use [i18next](https://github.com/i18next/react-i18next) for localisation, but translation to languages other than English happens on our [Weblate](https://hosted.weblate.org/engage/invokeai/) project. **Only the English source strings should be changed on this repo.** +We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project. **Only the English source strings should be changed on this repo.** ### openapi-typescript -[openapi-typescript](https://github.com/drwpow/openapi-typescript) is used to generate types from the server's OpenAPI schema. See TYPES_CODEGEN.md. +[openapi-typescript] is used to generate types from the server's OpenAPI schema. See TYPES_CODEGEN.md. + +### reactflow + +[reactflow] powers the Workflow Editor. + +### zod + +[zod] schemas are used to model data structures and provide runtime validation. ## Client Types Generation -We use [`openapi-typescript`](https://github.com/drwpow/openapi-typescript) to generate types from the app's OpenAPI schema. +We use [openapi-typescript] to generate types from the app's OpenAPI schema. The generated types are written to `invokeai/frontend/web/src/services/api/schema.d.ts`. This file is committed to the repo. @@ -98,11 +108,11 @@ Run with `yarn