diff --git a/CHANGELOG.md b/CHANGELOG.md
index bf74533..3549081 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
### Added
+- Add initial implementation of `OllamaLLM` (#11)
- Add implementation of `base.tool.BaseTool` and relevant data structures (#12)
- Add `tools` to `LLM.chat` and update relevant data structures (#8)
- Add scaffolding for `TaskHandler` (#6)
diff --git a/pyproject.toml b/pyproject.toml
index b37fb43..f0ad87a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,7 @@ dependencies = [
]
[tool.ruff]
-line-length = 79
+line-length = 80
[tool.ruff.format]
quote-style = "double"
@@ -64,6 +64,9 @@ select = [
"PL", # pylint
"D213" # Multiline Docstrings start on newline
]
+ignore = [
+    "A002" # Allow the `input` argument to shadow the Python builtin
+]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["D104"]
diff --git a/src/llm_agents_from_scratch/base/llm.py b/src/llm_agents_from_scratch/base/llm.py
index fb02f26..1661f39 100644
--- a/src/llm_agents_from_scratch/base/llm.py
+++ b/src/llm_agents_from_scratch/base/llm.py
@@ -1,8 +1,13 @@
"""Base LLM."""
from abc import ABC, abstractmethod
+from typing import Any, Sequence
-from llm_agents_from_scratch.data_structures import ChatMessage, CompleteResult
+from llm_agents_from_scratch.data_structures import (
+ ChatMessage,
+ CompleteResult,
+ ToolCallResult,
+)
from .tool import BaseTool
@@ -11,30 +16,54 @@ class BaseLLM(ABC):
"""Base LLM Class."""
@abstractmethod
- async def complete(self, prompt: str) -> CompleteResult:
+ async def complete(self, prompt: str, **kwargs: Any) -> CompleteResult:
"""Text Complete.
Args:
prompt (str): The prompt the LLM should use as input.
+ **kwargs (Any): Additional keyword arguments.
Returns:
-            str: The completion of the prompt.
+            CompleteResult: The text completion result.
-
"""
@abstractmethod
async def chat(
self,
- chat_messages: list[ChatMessage],
- tools: list[BaseTool] | None = None,
+ input: str,
+ chat_messages: Sequence[ChatMessage] | None = None,
+ tools: Sequence[BaseTool] | None = None,
+ **kwargs: Any,
) -> ChatMessage:
"""Chat interface.
Args:
- chat_messages (list[ChatMessage]): chat history.
- tools (list[BaseTool]): tools that the LLM can call.
+ input (str): The user's current input.
+            chat_messages (Sequence[ChatMessage] | None, optional): The chat
+                history. Defaults to None.
+            tools (Sequence[BaseTool] | None, optional): Tools that the LLM
+                can call. Defaults to None.
+ **kwargs (Any): Additional keyword arguments.
Returns:
ChatMessage: The response of the LLM structured as a `ChatMessage`.
+ """
+
+ @abstractmethod
+ async def continue_conversation_with_tool_results(
+ self,
+ tool_call_results: Sequence[ToolCallResult],
+ chat_messages: Sequence[ChatMessage],
+ **kwargs: Any,
+ ) -> ChatMessage:
+        """Continue a conversation by submitting tool call results.
+
+ Args:
+ tool_call_results (Sequence[ToolCallResult]):
+ Tool call results.
+            chat_messages (Sequence[ChatMessage]): The chat history.
+ **kwargs (Any): Additional keyword arguments.
+ Returns:
+ ChatMessage: The response of the LLM structured as a `ChatMessage`.
"""
diff --git a/src/llm_agents_from_scratch/data_structures/__init__.py b/src/llm_agents_from_scratch/data_structures/__init__.py
index 1f5a00b..a8f2f2b 100644
--- a/src/llm_agents_from_scratch/data_structures/__init__.py
+++ b/src/llm_agents_from_scratch/data_structures/__init__.py
@@ -1,5 +1,6 @@
from .agent import Task, TaskResult, TaskStep, TaskStepResult
from .llm import ChatMessage, ChatRole, CompleteResult
+from .tool import ToolCall, ToolCallResult
__all__ = [
# agent
@@ -11,4 +12,7 @@
"ChatRole",
"ChatMessage",
"CompleteResult",
+ # tool
+ "ToolCall",
+ "ToolCallResult",
]
diff --git a/src/llm_agents_from_scratch/data_structures/llm.py b/src/llm_agents_from_scratch/data_structures/llm.py
index 7a17588..539326a 100644
--- a/src/llm_agents_from_scratch/data_structures/llm.py
+++ b/src/llm_agents_from_scratch/data_structures/llm.py
@@ -1,10 +1,11 @@
"""Data Structures for LLMs."""
from enum import Enum
-from typing import Any
from pydantic import BaseModel, ConfigDict
+from .tool import ToolCall
+
class ChatRole(str, Enum):
"""Roles for chat messages."""
@@ -27,7 +28,7 @@ class ChatMessage(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
role: ChatRole
content: str
- tool_calls: list[dict[str, Any]] | None = None
+ tool_calls: list[ToolCall] | None = None
class CompleteResult(BaseModel):
@@ -39,4 +40,4 @@ class CompleteResult(BaseModel):
"""
response: str
- full_response: str
+ prompt: str
diff --git a/src/llm_agents_from_scratch/data_structures/tool.py b/src/llm_agents_from_scratch/data_structures/tool.py
index 96be23d..ba1c56e 100644
--- a/src/llm_agents_from_scratch/data_structures/tool.py
+++ b/src/llm_agents_from_scratch/data_structures/tool.py
@@ -25,5 +25,6 @@ class ToolCallResult(BaseModel):
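+        tool_call: The tool call that was made.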
error: Whether or not the tool call yielded an error.
"""
+ tool_call: ToolCall
content: Any | None
error: bool = False
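
Taken together, the updated data structures tie a result back to the call that produced it. A small illustrative snippet (names and values are made up):

```python
from llm_agents_from_scratch.data_structures import (
    ChatMessage,
    ToolCall,
    ToolCallResult,
)

# An assistant message that requests a tool call, using the typed ToolCall
# model rather than a raw dict.
call = ToolCall(tool_name="add", arguments={"a": 1, "b": 2})
message = ChatMessage(role="assistant", content="", tool_calls=[call])

# The result now carries the originating ToolCall alongside its content.
result = ToolCallResult(tool_call=call, content=3, error=False)
```
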
diff --git a/src/llm_agents_from_scratch/llms/ollama.py b/src/llm_agents_from_scratch/llms/ollama.py
deleted file mode 100644
index 9b06fae..0000000
--- a/src/llm_agents_from_scratch/llms/ollama.py
+++ /dev/null
@@ -1,39 +0,0 @@
-"""Ollama LLM integration."""
-
-from llm_agents_from_scratch.base.llm import BaseLLM
-from llm_agents_from_scratch.base.tool import BaseTool
-from llm_agents_from_scratch.data_structures import ChatMessage, CompleteResult
-
-
-class OllamaLLM(BaseLLM):
- """Ollama LLM class.
-
- Integration to `ollama` library for running open source models locally.
- """
-
- async def complete(self, prompt: str) -> CompleteResult:
- """Complete a prompt with an Ollama LLM.
-
- Args:
- prompt (str): The prompt to complete.
-
- Returns:
- CompleteResult: The text completion result.
- """
- raise NotImplementedError # pragma: no cover
-
- async def chat(
- self,
- chat_messages: list[ChatMessage],
- tools: list[BaseTool],
- ) -> ChatMessage:
- """Chat with an Ollama LLM.
-
- Args:
- chat_messages (list[ChatMessage]): The chat history.
- tools (list[BaseTool]): The tools available to the LLM.
-
- Returns:
- ChatMessage: The chat message from the LLM.
- """
- raise NotImplementedError # pragma: no cover
diff --git a/src/llm_agents_from_scratch/llms/ollama/__init__.py b/src/llm_agents_from_scratch/llms/ollama/__init__.py
new file mode 100644
index 0000000..3549b0e
--- /dev/null
+++ b/src/llm_agents_from_scratch/llms/ollama/__init__.py
@@ -0,0 +1,3 @@
+from .llm import OllamaLLM
+
+__all__ = ["OllamaLLM"]
diff --git a/src/llm_agents_from_scratch/llms/ollama/llm.py b/src/llm_agents_from_scratch/llms/ollama/llm.py
new file mode 100644
index 0000000..fb5b03b
--- /dev/null
+++ b/src/llm_agents_from_scratch/llms/ollama/llm.py
@@ -0,0 +1,119 @@
+"""Ollama LLM integration."""
+
+from typing import Any, Sequence
+
+from ollama import AsyncClient
+
+from llm_agents_from_scratch.base.llm import BaseLLM
+from llm_agents_from_scratch.base.tool import BaseTool
+from llm_agents_from_scratch.data_structures import (
+ ChatMessage,
+ CompleteResult,
+ ToolCallResult,
+)
+
+from .utils import (
+ chat_message_to_ollama_message,
+ ollama_message_to_chat_message,
+ tool_call_result_to_ollama_message,
+)
+
+
+class OllamaLLM(BaseLLM):
+ """Ollama LLM class.
+
+    Integration with the `ollama` library to run open-source models locally.
+ """
+
+ def __init__(self, model: str, *args: Any, **kwargs: Any) -> None:
+ """Create an OllamaLLM instance.
+
+ Args:
+ model (str): The name of the LLM model.
+ *args (Any): Additional positional arguments.
+ **kwargs (Any): Additional keyword arguments.
+ """
+ super().__init__(*args, **kwargs)
+ self.model = model
+ self._client = AsyncClient()
+
+ async def complete(self, prompt: str, **kwargs: Any) -> CompleteResult:
+ """Complete a prompt with an Ollama LLM.
+
+ Args:
+ prompt (str): The prompt to complete.
+ **kwargs (Any): Additional keyword arguments.
+
+ Returns:
+ CompleteResult: The text completion result.
+ """
+ response = await self._client.generate(
+ model=self.model,
+ prompt=prompt,
+ **kwargs,
+ )
+ return CompleteResult(
+ response=response.response,
+ prompt=prompt,
+ )
+
+ async def chat(
+ self,
+ input: str,
+ chat_messages: list[ChatMessage] | None = None,
+ tools: list[BaseTool] | None = None,
+ **kwargs: Any,
+ ) -> ChatMessage:
+ """Chat with an Ollama LLM.
+
+ Args:
+ input (str): The user's current input.
+ chat_messages (list[ChatMessage] | None, optional): The chat
+ history.
+ tools (list[BaseTool] | None, optional): The tools available to the
+ LLM.
+ **kwargs (Any): Additional keyword arguments.
+
+ Returns:
+ ChatMessage: The chat message from the LLM.
+ """
+        # Convert the chat history first, then append the user's new input
+        # as the most recent message.
+        o_messages = (
+            [chat_message_to_ollama_message(cm) for cm in chat_messages]
+            if chat_messages
+            else []
+        )
+        o_messages.append(
+            chat_message_to_ollama_message(
+                ChatMessage(role="user", content=input),
+            ),
+        )
+
+ # TODO: add tools to chat request.
+
+ result = await self._client.chat(model=self.model, messages=o_messages)
+
+ return ollama_message_to_chat_message(result.message)
+
+ async def continue_conversation_with_tool_results(
+ self,
+ tool_call_results: Sequence[ToolCallResult],
+ chat_messages: Sequence[ChatMessage],
+ **kwargs: Any,
+ ) -> ChatMessage:
+        """Continue a conversation with an Ollama LLM using tool call results.
+
+ Args:
+            tool_call_results (Sequence[ToolCallResult]): The tool call
+                results.
+ chat_messages (Sequence[ChatMessage]): The chat history.
+ **kwargs (Any): Additional keyword arguments.
+
+ Returns:
+ ChatMessage: The chat message from the LLM.
+ """
+        # Chat history first, followed by the tool call result messages.
+        o_messages = [
+            chat_message_to_ollama_message(cm) for cm in chat_messages
+        ] + [
+            tool_call_result_to_ollama_message(tc) for tc in tool_call_results
+        ]
+
+ result = await self._client.chat(model=self.model, messages=o_messages)
+
+ return ollama_message_to_chat_message(result.message)
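
A minimal usage sketch for the new `OllamaLLM` (assumes a local Ollama server is running and the model has been pulled; the model name is only an example):

```python
import asyncio

from llm_agents_from_scratch.llms.ollama import OllamaLLM


async def main() -> None:
    llm = OllamaLLM(model="llama3.2")

    # Plain text completion.
    completion = await llm.complete("Write a haiku about local LLMs.")
    print(completion.response)

    # Single chat turn; chat_messages and tools are optional.
    reply = await llm.chat(input="Hello there!")
    print(reply.role, reply.content)


asyncio.run(main())
```
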
diff --git a/src/llm_agents_from_scratch/llms/ollama/utils.py b/src/llm_agents_from_scratch/llms/ollama/utils.py
new file mode 100644
index 0000000..062992c
--- /dev/null
+++ b/src/llm_agents_from_scratch/llms/ollama/utils.py
@@ -0,0 +1,138 @@
+"""Ollama utils."""
+
+from ollama import Message as OllamaMessage
+from typing_extensions import assert_never
+
+from llm_agents_from_scratch.data_structures import (
+ ChatMessage,
+ ChatRole,
+ ToolCall,
+ ToolCallResult,
+)
+
+DEFAULT_TOOL_RESPONSE_TEMPLATE = """
+Below is the tool call result for a given tool call.
+
+tool name: {tool_name}
+arguments: {arguments}
+
+{tool_call_result}
+
+"""
+
+
+def ollama_message_to_chat_message(
+ ollama_message: OllamaMessage,
+) -> ChatMessage:
+ """Convert an ~ollama.Message to ChatMessage.
+
+ Args:
+        ollama_message (OllamaMessage): The ~ollama.Message to convert.
+
+ Returns:
+ ChatMessage: The converted message.
+ """
+ # role
+ match ollama_message.role:
+ case "assistant":
+ role = ChatRole.ASSISTANT
+ case "tool":
+ role = ChatRole.TOOL
+ case "user":
+ role = ChatRole.USER
+ case "system":
+ role = ChatRole.SYSTEM
+ case _:
+ msg = (
+ "Failed to convert ~ollama.Message due to invalid role: "
+ f"`{ollama_message.role}`."
+ )
+ raise RuntimeError(msg)
+
+ # convert tools
+ converted_tool_calls = (
+ [
+ ToolCall(
+ tool_name=o_tool_call.function.name,
+ arguments=o_tool_call.function.arguments,
+ )
+ for o_tool_call in ollama_message.tool_calls
+ ]
+ if ollama_message.tool_calls
+ else None
+ )
+
+ return ChatMessage(
+ role=role,
+ content=ollama_message.content,
+ tool_calls=converted_tool_calls,
+ )
+
+
+def chat_message_to_ollama_message(chat_message: ChatMessage) -> OllamaMessage:
+ """Convert a ChatMessage to an ~ollama.Message type.
+
+ Args:
+ chat_message (ChatMessage): The ChatMessage instance to convert.
+
+ Returns:
+ OllamaMessage: The converted message.
+ """
+ # role
+ match chat_message.role:
+ case ChatRole.ASSISTANT:
+ role = "assistant"
+ case ChatRole.TOOL:
+ role = "tool"
+ case ChatRole.USER:
+ role = "user"
+ case ChatRole.SYSTEM:
+ role = "system"
+ case _: # pragma: no cover
+ assert_never(chat_message.role)
+
+ # convert tool calls
+ converted_tool_calls = (
+ [
+ OllamaMessage.ToolCall(
+ function=OllamaMessage.ToolCall.Function(
+ name=tc.tool_name,
+ arguments=tc.arguments,
+ ),
+ )
+ for tc in chat_message.tool_calls
+ ]
+ if chat_message.tool_calls
+ else None
+ )
+
+ return OllamaMessage(
+ role=role,
+ content=chat_message.content,
+ tool_calls=converted_tool_calls,
+ )
+
+
+def tool_call_result_to_ollama_message(
+ tool_call_result: ToolCallResult,
+) -> OllamaMessage:
+ """Convert a tool call result to an ~ollama.Message.
+
+ Args:
+ tool_call_result (ToolCallResult): The tool call result.
+
+ Returns:
+ OllamaMessage: The converted message.
+ """
+ formatted_content = DEFAULT_TOOL_RESPONSE_TEMPLATE.format(
+ tool_name=tool_call_result.tool_call.tool_name,
+ arguments=tool_call_result.tool_call.arguments,
+ tool_call_result=tool_call_result.content,
+ )
+
+ return OllamaMessage(
+ role="tool",
+ content=formatted_content,
+ )
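
The converters are intended to round-trip cleanly between the two message types. A quick illustrative check (values are made up):

```python
from llm_agents_from_scratch.data_structures import ChatMessage, ToolCall
from llm_agents_from_scratch.llms.ollama.utils import (
    chat_message_to_ollama_message,
    ollama_message_to_chat_message,
)

original = ChatMessage(
    role="assistant",
    content="calling a tool",
    tool_calls=[ToolCall(tool_name="add", arguments={"a": 1, "b": 2})],
)

# ChatMessage -> ollama.Message -> ChatMessage preserves role, content,
# and tool calls.
restored = ollama_message_to_chat_message(
    chat_message_to_ollama_message(original),
)
assert restored.role == original.role
assert restored.tool_calls[0].tool_name == "add"
```
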
diff --git a/tests/conftest.py b/tests/conftest.py
index f6666c5..542ebf0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,8 +1,14 @@
+from typing import Any, Sequence
+
import pytest
from llm_agents_from_scratch.base.llm import BaseLLM
from llm_agents_from_scratch.base.tool import BaseTool
-from llm_agents_from_scratch.data_structures import ChatMessage, CompleteResult
+from llm_agents_from_scratch.data_structures import (
+ ChatMessage,
+ CompleteResult,
+ ToolCallResult,
+)
class MockBaseLLM(BaseLLM):
@@ -15,11 +21,23 @@ async def complete(self, prompt: str) -> CompleteResult:
async def chat(
self,
- chat_messages: list[ChatMessage],
- tools: list[BaseTool] | None = None,
+        input: str,
+        chat_messages: Sequence[ChatMessage] | None = None,
+        tools: Sequence[BaseTool] | None = None,
+ **kwargs: Any,
) -> ChatMessage:
return ChatMessage(role="assistant", content="mock chat response")
+ async def continue_conversation_with_tool_results(
+ self,
+ tool_call_results: Sequence[ToolCallResult],
+ chat_messages: Sequence[ChatMessage],
+ **kwargs: Any,
+    ) -> ChatMessage:
+ return ChatMessage(
+ role="assistant",
+ content="mock response to tool call result",
+ )
+
@pytest.fixture()
def mock_llm() -> BaseLLM:
diff --git a/tests/llms/test_base.py b/tests/llms/test_base.py
index 5ef6272..1f233eb 100644
--- a/tests/llms/test_base.py
+++ b/tests/llms/test_base.py
@@ -7,3 +7,4 @@ def test_base_abstract_attr() -> None:
assert "complete" in abstract_methods
assert "chat" in abstract_methods
+ assert "continue_conversation_with_tool_results" in abstract_methods
diff --git a/tests/llms/test_ollama.py b/tests/llms/test_ollama.py
index f3f4c1f..e8bc03d 100644
--- a/tests/llms/test_ollama.py
+++ b/tests/llms/test_ollama.py
@@ -1,7 +1,288 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from ollama import ChatResponse, GenerateResponse
+from ollama import Message as OllamaMessage
+
from llm_agents_from_scratch.base.llm import BaseLLM
+from llm_agents_from_scratch.data_structures import (
+ ChatMessage,
+ ChatRole,
+ ToolCall,
+ ToolCallResult,
+)
from llm_agents_from_scratch.llms.ollama import OllamaLLM
+from llm_agents_from_scratch.llms.ollama.utils import (
+ DEFAULT_TOOL_RESPONSE_TEMPLATE,
+ chat_message_to_ollama_message,
+ ollama_message_to_chat_message,
+ tool_call_result_to_ollama_message,
+)
def test_ollama_llm_class() -> None:
names_of_base_classes = [b.__name__ for b in OllamaLLM.__mro__]
assert BaseLLM.__name__ in names_of_base_classes
+
+
+@patch("llm_agents_from_scratch.llms.ollama.llm.AsyncClient")
+def test_init(mock_async_client_class: MagicMock) -> None:
+ """Tests init of OllamaLLM."""
+ mock_instance = MagicMock()
+ mock_async_client_class.return_value = mock_instance
+ llm = OllamaLLM(model="llama3.2")
+
+ assert llm.model == "llama3.2"
+ assert llm._client == mock_instance
+ mock_async_client_class.assert_called_once()
+
+
+@pytest.mark.asyncio
+@patch("llm_agents_from_scratch.llms.ollama.llm.AsyncClient")
+async def test_complete(mock_async_client_class: MagicMock) -> None:
+ """Test complete method."""
+ # arrange mocks
+ mock_instance = MagicMock()
+ mock_generate = AsyncMock()
+ mock_generate.return_value = GenerateResponse(
+ model="llama3.2",
+ response="fake response",
+ )
+ mock_instance.generate = mock_generate
+ mock_async_client_class.return_value = mock_instance
+
+ llm = OllamaLLM(model="llama3.2")
+
+ # act
+ result = await llm.complete("fake prompt")
+
+ # assert
+ assert result.response == "fake response"
+ assert result.prompt == "fake prompt"
+
+
+@pytest.mark.asyncio
+@patch("llm_agents_from_scratch.llms.ollama.llm.AsyncClient")
+async def test_chat(mock_async_client_class: MagicMock) -> None:
+ """Test chat method."""
+ # arrange mocks
+ mock_instance = MagicMock()
+ mock_chat = AsyncMock()
+ mock_chat.return_value = ChatResponse(
+ model="llama3.2",
+ message=OllamaMessage(
+ role="assistant",
+ content="some fake content",
+ tool_calls=[
+ OllamaMessage.ToolCall(
+ function=OllamaMessage.ToolCall.Function(
+ name="a fake tool",
+ arguments={"arg1": 1},
+ ),
+ ),
+ ],
+ ),
+ )
+ mock_instance.chat = mock_chat
+ mock_async_client_class.return_value = mock_instance
+
+ llm = OllamaLLM(model="llama3.2")
+
+ # act
+ result = await llm.chat("Some new input.")
+
+ assert result.role == "assistant"
+ assert result.content == "some fake content"
+ mock_chat.assert_awaited_once_with(
+ model="llama3.2",
+ messages=[OllamaMessage(role="user", content="Some new input.")],
+ )
+ mock_async_client_class.assert_called_once()
+
+
+@pytest.mark.asyncio
+@patch("llm_agents_from_scratch.llms.ollama.llm.AsyncClient")
+async def test_continue_conversation_with_tool_results(
+ mock_async_client_class: MagicMock,
+) -> None:
+ """Test continue_conversation_with_tool_results method."""
+
+ # arrange mocks
+ mock_instance = MagicMock()
+ mock_chat = AsyncMock()
+ mock_chat.return_value = ChatResponse(
+ model="llama3.2",
+ message=OllamaMessage(
+ role="assistant",
+ content="Thank you for the tool call results.",
+ ),
+ )
+ mock_instance.chat = mock_chat
+ mock_async_client_class.return_value = mock_instance
+
+ llm = OllamaLLM(model="llama3.2")
+
+ # act
+ tool_call_results = [
+ ToolCallResult(
+ tool_call=ToolCall(
+ tool_name="a fake tool",
+ arguments={"arg1": 1},
+ ),
+ content="Some content",
+ error=False,
+ ),
+ ]
+ result = await llm.continue_conversation_with_tool_results(
+ tool_call_results=tool_call_results,
+ chat_messages=[],
+ )
+
+ assert result.role == "assistant"
+ assert result.content == "Thank you for the tool call results."
+ mock_chat.assert_awaited_once_with(
+ model="llama3.2",
+ messages=[tool_call_result_to_ollama_message(tool_call_results[0])],
+ )
+ mock_async_client_class.assert_called_once()
+
+
+# test converter methods
+def test_chat_message_to_ollama_message() -> None:
+ """Tests conversion from ChatMessage to ~ollama.Message."""
+ messages = [
+ ChatMessage(
+ role="system",
+ content="0",
+ ),
+ ChatMessage(
+ role="user",
+ content="1",
+ ),
+ ChatMessage(
+ role="assistant",
+ content="2",
+ tool_calls=[
+ ToolCall(
+ tool_name="a tool",
+ arguments={
+ "arg1": "1",
+ "arg2": 2,
+ },
+ ),
+ ],
+ ),
+ ChatMessage(
+ role="tool",
+ content="3",
+ ),
+ ]
+
+ ollama_messages = [chat_message_to_ollama_message(m) for m in messages]
+
+ assert ollama_messages[0].content == "0"
+ assert ollama_messages[0].role == "system"
+ assert ollama_messages[0].tool_calls is None
+
+ assert ollama_messages[1].content == "1"
+ assert ollama_messages[1].role == "user"
+ assert ollama_messages[1].tool_calls is None
+
+ assert ollama_messages[2].content == "2"
+ assert ollama_messages[2].role == "assistant"
+ assert ollama_messages[2].tool_calls[0].function.name == "a tool"
+ assert ollama_messages[2].tool_calls[0].function.arguments == {
+ "arg1": "1",
+ "arg2": 2,
+ }
+
+ assert ollama_messages[3].content == "3"
+ assert ollama_messages[3].role == "tool"
+ assert ollama_messages[3].tool_calls is None
+
+
+def test_ollama_message_to_chat_message() -> None:
+ """Tests conversion from ~ollama.Message to ChatMessage."""
+ messages = [
+ OllamaMessage(
+ role="system",
+ content="0",
+ ),
+ OllamaMessage(
+ role="user",
+ content="1",
+ ),
+ OllamaMessage(
+ role="assistant",
+ content="2",
+ tool_calls=[
+ OllamaMessage.ToolCall(
+ function=OllamaMessage.ToolCall.Function(
+ name="fake tool",
+ arguments={
+ "fake_param": "1",
+ "another_fake_param": "2",
+ },
+ ),
+ ),
+ ],
+ ),
+ OllamaMessage(
+ role="tool",
+ content="3",
+ ),
+ ]
+
+ converted = [ollama_message_to_chat_message(m) for m in messages]
+
+ assert converted[0].role == ChatRole.SYSTEM
+ assert converted[0].content == "0"
+ assert converted[0].tool_calls is None
+
+ assert converted[1].role == ChatRole.USER
+ assert converted[1].content == "1"
+ assert converted[1].tool_calls is None
+
+ assert converted[2].role == ChatRole.ASSISTANT
+ assert converted[2].content == "2"
+ assert converted[2].tool_calls[0].tool_name == "fake tool"
+ assert converted[2].tool_calls[0].arguments == {
+ "fake_param": "1",
+ "another_fake_param": "2",
+ }
+
+ assert converted[3].role == ChatRole.TOOL
+ assert converted[3].content == "3"
+ assert converted[3].tool_calls is None
+
+
+def test_ollama_message_to_chat_message_raises_error() -> None:
+ """Test conversion to chat message raises error with invalid role."""
+ with pytest.raises(RuntimeError):
+ ollama_message_to_chat_message(
+ OllamaMessage(
+ role="invalid role",
+ content="0",
+ ),
+ )
+
+
+def test_tool_call_result_to_ollama_message() -> None:
+ """Test conversion of tool call result to an ~ollama.Message."""
+ tool_call_result = ToolCallResult(
+ tool_call=ToolCall(
+ tool_name="a fake tool",
+ arguments={"arg1": 1},
+ ),
+ content="Some content",
+ error=False,
+ )
+
+ converted = tool_call_result_to_ollama_message(tool_call_result)
+
+ assert converted.role == "tool"
+ assert converted.content == DEFAULT_TOOL_RESPONSE_TEMPLATE.format(
+ tool_name="a fake tool",
+ arguments={"arg1": 1},
+ tool_call_result="Some content",
+ )