Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 71 additions & 114 deletions tests/test_llm_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# tests/test_llm_agent.py

import re

from mesa.model import Model
from mesa.space import MultiGrid
Expand All @@ -11,170 +10,128 @@
from mesa_llm.reasoning.react import ReActReasoning


def test_apply_plan_adds_to_memory(monkeypatch):
monkeypatch.setenv("GEMINI_API_KEY", "dummy")
# Create a Mock Scheduler to bypass any potential installation/import issues
class MockScheduler:
def __init__(self, model):
self._ids = 0
self.agents = []

def add(self, agent):
self.agents.append(agent)

def next_id(self):
self._ids += 1
return self._ids


# Helper function to create a standardized DummyModel to avoid repetition
def create_dummy_model(seed):
class DummyModel(Model):
def __init__(self):
super().__init__(seed=42)
self.grid = MultiGrid(3, 3, torus=False)
super().__init__(seed=seed)
self.grid = MultiGrid(5, 5, torus=False)
self.schedule = MockScheduler(self) # Use the mock scheduler

def add_agent(self, pos):
system_prompt = "You are an agent in a simulation."
agents = LLMAgent.create_agents(
self,
n=1,
reasoning=ReActReasoning,
system_prompt=system_prompt,
system_prompt="System prompt",
vision=-1,
internal_state=["test_state"],
)
agent = agents[0]
# THE FIX: Call next_id() on the schedule, not the model
agent.unique_id = self.schedule.next_id()
self.grid.place_agent(agent, pos)
self.schedule.add(agent)
return agent

x, y = pos
return DummyModel()

self.grid.place_agent(agents[0], (x, y))
return agents[0]

model = DummyModel()
def test_apply_plan_adds_to_memory(monkeypatch):
monkeypatch.setenv("GEMINI_API_KEY", "dummy")
model = create_dummy_model(seed=42)
agent = model.add_agent((1, 1))
agent.memory = ShortTermMemory(
agent=agent,
n=5,
display=True,
)
agent.memory = ShortTermMemory(agent=agent, n=5, display=True)

# fake response returned by the tool manager
fake_response = [{"tool": "foo", "argument": "bar"}]

# monkeypatch the tool manager so no real tool calls are made
monkeypatch.setattr(
agent.tool_manager, "call_tools", lambda agent, llm_response: fake_response
)

plan = Plan(step=0, llm_plan="do something")

resp = agent.apply_plan(plan)

assert resp == fake_response

assert {
"tool": "foo",
"argument": "bar",
} in agent.memory.step_content.values() or agent.memory.step_content == {
"tool": "foo",
"argument": "bar",
}
assert {"tool": "foo", "argument": "bar"} in agent.memory.step_content.values()


def test_generate_obs_with_one_neighbor(monkeypatch):
monkeypatch.setenv("GEMINI_API_KEY", "dummy")

class DummyModel(Model):
def __init__(self):
super().__init__(seed=45)
self.grid = MultiGrid(3, 3, torus=False)

def add_agent(self, pos, agent_class=LLMAgent):
system_prompt = "You are an agent in a simulation."
agents = agent_class.create_agents(
self,
n=1,
reasoning=ReActReasoning,
system_prompt=system_prompt,
vision=-1,
internal_state=["test_state"],
)
x, y = pos
self.grid.place_agent(agents[0], (x, y))
return agents[0]

model = DummyModel()

model = create_dummy_model(seed=45)
agent = model.add_agent((1, 1))
agent.memory = ShortTermMemory(
agent=agent,
n=5,
display=True,
)
agent.unique_id = 1

agent.memory = ShortTermMemory(agent=agent, n=5)
neighbor = model.add_agent((1, 2))
neighbor.memory = ShortTermMemory(
agent=agent,
n=5,
display=True,
)
neighbor.unique_id = 2
monkeypatch.setattr(agent.memory, "add_to_memory", lambda *args, **kwargs: None)

obs = agent.generate_obs()

assert obs.self_state["agent_unique_id"] == 1

# we should have exactly one neighboring agent in local_state
assert obs.self_state["agent_unique_id"] == agent.unique_id
assert len(obs.local_state) == 1

# extract the neighbor
key = next(iter(obs.local_state.keys()))
assert key == "LLMAgent 2"

entry = obs.local_state[key]
assert entry["position"] == (1, 2)
assert entry["internal_state"] == ["test_state"]
assert key == f"LLMAgent {neighbor.unique_id}"
assert obs.local_state[key]["position"] == (1, 2)


def test_send_message_updates_both_agents_memory(monkeypatch):
monkeypatch.setenv("GEMINI_API_KEY", "dummy")

class DummyModel(Model):
def __init__(self):
super().__init__(seed=45)
self.grid = MultiGrid(3, 3, torus=False)

def add_agent(self, pos, agent_class=LLMAgent):
system_prompt = "You are an agent in a simulation."
agents = agent_class.create_agents(
self,
n=1,
reasoning=lambda agent: None,
system_prompt=system_prompt,
vision=-1,
internal_state=["test_state"],
)
x, y = pos
self.grid.place_agent(agents[0], (x, y))
return agents[0]

