diff --git a/README.md b/README.md index 935f258..e1e2cb0 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ Docs are available [here](https://masci.github.io/banks/). - [Examples](#examples) - [:point\_right: Render a prompt template as chat messages](#point_right-render-a-prompt-template-as-chat-messages) - [:point\_right: Use a LLM to generate a text while rendering a prompt](#point_right-use-a-llm-to-generate-a-text-while-rendering-a-prompt) + - [:point\_right: Call functions directly from the prompt](#point_right-call-functions-directly-from-the-prompt) - [:point\_right: Use prompt caching from Anthropic](#point_right-use-prompt-caching-from-anthropic) - [Reuse templates from registries](#reuse-templates-from-registries) - [Async support](#async-support) @@ -136,6 +137,43 @@ Examples: > Banks uses a cache to avoid generating text again for the same template with the same context. By default > the cache is in-memory but it can be customized. +### :point_right: Call functions directly from the prompt + +Banks provides a filter `tool` that can be used to convert a callable passed to a prompt into an LLM function call. +Docstrings are used to describe the tool and its arguments, and during prompt rendering Banks will perform all the LLM +roundtrips needed in case the model wants to use a tool within a `{% completion %}` block. For example: + +```py +import platform + +from banks import Prompt + + +def get_laptop_info(): + """Get information about the user laptop. + + For example, it returns the operating system and version, along with hardware and network specs.""" + return str(platform.uname()) + + +p = Prompt(""" +{% set response %} +{% completion model="gpt-3.5-turbo-0125" %} + {% chat role="user" %}{{ query }}{% endchat %} + {{ get_laptop_info | tool }} +{% endcompletion %} +{% endset %} + +{# the variable 'response' contains the result #} + +{{ response }} +""") + +print(p.text({"query": "Can you guess the name of my laptop?", "get_laptop_info": get_laptop_info})) +# Output: +# Based on the information provided, the name of your laptop is likely "MacGiver." +``` + ### :point_right: Use prompt caching from Anthropic Several inference providers support prompt caching to save time and costs, and Anthropic in particular offers diff --git a/pyproject.toml b/pyproject.toml index 5db274c..f5ada15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ + "griffe", "jinja2", "litellm", "pydantic", @@ -183,8 +184,9 @@ exclude_lines = [ [[tool.mypy.overrides]] module = [ - "simplemma.*", + # "griffe.*", "litellm.*", + "simplemma.*", ] ignore_missing_imports = true diff --git a/src/banks/env.py b/src/banks/env.py index 68a1fe3..247dd34 100644 --- a/src/banks/env.py +++ b/src/banks/env.py @@ -4,7 +4,7 @@ from jinja2 import Environment, PackageLoader, select_autoescape from .config import config -from .filters import cache_control, lemmatize +from .filters import cache_control, lemmatize, tool def _add_extensions(_env): @@ -40,4 +40,5 @@ def _add_extensions(_env): # Setup custom filters and defaults env.filters["lemmatize"] = lemmatize env.filters["cache_control"] = cache_control +env.filters["tool"] = tool _add_extensions(env) diff --git a/src/banks/errors.py b/src/banks/errors.py index 875caf1..d6f8cf2 100644 --- a/src/banks/errors.py +++ b/src/banks/errors.py @@ -19,3 +19,7 @@ class PromptNotFoundError(Exception): class InvalidPromptError(Exception): """The prompt is not valid.""" + + +class LLMError(Exception): + """The LLM had problems.""" diff --git a/src/banks/extensions/chat.py b/src/banks/extensions/chat.py index 0549674..ee75fa2 100644 --- a/src/banks/extensions/chat.py +++ b/src/banks/extensions/chat.py @@ -108,4 +108,4 @@ def _store_chat_messages(self, role, caller): parser = _ContentBlockParser() parser.feed(caller()) cm = ChatMessage(role=role, content=parser.content) - return cm.model_dump_json() + return cm.model_dump_json(exclude_none=True) + "\n" diff --git a/src/banks/extensions/completion.py b/src/banks/extensions/completion.py index dad849e..d12b0a6 100644 --- a/src/banks/extensions/completion.py +++ b/src/banks/extensions/completion.py @@ -1,15 +1,18 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT +import importlib +import json from typing import cast from jinja2 import TemplateSyntaxError, nodes from jinja2.ext import Extension from litellm import acompletion, completion -from litellm.types.utils import ModelResponse +from litellm.types.utils import Choices, ModelResponse from pydantic import ValidationError -from banks.types import ChatMessage +from banks.errors import InvalidPromptError, LLMError +from banks.types import ChatMessage, Tool SUPPORTED_KWARGS = ("model",) @@ -72,41 +75,95 @@ def parse(self, parser): return nodes.CallBlock(self.call_method("_do_completion_async", args), [], [], body).set_lineno(lineno) return nodes.CallBlock(self.call_method("_do_completion", args), [], [], body).set_lineno(lineno) + def _get_tool_callable(self, tools, tool_call): + for tool in tools: + if tool.function.name == tool_call.function.name: + module_name, func_name = tool.import_path.rsplit(".", maxsplit=1) + module = importlib.import_module(module_name) + return getattr(module, func_name) + msg = f"Function {tool_call.function.name} not found in available tools" + raise ValueError(msg) + def _do_completion(self, model_name, caller): """ Helper callback. """ - messages = self._body_to_messages(caller()) - if not messages: - return "" + messages, tools = self._body_to_messages(caller()) + + response = cast(ModelResponse, completion(model=model_name, messages=messages, tools=tools)) + choices = cast(list[Choices], response.choices) + tool_calls = choices[0].message.tool_calls + if not tool_calls: + return choices[0].message.content + + messages.append(choices[0].message) # type:ignore + for tool_call in tool_calls: + if not tool_call.function.name: + msg = "Malformed response: function name is empty" + raise LLMError(msg) + + func = self._get_tool_callable(tools, tool_call) + + function_args = json.loads(tool_call.function.arguments) + function_response = func(**function_args) + messages.append( + ChatMessage( + tool_call_id=tool_call.id, role="tool", name=tool_call.function.name, content=function_response + ) + ) response = cast(ModelResponse, completion(model=model_name, messages=messages)) - return response.choices[0].message.content # type: ignore + choices = cast(list[Choices], response.choices) + return choices[0].message.content async def _do_completion_async(self, model_name, caller): """ Helper callback. """ - messages = self._body_to_messages(caller()) - if not messages: - return "" + messages, tools = self._body_to_messages(caller()) + + response = cast(ModelResponse, await acompletion(model=model_name, messages=messages, tools=tools)) + choices = cast(list[Choices], response.choices) + tool_calls = choices[0].message.tool_calls or [] + if not tool_calls: + return choices[0].message.content + + messages.append(choices[0].message) # type:ignore + for tool_call in tool_calls: + if not tool_call.function.name: + msg = "Function name is empty" + raise LLMError(msg) + + func = self._get_tool_callable(tools, tool_call) + + messages.append( + ChatMessage( + tool_call_id=tool_call.id, + role="tool", + name=tool_call.function.name, + content=func(**json.loads(tool_call.function.arguments)), + ) + ) response = cast(ModelResponse, await acompletion(model=model_name, messages=messages)) - return response.choices[0].message.content # type: ignore + choices = cast(list[Choices], response.choices) + return choices[0].message.content - def _body_to_messages(self, body: str) -> list[ChatMessage]: + def _body_to_messages(self, body: str) -> tuple[list[ChatMessage], list[Tool]]: body = body.strip() - if not body: - return [] - messages = [] + tools = [] for line in body.split("\n"): try: messages.append(ChatMessage.model_validate_json(line)) except ValidationError: # pylint: disable=R0801 - pass + try: + tools.append(Tool.model_validate_json(line)) + except ValidationError: + pass if not messages: - messages.append(ChatMessage(role="user", content=body)) + msg = "Completion must contain at least one chat message" + raise InvalidPromptError(msg) - return messages + return (messages, tools) diff --git a/src/banks/filters/__init__.py b/src/banks/filters/__init__.py index 23d05f3..27645ad 100644 --- a/src/banks/filters/__init__.py +++ b/src/banks/filters/__init__.py @@ -3,8 +3,10 @@ # SPDX-License-Identifier: MIT from .cache_control import cache_control from .lemmatize import lemmatize +from .tool import tool __all__ = ( "cache_control", "lemmatize", + "tool", ) diff --git a/src/banks/filters/tool.py b/src/banks/filters/tool.py new file mode 100644 index 0000000..decab05 --- /dev/null +++ b/src/banks/filters/tool.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi +# +# SPDX-License-Identifier: MIT +from typing import Callable + +from banks.types import Tool + + +def tool(value: Callable) -> str: + t = Tool.from_callable(value) + return t.model_dump_json() + "\n" diff --git a/src/banks/types.py b/src/banks/types.py index 43b34ef..62eca9b 100644 --- a/src/banks/types.py +++ b/src/banks/types.py @@ -2,8 +2,13 @@ # # SPDX-License-Identifier: MIT from enum import Enum +from inspect import Parameter, getdoc, signature +from typing import Callable from pydantic import BaseModel +from typing_extensions import Self + +from .utils import parse_params_from_docstring, python_type_to_jsonschema # pylint: disable=invalid-name @@ -49,3 +54,84 @@ class Config: class ChatMessage(BaseModel): role: str content: str | ChatMessageContent + tool_call_id: str | None = None + name: str | None = None + + +class FunctionParameter(BaseModel): + type: str + description: str + + +class FunctionParameters(BaseModel): + type: str = "object" + properties: dict[str, FunctionParameter] + required: list[str] + + +class Function(BaseModel): + name: str + description: str + parameters: FunctionParameters + + +class Tool(BaseModel): + """A model representing a Tool to be used in function calling. + + This model should dump the following: + ``` + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + "import_path": "module.get_current_weather", + } + ``` + """ + + type: str = "function" + function: Function + import_path: str + + @classmethod + def from_callable(cls, func: Callable) -> Self: + sig = signature(func) + + # Try getting params descriptions from docstrings + param_docs = parse_params_from_docstring(func.__doc__ or "") + + # If the docstring is missing, use the qualname space-separated hopefully + # the LLM will get some more context than just using the function name. + description = getdoc(func) or " ".join(func.__qualname__.split(".")) + properties = {} + required = [] + for name, param in sig.parameters.items(): + p = FunctionParameter( + description=param_docs.get(name, {}).get("description", ""), + type=python_type_to_jsonschema(param.annotation), + ) + properties[name] = p + if param.default == Parameter.empty: + required.append(name) + + return cls( + function=Function( + name=func.__name__, + description=description, + parameters=FunctionParameters(properties=properties, required=required), + ), + import_path=f"{func.__module__}.{func.__qualname__}", + ) diff --git a/src/banks/utils.py b/src/banks/utils.py index eb29581..e3e82d5 100644 --- a/src/banks/utils.py +++ b/src/banks/utils.py @@ -1,5 +1,7 @@ import secrets +from griffe import Docstring, parse_google, parse_numpy, parse_sphinx + def strtobool(val: str) -> bool: """ @@ -22,3 +24,42 @@ def strtobool(val: str) -> bool: def generate_canary_word(prefix: str = "BANKS[", suffix: str = "]", token_length: int = 8) -> str: return f"{prefix}{secrets.token_hex(token_length // 2)}{suffix}" + + +def python_type_to_jsonschema(python_type: type) -> str: + """Given a Python type, returns the jsonschema string describing it.""" + if python_type is str: + return "string" + if python_type is int: + return "integer" + if python_type is float: + return "number" + if python_type is bool: + return "boolean" + if python_type is list: + return "array" + if python_type is dict: + return "object" + + msg = f"Unsupported type: {python_type}" + raise ValueError(msg) + + +def parse_params_from_docstring(docstring: str) -> dict[str, dict[str, str]]: + param_docs = [] + ds = Docstring(docstring) + for parser in (parse_google, parse_numpy, parse_sphinx): + sections = parser(ds) + for section in sections: + if section.kind == "parameters": + param_docs = section.value + break + if param_docs: + break + + ret = {} + for d in param_docs: + d_dict = d.as_dict() + ret[d_dict.pop("name")] = d_dict + + return ret diff --git a/tests/test_completion.py b/tests/test_completion.py index 7781f0f..4bb7816 100644 --- a/tests/test_completion.py +++ b/tests/test_completion.py @@ -1,10 +1,13 @@ +from os import getenv from unittest import mock import pytest from jinja2.environment import Environment +from litellm.types.utils import ChatCompletionMessageToolCall, Function +from banks.errors import InvalidPromptError, LLMError from banks.extensions.completion import CompletionExtension -from banks.types import ChatMessage +from banks.types import ChatMessage, Tool @pytest.fixture @@ -12,25 +15,186 @@ def ext(): return CompletionExtension(environment=Environment()) -def test__body_to_messages(ext): - assert ext._body_to_messages(" ") == [] - assert ext._body_to_messages(' \n{"role":"user", "content":"hello"}') == [ChatMessage(role="user", content="hello")] - assert ext._body_to_messages(" \nhello\n ") == [ChatMessage(role="user", content="hello")] - assert ext._body_to_messages('{"role":"user", "content":"hello"}\n HELLO!') == [ - ChatMessage(role="user", content="hello") +@pytest.fixture +def mocked_choices_no_tools(): + return [mock.MagicMock(message=mock.MagicMock(tool_calls=None, content="some response"))] + + +@pytest.fixture +def mocked_choices_with_tools(): + return [ + mock.MagicMock( + message=mock.MagicMock( + tool_calls=[ + ChatCompletionMessageToolCall( + id="call_DN6IiLULWZw7sobV6puCji1O", + function=Function( + arguments='{"location": "San Francisco", "unit": "celsius"}', name="get_current_weather" + ), + type="function", + ), + ChatCompletionMessageToolCall( + id="call_ERm1JfYO9AFo2oEWRmWUd40c", + function=Function( + arguments='{"location": "Tokyo", "unit": "celsius"}', name="get_current_weather" + ), + type="function", + ), + ChatCompletionMessageToolCall( + id="call_2lvUVB1y4wKunSxTenR0zClP", + function=Function( + arguments='{"location": "Paris", "unit": "celsius"}', name="get_current_weather" + ), + type="function", + ), + ], + content="some response", + ) + ) ] -def test__do_completion(ext): - assert ext._do_completion("test-model", lambda: " ") == "" +@pytest.fixture +def tools(): + return [ + Tool.model_validate( + { + "type": "function", + "function": { + "name": "getenv", + "description": "Get an environment variable, return None if it doesn't exist.", + "parameters": { + "type": "object", + "properties": { + "key": { + "type": "string", + "description": "The name of the environment variable", + }, + "default": { + "type": "string", + "description": "The value to return if the variable was not found", + }, + }, + "required": ["key"], + }, + }, + "import_path": "os.getenv", + } + ) + ] + + +def test__body_to_messages(ext): + assert ext._body_to_messages(' \n{"role":"user", "content":"hello"}') == ( + [ChatMessage(role="user", content="hello")], + [], + ) + assert ext._body_to_messages('{"role":"user", "content":"hello"}\n HELLO!') == ( + [ChatMessage(role="user", content="hello")], + [], + ) + with pytest.raises(InvalidPromptError, match="Completion must contain at least one chat message"): + ext._body_to_messages(" ") + with pytest.raises(InvalidPromptError, match="Completion must contain at least one chat message"): + ext._body_to_messages(" \nhello\n ") + + +def test__do_completion_no_prompt(ext): + with pytest.raises(InvalidPromptError, match="Completion must contain at least one chat message"): + ext._do_completion("test-model", lambda: " ") + + +@pytest.mark.asyncio +async def test__do_completion_async_no_prompt(ext): + with pytest.raises(InvalidPromptError, match="Completion must contain at least one chat message"): + await ext._do_completion_async("test-model", lambda: " ") + + +def test__do_completion_no_tools(ext, mocked_choices_no_tools): with mock.patch("banks.extensions.completion.completion") as mocked_completion: - ext._do_completion("test-model", lambda: "hello") - mocked_completion.assert_called_with(model="test-model", messages=[ChatMessage(role="user", content="hello")]) + mocked_completion.return_value.choices = mocked_choices_no_tools + ext._do_completion("test-model", lambda: '{"role":"user", "content":"hello"}') + mocked_completion.assert_called_with( + model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[] + ) @pytest.mark.asyncio -async def test__do_completion_async(ext): - assert await ext._do_completion_async("test-model", lambda: " ") == "" +async def test__do_completion_async_no_tools(ext, mocked_choices_no_tools): with mock.patch("banks.extensions.completion.acompletion") as mocked_completion: - await ext._do_completion_async("test-model", lambda: "hello") - mocked_completion.assert_called_with(model="test-model", messages=[ChatMessage(role="user", content="hello")]) + mocked_completion.return_value.choices = mocked_choices_no_tools + await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}') + mocked_completion.assert_called_with( + model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[] + ) + + +def test__do_completion_with_tools(ext, mocked_choices_with_tools): + ext._get_tool_callable = mock.MagicMock(return_value=lambda location, unit: f"I got {location} with {unit}") + ext._body_to_messages = mock.MagicMock(return_value=(["message1", "message2"], ["tool1", "tool2"])) + with mock.patch("banks.extensions.completion.completion") as mocked_completion: + mocked_completion.return_value.choices = mocked_choices_with_tools + ext._do_completion("test-model", lambda: '{"role":"user", "content":"hello"}') + calls = mocked_completion.call_args_list + assert len(calls) == 2 # complete query, complete with tool results + assert calls[0].kwargs["tools"] == ["tool1", "tool2"] + assert "tools" not in calls[1].kwargs + for m in calls[1].kwargs["messages"]: + if type(m) is ChatMessage: + assert m.role == "tool" + assert m.name == "get_current_weather" + + +@pytest.mark.asyncio +async def test__do_completion_async_with_tools(ext, mocked_choices_with_tools): + ext._get_tool_callable = mock.MagicMock(return_value=lambda location, unit: f"I got {location} with {unit}") + ext._body_to_messages = mock.MagicMock(return_value=(["message1", "message2"], ["tool1", "tool2"])) + with mock.patch("banks.extensions.completion.acompletion") as mocked_completion: + mocked_completion.return_value.choices = mocked_choices_with_tools + await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}') + calls = mocked_completion.call_args_list + assert len(calls) == 2 # complete query, complete with tool results + assert calls[0].kwargs["tools"] == ["tool1", "tool2"] + assert "tools" not in calls[1].kwargs + for m in calls[1].kwargs["messages"]: + if type(m) is ChatMessage: + assert m.role == "tool" + assert m.name == "get_current_weather" + + +def test__do_completion_with_tools_malformed(ext, mocked_choices_with_tools): + mocked_choices_with_tools[0].message.tool_calls[0].function.name = None + with mock.patch("banks.extensions.completion.completion") as mocked_completion: + mocked_completion.return_value.choices = mocked_choices_with_tools + with pytest.raises(LLMError): + ext._do_completion("test-model", lambda: '{"role":"user", "content":"hello"}') + + +@pytest.mark.asyncio +async def test__do_completion_async_with_tools_malformed(ext, mocked_choices_with_tools): + mocked_choices_with_tools[0].message.tool_calls[0].function.name = None + with mock.patch("banks.extensions.completion.acompletion") as mocked_completion: + mocked_completion.return_value.choices = mocked_choices_with_tools + with pytest.raises(LLMError): + await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}') + + +@pytest.mark.asyncio +async def test__do_completion_async_no_prompt_no_tools(ext, mocked_choices_no_tools): + with mock.patch("banks.extensions.completion.acompletion") as mocked_completion: + mocked_completion.return_value.choices = mocked_choices_no_tools + await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}') + mocked_completion.assert_called_with( + model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[] + ) + + +def test__get_tool_callable(ext, tools): + tool_call = mock.MagicMock() + + tool_call.function.name = "getenv" + assert ext._get_tool_callable(tools, tool_call) == getenv + + tool_call.function.name = "another_func" + with pytest.raises(ValueError, match="Function another_func not found in available tools"): + ext._get_tool_callable(tools, tool_call) diff --git a/tests/test_prompt.py b/tests/test_prompt.py index b1db95c..ace025d 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -89,9 +89,13 @@ def test_chat_messages(): p.text() == """ {"role":"system","content":"You are a helpful assistant.\\n"} + {"role":"user","content":"Hello, how are you?\\n"} + {"role":"system","content":"I'm doing well, thank you! How can I assist you today?\\n"} + {"role":"user","content":"Can you explain quantum computing?\\n"} + Some random text. """.strip() ) diff --git a/tests/test_tool.py b/tests/test_tool.py new file mode 100644 index 0000000..c46aed4 --- /dev/null +++ b/tests/test_tool.py @@ -0,0 +1,56 @@ +from banks.filters import tool +from banks.types import Tool + + +def test_tool(): + def my_tool_function(myparam: str): + """Description of the tool. + + Args: + myparam (str): description of the parameter + + """ + pass + + tool_dump = tool(my_tool_function) + t = Tool.model_validate_json(tool_dump) + assert t.model_dump() == { + "type": "function", + "function": { + "name": "my_tool_function", + "description": "Description of the tool.\n\nArgs:\n myparam (str): description of the parameter", + "parameters": { + "type": "object", + "properties": {"myparam": {"type": "string", "description": "description of the parameter"}}, + "required": ["myparam"], + }, + }, + "import_path": "tests.test_tool.test_tool..my_tool_function", + } + + +def test_tool_with_defaults(): + def my_tool_function(myparam: str = ""): + """Description of the tool. + + Args: + myparam (str): description of the parameter + + """ + pass + + tool_dump = tool(my_tool_function) + t = Tool.model_validate_json(tool_dump) + assert t.model_dump() == { + "type": "function", + "function": { + "name": "my_tool_function", + "description": "Description of the tool.\n\nArgs:\n myparam (str): description of the parameter", + "parameters": { + "type": "object", + "properties": {"myparam": {"type": "string", "description": "description of the parameter"}}, + "required": [], + }, + }, + "import_path": "tests.test_tool.test_tool_with_defaults..my_tool_function", + } diff --git a/tests/test_utils.py b/tests/test_utils.py index d3f9b32..73450b0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,7 @@ import pytest import regex as re -from banks.utils import generate_canary_word, strtobool +from banks.utils import generate_canary_word, parse_params_from_docstring, python_type_to_jsonschema, strtobool def test_generate_canary_word_defaults(): @@ -45,3 +45,68 @@ def test_strtobool_error(): ) def test_strtobool(test_input, expected): assert strtobool(test_input) == expected + + +@pytest.mark.parametrize( + "test_input,expected", + [ + (type("I am a string"), "string"), # noqa + (type(42), "integer"), # noqa + (type(0.42), "number"), # noqa + (type(True), "boolean"), # noqa + (type([]), "array"), + (type({"foo": "bar"}), "object"), + ], +) +def test_python_type_to_jsonschema(test_input, expected): + assert python_type_to_jsonschema(test_input) == expected + with pytest.raises(ValueError, match="Unsupported type: "): + python_type_to_jsonschema(type(Exception)) + + +def test_parse_params_from_docstring_google(): + def my_test_function(test_param: str): + """A docstring. + + Args: + test_param (str): The test parameter. + """ + pass + + assert parse_params_from_docstring(my_test_function.__doc__) == { # type: ignore + "test_param": {"annotation": "str", "description": "The test parameter."} + } + + +def test_parse_params_from_docstring_numpy(): + def my_test_function(test_param: str): + """A docstring. + + Parameters + ---------- + test_param : str + The test parameter. + """ + pass + + assert parse_params_from_docstring(my_test_function.__doc__) == { # type: ignore + "test_param": {"annotation": "str", "description": "The test parameter."} + } + + +def test_parse_params_from_docstring_sphinx(): + def my_test_function(test_param: str): + """A docstring. + + :param test_param: The test parameter. + :type test_param: str + """ + pass + + assert parse_params_from_docstring(my_test_function.__doc__) == { # type: ignore + "test_param": {"annotation": "str", "description": "The test parameter."} + } + + +def test_parse_params_from_docstring_empty(): + assert parse_params_from_docstring("") == {}