Skip to content

Commit

Permalink
expand ax.scatter kwargs that can be used (#2445)
Browse files Browse the repository at this point in the history
  • Loading branch information
quaquel authored Oct 31, 2024
1 parent 217cb58 commit 1a29aa4
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 29 deletions.
16 changes: 8 additions & 8 deletions docs/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,20 +177,20 @@ def agent_portrayal(agent):

model_params = {
"N": {
"type": "SliderInt",
"value": 50,
"label": "Number of agents:",
"min": 10,
"max": 100,
"step": 1,
"type": "SliderInt",
"value": 50,
"label": "Number of agents:",
"min": 10,
"max": 100,
"step": 1,
}
}

page = SolaraViz(
MyModel,
[
make_space_component(agent_portrayal),
make_plot_component("mean_age")
make_space_component(agent_portrayal),
make_plot_component("mean_age")
],
model_params=model_params
)
Expand Down
24 changes: 12 additions & 12 deletions mesa/examples/basic/boltzmann_wealth_model/app.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
from mesa.examples.basic.boltzmann_wealth_model.model import BoltzmannWealthModel
from mesa.visualization import (
SolaraViz,
make_plot_component,
make_space_component,
)
from mesa.visualization import SolaraViz, make_plot_component, make_space_component


def agent_portrayal(agent):
size = 10
color = "tab:red"
if agent.wealth > 0:
size = 50
color = "tab:blue"
return {"size": size, "color": color}
color = agent.wealth # we are using a colormap to translate wealth to color
return {"color": color}


model_params = {
Expand All @@ -28,6 +20,11 @@ def agent_portrayal(agent):
"height": 10,
}


def post_process(ax):
ax.get_figure().colorbar(ax.collections[0], label="wealth", ax=ax)


# Create initial model instance
model1 = BoltzmannWealthModel(50, 10, 10)

Expand All @@ -36,7 +33,10 @@ def agent_portrayal(agent):
# Under the hood these are just classes that receive the model instance.
# You can also author your own visualization elements, which can also be functions
# that receive the model instance and return a valid solara component.
SpaceGraph = make_space_component(agent_portrayal)

SpaceGraph = make_space_component(
agent_portrayal, cmap="viridis", vmin=0, vmax=10, post_process=post_process
)
GiniPlot = make_plot_component("Gini")

# Create the SolaraViz page. This will automatically create a server and display the
Expand Down
34 changes: 25 additions & 9 deletions mesa/visualization/components/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def draw_orthogonal_grid(
agent_portrayal: Callable,
ax: Axes | None = None,
draw_grid: bool = True,
**kwargs,
):
"""Visualize a orthogonal grid.
Expand All @@ -317,6 +318,7 @@ def draw_orthogonal_grid(
agent_portrayal: a callable that is called with the agent and returns a dict
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
draw_grid: whether to draw the grid
kwargs: additional keyword arguments passed to ax.scatter
Returns:
Returns the Axes object with the plot drawn onto it.
Expand All @@ -333,7 +335,7 @@ def draw_orthogonal_grid(
arguments = collect_agent_data(space, agent_portrayal, size=s_default)

# plot the agents
_scatter(ax, arguments)
_scatter(ax, arguments, **kwargs)

# further styling
ax.set_xlim(-0.5, space.width - 0.5)
Expand All @@ -354,6 +356,7 @@ def draw_hex_grid(
agent_portrayal: Callable,
ax: Axes | None = None,
draw_grid: bool = True,
**kwargs,
):
"""Visualize a hex grid.
Expand All @@ -362,6 +365,7 @@ def draw_hex_grid(
agent_portrayal: a callable that is called with the agent and returns a dict
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
draw_grid: whether to draw the grid
kwargs: additional keyword arguments passed to ax.scatter
Returns:
Returns the Axes object with the plot drawn onto it.
Expand Down Expand Up @@ -394,7 +398,7 @@ def draw_hex_grid(
arguments["loc"] = loc

# plot the agents
_scatter(ax, arguments)
_scatter(ax, arguments, **kwargs)

# further styling and adding of grid
ax.set_xlim(-1, space.width + 0.5)
Expand Down Expand Up @@ -443,6 +447,7 @@ def draw_network(
draw_grid: bool = True,
layout_alg=nx.spring_layout,
layout_kwargs=None,
**kwargs,
):
"""Visualize a network space.
Expand All @@ -453,6 +458,7 @@ def draw_network(
draw_grid: whether to draw the grid
layout_alg: a networkx layout algorithm or other callable with the same behavior
layout_kwargs: a dictionary of keyword arguments for the layout algorithm
kwargs: additional keyword arguments passed to ax.scatter
Returns:
Returns the Axes object with the plot drawn onto it.
Expand Down Expand Up @@ -488,7 +494,7 @@ def draw_network(
arguments["loc"] = pos[arguments["loc"]]

# plot the agents
_scatter(ax, arguments)
_scatter(ax, arguments, **kwargs)

# further styling
ax.set_axis_off()
Expand All @@ -506,14 +512,15 @@ def draw_network(


def draw_continuous_space(
space: ContinuousSpace, agent_portrayal: Callable, ax: Axes | None = None
space: ContinuousSpace, agent_portrayal: Callable, ax: Axes | None = None, **kwargs
):
"""Visualize a continuous space.
Args:
space: the space to visualize
agent_portrayal: a callable that is called with the agent and returns a dict
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
kwargs: additional keyword arguments passed to ax.scatter
Returns:
Returns the Axes object with the plot drawn onto it.
Expand All @@ -536,7 +543,7 @@ def draw_continuous_space(
arguments = collect_agent_data(space, agent_portrayal, size=s_default)

# plot the agents
_scatter(ax, arguments)
_scatter(ax, arguments, **kwargs)

# further visual styling
border_style = "solid" if not space.torus else (0, (5, 10))
Expand All @@ -552,14 +559,15 @@ def draw_continuous_space(


def draw_voroinoi_grid(
space: VoronoiGrid, agent_portrayal: Callable, ax: Axes | None = None
space: VoronoiGrid, agent_portrayal: Callable, ax: Axes | None = None, **kwargs
):
"""Visualize a voronoi grid.
Args:
space: the space to visualize
agent_portrayal: a callable that is called with the agent and returns a dict
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
kwargs: additional keyword arguments passed to ax.scatter
Returns:
Returns the Axes object with the plot drawn onto it.
Expand Down Expand Up @@ -589,7 +597,7 @@ def draw_voroinoi_grid(
ax.set_xlim(x_min - x_padding, x_max + x_padding)
ax.set_ylim(y_min - y_padding, y_max + y_padding)

_scatter(ax, arguments)
_scatter(ax, arguments, **kwargs)

for cell in space.all_cells:
polygon = cell.properties["polygon"]
Expand All @@ -604,8 +612,15 @@ def draw_voroinoi_grid(
return ax


def _scatter(ax: Axes, arguments):
"""Helper function for plotting the agents."""
def _scatter(ax: Axes, arguments, **kwargs):
"""Helper function for plotting the agents.
Args:
ax: a Matplotlib Axes instance
arguments: the agents specific arguments for platting
kwargs: additional keyword arguments for ax.scatter
"""
loc = arguments.pop("loc")

x = loc[:, 0]
Expand All @@ -624,6 +639,7 @@ def _scatter(ax: Axes, arguments):
marker=mark,
zorder=z_order,
**{k: v[logical] for k, v in arguments.items()},
**kwargs,
)


Expand Down

0 comments on commit 1a29aa4

Please sign in to comment.