Skip to content

Commit

Permalink
Pass through the calculated source_profiles when building sources
Browse files Browse the repository at this point in the history
This enforces an order that sources should be built such that any sources needed to calculate that source have been calculated

PiperOrigin-RevId: 726057164
  • Loading branch information
tamaranorman authored and Torax team committed Feb 12, 2025
1 parent 7904949 commit 341b154
Show file tree
Hide file tree
Showing 11 changed files with 233 additions and 331 deletions.
13 changes: 5 additions & 8 deletions torax/fvm/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,25 +373,22 @@ def _calc_coeffs_full(
dynamic_runtime_params_slice.profile_conditions.set_pedestal,
)

# This only calculates sources set to implicit in the config. All other
# sources are set to 0 (and should have their profiles already calculated in
# explicit_source_profiles).
implicit_source_profiles = source_profile_builders.build_source_profiles(
# Calculate the implicit source profiles and combines with the explicit
merged_source_profiles = source_profile_builders.build_source_profiles(
source_models=source_models,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
explicit=False,
)
merged_source_profiles = source_profiles_lib.SourceProfiles.merge(
explicit_source_profiles=explicit_source_profiles,
implicit_source_profiles=implicit_source_profiles,
)
j_bootstrap = merged_source_profiles.j_bootstrap

# Sum over all psi sources (except the bootstrap current).
external_current = sum(merged_source_profiles.psi.values())
# Needed to ensure the correct shape if no psi sources are present.
external_current = jnp.zeros_like(geo.rho)
external_current += sum(merged_source_profiles.psi.values())

currents = dataclasses.replace(
core_profiles.currents,
Expand Down
5 changes: 4 additions & 1 deletion torax/orchestration/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,10 @@ def _update_current_distribution(
) -> state.CoreProfiles:
"""Update bootstrap current based on the new core_profiles."""
bootstrap_profile = core_sources.j_bootstrap
external_current = sum(core_sources.psi.values())
# Needed for the case where no psi sources are present.
external_current = jnp.zeros_like(
core_profiles.currents.external_current_source)
external_current += sum(core_sources.psi.values())

johm = (
core_profiles.currents.jtot
Expand Down
38 changes: 14 additions & 24 deletions torax/sources/source_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from torax import state
from torax.config import runtime_params_slice
from torax.geometry import geometry
from torax.sources import source as source_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
Expand Down Expand Up @@ -79,38 +78,29 @@ def calc_and_sum_sources_psi(
# TODO(b/335597108): Revisit how to calculate this once we enable more
# expensive source functions that might not jittable (like file-based or
# RPC-based sources).
psi_profiles = source_profile_builders.build_standard_source_profiles(
static_runtime_params_slice,
dynamic_runtime_params_slice,
geo,
core_profiles,
source_models,
calculate_anyway=True,
affected_core_profiles=(source_lib.AffectedCoreProfile.PSI,),
)
static_bootstrap_runtime_params = static_runtime_params_slice.sources[
source_models.j_bootstrap_name
]
j_bootstrap_profiles = source_profile_builders.build_bootstrap_profiles(
j_bootstrap_profiles = source_models.j_bootstrap.get_bootstrap(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
static_source_runtime_params=static_bootstrap_runtime_params,
geo=geo,
core_profiles=core_profiles,
j_bootstrap_source=source_models.j_bootstrap,
calculate_anyway=True,
)
source_profiles = source_profiles_lib.SourceProfiles(
profiles = source_profiles_lib.SourceProfiles(
j_bootstrap=j_bootstrap_profiles,
qei=source_profiles_lib.QeiInfo.zeros(geo),
temp_el={},
temp_ion={},
ne={},
psi=psi_profiles[source_lib.AffectedCoreProfile.PSI],
psi={}, temp_el={}, temp_ion={}, ne={},
qei=source_profiles_lib.QeiInfo.zeros(geo))
source_profile_builders.build_standard_source_profiles(
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
source_models=source_models,
calculate_anyway=True,
psi_only=True,
calculated_source_profiles=profiles,
)

return (
sum_sources_psi(geo, source_profiles=source_profiles),
sum_sources_psi(geo, source_profiles=profiles),
j_bootstrap_profiles.sigma,
j_bootstrap_profiles.sigma_face,
)
Loading

0 comments on commit 341b154

Please sign in to comment.