Skip to content

Commit

Permalink
Add new compiler backend (#199) (#200)
Browse files Browse the repository at this point in the history
* Add async support

* Update settings.py
  • Loading branch information
jlowin authored Jul 3, 2024
1 parent 1d62265 commit 984c157
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 2 deletions.
42 changes: 41 additions & 1 deletion src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import random
import uuid
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Generator, Optional
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Generator, Optional

from langchain_core.language_models import BaseChatModel
from pydantic import Field, field_serializer
Expand All @@ -13,6 +13,7 @@
from controlflow.llm.messages import AIMessage, BaseMessage
from controlflow.llm.models import get_default_model
from controlflow.llm.rules import LLMRules
from controlflow.tools.tools import handle_tool_call_async
from controlflow.utilities.context import ctx
from controlflow.utilities.types import ControlFlowModel

Expand Down Expand Up @@ -185,5 +186,44 @@ def _run_model(
result = handle_tool_call(tool_call, tools=tools)
yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result)

async def _run_model_async(
self,
messages: list[BaseMessage],
additional_tools: list["Tool"] = None,
stream: bool = True,
) -> AsyncGenerator[Event, None]:
from controlflow.events.agent_events import (
AgentMessageDeltaEvent,
AgentMessageEvent,
)
from controlflow.events.tool_events import ToolCallEvent, ToolResultEvent
from controlflow.tools.tools import as_tools

model = self.get_model()

tools = as_tools(self.get_tools() + (additional_tools or []))
if tools:
model = model.bind_tools([t.to_lc_tool() for t in tools])

if stream:
response = None
async for delta in model.astream(messages):
if response is None:
response = delta
else:
response += delta
yield AgentMessageDeltaEvent(agent=self, delta=delta, snapshot=response)

else:
response: AIMessage = model.invoke(messages)

yield AgentMessageEvent(agent=self, message=response)

for tool_call in response.tool_calls + response.invalid_tool_calls:
yield ToolCallEvent(agent=self, tool_call=tool_call, message=response)

result = await handle_tool_call_async(tool_call, tools=tools)
yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result)


DEFAULT_AGENT = Agent(name="Marvin")
52 changes: 51 additions & 1 deletion src/controlflow/orchestration/controller.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Generator, Optional, TypeVar, Union
from typing import AsyncGenerator, Generator, Optional, TypeVar, Union

from pydantic import Field, field_validator

Expand Down Expand Up @@ -120,10 +120,52 @@ def run_once(self):
finally:
self.handle_event(ControllerEnd(controller=self))

async def run_once_async(self):
"""
Core pipeline for running the controller.
"""
from controlflow.events.controller_events import (
ControllerEnd,
ControllerError,
ControllerStart,
)

self.handle_event(ControllerStart(controller=self))

try:
ready_tasks = self.get_ready_tasks()

if not ready_tasks:
return

# select an agent
agent = self.get_agent(ready_tasks=ready_tasks)
active_tasks = self.get_active_tasks(agent=agent, ready_tasks=ready_tasks)

context = AgentContext(
agent=agent,
tasks=active_tasks,
flow=self.flow,
controller=self,
)

# run
await context.run_async()

except Exception as exc:
self.handle_event(ControllerError(controller=self, error=exc))
raise
finally:
self.handle_event(ControllerEnd(controller=self))

def run(self):
while any(t.is_incomplete() for t in self.tasks):
self.run_once()

async def run_async(self):
while any(t.is_incomplete() for t in self.tasks):
await self.run_once_async()

def get_ready_tasks(self) -> list[Task]:
all_tasks = self.flow.graph.upstream_tasks(self.tasks)
ready_tasks = [t for t in all_tasks if t.is_ready()]
Expand Down Expand Up @@ -286,3 +328,11 @@ def run(self) -> Generator["Event", None, None]:
messages = self.get_messages()
for event in self.agent._run_model(messages=messages, additional_tools=tools):
self.controller.handle_event(event, tasks=self.tasks, agents=[self.agent])

async def run_async(self) -> AsyncGenerator["Event", None]:
tools = self.get_tools()
messages = self.get_messages()
async for event in self.agent._run_model_async(
messages=messages, additional_tools=tools
):
self.controller.handle_event(event, tasks=self.tasks, agents=[self.agent])
63 changes: 63 additions & 0 deletions src/controlflow/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,30 @@ def run(self, input: dict):
)
return result

@prefect_task(task_run_name="Tool call: {self.name}")
async def run_async(self, input: dict):
result = self.fn(**input)
if inspect.isawaitable(result):
result = await result

# prepare artifact
passed_args = inspect.signature(self.fn).bind(**input).arguments
try:
# try to pretty print the args
passed_args = json.dumps(passed_args, indent=2)
except Exception:
pass
create_markdown_artifact(
markdown=TOOL_CALL_FUNCTION_RESULT_TEMPLATE.format(
name=self.name,
description=self.description or "(none provided)",
args=passed_args,
result=result,
),
key="tool-result",
)
return result

@classmethod
def from_function(
cls, fn: Callable, name: str = None, description: str = None, **kwargs
Expand Down Expand Up @@ -231,3 +255,42 @@ def handle_tool_call(tool_call: ToolCall, tools: list[Tool]) -> Any:
is_error=is_error,
is_private=getattr(tool, "private", False),
)


async def handle_tool_call_async(tool_call: ToolCall, tools: list[Tool]) -> Any:
"""
Given a ToolCall and set of available tools, runs the tool call and returns
a ToolResult object
"""
is_error = False
tool = None
tool_lookup = {t.name: t for t in tools}
fn_name = tool_call["name"]

if fn_name not in tool_lookup:
fn_output = f'Function "{fn_name}" not found.'
is_error = True
if controlflow.settings.tools_raise_on_error:
raise ValueError(fn_output)

if not is_error:
try:
tool = tool_lookup[fn_name]
fn_args = tool_call["args"]
if isinstance(tool, Tool):
fn_output = await tool.run_async(input=fn_args)
elif isinstance(tool, langchain_core.tools.BaseTool):
fn_output = await tool.ainvoke(input=fn_args)
except Exception as exc:
fn_output = f'Error calling function "{fn_name}": {exc}'
is_error = True
if controlflow.settings.tools_raise_on_error:
raise exc

return ToolResult(
tool_call_id=tool_call["id"],
result=fn_output,
str_result=output_to_string(fn_output),
is_error=is_error,
is_private=getattr(tool, "private", False),
)
File renamed without changes.

0 comments on commit 984c157

Please sign in to comment.