Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure LC tools work #230

Merged
merged 3 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ tests = [
"pytest>=7.0",
"pytest-timeout",
"pytest-xdist",
"langchain_community",
"duckduckgo-search",
]
dev = [
"controlflow[tests]",
Expand Down
13 changes: 9 additions & 4 deletions src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from controlflow.instructions import get_instructions
from controlflow.llm.messages import AIMessage, BaseMessage
from controlflow.llm.rules import LLMRules
from controlflow.tools.tools import handle_tool_call, handle_tool_call_async
from controlflow.tools.tools import (
as_lc_tools,
as_tools,
handle_tool_call,
handle_tool_call_async,
)
from controlflow.utilities.context import ctx
from controlflow.utilities.types import ControlFlowModel

Expand Down Expand Up @@ -156,7 +161,7 @@ def get_model(self, tools: list["Tool"] = None) -> BaseChatModel:
f"Agent {self.name}: No model provided and no default model could be loaded."
)
if tools:
model = model.bind_tools([t.to_lc_tool() for t in tools])
model = model.bind_tools(as_lc_tools(tools))
return model

def get_llm_rules(self) -> LLMRules:
Expand All @@ -165,7 +170,7 @@ def get_llm_rules(self) -> LLMRules:
"""
return controlflow.llm.rules.rules_for_model(self.get_model())

def get_tools(self) -> list[Callable]:
def get_tools(self) -> list["Tool"]:
from controlflow.tools.talk_to_user import talk_to_user

tools = self.tools.copy()
Expand All @@ -174,7 +179,7 @@ def get_tools(self) -> list[Callable]:
if self.memory is not None:
tools.extend(self.memory.get_tools())

return tools
return as_tools(tools)

@contextmanager
def create_context(self):
Expand Down
8 changes: 4 additions & 4 deletions src/controlflow/orchestration/agent_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import ExitStack
from functools import partial, wraps
from typing import Callable, Optional
from typing import Any, Callable, Optional

from pydantic import Field

Expand All @@ -11,7 +11,7 @@
from controlflow.llm.messages import BaseMessage
from controlflow.orchestration.handler import Handler
from controlflow.tasks.task import Task
from controlflow.tools.tools import Tool
from controlflow.tools.tools import Tool, as_tools
from controlflow.utilities.context import ctx
from controlflow.utilities.types import ControlFlowModel

Expand All @@ -30,7 +30,7 @@ class AgentContext(ControlFlowModel):
model_config = dict(arbitrary_types_allowed=True)
flow: Flow
tasks: list[Task]
tools: list[Tool] = []
tools: list[Any] = []
agents: list[BaseAgent] = Field(
default_factory=list,
description="Any other agents that are relevant to this operation, in order to properly load events",
Expand Down Expand Up @@ -59,7 +59,7 @@ def add_handlers(self, handlers: list[Handler]):
self.handlers = self.handlers + handlers

def add_tools(self, tools: list[Tool]):
self.tools = self.tools + tools
self.tools = self.tools + as_tools(tools)

def add_instructions(self, instructions: list[str]):
self.instructions = self.instructions + instructions
Expand Down
2 changes: 1 addition & 1 deletion src/controlflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class Settings(ControlFlowSettings):
)

tools_verbose: bool = Field(
default=False, description="If True, tools will log additional information."
default=True, description="If True, tools will log additional information."
)

# ------------ Prefect settings ------------
Expand Down
25 changes: 23 additions & 2 deletions src/controlflow/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,27 @@ def as_tools(
else:
raise ValueError(f"Invalid tool: {t}")

if (t.name, t.description, t.fn) in seen:
if (t.name, t.description) in seen:
continue
new_tools.append(t)
seen.add((t.name, t.description, t.fn))
seen.add((t.name, t.description))
return new_tools


def as_lc_tools(
    tools: list[Union[Callable, langchain_core.tools.BaseTool, Tool]],
) -> list[langchain_core.tools.BaseTool]:
    """Normalize a heterogeneous list of tools into LangChain ``BaseTool`` objects.

    Each entry is handled by kind:
      - LangChain ``BaseTool`` instances are passed through unchanged.
      - ControlFlow ``Tool`` instances are converted via ``to_lc_tool()``.
      - Any other callable is wrapped with ``StructuredTool.from_function``.

    Raises:
        ValueError: if an entry is none of the supported kinds.
    """
    new_tools = []
    for t in tools:
        if isinstance(t, langchain_core.tools.BaseTool):
            pass
        elif isinstance(t, Tool):
            t = t.to_lc_tool()
        elif callable(t):
            # Accept any callable (plain functions, lambdas, bound methods,
            # functools.partial objects). The previous inspect.isfunction()
            # check rejected everything except plain functions/lambdas even
            # though StructuredTool.from_function can wrap them.
            t = langchain_core.tools.StructuredTool.from_function(t)
        else:
            raise ValueError(f"Invalid tool: {t}")
        new_tools.append(t)
    return new_tools


Expand Down Expand Up @@ -266,6 +283,8 @@ def handle_tool_call(
fn_output = tool.run(input=fn_args)
elif isinstance(tool, langchain_core.tools.BaseTool):
fn_output = tool.invoke(input=fn_args)
else:
raise ValueError(f"Invalid tool: {tool}")
except Exception as exc:
fn_output = f'Error calling function "{fn_name}": {exc}'
is_error = True
Expand Down Expand Up @@ -307,6 +326,8 @@ async def handle_tool_call_async(tool_call: ToolCall, tools: list[Tool]) -> Any:
fn_output = await tool.run_async(input=fn_args)
elif isinstance(tool, langchain_core.tools.BaseTool):
fn_output = await tool.ainvoke(input=fn_args)
else:
raise ValueError(f"Invalid tool: {tool}")
except Exception as exc:
fn_output = f'Error calling function "{fn_name}": {exc}'
is_error = True
Expand Down
124 changes: 72 additions & 52 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@


def test_defaults():
# ensure that debug settings etc. are not left on by default
# ensure that debug settings etc. are not misconfigured during development
# change these settings to match whatever the default should be
assert controlflow.settings.tools_raise_on_error is False
assert controlflow.settings.tools_verbose is False
assert controlflow.settings.tools_verbose is True


def test_temporary_settings():
Expand All @@ -35,65 +36,84 @@ def test_prefect_settings_apply_at_runtime(caplog):


def test_import_without_default_api_key_warns_but_does_not_fail(monkeypatch, caplog):
# remove the OPENAI_API_KEY environment variable
monkeypatch.delenv("OPENAI_API_KEY", raising=False)

# Clear any previous logs
caplog.clear()

# Import the library
with caplog.at_level("WARNING"):
# Reload the library to apply changes
try:
with monkeypatch.context() as m:
# remove the OPENAI_API_KEY environment variable
m.delenv("OPENAI_API_KEY", raising=False)

# Clear any previous logs
caplog.clear()

# Import the library
with caplog.at_level("WARNING"):
# Reload the library to apply changes
defaults_module = importlib.import_module("controlflow.defaults")
importlib.reload(defaults_module)
importlib.reload(controlflow)

# Check if the warning was logged
assert any(
record.levelname == "WARNING"
and "The default LLM model could not be created" in record.message
for record in caplog.records
), "The expected warning was not logged"
finally:
defaults_module = importlib.import_module("controlflow.defaults")
importlib.reload(controlflow)
importlib.reload(defaults_module)

# Check if the warning was logged
assert any(
record.levelname == "WARNING"
and "The default LLM model could not be created" in record.message
for record in caplog.records
), "The expected warning was not logged"
importlib.reload(controlflow)


def test_import_without_default_api_key_errors_when_loading_model(monkeypatch):
# remove the OPENAI_API_KEY environment variable
monkeypatch.delenv("OPENAI_API_KEY", raising=False)

# Reload the library to apply changes
defaults_module = importlib.import_module("controlflow.defaults")
importlib.reload(controlflow)
importlib.reload(defaults_module)

with pytest.raises(ValueError, match="Did not find openai_api_key"):
controlflow.llm.models.get_default_model()

with pytest.raises(
ValueError, match="No model provided and no default model could be loaded"
):
controlflow.Agent().get_model()
try:
with monkeypatch.context() as m:
# remove the OPENAI_API_KEY environment variable
m.delenv("OPENAI_API_KEY", raising=False)

# Reload the library to apply changes
defaults_module = importlib.import_module("controlflow.defaults")
importlib.reload(defaults_module)
importlib.reload(controlflow)

with pytest.raises(ValueError, match="Did not find openai_api_key"):
controlflow.llm.models.get_default_model()

with pytest.raises(
ValueError,
match="No model provided and no default model could be loaded",
):
controlflow.Agent().get_model()
finally:
defaults_module = importlib.import_module("controlflow.defaults")
importlib.reload(defaults_module)
importlib.reload(controlflow)


def test_import_without_api_key_for_non_default_model_warns_but_does_not_fail(
monkeypatch, caplog
):
# remove the OPENAI_API_KEY environment variable
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.setenv("CONTROLFLOW_LLM_MODEL", "anthropic/not-a-model")

# Clear any previous logs
caplog.clear()

# Import the library
with caplog.at_level("WARNING"):
# Reload the library to apply changes
try:
with monkeypatch.context() as m:
# remove the OPENAI_API_KEY environment variable
m.delenv("OPENAI_API_KEY", raising=False)
m.setenv("CONTROLFLOW_LLM_MODEL", "anthropic/not-a-model")

# Clear any previous logs
caplog.clear()

# Import the library
with caplog.at_level("WARNING"):
# Reload the library to apply changes
defaults_module = importlib.import_module("controlflow.defaults")
importlib.reload(defaults_module)
importlib.reload(controlflow)

# Check if the warning was logged
assert any(
record.levelname == "WARNING"
and "The default LLM model could not be created" in record.message
for record in caplog.records
), "The expected warning was not logged"
finally:
defaults_module = importlib.import_module("controlflow.defaults")
importlib.reload(controlflow)
importlib.reload(defaults_module)

# Check if the warning was logged
assert any(
record.levelname == "WARNING"
and "The default LLM model could not be created" in record.message
for record in caplog.records
), "The expected warning was not logged"
importlib.reload(controlflow)
13 changes: 13 additions & 0 deletions tests/tools/test_lc_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import controlflow
from langchain_community.tools import DuckDuckGoSearchRun


def test_ddg_tool():
    """Verify that a LangChain community tool (DuckDuckGo search) can be
    passed directly to a ControlFlow task and used to complete it.

    NOTE(review): this is an integration test — it makes live LLM and
    network calls, so it requires an API key and internet access.
    """
    task = controlflow.Task(
        "Retrieve and summarize today's two top business headlines",
        tools=[DuckDuckGoSearchRun()],
        result_type=list[str],
    )
    task.run()
    # run() should have marked the task successful if the tool was usable
    assert task.is_successful()
assert task.is_successful()