Skip to content

Commit

Permalink
fix: xarray 2023.08.0 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
yaugenst-flex committed Oct 28, 2024
1 parent 143efcf commit 2c9ae23
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,7 +1630,7 @@ def test_custom_methods_grads(self, attr):
"""Test grads of TidyArrayBox methods implemented in autograd/boxes.py"""

def objective(x, attr):
da = DataArray(x, dims=map(str, range(x.ndim)))
da = DataArray(x)
attr_value = getattr(da, attr)
val = attr_value() if callable(attr_value) else attr_value
return val.item()
Expand All @@ -1643,11 +1643,11 @@ def test_multiply_at_grads(self, rng):

def objective(a, b):
coords = {str(i): np.arange(a.shape[i]) for i in range(a.ndim)}
da = DataArray(a, coords=coords, dims=map(str, range(a.ndim)))
da = DataArray(a, coords=coords)
da_mult = da.multiply_at(b, "0", [0, 1]) ** 2
return np.sum(da_mult).item()

a = rng.uniform(-1, 1, (3, 3))
b = 1.0
check_grads(lambda x: objective(x, b), modes=["fwd", "rev"], order=1)(a)
check_grads(lambda x: objective(a, x), modes=["fwd", "rev"], order=1)(b)
check_grads(lambda x: objective(x, b), modes=["fwd", "rev"], order=2)(a)
check_grads(lambda x: objective(a, x), modes=["fwd", "rev"], order=2)(b)
9 changes: 9 additions & 0 deletions tidy3d/components/autograd/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,21 @@
from typing import Any, Callable, Dict, List, Tuple

import autograd.numpy as anp
from autograd.extend import defjvp
from autograd.numpy.numpy_boxes import ArrayBox
from autograd.numpy.numpy_wrapper import _astype

TidyArrayBox = ArrayBox # NOT a subclass

_autograd_module_cache = {} # cache for imported autograd modules

defjvp(
_astype,
lambda g, ans, A, dtype, order="K", casting="unsafe", subok=True, copy=True: _astype(g, dtype),
)

anp.astype = _astype


@classmethod
def from_arraybox(cls, box: ArrayBox) -> TidyArrayBox:
Expand Down

0 comments on commit 2c9ae23

Please sign in to comment.