Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ControlFlowTask #913

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .audio_artifact import AudioArtifact
from .action_artifact import ActionArtifact
from .generic_artifact import GenericArtifact
from .control_flow_artifact import ControlFlowArtifact


__all__ = [
Expand All @@ -27,4 +28,5 @@
"AudioArtifact",
"ActionArtifact",
"GenericArtifact",
"ControlFlowArtifact",
]
8 changes: 5 additions & 3 deletions griptape/artifacts/boolean_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@

from attrs import define, field

from griptape.artifacts import BaseArtifact
from griptape.artifacts import BaseArtifact, TextArtifact


@define
class BooleanArtifact(BaseArtifact):
value: bool = field(converter=bool, metadata={"serializable": True})

@classmethod
def parse_bool(cls, value: Union[str, bool]) -> BooleanArtifact: # noqa: FBT001
"""Convert a string literal or bool to a BooleanArtifact. The string must be either "true" or "false" with any casing."""
def parse_bool(cls, value: Union[str, bool, TextArtifact]) -> BooleanArtifact: # noqa: FBT001
"""Convert a string literal, TextArtifact or bool to a BooleanArtifact. The string must be either "true" or "false" with any casing."""
if value is not None:
if isinstance(value, TextArtifact):
value = str(value)
if isinstance(value, str):
if value.lower() == "true":
return BooleanArtifact(True) # noqa: FBT003
Expand Down
8 changes: 8 additions & 0 deletions griptape/artifacts/control_flow_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from attrs import define

from griptape.artifacts import BaseArtifact


@define
class ControlFlowArtifact(BaseArtifact):
pass
3 changes: 3 additions & 0 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ def default_task_memory(self) -> TaskMemory:
def is_finished(self) -> bool:
return all(s.is_finished() for s in self.tasks)

def is_complete(self) -> bool:
return all(s.is_complete() for s in self.tasks)

def is_executing(self) -> bool:
return any(s for s in self.tasks if s.is_executing())

Expand Down
16 changes: 9 additions & 7 deletions griptape/structures/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,22 @@ def insert_task(
def try_run(self, *args) -> Workflow:
exit_loop = False

while not self.is_finished() and not exit_loop:
while not self.is_complete() and not exit_loop:
futures_list = {}
ordered_tasks = self.order_tasks()
executable_tasks = [*filter(lambda task: task.can_execute(), self.order_tasks())]

for task in ordered_tasks:
if task.can_execute():
future = self.futures_executor_fn().submit(task.execute)
futures_list[future] = task
if not executable_tasks:
exit_loop = True
break

for task in executable_tasks:
future = self.futures_executor_fn().submit(task.execute)
futures_list[future] = task

# Wait for all tasks to complete
for future in futures.as_completed(futures_list):
if isinstance(future.result(), ErrorArtifact) and self.fail_fast:
exit_loop = True

break

if self.conversation_memory and self.output is not None:
Expand Down
4 changes: 4 additions & 0 deletions griptape/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from .text_to_speech_task import TextToSpeechTask
from .structure_run_task import StructureRunTask
from .audio_transcription_task import AudioTranscriptionTask
from .base_control_flow_task import BaseControlFlowTask
from .choice_control_flow_task import ChoiceControlFlowTask

__all__ = [
"BaseTask",
Expand All @@ -46,4 +48,6 @@
"TextToSpeechTask",
"StructureRunTask",
"AudioTranscriptionTask",
"BaseControlFlowTask",
"ChoiceControlFlowTask",
]
29 changes: 29 additions & 0 deletions griptape/tasks/base_control_flow_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from __future__ import annotations

from abc import ABC

from attrs import define

from griptape.tasks import BaseTask


@define
class BaseControlFlowTask(BaseTask, ABC):
def before_run(self) -> None:
super().before_run()

self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text())

def after_run(self) -> None:
super().after_run()

self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text())

def _cancel_children_rec(self, task: BaseTask, chosen_task: BaseTask) -> None:
for child in filter(lambda child: child != chosen_task, task.children):
if all(parent.is_complete() for parent in filter(lambda parent: parent != task, child.parents)):
child.state = BaseTask.State.CANCELLED
self._cancel_children_rec(child, chosen_task)

def _get_task(self, task: str | BaseTask) -> BaseTask:
return self.structure.find_task(task) if isinstance(task, str) else task
15 changes: 14 additions & 1 deletion griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class State(Enum):
PENDING = 1
EXECUTING = 2
FINISHED = 3
CANCELLED = 4

id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
state: State = field(default=State.PENDING, kw_only=True)
Expand Down Expand Up @@ -122,9 +123,15 @@ def is_pending(self) -> bool:
def is_finished(self) -> bool:
return self.state == BaseTask.State.FINISHED

def is_cancelled(self) -> bool:
return self.state == BaseTask.State.CANCELLED

