Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
...

### Added

- Add scaffolding for `TaskHandler` (#6)
- Add `LLMAgent` and associated data structures (#6)
3 changes: 2 additions & 1 deletion src/llm_agents_from_scratch/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .agent import LLMAgent
from .task_handler import TaskHandler

__all__ = ["LLMAgent"]
__all__ = ["LLMAgent", "TaskHandler"]
30 changes: 30 additions & 0 deletions src/llm_agents_from_scratch/core/agent.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
"""Agent Module."""

import asyncio

from typing_extensions import Self

from llm_agents_from_scratch.base.llm import BaseLLM
from llm_agents_from_scratch.base.tool import BaseTool
from llm_agents_from_scratch.data_structures import Task, TaskResult

from .task_handler import TaskHandler


class LLMAgent:
Expand All @@ -23,3 +28,28 @@ def add_tool(self, tool: BaseTool) -> Self:
"""
self.tools = self.tools + [tool]
return self

def run(self, task: Task) -> TaskHandler:
"""Asynchronously run `task`."""

task_handler = TaskHandler(task, self.llm, self.tools)

async def _run() -> None:
"""Internal async run helper task."""
while not task_handler.done():
try:
step = await task_handler.get_next_step()
step_result = await task_handler.run_step(step)
if step_result.last_step:
task_result = TaskResult(
task=task,
content=step_result.content,
rollout="",
)
task_handler.set_result(task_result)
except Exception as e:
task_handler.set_exception(e)

task_handler.add_asyncio_task(asyncio.create_task(_run()))

return task_handler
55 changes: 55 additions & 0 deletions src/llm_agents_from_scratch/core/task_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Task Handler"""

import asyncio
from typing import Any

from llm_agents_from_scratch.base.llm import BaseLLM
from llm_agents_from_scratch.base.tool import BaseTool
from llm_agents_from_scratch.data_structures import (
Task,
TaskStep,
TaskStepResult,
)


class TaskHandler(asyncio.Future):
def __init__(
self,
task: Task,
llm: BaseLLM,
tools: list[BaseTool],
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.task = task
self.llm = llm
self.tools = tools
self._asyncio_tasks: list[asyncio.Task] = []

def add_asyncio_task(self, asyncio_task: asyncio.Task) -> None:
self._asyncio_tasks.append(asyncio_task)

async def get_next_step(self) -> TaskStep | None:
"""Based on task progress, determine next step.

Returns:
TaskStep | None: The next step to run, if `None` then Task is done.
"""
# TODO: implement
pass # pragma: no cover

async def run_step(self, step: TaskStep) -> TaskStepResult:
"""Run next step of a given task.

Example: perform tool call, generated LLM response, etc.

Args:
last_step (Any): The result of the previous step.

Returns:
Any: The result of the next sub step and sets result if Task the completion
of the sub-step represents the completion of the Task.
"""
# TODO: implement
pass # pragma: no cover
6 changes: 6 additions & 0 deletions src/llm_agents_from_scratch/data_structures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from .agent import Task, TaskResult, TaskStep, TaskStepResult
from .llm import ChatMessage, ChatRole, CompleteResult

__all__ = [
# agent
"Task",
"TaskResult",
"TaskStep",
"TaskStepResult",
# llm
"ChatRole",
"ChatMessage",
Expand Down
24 changes: 24 additions & 0 deletions src/llm_agents_from_scratch/data_structures/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Data Structures for LLM Agent."""

from pydantic import BaseModel


class Task(BaseModel):
instruction: str


class TaskStep(BaseModel):
instruction: str


class TaskStepResult(BaseModel):
task_step: TaskStep
content: str | None
last_step: bool = False


class TaskResult(BaseModel):
task: Task
content: str
rollout: str
error: bool = False
89 changes: 85 additions & 4 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from unittest.mock import MagicMock
import asyncio
from unittest.mock import MagicMock, patch

import pytest
from typing_extensions import override

from llm_agents_from_scratch.base.llm import BaseLLM
from llm_agents_from_scratch.core import LLMAgent
from llm_agents_from_scratch.core import LLMAgent, TaskHandler
from llm_agents_from_scratch.data_structures.agent import (
Task,
TaskStep,
TaskStepResult,
)


def test_init(mock_llm: BaseLLM) -> None:
"""tests init of LLMAgent"""
"""Tests init of LLMAgent."""

agent = LLMAgent(llm=mock_llm)

Expand All @@ -14,7 +23,7 @@ def test_init(mock_llm: BaseLLM) -> None:


def test_add_tool(mock_llm: BaseLLM) -> None:
"""tests add tool"""
"""Tests add tool."""

# arrange
tool = MagicMock()
Expand All @@ -25,3 +34,75 @@ def test_add_tool(mock_llm: BaseLLM) -> None:

# assert
assert agent.tools == [tool]


@pytest.mark.asyncio
@patch("llm_agents_from_scratch.core.agent.TaskHandler")
async def test_run(
mock_task_handler_class: MagicMock, mock_llm: BaseLLM
) -> None:
"""Tests run method."""

class MockTaskHandler(TaskHandler):
@override
async def get_next_step(self) -> TaskStep | None:
await asyncio.sleep(0.1)
return TaskStep(instruction="mock step")

@override
async def run_step(self, step: TaskStep) -> TaskStepResult:
await asyncio.sleep(0.1)
return TaskStepResult(
task_step=step, content="mock result", last_step=True
)

# arrange
agent = LLMAgent(llm=mock_llm)
task = Task(instruction="mock instruction")
mock_handler = MockTaskHandler(task, agent.llm, agent.tools)
mock_task_handler_class.return_value = mock_handler

# act
handler = agent.run(task)
await handler

# cleanup
for t in handler._asyncio_tasks:
t.cancel()

assert handler == mock_handler
mock_task_handler_class.assert_called_once_with(
task, agent.llm, agent.tools
)
assert handler.result().content == "mock result"


@pytest.mark.asyncio
@patch("llm_agents_from_scratch.core.agent.TaskHandler")
async def test_run_exception(
mock_task_handler_class: MagicMock, mock_llm: BaseLLM
) -> None:
"""Tests run method with exception."""

err = RuntimeError("mock error")

class MockTaskHandler(TaskHandler):
@override
async def get_next_step(self) -> TaskStep | None:
raise err

# arrange
agent = LLMAgent(llm=mock_llm)
task = Task(instruction="mock instruction")
mock_handler = MockTaskHandler(task, agent.llm, agent.tools)
mock_task_handler_class.return_value = mock_handler

# act
handler = agent.run(task)
await asyncio.sleep(0.1) # Let it run

assert handler == mock_handler
mock_task_handler_class.assert_called_once_with(
task, agent.llm, agent.tools
)
assert handler.exception() == err