Skip to content

Commit d82f42e

Browse files
authored
[Feature] Add max_steps to LLMAgent.run and set handler result to MaxStepsReachedError (#91)
* add MaxStepsReachedError and unit tests * changelog * nit up date docstring
1 parent 872680e commit d82f42e

File tree

5 files changed

+56
-4
lines changed

5 files changed

+56
-4
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
88

99
## Unreleased
1010

11+
### Added
12+
13+
- Add `max_steps` to `LLMAgent.run` and set handler result to `MaxStepsReachedError` if reached (#91)
14+
1115
### Changed
1216

1317
- Improve `NextStepDecision` to allow for only one next_step or task_result (#88)

src/llm_agents_from_scratch/agent/llm_agent.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
TaskStepResult,
1919
ToolCallResult,
2020
)
21-
from llm_agents_from_scratch.errors import LLMAgentError, TaskHandlerError
21+
from llm_agents_from_scratch.errors import (
22+
LLMAgentError,
23+
MaxStepsReachedError,
24+
TaskHandlerError,
25+
)
2226
from llm_agents_from_scratch.logger import get_logger
2327

2428
from .templates import TaskHandlerTemplates, default_task_handler_templates
@@ -344,11 +348,13 @@ async def run_step(self, step: TaskStep) -> TaskStepResult:
344348
content=final_content,
345349
)
346350

347-
def run(self, task: Task) -> TaskHandler:
351+
def run(self, task: Task, max_steps: int | None = None) -> TaskHandler:
348352
"""Agent's processing loop for executing tasks.
349353
350354
Args:
351355
task (Task): the Task to perform.
356+
max_steps (int | None): Maximum number of steps to run for task.
357+
Defaults to None.
352358
353359
Returns:
354360
TaskHandler: the TaskHandler object responsible for task execution.
@@ -364,8 +370,12 @@ async def _process_loop() -> None:
364370
"""
365371
self.logger.info(f"🚀 Starting task: {task.instruction}")
366372
step_result = None
373+
ix = 0
367374
while not task_handler.done():
368375
try:
376+
if max_steps and ix == max_steps:
377+
raise MaxStepsReachedError("Max steps reached.")
378+
369379
next_step = await task_handler.get_next_step(step_result)
370380

371381
match next_step:
@@ -386,6 +396,8 @@ async def _process_loop() -> None:
386396

387397
except Exception as e:
388398
task_handler.set_exception(e)
399+
finally:
400+
ix += 1
389401

390402
task_handler.background_task = asyncio.create_task(_process_loop())
391403

src/llm_agents_from_scratch/errors/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .agent import LLMAgentError
1+
from .agent import LLMAgentError, MaxStepsReachedError
22
from .core import LLMAgentsFromScratchError, LLMAgentsFromScratchWarning
33
from .task_handler import TaskHandlerError
44

@@ -8,6 +8,7 @@
88
"LLMAgentsFromScratchWarning",
99
# agent
1010
"LLMAgentError",
11+
"MaxStepsReachedError",
1112
# task handler
1213
"TaskHandlerError",
1314
]

src/llm_agents_from_scratch/errors/agent.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,9 @@ class LLMAgentError(LLMAgentsFromScratchError):
77
"""Base error for all TaskHandler-related exceptions."""
88

99
pass
10+
11+
12+
class MaxStepsReachedError(LLMAgentError):
13+
"""Raised if the maximum number of steps reached in a run() method call."""
14+
15+
pass

tests/agent/test_agent.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
TaskResult,
1313
TaskStep,
1414
)
15-
from llm_agents_from_scratch.errors import LLMAgentError
15+
from llm_agents_from_scratch.errors import LLMAgentError, MaxStepsReachedError
1616

1717

1818
def test_init(mock_llm: BaseLLM) -> None:
@@ -113,3 +113,32 @@ async def test_run_exception(
113113
await asyncio.sleep(0.1) # Let it run
114114

115115
assert handler.exception() == err
116+
117+
118+
@pytest.mark.asyncio
119+
@patch.object(LLMAgent.TaskHandler, "get_next_step")
120+
async def test_run_max_steps_reached_error(
121+
mock_get_next_step: AsyncMock,
122+
mock_llm: BaseLLM,
123+
) -> None:
124+
"""Tests run method reaches max step."""
125+
126+
# arrange
127+
task = Task(instruction="mock instruction")
128+
mock_get_next_step.side_effect = [
129+
TaskStep(
130+
instruction="mock 1",
131+
task_id=task.id_,
132+
),
133+
TaskStep(
134+
instruction="mock 2",
135+
task_id=task.id_,
136+
),
137+
]
138+
agent = LLMAgent(llm=mock_llm)
139+
140+
# act
141+
handler = agent.run(task, max_steps=1)
142+
await asyncio.sleep(0.1) # Let it run
143+
144+
assert isinstance(handler.exception(), MaxStepsReachedError)

0 commit comments

Comments
 (0)