diff --git a/tests/test_llm_agent.py b/tests/test_llm_agent.py index f707e2a..bf9eea4 100644 --- a/tests/test_llm_agent.py +++ b/tests/test_llm_agent.py @@ -1,6 +1,5 @@ # tests/test_llm_agent.py -import re from mesa.model import Model from mesa.space import MultiGrid @@ -11,170 +10,128 @@ from mesa_llm.reasoning.react import ReActReasoning -def test_apply_plan_adds_to_memory(monkeypatch): - monkeypatch.setenv("GEMINI_API_KEY", "dummy") +# Create a Mock Scheduler to bypass any potential installation/import issues +class MockScheduler: + def __init__(self, model): + self._ids = 0 + self.agents = [] + + def add(self, agent): + self.agents.append(agent) + + def next_id(self): + self._ids += 1 + return self._ids + +# Helper function to create a standardized DummyModel to avoid repetition +def create_dummy_model(seed): class DummyModel(Model): def __init__(self): - super().__init__(seed=42) - self.grid = MultiGrid(3, 3, torus=False) + super().__init__(seed=seed) + self.grid = MultiGrid(5, 5, torus=False) + self.schedule = MockScheduler(self) # Use the mock scheduler def add_agent(self, pos): - system_prompt = "You are an agent in a simulation." agents = LLMAgent.create_agents( self, n=1, reasoning=ReActReasoning, - system_prompt=system_prompt, + system_prompt="System prompt", vision=-1, internal_state=["test_state"], ) + agent = agents[0] + # THE FIX: Call next_id() on the schedule, not the model + agent.unique_id = self.schedule.next_id() + self.grid.place_agent(agent, pos) + self.schedule.add(agent) + return agent - x, y = pos + return DummyModel() - self.grid.place_agent(agents[0], (x, y)) - return agents[0] - model = DummyModel() +def test_apply_plan_adds_to_memory(monkeypatch): + monkeypatch.setenv("GEMINI_API_KEY", "dummy") + model = create_dummy_model(seed=42) agent = model.add_agent((1, 1)) - agent.memory = ShortTermMemory( - agent=agent, - n=5, - display=True, - ) + agent.memory = ShortTermMemory(agent=agent, n=5, display=True) - # fake response returned by the tool manager fake_response = [{"tool": "foo", "argument": "bar"}] - - # monkeypatch the tool manager so no real tool calls are made monkeypatch.setattr( agent.tool_manager, "call_tools", lambda agent, llm_response: fake_response ) plan = Plan(step=0, llm_plan="do something") - resp = agent.apply_plan(plan) assert resp == fake_response - - assert { - "tool": "foo", - "argument": "bar", - } in agent.memory.step_content.values() or agent.memory.step_content == { - "tool": "foo", - "argument": "bar", - } + assert {"tool": "foo", "argument": "bar"} in agent.memory.step_content.values() def test_generate_obs_with_one_neighbor(monkeypatch): monkeypatch.setenv("GEMINI_API_KEY", "dummy") - - class DummyModel(Model): - def __init__(self): - super().__init__(seed=45) - self.grid = MultiGrid(3, 3, torus=False) - - def add_agent(self, pos, agent_class=LLMAgent): - system_prompt = "You are an agent in a simulation." - agents = agent_class.create_agents( - self, - n=1, - reasoning=ReActReasoning, - system_prompt=system_prompt, - vision=-1, - internal_state=["test_state"], - ) - x, y = pos - self.grid.place_agent(agents[0], (x, y)) - return agents[0] - - model = DummyModel() - + model = create_dummy_model(seed=45) agent = model.add_agent((1, 1)) - agent.memory = ShortTermMemory( - agent=agent, - n=5, - display=True, - ) - agent.unique_id = 1 - + agent.memory = ShortTermMemory(agent=agent, n=5) neighbor = model.add_agent((1, 2)) - neighbor.memory = ShortTermMemory( - agent=agent, - n=5, - display=True, - ) - neighbor.unique_id = 2 monkeypatch.setattr(agent.memory, "add_to_memory", lambda *args, **kwargs: None) obs = agent.generate_obs() - assert obs.self_state["agent_unique_id"] == 1 - - # we should have exactly one neighboring agent in local_state + assert obs.self_state["agent_unique_id"] == agent.unique_id assert len(obs.local_state) == 1 - - # extract the neighbor key = next(iter(obs.local_state.keys())) - assert key == "LLMAgent 2" - - entry = obs.local_state[key] - assert entry["position"] == (1, 2) - assert entry["internal_state"] == ["test_state"] + assert key == f"LLMAgent {neighbor.unique_id}" + assert obs.local_state[key]["position"] == (1, 2) def test_send_message_updates_both_agents_memory(monkeypatch): monkeypatch.setenv("GEMINI_API_KEY", "dummy") - - class DummyModel(Model): - def __init__(self): - super().__init__(seed=45) - self.grid = MultiGrid(3, 3, torus=False) - - def add_agent(self, pos, agent_class=LLMAgent): - system_prompt = "You are an agent in a simulation." - agents = agent_class.create_agents( - self, - n=1, - reasoning=lambda agent: None, - system_prompt=system_prompt, - vision=-1, - internal_state=["test_state"], - ) - x, y = pos - self.grid.place_agent(agents[0], (x, y)) - return agents[0] - - model = DummyModel() + model = create_dummy_model(seed=45) sender = model.add_agent((0, 0)) - sender.memory = ShortTermMemory( - agent=sender, - n=5, - display=True, - ) - sender.unique_id = 1 - + sender.memory = ShortTermMemory(agent=sender, n=5) recipient = model.add_agent((1, 1)) - recipient.memory = ShortTermMemory( - agent=recipient, - n=5, - display=True, - ) - recipient.unique_id = 2 + recipient.memory = ShortTermMemory(agent=recipient, n=5) - # Track how many times add_to_memory is called call_counter = {"count": 0} def fake_add_to_memory(*args, **kwargs): call_counter["count"] += 1 - # monkeypatch both agents' memory modules monkeypatch.setattr(sender.memory, "add_to_memory", fake_add_to_memory) monkeypatch.setattr(recipient.memory, "add_to_memory", fake_add_to_memory) - result = sender.send_message("hello", recipients=[recipient]) - pattern = r"LLMAgent 1 → \[\] : hello" - assert re.match(pattern, result) - - # sender + recipient memory => should be called twice + sender.send_message("hello", recipients=[recipient]) assert call_counter["count"] == 2 + + +def test_generate_obs_zero_vision(monkeypatch): + monkeypatch.setenv("GEMINI_API_KEY", "dummy") + model = create_dummy_model(seed=45) + agent = model.add_agent((1, 1)) + agent.memory = ShortTermMemory(agent=agent, n=5) + monkeypatch.setattr(agent.memory, "add_to_memory", lambda *args, **kwargs: None) + _ = model.add_agent((1, 2)) + + agent.vision = 0 + obs = agent.generate_obs() + assert obs.local_state == {} + + +def test_generate_obs_limited_vision(monkeypatch): + monkeypatch.setenv("GEMINI_API_KEY", "dummy") + model = create_dummy_model(seed=45) + agent = model.add_agent((2, 2)) + agent.memory = ShortTermMemory(agent=agent, n=5) + monkeypatch.setattr(agent.memory, "add_to_memory", lambda *args, **kwargs: None) + + neighbor = model.add_agent((2, 3)) + far_agent = model.add_agent((4, 4)) + + agent.vision = 1 + obs = agent.generate_obs() + + assert len(obs.local_state) == 1 + assert f"LLMAgent {neighbor.unique_id}" in obs.local_state + assert f"LLMAgent {far_agent.unique_id}" not in obs.local_state diff --git a/tests/test_memory/test_st_memory.py b/tests/test_memory/test_st_memory.py new file mode 100644 index 0000000..ae70923 --- /dev/null +++ b/tests/test_memory/test_st_memory.py @@ -0,0 +1,71 @@ +# tests/test_memory/test_st_memory.py + +from collections import deque + +from mesa_llm.memory.memory import MemoryEntry +from mesa_llm.memory.st_memory import ShortTermMemory + + +class TestShortTermMemory: + """Tests for the ShortTermMemory class.""" + + def test_initialization(self, mock_agent): + """Test that the memory initializes correctly.""" + memory = ShortTermMemory(agent=mock_agent, n=7, display=False) + assert memory.n == 7 + assert isinstance(memory.short_term_memory, deque) + assert ( + memory.short_term_memory.maxlen is None + ) # Deque for STMemory is not bounded + + def test_process_step_logic(self, mock_agent): + """Test the two-part process_step logic for pre- and post-step.""" + mock_agent.model.steps = 1 + memory = ShortTermMemory(agent=mock_agent, n=5, display=False) + + # 1. Simulate pre_step: content is added with step=None + memory.step_content = {"observation": "seeing a cat"} + memory.process_step(pre_step=True) + + assert len(memory.short_term_memory) == 1 + first_entry = memory.short_term_memory[0] + assert first_entry.step is None + assert first_entry.content == {"observation": "seeing a cat"} + assert memory.step_content == {} # step_content should be cleared + + # 2. Simulate post_step: the previous entry is updated with the real step number + memory.step_content = {"action": "pet the cat"} + memory.process_step(pre_step=False) + + assert len(memory.short_term_memory) == 1 + updated_entry = memory.short_term_memory[0] + assert updated_entry.step == 1 # Step number is now set + assert updated_entry.content == { + "observation": "seeing a cat", + "action": "pet the cat", + } + assert memory.step_content == {} # step_content should be cleared again + + def test_format_short_term_empty(self, mock_agent): + """Test that formatting an empty memory returns the correct string.""" + memory = ShortTermMemory(agent=mock_agent) + assert memory.format_short_term() == "No recent memory." + + def test_get_communication_history(self, mock_agent): + """Test that communication history is correctly extracted.""" + memory = ShortTermMemory(agent=mock_agent) + + # Manually add some entries + msg_entry_content = {"message": "Hello there!"} + action_entry_content = {"action": "move"} + + memory.short_term_memory.append( + MemoryEntry(content=msg_entry_content, step=1, agent=mock_agent) + ) + memory.short_term_memory.append( + MemoryEntry(content=action_entry_content, step=1, agent=mock_agent) + ) + + history = memory.get_communication_history() + assert "step 1: Hello there!" in history + assert "action" not in history