Skip to content

Commit

Permalink
RFC: refactor plot callbacks registration. Erroneous calls to annotat…
Browse files Browse the repository at this point in the history
…ion methods now fail immediately and don't block rendering
  • Loading branch information
neutrinoceros committed Jun 4, 2022
1 parent b2965ce commit 2165f4c
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 87 deletions.
29 changes: 26 additions & 3 deletions yt/visualization/plot_modifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import re
import warnings
from abc import ABC, abstractmethod
from functools import wraps
from functools import update_wrapper, wraps
from numbers import Integral, Number
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Type, Union

import matplotlib
import numpy as np
Expand Down Expand Up @@ -33,8 +33,9 @@
from yt.visualization._commons import _swap_arg_pair_order, _swap_axes_extents
from yt.visualization.base_plot_types import CallbackWrapper
from yt.visualization.image_writer import apply_colormap
from yt.visualization.plot_window import PWViewerMPL

callback_registry = {}
callback_registry: Dict[str, Type["PlotCallback"]] = {}


def _verify_geometry(func):
Expand Down Expand Up @@ -83,7 +84,29 @@ class PlotCallback(ABC):
def __init_subclass__(cls, *args, **kwargs):
if inspect.isabstract(cls):
return

# register class
callback_registry[cls.__name__] = cls

# create a PWViewerMPL method by wrapping __init__
if cls.__init__.__doc__ is None:
# allow definition the docstring at the class level instead of __init__
cls.__init__.__doc__ = cls.__doc__

method_name = "annotate_" + cls._type_name

def closure(self, *args, **kwargs):
self._callbacks.append(cls(*args, **kwargs))

update_wrapper(
wrapper=closure,
wrapped=cls.__init__,
assigned=("__annotations__", "__doc__"),
)

closure.__name__ = method_name
setattr(PWViewerMPL, method_name, closure)

cls.__call__ = _verify_geometry(cls.__call__)

@abstractmethod
Expand Down
96 changes: 12 additions & 84 deletions yt/visualization/plot_window.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import abc
from collections import defaultdict
from functools import wraps
from numbers import Number
from typing import List, Optional, Type, Union

Expand All @@ -27,21 +26,20 @@
YTInvalidFieldType,
YTPlotCallbackError,
YTUnitNotRecognized,
YTUnsupportedPlotCallback,
)
from yt.utilities.math_utils import ortho_find
from yt.utilities.orientation import Orientation
from yt.visualization.base_plot_types import CallbackWrapper

from ._commons import MPL_VERSION, _swap_axes_extents
from .base_plot_types import CallbackWrapper, ImagePlotMPL
from .base_plot_types import ImagePlotMPL
from .fixed_resolution import (
FixedResolutionBuffer,
OffAxisProjectionFixedResolutionBuffer,
)
from .geo_plot_utils import get_mpl_transform
from .plot_container import (
ImagePlotContainer,
apply_callback,
get_log_minorticks,
get_symlog_minorticks,
invalidate_data,
Expand All @@ -51,7 +49,6 @@
log_transform,
symlog_transform,
)
from .plot_modifications import callback_registry

import sys # isort: skip

