Event and PIDController: event doesn't always occur #507

Open
dv-ai opened this issue Oct 4, 2024 · 3 comments
Labels
question User queries

Comments


dv-ai commented Oct 4, 2024

First, I want to thank you for your amazing library. You have done a massive amount of work, which is very useful for my research.


diffrax 0.6.0
optimistix 0.0.7
jax 0.4.30

When the PID controller and the Event functionality are used together, I found that the event is not always raised, because of the difference in complexity between the event function and the ODE function. For example, if the ODE is very simple (its solution is a straight line), the PID controller allows large steps. A large step can miss the event if the event condition changes sign twice between the two ends of the step, since the condition then has the same sign at both ends.
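To make the failure mode concrete, here is a minimal plain-Python sketch (not Diffrax code) that evaluates the issue's event condition only at the two ends of one large step. It assumes the exact solution y(t) = 0.5 + t of the example below and a hypothetical single step from t = 0 to t = 10; both roots (y = 2.5 at t = 2 and y = 5 at t = 4.5) lie inside that step, yet the endpoint values have the same sign.

```python
def event_condition(y):
    # Same cubic as in the issue: roots at y = 0, 2.5 and 5
    return (2.5 - y) * (5.0 - y) * (0.0 - y)


def y_exact(t):
    # Solution of dy/dt = 1 with y(0) = 0.5
    return 0.5 + t


# Hypothetical single accepted step covering the whole interval [0, 10]
t_start, t_end = 0.0, 10.0
c_start = event_condition(y_exact(t_start))
c_end = event_condition(y_exact(t_end))

# An endpoint-only check just compares the signs of c_start and c_end
sign_change = (c_start < 0) != (c_end < 0)
print(c_start, c_end, sign_change)  # -4.5 -462.0 False
```

Inside the step the condition is positive (for example at t = 3, where y = 3.5), so the event really does occur twice; the endpoint check simply cannot see it.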

To work around this, I found that integrating the event condition into the ODE function fixes the issue in my particular use case. Is there a more mathematically grounded method to resolve it?

A small Python example:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

import diffrax
import optimistix as optx


def event(x, high_pres=False, coeff=1.0):
    x = jnp.concatenate([jnp.ones((1, 1)) * x, jnp.zeros((1, 1))], axis=1)

    # event_condition = (2.5 - y) * (5.0 - y) * (0.0 - y): roots at y = 0, 2.5 and 5
    event_condition = lambda t, y, args, **kwargs: (-y[0, 0] + 2.5) * (-y[0, 0] + 5.0) * (-y[0, 0] + 0.0)

    # dx/dt = [1, 0]
    # Augmented with the event condition: dx/dt = [1, 0, |event_condition(t, x)| * coeff]
    # coeff selects whether the event condition is taken into account (1.0) or not (0.0)
    def ode_fun(t, x_, args):
        penalty = coeff * jnp.abs(event_condition(t, x_, args)).reshape(1, 1)
        return jnp.concatenate([jnp.ones((1, 1)), jnp.zeros((1, 1)), penalty], axis=1)

    fun = diffrax.ODETerm(ode_fun)

    if high_pres:
        stepsize_controller = diffrax.PIDController(rtol=1e-14, atol=1e-14)
    else:
        stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)

    solver = diffrax.Dopri8()
    root_finder = optx.Bisection(1e-10, 1e-10)  # rtol, atol

    t1 = 10

    sol = diffrax.diffeqsolve(
        fun,
        solver,
        0.0,
        t1,
        None,
        jnp.concatenate([x, jnp.zeros((1, 1))], axis=1),  # third component starts at 0
        stepsize_controller=stepsize_controller,
        max_steps=None,
        event=diffrax.Event(event_condition, root_finder),
        throw=False,
    )

    event_occurred = diffrax.RESULTS.event_occurred == sol.result

    t_result = sol.ts
    x_last = sol.ys[0][:, :2]

    print(event_occurred, x_last, t_result)
    print(sol.result)


event(0.5, high_pres=True, coeff=0.0)   # high precision (1e-14): the event is detected
event(0.5, high_pres=False, coeff=0.0)  # "low" precision (1e-8): the event is not detected
event(0.5, high_pres=False, coeff=1.0)  # "low" precision (1e-8) with the event condition added to the ODE: the event is detected
```

patrick-kidger (Owner) commented Oct 4, 2024

This is a classical problem with event handling.

Some ODE solvers attempt to reduce the impact of this by evaluating the event function, e.g., 10 times per step at equally spaced points. (You can of course still get the same issue; it just has to occur at a finer resolution. As such, Diffrax doesn't do this.)
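A rough, solver-independent sketch of that sub-sampling idea in plain Python: the hypothetical helper `first_sign_change` samples the condition at n + 1 equally spaced times within one step (n = 8 here, an arbitrary choice), using the exact solution from the issue's example as a stand-in for a solver's dense interpolant. It catches the issue's double crossing, but a hypothetical condition whose two roots are closer together than the sampling resolution is still missed.

```python
def first_sign_change(cond, y_of_t, t0, t1, n):
    """Sample cond(y_of_t(t)) at n + 1 equally spaced times in [t0, t1].

    Returns the first subinterval (ta, tb) whose endpoint values differ in sign,
    or None if no sign change is visible at this resolution.
    """
    ts = [t0 + (t1 - t0) * k / n for k in range(n + 1)]
    values = [cond(y_of_t(t)) for t in ts]
    for k in range(n):
        if (values[k] < 0) != (values[k + 1] < 0):
            return ts[k], ts[k + 1]
    return None


def y_exact(t):
    # Solution of dy/dt = 1 with y(0) = 0.5
    return 0.5 + t


def cubic(y):
    # Event condition from the issue: roots at y = 0, 2.5 and 5
    return (2.5 - y) * (5.0 - y) * (0.0 - y)


def narrow_pair(y):
    # Hypothetical condition whose two roots (y = 2.5 and y = 2.6) are only 0.1 apart
    return (y - 2.5) * (y - 2.6)


print(first_sign_change(cubic, y_exact, 0.0, 10.0, n=8))          # (1.25, 2.5)
print(first_sign_change(narrow_pair, y_exact, 0.0, 10.0, n=8))    # None: both roots fall between samples
print(first_sign_change(narrow_pair, y_exact, 0.0, 10.0, n=200))  # caught only at a much finer resolution
```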

I recommend picking an event function that changes sign only once in your region of integration. Unfortunately there's no smarter way to handle this in general.
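For the issue's example, one way to follow this advice is to split the cubic into conditions that each change sign only once along the trajectory, e.g. 2.5 - y and 5.0 - y (the root at y = 0 is never reached from y(0) = 0.5). The plain-Python sketch below uses hypothetical helper names and the exact solution y(t) = 0.5 + t in place of a solver; it shows that even the single large step from t = 0 to t = 10 now brackets each crossing, after which bisection recovers the event times.

```python
def y_exact(t):
    # Solution of dy/dt = 1 with y(0) = 0.5
    return 0.5 + t


def bisect(g, a, b, tol=1e-12):
    """Locate a root of g in [a, b], assuming g(a) and g(b) differ in sign."""
    ga = g(a)
    while b - a > tol:
        m = 0.5 * (a + b)
        gm = g(m)
        if (gm < 0) == (ga < 0):
            a, ga = m, gm  # root lies in [m, b]
        else:
            b = m  # root lies in [a, m]
    return 0.5 * (a + b)


# Each condition changes sign exactly once along y(t) for t in [0, 10]
conditions = {
    "y = 2.5": lambda y: 2.5 - y,
    "y = 5.0": lambda y: 5.0 - y,
}

t_start, t_end = 0.0, 10.0  # hypothetical single large step
event_times = {}
for name, cond in conditions.items():
    c_start, c_end = cond(y_exact(t_start)), cond(y_exact(t_end))
    if (c_start < 0) != (c_end < 0):
        # The endpoint signs differ, so the single crossing is bracketed by the step
        event_times[name] = bisect(lambda t: cond(y_exact(t)), t_start, t_end)

print(event_times)  # approximately {'y = 2.5': 2.0, 'y = 5.0': 4.5}
```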

patrick-kidger added the question (User queries) label Oct 4, 2024

dv-ai (Author) commented Oct 6, 2024

Thank you for your quick answer. After running into this issue, I realized that unless something is done on the solver side, the event will not always be detected.

I am not familiar with the literature on event location for ODEs, but I think some authors have demonstrated guarantees of event detection for modern solvers with fixed time, such as "Shampine, L. F., & Thompson, S. (2001). Event location for ordinary differential equations. Computers & Mathematics with Applications, 42(1-2), 85-93." Another reference: "Reliable solution of special event location problems for ODEs".

As I understand it, they modify the original ODE system (which matches my intuition) to obtain guarantees. The modification is quite natural: introduce the total derivative of the event function $S$ into the ODE:
$$
\begin{aligned}
\frac{dy(t)}{dt} &= F(y(t), t), & y(0) &= x, \\
\frac{dz(t)}{dt} &= \frac{\partial S(y(t), t)}{\partial t} + \nabla_y S(y(t), t) \cdot F(y(t), t), & z(0) &= S(x, 0).
\end{aligned}
$$
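To see why this helps, here is a self-contained plain-Python sketch (not Diffrax) applied to the issue's example, where F = 1 and S(y) = (2.5 - y)(5 - y)(0 - y) has no explicit t-dependence, so dz/dt = S'(y). A simple adaptive Heun-Euler pair stands in for Dopri8 with a PIDController, and the event is checked only at accepted step endpoints. The integrator, the initial step h0 = 1.0 and tol = 1e-4 are illustrative assumptions. Because the y equation alone has zero local error, the plain run takes a few huge steps; once z is part of the state, its error estimate keeps the steps small enough to see both sign changes.

```python
def S(y):
    # Event function from the issue: roots at y = 0, 2.5 and 5
    return (2.5 - y) * (5.0 - y) * (0.0 - y)


def dS_dy(y):
    # Derivative of S(y) = -12.5 y + 7.5 y^2 - y^3
    return -12.5 + 15.0 * y - 3.0 * y ** 2


def plain_field(t, state):
    # Original ODE: dy/dt = F = 1
    return [1.0]


def augmented_field(t, state):
    # Augmented ODE: dy/dt = 1, dz/dt = dS/dt + dS/dy * F = 0 + dS/dy * 1
    y = state[0]
    return [1.0, dS_dy(y) * 1.0]


def integrate(field, state0, t0=0.0, t1=10.0, h0=1.0, tol=1e-4):
    """Adaptive Heun-Euler (order 2(1)) integrator returning accepted times and states."""
    t, state, h = t0, list(state0), h0
    ts, states = [t], [list(state)]
    while t < t1:
        last = h >= t1 - t
        if last:
            h = t1 - t
        k1 = field(t, state)
        euler = [s + h * a for s, a in zip(state, k1)]
        k2 = field(t + h, euler)
        heun = [s + h * (a + b) / 2 for s, a, b in zip(state, k1, k2)]
        # Local error estimate: difference between the order-2 and order-1 solutions
        err = max(abs(p - q) for p, q in zip(heun, euler))
        if err <= tol:
            t = t1 if last else t + h
            state = heun
            ts.append(t)
            states.append(list(state))
        # Step-size update for an order-1 error estimate, factor clipped to [0.2, 5]
        factor = 5.0 if err == 0.0 else min(5.0, max(0.2, 0.9 * (tol / err) ** 0.5))
        h *= factor
    return ts, states


def count_sign_changes(states):
    # Endpoint-only event check: compare the sign of S(y) at consecutive accepted points
    signs = [S(s[0]) < 0 for s in states]
    return sum(a != b for a, b in zip(signs, signs[1:]))


ts_plain, states_plain = integrate(plain_field, [0.5])
ts_aug, states_aug = integrate(augmented_field, [0.5, S(0.5)])  # z(0) = S(x, 0)

print(len(ts_plain) - 1, count_sign_changes(states_plain))  # 3 steps, 0 sign changes: both events missed
print(len(ts_aug) - 1, count_sign_changes(states_aug))  # many more steps, 2 sign changes
print(states_aug[-1][1], S(states_aug[-1][0]))  # z(t1) closely tracks S(y(t1))
```

The extra cost is the price of the guarantee: the step size is now dictated by how fast S varies along the solution, not only by how hard the original ODE is.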

If Diffrax doesn't implement this kind of technique, I think it would be worth describing in the documentation. When I first ran into this, I assumed that the event-detection error would be tied to the ODE solver's error tolerance, and that Diffrax would manage it.

patrick-kidger (Owner) commented

Right, so this is a nice idea! It 'slows down' the integration to match the rate at which the event function varies.

Unfortunately this requires that the event function be real-valued and differentiable. Neither of these is necessarily true in the general case that we support in Diffrax.

That said I could see us perhaps adding something like an Event(..., reliable=True) flag that would add this auxiliary equation on an opt-in basis. I'd be happy to take a PR on that. (Also tagging @cholberg for interest.)
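A sketch of what such an opt-in flag might do internally, written as a hypothetical plain-Python helper `make_reliable_event_system` (not part of Diffrax, and not its API). It builds the auxiliary equation from user-supplied ∂S/∂t and ∇_y S, which is exactly where the real-valued and differentiable requirement enters. Arguments follow the (t, y) order used by the issue's code. In JAX, the total derivative could instead be obtained with `jax.jvp` of S along the tangent (1, F), avoiding the hand-written gradient.

```python
def make_reliable_event_system(F, S, dS_dt, grad_S):
    """Hypothetical helper: augment dy/dt = F(t, y) with dz/dt = d S(t, y(t)) / dt.

    The augmented state is a pair (y, z), with y a list of floats and z a float.
    Only meaningful when S is real-valued and differentiable.
    """

    def augmented_field(t, state):
        y, _z = state
        Fy = F(t, y)
        # Total derivative: dS/dt = partial_t S + grad_y S . F
        dz = dS_dt(t, y) + sum(g * f for g, f in zip(grad_S(t, y), Fy))
        return Fy, dz

    def initial_state(t0, x):
        # z(t0) = S(t0, x), so z(t) tracks S(t, y(t)) along the solution
        return list(x), S(t0, x)

    return augmented_field, initial_state


# The issue's example: dx/dt = [1, 0], event condition depends on the first component only
field, init = make_reliable_event_system(
    F=lambda t, y: [1.0, 0.0],
    S=lambda t, y: (2.5 - y[0]) * (5.0 - y[0]) * (0.0 - y[0]),
    dS_dt=lambda t, y: 0.0,
    grad_S=lambda t, y: [-12.5 + 15.0 * y[0] - 3.0 * y[0] ** 2, 0.0],
)

state0 = init(0.0, [0.5, 0.0])
print(state0)                          # ([0.5, 0.0], -4.5)
print(field(0.0, ([1.0, 0.0], -6.0)))  # ([1.0, 0.0], -0.5)
```

The resulting pair (y, z) would then be integrated as a single system, with the event checked on z or on S itself.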
