Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/).

### Added

- Add initial implementation of `OllamaLLM` (#11)
- Add implementation of `base.tool.BaseTool` and relevant data structures (#12)
- Add `tools` to `LLM.chat` and update relevant data structures (#8)
- Add scaffolding for `TaskHandler` (#6)
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -36,7 +36,7 @@ dependencies = [
]

[tool.ruff]
line-length = 79
line-length = 80

[tool.ruff.format]
quote-style = "double"
@@ -64,6 +64,9 @@ select = [
"PL", # pylint
"D213" # Multiline Docstrings start on newline
]
ignore = [
"A002" # Ignore variable `input` is shadowing a Python builtin in function
]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["D104"]
43 changes: 36 additions & 7 deletions src/llm_agents_from_scratch/base/llm.py
@@ -1,8 +1,13 @@
"""Base LLM."""

from abc import ABC, abstractmethod
from typing import Any, Sequence

from llm_agents_from_scratch.data_structures import ChatMessage, CompleteResult
from llm_agents_from_scratch.data_structures import (
ChatMessage,
CompleteResult,
ToolCallResult,
)

from .tool import BaseTool

@@ -11,30 +11,54 @@ class BaseLLM(ABC):
"""Base LLM Class."""

@abstractmethod
async def complete(self, prompt: str) -> CompleteResult:
async def complete(self, prompt: str, **kwargs: Any) -> CompleteResult:
"""Text Complete.

Args:
prompt (str): The prompt the LLM should use as input.
**kwargs (Any): Additional keyword arguments.

Returns:
            CompleteResult: The text completion of the prompt.

"""

@abstractmethod
async def chat(
self,
chat_messages: list[ChatMessage],
tools: list[BaseTool] | None = None,
input: str,
chat_messages: Sequence[ChatMessage] | None = None,
tools: Sequence[BaseTool] | None = None,
**kwargs: Any,
) -> ChatMessage:
"""Chat interface.

Args:
chat_messages (list[ChatMessage]): chat history.
tools (list[BaseTool]): tools that the LLM can call.
input (str): The user's current input.
            chat_messages (Sequence[ChatMessage] | None, optional): The chat
                history.
            tools (Sequence[BaseTool] | None, optional): The tools that the
                LLM can call.
**kwargs (Any): Additional keyword arguments.

Returns:
ChatMessage: The response of the LLM structured as a `ChatMessage`.
"""

@abstractmethod
async def continue_conversation_with_tool_results(
self,
tool_call_results: Sequence[ToolCallResult],
chat_messages: Sequence[ChatMessage],
**kwargs: Any,
) -> ChatMessage:
"""Continue a conversation submitting tool call results.

Args:
            tool_call_results (Sequence[ToolCallResult]): Tool call results.
            chat_messages (Sequence[ChatMessage]): The chat history.
**kwargs (Any): Additional keyword arguments.

Returns:
ChatMessage: The response of the LLM structured as a `ChatMessage`.
"""
4 changes: 4 additions & 0 deletions src/llm_agents_from_scratch/data_structures/__init__.py
@@ -1,5 +1,6 @@
from .agent import Task, TaskResult, TaskStep, TaskStepResult
from .llm import ChatMessage, ChatRole, CompleteResult
from .tool import ToolCall, ToolCallResult

__all__ = [
# agent
@@ -11,4 +12,7 @@
"ChatRole",
"ChatMessage",
"CompleteResult",
# tool
"ToolCall",
"ToolCallResult",
]
7 changes: 4 additions & 3 deletions src/llm_agents_from_scratch/data_structures/llm.py
@@ -1,10 +1,11 @@
"""Data Structures for LLMs."""

from enum import Enum
from typing import Any

from pydantic import BaseModel, ConfigDict

from .tool import ToolCall


class ChatRole(str, Enum):
"""Roles for chat messages."""
@@ -27,7 +28,7 @@ class ChatMessage(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
role: ChatRole
content: str
tool_calls: list[dict[str, Any]] | None = None
tool_calls: list[ToolCall] | None = None


class CompleteResult(BaseModel):
@@ -39,4 +40,4 @@ class CompleteResult(BaseModel):
"""

response: str
full_response: str
prompt: str
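A short sketch of the reshaped structures; the field values are invented, and it assumes `ChatRole` accepts the string "user" (which the Ollama integration below also relies on):

from llm_agents_from_scratch.data_structures import (
    ChatMessage,
    CompleteResult,
)

# CompleteResult now carries the originating prompt, not `full_response`.
result = CompleteResult(response="Paris.", prompt="Capital of France?")

# `tool_calls` is now a typed `list[ToolCall] | None` rather than raw dicts.
message = ChatMessage(role="user", content="What is the capital of France?")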
1 change: 1 addition & 0 deletions src/llm_agents_from_scratch/data_structures/tool.py
@@ -25,5 +25,6 @@ class ToolCallResult(BaseModel):
error: Whether or not the tool call yielded an error.
"""

tool_call: ToolCall
content: Any | None
error: bool = False
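Since the result now stores its originating `tool_call`, each `ToolCallResult` can be traced back to the call that produced it. A sketch follows; `ToolCall`'s constructor fields are not visible in this diff, so `tool_name` and `arguments` are hypothetical:

from llm_agents_from_scratch.data_structures import ToolCall, ToolCallResult

# NOTE: `tool_name` and `arguments` are assumed field names for ToolCall.
call = ToolCall(tool_name="search", arguments={"query": "weather"})
result = ToolCallResult(tool_call=call, content="Sunny", error=False)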
39 changes: 0 additions & 39 deletions src/llm_agents_from_scratch/llms/ollama.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/llm_agents_from_scratch/llms/ollama/__init__.py
@@ -0,0 +1,3 @@
from .llm import OllamaLLM

__all__ = ["OllamaLLM"]
119 changes: 119 additions & 0 deletions src/llm_agents_from_scratch/llms/ollama/llm.py
@@ -0,0 +1,119 @@
"""Ollama LLM integration."""

from typing import Any, Sequence

from ollama import AsyncClient

from llm_agents_from_scratch.base.llm import BaseLLM
from llm_agents_from_scratch.base.tool import BaseTool
from llm_agents_from_scratch.data_structures import (
ChatMessage,
CompleteResult,
ToolCallResult,
)

from .utils import (
chat_message_to_ollama_message,
ollama_message_to_chat_message,
tool_call_result_to_ollama_message,
)


class OllamaLLM(BaseLLM):
"""Ollama LLM class.

    Integration with the `ollama` library for running open source models
    locally.
"""

def __init__(self, model: str, *args: Any, **kwargs: Any) -> None:
"""Create an OllamaLLM instance.

Args:
model (str): The name of the LLM model.
*args (Any): Additional positional arguments.
**kwargs (Any): Additional keyword arguments.
"""
super().__init__(*args, **kwargs)
self.model = model
self._client = AsyncClient()

async def complete(self, prompt: str, **kwargs: Any) -> CompleteResult:
"""Complete a prompt with an Ollama LLM.

Args:
prompt (str): The prompt to complete.
**kwargs (Any): Additional keyword arguments.

Returns:
CompleteResult: The text completion result.
"""
response = await self._client.generate(
model=self.model,
prompt=prompt,
**kwargs,
)
return CompleteResult(
response=response.response,
prompt=prompt,
)

async def chat(
self,
input: str,
chat_messages: list[ChatMessage] | None = None,
tools: list[BaseTool] | None = None,
**kwargs: Any,
) -> ChatMessage:
"""Chat with an Ollama LLM.

Args:
input (str): The user's current input.
chat_messages (list[ChatMessage] | None, optional): The chat
history.
tools (list[BaseTool] | None, optional): The tools available to the
LLM.
**kwargs (Any): Additional keyword arguments.

Returns:
ChatMessage: The chat message from the LLM.
"""
        o_messages = [
            chat_message_to_ollama_message(cm) for cm in chat_messages or []
        ]
        # The current user input comes after the prior chat history.
        o_messages.append(
            chat_message_to_ollama_message(
                ChatMessage(role="user", content=input),
            ),
        )

# TODO: add tools to chat request.

        result = await self._client.chat(
            model=self.model,
            messages=o_messages,
            **kwargs,
        )

return ollama_message_to_chat_message(result.message)

async def continue_conversation_with_tool_results(
self,
tool_call_results: Sequence[ToolCallResult],
chat_messages: Sequence[ChatMessage],
**kwargs: Any,
) -> ChatMessage:
"""Implements continue_conversation_with_tool_results method.

Args:
tool_call_results (Sequence[ToolCallResult]): The tool call results.
chat_messages (Sequence[ChatMessage]): The chat history.
**kwargs (Any): Additional keyword arguments.

Returns:
ChatMessage: The chat message from the LLM.
"""
        # Tool results extend the prior history, so the history comes first.
        o_messages = [
            chat_message_to_ollama_message(cm) for cm in chat_messages
        ] + [
            tool_call_result_to_ollama_message(tc) for tc in tool_call_results
        ]

        result = await self._client.chat(
            model=self.model,
            messages=o_messages,
            **kwargs,
        )

return ollama_message_to_chat_message(result.message)
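End to end, the new integration can be exercised as in the sketch below; it assumes a local Ollama server is running and that the named model (here `llama3`, an arbitrary choice) has already been pulled:

import asyncio

from llm_agents_from_scratch.llms.ollama import OllamaLLM


async def main() -> None:
    llm = OllamaLLM(model="llama3")  # model name is an assumption

    completion = await llm.complete("Reply with one word: hello.")
    print(completion.response)

    reply = await llm.chat(input="What is 2 + 2?")
    print(reply.content)


if __name__ == "__main__":
    asyncio.run(main())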