Skip to content

Commit

Permalink
ensure correct data types in getter methods (#1030)
Browse files Browse the repository at this point in the history
Resolves #1029
  • Loading branch information
ddudt authored Jun 25, 2024
2 parents c94aa20 + 80c40ab commit a7bd35f
Show file tree
Hide file tree
Showing 14 changed files with 121 additions and 150 deletions.
19 changes: 10 additions & 9 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ def NFP(self):

@property
def sym(self):
"""str: {``'cos'``, ``'sin'``, ``False``} Type of symmetry."""
"""str: Type of symmetry."""
# one of: {'even', 'sin', 'cos', 'cos(t)', False}
return self.__dict__.setdefault("_sym", False)

@property
Expand Down Expand Up @@ -238,7 +239,7 @@ def __init__(self, L, sym="even"):
self._M = 0
self._N = 0
self._NFP = 1
self._sym = sym
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = "linear"

self._modes = self._get_modes(L=self.L)
Expand Down Expand Up @@ -351,7 +352,7 @@ def __init__(self, N, NFP=1, sym=False):
self._M = 0
self._N = check_nonnegint(N, "N", False)
self._NFP = check_posint(NFP, "NFP", False)
self._sym = sym
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = "linear"

self._modes = self._get_modes(N=self.N)
Expand Down Expand Up @@ -474,7 +475,7 @@ def __init__(self, M, N, NFP=1, sym=False):
self._M = check_nonnegint(M, "M", False)
self._N = check_nonnegint(N, "N", False)
self._NFP = check_posint(NFP, "NFP", False)
self._sym = sym
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = "linear"

self._modes = self._get_modes(M=self.M, N=self.N)
Expand Down Expand Up @@ -635,8 +636,8 @@ def __init__(self, L, M, sym=False, spectral_indexing="ansi"):
self._M = check_nonnegint(M, "M", False)
self._N = 0
self._NFP = 1
self._sym = sym
self._spectral_indexing = spectral_indexing
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = str(spectral_indexing)

self._modes = self._get_modes(
L=self.L, M=self.M, spectral_indexing=self.spectral_indexing
Expand Down Expand Up @@ -831,7 +832,7 @@ def __init__(self, L, M, N, NFP=1, sym=False):
self._M = check_nonnegint(M, "M", False)
self._N = check_nonnegint(N, "N", False)
self._NFP = check_posint(NFP, "NFP", False)
self._sym = sym
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = "linear"

self._modes = self._get_modes(L=self.L, M=self.M, N=self.N)
Expand Down Expand Up @@ -983,8 +984,8 @@ def __init__(self, L, M, N, NFP=1, sym=False, spectral_indexing="ansi"):
self._M = check_nonnegint(M, "M", False)
self._N = check_nonnegint(N, "N", False)
self._NFP = check_posint(NFP, "NFP", False)
self._sym = sym
self._spectral_indexing = spectral_indexing
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = str(spectral_indexing)

self._modes = self._get_modes(
L=self.L, M=self.M, N=self.N, spectral_indexing=self.spectral_indexing
Expand Down
7 changes: 4 additions & 3 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,14 +748,14 @@ def __init__(self, *coils, NFP=1, sym=False, name=""):
assert all([isinstance(coil, (_Coil)) for coil in coils])
[_check_type(coil, coils[0]) for coil in coils]
self._coils = list(coils)
self._NFP = NFP
self._sym = sym
self._NFP = int(NFP)
self._sym = bool(sym)
self._name = str(name)

@property
def name(self):
"""str: Name of the curve."""
return self._name
return self.__dict__.setdefault("_name", "")

@name.setter
def name(self, new):
Expand Down Expand Up @@ -837,6 +837,7 @@ def compute(
params = [get_params(names, coil) for coil in self]
if data is None:
data = [{}] * len(self)

# if user supplied initial data for each coil we also need to vmap over that.
data = vmap(
lambda d, x: self[0].compute(
Expand Down
4 changes: 2 additions & 2 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init__(
ValueError,
f"sym should be one of True, False, None, got {sym}",
)
self._sym = setdefault(sym, getattr(surface, "sym", False))
self._sym = bool(setdefault(sym, getattr(surface, "sym", False)))
self._R_sym = "cos" if self.sym else False
self._Z_sym = "sin" if self.sym else False

Expand Down Expand Up @@ -564,7 +564,7 @@ def change_resolution(
self._M_grid = int(setdefault(M_grid, self.M_grid))
self._N_grid = int(setdefault(N_grid, self.N_grid))
self._NFP = int(setdefault(NFP, self.NFP))
self._sym = setdefault(sym, self.sym)
self._sym = bool(setdefault(sym, self.sym))

old_modes_R = self.R_basis.modes
old_modes_Z = self.Z_basis.modes
Expand Down
8 changes: 4 additions & 4 deletions desc/geometry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ def rotmat(self, new):
@property
def name(self):
"""Name of the curve."""
return self._name
return self.__dict__.setdefault("_name", "")

@name.setter
def name(self, new):
self._name = new
self._name = str(new)

def compute(
self,
Expand Down Expand Up @@ -323,11 +323,11 @@ def _set_up(self):
@property
def name(self):
"""str: Name of the surface."""
return self._name
return self.__dict__.setdefault("_name", "")

@name.setter
def name(self, new):
self._name = new
self._name = str(new)

@property
def L(self):
Expand Down
4 changes: 2 additions & 2 deletions desc/geometry/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(

@property
def sym(self):
"""Whether this curve has stellarator symmetry."""
"""bool: Whether or not the curve is stellarator symmetric."""
return self._sym

@property
Expand Down Expand Up @@ -128,7 +128,7 @@ def change_resolution(self, N=None, NFP=None, sym=None):
and (sym != self.sym)
):
self._NFP = int(NFP if NFP is not None else self.NFP)
self._sym = sym if sym is not None else self.sym
self._sym = bool(sym) if sym is not None else self.sym
N = int(N if N is not None else self.N)
R_modes_old = self.R_basis.modes
Z_modes_old = self.Z_basis.modes
Expand Down
6 changes: 3 additions & 3 deletions desc/geometry/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(

self._R_lmn = copy_coeffs(R_lmn, modes_R, self.R_basis.modes[:, 1:])
self._Z_lmn = copy_coeffs(Z_lmn, modes_Z, self.Z_basis.modes[:, 1:])
self._sym = sym
self._sym = bool(sym)
self._rho = rho

if check_orientation and self._compute_orientation() == -1:
Expand Down Expand Up @@ -870,8 +870,8 @@ def __init__(

self._R_lmn = copy_coeffs(R_lmn, modes_R, self.R_basis.modes[:, :2])
self._Z_lmn = copy_coeffs(Z_lmn, modes_Z, self.Z_basis.modes[:, :2])
self._sym = sym
self._spectral_indexing = spectral_indexing
self._sym = bool(sym)
self._spectral_indexing = str(spectral_indexing)

self._zeta = zeta

Expand Down
2 changes: 1 addition & 1 deletion desc/input_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ def desc_output_to_input( # noqa: C901 - fxn too complex
Fourier coefficients below this value will be set to 0.
"""
from desc.grid import LinearGrid
from desc.io.equilibrium_io import load
from desc.io.optimizable_io import load
from desc.profiles import PowerSeriesProfile
from desc.utils import copy_coeffs

Expand Down
4 changes: 2 additions & 2 deletions desc/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Functions and classes for reading and writing DESC data."""

# InputReader lives outside this module for import ordering reasons, so we can
# import InputReader in __main__ without importing equilibrium_io which imports JAX
# import InputReader in __main__ without importing optimizable_io which imports JAX
# stuff potentially before we've set the GPU correctly.
# We include a link to it here for backwards compatibility
from desc.input_reader import InputReader

from .ascii_io import read_ascii, write_ascii
from .equilibrium_io import IOAble, load
from .hdf5_io import hdf5Reader, hdf5Writer
from .optimizable_io import IOAble, load
from .pickle_io import PickleReader, PickleWriter

__all__ = ["InputReader", "load"]
2 changes: 1 addition & 1 deletion desc/io/hdf5_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def isarray(x):
group = loc.create_group(attr)
self.write_list(data, where=group)
else:
from .equilibrium_io import IOAble
from .optimizable_io import IOAble

if isinstance(data, IOAble):
group = loc.create_group(attr)
Expand Down
3 changes: 3 additions & 0 deletions desc/io/equilibrium_io.py → desc/io/optimizable_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def load(load_from, file_format=None):
-------
obj :
The object saved in the file
"""
if file_format is None and isinstance(load_from, (str, os.PathLike)):
name = str(load_from)
Expand Down Expand Up @@ -83,6 +84,8 @@ def _unjittable(x):
return any([_unjittable(y) for y in x])
if isinstance(x, dict):
return any([_unjittable(y) for y in x.values()])
if hasattr(x, "dtype") and np.ndim(x) == 0:
return np.issubdtype(x.dtype, np.bool_) or np.issubdtype(x.dtype, np.int_)
return isinstance(x, (str, types.FunctionType, bool, int, np.int_))


Expand Down
130 changes: 53 additions & 77 deletions desc/objectives/_coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class _CoilObjective(_Objective):
of the objective. Has no effect on self.grad or self.hess which always use
reverse mode and forward over reverse mode respectively.
grid : Grid, list, optional
Collocation grid containing the nodes to evaluate at. If list, has to adhere to
Objective.dim_f
Collocation grid containing the nodes to evaluate at.
If a list, must have the same structure as coil.
name : str, optional
Name of the objective function.
Expand Down Expand Up @@ -96,102 +96,78 @@ def build(self, use_jit=True, verbose=1): # noqa:C901
"""
# local import to avoid circular import
from desc.coils import CoilSet, MixedCoilSet
from desc.coils import CoilSet, MixedCoilSet, _Coil

self._dim_f = 0
self._quad_weights = jnp.array([])
def _is_single_coil(c):
return isinstance(c, _Coil) and not isinstance(c, CoilSet)

def to_list(coilset):
"""Turn a MixedCoilSet container into a list of what it's containing."""
if isinstance(coilset, list):
return [to_list(x) for x in coilset]
elif isinstance(coilset, MixedCoilSet):
return [to_list(x) for x in coilset]
def _prune_coilset_tree(coilset):
"""Remove extra members from CoilSets (but not MixedCoilSets)."""
if isinstance(coilset, list) or isinstance(coilset, MixedCoilSet):
return [_prune_coilset_tree(c) for c in coilset]
elif isinstance(coilset, CoilSet):
# use the same grid/transform for CoilSet
return to_list(coilset.coils[0])
# CoilSet only uses a single grid/transform for all coils
return _prune_coilset_tree(coilset.coils[0])
else:
return [coilset]
return coilset # single coil

# gives structure of coils, e.g. MixedCoilSet(coils, coils) would give a
# a structure of [[*, *], [*, *]] if n = 2 coils
coil_leaves, coil_structure = tree_flatten(
self.things[0], is_leaf=lambda x: not hasattr(x, "__len__")
)
self._num_coils = len(coil_leaves)

# check type
if isinstance(self._grid, numbers.Integral):
self._grid = LinearGrid(N=self._grid, endpoint=False)
# all of these cases return a container MixedCoilSet that contains
# LinearGrids. i.e. MixedCoilSet.coils = list of LinearGrid
if self._grid is None:
# map default grid to structure of inputted coils
self._grid = tree_map(
lambda x: LinearGrid(
N=2 * x.N + 5, NFP=getattr(x, "NFP", 1), endpoint=False
),
self.things[0],
is_leaf=lambda x: not hasattr(x, "__len__"),
)
elif isinstance(self._grid, _Grid):
# map inputted single LinearGrid to structure of inputted coils
self._grid = [self._grid] * self._num_coils
self._grid = tree_unflatten(coil_structure, self._grid)
else:
# this case covers an inputted list of grids that matches the size
# of the inputted coils. Can be a 1D list or nested list.
flattened_grid = tree_leaves(
self._grid, is_leaf=lambda x: isinstance(x, _Grid)
)
self._grid = tree_unflatten(coil_structure, flattened_grid)
coil = self.things[0]
grid = self._grid

timer = Timer()
if verbose > 0:
print("Precomputing transforms")
timer.start("Precomputing transforms")

transforms = tree_map(
lambda x, y: get_transforms(self._data_keys, obj=x, grid=y),
self.things[0],
self._grid,
is_leaf=lambda x: not hasattr(x, "__len__"),
)
# get individual coils from coilset
coils, structure = tree_flatten(coil, is_leaf=_is_single_coil)
self._num_coils = len(coils)

grids = tree_leaves(self._grid, is_leaf=lambda x: hasattr(x, "num_nodes"))
self._dim_f = np.sum([grid.num_nodes for grid in grids])
self._quad_weights = np.concatenate([grid.spacing[:, 2] for grid in grids])
# map grid to list of length coils
if grid is None:
grid = [LinearGrid(N=2 * c.N + 5, endpoint=False) for c in coils]
if isinstance(grid, numbers.Integral):
grid = LinearGrid(N=self._grid, endpoint=False)
if isinstance(grid, _Grid):
grid = [grid] * self._num_coils
if isinstance(grid, list):
grid = tree_leaves(grid, is_leaf=lambda g: isinstance(g, _Grid))

# get only needed grids (1 per CoilSet) and flatten that list
self._grid = tree_leaves(
to_list(self._grid), is_leaf=lambda x: isinstance(x, _Grid)
)
transforms = tree_leaves(
to_list(transforms), is_leaf=lambda x: isinstance(x, dict)
errorif(
len(grid) != len(coils),
ValueError,
"grid input must be broadcastable to the coil structure.",
)

errorif(
np.any([grid.num_rho > 1 or grid.num_theta > 1 for grid in self._grid]),
np.any([g.num_rho > 1 or g.num_theta > 1 for g in grid]),
ValueError,
"Only use toroidal resolution for coil grids.",
)

# CoilSet and _Coil have one grid/transform
if not isinstance(self.things[0], MixedCoilSet):
self._grid = self._grid[0]
transforms = transforms[0]
self._dim_f = np.sum([g.num_nodes for g in grid])
quad_weights = np.concatenate([g.spacing[:, 2] for g in grid])

self._constants = {
"transforms": transforms,
"quad_weights": self._quad_weights,
}
# map grid to the same structure as coil and then remove unnecessary members
grid = tree_unflatten(structure, grid)
grid = _prune_coilset_tree(grid)
coil = _prune_coilset_tree(coil)

timer = Timer()
if verbose > 0:
print("Precomputing transforms")
timer.start("Precomputing transforms")

transforms = tree_map(
lambda c, g: get_transforms(self._data_keys, obj=c, grid=g),
coil,
grid,
is_leaf=lambda x: _is_single_coil(x) or isinstance(x, _Grid),
)

self._grid = grid
self._constants = {"transforms": transforms, "quad_weights": quad_weights}

timer.stop("Precomputing transforms")
if verbose > 1:
timer.disp("Precomputing transforms")

if self._normalize:
self._scales = [compute_scaling_factors(coil) for coil in coil_leaves]
self._scales = [compute_scaling_factors(coil) for coil in coils]

super().build(use_jit=use_jit, verbose=verbose)

Expand Down
Loading

0 comments on commit a7bd35f

Please sign in to comment.