Skip to content

Commit

Permalink
Merge branch 'master' into mcflugen/at-keyword-for-imshow
Browse files Browse the repository at this point in the history
  • Loading branch information
mcflugen committed May 22, 2022
2 parents 7eacf00 + 436f37c commit fccc3d5
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 21 deletions.
3 changes: 3 additions & 0 deletions landlab/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from .imshow import imshow_grid, imshow_grid_at_node
from .imshowhs import imshowhs_grid, imshowhs_grid_at_node
from .network_sediment_transporter import plot_network_and_parcels
from .graph import plot_graph
from .layers import plot_layers


__all__ = [
"imshow_grid",
"imshowhs_grid",
"imshow_grid_at_node",
"imshowhs_grid_at_node",
"plot_network_and_parcels",
"plot_layers",
"plot_graph",
]
118 changes: 97 additions & 21 deletions landlab/plot/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np


def plot_nodes(graph, color="r", with_id=True, markersize=10):
def plot_nodes(graph, color="r", with_id=True, markersize=4):
for node in range(len(graph.x_of_node)):
x, y = graph.x_of_node[node], graph.y_of_node[node]
plt.plot(
Expand Down Expand Up @@ -41,38 +41,114 @@ def plot_links(
plt.text(x + dx * 0.5, y + dy * 0.5, link, size=16, color=color)


def plot_patches(graph, color="g"):
def plot_patches(graph, color="g", with_id=False):
from matplotlib.patches import Polygon

for patch, nodes in enumerate(graph.nodes_at_patch):
nodes = nodes[nodes >= 0]
x, y = np.mean(graph.x_of_node[nodes]), np.mean(graph.y_of_node[nodes])
plt.text(x, y, patch, color=color, size=16)
plt.gca().add_patch(
Polygon(graph.xy_of_node[nodes], ec=color, fc=None, alpha=0.5)
)
if with_id:
plt.text(
x,
y,
patch,
color=color,
size=16,
horizontalalignment="center",
verticalalignment="center",
)


def plot_graph(graph, at="node,link,patch", with_id=True, axes=None):
"""Plot elements of a graph.
Parameters
----------
graph : graph-like
A landlab graph-like object.
at : str or iterable of str
Comma-separated list of elements to plot.
with_id : str, iterable of str or bool
Indicate which elements should be plotted with their corresponding id.
Either a comma-separated list of grid elements or ``True`` to include
ids for all elements of ``False`` for no elements.
axes : , optional
Add the plot to an existing matplotlib ``Axes``, otherwise, create a new one.
def plot_graph(graph, at="node,link,patch", with_id=True):
locs = [loc.strip() for loc in at.split(",")]
for loc in locs:
if loc not in ("node", "link", "patch", "corner", "face", "cell"):
raise ValueError('{at}: "at" element not understood'.format(at=loc))
Returns
-------
``Axes``
The ``Axes`` containing the plot.
"""
EVERYWHERE = {"node", "link", "patch", "corner", "face", "cell"}

plt.plot(graph.x_of_node, graph.y_of_node, ".", color="r")
plt.xlim([min(graph.x_of_node) - 0.5, max(graph.x_of_node) + 0.5])
plt.ylim([min(graph.y_of_node) - 0.5, max(graph.y_of_node) + 0.5])
if isinstance(with_id, bool):
with_id = EVERYWHERE if with_id else set()
else:
with_id = _parse_locations_as_set(with_id)
locs = _parse_locations_as_set(at)

ax = plt.axes() if axes is None else axes

ax.set_xlim([min(graph.x_of_node) - 0.5, max(graph.x_of_node) + 0.5])
ax.set_ylim([min(graph.y_of_node) - 0.5, max(graph.y_of_node) + 0.5])

if "node" in locs:
plot_nodes(graph, with_id=with_id, markersize=10)
plot_nodes(graph, with_id="node" in with_id, markersize=4)
if "link" in locs:
plot_links(graph, with_id=with_id, linewidth=None, as_arrow=False)
plot_links(graph, with_id="link" in with_id, linewidth=None, as_arrow=True)
if "patch" in locs:
plot_patches(graph)
plot_patches(graph, with_id="patch" in with_id)

if "corner" in locs:
plot_nodes(graph.dual, color="c")
plot_nodes(graph.dual, color="c", with_id="corner" in with_id)
if "face" in locs:
plot_links(graph.dual, linestyle="dotted", color="k")
plot_links(graph.dual, linestyle="dotted", color="k", with_id="face" in with_id)
if "cell" in locs and graph.number_of_cells > 0:
plot_patches(graph.dual, color="m")
plot_patches(graph.dual, color="m", with_id="cell" in with_id)

ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_aspect(1.0)

return ax

plt.xlabel("x")
plt.ylabel("y")
plt.gca().set_aspect(1.0)

plt.show()
def _parse_locations_as_set(locations):
"""Parse grid element locations as a set.
Parameters
----------
locations : str or iterable of str
Grid locations.
Returns
-------
set
Grid locations as strings.
Raises
------
ValueError
If any of the locations are invalid.
"""
EVERYWHERE = {"node", "link", "patch", "corner", "face", "cell"}

if isinstance(locations, str):
as_set = set(locations.split(","))
else:
as_set = set(locations)

as_set = {item.strip() for item in as_set}

unknown = sorted(as_set - EVERYWHERE)
if unknown:
unknown = [repr(item) for item in unknown]
raise ValueError(
f"unknown location{'s' if len(unknown) > 1 else ''} ({', '.join(unknown)})"
)

return as_set
4 changes: 4 additions & 0 deletions news/1425.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Enhanced the ``plot_graph`` function: allow the ``with_id`` keyword to
accept a list of elements that should have included IDs, fill in patches and
cells.

2 changes: 2 additions & 0 deletions news/1425.feature.1
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
The ``plot_graph`` function now can take lists of graph elements rather than only comma-separated strings.

3 changes: 3 additions & 0 deletions news/1425.feature.2
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Added a new keyword, ``axes`` to ``plot_graph`` to allow plotting within an
existing axes.

4 changes: 4 additions & 0 deletions news/1428.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fixed a bug where ``plot_graph`` would incorrectly include the last
node/corner with patches/cells that had fewer links/faces than the maximum of
the graph.

2 changes: 2 additions & 0 deletions news/1428.bugfix.1
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fixed a bug in ``plot_graph`` where patch and cell polygons were not drawn.

57 changes: 57 additions & 0 deletions tests/plot/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest

from landlab import RasterModelGrid
from landlab.plot.graph import _parse_locations_as_set, plot_graph


def _axes_arrows(ax):
from matplotlib.patches import FancyArrow

return [child for child in ax.get_children() if isinstance(child, FancyArrow)]


def test_parse_locations_from_str():
assert _parse_locations_as_set("node") == {"node"}
assert _parse_locations_as_set("node,link") == {"node", "link"}
assert _parse_locations_as_set("node,link,cell,link") == {"node", "link", "cell"}
assert _parse_locations_as_set("node , link, patch ") == {"node", "link", "patch"}


def test_parse_locations_from_iterable():
assert _parse_locations_as_set(["cell"]) == {"cell"}
assert _parse_locations_as_set(("patch", "corner")) == {"patch", "corner"}
assert _parse_locations_as_set(("patch", "corner", "patch")) == {"patch", "corner"}
assert _parse_locations_as_set((" patch ", "corner ")) == {"patch", "corner"}


def test_parse_locations_bad_value():
with pytest.raises(ValueError, match=r"^unknown location "):
_parse_locations_as_set("foo")
with pytest.raises(ValueError, match=r"^unknown locations "):
_parse_locations_as_set("cells,nodes")


@pytest.mark.parametrize("at", ["node", "link", "patch", "corner", "face", "cell"])
def test_plot_graph(at):
grid = RasterModelGrid((3, 4))
ax = plot_graph(grid, at=at)

assert ax.get_ylim() == (-0.5, 2.5)
assert ax.get_xlim() == (-0.5, 3.5)

if at in ("patch", "cell"):
assert len(ax.patches) == grid.number_of_elements(at)
elif at in ("link", "face"):
assert len(_axes_arrows(ax)) == grid.number_of_elements(at)
else:
assert len(ax.lines) == grid.number_of_elements(at)


def test_plot_graph_onto_existing():
grid = RasterModelGrid((3, 4))
ax = plot_graph(grid, at="node")
plot_graph(grid, at="cell", axes=ax)

assert (
len(ax.patches) + len(ax.lines) == grid.number_of_nodes + grid.number_of_cells
)

0 comments on commit fccc3d5

Please sign in to comment.