Expand Down Expand Up @@ -270,7 +267,6 @@ def __init__(
# Access the dictionary to force the key to be created
self._units_config[field]

self.setup_callbacks()
self._setup_plots()

def __iter__(self):
Expand Down Expand Up @@ -922,6 +918,11 @@ def __init__(self, *args, **kwargs):
self._frb: Optional[FixedResolutionBuffer] = None
PlotWindow.__init__(self, *args, **kwargs)

# import type here to avoid import cycles
from yt.visualization.plot_modifications import PlotCallback

self._callbacks: List[PlotCallback] = []

@property
def _data_valid(self) -> bool:
return self._frb is not None and self._frb._data_valid
Expand Down Expand Up @@ -1387,89 +1388,17 @@ def _setup_plots(self):

self._plot_valid = True

def setup_callbacks(self):
ignored = ["PlotCallback"]
if self._plot_type.startswith("OffAxis"):
ignored += [
"ParticleCallback",
"ClumpContourCallback",
"GridBoundaryCallback",
]
if self._plot_type == "OffAxisProjection":
ignored += [
"VelocityCallback",
"MagFieldCallback",
"QuiverCallback",
"CuttingQuiverCallback",
"StreamlineCallback",
"LineIntegralConvolutionCallback",
]
elif self._plot_type == "Particle":
ignored += [
"HopCirclesCallback",
"HopParticleCallback",
"ClumpContourCallback",
"GridBoundaryCallback",
"VelocityCallback",
"MagFieldCallback",
"QuiverCallback",
"CuttingQuiverCallback",
"StreamlineCallback",
"ContourCallback",
]

def missing_callback_closure(cbname):
def _(*args, **kwargs):
raise YTUnsupportedPlotCallback(
callback=cbname, plot_type=self._plot_type
)

return _

for key in callback_registry:
cbname = callback_registry[key]._type_name

if key in ignored:
self.__dict__["annotate_" + cbname] = missing_callback_closure(cbname)
continue

# We need to wrap to create a closure so that
# CallbackMaker is bound to the wrapped method.
def closure():
CallbackMaker = callback_registry[key]

@wraps(CallbackMaker)
def method(*args, **kwargs):
# We need to also do it here as "invalidate_plot"
# and "apply_callback" require the functions'
# __name__ in order to work properly
@wraps(CallbackMaker)
def cb(self, *a, **kwa):
# We construct the callback method
# skipping self
return CallbackMaker(*a, **kwa)

# Create callback
cb = invalidate_plot(apply_callback(cb))

return cb(self, *args, **kwargs)

return method

self.__dict__["annotate_" + cbname] = closure()

@invalidate_plot
def clear_annotations(self, index=None):
def clear_annotations(self, index: Optional[int] = None):
"""
Clear callbacks from the plot. If index is not set, clear all
callbacks. If index is set, clear that index (ie 0 is the first one
created, 1 is the 2nd one created, -1 is the last one created, etc.)
"""
if index is None:
self._callbacks = []
self._callbacks.clear()
else:
del self._callbacks[index]
self.setup_callbacks()
self._callbacks.pop(index)
return self

def list_annotations(self):
Expand All @@ -1484,7 +1413,7 @@ def list_annotations(self):
def run_callbacks(self):
for f in self.fields:
keys = self.frb.keys()
for name, (args, kwargs) in self._callbacks:
for callback in self._callbacks:
# need to pass _swap_axes and adjust all the callbacks
cbw = CallbackWrapper(
self,
Expand All @@ -1494,8 +1423,7 @@ def run_callbacks(self):
self._font_properties,
self._font_color,
)
CallbackMaker = callback_registry[name]
callback = CallbackMaker(*args[1:], **kwargs)

try:
callback(cbw)
except YTDataTypeUnsupported as e:
Expand Down
27 changes: 27 additions & 0 deletions yt/visualization/tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import inspect
import shutil
import tempfile

Expand Down Expand Up @@ -71,6 +72,32 @@ def _cleanup_fname():
shutil.rmtree(tmpdir)


def test_method_signature():
ds = fake_amr_ds(
fields=[("gas", "density"), ("gas", "velocity_x"), ("gas", "velocity_y")],
units=["g/cm**3", "m/s", "m/s"],
)
p = SlicePlot(ds, "z", ("gas", "density"))
sig = inspect.signature(p.annotate_velocity)
# checking the first few arguments rather than the whole signature
# we just want to validate that method wrapping works
assert list(sig.parameters.keys())[:4] == [
"factor",
"scale",
"scale_units",
"normalize",
]


def test_init_signature_error_callback():
ds = fake_amr_ds(
fields=[("gas", "density"), ("gas", "velocity_x"), ("gas", "velocity_y")],
units=["g/cm**3", "m/s", "m/s"],
)
p = SlicePlot(ds, "z", ("gas", "density"))
assert_raises(TypeError, p.annotate_velocity, {"invalid_argument": 1})


def check_axis_manipulation(plot_obj, prefix):
# convenience function for testing functionality of axis manipulation
# callbacks. Can use in any of the other test functions.
Expand Down

0 comments on commit 2165f4c

Please sign in to comment.