Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: refactor plot callbacks registration #3957

Merged
merged 8 commits into from
Jun 14, 2022
70 changes: 29 additions & 41 deletions yt/visualization/fixed_resolution.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import weakref
from functools import wraps
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional

import numpy as np

from yt._maintenance.deprecation import issue_deprecation_warning
from yt.data_objects.image_array import ImageArray
from yt.frontends.ytdata.utilities import save_as_dataset
from yt.funcs import get_output_filename, iter_fields, mylog
Expand All @@ -15,13 +15,11 @@
from yt.utilities.lib.pixelization_routines import pixelize_cylinder
from yt.utilities.on_demand_imports import _h5py as h5py

from .fixed_resolution_filters import (
FixedResolutionBufferFilter,
apply_filter,
filter_registry,
)
from .volume_rendering.api import off_axis_projection

if TYPE_CHECKING:
from yt.visualization.fixed_resolution_filters import FixedResolutionBufferFilter


class FixedResolutionBuffer:
r"""
Expand Down Expand Up @@ -105,19 +103,30 @@ def __init__(
antialias=True,
periodic=False,
*,
filters: Optional[List[FixedResolutionBufferFilter]] = None,
filters: Optional[List["FixedResolutionBufferFilter"]] = None,
):
self.data_source = data_source
self.ds = data_source.ds
self.bounds = bounds
self.buff_size = (int(buff_size[0]), int(buff_size[1]))
self.antialias = antialias
self.data: Dict[str, np.ndarray] = {}
self._filters = []
self.axis = data_source.axis
self.periodic = periodic
self._data_valid = False
self._filters = filters if filters is not None else []

# import type here to avoid import cycles
# note that this import statement is actually crucial at runtime:
# the filter methods for the present class are defined only when
# fixed_resolution_filters is imported, so we need to guarantee
# that it happens no later than instanciation
from yt.visualization.fixed_resolution_filters import (
FixedResolutionBufferFilter,
)

self._filters: List[FixedResolutionBufferFilter] = (
filters if filters is not None else []
)

ds = getattr(data_source, "ds", None)
if ds is not None:
Expand All @@ -134,8 +143,6 @@ def __init__(
self._period = (DD[xax], DD[yax])
self._edges = ((DLE[xax], DRE[xax]), (DLE[yax], DRE[yax]))

self.setup_filters()

def keys(self):
return self.data.keys()

Expand Down Expand Up @@ -166,8 +173,7 @@ def __getitem__(self, item):
int(self.antialias),
)

for name, (args, kwargs) in self._filters:
buff = filter_registry[name](*args, **kwargs).apply(buff)
buff = self._apply_filters(buff)

# FIXME FIXME FIXME we shouldn't need to do this for projections
# but that will require fixing data object access for particle
Expand All @@ -186,6 +192,11 @@ def __getitem__(self, item):
self._data_valid = True
return self.data[item]

def _apply_filters(self, buffer: np.ndarray) -> np.ndarray:
for f in self._filters:
buffer = f(buffer)
return buffer

def __setitem__(self, item, val):
self.data[item] = val

Expand Down Expand Up @@ -522,33 +533,10 @@ def limits(self):
return rv

def setup_filters(self):
for key in filter_registry:
filtername = filter_registry[key]._filter_name

# We need to wrap to create a closure so that
# FilterMaker is bound to the wrapped method.
def closure():
FilterMaker = filter_registry[key]

@wraps(FilterMaker)
def method(*args, **kwargs):
# We need to also do it here as "invalidate_plot"
# and "apply_callback" require the functions'
# __name__ in order to work properly
@wraps(FilterMaker)
def cb(self, *a, **kwa):
# We construct the callback method
# skipping self
return FilterMaker(*a, **kwa)

# Create callback
cb = apply_filter(cb)

return cb(self, *args, **kwargs)

return method

self.__dict__["apply_" + filtername] = closure()
issue_deprecation_warning(
"The FixedResolutionBuffer.setup_filters method is now a no-op. ",
since="4.1.0",
)


class CylindricalFixedResolutionBuffer(FixedResolutionBuffer):
Expand Down
41 changes: 36 additions & 5 deletions yt/visualization/fixed_resolution_filters.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
from abc import ABC, abstractmethod
from functools import wraps
from functools import update_wrapper, wraps

import numpy as np

filter_registry = {}
from yt._maintenance.deprecation import issue_deprecation_warning
from yt.visualization.fixed_resolution import FixedResolutionBuffer


def apply_filter(f):
issue_deprecation_warning(
"The apply_filter decorator is not used in yt any more and "
"will be removed in a future version. "
"Please do not use it.",
since="4.1",
)

@wraps(f)
def newfunc(self, *args, **kwargs):
self._filters.append((f.__name__, (args, kwargs)))
Expand All @@ -25,18 +33,41 @@ class FixedResolutionBufferFilter(ABC):
"""

def __init_subclass__(cls, *args, **kwargs):
super().__init_subclass__(*args, **kwargs)
filter_registry[cls.__name__] = cls

if cls.__init__.__doc__ is None:
# allow docstring definition at the class level instead of __init__
cls.__init__.__doc__ = cls.__doc__

# add a method to FixedResolutionBuffer
method_name = "apply_" + cls._filter_name

def closure(self, *args, **kwargs):
self._filters.append(cls(*args, **kwargs))
self._data_valid = False
return self

update_wrapper(
wrapper=closure,
wrapped=cls.__init__,
assigned=("__annotations__", "__doc__"),
)

closure.__name__ = method_name
setattr(FixedResolutionBuffer, method_name, closure)

@abstractmethod
def __init__(self, *args, **kwargs):
"""This method is required in subclasses, but the signature is arbitrary"""
pass

@abstractmethod
def apply(self, buff):
def apply(self, buff: np.ndarray) -> np.ndarray:
pass

def __call__(self, buff: np.ndarray) -> np.ndarray:
# alias to apply
return self.apply(buff)


class FixedResolutionBufferGaussBeamFilter(FixedResolutionBufferFilter):

Expand Down
8 changes: 8 additions & 0 deletions yt/visualization/plot_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from matplotlib.font_manager import FontProperties
from more_itertools.more import always_iterable

from yt._maintenance.deprecation import issue_deprecation_warning
from yt.config import ytcfg
from yt.data_objects.time_series import DatasetSeries
from yt.funcs import dictWithFactory, ensure_dir, is_sequence, iter_fields, mylog
Expand All @@ -36,6 +37,13 @@


def apply_callback(f):
issue_deprecation_warning(
"The apply_callback decorator is not used in yt any more and "
"will be removed in a future version. "
"Please do not use it.",
since="4.1",
)

@wraps(f)
def newfunc(*args, **kwargs):
args[0]._callbacks.append((f.__name__, (args, kwargs)))
Expand Down
Loading