diff --git a/CHANGELOG.md b/CHANGELOG.md
index bf74533..3549081 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
 
 ### Added
 
+- Add initial implementation of `OllamaLLM` (#11)
 - Add implementation of `base.tool.BaseTool` and relevant data structures (#12)
 - Add `tools` to `LLM.chat` and update relevant data structures (#8)
 - Add scaffolding for `TaskHandler` (#6)
diff --git a/pyproject.toml b/pyproject.toml
index b37fb43..f0ad87a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,7 @@ dependencies = [
 ]
 
 [tool.ruff]
-line-length = 79
+line-length = 80
 
 [tool.ruff.format]
 quote-style = "double"
@@ -64,6 +64,9 @@ select = [
     "PL", # pylint
     "D213" # Multiline Docstrings start on newline
 ]
+ignore = [
+    "A002" # Ignore variable `input` is shadowing a Python builtin in function
+]
 
 [tool.ruff.lint.per-file-ignores]
 "__init__.py" = ["D104"]
diff --git a/src/llm_agents_from_scratch/base/llm.py b/src/llm_agents_from_scratch/base/llm.py
index fb02f26..1661f39 100644
--- a/src/llm_agents_from_scratch/base/llm.py
+++ b/src/llm_agents_from_scratch/base/llm.py
@@ -1,8 +1,13 @@
 """Base LLM."""
 
 from abc import ABC, abstractmethod
+from typing import Any, Sequence
 
-from llm_agents_from_scratch.data_structures import ChatMessage, CompleteResult
+from llm_agents_from_scratch.data_structures import (
+    ChatMessage,
+    CompleteResult,
+    ToolCallResult,
+)
 
 from .tool import BaseTool
 
@@ -11,30 +16,54 @@ class BaseLLM(ABC):
     """Base LLM Class."""
 
     @abstractmethod
-    async def complete(self, prompt: str) -> CompleteResult:
+    async def complete(self, prompt: str, **kwargs: Any) -> CompleteResult:
         """Text Complete.
 
         Args:
             prompt (str): The prompt the LLM should use as input.
+            **kwargs (Any): Additional keyword arguments.
 
         Returns:
             str: The completion of the prompt.
-
         """
 
     @abstractmethod
     async def chat(
         self,
-        chat_messages: list[ChatMessage],
-        tools: list[BaseTool] | None = None,
+        input: str,
+        chat_messages: Sequence[ChatMessage] | None = None,
+        tools: Sequence[BaseTool] | None = None,
+        **kwargs: Any,
     ) -> ChatMessage:
         """Chat interface.
 
         Args:
-            chat_messages (list[ChatMessage]): chat history.
-            tools (list[BaseTool]): tools that the LLM can call.
+            input (str): The user's current input.
+            chat_messages (Sequence[ChatMessage]|None, optional): chat history.
+            tools (Sequence[BaseTool]|None, optional): tools that the LLM
+                can call.
+            **kwargs (Any): Additional keyword arguments.
 
         Returns:
             ChatMessage: The response of the LLM structured as a `ChatMessage`.
+        """
+
+    @abstractmethod
+    async def continue_conversation_with_tool_results(
+        self,
+        tool_call_results: Sequence[ToolCallResult],
+        chat_messages: Sequence[ChatMessage],
+        **kwargs: Any,
+    ) -> ChatMessage:
+        """Continue a conversation submitting tool call results.
+
+        Args:
+            tool_call_results (Sequence[ToolCallResult]):
+                Tool call results.
+            chat_messages (Sequence[ChatMessage]): The chat history.
+                Defaults to None.
+            **kwargs (Any): Additional keyword arguments.
+        Returns:
+            ChatMessage: The response of the LLM structured as a `ChatMessage`.
 
         """
diff --git a/src/llm_agents_from_scratch/data_structures/__init__.py b/src/llm_agents_from_scratch/data_structures/__init__.py
index 1f5a00b..a8f2f2b 100644
--- a/src/llm_agents_from_scratch/data_structures/__init__.py
+++ b/src/llm_agents_from_scratch/data_structures/__init__.py
@@ -1,5 +1,6 @@
 from .agent import Task, TaskResult, TaskStep, TaskStepResult
 from .llm import ChatMessage, ChatRole, CompleteResult
+from .tool import ToolCall, ToolCallResult
 
 __all__ = [
     # agent
@@ -11,4 +12,7 @@
     "ChatRole",
     "ChatMessage",
     "CompleteResult",
+    # tool
+    "ToolCall",
+    "ToolCallResult",
 ]
diff --git a/src/llm_agents_from_scratch/data_structures/llm.py b/src/llm_agents_from_scratch/data_structures/llm.py
index 7a17588..539326a 100644
--- a/src/llm_agents_from_scratch/data_structures/llm.py
+++ b/src/llm_agents_from_scratch/data_structures/llm.py
@@ -1,10 +1,11 @@
 """Data Structures for LLMs."""
 
 from enum import Enum
-from typing import Any
 
 from pydantic import BaseModel, ConfigDict
 
+from .tool import ToolCall
+
 
 class ChatRole(str, Enum):
     """Roles for chat messages."""
@@ -27,7 +28,7 @@ class ChatMessage(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
     role: ChatRole
     content: str
-    tool_calls: list[dict[str, Any]] | None = None
+    tool_calls: list[ToolCall] | None = None
 
 
 class CompleteResult(BaseModel):
@@ -39,4 +40,4 @@ class CompleteResult(BaseModel):
     """
 
     response: str
-    full_response: str
+    prompt: str
diff --git a/src/llm_agents_from_scratch/data_structures/tool.py b/src/llm_agents_from_scratch/data_structures/tool.py
index 96be23d..ba1c56e 100644
--- a/src/llm_agents_from_scratch/data_structures/tool.py
+++ b/src/llm_agents_from_scratch/data_structures/tool.py
@@ -25,5 +25,6 @@ class ToolCallResult(BaseModel):
         error: Whether or not the tool call yielded an error.
     """
 
+    tool_call: ToolCall
     content: Any | None
     error: bool = False
diff --git a/src/llm_agents_from_scratch/llms/ollama.py b/src/llm_agents_from_scratch/llms/ollama.py
deleted file mode 100644
index 9b06fae..0000000
--- a/src/llm_agents_from_scratch/llms/ollama.py
+++ /dev/null
@@ -1,39 +0,0 @@
-"""Ollama LLM integration."""
-
-from llm_agents_from_scratch.base.llm import BaseLLM
-from llm_agents_from_scratch.base.tool import BaseTool
-from llm_agents_from_scratch.data_structures import ChatMessage, CompleteResult
-
-
-class OllamaLLM(BaseLLM):
-    """Ollama LLM class.
-
-    Integration to `ollama` library for running open source models locally.
-    """
-
-    async def complete(self, prompt: str) -> CompleteResult:
-        """Complete a prompt with an Ollama LLM.
-
-        Args:
-            prompt (str): The prompt to complete.
-
-        Returns:
-            CompleteResult: The text completion result.
-        """
-        raise NotImplementedError  # pragma: no cover
-
-    async def chat(
-        self,
-        chat_messages: list[ChatMessage],
-        tools: list[BaseTool],
-    ) -> ChatMessage:
-        """Chat with an Ollama LLM.
-
-        Args:
-            chat_messages (list[ChatMessage]): The chat history.
-            tools (list[BaseTool]): The tools available to the LLM.
-
-        Returns:
-            ChatMessage: The chat message from the LLM.
-        """
-        raise NotImplementedError  # pragma: no cover
diff --git a/src/llm_agents_from_scratch/llms/ollama/__init__.py b/src/llm_agents_from_scratch/llms/ollama/__init__.py
new file mode 100644
index 0000000..3549b0e
--- /dev/null
+++ b/src/llm_agents_from_scratch/llms/ollama/__init__.py
@@ -0,0 +1,3 @@
+from .llm import OllamaLLM
+
+__all__ = ["OllamaLLM"]
diff --git a/src/llm_agents_from_scratch/llms/ollama/llm.py b/src/llm_agents_from_scratch/llms/ollama/llm.py
new file mode 100644
index 0000000..fb5b03b
--- /dev/null
+++ b/src/llm_agents_from_scratch/llms/ollama/llm.py
@@ -0,0 +1,119 @@
+"""Ollama LLM integration."""
+
+from typing import Any, Sequence
+
+from ollama import AsyncClient
+
+from llm_agents_from_scratch.base.llm import BaseLLM
+from llm_agents_from_scratch.base.tool import BaseTool
+from llm_agents_from_scratch.data_structures import (
+    ChatMessage,
+    CompleteResult,
+    ToolCallResult,
+)
+
+from .utils import (
+    chat_message_to_ollama_message,
+    ollama_message_to_chat_message,
+    tool_call_result_to_ollama_message,
+)
+
+
+class OllamaLLM(BaseLLM):
+    """Ollama LLM class.
+
+    Integration to `ollama` library for running open source models locally.
+    """
+
+    def __init__(self, model: str, *args: Any, **kwargs: Any) -> None:
+        """Create an OllamaLLM instance.
+
+        Args:
+            model (str): The name of the LLM model.
+            *args (Any): Additional positional arguments.
+            **kwargs (Any): Additional keyword arguments.
+        """
+        super().__init__(*args, **kwargs)
+        self.model = model
+        self._client = AsyncClient()
+
+    async def complete(self, prompt: str, **kwargs: Any) -> CompleteResult:
+        """Complete a prompt with an Ollama LLM.
+
+        Args:
+            prompt (str): The prompt to complete.
+            **kwargs (Any): Additional keyword arguments.
+
+        Returns:
+            CompleteResult: The text completion result.
+        """
+        response = await self._client.generate(
+            model=self.model,
+            prompt=prompt,
+            **kwargs,
+        )
+        return CompleteResult(
+            response=response.response,
+            prompt=prompt,
+        )
+
+    async def chat(
+        self,
+        input: str,
+        chat_messages: list[ChatMessage] | None = None,
+        tools: list[BaseTool] | None = None,
+        **kwargs: Any,
+    ) -> ChatMessage:
+        """Chat with an Ollama LLM.
+
+        Args:
+            input (str): The user's current input.
+            chat_messages (list[ChatMessage] | None, optional): The chat
+                history.
+            tools (list[BaseTool] | None, optional): The tools available to the
+                LLM.
+            **kwargs (Any): Additional keyword arguments.
+
+        Returns:
+            ChatMessage: The chat message from the LLM.
+        """
+        o_messages = [
+            chat_message_to_ollama_message(
+                ChatMessage(role="user", content=input),
+            ),
+        ]
+        o_messages.extend(
+            [chat_message_to_ollama_message(cm) for cm in chat_messages]
+            if chat_messages
+            else [],
+        )
+
+        # TODO: add tools to chat request.
+
+        result = await self._client.chat(model=self.model, messages=o_messages)
+
+        return ollama_message_to_chat_message(result.message)
+
+    async def continue_conversation_with_tool_results(
+        self,
+        tool_call_results: Sequence[ToolCallResult],
+        chat_messages: Sequence[ChatMessage],
+        **kwargs: Any,
+    ) -> ChatMessage:
+        """Implements continue_conversation_with_tool_results method.
+
+        Args:
+            tool_call_results (Sequence[ToolCallResult]): The tool call results.
+            chat_messages (Sequence[ChatMessage]): The chat history.
+            **kwargs (Any): Additional keyword arguments.
+
+        Returns:
+            ChatMessage: The chat message from the LLM.
+        """
+        o_messages = [
+            tool_call_result_to_ollama_message(tc) for tc in tool_call_results
+        ] + [chat_message_to_ollama_message(cm) for cm in chat_messages]
+
+        result = await self._client.chat(model=self.model, messages=o_messages)
+
+        return ollama_message_to_chat_message(result.message)
diff --git a/src/llm_agents_from_scratch/llms/ollama/utils.py b/src/llm_agents_from_scratch/llms/ollama/utils.py
new file mode 100644
index 0000000..062992c
--- /dev/null
+++ b/src/llm_agents_from_scratch/llms/ollama/utils.py
@@ -0,0 +1,138 @@
+"""Ollama utils."""
+
+from ollama import Message as OllamaMessage
+from typing_extensions import assert_never
+
+from llm_agents_from_scratch.data_structures import (
+    ChatMessage,
+    ChatRole,
+    ToolCall,
+    ToolCallResult,
+)
+
+DEFAULT_TOOL_RESPONSE_TEMPLATE = """
+The below is a tool call response for a given tool call.
+
+tool name: {tool_name}
+arguments: {arguments}
+
+
+
+{tool_call_result}
+
+"""
+
+
+def ollama_message_to_chat_message(
+    ollama_message: OllamaMessage,
+) -> ChatMessage:
+    """Convert an ~ollama.Message to ChatMessage.
+
+    Args:
+        ollama_message (Message): The ~ollama.Message to convert.
+
+    Returns:
+        ChatMessage: The converted message.
+    """
+    # role
+    match ollama_message.role:
+        case "assistant":
+            role = ChatRole.ASSISTANT
+        case "tool":
+            role = ChatRole.TOOL
+        case "user":
+            role = ChatRole.USER
+        case "system":
+            role = ChatRole.SYSTEM
+        case _:
+            msg = (
+                "Failed to convert ~ollama.Message due to invalid role: "
+                f"`{ollama_message.role}`."
+            )
+            raise RuntimeError(msg)
+
+    # convert tools
+    converted_tool_calls = (
+        [
+            ToolCall(
+                tool_name=o_tool_call.function.name,
+                arguments=o_tool_call.function.arguments,
+            )
+            for o_tool_call in ollama_message.tool_calls
+        ]
+        if ollama_message.tool_calls
+        else None
+    )
+
+    return ChatMessage(
+        role=role,
+        content=ollama_message.content,
+        tool_calls=converted_tool_calls,
+    )
+
+
+def chat_message_to_ollama_message(chat_message: ChatMessage) -> OllamaMessage:
+    """Convert a ChatMessage to an ~ollama.Message type.
+
+    Args:
+        chat_message (ChatMessage): The ChatMessage instance to convert.
+
+    Returns:
+        OllamaMessage: The converted message.
+    """
+    # role
+    match chat_message.role:
+        case ChatRole.ASSISTANT:
+            role = "assistant"
+        case ChatRole.TOOL:
+            role = "tool"
+        case ChatRole.USER:
+            role = "user"
+        case ChatRole.SYSTEM:
+            role = "system"
+        case _:  # pragma: no cover
+            assert_never(chat_message.role)
+
+    # convert tool calls
+    converted_tool_calls = (
+        [
+            OllamaMessage.ToolCall(
+                function=OllamaMessage.ToolCall.Function(
+                    name=tc.tool_name,
+                    arguments=tc.arguments,
+                ),
+            )
+            for tc in chat_message.tool_calls
+        ]
+        if chat_message.tool_calls
+        else None
+    )
+
+    return OllamaMessage(
+        role=role,
+        content=chat_message.content,
+        tool_calls=converted_tool_calls,
+    )
+
+
+def tool_call_result_to_ollama_message(
+    tool_call_result: ToolCallResult,
+) -> OllamaMessage:
+    """Convert a tool call result to an ~ollama.Message.
+
+    Args:
+        tool_call_result (ToolCallResult): The tool call result.
+
+    Returns:
+        OllamaMessage: The converted message.
+    """
+    formatted_content = DEFAULT_TOOL_RESPONSE_TEMPLATE.format(
+        tool_name=tool_call_result.tool_call.tool_name,
+        arguments=tool_call_result.tool_call.arguments,
+        tool_call_result=tool_call_result.content,
+    )
+
+    return OllamaMessage(
+        role="tool",
+        content=formatted_content,
+    )
diff --git a/tests/conftest.py b/tests/conftest.py
index f6666c5..542ebf0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,8 +1,14 @@
+from typing import Any, Sequence
+
 import pytest
 
 from llm_agents_from_scratch.base.llm import BaseLLM
 from llm_agents_from_scratch.base.tool import BaseTool
-from llm_agents_from_scratch.data_structures import ChatMessage, CompleteResult
+from llm_agents_from_scratch.data_structures import (
+    ChatMessage,
+    CompleteResult,
+    ToolCallResult,
+)
 
 
 class MockBaseLLM(BaseLLM):
@@ -15,11 +21,23 @@ async def complete(self, prompt: str) -> CompleteResult:
 
     async def chat(
         self,
-        chat_messages: list[ChatMessage],
-        tools: list[BaseTool] | None = None,
+        chat_messages: Sequence[ChatMessage],
+        tools: Sequence[BaseTool] | None = None,
+        **kwargs: Any,
     ) -> ChatMessage:
         return ChatMessage(role="assistant", content="mock chat response")
 
+    async def continue_conversation_with_tool_results(
+        self,
+        tool_call_results: Sequence[ToolCallResult],
+        chat_messages: Sequence[ChatMessage],
+        **kwargs: Any,
+    ):
+        return ChatMessage(
+            role="assistant",
+            content="mock response to tool call result",
+        )
+
 
 @pytest.fixture()
 def mock_llm() -> BaseLLM:
diff --git a/tests/llms/test_base.py b/tests/llms/test_base.py
index 5ef6272..1f233eb 100644
--- a/tests/llms/test_base.py
+++ b/tests/llms/test_base.py
@@ -7,3 +7,4 @@ def test_base_abstract_attr() -> None:
 
     assert "complete" in abstract_methods
     assert "chat" in abstract_methods
+    assert "continue_conversation_with_tool_results" in abstract_methods
diff --git a/tests/llms/test_ollama.py b/tests/llms/test_ollama.py
index f3f4c1f..e8bc03d 100644
--- a/tests/llms/test_ollama.py
+++ b/tests/llms/test_ollama.py
@@ -1,7 +1,288 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from ollama import ChatResponse, GenerateResponse
+from ollama import Message as OllamaMessage
+
 from llm_agents_from_scratch.base.llm import BaseLLM
+from llm_agents_from_scratch.data_structures import (
+    ChatMessage,
+    ChatRole,
+    ToolCall,
+    ToolCallResult,
+)
 from llm_agents_from_scratch.llms.ollama import OllamaLLM
+from llm_agents_from_scratch.llms.ollama.utils import (
+    DEFAULT_TOOL_RESPONSE_TEMPLATE,
+    chat_message_to_ollama_message,
+    ollama_message_to_chat_message,
+    tool_call_result_to_ollama_message,
+)
 
 
 def test_ollama_llm_class() -> None:
     names_of_base_classes = [b.__name__ for b in OllamaLLM.__mro__]
     assert BaseLLM.__name__ in names_of_base_classes
+
+
+@patch("llm_agents_from_scratch.llms.ollama.llm.AsyncClient")
+def test_init(mock_async_client_class: MagicMock) -> None:
+    """Tests init of OllamaLLM."""
+    mock_instance = MagicMock()
+    mock_async_client_class.return_value = mock_instance
+    llm = OllamaLLM(model="llama3.2")
+
+    assert llm.model == "llama3.2"
+    assert llm._client == mock_instance
+    mock_async_client_class.assert_called_once()
+
+
+@pytest.mark.asyncio
+@patch("llm_agents_from_scratch.llms.ollama.llm.AsyncClient")
+async def test_complete(mock_async_client_class: MagicMock) -> None:
+    """Test complete method."""
+    # arrange mocks
+    mock_instance = MagicMock()
+    mock_generate = AsyncMock()
+    mock_generate.return_value = GenerateResponse(
+        model="llama3.2",
+        response="fake response",
+    )
+    mock_instance.generate = mock_generate
+    mock_async_client_class.return_value = mock_instance
+
+    llm = OllamaLLM(model="llama3.2")
+
+    # act
+    result = await llm.complete("fake prompt")
+
+    # assert
+    assert result.response == "fake response"
+    assert result.prompt == "fake prompt"
+
+
+@pytest.mark.asyncio
+@patch("llm_agents_from_scratch.llms.ollama.llm.AsyncClient")
+async def test_chat(mock_async_client_class: MagicMock) -> None:
+    """Test chat method."""
+    # arrange mocks
+    mock_instance = MagicMock()
+    mock_chat = AsyncMock()
+    mock_chat.return_value = ChatResponse(
+        model="llama3.2",
+        message=OllamaMessage(
+            role="assistant",
+            content="some fake content",
+            tool_calls=[
+                OllamaMessage.ToolCall(
+                    function=OllamaMessage.ToolCall.Function(
+                        name="a fake tool",
+                        arguments={"arg1": 1},
+                    ),
+                ),
+            ],
+        ),
+    )
+    mock_instance.chat = mock_chat
+    mock_async_client_class.return_value = mock_instance
+
+    llm = OllamaLLM(model="llama3.2")
+
+    # act
+    result = await llm.chat("Some new input.")
+
+    assert result.role == "assistant"
+    assert result.content == "some fake content"
+    mock_chat.assert_awaited_once_with(
+        model="llama3.2",
+        messages=[OllamaMessage(role="user", content="Some new input.")],
+    )
+    mock_async_client_class.assert_called_once()
+
+
+@pytest.mark.asyncio
+@patch("llm_agents_from_scratch.llms.ollama.llm.AsyncClient")
+async def test_continue_conversation_with_tool_results(
+    mock_async_client_class: MagicMock,
+) -> None:
+    """Test continue_conversation_with_tool_results method."""
+
+    # arrange mocks
+    mock_instance = MagicMock()
+    mock_chat = AsyncMock()
+    mock_chat.return_value = ChatResponse(
+        model="llama3.2",
+        message=OllamaMessage(
+            role="assistant",
+            content="Thank you for the tool call results.",
+        ),
+    )
+    mock_instance.chat = mock_chat
+    mock_async_client_class.return_value = mock_instance
+
+    llm = OllamaLLM(model="llama3.2")
+
+    # act
+    tool_call_results = [
+        ToolCallResult(
+            tool_call=ToolCall(
+                tool_name="a fake tool",
+                arguments={"arg1": 1},
+            ),
+            content="Some content",
+            error=False,
+        ),
+    ]
+    result = await llm.continue_conversation_with_tool_results(
+        tool_call_results=tool_call_results,
+        chat_messages=[],
+    )
+
+    assert result.role == "assistant"
+    assert result.content == "Thank you for the tool call results."
+    mock_chat.assert_awaited_once_with(
+        model="llama3.2",
+        messages=[tool_call_result_to_ollama_message(tool_call_results[0])],
+    )
+    mock_async_client_class.assert_called_once()
+
+
+# test converter methods
+def test_chat_message_to_ollama_message() -> None:
+    """Tests conversion from ChatMessage to ~ollama.Message."""
+    messages = [
+        ChatMessage(
+            role="system",
+            content="0",
+        ),
+        ChatMessage(
+            role="user",
+            content="1",
+        ),
+        ChatMessage(
+            role="assistant",
+            content="2",
+            tool_calls=[
+                ToolCall(
+                    tool_name="a tool",
+                    arguments={
+                        "arg1": "1",
+                        "arg2": 2,
+                    },
+                ),
+            ],
+        ),
+        ChatMessage(
+            role="tool",
+            content="3",
+        ),
+    ]
+
+    ollama_messages = [chat_message_to_ollama_message(m) for m in messages]
+
+    assert ollama_messages[0].content == "0"
+    assert ollama_messages[0].role == "system"
+    assert ollama_messages[0].tool_calls is None
+
+    assert ollama_messages[1].content == "1"
+    assert ollama_messages[1].role == "user"
+    assert ollama_messages[1].tool_calls is None
+
+    assert ollama_messages[2].content == "2"
+    assert ollama_messages[2].role == "assistant"
+    assert ollama_messages[2].tool_calls[0].function.name == "a tool"
+    assert ollama_messages[2].tool_calls[0].function.arguments == {
+        "arg1": "1",
+        "arg2": 2,
+    }
+
+    assert ollama_messages[3].content == "3"
+    assert ollama_messages[3].role == "tool"
+    assert ollama_messages[3].tool_calls is None
+
+
+def test_ollama_message_to_chat_message() -> None:
+    """Tests conversion from ~ollama.Message to ChatMessage."""
+    messages = [
+        OllamaMessage(
+            role="system",
+            content="0",
+        ),
+        OllamaMessage(
+            role="user",
+            content="1",
+        ),
+        OllamaMessage(
+            role="assistant",
+            content="2",
+            tool_calls=[
+                OllamaMessage.ToolCall(
+                    function=OllamaMessage.ToolCall.Function(
+                        name="fake tool",
+                        arguments={
+                            "fake_param": "1",
+                            "another_fake_param": "2",
+                        },
+                    ),
+                ),
+            ],
+        ),
+        OllamaMessage(
+            role="tool",
+            content="3",
+        ),
+    ]
+
+    converted = [ollama_message_to_chat_message(m) for m in messages]
+
+    assert converted[0].role == ChatRole.SYSTEM
+    assert converted[0].content == "0"
+    assert converted[0].tool_calls is None
+
+    assert converted[1].role == ChatRole.USER
+    assert converted[1].content == "1"
+    assert converted[1].tool_calls is None
+
+    assert converted[2].role == ChatRole.ASSISTANT
+    assert converted[2].content == "2"
+    assert converted[2].tool_calls[0].tool_name == "fake tool"
+    assert converted[2].tool_calls[0].arguments == {
+        "fake_param": "1",
+        "another_fake_param": "2",
+    }
+
+    assert converted[3].role == ChatRole.TOOL
+    assert converted[3].content == "3"
+    assert converted[3].tool_calls is None
+
+
+def test_ollama_message_to_chat_message_raises_error() -> None:
+    """Test conversion to chat message raises error with invalid role."""
+    with pytest.raises(RuntimeError):
+        ollama_message_to_chat_message(
+            OllamaMessage(
+                role="invalid role",
+                content="0",
+            ),
+        )
+
+
+def test_tool_call_result_to_ollama_message() -> None:
+    """Test conversion of tool call result to an ~ollama.Message."""
+    tool_call_result = ToolCallResult(
+        tool_call=ToolCall(
+            tool_name="a fake tool",
+            arguments={"arg1": 1},
+        ),
+        content="Some content",
+        error=False,
+    )
+
+    converted = tool_call_result_to_ollama_message(tool_call_result)
+
+    assert converted.role == "tool"
+    assert converted.content == DEFAULT_TOOL_RESPONSE_TEMPLATE.format(
+        tool_name="a fake tool",
+        arguments={"arg1": 1},
+        tool_call_result="Some content",
+    )