Skip to content

Commit

Permalink
add documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Jan 29, 2025
1 parent 95207d7 commit 702a3e9
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 45 deletions.
12 changes: 10 additions & 2 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,10 +511,15 @@ def simulate_condition(
:param adjoint:
adjoint method. Recommended values are `diffrax.DirectAdjoint()` for jax.jacfwd (with vector-valued
outputs) and `diffrax.RecursiveCheckpointAdjoint()` for jax.grad (for scalar-valued outputs).
:param steady_state_event:
event function for steady state. See :func:`diffrax.steady_state_event` for details.
:param max_steps:
maximum number of solver steps
:param ret:
which output to return. See :class:`ReturnValue` for available options.
:param ts_mask:
mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of
the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2.
:return:
output according to `ret` and statistics
"""
Expand All @@ -524,7 +529,7 @@ def simulate_condition(
x = self._x0(p)

if not ts_mask.shape[0]:
ts_mask = jnp.zeros_like(my, dtype=jnp.bool_)
ts_mask = jnp.ones_like(my, dtype=jnp.bool_)

# Re-initialization
if x_reinit.shape[0]:
Expand Down Expand Up @@ -615,9 +620,12 @@ def simulate_condition(
m_obj = obs_trafo(my, iy_trafos)
if ret == ReturnValue.chi2:
sigma_obj = self._sigmays(ts, x, p, tcl, iys)
output = jnp.sum(jnp.square((ys_obj - m_obj) / sigma_obj))
chi2 = jnp.square((ys_obj - m_obj) / sigma_obj)
chi2 = jnp.where(ts_mask, chi2, 0.0)
output = jnp.sum(chi2)
else:
output = ys_obj - m_obj
output = jnp.where(ts_mask, output, 0.0)
else:
raise NotImplementedError(f"Return value {ret} not implemented.")

Expand Down
189 changes: 146 additions & 43 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,37 @@ def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem":
"""
return eqx.tree_at(lambda p: p.parameters, self, p)

def _prepare_conditions(
self, conditions: Iterable[str]
) -> tuple[
jt.Float[jt.Array, "np"], # noqa: F821
jt.Bool[jt.Array, "nx"], # noqa: F821
jt.Float[jt.Array, "nx"], # noqa: F821
]:
"""
Prepare conditions for simulation.
:param conditions:
Simulation conditions to prepare.
:return:
Tuple of parameter arrays, reinitialisation masks and reinitialisation values.
"""
p_array = jnp.stack([self.load_parameters(sc) for sc in conditions])

mask_reinit_array = jnp.stack(
[
self.load_reinitialisation(sc, p)[0]
for sc, p in zip(conditions, p_array)
]
)
x_reinit_array = jnp.stack(
[
self.load_reinitialisation(sc, p)[1]
for sc, p in zip(conditions, p_array)
]
)
return p_array, mask_reinit_array, x_reinit_array

@eqx.filter_vmap(
in_axes={
"max_steps": None,
Expand Down Expand Up @@ -577,14 +608,36 @@ def run_simulation(
"""
Run a simulation for a given simulation condition.
:param p:
Parameters for the simulation condition
:param ts_dyn:
(Padded) dynamic time points
:param ts_posteq:
(Padded) ost-equilibrium time points
:param my:
(Padded) measurements
:param iys:
(Padded) observable indices
:param iy_trafos:
(Padded) observable transformations indices
:param mask_reinit:
Mask for states that need reinitialisation
:param x_reinit:
Reinitialisation values for states
:param solver:
ODE solver to use for simulation
:param controller:
Step size controller to use for simulation
:param steady_state_event:
Steady state event function to use for post-equilibration. Allows customisation of the steady state
condition, see :func:`diffrax.steady_state_event` for details.
:param max_steps:
Maximum number of steps to take during simulation
:param x_preeq:
Pre-equilibration state if available
Pre-equilibration state. Can be empty if no pre-equilibration is available, in which case the states will
be initialised to the model default values.
:param ts_mask:
padding mask, see :meth:`JAXModel.simulate_condition` for details.
:param ret:
which output to return. See :class:`ReturnValue` for available options.
:return:
Expand All @@ -611,6 +664,61 @@ def run_simulation(
ret=ret,
)

def run_simulations(
self,
simulation_conditions: list[str],
preeq_array: jt.Float[jt.Array, "ncond *nx"], # noqa: F821, F722
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
ret: ReturnValue = ReturnValue.llh,
):
"""
Run simulations for a list of simulation conditions.
:param simulation_conditions:
List of simulation conditions to run simulations for.
:param preeq_array:
Matrix of pre-equilibrated states for the simulation conditions. Ordering must match the simulation
conditions. If no pre-equilibration is available for a condition, the corresponding row must be empty.
:param solver:
ODE solver to use for simulation.
:param controller:
Step size controller to use for simulation.
:param steady_state_event:
Steady state event function to use for post-equilibration. Allows customisation of the steady state
condition, see :func:`diffrax.steady_state_event` for details.
:param max_steps:
Maximum number of steps to take during simulation.
:param ret:
which output to return. See :class:`ReturnValue` for available options.
:return:
Output value and condition specific results and statistics. Results and statistics are returned as a dict
with arrays with the leading dimension corresponding to the simulation conditions.
"""
p_array, mask_reinit_array, x_reinit_array = self._prepare_conditions(
simulation_conditions
)
return self.run_simulation(
p_array,
self._ts_dyn,
self._ts_posteq,
self._my,
self._iys,
self._iy_trafos,
mask_reinit_array,
x_reinit_array,
solver,
controller,
steady_state_event,
max_steps,
preeq_array,
self._ts_masks,
ret,
)

@eqx.filter_vmap(
in_axes={
"max_steps": None,
Expand All @@ -632,12 +740,19 @@ def run_preequilibration(
"""
Run a pre-equilibration simulation for a given simulation condition.
:param simulation_condition:
Simulation condition to run simulation for.
:param p:
Parameters for the simulation condition
:param mask_reinit:
Mask for states that need reinitialisation
:param x_reinit:
Reinitialisation values for states
:param solver:
ODE solver to use for simulation
:param controller:
Step size controller to use for simulation
:param steady_state_event:
Steady state event function to use for pre-equilibration. Allows customisation of the steady state
condition, see :func:`diffrax.steady_state_event` for details.
:param max_steps:
Maximum number of steps to take during simulation
:return:
Expand All @@ -653,6 +768,29 @@ def run_preequilibration(
steady_state_event=steady_state_event,
)

def run_preequilibrations(
self,
simulation_conditions: list[str],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
):
p_array, mask_reinit_array, x_reinit_array = self._prepare_conditions(
simulation_conditions
)
return self.run_preequilibration(
p_array,
mask_reinit_array,
x_reinit_array,
solver,
controller,
steady_state_event,
max_steps,
)


def run_simulations(
problem: JAXProblem,
Expand Down Expand Up @@ -700,13 +838,8 @@ def run_simulations(
)

if preequilibration_conditions:
p_array, mask_reinit_array, x_reinit_array = _prepare_conditions(
problem, preequilibration_conditions
)
preeqs, preresults = problem.run_preequilibration(
p_array,
mask_reinit_array,
x_reinit_array,
preeqs, preresults = problem.run_preequilibrations(
preequilibration_conditions,
solver,
controller,
steady_state_event,
Expand All @@ -726,25 +859,13 @@ def run_simulations(
]
)

### simulation
p_array, mask_reinit_array, x_reinit_array = _prepare_conditions(
problem, dynamic_conditions
)
output, results = problem.run_simulation(
p_array,
problem._ts_dyn,
problem._ts_posteq,
problem._my,
problem._iys,
problem._iy_trafos,
mask_reinit_array,
x_reinit_array,
output, results = problem.run_simulations(
dynamic_conditions,
preeq_array,
solver,
controller,
steady_state_event,
max_steps,
preeq_array,
problem._ts_masks,
ret,
)

Expand All @@ -760,24 +881,6 @@ def run_simulations(
}


def _prepare_conditions(problem: JAXProblem, conditions: Iterable[str]):
p_array = jnp.stack([problem.load_parameters(sc) for sc in conditions])

mask_reinit_array = jnp.stack(
[
problem.load_reinitialisation(sc, p)[0]
for sc, p in zip(conditions, p_array)
]
)
x_reinit_array = jnp.stack(
[
problem.load_reinitialisation(sc, p)[1]
for sc, p in zip(conditions, p_array)
]
)
return p_array, mask_reinit_array, x_reinit_array


def petab_simulate(
problem: JAXProblem,
solver: diffrax.AbstractSolver = diffrax.Kvaerno5(),
Expand Down

0 comments on commit 702a3e9

Please sign in to comment.