Skip to content

Commit

Permalink
Add new compiler backend (#203)
Browse files Browse the repository at this point in the history
Add new compiler backend
  • Loading branch information
jlowin authored Jul 3, 2024
2 parents 09d8ad3 + b698f7c commit 9eb2297
Show file tree
Hide file tree
Showing 61 changed files with 2,277 additions and 2,499 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
recursive-include controlflow/orchestration/prompt_templates *
29 changes: 17 additions & 12 deletions examples/teacher_student.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
from controlflow import Agent, Task, flow
from controlflow.instructions import instructions

teacher = Agent(name="teacher")
student = Agent(name="student")
teacher = Agent(name="Teacher")
student = Agent(name="Student")


@flow
def demo():
with Task("Teach a class by asking and answering 3 questions") as task:
with Task("Teach a class by asking and answering 3 questions", agents=[teacher]):
for _ in range(3):
question = Task(
"ask the student a question. Wait for the student to answer your question before asking another one.",
str,
agents=[teacher],
"Ask the student a question.", result_type=str, agents=[teacher]
)
with instructions("one sentence max"):
Task(
"answer the question",
str,

with instructions("One sentence max"):
answer = Task(
"Answer the question.",
agents=[student],
context=dict(question=question),
)

task.run()
return task
grade = Task(
"Assess the answer.",
result_type=["pass", "fail"],
agents=[teacher],
context=dict(answer=answer),
)

# run each qa session, one at a time
grade.run()


t = demo()
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ authors = [
{ name = "Jeremiah Lowin", email = "153965+jlowin@users.noreply.github.com" },
]
dependencies = [
"prefect>=3.0rc4",
"prefect>=3.0rc10",
"jinja2>=3.1.4",
"langchain_core>=0.2.9",
"langchain_openai>=0.1.8",
"langchain-anthropic>=0.1.15",
"langchain-anthropic>=0.1.19",
"markdownify>=0.12.1",
"pydantic-settings>=2.2.1",
"textual>=0.61.1",
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

from .instructions import instructions
from .decorators import flow, task
from .llm.tools import tool
from .tools import tool

# --- Default settings ---

from .llm.models import _get_initial_default_model, get_default_model
from .flows.history import InMemoryHistory, get_default_history
from .events.history import InMemoryHistory, get_default_history

# assign to controlflow.default_model to change the default model
default_model = _get_initial_default_model()
Expand Down
88 changes: 86 additions & 2 deletions src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
import random
import uuid
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Generator, Optional

from langchain_core.language_models import BaseChatModel
from pydantic import Field, field_serializer

import controlflow
from controlflow.events.events import Event
from controlflow.instructions import get_instructions
from controlflow.llm.messages import AIMessage, BaseMessage
from controlflow.llm.models import get_default_model
from controlflow.llm.rules import LLMRules
from controlflow.tools.talk_to_user import talk_to_user
from controlflow.tools.tools import handle_tool_call_async
from controlflow.utilities.context import ctx
from controlflow.utilities.types import ControlFlowModel

Expand All @@ -20,6 +22,8 @@

if TYPE_CHECKING:
from controlflow.tasks.task import Task
from controlflow.tools.tools import Tool

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -109,6 +113,8 @@ def get_llm_rules(self) -> LLMRules:
return controlflow.llm.rules.rules_for_model(self.get_model())

def get_tools(self) -> list[Callable]:
from controlflow.tools.talk_to_user import talk_to_user

tools = self.tools.copy()
if self.user_access:
tools.append(talk_to_user)
Expand Down Expand Up @@ -141,5 +147,83 @@ def run(self, task: "Task"):
async def run_async(self, task: "Task"):
return await task.run_async(agents=[self])

def _run_model(
self,
messages: list[BaseMessage],
additional_tools: list["Tool"] = None,
stream: bool = True,
) -> Generator[Event, None, None]:
from controlflow.events.agent_events import (
AgentMessageDeltaEvent,
AgentMessageEvent,
)
from controlflow.events.tool_events import ToolCallEvent, ToolResultEvent
from controlflow.tools.tools import as_tools, handle_tool_call

model = self.get_model()

tools = as_tools(self.get_tools() + (additional_tools or []))
if tools:
model = model.bind_tools([t.to_lc_tool() for t in tools])

if stream:
response = None
for delta in model.stream(messages):
if response is None:
response = delta
else:
response += delta
yield AgentMessageDeltaEvent(agent=self, delta=delta, snapshot=response)

else:
response: AIMessage = model.invoke(messages)

yield AgentMessageEvent(agent=self, message=response)

for tool_call in response.tool_calls + response.invalid_tool_calls:
yield ToolCallEvent(agent=self, tool_call=tool_call, message=response)

result = handle_tool_call(tool_call, tools=tools)
yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result)

async def _run_model_async(
self,
messages: list[BaseMessage],
additional_tools: list["Tool"] = None,
stream: bool = True,
) -> AsyncGenerator[Event, None]:
from controlflow.events.agent_events import (
AgentMessageDeltaEvent,
AgentMessageEvent,
)
from controlflow.events.tool_events import ToolCallEvent, ToolResultEvent
from controlflow.tools.tools import as_tools

model = self.get_model()

tools = as_tools(self.get_tools() + (additional_tools or []))
if tools:
model = model.bind_tools([t.to_lc_tool() for t in tools])

if stream:
response = None
async for delta in model.astream(messages):
if response is None:
response = delta
else:
response += delta
yield AgentMessageDeltaEvent(agent=self, delta=delta, snapshot=response)

else:
response: AIMessage = model.invoke(messages)

yield AgentMessageEvent(agent=self, message=response)

for tool_call in response.tool_calls + response.invalid_tool_calls:
yield ToolCallEvent(agent=self, tool_call=tool_call, message=response)

result = await handle_tool_call_async(tool_call, tools=tools)
yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result)


DEFAULT_AGENT = Agent(name="Marvin")
10 changes: 7 additions & 3 deletions src/controlflow/agents/memory.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import abc
import uuid
from typing import ClassVar, Optional, cast
from typing import TYPE_CHECKING, ClassVar, Optional, cast

from pydantic import Field

from controlflow.tools import Tool
from controlflow.utilities.context import ctx
from controlflow.utilities.types import ControlFlowModel

if TYPE_CHECKING:
from controlflow.tools import Tool


class Memory(ControlFlowModel, abc.ABC):
id: str = Field(default_factory=lambda: uuid.uuid4().hex)
Expand All @@ -27,7 +29,9 @@ def update(self, value: str, index: int = None):
def delete(self, index: int):
raise NotImplementedError()

def get_tools(self) -> list[Tool]:
def get_tools(self) -> list["Tool"]:
from controlflow.tools import Tool

update_tool = Tool.from_function(
self.update,
name="update_memory",
Expand Down
Loading

0 comments on commit 9eb2297

Please sign in to comment.