diff --git a/CHANGELOG.md b/CHANGELOG.md index a210b07..5eff28c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,3 +11,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/). ... ### Added + +- Add scaffolding for `TaskHandler` (#6) +- Add `LLMAgent` and associated data structures (#6) diff --git a/src/llm_agents_from_scratch/core/__init__.py b/src/llm_agents_from_scratch/core/__init__.py index 48327e3..f61466d 100644 --- a/src/llm_agents_from_scratch/core/__init__.py +++ b/src/llm_agents_from_scratch/core/__init__.py @@ -1,3 +1,4 @@ from .agent import LLMAgent +from .task_handler import TaskHandler -__all__ = ["LLMAgent"] +__all__ = ["LLMAgent", "TaskHandler"] diff --git a/src/llm_agents_from_scratch/core/agent.py b/src/llm_agents_from_scratch/core/agent.py index 6d39418..7b2527a 100644 --- a/src/llm_agents_from_scratch/core/agent.py +++ b/src/llm_agents_from_scratch/core/agent.py @@ -1,9 +1,14 @@ """Agent Module.""" +import asyncio + from typing_extensions import Self from llm_agents_from_scratch.base.llm import BaseLLM from llm_agents_from_scratch.base.tool import BaseTool +from llm_agents_from_scratch.data_structures import Task, TaskResult + +from .task_handler import TaskHandler class LLMAgent: @@ -23,3 +28,28 @@ def add_tool(self, tool: BaseTool) -> Self: """ self.tools = self.tools + [tool] return self + + def run(self, task: Task) -> TaskHandler: + """Asynchronously run `task`.""" + + task_handler = TaskHandler(task, self.llm, self.tools) + + async def _run() -> None: + """Internal async run helper task.""" + while not task_handler.done(): + try: + step = await task_handler.get_next_step() + step_result = await task_handler.run_step(step) + if step_result.last_step: + task_result = TaskResult( + task=task, + content=step_result.content, + rollout="", + ) + task_handler.set_result(task_result) + except Exception as e: + task_handler.set_exception(e) + + task_handler.add_asyncio_task(asyncio.create_task(_run())) + + return task_handler diff --git a/src/llm_agents_from_scratch/core/task_handler.py b/src/llm_agents_from_scratch/core/task_handler.py new file mode 100644 index 0000000..ce33238 --- /dev/null +++ b/src/llm_agents_from_scratch/core/task_handler.py @@ -0,0 +1,55 @@ +"""Task Handler""" + +import asyncio +from typing import Any + +from llm_agents_from_scratch.base.llm import BaseLLM +from llm_agents_from_scratch.base.tool import BaseTool +from llm_agents_from_scratch.data_structures import ( + Task, + TaskStep, + TaskStepResult, +) + + +class TaskHandler(asyncio.Future): + def __init__( + self, + task: Task, + llm: BaseLLM, + tools: list[BaseTool], + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self.task = task + self.llm = llm + self.tools = tools + self._asyncio_tasks: list[asyncio.Task] = [] + + def add_asyncio_task(self, asyncio_task: asyncio.Task) -> None: + self._asyncio_tasks.append(asyncio_task) + + async def get_next_step(self) -> TaskStep | None: + """Based on task progress, determine next step. + + Returns: + TaskStep | None: The next step to run, if `None` then Task is done. + """ + # TODO: implement + pass # pragma: no cover + + async def run_step(self, step: TaskStep) -> TaskStepResult: + """Run next step of a given task. + + Example: perform tool call, generated LLM response, etc. + + Args: + last_step (Any): The result of the previous step. + + Returns: + Any: The result of the next sub step and sets result if Task the completion + of the sub-step represents the completion of the Task. + """ + # TODO: implement + pass # pragma: no cover diff --git a/src/llm_agents_from_scratch/data_structures/__init__.py b/src/llm_agents_from_scratch/data_structures/__init__.py index 7a8aed0..1f5a00b 100644 --- a/src/llm_agents_from_scratch/data_structures/__init__.py +++ b/src/llm_agents_from_scratch/data_structures/__init__.py @@ -1,6 +1,12 @@ +from .agent import Task, TaskResult, TaskStep, TaskStepResult from .llm import ChatMessage, ChatRole, CompleteResult __all__ = [ + # agent + "Task", + "TaskResult", + "TaskStep", + "TaskStepResult", # llm "ChatRole", "ChatMessage", diff --git a/src/llm_agents_from_scratch/data_structures/agent.py b/src/llm_agents_from_scratch/data_structures/agent.py new file mode 100644 index 0000000..0dc90f6 --- /dev/null +++ b/src/llm_agents_from_scratch/data_structures/agent.py @@ -0,0 +1,24 @@ +"""Data Structures for LLM Agent.""" + +from pydantic import BaseModel + + +class Task(BaseModel): + instruction: str + + +class TaskStep(BaseModel): + instruction: str + + +class TaskStepResult(BaseModel): + task_step: TaskStep + content: str | None + last_step: bool = False + + +class TaskResult(BaseModel): + task: Task + content: str + rollout: str + error: bool = False diff --git a/tests/test_agent.py b/tests/test_agent.py index 0236cf3..06c0eb6 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,11 +1,20 @@ -from unittest.mock import MagicMock +import asyncio +from unittest.mock import MagicMock, patch + +import pytest +from typing_extensions import override from llm_agents_from_scratch.base.llm import BaseLLM -from llm_agents_from_scratch.core import LLMAgent +from llm_agents_from_scratch.core import LLMAgent, TaskHandler +from llm_agents_from_scratch.data_structures.agent import ( + Task, + TaskStep, + TaskStepResult, +) def test_init(mock_llm: BaseLLM) -> None: - """tests init of LLMAgent""" + """Tests init of LLMAgent.""" agent = LLMAgent(llm=mock_llm) @@ -14,7 +23,7 @@ def test_init(mock_llm: BaseLLM) -> None: def test_add_tool(mock_llm: BaseLLM) -> None: - """tests add tool""" + """Tests add tool.""" # arrange tool = MagicMock() @@ -25,3 +34,75 @@ def test_add_tool(mock_llm: BaseLLM) -> None: # assert assert agent.tools == [tool] + + +@pytest.mark.asyncio +@patch("llm_agents_from_scratch.core.agent.TaskHandler") +async def test_run( + mock_task_handler_class: MagicMock, mock_llm: BaseLLM +) -> None: + """Tests run method.""" + + class MockTaskHandler(TaskHandler): + @override + async def get_next_step(self) -> TaskStep | None: + await asyncio.sleep(0.1) + return TaskStep(instruction="mock step") + + @override + async def run_step(self, step: TaskStep) -> TaskStepResult: + await asyncio.sleep(0.1) + return TaskStepResult( + task_step=step, content="mock result", last_step=True + ) + + # arrange + agent = LLMAgent(llm=mock_llm) + task = Task(instruction="mock instruction") + mock_handler = MockTaskHandler(task, agent.llm, agent.tools) + mock_task_handler_class.return_value = mock_handler + + # act + handler = agent.run(task) + await handler + + # cleanup + for t in handler._asyncio_tasks: + t.cancel() + + assert handler == mock_handler + mock_task_handler_class.assert_called_once_with( + task, agent.llm, agent.tools + ) + assert handler.result().content == "mock result" + + +@pytest.mark.asyncio +@patch("llm_agents_from_scratch.core.agent.TaskHandler") +async def test_run_exception( + mock_task_handler_class: MagicMock, mock_llm: BaseLLM +) -> None: + """Tests run method with exception.""" + + err = RuntimeError("mock error") + + class MockTaskHandler(TaskHandler): + @override + async def get_next_step(self) -> TaskStep | None: + raise err + + # arrange + agent = LLMAgent(llm=mock_llm) + task = Task(instruction="mock instruction") + mock_handler = MockTaskHandler(task, agent.llm, agent.tools) + mock_task_handler_class.return_value = mock_handler + + # act + handler = agent.run(task) + await asyncio.sleep(0.1) # Let it run + + assert handler == mock_handler + mock_task_handler_class.assert_called_once_with( + task, agent.llm, agent.tools + ) + assert handler.exception() == err