From 5dd0bc53ba58528e12df595eaa7a576ee364a85e Mon Sep 17 00:00:00 2001 From: matt Date: Wed, 3 Jul 2024 13:45:15 -0500 Subject: [PATCH 1/2] Add `ControlFlowTask` --- griptape/artifacts/__init__.py | 4 + griptape/artifacts/boolean_artifact.py | 11 +- griptape/artifacts/control_flow_artifact.py | 7 + griptape/artifacts/task_artifact.py | 30 ++++ griptape/memory/meta/__init__.py | 3 +- .../memory/meta/control_flow_meta_entry.py | 15 ++ griptape/structures/structure.py | 3 + griptape/structures/workflow.py | 16 +- griptape/tasks/__init__.py | 4 + griptape/tasks/base_control_flow_task.py | 35 +++++ griptape/tasks/base_task.py | 15 +- griptape/tasks/boolean_control_flow_task.py | 45 ++++++ griptape/tasks/choice_control_flow_task.py | 64 ++++++++ griptape/utils/structure_visualizer.py | 7 +- tests/unit/structures/test_workflow.py | 140 ++++++++++++++++++ 15 files changed, 385 insertions(+), 14 deletions(-) create mode 100644 griptape/artifacts/control_flow_artifact.py create mode 100644 griptape/artifacts/task_artifact.py create mode 100644 griptape/memory/meta/control_flow_meta_entry.py create mode 100644 griptape/tasks/base_control_flow_task.py create mode 100644 griptape/tasks/boolean_control_flow_task.py create mode 100644 griptape/tasks/choice_control_flow_task.py diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index e9fc6daba..13a1c8385 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -11,6 +11,8 @@ from .audio_artifact import AudioArtifact from .action_artifact import ActionArtifact from .generic_artifact import GenericArtifact +from .control_flow_artifact import ControlFlowArtifact +from .task_artifact import TaskArtifact __all__ = [ @@ -27,4 +29,6 @@ "AudioArtifact", "ActionArtifact", "GenericArtifact", + "ControlFlowArtifact", + "TaskArtifact", ] diff --git a/griptape/artifacts/boolean_artifact.py b/griptape/artifacts/boolean_artifact.py index 5bcdfac9b..cadf6724d 100644 --- a/griptape/artifacts/boolean_artifact.py +++ b/griptape/artifacts/boolean_artifact.py @@ -3,8 +3,7 @@ from typing import Union from attrs import define, field - -from griptape.artifacts import BaseArtifact +from griptape.artifacts import TextArtifact, BaseArtifact @define @@ -12,9 +11,13 @@ class BooleanArtifact(BaseArtifact): value: bool = field(converter=bool, metadata={"serializable": True}) @classmethod - def parse_bool(cls, value: Union[str, bool]) -> BooleanArtifact: # noqa: FBT001 - """Convert a string literal or bool to a BooleanArtifact. The string must be either "true" or "false" with any casing.""" + def parse_bool(cls, value: Union[str, bool, TextArtifact]) -> BooleanArtifact: + """ + Convert a string literal or bool to a BooleanArtifact. The string must be either "true" or "false" with any casing. + """ if value is not None: + if isinstance(value, TextArtifact): + value = str(value) if isinstance(value, str): if value.lower() == "true": return BooleanArtifact(True) # noqa: FBT003 diff --git a/griptape/artifacts/control_flow_artifact.py b/griptape/artifacts/control_flow_artifact.py new file mode 100644 index 000000000..6908bf139 --- /dev/null +++ b/griptape/artifacts/control_flow_artifact.py @@ -0,0 +1,7 @@ +from attrs import define +from griptape.artifacts import BaseArtifact + + +@define +class ControlFlowArtifact(BaseArtifact): + pass diff --git a/griptape/artifacts/task_artifact.py b/griptape/artifacts/task_artifact.py new file mode 100644 index 000000000..508815dd6 --- /dev/null +++ b/griptape/artifacts/task_artifact.py @@ -0,0 +1,30 @@ +from __future__ import annotations +from attrs import define, field +from typing import TYPE_CHECKING +from griptape.artifacts import ControlFlowArtifact + +if TYPE_CHECKING: + from griptape.tasks import BaseTask + from griptape.artifacts import BaseArtifact + + +@define +class TaskArtifact(ControlFlowArtifact): + value: BaseTask = field(metadata={"serializable": True}) + + @property + def task_id(self) -> str: + return self.value.id + + @property + def task(self) -> BaseTask: + return self.value + + def to_text(self) -> str: + return self.value.id + + def __add__(self, other: BaseArtifact) -> BaseArtifact: + raise NotImplementedError("TaskArtifact does not support addition") + + def __eq__(self, value: object) -> bool: + return self.value is value diff --git a/griptape/memory/meta/__init__.py b/griptape/memory/meta/__init__.py index 56b6d607c..0c9fd1824 100644 --- a/griptape/memory/meta/__init__.py +++ b/griptape/memory/meta/__init__.py @@ -1,5 +1,6 @@ from .base_meta_entry import BaseMetaEntry from .action_subtask_meta_entry import ActionSubtaskMetaEntry +from .control_flow_meta_entry import ControlFlowMetaEntry from .meta_memory import MetaMemory -__all__ = ["BaseMetaEntry", "MetaMemory", "ActionSubtaskMetaEntry"] +__all__ = ["BaseMetaEntry", "MetaMemory", "ActionSubtaskMetaEntry", "ControlFlowMetaEntry"] diff --git a/griptape/memory/meta/control_flow_meta_entry.py b/griptape/memory/meta/control_flow_meta_entry.py new file mode 100644 index 000000000..0d5d5ef84 --- /dev/null +++ b/griptape/memory/meta/control_flow_meta_entry.py @@ -0,0 +1,15 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Optional +from attrs import field, define +from griptape.memory.meta import BaseMetaEntry + +if TYPE_CHECKING: + from griptape.artifacts import BaseArtifact + + +@define +class ControlFlowMetaEntry(BaseMetaEntry): + type: str = field(default=__name__, kw_only=True, metadata={"serializable": False}) + input_tasks: list[str] = field(factory=list, kw_only=True) + output_tasks: list[str] = field(factory=list, kw_only=True) + output: Optional[BaseArtifact] = field(default=None, kw_only=True) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 8f095dfeb..1feb350d8 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -203,6 +203,9 @@ def default_task_memory(self) -> TaskMemory: def is_finished(self) -> bool: return all(s.is_finished() for s in self.tasks) + def is_complete(self) -> bool: + return all(s.is_complete() for s in self.tasks) + def is_executing(self) -> bool: return any(s for s in self.tasks if s.is_executing()) diff --git a/griptape/structures/workflow.py b/griptape/structures/workflow.py index 53e59e751..34619a263 100644 --- a/griptape/structures/workflow.py +++ b/griptape/structures/workflow.py @@ -94,20 +94,22 @@ def insert_task( def try_run(self, *args) -> Workflow: exit_loop = False - while not self.is_finished() and not exit_loop: + while not self.is_complete() and not exit_loop: futures_list = {} - ordered_tasks = self.order_tasks() + executable_tasks = [*filter(lambda task: task.can_execute(), self.order_tasks())] - for task in ordered_tasks: - if task.can_execute(): - future = self.futures_executor_fn().submit(task.execute) - futures_list[future] = task + if not executable_tasks: + exit_loop = True + break + + for task in executable_tasks: + future = self.futures_executor_fn().submit(task.execute) + futures_list[future] = task # Wait for all tasks to complete for future in futures.as_completed(futures_list): if isinstance(future.result(), ErrorArtifact) and self.fail_fast: exit_loop = True - break if self.conversation_memory and self.output is not None: diff --git a/griptape/tasks/__init__.py b/griptape/tasks/__init__.py index 764d1669a..e08124fc8 100644 --- a/griptape/tasks/__init__.py +++ b/griptape/tasks/__init__.py @@ -21,6 +21,8 @@ from .text_to_speech_task import TextToSpeechTask from .structure_run_task import StructureRunTask from .audio_transcription_task import AudioTranscriptionTask +from .base_control_flow_task import BaseControlFlowTask +from .choice_control_flow_task import ChoiceControlFlowTask __all__ = [ "BaseTask", @@ -46,4 +48,6 @@ "TextToSpeechTask", "StructureRunTask", "AudioTranscriptionTask", + "BaseControlFlowTask", + "ChoiceControlFlowTask", ] diff --git a/griptape/tasks/base_control_flow_task.py b/griptape/tasks/base_control_flow_task.py new file mode 100644 index 000000000..c7196b203 --- /dev/null +++ b/griptape/tasks/base_control_flow_task.py @@ -0,0 +1,35 @@ +from __future__ import annotations +from abc import ABC +from attrs import define +from griptape.tasks import BaseTask +from griptape.memory.meta import ControlFlowMetaEntry + + +@define +class BaseControlFlowTask(BaseTask, ABC): + def before_run(self) -> None: + super().before_run() + + self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nInput: {self.input.to_text()}") + + def after_run(self) -> None: + super().after_run() + + self.structure.meta_memory.add_entry( + ControlFlowMetaEntry( + input_tasks=[parent.id for parent in self.parents], + output_tasks=[child.id for child in filter(lambda child: not child.is_finished(), self.children)], + output=self.output, + ) + ) + + self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") + + def _cancel_children_rec(self, task: BaseTask, chosen_task: BaseTask) -> None: + for child in filter(lambda child: child != chosen_task, task.children): + if all(parent.is_complete() for parent in filter(lambda parent: parent != task, child.parents)): + child.state = BaseTask.State.CANCELLED + self._cancel_children_rec(child, chosen_task) + + def _get_task(self, task: str | BaseTask) -> BaseTask: + return self.structure.find_task(task) if isinstance(task, str) else task diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index ade656f87..67e21d3a0 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -23,6 +23,7 @@ class State(Enum): PENDING = 1 EXECUTING = 2 FINISHED = 3 + CANCELLED = 4 id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) state: State = field(default=State.PENDING, kw_only=True) @@ -122,9 +123,15 @@ def is_pending(self) -> bool: def is_finished(self) -> bool: return self.state == BaseTask.State.FINISHED + def is_cancelled(self) -> bool: + return self.state == BaseTask.State.CANCELLED + def is_executing(self) -> bool: return self.state == BaseTask.State.EXECUTING + def is_complete(self) -> bool: + return self.is_finished() or self.is_cancelled() + def before_run(self) -> None: if self.structure is not None: event_bus.publish_event( @@ -168,7 +175,13 @@ def execute(self) -> Optional[BaseArtifact]: return self.output def can_execute(self) -> bool: - return self.state == BaseTask.State.PENDING and all(parent.is_finished() for parent in self.parents) + return self.is_pending() and ( + ( + all(parent.is_complete() for parent in self.parents) + and any(parent.is_finished() for parent in self.parents) + ) + or len(self.parents) == 0 + ) def reset(self) -> BaseTask: self.state = BaseTask.State.PENDING diff --git a/griptape/tasks/boolean_control_flow_task.py b/griptape/tasks/boolean_control_flow_task.py new file mode 100644 index 000000000..a505b4b9a --- /dev/null +++ b/griptape/tasks/boolean_control_flow_task.py @@ -0,0 +1,45 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Literal, Union +from attrs import field + +from griptape.artifacts import BooleanArtifact +from griptape.tasks import BaseControlFlowTask + +if TYPE_CHECKING: + from griptape.tasks import BaseTask + + +class BooleanControlFlowTask(BaseControlFlowTask): + true_tasks: list[str | BaseTask] = field(factory=list, kw_only=True) + false_tasks: list[str | BaseTask] = field(factory=list, kw_only=True) + operator: Union[Literal["and"], Literal["or"], Literal["xor"]] = field(default="and", kw_only=True) + coerce_inputs_to_bool: bool = field(default=False, kw_only=True) + + def run(self) -> BooleanArtifact: + if not all( + choice_task if isinstance(choice_task, str) else choice_task.id in self.child_ids + for choice_task in [*self.true_tasks, *self.false_tasks] + ): + raise ValueError(f"BooleanControlFlowTask {self.id} has invalid true_tasks or false_tasks") + + inputs = [task.output for task in self.parents] + + if self.coerce_inputs_to_bool: + inputs = [BooleanArtifact(input) for input in inputs] + else: + if not all(isinstance(input, BooleanArtifact) for input in inputs): + raise ValueError(f"BooleanControlFlowTask {self.id} has non-BooleanArtifact inputs") + + if self.operator == "and": + self.output = BooleanArtifact(all(inputs)) + elif self.operator == "or": + self.output = BooleanArtifact(any(inputs)) + elif self.operator == "xor": + self.output = BooleanArtifact(sum([int(input.value) for input in inputs]) == 1) + else: + raise ValueError(f"BooleanControlFlowTask {self.id} has invalid operator {self.operator}") + + for task in self.true_tasks if self.output.value else self.false_tasks: + task = self._get_task(task) + self._cancel_children_rec(self, task) + return self.output diff --git a/griptape/tasks/choice_control_flow_task.py b/griptape/tasks/choice_control_flow_task.py new file mode 100644 index 000000000..1ec809aee --- /dev/null +++ b/griptape/tasks/choice_control_flow_task.py @@ -0,0 +1,64 @@ +from __future__ import annotations +from typing import Callable +from attrs import define, field + +from griptape.artifacts import BaseArtifact, ErrorArtifact, TaskArtifact, ListArtifact +from griptape.tasks import BaseTask +from griptape.tasks import BaseControlFlowTask + + +@define +class ChoiceControlFlowTask(BaseControlFlowTask): + control_flow_fn: Callable[[list[BaseTask] | BaseTask], list[BaseTask | str] | BaseTask | str] = field( + metadata={"serializable": False} + ) + + @property + def input(self) -> BaseArtifact: + if len(self.parents) == 1: + return TaskArtifact(self.parents[0]) + return ListArtifact([TaskArtifact(parent) for parent in self.parents]) + + def before_run(self) -> None: + super().before_run() + + self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nInput: {self.input.to_text()}") + + def after_run(self) -> None: + super().after_run() + + self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") + + def run(self) -> BaseArtifact: + tasks = self.control_flow_fn( + [artifact.value for artifact in self.input.value] + if isinstance(self.input, ListArtifact) + else self.input.value + ) + + if not isinstance(tasks, list): + tasks = [tasks] + + if tasks is None: + tasks = [] + + tasks = [self._get_task(task) for task in tasks] + + for task in tasks: + if task.id not in self.child_ids: + self.output = ErrorArtifact(f"ControlFlowTask {self.id} did not return a valid child task") + return self.output + + self.output = ( + ListArtifact( + [ + parent.value.output + for parent in filter(lambda parent: parent.value.output is not None, self.input.value) + ] # pyright: ignore + ) + if isinstance(self.input, ListArtifact) + else self.input.value.output + ) + self._cancel_children_rec(self, task) + + return self.output # pyright: ignore diff --git a/griptape/utils/structure_visualizer.py b/griptape/utils/structure_visualizer.py index f24443cd6..ea49e2f51 100644 --- a/griptape/utils/structure_visualizer.py +++ b/griptape/utils/structure_visualizer.py @@ -40,7 +40,12 @@ def to_url(self) -> str: def __render_task(self, task: BaseTask) -> str: if task.children: children = " & ".join([f"{self.__get_id(child.id)}({child.id})" for child in task.children]) - return f"{self.__get_id(task.id)}({task.id})--> {children};" + from griptape.tasks import ChoiceControlFlowTask + + if isinstance(task, ChoiceControlFlowTask): + return f"{self.__get_id(task.id)}{{{task.id}}}-.-> {children};" + else: + return f"{self.__get_id(task.id)}({task.id})--> {children};" else: return f"{self.__get_id(task.id)}({task.id});" diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index 242de29c5..e40c19fcd 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -2,6 +2,12 @@ import pytest +from pytest import fixture +from griptape.memory.task.storage import TextArtifactStorage +from tests.mocks.mock_prompt_driver import MockPromptDriver +from griptape.rules import Rule, Ruleset +from griptape.tasks import PromptTask, BaseTask, ToolkitTask, CodeExecutionTask, ChoiceControlFlowTask +from griptape.structures import Workflow from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import TextArtifactStorage @@ -29,6 +35,140 @@ def fn(task): return CodeExecutionTask(run_fn=fn) + def test_workflow_with_control_flow_task(self): + task1 = PromptTask("prompt1", id="task1") + task1.output = TextArtifact("task1 output") + task2 = PromptTask("prompt2", id="task2") + task3 = PromptTask("prompt3", id="task3") + task4 = PromptTask("prompt4", id="end") + control_flow_task = ChoiceControlFlowTask(id="control_flow_task", control_flow_fn=lambda x: task2) + control_flow_task.add_parent(task1) + control_flow_task.add_children([task2, task3]) + task4.add_parents([task2, task3]) + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4, control_flow_task]) + workflow.resolve_relationships() + workflow.run() + from griptape.utils import StructureVisualizer + + print(StructureVisualizer(workflow).to_url()) + + assert task1.state == BaseTask.State.FINISHED + assert task2.state == BaseTask.State.FINISHED + assert task3.state == BaseTask.State.CANCELLED + assert task4.state == BaseTask.State.FINISHED + assert control_flow_task.state == BaseTask.State.FINISHED + + def test_workflow_with_multiple_control_flow_tasks(self): + # control_flow_task should branch to task3 but + # task3 should be executed only once + # and task4 should be CANCELLED + task1 = PromptTask("prompt1", id="task1") + task2 = PromptTask("prompt2", id="task2") + task3 = PromptTask("prompt3", id="task3") + task4 = PromptTask("prompt4", id="task4") + task5 = PromptTask("prompt5", id="task5") + task6 = PromptTask("prompt6", id="task6") + control_flow_task1 = ChoiceControlFlowTask(id="control_flow_task1", control_flow_fn=lambda x: task3) + control_flow_task1.add_parent(task1) + control_flow_task1.add_children([task2, task3]) + control_flow_task2 = ChoiceControlFlowTask(id="control_flow_task2", control_flow_fn=lambda x: task5) + control_flow_task2.add_parent(task2) + control_flow_task2.add_children([task4, task5]) + task6.add_parents([task3, task4, task5]) + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[task1, task2, task3, task4, task5, task6, control_flow_task1, control_flow_task2], + ) + workflow.resolve_relationships() + workflow.run() + from griptape.utils import StructureVisualizer + + print(StructureVisualizer(workflow).to_url()) + + assert task1.state == BaseTask.State.FINISHED + assert task2.state == BaseTask.State.CANCELLED + assert task3.state == BaseTask.State.FINISHED + assert task4.state == BaseTask.State.CANCELLED + assert task5.state == BaseTask.State.CANCELLED + assert task6.state == BaseTask.State.FINISHED + assert control_flow_task1.state == BaseTask.State.FINISHED + assert control_flow_task2.state == BaseTask.State.CANCELLED + + def test_workflow_with_control_flow_task_multiple_input_parents(self): + # control_flow_task should branch to task3 but + # task3 should be executed only once + # and task4 should be CANCELLED + task1 = PromptTask("prompt1", id="task1", prompt_driver=MockPromptDriver(mock_output="3")) + task2 = PromptTask("prompt2", id="task2") + task3 = PromptTask(id="task3") + task4 = PromptTask("prompt4", id="task4") + task5 = PromptTask("prompt5", id="task5") + + def test(parents) -> tuple: + return "task3" if parents[0].output.value == "3" else "task4" + + control_flow_task = ChoiceControlFlowTask(id="control_flow_task", control_flow_fn=test) + control_flow_task.add_parents([task1, task2]) + control_flow_task.add_children([task3, task4]) + task5.add_parents([task3, task4]) + workflow = Workflow( + prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4, task5, control_flow_task] + ) + workflow.resolve_relationships() + workflow.run() + + assert task1.state == BaseTask.State.FINISHED + assert task2.state == BaseTask.State.FINISHED + assert task3.state == BaseTask.State.FINISHED + assert task4.state == BaseTask.State.CANCELLED + assert task5.state == BaseTask.State.FINISHED + assert control_flow_task.state == BaseTask.State.FINISHED + + def test_workflow_with_control_flow_task_multiple_child_parents(self): + # control_flow_task should branch to task3 but + # task3 should be executed only once + # and task4 should be CANCELLED + task1 = PromptTask("prompt1", id="task1") + task2 = PromptTask("prompt2", id="task2") + task3 = PromptTask(id="task3") + task4 = PromptTask("prompt4", id="task4") + task5 = PromptTask("prompt5", id="task5") + control_flow_task = ChoiceControlFlowTask(id="control_flow_task", control_flow_fn=lambda x: task3) + task2.add_parent(task1) + task2.add_child(task3) + control_flow_task.add_parent(task1) + control_flow_task.add_children([task3, task4]) + task5.add_parents([task3, task4]) + workflow = Workflow( + prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4, task5, control_flow_task] + ) + workflow.resolve_relationships() + workflow.run() + + assert task1.state == BaseTask.State.FINISHED + assert task2.state == BaseTask.State.FINISHED + assert task3.state == BaseTask.State.FINISHED + assert task4.state == BaseTask.State.CANCELLED + assert task5.state == BaseTask.State.FINISHED + assert control_flow_task.state == BaseTask.State.FINISHED + + for task in [task1, task2, task3, task4, task5, control_flow_task]: + task.reset() + assert task.state == BaseTask.State.PENDING + assert workflow.output is None + + # this time control_flow_task should branch to task4 + # and task3 should still be executed because it has another parent + control_flow_task.control_flow_fn = lambda x: task4 + workflow.run() + + assert task1.state == BaseTask.State.FINISHED + assert task2.state == BaseTask.State.FINISHED + assert task3.state == BaseTask.State.FINISHED + assert task4.state == BaseTask.State.FINISHED + assert task5.state == BaseTask.State.FINISHED + assert control_flow_task.state == BaseTask.State.FINISHED + def test_init(self): driver = MockPromptDriver() workflow = Workflow(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])]) From 9110e98a9429df22f4940bfb6cedb74fc0ce86d7 Mon Sep 17 00:00:00 2001 From: matt Date: Fri, 26 Jul 2024 12:22:02 -0500 Subject: [PATCH 2/2] updates --- griptape/artifacts/__init__.py | 2 - griptape/artifacts/boolean_artifact.py | 9 ++-- griptape/artifacts/control_flow_artifact.py | 1 + griptape/artifacts/task_artifact.py | 30 ------------- griptape/memory/meta/__init__.py | 3 +- .../memory/meta/control_flow_meta_entry.py | 15 ------- griptape/tasks/base_control_flow_task.py | 16 +++---- griptape/tasks/boolean_control_flow_task.py | 45 ------------------- griptape/tasks/choice_control_flow_task.py | 31 ++++++------- tests/unit/structures/test_workflow.py | 14 +----- 10 files changed, 26 insertions(+), 140 deletions(-) delete mode 100644 griptape/artifacts/task_artifact.py delete mode 100644 griptape/memory/meta/control_flow_meta_entry.py delete mode 100644 griptape/tasks/boolean_control_flow_task.py diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index 13a1c8385..6520e60d9 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -12,7 +12,6 @@ from .action_artifact import ActionArtifact from .generic_artifact import GenericArtifact from .control_flow_artifact import ControlFlowArtifact -from .task_artifact import TaskArtifact __all__ = [ @@ -30,5 +29,4 @@ "ActionArtifact", "GenericArtifact", "ControlFlowArtifact", - "TaskArtifact", ] diff --git a/griptape/artifacts/boolean_artifact.py b/griptape/artifacts/boolean_artifact.py index cadf6724d..d517b5c32 100644 --- a/griptape/artifacts/boolean_artifact.py +++ b/griptape/artifacts/boolean_artifact.py @@ -3,7 +3,8 @@ from typing import Union from attrs import define, field -from griptape.artifacts import TextArtifact, BaseArtifact + +from griptape.artifacts import BaseArtifact, TextArtifact @define @@ -11,10 +12,8 @@ class BooleanArtifact(BaseArtifact): value: bool = field(converter=bool, metadata={"serializable": True}) @classmethod - def parse_bool(cls, value: Union[str, bool, TextArtifact]) -> BooleanArtifact: - """ - Convert a string literal or bool to a BooleanArtifact. The string must be either "true" or "false" with any casing. - """ + def parse_bool(cls, value: Union[str, bool, TextArtifact]) -> BooleanArtifact: # noqa: FBT001 + """Convert a string literal, TextArtifact or bool to a BooleanArtifact. The string must be either "true" or "false" with any casing.""" if value is not None: if isinstance(value, TextArtifact): value = str(value) diff --git a/griptape/artifacts/control_flow_artifact.py b/griptape/artifacts/control_flow_artifact.py index 6908bf139..4e1b6532a 100644 --- a/griptape/artifacts/control_flow_artifact.py +++ b/griptape/artifacts/control_flow_artifact.py @@ -1,4 +1,5 @@ from attrs import define + from griptape.artifacts import BaseArtifact diff --git a/griptape/artifacts/task_artifact.py b/griptape/artifacts/task_artifact.py deleted file mode 100644 index 508815dd6..000000000 --- a/griptape/artifacts/task_artifact.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations -from attrs import define, field -from typing import TYPE_CHECKING -from griptape.artifacts import ControlFlowArtifact - -if TYPE_CHECKING: - from griptape.tasks import BaseTask - from griptape.artifacts import BaseArtifact - - -@define -class TaskArtifact(ControlFlowArtifact): - value: BaseTask = field(metadata={"serializable": True}) - - @property - def task_id(self) -> str: - return self.value.id - - @property - def task(self) -> BaseTask: - return self.value - - def to_text(self) -> str: - return self.value.id - - def __add__(self, other: BaseArtifact) -> BaseArtifact: - raise NotImplementedError("TaskArtifact does not support addition") - - def __eq__(self, value: object) -> bool: - return self.value is value diff --git a/griptape/memory/meta/__init__.py b/griptape/memory/meta/__init__.py index 0c9fd1824..56b6d607c 100644 --- a/griptape/memory/meta/__init__.py +++ b/griptape/memory/meta/__init__.py @@ -1,6 +1,5 @@ from .base_meta_entry import BaseMetaEntry from .action_subtask_meta_entry import ActionSubtaskMetaEntry -from .control_flow_meta_entry import ControlFlowMetaEntry from .meta_memory import MetaMemory -__all__ = ["BaseMetaEntry", "MetaMemory", "ActionSubtaskMetaEntry", "ControlFlowMetaEntry"] +__all__ = ["BaseMetaEntry", "MetaMemory", "ActionSubtaskMetaEntry"] diff --git a/griptape/memory/meta/control_flow_meta_entry.py b/griptape/memory/meta/control_flow_meta_entry.py deleted file mode 100644 index 0d5d5ef84..000000000 --- a/griptape/memory/meta/control_flow_meta_entry.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING, Optional -from attrs import field, define -from griptape.memory.meta import BaseMetaEntry - -if TYPE_CHECKING: - from griptape.artifacts import BaseArtifact - - -@define -class ControlFlowMetaEntry(BaseMetaEntry): - type: str = field(default=__name__, kw_only=True, metadata={"serializable": False}) - input_tasks: list[str] = field(factory=list, kw_only=True) - output_tasks: list[str] = field(factory=list, kw_only=True) - output: Optional[BaseArtifact] = field(default=None, kw_only=True) diff --git a/griptape/tasks/base_control_flow_task.py b/griptape/tasks/base_control_flow_task.py index c7196b203..b60b507ae 100644 --- a/griptape/tasks/base_control_flow_task.py +++ b/griptape/tasks/base_control_flow_task.py @@ -1,8 +1,10 @@ from __future__ import annotations + from abc import ABC + from attrs import define + from griptape.tasks import BaseTask -from griptape.memory.meta import ControlFlowMetaEntry @define @@ -10,20 +12,12 @@ class BaseControlFlowTask(BaseTask, ABC): def before_run(self) -> None: super().before_run() - self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nInput: {self.input.to_text()}") + self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) def after_run(self) -> None: super().after_run() - self.structure.meta_memory.add_entry( - ControlFlowMetaEntry( - input_tasks=[parent.id for parent in self.parents], - output_tasks=[child.id for child in filter(lambda child: not child.is_finished(), self.children)], - output=self.output, - ) - ) - - self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") + self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) def _cancel_children_rec(self, task: BaseTask, chosen_task: BaseTask) -> None: for child in filter(lambda child: child != chosen_task, task.children): diff --git a/griptape/tasks/boolean_control_flow_task.py b/griptape/tasks/boolean_control_flow_task.py deleted file mode 100644 index a505b4b9a..000000000 --- a/griptape/tasks/boolean_control_flow_task.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING, Literal, Union -from attrs import field - -from griptape.artifacts import BooleanArtifact -from griptape.tasks import BaseControlFlowTask - -if TYPE_CHECKING: - from griptape.tasks import BaseTask - - -class BooleanControlFlowTask(BaseControlFlowTask): - true_tasks: list[str | BaseTask] = field(factory=list, kw_only=True) - false_tasks: list[str | BaseTask] = field(factory=list, kw_only=True) - operator: Union[Literal["and"], Literal["or"], Literal["xor"]] = field(default="and", kw_only=True) - coerce_inputs_to_bool: bool = field(default=False, kw_only=True) - - def run(self) -> BooleanArtifact: - if not all( - choice_task if isinstance(choice_task, str) else choice_task.id in self.child_ids - for choice_task in [*self.true_tasks, *self.false_tasks] - ): - raise ValueError(f"BooleanControlFlowTask {self.id} has invalid true_tasks or false_tasks") - - inputs = [task.output for task in self.parents] - - if self.coerce_inputs_to_bool: - inputs = [BooleanArtifact(input) for input in inputs] - else: - if not all(isinstance(input, BooleanArtifact) for input in inputs): - raise ValueError(f"BooleanControlFlowTask {self.id} has non-BooleanArtifact inputs") - - if self.operator == "and": - self.output = BooleanArtifact(all(inputs)) - elif self.operator == "or": - self.output = BooleanArtifact(any(inputs)) - elif self.operator == "xor": - self.output = BooleanArtifact(sum([int(input.value) for input in inputs]) == 1) - else: - raise ValueError(f"BooleanControlFlowTask {self.id} has invalid operator {self.operator}") - - for task in self.true_tasks if self.output.value else self.false_tasks: - task = self._get_task(task) - self._cancel_children_rec(self, task) - return self.output diff --git a/griptape/tasks/choice_control_flow_task.py b/griptape/tasks/choice_control_flow_task.py index 1ec809aee..14b785db4 100644 --- a/griptape/tasks/choice_control_flow_task.py +++ b/griptape/tasks/choice_control_flow_task.py @@ -1,10 +1,11 @@ from __future__ import annotations + from typing import Callable + from attrs import define, field -from griptape.artifacts import BaseArtifact, ErrorArtifact, TaskArtifact, ListArtifact -from griptape.tasks import BaseTask -from griptape.tasks import BaseControlFlowTask +from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact +from griptape.tasks import BaseControlFlowTask, BaseTask @define @@ -16,18 +17,14 @@ class ChoiceControlFlowTask(BaseControlFlowTask): @property def input(self) -> BaseArtifact: if len(self.parents) == 1: - return TaskArtifact(self.parents[0]) - return ListArtifact([TaskArtifact(parent) for parent in self.parents]) - - def before_run(self) -> None: - super().before_run() - - self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nInput: {self.input.to_text()}") - - def after_run(self) -> None: - super().after_run() - - self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") + return self.parents[0].output if self.parents[0].output is not None else TextArtifact("") + parents = filter(lambda parent: parent.output is not None, self.parents) + return ListArtifact( + [ + parent.output + for parent in parents # pyright: ignore[reportArgumentType] + ] + ) def run(self) -> BaseArtifact: tasks = self.control_flow_fn( @@ -54,11 +51,11 @@ def run(self) -> BaseArtifact: [ parent.value.output for parent in filter(lambda parent: parent.value.output is not None, self.input.value) - ] # pyright: ignore + ] ) if isinstance(self.input, ListArtifact) else self.input.value.output ) self._cancel_children_rec(self, task) - return self.output # pyright: ignore + return self.output # pyright: ignore[reportReturnType] diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index e40c19fcd..b5cf5627f 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -2,18 +2,12 @@ import pytest -from pytest import fixture -from griptape.memory.task.storage import TextArtifactStorage -from tests.mocks.mock_prompt_driver import MockPromptDriver -from griptape.rules import Rule, Ruleset -from griptape.tasks import PromptTask, BaseTask, ToolkitTask, CodeExecutionTask, ChoiceControlFlowTask -from griptape.structures import Workflow from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset from griptape.structures import Workflow -from griptape.tasks import BaseTask, CodeExecutionTask, PromptTask, ToolkitTask +from griptape.tasks import BaseTask, ChoiceControlFlowTask, CodeExecutionTask, PromptTask, ToolkitTask from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -48,9 +42,6 @@ def test_workflow_with_control_flow_task(self): workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4, control_flow_task]) workflow.resolve_relationships() workflow.run() - from griptape.utils import StructureVisualizer - - print(StructureVisualizer(workflow).to_url()) assert task1.state == BaseTask.State.FINISHED assert task2.state == BaseTask.State.FINISHED @@ -81,9 +72,6 @@ def test_workflow_with_multiple_control_flow_tasks(self): ) workflow.resolve_relationships() workflow.run() - from griptape.utils import StructureVisualizer - - print(StructureVisualizer(workflow).to_url()) assert task1.state == BaseTask.State.FINISHED assert task2.state == BaseTask.State.CANCELLED