From faee99ba90325f3f6b198fbd4eb8d7cf93508294 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 3 Jul 2024 14:02:36 -0400 Subject: [PATCH 1/2] Add compiler backend --- MANIFEST.in | 1 + pyproject.toml | 2 +- src/controlflow/__init__.py | 10 +- src/controlflow/agents/agent.py | 48 ++- src/controlflow/agents/memory.py | 10 +- src/controlflow/controllers/controller.py | 344 --------------- .../controllers/instruction_template.py | 272 ------------ .../controllers/process_messages.py | 127 ------ src/controlflow/events/__init__.py | 1 + src/controlflow/events/agent_events.py | 118 ++++++ src/controlflow/events/controller_events.py | 26 ++ src/controlflow/events/event_store.py | 256 +++++++++++ src/controlflow/events/events.py | 36 ++ src/controlflow/events/message_compiler.py | 222 ++++++++++ src/controlflow/events/task_events.py | 24 ++ src/controlflow/events/tool_events.py | 62 +++ src/controlflow/flows/__init__.py | 2 +- src/controlflow/flows/flow.py | 113 ++--- .../{controllers => flows}/graph.py | 14 +- src/controlflow/flows/history.py | 18 +- src/controlflow/handlers/__init__.py | 0 src/controlflow/handlers/print_handler.py | 186 -------- src/controlflow/llm/__init__.py | 2 +- src/controlflow/llm/classify.py | 104 ----- src/controlflow/llm/completions.py | 399 ------------------ src/controlflow/llm/handlers.py | 106 ----- src/controlflow/llm/messages.py | 277 +----------- src/controlflow/llm/models.py | 8 + src/controlflow/llm/rules.py | 22 +- .../__init__.py | 0 src/controlflow/orchestration/controller.py | 288 +++++++++++++ src/controlflow/orchestration/handler.py | 9 + .../orchestration/print_handler.py | 188 +++++++++ .../orchestration/prompt_templates/agent.j2 | 49 +++ .../orchestration/prompt_templates/tools.j2 | 21 + .../prompt_templates/workflow.j2 | 88 ++++ src/controlflow/orchestration/prompts.py | 34 ++ src/controlflow/orchestration/tools.py | 103 +++++ src/controlflow/settings.py | 4 +- 
src/controlflow/tasks/agent_strategies.py | 54 ++- src/controlflow/tasks/task.py | 95 +---- src/controlflow/tools/__init__.py | 2 +- src/controlflow/{llm => tools}/tools.py | 86 ++-- src/controlflow/utilities/jinja.py | 21 +- src/controlflow/utilities/testing.py | 4 +- src/controlflow/utilities/types.py | 7 - tests/conftest.py | 2 +- tests/controllers/test_graph.py | 2 +- tests/fixtures/controlflow.py | 2 +- tests/llm/test_tools.py | 2 +- tests/tasks/test_tasks.py | 2 +- tests/test_settings.py | 6 + 52 files changed, 1788 insertions(+), 2091 deletions(-) create mode 100644 MANIFEST.in delete mode 100644 src/controlflow/controllers/controller.py delete mode 100644 src/controlflow/controllers/instruction_template.py delete mode 100644 src/controlflow/controllers/process_messages.py create mode 100644 src/controlflow/events/__init__.py create mode 100644 src/controlflow/events/agent_events.py create mode 100644 src/controlflow/events/controller_events.py create mode 100644 src/controlflow/events/event_store.py create mode 100644 src/controlflow/events/events.py create mode 100644 src/controlflow/events/message_compiler.py create mode 100644 src/controlflow/events/task_events.py create mode 100644 src/controlflow/events/tool_events.py rename src/controlflow/{controllers => flows}/graph.py (95%) delete mode 100644 src/controlflow/handlers/__init__.py delete mode 100644 src/controlflow/handlers/print_handler.py delete mode 100644 src/controlflow/llm/classify.py delete mode 100644 src/controlflow/llm/completions.py delete mode 100644 src/controlflow/llm/handlers.py rename src/controlflow/{controllers => orchestration}/__init__.py (100%) create mode 100644 src/controlflow/orchestration/controller.py create mode 100644 src/controlflow/orchestration/handler.py create mode 100644 src/controlflow/orchestration/print_handler.py create mode 100644 src/controlflow/orchestration/prompt_templates/agent.j2 create mode 100644 src/controlflow/orchestration/prompt_templates/tools.j2 
create mode 100644 src/controlflow/orchestration/prompt_templates/workflow.j2 create mode 100644 src/controlflow/orchestration/prompts.py create mode 100644 src/controlflow/orchestration/tools.py rename src/controlflow/{llm => tools}/tools.py (80%) diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..04fd9fb7 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +recursive-include controlflow/orchestration/prompt_templates * \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 96aee618..de1c5e22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ dependencies = [ "jinja2>=3.1.4", "langchain_core>=0.2.9", "langchain_openai>=0.1.8", - "langchain-anthropic>=0.1.15", + "langchain-anthropic>=0.1.19", "markdownify>=0.12.1", "pydantic-settings>=2.2.1", "textual>=0.61.1", diff --git a/src/controlflow/__init__.py b/src/controlflow/__init__.py index ca631c1a..e16bbe5a 100644 --- a/src/controlflow/__init__.py +++ b/src/controlflow/__init__.py @@ -8,20 +8,20 @@ from .instructions import instructions from .decorators import flow, task -from .llm.tools import tool +from .tools import tool # --- Default settings --- from .llm.models import _get_initial_default_model, get_default_model -from .flows.history import InMemoryHistory, get_default_history +from .events.event_store import InMemoryStore, get_default_event_store # assign to controlflow.default_model to change the default model default_model = _get_initial_default_model() del _get_initial_default_model -# assign to controlflow.default_history to change the default history -default_history = InMemoryHistory() -del InMemoryHistory +# assign to controlflow.default_event_store to change the default event store +default_event_store = InMemoryStore() +del InMemoryStore # assign to controlflow.default_agent to change the default agent default_agent = Agent(name="Marvin") diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index 8b5a8e33..e190a9f4 
100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -2,16 +2,17 @@ import random import uuid from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Generator, Optional from langchain_core.language_models import BaseChatModel from pydantic import Field, field_serializer import controlflow +from controlflow.events.events import Event from controlflow.instructions import get_instructions +from controlflow.llm.messages import AIMessage, BaseMessage from controlflow.llm.models import get_default_model from controlflow.llm.rules import LLMRules -from controlflow.tools.talk_to_user import talk_to_user from controlflow.utilities.context import ctx from controlflow.utilities.types import ControlFlowModel @@ -20,6 +21,8 @@ if TYPE_CHECKING: from controlflow.tasks.task import Task + from controlflow.tools.tools import Tool + logger = logging.getLogger(__name__) @@ -109,6 +112,8 @@ def get_llm_rules(self) -> LLMRules: return controlflow.llm.rules.rules_for_model(self.get_model()) def get_tools(self) -> list[Callable]: + from controlflow.tools.talk_to_user import talk_to_user + tools = self.tools.copy() if self.user_access: tools.append(talk_to_user) @@ -141,5 +146,44 @@ def run(self, task: "Task"): async def run_async(self, task: "Task"): return await task.run_async(agents=[self]) + def _run_model( + self, + messages: list[BaseMessage], + additional_tools: list["Tool"] = None, + stream: bool = True, + ) -> Generator[Event, None, None]: + from controlflow.events.agent_events import ( + AgentMessageDeltaEvent, + AgentMessageEvent, + ) + from controlflow.events.tool_events import ToolCallEvent, ToolResultEvent + from controlflow.tools.tools import as_tools, handle_tool_call + + model = self.get_model() + + tools = as_tools(self.get_tools() + (additional_tools or [])) + if tools: + model = model.bind_tools([t.to_lc_tool() for t in tools]) + + if stream: + 
response = None + for delta in model.stream(messages): + if response is None: + response = delta + else: + response += delta + yield AgentMessageDeltaEvent(agent=self, delta=delta, snapshot=response) + + else: + response: AIMessage = model.invoke(messages) + + yield AgentMessageEvent(agent=self, message=response) + + for tool_call in response.tool_calls + response.invalid_tool_calls: + yield ToolCallEvent(agent=self, tool_call=tool_call, message=response) + + result = handle_tool_call(tool_call, tools=tools) + yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result) + DEFAULT_AGENT = Agent(name="Marvin") diff --git a/src/controlflow/agents/memory.py b/src/controlflow/agents/memory.py index acf94153..b45bce71 100644 --- a/src/controlflow/agents/memory.py +++ b/src/controlflow/agents/memory.py @@ -1,13 +1,15 @@ import abc import uuid -from typing import ClassVar, Optional, cast +from typing import TYPE_CHECKING, ClassVar, Optional, cast from pydantic import Field -from controlflow.tools import Tool from controlflow.utilities.context import ctx from controlflow.utilities.types import ControlFlowModel +if TYPE_CHECKING: + from controlflow.tools import Tool + class Memory(ControlFlowModel, abc.ABC): id: str = Field(default_factory=lambda: uuid.uuid4().hex) @@ -27,7 +29,9 @@ def update(self, value: str, index: int = None): def delete(self, index: int): raise NotImplementedError() - def get_tools(self) -> list[Tool]: + def get_tools(self) -> list["Tool"]: + from controlflow.tools import Tool + update_tool = Tool.from_function( self.update, name="update_memory", diff --git a/src/controlflow/controllers/controller.py b/src/controlflow/controllers/controller.py deleted file mode 100644 index cc49313c..00000000 --- a/src/controlflow/controllers/controller.py +++ /dev/null @@ -1,344 +0,0 @@ -import inspect -import logging -import math -from collections import defaultdict -from contextlib import asynccontextmanager -from typing import Callable - -from pydantic 
import Field, PrivateAttr, field_validator, model_validator - -import controlflow -from controlflow.agents import Agent -from controlflow.controllers.graph import Graph -from controlflow.controllers.process_messages import prepare_messages -from controlflow.flows import Flow, get_flow -from controlflow.handlers.print_handler import PrintHandler -from controlflow.instructions import get_instructions -from controlflow.llm.completions import completion, completion_async -from controlflow.llm.handlers import ResponseHandler, TUIHandler -from controlflow.llm.messages import MessageType, SystemMessage -from controlflow.tasks.task import Task -from controlflow.tools import as_tools -from controlflow.utilities.context import ctx -from controlflow.utilities.prefect import create_markdown_artifact -from controlflow.utilities.prefect import prefect_task as prefect_task -from controlflow.utilities.types import ControlFlowModel - -logger = logging.getLogger(__name__) - - -def create_messages_markdown_artifact(messages, thread_id): - markdown_messages = "\n\n".join([f"{msg.role}: {msg.content}" for msg in messages]) - create_markdown_artifact( - key="messages", - markdown=inspect.cleandoc( - """ - # Messages - - *Thread ID: {thread_id}* - - {markdown_messages} - """.format( - thread_id=thread_id, - markdown_messages=markdown_messages, - ) - ), - ) - - -class Controller(ControlFlowModel): - """ - A controller contains logic for executing agents with context about the - larger workflow, including the flow itself, any tasks, and any other agents - they are collaborating with. The controller is responsible for orchestrating - agent behavior by generating instructions and tools for each agent. Note - that while the controller accepts details about (potentially multiple) - agents and tasks, it's responsiblity is to invoke one agent one time. Other - mechanisms should be used to orchestrate multiple agents invocations. This - is done by the controller to avoid tying e.g. 
agents to tasks or even a - specific flow. - """ - - # the flow is tracked by the Controller, not the Task, so that tasks can be - # defined and even instantiated outside a flow. When a Controller is - # created, we know we're inside a flow context and ready to load defaults - # and run. - flow: Flow = Field( - default_factory=get_flow, - description="The flow that the controller is a part of.", - validate_default=True, - ) - tasks: list[Task] = Field( - description="Tasks that the controller will complete.", - ) - agents: dict[Task, list[Agent]] = Field( - default_factory=dict, - description="Optionally assign agents to complete tasks. The provided mapping must be task" - " -> [agents]. Any tasks that aren't included will use their default agents.", - ) - context: dict = {} - model_config: dict = dict(extra="forbid") - enable_experimental_tui: bool = Field( - default_factory=lambda: controlflow.settings.enable_experimental_tui - ) - max_iterations: int = Field( - default_factory=lambda: controlflow.settings.max_iterations - ) - _iteration: int = 0 - _should_stop: bool = False - _end_turn_counts: dict = PrivateAttr(default_factory=lambda: defaultdict(int)) - - @property - def graph(self) -> Graph: - return Graph.from_tasks(self.flow.tasks.values()) - - @field_validator("agents", mode="before") - def _default_agents(cls, v): - if v is None: - v = {} - return v - - @model_validator(mode="after") - def _finalize(self): - for task in self.tasks: - self.flow.add_task(task) - return self - - def _create_end_turn_tool(self) -> Callable: - def end_turn(): - """ - This tool is for emergencies only; you should not use it normally. - If you find yourself in a situation where you are repeatedly invoked - and your normal tools do not work, or you can not escape the loop, - use this tool to signal to the controller that you are stuck. A new - agent will be selected to go next. If this tool is used 3 times by - an agent the workflow will be aborted automatically. 
- - """ - - # the agent's name is used as the key to track the number of times - key = getattr(ctx.get("agent", None), "name", None) - - self._end_turn_counts[key] += 1 - if self._end_turn_counts[key] >= 3: - self._should_stop = True - self._end_turn_counts[key] = 0 - - return ( - f"Ending turn. {3 - self._end_turn_counts[key]}" - " more uses will abort the workflow." - ) - - return end_turn - - @asynccontextmanager - async def tui(self): - if tui := ctx.get("tui"): - yield tui - elif self.enable_experimental_tui: - from controlflow.tui.app import TUIApp as TUI - - tui = TUI(flow=self.flow) - with ctx(tui=tui): - async with tui.run_context(): - yield tui - else: - yield - - def _setup_run(self): - """ - Generate the payload for a single run of the controller. - """ - if self._iteration >= (self.max_iterations or math.inf): - raise ValueError( - f"Controller has exceeded maximum iterations of {self.max_iterations}." - ) - ready_tasks = [t for t in self.tasks if t.is_ready()] - - # if there are no ready tasks, return. This will usually happen because - # all the tasks are complete. 
- if not ready_tasks: - return - - # start tracking tasks - for task in ready_tasks: - if not task._prefect_task.is_started: - task._prefect_task.start( - depends_on=[ - t.result for t in task.depends_on if t.result is not None - ] - ) - - messages = self.flow.get_messages() - - # get an agent from the next ready task - agents = self.agents.get(ready_tasks[0], None) - if agents is None: - agents = ready_tasks[0].get_agents() - if len(agents) == 1: - agent = agents[0] - else: - strategy_fn = ready_tasks[0].get_agent_strategy() - agent = strategy_fn(agents=agents, task=ready_tasks[0], flow=self.flow) - ready_tasks[0]._iteration += 1 - - from controlflow.controllers.instruction_template import MainTemplate - - tools = self.flow.tools + agent.get_tools() + [self._create_end_turn_tool()] - - # add tools for any ready tasks that the agent is assigned to - for task in ready_tasks: - if agent in self.agents.get(task, []) or agent in task.get_agents(): - tools.extend(task.get_tools()) - - instructions_template = MainTemplate( - agent=agent, - controller=self, - ready_tasks=ready_tasks, - current_task=ready_tasks[0], - context=self.context, - instructions=get_instructions(), - agent_assignments=self.agents, - ) - instructions = instructions_template.render() - - # prepare messages - system_message = SystemMessage(content=instructions) - - rules = agent.get_llm_rules() - - messages = prepare_messages( - agent=agent, - system_message=system_message, - messages=messages, - rules=rules, - tools=tools, - ) - - # setup handlers - handlers = [] - if self.enable_experimental_tui: - handlers.append(TUIHandler()) - elif controlflow.settings.enable_print_handler: - handlers.append(PrintHandler()) - # yield the agent payload - return dict( - agent=agent, - messages=messages, - tools=as_tools(tools), - handlers=handlers, - ) - - @prefect_task(task_run_name="Run LLM") - async def run_once_async(self) -> list[MessageType]: - async with self.tui(): - payload = self._setup_run() - if payload 
is None: - return - agent: Agent = payload.pop("agent") - response_handler = ResponseHandler() - payload["handlers"].append(response_handler) - - with ctx(agent=agent, flow=self.flow, controller=self): - response_gen = await completion_async( - messages=payload["messages"], - model=agent.get_model(), - tools=payload["tools"], - handlers=payload["handlers"], - max_iterations=1, - stream=True, - agent=agent, - ) - async for _ in response_gen: - pass - - # save history - self.flow.add_messages( - messages=response_handler.response_messages, - ) - self._iteration += 1 - - create_messages_markdown_artifact( - messages=response_handler.response_messages, - thread_id=self.flow.thread_id, - ) - - return response_handler.response_messages - - @prefect_task(task_run_name="Run LLM") - def run_once(self) -> list[MessageType]: - payload = self._setup_run() - if payload is None: - return - agent: Agent = payload.pop("agent") - response_handler = ResponseHandler() - payload["handlers"].append(response_handler) - - with ctx( - agent=agent, - flow=self.flow, - controller=self, - ): - response_gen = completion( - messages=payload["messages"], - model=agent.get_model(), - tools=payload["tools"], - handlers=payload["handlers"], - max_iterations=1, - stream=True, - agent=agent, - ) - for _ in response_gen: - pass - - # save history - self.flow.add_messages( - messages=response_handler.response_messages, - ) - self._iteration += 1 - - create_messages_markdown_artifact( - messages=response_handler.response_messages, - thread_id=self.flow.thread_id, - ) - - return response_handler.response_messages - - @prefect_task(task_run_name="Run LLM Controller") - async def run_async(self) -> list[MessageType]: - """ - Run the controller until all tasks are complete. 
- """ - if all(t.is_complete() for t in self.tasks): - return - - messages = [] - async with self.tui(): - # enter a flow context - with self.flow: - while ( - any(t.is_incomplete() for t in self.tasks) and not self._should_stop - ): - new_messages = await self.run_once_async() - if new_messages: - messages.extend(new_messages) - self._should_stop = False - return messages - - @prefect_task(task_run_name="Run LLM Controller") - def run(self) -> list[MessageType]: - """ - Run the controller until all tasks are complete. - """ - if all(t.is_complete() for t in self.tasks): - return - - messages = [] - # enter a flow context - with self.flow: - while any(t.is_incomplete() for t in self.tasks) and not self._should_stop: - new_messages = self.run_once() - if new_messages: - messages.extend(new_messages) - self._should_stop = False - return messages diff --git a/src/controlflow/controllers/instruction_template.py b/src/controlflow/controllers/instruction_template.py deleted file mode 100644 index 73a0ccfd..00000000 --- a/src/controlflow/controllers/instruction_template.py +++ /dev/null @@ -1,272 +0,0 @@ -import inspect - -from controlflow.agents import Agent -from controlflow.flows import Flow -from controlflow.tasks.task import Task -from controlflow.utilities.jinja import jinja_env -from controlflow.utilities.types import ControlFlowModel - -from .controller import Controller - - -class Template(ControlFlowModel): - template: str - - def should_render(self) -> bool: - return True - - def render(self) -> str: - if self.should_render(): - render_kwargs = dict(self) - render_kwargs.pop("template") - return jinja_env.from_string(inspect.cleandoc(self.template)).render( - **render_kwargs - ) - - -class AgentTemplate(Template): - template: str = """ - You are an AI agent participating in a workflow. Your role is to work on - your tasks and use the provided tools to complete those tasks and - communicate with the orchestrator. 
- - Important: The orchestrator is a Python script and cannot read or - respond to messages posted in this thread. You must use the provided - tools to communicate with the orchestrator. Posting messages in this - thread should only be used for thinking out loud, working through a - problem, or communicating with other agents. Any System messages or - messages prefixed with "SYSTEM:" are from the workflow system, not an - actual human. - - Your job is to: - 1. Select one or more tasks to work on from the ready tasks. - 2. Read the task instructions and work on completing the task objective, which may - involve using appropriate tools or collaborating with other agents - assigned to the same task. - 3. When you (and any other agents) have completed the task objective, - use the provided tool to inform the orchestrator of the task completion - and result. - 4. Repeat steps 1-3 until no more tasks are available for execution. - - Note that the orchestrator may decide to activate a different agent at any time. - - ## Your information - - - ID: {{ agent.id }} - - Name: "{{ agent.name }}" - {% if agent.description -%} - - Description: "{{ agent.description }}" - {% endif %} - - ## Instructions - - You must follow instructions at all times. Instructions can be added or removed at any time. 
- - - Never impersonate another agent - - {% if agent.instructions %} - {{ agent.instructions }} - {% endif %} - - {% if additional_instructions %} - {% for instruction in additional_instructions %} - - {{ instruction }} - {% endfor %} - {% endif %} - """ - agent: Agent - additional_instructions: list[str] - - -# class MemoryTemplate(Template): -# template: str = """ -# ## Memory - -# You have the following private memories: - -# {% for index, memory in memories %} -# - {{ index }}: {{ memory }} -# {% endfor %} - -# Use your memory to record information -# """ -# memories: dict[int, str] - - -class WorkflowTemplate(Template): - template: str = """ - - ## Tasks - - As soon as you have completed a task's objective, you must use the provided - tool to mark it successful and provide a result. It may take multiple - turns or collaboration with other agents to complete a task. Any agent - assigned to a task can complete it. Once a task is complete, no other - agent can interact with it. - - Tasks should only be marked failed due to technical errors like a broken - or erroring tool or unresponsive human. - - Tasks are not ready until all of their dependencies are met. Parent - tasks depend on all of their subtasks. - - ## Flow - - Name: {{ flow.name }} - {% if flow.description %} - Description: {{ flow.description }} - {% endif %} - {% if flow.context %} - Context: - {% for key, value in flow.context.items() %} - - {{ key }}: {{ value }} - {% endfor %} - {% endif %} - - ## Tasks - - ### Ready tasks - - These tasks are ready to be worked on because all of their dependencies have - been completed. You can only work on tasks to which you are assigned. 
- - {% for task in ready_tasks %} - #### Task {{ task.id }} - - objective: {{ task.objective }} - - instructions: {{ task.instructions}} - - context: {{ task.context }} - - result_type: {{ task.result_type }} - - depends_on: {{ task.depends_on }} - - parent: {{ task.parent }} - - assigned agents: {{ task.agents }} - {% if task.user_access %} - - user access: True - {% endif %} - - created_at: {{ task.created_at }} - - {% endfor %} - - ### Upstream tasks - - {% for task in upstream_tasks %} - #### Task {{ task.id }} - - objective: {{ task.objective }} - - instructions: {{ task.instructions}} - - status: {{ task.status }} - - result: {{ task.result }} - - error: {{ task.error }} - - context: {{ task.context }} - - depends_on: {{ task.depends_on }} - - parent: {{ task.parent }} - - assigned agents: {{ task.agents }} - {% if task.user_access %} - - user access: True - {% endif %} - - created_at: {{ task.created_at }} - - {% endfor %} - - ### Downstream tasks - - {% for task in downstream_tasks %} - #### Task {{ task.id }} - - objective: {{ task.objective }} - - instructions: {{ task.instructions}} - - status: {{ task.status }} - - result_type: {{ task.result_type }} - - context: {{ task.context }} - - depends_on: {{ task.depends_on }} - - parent: {{ task.parent }} - - assigned agents: {{ task.agents }} - {% if task.user_access %} - - user access: True - {% endif %} - - created_at: {{ task.created_at }} - - {% endfor %} - """ - - ready_tasks: list[dict] - upstream_tasks: list[dict] - downstream_tasks: list[dict] - current_task: Task - flow: Flow - - -class ToolTemplate(Template): - template: str = """ - You have access to various tools. They may change, so do not rely on history - to see what tools are available. - - ## Talking to human users - - If your task requires you to interact with a user, it will show - `user_access=True` and you will be given a `talk_to_user` tool. You can - use it to send messages to the user and optionally wait for a response. 
- This is how you tell the user things and ask questions. Do not mention - your tasks or the workflow. The user can only see messages you send - them via tool. They can not read the rest of the - thread. - - Human users may give poor, incorrect, or partial responses. You may need - to ask questions multiple times in order to complete your tasks. Do not - make up answers for omitted information; ask again and only fail the - task if you truly can not make progress. If your task requires human - interaction and neither it nor any assigned agents have `user_access`, - you can fail the task. - """ - - agent: Agent - - -class MainTemplate(ControlFlowModel): - agent: Agent - controller: Controller - ready_tasks: list[Task] - current_task: Task - context: dict - instructions: list[str] - agent_assignments: dict[Task, list[Agent]] - - def render(self): - # get up to 50 upstream and 50 downstream tasks - g = self.controller.graph - upstream_tasks = g.topological_sort([t for t in g.tasks if t.is_complete()])[ - -50: - ] - downstream_tasks = g.topological_sort( - [t for t in g.tasks if t.is_incomplete() and t not in self.ready_tasks] - )[:50] - - ready_tasks = [t.model_dump() for t in self.ready_tasks] - upstream_tasks = [t.model_dump() for t in upstream_tasks] - downstream_tasks = [t.model_dump() for t in downstream_tasks] - - # update agent assignments - assignments = {t.id: a for t, a in self.agent_assignments.items()} - for t in ready_tasks + upstream_tasks + downstream_tasks: - if t["id"] in assignments: - t["agents"] = assignments[t["id"]] - - templates = [ - AgentTemplate( - agent=self.agent, - additional_instructions=self.instructions, - ), - WorkflowTemplate( - flow=self.controller.flow, - ready_tasks=ready_tasks, - upstream_tasks=upstream_tasks, - downstream_tasks=downstream_tasks, - current_task=self.current_task, - ), - ToolTemplate(agent=self.agent), - # CommunicationTemplate( - # agent=self.agent, - # ), - ] - - rendered = [ - template.render() for template in 
templates if template.should_render() - ] - return "\n\n".join(rendered) diff --git a/src/controlflow/controllers/process_messages.py b/src/controlflow/controllers/process_messages.py deleted file mode 100644 index 83e42844..00000000 --- a/src/controlflow/controllers/process_messages.py +++ /dev/null @@ -1,127 +0,0 @@ -from typing import Optional, Union - -from controlflow.agents.agent import Agent -from controlflow.llm.messages import ( - AIMessage, - MessageType, - SystemMessage, - UserMessage, -) -from controlflow.llm.rules import LLMRules -from controlflow.tools import Tool - - -def create_system_message( - content: str, rules: LLMRules -) -> Union[SystemMessage, UserMessage]: - """ - Creates a SystemMessage or HumanMessage with SYSTEM: prefix, depending on the rules. - """ - if rules.system_message_must_be_first: - return SystemMessage(content=content) - else: - return UserMessage(content=f"SYSTEM: {content}") - - -def handle_agent_info_in_messages( - messages: list[MessageType], agent: Agent, rules: LLMRules -) -> list[MessageType]: - """ - If the message is from an agent, add a system message immediately before it to clarify which agent - it is from. This helps the system follow multi-agent conversations. - - """ - if not rules.add_system_messages_for_multi_agent: - return messages - - current_agent = agent - new_messages = [] - for msg in messages: - # if the message is from a different agent than the previous message, - # add a clarifying system message - if isinstance(msg, AIMessage) and msg.agent and msg.agent != current_agent: - system_msg = SystemMessage( - content=f'The following message is from agent "{msg.agent["name"]}" ' - f'with id {msg.agent["id"]}.' 
- ) - new_messages.append(system_msg) - current_agent = msg.agent - new_messages.append(msg) - return new_messages - - -def handle_system_messages_must_be_first(messages: list[MessageType], rules: LLMRules): - if rules.system_message_must_be_first: - new_messages = [] - # consolidate consecutive SystemMessages into one - if isinstance(messages[0], SystemMessage): - content = [messages[0].content] - i = 1 - while i < len(messages) and isinstance(messages[i], SystemMessage): - i += 1 - content.append(messages[i].content) - new_messages.append(SystemMessage(content="\n\n".join(content))) - - # replace all other SystemMessages with HumanMessages - for i, msg in enumerate(messages[len(new_messages) :]): - if isinstance(msg, SystemMessage): - msg = UserMessage(content=f"SYSTEM: {msg.content}") - new_messages.append(msg) - - return new_messages - else: - return messages - - -def handle_user_message_must_be_first_after_system( - messages: list[MessageType], rules: LLMRules -): - if rules.user_message_must_be_first_after_system: - if not messages: - messages.append(UserMessage(content="SYSTEM: Begin.")) - - # else get first non-system message - else: - i = 0 - while i < len(messages) and isinstance(messages[i], SystemMessage): - i += 1 - if i == len(messages) or ( - i < len(messages) and not isinstance(messages[i], UserMessage) - ): - messages.insert(i, UserMessage(content="SYSTEM: Begin.")) - return messages - - -def prepare_messages( - agent: Agent, - messages: list[MessageType], - system_message: Optional[SystemMessage], - rules: LLMRules, - tools: list[Tool], -): - """This is the main function for processing messages. 
It applies all the rules""" - messages = messages.copy() - - if system_message is not None: - messages.insert(0, system_message) - - messages = handle_agent_info_in_messages(messages, agent=agent, rules=rules) - - if not rules.allow_last_message_has_ai_role_with_tools: - if messages and tools and isinstance(messages[-1], AIMessage): - messages.append(create_system_message("Continue.", rules=rules)) - - if not rules.allow_consecutive_ai_messages: - if messages: - i = 1 - while i < len(messages): - if isinstance(messages[i], AIMessage) and isinstance( - messages[i - 1], AIMessage - ): - messages.insert(i, create_system_message("Continue.", rules=rules)) - i += 1 - - messages = handle_system_messages_must_be_first(messages, rules=rules) - messages = handle_user_message_must_be_first_after_system(messages, rules=rules) - - return messages diff --git a/src/controlflow/events/__init__.py b/src/controlflow/events/__init__.py new file mode 100644 index 00000000..23868b4f --- /dev/null +++ b/src/controlflow/events/__init__.py @@ -0,0 +1 @@ +from .events import Event diff --git a/src/controlflow/events/agent_events.py b/src/controlflow/events/agent_events.py new file mode 100644 index 00000000..e0cd3b6a --- /dev/null +++ b/src/controlflow/events/agent_events.py @@ -0,0 +1,118 @@ +from typing import Literal, Optional + +from pydantic import field_validator, model_validator + +from controlflow.agents.agent import Agent +from controlflow.events.events import Event, UnpersistedEvent +from controlflow.events.message_compiler import EventContext +from controlflow.llm.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, + SystemMessage, +) +from controlflow.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class SelectAgentEvent(Event): + event: Literal["select-agent"] = "select-agent" + agent: Agent + + # def to_messages(self, context: EventContext) -> list[BaseMessage]: + # return [ + # SystemMessage( + # content=f'Agent 
"{self.agent.name}" with ID {self.agent.id} was selected.' + # ) + # ] + + +class SystemMessageEvent(Event): + event: Literal["system-message"] = "system-message" + content: str + + def to_messages(self, context: EventContext) -> list[BaseMessage]: + return [SystemMessage(content=self.content)] + + +class UserMessageEvent(Event): + event: Literal["user-message"] = "user-message" + content: str + + def to_messages(self, context: EventContext) -> list[BaseMessage]: + return [HumanMessage(content=self.content)] + + +class AgentMessageEvent(Event): + event: Literal["agent-message"] = "agent-message" + agent: Agent + message: dict + + @field_validator("message", mode="before") + def _message(cls, v): + if isinstance(v, BaseMessage): + v = v.dict() + v["type"] = "ai" + return v + + @model_validator(mode="after") + def _finalize(self): + self.message["name"] = self.agent.name + + @property + def ai_message(self) -> AIMessage: + return AIMessage(**self.message) + + def to_messages(self, context: EventContext) -> list[BaseMessage]: + if self.agent.name == context.agent.name: + return [self.ai_message] + elif self.message["content"]: + return [ + SystemMessage( + content=f'The following message was posted by Agent "{self.agent.name}" ' + f"with ID {self.agent.id}:", + ), + HumanMessage( + # ensure this is stringified to avoid issues with inline tool calls + content=str(self.message["content"]), + name=self.agent.name, + ), + ] + else: + return [] + + +class AgentMessageDeltaEvent(UnpersistedEvent): + event: Literal["agent-message-delta"] = "agent-message-delta" + + agent: Agent + delta: dict + snapshot: dict + + @field_validator("delta", "snapshot", mode="before") + def _message(cls, v): + if isinstance(v, BaseMessage): + v = v.dict() + v["type"] = "AIMessageChunk" + return v + + @model_validator(mode="after") + def _finalize(self): + self.delta["name"] = self.agent.name + self.snapshot["name"] = self.agent.name + + @property + def delta_message(self) -> AIMessageChunk: + 
return AIMessageChunk(**self.delta) + + @property + def snapshot_message(self) -> AIMessage: + return AIMessage(**self.snapshot | {"type": "ai"}) + + +class EndTurnEvent(Event): + event: Literal["end-turn"] = "end-turn" + agent: Agent + next_agent_name: Optional[str] = None diff --git a/src/controlflow/events/controller_events.py b/src/controlflow/events/controller_events.py new file mode 100644 index 00000000..b69fc21d --- /dev/null +++ b/src/controlflow/events/controller_events.py @@ -0,0 +1,26 @@ +from typing import Literal + +from controlflow.events.events import UnpersistedEvent +from controlflow.orchestration.controller import Controller +from controlflow.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class ControllerStart(UnpersistedEvent): + event: Literal["controller-start"] = "controller-start" + persist: bool = False + controller: Controller + + +class ControllerEnd(UnpersistedEvent): + event: Literal["controller-end"] = "controller-end" + persist: bool = False + controller: Controller + + +class ControllerError(UnpersistedEvent): + event: Literal["controller-error"] = "controller-error" + persist: bool = False + controller: Controller + error: Exception diff --git a/src/controlflow/events/event_store.py b/src/controlflow/events/event_store.py new file mode 100644 index 00000000..3f0db393 --- /dev/null +++ b/src/controlflow/events/event_store.py @@ -0,0 +1,256 @@ +import abc +import json +import math +from functools import cache +from pathlib import Path +from typing import TYPE_CHECKING, Optional, Union + +from pydantic import Field, TypeAdapter, field_validator + +import controlflow +from controlflow.events.agent_events import ( + AgentMessageEvent, + EndTurnEvent, + SelectAgentEvent, + SystemMessageEvent, + UserMessageEvent, +) +from controlflow.events.events import Event +from controlflow.events.task_events import TaskCompleteEvent, TaskReadyEvent +from controlflow.events.tool_events import ToolResultEvent +from 
controlflow.utilities.types import ControlFlowModel + +if TYPE_CHECKING: + pass + +# This is a global variable that will be shared between all instances of InMemoryStore +IN_MEMORY_STORE = {} + + +@cache +def get_event_validator() -> TypeAdapter: + types = Union[ + TaskReadyEvent, + TaskCompleteEvent, + SelectAgentEvent, + SystemMessageEvent, + UserMessageEvent, + AgentMessageEvent, + EndTurnEvent, + ToolResultEvent, + Event, + ] + return TypeAdapter(list[types]) + + +def filter_events( + events: list[Event], + agent_ids: Optional[list[str]] = None, + task_ids: Optional[list[str]] = None, + types: Optional[list[str]] = None, + before_id: Optional[str] = None, + after_id: Optional[str] = None, + limit: Optional[int] = None, +): + """ + Filters a list of events based on the specified criteria. + + Args: + events (list[Event]): The list of events to filter. + agent_ids (Optional[list[str]]): The agent ids to filter by. Defaults to None. + task_ids (Optional[list[str]]): The task ids to filter by. Defaults to None. + types (Optional[list[str]]): The event types to filter by. Defaults to None. + before_id (Optional[str]): The ID of the event before which to start including events. Defaults to None. + after_id (Optional[str]): The ID of the event after which to stop including events. Defaults to None. + limit (Optional[int]): The maximum number of events to include. Defaults to None. + + Returns: + list[Event]: The filtered list of events. 
+ """ + new_events = [] + seen_before_id = True if not before_id else False + seen_after_id = False if not after_id else True + + for event in reversed(events): + if event.id == before_id: + seen_before_id = True + if event.id == after_id: + seen_after_id = True + + # if we haven't reached the `before_id` we can skip this event + if not seen_before_id: + continue + + # if we've reached the `after_id` we can stop searching + if seen_after_id: + break + + # if types are specified and this event is not one of them, skip it + if types and event.event not in types: + continue + + # if agent_ids are specified and this event has agent_ids and none of them are in the list, skip it + agent_match = ( + ( + agent_ids + and event.agent_ids + and any(a in event.agent_ids for a in agent_ids) + ) + or not agent_ids + or not event.agent_ids + ) + + # if task_ids are specified and this event has task_ids and none of them are in the list, skip it + task_match = ( + (task_ids and event.task_ids and any(t in event.task_ids for t in task_ids)) + or not task_ids + or not event.task_ids + ) + + # if neither agent_ids nor task_ids were matched + if not (agent_match or task_match): + continue + + new_events.append(event) + + if len(new_events) >= (limit or math.inf): + break + + return list(reversed(new_events)) + + +def get_default_event_store() -> "EventStore": + return controlflow.default_event_store + + +class EventStore(ControlFlowModel, abc.ABC): + @abc.abstractmethod + def get_events( + self, + thread_id: str, + types: Optional[list[str]] = None, + agent_ids: Optional[list[str]] = None, + task_ids: Optional[list[str]] = None, + before_id: Optional[str] = None, + after_id: Optional[str] = None, + limit: Optional[int] = None, + ) -> list[Event]: + raise NotImplementedError() + + @abc.abstractmethod + def add_events(self, thread_id: str, events: list[Event]): + raise NotImplementedError() + + +class InMemoryStore(EventStore): + store: dict[str, list[Event]] = 
Field(default_factory=lambda: IN_MEMORY_STORE) + + def add_events(self, thread_id: str, events: list[Event]): + self.store.setdefault(thread_id, []).extend(events) + + def get_events( + self, + thread_id: str, + types: Optional[list[str]] = None, + agent_ids: Optional[list[str]] = None, + task_ids: Optional[list[str]] = None, + before_id: Optional[str] = None, + after_id: Optional[str] = None, + limit: Optional[int] = None, + ) -> list[Event]: + """ + Retrieve a list of events based on the specified criteria. + + Args: + thread_id (str): The ID of the thread to retrieve events from. + agent_ids (Optional[list[str]]): The agent associated with the events (default: None). + task_ids (Optional[list[str]]): The list of tasks associated with the events (default: None). + types (Optional[list[str]]): The list of event types to filter by (default: None). + before_id (Optional[str]): The ID of the event before which to start retrieving events (default: None). + after_id (Optional[str]): The ID of the event after which to stop retrieving events (default: None). + limit (Optional[int]): The maximum number of events to retrieve (default: None). + + Returns: + list[Event]: A list of events that match the specified criteria. 
+ + """ + events = self.store.get(thread_id, []) + return filter_events( + events=events, + agent_ids=agent_ids, + task_ids=task_ids, + types=types, + before_id=before_id, + after_id=after_id, + limit=limit, + ) + + +class FileStore(EventStore): + base_path: Path = Field( + default_factory=lambda: controlflow.settings.home_path / "filestore_events" + ) + + def path(self, thread_id: str) -> Path: + return self.base_path / f"{thread_id}.json" + + @field_validator("base_path", mode="before") + def _validate_path(cls, v): + v = Path(v).expanduser() + if not v.exists(): + v.mkdir(parents=True, exist_ok=True) + return v + + def get_events( + self, + thread_id: str, + agent_ids: Optional[list[str]] = None, + task_ids: Optional[list[str]] = None, + types: Optional[list[str]] = None, + before_id: Optional[str] = None, + after_id: Optional[str] = None, + limit: Optional[int] = None, + ) -> list[Event]: + """ + Retrieves a list of events based on the specified criteria. + + Args: + thread_id (str): The ID of the thread to retrieve events from. + agent_ids (Optional[list[str]]): The agent associated with the events (default: None). + task_ids (Optional[list[str]]): The list of tasks associated with the events (default: None). + types (Optional[list[str]]): The list of event types to filter by (default: None). + before_id (Optional[str]): The ID of the event before which to stop retrieving events (default: None). + after_id (Optional[str]): The ID of the event after which to start retrieving events (default: None). + limit (Optional[int]): The maximum number of events to retrieve (default: None). + + Returns: + list[Event]: A list of events that match the specified criteria. 
+ """ + if not self.path(thread_id).exists(): + return [] + + with open(self.path(thread_id), "r") as f: + raw_data = f.read() + + validator = get_event_validator() + events = validator.validate_json(raw_data) + + return filter_events( + events=events, + agent_ids=agent_ids, + task_ids=task_ids, + types=types, + before_id=before_id, + after_id=after_id, + limit=limit, + ) + + def add_events(self, thread_id: str, events: list[Event]): + if self.path(thread_id).exists(): + with open(self.path(thread_id), "r") as f: + all_events = json.load(f) + else: + all_events = [] + all_events.extend([event.model_dump(mode="json") for event in events]) + with open(self.path(thread_id), "w") as f: + json.dump(all_events, f) diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py new file mode 100644 index 00000000..e839b73c --- /dev/null +++ b/src/controlflow/events/events.py @@ -0,0 +1,36 @@ +import datetime +import uuid +from typing import TYPE_CHECKING, Optional + +from pydantic import Field + +from controlflow.utilities.types import ControlFlowModel + +if TYPE_CHECKING: + from controlflow.events.message_compiler import EventContext + from controlflow.llm.messages import BaseMessage + +# This is a global variable that will be shared between all instances of InMemoryStore +IN_MEMORY_STORE = {} + + +class Event(ControlFlowModel): + model_config: dict = dict(extra="allow") + + event: str + id: str = Field(default_factory=lambda: uuid.uuid4().hex) + thread_id: Optional[str] = None + agent_ids: list[str] = [] + task_ids: list[str] = [] + timestamp: datetime.datetime = Field( + default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) + ) + persist: bool = True + + def to_messages(self, context: "EventContext") -> list["BaseMessage"]: + return [] + + +class UnpersistedEvent(Event): + model_config = dict(arbitrary_types_allowed=True) + persist: bool = False diff --git a/src/controlflow/events/message_compiler.py 
b/src/controlflow/events/message_compiler.py new file mode 100644 index 00000000..c6988ff8 --- /dev/null +++ b/src/controlflow/events/message_compiler.py @@ -0,0 +1,222 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import tiktoken + +import controlflow +from controlflow.events.events import Event +from controlflow.llm.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from controlflow.llm.rules import LLMRules +from controlflow.utilities.logging import get_logger + +if TYPE_CHECKING: + from controlflow.agents.agent import Agent + from controlflow.flows.flow import Flow + from controlflow.orchestration.controller import Controller + from controlflow.tasks.task import Task +logger = get_logger(__name__) + + +def add_user_message_to_beginning( + messages: list[BaseMessage], rules: LLMRules +) -> list[BaseMessage]: + """ + If the LLM requires the user message to be the first message, add a user + message to the beginning of the list. + """ + if rules.require_user_message_after_system: + if not messages or not isinstance(messages[0], HumanMessage): + messages.insert(0, HumanMessage(content="SYSTEM: Begin.")) + return messages + + +def ensure_at_least_one_message( + messages: list[BaseMessage], rules: LLMRules +) -> list[BaseMessage]: + if not messages and rules.require_at_least_one_message: + messages.append(HumanMessage(content="SYSTEM: Begin.")) + return messages + + +def add_user_message_to_end( + messages: list[BaseMessage], rules: LLMRules +) -> list[BaseMessage]: + """ + If the LLM doesn't allow the last message to be from the AI when using tools, + add a user message to the end of the list. 
+ """ + if not rules.allow_last_message_from_ai_when_using_tools: + if not messages or isinstance(messages[-1], AIMessage): + msg = HumanMessage(content="SYSTEM: Continue.") + messages.append(msg) + return messages + + +def remove_duplicate_messages(messages: list[BaseMessage]) -> list[BaseMessage]: + """ + Removes duplicate messages from the list. + """ + seen = set() + new_messages = [] + for message in messages: + if message.id not in seen: + new_messages.append(message) + if message.id: + seen.add(message.id) + return new_messages + + +def break_up_consecutive_ai_messages( + messages: list[BaseMessage], rules: LLMRules +) -> list[BaseMessage]: + """ + Breaks up consecutive AI messages by inserting a system message. + """ + if not messages or rules.allow_consecutive_ai_messages: + return messages + + new_messages = messages.copy() + i = 1 + while i < len(new_messages): + if isinstance(new_messages[i], AIMessage) and isinstance( + new_messages[i - 1], AIMessage + ): + new_messages.insert(i, SystemMessage(content="Continue.")) + i += 1 + + return new_messages + + +def convert_system_messages( + messages: list[BaseMessage], rules: LLMRules +) -> list[BaseMessage]: + """ + Converts system messages to human messages if the LLM doesnt support system messages. 
+ """ + if not messages or not rules.require_system_message_first: + return messages + + new_messages = [] + for message in messages: + if isinstance(message, SystemMessage): + new_messages.append(HumanMessage(content=f"SYSTEM: {message.content}")) + else: + new_messages.append(message) + return new_messages + + +def organize_tool_result_messages( + messages: list[BaseMessage], rules: LLMRules +) -> list[BaseMessage]: + if not messages or not rules.tool_result_must_follow_tool_call: + return messages + + tool_calls = {} + new_messages = [] + i = 0 + while i < len(messages): + message = messages[i] + # save the message index of any tool calls + if isinstance(message, AIMessage): + for tool_call in message.tool_calls + message.invalid_tool_calls: + tool_calls[tool_call["id"]] = i + new_messages.append(message) + + # move tool messages to follow their corresponding tool calls + elif isinstance(message, ToolMessage) and tool_call["id"] in tool_calls: + tool_call_index = tool_calls[tool_call["id"]] + new_messages.insert(tool_call_index + 1, message) + tool_calls[tool_call["id"]] += 1 + + else: + new_messages.append(message) + i += 1 + return new_messages + + +def count_tokens(message: BaseMessage) -> int: + # always use gpt-3.5 token counter with the entire message object; we only need to be approximate here + return len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(message.json())) + + +def trim_messages( + messages: list[BaseMessage], max_tokens: Optional[int] +) -> list[BaseMessage]: + """ + Trims messages to a maximum number of tokens while keeping the system message at the front. 
+ """ + + if not messages or max_tokens is None: + return messages + + new_messages = [] + budget = max_tokens + + for message in reversed(messages): + if count_tokens(message) > budget: + break + new_messages.append(message) + budget -= count_tokens(message) + + return list(reversed(new_messages)) + + +@dataclass +class EventContext: + llm_rules: LLMRules + agent: Optional["Agent"] + ready_tasks: list["Task"] + flow: Optional["Flow"] + controller: Optional["Controller"] + + +class MessageCompiler: + def __init__( + self, + events: list[Event], + context: EventContext, + system_prompt: str = None, + max_tokens: int = None, + ): + self.events = events + self.context = context + self.system_prompt = system_prompt + self.max_tokens = max_tokens or controlflow.settings.max_input_tokens + + def compile_to_messages(self) -> list[BaseMessage]: + if self.system_prompt: + system = [SystemMessage(content=self.system_prompt)] + max_tokens = self.max_tokens - count_tokens(system[0]) + else: + system = [] + max_tokens = self.max_tokens + + messages = [] + + for event in self.events: + messages.extend(event.to_messages(self.context)) + + # process messages + msgs = messages.copy() + + # trim messages + msgs = trim_messages(msgs, max_tokens=max_tokens) + + # apply LLM rules + msgs = ensure_at_least_one_message(msgs, rules=self.context.llm_rules) + msgs = add_user_message_to_beginning(msgs, rules=self.context.llm_rules) + msgs = add_user_message_to_end(msgs, rules=self.context.llm_rules) + msgs = remove_duplicate_messages(msgs) + msgs = organize_tool_result_messages(msgs, rules=self.context.llm_rules) + msgs = break_up_consecutive_ai_messages(msgs, rules=self.context.llm_rules) + + # this should go last + msgs = convert_system_messages(msgs, rules=self.context.llm_rules) + + return system + msgs diff --git a/src/controlflow/events/task_events.py b/src/controlflow/events/task_events.py new file mode 100644 index 00000000..7f3480ce --- /dev/null +++ 
b/src/controlflow/events/task_events.py @@ -0,0 +1,24 @@ +from typing import Literal + +from controlflow.events.events import Event, UnpersistedEvent +from controlflow.tasks.task import Task +from controlflow.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class TaskReadyEvent(UnpersistedEvent): + event: Literal["task-ready"] = "task-ready" + task: Task + + +class TaskCompleteEvent(Event): + event: Literal["task-complete"] = "task-complete" + task: Task + + # def to_messages(self, context: EventContext) -> list[BaseMessage]: + # return [ + # SystemMessage( + # content=f"Task {self.task.id} is complete with status: {self.task.status}" + # ) + # ] diff --git a/src/controlflow/events/tool_events.py b/src/controlflow/events/tool_events.py new file mode 100644 index 00000000..d27ce4fe --- /dev/null +++ b/src/controlflow/events/tool_events.py @@ -0,0 +1,62 @@ +from typing import Literal + +from pydantic import field_validator, model_validator + +from controlflow.agents.agent import Agent +from controlflow.events.events import Event +from controlflow.events.message_compiler import EventContext +from controlflow.llm.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage +from controlflow.tools.tools import ToolCall, ToolResult +from controlflow.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class ToolCallEvent(Event): + event: Literal["tool-call"] = "tool-call" + agent: Agent + tool_call: ToolCall + message: dict + + @field_validator("message", mode="before") + def _message(cls, v): + if isinstance(v, AIMessage): + v = v.dict() + v["type"] = "ai" + return v + + @model_validator(mode="after") + def _finalize(self): + self.message["name"] = self.agent.name + + @property + def ai_message(self) -> AIMessage: + return AIMessage(**self.message) + + +class ToolResultEvent(Event): + event: Literal["tool-result"] = "tool-result" + agent: Agent + tool_call: ToolCall + tool_result: ToolResult + + def to_messages(self, 
context: EventContext) -> list[BaseMessage]: + if self.agent.name == context.agent.name: + return [ + ToolMessage( + content=self.tool_result.str_result, + tool_call_id=self.tool_call["id"], + name=self.agent.name, + ) + ] + elif not self.tool_result.is_private: + return [ + SystemMessage( + content=f'The following {"failed" if self.tool_result.is_error else "successful"} ' + f'tool result was received by "{self.agent.name}" with ID {self.agent.id}:' + ), + SystemMessage(content=self.tool_result.str_result), + ] + + else: + return [] diff --git a/src/controlflow/flows/__init__.py b/src/controlflow/flows/__init__.py index d87c29d0..d1f717b7 100644 --- a/src/controlflow/flows/__init__.py +++ b/src/controlflow/flows/__init__.py @@ -1 +1 @@ -from .flow import Flow, get_flow, get_flow_messages +from .flow import Flow, get_flow diff --git a/src/controlflow/flows/flow.py b/src/controlflow/flows/flow.py index b64e89f8..8041e453 100644 --- a/src/controlflow/flows/flow.py +++ b/src/controlflow/flows/flow.py @@ -1,4 +1,3 @@ -import datetime import uuid from contextlib import contextmanager, nullcontext from typing import Any, Callable, Optional, Union @@ -6,8 +5,9 @@ from pydantic import Field from controlflow.agents import Agent -from controlflow.flows.history import History, get_default_history -from controlflow.llm.messages import MessageType +from controlflow.events.event_store import EventStore, get_default_event_store +from controlflow.events.events import Event +from controlflow.flows.graph import Graph from controlflow.tasks.task import Task from controlflow.utilities.context import ctx from controlflow.utilities.logging import get_logger @@ -18,10 +18,10 @@ class Flow(ControlFlowModel): + thread_id: str = Field(default_factory=lambda: uuid.uuid4().hex) name: Optional[str] = None description: Optional[str] = None - thread_id: str = Field(default_factory=lambda: uuid.uuid4().hex) - history: History = Field(default_factory=get_default_history) + event_store: EventStore 
= Field(default_factory=get_default_event_store) tools: list[Callable] = Field( default_factory=list, description="Tools that will be available to every agent in the flow", @@ -32,20 +32,20 @@ class Flow(ControlFlowModel): default_factory=list, ) context: dict[str, Any] = {} - tasks: dict[str, Task] = {} + graph: Graph = Field(default_factory=Graph) _cm_stack: list[contextmanager] = [] def __init__(self, *, copy_parent: bool = True, **kwargs): """ - By default, the flow will copy the history from the parent flow if one + By default, the flow will copy the event history from the parent flow if one exists, including all completed tasks. Because each flow is a new - thread, new messages will not be shared between the parent and child + thread, new events will not be shared between the parent and child flow. """ super().__init__(**kwargs) parent = get_flow() if parent and copy_parent: - self.add_messages(parent.get_messages()) + self.add_events(parent.get_events()) for task in parent.tasks.values(): if task.is_complete(): self.add_task(task) @@ -60,25 +60,34 @@ def __exit__(self, *exc_info): # exit the context manager return self._cm_stack.pop().__exit__(*exc_info) - def get_messages( + def add_task(self, task: Task): + self.graph.add_task(task) + + @property + def tasks(self) -> list[Task]: + return self.graph.topological_sort() + + def get_events( self, - limit: int = None, - before: datetime.datetime = None, - after: datetime.datetime = None, - ) -> list[MessageType]: - return self.history.load_messages( - thread_id=self.thread_id, limit=limit, before=before, after=after + agent_ids: Optional[list[str]] = None, + task_ids: Optional[list[str]] = None, + before_id: Optional[str] = None, + after_id: Optional[str] = None, + limit: Optional[int] = None, + types: Optional[list[str]] = None, + ) -> list[Event]: + return self.event_store.get_events( + thread_id=self.thread_id, + agent_ids=agent_ids, + task_ids=task_ids, + before_id=before_id, + after_id=after_id, + 
limit=limit, + types=types, ) - def add_messages(self, messages: list[MessageType]): - self.history.save_messages(thread_id=self.thread_id, messages=messages) - - def add_task(self, task: Task): - if self.tasks.get(task.id, task) is not task: - raise ValueError( - f"A different task with id '{task.id}' already exists in flow." - ) - self.tasks[task.id] = task + def add_events(self, events: list[Event]): + self.event_store.add_events(thread_id=self.thread_id, events=events) @contextmanager def create_context(self, create_prefect_flow_context: bool = True): @@ -92,57 +101,23 @@ def create_context(self, create_prefect_flow_context: bool = True): with ctx(**ctx_args), prefect_ctx: yield self - async def run_once_async(self): - """ - Runs one step of the flow asynchronously. - """ - if self.tasks: - from controlflow.controllers import Controller - - controller = Controller( - flow=self, - tasks=list(self.tasks.values()), - ) - await controller.run_once_async() - def run_once(self): """ Runs one step of the flow. """ - if self.tasks: - from controlflow.controllers import Controller - - controller = Controller( - flow=self, - tasks=list(self.tasks.values()), - ) - controller.run_once() - - async def run_async(self): - """ - Runs the flow asynchronously. - """ - if self.tasks: - from controlflow.controllers import Controller + from controlflow.orchestration import Controller - controller = Controller( - flow=self, - tasks=list(self.tasks.values()), - ) - await controller.run_async() + controller = Controller(flow=self) + controller.run_once() def run(self): """ Runs the flow. 
""" - if self.tasks: - from controlflow.controllers import Controller + from controlflow.orchestration import Controller - controller = Controller( - flow=self, - tasks=list(self.tasks.values()), - ) - controller.run() + controller = Controller(flow=self) + controller.run() def get_flow() -> Optional[Flow]: @@ -153,16 +128,14 @@ def get_flow() -> Optional[Flow]: return flow -def get_flow_messages(limit: int = None) -> list[MessageType]: +def get_flow_events(limit: int = None) -> list[Event]: """ - Loads messages from the flow's thread. - - Will error if no flow is found in the context. + Loads events from the active flow's thread. """ if limit is None: limit = 50 flow = get_flow() if flow: - return get_default_history().load_messages(flow.thread_id, limit=limit) + return flow.get_events(limit=limit) else: return [] diff --git a/src/controlflow/controllers/graph.py b/src/controlflow/flows/graph.py similarity index 95% rename from src/controlflow/controllers/graph.py rename to src/controlflow/flows/graph.py index bef4b14b..384b9323 100644 --- a/src/controlflow/controllers/graph.py +++ b/src/controlflow/flows/graph.py @@ -129,7 +129,7 @@ def upstream_tasks( f"upstream_{'immediate' if immediate else 'all'}_{tuple(start_tasks)}" ) if cache_key not in self._cache: - result = set() + result = set(start_tasks) visited = set() def _upstream(task): @@ -137,10 +137,7 @@ def _upstream(task): return visited.add(task) for edge in self.upstream_edges().get(task, []): - if ( - edge.upstream not in visited - and edge.upstream not in start_tasks - ): + if edge.upstream not in visited: result.add(edge.upstream) if not immediate: _upstream(edge.upstream) @@ -172,7 +169,7 @@ def downstream_tasks( f"downstream_{'immediate' if immediate else 'all'}_{tuple(start_tasks)}" ) if cache_key not in self._cache: - result = set() + result = set(start_tasks) visited = set() def _downstream(task): @@ -180,10 +177,7 @@ def _downstream(task): return visited.add(task) for edge in 
self.downstream_edges().get(task, []): - if ( - edge.downstream not in visited - and edge.downstream not in start_tasks - ): + if edge.downstream not in visited: result.add(edge.downstream) if not immediate: _downstream(edge.downstream) diff --git a/src/controlflow/flows/history.py b/src/controlflow/flows/history.py index cf519dbe..ebc81afe 100644 --- a/src/controlflow/flows/history.py +++ b/src/controlflow/flows/history.py @@ -7,7 +7,7 @@ from pydantic import Field, field_validator import controlflow -from controlflow.llm.messages import MessageType +from controlflow.llm.messages import BaseMessage from controlflow.utilities.types import ControlFlowModel # This is a global variable that will be shared between all instances of InMemoryHistory @@ -26,16 +26,16 @@ def load_messages( limit: int = None, before: datetime.datetime = None, after: datetime.datetime = None, - ) -> list[MessageType]: + ) -> list[BaseMessage]: raise NotImplementedError() @abc.abstractmethod - def save_messages(self, thread_id: str, messages: list[MessageType]): + def save_messages(self, thread_id: str, messages: list[BaseMessage]): raise NotImplementedError() class InMemoryHistory(History): - history: dict[str, list[MessageType]] = Field( + history: dict[str, list[BaseMessage]] = Field( default_factory=lambda: IN_MEMORY_HISTORY ) @@ -45,7 +45,7 @@ def load_messages( limit: int = None, before: datetime.datetime = None, after: datetime.datetime = None, - ) -> list[MessageType]: + ) -> list[BaseMessage]: messages = self.history.get(thread_id, []) filtered_messages = [ msg @@ -56,7 +56,7 @@ def load_messages( ] return list(reversed(filtered_messages)) - def save_messages(self, thread_id: str, messages: list[MessageType]): + def save_messages(self, thread_id: str, messages: list[BaseMessage]): self.history.setdefault(thread_id, []).extend(messages) @@ -81,7 +81,7 @@ def load_messages( limit: int = None, before: datetime.datetime = None, after: datetime.datetime = None, - ) -> list[MessageType]: + 
) -> list[BaseMessage]: if not self.path(thread_id).exists(): return [] @@ -90,7 +90,7 @@ def load_messages( messages = [] for msg in reversed(all_messages): - message = MessageType.model_validate(msg) + message = BaseMessage.model_validate(msg) if before is None or message.timestamp < before: if after is None or message.timestamp > after: messages.append(message) @@ -99,7 +99,7 @@ def load_messages( return list(reversed(messages)) - def save_messages(self, thread_id: str, messages: list[MessageType]): + def save_messages(self, thread_id: str, messages: list[BaseMessage]): if self.path(thread_id).exists(): with open(self.path(thread_id), "r") as f: all_messages = json.load(f) diff --git a/src/controlflow/handlers/__init__.py b/src/controlflow/handlers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/controlflow/handlers/print_handler.py b/src/controlflow/handlers/print_handler.py deleted file mode 100644 index 9130e73d..00000000 --- a/src/controlflow/handlers/print_handler.py +++ /dev/null @@ -1,186 +0,0 @@ -import datetime - -import rich -from rich import box -from rich.console import Group -from rich.live import Live -from rich.markdown import Markdown -from rich.panel import Panel -from rich.spinner import Spinner -from rich.table import Table - -import controlflow -from controlflow.llm.handlers import CompletionHandler -from controlflow.llm.messages import ( - AIMessage, - AIMessageChunk, - MessageType, - SystemMessage, - ToolCall, - ToolMessage, - UserMessage, -) -from controlflow.utilities.rich import console as cf_console - - -class PrintHandler(CompletionHandler): - def __init__(self): - self.messages: dict[str, MessageType] = {} - self.live: Live = Live(auto_refresh=True, console=cf_console) - self.paused_id: str = None - super().__init__() - - def on_start(self): - try: - self.live.start() - except rich.errors.LiveError: - pass - - def on_end(self): - self.live.stop() - - def on_exception(self, exc: Exception): - 
self.live.stop() - - def update_live(self, latest: MessageType = None): - # sort by timestamp, using the custom message id as a tiebreaker - # in case the same message appears twice (e.g. tool call and message) - messages = sorted(self.messages.items(), key=lambda m: (m[1].timestamp, m[0])) - content = [] - - tool_results = {} # To track tool results by their call ID - - # gather all tool messages first - for _, message in messages: - if isinstance(message, ToolMessage): - tool_results[message.tool_call_id] = message - - for _, message in messages: - if isinstance(message, (SystemMessage, UserMessage, AIMessage)): - content.append(format_message(message, tool_results=tool_results)) - # no need to handle tool messages - - if self.live.is_started: - self.live.update(Group(*content), refresh=True) - elif latest: - cf_console.print(format_message(latest)) - - def on_message_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): - self.messages[snapshot.id] = snapshot - self.update_live() - - def on_message_done(self, message: AIMessage): - self.messages[message.id] = message - self.update_live(latest=message) - - def on_tool_call_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): - self.messages[snapshot.id] = snapshot - self.update_live() - - def on_tool_call_done(self, message: AIMessage): - self.messages[message.id] = message - self.update_live(latest=message) - - def on_tool_result_created(self, message: AIMessage, tool_call: ToolCall): - # if collecting input on the terminal, pause the live display - # to avoid overwriting the input prompt - if tool_call["name"] == "talk_to_user": - self.paused_id = tool_call["id"] - self.live.stop() - self.messages.clear() - - def on_tool_result_done(self, message: ToolMessage): - self.messages[f"tool-result:{message.tool_call_id}"] = message - - # if we were paused, resume the live display - if self.paused_id and self.paused_id == message.tool_call_id: - self.paused_id = None - # print newline to avoid odd 
formatting issues - print() - self.live = Live(auto_refresh=False) - self.live.start() - self.update_live(latest=message) - - -def format_timestamp(timestamp: datetime.datetime) -> str: - local_timestamp = timestamp.astimezone() - return local_timestamp.strftime("%I:%M:%S %p").lstrip("0").rjust(11) - - -def status(icon, text) -> Table: - t = Table.grid(padding=1) - t.add_row(icon, text) - return t - - -ROLE_COLORS = { - "system": "gray", - "ai": "blue", - "user": "green", -} -ROLE_NAMES = { - "system": "System", - "ai": "Agent", - "user": "User", -} - - -def format_message(message: MessageType, tool_results: dict = None) -> Panel: - if message.role == "ai" and message.name: - title = f"Agent: {message.name}" - else: - title = ROLE_NAMES.get(message.role, "Agent") - - content = [] - if message.str_content: - content.append(Markdown(message.str_content or "")) - - tool_content = [] - for tool_call in getattr(message, "tool_calls", []): - tool_result = (tool_results or {}).get(tool_call["id"]) - if tool_result: - c = format_tool_result(tool_result) - - else: - c = format_tool_call(tool_call) - if c: - tool_content.append(c) - - if content and tool_content: - content.append("\n") - - return Panel( - Group(*content, *tool_content), - title=f"[bold]{title}[/]", - subtitle=f"[italic]{format_timestamp(message.timestamp)}[/]", - title_align="left", - subtitle_align="right", - border_style=ROLE_COLORS.get(message.role, "red"), - box=box.ROUNDED, - width=100, - expand=True, - padding=(1, 2), - ) - - -def format_tool_call(tool_call: ToolCall) -> Panel: - name = tool_call["name"] - args = tool_call["args"] - if controlflow.settings.tools_verbose: - return status(Spinner("dots"), f'Tool call: "{name}"\n\nTool args: {args}') - return status(Spinner("dots"), f'Tool call: "{name}"') - - -def format_tool_result(message: ToolMessage) -> Panel: - name = message.tool_call["name"] - - if message.is_error: - icon = ":x:" - else: - icon = ":white_check_mark:" - - if 
controlflow.settings.tools_verbose: - msg = f'Tool call: "{name}"\n\nTool result: {message.str_content}' - else: - msg = f'Tool call: "{name}"' - return status(icon, msg) diff --git a/src/controlflow/llm/__init__.py b/src/controlflow/llm/__init__.py index b1dda89d..8383c810 100644 --- a/src/controlflow/llm/__init__.py +++ b/src/controlflow/llm/__init__.py @@ -1 +1 @@ -from controlflow.llm import models, messages, tools, handlers, completions, rules +from controlflow.llm import models, messages, rules diff --git a/src/controlflow/llm/classify.py b/src/controlflow/llm/classify.py deleted file mode 100644 index 7458e77e..00000000 --- a/src/controlflow/llm/classify.py +++ /dev/null @@ -1,104 +0,0 @@ -from typing import Union - -import tiktoken -from langchain_openai import AzureChatOpenAI, ChatOpenAI -from pydantic import TypeAdapter - -import controlflow -from controlflow.llm.messages import AIMessage, SystemMessage, UserMessage -from controlflow.llm.models import BaseChatModel - - -def classify( - data: str, - labels: list, - instructions: str = None, - context: dict = None, - model: BaseChatModel = None, -): - try: - label_strings = [TypeAdapter(type(t)).dump_json(t).decode() for t in labels] - except Exception as exc: - raise ValueError(f"Unable to cast labels to strings: {exc}") - - messages = [ - SystemMessage( - """ - You are an expert classifier that always maintains as much semantic meaning - as possible when labeling information. You use inference or deduction whenever - necessary to understand missing or omitted data. Classify the provided data, - text, or information as one of the provided labels. For boolean labels, - consider "truthy" or affirmative inputs to be "true". - - ## Labels - - You must classify the data as one of the following labels, which are - numbered (starting from 0) and provide a brief description. Output - the label number only. 
- - {% for label in labels %} - - Label {{ loop.index0 }}: {{ label }} - {% endfor %} - """ - ).render(labels=label_strings), - UserMessage( - """ - ## Information to classify - - {{ data }} - - {% if instructions -%} - ## Additional instructions - - {{ instructions }} - {% endif %} - - {% if context -%} - ## Additional context - - {% for key, value in context.items() -%} - - {{ key }}: {{ value }} - - {% endfor %} - {% endif %} - - """ - ).render(data=data, instructions=instructions, context=context), - AIMessage(""" - The best label for the data is Label number - """), - ] - - model = model or controlflow.llm.models.get_default_model() - - kwargs = {} - if isinstance(model, (ChatOpenAI, AzureChatOpenAI)): - openai_kwargs = _openai_kwargs(model=model, n_labels=len(labels)) - kwargs.update(openai_kwargs) - else: - messages.append( - SystemMessage( - "Return only the label number, no other information or tokens." - ) - ) - - result = controlflow.llm.completions.completion( - messages=messages, - model=model, - max_tokens=1, - **kwargs, - ) - - index = int(result[0].content) - return labels[index] - - -def _openai_kwargs(model: Union[AzureChatOpenAI, ChatOpenAI], n_labels: int): - encoding = tiktoken.encoding_for_model(model.model_name) - - logit_bias = {} - for i in range(n_labels): - for token in encoding.encode(str(i)): - logit_bias[token] = 100 - - return dict(logit_bias=logit_bias) diff --git a/src/controlflow/llm/completions.py b/src/controlflow/llm/completions.py deleted file mode 100644 index 5f104965..00000000 --- a/src/controlflow/llm/completions.py +++ /dev/null @@ -1,399 +0,0 @@ -import math -from typing import TYPE_CHECKING, AsyncGenerator, Callable, Generator, Optional, Union - -import langchain_core.language_models as lc_models -import langchain_core.messages -import tiktoken -from langchain_core.messages.utils import trim_messages - -import controlflow -import controlflow.llm.models -from controlflow.llm.handlers import ( - CompletionEvent, - 
CompletionHandler, - ResponseHandler, -) -from controlflow.llm.messages import AIMessage, AIMessageChunk, BaseMessage, MessageType -from controlflow.llm.tools import ToolCall, as_tools, handle_tool_call - -if TYPE_CHECKING: - from controlflow.agents.agent import Agent - - -def token_counter(message: langchain_core.messages.BaseMessage) -> int: - # always use gpt-3.5 token counter with the entire message object; we only need to be approximate here - return len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(message.json())) - - -def handle_tool_calls( - message: AIMessage, - tools: list[Callable], - response_messages: list[MessageType], - agent: Optional["Agent"] = None, -): - """ - Emit events for the given message when it has tool calls. - """ - for tool_call in message.tool_calls: - yield CompletionEvent( - type="tool_result_created", - payload=dict(message=message, tool_call=tool_call), - ) - error = handle_multiple_talk_to_user_calls(tool_call, message) - tool_result_message = handle_tool_call( - tool_call, tools, error=error, agent=agent - ) - response_messages.append(tool_result_message) - yield CompletionEvent( - type="tool_result_done", payload=dict(message=tool_result_message) - ) - - -def handle_delta_events( - delta: langchain_core.messages.AIMessageChunk, - snapshot: langchain_core.messages.AIMessageChunk, - deltas: list[langchain_core.messages.AIMessageChunk], - agent: "Agent", -): - """ - Emit events for the given delta message. 
- - Note this function receives langchain messages - """ - delta = AIMessageChunk.from_langchain_message(delta, agent=agent) - snapshot = AIMessageChunk.from_langchain_message(snapshot, agent=agent) - - if delta.content: - if not deltas[-1].content: - yield CompletionEvent(type="message_created", payload=dict(delta=delta)) - if delta.content != deltas[-1].content: - yield CompletionEvent( - type="message_delta", - payload=dict(delta=delta, snapshot=snapshot), - ) - - if delta.tool_calls: - if not deltas[-1].tool_calls: - yield CompletionEvent(type="tool_call_created", payload=dict(delta=delta)) - yield CompletionEvent( - type="tool_call_delta", - payload=dict(delta=delta, snapshot=snapshot), - ) - - -def handle_done_events(message: AIMessage): - """ - Emit events for the given message when it has been processed. - """ - if message.content: - yield CompletionEvent(type="message_done", payload=dict(message=message)) - if message.tool_calls: - yield CompletionEvent(type="tool_call_done", payload=dict(message=message)) - - -def handle_multiple_talk_to_user_calls(tool_call: ToolCall, message: AIMessage): - if ( - tool_call["name"] == "talk_to_user" - and len([t for t in message.tool_calls if t["name"] == "talk_to_user"]) > 1 - ): - error = 'Tool call "talk_to_user" can only be used once per turn.' - else: - error = None - return error - - -def prepare_messages(messages: list[MessageType]) -> list[MessageType]: - """ - Make any necessary modifications to the messages before they are passed to the model. 
- """ - return messages - - -def _completion_generator( - messages: list[MessageType], - model: lc_models.BaseChatModel, - tools: Optional[list[Callable]], - max_iterations: int, - stream: bool, - agent: Optional["Agent"] = None, - **kwargs, -) -> Generator[CompletionEvent, None, None]: - response_messages = [] - response_message = None - - if tools: - model = model.bind_tools([t.to_lc_tool() for t in as_tools(tools)]) - - counter = 0 - try: - yield CompletionEvent(type="start", payload={}) - - # continue as long as the last response message contains tool calls (or - # there is no response message yet) - while not response_message or response_message.tool_calls: - input_messages = [ - m.to_langchain_message() - for m in messages + response_messages - if isinstance(m, BaseMessage) - ] - input_messages = trim_messages( - messages=input_messages, - max_tokens=controlflow.settings.max_input_tokens, - include_system=True, - token_counter=token_counter, - ) - - if not stream: - response_message = model.invoke(input=input_messages, **kwargs) - response_message = AIMessage.from_langchain_message( - response_message, agent=agent - ) - - else: - # all streaming responses are langchain Pydantic v1 models - # which we don't convert to AIMessage/AIMessageChunks for sanity. - # they are converted in handle_delta_events and when the stream is finished. 
- - # initialize the list of deltas with an empty delta - # to facilitate comparison with the previous delta - deltas = [langchain_core.messages.AIMessageChunk(content="")] - - for i, delta in enumerate(model.stream(input=input_messages, **kwargs)): - if i == 0: - snapshot = delta - else: - snapshot = snapshot + delta - - yield from handle_delta_events( - delta=delta, snapshot=snapshot, deltas=deltas, agent=agent - ) - - deltas.append(delta) - - # the last snapshot message is the response message - response_message = AIMessage.from_langchain_message( - snapshot, agent=agent - ) - - # handle done events for the response message - yield from handle_done_events(response_message) - - # append the response message to the list of response messages - response_messages.append(response_message) - - # handle tool calls - yield from handle_tool_calls( - message=response_message, - tools=tools, - response_messages=response_messages, - agent=agent, - ) - - counter += 1 - if counter >= (max_iterations or math.inf): - break - - except (BaseException, Exception) as exc: - yield CompletionEvent(type="exception", payload=dict(exc=exc)) - raise - finally: - yield CompletionEvent(type="end", payload={}) - - -async def _completion_async_generator( - messages: list[MessageType], - model: lc_models.BaseChatModel, - tools: Optional[list[Callable]], - max_iterations: int, - stream: bool, - agent: Optional["Agent"] = None, - **kwargs, -) -> AsyncGenerator[CompletionEvent, None]: - response_messages = [] - response_message = None - - if tools: - model = model.bind_tools([t.to_lc_tool() for t in as_tools(tools)]) - - counter = 0 - try: - yield CompletionEvent(type="start", payload={}) - - # continue as long as the last response message contains tool calls (or - # there is no response message yet) - while not response_message or response_message.tool_calls: - input_messages = [ - m.to_langchain_message() if isinstance(m, BaseMessage) else m - for m in messages + response_messages - ] - - 
input_messages = trim_messages( - messages=input_messages, - max_tokens=controlflow.settings.max_input_tokens, - include_system=True, - token_counter=token_counter, - ) - input_messages = [ - m.to_langchain_message() - for m in input_messages - if isinstance(m, BaseMessage) - ] - - if not stream: - response_message = await model.ainvoke(input=input_messages, **kwargs) - response_message = AIMessage.from_langchain_message( - response_message, agent=agent - ) - - else: - # all streaming responses are langchain Pydantic v1 models - # which we don't convert to AIMessage/AIMessageChunks for sanity. - # they are converted in handle_delta_events and when the stream is finished. - - # initialize the list of deltas with an empty delta - # to facilitate comparison with the previous delta - deltas = [langchain_core.messages.AIMessageChunk(content="")] - - async for i, delta in enumerate( - model.astream(input=input_messages, **kwargs) - ): - if i == 0: - snapshot = delta - else: - snapshot = snapshot + delta - - for event in handle_delta_events( - delta=delta, snapshot=snapshot, deltas=deltas, agent=agent - ): - yield event - - deltas.append(delta) - - # the last snapshot message is the response message - response_message = AIMessage.from_langchain_message( - snapshot, agent=agent - ) - - # handle done events for the response message - for event in handle_done_events(response_message): - yield event - - # append the response message to the list of response messages - response_messages.append(response_message) - - # handle tool calls - for event in handle_tool_calls( - message=response_message, - tools=tools, - response_messages=response_messages, - agent=agent, - ): - yield event - - counter += 1 - if counter >= (max_iterations or math.inf): - break - - except (BaseException, Exception) as exc: - yield CompletionEvent(type="exception", payload=dict(exc=exc)) - raise - finally: - yield CompletionEvent(type="end", payload={}) - - -def _handle_events( - generator: 
Generator[CompletionEvent, None, None], handlers: list[CompletionHandler] -) -> Generator[CompletionEvent, None, None]: - for event in generator: - for handler in handlers: - try: - handler.on_event(event) - except Exception as exc: - generator.throw(exc) - yield event - - -async def _handle_events_async( - generator: AsyncGenerator, handlers: list[CompletionHandler] -) -> AsyncGenerator[CompletionEvent, None]: - async for event in generator: - for handler in handlers: - try: - handler.on_event(event) - except Exception as exc: - await generator.athrow(exc) - yield event - - -def completion( - messages: list[MessageType], - model: lc_models.BaseChatModel = None, - tools: list[Callable] = None, - max_iterations: int = None, - handlers: list[CompletionHandler] = None, - stream: bool = False, - agent: Optional["Agent"] = None, - **kwargs, -) -> Union[list[MessageType], Generator[MessageType, None, None]]: - if model is None: - model = controlflow.llm.models.get_default_model() - - response_handler = ResponseHandler() - handlers = handlers or [] - handlers.append(response_handler) - - completion_generator = _completion_generator( - messages=messages, - model=model, - tools=tools, - max_iterations=max_iterations, - stream=stream, - agent=agent, - **kwargs, - ) - - handlers_generator = _handle_events(completion_generator, handlers) - - if stream: - return handlers_generator - else: - for _ in handlers_generator: - pass - return response_handler.response_messages - - -async def completion_async( - messages: list[MessageType], - model: lc_models.BaseChatModel = None, - tools: list[Callable] = None, - max_iterations: int = None, - handlers: list[CompletionHandler] = None, - stream: bool = False, - agent: Optional["Agent"] = None, - **kwargs, -) -> Union[list[MessageType], Generator[MessageType, None, None]]: - if model is None: - model = controlflow.llm.models.get_default_model() - - response_handler = ResponseHandler() - handlers = handlers or [] - 
handlers.append(response_handler) - - completion_generator = _completion_async_generator( - messages=messages, - model=model, - tools=tools, - max_iterations=max_iterations, - stream=stream, - agent=agent, - **kwargs, - ) - - handlers_generator = _handle_events_async(completion_generator, handlers) - - if stream: - return handlers_generator - else: - async for _ in handlers_generator: - pass - return response_handler.response_messages diff --git a/src/controlflow/llm/handlers.py b/src/controlflow/llm/handlers.py deleted file mode 100644 index efaa6d53..00000000 --- a/src/controlflow/llm/handlers.py +++ /dev/null @@ -1,106 +0,0 @@ -from controlflow.llm.messages import ( - AIMessage, - AIMessageChunk, - MessageType, - ToolCall, - ToolMessage, -) -from controlflow.utilities.context import ctx -from controlflow.utilities.types import ControlFlowModel - - -class CompletionEvent(ControlFlowModel): - type: str - payload: dict - - -class CompletionHandler: - def __init__(self): - self._response_message_ids: set[str] = set() - - def on_event(self, event: CompletionEvent): - method = getattr(self, f"on_{event.type}", None) - if not method: - raise ValueError(f"Unknown event type: {event.type}") - method(**event.payload) - if event.type in [ - "message_done", - "tool_call_done", - "tool_result_done", - ]: - # only fire the on_response_message hook once per message - # (a message could contain both a tool call and a message) - if event.payload["message"].id not in self._response_message_ids: - self.on_response_message(event.payload["message"]) - self._response_message_ids.add(event.payload["message"].id) - - def on_start(self): - pass - - def on_end(self): - pass - - def on_exception(self, exc: Exception): - pass - - def on_message_created(self, delta: AIMessageChunk): - pass - - def on_message_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): - pass - - def on_message_done(self, message: AIMessage): - pass - - def on_tool_call_created(self, delta: AIMessageChunk): 
- pass - - def on_tool_call_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): - pass - - def on_tool_call_done(self, message: AIMessage): - pass - - def on_tool_result_created(self, message: AIMessage, tool_call: ToolCall): - pass - - def on_tool_result_done(self, message: ToolMessage): - pass - - def on_response_message(self, message: MessageType): - """ - This handler is called whenever a message is generated that should be - included in the completion history (e.g. a `message`, `tool_call` or - `tool_result`). Note that this is called *in addition* to the respective - on_*_done handlers, and can be used to quickly collect all messages - generated during a completion. Messages that satisfy multiple criteria - (e.g. a message and a tool call) will only be included once. - """ - pass - - -class ResponseHandler(CompletionHandler): - """ - A handler for collecting response messages. - """ - - def __init__(self): - super().__init__() - self.response_messages = [] - - def on_response_message(self, message: MessageType): - self.response_messages.append(message) - - -class TUIHandler(CompletionHandler): - def on_message_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk) -> None: - if tui := ctx.get("tui"): - tui.update_message(message=snapshot) - - def on_tool_call_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): - if tui := ctx.get("tui"): - tui.update_message(message=snapshot) - - def on_tool_result_done(self, message: ToolMessage): - if tui := ctx.get("tui"): - tui.update_tool_result(message=message) diff --git a/src/controlflow/llm/messages.py b/src/controlflow/llm/messages.py index 7a85317a..76076b93 100644 --- a/src/controlflow/llm/messages.py +++ b/src/controlflow/llm/messages.py @@ -1,254 +1,23 @@ -import datetime -import re -import uuid -from typing import TYPE_CHECKING, Any, Literal, Optional, Union - -import langchain_core.messages -from langchain_core.messages.tool import InvalidToolCall, ToolCall, ToolCallChunk -from 
pydantic import Field, field_validator - -from controlflow.utilities.jinja import jinja_env -from controlflow.utilities.types import ControlFlowModel - -if TYPE_CHECKING: - from controlflow.agents.agent import Agent - - -class BaseMessage(ControlFlowModel): - """ - ControlFlow uses Message objects that are similar to LangChain messages, but more purpose built. - - Note that LangChain messages are Pydantic V1 models, while ControlFlow messages are Pydantic V2 models. - """ - - id: Optional[str] = Field(default_factory=lambda: uuid.uuid4().hex) - timestamp: datetime.datetime = Field( - default_factory=lambda: datetime.datetime.now(datetime.timezone.utc), - ) - role: str - content: Union[str, list[Union[str, dict]]] - name: Optional[str] = None - - # private attr to hold the original langchain message - _langchain_message: Optional[ - Union[ - langchain_core.messages.BaseMessage, - langchain_core.messages.BaseMessageChunk, - ] - ] = None - - @field_validator("name") - def _sanitize_name(cls, v): - # sanitize name for API compatibility - OpenAI API only allows alphanumeric characters, dashes, and underscores - if v is not None: - v = re.sub(r"[^a-zA-Z0-9_-]", "-", v).strip("-") - return v - - def __init__( - self, - content: Union[str, list[Union[str, dict]]], - _langchain_message: Optional[ - Union[ - langchain_core.messages.BaseMessage, - langchain_core.messages.BaseMessageChunk, - ] - ] = None, - **kwargs: Any, - ) -> None: - """Pass in content as positional arg.""" - super().__init__(content=content, **kwargs) - self._langchain_message = _langchain_message - - @property - def str_content(self) -> str: - if not isinstance(self.content, str): - return str(self.content) - return self.content - - def render(self, **kwargs) -> "MessageType": - """ - Renders the content as a jinja template with the given keyword arguments - and returns a new Message. 
- """ - content = jinja_env.from_string(self.content).render(**kwargs) - return self.model_copy(update=dict(content=content)) - - @classmethod - def _langchain_message_kwargs( - cls, message: langchain_core.messages.BaseMessage - ) -> "BaseMessage": - return message.dict(include={"content", "id", "name"}) | dict( - _langchain_message=message - ) - - @classmethod - def from_langchain_message(message: langchain_core.messages.BaseMessage, **kwargs): - raise NotImplementedError() - - def to_langchain_message(self) -> langchain_core.messages.BaseMessage: - raise NotImplementedError() - - -class AgentReference(ControlFlowModel): - name: str - - -class AgentMessageMixin(ControlFlowModel): - agent: Optional[AgentReference] = None - - @field_validator("agent", mode="before") - def _validate_agent(cls, v): - from controlflow.agents.agent import Agent - - if isinstance(v, Agent): - return AgentReference(name=v.name) - return v - - def __init__(self, *args, agent: "Agent" = None, **data): - if agent is not None and data.get("name") is None: - data["name"] = agent.name - super().__init__(*args, agent=agent, **data) - - -class AIMessage(BaseMessage, AgentMessageMixin): - role: Literal["ai"] = "ai" - tool_calls: list[ToolCall] = [] - - is_delta: bool = False - - def __init__(self, *args, **data): - super().__init__(*args, **data) - - # GPT-4 models somtimes use a hallucinated parallel tool calling mechanism - # whose name is not compatible with the API's restrictions on tool names - for tool_call in self.tool_calls: - if tool_call["name"] == "multi_tool_use.parallel": - tool_call["name"] = "multi_tool_use_parallel" - - def has_tool_calls(self) -> bool: - return any(self.tool_calls) - - @classmethod - def from_langchain_message( - cls, - message: langchain_core.messages.AIMessage, - **kwargs, - ): - data = dict( - **cls._langchain_message_kwargs(message), - tool_calls=message.tool_calls + getattr(message, "invalid_tool_calls", []), - ) - - return cls(**data | kwargs) - - def 
to_langchain_message( - self, - ) -> langchain_core.messages.AIMessage: - if self._langchain_message is not None: - return self._langchain_message - return langchain_core.messages.AIMessage( - content=self.content, tool_calls=self.tool_calls, id=self.id, name=self.name - ) - - -class AIMessageChunk(AIMessage): - tool_calls: list[ToolCallChunk] = [] - - @classmethod - def from_langchain_message( - cls, - message: Union[ - langchain_core.messages.AIMessageChunk, langchain_core.messages.AIMessage - ], - **kwargs, - ): - if isinstance(message, langchain_core.messages.AIMessageChunk): - tool_calls = message.tool_call_chunks - elif isinstance(message, langchain_core.messages.AIMessage): - tool_calls = [] - for i, call in enumerate(message.tool_calls): - tool_calls.append( - ToolCallChunk( - id=call["id"], - name=call["name"], - args=str(call["args"]) if call["args"] else None, - index=call.get("index", i), - ) - ) - data = dict( - **cls._langchain_message_kwargs(message), - tool_calls=tool_calls, - is_delta=True, - ) - - return cls(**data | kwargs) - - def to_langchain_message( - self, - ) -> langchain_core.messages.AIMessageChunk: - if self._langchain_message is not None: - return self._langchain_message - return langchain_core.messages.AIMessageChunk( - content=self.content, - tool_call_chunks=self.tool_calls, - id=self.id, - name=self.name, - ) - - -class UserMessage(BaseMessage): - role: Literal["user"] = "user" - - @classmethod - def from_langchain_message( - cls, message: langchain_core.messages.HumanMessage, **kwargs - ): - return cls(**cls._langchain_message_kwargs(message) | kwargs) - - def to_langchain_message(self) -> langchain_core.messages.BaseMessage: - if self._langchain_message is not None: - return self._langchain_message - return langchain_core.messages.HumanMessage( - content=self.content, id=self.id, name=self.name - ) - - -class SystemMessage(BaseMessage): - role: Literal["system"] = "system" - - @classmethod - def from_langchain_message( - cls, 
message: langchain_core.messages.SystemMessage, **kwargs - ): - return cls(**cls._langchain_message_kwargs(message) | kwargs) - - def to_langchain_message(self) -> langchain_core.messages.BaseMessage: - if self._langchain_message is not None: - return self._langchain_message - return langchain_core.messages.SystemMessage( - content=self.content, id=self.id, name=self.name - ) - - -class ToolMessage(BaseMessage, AgentMessageMixin): - model_config = dict(arbitrary_types_allowed=True) - role: Literal["tool"] = "tool" - tool_call_id: str - tool_call: Union[ToolCall, InvalidToolCall] - tool_result: Any = Field(None, exclude=True) - tool_metadata: dict[str, Any] = {} - is_error: bool = False - - def to_langchain_message(self) -> langchain_core.messages.BaseMessage: - if self._langchain_message is not None: - return self._langchain_message - else: - return langchain_core.messages.ToolMessage( - id=self.id, - name=self.name, - content=self.content, - tool_call_id=self.tool_call_id, - ) - - -MessageType = Union[UserMessage, AIMessage, SystemMessage, ToolMessage] +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, + InvalidToolCall, + SystemMessage, + ToolCall, + ToolCallChunk, + ToolMessage, +) + +__all__ = [ + "AIMessage", + "AIMessageChunk", + "BaseMessage", + "HumanMessage", + "InvalidToolCall", + "SystemMessage", + "ToolCall", + "ToolCallChunk", + "ToolMessage", +] diff --git a/src/controlflow/llm/models.py b/src/controlflow/llm/models.py index d5881fd9..aee3d9bd 100644 --- a/src/controlflow/llm/models.py +++ b/src/controlflow/llm/models.py @@ -45,6 +45,14 @@ def model_from_string( "To use Google models, please install the `langchain_google_genai` package." ) cls = ChatGoogleGenerativeAI + elif provider == "groq": + try: + from langchain_groq import ChatGroq + except ImportError: + raise ImportError( + "To use Groq models, please install the `langchain_groq` package." 
+ ) + cls = ChatGroq else: raise ValueError( f"Could not load provider automatically: {provider}. Please create your model manually." diff --git a/src/controlflow/llm/rules.py b/src/controlflow/llm/rules.py index 0db1dea4..c5a3a8d6 100644 --- a/src/controlflow/llm/rules.py +++ b/src/controlflow/llm/rules.py @@ -2,9 +2,10 @@ from langchain_openai import AzureChatOpenAI, ChatOpenAI from controlflow.llm.models import BaseChatModel +from controlflow.utilities.types import ControlFlowModel -class LLMRules: +class LLMRules(ControlFlowModel): """ LLM rules let us tailor DAG compilation, message generation, tool use, and other behavior to the requirements of different LLM provider APIs. @@ -13,14 +14,17 @@ class LLMRules: necessary. """ + # require at least one non-system message + require_at_least_one_message: bool = False + # system messages can only be provided as the very first message in a thread - system_message_must_be_first: bool = False + require_system_message_first: bool = False # other than a system message, the first message must be from the user - user_message_must_be_first_after_system: bool = False + require_user_message_after_system: bool = False # the last message in a thread can't be from an AI if tool use is allowed - allow_last_message_has_ai_role_with_tools: bool = True + allow_last_message_from_ai_when_using_tools: bool = True # consecutive AI messages must be separated by a user message allow_consecutive_ai_messages: bool = True @@ -29,15 +33,19 @@ class LLMRules: # (some APIs can use the `name` field for this purpose, but others can't) add_system_messages_for_multi_agent: bool = False + # if a tool is used, the result must follow the tool call immediately + tool_result_must_follow_tool_call: bool = True + class OpenAIRules(LLMRules): pass class AnthropicRules(LLMRules): - system_message_must_be_first: bool = True - user_message_must_be_first_after_system: bool = True - allow_last_message_has_ai_role_with_tools: bool = False + 
require_at_least_one_message: bool = True + require_system_message_first: bool = True + require_user_message_after_system: bool = True + allow_last_message_from_ai_when_using_tools: bool = False allow_consecutive_ai_messages: bool = False diff --git a/src/controlflow/controllers/__init__.py b/src/controlflow/orchestration/__init__.py similarity index 100% rename from src/controlflow/controllers/__init__.py rename to src/controlflow/orchestration/__init__.py diff --git a/src/controlflow/orchestration/controller.py b/src/controlflow/orchestration/controller.py new file mode 100644 index 00000000..77ac328f --- /dev/null +++ b/src/controlflow/orchestration/controller.py @@ -0,0 +1,288 @@ +import logging +from typing import Generator, Optional, TypeVar, Union + +from pydantic import Field, field_validator + +from controlflow.agents import Agent +from controlflow.events.agent_events import ( + EndTurnEvent, + SelectAgentEvent, +) +from controlflow.events.events import Event +from controlflow.events.task_events import TaskReadyEvent +from controlflow.flows import Flow +from controlflow.instructions import get_instructions +from controlflow.llm.messages import AIMessage +from controlflow.orchestration.handler import Handler +from controlflow.orchestration.tools import ( + create_end_turn_tool, + create_task_fail_tool, + create_task_success_tool, +) +from controlflow.tasks.task import Task +from controlflow.tools import as_tools +from controlflow.tools.tools import Tool +from controlflow.utilities.prefect import prefect_task as prefect_task +from controlflow.utilities.types import ControlFlowModel + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class Controller(ControlFlowModel): + """ + The controller is responsible for managing the flow of tasks and agents. It + is given objects that it is responsible for managing. At each iteration, the + controller will select a task and an agent to complete the task. 
def handle_event(
    self, event: Event, tasks: list[Task] = None, agents: list[Agent] = None
):
    """
    Stamp an event with this controller's context and deliver it.

    The event receives the flow's `thread_id` plus the ids of the given
    tasks and agents (empty lists when they are omitted). It is then passed
    to `on_event` of every handler in `self.handlers`, in order, and finally
    stored through `self.flow.add_events` when `event.persist` is truthy.

    Args:
        event: the event to annotate and dispatch; it is mutated in place.
        tasks: tasks the event relates to, if any.
        agents: agents the event relates to, if any.
    """
    related_tasks = tasks if tasks else []
    related_agents = agents if agents else []

    # attach routing metadata before any handler sees the event
    event.thread_id = self.flow.thread_id
    event.task_ids = [task.id for task in related_tasks]
    event.agent_ids = [agent.id for agent in related_agents]

    for listener in self.handlers:
        listener.on_event(event)

    # only events flagged as persistent are written to the flow
    if not event.persist:
        return
    self.flow.add_events([event])
