Skip to content

Commit

Permalink
API Allow HomogUniv.plot to forward kwargs to mpl plots
Browse files Browse the repository at this point in the history
Closes #424

Passing drawstyle and steps perform the same actions and are
supported for now. Only drawstyle will be supported in the future
but for now silently raise a pending deprecation notice
  • Loading branch information
drewejohnson committed Dec 18, 2020
1 parent 628f81e commit b9d88a3
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 11 deletions.
40 changes: 31 additions & 9 deletions serpentTools/objects/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""
from itertools import product
from collections import namedtuple
import warnings

from matplotlib import pyplot
from numpy import arange, hstack, ndarray, zeros_like
Expand Down Expand Up @@ -371,7 +372,7 @@ def _lookup(self, variableName, uncertainty):
@magicPlotDocDecorator
def plot(self, qtys, limitE=True, ax=None, logx=None, logy=None,
loglog=None, sigma=3, xlabel=None, ylabel=None, legend=None,
ncol=1, steps=True, labelFmt=None, labels=None):
ncol=1, steps=True, labelFmt=None, labels=None, **kwargs):
"""
Plot homogenized data as a function of energy.
Expand All @@ -398,6 +399,8 @@ def plot(self, qtys, limitE=True, ax=None, logx=None, logy=None,
If ``True``, plot values as constant within
energy bins.
{univLabelFmt}
{kwargs} :func:`matplotlib.pyplot.plot` or
:func:`matplotlib.pyplot.errorbar`
Returns
-------
Expand All @@ -418,15 +421,35 @@ def plot(self, qtys, limitE=True, ax=None, logx=None, logy=None,
if limitE:
eneCap = min(self.microGroups.max(), self.groups.max())

# Check kwargs
if "drawstyle" in kwargs and steps:
# Conflicting arguments but defer to user value for now
warnings.warn(
"Passing steps and drawstyle will default to using the "
"drawstyle value but may cause an error later",
PendingDeprecationWarning
)
else:
kwargs.setdefault("drawstyle", drawstyle)

if "label" in kwargs:
if len(qtys) > 1:
raise ValueError(
"Passing label while plotting multiple entries {} is "
"not allowed".format(qtys)
)
if labels is not None:
raise ValueError("Passing label and labels is not allowed")
labels = kwargs.pop("label")

if isinstance(labels, str):
labels = [labels, ]
if labels is None:
elif labels is None:
labels = [labelFmt, ] * len(qtys)
else:
if len(labels) != len(qtys):
raise IndexError(
"Need equal number of labels for plot quantities. "
"Given {} expected: {}".format(len(labels), len(qtys)))
elif len(labels) != len(qtys):
raise IndexError(
"Need equal number of labels for plot quantities. "
"Given {} expected: {}".format(len(labels), len(qtys)))

for key, label in zip(qtys, labels):
yVals = self.__getitem__(key)
Expand All @@ -451,8 +474,7 @@ def plot(self, qtys, limitE=True, ax=None, logx=None, logy=None,

label = self.__formatLabel(label, key)

ax.errorbar(xdata, yVals, yerr=yUncs, label=label,
drawstyle=drawstyle)
ax.errorbar(xdata, yVals, yerr=yUncs, label=label, **kwargs)

if ylabel is None:
ylabel, yUnits = (("Cross Section", "[cm$^{-1}$]") if onlyXS
Expand Down
3 changes: 1 addition & 2 deletions serpentTools/utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,7 @@ def placeLegend(ax, legend, handlesAndLabels=None, **kwargs):
{rax}
"""
# import pdb
# pdb.set_trace()

if handlesAndLabels is None:
handles, labels = ax.get_legend_handles_labels()
else:
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
40 changes: 40 additions & 0 deletions tests/plots/test_homoguniv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
from serpentTools.settings import rc
from serpentTools.data import readDataFile

from . import compare_or_update_plot


@pytest.fixture(scope="module")
def univ():
with rc:
rc["serpentVersion"] = "2.1.30"
reader = readDataFile("InnerAssembly_res.m")
yield reader.universes["0", 0, 0, 0]


@compare_or_update_plot
def test_homoguniv_single(univ):
univ.plot("infTot", label="Total", legend=True)


@compare_or_update_plot
def test_homoguniv_multi(univ):
univ.plot(
["infAbs", "infTot"],
logx=False,
logy=False,
labelFmt="{u} @ {b} MWd/kgU // {d} days // step {i}: {k}",
xlabel="Incident energy (MeV)",
ylabel="Macroscopic cross section (cm$^{-1})$",
# Addtional arguments to pass along to the underlying plot
linestyle="--",
)


@compare_or_update_plot
def test_homoguniv_multi_named(univ):
univ.plot(
["infAbs", "infFlx"],
labels=["Absorption", "Flux"],
)

0 comments on commit b9d88a3

Please sign in to comment.