diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index e9fc6daba..6520e60d9 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -11,6 +11,7 @@ from .audio_artifact import AudioArtifact from .action_artifact import ActionArtifact from .generic_artifact import GenericArtifact +from .control_flow_artifact import ControlFlowArtifact __all__ = [ @@ -27,4 +28,5 @@ "AudioArtifact", "ActionArtifact", "GenericArtifact", + "ControlFlowArtifact", ] diff --git a/griptape/artifacts/boolean_artifact.py b/griptape/artifacts/boolean_artifact.py index 5bcdfac9b..d517b5c32 100644 --- a/griptape/artifacts/boolean_artifact.py +++ b/griptape/artifacts/boolean_artifact.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseArtifact, TextArtifact @define @@ -12,9 +12,11 @@ class BooleanArtifact(BaseArtifact): value: bool = field(converter=bool, metadata={"serializable": True}) @classmethod - def parse_bool(cls, value: Union[str, bool]) -> BooleanArtifact: # noqa: FBT001 - """Convert a string literal or bool to a BooleanArtifact. The string must be either "true" or "false" with any casing.""" + def parse_bool(cls, value: Union[str, bool, TextArtifact]) -> BooleanArtifact: # noqa: FBT001 + """Convert a string literal, TextArtifact or bool to a BooleanArtifact. The string must be either "true" or "false" with any casing.""" if value is not None: + if isinstance(value, TextArtifact): + value = str(value) if isinstance(value, str): if value.lower() == "true": return BooleanArtifact(True) # noqa: FBT003 diff --git a/griptape/artifacts/control_flow_artifact.py b/griptape/artifacts/control_flow_artifact.py new file mode 100644 index 000000000..4e1b6532a --- /dev/null +++ b/griptape/artifacts/control_flow_artifact.py @@ -0,0 +1,8 @@ +from attrs import define + +from griptape.artifacts import BaseArtifact + + +@define +class ControlFlowArtifact(BaseArtifact): + pass diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 8f095dfeb..1feb350d8 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -203,6 +203,9 @@ def default_task_memory(self) -> TaskMemory: def is_finished(self) -> bool: return all(s.is_finished() for s in self.tasks) + def is_complete(self) -> bool: + return all(s.is_complete() for s in self.tasks) + def is_executing(self) -> bool: return any(s for s in self.tasks if s.is_executing()) diff --git a/griptape/structures/workflow.py b/griptape/structures/workflow.py index 53e59e751..34619a263 100644 --- a/griptape/structures/workflow.py +++ b/griptape/structures/workflow.py @@ -94,20 +94,22 @@ def insert_task( def try_run(self, *args) -> Workflow: exit_loop = False - while not self.is_finished() and not exit_loop: + while not self.is_complete() and not exit_loop: futures_list = {} - ordered_tasks = self.order_tasks() + executable_tasks = [*filter(lambda task: task.can_execute(), self.order_tasks())] - for task in ordered_tasks: - if task.can_execute(): - future = self.futures_executor_fn().submit(task.execute) - futures_list[future] = task + if not executable_tasks: + exit_loop = True + break + + for task in executable_tasks: + future = self.futures_executor_fn().submit(task.execute) + futures_list[future] = task # Wait for all tasks to complete for future in futures.as_completed(futures_list): if isinstance(future.result(), ErrorArtifact) and self.fail_fast: exit_loop = True - break if self.conversation_memory and self.output is not None: diff --git a/griptape/tasks/__init__.py b/griptape/tasks/__init__.py index 764d1669a..e08124fc8 100644 --- a/griptape/tasks/__init__.py +++ b/griptape/tasks/__init__.py @@ -21,6 +21,8 @@ from .text_to_speech_task import TextToSpeechTask from .structure_run_task import StructureRunTask from .audio_transcription_task import AudioTranscriptionTask +from .base_control_flow_task import BaseControlFlowTask +from .choice_control_flow_task import ChoiceControlFlowTask __all__ = [ "BaseTask", @@ -46,4 +48,6 @@ "TextToSpeechTask", "StructureRunTask", "AudioTranscriptionTask", + "BaseControlFlowTask", + "ChoiceControlFlowTask", ] diff --git a/griptape/tasks/base_control_flow_task.py b/griptape/tasks/base_control_flow_task.py new file mode 100644 index 000000000..b60b507ae --- /dev/null +++ b/griptape/tasks/base_control_flow_task.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from abc import ABC + +from attrs import define + +from griptape.tasks import BaseTask + + +@define +class BaseControlFlowTask(BaseTask, ABC): + def before_run(self) -> None: + super().before_run() + + self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) + + def after_run(self) -> None: + super().after_run() + + self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + + def _cancel_children_rec(self, task: BaseTask, chosen_task: BaseTask) -> None: + for child in filter(lambda child: child != chosen_task, task.children): + if all(parent.is_complete() for parent in filter(lambda parent: parent != task, child.parents)): + child.state = BaseTask.State.CANCELLED + self._cancel_children_rec(child, chosen_task) + + def _get_task(self, task: str | BaseTask) -> BaseTask: + return self.structure.find_task(task) if isinstance(task, str) else task diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index ade656f87..67e21d3a0 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -23,6 +23,7 @@ class State(Enum): PENDING = 1 EXECUTING = 2 FINISHED = 3 + CANCELLED = 4 id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) state: State = field(default=State.PENDING, kw_only=True) @@ -122,9 +123,15 @@ def is_pending(self) -> bool: def is_finished(self) -> bool: return self.state == BaseTask.State.FINISHED + def is_cancelled(self) -> bool: + return self.state == BaseTask.State.CANCELLED + def is_executing(self) -> bool: return self.state == BaseTask.State.EXECUTING + def is_complete(self) -> bool: + return self.is_finished() or self.is_cancelled() + def before_run(self) -> None: if self.structure is not None: event_bus.publish_event( @@ -168,7 +175,13 @@ def execute(self) -> Optional[BaseArtifact]: return self.output def can_execute(self) -> bool: - return self.state == BaseTask.State.PENDING and all(parent.is_finished() for parent in self.parents) + return self.is_pending() and ( + ( + all(parent.is_complete() for parent in self.parents) + and any(parent.is_finished() for parent in self.parents) + ) + or len(self.parents) == 0 + ) def reset(self) -> BaseTask: self.state = BaseTask.State.PENDING diff --git a/griptape/tasks/choice_control_flow_task.py b/griptape/tasks/choice_control_flow_task.py new file mode 100644 index 000000000..14b785db4 --- /dev/null +++ b/griptape/tasks/choice_control_flow_task.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import Callable + +from attrs import define, field + +from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.tasks import BaseControlFlowTask, BaseTask + + +@define +class ChoiceControlFlowTask(BaseControlFlowTask): + control_flow_fn: Callable[[list[BaseTask] | BaseTask], list[BaseTask | str] | BaseTask | str] = field( + metadata={"serializable": False} + ) + + @property + def input(self) -> BaseArtifact: + if len(self.parents) == 1: + return self.parents[0].output if self.parents[0].output is not None else TextArtifact("") + parents = filter(lambda parent: parent.output is not None, self.parents) + return ListArtifact( + [ + parent.output + for parent in parents # pyright: ignore[reportArgumentType] + ] + ) + + def run(self) -> BaseArtifact: + tasks = self.control_flow_fn( + [artifact.value for artifact in self.input.value] + if isinstance(self.input, ListArtifact) + else self.input.value + ) + + if not isinstance(tasks, list): + tasks = [tasks] + + if tasks is None: + tasks = [] + + tasks = [self._get_task(task) for task in tasks] + + for task in tasks: + if task.id not in self.child_ids: + self.output = ErrorArtifact(f"ControlFlowTask {self.id} did not return a valid child task") + return self.output + + self.output = ( + ListArtifact( + [ + parent.value.output + for parent in filter(lambda parent: parent.value.output is not None, self.input.value) + ] + ) + if isinstance(self.input, ListArtifact) + else self.input.value.output + ) + self._cancel_children_rec(self, task) + + return self.output # pyright: ignore[reportReturnType] diff --git a/griptape/utils/structure_visualizer.py b/griptape/utils/structure_visualizer.py index f24443cd6..ea49e2f51 100644 --- a/griptape/utils/structure_visualizer.py +++ b/griptape/utils/structure_visualizer.py @@ -40,7 +40,12 @@ def to_url(self) -> str: def __render_task(self, task: BaseTask) -> str: if task.children: children = " & ".join([f"{self.__get_id(child.id)}({child.id})" for child in task.children]) - return f"{self.__get_id(task.id)}({task.id})--> {children};" + from griptape.tasks import ChoiceControlFlowTask + + if isinstance(task, ChoiceControlFlowTask): + return f"{self.__get_id(task.id)}{{{task.id}}}-.-> {children};" + else: + return f"{self.__get_id(task.id)}({task.id})--> {children};" else: return f"{self.__get_id(task.id)}({task.id});" diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index 242de29c5..b5cf5627f 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -7,7 +7,7 @@ from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset from griptape.structures import Workflow -from griptape.tasks import BaseTask, CodeExecutionTask, PromptTask, ToolkitTask +from griptape.tasks import BaseTask, ChoiceControlFlowTask, CodeExecutionTask, PromptTask, ToolkitTask from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -29,6 +29,134 @@ def fn(task): return CodeExecutionTask(run_fn=fn) + def test_workflow_with_control_flow_task(self): + task1 = PromptTask("prompt1", id="task1") + task1.output = TextArtifact("task1 output") + task2 = PromptTask("prompt2", id="task2") + task3 = PromptTask("prompt3", id="task3") + task4 = PromptTask("prompt4", id="end") + control_flow_task = ChoiceControlFlowTask(id="control_flow_task", control_flow_fn=lambda x: task2) + control_flow_task.add_parent(task1) + control_flow_task.add_children([task2, task3]) + task4.add_parents([task2, task3]) + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4, control_flow_task]) + workflow.resolve_relationships() + workflow.run() + + assert task1.state == BaseTask.State.FINISHED + assert task2.state == BaseTask.State.FINISHED + assert task3.state == BaseTask.State.CANCELLED + assert task4.state == BaseTask.State.FINISHED + assert control_flow_task.state == BaseTask.State.FINISHED + + def test_workflow_with_multiple_control_flow_tasks(self): + # control_flow_task should branch to task3 but + # task3 should be executed only once + # and task4 should be CANCELLED + task1 = PromptTask("prompt1", id="task1") + task2 = PromptTask("prompt2", id="task2") + task3 = PromptTask("prompt3", id="task3") + task4 = PromptTask("prompt4", id="task4") + task5 = PromptTask("prompt5", id="task5") + task6 = PromptTask("prompt6", id="task6") + control_flow_task1 = ChoiceControlFlowTask(id="control_flow_task1", control_flow_fn=lambda x: task3) + control_flow_task1.add_parent(task1) + control_flow_task1.add_children([task2, task3]) + control_flow_task2 = ChoiceControlFlowTask(id="control_flow_task2", control_flow_fn=lambda x: task5) + control_flow_task2.add_parent(task2) + control_flow_task2.add_children([task4, task5]) + task6.add_parents([task3, task4, task5]) + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[task1, task2, task3, task4, task5, task6, control_flow_task1, control_flow_task2], + ) + workflow.resolve_relationships() + workflow.run() + + assert task1.state == BaseTask.State.FINISHED + assert task2.state == BaseTask.State.CANCELLED + assert task3.state == BaseTask.State.FINISHED + assert task4.state == BaseTask.State.CANCELLED + assert task5.state == BaseTask.State.CANCELLED + assert task6.state == BaseTask.State.FINISHED + assert control_flow_task1.state == BaseTask.State.FINISHED + assert control_flow_task2.state == BaseTask.State.CANCELLED + + def test_workflow_with_control_flow_task_multiple_input_parents(self): + # control_flow_task should branch to task3 but + # task3 should be executed only once + # and task4 should be CANCELLED + task1 = PromptTask("prompt1", id="task1", prompt_driver=MockPromptDriver(mock_output="3")) + task2 = PromptTask("prompt2", id="task2") + task3 = PromptTask(id="task3") + task4 = PromptTask("prompt4", id="task4") + task5 = PromptTask("prompt5", id="task5") + + def test(parents) -> tuple: + return "task3" if parents[0].output.value == "3" else "task4" + + control_flow_task = ChoiceControlFlowTask(id="control_flow_task", control_flow_fn=test) + control_flow_task.add_parents([task1, task2]) + control_flow_task.add_children([task3, task4]) + task5.add_parents([task3, task4]) + workflow = Workflow( + prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4, task5, control_flow_task] + ) + workflow.resolve_relationships() + workflow.run() + + assert task1.state == BaseTask.State.FINISHED + assert task2.state == BaseTask.State.FINISHED + assert task3.state == BaseTask.State.FINISHED + assert task4.state == BaseTask.State.CANCELLED + assert task5.state == BaseTask.State.FINISHED + assert control_flow_task.state == BaseTask.State.FINISHED + + def test_workflow_with_control_flow_task_multiple_child_parents(self): + # control_flow_task should branch to task3 but + # task3 should be executed only once + # and task4 should be CANCELLED + task1 = PromptTask("prompt1", id="task1") + task2 = PromptTask("prompt2", id="task2") + task3 = PromptTask(id="task3") + task4 = PromptTask("prompt4", id="task4") + task5 = PromptTask("prompt5", id="task5") + control_flow_task = ChoiceControlFlowTask(id="control_flow_task", control_flow_fn=lambda x: task3) + task2.add_parent(task1) + task2.add_child(task3) + control_flow_task.add_parent(task1) + control_flow_task.add_children([task3, task4]) + task5.add_parents([task3, task4]) + workflow = Workflow( + prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4, task5, control_flow_task] + ) + workflow.resolve_relationships() + workflow.run() + + assert task1.state == BaseTask.State.FINISHED + assert task2.state == BaseTask.State.FINISHED + assert task3.state == BaseTask.State.FINISHED + assert task4.state == BaseTask.State.CANCELLED + assert task5.state == BaseTask.State.FINISHED + assert control_flow_task.state == BaseTask.State.FINISHED + + for task in [task1, task2, task3, task4, task5, control_flow_task]: + task.reset() + assert task.state == BaseTask.State.PENDING + assert workflow.output is None + + # this time control_flow_task should branch to task4 + # and task3 should still be executed because it has another parent + control_flow_task.control_flow_fn = lambda x: task4 + workflow.run() + + assert task1.state == BaseTask.State.FINISHED + assert task2.state == BaseTask.State.FINISHED + assert task3.state == BaseTask.State.FINISHED + assert task4.state == BaseTask.State.FINISHED + assert task5.state == BaseTask.State.FINISHED + assert control_flow_task.state == BaseTask.State.FINISHED + def test_init(self): driver = MockPromptDriver() workflow = Workflow(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])])