Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve basic round robin #231

Merged
merged 2 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/controlflow/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from . import memory
from .agent import Agent
from .teams import Team
31 changes: 28 additions & 3 deletions src/controlflow/agents/teams.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@
logger = logging.getLogger(__name__)


class Team(BaseAgent):
class BaseTeam(BaseAgent):
"""
A team is a group of agents that can be assigned to a task.

Each team consists of one or more agents, and the only requirement for a
team is to implement the `get_agent` method. This method should return one of
the agents in the team, based on some logic that determines which agent should go next.
"""

name: str = Field(
description="The name of the team.",
default_factory=lambda: random.choice(TEAMS),
Expand Down Expand Up @@ -73,7 +81,24 @@ async def _run_async(self, context: "AgentContext"):
self._iterations += 1


class RoundRobinTeam(Team):
class Team(BaseTeam):
"""
The most basic team operates in a round robin fashion
"""

def get_agent(self, context: "AgentContext"):
# TODO: only advance agent if a tool wasn't used
# if the last event was a tool result, it should be shown to the same agent instead of advancing to the next agent
last_agent_event = context.get_events(
agents=self.agents,
tasks=context.tasks,
types=["tool-result", "agent-message"],
limit=1,
)
if (
last_agent_event
and last_agent_event[0].event == "tool-result"
and not last_agent_event[0].tool_result.end_turn
):
return last_agent_event[0].agent

return self.agents[self._iterations % len(self.agents)]
22 changes: 17 additions & 5 deletions src/controlflow/orchestration/agent_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,24 @@ def add_tools(self, tools: list[Tool]):
def add_instructions(self, instructions: list[str]):
self.instructions = self.instructions + instructions

def get_events(self, agents: list[Agent] = None) -> list[Event]:
upstream_tasks = [
t for t in self.flow.graph.upstream_tasks(self.tasks) if not t.private
]
def get_events(
self,
agents: list[Agent] = None,
tasks: list[Task] = None,
limit: Optional[int] = None,
**kwargs,
) -> list[Event]:
# if tasks are not provided, include all tasks that are upstream of the current tasks
if tasks is None:
tasks = [
t for t in self.flow.graph.upstream_tasks(self.tasks) if not t.private
]

events = self.flow.get_events(
agents=self.agents + (agents or []), tasks=upstream_tasks
agents=agents or self.agents,
tasks=tasks,
limit=limit or 100,
**kwargs,
)

return events
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ def __init__(
"deprecated and will be removed in future versions. "
"Please provide a single agent or team of agents instead."
)
from controlflow.agents.teams import RoundRobinTeam
from controlflow.agents.teams import Team

kwargs["agent"] = RoundRobinTeam(agents=agents)
kwargs["agent"] = Team(agents=agents)
else:
raise ValueError(
"The 'agents' argument is deprecated and cannot be used with the 'agent' argument."
Expand Down
6 changes: 4 additions & 2 deletions src/controlflow/tools/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
T = TypeVar("T")


def generate_result_schema(result_type: type[T]) -> type[T]:
def _generate_result_schema(result_type: type[T]) -> type[T]:
if result_type is None:
return None

Expand Down Expand Up @@ -43,12 +43,13 @@ def create_task_success_tool(task: Task) -> Tool:
Create an agent-compatible tool for marking this task as successful.
"""

result_schema = generate_result_schema(task.result_type)
result_schema = _generate_result_schema(task.result_type)

@tool(
name=f"mark_task_{task.id}_successful",
description=f"Mark task {task.id} as successful.",
private=True,
end_turn=True,
)
def succeed(result: result_schema) -> str: # type: ignore
task.mark_successful(result=result)
Expand All @@ -68,6 +69,7 @@ def create_task_fail_tool(task: Task) -> Tool:
f"Mark task {task.id} as failed. Only use when technical errors prevent success. Provide a detailed reason for the failure."
),
private=True,
end_turn=True,
)
def fail(reason: str) -> str:
task.mark_failed(reason=reason)
Expand Down
12 changes: 12 additions & 0 deletions src/controlflow/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ class Tool(ControlFlowModel):
)
metadata: dict = {}
private: bool = False
end_turn: bool = Field(
False,
description="If True, using this tool will end the agent's turn instead "
"of showing the result to the agent.",
)

fn: Callable = Field(None, exclude=True)

Expand Down Expand Up @@ -253,6 +258,7 @@ class ToolResult(ControlFlowModel):
str_result: str = Field(repr=False)
is_error: bool = False
is_private: bool = False
end_turn: bool = False


def handle_tool_call(
Expand All @@ -264,6 +270,7 @@ def handle_tool_call(
"""
is_error = False
is_private = False
end_turn = False
tool = None
tool_lookup = {t.name: t for t in tools}
fn_name = tool_call["name"]
Expand All @@ -281,6 +288,7 @@ def handle_tool_call(
fn_args = tool_call["args"]
if isinstance(tool, Tool):
fn_output = tool.run(input=fn_args)
end_turn = tool.end_turn
elif isinstance(tool, langchain_core.tools.BaseTool):
fn_output = tool.invoke(input=fn_args)
else:
Expand All @@ -297,6 +305,7 @@ def handle_tool_call(
str_result=output_to_string(fn_output),
is_error=is_error,
is_private=getattr(tool, "private", is_private),
end_turn=end_turn,
)


Expand All @@ -307,6 +316,7 @@ async def handle_tool_call_async(tool_call: ToolCall, tools: list[Tool]) -> Any:
"""
is_error = False
is_private = False
end_turn = False
tool = None
tool_lookup = {t.name: t for t in tools}
fn_name = tool_call["name"]
Expand All @@ -324,6 +334,7 @@ async def handle_tool_call_async(tool_call: ToolCall, tools: list[Tool]) -> Any:
fn_args = tool_call["args"]
if isinstance(tool, Tool):
fn_output = await tool.run_async(input=fn_args)
end_turn = tool.end_turn
elif isinstance(tool, langchain_core.tools.BaseTool):
fn_output = await tool.ainvoke(input=fn_args)
else:
Expand All @@ -340,4 +351,5 @@ async def handle_tool_call_async(tool_call: ToolCall, tools: list[Tool]) -> Any:
str_result=output_to_string(fn_output),
is_error=is_error,
is_private=getattr(tool, "private", is_private),
end_turn=end_turn,
)
14 changes: 14 additions & 0 deletions tests/agents/test_teams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from controlflow.agents import Agent, Team


def test_team_agents():
a1 = Agent(name="a1")
a2 = Agent(name="a2")
t = Team(agents=[a1, a2])
assert t.agents == [a1, a2]


def test_team_with_one_agent():
a1 = Agent(name="a1")
t = Team(agents=[a1])
assert t.agents == [a1]
1 change: 1 addition & 0 deletions tests/utilities/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,5 @@ def test_record_task_events(default_fake_llm):
str_result='Task 12345 ("say hello") marked successful.',
is_error=False,
is_private=True,
end_turn=True,
)