forked from invoke-ai/InvokeAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(ui): custom field types connection validation
In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic. *Actually, it was `"Unknown"`, but I changed it to custom for clarity. Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields. To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`. This ended up needing a bit of fanagling: - If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property. While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer. (Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.) - Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future. - We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`. Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`. This causes a TS error. This behaviour is somewhat controversial (see microsoft/TypeScript#14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent.
1 parent
8e125b8
commit 6ee8607
Showing
13 changed files
with
371 additions
and
178 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
8 changes: 4 additions & 4 deletions
8
invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,12 @@ | ||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; | ||
import { FIELD_COLORS } from 'features/nodes/types/constants'; | ||
import { FieldType } from 'features/nodes/types/field'; | ||
import { FIELDS } from 'features/nodes/types/constants'; | ||
import { FieldType } from 'features/nodes/types/types'; | ||
|
||
export const getFieldColor = (fieldType: FieldType | null): string => { | ||
export const getFieldColor = (fieldType: FieldType | string | null): string => { | ||
if (!fieldType) { | ||
return colorTokenToCssVar('base.500'); | ||
} | ||
const color = FIELD_COLORS[fieldType.name]; | ||
const color = FIELDS[fieldType]?.color; | ||
|
||
return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500'); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
209 changes: 135 additions & 74 deletions
209
invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
128 changes: 128 additions & 0 deletions
128
invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; | ||
import { | ||
CurrentImageNodeData, | ||
InputFieldValue, | ||
InvocationNodeData, | ||
InvocationTemplate, | ||
NotesNodeData, | ||
OutputFieldValue, | ||
} from 'features/nodes/types/types'; | ||
import { buildInputFieldValue } from 'features/nodes/util/fieldValueBuilders'; | ||
import { reduce } from 'lodash-es'; | ||
import { Node, XYPosition } from 'reactflow'; | ||
import { AnyInvocationType } from 'services/events/types'; | ||
import { v4 as uuidv4 } from 'uuid'; | ||
|
||
export const SHARED_NODE_PROPERTIES: Partial<Node> = { | ||
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, | ||
}; | ||
export const buildNodeData = ( | ||
type: AnyInvocationType | 'current_image' | 'notes', | ||
position: XYPosition, | ||
template?: InvocationTemplate | ||
): | ||
| Node<CurrentImageNodeData> | ||
| Node<NotesNodeData> | ||
| Node<InvocationNodeData> | ||
| undefined => { | ||
const nodeId = uuidv4(); | ||
|
||
if (type === 'current_image') { | ||
const node: Node<CurrentImageNodeData> = { | ||
...SHARED_NODE_PROPERTIES, | ||
id: nodeId, | ||
type: 'current_image', | ||
position, | ||
data: { | ||
id: nodeId, | ||
type: 'current_image', | ||
isOpen: true, | ||
label: 'Current Image', | ||
}, | ||
}; | ||
|
||
return node; | ||
} | ||
|
||
if (type === 'notes') { | ||
const node: Node<NotesNodeData> = { | ||
...SHARED_NODE_PROPERTIES, | ||
id: nodeId, | ||
type: 'notes', | ||
position, | ||
data: { | ||
id: nodeId, | ||
isOpen: true, | ||
label: 'Notes', | ||
notes: '', | ||
type: 'notes', | ||
}, | ||
}; | ||
|
||
return node; | ||
} | ||
|
||
if (template === undefined) { | ||
console.error(`Unable to find template ${type}.`); | ||
return; | ||
} | ||
|
||
const inputs = reduce( | ||
template.inputs, | ||
(inputsAccumulator, inputTemplate, inputName) => { | ||
const fieldId = uuidv4(); | ||
|
||
const inputFieldValue: InputFieldValue = buildInputFieldValue( | ||
fieldId, | ||
inputTemplate | ||
); | ||
|
||
inputsAccumulator[inputName] = inputFieldValue; | ||
|
||
return inputsAccumulator; | ||
}, | ||
{} as Record<string, InputFieldValue> | ||
); | ||
|
||
const outputs = reduce( | ||
template.outputs, | ||
(outputsAccumulator, outputTemplate, outputName) => { | ||
const fieldId = uuidv4(); | ||
|
||
const outputFieldValue: OutputFieldValue = { | ||
id: fieldId, | ||
name: outputName, | ||
type: outputTemplate.type, | ||
fieldKind: 'output', | ||
originalType: outputTemplate.originalType, | ||
}; | ||
|
||
outputsAccumulator[outputName] = outputFieldValue; | ||
|
||
return outputsAccumulator; | ||
}, | ||
{} as Record<string, OutputFieldValue> | ||
); | ||
|
||
const invocation: Node<InvocationNodeData> = { | ||
...SHARED_NODE_PROPERTIES, | ||
id: nodeId, | ||
type: 'invocation', | ||
position, | ||
data: { | ||
id: nodeId, | ||
type, | ||
version: template.version, | ||
label: '', | ||
notes: '', | ||
isOpen: true, | ||
embedWorkflow: false, | ||
isIntermediate: type === 'save_image' ? false : true, | ||
inputs, | ||
outputs, | ||
useCache: template.useCache, | ||
}, | ||
}; | ||
|
||
return invocation; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
90 changes: 46 additions & 44 deletions
90
invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters