Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aliases for data_index #648

Merged
merged 23 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions desc/compute/_basis_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,6 +1716,7 @@ def _e_sub_rho_rrz(params, transforms, profiles, data, **kwargs):
"omega_rt",
"omega_t",
],
aliases=["x_rrt", "x_rtr", "x_trr"],
)
def _e_sub_rho_rt(params, transforms, profiles, data, **kwargs):
data["e_rho_rt"] = jnp.array(
Expand Down Expand Up @@ -1852,6 +1853,7 @@ def _e_sub_rho_rtt(params, transforms, profiles, data, **kwargs):
"omega_tz",
"omega_z",
],
aliases=["x_rrt", "x_trr", "x_rtr"],
)
def _e_sub_rho_rtz(params, transforms, profiles, data, **kwargs):
data["e_rho_rtz"] = jnp.array(
Expand Down
3 changes: 2 additions & 1 deletion desc/compute/_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,7 +1306,7 @@ def _g_sup_zz(params, transforms, profiles, data, **kwargs):
coordinates="rtz",
data=["e^rho", "e^theta"],
)
def _g_sup_rt(params, transforms, profiles, data, **kwargs):
def _g_sup_rt(params, transforms, profiles, data):
data["g^rt"] = dot(data["e^rho"], data["e^theta"])
return data

Expand Down Expand Up @@ -1341,6 +1341,7 @@ def _g_sup_rz(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="rtz",
data=["e^theta", "e^zeta"],
aliases=["g^zt"],
)
def _g_sup_tz(params, transforms, profiles, data, **kwargs):
data["g^tz"] = dot(data["e^theta"], data["e^zeta"])
Expand Down
64 changes: 64 additions & 0 deletions desc/compute/data_index.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,50 @@
"""data_index contains all the quantities calculated by the compute functions."""
import functools
import itertools

import numpy as np


def find_permutations(primary, separator="_"):
"""Finds permutations of quantity names for aliases."""
split_name = primary.split(separator)
primary_permutation = split_name[-1]

new_permutations = list(itertools.permutations(primary_permutation))
# join new permutation to form alias keys
aliases = [
"".join(split_name[:-1]) + separator + "".join(perm)
for perm in new_permutations
]
aliases = np.unique(aliases)
aliases = np.delete(aliases, np.where(aliases == primary))

return aliases


def assign_alias_data(
alias, primary, base_class, data_index, params, profiles, transforms, data, **kwargs
):
"""Assigns primary data to alias.

Parameters
----------
alias : `str`
data_index key for alias of primary
primary : `str`
key defined in compute function

Returns
-------
data : `dict`
computed data dictionary (includes both alias and primary)

"""
data = data_index[base_class][primary]["fun"](
params, transforms, profiles, data, **kwargs
)
data[alias] = data[primary].copy()
return data


def register_compute_fun(
Expand All @@ -13,6 +59,7 @@ def register_compute_fun(
profiles,
coordinates,
data,
aliases=[],
parameterization="desc.equilibrium.equilibrium.Equilibrium",
axis_limit_data=None,
**kwargs,
Expand Down Expand Up @@ -69,6 +116,10 @@ def register_compute_fun(
"kwargs": list(kwargs.values()),
}

permutable_names = ["R_", "Z_", "phi_", "lambda_", "omega_", "sqrt(g)_"]
if not aliases and "".join(name.split("_")[:-1]) + "_" in permutable_names:
aliases = find_permutations(name)

def _decorator(func):
d = {
"label": label,
Expand All @@ -79,6 +130,7 @@ def _decorator(func):
"dim": dim,
"coordinates": coordinates,
"dependencies": deps,
"aliases": aliases,
}
for p in parameterization:
flag = False
Expand All @@ -89,6 +141,18 @@ def _decorator(func):
f"Already registered function with parameterization {p} and name {name}."
)
data_index[base_class][name] = d.copy()
for alias in aliases:

data_index[base_class][alias] = d.copy()
# assigns alias compute func to generator to be used later
data_index[base_class][alias]["fun"] = functools.partial(
assign_alias_data,
alias=alias,
primary=name,
base_class=base_class,
data_index=data_index,
)

flag = True
if not flag:
raise ValueError(
Expand Down
3 changes: 2 additions & 1 deletion desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ def _compute(
)
# now compute the quantity
data = data_index[parameterization][name]["fun"](
params, transforms, profiles, data, **kwargs
params=params, transforms=transforms, profiles=profiles, data=data, **kwargs
)

return data


Expand Down
20 changes: 20 additions & 0 deletions tests/test_axis_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,26 @@
}


def add_all_aliases(names):
"""Add aliases to limits."""
all_aliases = []
for name in names:
for base_class in data_index.keys():
if name in data_index[base_class].keys():
all_aliases.append(data_index[base_class][name]["aliases"])

# flatten
all_aliases = [name for sublist in all_aliases for name in sublist]
names.update(all_aliases)

return names


zero_limits = add_all_aliases(zero_limits)
not_finite_limits = add_all_aliases(not_finite_limits)
not_implemented_limits = add_all_aliases(not_implemented_limits)


def grow_seeds(
seeds, search_space, parameterization="desc.equilibrium.equilibrium.Equilibrium"
):
Expand Down
23 changes: 23 additions & 0 deletions tests/test_compute_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,29 @@ def myconvolve_2d(arr_1d, stencil, shape):
return conv


@pytest.mark.unit
def test_aliases():
"""Tests that data_index aliases are equal."""
surface = FourierRZToroidalSurface(
R_lmn=[10, 1, 0.2],
Z_lmn=[-2, -0.2],
modes_R=[[0, 0], [1, 0], [0, 1]],
modes_Z=[[-1, 0], [0, -1]],
)

eq = Equilibrium(surface=surface)

# automatic case
primary_data = eq.compute("R_tz")
alias_data = eq.compute("R_zt")
np.testing.assert_allclose(primary_data["R_tz"], alias_data["R_zt"])

# manual case
primary_data = eq.compute("e_rho_rt")
alias_data = eq.compute("x_rrt")
np.testing.assert_allclose(primary_data["e_rho_rt"], alias_data["x_rrt"])


@pytest.mark.unit
def test_total_volume(DummyStellarator):
"""Test that the volume enclosed by the LCFS is equal to the total volume."""
Expand Down
4 changes: 4 additions & 0 deletions tests/test_data_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,13 @@ def test_data_index_deps(self):
for base_class, superclasses in _class_inheritance.items():
if p in superclasses or p == base_class:
queried_deps[base_class][name] = deps
aliases = data_index[base_class][name]["aliases"]
for alias in aliases:
queried_deps[base_class][alias] = deps

for p in data_index:
for name, val in data_index[p].items():
print(name)
err_msg = f"Parameterization: {p}. Name: {name}."
deps = val["dependencies"]
data = set(deps["data"])
Expand Down