model = DummyModel()
model = create_dummy_model(seed=45)
sender = model.add_agent((0, 0))
sender.memory = ShortTermMemory(
agent=sender,
n=5,
display=True,
)
sender.unique_id = 1

sender.memory = ShortTermMemory(agent=sender, n=5)
recipient = model.add_agent((1, 1))
recipient.memory = ShortTermMemory(
agent=recipient,
n=5,
display=True,
)
recipient.unique_id = 2
recipient.memory = ShortTermMemory(agent=recipient, n=5)

# Track how many times add_to_memory is called
call_counter = {"count": 0}

def fake_add_to_memory(*args, **kwargs):
call_counter["count"] += 1

# monkeypatch both agents' memory modules
monkeypatch.setattr(sender.memory, "add_to_memory", fake_add_to_memory)
monkeypatch.setattr(recipient.memory, "add_to_memory", fake_add_to_memory)

result = sender.send_message("hello", recipients=[recipient])
pattern = r"LLMAgent 1 → \[<mesa_llm\.llm_agent\.LLMAgent object at 0x[0-9A-Fa-f]+>\] : hello"
assert re.match(pattern, result)

# sender + recipient memory => should be called twice
sender.send_message("hello", recipients=[recipient])
assert call_counter["count"] == 2


def test_generate_obs_zero_vision(monkeypatch):
monkeypatch.setenv("GEMINI_API_KEY", "dummy")
model = create_dummy_model(seed=45)
agent = model.add_agent((1, 1))
agent.memory = ShortTermMemory(agent=agent, n=5)
monkeypatch.setattr(agent.memory, "add_to_memory", lambda *args, **kwargs: None)
_ = model.add_agent((1, 2))

agent.vision = 0
obs = agent.generate_obs()
assert obs.local_state == {}


def test_generate_obs_limited_vision(monkeypatch):
monkeypatch.setenv("GEMINI_API_KEY", "dummy")
model = create_dummy_model(seed=45)
agent = model.add_agent((2, 2))
agent.memory = ShortTermMemory(agent=agent, n=5)
monkeypatch.setattr(agent.memory, "add_to_memory", lambda *args, **kwargs: None)

neighbor = model.add_agent((2, 3))
far_agent = model.add_agent((4, 4))

agent.vision = 1
obs = agent.generate_obs()

assert len(obs.local_state) == 1
assert f"LLMAgent {neighbor.unique_id}" in obs.local_state
assert f"LLMAgent {far_agent.unique_id}" not in obs.local_state
71 changes: 71 additions & 0 deletions tests/test_memory/test_st_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# tests/test_memory/test_st_memory.py

from collections import deque

from mesa_llm.memory.memory import MemoryEntry
from mesa_llm.memory.st_memory import ShortTermMemory


class TestShortTermMemory:
"""Tests for the ShortTermMemory class."""

def test_initialization(self, mock_agent):
"""Test that the memory initializes correctly."""
memory = ShortTermMemory(agent=mock_agent, n=7, display=False)
assert memory.n == 7
assert isinstance(memory.short_term_memory, deque)
assert (
memory.short_term_memory.maxlen is None
) # Deque for STMemory is not bounded

def test_process_step_logic(self, mock_agent):
"""Test the two-part process_step logic for pre- and post-step."""
mock_agent.model.steps = 1
memory = ShortTermMemory(agent=mock_agent, n=5, display=False)

# 1. Simulate pre_step: content is added with step=None
memory.step_content = {"observation": "seeing a cat"}
memory.process_step(pre_step=True)

assert len(memory.short_term_memory) == 1
first_entry = memory.short_term_memory[0]
assert first_entry.step is None
assert first_entry.content == {"observation": "seeing a cat"}
assert memory.step_content == {} # step_content should be cleared

# 2. Simulate post_step: the previous entry is updated with the real step number
memory.step_content = {"action": "pet the cat"}
memory.process_step(pre_step=False)

assert len(memory.short_term_memory) == 1
updated_entry = memory.short_term_memory[0]
assert updated_entry.step == 1 # Step number is now set
assert updated_entry.content == {
"observation": "seeing a cat",
"action": "pet the cat",
}
assert memory.step_content == {} # step_content should be cleared again

def test_format_short_term_empty(self, mock_agent):
"""Test that formatting an empty memory returns the correct string."""
memory = ShortTermMemory(agent=mock_agent)
assert memory.format_short_term() == "No recent memory."

def test_get_communication_history(self, mock_agent):
"""Test that communication history is correctly extracted."""
memory = ShortTermMemory(agent=mock_agent)

# Manually add some entries
msg_entry_content = {"message": "Hello there!"}
action_entry_content = {"action": "move"}

memory.short_term_memory.append(
MemoryEntry(content=msg_entry_content, step=1, agent=mock_agent)
)
memory.short_term_memory.append(
MemoryEntry(content=action_entry_content, step=1, agent=mock_agent)
)

history = memory.get_communication_history()
assert "step 1: Hello there!" in history
assert "action" not in history
Loading