def get_ready_tasks(self) -> list[Task]:
    """
    Return the tasks that can be worked on right now.

    Asks the flow's graph for `upstream_tasks(self.tasks)` and keeps, in the
    order the graph returned them, the entries whose `is_ready()` is true.
    """
    ready = []
    for candidate in self.flow.graph.upstream_tasks(self.tasks):
        if candidate.is_ready():
            ready.append(candidate)
    return ready
def get_agent(self, ready_tasks: list[Task]) -> Agent:
    """
    Choose the agent that takes the next turn.

    Candidates are the agents attached to the ready tasks: the controller's
    own `agents` mapping is used when it has an entry for a task, otherwise
    `task.get_agents()`. Selection then proceeds in this order:

    1. If there is exactly one candidate, it is returned immediately.
    2. If the most recent "select-agent" / "end-turn" event names an agent
       that is a candidate (matched by name), that agent is returned.
    3. Otherwise the first ready task's agent strategy picks one; that
       task's `_iteration` counter is advanced and a `SelectAgentEvent` is
       emitted for the chosen agent.

    Args:
        ready_tasks: tasks that are ready to run; must not be empty.

    Returns:
        The selected agent.

    Raises:
        ValueError: if no agent is available for any of the ready tasks.
    """
    candidates: list[Agent] = []
    for task in ready_tasks:
        # get agents from either controller assignments or the task defaults
        for agent in self.agents.get(task, task.get_agents()):
            # an agent assigned to several ready tasks is still one candidate;
            # compare by identity so no assumption is made about Agent equality
            if not any(agent is known for known in candidates):
                candidates.append(agent)

    if not candidates:
        raise ValueError("No agents are available for the ready tasks.")

    # if there is only one candidate, return it
    if len(candidates) == 1:
        return candidates[0]

    # get the last select-agent / end-turn event
    select_event = self.flow.get_events(limit=1, types=["select-agent", "end-turn"])

    # if an agent was selected and is a candidate, return it
    if select_event and select_event[0].event == "select-agent":
        agent = next(
            (a for a in candidates if a.name == select_event[0].agent.name), None
        )
        if agent:
            return agent
    # if an agent was nominated and is a candidate, return it
    elif select_event and select_event[0].event == "end-turn":
        agent = next(
            (a for a in candidates if a.name == select_event[0].next_agent_name),
            None,
        )
        if agent:
            return agent

    # if there are multiple candidates remaining, use the first task's
    # strategy to select one
    first_task = ready_tasks[0]
    strategy_fn = first_task.get_agent_strategy()
    agent = strategy_fn(agents=candidates, task=first_task, flow=self.flow)
    first_task._iteration += 1

    self.handle_event(SelectAgentEvent(agent=agent), agents=[agent])
    return agent
