Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update prompt customization #226

Merged
merged 1 commit into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/tutorial.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ with cf.instructions('No more than 5 sentences per document'):
technical_document.run()
```

When you run the `technical_document` task, ControlFlow will assign both the `docs_agent` and the `editor_agent` to complete the task. The `docs_agent` will generate the technical document, and the `editor_agent` will review and edit the document to ensure its accuracy and readability. By default, they will be run in round-robin fashion, but you can customize the agent selection strategy by passing a function as the task's `agent_strategy`.
When you run the `technical_document` task, ControlFlow will assign both the `docs_agent` and the `editor_agent` to complete the task. The `docs_agent` will generate the technical document, and the `editor_agent` will review and edit the document to ensure its accuracy and readability.

### Instructions

Expand Down
35 changes: 15 additions & 20 deletions src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class BaseAgent(ControlFlowModel, abc.ABC):
description: Optional[str] = Field(
None, description="A description of the agent, visible to other agents."
)
prompt: Optional[str] = Field(
None,
description="A prompt to display as a system message to the agent."
"Prompts are formatted as jinja templates, with keywords `agent: Agent` and `context: AgentContext`.",
)

def serialize_for_prompt(self) -> dict:
return self.model_dump()
Expand All @@ -65,8 +70,15 @@ async def _run_async(self, context: "AgentContext") -> list[Event]:
"""
raise NotImplementedError()

async def get_activation_prompt(self) -> str:
return f"Agent {self.name} is now active."
def get_prompt(self, context: "AgentContext") -> str:
from controlflow.orchestration import prompt_templates

template = prompt_templates.AgentTemplate(
template=self.prompt,
agent=self,
context=context,
)
return template.render()


class Agent(BaseAgent):
Expand All @@ -86,7 +98,7 @@ class Agent(BaseAgent):
False,
description="If True, the agent is given tools for interacting with a human user.",
)
system_template: Optional[str] = Field(
prompt: Optional[str] = Field(
None,
description="A system template for the agent. The template should be formatted as a jinja2 template.",
)
Expand Down Expand Up @@ -164,23 +176,6 @@ def get_tools(self) -> list[Callable]:

return tools

def get_prompt(self, context: "AgentContext") -> str:
from controlflow.orchestration import prompt_templates

if self.system_template:
template = prompt_templates.AgentTemplate(
template=self.system_template,
template_path=None,
agent=self,
context=context,
)
else:
template = prompt_templates.AgentTemplate(agent=self, context=context)
return template.render()

def get_activation_prompt(self) -> str:
return f"Agent {self.name} is now active."

@contextmanager
def create_context(self):
with ctx(agent=self):
Expand Down
15 changes: 11 additions & 4 deletions src/controlflow/agents/teams.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ class Team(BaseAgent):
None,
description="Instructions for all agents on the team, private to this agent.",
)

prompt: Optional[str] = Field(
None,
description="A prompt to display as an instruction to any agent selected as part of this team (or a nested team). "
"Prompts are formatted as jinja templates, with keywords `team: Team` and `context: AgentContext`.",
)
agents: list[Agent] = Field(
description="The agents in the team.",
default_factory=list,
Expand All @@ -47,9 +51,12 @@ def get_agent(self, context: "AgentContext") -> Agent:
raise NotImplementedError()

def get_prompt(self, context: "AgentContext") -> str:
from controlflow.orchestration.prompt_templates import TeamTemplate
from controlflow.orchestration import prompt_templates

return TeamTemplate(team=self, context=context).render()
template = prompt_templates.TeamTemplate(
template=self.prompt, team=self, context=context
)
return template.render()

def _run(self, context: "AgentContext"):
context.add_instructions([self.get_prompt(context=context)])
Expand All @@ -59,7 +66,7 @@ def _run(self, context: "AgentContext"):
self._iterations += 1

async def _run_async(self, context: "AgentContext"):
context.add_instructions([self.get_prompt()])
context.add_instructions([self.get_prompt(context=context)])
agent = self.get_agent(context=context)
with context.with_agent(agent) as agent_context:
await agent._run_async(context=agent_context)
Expand Down
6 changes: 6 additions & 0 deletions src/controlflow/flows/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ class Flow(ControlFlowModel):
description="The default agent for the flow. This agent will be used "
"for any task that does not specify an agent.",
)
prompt: Optional[str] = Field(
None,
description="A prompt to display to the agent working on the flow. "
"Prompts are formatted as jinja templates, with keywords `flow: Flow` and `context: AgentContext`.",
)
context: dict[str, Any] = {}
graph: Graph = Field(default_factory=Graph, repr=False, exclude=True)
_cm_stack: list[contextmanager] = []
Expand Down Expand Up @@ -79,6 +84,7 @@ def get_prompt(self, context: "AgentContext") -> str:
from controlflow.orchestration import prompt_templates

template = prompt_templates.FlowTemplate(
template=self.prompt,
flow=self,
context=context,
)
Expand Down
15 changes: 8 additions & 7 deletions src/controlflow/orchestration/prompt_templates.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from pydantic import model_validator

from controlflow.agents.agent import Agent
Expand All @@ -10,15 +12,14 @@


class Template(ControlFlowModel):
template: str = None
template_path: str = None
model_config = dict(extra="allow")
template: Optional[str] = None
template_path: Optional[str] = None

@model_validator(mode="after")
def _validate(self):
if not self.template and not self.template_path:
raise ValueError("Template or template_path must be provided.")
elif self.template and self.template_path:
raise ValueError("Only one of template or template_path must be provided.")
return self

def render(self, **kwargs) -> str:
Expand All @@ -29,10 +30,10 @@ def render(self, **kwargs) -> str:
del render_kwargs["template"]
del render_kwargs["template_path"]

if self.template_path:
template = prompt_env.get_template(self.template_path)
else:
if self.template is not None:
template = prompt_env.from_string(self.template)
else:
template = prompt_env.get_template(self.template_path)
return template.render(**render_kwargs | kwargs)

def should_render(self) -> bool:
Expand Down
35 changes: 8 additions & 27 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ class Task(ControlFlowModel):
depends_on: set["Task"] = Field(
default_factory=set, description="Tasks that this task depends on explicitly."
)
prompt: Optional[str] = Field(
None,
description="A prompt to display to the agent working on the task. "
"Prompts are formatted as jinja templates, with keywords `task: Task` and `context: AgentContext`.",
)
status: TaskStatus = TaskStatus.PENDING
result: T = None
result_type: Union[type[T], GenericAlias, _LiteralGenericAlias, None] = Field(
Expand All @@ -123,14 +128,6 @@ class Task(ControlFlowModel):
False,
description="Work on private tasks is not visible to agents other than those assigned to the task.",
)
agent_strategy: Optional[Callable] = Field(
None,
description="A function that returns an agent, used for customizing how "
"the next agent is selected. The returned agent must be one "
"of the assigned agents. If not provided, will be inferred "
"from the parent task; round-robin selection is the default. "
"Only used for tasks with more than one agent assigned.",
)
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
max_iterations: Optional[int] = Field(
default_factory=lambda: controlflow.settings.max_task_iterations,
Expand Down Expand Up @@ -453,24 +450,6 @@ def get_agent(self) -> "Agent":
else:
return controlflow.defaults.agent

def get_agent_strategy(self) -> Callable:
"""
Get a function for selecting the next agent to work on this
task.

If an agent_strategy is provided, it will be used. Otherwise, the parent
task's agent_strategy will be used. Finally, the global default agent_strategy
will be used (round-robin selection).
"""
if self.agent_strategy is not None:
return self.agent_strategy
elif self.parent:
return self.parent.get_agent_strategy()
else:
import controlflow.tasks.agent_strategies

return controlflow.tasks.agent_strategies.round_robin

def get_tools(self) -> list[Union[Tool, Callable]]:
tools = self.tools.copy()
if self.user_access:
Expand All @@ -483,7 +462,9 @@ def get_prompt(self, context: "AgentContext") -> str:
"""
from controlflow.orchestration import prompt_templates

template = prompt_templates.TaskTemplate(task=self, context=context)
template = prompt_templates.TaskTemplate(
template=self.prompt, task=self, context=context
)
return template.render()

def set_status(self, status: TaskStatus):
Expand Down
4 changes: 4 additions & 0 deletions src/controlflow/utilities/testing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from contextlib import contextmanager
from functools import partial
from typing import Union

from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel

import controlflow
from controlflow.events.history import InMemoryHistory
from controlflow.llm.messages import AIMessage, BaseMessage
from controlflow.tasks.task import Task

SimpleTask = partial(Task, objective="test", result_type=None)


class FakeLLM(FakeMessagesListChatModel):
Expand Down
28 changes: 28 additions & 0 deletions tests/agents/test_agents.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import controlflow
import pytest
from controlflow.agents import Agent
from controlflow.agents.names import AGENTS
from controlflow.flows import Flow
from controlflow.instructions import instructions
from controlflow.orchestration.agent_context import AgentContext
from controlflow.tasks.task import Task
from langchain_openai import ChatOpenAI

Expand Down Expand Up @@ -67,3 +70,28 @@ def test_updating_the_default_model_updates_the_default_agent_model(self):
task = Task("task")
assert task.get_agents()[0].model is None
assert task.get_agents()[0].get_model() is new_model


class TestAgentPrompt:
@pytest.fixture
def agent_context(self) -> AgentContext:
return AgentContext(agent=Agent(name="Test Agent"), flow=Flow(), tasks=[])

def test_default_prompt(self):
agent = Agent()
assert agent.prompt is None

def test_default_template(self, agent_context):
agent = Agent()
prompt = agent.get_prompt(context=agent_context)
assert prompt.startswith("# Agent")

def test_custom_prompt(self, agent_context):
agent = Agent(prompt="Custom Prompt")
prompt = agent.get_prompt(context=agent_context)
assert prompt == "Custom Prompt"

def test_custom_templated_prompt(self, agent_context):
agent = Agent(prompt="{{ agent.name }}", name="abc")
prompt = agent.get_prompt(context=agent_context)
assert prompt == "abc"
27 changes: 27 additions & 0 deletions tests/flows/test_flows.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
from controlflow.agents import Agent
from controlflow.events.events import UserMessage
from controlflow.flows import Flow, get_flow
from controlflow.orchestration.agent_context import AgentContext
from controlflow.tasks.task import Task
from controlflow.utilities.context import ctx

Expand Down Expand Up @@ -157,3 +159,28 @@ def test_flow_agent_becomes_task_default(self):
with Flow(agents=[agent]):
t2 = Task("t2")
assert t2.get_agents() == [agent]


class TestFlowPrompt:
@pytest.fixture
def agent_context(self) -> AgentContext:
return AgentContext(agent=Agent(name="Test Agent"), flow=Flow(), tasks=[])

def test_default_prompt(self):
flow = Flow()
assert flow.prompt is None

def test_default_template(self, agent_context):
flow = Flow()
prompt = flow.get_prompt(context=agent_context)
assert prompt.startswith("# Flow")

def test_custom_prompt(self, agent_context):
flow = Flow(prompt="Custom Prompt")
prompt = flow.get_prompt(context=agent_context)
assert prompt == "Custom Prompt"

def test_custom_templated_prompt(self, agent_context):
flow = Flow(prompt="{{ flow.name }}", name="abc")
prompt = flow.get_prompt(context=agent_context)
assert prompt == "abc"
31 changes: 27 additions & 4 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from functools import partial

import controlflow
import pytest
from controlflow.agents import Agent
from controlflow.flows import Flow
from controlflow.instructions import instructions
from controlflow.orchestration.agent_context import AgentContext
from controlflow.tasks.task import (
COMPLETE_STATUSES,
INCOMPLETE_STATUSES,
Task,
TaskStatus,
)
from controlflow.utilities.context import ctx

SimpleTask = partial(Task, objective="test", result_type=None)
from controlflow.utilities.testing import SimpleTask


def test_status_coverage():
Expand Down Expand Up @@ -261,3 +259,28 @@ def test_task_hash(self):
task1 = SimpleTask()
task2 = SimpleTask()
assert hash(task1) != hash(task2)


class TestTaskPrompt:
@pytest.fixture
def agent_context(self) -> AgentContext:
return AgentContext(agent=Agent(name="Test Agent"), flow=Flow(), tasks=[])

def test_default_prompt(self):
task = SimpleTask()
assert task.prompt is None

def test_default_template(self, agent_context):
task = SimpleTask()
prompt = task.get_prompt(context=agent_context)
assert prompt.startswith("## Task")

def test_custom_prompt(self, agent_context):
task = SimpleTask(prompt="Custom Prompt")
prompt = task.get_prompt(context=agent_context)
assert prompt == "Custom Prompt"

def test_custom_templated_prompt(self, agent_context):
task = SimpleTask(prompt="{{ task.objective }}", objective="abc")
prompt = task.get_prompt(context=agent_context)
assert prompt == "abc"
Loading