Skip to content

Commit

Permalink
refactor(iteration_node): use Sequence and Mapping in parameters (#11483
Browse files Browse the repository at this point in the history
)

Signed-off-by: -LAN- <laipz8200@outlook.com>
  • Loading branch information
laipz8200 authored Dec 9, 2024
1 parent c3c6a48 commit 537068c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 23 deletions.
17 changes: 9 additions & 8 deletions api/core/workflow/graph_engine/entities/event.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional

Expand Down Expand Up @@ -140,8 +141,8 @@ class BaseIterationEvent(GraphEngineEvent):

class IterationRunStartedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None


Expand All @@ -153,18 +154,18 @@ class IterationRunNextEvent(BaseIterationEvent):

class IterationRunSucceededEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
iteration_duration_map: Optional[dict[str, float]] = None


class IterationRunFailedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
error: str = Field(..., description="failed reason")

Expand Down
31 changes: 16 additions & 15 deletions api/core/workflow/nodes/iteration/iteration_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,17 +167,17 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
current_app._get_current_object(), # type: ignore
q,
iterator_list_value,
inputs,
outputs,
start_at,
graph_engine,
iteration_graph,
index,
item,
iter_run_map,
flask_app=current_app._get_current_object(), # type: ignore
q=q,
iterator_list_value=iterator_list_value,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine,
iteration_graph=iteration_graph,
index=index,
item=item,
iter_run_map=iter_run_map,
)
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
Expand Down Expand Up @@ -370,9 +370,9 @@ def _handle_event_metadata(
def _run_single_iter(
self,
*,
iterator_list_value: list[str],
iterator_list_value: Sequence[str],
variable_pool: VariablePool,
inputs: dict[str, list],
inputs: Mapping[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",
Expand Down Expand Up @@ -559,10 +559,11 @@ def _run_single_iter(

def _run_single_iter_parallel(
self,
*,
flask_app: Flask,
q: Queue,
iterator_list_value: list[str],
inputs: dict[str, list],
iterator_list_value: Sequence[str],
inputs: Mapping[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",
Expand Down

0 comments on commit 537068c

Please sign in to comment.