def get_tools(self) -> list[Tool]:
    """
    Assemble the tools offered to the agent for this turn.

    The list is built in this order: the flow's tools, the end-turn tool,
    and then, for each task in `self.tasks`, that task's own tools followed
    by its "mark successful" and "mark failed" tools. The combined list is
    passed through `as_tools` before being returned.
    """
    controller = self.controller
    agent = self.agent

    # flow tools first, then the tool that lets the agent end its turn
    collected = [
        *self.flow.tools,
        create_end_turn_tool(controller=controller, agent=agent),
    ]

    # per-task tools plus the completion tools bound to that task
    for task in self.tasks:
        collected += task.get_tools()
        collected += [
            create_task_success_tool(controller=controller, task=task, agent=agent),
            create_task_fail_tool(controller=controller, task=task, agent=agent),
        ]

    return as_tools(collected)
class Handler:
    """
    Base class for event handlers.

    Subclasses react to an event type by defining a method named
    `on_<event type>`, with dashes in the type replaced by underscores
    (for example `on_tool_call` for the "tool-call" event).
    """

    def on_event(self, event: Event):
        """
        Dispatch `event` to the matching `on_<event type>` method.

        The method is called with the event as the `event` keyword argument.
        Events without a matching method are ignored.
        """
        callback_name = "on_" + event.event.replace("-", "_")
        callback = getattr(self, callback_name, None)
        if not callback:
            return
        callback(event=event)
