Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OPSIM-1141: Plotters return Figure objects, not fig.number #400

Merged
merged 6 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ line_length = 110
exclude = [
"__init__.py",
]
line-length = 110
target-version = "py311"

[tool.ruff.lint]
ignore = [
"N802",
"N803",
Expand All @@ -124,21 +128,19 @@ ignore = [
"D400",
"E712",
]
line-length = 110
select = [
"E", # pycodestyle
"F", # pyflakes
"N", # pep8-naming
"W", # pycodestyle
]
target-version = "py311"
extend-select = [
"RUF100", # Warn about unused noqa
]

[tool.ruff.pycodestyle]
[tool.ruff.lint.pycodestyle]
max-doc-length = 79

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "numpy"

2 changes: 1 addition & 1 deletion rubin_sim/maf/batches/moving_objects_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ def _codePlot(key):
display_dict["subgroup"] = "Completeness over time"
display_dict["caption"] = "Completeness over time, for H values indicated in legend."
ph.save_fig(
fig.number,
fig,
f"{figroot}_CompletenessOverTime",
"Combo",
"CompletenessOverTime",
Expand Down
12 changes: 6 additions & 6 deletions rubin_sim/maf/metric_bundles/metric_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,9 +790,9 @@ def plot(self, plot_handler=None, plot_func=None, outfile_suffix=None, savefig=F
Returns
-------
made_plots : `dict`
Dictionary of plot_type:figure number key/value pairs,
Dictionary of plot_type:figure key/value pairs,
indicating what plots were created
and what matplotlib figure numbers were used.
and what matplotlib figures were used.
"""
# Generate a plot_handler if none was set.
if plot_handler is None:
Expand All @@ -808,10 +808,10 @@ def plot(self, plot_handler=None, plot_func=None, outfile_suffix=None, savefig=F
plot_handler.set_plot_dicts(plot_dicts=[self.plot_dict], reset=True)
made_plots = {}
if plot_func is not None:
fignum = plot_handler.plot(plot_func, outfile_suffix=outfile_suffix)
made_plots[plot_func.plotType] = fignum
fig = plot_handler.plot(plot_func, outfile_suffix=outfile_suffix)
made_plots[plot_func.plotType] = fig
else:
for plot_func in self.plot_funcs:
fignum = plot_handler.plot(plot_func, outfile_suffix=outfile_suffix)
made_plots[plot_func.plot_type] = fignum
fig = plot_handler.plot(plot_func, outfile_suffix=outfile_suffix)
made_plots[plot_func.plot_type] = fig
return made_plots
1 change: 0 additions & 1 deletion rubin_sim/maf/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .night_pointing_plotter import *
from .oned_plotters import *
from .perceptual_rainbow import *
from .plot_bundle import *
from .plot_handler import *
from .spatial_plotters import *
from .special_plotters import *
Expand Down
37 changes: 22 additions & 15 deletions rubin_sim/maf/plots/hg_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,21 +305,25 @@ def _map_colors(self, values): # pylint: disable=invalid-name, no-self-use

return colors, color_mappable

def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
def __call__(self, metric_value, slicer, user_plot_dict, fig=None):
"""Restructure the metric data to use, and build the figure.

Parameters
----------
metric_value : `numpy.ndarray`
Metric values
slicer : `rubin_sim.maf.slicers.baseSlicer.BaseSlicer`
must have "mjd" and "duration" slice points, in units
of days and seconds, respectively.
user_plot_dict : `dict`
Plotting parameters
fignum : `int`
matplotlib figure number
metric_value : `numpy.ma.MaskedArray`
The metric values from the bundle.
slicer : `rubin_sim.maf.slicers.TwoDSlicer`
The slicer.
user_plot_dict: `dict`
Dictionary of plot parameters set by user
(overrides default values).
fig : `matplotlib.figure.Figure`
Matplotlib figure number to use. Default = None, starts new figure.

Returns
-------
fig : `matplotlib.figure.Figure`
Figure with the plot.
"""
# Highest level driver for the plotter.
# Prepares data structures and figure-wide elements
Expand All @@ -344,7 +348,8 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
self.plot_dict.update(user_plot_dict)

# Generate the figure
fig = plt.figure(fignum, figsize=self.plot_dict["figsize"])
if fig is None:
fig = plt.figure(figsize=self.plot_dict["figsize"])

# Add the plots
color_mappable, axes = self._plot(fig, intervals)
Expand All @@ -355,7 +360,9 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):

# add the legend, if requested
if self.plot_dict["legend"] and len(self.color_map) > 0:
self._add_figure_legend(fig, axes)
fig = self._add_figure_legend(fig, axes)

return fig

def _add_figure_legend(self, fig, axes):
"""Creates and adds the figure legend.
Expand All @@ -370,8 +377,8 @@ def _add_figure_legend(self, fig, axes):

Returns
-------
figure_number : `int`
The matplotlib figure number
fig : `matplotlib.figure.Figure`
The matplotlib figure
"""
# Creates and adds the figure legend. This method follows two
# stages: first, it adds entries for each element in the color
Expand Down Expand Up @@ -423,7 +430,7 @@ def _add_figure_legend(self, fig, axes):
bbox_to_anchor=self.plot_dict["legend_bbox_to_anchor"],
)

return fig.number
return fig

def _add_colorbar(self, fig, color_mappable, axes): # pylint: disable=invalid-name, no-self-use
"""Add a colorbar.
Expand Down
17 changes: 10 additions & 7 deletions rubin_sim/maf/plots/hourglass_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self):
"title": None,
"xlabel": "Night - min(Night)",
"ylabel": "Hours from local midnight",
"figsize": None,
}
self.filter2color = {
"u": "purple",
Expand All @@ -24,20 +25,20 @@ def __init__(self):
"y": "red",
}

def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
def __call__(self, metric_value, slicer, user_plot_dict, fig=None):
"""
Generate the hourglass plot
"""
if slicer.slicer_name != "HourglassSlicer":
raise ValueError("HourglassPlot is for use with hourglass slicers")

fig = plt.figure(fignum)
ax = fig.add_subplot(111)

plot_dict = {}
plot_dict.update(self.default_plot_dict)
plot_dict.update(user_plot_dict)

if fig is None:
fig = plt.figure(figsize=plot_dict["figsize"])
ax = fig.add_subplot(111)
pernight = metric_value[0]["pernight"]
perfilter = metric_value[0]["perfilter"]

Expand All @@ -57,9 +58,11 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
color=self.filter2color[key],
transform=ax.transAxes,
)
# ax.plot(pernight['mjd'] - dmin, (pernight['twi6_rise'] - pernight['midnight']) * 24.,
# ax.plot(pernight['mjd'] - dmin,
# (pernight['twi6_rise'] - pernight['midnight']) * 24.,
# 'blue', label=r'6$^\circ$ twilight')
# ax.plot(pernight['mjd'] - dmin, (pernight['twi6_set'] - pernight['midnight']) * 24.,
# ax.plot(pernight['mjd'] - dmin,
# (pernight['twi6_set'] - pernight['midnight']) * 24.,
# 'blue')
ax.plot(
pernight["mjd"] - dmin,
Expand Down Expand Up @@ -114,4 +117,4 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
)
ax.axhline((pernight["twi18_set"] - pernight["midnight"]) * 24.0, color="red")

