From 27fd9071ba6442664848a6bae1b312476e910c98 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 17 Nov 2023 11:32:35 +1100 Subject: [PATCH 01/12] feat(ui): add support for custom field types Node authors may now create their own arbitrary/custom field types. Any pydantic model is supported. Two notes: 1. Your field type's class name must be unique. Suggest prefixing fields with something related to the node pack as a kind of namespace. 2. Custom field types function as connection-only fields. For example, if your custom field has string attributes, you will not get a text input for that attribute when you give a node a field with your custom type. This is the same behaviour as other complex fields that don't have custom UIs in the workflow editor - like, say, a string collection. --- .../web/src/features/nodes/types/constants.ts | 5 ++++ .../web/src/features/nodes/types/types.ts | 15 +++++++++++- .../nodes/util/fieldTemplateBuilders.ts | 23 +++++++++++-------- .../src/features/nodes/util/parseSchema.ts | 21 +++++++++-------- 4 files changed, 44 insertions(+), 20 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index c6eec736da0..c6b8e2acc4a 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -156,6 +156,11 @@ export const FIELDS: Record = { description: 'Any field type is accepted.', title: 'Any', }, + Unknown: { + color: 'gray.500', + description: 'Unknown field type is accepted.', + title: 'Unknown', + }, MetadataField: { color: 'gray.500', description: 'A metadata dict.', diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index c55d114dcf6..7ec5c5c293d 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -133,6 +133,7 @@ export const zFieldType = z.enum([ 'UNetField', 'VaeField', 'VaeModelField', + 'Unknown', ]); export type FieldType = z.infer; @@ -789,6 +790,11 @@ export const zAnyInputFieldValue = zInputFieldValueBase.extend({ value: z.any().optional(), }); +export const zUnknownInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('Unknown'), + value: z.any().optional(), +}); + export const zInputFieldValue = z.discriminatedUnion('type', [ zAnyInputFieldValue, zBoardInputFieldValue, @@ -846,6 +852,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [ zMetadataItemPolymorphicInputFieldValue, zMetadataInputFieldValue, zMetadataCollectionInputFieldValue, + zUnknownInputFieldValue, ]); export type InputFieldValue = z.infer; @@ -863,6 +870,11 @@ export type AnyInputFieldTemplate = InputFieldTemplateBase & { default: undefined; }; +export type UnknownInputFieldTemplate = InputFieldTemplateBase & { + type: 'Unknown'; + default: undefined; +}; + export type IntegerInputFieldTemplate = InputFieldTemplateBase & { type: 'integer'; default: number; @@ -1259,7 +1271,8 @@ export type InputFieldTemplate = | MetadataItemCollectionInputFieldTemplate | MetadataInputFieldTemplate | MetadataItemPolymorphicInputFieldTemplate - | MetadataCollectionInputFieldTemplate; + | MetadataCollectionInputFieldTemplate + | UnknownInputFieldTemplate; export const isInputFieldValue = ( field?: InputFieldValue | OutputFieldValue diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index 92e44e9ab2c..e67283d2e7a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -81,6 +81,7 @@ import { T2IAdapterModelInputFieldTemplate, T2IAdapterPolymorphicInputFieldTemplate, UNetInputFieldTemplate, + UnknownInputFieldTemplate, VaeInputFieldTemplate, VaeModelInputFieldTemplate, isArraySchemaObject, @@ -981,6 +982,18 @@ const buildSchedulerInputFieldTemplate = ({ return template; }; +const buildUnknownInputFieldTemplate = ({ + baseField, +}: BuildInputFieldArg): UnknownInputFieldTemplate => { + const template: UnknownInputFieldTemplate = { + ...baseField, + type: 'Unknown', + default: undefined, + }; + + return template; +}; + export const getFieldType = ( schemaObject: OpenAPIV3_1SchemaOrRef ): string | undefined => { @@ -1145,13 +1158,9 @@ const TEMPLATE_BUILDER_MAP: { UNetField: buildUNetInputFieldTemplate, VaeField: buildVaeInputFieldTemplate, VaeModelField: buildVaeModelInputFieldTemplate, + Unknown: buildUnknownInputFieldTemplate, }; -const isTemplatedFieldType = ( - fieldType: string | undefined -): fieldType is keyof typeof TEMPLATE_BUILDER_MAP => - Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP); - /** * Builds an input field from an invocation schema property. * @param fieldSchema The schema object @@ -1193,10 +1202,6 @@ export const buildInputFieldTemplate = ( ...extra, }; - if (!isTemplatedFieldType(fieldType)) { - return; - } - const builder = TEMPLATE_BUILDER_MAP[fieldType]; if (!builder) { diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index 8737fc52b9c..87c1d547608 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -4,6 +4,7 @@ import { reduce, startCase } from 'lodash-es'; import { OpenAPIV3_1 } from 'openapi-types'; import { AnyInvocationType } from 'services/events/types'; import { + FieldType, InputFieldTemplate, InvocationSchemaObject, InvocationTemplate, @@ -103,7 +104,7 @@ export const parseSchema = ( return inputsAccumulator; } - const fieldType = property.ui_type ?? getFieldType(property); + let fieldType = property.ui_type ?? getFieldType(property); if (!fieldType) { logger('nodes').warn( @@ -137,23 +138,23 @@ export const parseSchema = ( } if (!isFieldType(fieldType)) { - logger('nodes').warn( + logger('nodes').debug( { node: type, fieldName: propertyName, fieldType, field: parseify(property), }, - `Skipping unknown input field type: ${fieldType}` + `Fallback handling for unknown input field type: ${fieldType}` ); - return inputsAccumulator; + fieldType = 'Unknown'; } const field = buildInputFieldTemplate( schema, property, propertyName, - fieldType + fieldType as FieldType // we have already checked that fieldType is a valid FieldType, and forced it to be Unknown if not ); if (!field) { @@ -220,14 +221,14 @@ export const parseSchema = ( return outputsAccumulator; } - const fieldType = property.ui_type ?? getFieldType(property); + let fieldType = property.ui_type ?? getFieldType(property); if (!isFieldType(fieldType)) { - logger('nodes').warn( + logger('nodes').debug( { fieldName: propertyName, fieldType, field: parseify(property) }, - 'Skipping unknown output field type' + `Fallback handling for unknown input field type: ${fieldType}` ); - return outputsAccumulator; + fieldType = 'Unknown'; } outputsAccumulator[propertyName] = { @@ -236,7 +237,7 @@ export const parseSchema = ( title: property.title ?? (propertyName ? startCase(propertyName) : ''), description: property.description ?? '', - type: fieldType, + type: fieldType as FieldType, ui_hidden: property.ui_hidden ?? false, ui_type: property.ui_type, ui_order: property.ui_order, From 5ce2dc3a58859eabcc805cf3e0a2f22f4ebbe124 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 17 Nov 2023 12:01:39 +1100 Subject: [PATCH 02/12] feat(ui): fix tooltips for custom types We need to hold onto the original type of the field so they don't all just show up as "Unknown". --- .../nodes/Invocation/fields/FieldHandle.tsx | 17 +++++++++----- .../Invocation/fields/FieldTooltipContent.tsx | 4 +--- .../web/src/features/nodes/types/types.ts | 2 ++ .../nodes/util/fieldTemplateBuilders.ts | 4 +++- .../src/features/nodes/util/parseSchema.ts | 23 ++++++++++++++++++- 5 files changed, 39 insertions(+), 11 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index 31665902547..b458f2ca255 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -47,8 +47,8 @@ const FieldHandle = (props: FieldHandleProps) => { isConnectionStartField, connectionError, } = props; - const { name, type } = fieldTemplate; - const { color: typeColor, title } = FIELDS[type]; + const { name, type, originalType } = fieldTemplate; + const { color: typeColor } = FIELDS[type]; const styles: CSSProperties = useMemo(() => { const isCollectionType = COLLECTION_TYPES.includes(type); @@ -102,13 +102,18 @@ const FieldHandle = (props: FieldHandleProps) => { const tooltip = useMemo(() => { if (isConnectionInProgress && isConnectionStartField) { - return title; + return originalType; } if (isConnectionInProgress && connectionError) { - return connectionError ?? title; + return connectionError ?? originalType; } - return title; - }, [connectionError, isConnectionInProgress, isConnectionStartField, title]); + return originalType; + }, [ + connectionError, + isConnectionInProgress, + isConnectionStartField, + originalType, + ]); return ( { {fieldTemplate.description} )} - {fieldTemplate && Type: {FIELDS[fieldTemplate.type].title}} + {fieldTemplate && Type: {fieldTemplate.originalType}} {isInputTemplate && Input: {startCase(fieldTemplate.input)}} ); diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 7ec5c5c293d..92a18a418d1 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -191,6 +191,7 @@ export type OutputFieldTemplate = { type: FieldType; title: string; description: string; + originalType: string; // used for custom types } & _OutputField; export const zInputFieldValueBase = zFieldValueBase.extend({ @@ -863,6 +864,7 @@ export type InputFieldTemplateBase = { description: string; required: boolean; fieldKind: 'input'; + originalType: string; // used for custom types } & _InputField; export type AnyInputFieldTemplate = InputFieldTemplateBase & { diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index e67283d2e7a..fe94ec27828 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -1170,7 +1170,8 @@ export const buildInputFieldTemplate = ( nodeSchema: InvocationSchemaObject, fieldSchema: InvocationFieldSchema, name: string, - fieldType: FieldType + fieldType: FieldType, + originalType: string ) => { const { input, @@ -1192,6 +1193,7 @@ export const buildInputFieldTemplate = ( ui_order, ui_choice_labels, item_default, + originalType, }; const baseField = { diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index 87c1d547608..30fff76cafc 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -119,6 +119,9 @@ export const parseSchema = ( return inputsAccumulator; } + // stash this for custom types + const originalType = fieldType; + if (fieldType === 'WorkflowField') { withWorkflow = true; return inputsAccumulator; @@ -154,7 +157,8 @@ export const parseSchema = ( schema, property, propertyName, - fieldType as FieldType // we have already checked that fieldType is a valid FieldType, and forced it to be Unknown if not + fieldType as FieldType, // we have already checked that fieldType is a valid FieldType, and forced it to be Unknown if not + originalType ); if (!field) { @@ -223,6 +227,22 @@ export const parseSchema = ( let fieldType = property.ui_type ?? getFieldType(property); + if (!fieldType) { + logger('nodes').warn( + { + node: type, + fieldName: propertyName, + fieldType, + field: parseify(property), + }, + 'Missing output field type' + ); + return outputsAccumulator; + } + + // stash for custom types + const originalType = fieldType; + if (!isFieldType(fieldType)) { logger('nodes').debug( { fieldName: propertyName, fieldType, field: parseify(property) }, @@ -241,6 +261,7 @@ export const parseSchema = ( ui_hidden: property.ui_hidden ?? false, ui_type: property.ui_type, ui_order: property.ui_order, + originalType, }; return outputsAccumulator; From dc44debbabd9040dc5e2b2575744045d771114f0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 17 Nov 2023 12:09:15 +1100 Subject: [PATCH 03/12] fix(ui): fix ts error with custom fields --- .../frontend/web/src/features/nodes/util/fieldValueBuilders.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index ca2513649d3..89d6729f4c3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -60,6 +60,7 @@ const FIELD_VALUE_FALLBACK_MAP: { UNetField: undefined, VaeField: undefined, VaeModelField: undefined, + Unknown: undefined, }; export const buildInputFieldValue = ( From 98a0ce0f42c757947f20503de3af234548e06af0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 18 Nov 2023 10:40:19 +1100 Subject: [PATCH 04/12] feat(ui): custom field types connection validation In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic. *Actually, it was `"Unknown"`, but I changed it to custom for clarity. Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields. To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`. This ended up needing a bit of fanagling: - If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property. While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer. (Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.) - Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future. - We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`. Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`. This causes a TS error. This behaviour is somewhat controversial (see https://github.com/microsoft/TypeScript/issues/14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent. --- .../flow/AddNodePopover/AddNodePopover.tsx | 10 +++++-- .../connectionLines/CustomConnectionLine.tsx | 9 +++--- .../flow/edges/util/getEdgeColor.ts | 12 ++++++++ .../flow/edges/util/makeEdgeSelector.ts | 4 +-- .../nodes/Invocation/fields/FieldHandle.tsx | 30 +++++++------------ .../features/nodes/hooks/useFieldType.ts.ts | 3 +- .../src/features/nodes/store/nodesSlice.ts | 3 +- .../web/src/features/nodes/store/types.ts | 2 +- .../nodes/store/util/buildNodeData.ts | 1 + .../store/util/findConnectionToValidHandle.ts | 11 +++++-- .../util/makeIsConnectionValidSelector.ts | 4 +-- .../util/validateSourceAndTargetTypes.ts | 22 +++++++------- .../web/src/features/nodes/types/constants.ts | 8 ++--- .../web/src/features/nodes/types/types.ts | 19 ++++++------ .../nodes/util/fieldTemplateBuilders.ts | 12 ++++---- .../features/nodes/util/fieldValueBuilders.ts | 7 ++--- .../src/features/nodes/util/parseSchema.ts | 17 +++++++---- 17 files changed, 98 insertions(+), 76 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 5ddd1d4ece6..b514474a260 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -74,9 +74,13 @@ const AddNodePopover = () => { return some(handles, (handle) => { const sourceType = - handleFilter == 'source' ? fieldFilter : handle.type; + handleFilter == 'source' + ? fieldFilter + : handle.originalType ?? handle.type; const targetType = - handleFilter == 'target' ? fieldFilter : handle.type; + handleFilter == 'target' + ? fieldFilter + : handle.originalType ?? handle.type; return validateSourceAndTargetTypes(sourceType, targetType); }); @@ -111,7 +115,7 @@ const AddNodePopover = () => { data.sort((a, b) => a.label.localeCompare(b.label)); - return { data, t }; + return { data }; }, defaultSelectorOptions ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx index a379be7ee28..a14b7b23c6d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx @@ -2,18 +2,17 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; -import { FIELDS } from 'features/nodes/types/constants'; import { memo } from 'react'; import { ConnectionLineComponentProps, getBezierPath } from 'reactflow'; +import { getFieldColor } from '../edges/util/getEdgeColor'; const selector = createSelector(stateSelector, ({ nodes }) => { const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } = nodes; - const stroke = - currentConnectionFieldType && shouldColorEdges - ? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color) - : colorTokenToCssVar('base.500'); + const stroke = shouldColorEdges + ? getFieldColor(currentConnectionFieldType) + : colorTokenToCssVar('base.500'); let className = 'react-flow__custom_connection-path'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts new file mode 100644 index 00000000000..99ada97de14 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts @@ -0,0 +1,12 @@ +import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; +import { FIELDS } from 'features/nodes/types/constants'; +import { FieldType } from 'features/nodes/types/types'; + +export const getFieldColor = (fieldType: FieldType | string | null): string => { + if (!fieldType) { + return colorTokenToCssVar('base.500'); + } + const color = FIELDS[fieldType]?.color; + + return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500'); +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts index b5dc484eaea..a6a409e1ad2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; -import { FIELDS } from 'features/nodes/types/constants'; import { isInvocationNode } from 'features/nodes/types/types'; +import { getFieldColor } from './getEdgeColor'; export const makeEdgeSelector = ( source: string, @@ -29,7 +29,7 @@ export const makeEdgeSelector = ( const stroke = sourceType && nodes.shouldColorEdges - ? colorTokenToCssVar(FIELDS[sourceType].color) + ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); return { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index b458f2ca255..849003ffbeb 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -1,8 +1,6 @@ import { Tooltip } from '@chakra-ui/react'; -import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { COLLECTION_TYPES, - FIELDS, HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES, POLYMORPHIC_TYPES, @@ -13,6 +11,7 @@ import { } from 'features/nodes/types/types'; import { CSSProperties, memo, useMemo } from 'react'; import { Handle, HandleType, Position } from 'reactflow'; +import { getFieldColor } from '../../../edges/util/getEdgeColor'; export const handleBaseStyles: CSSProperties = { position: 'absolute', @@ -47,14 +46,14 @@ const FieldHandle = (props: FieldHandleProps) => { isConnectionStartField, connectionError, } = props; - const { name, type, originalType } = fieldTemplate; - const { color: typeColor } = FIELDS[type]; + const { name } = fieldTemplate; + const type = fieldTemplate.originalType ?? fieldTemplate.type; const styles: CSSProperties = useMemo(() => { - const isCollectionType = COLLECTION_TYPES.includes(type); - const isPolymorphicType = POLYMORPHIC_TYPES.includes(type); - const isModelType = MODEL_TYPES.includes(type); - const color = colorTokenToCssVar(typeColor); + const isCollectionType = COLLECTION_TYPES.some((t) => t === type); + const isPolymorphicType = POLYMORPHIC_TYPES.some((t) => t === type); + const isModelType = MODEL_TYPES.some((t) => t === type); + const color = getFieldColor(type); const s: CSSProperties = { backgroundColor: isCollectionType || isPolymorphicType @@ -97,23 +96,14 @@ const FieldHandle = (props: FieldHandleProps) => { isConnectionInProgress, isConnectionStartField, type, - typeColor, ]); const tooltip = useMemo(() => { - if (isConnectionInProgress && isConnectionStartField) { - return originalType; - } if (isConnectionInProgress && connectionError) { - return connectionError ?? originalType; + return connectionError; } - return originalType; - }, [ - connectionError, - isConnectionInProgress, - isConnectionStartField, - originalType, - ]); + return type; + }, [connectionError, isConnectionInProgress, type]); return ( ) => { const fieldType = state.currentConnectionFieldType; diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index f6bfa7cad8b..b81dd286d72 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -21,7 +21,7 @@ export type NodesState = { edges: Edge[]; nodeTemplates: Record; connectionStartParams: OnConnectStartParams | null; - currentConnectionFieldType: FieldType | null; + currentConnectionFieldType: FieldType | string | null; connectionMade: boolean; modifyingEdge: boolean; shouldShowFieldTypeLegend: boolean; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts b/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts index 6cecc8c4098..0efd3d17c6a 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts @@ -94,6 +94,7 @@ export const buildNodeData = ( name: outputName, type: outputTemplate.type, fieldKind: 'output', + originalType: outputTemplate.originalType, }; outputsAccumulator[outputName] = outputFieldValue; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts index 69386c1f23c..da2026ce570 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts @@ -12,7 +12,7 @@ import { getIsGraphAcyclic } from './getIsGraphAcyclic'; const isValidConnection = ( edges: Edge[], handleCurrentType: HandleType, - handleCurrentFieldType: FieldType, + handleCurrentFieldType: FieldType | string, node: Node, handle: InputFieldValue | OutputFieldValue ) => { @@ -35,7 +35,12 @@ const isValidConnection = ( } } - if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) { + if ( + !validateSourceAndTargetTypes( + handleCurrentFieldType, + handle.originalType ?? handle.type + ) + ) { isValidConnection = false; } @@ -49,7 +54,7 @@ export const findConnectionToValidHandle = ( handleCurrentNodeId: string, handleCurrentName: string, handleCurrentType: HandleType, - handleCurrentFieldType: FieldType + handleCurrentFieldType: FieldType | string ): Connection | null => { if (node.id === handleCurrentNodeId) { return null; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 57dd284b88d..cb7886e57e1 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -1,9 +1,9 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; -import { getIsGraphAcyclic } from './getIsGraphAcyclic'; import { FieldType } from 'features/nodes/types/types'; import i18n from 'i18next'; import { HandleType } from 'reactflow'; +import { getIsGraphAcyclic } from './getIsGraphAcyclic'; import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; /** @@ -15,7 +15,7 @@ export const makeConnectionErrorSelector = ( nodeId: string, fieldName: string, handleType: HandleType, - fieldType?: FieldType + fieldType?: FieldType | string ) => { return createSelector(stateSelector, (state) => { if (!fieldType) { diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts index 2f47e47a787..123cda8e044 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts @@ -7,8 +7,8 @@ import { import { FieldType } from 'features/nodes/types/types'; export const validateSourceAndTargetTypes = ( - sourceType: FieldType, - targetType: FieldType + sourceType: FieldType | string, + targetType: FieldType | string ) => { // TODO: There's a bug with Collect -> Iterate nodes: // https://github.com/invoke-ai/InvokeAI/issues/3956 @@ -31,17 +31,18 @@ export const validateSourceAndTargetTypes = ( */ const isCollectionItemToNonCollection = - sourceType === 'CollectionItem' && !COLLECTION_TYPES.includes(targetType); + sourceType === 'CollectionItem' && + !COLLECTION_TYPES.some((t) => t === targetType); const isNonCollectionToCollectionItem = targetType === 'CollectionItem' && - !COLLECTION_TYPES.includes(sourceType) && - !POLYMORPHIC_TYPES.includes(sourceType); + !COLLECTION_TYPES.some((t) => t === sourceType) && + !POLYMORPHIC_TYPES.some((t) => t === sourceType); const isAnythingToPolymorphicOfSameBaseType = - POLYMORPHIC_TYPES.includes(targetType) && + POLYMORPHIC_TYPES.some((t) => t === targetType) && (() => { - if (!POLYMORPHIC_TYPES.includes(targetType)) { + if (!POLYMORPHIC_TYPES.some((t) => t === targetType)) { return false; } const baseType = @@ -57,11 +58,12 @@ export const validateSourceAndTargetTypes = ( const isGenericCollectionToAnyCollectionOrPolymorphic = sourceType === 'Collection' && - (COLLECTION_TYPES.includes(targetType) || - POLYMORPHIC_TYPES.includes(targetType)); + (COLLECTION_TYPES.some((t) => t === targetType) || + POLYMORPHIC_TYPES.some((t) => t === targetType)); const isCollectionToGenericCollection = - targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType); + targetType === 'Collection' && + COLLECTION_TYPES.some((t) => t === sourceType); const isIntToFloat = sourceType === 'integer' && targetType === 'float'; diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index c6b8e2acc4a..93e6fa2948e 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -150,16 +150,16 @@ export const isPolymorphicItemType = ( ): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP => Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP); -export const FIELDS: Record = { +export const FIELDS: Record = { Any: { color: 'gray.500', description: 'Any field type is accepted.', title: 'Any', }, - Unknown: { + Custom: { color: 'gray.500', - description: 'Unknown field type is accepted.', - title: 'Unknown', + description: 'A custom field, provided by an external node.', + title: 'Custom', }, MetadataField: { color: 'gray.500', diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 92a18a418d1..d87e45c8073 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -133,7 +133,7 @@ export const zFieldType = z.enum([ 'UNetField', 'VaeField', 'VaeModelField', - 'Unknown', + 'Custom', ]); export type FieldType = z.infer; @@ -164,6 +164,7 @@ export const zFieldValueBase = z.object({ id: z.string().trim().min(1), name: z.string().trim().min(1), type: zFieldType, + originalType: z.string().optional(), }); export type FieldValueBase = z.infer; @@ -191,7 +192,7 @@ export type OutputFieldTemplate = { type: FieldType; title: string; description: string; - originalType: string; // used for custom types + originalType?: string; // used for custom types } & _OutputField; export const zInputFieldValueBase = zFieldValueBase.extend({ @@ -791,8 +792,8 @@ export const zAnyInputFieldValue = zInputFieldValueBase.extend({ value: z.any().optional(), }); -export const zUnknownInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('Unknown'), +export const zCustomInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('Custom'), value: z.any().optional(), }); @@ -853,7 +854,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [ zMetadataItemPolymorphicInputFieldValue, zMetadataInputFieldValue, zMetadataCollectionInputFieldValue, - zUnknownInputFieldValue, + zCustomInputFieldValue, ]); export type InputFieldValue = z.infer; @@ -864,7 +865,7 @@ export type InputFieldTemplateBase = { description: string; required: boolean; fieldKind: 'input'; - originalType: string; // used for custom types + originalType?: string; // used for custom types } & _InputField; export type AnyInputFieldTemplate = InputFieldTemplateBase & { @@ -872,8 +873,8 @@ export type AnyInputFieldTemplate = InputFieldTemplateBase & { default: undefined; }; -export type UnknownInputFieldTemplate = InputFieldTemplateBase & { - type: 'Unknown'; +export type CustomInputFieldTemplate = InputFieldTemplateBase & { + type: 'Custom'; default: undefined; }; @@ -1274,7 +1275,7 @@ export type InputFieldTemplate = | MetadataInputFieldTemplate | MetadataItemPolymorphicInputFieldTemplate | MetadataCollectionInputFieldTemplate - | UnknownInputFieldTemplate; + | CustomInputFieldTemplate; export const isInputFieldValue = ( field?: InputFieldValue | OutputFieldValue diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index fe94ec27828..3aa720ef5cc 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -81,7 +81,7 @@ import { T2IAdapterModelInputFieldTemplate, T2IAdapterPolymorphicInputFieldTemplate, UNetInputFieldTemplate, - UnknownInputFieldTemplate, + CustomInputFieldTemplate, VaeInputFieldTemplate, VaeModelInputFieldTemplate, isArraySchemaObject, @@ -982,12 +982,12 @@ const buildSchedulerInputFieldTemplate = ({ return template; }; -const buildUnknownInputFieldTemplate = ({ +const buildCustomInputFieldTemplate = ({ baseField, -}: BuildInputFieldArg): UnknownInputFieldTemplate => { - const template: UnknownInputFieldTemplate = { +}: BuildInputFieldArg): CustomInputFieldTemplate => { + const template: CustomInputFieldTemplate = { ...baseField, - type: 'Unknown', + type: 'Custom', default: undefined, }; @@ -1158,7 +1158,7 @@ const TEMPLATE_BUILDER_MAP: { UNetField: buildUNetInputFieldTemplate, VaeField: buildVaeInputFieldTemplate, VaeModelField: buildVaeModelInputFieldTemplate, - Unknown: buildUnknownInputFieldTemplate, + Custom: buildCustomInputFieldTemplate, }; /** diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index 89d6729f4c3..f8db78ecc39 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -60,7 +60,7 @@ const FIELD_VALUE_FALLBACK_MAP: { UNetField: undefined, VaeField: undefined, VaeModelField: undefined, - Unknown: undefined, + Custom: undefined, }; export const buildInputFieldValue = ( @@ -77,10 +77,9 @@ export const buildInputFieldValue = ( type: template.type, label: '', fieldKind: 'input', + originalType: template.originalType, + value: template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type], } as InputFieldValue; - fieldValue.value = - template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type]; - return fieldValue; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index 30fff76cafc..c9ccf503be2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -4,7 +4,6 @@ import { reduce, startCase } from 'lodash-es'; import { OpenAPIV3_1 } from 'openapi-types'; import { AnyInvocationType } from 'services/events/types'; import { - FieldType, InputFieldTemplate, InvocationSchemaObject, InvocationTemplate, @@ -150,14 +149,18 @@ export const parseSchema = ( }, `Fallback handling for unknown input field type: ${fieldType}` ); - fieldType = 'Unknown'; + fieldType = 'Custom'; + if (!isFieldType(fieldType)) { + // satisfy TS gods + return inputsAccumulator; + } } const field = buildInputFieldTemplate( schema, property, propertyName, - fieldType as FieldType, // we have already checked that fieldType is a valid FieldType, and forced it to be Unknown if not + fieldType, originalType ); @@ -248,7 +251,11 @@ export const parseSchema = ( { fieldName: propertyName, fieldType, field: parseify(property) }, `Fallback handling for unknown input field type: ${fieldType}` ); - fieldType = 'Unknown'; + fieldType = 'Custom'; + if (!isFieldType(fieldType)) { + // satisfy TS gods + return outputsAccumulator; + } } outputsAccumulator[propertyName] = { @@ -257,7 +264,7 @@ export const parseSchema = ( title: property.title ?? (propertyName ? startCase(propertyName) : ''), description: property.description ?? '', - type: fieldType as FieldType, + type: fieldType, ui_hidden: property.ui_hidden ?? false, ui_type: property.ui_type, ui_order: property.ui_order, From 3ff13dc93c78f6877bdfac7af78933c43924cb3f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 Nov 2023 12:47:05 +1100 Subject: [PATCH 05/12] fix(ui): typo --- invokeai/frontend/web/src/features/nodes/util/parseSchema.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index c9ccf503be2..6d3146f9f69 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -249,7 +249,7 @@ export const parseSchema = ( if (!isFieldType(fieldType)) { logger('nodes').debug( { fieldName: propertyName, fieldType, field: parseify(property) }, - `Fallback handling for unknown input field type: ${fieldType}` + `Fallback handling for unknown output field type: ${fieldType}` ); fieldType = 'Custom'; if (!isFieldType(fieldType)) { From e30f22ae7e5fd962f2bd8a4f7941203454854517 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 Nov 2023 14:30:08 +1100 Subject: [PATCH 06/12] feat(ui): add CustomCollection and CustomPolymorphic field types --- .../nodes/Invocation/fields/FieldHandle.tsx | 9 +- .../util/makeIsConnectionValidSelector.ts | 3 + .../web/src/features/nodes/types/constants.ts | 15 +++ .../web/src/features/nodes/types/types.ts | 30 ++++- .../nodes/util/fieldTemplateBuilders.ts | 108 ++++++++++++++++-- .../features/nodes/util/fieldValueBuilders.ts | 2 + .../src/features/nodes/util/parseSchema.ts | 66 +++++++---- 7 files changed, 196 insertions(+), 37 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index 849003ffbeb..1a43fcdbb01 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -50,8 +50,12 @@ const FieldHandle = (props: FieldHandleProps) => { const type = fieldTemplate.originalType ?? fieldTemplate.type; const styles: CSSProperties = useMemo(() => { - const isCollectionType = COLLECTION_TYPES.some((t) => t === type); - const isPolymorphicType = POLYMORPHIC_TYPES.some((t) => t === type); + const isCollectionType = COLLECTION_TYPES.some( + (t) => t === fieldTemplate.type + ); + const isPolymorphicType = POLYMORPHIC_TYPES.some( + (t) => t === fieldTemplate.type + ); const isModelType = MODEL_TYPES.some((t) => t === type); const color = getFieldColor(type); const s: CSSProperties = { @@ -92,6 +96,7 @@ const FieldHandle = (props: FieldHandleProps) => { return s; }, [ connectionError, + fieldTemplate.type, handleType, isConnectionInProgress, isConnectionStartField, diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index cb7886e57e1..224c0235f75 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -44,6 +44,9 @@ export const makeConnectionErrorSelector = ( const sourceType = handleType === 'source' ? fieldType : currentConnectionFieldType; + console.log('targetType', targetType); + console.log('sourceType', sourceType); + if (nodeId === connectionNodeId) { return i18n.t('nodes.cannotConnectToSelf'); } diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 93e6fa2948e..e8297c17e3c 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -35,6 +35,7 @@ export const COLLECTION_TYPES: FieldType[] = [ 'IPAdapterCollection', 'MetadataItemCollection', 'MetadataCollection', + 'CustomCollection', ]; export const POLYMORPHIC_TYPES: FieldType[] = [ @@ -50,6 +51,7 @@ export const POLYMORPHIC_TYPES: FieldType[] = [ 'T2IAdapterPolymorphic', 'IPAdapterPolymorphic', 'MetadataItemPolymorphic', + 'CustomPolymorphic', ]; export const MODEL_TYPES: FieldType[] = [ @@ -83,6 +85,7 @@ export const COLLECTION_MAP: FieldTypeMapWithNumber = { IPAdapterField: 'IPAdapterCollection', MetadataItemField: 'MetadataItemCollection', MetadataField: 'MetadataCollection', + Custom: 'CustomCollection', }; export const isCollectionItemType = ( itemType: string | undefined @@ -103,6 +106,7 @@ export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = { T2IAdapterField: 'T2IAdapterPolymorphic', IPAdapterField: 'IPAdapterPolymorphic', MetadataItemField: 'MetadataItemPolymorphic', + Custom: 'CustomPolymorphic', }; export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = { @@ -118,6 +122,7 @@ export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = { T2IAdapterPolymorphic: 'T2IAdapterField', IPAdapterPolymorphic: 'IPAdapterField', MetadataItemPolymorphic: 'MetadataItemField', + CustomPolymorphic: 'Custom', }; export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [ @@ -161,6 +166,16 @@ export const FIELDS: Record = { description: 'A custom field, provided by an external node.', title: 'Custom', }, + CustomCollection: { + color: 'gray.500', + description: 'A custom field collection, provided by an external node.', + title: 'Custom Collection', + }, + CustomPolymorphic: { + color: 'gray.500', + description: 'A custom field polymorphic, provided by an external node.', + title: 'Custom Polymorphic', + }, MetadataField: { color: 'gray.500', description: 'A metadata dict.', diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index d87e45c8073..03da1c31ddb 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -134,6 +134,8 @@ export const zFieldType = z.enum([ 'VaeField', 'VaeModelField', 'Custom', + 'CustomCollection', + 'CustomPolymorphic', ]); export type FieldType = z.infer; @@ -144,7 +146,7 @@ export type FieldTypeMapWithNumber = { export const zReservedFieldType = z.enum([ 'WorkflowField', - 'IsIntermediate', + 'IsIntermediate', // this is technically a reserved field type! 'MetadataField', ]); @@ -797,6 +799,16 @@ export const zCustomInputFieldValue = zInputFieldValueBase.extend({ value: z.any().optional(), }); +export const zCustomCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('CustomCollection'), + value: z.array(z.any()).optional(), +}); + +export const zCustomPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('CustomPolymorphic'), + value: z.union([z.any(), z.array(z.any())]).optional(), +}); + export const zInputFieldValue = z.discriminatedUnion('type', [ zAnyInputFieldValue, zBoardInputFieldValue, @@ -855,6 +867,8 @@ export const zInputFieldValue = z.discriminatedUnion('type', [ zMetadataInputFieldValue, zMetadataCollectionInputFieldValue, zCustomInputFieldValue, + zCustomCollectionInputFieldValue, + zCustomPolymorphicInputFieldValue, ]); export type InputFieldValue = z.infer; @@ -878,6 +892,16 @@ export type CustomInputFieldTemplate = InputFieldTemplateBase & { default: undefined; }; +export type CustomCollectionInputFieldTemplate = InputFieldTemplateBase & { + type: 'CustomCollection'; + default: []; +}; + +export type CustomPolymorphicInputFieldTemplate = InputFieldTemplateBase & { + type: 'CustomPolymorphic'; + default: undefined; +}; + export type IntegerInputFieldTemplate = InputFieldTemplateBase & { type: 'integer'; default: number; @@ -1275,7 +1299,9 @@ export type InputFieldTemplate = | MetadataInputFieldTemplate | MetadataItemPolymorphicInputFieldTemplate | MetadataCollectionInputFieldTemplate - | CustomInputFieldTemplate; + | CustomInputFieldTemplate + | CustomCollectionInputFieldTemplate + | CustomPolymorphicInputFieldTemplate; export const isInputFieldValue = ( field?: InputFieldValue | OutputFieldValue diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index 3aa720ef5cc..8538e1e48de 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -88,6 +88,9 @@ import { isNonArraySchemaObject, isRefObject, isSchemaObject, + isFieldType, + CustomCollectionInputFieldTemplate, + CustomPolymorphicInputFieldTemplate, } from '../types/types'; export type BaseFieldProperties = 'name' | 'title' | 'description'; @@ -982,6 +985,30 @@ const buildSchedulerInputFieldTemplate = ({ return template; }; +const buildCustomCollectionInputFieldTemplate = ({ + baseField, +}: BuildInputFieldArg): CustomCollectionInputFieldTemplate => { + const template: CustomCollectionInputFieldTemplate = { + ...baseField, + type: 'CustomCollection', + default: [], + }; + + return template; +}; + +const buildCustomPolymorphicInputFieldTemplate = ({ + baseField, +}: BuildInputFieldArg): CustomPolymorphicInputFieldTemplate => { + const template: CustomPolymorphicInputFieldTemplate = { + ...baseField, + type: 'CustomPolymorphic', + default: undefined, + }; + + return template; +}; + const buildCustomInputFieldTemplate = ({ baseField, }: BuildInputFieldArg): CustomInputFieldTemplate => { @@ -996,7 +1023,7 @@ const buildCustomInputFieldTemplate = ({ export const getFieldType = ( schemaObject: OpenAPIV3_1SchemaOrRef -): string | undefined => { +): { type: string; originalType: string } | undefined => { if (isSchemaObject(schemaObject)) { if (!schemaObject.type) { // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf @@ -1004,7 +1031,17 @@ export const getFieldType = ( if (schemaObject.allOf) { const allOf = schemaObject.allOf; if (allOf && allOf[0] && isRefObject(allOf[0])) { - return refObjectToSchemaName(allOf[0]); + // This is a single ref type + const originalType = refObjectToSchemaName(allOf[0]); + + if (!originalType) { + // something has gone terribly awry + return; + } + return { + type: isFieldType(originalType) ? originalType : 'Custom', + originalType, + }; } } else if (schemaObject.anyOf) { // ignore null types @@ -1017,8 +1054,17 @@ export const getFieldType = ( return true; }); if (anyOf.length === 1) { + // This is a single ref type if (isRefObject(anyOf[0])) { - return refObjectToSchemaName(anyOf[0]); + const originalType = refObjectToSchemaName(anyOf[0]); + if (!originalType) { + return; + } + + return { + type: isFieldType(originalType) ? originalType : 'Custom', + originalType, + }; } else if (isSchemaObject(anyOf[0])) { return getFieldType(anyOf[0]); } @@ -1064,16 +1110,29 @@ export const getFieldType = ( secondType = second.type; } } - if (firstType === secondType && isPolymorphicItemType(firstType)) { - return SINGLE_TO_POLYMORPHIC_MAP[firstType]; + if (firstType === secondType) { + if (isPolymorphicItemType(firstType)) { + // Known polymorphic field type + const originalType = SINGLE_TO_POLYMORPHIC_MAP[firstType]; + if (!originalType) { + return; + } + return { type: originalType, originalType }; + } + + // else custom polymorphic + return { + type: 'CustomPolymorphic', + originalType: `${firstType}Polymorphic`, + }; } } } else if (schemaObject.enum) { - return 'enum'; + return { type: 'enum', originalType: 'enum' }; } else if (schemaObject.type) { if (schemaObject.type === 'number') { // floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them - return 'float'; + return { type: 'float', originalType: 'float' }; } else if (schemaObject.type === 'array') { const itemType = isSchemaObject(schemaObject.items) ? schemaObject.items.type @@ -1085,16 +1144,39 @@ export const getFieldType = ( } if (isCollectionItemType(itemType)) { - return COLLECTION_MAP[itemType]; + // known collection field type + const originalType = COLLECTION_MAP[itemType]; + if (!originalType) { + return; + } + return { type: originalType, originalType }; } - return; - } else if (!isArray(schemaObject.type)) { - return schemaObject.type; + return { + type: 'CustomCollection', + originalType: `${itemType}Collection`, + }; + } else if ( + !isArray(schemaObject.type) && + schemaObject.type !== 'null' && // 'null' is not valid + schemaObject.type !== 'object' // 'object' is not valid + ) { + const originalType = schemaObject.type; + return { type: originalType, originalType }; } + // else ignore + return; } } else if (isRefObject(schemaObject)) { - return refObjectToSchemaName(schemaObject); + const originalType = refObjectToSchemaName(schemaObject); + if (!originalType) { + return; + } + + return { + type: isFieldType(originalType) ? originalType : 'Custom', + originalType, + }; } return; }; @@ -1159,6 +1241,8 @@ const TEMPLATE_BUILDER_MAP: { VaeField: buildVaeInputFieldTemplate, VaeModelField: buildVaeModelInputFieldTemplate, Custom: buildCustomInputFieldTemplate, + CustomCollection: buildCustomCollectionInputFieldTemplate, + CustomPolymorphic: buildCustomPolymorphicInputFieldTemplate, }; /** diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index f8db78ecc39..be223b2c5f3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -61,6 +61,8 @@ const FIELD_VALUE_FALLBACK_MAP: { VaeField: undefined, VaeModelField: undefined, Custom: undefined, + CustomCollection: [], + CustomPolymorphic: undefined, }; export const buildInputFieldValue = ( diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index 6d3146f9f69..407ac0d0a72 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -103,14 +103,15 @@ export const parseSchema = ( return inputsAccumulator; } - let fieldType = property.ui_type ?? getFieldType(property); + const fieldTypeResult = property.ui_type + ? { type: property.ui_type, originalType: property.ui_type } + : getFieldType(property); - if (!fieldType) { + if (!fieldTypeResult) { logger('nodes').warn( { node: type, fieldName: propertyName, - fieldType, field: parseify(property), }, 'Missing input field type' @@ -119,7 +120,7 @@ export const parseSchema = ( } // stash this for custom types - const originalType = fieldType; + const { type: fieldType, originalType } = fieldTypeResult; if (fieldType === 'WorkflowField') { withWorkflow = true; @@ -139,7 +140,7 @@ export const parseSchema = ( return inputsAccumulator; } - if (!isFieldType(fieldType)) { + if (!isFieldType(originalType)) { logger('nodes').debug( { node: type, @@ -149,11 +150,19 @@ export const parseSchema = ( }, `Fallback handling for unknown input field type: ${fieldType}` ); - fieldType = 'Custom'; - if (!isFieldType(fieldType)) { - // satisfy TS gods - return inputsAccumulator; - } + } + + if (!isFieldType(fieldType)) { + logger('nodes').warn( + { + node: type, + fieldName: propertyName, + fieldType, + field: parseify(property), + }, + `Unable to parse field type: ${fieldType}` + ); + return inputsAccumulator; } const field = buildInputFieldTemplate( @@ -170,6 +179,7 @@ export const parseSchema = ( node: type, fieldName: propertyName, fieldType, + originalType, field: parseify(property), }, 'Skipping input field with no template' @@ -228,14 +238,15 @@ export const parseSchema = ( return outputsAccumulator; } - let fieldType = property.ui_type ?? getFieldType(property); + const fieldTypeResult = property.ui_type + ? { type: property.ui_type, originalType: property.ui_type } + : getFieldType(property); - if (!fieldType) { + if (!fieldTypeResult) { logger('nodes').warn( { node: type, fieldName: propertyName, - fieldType, field: parseify(property), }, 'Missing output field type' @@ -243,19 +254,32 @@ export const parseSchema = ( return outputsAccumulator; } - // stash for custom types - const originalType = fieldType; + const { type: fieldType, originalType } = fieldTypeResult; if (!isFieldType(fieldType)) { logger('nodes').debug( - { fieldName: propertyName, fieldType, field: parseify(property) }, + { + node: type, + fieldName: propertyName, + fieldType, + originalType, + field: parseify(property), + }, `Fallback handling for unknown output field type: ${fieldType}` ); - fieldType = 'Custom'; - if (!isFieldType(fieldType)) { - // satisfy TS gods - return outputsAccumulator; - } + } + + if (!isFieldType(fieldType)) { + logger('nodes').warn( + { + node: type, + fieldName: propertyName, + fieldType, + field: parseify(property), + }, + `Unable to parse field type: ${fieldType}` + ); + return outputsAccumulator; } outputsAccumulator[propertyName] = { From 9ebffcd26bdf0340bedf029470a936e1cad4808a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 Nov 2023 17:44:42 +1100 Subject: [PATCH 07/12] feat(ui): add validation for CustomCollection & CustomPolymorphic types - Update connection validation for custom types - Use simple string parsing to determine if a field is a collection or polymorphic type. - No longer need to keep a list of collection and polymorphic types. - Added runtime checks in `baseinvocation.py` to ensure no fields are named in such a way that it could mess up the new parsing --- invokeai/app/invocations/baseinvocation.py | 20 ++++++- .../nodes/Invocation/fields/FieldHandle.tsx | 21 +++---- .../hooks/useAnyOrDirectInputFieldNames.ts | 8 +-- .../hooks/useConnectionInputFieldNames.ts | 9 +-- .../nodes/hooks/useIsValidConnection.ts | 16 ++++-- .../nodes/store/util/parseFieldType.ts | 14 +++++ .../util/validateSourceAndTargetTypes.ts | 57 +++++++++---------- .../web/src/features/nodes/types/constants.ts | 34 ----------- .../nodes/util/fieldTemplateBuilders.ts | 12 ++-- 9 files changed, 92 insertions(+), 99 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/parseFieldType.ts diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 1b3e535d340..cfeb229a09f 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -8,7 +8,7 @@ from enum import Enum from inspect import signature from types import UnionType -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, ForwardRef, Iterable, Literal, Optional, Type, TypeVar, Union import semver from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model @@ -653,12 +653,30 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None """ Validates the fields of an invocation or invocation output: - must not override any pydantic reserved fields + - must not end with "Collection" or "Polymorphic" as these are reserved for internal use - must be created via `InputField`, `OutputField`, or be an internal field defined in this file """ for name, field in model_fields.items(): if name in RESERVED_PYDANTIC_FIELD_NAMES: raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)') + if not field.annotation: + raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)') + + annotation_name = ( + field.annotation.__forward_arg__ if isinstance(field.annotation, ForwardRef) else field.annotation.__name__ + ) + + if annotation_name.endswith("Polymorphic"): + raise InvalidFieldError( + f'Invalid field type "{annotation_name}" for "{name}" on "{model_type}" (must not end in "Polymorphic")' + ) + + if annotation_name.endswith("Collection"): + raise InvalidFieldError( + f'Invalid field type "{annotation_name}" for "{name}" on "{model_type}" (must not end in "Collection")' + ) + field_kind = ( # _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index 1a43fcdbb01..aa20ccab94a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -1,9 +1,12 @@ import { Tooltip } from '@chakra-ui/react'; +import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; +import { + getIsCollection, + getIsPolymorphic, +} from 'features/nodes/store/util/parseFieldType'; import { - COLLECTION_TYPES, HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES, - POLYMORPHIC_TYPES, } from 'features/nodes/types/constants'; import { InputFieldTemplate, @@ -50,23 +53,17 @@ const FieldHandle = (props: FieldHandleProps) => { const type = fieldTemplate.originalType ?? fieldTemplate.type; const styles: CSSProperties = useMemo(() => { - const isCollectionType = COLLECTION_TYPES.some( - (t) => t === fieldTemplate.type - ); - const isPolymorphicType = POLYMORPHIC_TYPES.some( - (t) => t === fieldTemplate.type - ); + const isCollection = getIsCollection(fieldTemplate.type); + const isPolymorphic = getIsPolymorphic(fieldTemplate.type); const isModelType = MODEL_TYPES.some((t) => t === type); const color = getFieldColor(type); const s: CSSProperties = { backgroundColor: - isCollectionType || isPolymorphicType - ? 'var(--invokeai-colors-base-900)' - : color, + isCollection || isPolymorphic ? colorTokenToCssVar('base.900') : color, position: 'absolute', width: '1rem', height: '1rem', - borderWidth: isCollectionType || isPolymorphicType ? 4 : 0, + borderWidth: isCollection || isPolymorphic ? 4 : 0, borderStyle: 'solid', borderColor: color, borderRadius: isModelType ? 4 : '100%', diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index dda2efc1568..e65b036f78e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -4,11 +4,9 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { map } from 'lodash-es'; import { useMemo } from 'react'; +import { getIsPolymorphic } from '../store/util/parseFieldType'; +import { TYPES_WITH_INPUT_COMPONENTS } from '../types/constants'; import { isInvocationNode } from '../types/types'; -import { - POLYMORPHIC_TYPES, - TYPES_WITH_INPUT_COMPONENTS, -} from '../types/constants'; import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames'; export const useAnyOrDirectInputFieldNames = (nodeId: string) => { @@ -28,7 +26,7 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => { const fields = map(nodeTemplate.inputs).filter( (field) => (['any', 'direct'].includes(field.input) || - POLYMORPHIC_TYPES.includes(field.type)) && + getIsPolymorphic(field.type)) && TYPES_WITH_INPUT_COMPONENTS.includes(field.type) ); return getSortedFilteredFieldNames(fields); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts index 9fb31df801d..a000e76ae01 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts @@ -4,10 +4,8 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { map } from 'lodash-es'; import { useMemo } from 'react'; -import { - POLYMORPHIC_TYPES, - TYPES_WITH_INPUT_COMPONENTS, -} from '../types/constants'; +import { getIsPolymorphic } from '../store/util/parseFieldType'; +import { TYPES_WITH_INPUT_COMPONENTS } from '../types/constants'; import { isInvocationNode } from '../types/types'; import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames'; @@ -29,8 +27,7 @@ export const useConnectionInputFieldNames = (nodeId: string) => { // get the visible fields const fields = map(nodeTemplate.inputs).filter( (field) => - (field.input === 'connection' && - !POLYMORPHIC_TYPES.includes(field.type)) || + (field.input === 'connection' && !getIsPolymorphic(field.type)) || !TYPES_WITH_INPUT_COMPONENTS.includes(field.type) ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index c88d4758af1..a924a10cfc0 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -34,10 +34,10 @@ export const useIsValidConnection = () => { return false; } - const sourceType = sourceNode.data.outputs[sourceHandle]?.type; - const targetType = targetNode.data.inputs[targetHandle]?.type; + const sourceField = sourceNode.data.outputs[sourceHandle]; + const targetField = targetNode.data.inputs[targetHandle]; - if (!sourceType || !targetType) { + if (!sourceField || !targetField) { // something has gone terribly awry return false; } @@ -70,12 +70,18 @@ export const useIsValidConnection = () => { return edge.target === target && edge.targetHandle === targetHandle; }) && // except CollectionItem inputs can have multiples - targetType !== 'CollectionItem' + targetField.type !== 'CollectionItem' ) { return false; } - if (!validateSourceAndTargetTypes(sourceType, targetType)) { + // Must use the originalType here if it exists + if ( + !validateSourceAndTargetTypes( + sourceField?.originalType ?? sourceField.type, + targetField?.originalType ?? targetField.type + ) + ) { return false; } diff --git a/invokeai/frontend/web/src/features/nodes/store/util/parseFieldType.ts b/invokeai/frontend/web/src/features/nodes/store/util/parseFieldType.ts new file mode 100644 index 00000000000..da85b97d6f7 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/parseFieldType.ts @@ -0,0 +1,14 @@ +import { FieldType } from 'features/nodes/types/types'; + +export const getIsPolymorphic = (type: FieldType | string): boolean => + type.endsWith('Polymorphic'); + +export const getIsCollection = (type: FieldType | string): boolean => + type.endsWith('Collection'); + +export const getBaseType = (type: FieldType | string): FieldType | string => + getIsPolymorphic(type) + ? type.replace(/Polymorphic$/, '') + : getIsCollection(type) + ? type.replace(/Collection$/, '') + : type; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts index 123cda8e044..f4ae93fcb4d 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts @@ -1,18 +1,32 @@ -import { - COLLECTION_MAP, - COLLECTION_TYPES, - POLYMORPHIC_TO_SINGLE_MAP, - POLYMORPHIC_TYPES, -} from 'features/nodes/types/constants'; import { FieldType } from 'features/nodes/types/types'; +import { + getBaseType, + getIsCollection, + getIsPolymorphic, +} from './parseFieldType'; +/** + * Validates that the source and target types are compatible for a connection. + * @param sourceType The type of the source field. Must be the originalType if it exists. + * @param targetType The type of the target field. Must be the originalType if it exists. + * @returns True if the connection is valid, false otherwise. + */ export const validateSourceAndTargetTypes = ( sourceType: FieldType | string, targetType: FieldType | string ) => { + const isSourcePolymorphic = getIsPolymorphic(sourceType); + const isSourceCollection = getIsCollection(sourceType); + const sourceBaseType = getBaseType(sourceType); + + const isTargetPolymorphic = getIsPolymorphic(targetType); + const isTargetCollection = getIsCollection(targetType); + const targetBaseType = getBaseType(targetType); + // TODO: There's a bug with Collect -> Iterate nodes: // https://github.com/invoke-ai/InvokeAI/issues/3956 // Once this is resolved, we can remove this check. + // Note that 'Collection' here is a field type, not node type. if (sourceType === 'Collection' && targetType === 'Collection') { return false; } @@ -31,39 +45,21 @@ export const validateSourceAndTargetTypes = ( */ const isCollectionItemToNonCollection = - sourceType === 'CollectionItem' && - !COLLECTION_TYPES.some((t) => t === targetType); + sourceType === 'CollectionItem' && !isTargetCollection; const isNonCollectionToCollectionItem = targetType === 'CollectionItem' && - !COLLECTION_TYPES.some((t) => t === sourceType) && - !POLYMORPHIC_TYPES.some((t) => t === sourceType); + !isSourceCollection && + !isSourcePolymorphic; const isAnythingToPolymorphicOfSameBaseType = - POLYMORPHIC_TYPES.some((t) => t === targetType) && - (() => { - if (!POLYMORPHIC_TYPES.some((t) => t === targetType)) { - return false; - } - const baseType = - POLYMORPHIC_TO_SINGLE_MAP[ - targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP - ]; - - const collectionType = - COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP]; - - return sourceType === baseType || sourceType === collectionType; - })(); + isTargetPolymorphic && sourceBaseType === targetBaseType; const isGenericCollectionToAnyCollectionOrPolymorphic = - sourceType === 'Collection' && - (COLLECTION_TYPES.some((t) => t === targetType) || - POLYMORPHIC_TYPES.some((t) => t === targetType)); + sourceType === 'Collection' && (isTargetCollection || isTargetPolymorphic); const isCollectionToGenericCollection = - targetType === 'Collection' && - COLLECTION_TYPES.some((t) => t === sourceType); + targetType === 'Collection' && isSourceCollection; const isIntToFloat = sourceType === 'integer' && targetType === 'float'; @@ -73,6 +69,7 @@ export const validateSourceAndTargetTypes = ( const isTargetAnyType = targetType === 'Any'; + // One of these must be true for the connection to be valid return ( isCollectionItemToNonCollection || isNonCollectionToCollectionItem || diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index e8297c17e3c..db5c25d309d 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -20,40 +20,6 @@ export const KIND_MAP = { output: 'outputs' as const, }; -export const COLLECTION_TYPES: FieldType[] = [ - 'Collection', - 'IntegerCollection', - 'BooleanCollection', - 'FloatCollection', - 'StringCollection', - 'ImageCollection', - 'LatentsCollection', - 'ConditioningCollection', - 'ControlCollection', - 'ColorCollection', - 'T2IAdapterCollection', - 'IPAdapterCollection', - 'MetadataItemCollection', - 'MetadataCollection', - 'CustomCollection', -]; - -export const POLYMORPHIC_TYPES: FieldType[] = [ - 'IntegerPolymorphic', - 'BooleanPolymorphic', - 'FloatPolymorphic', - 'StringPolymorphic', - 'ImagePolymorphic', - 'LatentsPolymorphic', - 'ConditioningPolymorphic', - 'ControlPolymorphic', - 'ColorPolymorphic', - 'T2IAdapterPolymorphic', - 'IPAdapterPolymorphic', - 'MetadataItemPolymorphic', - 'CustomPolymorphic', -]; - export const MODEL_TYPES: FieldType[] = [ 'IPAdapterModelField', 'ControlNetModelField', diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index 8538e1e48de..4a5df835fa3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -8,9 +8,9 @@ import { } from 'lodash-es'; import { OpenAPIV3_1 } from 'openapi-types'; import { ControlField } from 'services/api/types'; +import { getIsPolymorphic } from '../store/util/parseFieldType'; import { COLLECTION_MAP, - POLYMORPHIC_TYPES, SINGLE_TO_POLYMORPHIC_MAP, isCollectionItemType, isPolymorphicItemType, @@ -35,6 +35,9 @@ import { ControlInputFieldTemplate, ControlNetModelInputFieldTemplate, ControlPolymorphicInputFieldTemplate, + CustomCollectionInputFieldTemplate, + CustomInputFieldTemplate, + CustomPolymorphicInputFieldTemplate, DenoiseMaskInputFieldTemplate, EnumInputFieldTemplate, FieldType, @@ -81,16 +84,13 @@ import { T2IAdapterModelInputFieldTemplate, T2IAdapterPolymorphicInputFieldTemplate, UNetInputFieldTemplate, - CustomInputFieldTemplate, VaeInputFieldTemplate, VaeModelInputFieldTemplate, isArraySchemaObject, + isFieldType, isNonArraySchemaObject, isRefObject, isSchemaObject, - isFieldType, - CustomCollectionInputFieldTemplate, - CustomPolymorphicInputFieldTemplate, } from '../types/types'; export type BaseFieldProperties = 'name' | 'title' | 'description'; @@ -1269,7 +1269,7 @@ export const buildInputFieldTemplate = ( const extra = { // TODO: Can we support polymorphic inputs in the UI? - input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input, + input: getIsPolymorphic(fieldType) ? 'connection' : input, ui_hidden, ui_component, ui_type, From 57567d4fc39e7e56aecaa9a48308c7fb0f905afa Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 Nov 2023 17:45:47 +1100 Subject: [PATCH 08/12] chore(ui): remove errant console.log --- .../components/ImageMetadataViewer/ImageMetadataActions.tsx | 2 -- .../features/nodes/store/util/makeIsConnectionValidSelector.ts | 3 --- 2 files changed, 5 deletions(-) diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index ce5b178fa25..9053ce973d6 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -157,8 +157,6 @@ const ImageMetadataActions = (props: Props) => { return null; } - console.log(metadata); - return ( <> {metadata.created_by && ( diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 224c0235f75..cb7886e57e1 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -44,9 +44,6 @@ export const makeConnectionErrorSelector = ( const sourceType = handleType === 'source' ? fieldType : currentConnectionFieldType; - console.log('targetType', targetType); - console.log('sourceType', sourceType); - if (nodeId === connectionNodeId) { return i18n.t('nodes.cannotConnectToSelf'); } From e047d431117df5f2eab0ee4412d8871253aa49e1 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 Nov 2023 18:05:25 +1100 Subject: [PATCH 09/12] fix(ui): rename 'nodes.currentConnectionFieldType' -> 'nodes.connectionStartFieldType' This was confusingly named and kept tripping me up. Renamed to be consistent with the `reactflow` `ConnectionStartParams` type. --- .../flow/AddNodePopover/AddNodePopover.tsx | 2 +- .../connectionLines/CustomConnectionLine.tsx | 4 ++-- .../nodes/hooks/useConnectionState.ts | 2 +- .../src/features/nodes/store/nodesSlice.ts | 22 +++++++++---------- .../web/src/features/nodes/store/types.ts | 2 +- .../util/makeIsConnectionValidSelector.ts | 8 +++---- 6 files changed, 20 insertions(+), 20 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index b514474a260..1496f793cea 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -57,7 +57,7 @@ const AddNodePopover = () => { const { t } = useTranslation(); const fieldFilter = useAppSelector( - (state) => state.nodes.currentConnectionFieldType + (state) => state.nodes.connectionStartFieldType ); const handleFilter = useAppSelector( (state) => state.nodes.connectionStartParams?.handleType diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx index a14b7b23c6d..f3d705b3476 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx @@ -7,11 +7,11 @@ import { ConnectionLineComponentProps, getBezierPath } from 'reactflow'; import { getFieldColor } from '../edges/util/getEdgeColor'; const selector = createSelector(stateSelector, ({ nodes }) => { - const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } = + const { shouldAnimateEdges, connectionStartFieldType, shouldColorEdges } = nodes; const stroke = shouldColorEdges - ? getFieldColor(currentConnectionFieldType) + ? getFieldColor(connectionStartFieldType) : colorTokenToCssVar('base.500'); let className = 'react-flow__custom_connection-path'; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index 96b2d652e92..cc3b2ce7ac9 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -8,7 +8,7 @@ import { useFieldType } from './useFieldType.ts'; const selectIsConnectionInProgress = createSelector( stateSelector, ({ nodes }) => - nodes.currentConnectionFieldType !== null && + nodes.connectionStartFieldType !== null && nodes.connectionStartParams !== null ); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index d6b1f04afd1..c47bb5d04e3 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -93,7 +93,7 @@ export const initialNodesState: NodesState = { nodeTemplates: {}, isReady: false, connectionStartParams: null, - currentConnectionFieldType: null, + connectionStartFieldType: null, connectionMade: false, modifyingEdge: false, addNewNodePosition: null, @@ -203,7 +203,7 @@ const nodesSlice = createSlice({ nodeId && handleId && handleType && - state.currentConnectionFieldType + state.connectionStartFieldType ) { const newConnection = findConnectionToValidHandle( node, @@ -212,7 +212,7 @@ const nodesSlice = createSlice({ nodeId, handleId, handleType, - state.currentConnectionFieldType + state.connectionStartFieldType ); if (newConnection) { state.edges = addEdge( @@ -224,7 +224,7 @@ const nodesSlice = createSlice({ } state.connectionStartParams = null; - state.currentConnectionFieldType = null; + state.connectionStartFieldType = null; }, edgeChangeStarted: (state) => { state.modifyingEdge = true; @@ -258,11 +258,11 @@ const nodesSlice = createSlice({ handleType === 'source' ? node.data.outputs[handleId] : node.data.inputs[handleId]; - state.currentConnectionFieldType = + state.connectionStartFieldType = field?.originalType ?? field?.type ?? null; }, connectionMade: (state, action: PayloadAction) => { - const fieldType = state.currentConnectionFieldType; + const fieldType = state.connectionStartFieldType; if (!fieldType) { return; } @@ -287,7 +287,7 @@ const nodesSlice = createSlice({ nodeId && handleId && handleType && - state.currentConnectionFieldType + state.connectionStartFieldType ) { const newConnection = findConnectionToValidHandle( mouseOverNode, @@ -296,7 +296,7 @@ const nodesSlice = createSlice({ nodeId, handleId, handleType, - state.currentConnectionFieldType + state.connectionStartFieldType ); if (newConnection) { state.edges = addEdge( @@ -307,14 +307,14 @@ const nodesSlice = createSlice({ } } state.connectionStartParams = null; - state.currentConnectionFieldType = null; + state.connectionStartFieldType = null; } else { state.addNewNodePosition = action.payload.cursorPosition; state.isAddNodePopoverOpen = true; } } else { state.connectionStartParams = null; - state.currentConnectionFieldType = null; + state.connectionStartFieldType = null; } state.modifyingEdge = false; }, @@ -943,7 +943,7 @@ const nodesSlice = createSlice({ //Make sure these get reset if we close the popover and haven't selected a node state.connectionStartParams = null; - state.currentConnectionFieldType = null; + state.connectionStartFieldType = null; }, addNodePopoverToggled: (state) => { state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen; diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index b81dd286d72..942e10a3c40 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -21,7 +21,7 @@ export type NodesState = { edges: Edge[]; nodeTemplates: Record; connectionStartParams: OnConnectStartParams | null; - currentConnectionFieldType: FieldType | string | null; + connectionStartFieldType: FieldType | string | null; connectionMade: boolean; modifyingEdge: boolean; shouldShowFieldTypeLegend: boolean; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index cb7886e57e1..30e73afd1ec 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -22,10 +22,10 @@ export const makeConnectionErrorSelector = ( return i18n.t('nodes.noFieldType'); } - const { currentConnectionFieldType, connectionStartParams, nodes, edges } = + const { connectionStartFieldType, connectionStartParams, nodes, edges } = state.nodes; - if (!connectionStartParams || !currentConnectionFieldType) { + if (!connectionStartParams || !connectionStartFieldType) { return i18n.t('nodes.noConnectionInProgress'); } @@ -40,9 +40,9 @@ export const makeConnectionErrorSelector = ( } const targetType = - handleType === 'target' ? fieldType : currentConnectionFieldType; + handleType === 'target' ? fieldType : connectionStartFieldType; const sourceType = - handleType === 'source' ? fieldType : currentConnectionFieldType; + handleType === 'source' ? fieldType : connectionStartFieldType; if (nodeId === connectionNodeId) { return i18n.t('nodes.cannotConnectToSelf'); From f280a2ecbd9f76b0189465c63e2a74947c9886b1 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 Nov 2023 18:09:32 +1100 Subject: [PATCH 10/12] fix(ui): fix ts error --- .../web/src/features/nodes/store/nodesPersistDenylist.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts b/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts index 64fee2293f9..1322bafa431 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts @@ -6,7 +6,7 @@ import { NodesState } from './types'; export const nodesPersistDenylist: (keyof NodesState)[] = [ 'nodeTemplates', 'connectionStartParams', - 'currentConnectionFieldType', + 'connectionStartFieldType', 'selectedNodes', 'selectedEdges', 'isReady', From 0e640adc2cb64e1666f56cbd774b2a929807ce21 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 20 Nov 2023 10:28:38 +1100 Subject: [PATCH 11/12] feat(nodes): add runtime check for custom field names "Custom", "CustomCollection" and "CustomPolymorphic" are reserved field names. --- invokeai/app/invocations/baseinvocation.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index cfeb229a09f..baf045fd9a8 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -648,6 +648,8 @@ class _Model(BaseModel): # Get all pydantic model attrs, methods, etc RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())} +RESERVED_INVOKEAI_FIELD_NAMES = {"Custom", "CustomCollection", "CustomPolymorphic"} + def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None: """ @@ -677,6 +679,9 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None f'Invalid field type "{annotation_name}" for "{name}" on "{model_type}" (must not end in "Collection")' ) + if annotation_name in RESERVED_INVOKEAI_FIELD_NAMES: + raise InvalidFieldError(f'Invalid field type "{annotation_name}" for "{name}" on "{model_type}" (reserved)') + field_kind = ( # _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None From b65acc01370486c78bfc0d30d701ce84a613a3ef Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 20 Nov 2023 10:29:04 +1100 Subject: [PATCH 12/12] chore(ui): add TODO for revising field type names --- .../web/src/features/nodes/types/constants.ts | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index db5c25d309d..3be561df703 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -36,6 +36,26 @@ export const MODEL_TYPES: FieldType[] = [ 'IPAdapterModelField', ]; +/** + * TODO: Revise the field type naming scheme + * + * Unfortunately, due to inconsistent naming of types, we need to keep the below map objects/callbacks. + * + * Problems: + * - some types do not use the word "Field" in their name, e.g. "Scheduler" + * - primitive types use all-lowercase names, e.g. "integer" + * - collection and polymorphic types do not use the word "Field" + * + * If these inconsistencies were resolved, we could remove these mappings and use simple string + * parsing/manipulation to handle field types. + * + * It would make some of the parsing logic simpler and reduce the maintenance overhead of adding new + * "official" field types. + * + * This will require migration logic for workflows to update their field types. Workflows *do* have a + * version attached to them, so this shouldn't be too difficult. + */ + export const COLLECTION_MAP: FieldTypeMapWithNumber = { integer: 'IntegerCollection', boolean: 'BooleanCollection',