Skip to content

Commit

Permalink
RFC: refactor plot filters registration
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Jun 4, 2022
1 parent 4d0f203 commit 30ee415
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 46 deletions.
66 changes: 25 additions & 41 deletions yt/visualization/fixed_resolution.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import weakref
from functools import wraps
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional

import numpy as np

from yt._maintenance.deprecation import issue_deprecation_warning
from yt.data_objects.image_array import ImageArray
from yt.frontends.ytdata.utilities import save_as_dataset
from yt.funcs import get_output_filename, iter_fields, mylog
Expand All @@ -15,13 +15,11 @@
from yt.utilities.lib.pixelization_routines import pixelize_cylinder
from yt.utilities.on_demand_imports import _h5py as h5py

from .fixed_resolution_filters import (
FixedResolutionBufferFilter,
apply_filter,
filter_registry,
)
from .volume_rendering.api import off_axis_projection

if TYPE_CHECKING:
from yt.visualization.fixed_resolution_filters import FixedResolutionBufferFilter


class FixedResolutionBuffer:
r"""
Expand Down Expand Up @@ -105,19 +103,26 @@ def __init__(
antialias=True,
periodic=False,
*,
filters: Optional[List[FixedResolutionBufferFilter]] = None,
filters: Optional[List["FixedResolutionBufferFilter"]] = None,
):
self.data_source = data_source
self.ds = data_source.ds
self.bounds = bounds
self.buff_size = (int(buff_size[0]), int(buff_size[1]))
self.antialias = antialias
self.data: Dict[str, np.ndarray] = {}
self._filters = []
self.axis = data_source.axis
self.periodic = periodic
self._data_valid = False
self._filters = filters if filters is not None else []

# import type here to avoid import cycles
from yt.visualization.fixed_resolution_filters import (
FixedResolutionBufferFilter,
)

self._filters: List[FixedResolutionBufferFilter] = (
filters if filters is not None else []
)

ds = getattr(data_source, "ds", None)
if ds is not None:
Expand All @@ -134,8 +139,6 @@ def __init__(
self._period = (DD[xax], DD[yax])
self._edges = ((DLE[xax], DRE[xax]), (DLE[yax], DRE[yax]))

self.setup_filters()

def keys(self):
return self.data.keys()

Expand Down Expand Up @@ -166,8 +169,7 @@ def __getitem__(self, item):
int(self.antialias),
)

for name, (args, kwargs) in self._filters:
buff = filter_registry[name](*args, **kwargs).apply(buff)
buff = self._apply_filters(buff)

# FIXME FIXME FIXME we shouldn't need to do this for projections
# but that will require fixing data object access for particle
Expand All @@ -186,6 +188,11 @@ def __getitem__(self, item):
self._data_valid = True
return self.data[item]

def _apply_filters(self, buffer: np.ndarray) -> np.ndarray:
for f in self._filters:
buffer = f(buffer)
return buffer

def __setitem__(self, item, val):
self.data[item] = val

Expand Down Expand Up @@ -522,33 +529,10 @@ def limits(self):
return rv

def setup_filters(self):
for key in filter_registry:
filtername = filter_registry[key]._filter_name

# We need to wrap to create a closure so that
# FilterMaker is bound to the wrapped method.
def closure():
FilterMaker = filter_registry[key]

@wraps(FilterMaker)
def method(*args, **kwargs):
# We need to also do it here as "invalidate_plot"
# requires the functions'
# __name__ in order to work properly
@wraps(FilterMaker)
def cb(self, *a, **kwa):
# We construct the callback method
# skipping self
return FilterMaker(*a, **kwa)

# Create callback
cb = apply_filter(cb)

return cb(self, *args, **kwargs)

return method

self.__dict__["apply_" + filtername] = closure()
issue_deprecation_warning(
"The FixedResolutionBuffer.setup_filters method is now a no-op. ",
since="4.1.0",
)


class CylindricalFixedResolutionBuffer(FixedResolutionBuffer):
Expand Down
28 changes: 23 additions & 5 deletions yt/visualization/fixed_resolution_filters.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from abc import ABC, abstractmethod
from functools import wraps
from functools import update_wrapper, wraps

import numpy as np

filter_registry = {}
from yt.visualization.fixed_resolution import FixedResolutionBuffer


def apply_filter(f):
Expand All @@ -25,18 +25,36 @@ class FixedResolutionBufferFilter(ABC):
"""

def __init_subclass__(cls, *args, **kwargs):
super().__init_subclass__(*args, **kwargs)
filter_registry[cls.__name__] = cls
# add a method to FixedResolutionBuffer
method_name = "apply_" + cls._filter_name

def closure(self, *args, **kwargs):
self._filters.append(cls(*args, **kwargs))
self._data_valid = False
return self

update_wrapper(
wrapper=closure,
wrapped=cls.__init__,
assigned=("__annotations__", "__doc__"),
)

closure.__name__ = method_name
setattr(FixedResolutionBuffer, method_name, closure)

@abstractmethod
def __init__(self, *args, **kwargs):
"""This method is required in subclasses, but the signature is arbitrary"""
pass

@abstractmethod
def apply(self, buff):
def apply(self, buff: np.ndarray) -> np.ndarray:
pass

def __call__(self, buff: np.ndarray) -> np.ndarray:
# alias to apply
return self.apply(buff)


class FixedResolutionBufferGaussBeamFilter(FixedResolutionBufferFilter):

Expand Down

0 comments on commit 30ee415

Please sign in to comment.