diff --git a/torax/constants.py b/torax/constants.py index 1d4275d7..27b64a66 100644 --- a/torax/constants.py +++ b/torax/constants.py @@ -51,6 +51,7 @@ class Constants: epsilon0: chex.Numeric mu0: chex.Numeric eps: chex.Numeric + c: chex.Numeric CONSTANTS: Final[Constants] = Constants( @@ -61,6 +62,7 @@ class Constants: epsilon0=8.854e-12, mu0=4 * jnp.pi * 1e-7, eps=1e-7, + c=2.99792458e8, ) # Taken from diff --git a/torax/physics.py b/torax/physics.py index 281fd71c..e7e12f0e 100644 --- a/torax/physics.py +++ b/torax/physics.py @@ -23,6 +23,7 @@ import chex import jax from jax import numpy as jnp + from torax import array_typing from torax import constants from torax import jax_utils @@ -474,6 +475,23 @@ def _calculate_lambda_ei( """ return 15.2 - 0.5 * jnp.log(ne / 1e20) + jnp.log(temp_el) +def _calculate_lambda_ee( + temp_el: jax.Array, + ne: jax.Array, +) -> jax.Array: + """Calculates Coulomb logarithm for electron-ion collisions. + + See Wesson 3rd edition p727. + + Args: + temp_el: Electron temperature in keV. + ne: Electron density in m^-3. + + Returns: + Coulomb logarithm. + """ + return 14.9 - 0.5 * jnp.log(ne / 1e20) + jnp.log(temp_el) + def fast_ion_fractional_heating_formula( birth_energy: float | array_typing.ArrayFloat, diff --git a/torax/transport_model/tests/tglf_based_transport_model.py b/torax/transport_model/tests/tglf_based_transport_model.py new file mode 100644 index 00000000..a4e2969d --- /dev/null +++ b/torax/transport_model/tests/tglf_based_transport_model.py @@ -0,0 +1,208 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for torax.transport_model.tglf_based_transport_model.""" +from absl.testing import absltest +from absl.testing import parameterized +import chex +import jax.numpy as jnp +from torax import core_profile_setters +from torax import state +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice +from torax.geometry import circular_geometry +from torax.geometry import geometry +from torax.pedestal_model import pedestal_model as pedestal_model_lib +from torax.pedestal_model import set_tped_nped +from torax.sources import source_models as source_models_lib +from torax.transport_model import tglf_based_transport_model +from torax.transport_model import quasilinear_transport_model +from torax.transport_model import runtime_params as runtime_params_lib + + +def _get_model_inputs(transport: tglf_based_transport_model.RuntimeParams): + """Returns the model inputs for testing.""" + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = circular_geometry.build_circular_geometry() + source_models_builder = source_models_lib.SourceModelsBuilder() + source_models = source_models_builder() + pedestal_model_builder = ( + set_tped_nped.SetTemperatureDensityPedestalModelBuilder() + ) + dynamic_runtime_params_slice = ( + runtime_params_slice.DynamicRuntimeParamsSliceProvider( + runtime_params=runtime_params, + transport=transport, + sources=source_models_builder.runtime_params, + pedestal=pedestal_model_builder.runtime_params, + torax_mesh=geo.torax_mesh, + )( + t=runtime_params.numerics.t_initial, + ) + ) + static_slice = runtime_params_slice.build_static_runtime_params_slice( + runtime_params=runtime_params, + source_runtime_params=source_models_builder.runtime_params, + torax_mesh=geo.torax_mesh, + ) + core_profiles = core_profile_setters.initial_core_profiles( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_slice, + geo=geo, + source_models=source_models, + ) + return dynamic_runtime_params_slice, geo, core_profiles + + +class TGLFBasedTransportModelTest(parameterized.TestCase): + """Unit tests for the `torax.transport_model.tglf_based_transport_model` module.""" + + def test_tglf_based_transport_model_output_shapes(self): + """Tests that the core transport output has the right shapes.""" + transport = tglf_based_transport_model.RuntimeParams( + **runtime_params_lib.RuntimeParams() + ) + transport_model = FakeTGLFBasedTransportModel() + dynamic_runtime_params_slice, geo, core_profiles = _get_model_inputs( + transport + ) + pedestal_model = set_tped_nped.SetTemperatureDensityPedestalModel() + pedestal_model_outputs = pedestal_model( + dynamic_runtime_params_slice, geo, core_profiles + ) + + core_transport = transport_model( + dynamic_runtime_params_slice, geo, core_profiles, pedestal_model_outputs + ) + expected_shape = geo.rho_face_norm.shape + self.assertEqual(core_transport.chi_face_ion.shape, expected_shape) + self.assertEqual(core_transport.chi_face_el.shape, expected_shape) + self.assertEqual(core_transport.d_face_el.shape, expected_shape) + self.assertEqual(core_transport.v_face_el.shape, expected_shape) + + def test_tglf_based_transport_model_prepare_tglf_inputs_shapes(self): + """Tests that the tglf inputs have the expected shapes.""" + transport = tglf_based_transport_model.RuntimeParams( + **runtime_params_lib.RuntimeParams() + ) + dynamic_runtime_params_slice, geo, core_profiles = _get_model_inputs( + transport + ) + transport_model = FakeTGLFBasedTransportModel() + tglf_inputs = transport_model._prepare_tglf_inputs( + Zeff_face=dynamic_runtime_params_slice.plasma_composition.Zeff_face, + q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor, + geo=geo, + core_profiles=core_profiles, + ) + + # Inputs that are 1D + vector_keys = [ + 'chiGB', + 'lref_over_lti', + 'lref_over_lte', + 'lref_over_lne', + 'lref_over_lni0', + 'lref_over_lni1', + 'Ti_over_Te', + 'drmaj', + 'q', + 's_hat', + 'nu_ee', + 'kappa', + 'kappa_shear', + 'delta', + 'delta_shear', + 'beta_e', + 'Zeff', + ] + # Inputs that are 0D + scalar_keys = ['Rmaj', 'Rmin'] + + expected_vector_length = geo.rho_face_norm.shape[0] + for key in vector_keys: + try: + self.assertEqual( + getattr(tglf_inputs, key).shape, (expected_vector_length,) + ) + except Exception as e: + print(key, getattr(tglf_inputs, key)) + raise e + for key in scalar_keys: + self.assertEqual(getattr(tglf_inputs, key).shape, ()) + + +class FakeTGLFBasedTransportModel( + tglf_based_transport_model.TGLFBasedTransportModel +): + """Fake TGLFBasedTransportModel for testing purposes.""" + + def __init__(self): + super().__init__() + self._frozen = True + + # pylint: disable=invalid-name + def prepare_tglf_inputs( + self, + Zeff_face: chex.Array, + q_correction_factor: chex.Numeric, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + ) -> tglf_based_transport_model.TGLFInputs: + """Exposing prepare_tglf_inputs for testing.""" + return self._prepare_tglf_inputs( + Zeff_face=Zeff_face, + q_correction_factor=q_correction_factor, + geo=geo, + core_profiles=core_profiles, + ) + + # pylint: enable=invalid-name + + def _call_implementation( + self, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + pedestal_model_output: pedestal_model_lib.PedestalModelOutput, + ) -> state.CoreTransport: + tglf_inputs = self._prepare_tglf_inputs( + Zeff_face=dynamic_runtime_params_slice.plasma_composition.Zeff_face, + q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor, + geo=geo, + core_profiles=core_profiles, + ) + + transport = dynamic_runtime_params_slice.transport + # Assert required for pytype. + assert isinstance( + transport, + tglf_based_transport_model.DynamicRuntimeParams, + ) + + return self._make_core_transport( + qi=jnp.ones(geo.rho_face_norm.shape) * 0.4, + qe=jnp.ones(geo.rho_face_norm.shape) * 0.5, + pfe=jnp.ones(geo.rho_face_norm.shape) * 1.6, + quasilinear_inputs=tglf_inputs, + transport=transport, + geo=geo, + core_profiles=core_profiles, + gradient_reference_length=geo.Rmaj, # TODO + gyrobohm_flux_reference_length=geo.Rmin, # TODO + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/torax/transport_model/tglf_based_transport_model.py b/torax/transport_model/tglf_based_transport_model.py new file mode 100644 index 00000000..acb17f73 --- /dev/null +++ b/torax/transport_model/tglf_based_transport_model.py @@ -0,0 +1,220 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base class and utils for TGLF-based models.""" + +import chex +from jax import numpy as jnp + +from torax.geometry import geometry +from torax import physics +from torax import state +from torax.constants import CONSTANTS +from torax.transport_model import quasilinear_transport_model +from torax.transport_model import runtime_params as runtime_params_lib + + +@chex.dataclass +class RuntimeParams(quasilinear_transport_model.RuntimeParams): + """Shared parameters for TGLF-based models.""" + + def make_provider( + self, torax_mesh: geometry.Grid1D | None = None + ) -> "RuntimeParamsProvider": + return RuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh)) + + +@chex.dataclass(frozen=True) +class DynamicRuntimeParams(quasilinear_transport_model.DynamicRuntimeParams): + """Shared parameters for TGLF-based models.""" + + pass + + +@chex.dataclass +class RuntimeParamsProvider(runtime_params_lib.RuntimeParamsProvider): + """Provides a RuntimeParams to use during time t of the sim.""" + + runtime_params_config: RuntimeParams + + def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: + return DynamicRuntimeParams(**self.get_dynamic_params_kwargs(t)) + + +@chex.dataclass(frozen=True) +class TGLFInputs(quasilinear_transport_model.QuasilinearInputs): + r"""Dimensionless inputs to TGLF-based models. + + See https://gafusion.github.io/doc/tglf/tglf_table.html for definitions. + """ + + # Ti/Te + Ti_over_Te: chex.Array + # drmaj/dr (flux surface centroid major radius gradient) + drmaj: chex.Array + # q + q: chex.Array + # r/q dq/dr + s_hat: chex.Array + # nu_ee (see note in prepare_tglf_inputs) + nu_ee: chex.Array + # Elongation, kappa + kappa: chex.Array + # Shear in elongation, r/kappa dkappa/dr + kappa_shear: chex.Array + # Triangularity, delta + delta: chex.Array + # Shear in triangularity, r ddelta/dr + delta_shear: chex.Array + # Electron pressure defined w.r.t B_unit + beta_e: chex.Array + # Effective charge + Zeff: chex.Array + + +class TGLFBasedTransportModel( + quasilinear_transport_model.QuasilinearTransportModel +): + """Base class for TGLF-based transport models.""" + + def _prepare_tglf_inputs( + self, + Zeff_face: chex.Array, + q_correction_factor: chex.Numeric, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + ) -> TGLFInputs: + # Shorthand 'standard' variables + Te_keV = core_profiles.temp_el.face_value() + Te_eV = Te_keV * 1e3 + Te_J = Te_keV * CONSTANTS.keV2J + Ti_keV = core_profiles.temp_ion.face_value() + ne = core_profiles.ne.face_value() * core_profiles.nref + # q must be recalculated since in the nonlinear solver psi has intermediate + # states in the iterative solve + q, _ = physics.calc_q_from_psi( + geo=geo, + psi=core_profiles.psi, + q_correction_factor=q_correction_factor, + ) + + # Reference values used for TGLF-specific normalisation + # - 'a' in TGLF means the minor radius at the LCFS + # - 'r' in TGLF means the flux surface centroid minor radius. Gradients are + # taken w.r.t. r + # https://gafusion.github.io/doc/tglf/tglf_list.html#rmin-loc + # - B_unit = 1/r d(psi_tor)/dr = q/r dpsi/dr + # https://gafusion.github.io/doc/geometry.html#effective-field + # - c_s (ion sound speed) + # https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization + m_D_amu = 2.014 # Mass of deuterium - TODO: load from lookup table + m_D = m_D_amu * CONSTANTS.mp # Mass of deuterium + c_s = (Te_J / m_D) ** 0.5 + a = geo.Rmin # Device minor radius at LCFS + r = geo.rmid_face # Flux surface centroid minor radius + B_unit = q / r * jnp.gradient(core_profiles.psi.face_value(), r) + + # Dimensionless gradients + normalized_log_gradients = quasilinear_transport_model.NormalizedLogarithmicGradients.from_profiles( + core_profiles=core_profiles, + radial_coordinate=geo.rmid, # TODO: Why does this have to be a variable on the cell grid? + reference_length=a, + ) + + # Dimensionless temperature ratio + Ti_over_Te = Ti_keV / Te_keV + + # Dimensionless electron-electron collision frequency = nu_ee / (c_s/a) + # https://gafusion.github.io/doc/tglf/tglf_list.html#xnue + # https://gafusion.github.io/doc/cgyro/cgyro_list.html#cgyro-nu-ee + # Note: In the TGLF docs, XNUE is mislabelled as electron-ion collision frequency. + # It is actually the electron-electron collision frequency, and is defined as in CGYRO + # See https://pyrokinetics.readthedocs.io/en/latest/user_guide/collisions.html#tglf + # Lambda_ee is computed with keV and m^-3 units + # normalised_nu_ee is computed with SI units (ie J rather than keV) + Lambda_ee = physics._calculate_lambda_ee(Te_keV, ne) + normalised_nu_ee = (4 * jnp.pi * ne * CONSTANTS.qe**4 * Lambda_ee) / ( + CONSTANTS.me**0.5 * (2 * Te_J) ** 1.5 + ) + nu_ee = normalised_nu_ee / (c_s / a) + + # Safety factor, q + # https://gafusion.github.io/doc/tglf/tglf_list.html#q-sa + # defined before + + # Safety factor shear, s_hat = r/q dq/dr + # https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-shat-sa + # Note: calc_s_from_psi_rmid gives rq dq/dr, so we divide by q**2 + # r_mid = r + s_hat = physics.calc_s_from_psi_rmid(geo, core_profiles.psi) / q**2 + + # Electron beta + # https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-betae + # Note: Te in eV + beta_e = 8 * jnp.pi * ne * Te_eV / B_unit**2 + + # Major radius shear = drmaj/drmin, where 'rmaj' is the flux surface centroid + # major radius and 'rmin' the flux surface centroid minor radius + # https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-drmajdx-loc + rmaj = ( + geo.Rin_face + geo.Rout_face + ) / 2 # Flux surface centroid maj radius + drmaj = jnp.gradient(rmaj, r) + + # Elongation shear = r/kappa dkappa/dr + # https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-s-kappa-loc + kappa = geo.elongation_face + kappa_shear = geo.rmid_face / kappa * jnp.gradient(kappa, r) + + # Triangularity shear = r ddelta/dr + # https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-s-delta-loc + delta_shear = r * jnp.gradient(geo.delta_face, r) + + # Gyrobohm diffusivity + # https://gafusion.github.io/doc/tglf/tglf_table.html#id7 + # https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization + # Note: TGLF uses the same normalisation as CGYRO + # This has an extra c^2 factor compared to TORAX's calculate_chiGB + chiGB = ( + quasilinear_transport_model.calculate_chiGB( + reference_temperature=Te_keV, # conversion to J done internally + reference_magnetic_field=B_unit, + reference_mass=m_D_amu, + reference_length=a, + ) + * CONSTANTS.c**2 + ) + + return TGLFInputs( + # From QuasilinearInputs + chiGB=chiGB, + Rmin=geo.Rmin, + Rmaj=geo.Rmaj, + lref_over_lti=normalized_log_gradients.lref_over_lti, + lref_over_lte=normalized_log_gradients.lref_over_lte, + lref_over_lne=normalized_log_gradients.lref_over_lne, + lref_over_lni0=normalized_log_gradients.lref_over_lni0, + lref_over_lni1=normalized_log_gradients.lref_over_lni1, + # From TGLFInputs + Ti_over_Te=Ti_over_Te, + drmaj=drmaj, + q=q, + s_hat=s_hat, + nu_ee=nu_ee, + kappa=kappa, + kappa_shear=kappa_shear, + delta=geo.delta_face, + delta_shear=delta_shear, + beta_e=beta_e, + Zeff=Zeff_face, + ) diff --git a/torax/transport_model/tglf_wrapper.py b/torax/transport_model/tglf_wrapper.py new file mode 100755 index 00000000..c0dd5b25 --- /dev/null +++ b/torax/transport_model/tglf_wrapper.py @@ -0,0 +1,403 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A wrapper around tglf. + +The wrapper calls tglf itself. Must be run with +TORAX_COMPILATION_ENABLED=False. Used for generating ground truth for QLKNN11D +evaluation. Kept as an internal model. +""" + +from __future__ import annotations + +from collections.abc import Callable +import dataclasses +import datetime +import os +import subprocess +import tempfile +from functools import partial +from typing import Dict, List, Union +from dataclasses import fields +from multiprocessing import Pool +import pandas as pd + +import chex +import numpy as np +from quasilinear_utils import QuasilinearTransportModel +#tglf_tools.tglf_io import inputfiles as tglf_inputtools +#from tglf_tools.tglf_io import tglfrun as tglf_runtools +from torax import geometry +from torax import jax_utils +from torax import state +from torax.config import runtime_params_slice +from torax.transport_model import tglf_based_transport_model +from torax.transport_model import runtime_params as runtime_params_lib +from torax.transport_model import transport_model +from torax.transport import tglf_tools + + + +# pylint: disable=invalid-name +@chex.dataclass +class RuntimeParams(tglf_based_transport_model.RuntimeParams): + """Extends the base runtime params with additional params for this model. + + See base class runtime_params.RuntimeParams docstring for more info. + """ + numprocs: int = 2 + NBASIS_MAX: int = 4 + NBASIS_MIN: int = 2 + USE_TRANSPORT_MODEL: bool = True + NS: int = 2 + NXGRID: int = 16 + GEOMETRY_FLAG: int = 1 + USE_BPER: bool = True + USE_BPAR: bool = True + KYGRID_MODEL: int = 4 + SAT_RULE: int = 1 + USE_MHD_RULE: bool = False + ALPHA_ZF: int = -1 + FILTER: float = 2.0 + + def make_provider( + self, torax_mesh: geometry.Grid1D | None = None + ) -> 'RuntimeParamsProvider': + return RuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh)) + + +@chex.dataclass(frozen=True) +class DynamicRuntimeParams(tglf_based_transport_model.DynamicRuntimeParams): + numprocs: int + NBASIS_MAX: int + NBASIS_MIN: int + USE_TRANSPORT_MODEL: bool + NS: int + NXGRID: int + GEOMETRY_FLAG: int + USE_BPER: bool + USE_BPAR: bool + KYGRID_MODEL: int + SAT_RULE: int + USE_MHD_RULE: bool + ALPHA_ZF: int + FILTER: float + +class RuntimeParamsProvider(runtime_params_lib.RuntimeParamsProvider): + """Provides a RuntimeParams to use during time t of the sim.""" + + runtime_params_config: RuntimeParams + + def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: + return DynamicRuntimeParams(**self.get_dynamic_params_kwargs(t)) + +_DEFAULT_tglfrun_NAME_PREFIX = 'torax_tglf_runs' +_DEFAULT_TGLF_EXEC_PATH = '~/tglf/tglf' +_TGLF_EXEC_PATH = os.environ.get( + 'TORAX_TGLF_EXEC_PATH', _DEFAULT_TGLF_EXEC_PATH +) + +class TGLFTransportModel(tglf_based_transport_model.TGLFBasedTransportModel): + """Calculates turbulent transport coefficients with tglf.""" + + def __init__( + self, + runtime_params: RuntimeParams | None = None, + ): + self._runtime_params = runtime_params or RuntimeParams() + self._tglfrun_parentdir = tempfile.TemporaryDirectory() + self._tglfrun_name = ( + _DEFAULT_tglfrun_NAME_PREFIX + + datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + ) + self._runpath = os.path.join(self._tglfrun_parentdir.name, self._tglfrun_name) + self._frozen = True + + @property + def runtime_params(self) -> RuntimeParams: + return self._runtime_params + + @runtime_params.setter + def runtime_params(self, runtime_params: RuntimeParams) -> None: + self._runtime_params = runtime_params + + def _get_one_simulation_rundir(self, n: float): + return os.path.join(self._runpath,f'sim_{n}') + + def _call_implementation( + self, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + ) -> state.CoreTransport: + """Calculates several transport coefficients simultaneously. + + Args: + dynamic_runtime_params_slice: Input runtime parameters + geo: Geometry of the torus. + core_profiles: Core plasma profiles. + + Returns: + coeffs: transport coefficients + + Raises: + EnvironmentError: if TORAX_COMPILATION_ENABLED is set to True. + """ + + if jax_utils.env_bool('TORAX_COMPILATION_ENABLED', True): + raise EnvironmentError( + 'TORAX_COMPILATION_ENABLED environment variable is set to True.' + 'JAX Compilation is not supported with tglf.' + ) + + # TODO + assert isinstance( + dynamic_runtime_params_slice.transport, DynamicRuntimeParams + ) + transport = dynamic_runtime_params_slice.transport + + tglf_inputs = self._prepare_tglf_inputs( + Zeff_face=dynamic_runtime_params_slice.plasma_composition.Zeff_face, + q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor, + geo=geo, + core_profiles=core_profiles, + ) + # Generate list of dictionaries that will correspond to input.tglf + tglf_plan = _extract_tglf_plan( + tglf_inputs=tglf_inputs, + dynamic_runtime_params_slice=dynamic_runtime_params_slice + ) + self._run_tglf( + tglf_plan=tglf_plan, + numprocs=dynamic_runtime_params_slice.transport.numprocs + ) + core_transport = self._extract_run_data( + tglf_inputs=tglf_inputs, + transport=transport, + geo=geo, + core_profiles=core_profiles, + ) + + return core_transport + + def _run_tglf( + self, + tglf_plan: List[Dict[str,Union[int,float,bool]]], + numprocs: int, + verbose: bool = True, + ) -> None: + """Runs tglf using command line tools. Loose coupling with TORAX.""" + + # Prepare parent run directory + if not os.path.exists(self._runpath): + os.makedirs(self._runpath) + + num_simulations = len(tglf_plan) + for n in num_simulations: + # Prepare local simulation directory + this_rundir = self._get_one_simulation_rundir(n) + if not os.path.exists(this_rundir): + os.makedirs(this_rundir) + # Dump input file + with open(os.path.join(this_rundir,'input.tglf'), 'w') as f: + for key,value in tglf_plan[n].items(): + f.write(f'{key}={value}') + # run TGLF + command = [ + 'tglf', + '-n', + str(numprocs), + '-e', + '.' + ] + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=this_rundir + ) + if verbose: + # Get output and error messages + stdout, stderr = process.communicate() + + # Print the output + print(stdout.decode()) + + # Print any error messages + if stderr: + print(stderr.decode()) + + def _extract_run_data( + self, + tglf_inputs: tglf_based_transport_model.TGLFInputs, + transport: DynamicRuntimeParams, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + ) -> state.CoreTransport: + """Extracts tglf run data from runpath.""" + + qi = [] + qe = [] + pfe = [] + for run_dir in list(filter(os.path.isdir, os.listdir(self._runpath))): + df = pd.read_fwf(os.path.join(run_dir, 'out.tglf.run'), skiprows=5, index_col=0) + pfe.append(df.loc['elec','Gam/Gam_GB']) + qe.append(df.loc['elec','Q/Q_GB']) + qi.append(df.loc['ion1','Q/Q_GB']) + return self._make_core_transport( + qi=qi, + qe=qe, + pfe=pfe, + quasilinear_inputs=tglf_inputs, + transport=transport, + geo=geo, + core_profiles=core_profiles, + ) + + +def _extract_tglf_plan( + tglf_inputs: tglf_based_transport_model.TGLFInputs, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + )->List[Dict[str,Union[int,float,bool]]]: + """Converts TORAX parameters to tglf input file input.tglf. Currently only supports electron and main ion. + + Args: + tglf_inputs: Precomputed physics data. + dynamic_runtime_params_slice: Runtime params at time t. + geo: TORAX geometry object. + core_profiles: TORAX CoreProfiles object, containing time-evolvable + quantities like q + + Returns: + A list containing dictionaries of input configs for TGLF. + """ + + def _get_q_prime_loc(params): + return params['SHAT']*(params['Q_LOC']/params['RMIN_LOC'])**2 # --- q_prime_loc = (q / r) ** 2 * shat = q / r * dq/dr + + def _get_p_prime_loc(params): + # The below is a manipulation of https://gafusion.github.io/doc/tglf/tglf_list.html#p-prime-loc + # --- p_prime_loc = -q/r * betae/(8*pi) * \sum_k [n_k/n_e * T_k/T_e * (a/Ln_k + a/LT_k)] + first_term = params['Q_LOC']/params['RMIN_LOC'] + second_term = params['BETAE']/(8*np.pi) + Sum_term = 0 + for k in [1,2]: #Electrons and main ion only + tmp = params[f'AS_{k}']*params[f'TAUS_{k}']*(params[f'RLNS_{k}']+params[f'RLTS_{k}']) + Sum_term += tmp + return -first_term*second_term*Sum_term + + def add_missing_params( + params: Dict[str,float], + transport: DynamicRuntimeParams)->Dict[str,float]: + """Utility to create TGLF input file + + Args: + physical_params: Physical parameters + transport: Runtime params at time t + + Returns: + TGLF inputs inclusive of numerics and other parameters not included in the + + """ + numerics_params = { + 'NBASIS_MAX': transport.NBASIS_MAX, + 'NBASIS_MIN': transport.NBASIS_MIN, + 'USE_TRANSPORT_MODEL': transport.USE_TRANSPORT_MODEL, + 'NS': transport.NS, + 'NXGRID': transport.NXGRID, + 'GEOMETRY_FLAG': transport.GEOMETRY_FLAG, + 'USE_BPER': transport.USE_BPER, + 'USE_BPAR': transport.USE_BPAR, + 'KYGRID_MODEL': transport.KYGRID_MODEL, + 'SAT_RULE': transport.SAT_RULE, + 'USE_MHD_RULE': transport.USE_MHD_RULE, + 'ALPHA_ZF': transport.ALPHA_ZF, + 'FILTER': transport.FILTER + } + params['P_PRIME_LOC'] = _get_p_prime_loc(params) + params['Q_PRIME_LOC'] = _get_q_prime_loc(params) + params.update(numerics_params) + return params + + assert isinstance( + dynamic_runtime_params_slice.transport, DynamicRuntimeParams + ) + transport: DynamicRuntimeParams = dynamic_runtime_params_slice.transport + prepare_input_dict = partial(add_missing_params, transport=transport) + zipped_arrays = zip( + np.array(tglf_inputs.Rmin), + np.array(tglf_inputs.dRmaj), + np.array(tglf_inputs.q), + np.array(tglf_inputs.Ate), + np.array(tglf_inputs.Ati), + np.array(tglf_inputs.Ane), + np.array(tglf_inputs.Ti_over_Te), + np.array(tglf_inputs.nu_ee), + np.array(tglf_inputs.kappa), + np.array(tglf_inputs.kappa_shear), + np.array(tglf_inputs.delta), + np.array(tglf_inputs.delta_shear), + np.array(tglf_inputs.beta_e), + ) + tglf_plan = [ + prepare_input_dict( + { + 'RMIN_LOC': rmin, + 'DRMAJDX_LOC': dR, + 'Q_LOC': q, + 'RMAJ_LOC': R, + 'RLTS_1': ate, + 'RLTS_2': ati, + 'RLNS_1': ane, + 'RLNS_2': ane, # quasineutrality + 'TAUS_1': 1, + 'TAUS_2': tie, + 'XNUE': nu, + 'KAPPA_LOC': k, + 'S_KAPPA_LOC': sk, + 'DELTA_LOC': d, + 'S_DELTA_LOC': sd, + 'BETAE': b, + 'ZEFF': z, + 'AS_1': 1, + 'AS_2': 1 + } + ) + for rmin, dR, q, R, ate, ati, ane, tie, nu, k, sk, d, sd, b, z in zipped_arrays + ] + return tglf_plan + + +def _default_tglf_builder() -> TGLFTransportModel: + return TGLFTransportModel() + + +@dataclasses.dataclass(kw_only=True) +class tglfTransportModelBuilder(transport_model.TransportModelBuilder): + """Builds a class tglfTransportModel.""" + + runtime_params: RuntimeParams = dataclasses.field( + default_factory=RuntimeParams + ) + model_path: str | None = None + + _builder: Callable[ + [], + TGLFTransportModel, + ] = _default_tglf_builder + + def __call__( + self, + ) -> TGLFTransportModel: + return self._builder() diff --git a/torax/transport_model/tglfnn.py b/torax/transport_model/tglfnn.py new file mode 100644 index 00000000..f8a12116 --- /dev/null +++ b/torax/transport_model/tglfnn.py @@ -0,0 +1,105 @@ +import chex +import jax.numpy as jnp +from flax import linen as nn + + +class TGLFNN(nn.Module): + """A simple MLP with dropout layers, ReLU activation, and outputting a mean and variance.""" + + hidden_dimension: int + n_hidden_layers: int + dropout: float + input_means: chex.Array + input_stds: chex.Array + output_mean: float + output_std: float + + @nn.compact + def __call__( + self, + x, + deterministic: bool = False, + standardise_inputs: bool = True, + standardise_outputs: bool = False, + ): + if standardise_inputs: + # Transform to 0 mean and unit variance + x = (x - self.input_means) / self.input_stds + + x = nn.Dense(self.hidden_dimension)(x) + x = nn.Dropout(rate=self.dropout, deterministic=deterministic)(x) + x = nn.relu(x) + for _ in range(self.n_hidden_layers): + x = nn.Dense(self.hidden_dimension)(x) + x = nn.Dropout(rate=self.dropout, deterministic=deterministic)(x) + x = nn.relu(x) + mean_and_var = nn.Dense(2)(x) + mean = mean_and_var[..., 0] + var = mean_and_var[..., 1] + var = nn.softplus(var) + + if not standardise_outputs: + # Transform back from 0 mean and unit variance + mean = mean * self.output_std + self.output_mean + var = var * self.output_std**2 + + return jnp.stack([mean, var], axis=-1) + + +class EnsembleTGLFNN(nn.Module): + """An ensemble of TGLFNN models.""" + + input_means: chex.Array + input_stds: chex.Array + output_mean: chex.Array + output_std: chex.Array + n_models: int = 5 + hidden_dimension: int = 512 + n_hidden_layers: int = 4 + dropout: float = 0.05 + + def setup( + self, + ): + self.models = [ + TGLFNN( + hidden_dimension=self.hidden_dimension, + n_hidden_layers=self.n_hidden_layers, + dropout=self.dropout, + input_means=self.input_means, + input_stds=self.input_stds, + output_mean=self.output_mean, + output_std=self.output_std, + ) + for i in range(self.n_models) + ] + + def __call__(self, x, *args, **kwargs): + # Shape is batch size x 2 x n_models + outputs = jnp.stack( + [model(x, *args, **kwargs) for model in self.models], axis=-1 + ) + # Shape is batch_size + mean = jnp.mean(outputs[:, 0, :], axis=-1) + aleatoric_uncertainty = jnp.mean(outputs[:, 1, :], axis=-1) + epistemic_uncertainty = jnp.var(outputs[:, 0, :], axis=-1) + return jnp.stack([mean, aleatoric_uncertainty + epistemic_uncertainty], axis=-1) + + def get_params_from_pytorch_state_dict(self, pytorch_state_dict: dict): + params = {} + for i in range(self.n_models): + model_dict = {} + for j in range(self.n_hidden_layers + 2): # +2 for input and output layers + # j*3 to skip dropout and activation + layer_dict = { + "kernel": jnp.array( + pytorch_state_dict[f"models.{i}.model.{j*3}.weight"] + ).T, + "bias": jnp.array( + pytorch_state_dict[f"models.{i}.model.{j*3}.bias"] + ).T, + } + model_dict[f"Dense_{j}"] = layer_dict + params[f"models_{i}"] = model_dict + + return params