def is_executing(self) -> bool:
return self.state == BaseTask.State.EXECUTING

def is_complete(self) -> bool:
return self.is_finished() or self.is_cancelled()

def before_run(self) -> None:
if self.structure is not None:
event_bus.publish_event(
Expand Down Expand Up @@ -168,7 +175,13 @@ def execute(self) -> Optional[BaseArtifact]:
return self.output

def can_execute(self) -> bool:
return self.state == BaseTask.State.PENDING and all(parent.is_finished() for parent in self.parents)
return self.is_pending() and (
(
all(parent.is_complete() for parent in self.parents)
and any(parent.is_finished() for parent in self.parents)
)
or len(self.parents) == 0
)

def reset(self) -> BaseTask:
self.state = BaseTask.State.PENDING
Expand Down
61 changes: 61 additions & 0 deletions griptape/tasks/choice_control_flow_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

from typing import Callable

from attrs import define, field

from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact
from griptape.tasks import BaseControlFlowTask, BaseTask


@define
class ChoiceControlFlowTask(BaseControlFlowTask):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a dedicated Task for this? What if we added control_flow_fn to BaseTask?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

future proof if there is other types of control flow tasks

control_flow_fn: Callable[[list[BaseTask] | BaseTask], list[BaseTask | str] | BaseTask | str] = field(
metadata={"serializable": False}
)

@property
def input(self) -> BaseArtifact:
if len(self.parents) == 1:
return self.parents[0].output if self.parents[0].output is not None else TextArtifact("")
parents = filter(lambda parent: parent.output is not None, self.parents)
return ListArtifact(
[
parent.output
for parent in parents # pyright: ignore[reportArgumentType]
]
)

def run(self) -> BaseArtifact:
tasks = self.control_flow_fn(
[artifact.value for artifact in self.input.value]
if isinstance(self.input, ListArtifact)
else self.input.value
)

if not isinstance(tasks, list):
tasks = [tasks]

if tasks is None:
tasks = []

tasks = [self._get_task(task) for task in tasks]

for task in tasks:
if task.id not in self.child_ids:
self.output = ErrorArtifact(f"ControlFlowTask {self.id} did not return a valid child task")
return self.output

self.output = (
ListArtifact(
[
parent.value.output
for parent in filter(lambda parent: parent.value.output is not None, self.input.value)
]
)
if isinstance(self.input, ListArtifact)
else self.input.value.output
)
self._cancel_children_rec(self, task)

return self.output # pyright: ignore[reportReturnType]
7 changes: 6 additions & 1 deletion griptape/utils/structure_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ def to_url(self) -> str:
def __render_task(self, task: BaseTask) -> str:
if task.children:
children = " & ".join([f"{self.__get_id(child.id)}({child.id})" for child in task.children])
return f"{self.__get_id(task.id)}({task.id})--> {children};"
from griptape.tasks import ChoiceControlFlowTask

if isinstance(task, ChoiceControlFlowTask):
return f"{self.__get_id(task.id)}{{{task.id}}}-.-> {children};"
else:
return f"{self.__get_id(task.id)}({task.id})--> {children};"
else:
return f"{self.__get_id(task.id)}({task.id});"

Expand Down
130 changes: 129 additions & 1 deletion tests/unit/structures/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from griptape.memory.task.storage import TextArtifactStorage
from griptape.rules import Rule, Ruleset
from griptape.structures import Workflow
from griptape.tasks import BaseTask, CodeExecutionTask, PromptTask, ToolkitTask
from griptape.tasks import BaseTask, ChoiceControlFlowTask, CodeExecutionTask, PromptTask, ToolkitTask
from tests.mocks.mock_embedding_driver import MockEmbeddingDriver
from tests.mocks.mock_prompt_driver import MockPromptDriver
from tests.mocks.mock_tool.tool import MockTool
Expand All @@ -29,6 +29,134 @@ def fn(task):

return CodeExecutionTask(run_fn=fn)

def test_workflow_with_control_flow_task(self):
task1 = PromptTask("prompt1", id="task1")
task1.output = TextArtifact("task1 output")
task2 = PromptTask("prompt2", id="task2")
task3 = PromptTask("prompt3", id="task3")
task4 = PromptTask("prompt4", id="end")
control_flow_task = ChoiceControlFlowTask(id="control_flow_task", control_flow_fn=lambda x: task2)
control_flow_task.add_parent(task1)
control_flow_task.add_children([task2, task3])
task4.add_parents([task2, task3])
workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4, control_flow_task])
workflow.resolve_relationships()
workflow.run()

assert task1.state == BaseTask.State.FINISHED
assert task2.state == BaseTask.State.FINISHED
assert task3.state == BaseTask.State.CANCELLED
assert task4.state == BaseTask.State.FINISHED
assert control_flow_task.state == BaseTask.State.FINISHED

