Skip to content

Commit

Permalink
Merge pull request #648 from PlasmaControl/ko/aliases
Browse files Browse the repository at this point in the history
Add aliases for data_index
  • Loading branch information
kianorr authored Nov 17, 2023
2 parents 6f24cf9 + 33075bb commit 23d0635
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 14 deletions.
1 change: 1 addition & 0 deletions desc/compute/_basis_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,6 +1716,7 @@ def _e_sub_rho_rrz(params, transforms, profiles, data, **kwargs):
"omega_rt",
"omega_t",
],
aliases=["x_rrt", "x_rtr", "x_trr"],
)
def _e_sub_rho_rt(params, transforms, profiles, data, **kwargs):
data["e_rho_rt"] = jnp.array(
Expand Down
1 change: 1 addition & 0 deletions desc/compute/_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,6 +1341,7 @@ def _g_sup_rz(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="rtz",
data=["e^theta", "e^zeta"],
aliases=["g^zt"],
)
def _g_sup_tz(params, transforms, profiles, data, **kwargs):
data["g^tz"] = dot(data["e^theta"], data["e^zeta"])
Expand Down
72 changes: 72 additions & 0 deletions desc/compute/data_index.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,55 @@
"""data_index contains all the quantities calculated by the compute functions."""
import functools
from collections import deque

import numpy as np


def find_permutations(primary, separator="_"):
"""Finds permutations of quantity names for aliases."""
split_name = primary.split(separator)
primary_permutation = split_name[-1]
primary_permutation = deque(primary_permutation)

new_permutations = []
for i in range(len(primary_permutation)):
primary_permutation.rotate(1)
new_permutations.append(list(primary_permutation))

# join new permutation to form alias keys
aliases = [
"".join(split_name[:-1]) + separator + "".join(perm)
for perm in new_permutations
]
aliases = np.unique(aliases)
aliases = np.delete(aliases, np.where(aliases == primary))

return aliases


def assign_alias_data(
alias, primary, base_class, data_index, params, profiles, transforms, data, **kwargs
):
"""Assigns primary data to alias.
Parameters
----------
alias : `str`
data_index key for alias of primary
primary : `str`
key defined in compute function
Returns
-------
data : `dict`
computed data dictionary (includes both alias and primary)
"""
data = data_index[base_class][primary]["fun"](
params, transforms, profiles, data, **kwargs
)
data[alias] = data[primary].copy()
return data


def register_compute_fun(
Expand All @@ -13,6 +64,7 @@ def register_compute_fun(
profiles,
coordinates,
data,
aliases=[],
parameterization="desc.equilibrium.equilibrium.Equilibrium",
axis_limit_data=None,
**kwargs,
Expand Down Expand Up @@ -51,6 +103,9 @@ def register_compute_fun(
or `desc.equilibrium.Equilibrium`.
axis_limit_data : list of str
Names of other items in the data index needed to compute axis limit of qty.
aliases : list
Aliases of `name`. Will be stored in the data dictionary as a copy of `name`s
data.
Notes
-----
Expand All @@ -69,6 +124,10 @@ def register_compute_fun(
"kwargs": list(kwargs.values()),
}

permutable_names = ["R_", "Z_", "phi_", "lambda_", "omega_"]
if not aliases and "".join(name.split("_")[:-1]) + "_" in permutable_names:
aliases = find_permutations(name)

def _decorator(func):
d = {
"label": label,
Expand All @@ -79,6 +138,7 @@ def _decorator(func):
"dim": dim,
"coordinates": coordinates,
"dependencies": deps,
"aliases": aliases,
}
for p in parameterization:
flag = False
Expand All @@ -89,6 +149,18 @@ def _decorator(func):
f"Already registered function with parameterization {p} and name {name}."
)
data_index[base_class][name] = d.copy()
for alias in aliases:

data_index[base_class][alias] = d.copy()
# assigns alias compute func to generator to be used later
data_index[base_class][alias]["fun"] = functools.partial(
assign_alias_data,
alias=alias,
primary=name,
base_class=base_class,
data_index=data_index,
)

flag = True
if not flag:
raise ValueError(
Expand Down
3 changes: 2 additions & 1 deletion desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ def _compute(
)
# now compute the quantity
data = data_index[parameterization][name]["fun"](
params, transforms, profiles, data, **kwargs
params=params, transforms=transforms, profiles=profiles, data=data, **kwargs
)

return data


Expand Down
33 changes: 20 additions & 13 deletions docs/write_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,29 @@ def _escape(line):

def write_csv(parameterization):
with open(parameterization + ".csv", "w", newline="") as f:
fieldnames = ["Name", "Label", "Units", "Description", "Module"]
fieldnames = ["Name", "Label", "Units", "Description", "Module", "Aliases"]
writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
writer.writeheader()

datidx = data_index[parameterization]
keys = datidx.keys()
for key in keys:
d = {
"Name": "``" + key + "``",
"Label": ":math:`" + datidx[key]["label"].replace("$", "") + "`",
"Units": datidx[key]["units_long"],
"Description": datidx[key]["description"],
"Module": "``" + datidx[key]["fun"].__module__ + "``",
}

# stuff like |x| is interpreted as a substitution by rst, need to escape
d["Description"] = _escape(d["Description"])
writer.writerow(d)
if key not in data_index[parameterization][key]["aliases"]:
d = {
"Name": "``" + key + "``",
"Label": ":math:`" + datidx[key]["label"].replace("$", "") + "`",
"Units": datidx[key]["units_long"],
"Description": datidx[key]["description"],
"Module": "``" + datidx[key]["fun"].__module__ + "``",
"Aliases": f"{['``' + alias + '``' for alias in datidx[key]['aliases']]}".strip(
"[]"
).replace(
"'", ""
),
}
# stuff like |x| is interpreted as a substitution by rst, need to escape
d["Description"] = _escape(d["Description"])
writer.writerow(d)


header = """
Expand All @@ -53,6 +58,8 @@ def write_csv(parameterization):
* **Units** : physical units for the variable
* **Description** : description of the variable
* **Module** : where in the code the source is defined (mostly for developers)
* **Aliases** : alternative names of a variable that can be used in the same way as
the primary name
"""
Expand All @@ -64,7 +71,7 @@ def write_csv(parameterization):
.. csv-table:: List of Variables: {}
:file: {}.csv
:widths: 15, 15, 15, 60, 30
:widths: 15, 15, 15, 60, 30, 15
:header-rows: 1
"""
Expand Down
20 changes: 20 additions & 0 deletions tests/test_axis_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,26 @@
}


def add_all_aliases(names):
"""Add aliases to limits."""
all_aliases = []
for name in names:
for base_class in data_index.keys():
if name in data_index[base_class].keys():
all_aliases.append(data_index[base_class][name]["aliases"])

# flatten
all_aliases = [name for sublist in all_aliases for name in sublist]
names.update(all_aliases)

return names


zero_limits = add_all_aliases(zero_limits)
not_finite_limits = add_all_aliases(not_finite_limits)
not_implemented_limits = add_all_aliases(not_implemented_limits)


def grow_seeds(
seeds, search_space, parameterization="desc.equilibrium.equilibrium.Equilibrium"
):
Expand Down
23 changes: 23 additions & 0 deletions tests/test_compute_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,29 @@ def myconvolve_2d(arr_1d, stencil, shape):
return conv


@pytest.mark.unit
def test_aliases():
"""Tests that data_index aliases are equal."""
surface = FourierRZToroidalSurface(
R_lmn=[10, 1, 0.2],
Z_lmn=[-2, -0.2],
modes_R=[[0, 0], [1, 0], [0, 1]],
modes_Z=[[-1, 0], [0, -1]],
)

eq = Equilibrium(surface=surface)

# automatic case
primary_data = eq.compute("R_tz")
alias_data = eq.compute("R_zt")
np.testing.assert_allclose(primary_data["R_tz"], alias_data["R_zt"])

# manual case
primary_data = eq.compute("e_rho_rt")
alias_data = eq.compute("x_rrt")
np.testing.assert_allclose(primary_data["e_rho_rt"], alias_data["x_rrt"])


@pytest.mark.unit
def test_total_volume(DummyStellarator):
"""Test that the volume enclosed by the LCFS is equal to the total volume."""
Expand Down
4 changes: 4 additions & 0 deletions tests/test_data_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,13 @@ def test_data_index_deps(self):
for base_class, superclasses in _class_inheritance.items():
if p in superclasses or p == base_class:
queried_deps[base_class][name] = deps
aliases = data_index[base_class][name]["aliases"]
for alias in aliases:
queried_deps[base_class][alias] = deps

for p in data_index:
for name, val in data_index[p].items():
print(name)
err_msg = f"Parameterization: {p}. Name: {name}."
deps = val["dependencies"]
data = set(deps["data"])
Expand Down

0 comments on commit 23d0635

Please sign in to comment.