Skip to content

Commit

Permalink
Make ohmic_heat_source use precalculated sources
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726059816
  • Loading branch information
tamaranorman authored and Torax team committed Feb 12, 2025
1 parent 341b154 commit ca3a492
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 18 deletions.
33 changes: 18 additions & 15 deletions torax/sources/ohmic_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,31 +166,36 @@ def calculate_psidot_from_psi_sources(


def ohmic_model_func(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
unused_source_name: str,
core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
source_models: source_models_lib.SourceModels,
calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models_lib.SourceModels,
) -> tuple[chex.Array, ...]:
"""Returns the Ohmic source for electron heat equation."""
if source_models is None:
raise TypeError('source_models is a required argument for ohmic_model_func')
if calculated_source_profiles is None:
raise ValueError(
'calculated_source_profiles is a required argument for'
' ohmic_model_func. This can occur if this source function is used in'
' an explicit source.'
)

jtot, _, _ = physics.calc_jtot_from_psi(
geo,
core_profiles.psi,
)

psidot = calc_psidot(
static_runtime_params_slice,
dynamic_runtime_params_slice,
geo,
core_profiles,
source_models,
psidot = calculate_psidot_from_psi_sources(
psi_sources=source_operations.sum_sources_psi(
geo, calculated_source_profiles
),
sigma=calculated_source_profiles.j_bootstrap.sigma,
sigma_face=calculated_source_profiles.j_bootstrap.sigma_face,
resistivity_multiplier=dynamic_runtime_params_slice.numerics.resistivity_mult,
psi=core_profiles.psi,
geo=geo,
)

pohm = jtot * psidot / (2 * jnp.pi * geo.Rmaj)
return (pohm,)

Expand All @@ -214,8 +219,6 @@ class OhmicHeatSource(source_lib.Source):
SOURCE_NAME: ClassVar[str] = 'ohmic_heat_source'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'ohmic_model_func'
model_func: source_lib.SourceProfileFunction = ohmic_model_func
# Users must pass in a pointer to the complete set of sources to this object.
source_models: source_models_lib.SourceModels

@property
def source_name(self) -> str:
Expand Down
1 change: 0 additions & 1 deletion torax/sources/register_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ class SupportedSource:
ModelFunction(
source_profile_function=ohmic_heat_source.ohmic_model_func,
runtime_params_class=ohmic_heat_source.OhmicRuntimeParams,
links_back=True,
)
)
},
Expand Down
38 changes: 37 additions & 1 deletion torax/sources/tests/ohmic_heat_source_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

from absl.testing import absltest
from torax.config import runtime_params_slice
from torax.sources import ohmic_heat_source
from torax.sources.tests import test_lib

Expand All @@ -25,9 +28,42 @@ def setUpClass(cls):
source_class=ohmic_heat_source.OhmicHeatSource,
runtime_params_class=ohmic_heat_source.OhmicRuntimeParams,
source_name=ohmic_heat_source.OhmicHeatSource.SOURCE_NAME,
links_back=True,
model_func=ohmic_heat_source.ohmic_model_func,
needs_source_models=True,
)

def test_raises_error_if_calculated_source_profiles_is_none(self):
"""Tests that the source raises an error if calculated_source_profiles is None."""
source = ohmic_heat_source.OhmicHeatSource(
model_func=ohmic_heat_source.ohmic_model_func
)
static_runtime_params_slice = mock.create_autospec(
runtime_params_slice.StaticRuntimeParamsSlice,
instance=True,
sources={
self._source_name: (
self._runtime_params_class().build_static_params()
)
},
)
dynamic_runtime_params_slice = mock.create_autospec(
runtime_params_slice.DynamicRuntimeParamsSlice,
instance=True,
sources={self._source_name: mock.ANY},
)
with self.assertRaisesRegex(
ValueError,
'calculated_source_profiles is a required argument for'
' ohmic_model_func. This can occur if this source function is used in'
' an explicit source.',
):
source.get_value(
static_runtime_params_slice,
dynamic_runtime_params_slice,
mock.ANY,
mock.ANY,
calculated_source_profiles=None,
)


if __name__ == '__main__':
Expand Down
17 changes: 16 additions & 1 deletion torax/sources/tests/test_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source as source_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profiles


# Most of the checks and computations in TORAX require float64.
Expand Down Expand Up @@ -62,6 +63,7 @@ class SourceTestCase(parameterized.TestCase):
_config_attr_name: str
_source_name: str
_runtime_params_class: Type[runtime_params_lib.RuntimeParams]
_needs_source_models: bool

@classmethod
def setUpClass(
Expand All @@ -71,6 +73,7 @@ def setUpClass(
source_name: str,
model_func: source_lib.SourceProfileFunction | None,
links_back: bool = False,
needs_source_models: bool = False,
source_class_builder: source_lib.SourceBuilderProtocol | None = None,
):
super().setUpClass()
Expand All @@ -87,6 +90,7 @@ def setUpClass(
cls._runtime_params_class = runtime_params_class
cls._links_back = links_back
cls._source_name = source_name
cls._needs_source_models = needs_source_models

def test_runtime_params_builds_dynamic_params(self):
runtime_params = self._runtime_params_class()
Expand Down Expand Up @@ -161,12 +165,23 @@ def test_source_value_on_the_cell_grid(self):
geo=geo,
source_models=source_models,
)
if self._needs_source_models:
calculated_source_profiles = source_profiles.SourceProfiles(
j_bootstrap=source_profiles.BootstrapCurrentProfile.zero_profile(geo),
psi={'foo': jnp.full(geo.rho.shape, 13.0)},
temp_el={'foo_source': jnp.full(geo.rho.shape, 17.0)},
temp_ion={'foo_sink': jnp.full(geo.rho.shape, 19.0)},
ne={},
qei=source_profiles.QeiInfo.zeros(geo)
)
else:
calculated_source_profiles = None
value = source.get_value(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_slice,
geo=geo,
core_profiles=core_profiles,
calculated_source_profiles=None,
calculated_source_profiles=calculated_source_profiles,
)[0]
chex.assert_rank(value, 1)
self.assertEqual(value.shape, geo.rho.shape)
Expand Down

0 comments on commit ca3a492

Please sign in to comment.