Skip to content

Commit

Permalink
Add initialisation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-brown committed Dec 6, 2024
1 parent 3b7f33c commit 0051d37
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 130 deletions.
3 changes: 1 addition & 2 deletions torax/config/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
import logging

import chex
from typing_extensions import override

from torax import array_typing
from torax import geometry
from torax import interpolated_param
from torax.config import base
from torax.config import config_args
from typing_extensions import override


# pylint: disable=invalid-name
Expand Down
6 changes: 3 additions & 3 deletions torax/config/tests/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@

"""Unit tests for the `torax.config.profile_conditions` module."""

import numpy as np
import xarray as xr
from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
from torax import geometry
from torax import interpolated_param
from torax.config import config_args
from torax.config import profile_conditions
import xarray as xr


# pylint: disable=invalid-name
Expand Down Expand Up @@ -251,5 +250,6 @@ def test_profile_conditions_raises_error_if_boundary_condition_not_defined(
Ti_bound_right=None,
)


if __name__ == '__main__':
absltest.main()
76 changes: 52 additions & 24 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
evolved by the PDE system.
"""
import dataclasses

import jax
from jax import numpy as jnp

from torax import constants
from torax import geometry
from torax import jax_utils
Expand Down Expand Up @@ -607,58 +605,88 @@ def _init_psi_and_current(
dynamic_runtime_params_slice.profile_conditions.Vloop_bound_right is not None
)

# Retrieving psi from the profile conditions.
# Case 1: retrieving psi from the profile conditions.
if dynamic_runtime_params_slice.profile_conditions.psi is not None:
# TODO: do we need to support the case where psi is given, but Vloop_bound_right
# is used to set the BC rather than Ip_tot?
psi = cell_variable.CellVariable(
# Calculate the dpsi/drho necessary to achieve the given Ip_tot
dpsi_drho_edge = _calculate_psi_grad_constraint_from_Ip_tot(
dynamic_runtime_params_slice,
geo,
)

# Set the psi BCs to ensure the correct Ip_tot
if use_Vloop_bound_right:
# Extrapolate using the dpsi/drho calculated above to set the psi value at the right face
psi = cell_variable.CellVariable(
value=dynamic_runtime_params_slice.profile_conditions.psi,
right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot(
dynamic_runtime_params_slice,
geo,
),
right_face_constraint=dynamic_runtime_params_slice.profile_conditions.psi[-1] + dpsi_drho_edge * geo.drho[-1]/2,
dr=geo.drho_norm,
)
)
else:
# Use the dpsi/drho calculated above as the right face gradient constraint
psi = cell_variable.CellVariable(
value=dynamic_runtime_params_slice.profile_conditions.psi,
right_face_grad_constraint=dpsi_drho_edge,
dr=geo.drho_norm,
)

core_profiles = dataclasses.replace(core_profiles, psi=psi)
currents = _calculate_currents_from_psi(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
source_models=source_models,
)
# Retrieving psi from the standard geometry input.

# Case 2: retrieving psi from the standard geometry input.
elif (
isinstance(geo, geometry.StandardGeometry)
and not dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j
):
# psi is already provided from a numerical equilibrium, so no need to
# first calculate currents. However, non-inductive currents are still
# calculated and used in current diffusion equation.

# Calculate the dpsi/drho necessary to achieve the given Ip_tot
dpsi_drho_edge = _calculate_psi_grad_constraint_from_Ip_tot(
dynamic_runtime_params_slice,
geo,
)

# Set the psi BCs based on whether Vloop is provided and the source of Ip
if use_Vloop_bound_right and geo.Ip_from_parameters:
right_face_grad_constraint = None
right_face_constraint = geo.psi_from_Ip[-1] + dpsi_drho_edge * geo.drho[-1]/2
elif use_Vloop_bound_right:
right_face_grad_constraint = None
right_face_constraint = geo.psi_from_Ip[-1]
else:
right_face_grad_constraint = dpsi_drho_edge
right_face_constraint = None

psi = cell_variable.CellVariable(
value=geo.psi_from_Ip,
right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot(
dynamic_runtime_params_slice,
geo,
)
if not use_Vloop_bound_right
else None,
right_face_constraint=geo.psi_from_Ip[-1]
if use_Vloop_bound_right
else None,
dr=geo.drho_norm,
value=geo.psi_from_Ip, # Use psi from equilibrium
right_face_grad_constraint=right_face_grad_constraint,
right_face_constraint=right_face_constraint,
dr=geo.drho_norm,
)
core_profiles = dataclasses.replace(core_profiles, psi=psi)
# Calculate non-inductive currents
currents = _calculate_currents_from_psi(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
source_models=source_models,
)
# Calculating j according to nu formula and psi from j.

# Case 3: calculating j according to nu formula and psi from j.
elif (
isinstance(geo, geometry.CircularAnalyticalGeometry)
or dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j
):
# TODO: Vloop_bound_right is not yet supported for this case.
if use_Vloop_bound_right:
raise NotImplementedError('Vloop_bound_right not yet supported for this case.')

currents = _prescribe_currents_no_bootstrap(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
Expand Down
90 changes: 0 additions & 90 deletions torax/examples/vloop.py

This file was deleted.

8 changes: 3 additions & 5 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,13 @@

import dataclasses
import time
from typing import Any
from typing import Optional
from typing import Any, Optional

from absl import logging
import chex
import jax
import jax.numpy as jnp
import numpy as np
import xarray as xr
from absl import logging

from torax import calc_coeffs
from torax import core_profile_setters
from torax import geometry
Expand All @@ -60,6 +57,7 @@
from torax.time_step_calculator import chi_time_step_calculator
from torax.time_step_calculator import time_step_calculator as ts
from torax.transport_model import transport_model as transport_model_lib
import xarray as xr


def _log_timestep(
Expand Down
3 changes: 1 addition & 2 deletions torax/tests/boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
"""Tests for module torax.boundary_conditions."""


import numpy as np
from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
from torax import constants
from torax import core_profile_setters
from torax import geometry
Expand Down
6 changes: 3 additions & 3 deletions torax/tests/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for module torax.boundary_conditions."""
"""Tests for module torax.core_profile_setters."""

import numpy as np
from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
from torax import core_profile_setters
from torax import geometry
from torax import physics
Expand All @@ -28,6 +27,7 @@
from torax.stepper import runtime_params as stepper_params_lib
from torax.transport_model import runtime_params as transport_params_lib


SMALL_VALUE = 1e-6


Expand Down
1 change: 0 additions & 1 deletion torax/tests/test_lib/explicit_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import jax
from jax import numpy as jnp

from torax import constants
from torax import core_profile_setters
from torax import geometry
Expand Down

0 comments on commit 0051d37

Please sign in to comment.