return fig.number
return fig
43 changes: 26 additions & 17 deletions rubin_sim/maf/plots/mo_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@

from .plot_handler import BasePlotter

# mag_sun = -27.1 # apparent r band magnitude of the sun. this sets the band for the magnitude limit.
# see http://www.ucolick.org/~cnaw/sun.html for apparent magnitudes in other bands.
mag_sun = -26.74 # apparent V band magnitude of the Sun (our H mags translate to V band)
# mag_sun = -27.1
# apparent r band magnitude of the sun.
# this sets the band for the magnitude limit.
# see http://www.ucolick.org/~cnaw/sun.html for apparent mags in other bands.
mag_sun = -26.74
# apparent V band magnitude of the Sun (our H mags translate to V band)
km_per_au = 1.496e8
m_per_km = 1000

Expand Down Expand Up @@ -37,7 +40,7 @@ def __init__(self):
}
self.min_hrange = 1.0

def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
def __call__(self, metric_value, slicer, user_plot_dict, fig=None):
if "linestyle" not in user_plot_dict:
user_plot_dict["linestyle"] = "-"
plot_dict = {}
Expand All @@ -49,11 +52,13 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
reduce_func = np.mean
if hvals.shape[0] == 1:
# We have a simple set of values to plot against H.
# This may be due to running a summary metric, such as completeness.
# This may be due to running a summary metric,
# such as completeness.
m_vals = metric_value[0].filled()
elif len(hvals) == slicer.shape[1]:
# Using cloned H distribution.
# Apply 'np_reduce' method directly to metric values, and plot at matching H values.
# Apply 'np_reduce' method directly to metric values,
# and plot at matching H values.
m_vals = reduce_func(metric_value.filled(), axis=0)
else:
# Probably each object has its own H value.
Expand All @@ -67,7 +72,8 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
nbins = 30
stepsize = hrange / float(nbins)
bins = np.arange(min_h, min_h + hrange + stepsize / 2.0, stepsize)
# In each bin of H, calculate the 'np_reduce' value of the corresponding metric_values.
# In each bin of H, calculate the 'np_reduce' value of the
# corresponding metric_values.
inds = np.digitize(hvals, bins)
inds = inds - 1
m_vals = np.zeros(len(bins), float)
Expand All @@ -79,7 +85,8 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
m_vals[i] = reduce_func(match.filled())
hvals = bins
# Plot the values.
fig = plt.figure(fignum, figsize=plot_dict["figsize"])
if fig is None:
fig = plt.figure(figsize=plot_dict["figsize"])
ax = plt.gca()
ax.plot(
hvals,
Expand Down Expand Up @@ -130,7 +137,7 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
plt.xlabel(plot_dict["xlabel"])
plt.ylabel(plot_dict["ylabel"])
plt.tight_layout()
return fig.number
return fig


class MetricVsOrbit(BasePlotter):
Expand Down Expand Up @@ -159,11 +166,12 @@ def __init__(self, xaxis="q", yaxis="e"):
"figsize": None,
}

