diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..04fd9fb7 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +recursive-include controlflow/orchestration/prompt_templates * \ No newline at end of file diff --git a/examples/teacher_student.py b/examples/teacher_student.py index b6626850..549c195a 100644 --- a/examples/teacher_student.py +++ b/examples/teacher_student.py @@ -1,29 +1,34 @@ from controlflow import Agent, Task, flow from controlflow.instructions import instructions -teacher = Agent(name="teacher") -student = Agent(name="student") +teacher = Agent(name="Teacher") +student = Agent(name="Student") @flow def demo(): - with Task("Teach a class by asking and answering 3 questions") as task: + with Task("Teach a class by asking and answering 3 questions", agents=[teacher]): for _ in range(3): question = Task( - "ask the student a question. Wait for the student to answer your question before asking another one.", - str, - agents=[teacher], + "Ask the student a question.", result_type=str, agents=[teacher] ) - with instructions("one sentence max"): - Task( - "answer the question", - str, + + with instructions("One sentence max"): + answer = Task( + "Answer the question.", agents=[student], context=dict(question=question), ) - task.run() - return task + grade = Task( + "Assess the answer.", + result_type=["pass", "fail"], + agents=[teacher], + context=dict(answer=answer), + ) + + # run each Q&A session, one at a time + grade.run() t = demo() diff --git a/pyproject.toml b/pyproject.toml index 96aee618..cc15be49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,11 +6,11 @@ authors = [ { name = "Jeremiah Lowin", email = "153965+jlowin@users.noreply.github.com" }, ] dependencies = [ - "prefect>=3.0rc4", + "prefect>=3.0rc10", "jinja2>=3.1.4", "langchain_core>=0.2.9", "langchain_openai>=0.1.8", - "langchain-anthropic>=0.1.15", + "langchain-anthropic>=0.1.19", "markdownify>=0.12.1", "pydantic-settings>=2.2.1", "textual>=0.61.1", diff --git a/src/controlflow/__init__.py b/src/controlflow/__init__.py index ca631c1a..b194f197 100644 --- a/src/controlflow/__init__.py +++ b/src/controlflow/__init__.py @@ -8,12 +8,12 @@ from .instructions import instructions from .decorators import flow, task -from .llm.tools import tool +from .tools import tool # --- Default settings --- from .llm.models import _get_initial_default_model, get_default_model -from .flows.history import InMemoryHistory, get_default_history +from .events.history import InMemoryHistory, get_default_history # assign to controlflow.default_model to change the default model default_model = _get_initial_default_model() diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index 8b5a8e33..419f6c2b 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -2,16 +2,18 @@ import random import uuid from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Generator, Optional from langchain_core.language_models import BaseChatModel from pydantic import Field, field_serializer import controlflow +from controlflow.events.events import Event from controlflow.instructions import get_instructions +from controlflow.llm.messages import AIMessage, BaseMessage from controlflow.llm.models import get_default_model from controlflow.llm.rules import LLMRules -from controlflow.tools.talk_to_user import talk_to_user +from controlflow.tools.tools import handle_tool_call_async from
controlflow.utilities.context import ctx from controlflow.utilities.types import ControlFlowModel @@ -20,6 +22,8 @@ if TYPE_CHECKING: from controlflow.tasks.task import Task + from controlflow.tools.tools import Tool + logger = logging.getLogger(__name__) @@ -109,6 +113,8 @@ def get_llm_rules(self) -> LLMRules: return controlflow.llm.rules.rules_for_model(self.get_model()) def get_tools(self) -> list[Callable]: + from controlflow.tools.talk_to_user import talk_to_user + tools = self.tools.copy() if self.user_access: tools.append(talk_to_user) @@ -141,5 +147,83 @@ def run(self, task: "Task"): async def run_async(self, task: "Task"): return await task.run_async(agents=[self]) + def _run_model( + self, + messages: list[BaseMessage], + additional_tools: list["Tool"] = None, + stream: bool = True, + ) -> Generator[Event, None, None]: + from controlflow.events.agent_events import ( + AgentMessageDeltaEvent, + AgentMessageEvent, + ) + from controlflow.events.tool_events import ToolCallEvent, ToolResultEvent + from controlflow.tools.tools import as_tools, handle_tool_call + + model = self.get_model() + + tools = as_tools(self.get_tools() + (additional_tools or [])) + if tools: + model = model.bind_tools([t.to_lc_tool() for t in tools]) + + if stream: + response = None + for delta in model.stream(messages): + if response is None: + response = delta + else: + response += delta + yield AgentMessageDeltaEvent(agent=self, delta=delta, snapshot=response) + + else: + response: AIMessage = model.invoke(messages) + + yield AgentMessageEvent(agent=self, message=response) + + for tool_call in response.tool_calls + response.invalid_tool_calls: + yield ToolCallEvent(agent=self, tool_call=tool_call, message=response) + + result = handle_tool_call(tool_call, tools=tools) + yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result) + + async def _run_model_async( + self, + messages: list[BaseMessage], + additional_tools: list["Tool"] = None, + stream: bool = True, + ) -> AsyncGenerator[Event, None]: + from controlflow.events.agent_events import ( + AgentMessageDeltaEvent, + AgentMessageEvent, + ) + from controlflow.events.tool_events import ToolCallEvent, ToolResultEvent + from controlflow.tools.tools import as_tools + + model = self.get_model() + + tools = as_tools(self.get_tools() + (additional_tools or [])) + if tools: + model = model.bind_tools([t.to_lc_tool() for t in tools]) + + if stream: + response = None + async for delta in model.astream(messages): + if response is None: + response = delta + else: + response += delta + yield AgentMessageDeltaEvent(agent=self, delta=delta, snapshot=response) + + else: + response: AIMessage = await model.ainvoke(messages) + + yield AgentMessageEvent(agent=self, message=response) + + for tool_call in response.tool_calls + response.invalid_tool_calls: + yield ToolCallEvent(agent=self, tool_call=tool_call, message=response) + + result = await handle_tool_call_async(tool_call, tools=tools) + yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result) + DEFAULT_AGENT = Agent(name="Marvin") diff --git a/src/controlflow/agents/memory.py b/src/controlflow/agents/memory.py index acf94153..b45bce71 100644 --- a/src/controlflow/agents/memory.py +++ b/src/controlflow/agents/memory.py @@ -1,13 +1,15 @@ import abc import uuid -from typing import ClassVar, Optional, cast +from typing import TYPE_CHECKING, ClassVar, Optional, cast from pydantic import Field -from controlflow.tools import Tool from controlflow.utilities.context import ctx from
controlflow.utilities.types import ControlFlowModel +if TYPE_CHECKING: + from controlflow.tools import Tool + class Memory(ControlFlowModel, abc.ABC): id: str = Field(default_factory=lambda: uuid.uuid4().hex) @@ -27,7 +29,9 @@ def update(self, value: str, index: int = None): def delete(self, index: int): raise NotImplementedError() - def get_tools(self) -> list[Tool]: + def get_tools(self) -> list["Tool"]: + from controlflow.tools import Tool + update_tool = Tool.from_function( self.update, name="update_memory", diff --git a/src/controlflow/controllers/controller.py b/src/controlflow/controllers/controller.py deleted file mode 100644 index cc49313c..00000000 --- a/src/controlflow/controllers/controller.py +++ /dev/null @@ -1,344 +0,0 @@ -import inspect -import logging -import math -from collections import defaultdict -from contextlib import asynccontextmanager -from typing import Callable - -from pydantic import Field, PrivateAttr, field_validator, model_validator - -import controlflow -from controlflow.agents import Agent -from controlflow.controllers.graph import Graph -from controlflow.controllers.process_messages import prepare_messages -from controlflow.flows import Flow, get_flow -from controlflow.handlers.print_handler import PrintHandler -from controlflow.instructions import get_instructions -from controlflow.llm.completions import completion, completion_async -from controlflow.llm.handlers import ResponseHandler, TUIHandler -from controlflow.llm.messages import MessageType, SystemMessage -from controlflow.tasks.task import Task -from controlflow.tools import as_tools -from controlflow.utilities.context import ctx -from controlflow.utilities.prefect import create_markdown_artifact -from controlflow.utilities.prefect import prefect_task as prefect_task -from controlflow.utilities.types import ControlFlowModel - -logger = logging.getLogger(__name__) - - -def create_messages_markdown_artifact(messages, thread_id): - markdown_messages = "\n\n".join([f"{msg.role}: {msg.content}" for msg in messages]) - create_markdown_artifact( - key="messages", - markdown=inspect.cleandoc( - """ - # Messages - - *Thread ID: {thread_id}* - - {markdown_messages} - """.format( - thread_id=thread_id, - markdown_messages=markdown_messages, - ) - ), - ) - - -class Controller(ControlFlowModel): - """ - A controller contains logic for executing agents with context about the - larger workflow, including the flow itself, any tasks, and any other agents - they are collaborating with. The controller is responsible for orchestrating - agent behavior by generating instructions and tools for each agent. Note - that while the controller accepts details about (potentially multiple) - agents and tasks, it's responsiblity is to invoke one agent one time. Other - mechanisms should be used to orchestrate multiple agents invocations. This - is done by the controller to avoid tying e.g. agents to tasks or even a - specific flow. - """ - - # the flow is tracked by the Controller, not the Task, so that tasks can be - # defined and even instantiated outside a flow. When a Controller is - # created, we know we're inside a flow context and ready to load defaults - # and run. - flow: Flow = Field( - default_factory=get_flow, - description="The flow that the controller is a part of.", - validate_default=True, - ) - tasks: list[Task] = Field( - description="Tasks that the controller will complete.", - ) - agents: dict[Task, list[Agent]] = Field( - default_factory=dict, - description="Optionally assign agents to complete tasks. 
The provided mapping must be task" - " -> [agents]. Any tasks that aren't included will use their default agents.", - ) - context: dict = {} - model_config: dict = dict(extra="forbid") - enable_experimental_tui: bool = Field( - default_factory=lambda: controlflow.settings.enable_experimental_tui - ) - max_iterations: int = Field( - default_factory=lambda: controlflow.settings.max_iterations - ) - _iteration: int = 0 - _should_stop: bool = False - _end_turn_counts: dict = PrivateAttr(default_factory=lambda: defaultdict(int)) - - @property - def graph(self) -> Graph: - return Graph.from_tasks(self.flow.tasks.values()) - - @field_validator("agents", mode="before") - def _default_agents(cls, v): - if v is None: - v = {} - return v - - @model_validator(mode="after") - def _finalize(self): - for task in self.tasks: - self.flow.add_task(task) - return self - - def _create_end_turn_tool(self) -> Callable: - def end_turn(): - """ - This tool is for emergencies only; you should not use it normally. - If you find yourself in a situation where you are repeatedly invoked - and your normal tools do not work, or you can not escape the loop, - use this tool to signal to the controller that you are stuck. A new - agent will be selected to go next. If this tool is used 3 times by - an agent the workflow will be aborted automatically. - - """ - - # the agent's name is used as the key to track the number of times - key = getattr(ctx.get("agent", None), "name", None) - - self._end_turn_counts[key] += 1 - if self._end_turn_counts[key] >= 3: - self._should_stop = True - self._end_turn_counts[key] = 0 - - return ( - f"Ending turn. {3 - self._end_turn_counts[key]}" - " more uses will abort the workflow." - ) - - return end_turn - - @asynccontextmanager - async def tui(self): - if tui := ctx.get("tui"): - yield tui - elif self.enable_experimental_tui: - from controlflow.tui.app import TUIApp as TUI - - tui = TUI(flow=self.flow) - with ctx(tui=tui): - async with tui.run_context(): - yield tui - else: - yield - - def _setup_run(self): - """ - Generate the payload for a single run of the controller. - """ - if self._iteration >= (self.max_iterations or math.inf): - raise ValueError( - f"Controller has exceeded maximum iterations of {self.max_iterations}." - ) - ready_tasks = [t for t in self.tasks if t.is_ready()] - - # if there are no ready tasks, return. This will usually happen because - # all the tasks are complete. 
- if not ready_tasks: - return - - # start tracking tasks - for task in ready_tasks: - if not task._prefect_task.is_started: - task._prefect_task.start( - depends_on=[ - t.result for t in task.depends_on if t.result is not None - ] - ) - - messages = self.flow.get_messages() - - # get an agent from the next ready task - agents = self.agents.get(ready_tasks[0], None) - if agents is None: - agents = ready_tasks[0].get_agents() - if len(agents) == 1: - agent = agents[0] - else: - strategy_fn = ready_tasks[0].get_agent_strategy() - agent = strategy_fn(agents=agents, task=ready_tasks[0], flow=self.flow) - ready_tasks[0]._iteration += 1 - - from controlflow.controllers.instruction_template import MainTemplate - - tools = self.flow.tools + agent.get_tools() + [self._create_end_turn_tool()] - - # add tools for any ready tasks that the agent is assigned to - for task in ready_tasks: - if agent in self.agents.get(task, []) or agent in task.get_agents(): - tools.extend(task.get_tools()) - - instructions_template = MainTemplate( - agent=agent, - controller=self, - ready_tasks=ready_tasks, - current_task=ready_tasks[0], - context=self.context, - instructions=get_instructions(), - agent_assignments=self.agents, - ) - instructions = instructions_template.render() - - # prepare messages - system_message = SystemMessage(content=instructions) - - rules = agent.get_llm_rules() - - messages = prepare_messages( - agent=agent, - system_message=system_message, - messages=messages, - rules=rules, - tools=tools, - ) - - # setup handlers - handlers = [] - if self.enable_experimental_tui: - handlers.append(TUIHandler()) - elif controlflow.settings.enable_print_handler: - handlers.append(PrintHandler()) - # yield the agent payload - return dict( - agent=agent, - messages=messages, - tools=as_tools(tools), - handlers=handlers, - ) - - @prefect_task(task_run_name="Run LLM") - async def run_once_async(self) -> list[MessageType]: - async with self.tui(): - payload = self._setup_run() - if payload is None: - return - agent: Agent = payload.pop("agent") - response_handler = ResponseHandler() - payload["handlers"].append(response_handler) - - with ctx(agent=agent, flow=self.flow, controller=self): - response_gen = await completion_async( - messages=payload["messages"], - model=agent.get_model(), - tools=payload["tools"], - handlers=payload["handlers"], - max_iterations=1, - stream=True, - agent=agent, - ) - async for _ in response_gen: - pass - - # save history - self.flow.add_messages( - messages=response_handler.response_messages, - ) - self._iteration += 1 - - create_messages_markdown_artifact( - messages=response_handler.response_messages, - thread_id=self.flow.thread_id, - ) - - return response_handler.response_messages - - @prefect_task(task_run_name="Run LLM") - def run_once(self) -> list[MessageType]: - payload = self._setup_run() - if payload is None: - return - agent: Agent = payload.pop("agent") - response_handler = ResponseHandler() - payload["handlers"].append(response_handler) - - with ctx( - agent=agent, - flow=self.flow, - controller=self, - ): - response_gen = completion( - messages=payload["messages"], - model=agent.get_model(), - tools=payload["tools"], - handlers=payload["handlers"], - max_iterations=1, - stream=True, - agent=agent, - ) - for _ in response_gen: - pass - - # save history - self.flow.add_messages( - messages=response_handler.response_messages, - ) - self._iteration += 1 - - create_messages_markdown_artifact( - messages=response_handler.response_messages, - thread_id=self.flow.thread_id, - 
) - - return response_handler.response_messages - - @prefect_task(task_run_name="Run LLM Controller") - async def run_async(self) -> list[MessageType]: - """ - Run the controller until all tasks are complete. - """ - if all(t.is_complete() for t in self.tasks): - return - - messages = [] - async with self.tui(): - # enter a flow context - with self.flow: - while ( - any(t.is_incomplete() for t in self.tasks) and not self._should_stop - ): - new_messages = await self.run_once_async() - if new_messages: - messages.extend(new_messages) - self._should_stop = False - return messages - - @prefect_task(task_run_name="Run LLM Controller") - def run(self) -> list[MessageType]: - """ - Run the controller until all tasks are complete. - """ - if all(t.is_complete() for t in self.tasks): - return - - messages = [] - # enter a flow context - with self.flow: - while any(t.is_incomplete() for t in self.tasks) and not self._should_stop: - new_messages = self.run_once() - if new_messages: - messages.extend(new_messages) - self._should_stop = False - return messages diff --git a/src/controlflow/controllers/instruction_template.py b/src/controlflow/controllers/instruction_template.py deleted file mode 100644 index 73a0ccfd..00000000 --- a/src/controlflow/controllers/instruction_template.py +++ /dev/null @@ -1,272 +0,0 @@ -import inspect - -from controlflow.agents import Agent -from controlflow.flows import Flow -from controlflow.tasks.task import Task -from controlflow.utilities.jinja import jinja_env -from controlflow.utilities.types import ControlFlowModel - -from .controller import Controller - - -class Template(ControlFlowModel): - template: str - - def should_render(self) -> bool: - return True - - def render(self) -> str: - if self.should_render(): - render_kwargs = dict(self) - render_kwargs.pop("template") - return jinja_env.from_string(inspect.cleandoc(self.template)).render( - **render_kwargs - ) - - -class AgentTemplate(Template): - template: str = """ - You are an AI agent participating in a workflow. Your role is to work on - your tasks and use the provided tools to complete those tasks and - communicate with the orchestrator. - - Important: The orchestrator is a Python script and cannot read or - respond to messages posted in this thread. You must use the provided - tools to communicate with the orchestrator. Posting messages in this - thread should only be used for thinking out loud, working through a - problem, or communicating with other agents. Any System messages or - messages prefixed with "SYSTEM:" are from the workflow system, not an - actual human. - - Your job is to: - 1. Select one or more tasks to work on from the ready tasks. - 2. Read the task instructions and work on completing the task objective, which may - involve using appropriate tools or collaborating with other agents - assigned to the same task. - 3. When you (and any other agents) have completed the task objective, - use the provided tool to inform the orchestrator of the task completion - and result. - 4. Repeat steps 1-3 until no more tasks are available for execution. - - Note that the orchestrator may decide to activate a different agent at any time. - - ## Your information - - - ID: {{ agent.id }} - - Name: "{{ agent.name }}" - {% if agent.description -%} - - Description: "{{ agent.description }}" - {% endif %} - - ## Instructions - - You must follow instructions at all times. Instructions can be added or removed at any time. 
- - - Never impersonate another agent - - {% if agent.instructions %} - {{ agent.instructions }} - {% endif %} - - {% if additional_instructions %} - {% for instruction in additional_instructions %} - - {{ instruction }} - {% endfor %} - {% endif %} - """ - agent: Agent - additional_instructions: list[str] - - -# class MemoryTemplate(Template): -# template: str = """ -# ## Memory - -# You have the following private memories: - -# {% for index, memory in memories %} -# - {{ index }}: {{ memory }} -# {% endfor %} - -# Use your memory to record information -# """ -# memories: dict[int, str] - - -class WorkflowTemplate(Template): - template: str = """ - - ## Tasks - - As soon as you have completed a task's objective, you must use the provided - tool to mark it successful and provide a result. It may take multiple - turns or collaboration with other agents to complete a task. Any agent - assigned to a task can complete it. Once a task is complete, no other - agent can interact with it. - - Tasks should only be marked failed due to technical errors like a broken - or erroring tool or unresponsive human. - - Tasks are not ready until all of their dependencies are met. Parent - tasks depend on all of their subtasks. - - ## Flow - - Name: {{ flow.name }} - {% if flow.description %} - Description: {{ flow.description }} - {% endif %} - {% if flow.context %} - Context: - {% for key, value in flow.context.items() %} - - {{ key }}: {{ value }} - {% endfor %} - {% endif %} - - ## Tasks - - ### Ready tasks - - These tasks are ready to be worked on because all of their dependencies have - been completed. You can only work on tasks to which you are assigned. - - {% for task in ready_tasks %} - #### Task {{ task.id }} - - objective: {{ task.objective }} - - instructions: {{ task.instructions}} - - context: {{ task.context }} - - result_type: {{ task.result_type }} - - depends_on: {{ task.depends_on }} - - parent: {{ task.parent }} - - assigned agents: {{ task.agents }} - {% if task.user_access %} - - user access: True - {% endif %} - - created_at: {{ task.created_at }} - - {% endfor %} - - ### Upstream tasks - - {% for task in upstream_tasks %} - #### Task {{ task.id }} - - objective: {{ task.objective }} - - instructions: {{ task.instructions}} - - status: {{ task.status }} - - result: {{ task.result }} - - error: {{ task.error }} - - context: {{ task.context }} - - depends_on: {{ task.depends_on }} - - parent: {{ task.parent }} - - assigned agents: {{ task.agents }} - {% if task.user_access %} - - user access: True - {% endif %} - - created_at: {{ task.created_at }} - - {% endfor %} - - ### Downstream tasks - - {% for task in downstream_tasks %} - #### Task {{ task.id }} - - objective: {{ task.objective }} - - instructions: {{ task.instructions}} - - status: {{ task.status }} - - result_type: {{ task.result_type }} - - context: {{ task.context }} - - depends_on: {{ task.depends_on }} - - parent: {{ task.parent }} - - assigned agents: {{ task.agents }} - {% if task.user_access %} - - user access: True - {% endif %} - - created_at: {{ task.created_at }} - - {% endfor %} - """ - - ready_tasks: list[dict] - upstream_tasks: list[dict] - downstream_tasks: list[dict] - current_task: Task - flow: Flow - - -class ToolTemplate(Template): - template: str = """ - You have access to various tools. They may change, so do not rely on history - to see what tools are available. 
- - ## Talking to human users - - If your task requires you to interact with a user, it will show - `user_access=True` and you will be given a `talk_to_user` tool. You can - use it to send messages to the user and optionally wait for a response. - This is how you tell the user things and ask questions. Do not mention - your tasks or the workflow. The user can only see messages you send - them via tool. They can not read the rest of the - thread. - - Human users may give poor, incorrect, or partial responses. You may need - to ask questions multiple times in order to complete your tasks. Do not - make up answers for omitted information; ask again and only fail the - task if you truly can not make progress. If your task requires human - interaction and neither it nor any assigned agents have `user_access`, - you can fail the task. - """ - - agent: Agent - - -class MainTemplate(ControlFlowModel): - agent: Agent - controller: Controller - ready_tasks: list[Task] - current_task: Task - context: dict - instructions: list[str] - agent_assignments: dict[Task, list[Agent]] - - def render(self): - # get up to 50 upstream and 50 downstream tasks - g = self.controller.graph - upstream_tasks = g.topological_sort([t for t in g.tasks if t.is_complete()])[ - -50: - ] - downstream_tasks = g.topological_sort( - [t for t in g.tasks if t.is_incomplete() and t not in self.ready_tasks] - )[:50] - - ready_tasks = [t.model_dump() for t in self.ready_tasks] - upstream_tasks = [t.model_dump() for t in upstream_tasks] - downstream_tasks = [t.model_dump() for t in downstream_tasks] - - # update agent assignments - assignments = {t.id: a for t, a in self.agent_assignments.items()} - for t in ready_tasks + upstream_tasks + downstream_tasks: - if t["id"] in assignments: - t["agents"] = assignments[t["id"]] - - templates = [ - AgentTemplate( - agent=self.agent, - additional_instructions=self.instructions, - ), - WorkflowTemplate( - flow=self.controller.flow, - ready_tasks=ready_tasks, - upstream_tasks=upstream_tasks, - downstream_tasks=downstream_tasks, - current_task=self.current_task, - ), - ToolTemplate(agent=self.agent), - # CommunicationTemplate( - # agent=self.agent, - # ), - ] - - rendered = [ - template.render() for template in templates if template.should_render() - ] - return "\n\n".join(rendered) diff --git a/src/controlflow/controllers/process_messages.py b/src/controlflow/controllers/process_messages.py deleted file mode 100644 index 83e42844..00000000 --- a/src/controlflow/controllers/process_messages.py +++ /dev/null @@ -1,127 +0,0 @@ -from typing import Optional, Union - -from controlflow.agents.agent import Agent -from controlflow.llm.messages import ( - AIMessage, - MessageType, - SystemMessage, - UserMessage, -) -from controlflow.llm.rules import LLMRules -from controlflow.tools import Tool - - -def create_system_message( - content: str, rules: LLMRules -) -> Union[SystemMessage, UserMessage]: - """ - Creates a SystemMessage or HumanMessage with SYSTEM: prefix, depending on the rules. - """ - if rules.system_message_must_be_first: - return SystemMessage(content=content) - else: - return UserMessage(content=f"SYSTEM: {content}") - - -def handle_agent_info_in_messages( - messages: list[MessageType], agent: Agent, rules: LLMRules -) -> list[MessageType]: - """ - If the message is from an agent, add a system message immediately before it to clarify which agent - it is from. This helps the system follow multi-agent conversations. 
- - """ - if not rules.add_system_messages_for_multi_agent: - return messages - - current_agent = agent - new_messages = [] - for msg in messages: - # if the message is from a different agent than the previous message, - # add a clarifying system message - if isinstance(msg, AIMessage) and msg.agent and msg.agent != current_agent: - system_msg = SystemMessage( - content=f'The following message is from agent "{msg.agent["name"]}" ' - f'with id {msg.agent["id"]}.' - ) - new_messages.append(system_msg) - current_agent = msg.agent - new_messages.append(msg) - return new_messages - - -def handle_system_messages_must_be_first(messages: list[MessageType], rules: LLMRules): - if rules.system_message_must_be_first: - new_messages = [] - # consolidate consecutive SystemMessages into one - if isinstance(messages[0], SystemMessage): - content = [messages[0].content] - i = 1 - while i < len(messages) and isinstance(messages[i], SystemMessage): - i += 1 - content.append(messages[i].content) - new_messages.append(SystemMessage(content="\n\n".join(content))) - - # replace all other SystemMessages with HumanMessages - for i, msg in enumerate(messages[len(new_messages) :]): - if isinstance(msg, SystemMessage): - msg = UserMessage(content=f"SYSTEM: {msg.content}") - new_messages.append(msg) - - return new_messages - else: - return messages - - -def handle_user_message_must_be_first_after_system( - messages: list[MessageType], rules: LLMRules -): - if rules.user_message_must_be_first_after_system: - if not messages: - messages.append(UserMessage(content="SYSTEM: Begin.")) - - # else get first non-system message - else: - i = 0 - while i < len(messages) and isinstance(messages[i], SystemMessage): - i += 1 - if i == len(messages) or ( - i < len(messages) and not isinstance(messages[i], UserMessage) - ): - messages.insert(i, UserMessage(content="SYSTEM: Begin.")) - return messages - - -def prepare_messages( - agent: Agent, - messages: list[MessageType], - system_message: Optional[SystemMessage], - rules: LLMRules, - tools: list[Tool], -): - """This is the main function for processing messages. 
It applies all the rules""" - messages = messages.copy() - - if system_message is not None: - messages.insert(0, system_message) - - messages = handle_agent_info_in_messages(messages, agent=agent, rules=rules) - - if not rules.allow_last_message_has_ai_role_with_tools: - if messages and tools and isinstance(messages[-1], AIMessage): - messages.append(create_system_message("Continue.", rules=rules)) - - if not rules.allow_consecutive_ai_messages: - if messages: - i = 1 - while i < len(messages): - if isinstance(messages[i], AIMessage) and isinstance( - messages[i - 1], AIMessage - ): - messages.insert(i, create_system_message("Continue.", rules=rules)) - i += 1 - - messages = handle_system_messages_must_be_first(messages, rules=rules) - messages = handle_user_message_must_be_first_after_system(messages, rules=rules) - - return messages diff --git a/src/controlflow/decorators.py b/src/controlflow/decorators.py index c02fc7bd..e83503d8 100644 --- a/src/controlflow/decorators.py +++ b/src/controlflow/decorators.py @@ -7,8 +7,7 @@ from controlflow.flows import Flow from controlflow.tasks.task import Task from controlflow.utilities.logging import get_logger -from controlflow.utilities.prefect import prefect_flow as prefect_flow -from controlflow.utilities.prefect import prefect_task as prefect_task +from controlflow.utilities.prefect import prefect_flow, prefect_task # from controlflow.utilities.marvin import patch_marvin from controlflow.utilities.tasks import resolve_tasks diff --git a/src/controlflow/events/__init__.py b/src/controlflow/events/__init__.py new file mode 100644 index 00000000..23868b4f --- /dev/null +++ b/src/controlflow/events/__init__.py @@ -0,0 +1 @@ +from .events import Event diff --git a/src/controlflow/events/agent_events.py b/src/controlflow/events/agent_events.py new file mode 100644 index 00000000..e0cd3b6a --- /dev/null +++ b/src/controlflow/events/agent_events.py @@ -0,0 +1,120 @@ +from typing import Literal, Optional + +from pydantic import field_validator, model_validator + +from controlflow.agents.agent import Agent +from controlflow.events.events import Event, UnpersistedEvent +from controlflow.events.message_compiler import EventContext +from controlflow.llm.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, + SystemMessage, +) +from controlflow.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class SelectAgentEvent(Event): + event: Literal["select-agent"] = "select-agent" + agent: Agent + + # def to_messages(self, context: EventContext) -> list[BaseMessage]: + # return [ + # SystemMessage( + # content=f'Agent "{self.agent.name}" with ID {self.agent.id} was selected.'
+ # ) + # ] + + +class SystemMessageEvent(Event): + event: Literal["system-message"] = "system-message" + content: str + + def to_messages(self, context: EventContext) -> list[BaseMessage]: + return [SystemMessage(content=self.content)] + + +class UserMessageEvent(Event): + event: Literal["user-message"] = "user-message" + content: str + + def to_messages(self, context: EventContext) -> list[BaseMessage]: + return [HumanMessage(content=self.content)] + + +class AgentMessageEvent(Event): + event: Literal["agent-message"] = "agent-message" + agent: Agent + message: dict + + @field_validator("message", mode="before") + def _message(cls, v): + if isinstance(v, BaseMessage): + v = v.dict() + v["type"] = "ai" + return v + + @model_validator(mode="after") + def _finalize(self): + self.message["name"] = self.agent.name + return self + + @property + def ai_message(self) -> AIMessage: + return AIMessage(**self.message) + + def to_messages(self, context: EventContext) -> list[BaseMessage]: + if self.agent.name == context.agent.name: + return [self.ai_message] + elif self.message["content"]: + return [ + SystemMessage( + content=f'The following message was posted by Agent "{self.agent.name}" ' + f"with ID {self.agent.id}:", + ), + HumanMessage( + # ensure this is stringified to avoid issues with inline tool calls + content=str(self.message["content"]), + name=self.agent.name, + ), + ] + else: + return [] + + +class AgentMessageDeltaEvent(UnpersistedEvent): + event: Literal["agent-message-delta"] = "agent-message-delta" + + agent: Agent + delta: dict + snapshot: dict + + @field_validator("delta", "snapshot", mode="before") + def _message(cls, v): + if isinstance(v, BaseMessage): + v = v.dict() + v["type"] = "AIMessageChunk" + return v + + @model_validator(mode="after") + def _finalize(self): + self.delta["name"] = self.agent.name + self.snapshot["name"] = self.agent.name + return self + + @property + def delta_message(self) -> AIMessageChunk: + return AIMessageChunk(**self.delta) + + @property + def snapshot_message(self) -> AIMessage: + return AIMessage(**self.snapshot | {"type": "ai"}) + + +class EndTurnEvent(Event): + event: Literal["end-turn"] = "end-turn" + agent: Agent + next_agent_name: Optional[str] = None diff --git a/src/controlflow/events/controller_events.py b/src/controlflow/events/controller_events.py new file mode 100644 index 00000000..b69fc21d --- /dev/null +++ b/src/controlflow/events/controller_events.py @@ -0,0 +1,26 @@ +from typing import Literal + +from controlflow.events.events import UnpersistedEvent +from controlflow.orchestration.controller import Controller +from controlflow.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class ControllerStart(UnpersistedEvent): + event: Literal["controller-start"] = "controller-start" + persist: bool = False + controller: Controller + + +class ControllerEnd(UnpersistedEvent): + event: Literal["controller-end"] = "controller-end" + persist: bool = False + controller: Controller + + +class ControllerError(UnpersistedEvent): + event: Literal["controller-error"] = "controller-error" + persist: bool = False + controller: Controller + error: Exception diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py new file mode 100644 index 00000000..e839b73c --- /dev/null +++ b/src/controlflow/events/events.py @@ -0,0 +1,36 @@ +import datetime +import uuid +from typing import TYPE_CHECKING, Optional + +from pydantic import Field + +from controlflow.utilities.types import ControlFlowModel + +if TYPE_CHECKING: + from
controlflow.events.message_compiler import EventContext + from controlflow.llm.messages import BaseMessage + +# This is a global variable that will be shared between all instances of InMemoryHistory +IN_MEMORY_STORE = {} + + +class Event(ControlFlowModel): + model_config: dict = dict(extra="allow") + + event: str + id: str = Field(default_factory=lambda: uuid.uuid4().hex) + thread_id: Optional[str] = None + agent_ids: list[str] = [] + task_ids: list[str] = [] + timestamp: datetime.datetime = Field( + default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) + ) + persist: bool = True + + def to_messages(self, context: "EventContext") -> list["BaseMessage"]: + return [] + + +class UnpersistedEvent(Event): + model_config = dict(arbitrary_types_allowed=True) + persist: bool = False diff --git a/src/controlflow/events/history.py b/src/controlflow/events/history.py new file mode 100644 index 00000000..f44cf615 --- /dev/null +++ b/src/controlflow/events/history.py @@ -0,0 +1,258 @@ +import abc +import json +import math +from functools import cache +from pathlib import Path +from typing import TYPE_CHECKING, Optional, Union + +from pydantic import Field, TypeAdapter, field_validator + +import controlflow +from controlflow.events.agent_events import ( + AgentMessageEvent, + EndTurnEvent, + SelectAgentEvent, + SystemMessageEvent, + UserMessageEvent, +) +from controlflow.events.events import Event +from controlflow.events.task_events import TaskCompleteEvent, TaskReadyEvent +from controlflow.events.tool_events import ToolResultEvent +from controlflow.utilities.types import ControlFlowModel + +if TYPE_CHECKING: + pass + +# This is a global variable that will be shared between all instances of InMemoryHistory +IN_MEMORY_STORE = {} + + +def get_default_history() -> "History": + return controlflow.default_history + + +@cache
def get_event_validator() -> TypeAdapter: + types = Union[ + TaskReadyEvent, + TaskCompleteEvent, + SelectAgentEvent, + SystemMessageEvent, + UserMessageEvent, + AgentMessageEvent, + EndTurnEvent, + ToolResultEvent, + Event, + ] + return TypeAdapter(list[types]) + + +def filter_events( + events: list[Event], + agent_ids: Optional[list[str]] = None, + task_ids: Optional[list[str]] = None, + types: Optional[list[str]] = None, + before_id: Optional[str] = None, + after_id: Optional[str] = None, + limit: Optional[int] = None, +): + """ + Filters a list of events based on the specified criteria. + + Args: + events (list[Event]): The list of events to filter. + agent_ids (Optional[list[str]]): The agent ids to filter by. Defaults to None. + task_ids (Optional[list[str]]): The task ids to filter by. Defaults to None. + types (Optional[list[str]]): The event types to filter by. Defaults to None. + before_id (Optional[str]): Only include events at or before this event ID. Defaults to None. + after_id (Optional[str]): Only include events after this event ID. Defaults to None. + limit (Optional[int]): The maximum number of events to include. Defaults to None. + + Returns: + list[Event]: The filtered list of events.
+ """ + new_events = [] + seen_before_id = before_id is None + seen_after_id = False + + for event in reversed(events): + if event.id == before_id: + seen_before_id = True + if event.id == after_id: + seen_after_id = True + + # if we haven't reached the `before_id` we can skip this event + if not seen_before_id: + continue + + # if we've reached the `after_id` we can stop searching + if seen_after_id: + break + + # if types are specified and this event is not one of them, skip it + if types and event.event not in types: + continue + + # if agent_ids are specified and this event has agent_ids and none of them are in the list, skip it + agent_match = ( + ( + agent_ids + and event.agent_ids + and any(a in event.agent_ids for a in agent_ids) + ) + or not agent_ids + or not event.agent_ids + ) + + # if task_ids are specified and this event has task_ids and none of them are in the list, skip it + task_match = ( + (task_ids and event.task_ids and any(t in event.task_ids for t in task_ids)) + or not task_ids + or not event.task_ids + ) + + # if neither agent_ids nor task_ids were matched + if not (agent_match or task_match): + continue + + new_events.append(event) + + if len(new_events) >= (limit or math.inf): + break + + return list(reversed(new_events)) + + +class History(ControlFlowModel, abc.ABC): + @abc.abstractmethod + def get_events( + self, + thread_id: str, + types: Optional[list[str]] = None, + agent_ids: Optional[list[str]] = None, + task_ids: Optional[list[str]] = None, + before_id: Optional[str] = None, + after_id: Optional[str] = None, + limit: Optional[int] = None, + ) -> list[Event]: + raise NotImplementedError() + + @abc.abstractmethod + def add_events(self, thread_id: str, events: list[Event]): + raise NotImplementedError() + + +class InMemoryHistory(History): + history: dict[str, list[Event]] = Field( + default_factory=lambda: IN_MEMORY_STORE, repr=False + ) + + def add_events(self, thread_id: str, events: list[Event]): + self.history.setdefault(thread_id, []).extend(events) + + def get_events( + self, + thread_id: str, + types: Optional[list[str]] = None, + agent_ids: Optional[list[str]] = None, + task_ids: Optional[list[str]] = None, + before_id: Optional[str] = None, + after_id: Optional[str] = None, + limit: Optional[int] = None, + ) -> list[Event]: + """ + Retrieve a list of events based on the specified criteria. + + Args: + thread_id (str): The ID of the thread to retrieve events from. + agent_ids (Optional[list[str]]): The agent IDs to filter by (default: None). + task_ids (Optional[list[str]]): The task IDs to filter by (default: None). + types (Optional[list[str]]): The list of event types to filter by (default: None). + before_id (Optional[str]): Only retrieve events at or before this event ID (default: None). + after_id (Optional[str]): Only retrieve events after this event ID (default: None). + limit (Optional[int]): The maximum number of events to retrieve (default: None). + + Returns: + list[Event]: A list of events that match the specified criteria.
+ + """ + events = self.history.get(thread_id, []) + return filter_events( + events=events, + agent_ids=agent_ids, + task_ids=task_ids, + types=types, + before_id=before_id, + after_id=after_id, + limit=limit, + ) + + +class FileHistory(History): + base_path: Path = Field( + default_factory=lambda: controlflow.settings.home_path / "filestore_events" + ) + + def path(self, thread_id: str) -> Path: + return self.base_path / f"{thread_id}.json" + + @field_validator("base_path", mode="before") + def _validate_path(cls, v): + v = Path(v).expanduser() + if not v.exists(): + v.mkdir(parents=True, exist_ok=True) + return v + + def get_events( + self, + thread_id: str, + agent_ids: Optional[list[str]] = None, + task_ids: Optional[list[str]] = None, + types: Optional[list[str]] = None, + before_id: Optional[str] = None, + after_id: Optional[str] = None, + limit: Optional[int] = None, + ) -> list[Event]: + """ + Retrieves a list of events based on the specified criteria. + + Args: + thread_id (str): The ID of the thread to retrieve events from. + agent_ids (Optional[list[str]]): The agent IDs to filter by (default: None). + task_ids (Optional[list[str]]): The task IDs to filter by (default: None). + types (Optional[list[str]]): The list of event types to filter by (default: None). + before_id (Optional[str]): Only retrieve events at or before this event ID (default: None). + after_id (Optional[str]): Only retrieve events after this event ID (default: None). + limit (Optional[int]): The maximum number of events to retrieve (default: None). + + Returns: + list[Event]: A list of events that match the specified criteria. + """ + if not self.path(thread_id).exists(): + return [] + + with open(self.path(thread_id), "r") as f: + raw_data = f.read() + + validator = get_event_validator() + events = validator.validate_json(raw_data) + + return filter_events( + events=events, + agent_ids=agent_ids, + task_ids=task_ids, + types=types, + before_id=before_id, + after_id=after_id, + limit=limit, + ) + + def add_events(self, thread_id: str, events: list[Event]): + if self.path(thread_id).exists(): + with open(self.path(thread_id), "r") as f: + all_events = json.load(f) + else: + all_events = [] + all_events.extend([event.model_dump(mode="json") for event in events]) + with open(self.path(thread_id), "w") as f: + json.dump(all_events, f) diff --git a/src/controlflow/events/message_compiler.py b/src/controlflow/events/message_compiler.py new file mode 100644 index 00000000..c6988ff8 --- /dev/null +++ b/src/controlflow/events/message_compiler.py @@ -0,0 +1,222 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import tiktoken + +import controlflow +from controlflow.events.events import Event +from controlflow.llm.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from controlflow.llm.rules import LLMRules +from controlflow.utilities.logging import get_logger + +if TYPE_CHECKING: + from controlflow.agents.agent import Agent + from controlflow.flows.flow import Flow + from controlflow.orchestration.controller import Controller + from controlflow.tasks.task import Task +logger = get_logger(__name__) + + +def add_user_message_to_beginning( + messages: list[BaseMessage], rules: LLMRules +) -> list[BaseMessage]: + """ + If the LLM requires a user message to follow the system message, add a user + message to the beginning of the list.
+ """ + if rules.require_user_message_after_system: + if not messages or not isinstance(messages[0], HumanMessage): + messages.insert(0, HumanMessage(content="SYSTEM: Begin.")) + return messages + + +def ensure_at_least_one_message( + messages: list[BaseMessage], rules: LLMRules +) -> list[BaseMessage]: + if not messages and rules.require_at_least_one_message: + messages.append(HumanMessage(content="SYSTEM: Begin.")) + return messages + + +def add_user_message_to_end( + messages: list[BaseMessage], rules: LLMRules +) -> list[BaseMessage]: + """ + If the LLM doesn't allow the last message to be from the AI when using tools, + add a user message to the end of the list. + """ + if not rules.allow_last_message_from_ai_when_using_tools: + if not messages or isinstance(messages[-1], AIMessage): + msg = HumanMessage(content="SYSTEM: Continue.") + messages.append(msg) + return messages + + +def remove_duplicate_messages(messages: list[BaseMessage]) -> list[BaseMessage]: + """ + Removes duplicate messages from the list. + """ + seen = set() + new_messages = [] + for message in messages: + if message.id not in seen: + new_messages.append(message) + if message.id: + seen.add(message.id) + return new_messages + + +def break_up_consecutive_ai_messages( + messages: list[BaseMessage], rules: LLMRules +) -> list[BaseMessage]: + """ + Breaks up consecutive AI messages by inserting a system message. + """ + if not messages or rules.allow_consecutive_ai_messages: + return messages + + new_messages = messages.copy() + i = 1 + while i < len(new_messages): + if isinstance(new_messages[i], AIMessage) and isinstance( + new_messages[i - 1], AIMessage + ): + new_messages.insert(i, SystemMessage(content="Continue.")) + i += 1 + + return new_messages + + +def convert_system_messages( + messages: list[BaseMessage], rules: LLMRules +) -> list[BaseMessage]: + """ + Converts system messages to human messages for LLMs that only accept a system message at the start of the conversation.
+ """ + if not messages or not rules.require_system_message_first: + return messages + + new_messages = [] + for message in messages: + if isinstance(message, SystemMessage): + new_messages.append(HumanMessage(content=f"SYSTEM: {message.content}")) + else: + new_messages.append(message) + return new_messages + + +def organize_tool_result_messages( + messages: list[BaseMessage], rules: LLMRules +) -> list[BaseMessage]: + if not messages or not rules.tool_result_must_follow_tool_call: + return messages + + tool_calls = {} + new_messages = [] + i = 0 + while i < len(messages): + message = messages[i] + # save the message index of any tool calls + if isinstance(message, AIMessage): + for tool_call in message.tool_calls + message.invalid_tool_calls: + tool_calls[tool_call["id"]] = i + new_messages.append(message) + + # move tool messages to follow their corresponding tool calls + elif isinstance(message, ToolMessage) and message.tool_call_id in tool_calls: + tool_call_index = tool_calls[message.tool_call_id] + new_messages.insert(tool_call_index + 1, message) + tool_calls[message.tool_call_id] += 1 + + else: + new_messages.append(message) + i += 1 + return new_messages + + +def count_tokens(message: BaseMessage) -> int: + # always use gpt-3.5 token counter with the entire message object; we only need to be approximate here + return len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(message.json())) + + +def trim_messages( + messages: list[BaseMessage], max_tokens: Optional[int] +) -> list[BaseMessage]: + """ + Trims messages to a maximum number of tokens while keeping the system message at the front. + """ + + if not messages or max_tokens is None: + return messages + + new_messages = [] + budget = max_tokens + + for message in reversed(messages): + if count_tokens(message) > budget: + break + new_messages.append(message) + budget -= count_tokens(message) + + return list(reversed(new_messages)) + + +@dataclass +class EventContext: + llm_rules: LLMRules + agent: Optional["Agent"] + ready_tasks: list["Task"] + flow: Optional["Flow"] + controller: Optional["Controller"] + + +class MessageCompiler: + def __init__( + self, + events: list[Event], + context: EventContext, + system_prompt: str = None, + max_tokens: int = None, + ): + self.events = events + self.context = context + self.system_prompt = system_prompt + self.max_tokens = max_tokens or controlflow.settings.max_input_tokens + + def compile_to_messages(self) -> list[BaseMessage]: + if self.system_prompt: + system = [SystemMessage(content=self.system_prompt)] + max_tokens = self.max_tokens - count_tokens(system[0]) + else: + system = [] + max_tokens = self.max_tokens + + messages = [] + + for event in self.events: + messages.extend(event.to_messages(self.context)) + + # process messages + msgs = messages.copy() + + # trim messages + msgs = trim_messages(msgs, max_tokens=max_tokens) + + # apply LLM rules + msgs = ensure_at_least_one_message(msgs, rules=self.context.llm_rules) + msgs = add_user_message_to_beginning(msgs, rules=self.context.llm_rules) + msgs = add_user_message_to_end(msgs, rules=self.context.llm_rules) + msgs = remove_duplicate_messages(msgs) + msgs = organize_tool_result_messages(msgs, rules=self.context.llm_rules) + msgs = break_up_consecutive_ai_messages(msgs, rules=self.context.llm_rules) + + # this should go last + msgs = convert_system_messages(msgs, rules=self.context.llm_rules) + + return system + msgs diff --git a/src/controlflow/events/task_events.py b/src/controlflow/events/task_events.py new file mode 100644 index
00000000..7f3480ce --- /dev/null +++ b/src/controlflow/events/task_events.py @@ -0,0 +1,24 @@ +from typing import Literal + +from controlflow.events.events import Event, UnpersistedEvent +from controlflow.tasks.task import Task +from controlflow.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class TaskReadyEvent(UnpersistedEvent): + event: Literal["task-ready"] = "task-ready" + task: Task + + +class TaskCompleteEvent(Event): + event: Literal["task-complete"] = "task-complete" + task: Task + + # def to_messages(self, context: EventContext) -> list[BaseMessage]: + # return [ + # SystemMessage( + # content=f"Task {self.task.id} is complete with status: {self.task.status}" + # ) + # ] diff --git a/src/controlflow/events/tool_events.py b/src/controlflow/events/tool_events.py new file mode 100644 index 00000000..d27ce4fe --- /dev/null +++ b/src/controlflow/events/tool_events.py @@ -0,0 +1,63 @@ +from typing import Literal + +from pydantic import field_validator, model_validator + +from controlflow.agents.agent import Agent +from controlflow.events.events import Event +from controlflow.events.message_compiler import EventContext +from controlflow.llm.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage +from controlflow.tools.tools import ToolCall, ToolResult +from controlflow.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class ToolCallEvent(Event): + event: Literal["tool-call"] = "tool-call" + agent: Agent + tool_call: ToolCall + message: dict + + @field_validator("message", mode="before") + def _message(cls, v): + if isinstance(v, AIMessage): + v = v.dict() + v["type"] = "ai" + return v + + @model_validator(mode="after") + def _finalize(self): + self.message["name"] = self.agent.name + return self + + @property + def ai_message(self) -> AIMessage: + return AIMessage(**self.message) + + +class ToolResultEvent(Event): + event: Literal["tool-result"] = "tool-result" + agent: Agent + tool_call: ToolCall + tool_result: ToolResult + + def to_messages(self, context: EventContext) -> list[BaseMessage]: + if self.agent.name == context.agent.name: + return [ + ToolMessage( + content=self.tool_result.str_result, + tool_call_id=self.tool_call["id"], + name=self.agent.name, + ) + ] + elif not self.tool_result.is_private: + return [ + SystemMessage( + content=f'The following {"failed" if self.tool_result.is_error else "successful"} ' + f'tool result was received by "{self.agent.name}" with ID {self.agent.id}:' + ), + SystemMessage(content=self.tool_result.str_result), + ] + + else: + return [] diff --git a/src/controlflow/flows/__init__.py b/src/controlflow/flows/__init__.py index d87c29d0..d1f717b7 100644 --- a/src/controlflow/flows/__init__.py +++ b/src/controlflow/flows/__init__.py @@ -1 +1 @@ -from .flow import Flow, get_flow, get_flow_messages +from .flow import Flow, get_flow diff --git a/src/controlflow/flows/flow.py b/src/controlflow/flows/flow.py index b64e89f8..9e2f8b1c 100644 --- a/src/controlflow/flows/flow.py +++ b/src/controlflow/flows/flow.py @@ -1,4 +1,3 @@ -import datetime import uuid from contextlib import contextmanager, nullcontext from typing import Any, Callable, Optional, Union @@ -6,8 +5,9 @@ from pydantic import Field from controlflow.agents import Agent -from controlflow.flows.history import History, get_default_history -from controlflow.llm.messages import MessageType +from controlflow.events.events import Event +from controlflow.events.history import History, get_default_history +from controlflow.flows.graph import Graph from
controlflow.tasks.task import Task from controlflow.utilities.context import ctx from controlflow.utilities.logging import get_logger @@ -18,10 +18,11 @@ class Flow(ControlFlowModel): + model_config = dict(arbitrary_types_allowed=True) + thread_id: str = Field(default_factory=lambda: uuid.uuid4().hex) name: Optional[str] = None description: Optional[str] = None - thread_id: str = Field(default_factory=lambda: uuid.uuid4().hex) - history: History = Field(default_factory=get_default_history) + event_store: History = Field(default_factory=get_default_history) tools: list[Callable] = Field( default_factory=list, description="Tools that will be available to every agent in the flow", @@ -32,21 +33,21 @@ class Flow(ControlFlowModel): default_factory=list, ) context: dict[str, Any] = {} - tasks: dict[str, Task] = {} + graph: Graph = Field(default_factory=Graph, repr=False, exclude=True) _cm_stack: list[contextmanager] = [] def __init__(self, *, copy_parent: bool = True, **kwargs): """ - By default, the flow will copy the history from the parent flow if one + By default, the flow will copy the event history from the parent flow if one exists, including all completed tasks. Because each flow is a new - thread, new messages will not be shared between the parent and child + thread, new events will not be shared between the parent and child flow. """ super().__init__(**kwargs) parent = get_flow() if parent and copy_parent: - self.add_messages(parent.get_messages()) - for task in parent.tasks.values(): + self.add_events(parent.get_events()) + for task in parent.tasks: if task.is_complete(): self.add_task(task) @@ -60,25 +61,34 @@ def __exit__(self, *exc_info): # exit the context manager return self._cm_stack.pop().__exit__(*exc_info) - def get_messages( + def add_task(self, task: Task): + self.graph.add_task(task) + + @property + def tasks(self) -> list[Task]: + return self.graph.topological_sort() + + def get_events( self, - limit: int = None, - before: datetime.datetime = None, - after: datetime.datetime = None, - ) -> list[MessageType]: - return self.history.load_messages( - thread_id=self.thread_id, limit=limit, before=before, after=after + agent_ids: Optional[list[str]] = None, + task_ids: Optional[list[str]] = None, + before_id: Optional[str] = None, + after_id: Optional[str] = None, + limit: Optional[int] = None, + types: Optional[list[str]] = None, + ) -> list[Event]: + return self.event_store.get_events( + thread_id=self.thread_id, + agent_ids=agent_ids, + task_ids=task_ids, + before_id=before_id, + after_id=after_id, + limit=limit, + types=types, ) - def add_messages(self, messages: list[MessageType]): - self.history.save_messages(thread_id=self.thread_id, messages=messages) - - def add_task(self, task: Task): - if self.tasks.get(task.id, task) is not task: - raise ValueError( - f"A different task with id '{task.id}' already exists in flow." - ) - self.tasks[task.id] = task + def add_events(self, events: list[Event]): + self.event_store.add_events(thread_id=self.thread_id, events=events) @contextmanager def create_context(self, create_prefect_flow_context: bool = True): @@ -92,57 +102,23 @@ def create_context(self, create_prefect_flow_context: bool = True): with ctx(**ctx_args), prefect_ctx: yield self - async def run_once_async(self): - """ - Runs one step of the flow asynchronously. 
- """ - if self.tasks: - from controlflow.controllers import Controller - - controller = Controller( - flow=self, - tasks=list(self.tasks.values()), - ) - await controller.run_once_async() - def run_once(self): """ Runs one step of the flow. """ - if self.tasks: - from controlflow.controllers import Controller - - controller = Controller( - flow=self, - tasks=list(self.tasks.values()), - ) - controller.run_once() - - async def run_async(self): - """ - Runs the flow asynchronously. - """ - if self.tasks: - from controlflow.controllers import Controller + from controlflow.orchestration import Controller - controller = Controller( - flow=self, - tasks=list(self.tasks.values()), - ) - await controller.run_async() + controller = Controller(flow=self) + controller.run_once() def run(self): """ Runs the flow. """ - if self.tasks: - from controlflow.controllers import Controller + from controlflow.orchestration import Controller - controller = Controller( - flow=self, - tasks=list(self.tasks.values()), - ) - controller.run() + controller = Controller(flow=self) + controller.run() def get_flow() -> Optional[Flow]: @@ -153,16 +129,14 @@ def get_flow() -> Optional[Flow]: return flow -def get_flow_messages(limit: int = None) -> list[MessageType]: +def get_flow_events(limit: int = None) -> list[Event]: """ - Loads messages from the flow's thread. - - Will error if no flow is found in the context. + Loads events from the active flow's thread. """ if limit is None: limit = 50 flow = get_flow() if flow: - return get_default_history().load_messages(flow.thread_id, limit=limit) + return flow.get_events(limit=limit) else: return [] diff --git a/src/controlflow/controllers/graph.py b/src/controlflow/flows/graph.py similarity index 79% rename from src/controlflow/controllers/graph.py rename to src/controlflow/flows/graph.py index bef4b14b..c1fd9c82 100644 --- a/src/controlflow/controllers/graph.py +++ b/src/controlflow/flows/graph.py @@ -1,8 +1,7 @@ +from dataclasses import dataclass from enum import Enum from typing import Optional, TypeVar -from pydantic import BaseModel - from controlflow.tasks.task import Task T = TypeVar("T") @@ -32,7 +31,8 @@ class EdgeType(Enum): SUBTASK = "subtask" -class Edge(BaseModel): +@dataclass +class Edge: upstream: Task downstream: Task type: EdgeType @@ -40,29 +40,39 @@ class Edge(BaseModel): def __repr__(self): return f"{self.type}: {self.upstream.friendly_name()} -> {self.downstream.friendly_name()}" - def __hash__(self) -> int: - return id(self) - + def __hash__(self) -> id: + return hash((id(self.upstream), id(self.downstream), self.type)) -class Graph(BaseModel): - tasks: set[Task] = set() - edges: set[Edge] = set() - _cache: dict[str, dict[Task, list[Task]]] = {} - def __init__(self): - super().__init__() - - @classmethod - def from_tasks(cls, tasks: list[Task]) -> "Graph": - graph = cls() - for task in tasks: - graph.add_task(task) - return graph +class Graph: + def __init__(self, tasks: list[Task] = None, edges: list[Edge] = None): + self.tasks: set[Task] = set() + self.edges: set[Edge] = set() + self._cache: dict[str[dict[Task, list[Task]]]] = {} + if tasks: + for task in tasks: + self.add_task(task) + if edges: + for edge in edges: + self.add_edge(edge) def add_task(self, task: Task): if task in self.tasks: return + self.tasks.add(task) + + # add the task's parent + if task.parent: + self.add_edge( + Edge( + upstream=task, + downstream=task.parent, + type=EdgeType.SUBTASK, + ) + ) + + # add the task's subtasks for subtask in task._subtasks: self.add_edge( Edge( @@ 
-72,6 +82,7 @@ def add_task(self, task: Task): ) ) + # add the task's dependencies for upstream in task.depends_on: if upstream not in task._subtasks: self.add_edge( @@ -129,7 +140,7 @@ def upstream_tasks( f"upstream_{'immediate' if immediate else 'all'}_{tuple(start_tasks)}" ) if cache_key not in self._cache: - result = set() + result = set(start_tasks) visited = set() def _upstream(task): @@ -137,10 +148,7 @@ def _upstream(task): return visited.add(task) for edge in self.upstream_edges().get(task, []): - if ( - edge.upstream not in visited - and edge.upstream not in start_tasks - ): + if edge.upstream not in visited: result.add(edge.upstream) if not immediate: _upstream(edge.upstream) @@ -172,7 +180,7 @@ def downstream_tasks( f"downstream_{'immediate' if immediate else 'all'}_{tuple(start_tasks)}" ) if cache_key not in self._cache: - result = set() + result = set(start_tasks) visited = set() def _downstream(task): @@ -180,10 +188,7 @@ def _downstream(task): return visited.add(task) for edge in self.downstream_edges().get(task, []): - if ( - edge.downstream not in visited - and edge.downstream not in start_tasks - ): + if edge.downstream not in visited: result.add(edge.downstream) if not immediate: _downstream(edge.downstream) @@ -199,7 +204,7 @@ def _downstream(task): def topological_sort(self, tasks: Optional[list[Task]] = None) -> list[Task]: """ - Perform a topological sort on the provided tasks or all tasks in the graph. + Perform a deterministic topological sort on the provided tasks or all tasks in the graph. Args: tasks (Optional[list[Task]]): A list of tasks to sort topologically. @@ -208,6 +213,15 @@ def topological_sort(self, tasks: Optional[list[Task]] = None) -> list[Task]: Returns: list[Task]: A list of tasks in topological order (upstream tasks first). 
""" + # Create a cache key based on the input tasks + cache_key = ( + f"topo_sort_{tuple(sorted(task.id for task in (tasks or self.tasks)))}" + ) + + # Check if the result is already in the cache + if cache_key in self._cache: + return self._cache[cache_key] + if tasks is None: tasks_to_sort = self.tasks else: @@ -222,6 +236,8 @@ def topological_sort(self, tasks: Optional[list[Task]] = None) -> list[Task]: # Kahn's algorithm for topological sorting result = [] no_incoming = [task for task in tasks_to_sort if not dependencies[task]] + # sort to create a deterministic order + no_incoming.sort(key=lambda t: t.created_at) while no_incoming: task = no_incoming.pop(0) @@ -233,6 +249,8 @@ def topological_sort(self, tasks: Optional[list[Task]] = None) -> list[Task]: dependencies[dependent_task].remove(task) if not dependencies[dependent_task]: no_incoming.append(dependent_task) + # resort to maintain deterministic order + no_incoming.sort(key=lambda t: t.created_at) # Check for cycles if len(result) != len(tasks_to_sort): @@ -240,4 +258,6 @@ def topological_sort(self, tasks: Optional[list[Task]] = None) -> list[Task]: "The graph contains a cycle and cannot be topologically sorted" ) + # Cache the result before returning + self._cache[cache_key] = result return result diff --git a/src/controlflow/flows/history.py b/src/controlflow/flows/history.py deleted file mode 100644 index cf519dbe..00000000 --- a/src/controlflow/flows/history.py +++ /dev/null @@ -1,110 +0,0 @@ -import abc -import datetime -import json -import math -from pathlib import Path - -from pydantic import Field, field_validator - -import controlflow -from controlflow.llm.messages import MessageType -from controlflow.utilities.types import ControlFlowModel - -# This is a global variable that will be shared between all instances of InMemoryHistory -IN_MEMORY_HISTORY = {} - - -def get_default_history() -> "History": - return controlflow.default_history - - -class History(ControlFlowModel, abc.ABC): - @abc.abstractmethod - def load_messages( - self, - thread_id: str, - limit: int = None, - before: datetime.datetime = None, - after: datetime.datetime = None, - ) -> list[MessageType]: - raise NotImplementedError() - - @abc.abstractmethod - def save_messages(self, thread_id: str, messages: list[MessageType]): - raise NotImplementedError() - - -class InMemoryHistory(History): - history: dict[str, list[MessageType]] = Field( - default_factory=lambda: IN_MEMORY_HISTORY - ) - - def load_messages( - self, - thread_id: str, - limit: int = None, - before: datetime.datetime = None, - after: datetime.datetime = None, - ) -> list[MessageType]: - messages = self.history.get(thread_id, []) - filtered_messages = [ - msg - for i, msg in enumerate(reversed(messages)) - if (before is None or msg.timestamp < before) - and (after is None or msg.timestamp > after) - and i < (limit or math.inf) - ] - return list(reversed(filtered_messages)) - - def save_messages(self, thread_id: str, messages: list[MessageType]): - self.history.setdefault(thread_id, []).extend(messages) - - -class FileHistory(History): - base_path: Path = Field( - default_factory=lambda: controlflow.settings.home_path / "history" - ) - - def path(self, thread_id: str) -> Path: - return self.base_path / f"{thread_id}.json" - - @field_validator("base_path", mode="before") - def _validate_path(cls, v): - v = Path(v).expanduser() - if not v.exists(): - v.mkdir(parents=True, exist_ok=True) - return v - - def load_messages( - self, - thread_id: str, - limit: int = None, - before: datetime.datetime = 
None, - after: datetime.datetime = None, - ) -> list[MessageType]: - if not self.path(thread_id).exists(): - return [] - - with open(self.path(thread_id), "r") as f: - all_messages = json.load(f) - - messages = [] - for msg in reversed(all_messages): - message = MessageType.model_validate(msg) - if before is None or message.timestamp < before: - if after is None or message.timestamp > after: - messages.append(message) - if len(messages) >= limit or math.inf: - break - - return list(reversed(messages)) - - def save_messages(self, thread_id: str, messages: list[MessageType]): - if self.path(thread_id).exists(): - with open(self.path(thread_id), "r") as f: - all_messages = json.load(f) - else: - all_messages = [] - all_messages.extend([msg.model_dump(mode="json") for msg in messages]) - with open(self.path(thread_id), "w") as f: - json.dump(all_messages, f) diff --git a/src/controlflow/handlers/print_handler.py b/src/controlflow/handlers/print_handler.py deleted file mode 100644 index 9130e73d..00000000 --- a/src/controlflow/handlers/print_handler.py +++ /dev/null @@ -1,186 +0,0 @@ -import datetime - -import rich -from rich import box -from rich.console import Group -from rich.live import Live -from rich.markdown import Markdown -from rich.panel import Panel -from rich.spinner import Spinner -from rich.table import Table - -import controlflow -from controlflow.llm.handlers import CompletionHandler -from controlflow.llm.messages import ( - AIMessage, - AIMessageChunk, - MessageType, - SystemMessage, - ToolCall, - ToolMessage, - UserMessage, -) -from controlflow.utilities.rich import console as cf_console - - -class PrintHandler(CompletionHandler): - def __init__(self): - self.messages: dict[str, MessageType] = {} - self.live: Live = Live(auto_refresh=True, console=cf_console) - self.paused_id: str = None - super().__init__() - - def on_start(self): - try: - self.live.start() - except rich.errors.LiveError: - pass - - def on_end(self): - self.live.stop() - - def on_exception(self, exc: Exception): - self.live.stop() - - def update_live(self, latest: MessageType = None): - # sort by timestamp, using the custom message id as a tiebreaker - # in case the same message appears twice (e.g. 
tool call and message) - messages = sorted(self.messages.items(), key=lambda m: (m[1].timestamp, m[0])) - content = [] - - tool_results = {} # To track tool results by their call ID - - # gather all tool messages first - for _, message in messages: - if isinstance(message, ToolMessage): - tool_results[message.tool_call_id] = message - - for _, message in messages: - if isinstance(message, (SystemMessage, UserMessage, AIMessage)): - content.append(format_message(message, tool_results=tool_results)) - # no need to handle tool messages - - if self.live.is_started: - self.live.update(Group(*content), refresh=True) - elif latest: - cf_console.print(format_message(latest)) - - def on_message_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): - self.messages[snapshot.id] = snapshot - self.update_live() - - def on_message_done(self, message: AIMessage): - self.messages[message.id] = message - self.update_live(latest=message) - - def on_tool_call_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): - self.messages[snapshot.id] = snapshot - self.update_live() - - def on_tool_call_done(self, message: AIMessage): - self.messages[message.id] = message - self.update_live(latest=message) - - def on_tool_result_created(self, message: AIMessage, tool_call: ToolCall): - # if collecting input on the terminal, pause the live display - # to avoid overwriting the input prompt - if tool_call["name"] == "talk_to_user": - self.paused_id = tool_call["id"] - self.live.stop() - self.messages.clear() - - def on_tool_result_done(self, message: ToolMessage): - self.messages[f"tool-result:{message.tool_call_id}"] = message - - # if we were paused, resume the live display - if self.paused_id and self.paused_id == message.tool_call_id: - self.paused_id = None - # print newline to avoid odd formatting issues - print() - self.live = Live(auto_refresh=False) - self.live.start() - self.update_live(latest=message) - - -def format_timestamp(timestamp: datetime.datetime) -> str: - local_timestamp = timestamp.astimezone() - return local_timestamp.strftime("%I:%M:%S %p").lstrip("0").rjust(11) - - -def status(icon, text) -> Table: - t = Table.grid(padding=1) - t.add_row(icon, text) - return t - - -ROLE_COLORS = { - "system": "gray", - "ai": "blue", - "user": "green", -} -ROLE_NAMES = { - "system": "System", - "ai": "Agent", - "user": "User", -} - - -def format_message(message: MessageType, tool_results: dict = None) -> Panel: - if message.role == "ai" and message.name: - title = f"Agent: {message.name}" - else: - title = ROLE_NAMES.get(message.role, "Agent") - - content = [] - if message.str_content: - content.append(Markdown(message.str_content or "")) - - tool_content = [] - for tool_call in getattr(message, "tool_calls", []): - tool_result = (tool_results or {}).get(tool_call["id"]) - if tool_result: - c = format_tool_result(tool_result) - - else: - c = format_tool_call(tool_call) - if c: - tool_content.append(c) - - if content and tool_content: - content.append("\n") - - return Panel( - Group(*content, *tool_content), - title=f"[bold]{title}[/]", - subtitle=f"[italic]{format_timestamp(message.timestamp)}[/]", - title_align="left", - subtitle_align="right", - border_style=ROLE_COLORS.get(message.role, "red"), - box=box.ROUNDED, - width=100, - expand=True, - padding=(1, 2), - ) - - -def format_tool_call(tool_call: ToolCall) -> Panel: - name = tool_call["name"] - args = tool_call["args"] - if controlflow.settings.tools_verbose: - return status(Spinner("dots"), f'Tool call: "{name}"\n\nTool args: {args}') - 
return status(Spinner("dots"), f'Tool call: "{name}"') - - -def format_tool_result(message: ToolMessage) -> Panel: - name = message.tool_call["name"] - - if message.is_error: - icon = ":x:" - else: - icon = ":white_check_mark:" - - if controlflow.settings.tools_verbose: - msg = f'Tool call: "{name}"\n\nTool result: {message.str_content}' - else: - msg = f'Tool call: "{name}"' - return status(icon, msg) diff --git a/src/controlflow/llm/__init__.py b/src/controlflow/llm/__init__.py index b1dda89d..8383c810 100644 --- a/src/controlflow/llm/__init__.py +++ b/src/controlflow/llm/__init__.py @@ -1 +1 @@ -from controlflow.llm import models, messages, tools, handlers, completions, rules +from controlflow.llm import models, messages, rules diff --git a/src/controlflow/llm/classify.py b/src/controlflow/llm/classify.py deleted file mode 100644 index 7458e77e..00000000 --- a/src/controlflow/llm/classify.py +++ /dev/null @@ -1,104 +0,0 @@ -from typing import Union - -import tiktoken -from langchain_openai import AzureChatOpenAI, ChatOpenAI -from pydantic import TypeAdapter - -import controlflow -from controlflow.llm.messages import AIMessage, SystemMessage, UserMessage -from controlflow.llm.models import BaseChatModel - - -def classify( - data: str, - labels: list, - instructions: str = None, - context: dict = None, - model: BaseChatModel = None, -): - try: - label_strings = [TypeAdapter(type(t)).dump_json(t).decode() for t in labels] - except Exception as exc: - raise ValueError(f"Unable to cast labels to strings: {exc}") - - messages = [ - SystemMessage( - """ - You are an expert classifier that always maintains as much semantic meaning - as possible when labeling information. You use inference or deduction whenever - necessary to understand missing or omitted data. Classify the provided data, - text, or information as one of the provided labels. For boolean labels, - consider "truthy" or affirmative inputs to be "true". - - ## Labels - - You must classify the data as one of the following labels, which are - numbered (starting from 0) and provide a brief description. Output - the label number only. - - {% for label in labels %} - - Label {{ loop.index0 }}: {{ label }} - {% endfor %} - """ - ).render(labels=label_strings), - UserMessage( - """ - ## Information to classify - - {{ data }} - - {% if instructions -%} - ## Additional instructions - - {{ instructions }} - {% endif %} - - {% if context -%} - ## Additional context - - {% for key, value in context.items() -%} - - {{ key }}: {{ value }} - - {% endfor %} - {% endif %} - - """ - ).render(data=data, instructions=instructions, context=context), - AIMessage(""" - The best label for the data is Label number - """), - ] - - model = model or controlflow.llm.models.get_default_model() - - kwargs = {} - if isinstance(model, (ChatOpenAI, AzureChatOpenAI)): - openai_kwargs = _openai_kwargs(model=model, n_labels=len(labels)) - kwargs.update(openai_kwargs) - else: - messages.append( - SystemMessage( - "Return only the label number, no other information or tokens." 
- ) - ) - - result = controlflow.llm.completions.completion( - messages=messages, - model=model, - max_tokens=1, - **kwargs, - ) - - index = int(result[0].content) - return labels[index] - - -def _openai_kwargs(model: Union[AzureChatOpenAI, ChatOpenAI], n_labels: int): - encoding = tiktoken.encoding_for_model(model.model_name) - - logit_bias = {} - for i in range(n_labels): - for token in encoding.encode(str(i)): - logit_bias[token] = 100 - - return dict(logit_bias=logit_bias) diff --git a/src/controlflow/llm/completions.py b/src/controlflow/llm/completions.py deleted file mode 100644 index 5f104965..00000000 --- a/src/controlflow/llm/completions.py +++ /dev/null @@ -1,399 +0,0 @@ -import math -from typing import TYPE_CHECKING, AsyncGenerator, Callable, Generator, Optional, Union - -import langchain_core.language_models as lc_models -import langchain_core.messages -import tiktoken -from langchain_core.messages.utils import trim_messages - -import controlflow -import controlflow.llm.models -from controlflow.llm.handlers import ( - CompletionEvent, - CompletionHandler, - ResponseHandler, -) -from controlflow.llm.messages import AIMessage, AIMessageChunk, BaseMessage, MessageType -from controlflow.llm.tools import ToolCall, as_tools, handle_tool_call - -if TYPE_CHECKING: - from controlflow.agents.agent import Agent - - -def token_counter(message: langchain_core.messages.BaseMessage) -> int: - # always use gpt-3.5 token counter with the entire message object; we only need to be approximate here - return len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(message.json())) - - -def handle_tool_calls( - message: AIMessage, - tools: list[Callable], - response_messages: list[MessageType], - agent: Optional["Agent"] = None, -): - """ - Emit events for the given message when it has tool calls. - """ - for tool_call in message.tool_calls: - yield CompletionEvent( - type="tool_result_created", - payload=dict(message=message, tool_call=tool_call), - ) - error = handle_multiple_talk_to_user_calls(tool_call, message) - tool_result_message = handle_tool_call( - tool_call, tools, error=error, agent=agent - ) - response_messages.append(tool_result_message) - yield CompletionEvent( - type="tool_result_done", payload=dict(message=tool_result_message) - ) - - -def handle_delta_events( - delta: langchain_core.messages.AIMessageChunk, - snapshot: langchain_core.messages.AIMessageChunk, - deltas: list[langchain_core.messages.AIMessageChunk], - agent: "Agent", -): - """ - Emit events for the given delta message. - - Note this function receives langchain messages - """ - delta = AIMessageChunk.from_langchain_message(delta, agent=agent) - snapshot = AIMessageChunk.from_langchain_message(snapshot, agent=agent) - - if delta.content: - if not deltas[-1].content: - yield CompletionEvent(type="message_created", payload=dict(delta=delta)) - if delta.content != deltas[-1].content: - yield CompletionEvent( - type="message_delta", - payload=dict(delta=delta, snapshot=snapshot), - ) - - if delta.tool_calls: - if not deltas[-1].tool_calls: - yield CompletionEvent(type="tool_call_created", payload=dict(delta=delta)) - yield CompletionEvent( - type="tool_call_delta", - payload=dict(delta=delta, snapshot=snapshot), - ) - - -def handle_done_events(message: AIMessage): - """ - Emit events for the given message when it has been processed. 
- """ - if message.content: - yield CompletionEvent(type="message_done", payload=dict(message=message)) - if message.tool_calls: - yield CompletionEvent(type="tool_call_done", payload=dict(message=message)) - - -def handle_multiple_talk_to_user_calls(tool_call: ToolCall, message: AIMessage): - if ( - tool_call["name"] == "talk_to_user" - and len([t for t in message.tool_calls if t["name"] == "talk_to_user"]) > 1 - ): - error = 'Tool call "talk_to_user" can only be used once per turn.' - else: - error = None - return error - - -def prepare_messages(messages: list[MessageType]) -> list[MessageType]: - """ - Make any necessary modifications to the messages before they are passed to the model. - """ - return messages - - -def _completion_generator( - messages: list[MessageType], - model: lc_models.BaseChatModel, - tools: Optional[list[Callable]], - max_iterations: int, - stream: bool, - agent: Optional["Agent"] = None, - **kwargs, -) -> Generator[CompletionEvent, None, None]: - response_messages = [] - response_message = None - - if tools: - model = model.bind_tools([t.to_lc_tool() for t in as_tools(tools)]) - - counter = 0 - try: - yield CompletionEvent(type="start", payload={}) - - # continue as long as the last response message contains tool calls (or - # there is no response message yet) - while not response_message or response_message.tool_calls: - input_messages = [ - m.to_langchain_message() - for m in messages + response_messages - if isinstance(m, BaseMessage) - ] - input_messages = trim_messages( - messages=input_messages, - max_tokens=controlflow.settings.max_input_tokens, - include_system=True, - token_counter=token_counter, - ) - - if not stream: - response_message = model.invoke(input=input_messages, **kwargs) - response_message = AIMessage.from_langchain_message( - response_message, agent=agent - ) - - else: - # all streaming responses are langchain Pydantic v1 models - # which we don't convert to AIMessage/AIMessageChunks for sanity. - # they are converted in handle_delta_events and when the stream is finished. 
- - # initialize the list of deltas with an empty delta - # to facilitate comparison with the previous delta - deltas = [langchain_core.messages.AIMessageChunk(content="")] - - for i, delta in enumerate(model.stream(input=input_messages, **kwargs)): - if i == 0: - snapshot = delta - else: - snapshot = snapshot + delta - - yield from handle_delta_events( - delta=delta, snapshot=snapshot, deltas=deltas, agent=agent - ) - - deltas.append(delta) - - # the last snapshot message is the response message - response_message = AIMessage.from_langchain_message( - snapshot, agent=agent - ) - - # handle done events for the response message - yield from handle_done_events(response_message) - - # append the response message to the list of response messages - response_messages.append(response_message) - - # handle tool calls - yield from handle_tool_calls( - message=response_message, - tools=tools, - response_messages=response_messages, - agent=agent, - ) - - counter += 1 - if counter >= (max_iterations or math.inf): - break - - except (BaseException, Exception) as exc: - yield CompletionEvent(type="exception", payload=dict(exc=exc)) - raise - finally: - yield CompletionEvent(type="end", payload={}) - - -async def _completion_async_generator( - messages: list[MessageType], - model: lc_models.BaseChatModel, - tools: Optional[list[Callable]], - max_iterations: int, - stream: bool, - agent: Optional["Agent"] = None, - **kwargs, -) -> AsyncGenerator[CompletionEvent, None]: - response_messages = [] - response_message = None - - if tools: - model = model.bind_tools([t.to_lc_tool() for t in as_tools(tools)]) - - counter = 0 - try: - yield CompletionEvent(type="start", payload={}) - - # continue as long as the last response message contains tool calls (or - # there is no response message yet) - while not response_message or response_message.tool_calls: - input_messages = [ - m.to_langchain_message() if isinstance(m, BaseMessage) else m - for m in messages + response_messages - ] - - input_messages = trim_messages( - messages=input_messages, - max_tokens=controlflow.settings.max_input_tokens, - include_system=True, - token_counter=token_counter, - ) - input_messages = [ - m.to_langchain_message() - for m in input_messages - if isinstance(m, BaseMessage) - ] - - if not stream: - response_message = await model.ainvoke(input=input_messages, **kwargs) - response_message = AIMessage.from_langchain_message( - response_message, agent=agent - ) - - else: - # all streaming responses are langchain Pydantic v1 models - # which we don't convert to AIMessage/AIMessageChunks for sanity. - # they are converted in handle_delta_events and when the stream is finished. 
- - # initialize the list of deltas with an empty delta - # to facilitate comparison with the previous delta - deltas = [langchain_core.messages.AIMessageChunk(content="")] - - async for i, delta in enumerate( - model.astream(input=input_messages, **kwargs) - ): - if i == 0: - snapshot = delta - else: - snapshot = snapshot + delta - - for event in handle_delta_events( - delta=delta, snapshot=snapshot, deltas=deltas, agent=agent - ): - yield event - - deltas.append(delta) - - # the last snapshot message is the response message - response_message = AIMessage.from_langchain_message( - snapshot, agent=agent - ) - - # handle done events for the response message - for event in handle_done_events(response_message): - yield event - - # append the response message to the list of response messages - response_messages.append(response_message) - - # handle tool calls - for event in handle_tool_calls( - message=response_message, - tools=tools, - response_messages=response_messages, - agent=agent, - ): - yield event - - counter += 1 - if counter >= (max_iterations or math.inf): - break - - except (BaseException, Exception) as exc: - yield CompletionEvent(type="exception", payload=dict(exc=exc)) - raise - finally: - yield CompletionEvent(type="end", payload={}) - - -def _handle_events( - generator: Generator[CompletionEvent, None, None], handlers: list[CompletionHandler] -) -> Generator[CompletionEvent, None, None]: - for event in generator: - for handler in handlers: - try: - handler.on_event(event) - except Exception as exc: - generator.throw(exc) - yield event - - -async def _handle_events_async( - generator: AsyncGenerator, handlers: list[CompletionHandler] -) -> AsyncGenerator[CompletionEvent, None]: - async for event in generator: - for handler in handlers: - try: - handler.on_event(event) - except Exception as exc: - await generator.athrow(exc) - yield event - - -def completion( - messages: list[MessageType], - model: lc_models.BaseChatModel = None, - tools: list[Callable] = None, - max_iterations: int = None, - handlers: list[CompletionHandler] = None, - stream: bool = False, - agent: Optional["Agent"] = None, - **kwargs, -) -> Union[list[MessageType], Generator[MessageType, None, None]]: - if model is None: - model = controlflow.llm.models.get_default_model() - - response_handler = ResponseHandler() - handlers = handlers or [] - handlers.append(response_handler) - - completion_generator = _completion_generator( - messages=messages, - model=model, - tools=tools, - max_iterations=max_iterations, - stream=stream, - agent=agent, - **kwargs, - ) - - handlers_generator = _handle_events(completion_generator, handlers) - - if stream: - return handlers_generator - else: - for _ in handlers_generator: - pass - return response_handler.response_messages - - -async def completion_async( - messages: list[MessageType], - model: lc_models.BaseChatModel = None, - tools: list[Callable] = None, - max_iterations: int = None, - handlers: list[CompletionHandler] = None, - stream: bool = False, - agent: Optional["Agent"] = None, - **kwargs, -) -> Union[list[MessageType], Generator[MessageType, None, None]]: - if model is None: - model = controlflow.llm.models.get_default_model() - - response_handler = ResponseHandler() - handlers = handlers or [] - handlers.append(response_handler) - - completion_generator = _completion_async_generator( - messages=messages, - model=model, - tools=tools, - max_iterations=max_iterations, - stream=stream, - agent=agent, - **kwargs, - ) - - handlers_generator = 
_handle_events_async(completion_generator, handlers) - - if stream: - return handlers_generator - else: - async for _ in handlers_generator: - pass - return response_handler.response_messages diff --git a/src/controlflow/llm/handlers.py b/src/controlflow/llm/handlers.py deleted file mode 100644 index efaa6d53..00000000 --- a/src/controlflow/llm/handlers.py +++ /dev/null @@ -1,106 +0,0 @@ -from controlflow.llm.messages import ( - AIMessage, - AIMessageChunk, - MessageType, - ToolCall, - ToolMessage, -) -from controlflow.utilities.context import ctx -from controlflow.utilities.types import ControlFlowModel - - -class CompletionEvent(ControlFlowModel): - type: str - payload: dict - - -class CompletionHandler: - def __init__(self): - self._response_message_ids: set[str] = set() - - def on_event(self, event: CompletionEvent): - method = getattr(self, f"on_{event.type}", None) - if not method: - raise ValueError(f"Unknown event type: {event.type}") - method(**event.payload) - if event.type in [ - "message_done", - "tool_call_done", - "tool_result_done", - ]: - # only fire the on_response_message hook once per message - # (a message could contain both a tool call and a message) - if event.payload["message"].id not in self._response_message_ids: - self.on_response_message(event.payload["message"]) - self._response_message_ids.add(event.payload["message"].id) - - def on_start(self): - pass - - def on_end(self): - pass - - def on_exception(self, exc: Exception): - pass - - def on_message_created(self, delta: AIMessageChunk): - pass - - def on_message_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): - pass - - def on_message_done(self, message: AIMessage): - pass - - def on_tool_call_created(self, delta: AIMessageChunk): - pass - - def on_tool_call_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): - pass - - def on_tool_call_done(self, message: AIMessage): - pass - - def on_tool_result_created(self, message: AIMessage, tool_call: ToolCall): - pass - - def on_tool_result_done(self, message: ToolMessage): - pass - - def on_response_message(self, message: MessageType): - """ - This handler is called whenever a message is generated that should be - included in the completion history (e.g. a `message`, `tool_call` or - `tool_result`). Note that this is called *in addition* to the respective - on_*_done handlers, and can be used to quickly collect all messages - generated during a completion. Messages that satisfy multiple criteria - (e.g. a message and a tool call) will only be included once. - """ - pass - - -class ResponseHandler(CompletionHandler): - """ - A handler for collecting response messages. 
- """ - - def __init__(self): - super().__init__() - self.response_messages = [] - - def on_response_message(self, message: MessageType): - self.response_messages.append(message) - - -class TUIHandler(CompletionHandler): - def on_message_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk) -> None: - if tui := ctx.get("tui"): - tui.update_message(message=snapshot) - - def on_tool_call_delta(self, delta: AIMessageChunk, snapshot: AIMessageChunk): - if tui := ctx.get("tui"): - tui.update_message(message=snapshot) - - def on_tool_result_done(self, message: ToolMessage): - if tui := ctx.get("tui"): - tui.update_tool_result(message=message) diff --git a/src/controlflow/llm/messages.py b/src/controlflow/llm/messages.py index 7a85317a..76076b93 100644 --- a/src/controlflow/llm/messages.py +++ b/src/controlflow/llm/messages.py @@ -1,254 +1,23 @@ -import datetime -import re -import uuid -from typing import TYPE_CHECKING, Any, Literal, Optional, Union - -import langchain_core.messages -from langchain_core.messages.tool import InvalidToolCall, ToolCall, ToolCallChunk -from pydantic import Field, field_validator - -from controlflow.utilities.jinja import jinja_env -from controlflow.utilities.types import ControlFlowModel - -if TYPE_CHECKING: - from controlflow.agents.agent import Agent - - -class BaseMessage(ControlFlowModel): - """ - ControlFlow uses Message objects that are similar to LangChain messages, but more purpose built. - - Note that LangChain messages are Pydantic V1 models, while ControlFlow messages are Pydantic V2 models. - """ - - id: Optional[str] = Field(default_factory=lambda: uuid.uuid4().hex) - timestamp: datetime.datetime = Field( - default_factory=lambda: datetime.datetime.now(datetime.timezone.utc), - ) - role: str - content: Union[str, list[Union[str, dict]]] - name: Optional[str] = None - - # private attr to hold the original langchain message - _langchain_message: Optional[ - Union[ - langchain_core.messages.BaseMessage, - langchain_core.messages.BaseMessageChunk, - ] - ] = None - - @field_validator("name") - def _sanitize_name(cls, v): - # sanitize name for API compatibility - OpenAI API only allows alphanumeric characters, dashes, and underscores - if v is not None: - v = re.sub(r"[^a-zA-Z0-9_-]", "-", v).strip("-") - return v - - def __init__( - self, - content: Union[str, list[Union[str, dict]]], - _langchain_message: Optional[ - Union[ - langchain_core.messages.BaseMessage, - langchain_core.messages.BaseMessageChunk, - ] - ] = None, - **kwargs: Any, - ) -> None: - """Pass in content as positional arg.""" - super().__init__(content=content, **kwargs) - self._langchain_message = _langchain_message - - @property - def str_content(self) -> str: - if not isinstance(self.content, str): - return str(self.content) - return self.content - - def render(self, **kwargs) -> "MessageType": - """ - Renders the content as a jinja template with the given keyword arguments - and returns a new Message. 
- """ - content = jinja_env.from_string(self.content).render(**kwargs) - return self.model_copy(update=dict(content=content)) - - @classmethod - def _langchain_message_kwargs( - cls, message: langchain_core.messages.BaseMessage - ) -> "BaseMessage": - return message.dict(include={"content", "id", "name"}) | dict( - _langchain_message=message - ) - - @classmethod - def from_langchain_message(message: langchain_core.messages.BaseMessage, **kwargs): - raise NotImplementedError() - - def to_langchain_message(self) -> langchain_core.messages.BaseMessage: - raise NotImplementedError() - - -class AgentReference(ControlFlowModel): - name: str - - -class AgentMessageMixin(ControlFlowModel): - agent: Optional[AgentReference] = None - - @field_validator("agent", mode="before") - def _validate_agent(cls, v): - from controlflow.agents.agent import Agent - - if isinstance(v, Agent): - return AgentReference(name=v.name) - return v - - def __init__(self, *args, agent: "Agent" = None, **data): - if agent is not None and data.get("name") is None: - data["name"] = agent.name - super().__init__(*args, agent=agent, **data) - - -class AIMessage(BaseMessage, AgentMessageMixin): - role: Literal["ai"] = "ai" - tool_calls: list[ToolCall] = [] - - is_delta: bool = False - - def __init__(self, *args, **data): - super().__init__(*args, **data) - - # GPT-4 models somtimes use a hallucinated parallel tool calling mechanism - # whose name is not compatible with the API's restrictions on tool names - for tool_call in self.tool_calls: - if tool_call["name"] == "multi_tool_use.parallel": - tool_call["name"] = "multi_tool_use_parallel" - - def has_tool_calls(self) -> bool: - return any(self.tool_calls) - - @classmethod - def from_langchain_message( - cls, - message: langchain_core.messages.AIMessage, - **kwargs, - ): - data = dict( - **cls._langchain_message_kwargs(message), - tool_calls=message.tool_calls + getattr(message, "invalid_tool_calls", []), - ) - - return cls(**data | kwargs) - - def to_langchain_message( - self, - ) -> langchain_core.messages.AIMessage: - if self._langchain_message is not None: - return self._langchain_message - return langchain_core.messages.AIMessage( - content=self.content, tool_calls=self.tool_calls, id=self.id, name=self.name - ) - - -class AIMessageChunk(AIMessage): - tool_calls: list[ToolCallChunk] = [] - - @classmethod - def from_langchain_message( - cls, - message: Union[ - langchain_core.messages.AIMessageChunk, langchain_core.messages.AIMessage - ], - **kwargs, - ): - if isinstance(message, langchain_core.messages.AIMessageChunk): - tool_calls = message.tool_call_chunks - elif isinstance(message, langchain_core.messages.AIMessage): - tool_calls = [] - for i, call in enumerate(message.tool_calls): - tool_calls.append( - ToolCallChunk( - id=call["id"], - name=call["name"], - args=str(call["args"]) if call["args"] else None, - index=call.get("index", i), - ) - ) - data = dict( - **cls._langchain_message_kwargs(message), - tool_calls=tool_calls, - is_delta=True, - ) - - return cls(**data | kwargs) - - def to_langchain_message( - self, - ) -> langchain_core.messages.AIMessageChunk: - if self._langchain_message is not None: - return self._langchain_message - return langchain_core.messages.AIMessageChunk( - content=self.content, - tool_call_chunks=self.tool_calls, - id=self.id, - name=self.name, - ) - - -class UserMessage(BaseMessage): - role: Literal["user"] = "user" - - @classmethod - def from_langchain_message( - cls, message: langchain_core.messages.HumanMessage, **kwargs - ): - return 
cls(**cls._langchain_message_kwargs(message) | kwargs) - - def to_langchain_message(self) -> langchain_core.messages.BaseMessage: - if self._langchain_message is not None: - return self._langchain_message - return langchain_core.messages.HumanMessage( - content=self.content, id=self.id, name=self.name - ) - - -class SystemMessage(BaseMessage): - role: Literal["system"] = "system" - - @classmethod - def from_langchain_message( - cls, message: langchain_core.messages.SystemMessage, **kwargs - ): - return cls(**cls._langchain_message_kwargs(message) | kwargs) - - def to_langchain_message(self) -> langchain_core.messages.BaseMessage: - if self._langchain_message is not None: - return self._langchain_message - return langchain_core.messages.SystemMessage( - content=self.content, id=self.id, name=self.name - ) - - -class ToolMessage(BaseMessage, AgentMessageMixin): - model_config = dict(arbitrary_types_allowed=True) - role: Literal["tool"] = "tool" - tool_call_id: str - tool_call: Union[ToolCall, InvalidToolCall] - tool_result: Any = Field(None, exclude=True) - tool_metadata: dict[str, Any] = {} - is_error: bool = False - - def to_langchain_message(self) -> langchain_core.messages.BaseMessage: - if self._langchain_message is not None: - return self._langchain_message - else: - return langchain_core.messages.ToolMessage( - id=self.id, - name=self.name, - content=self.content, - tool_call_id=self.tool_call_id, - ) - - -MessageType = Union[UserMessage, AIMessage, SystemMessage, ToolMessage] +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, + InvalidToolCall, + SystemMessage, + ToolCall, + ToolCallChunk, + ToolMessage, +) + +__all__ = [ + "AIMessage", + "AIMessageChunk", + "BaseMessage", + "HumanMessage", + "InvalidToolCall", + "SystemMessage", + "ToolCall", + "ToolCallChunk", + "ToolMessage", +] diff --git a/src/controlflow/llm/models.py b/src/controlflow/llm/models.py index d5881fd9..aee3d9bd 100644 --- a/src/controlflow/llm/models.py +++ b/src/controlflow/llm/models.py @@ -45,6 +45,14 @@ def model_from_string( "To use Google models, please install the `langchain_google_genai` package." ) cls = ChatGoogleGenerativeAI + elif provider == "groq": + try: + from langchain_groq import ChatGroq + except ImportError: + raise ImportError( + "To use Groq models, please install the `langchain_groq` package." + ) + cls = ChatGroq else: raise ValueError( f"Could not load provider automatically: {provider}. Please create your model manually." diff --git a/src/controlflow/llm/rules.py b/src/controlflow/llm/rules.py index 0db1dea4..c5a3a8d6 100644 --- a/src/controlflow/llm/rules.py +++ b/src/controlflow/llm/rules.py @@ -2,9 +2,10 @@ from langchain_openai import AzureChatOpenAI, ChatOpenAI from controlflow.llm.models import BaseChatModel +from controlflow.utilities.types import ControlFlowModel -class LLMRules: +class LLMRules(ControlFlowModel): """ LLM rules let us tailor DAG compilation, message generation, tool use, and other behavior to the requirements of different LLM provider APIs. @@ -13,14 +14,17 @@ class LLMRules: necessary. 
""" + # require at least one non-system message + require_at_least_one_message: bool = False + # system messages can only be provided as the very first message in a thread - system_message_must_be_first: bool = False + require_system_message_first: bool = False # other than a system message, the first message must be from the user - user_message_must_be_first_after_system: bool = False + require_user_message_after_system: bool = False # the last message in a thread can't be from an AI if tool use is allowed - allow_last_message_has_ai_role_with_tools: bool = True + allow_last_message_from_ai_when_using_tools: bool = True # consecutive AI messages must be separated by a user message allow_consecutive_ai_messages: bool = True @@ -29,15 +33,19 @@ class LLMRules: # (some APIs can use the `name` field for this purpose, but others can't) add_system_messages_for_multi_agent: bool = False + # if a tool is used, the result must follow the tool call immediately + tool_result_must_follow_tool_call: bool = True + class OpenAIRules(LLMRules): pass class AnthropicRules(LLMRules): - system_message_must_be_first: bool = True - user_message_must_be_first_after_system: bool = True - allow_last_message_has_ai_role_with_tools: bool = False + require_at_least_one_message: bool = True + require_system_message_first: bool = True + require_user_message_after_system: bool = True + allow_last_message_from_ai_when_using_tools: bool = False allow_consecutive_ai_messages: bool = False diff --git a/src/controlflow/controllers/__init__.py b/src/controlflow/orchestration/__init__.py similarity index 100% rename from src/controlflow/controllers/__init__.py rename to src/controlflow/orchestration/__init__.py diff --git a/src/controlflow/orchestration/controller.py b/src/controlflow/orchestration/controller.py new file mode 100644 index 00000000..dbf0091e --- /dev/null +++ b/src/controlflow/orchestration/controller.py @@ -0,0 +1,374 @@ +import logging +from typing import AsyncGenerator, Generator, Optional, TypeVar, Union + +from pydantic import Field, field_validator + +import controlflow +from controlflow.agents import Agent +from controlflow.events.agent_events import ( + EndTurnEvent, + SelectAgentEvent, + SystemMessageEvent, +) +from controlflow.events.events import Event +from controlflow.events.task_events import TaskReadyEvent +from controlflow.flows import Flow +from controlflow.instructions import get_instructions +from controlflow.llm.messages import AIMessage +from controlflow.orchestration.handler import Handler +from controlflow.orchestration.tools import ( + create_end_turn_tool, + create_task_fail_tool, + create_task_success_tool, +) +from controlflow.tasks.task import Task +from controlflow.tools import as_tools +from controlflow.tools.tools import Tool +from controlflow.utilities.prefect import prefect_task as prefect_task +from controlflow.utilities.types import ControlFlowModel + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class Controller(ControlFlowModel): + """ + The controller is responsible for managing the flow of tasks and agents. It + is given objects that it is responsible for managing. At each iteration, the + controller will select a task and an agent to complete the task. The + controller will then create and execute an AgentContext to run the task. 
+ """ + + model_config = dict(arbitrary_types_allowed=True) + flow: "Flow" = Field(description="The flow that the controller is managing") + tasks: Optional[list[Task]] = Field( + None, + description="Tasks to be completed by the controller. If None, all tasks in the flow will be used.", + ) + agents: dict[Task, list[Agent]] = Field( + default_factory=dict, + description="Optionally assign agents to complete tasks. The provided mapping must be task" + " -> [agents]. Any tasks that aren't included will use their default agents.", + ) + handlers: list[Handler] = Field(None, validate_default=True) + _ready_task_counter: int = 0 + + @field_validator("handlers", mode="before") + def _handlers(cls, v): + from controlflow.orchestration.print_handler import PrintHandler + + if v is None and controlflow.settings.enable_print_handler: + v = [PrintHandler()] + return v or [] + + @field_validator("agents", mode="before") + def _agents(cls, v): + if v is None: + v = {} + return v + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.tasks = self.tasks or self.flow.tasks + for task in self.tasks: + self.flow.add_task(task) + + def handle_event( + self, event: Event, tasks: list[Task] = None, agents: list[Agent] = None + ): + event.thread_id = self.flow.thread_id + event.task_ids = [t.id for t in tasks or []] + event.agent_ids = [a.id for a in agents or []] + for handler in self.handlers: + handler.on_event(event) + if event.persist: + self.flow.add_events([event]) + + def run_once(self): + """ + Core pipeline for running the controller. + """ + from controlflow.events.controller_events import ( + ControllerEnd, + ControllerError, + ControllerStart, + ) + + self.handle_event(ControllerStart(controller=self)) + + try: + ready_tasks = self.get_ready_tasks() + + if not ready_tasks: + self._ready_task_counter += 1 + if self._ready_task_counter >= 3: + raise ValueError("No tasks are ready to run. This is unexpected.") + return + else: + self._ready_task_counter = 0 + + # select an agent + agent = self.get_agent(ready_tasks=ready_tasks) + active_tasks = self.get_active_tasks(agent=agent, ready_tasks=ready_tasks) + + context = AgentContext( + agent=agent, + tasks=active_tasks, + flow=self.flow, + controller=self, + ) + + # run + context.run() + + except Exception as exc: + self.handle_event(ControllerError(controller=self, error=exc)) + raise + finally: + self.handle_event(ControllerEnd(controller=self)) + + async def run_once_async(self): + """ + Core pipeline for running the controller. 
+ """ + from controlflow.events.controller_events import ( + ControllerEnd, + ControllerError, + ControllerStart, + ) + + self.handle_event(ControllerStart(controller=self)) + + try: + ready_tasks = self.get_ready_tasks() + + if not ready_tasks: + return + + # select an agent + agent = self.get_agent(ready_tasks=ready_tasks) + active_tasks = self.get_active_tasks(agent=agent, ready_tasks=ready_tasks) + + context = AgentContext( + agent=agent, + tasks=active_tasks, + flow=self.flow, + controller=self, + ) + + # run + await context.run_async() + + except Exception as exc: + self.handle_event(ControllerError(controller=self, error=exc)) + raise + finally: + self.handle_event(ControllerEnd(controller=self)) + + def run(self): + while any(t.is_incomplete() for t in self.tasks): + self.run_once() + + async def run_async(self): + while any(t.is_incomplete() for t in self.tasks): + await self.run_once_async() + + def get_ready_tasks(self) -> list[Task]: + all_tasks = self.flow.graph.upstream_tasks(self.tasks) + ready_tasks = [t for t in all_tasks if t.is_ready()] + return ready_tasks + + def get_active_tasks(self, agent: Agent, ready_tasks: list[Task]) -> list[Task]: + """ + Get the subset of ready tasks that the agent is assigned to. + """ + active_tasks = [] + for task in ready_tasks: + if agent in self.agents_for_task(task): + active_tasks.append(task) + self.handle_event(TaskReadyEvent(task=task), tasks=[task]) + if not task._prefect_task.is_started: + task._prefect_task.start( + depends_on=[t.result for t in task.depends_on] + ) + if not active_tasks: + raise ValueError("No active tasks for agent. This is unexpected.") + return active_tasks + + def agents_for_task(self, task: Task) -> list[Agent]: + return self.agents.get(task, task.get_agents()) + + def get_agent(self, ready_tasks: list[Task]) -> tuple[Agent, list[Task]]: + candidates = [ + agent + for task in ready_tasks + # get agents from either controller assignments or the task defaults + for agent in self.agents_for_task(task) + ] + + # if there is only one candidate, return it + if len(candidates) == 1: + agent = candidates[0] + + # get the last select-agent or end-turn event + agent_event: list[Union[SelectAgentEvent, EndTurnEvent]] = self.flow.get_events( + limit=1, + types=["select-agent", "end-turn"], + task_ids=[t.id for t in self.flow.graph.upstream_tasks(ready_tasks)], + ) + if agent_event: + event = agent_event[0] + # if an agent was selected and is a candidate, return it + if event.event == "select-agent": + agent = next( + (a for a in candidates if a.name == event.agent.name), None + ) + if agent: + return agent + # if an agent was nominated and is a candidate, return it + elif event.event == "end-turn" and event.next_agent_name is not None: + agent = next( + (a for a in candidates if a.name == event.next_agent_name), + None, + ) + if agent: + return agent + + # if there are multiple candiates remaining, use the first task's strategy to select one + strategy_fn = ready_tasks[0].get_agent_strategy() + agent = strategy_fn(agents=candidates, task=ready_tasks[0], flow=self.flow) + ready_tasks[0]._iteration += 1 + + self.handle_event(SelectAgentEvent(agent=agent), agents=[agent]) + return agent + + +class AgentContext(ControlFlowModel): + agent: Agent = Field(description="The active agent") + tasks: list[Task] = Field( + description="The tasks that the agent is assigned to complete that are ready to be completed" + ) + flow: Flow = Field(description="The flow that the agent is working in") + controller: Controller = Field( + 
description="The controller that is managing the flow" + ) + + def get_events(self) -> list[Event]: + return self.flow.get_events( + agent_ids=[self.agent.id], + task_ids=[t.id for t in self.flow.graph.upstream_tasks(self.tasks)], + ) + + def get_prompt(self, tools: list[Tool]) -> str: + from controlflow.orchestration import prompts + + # get up to 50 upstream and 50 downstream tasks + g = self.flow.graph + upstream_tasks = g.topological_sort([t for t in g.tasks if t.is_complete()])[ + -50: + ] + downstream_tasks = g.topological_sort( + [t for t in g.tasks if t.is_incomplete() and t not in self.tasks] + )[:50] + + tasks = [t.model_dump() for t in self.tasks] + upstream_tasks = [t.model_dump() for t in upstream_tasks] + downstream_tasks = [t.model_dump() for t in downstream_tasks] + + agent_prompt = prompts.AgentTemplate( + agent=self.agent, + additional_instructions=get_instructions(), + ) + + workflow_prompt = prompts.WorkflowTemplate( + flow=self.flow, + ready_tasks=tasks, + upstream_tasks=upstream_tasks, + downstream_tasks=downstream_tasks, + ) + + tool_prompt = prompts.ToolTemplate( + agent=self.agent, + has_user_access_tool="talk_to_user" in [t.name for t in tools], + has_end_turn_tool="end_turn" in [t.name for t in tools], + ) + + communication_prompt = prompts.CommunicationTemplate() + + prompts = [ + p.render() + for p in [agent_prompt, workflow_prompt, tool_prompt, communication_prompt] + ] + + return "\n\n".join(prompts) + + def get_tools(self) -> list[Tool]: + tools = [] + + # add flow tools + tools.extend(self.flow.tools) + + # add end turn tool if there are multiple agents for any task + if any(len(self.controller.agents_for_task(t)) > 1 for t in self.tasks): + tools.append( + create_end_turn_tool(controller=self.controller, agent=self.agent) + ) + + # add tools for working with tasks + for task in self.tasks: + tools.extend(task.get_tools()) + tools.append( + create_task_success_tool( + controller=self.controller, task=task, agent=self.agent + ) + ) + tools.append( + create_task_fail_tool( + controller=self.controller, task=task, agent=self.agent + ) + ) + + return as_tools(tools) + + def get_messages(self, tools: list[Tool] = None) -> list[AIMessage]: + from controlflow.events.message_compiler import EventContext, MessageCompiler + + events = self.flow.get_events( + agent_ids=[self.agent.id], + task_ids=[t.id for t in self.flow.graph.upstream_tasks(self.tasks)], + ) + + events.append( + SystemMessageEvent(content=f"{self.agent.name}, it is your turn.") + ) + + event_context = EventContext( + llm_rules=self.agent.get_llm_rules(), + agent=self.agent, + ready_tasks=self.tasks, + controller=self.controller, + flow=self.flow, + ) + + compiler = MessageCompiler( + events=events, + context=event_context, + system_prompt=self.get_prompt(tools=tools), + ) + messages = compiler.compile_to_messages() + return messages + + def run(self) -> Generator["Event", None, None]: + tools = self.get_tools() + messages = self.get_messages(tools=tools) + for event in self.agent._run_model(messages=messages, additional_tools=tools): + self.controller.handle_event(event, tasks=self.tasks, agents=[self.agent]) + + async def run_async(self) -> AsyncGenerator["Event", None]: + tools = self.get_tools() + messages = self.get_messages() + async for event in self.agent._run_model_async( + messages=messages, additional_tools=tools + ): + self.controller.handle_event(event, tasks=self.tasks, agents=[self.agent]) diff --git a/src/controlflow/orchestration/handler.py b/src/controlflow/orchestration/handler.py 
new file mode 100644 index 00000000..92e96695 --- /dev/null +++ b/src/controlflow/orchestration/handler.py @@ -0,0 +1,9 @@ +from controlflow.events.events import Event + + +class Handler: + def on_event(self, event: Event): + event_type = event.event.replace("-", "_") + method = getattr(self, f"on_{event_type}", None) + if method: + method(event=event) diff --git a/src/controlflow/orchestration/print_handler.py b/src/controlflow/orchestration/print_handler.py new file mode 100644 index 00000000..f7603aea --- /dev/null +++ b/src/controlflow/orchestration/print_handler.py @@ -0,0 +1,203 @@ +import datetime +from typing import Optional, Union + +import rich +from rich import box +from rich.console import Group +from rich.live import Live +from rich.markdown import Markdown +from rich.panel import Panel +from rich.spinner import Spinner +from rich.table import Table + +import controlflow +from controlflow.events.agent_events import ( + AgentMessageDeltaEvent, + AgentMessageEvent, +) +from controlflow.events.controller_events import ( + ControllerEnd, + ControllerError, + ControllerStart, +) +from controlflow.events.events import Event +from controlflow.events.tool_events import ToolCall, ToolCallEvent, ToolResultEvent +from controlflow.llm.messages import BaseMessage +from controlflow.orchestration.handler import Handler +from controlflow.utilities.rich import console as cf_console + + +class PrintHandler(Handler): + def __init__(self): + self.events: dict[str, Event] = {} + self.paused_id: str = None + super().__init__() + + def on_controller_start(self, event: ControllerStart): + self.live: Live = Live(auto_refresh=False, console=cf_console) + self.events.clear() + try: + self.live.start() + except rich.errors.LiveError: + pass + + def on_controller_end(self, event: ControllerEnd): + self.live.stop() + + def on_controller_error(self, event: ControllerError): + self.live.stop() + + def update_live(self, latest: BaseMessage = None): + events = sorted(self.events.items(), key=lambda e: (e[1].timestamp, e[0])) + content = [] + + tool_results = {} # To track tool results by their call ID + + # gather all tool events first + for _, event in events: + if isinstance(event, ToolResultEvent): + tool_results[event.tool_call["id"]] = event + + for _, event in events: + if isinstance(event, (AgentMessageDeltaEvent, AgentMessageEvent)): + if formatted := format_event(event, tool_results=tool_results): + content.append(formatted) + + if not content: + return + elif self.live.is_started: + self.live.update(Group(*content), refresh=True) + elif latest: + cf_console.print(format_event(latest)) + + def on_agent_message_delta(self, event: AgentMessageDeltaEvent): + self.events[event.snapshot_message.id] = event + self.update_live() + + def on_agent_message(self, event: AgentMessageEvent): + self.events[event.ai_message.id] = event + self.update_live() + + def on_tool_call(self, event: ToolCallEvent): + # if collecting input on the terminal, pause the live display + # to avoid overwriting the input prompt + if event.tool_call["name"] == "talk_to_user": + self.paused_id = event.tool_call["id"] + self.live.stop() + self.events.clear() + + def on_tool_result(self, event: ToolResultEvent): + self.events[f"tool-result:{event.tool_call['id']}"] = event + + # if we were paused, resume the live display + if self.paused_id and self.paused_id == event.tool_call["id"]: + self.paused_id = None + # print newline to avoid odd formatting issues + print() + self.live = Live(auto_refresh=False, console=cf_console) + self.live.start() +
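+ # re-render so the resumed display shows the buffered tool result +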
self.update_live(latest=event) + + +ROLE_COLORS = { + "system": "gray", + "ai": "blue", + "user": "green", +} +ROLE_NAMES = { + "system": "System", + "ai": "Agent", + "user": "User", +} + + +def format_timestamp(timestamp: datetime.datetime) -> str: + local_timestamp = timestamp.astimezone() + return local_timestamp.strftime("%I:%M:%S %p").lstrip("0").rjust(11) + + +def status(icon, text) -> Table: + t = Table.grid(padding=1) + t.add_row(icon, text) + return t + + +def format_event( + event: Union[AgentMessageDeltaEvent, AgentMessageEvent], + tool_results: dict[str, ToolResultEvent] = None, +) -> Optional[Panel]: + title = f"Agent: {event.agent.name}" + + content = [] + if isinstance(event, AgentMessageDeltaEvent): + message = event.snapshot_message + elif isinstance(event, AgentMessageEvent): + message = event.ai_message + else: + return + + if message.content: + if isinstance(message.content, str): + content.append(Markdown(str(message.content))) + elif isinstance(message.content, dict): + if "content" in message.content: + content.append(Markdown(str(message.content["content"]))) + elif "text" in message.content: + content.append(Markdown(str(message.content["text"]))) + elif isinstance(message.content, list): + for item in message.content: + if isinstance(item, str): + content.append(Markdown(str(item))) + elif "content" in item: + content.append(Markdown(str(item["content"]))) + elif "text" in item: + content.append(Markdown(str(item["text"]))) + + tool_content = [] + for tool_call in message.tool_calls + message.invalid_tool_calls: + tool_result = (tool_results or {}).get(tool_call["id"]) + if tool_result: + c = format_tool_result(tool_result) + else: + c = format_tool_call(tool_call) + if c: + tool_content.append(c) + + if content and tool_content: + content.append("\n") + + return Panel( + Group(*content, *tool_content), + title=f"[bold]{title}[/]", + subtitle=f"[italic]{format_timestamp(event.timestamp)}[/]", + title_align="left", + subtitle_align="right", + border_style=ROLE_COLORS.get("ai", "red"), + box=box.ROUNDED, + width=100, + expand=True, + padding=(1, 2), + ) + + +def format_tool_call(tool_call: ToolCall) -> Table: + if controlflow.settings.tools_verbose: + return status( + Spinner("dots"), + f'Tool call: "{tool_call["name"]}"\n\nTool args: {tool_call["args"]}', + ) + return status(Spinner("dots"), f'Tool call: "{tool_call["name"]}"') + + +def format_tool_result(event: ToolResultEvent) -> Table: + if event.tool_result.is_error: + icon = ":x:" + else: + icon = ":white_check_mark:" + + if controlflow.settings.tools_verbose: + msg = f'Tool call: "{event.tool_call["name"]}"\n\nTool args: {event.tool_call["args"]}\n\nTool result: {event.tool_result.str_result}' + else: + msg = f'Tool call: "{event.tool_call["name"]}"' + return status(icon, msg) diff --git a/src/controlflow/orchestration/prompt_templates/agent.md.jinja b/src/controlflow/orchestration/prompt_templates/agent.md.jinja new file mode 100644 index 00000000..e97ce8df --- /dev/null +++ b/src/controlflow/orchestration/prompt_templates/agent.md.jinja @@ -0,0 +1,47 @@ +# Agent + +You are an AI agent participating in a workflow. Your role is to work on +your tasks and use the provided tools to complete those tasks and +communicate with the orchestrator. + +Important: The orchestrator is a Python script and cannot read or +respond to messages posted in this thread. You must use the provided +tools to communicate with the orchestrator. 
Posting messages in this +thread should only be used for thinking out loud, working through a +problem, or communicating with other agents. Any System messages or +messages prefixed with "SYSTEM:" are from the workflow system, not an +actual human. + +Your job is to: +1. Select one or more tasks to work on from the ready tasks. +2. Read the task instructions and work on completing the task objective, which +may involve using appropriate tools or collaborating with other agents assigned +to the same task. +3. When you (and any other agents) have completed the task objective, use the +provided tool to inform the orchestrator of the task completion and result. +4. Repeat steps 1-3 until no more tasks are available for execution. + +## Your information + +- ID: {{ agent.id }} +- Name: "{{ agent.name }}" +{% if agent.description -%} +- Description: "{{ agent.description }}" +{% endif %} + +## Instructions + +You must follow instructions at all times. Instructions can be added or removed +at any time. + +- Never impersonate another agent + +{% if agent.instructions %} +{{ agent.instructions }} +{% endif %} + +{% if additional_instructions %} +{% for instruction in additional_instructions %} +- {{ instruction }} +{% endfor %} +{% endif %} \ No newline at end of file diff --git a/src/controlflow/orchestration/prompt_templates/communication.md.jinja b/src/controlflow/orchestration/prompt_templates/communication.md.jinja new file mode 100644 index 00000000..a98dc47a --- /dev/null +++ b/src/controlflow/orchestration/prompt_templates/communication.md.jinja @@ -0,0 +1,8 @@ +# Communication + +You are collaborating with other agents to complete tasks (or dependent +tasks). You can communicate with other agents in the thread by posting messages. +Agents can see any messages or tool calls that pertain to any task they are +assigned to. + +If your task asks you to interact with other agents or otherwise "speak out loud", post messages to accomplish that. diff --git a/src/controlflow/orchestration/prompt_templates/tools.md.jinja b/src/controlflow/orchestration/prompt_templates/tools.md.jinja new file mode 100644 index 00000000..e1453956 --- /dev/null +++ b/src/controlflow/orchestration/prompt_templates/tools.md.jinja @@ -0,0 +1,34 @@ +# Tools + +You have access to various tools. They may change, so do not rely on history to +see what tools are available. Whenever possible, use tools in parallel or at the +same time as posting a message, rather than waiting for your next turn. + +{% if has_end_turn_tool %} +## Ending your turn + +Your turn will continue until you mark a task complete or you use the `end_turn` +tool. When using the `end_turn` tool, you can optionally name the agent that +should go next, or leave it blank to let the orchestrator decide. You must use +the `end_turn` tool to let another agent speak. + +{% endif %} + +{% if has_user_access_tool %} +## Talking to human users + +If your task requires you to interact with a user, it will show +`user_access=True` and you will be given a `talk_to_user` tool. You can +use it to send messages to the user and optionally wait for a response. +This is how you tell the user things and ask questions. Do not mention +your tasks or the workflow. The user can only see messages you send +them via this tool. They cannot read the rest of the +thread. + +Human users may give poor, incorrect, or partial responses. You may need +to ask questions multiple times in order to complete your tasks. 
Do not +make up answers for omitted information; ask again and only fail the +task if you truly cannot make progress. If your task requires human +interaction and neither it nor any assigned agents have `user_access`, +you can fail the task. +{% endif %} \ No newline at end of file diff --git a/src/controlflow/orchestration/prompt_templates/workflow.md.jinja b/src/controlflow/orchestration/prompt_templates/workflow.md.jinja new file mode 100644 index 00000000..85682af0 --- /dev/null +++ b/src/controlflow/orchestration/prompt_templates/workflow.md.jinja @@ -0,0 +1,94 @@ +# Workflow + +As soon as you have completed a task's objective, you must use the provided +tool to mark it successful and provide a result. It may take multiple +turns or collaboration with other agents to complete a task. Any agent +assigned to a task can complete it. Once a task is complete, no other +agent can interact with it. + +Tasks should only be marked failed due to technical errors like a broken +or erroring tool or an unresponsive human. + +Tasks are not ready until all of their dependencies are met. Parent +tasks depend on all of their subtasks. + +## Flow + +Name: {{ flow.name }} +{% if flow.description %} +Description: {{ flow.description }} +{% endif %} +{% if flow.context %} +Context: +{% for key, value in flow.context.items() %} +- {{ key }}: {{ value }} +{% endfor %} +{% endif %} + +## Tasks + +{% if ready_tasks %} +### Ready tasks + +These tasks are ready to be worked on because all of their dependencies have +been completed. You can only work on tasks to which you are assigned. + +{% for task in ready_tasks %} +#### Task {{ task.id }} +- objective: {{ task.objective }} +- instructions: {{ task.instructions }} +- context: {{ task.context }} +- result_type: {{ task.result_type }} +- depends_on: {{ task.depends_on }} +- parent: {{ task.parent }} +- assigned agents: {{ task.agents }} +{% if task.user_access %} +- user access: True +{% endif %} +- created_at: {{ task.created_at }} + +{% endfor %} +{% endif %} + +{% if upstream_tasks %} +### Upstream tasks + +{% for task in upstream_tasks %} +#### Task {{ task.id }} +- objective: {{ task.objective }} +- instructions: {{ task.instructions }} +- status: {{ task.status }} +- result: {{ task.result }} +- error: {{ task.error }} +- context: {{ task.context }} +- depends_on: {{ task.depends_on }} +- parent: {{ task.parent }} +- assigned agents: {{ task.agents }} +{% if task.user_access %} +- user access: True +{% endif %} +- created_at: {{ task.created_at }} + +{% endfor %} +{% endif %} + +{% if downstream_tasks %} +### Downstream tasks + +{% for task in downstream_tasks %} +#### Task {{ task.id }} +- objective: {{ task.objective }} +- instructions: {{ task.instructions }} +- status: {{ task.status }} +- result_type: {{ task.result_type }} +- context: {{ task.context }} +- depends_on: {{ task.depends_on }} +- parent: {{ task.parent }} +- assigned agents: {{ task.agents }} +{% if task.user_access %} +- user access: True +{% endif %} +- created_at: {{ task.created_at }} + +{% endfor %} +{% endif %} \ No newline at end of file diff --git a/src/controlflow/orchestration/prompts.py b/src/controlflow/orchestration/prompts.py new file mode 100644 index 00000000..b70793e0 --- /dev/null +++ b/src/controlflow/orchestration/prompts.py @@ -0,0 +1,49 @@ +from controlflow.agents import Agent +from controlflow.flows import Flow +from controlflow.utilities.jinja import prompt_env +from controlflow.utilities.types import ControlFlowModel + + +class Template(ControlFlowModel): + 
template_path: str + + def render(self) -> str: + if not self.should_render(): + return "" + + render_kwargs = dict(self) + template_path = render_kwargs.pop("template_path") + template_env = prompt_env.get_template(template_path) + return template_env.render(**render_kwargs) + + def should_render(self) -> bool: + return True + + +class AgentTemplate(Template): + template_path: str = "agent.md.jinja" + agent: Agent + additional_instructions: list[str] + + +class WorkflowTemplate(Template): + template_path: str = "workflow.md.jinja" + + ready_tasks: list[dict] + upstream_tasks: list[dict] + downstream_tasks: list[dict] + flow: Flow + + +class ToolTemplate(Template): + template_path: str = "tools.md.jinja" + agent: Agent + has_user_access_tool: bool + has_end_turn_tool: bool + + def should_render(self): + return self.has_user_access_tool or self.has_end_turn_tool + + +class CommunicationTemplate(Template): + template_path: str = "communication.md.jinja" diff --git a/src/controlflow/orchestration/tools.py b/src/controlflow/orchestration/tools.py new file mode 100644 index 00000000..76d86cb4 --- /dev/null +++ b/src/controlflow/orchestration/tools.py @@ -0,0 +1,105 @@ +from typing import TYPE_CHECKING, TypeVar + +from pydantic import PydanticSchemaGenerationError, TypeAdapter + +from controlflow.agents import Agent +from controlflow.events.agent_events import EndTurnEvent +from controlflow.events.task_events import TaskCompleteEvent +from controlflow.tasks.task import Task +from controlflow.tools.tools import Tool, tool + +if TYPE_CHECKING: + from controlflow.orchestration.controller import Controller + +T = TypeVar("T") + + +def generate_result_schema(result_type: type[T]) -> type[T]: + if result_type is None: + return None + + result_schema = None + # try loading pydantic-compatible schemas + try: + TypeAdapter(result_type) + result_schema = result_type + except PydanticSchemaGenerationError: + pass + # try loading as dataframe + # try: + # import pandas as pd + + # if result_type is pd.DataFrame: + # result_schema = PandasDataFrame + # elif result_type is pd.Series: + # result_schema = PandasSeries + # except ImportError: + # pass + if result_schema is None: + raise ValueError( + f"Could not load or infer schema for result type {result_type}. " + "Please use a custom type or add compatibility." + ) + return result_schema + + +def create_task_success_tool( + controller: "Controller", task: Task, agent: Agent +) -> Tool: + """ + Create an agent-compatible tool for marking this task as successful. + """ + + result_schema = generate_result_schema(task.result_type) + + @tool( + name=f"mark_task_{task.id}_successful", + description=f"Mark task {task.id} as successful.", + private=True, + ) + def succeed(result: result_schema) -> str: # type: ignore + result = task.mark_successful(result=result) + controller.handle_event(TaskCompleteEvent(task=task)) + controller.handle_event(EndTurnEvent(agent=agent)) + return result + + return succeed + + +def create_task_fail_tool(controller: "Controller", task: Task, agent: Agent) -> Tool: + """ + Create an agent-compatible tool for failing this task. + """ + + @tool( + name=f"mark_task_{task.id}_failed", + description=f"Mark task {task.id} as failed. 
Only use when technical errors prevent success.", + private=True, + ) + def fail(error: str) -> str: + result = task.mark_failed(error=error) + controller.handle_event(TaskCompleteEvent(task=task)) + controller.handle_event(EndTurnEvent(agent=agent)) + return result + + return fail + + +def create_end_turn_tool(controller: "Controller", agent: Agent) -> Tool: + """ + Create an agent-compatible tool for ending the turn. + """ + + @tool(private=True) + def end_turn(next_agent_name: str = None) -> str: + """ + End your turn so another agent can work. You can optionally choose + the next agent, which can be any other agent assigned to a ready task. + Choose an agent likely to help you complete your tasks. + """ + controller.handle_event( + EndTurnEvent(agent=agent, next_agent_name=next_agent_name) + ) + return "Turn ended." + + return end_turn diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index c0a1633d..a6b80252 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -45,7 +45,7 @@ class Settings(ControlFlowSettings): # ------------ display and logging settings ------------ log_prints: bool = Field( - False, + default=False, description="Whether to log workflow prints to the Prefect logger by default.", ) @@ -94,11 +94,12 @@ class Settings(ControlFlowSettings): # ------------ Debug settings ------------ tools_raise_on_error: bool = Field( - False, description="If True, an error in a tool call will raise an exception." + default=False, + description="If True, an error in a tool call will raise an exception.", ) tools_verbose: bool = Field( - False, description="If True, tools will log additional information." + default=False, description="If True, tools will log additional information." ) # ------------ Prefect settings ------------ diff --git a/src/controlflow/tasks/agent_strategies.py b/src/controlflow/tasks/agent_strategies.py index f09f30a0..f8371dfe 100644 --- a/src/controlflow/tasks/agent_strategies.py +++ b/src/controlflow/tasks/agent_strategies.py @@ -1,7 +1,5 @@ from controlflow.agents import Agent -from controlflow.flows import Flow, get_flow_messages -from controlflow.instructions import get_instructions -from controlflow.llm.classify import classify +from controlflow.flows import Flow from controlflow.tasks.task import Task @@ -14,29 +12,29 @@ def round_robin( return agents[task._iteration % len(agents)] -def moderator( - agents: list[Agent], - tasks: list[Task], - context: dict = None, - iteration: int = 0, - model: str = None, -) -> Agent: - history = get_flow_messages() - instructions = get_instructions() - context = context or {} - context.update(tasks=tasks, history=history, instructions=instructions) +# def moderator( +# agents: list[Agent], +# tasks: list[Task], +# context: dict = None, +# iteration: int = 0, +# model: str = None, +# ) -> Agent: +# history = get_flow_messages() +# instructions = get_instructions() +# context = context or {} +# context.update(tasks=tasks, history=history, instructions=instructions) - agent = classify( - context, - labels=agents, - instructions=""" - Given the context, choose the AI agent best suited to take the - next turn at completing the tasks in the task graph. Take into account - any descriptions, tasks, history, instructions, and tools. Focus on - agents assigned to upstream dependencies or subtasks that need to be - completed before their downstream/parents can be completed. An agent - can only work on a task that it is assigned to. 
- """, - model=model, - ) - return agent +# agent = classify( +# context, +# labels=agents, +# instructions=""" +# Given the context, choose the AI agent best suited to take the +# next turn at completing the tasks in the task graph. Take into account +# any descriptions, tasks, history, instructions, and tools. Focus on +# agents assigned to upstream dependencies or subtasks that need to be +# completed before their downstream/parents can be completed. An agent +# can only work on a task that it is assigned to. +# """, +# model=model, +# ) +# return agent diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index fa0e70fe..b2915523 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -6,7 +6,6 @@ TYPE_CHECKING, Any, Callable, - Generator, GenericAlias, Literal, Optional, @@ -22,7 +21,6 @@ TypeAdapter, field_serializer, field_validator, - model_validator, ) import controlflow @@ -44,7 +42,6 @@ ) if TYPE_CHECKING: - from controlflow.controllers.graph import Graph from controlflow.flows import Flow T = TypeVar("T") @@ -152,6 +149,24 @@ def __init__( tags=[self.__class__.__name__], ) + # create dependencies to tasks passed in as depends_on + for task in self.depends_on: + self.add_dependency(task) + + # create dependencies to tasks passed as subtasks + if self.parent is not None: + self.parent.add_subtask(self) + + # create dependencies to tasks passed in as context + context_tasks = collect_tasks(self.context) + + for task in context_tasks: + self.add_dependency(task) + + # add task to flow, if exists + if flow := controlflow.flows.get_flow(): + flow.add_task(self) + def __hash__(self): return hash((self.__class__.__name__, self.id)) @@ -172,7 +187,7 @@ def __eq__(self, other): return False def __repr__(self) -> str: - serialized = self.model_dump() + serialized = self.model_dump(include={"id", "objective"}) return f"{self.__class__.__name__}({', '.join(f'{key}={repr(value)}' for key, value in serialized.items())})" @field_validator("parent", mode="before") @@ -196,28 +211,6 @@ def _turn_list_into_literal_result_type(cls, v): return Literal[tuple(v)] # type: ignore return v - @model_validator(mode="after") - def _finalize(self): - # add task to flow, if exists - if flow := controlflow.flows.get_flow(): - flow.add_task(self) - - # create dependencies to tasks passed in as depends_on - for task in self.depends_on: - self.add_dependency(task) - - # create dependencies to tasks passed as subtasks - if self.parent is not None: - self.parent.add_subtask(self) - - # create dependencies to tasks passed in as context - context_tasks = collect_tasks(self.context) - - for task in context_tasks: - self.depends_on.add(task) - - return self - @field_serializer("parent") def _serialize_parent(self, parent: Optional["Task"]): return parent.id if parent is not None else None @@ -262,16 +255,11 @@ def friendly_name(self): objective = f'"{self.objective}"' return f"Task {self.id} ({objective})" - def as_graph(self) -> "Graph": - from controlflow.controllers.graph import Graph - - return Graph.from_tasks(tasks=[self]) - @property def subtasks(self) -> list["Task"]: - from controlflow.controllers.graph import Graph + from controlflow.flows.graph import Graph - return Graph.from_tasks(tasks=self._subtasks).topological_sort() + return Graph(tasks=self._subtasks).topological_sort() def add_subtask(self, task: "Task"): """ @@ -302,7 +290,7 @@ def run_once(self, agents: Optional[list["Agent"]] = None, flow: "Flow" = None): "Task.run_once() must be called within a 
flow context or with a flow argument." ) - from controlflow.controllers import Controller + from controlflow.orchestration import Controller controller = Controller( tasks=[self], flow=flow, agents={self: agents} if agents else None @@ -323,7 +311,7 @@ async def run_once_async( "Task.run_once_async() must be called within a flow context or with a flow argument." ) - from controlflow.controllers import Controller + from controlflow.orchestration import Controller controller = Controller( tasks=[self], flow=flow, agents={self: agents} if agents else None @@ -336,9 +324,9 @@ def run( agents: Optional[list["Agent"]] = None, raise_on_error: bool = True, flow: "Flow" = None, - ) -> Generator[T, None, None]: + ) -> T: """ - Internal function that can handle both sync and async runs by yielding either the result or the coroutine. + Run the task until it is complete """ from controlflow.flows import Flow, get_flow @@ -352,7 +340,7 @@ def run( else: flow = Flow() - from controlflow.controllers import Controller + from controlflow.orchestration import Controller controller = Controller( tasks=[self], flow=flow, agents={self: agents} if agents else None @@ -370,9 +358,9 @@ async def run_async( agents: Optional[list["Agent"]] = None, raise_on_error: bool = True, flow: "Flow" = None, - ) -> Generator[T, None, None]: + ) -> T: """ - Internal function that can handle both sync and async runs by yielding either the result or the coroutine. + Run the task until it is complete """ from controlflow.flows import Flow, get_flow @@ -386,7 +374,7 @@ async def run_async( else: flow = Flow() - from controlflow.controllers import Controller + from controlflow.orchestration import Controller controller = Controller( tasks=[self], flow=flow, agents={self: agents} if agents else None @@ -436,53 +424,6 @@ def is_ready(self) -> bool: """ return self.is_incomplete() and all(t.is_complete() for t in self.depends_on) - def _create_success_tool(self) -> Tool: - """ - Create an agent-compatible tool for marking this task as successful. - """ - # generate tool for result_type=None - if self.result_type is None: - - def succeed() -> str: - return self.mark_successful(result=None) - - # generate tool for other result types - else: - result_schema = generate_result_schema(self.result_type) - - def succeed(result: result_schema) -> str: # type: ignore - return self.mark_successful(result=result) - - return Tool.from_function( - succeed, - name=f"mark_task_{self.id}_successful", - description=f"Mark task {self.id} as successful.", - metadata=dict(ignore_result=True), - ) - - def _create_fail_tool(self) -> Tool: - """ - Create an agent-compatible tool for failing this task. - """ - - return Tool.from_function( - self.mark_failed, - name=f"mark_task_{self.id}_failed", - description=f"Mark task {self.id} as failed. Only use when technical errors prevent success.", - metadata=dict(ignore_result=True), - ) - - def _create_skip_tool(self) -> Tool: - """ - Create an agent-compatible tool for skipping this task. - """ - return Tool.from_function( - self.mark_skipped, - name=f"mark_task_{self.id}_skipped", - description=f"Mark task {self.id} as skipped. 
Only use when completing a parent task early.", - metadata=dict(ignore_result=True), - ) - def get_agents(self) -> list["Agent"]: if self.agents: return self.agents @@ -521,12 +462,6 @@ def get_agent_strategy(self) -> Callable: def get_tools(self) -> list[Union[Tool, Callable]]: tools = self.tools.copy() - # if this task is ready to run, generate tools - if self.is_ready(): - tools.extend([self._create_fail_tool(), self._create_success_tool()]) - # add skip tool if this task has a parent task - # if self.parent is not None: - # tools.append(self._create_skip_tool()) if self.user_access: tools.append(talk_to_user) return tools @@ -602,32 +537,6 @@ def generate_subtasks(self, instructions: str = None, agent: Agent = None): ) -def generate_result_schema(result_type: type[T]) -> type[T]: - result_schema = None - # try loading pydantic-compatible schemas - try: - TypeAdapter(result_type) - result_schema = result_type - except PydanticSchemaGenerationError: - pass - # try loading as dataframe - # try: - # import pandas as pd - - # if result_type is pd.DataFrame: - # result_schema = PandasDataFrame - # elif result_type is pd.Series: - # result_schema = PandasSeries - # except ImportError: - # pass - if result_schema is None: - raise ValueError( - f"Could not load or infer schema for result type {result_type}. " - "Please use a custom type or add compatibility." - ) - return result_schema - - def validate_result(result: Any, result_type: type[T]) -> T: if result_type is None and result is not None: raise ValueError("Task has result_type=None, but a result was provided.") diff --git a/src/controlflow/tools/__init__.py b/src/controlflow/tools/__init__.py index 510e8df5..4b3e4222 100644 --- a/src/controlflow/tools/__init__.py +++ b/src/controlflow/tools/__init__.py @@ -1 +1 @@ -from controlflow.llm.tools import tool, Tool, as_tools +from controlflow.tools.tools import tool, Tool, as_tools diff --git a/src/controlflow/llm/tools.py b/src/controlflow/tools/tools.py similarity index 62% rename from src/controlflow/llm/tools.py rename to src/controlflow/tools/tools.py index 1a8c7f3b..da1a50ba 100644 --- a/src/controlflow/llm/tools.py +++ b/src/controlflow/tools/tools.py @@ -2,7 +2,7 @@ import inspect import json import typing -from typing import TYPE_CHECKING, Annotated, Any, Callable, Optional, Union +from typing import Annotated, Any, Callable, Optional, Union import langchain_core.tools import pydantic @@ -12,14 +12,9 @@ from pydantic import Field, TypeAdapter import controlflow -from controlflow.llm.messages import ToolMessage from controlflow.utilities.prefect import create_markdown_artifact, prefect_task from controlflow.utilities.types import ControlFlowModel -if TYPE_CHECKING: - from controlflow.agents import Agent - - TOOL_CALL_FUNCTION_RESULT_TEMPLATE = """ # Tool call: {name} @@ -43,11 +38,13 @@ class Tool(ControlFlowModel): name: str description: str parameters: dict - metadata: dict = Field({}, exclude_none=True) + metadata: dict = {} + private: bool = False fn: Callable = Field(None, exclude=True) def to_lc_tool(self) -> dict: - return self.model_dump(include={"name", "description", "parameters"}) + payload = self.model_dump(include={"name", "description", "parameters"}) + return dict(type="function", function=payload) @prefect_task(task_run_name="Tool call: {self.name}") def run(self, input: dict): @@ -73,6 +70,30 @@ def run(self, input: dict): ) return result + @prefect_task(task_run_name="Tool call: {self.name}") + async def run_async(self, input: dict): + result = self.fn(**input) + if 
inspect.isawaitable(result): + result = await result + + # prepare artifact + passed_args = inspect.signature(self.fn).bind(**input).arguments + try: + # try to pretty print the args + passed_args = json.dumps(passed_args, indent=2) + except Exception: + pass + create_markdown_artifact( + markdown=TOOL_CALL_FUNCTION_RESULT_TEMPLATE.format( + name=self.name, + description=self.description or "(none provided)", + args=passed_args, + result=result, + ), + key="tool-result", + ) + return result + @classmethod def from_function( cls, fn: Callable, name: str = None, description: str = None, **kwargs @@ -134,18 +155,14 @@ def tool( *, name: Optional[str] = None, description: Optional[str] = None, - metadata: Optional[dict] = None, + **kwargs, ) -> Tool: """ Decorator for turning a function into a Tool """ if fn is None: - return functools.partial( - tool, name=name, description=description, metadata=metadata - ) - return Tool.from_function( - fn, name=name, description=description, metadata=metadata or {} - ) + return functools.partial(tool, name=name, description=description, **kwargs) + return Tool.from_function(fn, name=name, description=description, **kwargs) def as_tools( @@ -193,45 +210,87 @@ def output_to_string(output: Any) -> str: return str(output) -def handle_tool_call( - tool_call: ToolCall, - tools: list[Tool], - error: str = None, - agent: "Agent" = None, -) -> ToolMessage: +class ToolResult(ControlFlowModel): + tool_call_id: str + result: Any = Field(exclude=True, repr=False) + str_result: str = Field(repr=False) + is_error: bool = False + is_private: bool = False + + +def handle_tool_call(tool_call: ToolCall, tools: list[Tool]) -> ToolResult: + """ + Given a ToolCall and set of available tools, runs the tool call and returns + a ToolResult object + """ + is_error = False + tool = None tool_lookup = {t.name: t for t in tools} fn_name = tool_call["name"] - is_error = False - metadata = {} - try: - if error: - fn_output = error - is_error = True - elif fn_name not in tool_lookup: - fn_output = f'Function "{fn_name}" not found.' - is_error = True - else: + + if fn_name not in tool_lookup: + fn_output = f'Function "{fn_name}" not found.' + is_error = True + if controlflow.settings.tools_raise_on_error: + raise ValueError(fn_output) + + if not is_error: + try: + tool = tool_lookup[fn_name] - metadata.update(getattr(tool, "metadata", {})) fn_args = tool_call["args"] if isinstance(tool, Tool): fn_output = tool.run(input=fn_args) elif isinstance(tool, langchain_core.tools.BaseTool): fn_output = tool.invoke(input=fn_args) - except Exception as exc: - fn_output = f'Error calling function "{fn_name}": {exc}' + except Exception as exc: + fn_output = f'Error calling function "{fn_name}": {exc}' + is_error = True + if controlflow.settings.tools_raise_on_error: + raise exc + + return ToolResult( + tool_call_id=tool_call["id"], + result=fn_output, + str_result=output_to_string(fn_output), + is_error=is_error, + is_private=getattr(tool, "private", False), + ) + + +async def handle_tool_call_async(tool_call: ToolCall, tools: list[Tool]) -> ToolResult: + """ + Given a ToolCall and set of available tools, runs the tool call and returns + a ToolResult object + """ + is_error = False + tool = None + tool_lookup = {t.name: t for t in tools} + fn_name = tool_call["name"] + + if fn_name not in tool_lookup: + fn_output = f'Function "{fn_name}" not found.' 
is_error = True if controlflow.settings.tools_raise_on_error: - raise + raise ValueError(fn_output) - from controlflow.llm.messages import ToolMessage + if not is_error: + try: + tool = tool_lookup[fn_name] + fn_args = tool_call["args"] + if isinstance(tool, Tool): + fn_output = await tool.run_async(input=fn_args) + elif isinstance(tool, langchain_core.tools.BaseTool): + fn_output = await tool.ainvoke(input=fn_args) + except Exception as exc: + fn_output = f'Error calling function "{fn_name}": {exc}' + is_error = True + if controlflow.settings.tools_raise_on_error: + raise exc - return ToolMessage( - content=output_to_string(fn_output), + return ToolResult( tool_call_id=tool_call["id"], - tool_call=tool_call, - tool_result=fn_output, - tool_metadata=metadata, + result=fn_output, + str_result=output_to_string(fn_output), is_error=is_error, - agent=agent, + is_private=getattr(tool, "private", False), ) diff --git a/src/controlflow/utilities/jinja.py b/src/controlflow/utilities/jinja.py index 1297add3..82b38f13 100644 --- a/src/controlflow/utilities/jinja.py +++ b/src/controlflow/utilities/jinja.py @@ -4,9 +4,17 @@ from zoneinfo import ZoneInfo from jinja2 import Environment as JinjaEnvironment -from jinja2 import StrictUndefined, select_autoescape +from jinja2 import PackageLoader, StrictUndefined, select_autoescape -jinja_env = JinjaEnvironment( +global_fns = { + "now": lambda: datetime.now(ZoneInfo("UTC")), + "inspect": inspect, + "getcwd": os.getcwd, + "zip": zip, +} + +prompt_env = JinjaEnvironment( + loader=PackageLoader("controlflow.orchestration", "prompt_templates"), autoescape=select_autoescape(default_for_string=False), trim_blocks=True, lstrip_blocks=True, @@ -14,11 +22,4 @@ undefined=StrictUndefined, ) -jinja_env.globals.update( - { - "now": lambda: datetime.now(ZoneInfo("UTC")), - "inspect": inspect, - "getcwd": os.getcwd, - "zip": zip, - } -) +prompt_env.globals.update(global_fns) diff --git a/src/controlflow/utilities/prefect.py b/src/controlflow/utilities/prefect.py index 5b9da445..c4c8714c 100644 --- a/src/controlflow/utilities/prefect.py +++ b/src/controlflow/utilities/prefect.py @@ -50,7 +50,7 @@ def prefect_task(*args, **kwargs): kwargs.setdefault("log_prints", controlflow.settings.log_prints) kwargs.setdefault("cache_policy", prefect.cache_policies.NONE) - kwargs.setdefault("result_serializer", prefect.serializers.JSONSerializer()) + kwargs.setdefault("result_serializer", "json") return prefect.task(*args, **kwargs) @@ -61,6 +61,7 @@ def prefect_flow(*args, **kwargs): """ kwargs.setdefault("log_prints", controlflow.settings.log_prints) + kwargs.setdefault("result_serializer", "json") return prefect.flow(*args, **kwargs) @@ -182,12 +183,11 @@ def start(self, depends_on: list = None): self.is_started = True self._client = get_client(sync_client=True) - self._task = prefect.Task( - fn=lambda: None, + self._task = prefect_task( name=self.name, description=self.description, tags=self.tags, - ) + )(lambda: None) self._task_run = run_coro_as_sync( self._task.create_run( diff --git a/src/controlflow/utilities/testing.py b/src/controlflow/utilities/testing.py index 59dad02c..f5ee9f80 100644 --- a/src/controlflow/utilities/testing.py +++ b/src/controlflow/utilities/testing.py @@ -1,24 +1,15 @@ from contextlib import contextmanager -from typing import Union -import langchain_core.messages from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel import controlflow -from controlflow.flows.history import InMemoryHistory -from controlflow.llm.messages 
import BaseMessage, MessageType +from controlflow.events.history import InMemoryHistory +from controlflow.llm.messages import BaseMessage class FakeLLM(FakeMessagesListChatModel): - def set_responses( - self, responses: list[Union[MessageType, langchain_core.messages.BaseMessage]] - ): - new_responses = [] - for msg in responses: - if isinstance(msg, BaseMessage): - msg = msg.to_langchain_message() - new_responses.append(msg) - self.responses = new_responses + def set_responses(self, responses: list[BaseMessage]): + self.responses = responses def bind_tools(self, *args, **kwargs): """When binding tools, passthrough""" @@ -30,38 +21,32 @@ def get_num_tokens_from_messages(self, messages: list) -> int: @contextmanager -def record_messages( - remove_additional_kwargs: bool = True, remove_tool_call_chunks: bool = True -): +def record_events(): """ Context manager for recording all messages in a flow, useful for testing. - with record_messages() as messages: + with record_events() as events: cf.Task("say hello").run() - assert messages[0].content == "Hello!" + assert events[0].content == "Hello!" """ history = InMemoryHistory(history={}) old_default_history = controlflow.default_history controlflow.default_history = history - messages = [] + events = [] try: - yield messages + yield events finally: controlflow.default_history = old_default_history - _messages_buffer = [] - for _, thread_messages in history.history.items(): - for message in thread_messages: - message = message.copy() - if hasattr(message, "additional_kwargs") and remove_additional_kwargs: - message.additional_kwargs = {} - if hasattr(message, "tool_call_chunks") and remove_tool_call_chunks: - message.tool_call_chunks = [] - _messages_buffer.append(message) - - messages.extend(sorted(_messages_buffer, key=lambda m: m.timestamp)) + _events_buffer = [] + for _, thread_events in history.history.items(): + for event in thread_events: + event = event.copy() + _events_buffer.append(event) + + events.extend(sorted(_events_buffer, key=lambda m: m.timestamp)) diff --git a/src/controlflow/utilities/types.py b/src/controlflow/utilities/types.py index 8f81376d..1caa96a3 100644 --- a/src/controlflow/utilities/types.py +++ b/src/controlflow/utilities/types.py @@ -34,10 +34,3 @@ class PandasSeries(ControlFlowModel): index: Optional[list[str]] = None name: Optional[str] = None dtype: Optional[str] = None - - -class _OpenAIBaseType(ControlFlowModel): - model_config = ConfigDict(extra="allow") - - -__all__ = ["ControlFlowModel", "PandasDataFrame", "PandasSeries", "_OpenAIBaseType"] diff --git a/tests/conftest.py b/tests/conftest.py index adcf8563..4285f452 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ import pytest -from controlflow.llm.messages import MessageType +from controlflow.llm.messages import BaseMessage from controlflow.settings import temporary_settings from prefect.testing.utilities import prefect_test_harness diff --git a/tests/controllers/__init__.py b/tests/controllers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/fixtures/controlflow.py b/tests/fixtures/controlflow.py index f6b1c086..3f4d7bcc 100644 --- a/tests/fixtures/controlflow.py +++ b/tests/fixtures/controlflow.py @@ -1,6 +1,6 @@ import controlflow import pytest -from controlflow.llm.messages import MessageType +from controlflow.llm.messages import BaseMessage from controlflow.utilities.testing import FakeLLM diff --git a/tests/flows/test_flows.py b/tests/flows/test_flows.py index ad98f5d6..b3ed3404 100644 --- 
a/tests/flows/test_flows.py +++ b/tests/flows/test_flows.py @@ -1,6 +1,6 @@ from controlflow.agents import Agent +from controlflow.events.agent_events import UserMessageEvent from controlflow.flows import Flow, get_flow -from controlflow.llm.messages import UserMessage from controlflow.tasks.task import Task from controlflow.utilities.context import ctx @@ -71,7 +71,7 @@ def test_tasks_created_in_flow_context(self): t1 = Task("test 1") t2 = Task("test 2") - assert flow.tasks == {t1.id: t1, t2.id: t2} + assert flow.tasks == [t1, t2] def test_tasks_created_in_nested_flows_only_in_inner_flow(self): with Flow() as flow1: @@ -79,73 +79,63 @@ def test_tasks_created_in_nested_flows_only_in_inner_flow(self): with Flow() as flow2: t2 = Task("test 2") - assert flow1.tasks == {t1.id: t1} - assert flow2.tasks == {t2.id: t2} + assert flow1.tasks == [t1] + assert flow2.tasks == [t2] + + def test_inner_flow_includes_completed_parent_tasks(self): + with Flow() as flow1: + t1 = Task("test 1", status="SUCCESSFUL") + t2 = Task("test 2") + with Flow() as flow2: + t3 = Task("test 3") + + assert flow1.tasks == [t1, t2] + assert flow2.tasks == [t1, t3] class TestFlowHistory: - def test_get_messages_empty(self): + def test_get_events_empty(self): flow = Flow() - messages = flow.get_messages() + messages = flow.get_events() assert messages == [] - def test_add_messages_with_history(self): - flow = Flow() - flow.add_messages( - messages=[UserMessage(content="hello"), UserMessage(content="world")] - ) - messages = flow.get_messages() - assert len(messages) == 2 - assert [m.content for m in messages] == ["hello", "world"] - - def test_copy_parent_history(self): - flow1 = Flow() - flow1.add_messages( - messages=[UserMessage(content="hello"), UserMessage(content="world")] - ) - - with flow1: - flow2 = Flow() - - messages1 = flow1.get_messages() - assert len(messages1) == 2 - assert [m.content for m in messages1] == ["hello", "world"] - - messages2 = flow2.get_messages() - assert len(messages2) == 2 - assert [m.content for m in messages2] == ["hello", "world"] - def test_disable_copying_parent_history(self): flow1 = Flow() - flow1.add_messages( - messages=[UserMessage(content="hello"), UserMessage(content="world")] + flow1.add_events( + [ + UserMessageEvent(content="hello"), + UserMessageEvent(content="world"), + ] ) with flow1: flow2 = Flow(copy_parent=False) - messages1 = flow1.get_messages() + messages1 = flow1.get_events() assert len(messages1) == 2 assert [m.content for m in messages1] == ["hello", "world"] - messages2 = flow2.get_messages() + messages2 = flow2.get_events() assert len(messages2) == 0 def test_child_flow_messages_dont_go_to_parent(self): flow1 = Flow() - flow1.add_messages( - messages=[UserMessage(content="hello"), UserMessage(content="world")] + flow1.add_events( + [ + UserMessageEvent(content="hello"), + UserMessageEvent(content="world"), + ] ) with flow1: flow2 = Flow() - flow2.add_messages(messages=[UserMessage(content="goodbye")]) + flow2.add_events([UserMessageEvent(content="goodbye")]) - messages1 = flow1.get_messages() + messages1 = flow1.get_events() assert len(messages1) == 2 assert [m.content for m in messages1] == ["hello", "world"] - messages2 = flow2.get_messages() + messages2 = flow2.get_events() assert len(messages2) == 3 assert [m.content for m in messages2] == ["hello", "world", "goodbye"] diff --git a/tests/controllers/test_graph.py b/tests/flows/test_graph.py similarity index 79% rename from tests/controllers/test_graph.py rename to tests/flows/test_graph.py index 
e775e7e4..be6209d2 100644 --- a/tests/controllers/test_graph.py +++ b/tests/flows/test_graph.py @@ -1,5 +1,5 @@ # test_graph.py -from controlflow.controllers.graph import Edge, EdgeType, Graph +from controlflow.flows.graph import Edge, EdgeType, Graph from controlflow.tasks.task import Task @@ -34,7 +34,7 @@ def test_from_tasks(): task1 = Task(objective="Task 1") task2 = Task(objective="Task 2", depends_on=[task1]) task3 = Task(objective="Task 3", parent=task2) - graph = Graph.from_tasks([task1, task2, task3]) + graph = Graph(tasks=[task1, task2, task3]) assert len(graph.tasks) == 3 assert task1 in graph.tasks assert task2 in graph.tasks @@ -57,7 +57,7 @@ def test_from_tasks(): def test_upstream_edges(): task1 = Task(objective="Task 1") task2 = Task(objective="Task 2", depends_on=[task1]) - graph = Graph.from_tasks([task1, task2]) + graph = Graph(tasks=[task1, task2]) upstream_edges = graph.upstream_edges() assert len(upstream_edges[task1]) == 0 assert len(upstream_edges[task2]) == 1 @@ -67,7 +67,7 @@ def test_upstream_edges(): def test_downstream_edges(): task1 = Task(objective="Task 1") task2 = Task(objective="Task 2", depends_on=[task1]) - graph = Graph.from_tasks([task1, task2]) + graph = Graph(tasks=[task1, task2]) downstream_edges = graph.downstream_edges() assert len(downstream_edges[task1]) == 1 assert len(downstream_edges[task2]) == 0 @@ -79,7 +79,7 @@ def test_topological_sort(): task2 = Task(objective="Task 2", depends_on=[task1]) task3 = Task(objective="Task 3", depends_on=[task2]) task4 = Task(objective="Task 4", depends_on=[task3]) - graph = Graph.from_tasks([task1, task2, task3, task4]) + graph = Graph(tasks=[task1, task2, task3, task4]) sorted_tasks = graph.topological_sort() assert len(sorted_tasks) == 4 assert sorted_tasks.index(task1) < sorted_tasks.index(task2) @@ -87,6 +87,16 @@ def test_topological_sort(): assert sorted_tasks.index(task3) < sorted_tasks.index(task4) +def test_topological_sort_uses_time_to_tiebreak(): + task1 = Task(objective="Task 1") + task2 = Task(objective="Task 2") + task3 = Task(objective="Task 3") + task4 = Task(objective="Task 4") + graph = Graph(tasks=[task1, task2, task3, task4]) + sorted_tasks = graph.topological_sort() + assert sorted_tasks == [task1, task2, task3, task4] + + def test_topological_sort_with_fan_in_and_fan_out(): task1 = Task(objective="Task 1") task2 = Task(objective="Task 2") @@ -125,12 +135,11 @@ def test_upstream_tasks(): graph.add_edge(edge2) graph.add_edge(edge3) - assert graph.upstream_tasks([task3]) == [task1, task2] - assert graph.upstream_tasks([task2]) == [task1] - assert graph.upstream_tasks([task1]) == [] + assert graph.upstream_tasks([task3]) == [task1, task2, task3] + assert graph.upstream_tasks([task2]) == [task1, task2] + assert graph.upstream_tasks([task1]) == [task1] - # never include a start task in the usptream list - assert graph.upstream_tasks([task1, task3]) == [task2] + assert graph.upstream_tasks([task1, task3]) == [task1, task2, task3] def test_downstream_tasks(): @@ -147,9 +156,9 @@ def test_downstream_tasks(): graph.add_edge(edge2) graph.add_edge(edge3) - assert graph.downstream_tasks([task3]) == [] - assert graph.downstream_tasks([task2]) == [task3] - assert graph.downstream_tasks([task1]) == [task2, task3] + assert graph.downstream_tasks([task3]) == [task3] + assert graph.downstream_tasks([task2]) == [task2, task3] + assert graph.downstream_tasks([task1]) == [task1, task2, task3] # never include a start task in the downstream list - assert graph.downstream_tasks([task1, task3]) == [task2] + 
assert graph.downstream_tasks([task1, task3]) == [task1, task2, task3] diff --git a/tests/llm/test_messages.py b/tests/llm/test_messages.py deleted file mode 100644 index 67adec53..00000000 --- a/tests/llm/test_messages.py +++ /dev/null @@ -1,15 +0,0 @@ -import controlflow -from controlflow.llm.messages import AgentReference, AIMessage - - -class TestAIMessage: - def test_agent(self): - agent = controlflow.Agent(name="Test Agent!") - message = AIMessage(content="", agent=agent) - assert isinstance(message.agent, AgentReference) - assert message.agent.name == "Test Agent!" - - def test_name_loaded_from_agent(self): - agent = controlflow.Agent(name="Test Agent!") - message = AIMessage(content="", agent=agent) - assert message.name == "Test-Agent" diff --git a/src/controlflow/handlers/__init__.py b/tests/orchestration/__init__.py similarity index 100% rename from src/controlflow/handlers/__init__.py rename to tests/orchestration/__init__.py diff --git a/tests/orchestration/test_controller.py b/tests/orchestration/test_controller.py new file mode 100644 index 00000000..eb58ea6f --- /dev/null +++ b/tests/orchestration/test_controller.py @@ -0,0 +1,33 @@ +from controlflow.flows import Flow +from controlflow.orchestration.controller import Controller +from controlflow.tasks import Task + + +class TestReadyTasks: + def test_ready_tasks(self): + controller = Controller(flow=Flow()) + assert controller.get_ready_tasks() == [] + + def test_ready_tasks_nested_1(self): + with Flow() as flow: + with Task("parent") as parent: + child_1 = Task("child 1") + child_2 = Task("child 2") + + assert Controller(flow=flow, tasks=[]).get_ready_tasks() == [child_1, child_2] + assert Controller(flow=flow, tasks=[child_1]).get_ready_tasks() == [child_1] + assert Controller(flow=flow, tasks=[child_2]).get_ready_tasks() == [child_2] + assert Controller(flow=flow, tasks=[parent]).get_ready_tasks() == [ + child_1, + child_2, + ] + + def test_ready_tasks_nested(self): + with Flow() as flow: + with Task("parent"): + child_1 = Task("child 1") + child_2 = Task("child 2", context=dict(sibling=child_1)) + + assert Controller(flow=flow, tasks=[child_2]).get_ready_tasks() == [child_1] + assert Controller(flow=flow, tasks=[]).get_ready_tasks() == [child_1] + assert Controller(flow=flow, tasks=[child_1]).get_ready_tasks() == [child_1] diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index 9ace51c0..13e63a3c 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -2,7 +2,6 @@ import pytest from controlflow.agents import Agent, get_default_agent -from controlflow.controllers.graph import EdgeType from controlflow.flows import Flow from controlflow.instructions import instructions from controlflow.tasks.task import Task, TaskStatus @@ -48,6 +47,23 @@ def test_task_dependencies(): assert task2 in task1._downstreams +def test_task_context_dependencies(): + task1 = SimpleTask() + task2 = SimpleTask(context=dict(a=task1)) + assert task1 in task2.depends_on + assert task2 in task1._downstreams + + +def test_task_context_complex_dependencies(): + task1 = SimpleTask() + task2 = SimpleTask() + task3 = SimpleTask(context=dict(a=[task1], b=dict(c=[task2]))) + assert task1 in task3.depends_on + assert task2 in task3.depends_on + assert task3 in task1._downstreams + assert task3 in task2._downstreams + + def test_task_subtasks(): task1 = SimpleTask() task2 = SimpleTask(parent=task1) @@ -121,17 +137,29 @@ def test_task_loads_agent_from_parent_before_flow(): assert child.get_agents() == [agent2] -def 
test_task_tracking(): - with Flow() as flow: +class TestFlowRegistration: + def test_task_tracking(self): + with Flow() as flow: + task = SimpleTask() + assert task in flow.tasks + + def test_task_tracking_on_call(self): task = SimpleTask() - assert task in flow.tasks.values() + with Flow() as flow: + task.run_once() + assert task in flow.tasks + def test_parent_child_tracking(self): + with Flow() as flow: + with SimpleTask() as parent: + with SimpleTask() as child: + grandchild = SimpleTask() -def test_task_tracking_on_call(): - task = SimpleTask() - with Flow() as flow: - task.run_once() - assert task in flow.tasks.values() + assert parent in flow.tasks + assert child in flow.tasks + assert grandchild in flow.tasks + + assert len(flow.graph.edges) == 2 class TestTaskStatus: @@ -213,100 +241,3 @@ def test_task_hash(self): task1 = SimpleTask() task2 = SimpleTask() assert hash(task1) != hash(task2) - - def test_ready_task_adds_tools(self): - task = SimpleTask() - assert task.is_ready() - - tools = task.get_tools() - assert any(tool.name == f"mark_task_{task.id}_failed" for tool in tools) - assert any(tool.name == f"mark_task_{task.id}_successful" for tool in tools) - - def test_completed_task_does_not_add_tools(self): - task = SimpleTask() - task.mark_successful() - tools = task.get_tools() - assert not any(tool.name == f"mark_task_{task.id}_failed" for tool in tools) - assert not any(tool.name == f"mark_task_{task.id}_successful" for tool in tools) - - def test_task_with_incomplete_upstream_does_not_add_tools(self): - upstream_task = SimpleTask() - downstream_task = SimpleTask(depends_on=[upstream_task]) - tools = downstream_task.get_tools() - assert not any( - tool.name == f"mark_task_{downstream_task.id}_failed" for tool in tools - ) - assert not any( - tool.name == f"mark_task_{downstream_task.id}_successful" for tool in tools - ) - - def test_task_with_incomplete_subtask_does_not_add_tools(self): - parent = SimpleTask() - SimpleTask(parent=parent) - tools = parent.get_tools() - assert not any(tool.name == f"mark_task_{parent.id}_failed" for tool in tools) - assert not any( - tool.name == f"mark_task_{parent.id}_successful" for tool in tools - ) - - -class TestTaskToGraph: - def test_single_task_graph(self): - task = SimpleTask() - graph = task.as_graph() - assert len(graph.tasks) == 1 - assert task in graph.tasks - assert len(graph.edges) == 0 - - def test_task_with_subtasks_graph(self): - task1 = SimpleTask() - task2 = SimpleTask(parent=task1) - graph = task1.as_graph() - assert len(graph.tasks) == 2 - assert task1 in graph.tasks - assert task2 in graph.tasks - assert len(graph.edges) == 1 - assert any( - edge.upstream == task2 - and edge.downstream == task1 - and edge.type == EdgeType.SUBTASK - for edge in graph.edges - ) - - def test_task_with_dependencies_graph(self): - task1 = SimpleTask() - task2 = SimpleTask(depends_on=[task1]) - graph = task2.as_graph() - assert len(graph.tasks) == 2 - assert task1 in graph.tasks - assert task2 in graph.tasks - assert len(graph.edges) == 1 - assert any( - edge.upstream == task1 - and edge.downstream == task2 - and edge.type == EdgeType.DEPENDENCY - for edge in graph.edges - ) - - def test_task_with_subtasks_and_dependencies_graph(self): - task1 = SimpleTask() - task2 = SimpleTask(depends_on=[task1]) - task3 = SimpleTask(objective="Task 3", parent=task2) - graph = task2.as_graph() - assert len(graph.tasks) == 3 - assert task1 in graph.tasks - assert task2 in graph.tasks - assert task3 in graph.tasks - assert len(graph.edges) == 2 - assert any( - 
edge.upstream == task1 - and edge.downstream == task2 - and edge.type == EdgeType.DEPENDENCY - for edge in graph.edges - ) - assert any( - edge.upstream == task3 - and edge.downstream == task2 - and edge.type == EdgeType.SUBTASK - for edge in graph.edges - ) diff --git a/tests/test_settings.py b/tests/test_settings.py index 6246b0d0..b09e688f 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -6,6 +6,12 @@ from prefect.logging import get_logger +def test_defaults(): + # ensure that debug settings etc. are not left on by default + assert controlflow.settings.tools_raise_on_error is False + assert controlflow.settings.tools_verbose is False + + def test_temporary_settings(): assert controlflow.settings.tools_raise_on_error is False with temporary_settings(tools_raise_on_error=True): diff --git a/tests/llm/test_tools.py b/tests/tools/test_tools.py similarity index 99% rename from tests/llm/test_tools.py rename to tests/tools/test_tools.py index 7fb037fe..67e55e44 100644 --- a/tests/llm/test_tools.py +++ b/tests/tools/test_tools.py @@ -4,7 +4,7 @@ import pytest from controlflow.agents.agent import Agent from controlflow.llm.messages import ToolMessage -from controlflow.llm.tools import ( +from controlflow.tools.tools import ( Tool, handle_tool_call, tool, diff --git a/tests/utilities/test_testing.py b/tests/utilities/test_testing.py index c23a0db6..875a7f67 100644 --- a/tests/utilities/test_testing.py +++ b/tests/utilities/test_testing.py @@ -1,63 +1,48 @@ -import datetime - import controlflow -from controlflow.llm.messages import AIMessage, ToolMessage -from controlflow.utilities.testing import record_messages +from controlflow.llm.messages import AIMessage +from controlflow.utilities.testing import record_events -def test_record_messages_empty(): - with record_messages() as messages: +def test_record_events_empty(): + with record_events() as events: pass - assert messages == [] + assert events == [] -def test_record_task_messages(default_fake_llm): - task = controlflow.Task("say hello") +def test_record_task_events(default_fake_llm): + task = controlflow.Task("say hello", id="12345") response = AIMessage( - agent=dict(name="Marvin"), id="run-2af8bb73-661f-4ec3-92ff-d7d8e3074926", - timestamp=datetime.datetime( - 2024, 6, 23, 17, 12, 24, 91830, tzinfo=datetime.timezone.utc - ), + name="Marvin", role="ai", content="", - name="Marvin", tool_calls=[ { - "name": f"mark_task_{task.id}_successful", + "name": "mark_task_12345_successful", "args": {"result": "Hello!"}, "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", } ], - is_delta=False, ) default_fake_llm.set_responses([response]) - with record_messages() as rec_messages: + with record_events() as events: task.run() - assert rec_messages[0].content == response.content - assert rec_messages[0].id == response.id - assert rec_messages[0].tool_calls == response.tool_calls - - expected_tool_message = ToolMessage( - agent=dict(name="Marvin"), - id="cb84bb8f3e0f4245bbf5eefeee9272b2", - timestamp=datetime.datetime( - 2024, 6, 23, 17, 12, 24, 187384, tzinfo=datetime.timezone.utc - ), - role="tool", - content=f'Task {task.id} ("say hello") marked successful by Marvin.', - name="Marvin", + assert events[0].event == "select-agent" + assert events[1].event == "agent-message" + assert response == events[1].ai_message + + assert events[5].event == "tool-result" + assert events[5].tool_call == { + "name": "mark_task_12345_successful", + "args": {"result": "Hello!"}, + "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", + } + assert events[5].tool_result.model_dump() == 
dict( tool_call_id="call_ZEPdV8mCgeBe5UHjKzm6e3pe", - tool_call={ - "name": f"mark_task_{task.id}_successful", - "args": {"result": "Hello!"}, - "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", - }, - tool_metadata={"ignore_result": True}, + str_result='Task 12345 ("say hello") marked successful.', + is_error=False, + is_private=True, ) - assert rec_messages[1].content == expected_tool_message.content - assert rec_messages[1].tool_call_id == expected_tool_message.tool_call_id - assert rec_messages[1].tool_call == expected_tool_message.tool_call
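
For illustration, a minimal sketch of subclassing the new `Handler` base class from `src/controlflow/orchestration/handler.py`. The `LoggingHandler` name is hypothetical; the `on_*` method names assume the hyphenated event names (`"agent-message"`, `"tool-result"`) that `PrintHandler` and the tests above rely on, and how handlers are registered with the controller is not shown in this diff.

```python
from controlflow.events.agent_events import AgentMessageEvent
from controlflow.events.tool_events import ToolResultEvent
from controlflow.orchestration.handler import Handler


class LoggingHandler(Handler):
    """Hypothetical handler that collects a plain-text transcript."""

    def __init__(self):
        self.transcript: list[str] = []

    # Handler.on_event dispatches here for events where event.event == "agent-message"
    def on_agent_message(self, event: AgentMessageEvent):
        self.transcript.append(f"{event.agent.name}: {event.ai_message.content}")

    # dispatched for events where event.event == "tool-result"
    def on_tool_result(self, event: ToolResultEvent):
        self.transcript.append(f"[tool] {event.tool_result.str_result}")
```

Because dispatch is purely name-based (`event.event.replace("-", "_")` looked up as `on_<name>`), a subclass only implements the events it cares about and silently ignores the rest.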
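A minimal sketch of the reworked tool path in `src/controlflow/tools/tools.py`: decorating a plain function, then resolving a model-issued tool call into a `ToolResult`. The `add` function and the literal tool-call dict are illustrative; the dict shape (`name`/`args`/`id`) matches the `ToolCall` payloads used throughout this diff, and in practice this runs inside a controller turn where the Prefect context is available.

```python
from controlflow.tools.tools import as_tools, handle_tool_call, tool


@tool(name="add", description="Add two integers.")
def add(a: int, b: int) -> int:
    return a + b


tools = as_tools([add])

# simulate a tool call as it would arrive from the model
tool_call = {"name": "add", "args": {"a": 1, "b": 2}, "id": "call-1"}
result = handle_tool_call(tool_call, tools=tools)

assert result.result == 3        # raw value (excluded from serialization)
assert result.str_result == "3"  # stringified form sent back to the LLM
assert result.is_error is False
```

Note that `is_private` mirrors the new `private` flag on the tool, which is how the task success/fail and `end_turn` tools created in `orchestration/tools.py` keep their raw results out of the shared thread.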
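A minimal sketch of the `Template` classes from `src/controlflow/orchestration/prompts.py`, using only templates added in this diff: every model field other than `template_path` is passed to the Jinja template as a render variable, and `should_render()` gates whether `render()` does any work at all.

```python
from controlflow.agents import Agent
from controlflow.orchestration.prompts import CommunicationTemplate, ToolTemplate

# renders the packaged communication.md.jinja (no variables needed)
print(CommunicationTemplate().render())

# ToolTemplate.should_render() is False when neither tool flag is set,
# so render() returns "" without ever loading tools.md.jinja
empty = ToolTemplate(
    agent=Agent(name="Marvin"),
    has_user_access_tool=False,
    has_end_turn_tool=False,
).render()
assert empty == ""
```

This is why `AgentContext.get_prompt()` can unconditionally render all four templates and join the results: sections that do not apply simply contribute empty strings.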
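Finally, a small sketch of the revised `Graph` semantics that the updated tests in `tests/flows/test_graph.py` pin down: graphs are now constructed with `Graph(tasks=...)` rather than `Graph.from_tasks(...)`, and `upstream_tasks` / `downstream_tasks` now include the query tasks themselves.

```python
from controlflow.flows.graph import Graph
from controlflow.tasks.task import Task

t1 = Task(objective="collect data")
t2 = Task(objective="summarize", depends_on=[t1])

g = Graph(tasks=[t1, t2])
assert g.topological_sort() == [t1, t2]
assert g.upstream_tasks([t2]) == [t1, t2]    # now includes t2 itself
assert g.downstream_tasks([t1]) == [t1, t2]  # now includes t1 itself
```

The inclusive behavior matters for the orchestration code above: `AgentContext.get_events()` asks for events on `upstream_tasks(self.tasks)`, which now naturally covers the ready tasks themselves as well as their dependencies.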