
Events #387

Merged: 28 commits, Jun 29, 2024

Conversation

@cholberg (Contributor)

Updates to how events are handled in diffrax. The main changes are:

  • There is now only one event class, Event.
  • Multiple cond_fn are supported. An event is triggered whenever one of them changes sign.
  • Supports differentiable event times when the cond_fn are real-valued and a root_finder is provided (see the sketch below).
  • New tests for the new event implementation.

Some things that still might require a little thinking about:

  • Update documentation (and maybe examples, too).
  • What to do with outdated tests.
  • How to handle the case where the control is not smooth, e.g., Brownian motion.
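
For reference, a minimal usage sketch of the API described above (hedged: argument names follow the examples later in this thread and may differ slightly from the final merged version):

```python
import diffrax
import jax.numpy as jnp
import optimistix as optx

term = diffrax.ODETerm(lambda t, y, args: 1.0)  # dy/dt = 1, starting from y(0) = -10

def cond_fn(state, y, **kwargs):
    return y  # real-valued: the event triggers when y crosses zero

# Providing a root_finder makes the event time differentiable.
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = diffrax.Event(cond_fn, root_finder)
sol = diffrax.diffeqsolve(
    term, diffrax.Tsit5(), t0=0.0, t1=jnp.inf, dt0=1.0, y0=-10.0, event=event
)
# sol.ts[-1] is (approximately) the event time, here 10.0.
```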

@patrick-kidger (Owner) left a comment:

This looks very clean! I've gone through and left some comments.

Comment on lines 90 to 93
event_result: Optional[PyTree[Union[BoolScalarLike, RealScalarLike]]] = None
event_mask: Optional[PyTree[BoolScalarLike]] = None
dense_info_for_event: Optional[DenseInfo] = None
tprevprev: Optional[FloatScalarLike] = None

I don't think default arguments are needed here.

(This is also just a general good-practice rule of thumb: try to avoid default arguments where possible, as they're a common source of unexpected behaviour. In practice this usually means only using them on public interfaces.)
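
For instance, a toy illustration of this rule of thumb (hypothetical classes, not from this PR):

```python
import equinox as eqx

class _SolverState(eqx.Module):  # internal: callers always pass every field
    step_count: int
    last_error: float

class MySolver(eqx.Module):  # public interface: defaults are reasonable here
    rtol: float = 1e-3
    atol: float = 1e-6
```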

@@ -227,14 +238,45 @@ def _maybe_static(static_x: Optional[ArrayLike], x: ArrayLike) -> ArrayLike:
return x


def _is_cond_fn(x: Any) -> bool:
    return isinstance(x, Callable)

I think you can replace this whole function with just the built-in callable.
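
For example (a toy illustration; the builtin works directly as an is_leaf predicate):

```python
import jax.tree_util as jtu

cond_fns = {"a": lambda x: x + 1.0, "b": lambda x: 2.0 * x}
results = jtu.tree_map(lambda f: f(3.0), cond_fns, is_leaf=callable)
# results == {"a": 4.0, "b": 6.0}
```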

    new_event_result: Union[BoolScalarLike, RealScalarLike],
) -> BoolScalarLike:
    return jnp.sign(jnp.array(old_event_result, float)) != jnp.sign(
        jnp.array(new_event_result, float)

Don't use the Python builtin float (or int etc.) as a dtype. This will actually give different behaviour on different platforms (MacOS vs Windows etc.) for some strange reason.
I think in this case you're probably looking for dtype=jnp.result_type(old_event_result.dtype, jnp.float32)?
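
A hedged sketch of what that might look like (promoting via result_type instead of the Python builtin float):

```python
import jax.numpy as jnp

def _sign_changed(old_event_result, new_event_result):
    # Promote bools/reals to a platform-independent floating dtype.
    dtype = jnp.result_type(jnp.asarray(old_event_result).dtype, jnp.float32)
    old = jnp.asarray(old_event_result, dtype)
    new = jnp.asarray(new_event_result, dtype)
    return jnp.sign(old) != jnp.sign(new)
```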

@@ -309,6 +351,14 @@ def body_fun_aux(state):
# everything breaks.) See #143.
y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error)

# Save info for event handling
if event is not None:

Nit: remove the `not` and flip the branches.

Comment on lines 1041 to 1049
_, _, dense_info_for_event, _, _ = solver.step(
    terms,
    tprev,
    tnext,
    y0,
    args,
    solver_state,
    made_jump,
)

Note that this will actually evaluate the function, I think essentially pointlessly (as it will be re-evaluated in the loop).
I think using jax.eval_shape or eqx.filter_eval_shape will allow you to get the structure/dtype/shape of dense_info_for_event without having to actually do any work at runtime. (And without having to compile solver.step an additional time during compilation time either, which is a nontrivial concern.)
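
A hedged sketch of that suggestion, reusing the argument names from the snippet quoted above:

```python
import equinox as eqx

# Only the structure/shape/dtype of dense_info is needed here, so avoid any
# numerical work (and an extra trace of solver.step at compile time).
_, _, dense_info_for_event, _, _ = eqx.filter_eval_shape(
    solver.step, terms, tprev, tnext, y0, args, solver_state, made_jump
)
```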

Comment on lines 1078 to 1102
if event is not None:

    def _call_event(_cond_fn):
        return _cond_fn(
            init_state,
            y=y0,
            solver=solver,
            stepsize_controller=stepsize_controller,
            saveat=saveat,
            t0=t0,
            t1=t1,
            dt0=dt0,
            max_steps=max_steps,
            terms=terms,
            args=args,
        )

    event_result = jtu.tree_map(_call_event, event.cond_fn, is_leaf=_is_cond_fn)
    event_mask = jtu.tree_map(lambda x: False, event.cond_fn, is_leaf=_is_cond_fn)
    init_state = eqx.tree_at(
        lambda s: s.event_result, init_state, event_result, is_leaf=_is_none
    )
    init_state = eqx.tree_at(
        lambda s: s.event_mask, init_state, event_mask, is_leaf=_is_none
    )

(a) I think if you put this before init_state = ... then you can just pass in this information during initialisation, without tree_at.
(b) I think again you can use eval_shape to initialise arrays of the appropriate sort without having to actually evaluate each cond_fn.

if event is not None:
    event_mask = final_state.event_mask
    event_happened = _event_happened(event_mask)
    tevent = final_state.tprev

Nit: prefer single-assignment form:

if event.root_finder is None:
    tevent = final_state.tprev
else:
    # all the rest of the branch below
    tevent = ...

This is one of those software best practices for writing easy-to-read code.

cholberg added a commit to cholberg/diffrax that referenced this pull request Mar 30, 2024
@cholberg (Contributor, Author)

Thank you for the comments, Patrick. I have gone through and made changes accordingly. The only thing I would disagree with is your comment about not having to evaluate cond_fn. Correct me if I am wrong, but would we not need to know the sign of cond_fn evaluated at the initial condition? That is, a priori you cannot know whether an event happens when the condition function goes from negative to positive or vice versa. Furthermore, if we want cond_fn to also be able to take the state as an argument, event_result would need to be initialized after init_state is defined.

@patrick-kidger (Owner)

Just letting you know that I've not forgotten about this! I'm trying to focus on getting #344 in first, and then I'm hoping to return to this. They're both quite large changes so I don't want them to step on each other's toes.

patrick-kidger force-pushed the dev branch 2 times, most recently from 76a9441 to 34cbe5c (April 20, 2024 09:27)
@patrick-kidger (Owner) commented May 4, 2024

Okay, #344 is in! I'd love to get this in next.

I appreciate that's rather a lot of merge conflicts. If you're able to rebase onto the latest dev branch, then I'd be very happy to come back to this and start work on getting it in.

cholberg added a commit to cholberg/diffrax that referenced this pull request May 6, 2024
The main changes are:

    1. Added the generic Event class:
    ```
    class Event:
        event_function: PyTree[EventFn]
        root_finder: Optional[optx.AbstractRootFinder] = None
    ```
    EventFn is defined as:
    ```
    class EventFn(eqx.Module):
        cond_fn: Callable[..., Union[BoolScalarLike, RealScalarLike]]
        transition_fn: Optional[Callable[[PyTree[ArrayLike]], PyTree[ArrayLike]]] = (
            lambda x: x
        )
    ```

    2. Added root finding procedure in diffeqsolve to find exact event times that are differentiable. This is only done when root_finder is not None in the given Event class.

    3. Added event_mask to the Solution class so that, when multiple event functions are passed, the user can see which one was triggered for a given solve.

Hopefully this new event-handling is sufficiently generic to handle all kinds of events in a unified and simple manner. The main benefit is that we can now differentiate through events. So far the current implementation is enough to deal with ODEs, but I suspect more is needed for dealing with SDEs.

The new approach reduces to the old approach when passing only one EventFn with a boolean cond_fn and no transition_fn.

For now the transition_fn is not used, but it will be useful when adding non-terminating events. Similarly, we might add other attributes to EventFn to distinguish between different types of events.

No event cases in root-finding

At the end of the root-finding step (L1146 in _integrate.py), I changed:
```
return jtu.tree_map(
    _call_real,
    event.event_fn,
    final_state.event_result,
    final_state.event_compare,
    is_leaf=_is_event_fn,
)
```

to

```
results = jtu.tree_map(
    _call_real,
    event.event_fn,
    final_state.event_result,
    final_state.event_compare,
    is_leaf=_is_event_fn,
)
results_ravel, _ = jfu.ravel_pytree(results)
return jnp.where(event_happened, results_ravel, final_state.tprev - t)
```

Thus, if no event occurs, the root-find will return tprev as desired. Previously, _call_real() was constantly 0 in this case, which caused an error in the root-find.

Added EventFn and Event to diffrax/__init__.py

Added tests for new event handling

I added new tests for the updated event implementation which, apart from the old ones, also checks that the right event time is found in the case where a root-find is called and that the derivatives match the theoretical derivatives.

Furthermore, I marked the following tests, which rely on the old event implementation, with @pytest.mark.skip:
 - test_event.py::test_discrete_terminate1
 - test_event.py::test_discrete_terminate2
 - test_event.py::test_event_backsolve
 - test_adjoint.py::test_implicit

In order to avoid pyright errors I had to add # pyright: ignore in a few places in the old tests referenced above.

Deleted old event implementation

I deleted the following two classes:
- diffrax._event.DiscreteTerminatingEvent
- diffrax._event.SteadyStateEvent

These were also removed from the diffrax.__init__.py

Minor changes to event handling

The changes are the following:

- Tweaked the event API and got rid of the EventFn class. Now there is only an Event class:

```
class Event(eqx.Module):
    cond_fn: PyTree[Callable[..., Union[BoolScalarLike, RealScalarLike]]]
    root_finder: Optional[optx.AbstractRootFinder] = None
```

- Changed the way boolean condition functions are handled in the root finding step. Now instead of calling _bool_event_gradient, we simply return result = final_state.tprev - t.

- Removed all cases where jtu.ravel_pytree was used.

- Changed "teventprev" to "tprevprev" and "event_compare" to "event_mask" in the State class.

- Updated tests.py and __init__.py to reflect the changes.

Minor changes for simplicity

I slightly changed the initialization of the event attributes in the state in _integrate.py mainly for aesthetic reasons.

Made changes according to comments on patrick-kidger#387

No event case

Changed it so that the final value of the solve is returned in cases where no event happens instead of evaluating the interpolator.
@cholberg (Contributor, Author) commented May 6, 2024

Perfect, I rebased and squashed all the commits into a single big one. Quite a few tests are failing when I run it locally now, but I just wanted to update it so you could have a look.

Two thoughts since we last touched base:

  1. We should probably narrow down exactly what arguments should be passed to the functions in cond_fn.
  2. We might want to make it so that the user has a way of passing arguments to the optimistix.root_find call that finds the exact event time.

@LuggiStruggi commented May 7, 2024

Hi! Maybe I am using it wrong, but at the moment I can't get the root_finder to do anything. I get the same event times whether I use the Newton method as a root finder or just root_finder = None. Also, the tests still pass when setting the root finder to None:

def test_continuous_event_time():
    term = diffrax.ODETerm(lambda t, y, args: 1.0)
    solver = diffrax.Tsit5()
    t0 = 0
    t1 = jnp.inf
    dt0 = 1.0
    y0 = -10.0

    def cond_fn(state, y, **kwargs):
        assert isinstance(state.y, jax.Array)
        return y

    # root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
    root_finder = None
    event = diffrax.Event(cond_fn, root_finder)
    sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event)
    assert jnp.all(jnp.isclose(cast(Array, sol.ts), 10.0, 1e-5))

Or in my own code, the resulting event times do not change whether I use a root finder or not:

import jax.numpy as jnp
import optimistix as optx
from diffrax import Event, ODETerm, Tsit5, diffeqsolve

def bouncing_ball():
    g = 9.81
    damping = 0.8
    max_bounces = 10
    vx_0 = 5.0

    def dynamics(t, y, args):
        x, y, vx, vy = y
        dxdt = vx
        dydt = vy
        dvxdt = 0
        dvydt = -g
        return jnp.array([dxdt, dydt, dvxdt, dvydt])

    def cond_fn(state, y, **kwargs):
        return y[1] < 0

    y0 = jnp.array([0.0, 10.0, vx_0, 0.0])
    t0, t1 = 0, float('inf')

    times = []
    states = []

    for _ in range(max_bounces):
        root_finder = optx.Newton(1e-5, 1e-5)
        # root_finder = None
        event = Event(cond_fn, root_finder=root_finder)
        solver = Tsit5()

        sol = diffeqsolve(ODETerm(dynamics), solver, t0, t1, 0.01, y0, event=event)

        t0 = sol.ts[-1]
        last_y = sol.ys[-1]
        y0 = last_y * jnp.array([1, 0, 1, -damping])
        times.append(sol.ts)
        states.append(y0)

    return jnp.array(times), jnp.array(states)

Thanks for the help! (Also, sorry if this is the wrong place to ask; just let me know where I should write this.) :)

@cholberg (Contributor, Author) commented May 7, 2024

> Hi! Maybe I am using it wrong, but at the moment I can't get the root_finder to do anything. I get the same event times whether I use the Newton method as a root finder or just root_finder = None. Also, the tests still pass when setting the root finder to None:

Ah, thanks for mentioning this. This is essentially due to the fact that the solution to the ODE is linear and that dt0 divides 10.0, which is exactly the time at which the solution crosses 0. In other words, it just so happens that the root is exactly at the end point of the last step of the solver, which is also the event time returned when no root finder is provided. I have tweaked the test so this is no longer the case.

> Or in my own code, the resulting event times do not change whether I use a root finder or not:

In your example the cond_fn returns a boolean. In this case the returned event time is exactly the first step of the solver for which cond_fn switches sign, i.e., your event time will be n * dt0 where n is the number of steps taken until and including the point at which y[1] >= 0. Note that this is the same behaviour as when root_finder=None.

If you want continuous event times, you should specify a real-valued condition function. In your bouncing ball example, this would simply correspond to setting:

def cond_fn(state, y, **kwargs):
    return y[1]

@LuggiStruggi commented May 7, 2024

Ah, I see! It was just a bit confusing that both a boolean and a comparison with 0 are possible. I tried a real-valued cond_fn initially, but then the event was triggered immediately at t=0 every time, because I reset the state to 0 at every jump. I guess setting it to a small value should work. Thank you! :)

@LuggiStruggi commented May 7, 2024

Regarding my confusion: maybe this is a bad idea, but it might make sense to warn if both a boolean cond_fn and a root finder are passed to the event, since that combination should never make sense?

@LuggiStruggi commented May 13, 2024

If I pass multiple cond_fn, what would you suggest for determining which of them caused the event? :)
Thanks for the help!

@cholberg (Contributor, Author)

> Regarding my confusion: maybe this is a bad idea, but it might make sense to warn if both a boolean cond_fn and a root finder are passed to the event, since that combination should never make sense?

Hmm, I'm not too sure about this. There might be cases where you have one real-valued event function and one boolean. E.g., in the bouncing ball example we might want to add an extra function to cond_fn stopping the solve when the velocity is low enough (i.e., the ball has reached a steady state and stopped bouncing). But happy to hear your thoughts as well @patrick-kidger.

> If I pass multiple cond_fn, what would you suggest for determining which of them caused the event? :)
> Thanks for the help!

The solution returned by diffeqsolve has an attribute called event_mask, which is a PyTree of the same structure as your cond_fn, where each leaf is False if the corresponding condition function did not trigger an event and True otherwise. (Note: for now, only one leaf can be True.) If you're interested in a more involved example of how this all works with multiple events, you might want to check out this repo (specifically snn.py). It is very much a work in progress, but it should serve as an example of how it all works. Hope that helps!
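
For instance, a hedged sketch with two condition functions (illustrative names; cond_fn signature as used elsewhere in this thread at this point in the PR):

```python
import diffrax
import jax.numpy as jnp
import optimistix as optx

# Ball state y = (height, velocity).
term = diffrax.ODETerm(lambda t, y, args: jnp.array([y[1], -9.81]))

def hits_ground(state, y, **kwargs):
    return y[0]  # real-valued: crosses zero at the ground

def nearly_stopped(state, y, **kwargs):
    return jnp.abs(y[1]) < 1e-3  # boolean

event = diffrax.Event(
    cond_fn={"ground": hits_ground, "stopped": nearly_stopped},
    root_finder=optx.Newton(1e-5, 1e-5, optx.rms_norm),
)
sol = diffrax.diffeqsolve(
    term, diffrax.Tsit5(), t0=0.0, t1=jnp.inf, dt0=0.1,
    y0=jnp.array([10.0, -1.0]), event=event,
)
print(sol.event_mask)  # e.g. {"ground": True, "stopped": False}
```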

@patrick-kidger (Owner)

Agreed on the first point. Happy to add a warning if all events are Boolean, though -- no strong feelings.

@dkweiss31 (Contributor)

Hi folks! Very excited about this PR, as I'm thinking about quantum jump applications for dynamiqs. Unfortunately, I'm running into an error if I try to pass an option different from saveat = SaveAt(t1=True):

import diffrax as dx
import optimistix as optx
import jax
import jax.numpy as jnp

term = dx.ODETerm(lambda t, y, args: y)
solver = dx.Tsit5()
t0 = 0
t1 = 100.0
dt0 = 1.0
y0 = 1.0
ts = jnp.arange(t0, t1, dt0)

def cond_fn(state, y, **kwargs):
    assert isinstance(state.y, jax.Array)
    return y - jnp.exp(1.0)

fn = lambda t, y, args: y

subsaveat_a = dx.SubSaveAt(ts=ts, fn=fn)  # save solution regularly
subsaveat_b = dx.SubSaveAt(t1=True)  # save last state
saveat = dx.SaveAt(subs=[subsaveat_a, subsaveat_b])

root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)
sol = dx.diffeqsolve(
    term, solver, t0, t1, dt0, y0, saveat=saveat, event=event
)

This runs into an error on line 1204 of _integrate.py

    ys = jtu.tree_map(lambda _y, _yevent: _y.at[-1].set(_yevent), ys, yevent)
ValueError: Expected list, got Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>.

Thanks!

cholberg added a commit to cholberg/diffrax that referenced this pull request May 15, 2024
@cholberg (Contributor, Author)

Ah, yes, I forgot to handle the case where multiple SubSaveAts are passed. It should be fixed now with the latest commit. Let me know if that works for you.

@dkweiss31 (Contributor)

Thanks for the quick response! That fixed the example I posted; however, I am still running into issues on slightly more complicated examples that are more in line with how dynamiqs actually calls diffeqsolve, and more specifically with how it saves data as the simulation progresses. Here is an MWE where y saves the state and y2 saves "expectation values".

import diffrax as dx
import optimistix as optx
import jax
import jax.numpy as jnp
import equinox as eqx
from jax import Array

term = dx.ODETerm(lambda t, y, args: y + t)
solver = dx.Tsit5()
t0 = 0
t1 = 100.0
dt0 = 1.0
y0 = jnp.array([[1.0], [0.0]])
ts = jnp.arange(t0, t1, dt0)


def cond_fn(state, y, **kwargs):
    assert isinstance(state.y, jax.Array)
    norm = jnp.einsum("ij,ij->", y, y)
    return norm - jnp.exp(1.0)


class Saved(eqx.Module):
    y: Array
    y2: Array

def save_fn(t, y, args):
    ynorm = jnp.einsum("ij,ij->", y, y)
    return Saved(y, jnp.array([ynorm, 3 * ynorm]))

subsaveat_a = dx.SubSaveAt(ts=ts, fn=save_fn)  # save solution regularly
subsaveat_b = dx.SubSaveAt(t1=True)  # save last state
saveat = dx.SaveAt(subs=[subsaveat_a, subsaveat_b])

root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)
sol = dx.diffeqsolve(
    term, solver, t0, t1, dt0, y0, saveat=saveat, event=event
)

This runs into

ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(2, 2) shape=(2,)

Interestingly, the code runs without errors if I replace return Saved(y, jnp.array([ynorm, 3 * ynorm])) with return Saved(y, jnp.array([y, 3 * y]))

@cholberg (Contributor, Author)

You're right, I did not account for the fact that SubSaveAt.fn could return a PyTree. It should be fixed now; at least your MWE works with the latest commit.

@dkweiss31 (Contributor)

Indeed, that fixed my MWE! I hate to be such a pain, but I am now running into another issue; here is an example that is much closer to the actual code I am interested in running.

import diffrax as dx
import optimistix as optx
import jax.numpy as jnp

L_op = 0.1 * jnp.array([[0.0, 1.0],
                        [0.0, 0.0]], dtype=complex)
H = 0.0 * L_op
t0 = 0
t1 = 100.0
dt0 = 1.0
y0 = jnp.array([[0.0], [1.0]], dtype=complex)


def vector_field(t, state, _args):
    L_d_L = jnp.transpose(L_op) @ L_op
    new_state = -1j * (H - 1j * 0.5 * L_d_L) @ state
    return new_state


def cond_fn(state, **kwargs):
    psi = state.y
    prob = jnp.abs(jnp.einsum("id,id->", jnp.conj(psi), psi))
    return prob - 0.95


term = dx.ODETerm(vector_field)
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)

sol = dx.diffeqsolve(
    term, dx.Tsit5(), t0, t1, dt0, y0, event=event
)

This runs into

equinox._errors.EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

It's possible this could be due to my use of complex numbers, which, as I understand it, are only partially supported in diffrax? However, with the previous DiscreteTerminatingEvent I did not run into such errors. Note that if I change 0.95 to 0.0 I don't see this error (likely because the event is not triggered).

@patrick-kidger (Owner)

Complex number support is definitely still iffy. Can you try reproducing this without using them? You can still solve the same equation, mathematically speaking, just by splitting things into separate real and imaginary components.

@dkweiss31 (Contributor)

Right, here is the same example using the complex->real isomorphism described e.g. here (see Eq. (9)). I am getting the same error as before, so it seems this is not a complex-number issue.

import diffrax as dx
import optimistix as optx
import jax.numpy as jnp


def mat_cmp_to_real(matrix):
    re_matrix = jnp.real(matrix)
    im_matrix = jnp.imag(matrix)
    top_row = jnp.hstack((re_matrix, -im_matrix))
    bottom_row = jnp.hstack((im_matrix, re_matrix))
    return jnp.vstack((top_row, bottom_row))


def vec_cmp_to_real(vector):
    re_vec = jnp.real(vector)
    im_vec = jnp.imag(vector)
    return jnp.vstack((re_vec, im_vec))


L_op = 0.1 * jnp.array([[0.0, 1.0],
                        [0.0, 0.0]], dtype=complex)
L_d_L = jnp.transpose(L_op) @ L_op
H = 0.0 * L_op
_prop = -1j * (H - 1j * 0.5 * L_d_L)
_y0 = jnp.array([[0.0], [1.0]], dtype=complex)
prop = mat_cmp_to_real(_prop)
y0 = vec_cmp_to_real(_y0)
t0 = 0
t1 = 100.0
dt0 = 1.0


def vector_field(t, state, _args):
    new_state = prop @ state
    return new_state


def cond_fn(state, **kwargs):
    psi = state.y
    prob = jnp.abs(jnp.einsum("id,id->", jnp.conj(psi), psi))
    return prob - 0.95


term = dx.ODETerm(vector_field)
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = dx.Event(cond_fn, root_finder)

sol = dx.diffeqsolve(
    term, dx.Tsit5(), t0, t1, dt0, y0, event=event
)

patrick-kidger added a commit that referenced this pull request May 19, 2024
- Semantic change: boolean events now trigger when they become truthy (before, they triggered when they swapped between falsy and truthy). Note that this required twiddling a few things around, as previously it was impossible for an event to trigger on the first step; now it can.
- Semantic change: event functions now have the signature `(t, y, args, *, terms, solver, **etc)` for consistency with vector fields and with `SaveAt(fn=...)`.
- Feature: now backward-compatible with the old discrete terminating events.
- Feature: added `diffrax.steady_state_event`.
- Bugfix: the final `t` and `y` from an event are now saved in the correct index of `ts` and `ys`, rather than just always being saved at index `-1`.
- Bugfix: at one point `args` referred to the `args` coming from a root find rather than the overall `diffeqsolve`.
- Bugfix: the current `state.tprev` was used instead of the previous state's `tnext`. (These are usually but not always the same -- in particular when around jumps.)
- Bugfix: added some checks when the condition function of an event does not return a bool/float scalar.
- Performance: includes a fastpath for skipping the rootfind if no events are triggered.
- Performance: now avoiding tracing for the shape of `dense_info` twice when using adaptive step size controllers alongside events.
- Performance: avoided quadratic loop for figuring out what was the first event to trigger.
- Chore: added support for the possibility of the final root find (for the time of the event) failing.
- Chore: removed some dead code (`_bool_event_gradient`).
- Chore: removed references in the docs to the old `discrete_terminating_event`.

In addition, some drive-bys:

- Fixed warnings about pending deprecations `jnp.clip(..., a_min=..., a_max=...)`.
- Had `aux_stats` (in `_integrate.py`) forward to the overall output statistics. In practice this is empty but it's worth doing for the future.
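
A hedged sketch of the adjusted API described in this commit (new condition-function signature, plus the new steady-state helper):

```python
import diffrax
import optimistix as optx

def cond_fn(t, y, args, **kwargs):
    return y  # real-valued: event when y crosses zero

event = diffrax.Event(cond_fn, root_finder=optx.Newton(1e-5, 1e-5, optx.rms_norm))

# Or, replacing the old SteadyStateEvent:
steady_event = diffrax.Event(cond_fn=diffrax.steady_state_event())
```
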
@patrick-kidger (Owner)

Thank you @dkweiss31!
@cholberg, are you able to have a look at this example?

Anyway, as promised! Getting this in is my next priority. As such I've gone through and submitted a PR against this branch here. I don't claim that everything I've done is necessarily correct, so @cholberg I'd appreciate a review! :D

@cholberg (Contributor, Author)

Perfect! Rebased on the latest dev now :)

@johannahaffner (Contributor) commented Jun 26, 2024

Hi All!

I just opened #448; I get the same behavior on diffrax:main and cholberg/diffrax:dev.
With solver tolerances atol=1e-06, rtol=1e-03 on cholberg/diffrax:dev, the solver terminates without taking a single step (?); the reported t1 is 0.0.

There also appears to be a difference with respect to how the results are reported, as I get

---> 28 assert solution_a.result == dfx.RESULTS.discrete_terminating_event_occurred
(...)
--> 158 raise ValueError(
    159     "Can only compare equality between enumerations of the same type."
    160 )

ValueError: Can only compare equality between enumerations of the same type.

which is not raised on diffrax:main.

@johannahaffner (Contributor) commented Jun 26, 2024

Alright, I was merely being obtuse and didn't think about the most obvious thing: how the absolute and relative tolerances actually affect the terminating condition, and what that is based on. I assumed that it would compare rates of change $dy$ for consecutive steps. If you think that others might make the same mistake, I suggest adding a half-sentence to the documentation, something like

class SteadyStateEvent(AbstractDiscreteTerminatingEvent):
    """Terminates the solve once it reaches a steady state, defined as 
    norm(dy) < atol + rtol * norm(y).
    """

patrick-kidger merged commit b55e4e4 into patrick-kidger:dev on Jun 29, 2024 (2 checks passed).
@patrick-kidger (Owner)

Aaaaand... merged! 🎉

Thank you for all your hard work @cholberg! This was a massive PR, so getting this in is no small feat. Especially in a large existing codebase like Diffrax.

This is now on the dev branch, which I'll be merging into main and doing a new release for shortly. :)

@patrick-kidger (Owner)

By the way, @dkweiss31 -- since I've not seen any movement on jax-ml/jax#22011, my guess is that we should probably work around this by registering a transpose rule for lax.stop_gradient ourselves. This transpose rule should just be the identity function.

I'd be happy to take a PR on this!

@dkweiss31 (Contributor)

Nice, I'd love to take a swing at that! This would probably want to be in Equinox rather than diffrax?

@patrick-kidger (Owner) commented Jul 1, 2024

Hmm, I think probably Optimistix actually, now you mention it.
This is what's doing the manual linearisation and transposition, so it's there that we fall afoul of this bug.

patrick-kidger added a commit that referenced this pull request Jul 1, 2024
* Changes to how events are handled in diffrax.


* Test now fails when no root finder is provided

* Saving events with `SubSaveAt`s

Previously, updating the last element of ys and ts did not handle the case where multiple `SubSaveAt`s were used. This is now fixed by adding a `jtu.tree_map` in the appropriate place.

* Accounting for `SubSaveAt.fn` returning a PyTree

* Adjustments to #387 (events): the semantic changes, features, bugfixes, performance improvements, chores, and drive-by fixes listed in the commit message above (May 19, 2024).

* Save values returned by root find when

* now returns condition function

* Fixed error for . All tests pass now.

* Added additional tests

Added a bunch of additional tests for events.

Also changed the way `save_index` was updated to handle PyTrees of subsaveats.

* Fixed save_index update and shape+dtype check for cond_fn

* Added PyTree check in _outer_cond_fn

* Added tests for checking that events error out correctly under misspecified cond_fn

* Fixed small error in the save_index update for events

* Updated how events are saved

When passing `SaveAt(steps=True, ts=ts)` for some array `ts`, values are saved both at the times in `ts` and at the end of every solver step. In practice this means that some of the saved values might fall after the event time. I changed it so that these values are deleted.

* Added tests for different configurations of saveat

* Changed to ValueError when cond_fn returns non-boolean/float.

* Added docstring to Event class

* Updated docstring for steady_state_event

* Updated docstring for ImplicitAdjoint

* Added example to Event docstring

* Updated steady state example to use the new syntax

* Fixed weird type checker error

* Updated steady state test to use the new syntax

* Doc tweaks for events

* Typo in comment

* Simplified unsaving

* Deleted extra unnecessary argument

* Changed to strict inequality to be in line with the usual saving behaviour

---------

Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
cholberg deleted the dev branch (July 2, 2024 09:22)
@dkweiss31 (Contributor)

Hey @patrick-kidger, sorry for the delay, I was on vacation last week :).

Please forgive me, and apologies in advance for my likely elementary JAX mistakes: as a test I've tried doing the simplest possible thing, and added

from jax.interpreters import ad
from jax._src.ad_util import stop_gradient_p


def stop_gradient_transpose(ct, x):
    return ct

ad.primitive_transposes[stop_gradient_p] = stop_gradient_transpose

to the top of your example in jax-ml/jax#22011 as well as in my above example. Interestingly, your example now runs, while my example fails on line 720 of _integrate.py

line 720, in loop
    tfinal, yfinal, result = lax.cond(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: AssertionError

Any idea what might be happening? (Also happy to move this conversation elsewhere, but I thought I'd keep it on this thread for continuity's sake :))

@patrick-kidger (Owner)

Happy to keep the discussion here!
The fix is just three characters: you want return (ct,). :)

With this, your example works for me!
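
For completeness, the corrected workaround (the earlier snippet with the one-tuple return):

```python
from jax.interpreters import ad
from jax._src.ad_util import stop_gradient_p

def stop_gradient_transpose(ct, x):
    # Transpose of the identity is the identity; return one cotangent per input.
    return (ct,)

ad.primitive_transposes[stop_gradient_p] = stop_gradient_transpose
```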

@dkweiss31 (Contributor)

Ha! I knew it was something silly. Thank you! Just opened a PR in optimistix :)

@dkweiss31 (Contributor)

I am, unfortunately, still encountering errors after the transpose fix. See my latest MWE, which fails with

EqxRuntimeError: Terminating differential equation solve because an event occurred.

after a few times through the for loop of the optimizer. Such a bizarre error, since I expect the differential equation to terminate!

from functools import partial
import optax
import jax
import jax.numpy as jnp
import diffrax as dx
import optimistix as optx


def run_traj(H, jump_ops, psi0):
    def vector_field(t, state, _args):
        Ls = jnp.stack([L for L in jump_ops])
        LdL = (Ls.mT.conj() @ Ls).sum(axis=0)
        new_state = -1j * (H - 1j * 0.5 * LdL) @ state
        return new_state

    def cond_fn(t, y, *args, **kwargs):
        return jnp.abs(jnp.conj(y) @ y) - no_jump_prob

    event = dx.Event(cond_fn, optx.Newton(1e-5, 1e-5, optx.rms_norm))
    solution = dx.diffeqsolve(
        dx.ODETerm(vector_field), dx.Tsit5(), t0=ts[0], t1=ts[-1], dt0=dt0, y0=psi0, event=event,
    )
    return solution.ys


def run(H, jump_ops, psi0):
    f = jax.vmap(run_traj, in_axes=(None, None, 0))
    return f(H, jump_ops, psi0)


def optimize(params_to_optimize):
    opt_state = optimizer.init(params_to_optimize)
    for epoch in range(10):
        params_to_optimize, opt_state = step(params_to_optimize, opt_state)
    return params_to_optimize


@partial(jax.jit)
def step(params_to_optimize, opt_state,):
    grads = jax.grad(loss)(params_to_optimize,)
    updates, opt_state = optimizer.update(grads, opt_state)
    params_to_optimize = optax.apply_updates(params_to_optimize, updates)
    return params_to_optimize, opt_state


def loss(params_to_optimize):
    res = run(params_to_optimize * H1, [a, ], initial_states)
    return jnp.log(jnp.sum(jnp.abs(res)))


dim = 4
a = jnp.diag(jnp.sqrt(jnp.arange(1, stop=dim, dtype=complex)), k=1)
H1 = a + a.mT.conj()
ket0 = jnp.array([1.0, 0.0, 0.0, 0.0], dtype=complex)
ket1 = jnp.array([0.0, 1.0, 0.0, 0.0], dtype=complex)
initial_states = jnp.asarray([ket0, ket1])
dt0 = 1.0
ts = jnp.arange(0.0, 100.0, dt0)
no_jump_prob = 0.9
optimizer = optax.adam(learning_rate=0.9)
init_drive_params = 0.1
opt_params = optimize(
    params_to_optimize=init_drive_params,
)

@dkweiss31 (Contributor)

Sorry, the previous example was not exactly minimal. Here is a more minimal example that doesn't need the optimization layer. The error is now

equinox._errors.EqxRuntimeError: The maximum number of steps was reached in the nonlinear solver. The problem may not be solveable (e.g., a root-find on a function that has no roots), or you may need to increase `max_steps`.

I've tried increasing the number of max steps to no avail. Interestingly, choosing init_drive_params = -1.1 or 1.2 runs without an error

import jax.numpy as jnp
import diffrax as dx
import optimistix as optx


def run_traj(H, jump_op, psi0):
    def vector_field(t, state, _args):
        LdL = jump_op.mT.conj() @ jump_op
        new_state = -1j * (H - 1j * 0.5 * LdL) @ state
        return new_state

    def cond_fn(t, y, *args, **kwargs):
        return jnp.abs(jnp.conj(y) @ y) - no_jump_prob

    event = dx.Event(cond_fn, optx.Newton(1e-5, 1e-5, optx.rms_norm))
    solution = dx.diffeqsolve(
        dx.ODETerm(vector_field), dx.Tsit5(), t0=ts[0], t1=ts[-1], dt0=dt0, y0=psi0, event=event
    )
    return solution.ys


a = jnp.diag(jnp.sqrt(jnp.arange(1, stop=4, dtype=complex)), k=1)
ket1 = jnp.array([0.0, 1.0, 0.0, 0.0], dtype=complex)
dt0 = 1.0
ts = jnp.arange(0.0, 100.0, dt0)
no_jump_prob = 0.9
init_drive_params = -1.19
res = run_traj(init_drive_params * (a + a.mT.conj()), a, ket1)

@patrick-kidger (Owner)

Hmm! I think a big part of what this is speaking to is the need for better debug tooling.

There is one big thing I've been wanting to do here, which is to stop using throw=False on all of our calls to Optimistix. Instead, we should respect the value of diffeqsolve(..., throw=...). Then if this is True, the error will be localized to exactly where it occurred. If this is False, we still get the desired behaviour of not throwing. I'd be happy to take a PR on this. (I think it might need a bit of threading through to all the places that we use it.)

Tweaking this locally and setting EQX_ON_ERROR=breakpoint and JAX_DISABLE_JIT=1, I can see that this error is thrown from a root find to locate the event, seemingly on the very first step. Which seems accurate:

> event.cond_fn(0, _interpolator.evaluate(0))
Array(0.10000002, dtype=float32)
> event.cond_fn(1, _interpolator.evaluate(1))
Array(-0.47552538, dtype=float32)

Can you add some jax.debug.breakpoint (or some of the tools from eqx.debug.*) inside the iteration to try debugging this?
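
For example, a hedged sketch of instrumenting the condition function from the MWE above (jax.debug.print works under jit; swap in jax.debug.breakpoint() to drop into a debugger):

```python
import jax
import jax.numpy as jnp

def cond_fn(t, y, *args, **kwargs):
    val = jnp.abs(jnp.conj(y) @ y) - no_jump_prob  # `no_jump_prob` as defined in the MWE above
    jax.debug.print("t={t}  cond={c}", t=t, c=val)
    # jax.debug.breakpoint()  # alternatively, pause here under the debugger
    return val
```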

@patrick-kidger (Owner)

FWIW you've spurred me to try and improve things from the Equinox side, at least.

patrick-kidger/equinox#785

This should now give much clearer and more informative error messages, at least once the error_if is actually triggered.
