Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ui): add support for custom field types #5113

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion invokeai/app/invocations/baseinvocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from enum import Enum
from inspect import signature
from types import UnionType
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, ForwardRef, Iterable, Literal, Optional, Type, TypeVar, Union

import semver
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
Expand Down Expand Up @@ -648,17 +648,40 @@ class _Model(BaseModel):
# Get all pydantic model attrs, methods, etc
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}

RESERVED_INVOKEAI_FIELD_NAMES = {"Custom", "CustomCollection", "CustomPolymorphic"}


def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
"""
Validates the fields of an invocation or invocation output:
- must not override any pydantic reserved fields
- must not end with "Collection" or "Polymorphic" as these are reserved for internal use
- must be created via `InputField`, `OutputField`, or be an internal field defined in this file
"""
for name, field in model_fields.items():
if name in RESERVED_PYDANTIC_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')

if not field.annotation:
raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)')

annotation_name = (
field.annotation.__forward_arg__ if isinstance(field.annotation, ForwardRef) else field.annotation.__name__
)

if annotation_name.endswith("Polymorphic"):
raise InvalidFieldError(
f'Invalid field type "{annotation_name}" for "{name}" on "{model_type}" (must not end in "Polymorphic")'
)

if annotation_name.endswith("Collection"):
raise InvalidFieldError(
f'Invalid field type "{annotation_name}" for "{name}" on "{model_type}" (must not end in "Collection")'
)

if annotation_name in RESERVED_INVOKEAI_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field type "{annotation_name}" for "{name}" on "{model_type}" (reserved)')

