Merged
Commits (23)
f9d38c1  Implemented ordering for expanded iterators (JPPhoto, Jan 6, 2026)
fb8de3f  Update test_graph_execution_state.py (JPPhoto, Jan 6, 2026)
526d082  Filter invalid nested-iterator parent mappings in _prepare() (JPPhoto, Jan 6, 2026)
9c5ee7a  Merge branch 'main' into iterator-expansion-ordering (JPPhoto, Jan 6, 2026)
5bcbafb  Merge branch 'invoke-ai:main' into iterator-expansion-ordering (JPPhoto, Jan 6, 2026)
8e0d092  Merge branch 'main' into iterator-expansion-ordering (JPPhoto, Jan 9, 2026)
f75d715  Merge branch 'main' into iterator-expansion-ordering (JPPhoto, Jan 11, 2026)
071b6c0  Merge branch 'main' into iterator-expansion-ordering (JPPhoto, Jan 11, 2026)
453e02b  Merge branch 'invoke-ai:main' into iterator-expansion-ordering (JPPhoto, Jan 13, 2026)
5a1dcfc  Merge branch 'invoke-ai:main' into iterator-expansion-ordering (JPPhoto, Jan 15, 2026)
6f9f717  Fixed Collect node ordering (JPPhoto, Jan 16, 2026)
2783198  Merge branch 'main' into iterator-expansion-ordering (JPPhoto, Jan 16, 2026)
07ca01f  ruff (JPPhoto, Jan 16, 2026)
3291fec  Removed ordering guarantees from test_node_graph.py (JPPhoto, Jan 16, 2026)
fb77e84  Merge branch 'invoke-ai:main' into iterator-expansion-ordering (JPPhoto, Jan 21, 2026)
365220a  Merge branch 'invoke-ai:main' into iterator-expansion-ordering (JPPhoto, Jan 26, 2026)
af580ce  Merge branch 'invoke-ai:main' into iterator-expansion-ordering (JPPhoto, Jan 27, 2026)
a1999b9  Merge branch 'invoke-ai:main' into iterator-expansion-ordering (JPPhoto, Jan 28, 2026)
d961698  Merge branch 'invoke-ai:main' into iterator-expansion-ordering (JPPhoto, Jan 29, 2026)
f986658  Merge branch 'main' into iterator-expansion-ordering (JPPhoto, Feb 1, 2026)
20b3d34  Fix iterator prep and type compatibility in graph execution (JPPhoto, Feb 1, 2026)
500aac2  Merge branch 'invoke-ai:main' into iterator-expansion-ordering (JPPhoto, Feb 1, 2026)
d00f6c8  Merge branch 'main' into iterator-expansion-ordering (lstein, Feb 1, 2026)
146 changes: 99 additions & 47 deletions invokeai/app/services/shared/graph.py
@@ -124,38 +124,36 @@ def is_any(t: Any) -> bool:
 
 
 def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
-    if not from_type:
-        return False
-    if not to_type:
+    if not from_type or not to_type:
         return False
 
-    # TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
-    if from_type and to_type:
-        # Ports are compatible
-        if from_type == to_type or is_any(from_type) or is_any(to_type):
-            return True
+    # Ports are compatible
+    if from_type == to_type or is_any(from_type) or is_any(to_type):
+        return True
 
-        if from_type in get_args(to_type):
-            return True
+    if from_type in get_args(to_type):
+        return True
 
-        if to_type in get_args(from_type):
-            return True
+    if to_type in get_args(from_type):
+        return True
 
-        # allow int -> float, pydantic will cast for us
-        if from_type is int and to_type is float:
-            return True
+    # allow int -> float, pydantic will cast for us
+    if from_type is int and to_type is float:
+        return True
 
-        # allow int|float -> str, pydantic will cast for us
-        if (from_type is int or from_type is float) and to_type is str:
-            return True
+    # allow int|float -> str, pydantic will cast for us
+    if (from_type is int or from_type is float) and to_type is str:
+        return True
 
-        # if not issubclass(from_type, to_type):
-        if not is_union_subtype(from_type, to_type):
-            return False
-    else:
-        return False
+    # Prefer issubclass when both are real classes
+    try:
+        if isinstance(from_type, type) and isinstance(to_type, type):
+            return issubclass(from_type, to_type)
+    except TypeError:
+        pass
 
-    return True
+    # Union-to-Union (or Union-to-non-Union) handling
+    return is_union_subtype(from_type, to_type)
 
 
 def are_connections_compatible(
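Note: a minimal standalone sketch of how the rewritten check behaves. The `Optional[int]` target and the local `Base`/`Child` classes are illustrative and not part of the diff; only the import path comes from this PR's tests.

```python
from typing import Optional

from invokeai.app.services.shared.graph import are_connection_types_compatible


class Base:
    pass


class Child(Base):
    pass


# Unchanged behavior: identical types and the int -> float cast are accepted.
assert are_connection_types_compatible(int, int)
assert are_connection_types_compatible(int, float)

# Unchanged behavior: a member of a Union target is accepted via get_args().
assert are_connection_types_compatible(int, Optional[int])

# New behavior: a subclass output can feed a base-class input (issubclass path).
assert are_connection_types_compatible(Child, Base)

# Unrelated concrete classes are still rejected.
assert not are_connection_types_compatible(str, int)
```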
@@ -654,6 +652,9 @@ def _is_iterator_connection_valid(
         if new_output is not None:
             outputs.append(new_output)
 
+        if len(inputs) == 0:
+            return "Iterator must have a collection input edge"
+
         # Only one input is allowed for iterators
         if len(inputs) > 1:
             return "Iterator may only have one input edge"
@@ -675,9 +676,13 @@
 
         # Collector input type must match all iterator output types
         if isinstance(input_node, CollectInvocation):
+            collector_inputs = self._get_input_edges(input_node.id, ITEM_FIELD)
+            if len(collector_inputs) == 0:
+                return "Iterator input collector must have at least one item input edge"
+
             # Traverse the graph to find the first collector input edge. Collectors validate that their collection
             # inputs are all of the same type, so we can use the first input edge to determine the collector's type
-            first_collector_input_edge = self._get_input_edges(input_node.id, ITEM_FIELD)[0]
+            first_collector_input_edge = collector_inputs[0]
             first_collector_input_type = get_output_field_type(
                 self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
             )
@@ -751,21 +756,12 @@ def nx_graph(self) -> nx.DiGraph:
         g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
         return g
 
-    def nx_graph_with_data(self) -> nx.DiGraph:
-        """Returns a NetworkX DiGraph representing the data and layout of this graph"""
-        g = nx.DiGraph()
-        g.add_nodes_from(list(self.nodes.items()))
-        g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
-        return g
 
     def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph:
         """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
         g = nx_graph or nx.DiGraph()
 
-        # Add all nodes from this graph except graph/iteration nodes
-        g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)])
-
-        # TODO: figure out if iteration nodes need to be expanded
+        g.add_nodes_from([n.id for n in self.nodes.values()])
 
         unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
         g.add_edges_from(unique_edges)
@@ -816,10 +812,57 @@ class GraphExecutionState(BaseModel):
     # Optional priority; others follow in name order
     ready_order: list[str] = Field(default_factory=list)
     indegree: dict[str, int] = Field(default_factory=dict, description="Remaining unmet input count for exec nodes")
+    _iteration_path_cache: dict[str, tuple[int, ...]] = PrivateAttr(default_factory=dict)
 
     def _type_key(self, node_obj: BaseInvocation) -> str:
         return node_obj.__class__.__name__
 
+    def _get_iteration_path(self, exec_node_id: str) -> tuple[int, ...]:
+        """Best-effort outer->inner iteration indices for an execution node, stopping at collectors."""
+        cached = self._iteration_path_cache.get(exec_node_id)
+        if cached is not None:
+            return cached
+
+        # Only prepared execution nodes participate; otherwise treat as non-iterated.
+        source_node_id = self.prepared_source_mapping.get(exec_node_id)
+        if source_node_id is None:
+            self._iteration_path_cache[exec_node_id] = ()
+            return ()
+
+        # Source-graph iterator ancestry, with edges into collectors removed so iteration context doesn't leak.
+        it_g = self._iterator_graph(self.graph.nx_graph())
+        iterator_sources = [
+            n for n in nx.ancestors(it_g, source_node_id) if isinstance(self.graph.get_node(n), IterateInvocation)
+        ]
+
+        # Order iterators outer->inner via topo order of the iterator graph.
+        topo = list(nx.topological_sort(it_g))
+        topo_index = {n: i for i, n in enumerate(topo)}
+        iterator_sources.sort(key=lambda n: topo_index.get(n, 0))
+
+        # Map iterator source nodes to the prepared iterator exec nodes that are ancestors of exec_node_id.
+        eg = self.execution_graph.nx_graph()
+        path: list[int] = []
+        for it_src in iterator_sources:
+            prepared = self.source_prepared_mapping.get(it_src)
+            if not prepared:
+                continue
+            it_exec = next((p for p in prepared if nx.has_path(eg, p, exec_node_id)), None)
+            if it_exec is None:
+                continue
+            it_node = self.execution_graph.nodes.get(it_exec)
+            if isinstance(it_node, IterateInvocation):
+                path.append(it_node.index)
+
+        # If this exec node is itself an iterator, include its own index as the innermost element.
+        node_obj = self.execution_graph.nodes.get(exec_node_id)
+        if isinstance(node_obj, IterateInvocation):
+            path.append(node_obj.index)
+
+        result = tuple(path)
+        self._iteration_path_cache[exec_node_id] = result
+        return result
+
     def _queue_for(self, cls_name: str) -> Deque[str]:
         q = self._ready_queues.get(cls_name)
         if q is None:
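Note: the ordering key is just a tuple of iterator indices, outermost first, so plain tuple comparison already gives the intended outer-to-inner ordering. A small sketch with made-up paths (illustrative values, not output captured from this PR):

```python
# Hypothetical iteration paths: () for a non-iterated node, (outer,) or (outer, inner) otherwise.
paths = [(1, 0), (0, 1), (1, 1), (0, 0), ()]

# Tuples compare lexicographically, so sorting orders by outer index first, then inner,
# and the empty path (non-iterated work) sorts ahead of everything else.
assert sorted(paths) == [(), (0, 0), (0, 1), (1, 0), (1, 1)]
```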
@@ -843,7 +886,15 @@ def _enqueue_if_ready(self, nid: str) -> None:
         if self.indegree[nid] != 0 or nid in self.executed:
             return
         node_obj = self.execution_graph.nodes[nid]
-        self._queue_for(self._type_key(node_obj)).append(nid)
+        q = self._queue_for(self._type_key(node_obj))
+        nid_path = self._get_iteration_path(nid)
+        # Insert in lexicographic outer->inner order; preserve FIFO for equal paths.
+        for i, existing in enumerate(q):
+            if self._get_iteration_path(existing) > nid_path:
+                q.insert(i, nid)
+                break
+        else:
+            q.append(nid)
 
     model_config = ConfigDict(
         json_schema_extra={
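Note: a standalone sketch of the insertion pattern above, with a plain dict standing in for `_get_iteration_path()` and string node ids; none of these names come from the diff.

```python
from collections import deque

# Hypothetical iteration paths keyed by node id.
paths = {"a": (1,), "b": (0,), "c": (1,), "d": (0,)}


def enqueue_ordered(q: deque, nid: str) -> None:
    """Insert nid before the first queued id with a strictly greater path; append otherwise."""
    nid_path = paths[nid]
    for i, existing in enumerate(q):
        if paths[existing] > nid_path:
            q.insert(i, nid)
            break
    else:
        q.append(nid)


q: deque = deque()
for nid in ("a", "b", "c", "d"):
    enqueue_ordered(q, nid)

# Ordered by path; ids with equal paths ("b"/"d" and "a"/"c") keep their arrival (FIFO) order.
assert list(q) == ["b", "d", "a", "c"]
```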
@@ -1083,12 +1134,12 @@ def no_unexecuted_iter_ancestors(n: str) -> bool:
 
         # Select the correct prepared parents for each iteration
         # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
-        # TODO: Handle a node mapping to none
         eg = self.execution_graph.nx_graph_flat()
         prepared_parent_mappings = [
             [(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
             for it in iterator_node_prepared_combinations
         ]  # type: ignore
+        prepared_parent_mappings = [m for m in prepared_parent_mappings if all(p[1] is not None for p in m)]
 
         # Create execution node for each iteration
         for iteration_mappings in prepared_parent_mappings:
@@ -1110,15 +1161,17 @@ def _get_iteration_node(
         if len(prepared_nodes) == 1:
             return next(iter(prepared_nodes))
 
-        # Check if the requested node is an iterator
-        prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
-        if prepared_iterator is not None:
-            return prepared_iterator
-
         # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
         iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
         parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]
 
+        # If the requested node is an iterator, only accept it if it is compatible with all parent iterators
+        prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
+        if prepared_iterator is not None:
+            if all(nx.has_path(execution_graph, pit[0], prepared_iterator) for pit in parent_iterators):
+                return prepared_iterator
+            return None
+
         return next(
             (n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
             None,
@@ -1156,11 +1209,10 @@ def _prepare_inputs(self, node: BaseInvocation):
         # Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input
         # will see the mutation.
        if isinstance(node, CollectInvocation):
-            output_collection = [
-                copydeep(getattr(self.results[edge.source.node_id], edge.source.field))
-                for edge in input_edges
-                if edge.destination.field == ITEM_FIELD
-            ]
+            item_edges = [e for e in input_edges if e.destination.field == ITEM_FIELD]
+            item_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id))
+
+            output_collection = [copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges]
             node.collection = output_collection
         else:
             for edge in input_edges:
89 changes: 89 additions & 0 deletions tests/test_graph_execution_state.py
@@ -225,3 +225,92 @@ def test_graph_iterate_execution_order(execution_number: int):
     _ = invoke_next(g)
     assert _[1].item == "Dinosaur Sushi"
     _ = invoke_next(g)
+
+
+# Because this tests deterministic ordering, we run it multiple times
+@pytest.mark.parametrize("execution_number", range(5))
+def test_graph_nested_iterate_execution_order(execution_number: int):
+    """
+    Validates best-effort in-order execution for nodes expanded under nested iterators.
+    Expected lexicographic order by (outer_index, inner_index), subject to readiness.
+    """
+    graph = Graph()
+
+    # Outer iterator: [0, 1]
+    graph.add_node(RangeInvocation(id="outer_range", start=0, stop=2, step=1))
+    graph.add_node(IterateInvocation(id="outer_iter"))
+
+    # Inner iterator is derived from the outer item:
+    # start = outer_item * 10
+    # stop = start + 2 => yields 2 items per outer item
+    graph.add_node(MultiplyInvocation(id="mul10", b=10))
+    graph.add_node(AddInvocation(id="stop_plus2", b=2))
+    graph.add_node(RangeInvocation(id="inner_range", start=0, stop=1, step=1))
+    graph.add_node(IterateInvocation(id="inner_iter"))
+
+    # Observe inner items (they encode outer via start=outer*10)
+    graph.add_node(AddInvocation(id="sum", b=0))
+
+    graph.add_edge(create_edge("outer_range", "collection", "outer_iter", "collection"))
+    graph.add_edge(create_edge("outer_iter", "item", "mul10", "a"))
+    graph.add_edge(create_edge("mul10", "value", "stop_plus2", "a"))
+    graph.add_edge(create_edge("mul10", "value", "inner_range", "start"))
+    graph.add_edge(create_edge("stop_plus2", "value", "inner_range", "stop"))
+    graph.add_edge(create_edge("inner_range", "collection", "inner_iter", "collection"))
+    graph.add_edge(create_edge("inner_iter", "item", "sum", "a"))
+
+    g = GraphExecutionState(graph=graph)
+    sum_values: list[int] = []
+
+    while True:
+        n, o = invoke_next(g)
+        if n is None:
+            break
+        if g.prepared_source_mapping[n.id] == "sum":
+            sum_values.append(o.value)
+
+    assert sum_values == [0, 1, 10, 11]
+
+
+def test_graph_validate_self_iterator_without_collection_input_raises_invalid_edge_error():
+    """Iterator nodes with no collection input should fail validation cleanly.
+
+    This test exposes the bug where validation crashes with IndexError instead of raising InvalidEdgeError.
+    """
+    from invokeai.app.services.shared.graph import InvalidEdgeError
+
+    graph = Graph()
+    graph.add_node(IterateInvocation(id="iterate"))
+
+    with pytest.raises(InvalidEdgeError):
+        graph.validate_self()
+
+
+def test_graph_validate_self_collector_without_item_inputs_raises_invalid_edge_error():
+    """Collector nodes with no item inputs should fail validation cleanly.
+
+    This test exposes the bug where validation can crash (e.g. StopIteration) instead of raising InvalidEdgeError.
+    """
+    from invokeai.app.services.shared.graph import InvalidEdgeError
+
+    graph = Graph()
+    graph.add_node(CollectInvocation(id="collect"))
+
+    with pytest.raises(InvalidEdgeError):
+        graph.validate_self()
+
+
+def test_are_connection_types_compatible_accepts_subclass_to_base():
+    """A subclass output should be connectable to a base-class input.
+
+    This test exposes the bug where non-Union targets reject valid subclass connections.
+    """
+    from invokeai.app.services.shared.graph import are_connection_types_compatible
+
+    class Base:
+        pass
+
+    class Child(Base):
+        pass
+
+    assert are_connection_types_compatible(Child, Base) is True
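Note: the expected values in `test_graph_nested_iterate_execution_order` follow directly from the node parameters in the test; this is the same arithmetic without the graph machinery.

```python
# Outer items come from range(0, 2); each inner range starts at outer * 10 and holds two items,
# so in-order execution yields the inner items grouped by outer iteration.
expected = [inner for outer in range(0, 2) for inner in range(outer * 10, outer * 10 + 2)]
assert expected == [0, 1, 10, 11]
```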
2 changes: 1 addition & 1 deletion tests/test_node_graph.py
@@ -785,7 +785,7 @@ def test_collector_different_incomers():
     run_session_with_mock_context(session)
     output = get_single_output_from_session(session, n3.id)
     assert isinstance(output, CollectInvocationOutput)
-    assert output.collection == ["Banana", "Sushi"]  # Both inputs should be collected
+    assert set(output.collection) == {"Banana", "Sushi"}  # Both inputs should be collected, no order guarantee
 
 
 def test_iterator_collector_iterator_chain():