def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
def __call__(self, metric_value, slicer, user_plot_dict, fig=None):
plot_dict = {}
plot_dict.update(self.default_plot_dict)
plot_dict.update(user_plot_dict)
fig = plt.figure(fignum, figsize=plot_dict["figsize"])
if fig is None:
fig = plt.figure(figsize=plot_dict["figsize"])
xvals = slicer.slice_points["orbits"][plot_dict["xaxis"]]
yvals = slicer.slice_points["orbits"][plot_dict["yaxis"]]
# Set x/y bins.
Expand Down Expand Up @@ -249,13 +257,13 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
plt.title(plot_dict["title"])
plt.xlabel(plot_dict["xlabel"])
plt.ylabel(plot_dict["ylabel"])
return fig.number
return fig


class MetricVsOrbitPoints(BasePlotter):
"""
Plot metric values (at a particular H value) as function of orbital parameters,
using points for each metric value.
Plot metric values (at a particular H value) as function
of orbital parameters, using points for each metric value.
"""

def __init__(self, xaxis="q", yaxis="e"):
Expand All @@ -276,11 +284,12 @@ def __init__(self, xaxis="q", yaxis="e"):
"figsize": None,
}

def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
def __call__(self, metric_value, slicer, user_plot_dict, fig=None):
plot_dict = {}
plot_dict.update(self.default_plot_dict)
plot_dict.update(user_plot_dict)
fig = plt.figure(fignum, figsize=plot_dict["figsize"])
if fig is None:
fig = plt.figure(figsize=plot_dict["figsize"])
xvals = slicer.slice_points["orbits"][plot_dict["xaxis"]]
yvals = slicer.slice_points["orbits"][plot_dict["yaxis"]]
# Identify the relevant metric_values for the Hvalue we want to plot.
Expand Down Expand Up @@ -346,4 +355,4 @@ def __call__(self, metric_value, slicer, user_plot_dict, fignum=None):
plt.title(plot_dict["title"])
plt.xlabel(plot_dict["xlabel"])
plt.ylabel(plot_dict["ylabel"])
return fig.number
return fig
Loading
Loading