Skip to content

Commit

Permalink
Explicitly test basic examples (#2390)
Browse files Browse the repository at this point in the history
  • Loading branch information
quaquel authored Oct 20, 2024
1 parent d78641e commit 988ab71
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 86 deletions.
4 changes: 2 additions & 2 deletions mesa/examples/basic/boltzmann_wealth_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ class BoltzmannWealthModel(mesa.Model):
highly skewed distribution of wealth.
"""

def __init__(self, n=100, width=10, height=10):
super().__init__()
def __init__(self, n=100, width=10, height=10, seed=None):
super().__init__(seed=seed)
self.num_agents = n
self.grid = mesa.space.MultiGrid(width, height, True)

Expand Down
4 changes: 2 additions & 2 deletions mesa/examples/basic/conways_game_of_life/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
class ConwaysGameOfLife(Model):
"""Represents the 2-dimensional array of cells in Conway's Game of Life."""

def __init__(self, width=50, height=50):
def __init__(self, width=50, height=50, seed=None):
"""Create a new playing area of (width, height) cells."""
super().__init__()
super().__init__(seed=seed)
# Use a simple grid, where edges wrap around.
self.grid = SingleGrid(width, height, torus=True)

Expand Down
8 changes: 3 additions & 5 deletions mesa/examples/basic/virus_on_network/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ def __init__(
virus_check_frequency=0.4,
recovery_chance=0.3,
gain_resistance_chance=0.5,
seed=None,
):
super().__init__()
super().__init__(seed=seed)
self.num_nodes = num_nodes
prob = avg_node_degree / self.num_nodes
self.G = nx.erdos_renyi_graph(n=self.num_nodes, p=prob)
Expand All @@ -56,6 +57,7 @@ def __init__(
"Infected": number_infected,
"Susceptible": number_susceptible,
"Resistant": number_resistant,
"R over S": self.resistant_susceptible_ratio,
}
)

Expand Down Expand Up @@ -93,7 +95,3 @@ def step(self):
self.agents.shuffle_do("step")
# collect data
self.datacollector.collect(self)

def run_model(self, n):
for _ in range(n):
self.step()
104 changes: 27 additions & 77 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,40 @@
# noqa: D100
import contextlib
import importlib
import os.path
import sys
import unittest
from mesa.examples import (
BoidFlockers,
BoltzmannWealthModel,
ConwaysGameOfLife,
Schelling,
VirusOnNetwork,
)


def test_examples_imports():
"""Test examples imports."""
from mesa.examples import (
BoidFlockers,
BoltzmannWealthModel,
ConwaysGameOfLife,
Schelling,
VirusOnNetwork,
)
def test_boltzmann_model(): # noqa: D103
model = BoltzmannWealthModel(seed=42)

BoltzmannWealthModel()
Schelling()
BoidFlockers()
ConwaysGameOfLife()
VirusOnNetwork()
for _i in range(10):
model.step()


def classcase(name): # noqa: D103
return "".join(x.capitalize() for x in name.replace("-", "_").split("_"))
def test_conways_game_model(): # noqa: D103
model = ConwaysGameOfLife(seed=42)
for _i in range(10):
model.step()


@unittest.skip(
"Skipping TextExamples, because examples folder was moved. More discussion needed."
)
class TestExamples(unittest.TestCase):
"""Test examples' models.
def test_schelling_model(): # noqa: D103
model = Schelling(seed=42)
for _i in range(10):
model.step()

This creates a model object and iterates it through
some steps. The idea is to get code coverage, rather than to test the
details of each example's model.
"""

EXAMPLES = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../mesa/examples")
)
def test_virus_on_network(): # noqa: D103
model = VirusOnNetwork(seed=42)
for _i in range(10):
model.step()

@contextlib.contextmanager
def active_example_dir(self, example):
"""Save and restore sys.path and sys.modules."""
old_sys_path = sys.path[:]
old_sys_modules = sys.modules.copy()
old_cwd = os.getcwd()
example_path = os.path.abspath(os.path.join(self.EXAMPLES, example))
try:
sys.path.insert(0, example_path)
os.chdir(example_path)
yield
finally:
os.chdir(old_cwd)
added = [m for m in sys.modules if m not in old_sys_modules]
for mod in added:
del sys.modules[mod]
sys.modules.update(old_sys_modules)
sys.path[:] = old_sys_path

def test_examples(self): # noqa: D102
for example in os.listdir(self.EXAMPLES):
if not os.path.isdir(os.path.join(self.EXAMPLES, example)):
continue
if hasattr(self, f"test_{example.replace('-', '_')}"):
# non-standard example; tested below
continue
def test_boid_flockers(): # noqa: D103
model = BoidFlockers(seed=42)

print(f"testing example {example!r}")
with self.active_example_dir(example):
try:
# epstein_civil_violence.py at the top level
mod = importlib.import_module("model")
server = importlib.import_module("server")
server.server.render_model()
except ImportError:
# <example>/epstein_civil_violence.py
mod = importlib.import_module(f"{example.replace('-', '_')}.model")
server = importlib.import_module(
f"{example.replace('-', '_')}.server"
)
server.server.render_model()
model_class = getattr(mod, classcase(example))
model = model_class()
for _ in range(10):
model.step()
self.assertEqual(model.steps, 10)
for _i in range(10):
model.step()

0 comments on commit 988ab71

Please sign in to comment.