Skip to content

Commit

Permalink
Merge pull request #400 from lsst/tickets/OPSIM-1141
Browse files Browse the repository at this point in the history
OPSIM-1141: Plotters return Figure objects, not fig.number
  • Loading branch information
rhiannonlynne authored Mar 29, 2024
2 parents 88eb4e9 + a12cb6e commit dcaac1c
Show file tree
Hide file tree
Showing 18 changed files with 462 additions and 361 deletions.
10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ line_length = 110
exclude = [
"__init__.py",
]
line-length = 110
target-version = "py311"

[tool.ruff.lint]
ignore = [
"N802",
"N803",
Expand All @@ -124,21 +128,19 @@ ignore = [
"D400",
"E712",
]
line-length = 110
select = [
"E", # pycodestyle
"F", # pyflakes
"N", # pep8-naming
"W", # pycodestyle
]
target-version = "py311"
extend-select = [
"RUF100", # Warn about unused noqa
]

[tool.ruff.pycodestyle]
[tool.ruff.lint.pycodestyle]
max-doc-length = 79

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "numpy"

2 changes: 1 addition & 1 deletion rubin_sim/maf/batches/moving_objects_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ def _codePlot(key):
display_dict["subgroup"] = "Completeness over time"
display_dict["caption"] = "Completeness over time, for H values indicated in legend."
ph.save_fig(
fig.number,
fig,
f"{figroot}_CompletenessOverTime",
"Combo",
"CompletenessOverTime",
Expand Down
12 changes: 6 additions & 6 deletions rubin_sim/maf/metric_bundles/metric_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,9 +790,9 @@ def plot(self, plot_handler=None, plot_func=None, outfile_suffix=None, savefig=F
Returns
-------
made_plots : `dict`
Dictionary of plot_type:figure number key/value pairs,
Dictionary of plot_type:figure key/value pairs,
indicating what plots were created
and what matplotlib figure numbers were used.
and what matplotlib figures were used.
"""
# Generate a plot_handler if none was set.
if plot_handler is None:
Expand All @@ -808,10 +808,10 @@ def plot(self, plot_handler=None, plot_func=None, outfile_suffix=None, savefig=F
plot_handler.set_plot_dicts(plot_dicts=[self.plot_dict], reset=True)
made_plots = {}
if plot_func is not None:
fignum = plot_handler.plot(plot_func, outfile_suffix=outfile_suffix)
made_plots[plot_func.plotType] = fignum
fig = plot_handler.plot(plot_func, outfile_suffix=outfile_suffix)
made_plots[plot_func.plotType] = fig
else:
for plot_func in self.plot_funcs:
fignum = plot_handler.plot(plot_func, outfile_suffix=outfile_suffix)
made_plots[plot_func.plot_type] = fignum
fig = plot_handler.plot(plot_func, outfile_suffix=outfile_suffix)
made_plots[plot_func.plot_type] = fig
return made_plots
1 change: 0 additions & 1 deletion rubin_sim/maf/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .night_pointing_plotter import *
from .oned_plotters import *
from .perceptual_rainbow import *
from .plot_bundle import *
from .plot_handler import *
from .spatial_plotters import *
from .special_plotters import *
Expand Down
37 changes: 22 additions & 15 deletions rubin_sim/maf/plots/hg_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,21 +305,25 @@ def _map_colors(self, values): # pylint: disable=invalid-name, no-self-use

return colors, color_mappable

def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
def __call__(self, metric_value, slicer, user_plot_dict, fig=None):
"""Restructure the metric data to use, and build the figure.
Parameters
----------
metric_value : `numpy.ndarray`
Metric values
slicer : `rubin_sim.maf.slicers.baseSlicer.BaseSlicer`
must have "mjd" and "duration" slice points, in units
of days and seconds, respectively.
user_plot_dict : `dict`
Plotting parameters
fignum : `int`
matplotlib figure number
metric_value : `numpy.ma.MaskedArray`
The metric values from the bundle.
slicer : `rubin_sim.maf.slicers.TwoDSlicer`
The slicer.
user_plot_dict: `dict`
Dictionary of plot parameters set by user
(overrides default values).
fig : `matplotlib.figure.Figure`
Matplotlib figure number to use. Default = None, starts new figure.
Returns
-------
fig : `matplotlib.figure.Figure`
Figure with the plot.
"""
# Highest level driver for the plotter.
# Prepares data structures and figure-wide elements
Expand All @@ -344,7 +348,8 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
self.plot_dict.update(user_plot_dict)

# Generate the figure
fig = plt.figure(fignum, figsize=self.plot_dict["figsize"])
if fig is None:
fig = plt.figure(figsize=self.plot_dict["figsize"])

# Add the plots
color_mappable, axes = self._plot(fig, intervals)
Expand All @@ -355,7 +360,9 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):

# add the legend, if requested
if self.plot_dict["legend"] and len(self.color_map) > 0:
self._add_figure_legend(fig, axes)
fig = self._add_figure_legend(fig, axes)

return fig

def _add_figure_legend(self, fig, axes):
"""Creates and adds the figure legend.
Expand All @@ -370,8 +377,8 @@ def _add_figure_legend(self, fig, axes):
Returns
-------
figure_number : `int`
The matplotlib figure number
fig : `matplotlib.figure.Figure`
The matplotlib figure
"""
# Creates and adds the figure legend. This method follows two
# stages: first, it adds entries for each element in the color
Expand Down Expand Up @@ -423,7 +430,7 @@ def _add_figure_legend(self, fig, axes):
bbox_to_anchor=self.plot_dict["legend_bbox_to_anchor"],
)

return fig.number
return fig

def _add_colorbar(self, fig, color_mappable, axes): # pylint: disable=invalid-name, no-self-use
"""Add a colorbar.
Expand Down
17 changes: 10 additions & 7 deletions rubin_sim/maf/plots/hourglass_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self):
"title": None,
"xlabel": "Night - min(Night)",
"ylabel": "Hours from local midnight",
"figsize": None,
}
self.filter2color = {
"u": "purple",
Expand All @@ -24,20 +25,20 @@ def __init__(self):
"y": "red",
}

def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
def __call__(self, metric_value, slicer, user_plot_dict, fig=None):
"""
Generate the hourglass plot
"""
if slicer.slicer_name != "HourglassSlicer":
raise ValueError("HourglassPlot is for use with hourglass slicers")

fig = plt.figure(fignum)
ax = fig.add_subplot(111)

plot_dict = {}
plot_dict.update(self.default_plot_dict)
plot_dict.update(user_plot_dict)

if fig is None:
fig = plt.figure(figsize=plot_dict["figsize"])
ax = fig.add_subplot(111)
pernight = metric_value[0]["pernight"]
perfilter = metric_value[0]["perfilter"]

Expand All @@ -57,9 +58,11 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
color=self.filter2color[key],
transform=ax.transAxes,
)
# ax.plot(pernight['mjd'] - dmin, (pernight['twi6_rise'] - pernight['midnight']) * 24.,
# ax.plot(pernight['mjd'] - dmin,
# (pernight['twi6_rise'] - pernight['midnight']) * 24.,
# 'blue', label=r'6$^\circ$ twilight')
# ax.plot(pernight['mjd'] - dmin, (pernight['twi6_set'] - pernight['midnight']) * 24.,
# ax.plot(pernight['mjd'] - dmin,
# (pernight['twi6_set'] - pernight['midnight']) * 24.,
# 'blue')
ax.plot(
pernight["mjd"] - dmin,
Expand Down Expand Up @@ -114,4 +117,4 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
)
ax.axhline((pernight["twi18_set"] - pernight["midnight"]) * 24.0, color="red")