def update_live(self, latest: BaseMessage = None):
    """
    Re-render the collected events.

    Events are ordered by `(timestamp, key)`. Tool result events are indexed
    by their tool call id so that agent messages can be formatted together
    with the results of their tool calls. Nothing happens when no agent
    message produces output. Otherwise the live display is refreshed if it
    is started; if it is not, `latest` (when given) is formatted and printed
    to the console instead.
    """
    ordered = [
        entry
        for _, entry in sorted(
            self.events.items(), key=lambda item: (item[1].timestamp, item[0])
        )
    ]

    # index tool results by call id before formatting any message
    results_by_call = {
        entry.tool_call["id"]: entry
        for entry in ordered
        if isinstance(entry, ToolResultEvent)
    }

    panels = []
    for entry in ordered:
        if not isinstance(entry, (AgentMessageDeltaEvent, AgentMessageEvent)):
            continue
        panel = format_event(entry, tool_results=results_by_call)
        if panel:
            panels.append(panel)

    if not panels:
        return
    if self.live.is_started:
        self.live.update(Group(*panels), refresh=True)
    elif latest:
        cf_console.print(format_event(latest))
def format_timestamp(timestamp: datetime.datetime) -> str:
    """
    Render a timestamp as a 12-hour clock time in the local timezone.

    The leading zero of the hour is dropped and the text is right-aligned
    to a width of 11 characters, e.g. " 9:05:03 AM".
    """
    clock = f"{timestamp.astimezone():%I:%M:%S %p}"
    return clock.lstrip("0").rjust(11)
