From d087a0c9af8322e3c59a31b902cf967dd21a9d91 Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Tue, 5 Sep 2023 13:00:54 -0400 Subject: [PATCH 01/18] added draft code for aliases --- desc/compute/_metric.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index d8d51d40b9..030e641e8a 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -1294,7 +1294,8 @@ def _g_sup_zz(params, transforms, profiles, data, **kwargs): @register_compute_fun( - name="g^rt", + # name="g^rt" + name = ["g^" + i + j for i in ['r', 't', 'z'] for j in ['r', 't', 'z']], label="g^{\\rho\\theta}", units="m^{-2}", units_long="inverse square meters", @@ -1307,7 +1308,12 @@ def _g_sup_zz(params, transforms, profiles, data, **kwargs): data=["e^rho", "e^theta"], ) def _g_sup_rt(params, transforms, profiles, data, **kwargs): - data["g^rt"] = dot(data["e^rho"], data["e^theta"]) + # data["g^rt"] = dot(data["e^rho"], data["e^theta"]) + # reruns each time but looks cleaner this way i think + aliases = ["g^" + i + j for i in ['r', 't', 'z'] for j in ['r', 't', 'z']] + for alias in aliases: + data[alias] = dot(data["e^rho"], data["e^theta"]) + return data From a4504f39015af9787ba02ad22b0dd2448881c04c Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Tue, 5 Sep 2023 23:36:16 -0400 Subject: [PATCH 02/18] data_index adjustment bc forgot to save --- desc/compute/data_index.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index 3dae714d86..141b00d77b 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -88,7 +88,12 @@ def _decorator(func): raise ValueError( f"Already registered function with parameterization {p} and name {name}." ) - data_index[base_class][name] = d.copy() + if isinstance(name, list): + data_index[base_class][name[0]] = d.copy() + for alias in name: + data_index[base_class][alias] = data_index[base_class][name[0]] + else: + data_index[base_class][name] = d.copy() flag = True if not flag: raise ValueError( From 0dc823b52a79024e53462f5e122681a2968f0bc9 Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Mon, 25 Sep 2023 23:00:24 -0400 Subject: [PATCH 03/18] made test for compute --- desc/compute/_metric.py | 2 +- tests/test_compute_funs.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index 030e641e8a..6f4a96234d 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -1295,7 +1295,7 @@ def _g_sup_zz(params, transforms, profiles, data, **kwargs): @register_compute_fun( # name="g^rt" - name = ["g^" + i + j for i in ['r', 't', 'z'] for j in ['r', 't', 'z']], + name=["g^" + i + j for i in ['r', 't', 'z'] for j in ['r', 't', 'z']], label="g^{\\rho\\theta}", units="m^{-2}", units_long="inverse square meters", diff --git a/tests/test_compute_funs.py b/tests/test_compute_funs.py index 7ecfe14f89..6ceb15ed8c 100644 --- a/tests/test_compute_funs.py +++ b/tests/test_compute_funs.py @@ -43,6 +43,15 @@ def myconvolve_2d(arr_1d, stencil, shape): return conv +def test_aliases(): + eq = Equilibrium() # torus + rho = np.linspace(0, 1, 64) + grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, sym=eq.sym, rho=rho) + coeffs = [eq.compute("g^" + i + j, grid=grid) + for i in ['r', 't', 'z'] for j in ['r', 't', 'z']] + np.testing.assert_allclose(coeffs) + + @pytest.mark.unit def test_total_volume(DummyStellarator): """Test that the volume enclosed by the LCFS is equal to the total volume.""" From f150d3338b8fd6993dc55999b240b44bd17f81f8 Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Mon, 25 Sep 2023 23:26:12 -0400 Subject: [PATCH 04/18] updated alias test --- tests/test_compute_funs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_compute_funs.py b/tests/test_compute_funs.py index 6ceb15ed8c..afd3f56d3c 100644 --- a/tests/test_compute_funs.py +++ b/tests/test_compute_funs.py @@ -43,13 +43,14 @@ def myconvolve_2d(arr_1d, stencil, shape): return conv +@pytest.mark.unit def test_aliases(): eq = Equilibrium() # torus rho = np.linspace(0, 1, 64) grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, sym=eq.sym, rho=rho) - coeffs = [eq.compute("g^" + i + j, grid=grid) - for i in ['r', 't', 'z'] for j in ['r', 't', 'z']] - np.testing.assert_allclose(coeffs) + aliases = ["g^" + i + j for i in ['r', 't', 'z'] for j in ['r', 't', 'z']] + data = eq.compute(aliases, grid=grid) + np.testing.assert_allclose([data[alias] for alias in aliases]) @pytest.mark.unit From 3f60eee2e91cf7ea6131b5a29ef576c8e8137e00 Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Fri, 20 Oct 2023 10:55:35 -0700 Subject: [PATCH 05/18] uncommented out code again --- desc/compute/_metric.py | 10 ++++------ desc/compute/data_index.py | 11 +++++------ 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index 6f4a96234d..dd07b38c92 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -1294,8 +1294,7 @@ def _g_sup_zz(params, transforms, profiles, data, **kwargs): @register_compute_fun( - # name="g^rt" - name=["g^" + i + j for i in ['r', 't', 'z'] for j in ['r', 't', 'z']], + name="g^rt", label="g^{\\rho\\theta}", units="m^{-2}", units_long="inverse square meters", @@ -1306,11 +1305,10 @@ def _g_sup_zz(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["e^rho", "e^theta"], + aliases=["g^" + i + j for i in ["r", "t", "z"] for j in ["r", "t", "z"]], ) -def _g_sup_rt(params, transforms, profiles, data, **kwargs): - # data["g^rt"] = dot(data["e^rho"], data["e^theta"]) - # reruns each time but looks cleaner this way i think - aliases = ["g^" + i + j for i in ['r', 't', 'z'] for j in ['r', 't', 'z']] +def _g_sup_rt(params, transforms, profiles, data): + aliases = ["g^" + i + j for i in ["r", "t", "z"] for j in ["r", "t", "z"]] for alias in aliases: data[alias] = dot(data["e^rho"], data["e^theta"]) diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index a858d401ed..4ea238b2d7 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -13,6 +13,7 @@ def register_compute_fun( profiles, coordinates, data, + aliases=[], parameterization="desc.equilibrium.equilibrium.Equilibrium", axis_limit_data=None, **kwargs, @@ -88,12 +89,10 @@ def _decorator(func): raise ValueError( f"Already registered function with parameterization {p} and name {name}." ) - if isinstance(name, list): - data_index[base_class][name[0]] = d.copy() - for alias in name: - data_index[base_class][alias] = data_index[base_class][name[0]] - else: - data_index[base_class][name] = d.copy() + data_index[base_class][name] = d.copy() + for alias in aliases: + data_index[base_class][alias] = data_index[base_class][name] + flag = True if not flag: raise ValueError( From 4a29aa9912b5058f31adb80202d750b2f772e74c Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Sun, 22 Oct 2023 16:46:16 -0400 Subject: [PATCH 06/18] added example cases --- desc/compute/_metric.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index dd07b38c92..fd573ff2ea 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -1305,12 +1305,14 @@ def _g_sup_zz(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["e^rho", "e^theta"], - aliases=["g^" + i + j for i in ["r", "t", "z"] for j in ["r", "t", "z"]], + aliases=["g^rt_alias", "g^rt_alias2"], ) def _g_sup_rt(params, transforms, profiles, data): - aliases = ["g^" + i + j for i in ["r", "t", "z"] for j in ["r", "t", "z"]] + data["g^rt"] = dot(data["e^rho"], data["e^theta"]) + + aliases = ["g^rt_alias", "g^rt_alias2"] for alias in aliases: - data[alias] = dot(data["e^rho"], data["e^theta"]) + data[alias] = data["g^rt"].copy() return data @@ -1327,9 +1329,15 @@ def _g_sup_rt(params, transforms, profiles, data): profiles=[], coordinates="rtz", data=["e^rho", "e^zeta"], + aliases=["g^rz_alias1", "g^rz_alias2"], ) def _g_sup_rz(params, transforms, profiles, data, **kwargs): data["g^rz"] = dot(data["e^rho"], data["e^zeta"]) + + aliases = (["g^rz_alias1", "g^rz_alias2"],) + for alias in aliases: + data[alias] = data["g^rz"].copy() + return data From 2ff78c84aa2cbe49377ebe34d4ad94729e7b6c90 Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Sat, 4 Nov 2023 19:51:21 -0400 Subject: [PATCH 07/18] added function to utils used in data_index --- desc/compute/_metric.py | 12 ------------ desc/compute/data_index.py | 5 +++++ desc/utils.py | 16 ++++++++++++++++ tests/test_compute_funs.py | 20 +++++++++++++------- 4 files changed, 34 insertions(+), 19 deletions(-) diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index fd573ff2ea..1777bebbcc 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -1305,15 +1305,9 @@ def _g_sup_zz(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["e^rho", "e^theta"], - aliases=["g^rt_alias", "g^rt_alias2"], ) def _g_sup_rt(params, transforms, profiles, data): data["g^rt"] = dot(data["e^rho"], data["e^theta"]) - - aliases = ["g^rt_alias", "g^rt_alias2"] - for alias in aliases: - data[alias] = data["g^rt"].copy() - return data @@ -1329,15 +1323,9 @@ def _g_sup_rt(params, transforms, profiles, data): profiles=[], coordinates="rtz", data=["e^rho", "e^zeta"], - aliases=["g^rz_alias1", "g^rz_alias2"], ) def _g_sup_rz(params, transforms, profiles, data, **kwargs): data["g^rz"] = dot(data["e^rho"], data["e^zeta"]) - - aliases = (["g^rz_alias1", "g^rz_alias2"],) - for alias in aliases: - data[alias] = data["g^rz"].copy() - return data diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index 4ea238b2d7..4b8eef72cb 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -1,4 +1,5 @@ """data_index contains all the quantities calculated by the compute functions.""" +from desc.utils import find_permutations def register_compute_fun( @@ -70,6 +71,10 @@ def register_compute_fun( "kwargs": list(kwargs.values()), } + permutable_names = ["R_", "Z_", "phi_", "lambda_", "omega_", "sqrt(g)_"] + if not aliases and "".join(name.split("_")[:-1]) + "_" in permutable_names: + aliases = find_permutations(name) + def _decorator(func): d = { "label": label, diff --git a/desc/utils.py b/desc/utils.py index ae051a990d..36495cf4e0 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -196,6 +196,22 @@ def __getitem__(self, index): Index = _Indexable() +def find_permutations(name, separator="_"): + """Finds permutations of quantity names for aliases.""" + split_name = name.split(separator) + original_permutation = split_name[-1] + + new_permutations = list(permutations(original_permutation)) + aliases = [ + "".join(split_name[:-1]) + separator + "".join(perm) + for perm in new_permutations + ] + aliases = np.unique(aliases) + aliases = np.delete(aliases, np.where(aliases == name)) + + return aliases + + def equals(a, b): """Compare (possibly nested) objects, such as dicts and lists. diff --git a/tests/test_compute_funs.py b/tests/test_compute_funs.py index 56c28f0488..46e4a64a82 100644 --- a/tests/test_compute_funs.py +++ b/tests/test_compute_funs.py @@ -48,14 +48,20 @@ def myconvolve_2d(arr_1d, stencil, shape): @pytest.mark.unit def test_aliases(): - eq = Equilibrium() # torus - rho = np.linspace(0, 1, 64) - grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, sym=eq.sym, rho=rho) - aliases = ["g^" + i + j for i in ['r', 't', 'z'] for j in ['r', 't', 'z']] - data = eq.compute(aliases, grid=grid) - np.testing.assert_allclose([data[alias] for alias in aliases]) + """Tests that data_index aliases are equal.""" + n = 10 + surface = FourierRZToroidalSurface( + R_lmn=np.array([10, 1, 0.5]), + Z_lmn=np.array([0, -1, -0.5]), + modes_R=np.array([[0, 0], [1, 0], [1, n]]), + modes_Z=np.array([[0, 0], [-1, 0], [-1, n]]), + ) + eq = Equilibrium(surface=surface) + data_1 = eq.compute("R_tz") + data_2 = eq.compute("R_zt") + np.testing.assert_allclose(data_1["R_tz"], data_2["R_tz"]) + - @pytest.mark.unit def test_total_volume(DummyStellarator): """Test that the volume enclosed by the LCFS is equal to the total volume.""" From 640a4b1aacd8639e3c2c327e5a9aeff4b1d56678 Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Wed, 8 Nov 2023 22:11:13 -0500 Subject: [PATCH 08/18] manually added a couple cases --- desc/compute/_basis_vectors.py | 1 + desc/compute/_metric.py | 1 + 2 files changed, 2 insertions(+) diff --git a/desc/compute/_basis_vectors.py b/desc/compute/_basis_vectors.py index c7a49f9f88..255fb431f4 100644 --- a/desc/compute/_basis_vectors.py +++ b/desc/compute/_basis_vectors.py @@ -1716,6 +1716,7 @@ def _e_sub_rho_rrz(params, transforms, profiles, data, **kwargs): "omega_rt", "omega_t", ], + aliases=["x_rrt", "x_rtr", "x_trr"], ) def _e_sub_rho_rt(params, transforms, profiles, data, **kwargs): data["e_rho_rt"] = jnp.array( diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index 1777bebbcc..c0b33ef2b7 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -1341,6 +1341,7 @@ def _g_sup_rz(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["e^theta", "e^zeta"], + aliases=["g^zt"], ) def _g_sup_tz(params, transforms, profiles, data, **kwargs): data["g^tz"] = dot(data["e^theta"], data["e^zeta"]) From 2589f8eb263e35733605f2847cc2ff2fde96422a Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Thu, 9 Nov 2023 14:20:55 -0500 Subject: [PATCH 09/18] updated test --- tests/test_compute_funs.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_compute_funs.py b/tests/test_compute_funs.py index 46e4a64a82..631c1bf44f 100644 --- a/tests/test_compute_funs.py +++ b/tests/test_compute_funs.py @@ -49,18 +49,18 @@ def myconvolve_2d(arr_1d, stencil, shape): @pytest.mark.unit def test_aliases(): """Tests that data_index aliases are equal.""" - n = 10 - surface = FourierRZToroidalSurface( - R_lmn=np.array([10, 1, 0.5]), - Z_lmn=np.array([0, -1, -0.5]), - modes_R=np.array([[0, 0], [1, 0], [1, n]]), - modes_Z=np.array([[0, 0], [-1, 0], [-1, n]]), - ) - eq = Equilibrium(surface=surface) + eq = Equilibrium() + + # automatic case data_1 = eq.compute("R_tz") data_2 = eq.compute("R_zt") np.testing.assert_allclose(data_1["R_tz"], data_2["R_tz"]) + # manual case + data_1 = eq.compute("x_rrt") + data_2 = eq.compute("e_rho_rt") + np.testing.assert_allclose(data_1["e_rho_rt"], data_2["e_rho_rt"]) + @pytest.mark.unit def test_total_volume(DummyStellarator): From c2372ea10dbe552ad04640de8bcc0ba0846d4aa5 Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Sat, 11 Nov 2023 01:20:32 -0500 Subject: [PATCH 10/18] formatting --- desc/compute/data_index.py | 43 ++++++++++++++++++++++++++++++++++++-- desc/compute/utils.py | 3 ++- desc/utils.py | 16 -------------- 3 files changed, 43 insertions(+), 19 deletions(-) diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index 4b8eef72cb..271fb8826f 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -1,5 +1,35 @@ """data_index contains all the quantities calculated by the compute functions.""" -from desc.utils import find_permutations +import functools +import itertools + +import numpy as np + + +def find_permutations(name, separator="_"): + """Finds permutations of quantity names for aliases.""" + split_name = name.split(separator) + original_permutation = split_name[-1] + + new_permutations = list(itertools.permutations(original_permutation)) + aliases = [ + "".join(split_name[:-1]) + separator + "".join(perm) + for perm in new_permutations + ] + aliases = np.unique(aliases) + aliases = np.delete(aliases, np.where(aliases == name)) + + return aliases + + +def assign_alias_data( + alias, primary, base_class, data_index, params, profiles, transforms, data, **kwargs +): + """Assigns primary data to alias.""" + data = data_index[base_class][primary]["fun"]( + params, transforms, profiles, data, **kwargs + ) + data[alias] = data[primary].copy() + return data def register_compute_fun( @@ -96,7 +126,16 @@ def _decorator(func): ) data_index[base_class][name] = d.copy() for alias in aliases: - data_index[base_class][alias] = data_index[base_class][name] + + data_index[base_class][alias] = d.copy() + # assigns alias compute func to generator to be used later + data_index[base_class][alias]["fun"] = functools.partial( + assign_alias_data, + alias=alias, + primary=name, + base_class=base_class, + data_index=data_index, + ) flag = True if not flag: diff --git a/desc/compute/utils.py b/desc/compute/utils.py index 40edec12fa..a6d70c57af 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -130,8 +130,9 @@ def _compute( ) # now compute the quantity data = data_index[parameterization][name]["fun"]( - params, transforms, profiles, data, **kwargs + params=params, transforms=transforms, profiles=profiles, data=data, **kwargs ) + return data diff --git a/desc/utils.py b/desc/utils.py index 36495cf4e0..ae051a990d 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -196,22 +196,6 @@ def __getitem__(self, index): Index = _Indexable() -def find_permutations(name, separator="_"): - """Finds permutations of quantity names for aliases.""" - split_name = name.split(separator) - original_permutation = split_name[-1] - - new_permutations = list(permutations(original_permutation)) - aliases = [ - "".join(split_name[:-1]) + separator + "".join(perm) - for perm in new_permutations - ] - aliases = np.unique(aliases) - aliases = np.delete(aliases, np.where(aliases == name)) - - return aliases - - def equals(a, b): """Compare (possibly nested) objects, such as dicts and lists. From 8b8559a18fb09d6bdc55b9d6235a16bc97b81c21 Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Mon, 13 Nov 2023 17:26:49 -0500 Subject: [PATCH 11/18] formatting --- desc/compute/_basis_vectors.py | 1 + desc/compute/data_index.py | 16 +++++++++++++++- tests/test_compute_funs.py | 21 ++++++++++++++------- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/desc/compute/_basis_vectors.py b/desc/compute/_basis_vectors.py index 255fb431f4..1c5f1d9dfd 100644 --- a/desc/compute/_basis_vectors.py +++ b/desc/compute/_basis_vectors.py @@ -1853,6 +1853,7 @@ def _e_sub_rho_rtt(params, transforms, profiles, data, **kwargs): "omega_tz", "omega_z", ], + aliases=["x_rrt", "x_trr", "x_rtr"], ) def _e_sub_rho_rtz(params, transforms, profiles, data, **kwargs): data["e_rho_rtz"] = jnp.array( diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index 271fb8826f..e1a998c14e 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -24,7 +24,21 @@ def find_permutations(name, separator="_"): def assign_alias_data( alias, primary, base_class, data_index, params, profiles, transforms, data, **kwargs ): - """Assigns primary data to alias.""" + """Assigns primary data to alias. + + Parameters + ---------- + alias : `str` + data_index key for alias of primary + primary : `str` + key defined in compute function + + Returns + ------- + data : `dict` + computed data dictionary (includes both alias and primary) + + """ data = data_index[base_class][primary]["fun"]( params, transforms, profiles, data, **kwargs ) diff --git a/tests/test_compute_funs.py b/tests/test_compute_funs.py index 631c1bf44f..3d139c707c 100644 --- a/tests/test_compute_funs.py +++ b/tests/test_compute_funs.py @@ -49,17 +49,24 @@ def myconvolve_2d(arr_1d, stencil, shape): @pytest.mark.unit def test_aliases(): """Tests that data_index aliases are equal.""" - eq = Equilibrium() + surface = FourierRZToroidalSurface( + R_lmn=[10, 1, 0.2], + Z_lmn=[-2, -0.2], + modes_R=[[0, 0], [1, 0], [0, 1]], + modes_Z=[[-1, 0], [0, -1]], + ) + + eq = Equilibrium(surface=surface) # automatic case - data_1 = eq.compute("R_tz") - data_2 = eq.compute("R_zt") - np.testing.assert_allclose(data_1["R_tz"], data_2["R_tz"]) + primary_data = eq.compute("R_tz") + alias_data = eq.compute("R_zt") + np.testing.assert_allclose(primary_data["R_tz"], alias_data["R_zt"]) # manual case - data_1 = eq.compute("x_rrt") - data_2 = eq.compute("e_rho_rt") - np.testing.assert_allclose(data_1["e_rho_rt"], data_2["e_rho_rt"]) + primary_data = eq.compute("e_rho_rt") + alias_data = eq.compute("x_rrt") + np.testing.assert_allclose(primary_data["e_rho_rt"], alias_data["x_rrt"]) @pytest.mark.unit From e06c9e2f278ad8e5f4ce29c5bc9668bb0a86befc Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Mon, 13 Nov 2023 21:19:21 -0500 Subject: [PATCH 12/18] adjust test_data_index.py --- desc/compute/data_index.py | 1 + tests/test_data_index.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index e1a998c14e..2944aee15d 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -129,6 +129,7 @@ def _decorator(func): "dim": dim, "coordinates": coordinates, "dependencies": deps, + "aliases": aliases, } for p in parameterization: flag = False diff --git a/tests/test_data_index.py b/tests/test_data_index.py index 325cba7a76..7ce9f4a774 100644 --- a/tests/test_data_index.py +++ b/tests/test_data_index.py @@ -92,8 +92,14 @@ def test_data_index_deps(self): if p in superclasses or p == base_class: queried_deps[base_class][name] = deps + for alias in data_index[base_class][name][ + "aliases" + ]: + queried_deps[base_class][alias] = deps + for p in data_index: for name, val in data_index[p].items(): + print(name) err_msg = f"Parameterization: {p}. Name: {name}." deps = val["dependencies"] data = set(deps["data"]) From 8842a76ed0d9f488a061d1e1efe0adb4bb01fdfe Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Mon, 13 Nov 2023 23:58:40 -0500 Subject: [PATCH 13/18] added docstring --- desc/compute/data_index.py | 11 ++++++----- tests/test_axis_limits.py | 20 ++++++++++++++++++++ tests/test_data_index.py | 6 ++---- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index 2944aee15d..0fe4df2116 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -5,18 +5,19 @@ import numpy as np -def find_permutations(name, separator="_"): +def find_permutations(primary, separator="_"): """Finds permutations of quantity names for aliases.""" - split_name = name.split(separator) - original_permutation = split_name[-1] + split_name = primary.split(separator) + primary_permutation = split_name[-1] - new_permutations = list(itertools.permutations(original_permutation)) + new_permutations = list(itertools.permutations(primary_permutation)) + # join new permutation to form alias keys aliases = [ "".join(split_name[:-1]) + separator + "".join(perm) for perm in new_permutations ] aliases = np.unique(aliases) - aliases = np.delete(aliases, np.where(aliases == name)) + aliases = np.delete(aliases, np.where(aliases == primary)) return aliases diff --git a/tests/test_axis_limits.py b/tests/test_axis_limits.py index 0b2b37a41b..89c83016ff 100644 --- a/tests/test_axis_limits.py +++ b/tests/test_axis_limits.py @@ -81,6 +81,26 @@ } +def add_all_aliases(names): + """Add aliases to limits.""" + all_aliases = [] + for name in names: + for base_class in data_index.keys(): + if name in data_index[base_class].keys(): + all_aliases.append(data_index[base_class][name]["aliases"]) + + # flatten + all_aliases = [name for sublist in all_aliases for name in sublist] + names.update(all_aliases) + + return names + + +zero_limits = add_all_aliases(zero_limits) +not_finite_limits = add_all_aliases(not_finite_limits) +not_implemented_limits = add_all_aliases(not_implemented_limits) + + def grow_seeds( seeds, search_space, parameterization="desc.equilibrium.equilibrium.Equilibrium" ): diff --git a/tests/test_data_index.py b/tests/test_data_index.py index 7ce9f4a774..e3fd65a495 100644 --- a/tests/test_data_index.py +++ b/tests/test_data_index.py @@ -91,10 +91,8 @@ def test_data_index_deps(self): for base_class, superclasses in _class_inheritance.items(): if p in superclasses or p == base_class: queried_deps[base_class][name] = deps - - for alias in data_index[base_class][name][ - "aliases" - ]: + aliases = data_index[base_class][name]["aliases"] + for alias in aliases: queried_deps[base_class][alias] = deps for p in data_index: From db43b0f9b3500c1d18eec7ad7a76718be4251841 Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Tue, 14 Nov 2023 17:38:23 -0500 Subject: [PATCH 14/18] added alias documentation and fix metric args --- desc/compute/_metric.py | 2 +- desc/compute/data_index.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index c0b33ef2b7..edf06293c9 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -1306,7 +1306,7 @@ def _g_sup_zz(params, transforms, profiles, data, **kwargs): coordinates="rtz", data=["e^rho", "e^theta"], ) -def _g_sup_rt(params, transforms, profiles, data): +def _g_sup_rt(params, transforms, profiles, data, **kwargs): data["g^rt"] = dot(data["e^rho"], data["e^theta"]) return data diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index 0fe4df2116..3ebe262897 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -98,6 +98,9 @@ def register_compute_fun( or `desc.equilibrium.Equilibrium`. axis_limit_data : list of str Names of other items in the data index needed to compute axis limit of qty. + aliases : list + Aliases of `name`. Will be stored in the data dictionary as a copy of `name`s + data. Notes ----- From 5075e81037bf2206b1746faff2eeabcd3202a91d Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Wed, 15 Nov 2023 20:43:43 -0500 Subject: [PATCH 15/18] removed rows and added column for aliases --- docs/write_variables.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/docs/write_variables.py b/docs/write_variables.py index 3c22784509..e4d7c1c6db 100644 --- a/docs/write_variables.py +++ b/docs/write_variables.py @@ -26,14 +26,17 @@ def write_csv(parameterization): datidx = data_index[parameterization] keys = datidx.keys() for key in keys: - d = { - "Name": "``" + key + "``", - "Label": ":math:`" + datidx[key]["label"].replace("$", "") + "`", - "Units": datidx[key]["units_long"], - "Description": datidx[key]["description"], - "Module": "``" + datidx[key]["fun"].__module__ + "``", - } - + if key not in data_index[parameterization][key]["aliases"]: + d = { + "Name": "``" + key + "``", + "Label": ":math:`" + datidx[key]["label"].replace("$", "") + "`", + "Units": datidx[key]["units_long"], + "Description": datidx[key]["description"], + "Module": "``" + datidx[key]["fun"].__module__ + "``", + "Aliases": f"{['``' + alias + '``' for alias in datidx[key]['aliases']]}".strip( + "[]" + ), + } # stuff like |x| is interpreted as a substitution by rst, need to escape d["Description"] = _escape(d["Description"]) writer.writerow(d) From c968eb51dad794d9d6d594a3be6b81c1ffa4ae71 Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Wed, 15 Nov 2023 21:03:11 -0500 Subject: [PATCH 16/18] added to fieldnames --- docs/write_variables.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/write_variables.py b/docs/write_variables.py index e4d7c1c6db..230ce5a0dd 100644 --- a/docs/write_variables.py +++ b/docs/write_variables.py @@ -19,7 +19,7 @@ def _escape(line): def write_csv(parameterization): with open(parameterization + ".csv", "w", newline="") as f: - fieldnames = ["Name", "Label", "Units", "Description", "Module"] + fieldnames = ["Name", "Label", "Units", "Description", "Module", "Aliases"] writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore") writer.writeheader() @@ -37,9 +37,9 @@ def write_csv(parameterization): "[]" ), } - # stuff like |x| is interpreted as a substitution by rst, need to escape - d["Description"] = _escape(d["Description"]) - writer.writerow(d) + # stuff like |x| is interpreted as a substitution by rst, need to escape + d["Description"] = _escape(d["Description"]) + writer.writerow(d) header = """ @@ -56,6 +56,8 @@ def write_csv(parameterization): * **Units** : physical units for the variable * **Description** : description of the variable * **Module** : where in the code the source is defined (mostly for developers) + * **Aliases** : alternative names of a variable that can be used in the same way as + the primary name """ @@ -67,7 +69,7 @@ def write_csv(parameterization): .. csv-table:: List of Variables: {} :file: {}.csv - :widths: 15, 15, 15, 60, 30 + :widths: 15, 15, 15, 60, 30, 15 :header-rows: 1 """ From de5b9463a128978f64314d7902bf301f73e841d9 Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Wed, 15 Nov 2023 22:18:43 -0500 Subject: [PATCH 17/18] fix find_permutations --- desc/compute/data_index.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index 3ebe262897..d7e3db3c52 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -1,6 +1,6 @@ """data_index contains all the quantities calculated by the compute functions.""" import functools -import itertools +from collections import deque import numpy as np @@ -9,8 +9,13 @@ def find_permutations(primary, separator="_"): """Finds permutations of quantity names for aliases.""" split_name = primary.split(separator) primary_permutation = split_name[-1] + primary_permutation = deque(primary_permutation) + + new_permutations = [] + for i in range(len(primary_permutation)): + primary_permutation.rotate(1) + new_permutations.append(list(primary_permutation)) - new_permutations = list(itertools.permutations(primary_permutation)) # join new permutation to form alias keys aliases = [ "".join(split_name[:-1]) + separator + "".join(perm) @@ -119,7 +124,7 @@ def register_compute_fun( "kwargs": list(kwargs.values()), } - permutable_names = ["R_", "Z_", "phi_", "lambda_", "omega_", "sqrt(g)_"] + permutable_names = ["R_", "Z_", "phi_", "lambda_", "omega_"] if not aliases and "".join(name.split("_")[:-1]) + "_" in permutable_names: aliases = find_permutations(name) From b10e347c630252228b9e87ae4996987463ff18bd Mon Sep 17 00:00:00 2001 From: Kian Orr Date: Thu, 16 Nov 2023 12:01:02 -0500 Subject: [PATCH 18/18] remove quotes and delete incorrect aliases --- desc/compute/_basis_vectors.py | 1 - docs/write_variables.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/desc/compute/_basis_vectors.py b/desc/compute/_basis_vectors.py index 1c5f1d9dfd..255fb431f4 100644 --- a/desc/compute/_basis_vectors.py +++ b/desc/compute/_basis_vectors.py @@ -1853,7 +1853,6 @@ def _e_sub_rho_rtt(params, transforms, profiles, data, **kwargs): "omega_tz", "omega_z", ], - aliases=["x_rrt", "x_trr", "x_rtr"], ) def _e_sub_rho_rtz(params, transforms, profiles, data, **kwargs): data["e_rho_rtz"] = jnp.array( diff --git a/docs/write_variables.py b/docs/write_variables.py index 230ce5a0dd..1ee47aa758 100644 --- a/docs/write_variables.py +++ b/docs/write_variables.py @@ -35,6 +35,8 @@ def write_csv(parameterization): "Module": "``" + datidx[key]["fun"].__module__ + "``", "Aliases": f"{['``' + alias + '``' for alias in datidx[key]['aliases']]}".strip( "[]" + ).replace( + "'", "" ), } # stuff like |x| is interpreted as a substitution by rst, need to escape