Skip to content

Commit

Permalink
Merge pull request #255 from ami-iit/fix/variable_step_integrators
Browse files Browse the repository at this point in the history
Handle auxiliary dictionary in variable step integrators
  • Loading branch information
flferretti authored Oct 4, 2024
2 parents 7a2d193 + 0d5fbf7 commit 5c9215e
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/jaxsim/integrators/variable_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,9 @@ def init(
**kwargs,
)

def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
def __call__(
self, x0: State, t0: Time, dt: TimeStep, **kwargs
) -> tuple[NextState, dict[str, Any]]:

# This method is called differently in three stages:
#
Expand Down Expand Up @@ -294,14 +296,17 @@ def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
# In Stage 3, dt0 is taken from the previous step. If the integrator supports
# FSAL, dxdt0 is taken from the previous step. Otherwise, it is computed by
# evaluating the dynamics.
self.params["dt0"], self.params["dxdt0"] = jax.lax.cond(
self.params["dt0"], self.params["dxdt0"], aux_dict = jax.lax.cond(
pred=jnp.logical_or("dt0" not in self.params, integrator_first_step),
true_fun=lambda params: estimate_step_size(
x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
true_fun=lambda params: (
*estimate_step_size(
x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
),
self.params.get("dxdt0", f(x0, t0))[1],
),
false_fun=lambda params: (
params.get("dt0", jnp.array(0).astype(float)),
self.params.get("dxdt0", f(x0, t0)[0]),
*self.params.get("dxdt0", f(x0, t0)),
),
operand=self.params,
)
Expand Down Expand Up @@ -355,7 +360,7 @@ def while_loop_body(carry: Carry) -> Carry:
# The output z contains multiple solutions (depending on the rows of b.T).
with self.editable(validate=True) as integrator:
integrator.params = params
z = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs)
z, _ = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs)
params_next = integrator.params

# Extract the high-order solution xf and the low-order estimate x̂f.
Expand Down Expand Up @@ -481,7 +486,7 @@ def reject_step():
with self.mutable_context(mutability=Mutability.MUTABLE):
self.params = params_tf

return xf
return xf, aux_dict

@property
def order_of_solution(self) -> int:
Expand Down

0 comments on commit 5c9215e

Please sign in to comment.