Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(ui): custom field types connection validation
Browse files Browse the repository at this point in the history
In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic.

*Actually, it was `"Unknown"`, but I changed it to custom for clarity.

Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields.

To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`.

This ended up needing a bit of fanagling:

- If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property.

While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer.

(Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.)

- Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future.

- We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`.

Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`.

This causes a TS error. This behaviour is somewhat controversial (see microsoft/TypeScript#14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent.
psychedelicious authored and skunkworxdark committed Nov 29, 2023
1 parent 8e125b8 commit 6ee8607
Showing 13 changed files with 371 additions and 178 deletions.
Original file line number Diff line number Diff line change
@@ -73,9 +73,13 @@ const AddNodePopover = () => {

return some(handles, (handle) => {
const sourceType =
handleFilter == 'source' ? fieldFilter : handle.type;
handleFilter == 'source'
? fieldFilter
: handle.originalType ?? handle.type;
const targetType =
handleFilter == 'target' ? fieldFilter : handle.type;
handleFilter == 'target'
? fieldFilter
: handle.originalType ?? handle.type;

return validateSourceAndTargetTypes(sourceType, targetType);
});
Original file line number Diff line number Diff line change
@@ -4,14 +4,14 @@ import { useAppSelector } from 'app/store/storeHooks';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { memo } from 'react';
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import { getFieldColor } from '../edges/util/getEdgeColor';

const selector = createSelector(stateSelector, ({ nodes }) => {
const { shouldAnimateEdges, connectionStartFieldType, shouldColorEdges } =
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
nodes;

const stroke = shouldColorEdges
? getFieldColor(connectionStartFieldType)
? getFieldColor(currentConnectionFieldType)
: colorTokenToCssVar('base.500');

let className = 'react-flow__custom_connection-path';
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELD_COLORS } from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/field';
import { FIELDS } from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/types';

export const getFieldColor = (fieldType: FieldType | null): string => {
export const getFieldColor = (fieldType: FieldType | string | null): string => {
if (!fieldType) {
return colorTokenToCssVar('base.500');
}
const color = FIELD_COLORS[fieldType.name];
const color = FIELDS[fieldType]?.color;

return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
};
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/types';
import { getFieldColor } from './getEdgeColor';

export const makeEdgeSelector = (
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import { Tooltip } from '@chakra-ui/react';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import {
COLLECTION_TYPES,
FIELDS,
HANDLE_TOOLTIP_OPEN_DELAY,
MODEL_TYPES,
POLYMORPHIC_TYPES,
@@ -13,6 +11,7 @@ import {
} from 'features/nodes/types/types';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, HandleType, Position } from 'reactflow';
import { getFieldColor } from '../../../edges/util/getEdgeColor';

export const handleBaseStyles: CSSProperties = {
position: 'absolute',
@@ -47,14 +46,14 @@ const FieldHandle = (props: FieldHandleProps) => {
isConnectionStartField,
connectionError,
} = props;
const { name, type, originalType } = fieldTemplate;
const { color: typeColor } = FIELDS[type];
const { name } = fieldTemplate;
const type = fieldTemplate.originalType ?? fieldTemplate.type;

const styles: CSSProperties = useMemo(() => {
const isCollectionType = COLLECTION_TYPES.includes(type);
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
const isModelType = MODEL_TYPES.includes(type);
const color = colorTokenToCssVar(typeColor);
const isCollectionType = COLLECTION_TYPES.some((t) => t === type);
const isPolymorphicType = POLYMORPHIC_TYPES.some((t) => t === type);
const isModelType = MODEL_TYPES.some((t) => t === type);
const color = getFieldColor(type);
const s: CSSProperties = {
backgroundColor:
isCollectionType || isPolymorphicType
@@ -97,23 +96,14 @@ const FieldHandle = (props: FieldHandleProps) => {
isConnectionInProgress,
isConnectionStartField,
type,
typeColor,
]);

const tooltip = useMemo(() => {
if (isConnectionInProgress && isConnectionStartField) {
return originalType;
}
if (isConnectionInProgress && connectionError) {
return connectionError ?? originalType;
return connectionError;
}
return originalType;
}, [
connectionError,
isConnectionInProgress,
isConnectionStartField,
originalType,
]);
return type;
}, [connectionError, isConnectionInProgress, type]);

return (
<Tooltip
Original file line number Diff line number Diff line change
@@ -3,8 +3,8 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { KIND_MAP } from 'features/nodes/types/constants';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { KIND_MAP } from '../types/constants';
import { isInvocationNode } from '../types/types';

export const useFieldType = (
nodeId: string,
@@ -21,7 +21,7 @@ export const useFieldType = (
return;
}
const field = node.data[KIND_MAP[kind]][fieldName];
return field?.type;
return field?.originalType ?? field?.type;
},
defaultSelectorOptions
),
209 changes: 135 additions & 74 deletions invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts

Large diffs are not rendered by default.

26 changes: 15 additions & 11 deletions invokeai/frontend/web/src/features/nodes/store/types.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
import {
Edge,
Node,
OnConnectStartParams,
SelectionMode,
Viewport,
XYPosition,
} from 'reactflow';
import { FieldIdentifier, FieldType } from 'features/nodes/types/field';
import {
AnyNode,
InvocationNodeEdge,
FieldIdentifier,
FieldType,
InvocationEdgeExtra,
InvocationTemplate,
NodeData,
NodeExecutionState,
} from 'features/nodes/types/invocation';
import { WorkflowV2 } from 'features/nodes/types/workflow';
Workflow,
} from '../types/types';

export type NodesState = {
nodes: AnyNode[];
edges: InvocationNodeEdge[];
nodes: Node<NodeData>[];
edges: Edge<InvocationEdgeExtra>[];
nodeTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
connectionStartFieldType: FieldType | null;
currentConnectionFieldType: FieldType | string | null;
connectionMade: boolean;
modifyingEdge: boolean;
shouldShowFieldTypeLegend: boolean;
shouldShowMinimapPanel: boolean;
shouldValidateGraph: boolean;
shouldAnimateEdges: boolean;
@@ -29,14 +33,14 @@ export type NodesState = {
shouldColorEdges: boolean;
selectedNodes: string[];
selectedEdges: string[];
workflow: Omit<WorkflowV2, 'nodes' | 'edges'>;
workflow: Omit<Workflow, 'nodes' | 'edges'>;
nodeExecutionStates: Record<string, NodeExecutionState>;
viewport: Viewport;
isReady: boolean;
mouseOverField: FieldIdentifier | null;
mouseOverNode: string | null;
nodesToCopy: AnyNode[];
edgesToCopy: InvocationNodeEdge[];
nodesToCopy: Node<NodeData>[];
edgesToCopy: Edge<InvocationEdgeExtra>[];
isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null;
selectionMode: SelectionMode;
128 changes: 128 additions & 0 deletions invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import {
CurrentImageNodeData,
InputFieldValue,
InvocationNodeData,
InvocationTemplate,
NotesNodeData,
OutputFieldValue,
} from 'features/nodes/types/types';
import { buildInputFieldValue } from 'features/nodes/util/fieldValueBuilders';
import { reduce } from 'lodash-es';
import { Node, XYPosition } from 'reactflow';
import { AnyInvocationType } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid';

export const SHARED_NODE_PROPERTIES: Partial<Node> = {
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
};
export const buildNodeData = (
type: AnyInvocationType | 'current_image' | 'notes',
position: XYPosition,
template?: InvocationTemplate
):
| Node<CurrentImageNodeData>
| Node<NotesNodeData>
| Node<InvocationNodeData>
| undefined => {
const nodeId = uuidv4();

if (type === 'current_image') {
const node: Node<CurrentImageNodeData> = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'current_image',
position,
data: {
id: nodeId,
type: 'current_image',
isOpen: true,
label: 'Current Image',
},
};

return node;
}

if (type === 'notes') {
const node: Node<NotesNodeData> = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'notes',
position,
data: {
id: nodeId,
isOpen: true,
label: 'Notes',
notes: '',
type: 'notes',
},
};

return node;
}

if (template === undefined) {
console.error(`Unable to find template ${type}.`);
return;
}

const inputs = reduce(
template.inputs,
(inputsAccumulator, inputTemplate, inputName) => {
const fieldId = uuidv4();

const inputFieldValue: InputFieldValue = buildInputFieldValue(
fieldId,
inputTemplate
);

inputsAccumulator[inputName] = inputFieldValue;

return inputsAccumulator;
},
{} as Record<string, InputFieldValue>
);

const outputs = reduce(
template.outputs,
(outputsAccumulator, outputTemplate, outputName) => {
const fieldId = uuidv4();

const outputFieldValue: OutputFieldValue = {
id: fieldId,
name: outputName,
type: outputTemplate.type,
fieldKind: 'output',
originalType: outputTemplate.originalType,
};

outputsAccumulator[outputName] = outputFieldValue;

return outputsAccumulator;
},
{} as Record<string, OutputFieldValue>
);

const invocation: Node<InvocationNodeData> = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'invocation',
position,
data: {
id: nodeId,
type,
version: template.version,
label: '',
notes: '',
isOpen: true,
embedWorkflow: false,
isIntermediate: type === 'save_image' ? false : true,
inputs,
outputs,
useCache: template.useCache,
},
};

return invocation;
};
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
const isValidConnection = (
edges: Edge[],
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType,
handleCurrentFieldType: FieldType | string,
node: Node,
handle: FieldInputInstance | FieldOutputInstance
) => {
@@ -34,7 +34,12 @@ const isValidConnection = (
}
}

if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) {
if (
!validateSourceAndTargetTypes(
handleCurrentFieldType,
handle.originalType ?? handle.type
)
) {
isValidConnection = false;
}

@@ -48,7 +53,7 @@ export const findConnectionToValidHandle = (
handleCurrentNodeId: string,
handleCurrentName: string,
handleCurrentType: HandleType,
handleCurrentFieldType: FieldType
handleCurrentFieldType: FieldType | string
): Connection | null => {
if (node.id === handleCurrentNodeId) {
return null;
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { FieldType } from 'features/nodes/types/field';
import { FieldType } from 'features/nodes/types/types';
import i18n from 'i18next';
import { HandleType } from 'reactflow';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
@@ -15,17 +15,17 @@ export const makeConnectionErrorSelector = (
nodeId: string,
fieldName: string,
handleType: HandleType,
fieldType?: FieldType
fieldType?: FieldType | string
) => {
return createSelector(stateSelector, (state): string | undefined => {
return createSelector(stateSelector, (state) => {
if (!fieldType) {
return i18n.t('nodes.noFieldType');
}

const { connectionStartFieldType, connectionStartParams, nodes, edges } =
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
state.nodes;

if (!connectionStartParams || !connectionStartFieldType) {
if (!connectionStartParams || !currentConnectionFieldType) {
return i18n.t('nodes.noConnectionInProgress');
}

@@ -40,9 +40,9 @@ export const makeConnectionErrorSelector = (
}

const targetType =
handleType === 'target' ? fieldType : connectionStartFieldType;
handleType === 'target' ? fieldType : currentConnectionFieldType;
const sourceType =
handleType === 'source' ? fieldType : connectionStartFieldType;
handleType === 'source' ? fieldType : currentConnectionFieldType;

if (nodeId === connectionNodeId) {
return i18n.t('nodes.cannotConnectToSelf');
@@ -80,7 +80,7 @@ export const makeConnectionErrorSelector = (
return edge.target === target && edge.targetHandle === targetHandle;
}) &&
// except CollectionItem inputs can have multiples
targetType.name !== 'CollectionItemField'
targetType !== 'CollectionItem'
) {
return i18n.t('nodes.inputMayOnlyHaveOneConnection');
}
@@ -100,6 +100,6 @@ export const makeConnectionErrorSelector = (
return i18n.t('nodes.connectionWouldCreateCycle');
}

return;
return null;
});
};
Original file line number Diff line number Diff line change
@@ -1,81 +1,83 @@
import { FieldType } from 'features/nodes/types/field';
import { isEqual } from 'lodash-es';
import {
COLLECTION_MAP,
COLLECTION_TYPES,
POLYMORPHIC_TO_SINGLE_MAP,
POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/types';

/**
* Validates that the source and target types are compatible for a connection.
* @param sourceType The type of the source field.
* @param targetType The type of the target field.
* @returns True if the connection is valid, false otherwise.
*/
export const validateSourceAndTargetTypes = (
sourceType: FieldType,
targetType: FieldType
sourceType: FieldType | string,
targetType: FieldType | string
) => {
// TODO: There's a bug with Collect -> Iterate nodes:
// https://github.com/invoke-ai/InvokeAI/issues/3956
// Once this is resolved, we can remove this check.
if (
sourceType.name === 'CollectionField' &&
targetType.name === 'CollectionField'
) {
if (sourceType === 'Collection' && targetType === 'Collection') {
return false;
}

if (isEqual(sourceType, targetType)) {
if (sourceType === targetType) {
return true;
}

/**
* Connection types must be the same for a connection, with exceptions:
* - CollectionItem can connect to any non-Collection
* - Non-Collections can connect to CollectionItem
* - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type
* - Generic Collection can connect to any other Collection or CollectionOrScalar
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
* - Generic Collection can connect to any other Collection or Polymorphic
* - Any Collection can connect to a Generic Collection
*/

const isCollectionItemToNonCollection =
sourceType.name === 'CollectionItemField' && !targetType.isCollection;
sourceType === 'CollectionItem' &&
!COLLECTION_TYPES.some((t) => t === targetType);

const isNonCollectionToCollectionItem =
targetType.name === 'CollectionItemField' &&
!sourceType.isCollection &&
!sourceType.isCollectionOrScalar;
targetType === 'CollectionItem' &&
!COLLECTION_TYPES.some((t) => t === sourceType) &&
!POLYMORPHIC_TYPES.some((t) => t === sourceType);

const isAnythingToCollectionOrScalarOfSameBaseType =
targetType.isCollectionOrScalar && sourceType.name === targetType.name;
const isAnythingToPolymorphicOfSameBaseType =
POLYMORPHIC_TYPES.some((t) => t === targetType) &&
(() => {
if (!POLYMORPHIC_TYPES.some((t) => t === targetType)) {
return false;
}
const baseType =
POLYMORPHIC_TO_SINGLE_MAP[
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
];

const isGenericCollectionToAnyCollectionOrCollectionOrScalar =
sourceType.name === 'CollectionField' &&
(targetType.isCollection || targetType.isCollectionOrScalar);
const collectionType =
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];

const isCollectionToGenericCollection =
targetType.name === 'CollectionField' && sourceType.isCollection;
return sourceType === baseType || sourceType === collectionType;
})();

const isGenericCollectionToAnyCollectionOrPolymorphic =
sourceType === 'Collection' &&
(COLLECTION_TYPES.some((t) => t === targetType) ||
POLYMORPHIC_TYPES.some((t) => t === targetType));

const areBothTypesSingle =
!sourceType.isCollection &&
!sourceType.isCollectionOrScalar &&
!targetType.isCollection &&
!targetType.isCollectionOrScalar;
const isCollectionToGenericCollection =
targetType === 'Collection' &&
COLLECTION_TYPES.some((t) => t === sourceType);

const isIntToFloat =
areBothTypesSingle &&
sourceType.name === 'IntegerField' &&
targetType.name === 'FloatField';
const isIntToFloat = sourceType === 'integer' && targetType === 'float';

const isIntOrFloatToString =
areBothTypesSingle &&
(sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') &&
targetType.name === 'StringField';
(sourceType === 'integer' || sourceType === 'float') &&
targetType === 'string';

const isTargetAnyType = targetType.name === 'AnyField';
const isTargetAnyType = targetType === 'Any';

// One of these must be true for the connection to be valid
return (
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToCollectionOrScalarOfSameBaseType ||
isGenericCollectionToAnyCollectionOrCollectionOrScalar ||
isAnythingToPolymorphicOfSameBaseType ||
isGenericCollectionToAnyCollectionOrPolymorphic ||
isCollectionToGenericCollection ||
isIntToFloat ||
isIntOrFloatToString ||
Original file line number Diff line number Diff line change
@@ -60,7 +60,7 @@ const FIELD_VALUE_FALLBACK_MAP: {
UNetField: undefined,
VaeField: undefined,
VaeModelField: undefined,
Unknown: undefined,
Custom: undefined,
};

export const buildInputFieldValue = (
@@ -77,10 +77,9 @@ export const buildInputFieldValue = (
type: template.type,
label: '',
fieldKind: 'input',
originalType: template.originalType,
value: template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type],
} as InputFieldValue;

fieldValue.value =
template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type];

return fieldValue;
};

0 comments on commit 6ee8607

Please sign in to comment.