Refactor code: remove deprecated interval parameter, clean up unused classes, and improve import statements
david-zwicker committed Feb 10, 2025
1 parent 8c4246e commit e190945
Showing 12 changed files with 41 additions and 334 deletions.
54 changes: 1 addition & 53 deletions pde/fields/collection.py
@@ -886,50 +886,6 @@ def _plot_merged_image(
}
return PlotReference(ax, axes_image, parameters)

@plot_on_axes(update_method="_update_rgb_image_plot")
def _plot_rgb_image(
self,
ax,
transpose: bool = False,
vmin: float | list[float | None] | None = None,
vmax: float | list[float | None] | None = None,
**kwargs,
) -> PlotReference:
r"""Visualize fields by mapping to different color chanels in a 2d density plot.
Args:
ax (:class:`matplotlib.axes.Axes`):
Figure axes to be used for plotting.
transpose (bool):
Determines whether the transpose of the data is plotted
vmin, vmax (float, list of float):
Define the data range that the color channels cover. By default, they
cover the complete value range of the supplied data.
\**kwargs:
Additional keyword arguments that affect the image. Non-Cartesian grids
might support `performance_goal` to influence how an image is created
from raw data. Finally, remaining arguments are passed to
:func:`matplotlib.pyplot.imshow` to affect the appearance.
Returns:
:class:`PlotReference`: Instance that contains information to update the
plot with new data later.
"""
# since 2024-01-25
warnings.warn(
"`rgb_image` is deprecated in favor of `merged`", DeprecationWarning
)
return self._plot_merged_image( # type: ignore
ax=ax,
colors="rgb",
background_color="k",
projection="max",
transpose=transpose,
vmin=vmin,
vmax=vmax,
**kwargs,
)
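The removed `_plot_rgb_image` wrapper only forwarded to `_plot_merged_image` with fixed arguments (`colors="rgb"`, black background, max-projection), so callers can switch to the merged-image route directly. A minimal migration sketch, not part of the commit, assuming the public py-pde API:

```python
# Migration sketch (assumes the top-level py-pde API; not taken from this commit).
import pde

grid = pde.UnitGrid([64, 64])
fields = pde.FieldCollection(
    [pde.ScalarField.random_uniform(grid), pde.ScalarField.random_uniform(grid)]
)

# before (deprecated): fields.plot(kind="rgb")
fields.plot(kind="merged")  # maps each field to a color channel of one image
```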

def _update_plot(self, reference: list[PlotReference]) -> None:
"""Update a plot collection with the current field values.
@@ -984,7 +940,7 @@ def plot(
List of :class:`PlotReference`: Instances that contain information
to update all the plots with new data later.
"""
if kind in {"merged", "rgb", "rgb_image", "rgb-image"}:
if kind in {"merged"}:
num_panels = 1
else:
num_panels = len(self)
@@ -1024,14 +980,6 @@ def plot(
)
]

elif kind in {"rgb", "rgb_image", "rgb-image"}:
# plot a single RGB representation
reference = [
self._plot_rgb_image(
ax=axs[0], action="none", **kwargs, **subplot_args[0]
)
]

else:
# plot all the elements onto the respective axes
if isinstance(kind, str):
157 changes: 1 addition & 156 deletions pde/grids/boundaries/local.py
@@ -73,12 +73,7 @@
from ...tools.cache import cached_method
from ...tools.docstrings import fill_in_docstring
from ...tools.numba import address_as_void_pointer, jit, numba_dict
from ...tools.typing import (
AdjacentEvaluator,
FloatNumerical,
GhostCellSetter,
VirtualPointEvaluator,
)
from ...tools.typing import FloatNumerical, GhostCellSetter, VirtualPointEvaluator
from ..base import GridBase, PeriodicityError

if TYPE_CHECKING:
@@ -592,23 +587,6 @@ def make_virtual_point_evaluator(self) -> VirtualPointEvaluator:
the boundary condition.
"""

def make_adjacent_evaluator(self) -> AdjacentEvaluator:
"""Returns a function evaluating the value adjacent to a given point.
.. deprecated:: Since 2023-12-19
Returns:
function: A function with signature (arr_1d, i_point, bc_idx), where
`arr_1d` is the one-dimensional data array (the data points along the axis
perpendicular to the boundary), `i_point` is the index into this array for
the current point and bc_idx are the remaining indices of the current point,
which indicate the location on the boundary plane. The result of the
function is the data value at the adjacent point along the axis associated
with this boundary condition in the upper (lower) direction when `upper` is
True (False).
"""
raise NotImplementedError

@abstractmethod
def set_ghost_cells(self, data_full: np.ndarray, *, args=None) -> None:
"""Set the ghost cell values for this boundary.
@@ -1529,9 +1507,6 @@ def get_sparse_matrix_data(
def get_virtual_point(self, arr, idx: tuple[int, ...] | None = None) -> float:
raise NotImplementedError

def make_adjacent_evaluator(self) -> AdjacentEvaluator:
raise NotImplementedError

def _get_value_cell_index(self, with_ghost_cells: bool) -> int:
if self.value_cell is None:
# pick adjacent cell by default
@@ -2159,69 +2134,6 @@ def virtual_point(

return virtual_point # type: ignore

def make_adjacent_evaluator(self) -> AdjacentEvaluator:
# method deprecated since 2023-12-19
warnings.warn("`make_adjacent_evaluator` is deprecated", DeprecationWarning)
# get values distinguishing upper from lower boundary
if self.upper:
i_bndry = self.grid.shape[self.axis] - 1
i_dx = 1
else:
i_bndry = 0
i_dx = -1

if self.homogeneous:
# the boundary condition does not depend on space

# calculate necessary constants
const, factor, index = self.get_virtual_point_data(compiled=True)
zeros = np.zeros(self._shape_tensor)
ones = np.ones(self._shape_tensor)

@register_jitable(inline="always")
def adjacent_point(
arr_1d: np.ndarray, i_point: int, bc_idx: tuple[int, ...]
) -> FloatNumerical:
"""Evaluate the value adjacent to the current point."""
# determine the parameters for evaluating adjacent point. Note
# that defining the variables c and f for the interior points
# seems needless, but it turns out that this results in a 10x
# faster function (because of branch prediction?).
if i_point == i_bndry:
c, f, i = const(), factor(), index
else:
c, f, i = zeros, ones, i_point + i_dx # INTENTIONAL

# calculate the values
return c + f * arr_1d[..., i] # type: ignore

else:
# the boundary condition is a function of space

# calculate necessary constants
const, factor, index = self.get_virtual_point_data(compiled=True)
zeros = np.zeros(self._shape_tensor + self._shape_boundary)
ones = np.ones(self._shape_tensor + self._shape_boundary)

@register_jitable(inline="always")
def adjacent_point(arr_1d, i_point, bc_idx) -> float:
"""Evaluate the value adjacent to the current point."""
# determine the parameters for evaluating adjacent point. Note
# that defining the variables c and f for the interior points
# seems needless, but it turns out that this results in a 10x
# faster function (because of branch prediction?). This is
# surprising, because it uses arrays zeros and ones that are
# quite pointless
if i_point == i_bndry:
c, f, i = const(), factor(), index
else:
c, f, i = zeros, ones, i_point + i_dx # INTENTIONAL

# calculate the values
return c[bc_idx] + f[bc_idx] * arr_1d[..., i] # type: ignore

return adjacent_point # type: ignore
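The comment in the removed evaluator notes that assigning the triple `(c, f, i)` in both branches, even though the interior values are trivial constants, made the compiled function roughly 10x faster, presumably because the branch stays cheap to predict and the arithmetic after it is identical. A toy sketch of that pattern in plain NumPy (names and numbers are hypothetical, not part of py-pde):

```python
# Toy sketch of the uniform-branch pattern used above (hypothetical names/values).
import numpy as np

def make_adjacent_point(const, factor, index, i_bndry, i_dx):
    """Build an evaluator for the value adjacent to cell `i_point` along one axis."""

    def adjacent_point(arr_1d, i_point):
        if i_point == i_bndry:
            c, f, i = const, factor, index      # boundary cell: virtual-point data
        else:
            c, f, i = 0.0, 1.0, i_point + i_dx  # interior cell: plain neighbour
        return c + f * arr_1d[i]                # same expression in both cases

    return adjacent_point

arr = np.array([1.0, 2.0, 3.0, 4.0])
# upper boundary of a 4-cell axis; virtual point chosen as 2*1.0 - arr[3]
adjacent = make_adjacent_point(const=2.0, factor=-1.0, index=3, i_bndry=3, i_dx=1)
print(adjacent(arr, 1))  # interior -> arr[2] == 3.0
print(adjacent(arr, 3))  # boundary -> 2.0 - arr[3] == -2.0
```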

def set_ghost_cells(self, data_full: np.ndarray, *, args=None) -> None:
# calculate necessary constants
const, factor, index = self.get_virtual_point_data()
@@ -2754,73 +2666,6 @@ def virtual_point(arr: np.ndarray, idx: tuple[int, ...], args=None):

return virtual_point # type: ignore

def make_adjacent_evaluator(self) -> AdjacentEvaluator:
# method deprecated since 2023-12-19
warnings.warn("`make_adjacent_evaluator` is deprecated", DeprecationWarning)
size = self.grid.shape[self.axis]
if size < 2:
raise ValueError(
f"Need at least two support points along axis {self.axis} to apply "
"boundary conditions"
)

# get values distinguishing upper from lower boundary
if self.upper:
i_bndry = size - 1
i_dx = 1
else:
i_bndry = 0
i_dx = -1

# calculate necessary constants
data_vp = self.get_virtual_point_data()

zeros = np.zeros_like(self.value)
ones = np.ones_like(self.value)

if self.homogeneous:
# the boundary condition does not depend on space

@register_jitable
def adjacent_point(
arr_1d: np.ndarray, i_point: int, bc_idx: tuple[int, ...]
) -> float:
"""Evaluate the value adjacent to the current point."""
# determine the parameters for evaluating adjacent point
if i_point == i_bndry:
data = data_vp
else:
data = (zeros, ones, i_point + i_dx, zeros, 0)

# calculate the values
return ( # type: ignore
data[0]
+ data[1] * arr_1d[..., data[2]]
+ data[3] * arr_1d[..., data[4]]
)

else:
# the boundary condition is a function of space

@register_jitable
def adjacent_point(
arr_1d: np.ndarray, i_point: int, bc_idx: tuple[int, ...]
) -> float:
"""Evaluate the value adjacent to the current point."""
# determine the parameters for evaluating adjacent point
if i_point == i_bndry:
data = data_vp
else:
data = (zeros, ones, i_point + i_dx, zeros, 0)

return ( # type: ignore
data[0][bc_idx]
+ data[1][bc_idx] * arr_1d[..., data[2]]
+ data[3][bc_idx] * arr_1d[..., data[4]]
)

return adjacent_point # type: ignore
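In this second removed variant, `get_virtual_point_data()` evidently returns a 5-tuple `(const, factor1, index1, factor2, index2)` and the value is assembled as `data[0] + data[1]*arr_1d[data[2]] + data[3]*arr_1d[data[4]]`. A tiny worked example with made-up numbers:

```python
# Worked example of the 5-tuple arithmetic visible above (numbers are made up).
import numpy as np

arr_1d = np.array([1.0, 2.0, 3.0, 4.0])
data = (0.5, 2.0, 3, -1.0, 2)  # (const, factor1, index1, factor2, index2)
value = data[0] + data[1] * arr_1d[data[2]] + data[3] * arr_1d[data[4]]
print(value)  # 0.5 + 2.0*4.0 - 1.0*3.0 == 5.5
```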

def set_ghost_cells(self, data_full: np.ndarray, *, args=None) -> None:
# calculate necessary constants
data = self.get_virtual_point_data()
9 changes: 2 additions & 7 deletions pde/storage/base.py
@@ -284,7 +284,6 @@ def tracker(
interrupts: InterruptData = 1,
*,
transformation: Callable[[FieldBase, float], FieldBase] | None = None,
interval=None,
) -> StorageTracker:
"""Create object that can be used as a tracker to fill this storage.
@@ -321,10 +320,7 @@ def add_to_state(state):
possible by defining appropriate :func:`add_to_state`
"""
return StorageTracker(
storage=self,
interrupts=interrupts,
transformation=transformation,
interval=interval,
storage=self, interrupts=interrupts, transformation=transformation
)
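With the deprecated `interval` keyword removed, callers select the storage spacing via `interrupts`. A minimal migration sketch, assuming `MemoryStorage` (a `StorageBase` subclass exported at the package level) and py-pde's usual field API:

```python
# Migration sketch (assumes pde.MemoryStorage and Field.smooth are available).
import pde

storage = pde.MemoryStorage()

# before (deprecated): tracker = storage.tracker(interval=10)
tracker = storage.tracker(interrupts=10)  # store a snapshot every 10 time units

# optionally post-process each snapshot; the callable gets the field and the time
tracker_smoothed = storage.tracker(
    interrupts=10, transformation=lambda field, t=None: field.smooth(1)
)
```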

def start_writing(self, field: FieldBase, info: InfoDict | None = None) -> None:
@@ -566,7 +562,6 @@ def __init__(
interrupts: InterruptData = 1,
*,
transformation: Callable[[FieldBase, float], FieldBase] | None = None,
interval=None,
):
"""
Args:
@@ -582,7 +577,7 @@ def __init__(
the current field, while the optional second argument is the associated
time.
"""
super().__init__(interrupts=interrupts, interval=interval)
super().__init__(interrupts=interrupts)
self.storage = storage
if transformation is not None and not callable(transformation):
raise TypeError("`transformation` must be callable")
34 changes: 8 additions & 26 deletions pde/tools/numba.py
@@ -13,30 +13,13 @@
import numba as nb
import numpy as np
from numba.core.types import npytypes, scalars
from numba.extending import overload, register_jitable
from numba.extending import is_jitted, overload, register_jitable
from numba.typed import Dict as NumbaDict

from .. import config
from ..tools.misc import decorator_arguments
from .typing import Number

try:
# is_jitted has been added in numba 0.53 on 2021-03-11
from numba.extending import is_jitted

except ImportError:
# for earlier version of numba, we need to define the function

def is_jitted(function: Callable) -> bool:
"""Determine whether a function has already been jitted."""
try:
from numba.core.dispatcher import Dispatcher
except ImportError:
# assume older numba module structure
from numba.dispatcher import Dispatcher
return isinstance(function, Dispatcher)


# numba version as a list of integers
NUMBA_VERSION = [int(v) for v in nb.__version__.split(".")[:2]]

@@ -177,7 +160,6 @@ def jit(function: TFunc, signature=None, parallel: bool = False, **kwargs) -> TF
return function

# prepare the compilation arguments
kwargs.setdefault("nopython", True)
if config["numba.fastmath"] is True:
# enable some (but not all) fastmath flags. We skip the flags that affect
# handling of infinities and NaN for safety by default. Use "fast" to enable all
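The comment above says that `config["numba.fastmath"] is True` enables only the fastmath flags that leave NaN and infinity handling intact, while the string "fast" enables everything. A hedged sketch of such a mapping; the concrete flag set is an assumption, not read from the commit (numba's `fastmath` option accepts either a bool or a set of LLVM flags):

```python
# Hedged sketch of the described fastmath policy (the exact flag set is an assumption).
def fastmath_kwarg(setting):
    if setting == "fast":
        return True  # all LLVM fastmath flags, including nnan/ninf
    if setting is True:
        # keep only flags that do not change NaN/inf semantics
        return {"nsz", "arcp", "contract", "afn", "reassoc"}
    return False

# e.g. numba.njit(fastmath=fastmath_kwarg(True))(some_function)
```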
@@ -192,10 +174,10 @@ def jit(function: TFunc, signature=None, parallel: bool = False, **kwargs) -> TF
# log some details
logger = logging.getLogger(__name__)
name = getattr(function, "__name__", "<anonymous function>")
if kwargs["nopython"]: # standard case
logger.info("Compile `%s` with parallel=%s", name, kwargs["parallel"])
else: # this might imply numba falls back to object mode
logger.warning("Compile `%s` with nopython=False", name)
if kwargs["parallel"]:
logger.info("Compile `%s`", name)

Check warning on line 178 in pde/tools/numba.py

View check run for this annotation

Codecov / codecov/patch

pde/tools/numba.py#L178

Added line #L178 was not covered by tests
else:
logger.info("Compile `%s` with parallel=True", name)

# increase the compilation counter by one
JIT_COUNT.increment()
@@ -308,7 +290,7 @@ def get_common_numba_dtype(*args):
return nb.double


@jit(nopython=True, nogil=True)
@jit(nogil=True)
def _random_seed_compiled(seed: int) -> None:
"""Sets the seed of the random number generator of numba."""
np.random.seed(seed)
Expand All @@ -325,7 +307,7 @@ def random_seed(seed: int = 0) -> None:
_random_seed_compiled(seed)


if NUMBA_VERSION < [0, 45]:
if NUMBA_VERSION < [0, 59]:
warnings.warn(
"Your numba version is outdated. Please install at least version 0.45"
"Your numba version is outdated. Please install at least version 0.59"
)
7 changes: 0 additions & 7 deletions pde/tools/typing.py
@@ -44,13 +44,6 @@ def __call__(self, arr: np.ndarray, idx: tuple[int, ...], args=None) -> float:
"""Evaluate the virtual point at the given position."""


class AdjacentEvaluator(Protocol):
def __call__(
self, arr_1d: np.ndarray, i_point: int, bc_idx: tuple[int, ...]
) -> float:
"""Evaluate the values at adjecent points."""


class GhostCellSetter(Protocol):
def __call__(self, data_full: np.ndarray, args=None) -> None:
"""Set the ghost cells."""