diff --git a/desc/compute/_basis_vectors.py b/desc/compute/_basis_vectors.py index c7a49f9f88..255fb431f4 100644 --- a/desc/compute/_basis_vectors.py +++ b/desc/compute/_basis_vectors.py @@ -1716,6 +1716,7 @@ def _e_sub_rho_rrz(params, transforms, profiles, data, **kwargs): "omega_rt", "omega_t", ], + aliases=["x_rrt", "x_rtr", "x_trr"], ) def _e_sub_rho_rt(params, transforms, profiles, data, **kwargs): data["e_rho_rt"] = jnp.array( diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index d8d51d40b9..edf06293c9 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -1341,6 +1341,7 @@ def _g_sup_rz(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["e^theta", "e^zeta"], + aliases=["g^zt"], ) def _g_sup_tz(params, transforms, profiles, data, **kwargs): data["g^tz"] = dot(data["e^theta"], data["e^zeta"]) diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index 5c4dbf4e51..d7e3db3c52 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -1,4 +1,55 @@ """data_index contains all the quantities calculated by the compute functions.""" +import functools +from collections import deque + +import numpy as np + + +def find_permutations(primary, separator="_"): + """Finds permutations of quantity names for aliases.""" + split_name = primary.split(separator) + primary_permutation = split_name[-1] + primary_permutation = deque(primary_permutation) + + new_permutations = [] + for i in range(len(primary_permutation)): + primary_permutation.rotate(1) + new_permutations.append(list(primary_permutation)) + + # join new permutation to form alias keys + aliases = [ + "".join(split_name[:-1]) + separator + "".join(perm) + for perm in new_permutations + ] + aliases = np.unique(aliases) + aliases = np.delete(aliases, np.where(aliases == primary)) + + return aliases + + +def assign_alias_data( + alias, primary, base_class, data_index, params, profiles, transforms, data, **kwargs +): + """Assigns primary data to alias. + + Parameters + ---------- + alias : `str` + data_index key for alias of primary + primary : `str` + key defined in compute function + + Returns + ------- + data : `dict` + computed data dictionary (includes both alias and primary) + + """ + data = data_index[base_class][primary]["fun"]( + params, transforms, profiles, data, **kwargs + ) + data[alias] = data[primary].copy() + return data def register_compute_fun( @@ -13,6 +64,7 @@ def register_compute_fun( profiles, coordinates, data, + aliases=[], parameterization="desc.equilibrium.equilibrium.Equilibrium", axis_limit_data=None, **kwargs, @@ -51,6 +103,9 @@ def register_compute_fun( or `desc.equilibrium.Equilibrium`. axis_limit_data : list of str Names of other items in the data index needed to compute axis limit of qty. + aliases : list + Aliases of `name`. Will be stored in the data dictionary as a copy of `name`s + data. Notes ----- @@ -69,6 +124,10 @@ def register_compute_fun( "kwargs": list(kwargs.values()), } + permutable_names = ["R_", "Z_", "phi_", "lambda_", "omega_"] + if not aliases and "".join(name.split("_")[:-1]) + "_" in permutable_names: + aliases = find_permutations(name) + def _decorator(func): d = { "label": label, @@ -79,6 +138,7 @@ def _decorator(func): "dim": dim, "coordinates": coordinates, "dependencies": deps, + "aliases": aliases, } for p in parameterization: flag = False @@ -89,6 +149,18 @@ def _decorator(func): f"Already registered function with parameterization {p} and name {name}." ) data_index[base_class][name] = d.copy() + for alias in aliases: + + data_index[base_class][alias] = d.copy() + # assigns alias compute func to generator to be used later + data_index[base_class][alias]["fun"] = functools.partial( + assign_alias_data, + alias=alias, + primary=name, + base_class=base_class, + data_index=data_index, + ) + flag = True if not flag: raise ValueError( diff --git a/desc/compute/utils.py b/desc/compute/utils.py index 40edec12fa..a6d70c57af 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -130,8 +130,9 @@ def _compute( ) # now compute the quantity data = data_index[parameterization][name]["fun"]( - params, transforms, profiles, data, **kwargs + params=params, transforms=transforms, profiles=profiles, data=data, **kwargs ) + return data diff --git a/docs/write_variables.py b/docs/write_variables.py index 3c22784509..1ee47aa758 100644 --- a/docs/write_variables.py +++ b/docs/write_variables.py @@ -19,24 +19,29 @@ def _escape(line): def write_csv(parameterization): with open(parameterization + ".csv", "w", newline="") as f: - fieldnames = ["Name", "Label", "Units", "Description", "Module"] + fieldnames = ["Name", "Label", "Units", "Description", "Module", "Aliases"] writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore") writer.writeheader() datidx = data_index[parameterization] keys = datidx.keys() for key in keys: - d = { - "Name": "``" + key + "``", - "Label": ":math:`" + datidx[key]["label"].replace("$", "") + "`", - "Units": datidx[key]["units_long"], - "Description": datidx[key]["description"], - "Module": "``" + datidx[key]["fun"].__module__ + "``", - } - - # stuff like |x| is interpreted as a substitution by rst, need to escape - d["Description"] = _escape(d["Description"]) - writer.writerow(d) + if key not in data_index[parameterization][key]["aliases"]: + d = { + "Name": "``" + key + "``", + "Label": ":math:`" + datidx[key]["label"].replace("$", "") + "`", + "Units": datidx[key]["units_long"], + "Description": datidx[key]["description"], + "Module": "``" + datidx[key]["fun"].__module__ + "``", + "Aliases": f"{['``' + alias + '``' for alias in datidx[key]['aliases']]}".strip( + "[]" + ).replace( + "'", "" + ), + } + # stuff like |x| is interpreted as a substitution by rst, need to escape + d["Description"] = _escape(d["Description"]) + writer.writerow(d) header = """ @@ -53,6 +58,8 @@ def write_csv(parameterization): * **Units** : physical units for the variable * **Description** : description of the variable * **Module** : where in the code the source is defined (mostly for developers) + * **Aliases** : alternative names of a variable that can be used in the same way as + the primary name """ @@ -64,7 +71,7 @@ def write_csv(parameterization): .. csv-table:: List of Variables: {} :file: {}.csv - :widths: 15, 15, 15, 60, 30 + :widths: 15, 15, 15, 60, 30, 15 :header-rows: 1 """ diff --git a/tests/test_axis_limits.py b/tests/test_axis_limits.py index 0b2b37a41b..89c83016ff 100644 --- a/tests/test_axis_limits.py +++ b/tests/test_axis_limits.py @@ -81,6 +81,26 @@ } +def add_all_aliases(names): + """Add aliases to limits.""" + all_aliases = [] + for name in names: + for base_class in data_index.keys(): + if name in data_index[base_class].keys(): + all_aliases.append(data_index[base_class][name]["aliases"]) + + # flatten + all_aliases = [name for sublist in all_aliases for name in sublist] + names.update(all_aliases) + + return names + + +zero_limits = add_all_aliases(zero_limits) +not_finite_limits = add_all_aliases(not_finite_limits) +not_implemented_limits = add_all_aliases(not_implemented_limits) + + def grow_seeds( seeds, search_space, parameterization="desc.equilibrium.equilibrium.Equilibrium" ): diff --git a/tests/test_compute_funs.py b/tests/test_compute_funs.py index d53bed263c..bf3e7fa0aa 100644 --- a/tests/test_compute_funs.py +++ b/tests/test_compute_funs.py @@ -46,6 +46,29 @@ def myconvolve_2d(arr_1d, stencil, shape): return conv +@pytest.mark.unit +def test_aliases(): + """Tests that data_index aliases are equal.""" + surface = FourierRZToroidalSurface( + R_lmn=[10, 1, 0.2], + Z_lmn=[-2, -0.2], + modes_R=[[0, 0], [1, 0], [0, 1]], + modes_Z=[[-1, 0], [0, -1]], + ) + + eq = Equilibrium(surface=surface) + + # automatic case + primary_data = eq.compute("R_tz") + alias_data = eq.compute("R_zt") + np.testing.assert_allclose(primary_data["R_tz"], alias_data["R_zt"]) + + # manual case + primary_data = eq.compute("e_rho_rt") + alias_data = eq.compute("x_rrt") + np.testing.assert_allclose(primary_data["e_rho_rt"], alias_data["x_rrt"]) + + @pytest.mark.unit def test_total_volume(DummyStellarator): """Test that the volume enclosed by the LCFS is equal to the total volume.""" diff --git a/tests/test_data_index.py b/tests/test_data_index.py index 325cba7a76..e3fd65a495 100644 --- a/tests/test_data_index.py +++ b/tests/test_data_index.py @@ -91,9 +91,13 @@ def test_data_index_deps(self): for base_class, superclasses in _class_inheritance.items(): if p in superclasses or p == base_class: queried_deps[base_class][name] = deps + aliases = data_index[base_class][name]["aliases"] + for alias in aliases: + queried_deps[base_class][alias] = deps for p in data_index: for name, val in data_index[p].items(): + print(name) err_msg = f"Parameterization: {p}. Name: {name}." deps = val["dependencies"] data = set(deps["data"])