Skip to content

Commit

Permalink
add async unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
masci committed Oct 17, 2024
1 parent 1552d4f commit 2cdbace
Showing 1 changed file with 70 additions and 30 deletions.
100 changes: 70 additions & 30 deletions tests/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,36 @@ def mocked_choices_with_tools():
]


@pytest.fixture
def tools():
return [
Tool.model_validate(
{
"type": "function",
"function": {
"name": "getenv",
"description": "Get an environment variable, return None if it doesn't exist.",
"parameters": {
"type": "object",
"properties": {
"key": {
"type": "string",
"description": "The name of the environment variable",
},
"default": {
"type": "string",
"description": "The value to return if the variable was not found",
},
},
"required": ["key"],
},
},
"import_path": "os.getenv",
}
)
]


def test__body_to_messages(ext):
assert ext._body_to_messages(' \n{"role":"user", "content":"hello"}') == (
[ChatMessage(role="user", content="hello")],
Expand All @@ -74,6 +104,12 @@ def test__do_completion_no_prompt(ext):
ext._do_completion("test-model", lambda: " ")


@pytest.mark.asyncio
async def test__do_completion_async_no_prompt(ext):
with pytest.raises(InvalidPromptError, match="Completion must contain at least one chat message"):
await ext._do_completion_async("test-model", lambda: " ")


def test__do_completion_no_tools(ext, mocked_choices_no_tools):
with mock.patch("banks.extensions.completion.completion") as mocked_completion:
mocked_completion.return_value.choices = mocked_choices_no_tools
Expand All @@ -83,6 +119,16 @@ def test__do_completion_no_tools(ext, mocked_choices_no_tools):
)


@pytest.mark.asyncio
async def test__do_completion_async_no_tools(ext, mocked_choices_no_tools):
with mock.patch("banks.extensions.completion.acompletion") as mocked_completion:
mocked_completion.return_value.choices = mocked_choices_no_tools
await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}')
mocked_completion.assert_called_with(
model="test-model", messages=[ChatMessage(role="user", content="hello")], tools=[]
)


def test__do_completion_with_tools(ext, mocked_choices_with_tools):
ext._get_tool_callable = mock.MagicMock(return_value=lambda location, unit: f"I got {location} with {unit}")
ext._body_to_messages = mock.MagicMock(return_value=(["message1", "message2"], ["tool1", "tool2"]))
Expand All @@ -99,6 +145,23 @@ def test__do_completion_with_tools(ext, mocked_choices_with_tools):
assert m.name == "get_current_weather"


@pytest.mark.asyncio
async def test__do_completion_async_with_tools(ext, mocked_choices_with_tools):
ext._get_tool_callable = mock.MagicMock(return_value=lambda location, unit: f"I got {location} with {unit}")
ext._body_to_messages = mock.MagicMock(return_value=(["message1", "message2"], ["tool1", "tool2"]))
with mock.patch("banks.extensions.completion.acompletion") as mocked_completion:
mocked_completion.return_value.choices = mocked_choices_with_tools
await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}')
calls = mocked_completion.call_args_list
assert len(calls) == 2 # complete query, complete with tool results
assert calls[0].kwargs["tools"] == ["tool1", "tool2"]
assert "tools" not in calls[1].kwargs
for m in calls[1].kwargs["messages"]:
if type(m) is ChatMessage:
assert m.role == "tool"
assert m.name == "get_current_weather"


def test__do_completion_with_tools_malformed(ext, mocked_choices_with_tools):
mocked_choices_with_tools[0].message.tool_calls[0].function.name = None
with mock.patch("banks.extensions.completion.completion") as mocked_completion:
Expand All @@ -108,9 +171,12 @@ def test__do_completion_with_tools_malformed(ext, mocked_choices_with_tools):


@pytest.mark.asyncio
async def test__do_completion_async_no_prompt(ext):
with pytest.raises(InvalidPromptError, match="Completion must contain at least one chat message"):
await ext._do_completion_async("test-model", lambda: " ")
async def test__do_completion_async_with_tools_malformed(ext, mocked_choices_with_tools):
mocked_choices_with_tools[0].message.tool_calls[0].function.name = None
with mock.patch("banks.extensions.completion.acompletion") as mocked_completion:
mocked_completion.return_value.choices = mocked_choices_with_tools
with pytest.raises(LLMError):
await ext._do_completion_async("test-model", lambda: '{"role":"user", "content":"hello"}')


@pytest.mark.asyncio
Expand All @@ -123,33 +189,7 @@ async def test__do_completion_async_no_prompt_no_tools(ext, mocked_choices_no_to
)


def test__get_tool_callable(ext):
tools = [
Tool.model_validate(
{
"type": "function",
"function": {
"name": "getenv",
"description": "Get an environment variable, return None if it doesn't exist.",
"parameters": {
"type": "object",
"properties": {
"key": {
"type": "string",
"description": "The name of the environment variable",
},
"default": {
"type": "string",
"description": "The value to return if the variable was not found",
},
},
"required": ["key"],
},
},
"import_path": "os.getenv",
}
)
]
def test__get_tool_callable(ext, tools):
tool_call = mock.MagicMock()

tool_call.function.name = "getenv"
Expand Down

0 comments on commit 2cdbace

Please sign in to comment.