+ icon = ":x:" + else: + icon = ":white_check_mark:" + + if controlflow.settings.tools_verbose: + msg = f'Tool call: "{event.tool_call["name"]}"\n\nTool args: {event.tool_call["args"]}\n\nTool result: {event.tool_result.str_result}' + else: + msg = f'Tool call: "{event.tool_call["name"]}"' + return status(icon, msg) diff --git a/src/controlflow/orchestration/prompt_templates/agent.j2 b/src/controlflow/orchestration/prompt_templates/agent.j2 new file mode 100644 index 00000000..961c52e5 --- /dev/null +++ b/src/controlflow/orchestration/prompt_templates/agent.j2 @@ -0,0 +1,49 @@ +# Agent + +You are an AI agent participating in a workflow. Your role is to work on +your tasks and use the provided tools to complete those tasks and +communicate with the orchestrator. + +Important: The orchestrator is a Python script and cannot read or +respond to messages posted in this thread. You must use the provided +tools to communicate with the orchestrator. Posting messages in this +thread should only be used for thinking out loud, working through a +problem, or communicating with other agents. Any System messages or +messages prefixed with "SYSTEM:" are from the workflow system, not an +actual human. + +Your job is to: +1. Select one or more tasks to work on from the ready tasks. +2. Read the task instructions and work on completing the task objective, which +may involve using appropriate tools or collaborating with other agents assigned +to the same task. +3. When you (and any other agents) have completed the task objective, use the +provided tool to inform the orchestrator of the task completion and result. +4. Repeat steps 1-3 until no more tasks are available for execution. + +Note that the orchestrator may decide to activate a different agent at any time. + +## Your information + +- ID: {{ agent.id }} +- Name: "{{ agent.name }}" +{% if agent.description -%} +- Description: "{{ agent.description }}" +{% endif %} + +## Instructions + +You must follow instructions at all times. 
Instructions can be added or removed +at any time. + +- Never impersonate another agent + +{% if agent.instructions %} +{{ agent.instructions }} +{% endif %} + +{% if additional_instructions %} +{% for instruction in additional_instructions %} +- {{ instruction }} +{% endfor %} +{% endif %} \ No newline at end of file diff --git a/src/controlflow/orchestration/prompt_templates/tools.j2 b/src/controlflow/orchestration/prompt_templates/tools.j2 new file mode 100644 index 00000000..62441f70 --- /dev/null +++ b/src/controlflow/orchestration/prompt_templates/tools.j2 @@ -0,0 +1,21 @@ +# Tools + +You have access to various tools. They may change, so do not rely on history +to see what tools are available. + +## Talking to human users + +If your task requires you to interact with a user, it will show +`user_access=True` and you will be given a `talk_to_user` tool. You can +use it to send messages to the user and optionally wait for a response. +This is how you tell the user things and ask questions. Do not mention +your tasks or the workflow. The user can only see messages you send +them via tool. They can not read the rest of the +thread. + +Human users may give poor, incorrect, or partial responses. You may need +to ask questions multiple times in order to complete your tasks. Do not +make up answers for omitted information; ask again and only fail the +task if you truly can not make progress. If your task requires human +interaction and neither it nor any assigned agents have `user_access`, +you can fail the task. \ No newline at end of file diff --git a/src/controlflow/orchestration/prompt_templates/workflow.j2 b/src/controlflow/orchestration/prompt_templates/workflow.j2 new file mode 100644 index 00000000..d0dca962 --- /dev/null +++ b/src/controlflow/orchestration/prompt_templates/workflow.j2 @@ -0,0 +1,88 @@ +# Workflow + +As soon as you have completed a task's objective, you must use the provided +tool to mark it successful and provide a result. 
It may take multiple +turns or collaboration with other agents to complete a task. Any agent +assigned to a task can complete it. Once a task is complete, no other +agent can interact with it. + +Tasks should only be marked failed due to technical errors like a broken +or erroring tool or unresponsive human. + +Tasks are not ready until all of their dependencies are met. Parent +tasks depend on all of their subtasks. + +## Flow + +Name: {{ flow.name }} +{% if flow.description %} +Description: {{ flow.description }} +{% endif %} +{% if flow.context %} +Context: +{% for key, value in flow.context.items() %} +- {{ key }}: {{ value }} +{% endfor %} +{% endif %} + +## Tasks + +### Ready tasks + +These tasks are ready to be worked on because all of their dependencies have +been completed. You can only work on tasks to which you are assigned. + +{% for task in ready_tasks %} +#### Task {{ task.id }} +- objective: {{ task.objective }} +- instructions: {{ task.instructions}} +- context: {{ task.context }} +- result_type: {{ task.result_type }} +- depends_on: {{ task.depends_on }} +- parent: {{ task.parent }} +- assigned agents: {{ task.agents }} +{% if task.user_access %} +- user access: True +{% endif %} +- created_at: {{ task.created_at }} + +{% endfor %} + +### Upstream tasks + +{% for task in upstream_tasks %} +#### Task {{ task.id }} +- objective: {{ task.objective }} +- instructions: {{ task.instructions}} +- status: {{ task.status }} +- result: {{ task.result }} +- error: {{ task.error }} +- context: {{ task.context }} +- depends_on: {{ task.depends_on }} +- parent: {{ task.parent }} +- assigned agents: {{ task.agents }} +{% if task.user_access %} +- user access: True +{% endif %} +- created_at: {{ task.created_at }} + +{% endfor %} + +### Downstream tasks + +{% for task in downstream_tasks %} +#### Task {{ task.id }} +- objective: {{ task.objective }} +- instructions: {{ task.instructions}} +- status: {{ task.status }} +- result_type: {{ task.result_type }} +- 
class Template(ControlFlowModel):
    """
    Base class for prompt templates.

    Subclasses declare the template file to load plus the fields that the
    template needs; every field except `template_path` is passed to the
    template as a render variable.
    """

    # name of the template, resolved through `prompt_env.get_template`
    template_path: str

    def render(self) -> str:
        """Render the template with this model's fields as variables."""
        variables = dict(self)
        # the path selects the template and is not itself a render variable
        path = variables.pop("template_path")
        return prompt_env.get_template(path).render(**variables)