def test_workflow_with_multiple_control_flow_tasks(self):
# control_flow_task should branch to task3 but
# task3 should be executed only once
# and task4 should be CANCELLED
task1 = PromptTask("prompt1", id="task1")
task2 = PromptTask("prompt2", id="task2")
task3 = PromptTask("prompt3", id="task3")
task4 = PromptTask("prompt4", id="task4")
task5 = PromptTask("prompt5", id="task5")
task6 = PromptTask("prompt6", id="task6")
control_flow_task1 = ChoiceControlFlowTask(id="control_flow_task1", control_flow_fn=lambda x: task3)
control_flow_task1.add_parent(task1)
control_flow_task1.add_children([task2, task3])
control_flow_task2 = ChoiceControlFlowTask(id="control_flow_task2", control_flow_fn=lambda x: task5)
control_flow_task2.add_parent(task2)
control_flow_task2.add_children([task4, task5])
task6.add_parents([task3, task4, task5])
workflow = Workflow(
prompt_driver=MockPromptDriver(),
tasks=[task1, task2, task3, task4, task5, task6, control_flow_task1, control_flow_task2],
)
workflow.resolve_relationships()
workflow.run()

assert task1.state == BaseTask.State.FINISHED
assert task2.state == BaseTask.State.CANCELLED
assert task3.state == BaseTask.State.FINISHED
assert task4.state == BaseTask.State.CANCELLED
assert task5.state == BaseTask.State.CANCELLED
assert task6.state == BaseTask.State.FINISHED
assert control_flow_task1.state == BaseTask.State.FINISHED
assert control_flow_task2.state == BaseTask.State.CANCELLED

def test_workflow_with_control_flow_task_multiple_input_parents(self):
# control_flow_task should branch to task3 but
# task3 should be executed only once
# and task4 should be CANCELLED
task1 = PromptTask("prompt1", id="task1", prompt_driver=MockPromptDriver(mock_output="3"))
task2 = PromptTask("prompt2", id="task2")
task3 = PromptTask(id="task3")
task4 = PromptTask("prompt4", id="task4")
task5 = PromptTask("prompt5", id="task5")

def test(parents) -> tuple:
return "task3" if parents[0].output.value == "3" else "task4"

control_flow_task = ChoiceControlFlowTask(id="control_flow_task", control_flow_fn=test)
control_flow_task.add_parents([task1, task2])
control_flow_task.add_children([task3, task4])
task5.add_parents([task3, task4])
workflow = Workflow(
prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4, task5, control_flow_task]
)
workflow.resolve_relationships()
workflow.run()

assert task1.state == BaseTask.State.FINISHED
assert task2.state == BaseTask.State.FINISHED
assert task3.state == BaseTask.State.FINISHED
assert task4.state == BaseTask.State.CANCELLED
assert task5.state == BaseTask.State.FINISHED
assert control_flow_task.state == BaseTask.State.FINISHED

def test_workflow_with_control_flow_task_multiple_child_parents(self):
# control_flow_task should branch to task3 but
# task3 should be executed only once
# and task4 should be CANCELLED
task1 = PromptTask("prompt1", id="task1")
task2 = PromptTask("prompt2", id="task2")
task3 = PromptTask(id="task3")
task4 = PromptTask("prompt4", id="task4")
task5 = PromptTask("prompt5", id="task5")
control_flow_task = ChoiceControlFlowTask(id="control_flow_task", control_flow_fn=lambda x: task3)
task2.add_parent(task1)
task2.add_child(task3)
control_flow_task.add_parent(task1)
control_flow_task.add_children([task3, task4])
task5.add_parents([task3, task4])
workflow = Workflow(
prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4, task5, control_flow_task]
)
workflow.resolve_relationships()
workflow.run()

assert task1.state == BaseTask.State.FINISHED
assert task2.state == BaseTask.State.FINISHED
assert task3.state == BaseTask.State.FINISHED
assert task4.state == BaseTask.State.CANCELLED
assert task5.state == BaseTask.State.FINISHED
assert control_flow_task.state == BaseTask.State.FINISHED

for task in [task1, task2, task3, task4, task5, control_flow_task]:
task.reset()
assert task.state == BaseTask.State.PENDING
assert workflow.output is None

# this time control_flow_task should branch to task4
# and task3 should still be executed because it has another parent
control_flow_task.control_flow_fn = lambda x: task4
workflow.run()

assert task1.state == BaseTask.State.FINISHED
assert task2.state == BaseTask.State.FINISHED
assert task3.state == BaseTask.State.FINISHED
assert task4.state == BaseTask.State.FINISHED
assert task5.state == BaseTask.State.FINISHED
assert control_flow_task.state == BaseTask.State.FINISHED

def test_init(self):
driver = MockPromptDriver()
workflow = Workflow(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])])
Expand Down
Loading