From 36623d199514b78b9c28e32c2f54cd3f8a2252c3 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Thu, 11 Jul 2024 20:17:09 -0400 Subject: [PATCH] Update prompt customization --- docs/tutorial.mdx | 2 +- src/controlflow/agents/agent.py | 35 ++++++++----------- src/controlflow/agents/teams.py | 15 +++++--- src/controlflow/flows/flow.py | 6 ++++ .../orchestration/prompt_templates.py | 15 ++++---- src/controlflow/tasks/task.py | 35 +++++-------------- src/controlflow/utilities/testing.py | 4 +++ tests/agents/test_agents.py | 28 +++++++++++++++ tests/flows/test_flows.py | 27 ++++++++++++++ tests/tasks/test_tasks.py | 31 +++++++++++++--- 10 files changed, 135 insertions(+), 63 deletions(-) diff --git a/docs/tutorial.mdx b/docs/tutorial.mdx index b4b3d284..a8132b8d 100644 --- a/docs/tutorial.mdx +++ b/docs/tutorial.mdx @@ -423,7 +423,7 @@ with cf.instructions('No more than 5 sentences per document'): technical_document.run() ``` -When you run the `technical_document` task, ControlFlow will assign both the `docs_agent` and the `editor_agent` to complete the task. The `docs_agent` will generate the technical document, and the `editor_agent` will review and edit the document to ensure its accuracy and readability. By default, they will be run in round-robin fashion, but you can customize the agent selection strategy by passing a function as the task's `agent_strategy`. +When you run the `technical_document` task, ControlFlow will assign both the `docs_agent` and the `editor_agent` to complete the task. The `docs_agent` will generate the technical document, and the `editor_agent` will review and edit the document to ensure its accuracy and readability. ### Instructions diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index c79f6057..82f56419 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -47,6 +47,11 @@ class BaseAgent(ControlFlowModel, abc.ABC): description: Optional[str] = Field( None, description="A description of the agent, visible to other agents." ) + prompt: Optional[str] = Field( + None, + description="A prompt to display as a system message to the agent." + "Prompts are formatted as jinja templates, with keywords `agent: Agent` and `context: AgentContext`.", + ) def serialize_for_prompt(self) -> dict: return self.model_dump() @@ -65,8 +70,15 @@ async def _run_async(self, context: "AgentContext") -> list[Event]: """ raise NotImplementedError() - async def get_activation_prompt(self) -> str: - return f"Agent {self.name} is now active." + def get_prompt(self, context: "AgentContext") -> str: + from controlflow.orchestration import prompt_templates + + template = prompt_templates.AgentTemplate( + template=self.prompt, + agent=self, + context=context, + ) + return template.render() class Agent(BaseAgent): @@ -86,7 +98,7 @@ class Agent(BaseAgent): False, description="If True, the agent is given tools for interacting with a human user.", ) - system_template: Optional[str] = Field( + prompt: Optional[str] = Field( None, description="A system template for the agent. The template should be formatted as a jinja2 template.", ) @@ -164,23 +176,6 @@ def get_tools(self) -> list[Callable]: return tools - def get_prompt(self, context: "AgentContext") -> str: - from controlflow.orchestration import prompt_templates - - if self.system_template: - template = prompt_templates.AgentTemplate( - template=self.system_template, - template_path=None, - agent=self, - context=context, - ) - else: - template = prompt_templates.AgentTemplate(agent=self, context=context) - return template.render() - - def get_activation_prompt(self) -> str: - return f"Agent {self.name} is now active." - @contextmanager def create_context(self): with ctx(agent=self): diff --git a/src/controlflow/agents/teams.py b/src/controlflow/agents/teams.py index ccc3961f..b3963ddc 100644 --- a/src/controlflow/agents/teams.py +++ b/src/controlflow/agents/teams.py @@ -24,7 +24,11 @@ class Team(BaseAgent): None, description="Instructions for all agents on the team, private to this agent.", ) - + prompt: Optional[str] = Field( + None, + description="A prompt to display as an instruction to any agent selected as part of this team (or a nested team). " + "Prompts are formatted as jinja templates, with keywords `team: Team` and `context: AgentContext`.", + ) agents: list[Agent] = Field( description="The agents in the team.", default_factory=list, @@ -47,9 +51,12 @@ def get_agent(self, context: "AgentContext") -> Agent: raise NotImplementedError() def get_prompt(self, context: "AgentContext") -> str: - from controlflow.orchestration.prompt_templates import TeamTemplate + from controlflow.orchestration import prompt_templates - return TeamTemplate(team=self, context=context).render() + template = prompt_templates.TeamTemplate( + template=self.prompt, team=self, context=context + ) + return template.render() def _run(self, context: "AgentContext"): context.add_instructions([self.get_prompt(context=context)]) @@ -59,7 +66,7 @@ def _run(self, context: "AgentContext"): self._iterations += 1 async def _run_async(self, context: "AgentContext"): - context.add_instructions([self.get_prompt()]) + context.add_instructions([self.get_prompt(context=context)]) agent = self.get_agent(context=context) with context.with_agent(agent) as agent_context: await agent._run_async(context=agent_context) diff --git a/src/controlflow/flows/flow.py b/src/controlflow/flows/flow.py index cbafec5a..75b83bdc 100644 --- a/src/controlflow/flows/flow.py +++ b/src/controlflow/flows/flow.py @@ -36,6 +36,11 @@ class Flow(ControlFlowModel): description="The default agent for the flow. This agent will be used " "for any task that does not specify an agent.", ) + prompt: Optional[str] = Field( + None, + description="A prompt to display to the agent working on the flow. " + "Prompts are formatted as jinja templates, with keywords `flow: Flow` and `context: AgentContext`.", + ) context: dict[str, Any] = {} graph: Graph = Field(default_factory=Graph, repr=False, exclude=True) _cm_stack: list[contextmanager] = [] @@ -79,6 +84,7 @@ def get_prompt(self, context: "AgentContext") -> str: from controlflow.orchestration import prompt_templates template = prompt_templates.FlowTemplate( + template=self.prompt, flow=self, context=context, ) diff --git a/src/controlflow/orchestration/prompt_templates.py b/src/controlflow/orchestration/prompt_templates.py index f49dd2f7..fe44fba3 100644 --- a/src/controlflow/orchestration/prompt_templates.py +++ b/src/controlflow/orchestration/prompt_templates.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import model_validator from controlflow.agents.agent import Agent @@ -10,15 +12,14 @@ class Template(ControlFlowModel): - template: str = None - template_path: str = None + model_config = dict(extra="allow") + template: Optional[str] = None + template_path: Optional[str] = None @model_validator(mode="after") def _validate(self): if not self.template and not self.template_path: raise ValueError("Template or template_path must be provided.") - elif self.template and self.template_path: - raise ValueError("Only one of template or template_path must be provided.") return self def render(self, **kwargs) -> str: @@ -29,10 +30,10 @@ def render(self, **kwargs) -> str: del render_kwargs["template"] del render_kwargs["template_path"] - if self.template_path: - template = prompt_env.get_template(self.template_path) - else: + if self.template is not None: template = prompt_env.from_string(self.template) + else: + template = prompt_env.get_template(self.template_path) return template.render(**render_kwargs | kwargs) def should_render(self) -> bool: diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index f9b7d9a6..f8bff156 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -105,6 +105,11 @@ class Task(ControlFlowModel): depends_on: set["Task"] = Field( default_factory=set, description="Tasks that this task depends on explicitly." ) + prompt: Optional[str] = Field( + None, + description="A prompt to display to the agent working on the task. " + "Prompts are formatted as jinja templates, with keywords `task: Task` and `context: AgentContext`.", + ) status: TaskStatus = TaskStatus.PENDING result: T = None result_type: Union[type[T], GenericAlias, _LiteralGenericAlias, None] = Field( @@ -123,14 +128,6 @@ class Task(ControlFlowModel): False, description="Work on private tasks is not visible to agents other than those assigned to the task.", ) - agent_strategy: Optional[Callable] = Field( - None, - description="A function that returns an agent, used for customizing how " - "the next agent is selected. The returned agent must be one " - "of the assigned agents. If not provided, will be inferred " - "from the parent task; round-robin selection is the default. " - "Only used for tasks with more than one agent assigned.", - ) created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) max_iterations: Optional[int] = Field( default_factory=lambda: controlflow.settings.max_task_iterations, @@ -453,24 +450,6 @@ def get_agent(self) -> "Agent": else: return controlflow.defaults.agent - def get_agent_strategy(self) -> Callable: - """ - Get a function for selecting the next agent to work on this - task. - - If an agent_strategy is provided, it will be used. Otherwise, the parent - task's agent_strategy will be used. Finally, the global default agent_strategy - will be used (round-robin selection). - """ - if self.agent_strategy is not None: - return self.agent_strategy - elif self.parent: - return self.parent.get_agent_strategy() - else: - import controlflow.tasks.agent_strategies - - return controlflow.tasks.agent_strategies.round_robin - def get_tools(self) -> list[Union[Tool, Callable]]: tools = self.tools.copy() if self.user_access: @@ -483,7 +462,9 @@ def get_prompt(self, context: "AgentContext") -> str: """ from controlflow.orchestration import prompt_templates - template = prompt_templates.TaskTemplate(task=self, context=context) + template = prompt_templates.TaskTemplate( + template=self.prompt, task=self, context=context + ) return template.render() def set_status(self, status: TaskStatus): diff --git a/src/controlflow/utilities/testing.py b/src/controlflow/utilities/testing.py index e2b14adc..b66bcb21 100644 --- a/src/controlflow/utilities/testing.py +++ b/src/controlflow/utilities/testing.py @@ -1,4 +1,5 @@ from contextlib import contextmanager +from functools import partial from typing import Union from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel @@ -6,6 +7,9 @@ import controlflow from controlflow.events.history import InMemoryHistory from controlflow.llm.messages import AIMessage, BaseMessage +from controlflow.tasks.task import Task + +SimpleTask = partial(Task, objective="test", result_type=None) class FakeLLM(FakeMessagesListChatModel): diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index edc78d99..bb9dbfe4 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -1,7 +1,10 @@ import controlflow +import pytest from controlflow.agents import Agent from controlflow.agents.names import AGENTS +from controlflow.flows import Flow from controlflow.instructions import instructions +from controlflow.orchestration.agent_context import AgentContext from controlflow.tasks.task import Task from langchain_openai import ChatOpenAI @@ -67,3 +70,28 @@ def test_updating_the_default_model_updates_the_default_agent_model(self): task = Task("task") assert task.get_agents()[0].model is None assert task.get_agents()[0].get_model() is new_model + + +class TestAgentPrompt: + @pytest.fixture + def agent_context(self) -> AgentContext: + return AgentContext(agent=Agent(name="Test Agent"), flow=Flow(), tasks=[]) + + def test_default_prompt(self): + agent = Agent() + assert agent.prompt is None + + def test_default_template(self, agent_context): + agent = Agent() + prompt = agent.get_prompt(context=agent_context) + assert prompt.startswith("# Agent") + + def test_custom_prompt(self, agent_context): + agent = Agent(prompt="Custom Prompt") + prompt = agent.get_prompt(context=agent_context) + assert prompt == "Custom Prompt" + + def test_custom_templated_prompt(self, agent_context): + agent = Agent(prompt="{{ agent.name }}", name="abc") + prompt = agent.get_prompt(context=agent_context) + assert prompt == "abc" diff --git a/tests/flows/test_flows.py b/tests/flows/test_flows.py index 883f1cef..d693d952 100644 --- a/tests/flows/test_flows.py +++ b/tests/flows/test_flows.py @@ -1,6 +1,8 @@ +import pytest from controlflow.agents import Agent from controlflow.events.events import UserMessage from controlflow.flows import Flow, get_flow +from controlflow.orchestration.agent_context import AgentContext from controlflow.tasks.task import Task from controlflow.utilities.context import ctx @@ -157,3 +159,28 @@ def test_flow_agent_becomes_task_default(self): with Flow(agents=[agent]): t2 = Task("t2") assert t2.get_agents() == [agent] + + +class TestFlowPrompt: + @pytest.fixture + def agent_context(self) -> AgentContext: + return AgentContext(agent=Agent(name="Test Agent"), flow=Flow(), tasks=[]) + + def test_default_prompt(self): + flow = Flow() + assert flow.prompt is None + + def test_default_template(self, agent_context): + flow = Flow() + prompt = flow.get_prompt(context=agent_context) + assert prompt.startswith("# Flow") + + def test_custom_prompt(self, agent_context): + flow = Flow(prompt="Custom Prompt") + prompt = flow.get_prompt(context=agent_context) + assert prompt == "Custom Prompt" + + def test_custom_templated_prompt(self, agent_context): + flow = Flow(prompt="{{ flow.name }}", name="abc") + prompt = flow.get_prompt(context=agent_context) + assert prompt == "abc" diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index bfd08ebf..424e720f 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -1,10 +1,9 @@ -from functools import partial - import controlflow import pytest from controlflow.agents import Agent from controlflow.flows import Flow from controlflow.instructions import instructions +from controlflow.orchestration.agent_context import AgentContext from controlflow.tasks.task import ( COMPLETE_STATUSES, INCOMPLETE_STATUSES, @@ -12,8 +11,7 @@ TaskStatus, ) from controlflow.utilities.context import ctx - -SimpleTask = partial(Task, objective="test", result_type=None) +from controlflow.utilities.testing import SimpleTask def test_status_coverage(): @@ -261,3 +259,28 @@ def test_task_hash(self): task1 = SimpleTask() task2 = SimpleTask() assert hash(task1) != hash(task2) + + +class TestTaskPrompt: + @pytest.fixture + def agent_context(self) -> AgentContext: + return AgentContext(agent=Agent(name="Test Agent"), flow=Flow(), tasks=[]) + + def test_default_prompt(self): + task = SimpleTask() + assert task.prompt is None + + def test_default_template(self, agent_context): + task = SimpleTask() + prompt = task.get_prompt(context=agent_context) + assert prompt.startswith("## Task") + + def test_custom_prompt(self, agent_context): + task = SimpleTask(prompt="Custom Prompt") + prompt = task.get_prompt(context=agent_context) + assert prompt == "Custom Prompt" + + def test_custom_templated_prompt(self, agent_context): + task = SimpleTask(prompt="{{ task.objective }}", objective="abc") + prompt = task.get_prompt(context=agent_context) + assert prompt == "abc"