def generate_result_schema(result_type: type[T]) -> type[T]:
    """
    Return the schema to use for a task result of type `result_type`.

    `None` is passed through unchanged. Any other type is accepted, and
    returned as-is, when a pydantic `TypeAdapter` can be built for it.

    Raises:
        ValueError: if pydantic cannot generate a schema for the type.
    """
    if result_type is None:
        return None

    # probe whether pydantic can describe the type; the adapter is discarded
    try:
        TypeAdapter(result_type)
    except PydanticSchemaGenerationError:
        supported = False
    else:
        supported = True

    # NOTE: pandas DataFrame / Series result types are not handled here;
    # the dedicated branch for them is currently disabled.
    if not supported:
        raise ValueError(
            f"Could not load or infer schema for result type {result_type}. "
            "Please use a custom type or add compatibility."
        )
    return result_type
def create_end_turn_tool(controller: "Controller", agent: Agent) -> Tool:
    """
    Create an agent-compatible tool for ending the turn.

    When called, the tool reports an `EndTurnEvent` for `agent` to the
    controller, carrying the optional name of the agent nominated to go
    next, and returns a short confirmation string.

    Args:
        controller: the controller that receives the end-turn event.
        agent: the agent whose turn the tool ends.
    """

    # NOTE(review): the inner docstring is presumably what `tool` exposes to
    # the model as the tool description - confirm before rewording it.
    @tool(private=True)
    def end_turn(next_agent_name: str = None) -> str:
        """End your turn so another agent can work. You can optionally choose
        the next agent, which can be any other agent assigned to a ready task.
        Choose an agent likely to help you complete your tasks."""
        turn_over = EndTurnEvent(agent=agent, next_agent_name=next_agent_name)
        controller.handle_event(turn_over)
        return "Turn ended."

    return end_turn
) # ------------ Prefect settings ------------ diff --git a/src/controlflow/tasks/agent_strategies.py b/src/controlflow/tasks/agent_strategies.py index f09f30a0..f8371dfe 100644 --- a/src/controlflow/tasks/agent_strategies.py +++ b/src/controlflow/tasks/agent_strategies.py @@ -1,7 +1,5 @@ from controlflow.agents import Agent -from controlflow.flows import Flow, get_flow_messages -from controlflow.instructions import get_instructions -from controlflow.llm.classify import classify +from controlflow.flows import Flow from controlflow.tasks.task import Task @@ -14,29 +12,29 @@ def round_robin( return agents[task._iteration % len(agents)] -def moderator( - agents: list[Agent], - tasks: list[Task], - context: dict = None, - iteration: int = 0, - model: str = None, -) -> Agent: - history = get_flow_messages() - instructions = get_instructions() - context = context or {} - context.update(tasks=tasks, history=history, instructions=instructions) +# def moderator( +# agents: list[Agent], +# tasks: list[Task], +# context: dict = None, +# iteration: int = 0, +# model: str = None, +# ) -> Agent: +# history () +# instructions = get_instructions() +# context = context or {} +# context.update(tasks=tasks, history=history, instructions=instructions) - agent = classify( - context, - labels=agents, - instructions=""" - Given the context, choose the AI agent best suited to take the - next turn at completing the tasks in the task graph. Take into account - any descriptions, tasks, history, instructions, and tools. Focus on - agents assigned to upstream dependencies or subtasks that need to be - completed before their downstream/parents can be completed. An agent - can only work on a task that it is assigned to. - """, - model=model, - ) - return agent +# agent = classify( +# context, +# labels=agents, +# instructions=""" +# Given the context, choose the AI agent best suited to take the +# next turn at completing the tasks in the task graph. 
@property
def subtasks(self) -> list["Task"]:
    """Return this task's subtasks in the graph's topological order."""
    # imported here rather than at module level, as in the original
    from controlflow.flows.graph import Graph

    subtask_graph = Graph.from_tasks(tasks=self._subtasks)
    return subtask_graph.topological_sort()
) - from controlflow.controllers import Controller + from controlflow.orchestration import Controller controller = Controller( tasks=[self], flow=flow, agents={self: agents} if agents else None @@ -352,7 +346,7 @@ def run( else: flow = Flow() - from controlflow.controllers import Controller + from controlflow.orchestration import Controller controller = Controller( tasks=[self], flow=flow, agents={self: agents} if agents else None @@ -386,7 +380,7 @@ async def run_async( else: flow = Flow() - from controlflow.controllers import Controller + from controlflow.orchestration import Controller controller = Controller( tasks=[self], flow=flow, agents={self: agents} if agents else None @@ -436,53 +430,6 @@ def is_ready(self) -> bool: """ return self.is_incomplete() and all(t.is_complete() for t in self.depends_on) - def _create_success_tool(self) -> Tool: - """ - Create an agent-compatible tool for marking this task as successful. - """ - # generate tool for result_type=None - if self.result_type is None: - - def succeed() -> str: - return self.mark_successful(result=None) - - # generate tool for other result types - else: - result_schema = generate_result_schema(self.result_type) - - def succeed(result: result_schema) -> str: # type: ignore - return self.mark_successful(result=result) - - return Tool.from_function( - succeed, - name=f"mark_task_{self.id}_successful", - description=f"Mark task {self.id} as successful.", - metadata=dict(ignore_result=True), - ) - - def _create_fail_tool(self) -> Tool: - """ - Create an agent-compatible tool for failing this task. - """ - - return Tool.from_function( - self.mark_failed, - name=f"mark_task_{self.id}_failed", - description=f"Mark task {self.id} as failed. Only use when technical errors prevent success.", - metadata=dict(ignore_result=True), - ) - - def _create_skip_tool(self) -> Tool: - """ - Create an agent-compatible tool for skipping this task. 
- """ - return Tool.from_function( - self.mark_skipped, - name=f"mark_task_{self.id}_skipped", - description=f"Mark task {self.id} as skipped. Only use when completing a parent task early.", - metadata=dict(ignore_result=True), - ) - def get_agents(self) -> list["Agent"]: if self.agents: return self.agents @@ -521,12 +468,6 @@ def get_agent_strategy(self) -> Callable: def get_tools(self) -> list[Union[Tool, Callable]]: tools = self.tools.copy() - # if this task is ready to run, generate tools - if self.is_ready(): - tools.extend([self._create_fail_tool(), self._create_success_tool()]) - # add skip tool if this task has a parent task - # if self.parent is not None: - # tools.append(self._create_skip_tool()) if self.user_access: tools.append(talk_to_user) return tools @@ -602,32 +543,6 @@ def generate_subtasks(self, instructions: str = None, agent: Agent = None): ) -def generate_result_schema(result_type: type[T]) -> type[T]: - result_schema = None - # try loading pydantic-compatible schemas - try: - TypeAdapter(result_type) - result_schema = result_type - except PydanticSchemaGenerationError: - pass - # try loading as dataframe - # try: - # import pandas as pd - - # if result_type is pd.DataFrame: - # result_schema = PandasDataFrame - # elif result_type is pd.Series: - # result_schema = PandasSeries - # except ImportError: - # pass - if result_schema is None: - raise ValueError( - f"Could not load or infer schema for result type {result_type}. " - "Please use a custom type or add compatibility." 
- ) - return result_schema - - def validate_result(result: Any, result_type: type[T]) -> T: if result_type is None and result is not None: raise ValueError("Task has result_type=None, but a result was provided.") diff --git a/src/controlflow/tools/__init__.py b/src/controlflow/tools/__init__.py index 510e8df5..4b3e4222 100644 --- a/src/controlflow/tools/__init__.py +++ b/src/controlflow/tools/__init__.py @@ -1 +1 @@ -from controlflow.llm.tools import tool, Tool, as_tools +from controlflow.tools.tools import tool, Tool, as_tools diff --git a/src/controlflow/llm/tools.py b/src/controlflow/tools/tools.py similarity index 80% rename from src/controlflow/llm/tools.py rename to src/controlflow/tools/tools.py index 1a8c7f3b..1a16d500 100644 --- a/src/controlflow/llm/tools.py +++ b/src/controlflow/tools/tools.py @@ -2,7 +2,7 @@ import inspect import json import typing -from typing import TYPE_CHECKING, Annotated, Any, Callable, Optional, Union +from typing import Annotated, Any, Callable, Optional, Union import langchain_core.tools import pydantic @@ -12,14 +12,9 @@ from pydantic import Field, TypeAdapter import controlflow -from controlflow.llm.messages import ToolMessage from controlflow.utilities.prefect import create_markdown_artifact, prefect_task from controlflow.utilities.types import ControlFlowModel -if TYPE_CHECKING: - from controlflow.agents import Agent - - TOOL_CALL_FUNCTION_RESULT_TEMPLATE = """ # Tool call: {name} @@ -43,11 +38,13 @@ class Tool(ControlFlowModel): name: str description: str parameters: dict - metadata: dict = Field({}, exclude_none=True) + metadata: dict = {} + private: bool = False fn: Callable = Field(None, exclude=True) def to_lc_tool(self) -> dict: - return self.model_dump(include={"name", "description", "parameters"}) + payload = self.model_dump(include={"name", "description", "parameters"}) + return dict(type="function", function=payload) @prefect_task(task_run_name="Tool call: {self.name}") def run(self, input: dict): @@ -134,18 
+131,14 @@ def tool( *, name: Optional[str] = None, description: Optional[str] = None, - metadata: Optional[dict] = None, + **kwargs, ) -> Tool: """ Decorator for turning a function into a Tool """ if fn is None: - return functools.partial( - tool, name=name, description=description, metadata=metadata - ) - return Tool.from_function( - fn, name=name, description=description, metadata=metadata or {} - ) + return functools.partial(tool, name=name, description=description, **kwargs) + return Tool.from_function(fn, name=name, description=description, **kwargs) def as_tools( @@ -193,45 +186,48 @@ def output_to_string(output: Any) -> str: return str(output) -def handle_tool_call( - tool_call: ToolCall, - tools: list[Tool], - error: str = None, - agent: "Agent" = None, -) -> ToolMessage: +class ToolResult(ControlFlowModel): + tool_call_id: str + result: Any = Field(exclude=True, repr=False) + str_result: str = Field(repr=False) + is_error: bool = False + is_private: bool = False + + +def handle_tool_call(tool_call: ToolCall, tools: list[Tool]) -> Any: + """ + Given a ToolCall and set of available tools, runs the tool call and returns + a ToolResult object + """ + is_error = False + tool = None tool_lookup = {t.name: t for t in tools} fn_name = tool_call["name"] - is_error = False - metadata = {} - try: - if error: - fn_output = error - is_error = True - elif fn_name not in tool_lookup: - fn_output = f'Function "{fn_name}" not found.' - is_error = True - else: + + if fn_name not in tool_lookup: + fn_output = f'Function "{fn_name}" not found.' 
+ is_error = True + if controlflow.settings.tools_raise_on_error: + raise ValueError(fn_output) + + if not is_error: + try: tool = tool_lookup[fn_name] - metadata.update(getattr(tool, "metadata", {})) fn_args = tool_call["args"] if isinstance(tool, Tool): fn_output = tool.run(input=fn_args) elif isinstance(tool, langchain_core.tools.BaseTool): fn_output = tool.invoke(input=fn_args) - except Exception as exc: - fn_output = f'Error calling function "{fn_name}": {exc}' - is_error = True - if controlflow.settings.tools_raise_on_error: - raise - - from controlflow.llm.messages import ToolMessage + except Exception as exc: + fn_output = f'Error calling function "{fn_name}": {exc}' + is_error = True + if controlflow.settings.tools_raise_on_error: + raise exc - return ToolMessage( - content=output_to_string(fn_output), + return ToolResult( tool_call_id=tool_call["id"], - tool_call=tool_call, - tool_result=fn_output, - tool_metadata=metadata, + result=fn_output, + str_result=output_to_string(fn_output), is_error=is_error, - agent=agent, + is_private=getattr(tool, "private", False), ) diff --git a/src/controlflow/utilities/jinja.py b/src/controlflow/utilities/jinja.py index 1297add3..82b38f13 100644 --- a/src/controlflow/utilities/jinja.py +++ b/src/controlflow/utilities/jinja.py @@ -4,9 +4,17 @@ from zoneinfo import ZoneInfo from jinja2 import Environment as JinjaEnvironment -from jinja2 import StrictUndefined, select_autoescape +from jinja2 import PackageLoader, StrictUndefined, select_autoescape -jinja_env = JinjaEnvironment( +global_fns = { + "now": lambda: datetime.now(ZoneInfo("UTC")), + "inspect": inspect, + "getcwd": os.getcwd, + "zip": zip, +} + +prompt_env = JinjaEnvironment( + loader=PackageLoader("controlflow.orchestration", "prompt_templates"), autoescape=select_autoescape(default_for_string=False), trim_blocks=True, lstrip_blocks=True, @@ -14,11 +22,4 @@ undefined=StrictUndefined, ) -jinja_env.globals.update( - { - "now": lambda: datetime.now(ZoneInfo("UTC")), 
- "inspect": inspect, - "getcwd": os.getcwd, - "zip": zip, - } -) +prompt_env.globals.update(global_fns) diff --git a/src/controlflow/utilities/testing.py b/src/controlflow/utilities/testing.py index 59dad02c..6f908e90 100644 --- a/src/controlflow/utilities/testing.py +++ b/src/controlflow/utilities/testing.py @@ -6,12 +6,12 @@ import controlflow from controlflow.flows.history import InMemoryHistory -from controlflow.llm.messages import BaseMessage, MessageType +from controlflow.llm.messages import BaseMessage class FakeLLM(FakeMessagesListChatModel): def set_responses( - self, responses: list[Union[MessageType, langchain_core.messages.BaseMessage]] + self, responses: list[Union[BaseMessage, langchain_core.messages.BaseMessage]] ): new_responses = [] for msg in responses: diff --git a/src/controlflow/utilities/types.py b/src/controlflow/utilities/types.py index 8f81376d..1caa96a3 100644 --- a/src/controlflow/utilities/types.py +++ b/src/controlflow/utilities/types.py @@ -34,10 +34,3 @@ class PandasSeries(ControlFlowModel): index: Optional[list[str]] = None name: Optional[str] = None dtype: Optional[str] = None - - -class _OpenAIBaseType(ControlFlowModel): - model_config = ConfigDict(extra="allow") - - -__all__ = ["ControlFlowModel", "PandasDataFrame", "PandasSeries", "_OpenAIBaseType"] diff --git a/tests/conftest.py b/tests/conftest.py index adcf8563..4285f452 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ import pytest -from controlflow.llm.messages import MessageType +from controlflow.llm.messages import BaseMessage from controlflow.settings import temporary_settings from prefect.testing.utilities import prefect_test_harness diff --git a/tests/controllers/test_graph.py b/tests/controllers/test_graph.py index e775e7e4..b5b006cc 100644 --- a/tests/controllers/test_graph.py +++ b/tests/controllers/test_graph.py @@ -1,5 +1,5 @@ # test_graph.py -from controlflow.controllers.graph import Edge, EdgeType, Graph +from controlflow.flows.graph import 
Edge, EdgeType, Graph from controlflow.tasks.task import Task diff --git a/tests/fixtures/controlflow.py b/tests/fixtures/controlflow.py index f6b1c086..3f4d7bcc 100644 --- a/tests/fixtures/controlflow.py +++ b/tests/fixtures/controlflow.py @@ -1,6 +1,6 @@ import controlflow import pytest -from controlflow.llm.messages import MessageType +from controlflow.llm.messages import BaseMessage from controlflow.utilities.testing import FakeLLM diff --git a/tests/llm/test_tools.py b/tests/llm/test_tools.py index 7fb037fe..67e55e44 100644 --- a/tests/llm/test_tools.py +++ b/tests/llm/test_tools.py @@ -4,7 +4,7 @@ import pytest from controlflow.agents.agent import Agent from controlflow.llm.messages import ToolMessage -from controlflow.llm.tools import ( +from controlflow.tools.tools import ( Tool, handle_tool_call, tool, diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index 9ace51c0..8f49fc02 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -2,8 +2,8 @@ import pytest from controlflow.agents import Agent, get_default_agent -from controlflow.controllers.graph import EdgeType from controlflow.flows import Flow +from controlflow.flows.graph import EdgeType from controlflow.instructions import instructions from controlflow.tasks.task import Task, TaskStatus from controlflow.utilities.context import ctx diff --git a/tests/test_settings.py b/tests/test_settings.py index 6246b0d0..b09e688f 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -6,6 +6,12 @@ from prefect.logging import get_logger +def test_defaults(): + # ensure that debug settings etc. 
are not left on by default + assert controlflow.settings.tools_raise_on_error is False + assert controlflow.settings.tools_verbose is False + + def test_temporary_settings(): assert controlflow.settings.tools_raise_on_error is False with temporary_settings(tools_raise_on_error=True): From f3c5383686787786250f52fbf1d4159e3002ca8e Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 3 Jul 2024 15:14:42 -0400 Subject: [PATCH 2/2] Update tests --- src/controlflow/__init__.py | 8 +- .../events/{event_store.py => history.py} | 20 ++-- src/controlflow/flows/flow.py | 6 +- src/controlflow/flows/graph.py | 17 ++- src/controlflow/flows/history.py | 110 ------------------ src/controlflow/settings.py | 4 +- src/controlflow/utilities/testing.py | 45 +++---- tests/controllers/test_graph.py | 27 +++-- tests/flows/test_flows.py | 72 +++++------- tests/llm/test_messages.py | 15 --- tests/tasks/test_tasks.py | 102 +--------------- tests/utilities/test_testing.py | 63 ++++------ 12 files changed, 125 insertions(+), 364 deletions(-) rename src/controlflow/events/{event_store.py => history.py} (95%) delete mode 100644 src/controlflow/flows/history.py delete mode 100644 tests/llm/test_messages.py diff --git a/src/controlflow/__init__.py b/src/controlflow/__init__.py index e16bbe5a..b194f197 100644 --- a/src/controlflow/__init__.py +++ b/src/controlflow/__init__.py @@ -13,15 +13,15 @@ # --- Default settings --- from .llm.models import _get_initial_default_model, get_default_model -from .events.event_store import InMemoryStore, get_default_event_store +from .events.history import InMemoryHistory, get_default_history # assign to controlflow.default_model to change the default model default_model = _get_initial_default_model() del _get_initial_default_model -# assign to controlflow.default_event_store to change the default event store -default_event_store = InMemoryStore() -del InMemoryStore +# assign to controlflow.default_history to change the 
default history +default_history = InMemoryHistory() +del InMemoryHistory # assign to controlflow.default_agent to change the default agent default_agent = Agent(name="Marvin") diff --git a/src/controlflow/events/event_store.py b/src/controlflow/events/history.py similarity index 95% rename from src/controlflow/events/event_store.py rename to src/controlflow/events/history.py index 3f0db393..bf7bc59c 100644 --- a/src/controlflow/events/event_store.py +++ b/src/controlflow/events/history.py @@ -27,6 +27,10 @@ IN_MEMORY_STORE = {} +def get_default_history() -> "History": + return controlflow.default_history + + @cache def get_event_validator() -> TypeAdapter: types = Union[ @@ -119,11 +123,7 @@ def filter_events( return list(reversed(new_events)) -def get_default_event_store() -> "EventStore": - return controlflow.default_event_store - - -class EventStore(ControlFlowModel, abc.ABC): +class History(ControlFlowModel, abc.ABC): @abc.abstractmethod def get_events( self, @@ -142,11 +142,11 @@ def add_events(self, thread_id: str, events: list[Event]): raise NotImplementedError() -class InMemoryStore(EventStore): - store: dict[str, list[Event]] = Field(default_factory=lambda: IN_MEMORY_STORE) +class InMemoryHistory(History): + history: dict[str, list[Event]] = Field(default_factory=lambda: IN_MEMORY_STORE) def add_events(self, thread_id: str, events: list[Event]): - self.store.setdefault(thread_id, []).extend(events) + self.history.setdefault(thread_id, []).extend(events) def get_events( self, @@ -174,7 +174,7 @@ def get_events( list[Event]: A list of events that match the specified criteria. 
""" - events = self.store.get(thread_id, []) + events = self.history.get(thread_id, []) return filter_events( events=events, agent_ids=agent_ids, @@ -186,7 +186,7 @@ def get_events( ) -class FileStore(EventStore): +class FileHistory(History): base_path: Path = Field( default_factory=lambda: controlflow.settings.home_path / "filestore_events" ) diff --git a/src/controlflow/flows/flow.py b/src/controlflow/flows/flow.py index 8041e453..81c21c1e 100644 --- a/src/controlflow/flows/flow.py +++ b/src/controlflow/flows/flow.py @@ -5,8 +5,8 @@ from pydantic import Field from controlflow.agents import Agent -from controlflow.events.event_store import EventStore, get_default_event_store from controlflow.events.events import Event +from controlflow.events.history import History, get_default_history from controlflow.flows.graph import Graph from controlflow.tasks.task import Task from controlflow.utilities.context import ctx @@ -21,7 +21,7 @@ class Flow(ControlFlowModel): thread_id: str = Field(default_factory=lambda: uuid.uuid4().hex) name: Optional[str] = None description: Optional[str] = None - event_store: EventStore = Field(default_factory=get_default_event_store) + event_store: History = Field(default_factory=get_default_history) tools: list[Callable] = Field( default_factory=list, description="Tools that will be available to every agent in the flow", @@ -46,7 +46,7 @@ def __init__(self, *, copy_parent: bool = True, **kwargs): parent = get_flow() if parent and copy_parent: self.add_events(parent.get_events()) - for task in parent.tasks.values(): + for task in parent.tasks: if task.is_complete(): self.add_task(task) diff --git a/src/controlflow/flows/graph.py b/src/controlflow/flows/graph.py index 384b9323..7f9d0e91 100644 --- a/src/controlflow/flows/graph.py +++ b/src/controlflow/flows/graph.py @@ -193,7 +193,7 @@ def _downstream(task): def topological_sort(self, tasks: Optional[list[Task]] = None) -> list[Task]: """ - Perform a topological sort on the provided tasks or 
all tasks in the graph. + Perform a deterministic topological sort on the provided tasks or all tasks in the graph. Args: tasks (Optional[list[Task]]): A list of tasks to sort topologically. @@ -202,6 +202,15 @@ def topological_sort(self, tasks: Optional[list[Task]] = None) -> list[Task]: Returns: list[Task]: A list of tasks in topological order (upstream tasks first). """ + # Create a cache key based on the input tasks + cache_key = ( + f"topo_sort_{tuple(sorted(task.id for task in (tasks or self.tasks)))}" + ) + + # Check if the result is already in the cache + if cache_key in self._cache: + return self._cache[cache_key] + if tasks is None: tasks_to_sort = self.tasks else: @@ -216,6 +225,8 @@ def topological_sort(self, tasks: Optional[list[Task]] = None) -> list[Task]: # Kahn's algorithm for topological sorting result = [] no_incoming = [task for task in tasks_to_sort if not dependencies[task]] + # sort to create a deterministic order + no_incoming.sort(key=lambda t: t.created_at) while no_incoming: task = no_incoming.pop(0) @@ -227,6 +238,8 @@ def topological_sort(self, tasks: Optional[list[Task]] = None) -> list[Task]: dependencies[dependent_task].remove(task) if not dependencies[dependent_task]: no_incoming.append(dependent_task) + # resort to maintain deterministic order + no_incoming.sort(key=lambda t: t.created_at) # Check for cycles if len(result) != len(tasks_to_sort): @@ -234,4 +247,6 @@ def topological_sort(self, tasks: Optional[list[Task]] = None) -> list[Task]: "The graph contains a cycle and cannot be topologically sorted" ) + # Cache the result before returning + self._cache[cache_key] = result return result diff --git a/src/controlflow/flows/history.py b/src/controlflow/flows/history.py deleted file mode 100644 index ebc81afe..00000000 --- a/src/controlflow/flows/history.py +++ /dev/null @@ -1,110 +0,0 @@ -import abc -import datetime -import json -import math -from pathlib import Path - -from pydantic import Field, field_validator - -import 
controlflow -from controlflow.llm.messages import BaseMessage -from controlflow.utilities.types import ControlFlowModel - -# This is a global variable that will be shared between all instances of InMemoryHistory -IN_MEMORY_HISTORY = {} - - -def get_default_history() -> "History": - return controlflow.default_history - - -class History(ControlFlowModel, abc.ABC): - @abc.abstractmethod - def load_messages( - self, - thread_id: str, - limit: int = None, - before: datetime.datetime = None, - after: datetime.datetime = None, - ) -> list[BaseMessage]: - raise NotImplementedError() - - @abc.abstractmethod - def save_messages(self, thread_id: str, messages: list[BaseMessage]): - raise NotImplementedError() - - -class InMemoryHistory(History): - history: dict[str, list[BaseMessage]] = Field( - default_factory=lambda: IN_MEMORY_HISTORY - ) - - def load_messages( - self, - thread_id: str, - limit: int = None, - before: datetime.datetime = None, - after: datetime.datetime = None, - ) -> list[BaseMessage]: - messages = self.history.get(thread_id, []) - filtered_messages = [ - msg - for i, msg in enumerate(reversed(messages)) - if (before is None or msg.timestamp < before) - and (after is None or msg.timestamp > after) - and i < (limit or math.inf) - ] - return list(reversed(filtered_messages)) - - def save_messages(self, thread_id: str, messages: list[BaseMessage]): - self.history.setdefault(thread_id, []).extend(messages) - - -class FileHistory(History): - base_path: Path = Field( - default_factory=lambda: controlflow.settings.home_path / "history" - ) - - def path(self, thread_id: str) -> Path: - return self.base_path / f"{thread_id}.json" - - @field_validator("base_path", mode="before") - def _validate_path(cls, v): - v = Path(v).expanduser() - if not v.exists(): - v.mkdir(parents=True, exist_ok=True) - return v - - def load_messages( - self, - thread_id: str, - limit: int = None, - before: datetime.datetime = None, - after: datetime.datetime = None, - ) -> 
list[BaseMessage]: - if not self.path(thread_id).exists(): - return [] - - with open(self.path(thread_id), "r") as f: - all_messages = json.load(f) - - messages = [] - for msg in reversed(all_messages): - message = BaseMessage.model_validate(msg) - if before is None or message.timestamp < before: - if after is None or message.timestamp > after: - messages.append(message) - if len(messages) >= limit or math.inf: - break - - return list(reversed(messages)) - - def save_messages(self, thread_id: str, messages: list[BaseMessage]): - if self.path(thread_id).exists(): - with open(self.path(thread_id), "r") as f: - all_messages = json.load(f) - else: - all_messages = [] - all_messages.extend([msg.model_dump(mode="json") for msg in messages]) - with open(self.path(thread_id), "w") as f: - json.dump(all_messages, f) diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index c492701e..c0a1633d 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -94,11 +94,11 @@ class Settings(ControlFlowSettings): # ------------ Debug settings ------------ tools_raise_on_error: bool = Field( - True, description="If True, an error in a tool call will raise an exception." + False, description="If True, an error in a tool call will raise an exception." ) tools_verbose: bool = Field( - True, description="If True, tools will log additional information." + False, description="If True, tools will log additional information." 
) # ------------ Prefect settings ------------ diff --git a/src/controlflow/utilities/testing.py b/src/controlflow/utilities/testing.py index 6f908e90..f5ee9f80 100644 --- a/src/controlflow/utilities/testing.py +++ b/src/controlflow/utilities/testing.py @@ -1,24 +1,15 @@ from contextlib import contextmanager -from typing import Union -import langchain_core.messages from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel import controlflow -from controlflow.flows.history import InMemoryHistory +from controlflow.events.history import InMemoryHistory from controlflow.llm.messages import BaseMessage class FakeLLM(FakeMessagesListChatModel): - def set_responses( - self, responses: list[Union[BaseMessage, langchain_core.messages.BaseMessage]] - ): - new_responses = [] - for msg in responses: - if isinstance(msg, BaseMessage): - msg = msg.to_langchain_message() - new_responses.append(msg) - self.responses = new_responses + def set_responses(self, responses: list[BaseMessage]): + self.responses = responses def bind_tools(self, *args, **kwargs): """When binding tools, passthrough""" @@ -30,38 +21,32 @@ def get_num_tokens_from_messages(self, messages: list) -> int: @contextmanager -def record_messages( - remove_additional_kwargs: bool = True, remove_tool_call_chunks: bool = True -): +def record_events(): """ Context manager for recording all messages in a flow, useful for testing. - with record_messages() as messages: + with record_events() as events: cf.Task("say hello").run() - assert messages[0].content == "Hello!" + assert events[0].content == "Hello!" 
""" history = InMemoryHistory(history={}) old_default_history = controlflow.default_history controlflow.default_history = history - messages = [] + events = [] try: - yield messages + yield events finally: controlflow.default_history = old_default_history - _messages_buffer = [] - for _, thread_messages in history.history.items(): - for message in thread_messages: - message = message.copy() - if hasattr(message, "additional_kwargs") and remove_additional_kwargs: - message.additional_kwargs = {} - if hasattr(message, "tool_call_chunks") and remove_tool_call_chunks: - message.tool_call_chunks = [] - _messages_buffer.append(message) - - messages.extend(sorted(_messages_buffer, key=lambda m: m.timestamp)) + _events_buffer = [] + for _, thread_events in history.history.items(): + for event in thread_events: + event = event.copy() + _events_buffer.append(event) + + events.extend(sorted(_events_buffer, key=lambda m: m.timestamp)) diff --git a/tests/controllers/test_graph.py b/tests/controllers/test_graph.py index b5b006cc..7dc9dde1 100644 --- a/tests/controllers/test_graph.py +++ b/tests/controllers/test_graph.py @@ -87,6 +87,16 @@ def test_topological_sort(): assert sorted_tasks.index(task3) < sorted_tasks.index(task4) +def test_topological_sort_uses_time_to_tiebreak(): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2") + task3 = Task(objective="Task 3") + task4 = Task(objective="Task 4") + graph = Graph.from_tasks([task1, task2, task3, task4]) + sorted_tasks = graph.topological_sort() + assert sorted_tasks == [task1, task2, task3, task4] + + def test_topological_sort_with_fan_in_and_fan_out(): task1 = Task(objective="Task 1") task2 = Task(objective="Task 2") @@ -125,12 +135,11 @@ def test_upstream_tasks(): graph.add_edge(edge2) graph.add_edge(edge3) - assert graph.upstream_tasks([task3]) == [task1, task2] - assert graph.upstream_tasks([task2]) == [task1] - assert graph.upstream_tasks([task1]) == [] + assert graph.upstream_tasks([task3]) == [task1, 
task2, task3] + assert graph.upstream_tasks([task2]) == [task1, task2] + assert graph.upstream_tasks([task1]) == [task1] - # never include a start task in the usptream list - assert graph.upstream_tasks([task1, task3]) == [task2] + assert graph.upstream_tasks([task1, task3]) == [task1, task2, task3] def test_downstream_tasks(): @@ -147,9 +156,9 @@ def test_downstream_tasks(): graph.add_edge(edge2) graph.add_edge(edge3) - assert graph.downstream_tasks([task3]) == [] - assert graph.downstream_tasks([task2]) == [task3] - assert graph.downstream_tasks([task1]) == [task2, task3] + assert graph.downstream_tasks([task3]) == [task3] + assert graph.downstream_tasks([task2]) == [task2, task3] + assert graph.downstream_tasks([task1]) == [task1, task2, task3] # never include a start task in the downstream list - assert graph.downstream_tasks([task1, task3]) == [task2] + assert graph.downstream_tasks([task1, task3]) == [task1, task2, task3] diff --git a/tests/flows/test_flows.py b/tests/flows/test_flows.py index ad98f5d6..b3ed3404 100644 --- a/tests/flows/test_flows.py +++ b/tests/flows/test_flows.py @@ -1,6 +1,6 @@ from controlflow.agents import Agent +from controlflow.events.agent_events import UserMessageEvent from controlflow.flows import Flow, get_flow -from controlflow.llm.messages import UserMessage from controlflow.tasks.task import Task from controlflow.utilities.context import ctx @@ -71,7 +71,7 @@ def test_tasks_created_in_flow_context(self): t1 = Task("test 1") t2 = Task("test 2") - assert flow.tasks == {t1.id: t1, t2.id: t2} + assert flow.tasks == [t1, t2] def test_tasks_created_in_nested_flows_only_in_inner_flow(self): with Flow() as flow1: @@ -79,73 +79,63 @@ def test_tasks_created_in_nested_flows_only_in_inner_flow(self): with Flow() as flow2: t2 = Task("test 2") - assert flow1.tasks == {t1.id: t1} - assert flow2.tasks == {t2.id: t2} + assert flow1.tasks == [t1] + assert flow2.tasks == [t2] + + def test_inner_flow_includes_completed_parent_tasks(self): + with 
Flow() as flow1: + t1 = Task("test 1", status="SUCCESSFUL") + t2 = Task("test 2") + with Flow() as flow2: + t3 = Task("test 3") + + assert flow1.tasks == [t1, t2] + assert flow2.tasks == [t1, t3] class TestFlowHistory: - def test_get_messages_empty(self): + def test_get_events_empty(self): flow = Flow() - messages = flow.get_messages() + messages = flow.get_events() assert messages == [] - def test_add_messages_with_history(self): - flow = Flow() - flow.add_messages( - messages=[UserMessage(content="hello"), UserMessage(content="world")] - ) - messages = flow.get_messages() - assert len(messages) == 2 - assert [m.content for m in messages] == ["hello", "world"] - - def test_copy_parent_history(self): - flow1 = Flow() - flow1.add_messages( - messages=[UserMessage(content="hello"), UserMessage(content="world")] - ) - - with flow1: - flow2 = Flow() - - messages1 = flow1.get_messages() - assert len(messages1) == 2 - assert [m.content for m in messages1] == ["hello", "world"] - - messages2 = flow2.get_messages() - assert len(messages2) == 2 - assert [m.content for m in messages2] == ["hello", "world"] - def test_disable_copying_parent_history(self): flow1 = Flow() - flow1.add_messages( - messages=[UserMessage(content="hello"), UserMessage(content="world")] + flow1.add_events( + [ + UserMessageEvent(content="hello"), + UserMessageEvent(content="world"), + ] ) with flow1: flow2 = Flow(copy_parent=False) - messages1 = flow1.get_messages() + messages1 = flow1.get_events() assert len(messages1) == 2 assert [m.content for m in messages1] == ["hello", "world"] - messages2 = flow2.get_messages() + messages2 = flow2.get_events() assert len(messages2) == 0 def test_child_flow_messages_dont_go_to_parent(self): flow1 = Flow() - flow1.add_messages( - messages=[UserMessage(content="hello"), UserMessage(content="world")] + flow1.add_events( + [ + UserMessageEvent(content="hello"), + UserMessageEvent(content="world"), + ] ) with flow1: flow2 = Flow() - 
flow2.add_messages(messages=[UserMessage(content="goodbye")]) + flow2.add_events([UserMessageEvent(content="goodbye")]) - messages1 = flow1.get_messages() + messages1 = flow1.get_events() assert len(messages1) == 2 assert [m.content for m in messages1] == ["hello", "world"] - messages2 = flow2.get_messages() + messages2 = flow2.get_events() assert len(messages2) == 3 assert [m.content for m in messages2] == ["hello", "world", "goodbye"] diff --git a/tests/llm/test_messages.py b/tests/llm/test_messages.py deleted file mode 100644 index 67adec53..00000000 --- a/tests/llm/test_messages.py +++ /dev/null @@ -1,15 +0,0 @@ -import controlflow -from controlflow.llm.messages import AgentReference, AIMessage - - -class TestAIMessage: - def test_agent(self): - agent = controlflow.Agent(name="Test Agent!") - message = AIMessage(content="", agent=agent) - assert isinstance(message.agent, AgentReference) - assert message.agent.name == "Test Agent!" - - def test_name_loaded_from_agent(self): - agent = controlflow.Agent(name="Test Agent!") - message = AIMessage(content="", agent=agent) - assert message.name == "Test-Agent" diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index 8f49fc02..2abdb808 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -3,7 +3,6 @@ import pytest from controlflow.agents import Agent, get_default_agent from controlflow.flows import Flow -from controlflow.flows.graph import EdgeType from controlflow.instructions import instructions from controlflow.tasks.task import Task, TaskStatus from controlflow.utilities.context import ctx @@ -124,14 +123,14 @@ def test_task_loads_agent_from_parent_before_flow(): def test_task_tracking(): with Flow() as flow: task = SimpleTask() - assert task in flow.tasks.values() + assert task in flow.tasks def test_task_tracking_on_call(): task = SimpleTask() with Flow() as flow: task.run_once() - assert task in flow.tasks.values() + assert task in flow.tasks class TestTaskStatus: @@ -213,100 
+212,3 @@ def test_task_hash(self): task1 = SimpleTask() task2 = SimpleTask() assert hash(task1) != hash(task2) - - def test_ready_task_adds_tools(self): - task = SimpleTask() - assert task.is_ready() - - tools = task.get_tools() - assert any(tool.name == f"mark_task_{task.id}_failed" for tool in tools) - assert any(tool.name == f"mark_task_{task.id}_successful" for tool in tools) - - def test_completed_task_does_not_add_tools(self): - task = SimpleTask() - task.mark_successful() - tools = task.get_tools() - assert not any(tool.name == f"mark_task_{task.id}_failed" for tool in tools) - assert not any(tool.name == f"mark_task_{task.id}_successful" for tool in tools) - - def test_task_with_incomplete_upstream_does_not_add_tools(self): - upstream_task = SimpleTask() - downstream_task = SimpleTask(depends_on=[upstream_task]) - tools = downstream_task.get_tools() - assert not any( - tool.name == f"mark_task_{downstream_task.id}_failed" for tool in tools - ) - assert not any( - tool.name == f"mark_task_{downstream_task.id}_successful" for tool in tools - ) - - def test_task_with_incomplete_subtask_does_not_add_tools(self): - parent = SimpleTask() - SimpleTask(parent=parent) - tools = parent.get_tools() - assert not any(tool.name == f"mark_task_{parent.id}_failed" for tool in tools) - assert not any( - tool.name == f"mark_task_{parent.id}_successful" for tool in tools - ) - - -class TestTaskToGraph: - def test_single_task_graph(self): - task = SimpleTask() - graph = task.as_graph() - assert len(graph.tasks) == 1 - assert task in graph.tasks - assert len(graph.edges) == 0 - - def test_task_with_subtasks_graph(self): - task1 = SimpleTask() - task2 = SimpleTask(parent=task1) - graph = task1.as_graph() - assert len(graph.tasks) == 2 - assert task1 in graph.tasks - assert task2 in graph.tasks - assert len(graph.edges) == 1 - assert any( - edge.upstream == task2 - and edge.downstream == task1 - and edge.type == EdgeType.SUBTASK - for edge in graph.edges - ) - - def 
test_task_with_dependencies_graph(self): - task1 = SimpleTask() - task2 = SimpleTask(depends_on=[task1]) - graph = task2.as_graph() - assert len(graph.tasks) == 2 - assert task1 in graph.tasks - assert task2 in graph.tasks - assert len(graph.edges) == 1 - assert any( - edge.upstream == task1 - and edge.downstream == task2 - and edge.type == EdgeType.DEPENDENCY - for edge in graph.edges - ) - - def test_task_with_subtasks_and_dependencies_graph(self): - task1 = SimpleTask() - task2 = SimpleTask(depends_on=[task1]) - task3 = SimpleTask(objective="Task 3", parent=task2) - graph = task2.as_graph() - assert len(graph.tasks) == 3 - assert task1 in graph.tasks - assert task2 in graph.tasks - assert task3 in graph.tasks - assert len(graph.edges) == 2 - assert any( - edge.upstream == task1 - and edge.downstream == task2 - and edge.type == EdgeType.DEPENDENCY - for edge in graph.edges - ) - assert any( - edge.upstream == task3 - and edge.downstream == task2 - and edge.type == EdgeType.SUBTASK - for edge in graph.edges - ) diff --git a/tests/utilities/test_testing.py b/tests/utilities/test_testing.py index c23a0db6..875a7f67 100644 --- a/tests/utilities/test_testing.py +++ b/tests/utilities/test_testing.py @@ -1,63 +1,48 @@ -import datetime - import controlflow -from controlflow.llm.messages import AIMessage, ToolMessage -from controlflow.utilities.testing import record_messages +from controlflow.llm.messages import AIMessage +from controlflow.utilities.testing import record_events -def test_record_messages_empty(): - with record_messages() as messages: +def test_record_events_empty(): + with record_events() as events: pass - assert messages == [] + assert events == [] -def test_record_task_messages(default_fake_llm): - task = controlflow.Task("say hello") +def test_record_task_events(default_fake_llm): + task = controlflow.Task("say hello", id="12345") response = AIMessage( - agent=dict(name="Marvin"), id="run-2af8bb73-661f-4ec3-92ff-d7d8e3074926", - 
timestamp=datetime.datetime( - 2024, 6, 23, 17, 12, 24, 91830, tzinfo=datetime.timezone.utc - ), + name="Marvin", role="ai", content="", - name="Marvin", tool_calls=[ { - "name": f"mark_task_{task.id}_successful", + "name": "mark_task_12345_successful", "args": {"result": "Hello!"}, "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", } ], - is_delta=False, ) default_fake_llm.set_responses([response]) - with record_messages() as rec_messages: + with record_events() as events: task.run() - assert rec_messages[0].content == response.content - assert rec_messages[0].id == response.id - assert rec_messages[0].tool_calls == response.tool_calls - - expected_tool_message = ToolMessage( - agent=dict(name="Marvin"), - id="cb84bb8f3e0f4245bbf5eefeee9272b2", - timestamp=datetime.datetime( - 2024, 6, 23, 17, 12, 24, 187384, tzinfo=datetime.timezone.utc - ), - role="tool", - content=f'Task {task.id} ("say hello") marked successful by Marvin.', - name="Marvin", + assert events[0].event == "select-agent" + assert events[1].event == "agent-message" + assert response == events[1].ai_message + + assert events[5].event == "tool-result" + assert events[5].tool_call == { + "name": "mark_task_12345_successful", + "args": {"result": "Hello!"}, + "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", + } + assert events[5].tool_result.model_dump() == dict( tool_call_id="call_ZEPdV8mCgeBe5UHjKzm6e3pe", - tool_call={ - "name": f"mark_task_{task.id}_successful", - "args": {"result": "Hello!"}, - "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", - }, - tool_metadata={"ignore_result": True}, + str_result='Task 12345 ("say hello") marked successful.', + is_error=False, + is_private=True, ) - assert rec_messages[1].content == expected_tool_message.content - assert rec_messages[1].tool_call_id == expected_tool_message.tool_call_id - assert rec_messages[1].tool_call == expected_tool_message.tool_call