return fig.number
return fig
43 changes: 26 additions & 17 deletions rubin_sim/maf/plots/mo_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@

from .plot_handler import BasePlotter

# mag_sun = -27.1 # apparent r band magnitude of the sun. this sets the band for the magnitude limit.
# see http://www.ucolick.org/~cnaw/sun.html for apparent magnitudes in other bands.
mag_sun = -26.74 # apparent V band magnitude of the Sun (our H mags translate to V band)
# mag_sun = -27.1
# apparent r band magnitude of the sun.
# this sets the band for the magnitude limit.
# see http://www.ucolick.org/~cnaw/sun.html for apparent mags in other bands.
mag_sun = -26.74
# apparent V band magnitude of the Sun (our H mags translate to V band)
km_per_au = 1.496e8
m_per_km = 1000

Expand Down Expand Up @@ -37,7 +40,7 @@ def __init__(self):
}
self.min_hrange = 1.0

def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
def __call__(self, metric_value, slicer, user_plot_dict, fig=None):
if "linestyle" not in user_plot_dict:
user_plot_dict["linestyle"] = "-"
plot_dict = {}
Expand All @@ -49,11 +52,13 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
reduce_func = np.mean
if hvals.shape[0] == 1:
# We have a simple set of values to plot against H.
# This may be due to running a summary metric, such as completeness.
# This may be due to running a summary metric,
# such as completeness.
m_vals = metric_value[0].filled()
elif len(hvals) == slicer.shape[1]:
# Using cloned H distribution.
# Apply 'np_reduce' method directly to metric values, and plot at matching H values.
# Apply 'np_reduce' method directly to metric values,
# and plot at matching H values.
m_vals = reduce_func(metric_value.filled(), axis=0)
else:
# Probably each object has its own H value.
Expand All @@ -67,7 +72,8 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
nbins = 30
stepsize = hrange / float(nbins)
bins = np.arange(min_h, min_h + hrange + stepsize / 2.0, stepsize)
# In each bin of H, calculate the 'np_reduce' value of the corresponding metric_values.
# In each bin of H, calculate the 'np_reduce' value of the
# corresponding metric_values.
inds = np.digitize(hvals, bins)
inds = inds - 1
m_vals = np.zeros(len(bins), float)
Expand All @@ -79,7 +85,8 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
m_vals[i] = reduce_func(match.filled())
hvals = bins
# Plot the values.
fig = plt.figure(fignum, figsize=plot_dict["figsize"])
if fig is None:
fig = plt.figure(figsize=plot_dict["figsize"])
ax = plt.gca()
ax.plot(
hvals,
Expand Down Expand Up @@ -130,7 +137,7 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
plt.xlabel(plot_dict["xlabel"])
plt.ylabel(plot_dict["ylabel"])
plt.tight_layout()
return fig.number
return fig


class MetricVsOrbit(BasePlotter):
Expand Down Expand Up @@ -159,11 +166,12 @@ def __init__(self, xaxis="q", yaxis="e"):
"figsize": None,
}

