From fc884a008354f8901e8d98c28550f8e908b2491c Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 9 Oct 2024 10:01:35 +0200 Subject: [PATCH 01/11] first draft --- src/banks/env.py | 3 ++- src/banks/filters/__init__.py | 2 ++ src/banks/filters/tool.py | 8 ++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 src/banks/filters/tool.py diff --git a/src/banks/env.py b/src/banks/env.py index 68a1fe3..247dd34 100644 --- a/src/banks/env.py +++ b/src/banks/env.py @@ -4,7 +4,7 @@ from jinja2 import Environment, PackageLoader, select_autoescape from .config import config -from .filters import cache_control, lemmatize +from .filters import cache_control, lemmatize, tool def _add_extensions(_env): @@ -40,4 +40,5 @@ def _add_extensions(_env): # Setup custom filters and defaults env.filters["lemmatize"] = lemmatize env.filters["cache_control"] = cache_control +env.filters["tool"] = tool _add_extensions(env) diff --git a/src/banks/filters/__init__.py b/src/banks/filters/__init__.py index 23d05f3..27645ad 100644 --- a/src/banks/filters/__init__.py +++ b/src/banks/filters/__init__.py @@ -3,8 +3,10 @@ # SPDX-License-Identifier: MIT from .cache_control import cache_control from .lemmatize import lemmatize +from .tool import tool __all__ = ( "cache_control", "lemmatize", + "tool", ) diff --git a/src/banks/filters/tool.py b/src/banks/filters/tool.py new file mode 100644 index 0000000..68f7262 --- /dev/null +++ b/src/banks/filters/tool.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi +# +# SPDX-License-Identifier: MIT +from typing import Callable + + +def tool(value: Callable) -> str: + return f"Function {value}" From b715793d09b9d41f013babf496f728e468e44531 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Mon, 14 Oct 2024 09:09:49 +0200 Subject: [PATCH 02/11] add introspection utils --- src/banks/utils.py | 41 +++++++++++++++++++++++++++ tests/test_utils.py | 67 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 107 insertions(+), 1 deletion(-) diff --git a/src/banks/utils.py b/src/banks/utils.py index eb29581..ccc8fa7 100644 --- a/src/banks/utils.py +++ b/src/banks/utils.py @@ -1,5 +1,7 @@ import secrets +from griffe import Docstring, parse_google, parse_numpy, parse_sphinx + def strtobool(val: str) -> bool: """ @@ -22,3 +24,42 @@ def strtobool(val: str) -> bool: def generate_canary_word(prefix: str = "BANKS[", suffix: str = "]", token_length: int = 8) -> str: return f"{prefix}{secrets.token_hex(token_length // 2)}{suffix}" + + +def python_type_to_jsonschema(python_type: type) -> str: + """Given a Python type, returns the jsonschema string describing it.""" + if python_type is str: + return "string" + elif python_type is int: + return "integer" + elif python_type is float: + return "number" + elif python_type is bool: + return "boolean" + elif python_type is list: + return "array" + elif python_type is dict: + return "object" + else: + msg = f"Unsupported type: {python_type}" + raise ValueError(msg) + + +def parse_params_from_docstring(docstring: str) -> dict[str, dict[str, str]]: + param_docs = [] + ds = Docstring(docstring) + for parser in (parse_google, parse_numpy, parse_sphinx): + sections = parser(ds) + for section in sections: + if section.kind == "parameters": + param_docs = section.value + break + if param_docs: + break + + ret = {} + for d in param_docs: + d_dict = d.as_dict() + ret[d_dict.pop("name")] = d_dict + + return ret diff --git a/tests/test_utils.py b/tests/test_utils.py index d3f9b32..73450b0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,7 @@ import pytest import regex as re -from banks.utils import generate_canary_word, strtobool +from banks.utils import generate_canary_word, parse_params_from_docstring, python_type_to_jsonschema, strtobool def test_generate_canary_word_defaults(): @@ -45,3 +45,68 @@ def test_strtobool_error(): ) def test_strtobool(test_input, expected): assert strtobool(test_input) == expected + + +@pytest.mark.parametrize( + "test_input,expected", + [ + (type("I am a string"), "string"), # noqa + (type(42), "integer"), # noqa + (type(0.42), "number"), # noqa + (type(True), "boolean"), # noqa + (type([]), "array"), + (type({"foo": "bar"}), "object"), + ], +) +def test_python_type_to_jsonschema(test_input, expected): + assert python_type_to_jsonschema(test_input) == expected + with pytest.raises(ValueError, match="Unsupported type: "): + python_type_to_jsonschema(type(Exception)) + + +def test_parse_params_from_docstring_google(): + def my_test_function(test_param: str): + """A docstring. + + Args: + test_param (str): The test parameter. + """ + pass + + assert parse_params_from_docstring(my_test_function.__doc__) == { # type: ignore + "test_param": {"annotation": "str", "description": "The test parameter."} + } + + +def test_parse_params_from_docstring_numpy(): + def my_test_function(test_param: str): + """A docstring. + + Parameters + ---------- + test_param : str + The test parameter. + """ + pass + + assert parse_params_from_docstring(my_test_function.__doc__) == { # type: ignore + "test_param": {"annotation": "str", "description": "The test parameter."} + } + + +def test_parse_params_from_docstring_sphinx(): + def my_test_function(test_param: str): + """A docstring. + + :param test_param: The test parameter. + :type test_param: str + """ + pass + + assert parse_params_from_docstring(my_test_function.__doc__) == { # type: ignore + "test_param": {"annotation": "str", "description": "The test parameter."} + } + + +def test_parse_params_from_docstring_empty(): + assert parse_params_from_docstring("") == {} From 2fad1dfed73baf72214457015763eac49ee827f4 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Tue, 15 Oct 2024 10:08:11 +0200 Subject: [PATCH 03/11] add tool filter --- src/banks/filters/tool.py | 5 ++- src/banks/types.py | 88 +++++++++++++++++++++++++++++++++++++++ tests/test_tool.py | 56 +++++++++++++++++++++++++ 3 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 tests/test_tool.py diff --git a/src/banks/filters/tool.py b/src/banks/filters/tool.py index 68f7262..1aca0ec 100644 --- a/src/banks/filters/tool.py +++ b/src/banks/filters/tool.py @@ -3,6 +3,9 @@ # SPDX-License-Identifier: MIT from typing import Callable +from banks.types import Tool + def tool(value: Callable) -> str: - return f"Function {value}" + tool = Tool.from_callable(value) + return tool.model_dump_json() diff --git a/src/banks/types.py b/src/banks/types.py index 43b34ef..bd173c9 100644 --- a/src/banks/types.py +++ b/src/banks/types.py @@ -2,8 +2,13 @@ # # SPDX-License-Identifier: MIT from enum import Enum +from inspect import Parameter, getdoc, signature +from typing import Callable from pydantic import BaseModel +from typing_extensions import Self + +from .utils import parse_params_from_docstring, python_type_to_jsonschema # pylint: disable=invalid-name @@ -49,3 +54,86 @@ class Config: class ChatMessage(BaseModel): role: str content: str | ChatMessageContent + + +class ChatWithToolMessage(ChatMessage): + tool_call_id: str + name: str + + +class FunctionParameter(BaseModel): + type: str + description: str + + +class FunctionParameters(BaseModel): + type: str = "object" + properties: dict[str, FunctionParameter] + required: list[str] + + +class Function(BaseModel): + name: str + description: str + parameters: FunctionParameters + + +class Tool(BaseModel): + """A model representing a Tool to be used in function calling. + + This model should dump the following: + ``` + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ``` + """ + + type: str = "function" + function: Function + _import_path: str + + @classmethod + def from_callable(cls, func: Callable) -> Self: + sig = signature(func) + + # Try getting params descriptions from docstrings + param_docs = parse_params_from_docstring(func.__doc__ or "") + + # If the docstring is missing, use the qualname space-separated hopefully + # the LLM will get some more context than just using the function name. + description = getdoc(func) or " ".join(func.__qualname__.split(".")) + properties = {} + required = [] + for name, param in sig.parameters.items(): + p = FunctionParameter( + description=param_docs.get(name, {}).get("description", ""), + type=python_type_to_jsonschema(param.annotation), + ) + properties[name] = p + if param.default == Parameter.empty: + required.append(name) + + return cls( + function=Function( + name=func.__name__, + description=description, + parameters=FunctionParameters(properties=properties, required=required), + ), + _import_path=f"{func.__module__}.{func.__qualname__}", + ) diff --git a/tests/test_tool.py b/tests/test_tool.py new file mode 100644 index 0000000..b7a272d --- /dev/null +++ b/tests/test_tool.py @@ -0,0 +1,56 @@ +import inspect + +from banks.filters import tool +from banks.types import Tool + + +def test_tool(): + def my_tool_function(myparam: str): + """Description of the tool. + + Args: + myparam (str): description of the parameter + + """ + pass + + tool_dump = tool(my_tool_function) + t = Tool.model_validate_json(tool_dump) + assert t.model_dump() == { + "function": { + "description": inspect.getdoc(my_tool_function), + "name": "my_tool_function", + "parameters": { + "properties": {"myparam": {"description": "description " "of the " "parameter", "type": "string"}}, + "required": ["myparam"], + "type": "object", + }, + }, + "type": "function", + } + + +def test_tool_with_defaults(): + def my_tool_function(myparam: str = ""): + """Description of the tool. + + Args: + myparam (str): description of the parameter + + """ + pass + + tool_dump = tool(my_tool_function) + t = Tool.model_validate_json(tool_dump) + assert t.model_dump() == { + "function": { + "description": inspect.getdoc(my_tool_function), + "name": "my_tool_function", + "parameters": { + "properties": {"myparam": {"description": "description " "of the " "parameter", "type": "string"}}, + "required": [], + "type": "object", + }, + }, + "type": "function", + } From a57ac4f33c0989a0a81cea79d50d6238de0f8a29 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Tue, 15 Oct 2024 11:07:13 +0200 Subject: [PATCH 04/11] checkpoint --- src/banks/errors.py | 4 ++ src/banks/extensions/completion.py | 97 ++++++++++++++++++++++++------ tests/test_completion.py | 57 +++++++++++++----- 3 files changed, 124 insertions(+), 34 deletions(-) diff --git a/src/banks/errors.py b/src/banks/errors.py index 875caf1..d6f8cf2 100644 --- a/src/banks/errors.py +++ b/src/banks/errors.py @@ -19,3 +19,7 @@ class PromptNotFoundError(Exception): class InvalidPromptError(Exception): """The prompt is not valid.""" + + +class LLMError(Exception): + """The LLM had problems.""" diff --git a/src/banks/extensions/completion.py b/src/banks/extensions/completion.py index dad849e..0d6be41 100644 --- a/src/banks/extensions/completion.py +++ b/src/banks/extensions/completion.py @@ -1,15 +1,18 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT +import importlib +import json from typing import cast from jinja2 import TemplateSyntaxError, nodes from jinja2.ext import Extension from litellm import acompletion, completion -from litellm.types.utils import ModelResponse +from litellm.types.utils import Choices, ModelResponse from pydantic import ValidationError -from banks.types import ChatMessage +from banks.errors import InvalidPromptError, LLMError +from banks.types import ChatMessage, ChatWithToolMessage, Tool SUPPORTED_KWARGS = ("model",) @@ -72,41 +75,97 @@ def parse(self, parser): return nodes.CallBlock(self.call_method("_do_completion_async", args), [], [], body).set_lineno(lineno) return nodes.CallBlock(self.call_method("_do_completion", args), [], [], body).set_lineno(lineno) + def _get_tool_callable(self, tools, tool_call): + for tool in tools: + if tool.function.name == tool_call.function.name: + module_name, func_name = tool._import_path.rsplit(".", maxsplit=1) + module = importlib.import_module(module_name) + return getattr(module, func_name) + return None + def _do_completion(self, model_name, caller): """ Helper callback. """ - messages = self._body_to_messages(caller()) - if not messages: - return "" - + messages, tools = self._body_to_messages(caller()) + + response = cast(ModelResponse, completion(model=model_name, messages=messages, tools=tools)) + choices = cast(list[Choices], response.choices) + tool_calls = choices[0].message.tool_calls + if not tool_calls: + return choices[0].message.content + + for tool_call in tool_calls: + if not tool_call.function.name: + msg = "Function name is empty" + raise LLMError(msg) + + for tool in tools: + if tool.function.name == tool_call.function.name: + module_name, func_name = tool._import_path.rsplit(".", maxsplit=1) + module = importlib.import_module(module_name) + func = getattr(module, func_name) + + function_args = json.loads(tool_call.function.arguments) + function_response = func(**function_args) + messages.append( + ChatWithToolMessage( + tool_call_id=tool_call.id, role="tool", name=tool_call.function.name, content=function_response + ) + ) response = cast(ModelResponse, completion(model=model_name, messages=messages)) - return response.choices[0].message.content # type: ignore + choices = cast(list[Choices], response.choices) + return choices[0].message.content async def _do_completion_async(self, model_name, caller): """ Helper callback. """ - messages = self._body_to_messages(caller()) - if not messages: - return "" - + messages, tools = self._body_to_messages(caller()) + + response = cast(ModelResponse, await acompletion(model=model_name, messages=messages, tools=tools)) + choices = cast(list[Choices], response.choices) + tool_calls = choices[0].message.tool_calls or [] + if not tool_calls: + return choices[0].message.content + + for tool_call in tool_calls: + if not tool_call.function.name: + msg = "Function name is empty" + raise LLMError(msg) + + for tool in tools: + if tool.function.name == tool_call.function.name: + module_name, func_name = tool._import_path.rsplit(".", maxsplit=1) + module = importlib.import_module(module_name) + func = getattr(module, func_name) + + function_args = json.loads(tool_call.function.arguments) + function_response = func(**function_args) + messages.append( + ChatWithToolMessage( + tool_call_id=tool_call.id, role="tool", name=tool_call.function.name, content=function_response + ) + ) response = cast(ModelResponse, await acompletion(model=model_name, messages=messages)) - return response.choices[0].message.content # type: ignore + choices = cast(list[Choices], response.choices) + return choices[0].message.content - def _body_to_messages(self, body: str) -> list[ChatMessage]: + def _body_to_messages(self, body: str) -> tuple[list[ChatMessage], list[Tool]]: body = body.strip() - if not body: - return [] - messages = [] + tools = [] for line in body.split("\n"): try: messages.append(ChatMessage.model_validate_json(line)) except ValidationError: # pylint: disable=R0801 - pass + try: + tools.append(Tool.model_validate_json(line)) + except ValidationError: + pass if not messages: - messages.append(ChatMessage(role="user", content=body)) + msg = "Completion must contain at least one chat message" + raise InvalidPromptError(msg) - return messages + return (messages, tools) diff --git a/tests/test_completion.py b/tests/test_completion.py index 7781f0f..dd53b1b 100644 --- a/tests/test_completion.py +++ b/tests/test_completion.py @@ -3,6 +3,7 @@ import pytest from jinja2.environment import Environment +from banks.errors import InvalidPromptError from banks.extensions.completion import CompletionExtension from banks.types import ChatMessage @@ -12,25 +13,51 @@ def ext(): return CompletionExtension(environment=Environment()) -def test__body_to_messages(ext): - assert ext._body_to_messages(" ") == [] - assert ext._body_to_messages(' \n{"role":"user", "content":"hello"}') == [ChatMessage(role="user", content="hello")] - assert ext._body_to_messages(" \nhello\n ") == [ChatMessage(role="user", content="hello")] - assert ext._body_to_messages('{"role":"user", "content":"hello"}\n HELLO!') == [ - ChatMessage(role="user", content="hello") - ] +@pytest.fixture +def mocked_choices_no_tools(): + return [mock.MagicMock(message=mock.MagicMock(tool_calls=None, content="some response"))] -def test__do_completion(ext): - assert ext._do_completion("test-model", lambda: " ") == "" +def test__body_to_messages(ext): + assert ext._body_to_messages(' \n{"role":"user", "content":"hello"}') == ( + [ChatMessage(role="user", content="hello")], + [], + ) + assert ext._body_to_messages('{"role":"user", "content":"hello"}\n HELLO!') == ( + [ChatMessage(role="user", content="hello")], + [], + ) + with pytest.raises(InvalidPromptError, match="Completion must contain at least one chat message"): + ext._body_to_messages(" ") + with pytest.raises(InvalidPromptError, match="Completion must contain at least one chat message"): + ext._body_to_messages(" \nhello\n ") + + +def test__do_completion_no_prompt(ext): + with pytest.raises(InvalidPromptError, match="Completion must contain at least one chat message"): + ext._do_completion("test-model", lambda: " ") + + +def test__do_completion_no_tools(ext, mocked_choices_no_tools): with mock.patch("banks.extensions.completion.completion") as mocked_completion: - ext._do_completion("test-model", lambda: "hello") - mocked_completion.assert_called_with(model="test-model", messages=[ChatMessage(role="user", content="hello")]) + mocked_completion.return_value.choices = mocked_choices_no_tools + ext._do_completion("test-model", lambda: '{"role":"user", "content":"hello"}') + mocked_completion.assert_called_with( + model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[] + ) + + +@pytest.mark.asyncio +async def test__do_completion_async_no_prompt(ext): + with pytest.raises(InvalidPromptError, match="Completion must contain at least one chat message"): + await ext._do_completion_async("test-model", lambda: " ") @pytest.mark.asyncio -async def test__do_completion_async(ext): - assert await ext._do_completion_async("test-model", lambda: " ") == "" +async def test__do_completion_async_no_prompt_no_tools(ext, mocked_choices_no_tools): with mock.patch("banks.extensions.completion.acompletion") as mocked_completion: - await ext._do_completion_async("test-model", lambda: "hello") - mocked_completion.assert_called_with(model="test-model", messages=[ChatMessage(role="user", content="hello")]) + mocked_completion.return_value.choices = mocked_choices_no_tools + await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}') + mocked_completion.assert_called_with( + model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[] + ) From e818f8069e764ac7dd1bb40dff856a1bf2befd67 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Tue, 15 Oct 2024 21:21:43 +0200 Subject: [PATCH 05/11] first working version --- src/banks/extensions/chat.py | 2 +- src/banks/extensions/completion.py | 23 ++++++++++++----------- src/banks/filters/tool.py | 2 +- src/banks/types.py | 14 ++++++++------ tests/test_prompt.py | 4 ++++ tests/test_tool.py | 18 ++++++++++-------- 6 files changed, 36 insertions(+), 27 deletions(-) diff --git a/src/banks/extensions/chat.py b/src/banks/extensions/chat.py index 0549674..ee75fa2 100644 --- a/src/banks/extensions/chat.py +++ b/src/banks/extensions/chat.py @@ -108,4 +108,4 @@ def _store_chat_messages(self, role, caller): parser = _ContentBlockParser() parser.feed(caller()) cm = ChatMessage(role=role, content=parser.content) - return cm.model_dump_json() + return cm.model_dump_json(exclude_none=True) + "\n" diff --git a/src/banks/extensions/completion.py b/src/banks/extensions/completion.py index 0d6be41..6b7d200 100644 --- a/src/banks/extensions/completion.py +++ b/src/banks/extensions/completion.py @@ -12,7 +12,7 @@ from pydantic import ValidationError from banks.errors import InvalidPromptError, LLMError -from banks.types import ChatMessage, ChatWithToolMessage, Tool +from banks.types import ChatMessage, Tool SUPPORTED_KWARGS = ("model",) @@ -78,10 +78,11 @@ def parse(self, parser): def _get_tool_callable(self, tools, tool_call): for tool in tools: if tool.function.name == tool_call.function.name: - module_name, func_name = tool._import_path.rsplit(".", maxsplit=1) + module_name, func_name = tool.import_path.rsplit(".", maxsplit=1) module = importlib.import_module(module_name) return getattr(module, func_name) - return None + msg = f"Function {tool.function.name} not found in available tools" + raise ValueError(msg) def _do_completion(self, model_name, caller): """ @@ -95,24 +96,22 @@ def _do_completion(self, model_name, caller): if not tool_calls: return choices[0].message.content + messages.append(choices[0].message) # type:ignore for tool_call in tool_calls: if not tool_call.function.name: msg = "Function name is empty" raise LLMError(msg) - for tool in tools: - if tool.function.name == tool_call.function.name: - module_name, func_name = tool._import_path.rsplit(".", maxsplit=1) - module = importlib.import_module(module_name) - func = getattr(module, func_name) + func = self._get_tool_callable(tools, tool_call) function_args = json.loads(tool_call.function.arguments) function_response = func(**function_args) messages.append( - ChatWithToolMessage( + ChatMessage( tool_call_id=tool_call.id, role="tool", name=tool_call.function.name, content=function_response ) ) + response = cast(ModelResponse, completion(model=model_name, messages=messages)) choices = cast(list[Choices], response.choices) return choices[0].message.content @@ -129,6 +128,7 @@ async def _do_completion_async(self, model_name, caller): if not tool_calls: return choices[0].message.content + messages.append(choices[0].message) # type:ignore for tool_call in tool_calls: if not tool_call.function.name: msg = "Function name is empty" @@ -136,17 +136,18 @@ async def _do_completion_async(self, model_name, caller): for tool in tools: if tool.function.name == tool_call.function.name: - module_name, func_name = tool._import_path.rsplit(".", maxsplit=1) + module_name, func_name = tool.import_path.rsplit(".", maxsplit=1) module = importlib.import_module(module_name) func = getattr(module, func_name) function_args = json.loads(tool_call.function.arguments) function_response = func(**function_args) messages.append( - ChatWithToolMessage( + ChatMessage( tool_call_id=tool_call.id, role="tool", name=tool_call.function.name, content=function_response ) ) + response = cast(ModelResponse, await acompletion(model=model_name, messages=messages)) choices = cast(list[Choices], response.choices) return choices[0].message.content diff --git a/src/banks/filters/tool.py b/src/banks/filters/tool.py index 1aca0ec..a8f6a18 100644 --- a/src/banks/filters/tool.py +++ b/src/banks/filters/tool.py @@ -8,4 +8,4 @@ def tool(value: Callable) -> str: tool = Tool.from_callable(value) - return tool.model_dump_json() + return tool.model_dump_json() + "\n" diff --git a/src/banks/types.py b/src/banks/types.py index bd173c9..1d91a39 100644 --- a/src/banks/types.py +++ b/src/banks/types.py @@ -5,6 +5,7 @@ from inspect import Parameter, getdoc, signature from typing import Callable +from litellm.types.utils import Message from pydantic import BaseModel from typing_extensions import Self @@ -54,11 +55,12 @@ class Config: class ChatMessage(BaseModel): role: str content: str | ChatMessageContent + tool_call_id: str | None = None + name: str | None = None - -class ChatWithToolMessage(ChatMessage): - tool_call_id: str - name: str + @classmethod + def from_litellm(cls, msg: Message) -> Self: + return cls(role=msg.role, content=msg.content or "") class FunctionParameter(BaseModel): @@ -106,7 +108,7 @@ class Tool(BaseModel): type: str = "function" function: Function - _import_path: str + import_path: str @classmethod def from_callable(cls, func: Callable) -> Self: @@ -135,5 +137,5 @@ def from_callable(cls, func: Callable) -> Self: description=description, parameters=FunctionParameters(properties=properties, required=required), ), - _import_path=f"{func.__module__}.{func.__qualname__}", + import_path=f"{func.__module__}.{func.__qualname__}", ) diff --git a/tests/test_prompt.py b/tests/test_prompt.py index b1db95c..ace025d 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -89,9 +89,13 @@ def test_chat_messages(): p.text() == """ {"role":"system","content":"You are a helpful assistant.\\n"} + {"role":"user","content":"Hello, how are you?\\n"} + {"role":"system","content":"I'm doing well, thank you! How can I assist you today?\\n"} + {"role":"user","content":"Can you explain quantum computing?\\n"} + Some random text. """.strip() ) diff --git a/tests/test_tool.py b/tests/test_tool.py index b7a272d..a426c29 100644 --- a/tests/test_tool.py +++ b/tests/test_tool.py @@ -17,16 +17,17 @@ def my_tool_function(myparam: str): tool_dump = tool(my_tool_function) t = Tool.model_validate_json(tool_dump) assert t.model_dump() == { + "type": "function", "function": { - "description": inspect.getdoc(my_tool_function), "name": "my_tool_function", + "description": "Description of the tool.\n\nArgs:\n myparam (str): description of the parameter", "parameters": { - "properties": {"myparam": {"description": "description " "of the " "parameter", "type": "string"}}, - "required": ["myparam"], "type": "object", + "properties": {"myparam": {"type": "string", "description": "description of the parameter"}}, + "required": ["myparam"], }, }, - "type": "function", + "import_path": "tests.test_tool.test_tool..my_tool_function", } @@ -43,14 +44,15 @@ def my_tool_function(myparam: str = ""): tool_dump = tool(my_tool_function) t = Tool.model_validate_json(tool_dump) assert t.model_dump() == { + "type": "function", "function": { - "description": inspect.getdoc(my_tool_function), "name": "my_tool_function", + "description": "Description of the tool.\n\nArgs:\n myparam (str): description of the parameter", "parameters": { - "properties": {"myparam": {"description": "description " "of the " "parameter", "type": "string"}}, - "required": [], "type": "object", + "properties": {"myparam": {"type": "string", "description": "description of the parameter"}}, + "required": [], }, }, - "type": "function", + "import_path": "tests.test_tool.test_tool_with_defaults..my_tool_function", } From 81301ce6bf3b5bf3742b06c8c646e526dfc5703b Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Tue, 15 Oct 2024 21:30:15 +0200 Subject: [PATCH 06/11] fix linting --- tests/test_tool.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_tool.py b/tests/test_tool.py index a426c29..c46aed4 100644 --- a/tests/test_tool.py +++ b/tests/test_tool.py @@ -1,5 +1,3 @@ -import inspect - from banks.filters import tool from banks.types import Tool From e9f06650b8ccb3c43cde3a0d9cd17fdc397e1553 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Tue, 15 Oct 2024 21:42:09 +0200 Subject: [PATCH 07/11] add missing deps, fix pylint errors --- pyproject.toml | 4 +++- src/banks/extensions/completion.py | 15 ++++++--------- src/banks/filters/tool.py | 4 ++-- src/banks/utils.py | 16 ++++++++-------- 4 files changed, 19 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5db274c..f5ada15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ + "griffe", "jinja2", "litellm", "pydantic", @@ -183,8 +184,9 @@ exclude_lines = [ [[tool.mypy.overrides]] module = [ - "simplemma.*", + # "griffe.*", "litellm.*", + "simplemma.*", ] ignore_missing_imports = true diff --git a/src/banks/extensions/completion.py b/src/banks/extensions/completion.py index 6b7d200..b97f766 100644 --- a/src/banks/extensions/completion.py +++ b/src/banks/extensions/completion.py @@ -81,7 +81,7 @@ def _get_tool_callable(self, tools, tool_call): module_name, func_name = tool.import_path.rsplit(".", maxsplit=1) module = importlib.import_module(module_name) return getattr(module, func_name) - msg = f"Function {tool.function.name} not found in available tools" + msg = f"Function {tool_call.function.name} not found in available tools" raise ValueError(msg) def _do_completion(self, model_name, caller): @@ -134,17 +134,14 @@ async def _do_completion_async(self, model_name, caller): msg = "Function name is empty" raise LLMError(msg) - for tool in tools: - if tool.function.name == tool_call.function.name: - module_name, func_name = tool.import_path.rsplit(".", maxsplit=1) - module = importlib.import_module(module_name) - func = getattr(module, func_name) + func = self._get_tool_callable(tools, tool_call) - function_args = json.loads(tool_call.function.arguments) - function_response = func(**function_args) messages.append( ChatMessage( - tool_call_id=tool_call.id, role="tool", name=tool_call.function.name, content=function_response + tool_call_id=tool_call.id, + role="tool", + name=tool_call.function.name, + content=func(**json.loads(tool_call.function.arguments)), ) ) diff --git a/src/banks/filters/tool.py b/src/banks/filters/tool.py index a8f6a18..decab05 100644 --- a/src/banks/filters/tool.py +++ b/src/banks/filters/tool.py @@ -7,5 +7,5 @@ def tool(value: Callable) -> str: - tool = Tool.from_callable(value) - return tool.model_dump_json() + "\n" + t = Tool.from_callable(value) + return t.model_dump_json() + "\n" diff --git a/src/banks/utils.py b/src/banks/utils.py index ccc8fa7..e3e82d5 100644 --- a/src/banks/utils.py +++ b/src/banks/utils.py @@ -30,19 +30,19 @@ def python_type_to_jsonschema(python_type: type) -> str: """Given a Python type, returns the jsonschema string describing it.""" if python_type is str: return "string" - elif python_type is int: + if python_type is int: return "integer" - elif python_type is float: + if python_type is float: return "number" - elif python_type is bool: + if python_type is bool: return "boolean" - elif python_type is list: + if python_type is list: return "array" - elif python_type is dict: + if python_type is dict: return "object" - else: - msg = f"Unsupported type: {python_type}" - raise ValueError(msg) + + msg = f"Unsupported type: {python_type}" + raise ValueError(msg) def parse_params_from_docstring(docstring: str) -> dict[str, dict[str, str]]: From 551b57d603f7736591f23358120d19dab076e887 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 16 Oct 2024 10:15:07 +0200 Subject: [PATCH 08/11] add function tool example to the README --- README.md | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/README.md b/README.md index 935f258..e1e2cb0 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ Docs are available [here](https://masci.github.io/banks/). - [Examples](#examples) - [:point\_right: Render a prompt template as chat messages](#point_right-render-a-prompt-template-as-chat-messages) - [:point\_right: Use a LLM to generate a text while rendering a prompt](#point_right-use-a-llm-to-generate-a-text-while-rendering-a-prompt) + - [:point\_right: Call functions directly from the prompt](#point_right-call-functions-directly-from-the-prompt) - [:point\_right: Use prompt caching from Anthropic](#point_right-use-prompt-caching-from-anthropic) - [Reuse templates from registries](#reuse-templates-from-registries) - [Async support](#async-support) @@ -136,6 +137,43 @@ Examples: > Banks uses a cache to avoid generating text again for the same template with the same context. By default > the cache is in-memory but it can be customized. +### :point_right: Call functions directly from the prompt + +Banks provides a filter `tool` that can be used to convert a callable passed to a prompt into an LLM function call. +Docstrings are used to describe the tool and its arguments, and during prompt rendering Banks will perform all the LLM +roundtrips needed in case the model wants to use a tool within a `{% completion %}` block. For example: + +```py +import platform + +from banks import Prompt + + +def get_laptop_info(): + """Get information about the user laptop. + + For example, it returns the operating system and version, along with hardware and network specs.""" + return str(platform.uname()) + + +p = Prompt(""" +{% set response %} +{% completion model="gpt-3.5-turbo-0125" %} + {% chat role="user" %}{{ query }}{% endchat %} + {{ get_laptop_info | tool }} +{% endcompletion %} +{% endset %} + +{# the variable 'response' contains the result #} + +{{ response }} +""") + +print(p.text({"query": "Can you guess the name of my laptop?", "get_laptop_info": get_laptop_info})) +# Output: +# Based on the information provided, the name of your laptop is likely "MacGiver." +``` + ### :point_right: Use prompt caching from Anthropic Several inference providers support prompt caching to save time and costs, and Anthropic in particular offers From 7b14d91b7767998eabbc84401ae1aa285a08c9bb Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 16 Oct 2024 11:02:10 +0200 Subject: [PATCH 09/11] add unit tests --- src/banks/extensions/completion.py | 2 +- src/banks/types.py | 6 +- tests/test_completion.py | 101 ++++++++++++++++++++++++++++- 3 files changed, 101 insertions(+), 8 deletions(-) diff --git a/src/banks/extensions/completion.py b/src/banks/extensions/completion.py index b97f766..d12b0a6 100644 --- a/src/banks/extensions/completion.py +++ b/src/banks/extensions/completion.py @@ -99,7 +99,7 @@ def _do_completion(self, model_name, caller): messages.append(choices[0].message) # type:ignore for tool_call in tool_calls: if not tool_call.function.name: - msg = "Function name is empty" + msg = "Malformed response: function name is empty" raise LLMError(msg) func = self._get_tool_callable(tools, tool_call) diff --git a/src/banks/types.py b/src/banks/types.py index 1d91a39..62eca9b 100644 --- a/src/banks/types.py +++ b/src/banks/types.py @@ -5,7 +5,6 @@ from inspect import Parameter, getdoc, signature from typing import Callable -from litellm.types.utils import Message from pydantic import BaseModel from typing_extensions import Self @@ -58,10 +57,6 @@ class ChatMessage(BaseModel): tool_call_id: str | None = None name: str | None = None - @classmethod - def from_litellm(cls, msg: Message) -> Self: - return cls(role=msg.role, content=msg.content or "") - class FunctionParameter(BaseModel): type: str @@ -102,6 +97,7 @@ class Tool(BaseModel): "required": ["location"], }, }, + "import_path": "module.get_current_weather", } ``` """ diff --git a/tests/test_completion.py b/tests/test_completion.py index dd53b1b..abf7524 100644 --- a/tests/test_completion.py +++ b/tests/test_completion.py @@ -1,11 +1,13 @@ +from os import getenv from unittest import mock import pytest from jinja2.environment import Environment +from litellm.types.utils import ChatCompletionMessageToolCall, Function -from banks.errors import InvalidPromptError +from banks.errors import InvalidPromptError, LLMError from banks.extensions.completion import CompletionExtension -from banks.types import ChatMessage +from banks.types import ChatMessage, Tool @pytest.fixture @@ -18,6 +20,40 @@ def mocked_choices_no_tools(): return [mock.MagicMock(message=mock.MagicMock(tool_calls=None, content="some response"))] +@pytest.fixture +def mocked_choices_with_tools(): + return [ + mock.MagicMock( + message=mock.MagicMock( + tool_calls=[ + ChatCompletionMessageToolCall( + id="call_DN6IiLULWZw7sobV6puCji1O", + function=Function( + arguments='{"location": "San Francisco", "unit": "celsius"}', name="get_current_weather" + ), + type="function", + ), + ChatCompletionMessageToolCall( + id="call_ERm1JfYO9AFo2oEWRmWUd40c", + function=Function( + arguments='{"location": "Tokyo", "unit": "celsius"}', name="get_current_weather" + ), + type="function", + ), + ChatCompletionMessageToolCall( + id="call_2lvUVB1y4wKunSxTenR0zClP", + function=Function( + arguments='{"location": "Paris", "unit": "celsius"}', name="get_current_weather" + ), + type="function", + ), + ], + content="some response", + ) + ) + ] + + def test__body_to_messages(ext): assert ext._body_to_messages(' \n{"role":"user", "content":"hello"}') == ( [ChatMessage(role="user", content="hello")], @@ -47,6 +83,30 @@ def test__do_completion_no_tools(ext, mocked_choices_no_tools): ) +def test__do_completion_with_tools(ext, mocked_choices_with_tools): + ext._get_tool_callable = mock.MagicMock(return_value=lambda location, unit: f"I got {location} with {unit}") + ext._body_to_messages = mock.MagicMock(return_value=(["message1", "message2"], ["tool1", "tool2"])) + with mock.patch("banks.extensions.completion.completion") as mocked_completion: + mocked_completion.return_value.choices = mocked_choices_with_tools + ext._do_completion("test-model", lambda: '{"role":"user", "content":"hello"}') + calls = mocked_completion.call_args_list + assert len(calls) == 2 # complete query, complete with tool results + assert calls[0].kwargs["tools"] == ["tool1", "tool2"] + assert "tools" not in calls[1].kwargs + for m in calls[1].kwargs["messages"]: + if type(m) == ChatMessage: + assert m.role == "tool" + assert m.name == "get_current_weather" + + +def test__do_completion_with_tools_malformed(ext, mocked_choices_with_tools): + mocked_choices_with_tools[0].message.tool_calls[0].function.name = None + with mock.patch("banks.extensions.completion.completion") as mocked_completion: + mocked_completion.return_value.choices = mocked_choices_with_tools + with pytest.raises(LLMError): + ext._do_completion("test-model", lambda: '{"role":"user", "content":"hello"}') + + @pytest.mark.asyncio async def test__do_completion_async_no_prompt(ext): with pytest.raises(InvalidPromptError, match="Completion must contain at least one chat message"): @@ -61,3 +121,40 @@ async def test__do_completion_async_no_prompt_no_tools(ext, mocked_choices_no_to mocked_completion.assert_called_with( model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[] ) + + +def test__get_tool_callable(ext): + tools = [ + Tool.model_validate( + { + "type": "function", + "function": { + "name": "getenv", + "description": "Get an environment variable, return None if it doesn't exist.", + "parameters": { + "type": "object", + "properties": { + "key": { + "type": "string", + "description": "The name of the environment variable", + }, + "default": { + "type": "string", + "description": "The value to return if the variable was not found", + }, + }, + "required": ["key"], + }, + }, + "import_path": "os.getenv", + } + ) + ] + tool_call = mock.MagicMock() + + tool_call.function.name = "getenv" + assert ext._get_tool_callable(tools, tool_call) == getenv + + tool_call.function.name = "another_func" + with pytest.raises(ValueError, match="Function another_func not found in available tools"): + ext._get_tool_callable(tools, tool_call) From 1552d4f2a5d356f6ea10843cb93efa4310c28090 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 16 Oct 2024 11:06:52 +0200 Subject: [PATCH 10/11] minor --- tests/test_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_completion.py b/tests/test_completion.py index abf7524..f0f643d 100644 --- a/tests/test_completion.py +++ b/tests/test_completion.py @@ -94,7 +94,7 @@ def test__do_completion_with_tools(ext, mocked_choices_with_tools): assert calls[0].kwargs["tools"] == ["tool1", "tool2"] assert "tools" not in calls[1].kwargs for m in calls[1].kwargs["messages"]: - if type(m) == ChatMessage: + if type(m) is ChatMessage: assert m.role == "tool" assert m.name == "get_current_weather" From 2cdbace275eb682c2031dae2773c63d1165b1827 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Thu, 17 Oct 2024 09:59:31 +0200 Subject: [PATCH 11/11] add async unit tests --- tests/test_completion.py | 100 +++++++++++++++++++++++++++------------ 1 file changed, 70 insertions(+), 30 deletions(-) diff --git a/tests/test_completion.py b/tests/test_completion.py index f0f643d..4bb7816 100644 --- a/tests/test_completion.py +++ b/tests/test_completion.py @@ -54,6 +54,36 @@ def mocked_choices_with_tools(): ] +@pytest.fixture +def tools(): + return [ + Tool.model_validate( + { + "type": "function", + "function": { + "name": "getenv", + "description": "Get an environment variable, return None if it doesn't exist.", + "parameters": { + "type": "object", + "properties": { + "key": { + "type": "string", + "description": "The name of the environment variable", + }, + "default": { + "type": "string", + "description": "The value to return if the variable was not found", + }, + }, + "required": ["key"], + }, + }, + "import_path": "os.getenv", + } + ) + ] + + def test__body_to_messages(ext): assert ext._body_to_messages(' \n{"role":"user", "content":"hello"}') == ( [ChatMessage(role="user", content="hello")], @@ -74,6 +104,12 @@ def test__do_completion_no_prompt(ext): ext._do_completion("test-model", lambda: " ") +@pytest.mark.asyncio +async def test__do_completion_async_no_prompt(ext): + with pytest.raises(InvalidPromptError, match="Completion must contain at least one chat message"): + await ext._do_completion_async("test-model", lambda: " ") + + def test__do_completion_no_tools(ext, mocked_choices_no_tools): with mock.patch("banks.extensions.completion.completion") as mocked_completion: mocked_completion.return_value.choices = mocked_choices_no_tools @@ -83,6 +119,16 @@ def test__do_completion_no_tools(ext, mocked_choices_no_tools): ) +@pytest.mark.asyncio +async def test__do_completion_async_no_tools(ext, mocked_choices_no_tools): + with mock.patch("banks.extensions.completion.acompletion") as mocked_completion: + mocked_completion.return_value.choices = mocked_choices_no_tools + await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}') + mocked_completion.assert_called_with( + model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[] + ) + + def test__do_completion_with_tools(ext, mocked_choices_with_tools): ext._get_tool_callable = mock.MagicMock(return_value=lambda location, unit: f"I got {location} with {unit}") ext._body_to_messages = mock.MagicMock(return_value=(["message1", "message2"], ["tool1", "tool2"])) @@ -99,6 +145,23 @@ def test__do_completion_with_tools(ext, mocked_choices_with_tools): assert m.name == "get_current_weather" +@pytest.mark.asyncio +async def test__do_completion_async_with_tools(ext, mocked_choices_with_tools): + ext._get_tool_callable = mock.MagicMock(return_value=lambda location, unit: f"I got {location} with {unit}") + ext._body_to_messages = mock.MagicMock(return_value=(["message1", "message2"], ["tool1", "tool2"])) + with mock.patch("banks.extensions.completion.acompletion") as mocked_completion: + mocked_completion.return_value.choices = mocked_choices_with_tools + await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}') + calls = mocked_completion.call_args_list + assert len(calls) == 2 # complete query, complete with tool results + assert calls[0].kwargs["tools"] == ["tool1", "tool2"] + assert "tools" not in calls[1].kwargs + for m in calls[1].kwargs["messages"]: + if type(m) is ChatMessage: + assert m.role == "tool" + assert m.name == "get_current_weather" + + def test__do_completion_with_tools_malformed(ext, mocked_choices_with_tools): mocked_choices_with_tools[0].message.tool_calls[0].function.name = None with mock.patch("banks.extensions.completion.completion") as mocked_completion: @@ -108,9 +171,12 @@ def test__do_completion_with_tools_malformed(ext, mocked_choices_with_tools): @pytest.mark.asyncio -async def test__do_completion_async_no_prompt(ext): - with pytest.raises(InvalidPromptError, match="Completion must contain at least one chat message"): - await ext._do_completion_async("test-model", lambda: " ") +async def test__do_completion_async_with_tools_malformed(ext, mocked_choices_with_tools): + mocked_choices_with_tools[0].message.tool_calls[0].function.name = None + with mock.patch("banks.extensions.completion.acompletion") as mocked_completion: + mocked_completion.return_value.choices = mocked_choices_with_tools + with pytest.raises(LLMError): + await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}') @pytest.mark.asyncio @@ -123,33 +189,7 @@ async def test__do_completion_async_no_prompt_no_tools(ext, mocked_choices_no_to ) -def test__get_tool_callable(ext): - tools = [ - Tool.model_validate( - { - "type": "function", - "function": { - "name": "getenv", - "description": "Get an environment variable, return None if it doesn't exist.", - "parameters": { - "type": "object", - "properties": { - "key": { - "type": "string", - "description": "The name of the environment variable", - }, - "default": { - "type": "string", - "description": "The value to return if the variable was not found", - }, - }, - "required": ["key"], - }, - }, - "import_path": "os.getenv", - } - ) - ] +def test__get_tool_callable(ext, tools): tool_call = mock.MagicMock() tool_call.function.name = "getenv"