field_kind = (
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@ const ImageMetadataActions = (props: Props) => {
return null;
}

console.log(metadata);

return (
<>
{metadata.created_by && (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ const AddNodePopover = () => {
const { t } = useTranslation();

const fieldFilter = useAppSelector(
(state) => state.nodes.currentConnectionFieldType
(state) => state.nodes.connectionStartFieldType
);
const handleFilter = useAppSelector(
(state) => state.nodes.connectionStartParams?.handleType
Expand All @@ -74,9 +74,13 @@ const AddNodePopover = () => {

return some(handles, (handle) => {
const sourceType =
handleFilter == 'source' ? fieldFilter : handle.type;
handleFilter == 'source'
? fieldFilter
: handle.originalType ?? handle.type;
const targetType =
handleFilter == 'target' ? fieldFilter : handle.type;
handleFilter == 'target'
? fieldFilter
: handle.originalType ?? handle.type;

return validateSourceAndTargetTypes(sourceType, targetType);
});
Expand Down Expand Up @@ -111,7 +115,7 @@ const AddNodePopover = () => {

data.sort((a, b) => a.label.localeCompare(b.label));

return { data, t };
return { data };
},
defaultSelectorOptions
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELDS } from 'features/nodes/types/constants';
import { memo } from 'react';
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
import { getFieldColor } from '../edges/util/getEdgeColor';

const selector = createSelector(stateSelector, ({ nodes }) => {
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
const { shouldAnimateEdges, connectionStartFieldType, shouldColorEdges } =
nodes;

const stroke =
currentConnectionFieldType && shouldColorEdges
? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color)
: colorTokenToCssVar('base.500');
const stroke = shouldColorEdges
? getFieldColor(connectionStartFieldType)
: colorTokenToCssVar('base.500');

let className = 'react-flow__custom_connection-path';

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELDS } from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/types';

export const getFieldColor = (fieldType: FieldType | string | null): string => {
if (!fieldType) {
return colorTokenToCssVar('base.500');
}
const color = FIELDS[fieldType]?.color;

return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
};
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELDS } from 'features/nodes/types/constants';
import { isInvocationNode } from 'features/nodes/types/types';
import { getFieldColor } from './getEdgeColor';

export const makeEdgeSelector = (
source: string,
Expand All @@ -29,7 +29,7 @@ export const makeEdgeSelector = (

const stroke =
sourceType && nodes.shouldColorEdges
? colorTokenToCssVar(FIELDS[sourceType].color)
? getFieldColor(sourceType)
: colorTokenToCssVar('base.500');

return {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import { Tooltip } from '@chakra-ui/react';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import {
COLLECTION_TYPES,
FIELDS,
getIsCollection,
getIsPolymorphic,
} from 'features/nodes/store/util/parseFieldType';
import {
HANDLE_TOOLTIP_OPEN_DELAY,
MODEL_TYPES,
POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import {
InputFieldTemplate,
OutputFieldTemplate,
} from 'features/nodes/types/types';
import { CSSProperties, memo, useMemo } from 'react';
import { Handle, HandleType, Position } from 'reactflow';
import { getFieldColor } from '../../../edges/util/getEdgeColor';

export const handleBaseStyles: CSSProperties = {
position: 'absolute',
Expand Down Expand Up @@ -47,23 +49,21 @@ const FieldHandle = (props: FieldHandleProps) => {
isConnectionStartField,
connectionError,
} = props;
const { name, type } = fieldTemplate;
const { color: typeColor, title } = FIELDS[type];
const { name } = fieldTemplate;
const type = fieldTemplate.originalType ?? fieldTemplate.type;

const styles: CSSProperties = useMemo(() => {
const isCollectionType = COLLECTION_TYPES.includes(type);
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
const isModelType = MODEL_TYPES.includes(type);
const color = colorTokenToCssVar(typeColor);
const isCollection = getIsCollection(fieldTemplate.type);
const isPolymorphic = getIsPolymorphic(fieldTemplate.type);
const isModelType = MODEL_TYPES.some((t) => t === type);
const color = getFieldColor(type);
const s: CSSProperties = {
backgroundColor:
isCollectionType || isPolymorphicType
? 'var(--invokeai-colors-base-900)'
: color,
isCollection || isPolymorphic ? colorTokenToCssVar('base.900') : color,
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
borderWidth: isCollection || isPolymorphic ? 4 : 0,
borderStyle: 'solid',
borderColor: color,
borderRadius: isModelType ? 4 : '100%',
Expand Down Expand Up @@ -93,22 +93,19 @@ const FieldHandle = (props: FieldHandleProps) => {
return s;
}, [
connectionError,
fieldTemplate.type,
handleType,
isConnectionInProgress,
isConnectionStartField,
type,
typeColor,
]);

const tooltip = useMemo(() => {
if (isConnectionInProgress && isConnectionStartField) {
return title;
}
if (isConnectionInProgress && connectionError) {
return connectionError ?? title;
return connectionError;
}
return title;
}, [connectionError, isConnectionInProgress, isConnectionStartField, title]);
return type;
}, [connectionError, isConnectionInProgress, type]);

return (
<Tooltip
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import { Flex, Text } from '@chakra-ui/react';
import { useFieldData } from 'features/nodes/hooks/useFieldData';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { FIELDS } from 'features/nodes/types/constants';
import {
isInputFieldTemplate,
isInputFieldValue,
} from 'features/nodes/types/types';
import { startCase } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';

interface Props {
nodeId: string;
fieldName: string;
Expand Down Expand Up @@ -49,7 +47,7 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
{fieldTemplate.description}
</Text>
)}
{fieldTemplate && <Text>Type: {FIELDS[fieldTemplate.type].title}</Text>}
{fieldTemplate && <Text>Type: {fieldTemplate.originalType}</Text>}
{isInputTemplate && <Text>Input: {startCase(fieldTemplate.input)}</Text>}
</Flex>
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import { useMemo } from 'react';
import { getIsPolymorphic } from '../store/util/parseFieldType';
import { TYPES_WITH_INPUT_COMPONENTS } from '../types/constants';
import { isInvocationNode } from '../types/types';
import {
POLYMORPHIC_TYPES,
TYPES_WITH_INPUT_COMPONENTS,
} from '../types/constants';
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';

export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
Expand All @@ -28,7 +26,7 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
const fields = map(nodeTemplate.inputs).filter(
(field) =>
(['any', 'direct'].includes(field.input) ||
POLYMORPHIC_TYPES.includes(field.type)) &&
getIsPolymorphic(field.type)) &&
TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
);
return getSortedFilteredFieldNames(fields);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { map } from 'lodash-es';
import { useMemo } from 'react';
import {
POLYMORPHIC_TYPES,
TYPES_WITH_INPUT_COMPONENTS,
} from '../types/constants';
import { getIsPolymorphic } from '../store/util/parseFieldType';
import { TYPES_WITH_INPUT_COMPONENTS } from '../types/constants';
import { isInvocationNode } from '../types/types';
import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames';

Expand All @@ -29,8 +27,7 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
// get the visible fields
const fields = map(nodeTemplate.inputs).filter(
(field) =>
(field.input === 'connection' &&
!POLYMORPHIC_TYPES.includes(field.type)) ||
(field.input === 'connection' && !getIsPolymorphic(field.type)) ||
!TYPES_WITH_INPUT_COMPONENTS.includes(field.type)
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { useFieldType } from './useFieldType.ts';
const selectIsConnectionInProgress = createSelector(
stateSelector,
({ nodes }) =>
nodes.currentConnectionFieldType !== null &&
nodes.connectionStartFieldType !== null &&
nodes.connectionStartParams !== null
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ export const useFieldType = (
if (!isInvocationNode(node)) {
return;
}
return node?.data[KIND_MAP[kind]][fieldName]?.type;
const field = node.data[KIND_MAP[kind]][fieldName];
return field?.originalType ?? field?.type;
},
defaultSelectorOptions
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ export const useIsValidConnection = () => {
return false;
}

const sourceType = sourceNode.data.outputs[sourceHandle]?.type;
const targetType = targetNode.data.inputs[targetHandle]?.type;
const sourceField = sourceNode.data.outputs[sourceHandle];
const targetField = targetNode.data.inputs[targetHandle];

if (!sourceType || !targetType) {
if (!sourceField || !targetField) {
// something has gone terribly awry
return false;
}
Expand Down Expand Up @@ -70,12 +70,18 @@ export const useIsValidConnection = () => {
return edge.target === target && edge.targetHandle === targetHandle;
}) &&
// except CollectionItem inputs can have multiples
targetType !== 'CollectionItem'
targetField.type !== 'CollectionItem'
) {
return false;
}

if (!validateSourceAndTargetTypes(sourceType, targetType)) {
// Must use the originalType here if it exists
if (
!validateSourceAndTargetTypes(
sourceField?.originalType ?? sourceField.type,
targetField?.originalType ?? targetField.type
)
) {
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { NodesState } from './types';
export const nodesPersistDenylist: (keyof NodesState)[] = [
'nodeTemplates',
'connectionStartParams',
'currentConnectionFieldType',
'connectionStartFieldType',
'selectedNodes',
'selectedEdges',
'isReady',
Expand Down
Loading