def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
def __call__(self, metric_value, slicer, user_plot_dict, fig=None):
plot_dict = {}
plot_dict.update(self.default_plot_dict)
plot_dict.update(user_plot_dict)
fig = plt.figure(fignum, figsize=plot_dict["figsize"])
if fig is None:
fig = plt.figure(figsize=plot_dict["figsize"])
xvals = slicer.slice_points["orbits"][plot_dict["xaxis"]]
yvals = slicer.slice_points["orbits"][plot_dict["yaxis"]]
# Set x/y bins.
Expand Down Expand Up @@ -249,13 +257,13 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
plt.title(plot_dict["title"])
plt.xlabel(plot_dict["xlabel"])
plt.ylabel(plot_dict["ylabel"])
return fig.number
return fig


class MetricVsOrbitPoints(BasePlotter):
"""
Plot metric values (at a particular H value) as function of orbital parameters,
using points for each metric value.
Plot metric values (at a particular H value) as function
of orbital parameters, using points for each metric value.
"""

def __init__(self, xaxis="q", yaxis="e"):
Expand All @@ -276,11 +284,12 @@ def __init__(self, xaxis="q", yaxis="e"):
"figsize": None,
}

def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
def __call__(self, metric_value, slicer, user_plot_dict, fig=None):
plot_dict = {}
plot_dict.update(self.default_plot_dict)
plot_dict.update(user_plot_dict)
fig = plt.figure(fignum, figsize=plot_dict["figsize"])
if fig is None:
fig = plt.figure(figsize=plot_dict["figsize"])
xvals = slicer.slice_points["orbits"][plot_dict["xaxis"]]
yvals = slicer.slice_points["orbits"][plot_dict["yaxis"]]
# Identify the relevant metric_values for the Hvalue we want to plot.
Expand Down Expand Up @@ -346,4 +355,4 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
plt.title(plot_dict["title"])
plt.xlabel(plot_dict["xlabel"])
plt.ylabel(plot_dict["ylabel"])
return fig.number
return fig
Loading

0 comments on commit dcaac1c

Please sign in to comment.