feat: Add function calling from {% completion %} #20

Merged · 11 commits · Oct 17, 2024

README.md: 38 additions & 0 deletions
@@ -29,6 +29,7 @@ Docs are available [here](https://masci.github.io/banks/).
- [Examples](#examples)
- [:point\_right: Render a prompt template as chat messages](#point_right-render-a-prompt-template-as-chat-messages)
- [:point\_right: Use a LLM to generate a text while rendering a prompt](#point_right-use-a-llm-to-generate-a-text-while-rendering-a-prompt)
- [:point\_right: Call functions directly from the prompt](#point_right-call-functions-directly-from-the-prompt)
- [:point\_right: Use prompt caching from Anthropic](#point_right-use-prompt-caching-from-anthropic)
- [Reuse templates from registries](#reuse-templates-from-registries)
- [Async support](#async-support)
@@ -136,6 +137,43 @@ Examples:
> Banks uses a cache to avoid generating text again for the same template with the same context. By default
> the cache is in-memory but it can be customized.

### :point_right: Call functions directly from the prompt

Banks provides a `tool` filter that converts a callable passed to a prompt into an LLM function call. The callable's
docstring is used to describe the tool and its arguments, and during prompt rendering Banks performs all the LLM
roundtrips needed whenever the model decides to call a tool within a `{% completion %}` block. For example:

```py
import platform

from banks import Prompt


def get_laptop_info():
"""Get information about the user laptop.

For example, it returns the operating system and version, along with hardware and network specs."""
return str(platform.uname())


p = Prompt("""
{% set response %}
{% completion model="gpt-3.5-turbo-0125" %}
{% chat role="user" %}{{ query }}{% endchat %}
{{ get_laptop_info | tool }}
{% endcompletion %}
{% endset %}

{# the variable 'response' contains the result #}

{{ response }}
""")

print(p.text({"query": "Can you guess the name of my laptop?", "get_laptop_info": get_laptop_info}))
# Output:
# Based on the information provided, the name of your laptop is likely "MacGiver."
```
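
Under the hood, the `tool` filter serializes the callable into the function-calling schema defined by the `Tool` model
in `banks.types`. Below is a minimal sketch of inspecting that output directly; the exact `import_path` and
`description` depend on where and how the function is defined, and the JSON in the comment is reformatted for
readability:

```py
from banks.filters import tool


def get_laptop_info():
    """Get information about the user laptop.

    For example, it returns the operating system and version, along with hardware and network specs."""
    ...


# The filter returns a compact JSON string describing the function, roughly:
# {
#   "type": "function",
#   "function": {
#     "name": "get_laptop_info",
#     "description": "Get information about the user laptop. ...",
#     "parameters": {"type": "object", "properties": {}, "required": []}
#   },
#   "import_path": "__main__.get_laptop_info"
# }
print(tool(get_laptop_info))
```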

### :point_right: Use prompt caching from Anthropic

Several inference providers support prompt caching to save time and costs, and Anthropic in particular offers
pyproject.toml: 3 additions & 1 deletion
@@ -23,6 +23,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"griffe",
"jinja2",
"litellm",
"pydantic",
@@ -183,8 +184,9 @@ exclude_lines = [

[[tool.mypy.overrides]]
module = [
"simplemma.*",
# "griffe.*",
"litellm.*",
"simplemma.*",
]
ignore_missing_imports = true

src/banks/env.py: 2 additions & 1 deletion
@@ -4,7 +4,7 @@
from jinja2 import Environment, PackageLoader, select_autoescape

from .config import config
from .filters import cache_control, lemmatize
from .filters import cache_control, lemmatize, tool


def _add_extensions(_env):
@@ -40,4 +40,5 @@ def _add_extensions(_env):
# Setup custom filters and defaults
env.filters["lemmatize"] = lemmatize
env.filters["cache_control"] = cache_control
env.filters["tool"] = tool
_add_extensions(env)
src/banks/errors.py: 4 additions & 0 deletions
@@ -19,3 +19,7 @@ class PromptNotFoundError(Exception):

class InvalidPromptError(Exception):
"""The prompt is not valid."""


class LLMError(Exception):
"""The LLM had problems."""
src/banks/extensions/chat.py: 1 addition & 1 deletion
@@ -108,4 +108,4 @@ def _store_chat_messages(self, role, caller):
parser = _ContentBlockParser()
parser.feed(caller())
cm = ChatMessage(role=role, content=parser.content)
return cm.model_dump_json()
return cm.model_dump_json(exclude_none=True) + "\n"
src/banks/extensions/completion.py: 74 additions & 17 deletions
@@ -1,15 +1,18 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <mpippi@gmail.com>
#
# SPDX-License-Identifier: MIT
import importlib
import json
from typing import cast

from jinja2 import TemplateSyntaxError, nodes
from jinja2.ext import Extension
from litellm import acompletion, completion
from litellm.types.utils import ModelResponse
from litellm.types.utils import Choices, ModelResponse
from pydantic import ValidationError

from banks.types import ChatMessage
from banks.errors import InvalidPromptError, LLMError
from banks.types import ChatMessage, Tool

SUPPORTED_KWARGS = ("model",)

@@ -72,41 +75,95 @@ def parse(self, parser):
return nodes.CallBlock(self.call_method("_do_completion_async", args), [], [], body).set_lineno(lineno)
return nodes.CallBlock(self.call_method("_do_completion", args), [], [], body).set_lineno(lineno)

def _get_tool_callable(self, tools, tool_call):
for tool in tools:
if tool.function.name == tool_call.function.name:
module_name, func_name = tool.import_path.rsplit(".", maxsplit=1)
module = importlib.import_module(module_name)
return getattr(module, func_name)
msg = f"Function {tool_call.function.name} not found in available tools"
raise ValueError(msg)

def _do_completion(self, model_name, caller):
"""
Helper callback.
"""
messages = self._body_to_messages(caller())
if not messages:
return ""
messages, tools = self._body_to_messages(caller())

response = cast(ModelResponse, completion(model=model_name, messages=messages, tools=tools))
choices = cast(list[Choices], response.choices)
tool_calls = choices[0].message.tool_calls
if not tool_calls:
return choices[0].message.content

messages.append(choices[0].message) # type:ignore
for tool_call in tool_calls:
if not tool_call.function.name:
msg = "Malformed response: function name is empty"
raise LLMError(msg)

func = self._get_tool_callable(tools, tool_call)

function_args = json.loads(tool_call.function.arguments)
function_response = func(**function_args)
messages.append(
ChatMessage(
tool_call_id=tool_call.id, role="tool", name=tool_call.function.name, content=function_response
)
)

response = cast(ModelResponse, completion(model=model_name, messages=messages))
return response.choices[0].message.content # type: ignore
choices = cast(list[Choices], response.choices)
return choices[0].message.content

async def _do_completion_async(self, model_name, caller):
"""
Helper callback.
"""
messages = self._body_to_messages(caller())
if not messages:
return ""
messages, tools = self._body_to_messages(caller())

response = cast(ModelResponse, await acompletion(model=model_name, messages=messages, tools=tools))
choices = cast(list[Choices], response.choices)
tool_calls = choices[0].message.tool_calls or []
if not tool_calls:
return choices[0].message.content

messages.append(choices[0].message) # type:ignore
for tool_call in tool_calls:
if not tool_call.function.name:
msg = "Function name is empty"
raise LLMError(msg)

func = self._get_tool_callable(tools, tool_call)

messages.append(
ChatMessage(
tool_call_id=tool_call.id,
role="tool",
name=tool_call.function.name,
content=func(**json.loads(tool_call.function.arguments)),
)
)

response = cast(ModelResponse, await acompletion(model=model_name, messages=messages))
return response.choices[0].message.content # type: ignore
choices = cast(list[Choices], response.choices)
return choices[0].message.content

def _body_to_messages(self, body: str) -> list[ChatMessage]:
def _body_to_messages(self, body: str) -> tuple[list[ChatMessage], list[Tool]]:
body = body.strip()
if not body:
return []

messages = []
tools = []
for line in body.split("\n"):
try:
messages.append(ChatMessage.model_validate_json(line))
except ValidationError: # pylint: disable=R0801
pass
try:
tools.append(Tool.model_validate_json(line))
except ValidationError:
pass

if not messages:
messages.append(ChatMessage(role="user", content=body))
msg = "Completion must contain at least one chat message"
raise InvalidPromptError(msg)

return messages
return (messages, tools)
src/banks/filters/__init__.py: 2 additions & 0 deletions
@@ -3,8 +3,10 @@
# SPDX-License-Identifier: MIT
from .cache_control import cache_control
from .lemmatize import lemmatize
from .tool import tool

__all__ = (
"cache_control",
"lemmatize",
"tool",
)
src/banks/filters/tool.py: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <mpippi@gmail.com>
#
# SPDX-License-Identifier: MIT
from typing import Callable

from banks.types import Tool


def tool(value: Callable) -> str:
t = Tool.from_callable(value)
return t.model_dump_json() + "\n"
src/banks/types.py: 86 additions & 0 deletions
@@ -2,8 +2,13 @@
#
# SPDX-License-Identifier: MIT
from enum import Enum
from inspect import Parameter, getdoc, signature
from typing import Callable

from pydantic import BaseModel
from typing_extensions import Self

from .utils import parse_params_from_docstring, python_type_to_jsonschema

# pylint: disable=invalid-name

@@ -49,3 +54,84 @@ class Config:
class ChatMessage(BaseModel):
role: str
content: str | ChatMessageContent
tool_call_id: str | None = None
name: str | None = None


class FunctionParameter(BaseModel):
type: str
description: str


class FunctionParameters(BaseModel):
type: str = "object"
properties: dict[str, FunctionParameter]
required: list[str]


class Function(BaseModel):
name: str
description: str
parameters: FunctionParameters


class Tool(BaseModel):
"""A model representing a Tool to be used in function calling.

This model should dump the following:
```
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
"import_path": "module.get_current_weather",
}
```
"""

type: str = "function"
function: Function
import_path: str

@classmethod
def from_callable(cls, func: Callable) -> Self:
sig = signature(func)

# Try getting params descriptions from docstrings
param_docs = parse_params_from_docstring(func.__doc__ or "")

# If the docstring is missing, fall back to the space-separated qualname, so the
# LLM hopefully gets a bit more context than the bare function name.
description = getdoc(func) or " ".join(func.__qualname__.split("."))
properties = {}
required = []
for name, param in sig.parameters.items():
p = FunctionParameter(
description=param_docs.get(name, {}).get("description", ""),
type=python_type_to_jsonschema(param.annotation),
)
properties[name] = p
if param.default == Parameter.empty:
required.append(name)

return cls(
function=Function(
name=func.__name__,
description=description,
parameters=FunctionParameters(properties=properties, required=required),
),
import_path=f"{func.__module__}.{func.__qualname__}",
)
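
For reference, here is a small sketch of how `Tool.from_callable` could be exercised on its own. It assumes
`parse_params_from_docstring` understands an Args-style docstring and that `python_type_to_jsonschema` maps `str` to
`"string"`; neither helper is part of this diff, so treat the exact dump as illustrative:

```py
from banks.types import Tool


def get_current_weather(location: str, unit: str = "celsius") -> str:
    """Get the current weather in a given location.

    Args:
        location: The city and state, e.g. San Francisco, CA
        unit: Temperature unit, either "celsius" or "fahrenheit"
    """
    return f"Sunny in {location} ({unit})"


t = Tool.from_callable(get_current_weather)
# 'location' has no default value, so it lands in "required"; 'unit' has one, so it does not.
# The dump should resemble the get_current_weather example in the Tool docstring above,
# minus the "enum" entry, which from_callable does not generate.
print(t.model_dump_json(indent=2))
```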