First DDE version #169
Conversation
Okay, so there are a lot of spurious changes here, mostly due to unneeded formatting changes. Take a look at CONTRIBUTING.md, and in particular the pre-commit hooks. These will autoformat etc. the code. I'll be able to do a proper review then.

Regarding passing the initial history: I'd suggest something like

```python
def diffeqsolve(y0, delays, ...):
    if delays is None:
        y0_history = None
    else:
        y0_history = y0
        y0 = y0_history(t0)
    ...
    adjoint.loop(..., y0_history=y0_history)
```

so that internally we still disambiguate between `y0` and `y0_history`.

Regarding the changes to constant step sizing: hmm, this seems strange to me. I don't think we should need to change any stepsize controller at all. I think the stepsize controller changes we need to make (due to discontinuities) should happen entirely within the integration loop.
Hello,

1/ I used the pre-commit hooks but still have one small spurious change in ...

4/ One edge case was found; I haven't thought about it too much and did a "sloppy" fix for now (https://github.com/patrick-kidger/diffrax/pull/169/files#r991488410). Essentially it comes up when we integrate a step from ...

Things not done: ...
Latest commit does what you suggested.

To do: ...

Comments: ...

Latest commit has: ...
Okay I've reviewed about half of it! I'll try and get to the other half in the next week or so.

Overall I like where this is going. I think we now have some first implementations for every piece.

Regarding my comments about avoiding in-place updates: I think this is probably doable by updating `HistoryVectorField` to operate on three regions: `y0_history`, the recorded `dense_infos`, and finally the current step. (Much like how it is operating over two regions at the moment.) I think this should allow us to generate efficient code.
Latest commit updates some remarks from code comments after the latest review (#169 (review)). I've bundled the delay arguments together, as you suggested, for an easier API:

```python
class _Delays(eqx.Module):
    delays: Optional[PyTree[Callable]]
    initial_discontinuities: Union[Array, Tuple]
    max_discontinuities: Int
    recurrent_checking: Bool
    rtol: float
    atol: float
    eps: float
```
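For concreteness, a hypothetical construction of this bundle might look as follows. The field values here are illustrative guesses, not defaults from this PR; only the constant-delay callable mirrors usage seen elsewhere in this thread:

```python
import jax.numpy as jnp

# Hypothetical example values for each field of the _Delays bundle above.
delays = _Delays(
    delays=(lambda t, y, args: 1.0,),          # one constant delay of 1.0
    initial_discontinuities=jnp.array([0.0]),  # known discontinuity at t0
    max_discontinuities=32,                    # buffer size for found jumps
    recurrent_checking=False,                  # don't re-check accepted steps
    rtol=1e-3,
    atol=1e-6,
    eps=1e-8,
)
```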
What you suggested is great for later iterations, because we can just slap any new arguments into the `Delays` class.

What needs to be done: ...
Okay, let's focus on this bit before we move on to discussing the discontinuity handling below. I'll leave you to make the changes already discussed here. Let me know if any of them aren't clear.
Sounds good. I'll take care of the first bullet point later; the second one should be OK on my side. However, I'd like to have your take on the third one, with the in-place operations (for ...).
Sure thing. I'm suggesting something like:

```python
class _HistoryVectorField(eqx.Module):
    ...
    tprev: float
    tnext: float
    dense_info: PyTree[Array]
    interpolation_cls: Type[AbstractLocalInterpolation]

    def __call__(self, t, y, args):
        ...
        if self.dense_interp is None:
            ...
        else:
            for delay in delays:
                delay_val = delay(t, y, args)
                alpha_val = t - delay_val
                # Which of the three regions does the delayed time fall into?
                is_before_t0 = alpha_val < self.t0
                is_before_tprev = alpha_val < self.tprev
                # Clamp the evaluation point into each region.
                at_most_t0 = jnp.minimum(alpha_val, self.t0)
                t0_to_tprev = jnp.clip(alpha_val, self.t0, self.tprev)
                at_least_tprev = jnp.maximum(alpha_val, self.tprev)
                step_interpolation = self.interpolation_cls(
                    t0=self.tprev, t1=self.tnext, **self.dense_info
                )
                switch = jnp.where(is_before_t0, 0, jnp.where(is_before_tprev, 1, 2))
                history_val = lax.switch(
                    switch,
                    [
                        lambda: self.y0_history(at_most_t0),
                        lambda: self.dense_interp(t0_to_tprev),
                        lambda: step_interpolation.evaluate(at_least_tprev),
                    ],
                )
                history_vals.append(history_val)
        ...
        return ...
```

And then when it is called inside the implicit routine:

```python
def body_fun(val):
    dense_info, ... = val
    ...
    _HistoryVectorField(..., state.tprev, state.tnext, dense_info, solver.interpolation_cls)
    ...
    return new_dense_info, ...
```
Latest commit should have handled all of the issues mentioned above.
**Discussion/Bottleneck for implicit step**

Regarding the implicit step, we have an issue when it comes to large steps, because of the step interpolation:

```python
step_interpolation = self.interpolation_cls(
    t0=self.tprev, t1=self.tnext, **self.dense_info
)
```

To elaborate a bit more: if we have an implicit step from ... Perhaps we could change the current convergence predicate

```python
_pred = (((y - y_prev) / y) > delays.eps).any()
```

to something that checks the MSE of the extrapolated history function before and after the integration step (a rough sketch follows below). Not sure; with this in mind, the ... This also impacts the population of the ...
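To make that floated alternative concrete, here is a rough sketch of an MSE-based convergence check. This is entirely illustrative: `history_before`, `history_after`, and the sampling resolution are hypothetical names, not part of this PR:

```python
import jax
import jax.numpy as jnp

def mse_not_converged(history_before, history_after, tprev, tnext, eps):
    # Sample the extrapolated history over the current step and compare
    # its values before and after the latest implicit iteration.
    ts = jnp.linspace(tprev, tnext, 8)
    before = jax.vmap(history_before)(ts)
    after = jax.vmap(history_after)(ts)
    return jnp.mean((before - after) ** 2) > eps
```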
I don't think this is true. Anything after ...

Regarding 2-point interpolations: this isn't the case for most solvers. Each solver evaluates various intermediate quantities during its step (e.g. the stages of an RK solver), and these also feed into the interpolation. Even if it were, though, I don't think it matters: we just need to converge to a solution of the implicit problem.
Ok, this makes sense, so I'll take back what I said in my "Discussion/Bottleneck for implicit step" comment. Thanks for the clarification! The ...

Edit: ...
Relevant changes are made: ...

In order to get something backprop-compatible in ...
Okay, this is probably the final round of review on the implicit stepping. (Nice work!) I'll go over the discontinuity checking at some point shortly.

Regarding backprop through the `lax.while_loop`: the correct thing to do here is actually to use the implicit function theorem, e.g. as is already done with Newton's method. This can be wired up using the `misc.implicit_jvp` helper. Let me know if you're not familiar with this and I can give you some pointers on how this works.
Great news!
diffrax/__init__.py (outdated):

```python
@@ -6,6 +6,7 @@
    RecursiveCheckpointAdjoint,
)
from .brownian import AbstractBrownianPath, UnsafeBrownianPath, VirtualBrownianTree
from .delays import _HistoryVectorField, Delays, history_extrapolation_implicit
```
Nit: I think we only want to expose `Delays` in the public interface. Everything else is Diffrax-internal.
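In other words, the public re-export in `diffrax/__init__.py` would presumably shrink to something like the following (a sketch of the implied change, not a literal patch from this PR):

```python
from .delays import Delays
```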
I'm sorry, I'm not sure I understand what you mean.
diffrax/integrate.py (outdated):

```python
)
return _next_ta, _next_tb, _pred, _step, max_step

_init_val = (sub_tprev, sub_tnext, True, 0, 400)
```
What if we hit `max_steps` here? I think we need a mechanism in this whole discontinuity procedure to allow the discontinuity-finding to fail. (Namely, reject the step and use some more naive way of picking `tnext`, e.g. `tprev` plus 0.1 times the current interval length or something.)
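For illustration, that naive fallback might look like the following sketch. Only the 0.1 factor comes from the comment above; the function and its argument names are hypothetical:

```python
import jax.numpy as jnp

def pick_tnext(tprev, tnext, root_found, t_root):
    # If the root find failed (e.g. it hit max_steps), reject the step and
    # fall back to a short, conservative step instead of trusting a bad root.
    naive = tprev + 0.1 * (tnext - tprev)
    return jnp.where(root_found, t_root, naive)
```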
If we hit `max_steps`, that would just mean that the NewtonSolver takes the lead with the current root approximation? Speaking from experience, since each `sub_tprev` and `sub_tnext` used are pretty small intervals, this would probably not happen? The interval `[tprev, tnext]` is split into N sub-intervals where root tracking is done.
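To make that splitting concrete, a small illustrative sketch (N and the endpoints are arbitrary example values, not constants from this PR):

```python
import jax.numpy as jnp

tprev, tnext, N = 0.0, 0.5, 8  # arbitrary example values
edges = jnp.linspace(tprev, tnext, N + 1)
# Shape (N, 2): each row is one small bracket on which a root find runs.
sub_intervals = jnp.stack([edges[:-1], edges[1:]], axis=-1)
```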
diffrax/integrate.py (outdated):

```python
_discont = _discont_solver(_h, _tb, args).root
_disconts.append(_discont)
if _discont.size == 0:
    return jnp.inf
```
I don't think we can ever hit this branch?
(And indeed, as currently written it would crash, since it doesn't also return a bool.)
Yup, we would never hit this. No problem should be encountered, since `NewtonNonlinearSolver` returns `NaN`s if it fails, right? (If that is the case we should be good then.)
Alright, on to the next block of code! As for ...
This is looking good. Let's start adding this into the documentation and so on:

- Add a new doc page under the "Advanced API" section.
- Add a short example. (See the steady-state example for inspiration.)

Let's also add a `test_delays.py` file, checking all the edge cases (a sketch of such a file appears after this comment). Off the top of my head:

- Basic checks that the solver actually runs without crashing.
- Numerical checks that the expected solution is obtained.
- Check what happens when we exceed `Delays.max_discontinuities`.
- Check what happens when we exceed `Delays.max_steps`.
- Test combining delays with stochastic terms.
- Test combining delays with `PIDController(jump_ts=...)`.
- Test a 'smooth' DDE (whose initial history is such that there is no discontinuity), and check that this can be solved without any expensive discontinuity detection at all. (Actually, I don't think this is possible at the moment -- maybe the `recurrent_checking` argument should be generalised: `discontinuity_checking=True/False/None` for always check / only check on rejected steps / never check?)
- Test DDEs that hit both the implicit-step and the explicit-step branches, and that those branches are taken. (To test this: perhaps we can count the number of implicit and explicit steps, and return that in the auxiliary `stats`.)

Other changes that come to mind:

- Another good auxiliary statistic could be how many discontinuities were encountered. (Including any discontinuities arising from `jump_ts`?)
- I think the numerics might still have a subtle bug from the lack of something like `_clip_to_end` (diffrax/integrate.py, line 82 at 5f5a121, `def _clip_to_end(tprev, tnext, t1, keep_step):`), i.e. clipping `tnext` to the end of the integration interval.
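As a starting point, here is a minimal sketch of what `test_delays.py` might contain. This assumes the `delays=`/`diffrax.Delays` API from this PR and the `history` keyword seen in the MWE later in this thread; the analytic reference comes from the method of steps:

```python
import jax.numpy as jnp

import diffrax


def test_constant_delay_runs():
    # dy/dt = -y(t - 1), with constant history y(t) = 1 for t <= 0.
    def f(t, y, args, *, history):
        (y_delayed,) = history
        return -y_delayed

    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f),
        diffrax.Bosh3(),
        t0=0.0,
        t1=2.0,
        dt0=0.05,
        y0=lambda t: jnp.array([1.0]),
        delays=diffrax.Delays(delays=(lambda t, y, args: 1.0,)),
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 2.0, 11), dense=True),
    )
    # On [0, 1] the delayed value is the constant history, so y' = -1 and the
    # exact solution is y(t) = 1 - t (method of steps).
    mask = sol.ts <= 1.0
    assert jnp.allclose(sol.ys[mask, 0], 1.0 - sol.ts[mask], atol=1e-3)
```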
As for `implicit_jvp`: fix some function `f`, and define `y(θ)` as being the value of `y` satisfying `f(y, θ) = 0`. (And we assume there is a unique such `y`.) Then we can see that the function `f` has implicitly defined a function `θ -> y(θ)`.

We can now seek to evaluate the derivative `dy/dθ`. This actually involves a linear solve, and is what `implicit_jvp` does. This is pretty easy to do. Probably the best reference is:

http://implicit-layers-tutorial.org/implicit_functions/

Also see the final equation of this section:

https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem

You can see this in action for the existing nonlinear solvers, which use the implicit function theorem to differentiate through their root-finding operation.
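To make this concrete, here is a minimal self-contained sketch of the idea in plain JAX. It is illustrative only and is not Diffrax's actual `misc.implicit_jvp` helper; the cubic `f` and the Newton iteration are arbitrary choices:

```python
import jax
import jax.numpy as jnp

# f(y, theta) = y**3 - theta implicitly defines y(theta) = theta ** (1/3).
def f(y, theta):
    return y**3 - theta

@jax.custom_jvp
def solve(theta):
    # Forward pass: Newton iterations on y -> f(y, theta).
    y = jnp.asarray(1.0)
    for _ in range(50):
        y = y - f(y, theta) / jax.grad(f, argnums=0)(y, theta)
    return y

@solve.defjvp
def solve_jvp(primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    y = solve(theta)
    # Implicit function theorem: dy/dtheta = -(df/dy)^{-1} df/dtheta.
    # In general this is a linear solve; for scalars it is just a division.
    df_dy = jax.grad(f, argnums=0)(y, theta)
    df_dtheta = jax.grad(f, argnums=1)(y, theta)
    return y, -df_dtheta / df_dy * theta_dot

# d(theta**(1/3))/dtheta at theta = 8 is 1/12 ≈ 0.0833.
print(jax.grad(solve)(8.0))
```

The key point is that the backward pass never differentiates through the Newton loop itself; it only needs derivatives of `f` at the converged root, which is exactly why this approach sidesteps backprop through `lax.while_loop`.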
diffrax/delays.py (outdated):

```python
from .local_interpolation import AbstractLocalInterpolation
from .misc import rms_norm
from .misc.omega import ω
from .misc.unvmap import unvmap_any
```
Before merging, we'll have to rebase your branch onto the current `main` version of Diffrax. When that happens, heads-up that these imports will change to:

```python
import equinox.internal as eqxi
from equinox.internal import ω

eqxi.unvmap_any
```
Thanks for the review Patrick :). So from what I understood, what I need to do is to create a new class:

```python
class DDEImplicitNonLinearSolve(AbstractNonlinearSolver):
    def _solve(
        self,
        fn: Callable,  # would be `terms`
        x: PyTree,  # here would be `state.y`
        nondiff_args: PyTree,  # all the other args from the current history_extrapolation_implicit(...)
        diff_args: PyTree,
    ):
        ...
```

and then do something like:

```python
def history_extrapolation_implicit(...):
    nonlinearsolver = DDEImplicitNonLinearSolve(...)
    results = nonlinearsolver(terms, y, args).root
    y, y_error, dense_info, solver_state, solver_result = results
    return y, y_error, dense_info, solver_state, solver_result
```

If that's the case, could you explain how you usually work with your ...
Adding

```python
ts = state.dense_ts[...]
```

works. Unfortunately, unwrapping the buffer with other structures like

```python
infos = jtu.tree_map(lambda x: x[...], state.dense_infos)
```

does not. The problem came from the ... However, I found another way:

```python
unwrapped_buffer = jtu.tree_leaves(
    eqx.filter(state.dense_infos, eqx.is_inexact_array),
    is_leaf=eqx.is_inexact_array,
)
unwrapped_dense_infos = dict(zip(state.dense_infos.keys(), unwrapped_buffer))
```
Notice how I pass in an additional argument in my previous snippet: ... This is because each ... The use of buffers is a pretty advanced/annoying detail. I'm pondering using something like Quax to create a safer API for this, but that's a long way down the to-do list.
Indeed you're right, I just realised that in the documentation yesterday; this makes sense now! i.e.

```python
dense_interp = DenseInterpolation(
    ts=state.dense_ts[...],
    infos=jtu.tree_map(lambda _, x: x[...], dense_info, state.dense_infos),
    ...
)
(
    y,
    y_error,
    dense_info,
    solver_state,
    solver_result,
) = history_extrapolation_implicit(
    ...
)
```

So I would agree that this works if ...
Since I'm starting to see activity here again -- you can track my progress updating Diffrax in #217 and https://github.com/patrick-kidger/diffrax/tree/big-refactor. (This currently depends on the unreleased versions of Equinox and jaxtyping.) Mostly there now!
Reviving this PR! The MWE to test it out:

```python
import jax
import equinox as eqx
import jax.numpy as jnp
import jax.random as jrandom

import diffrax


class Func(eqx.Module):
    linear: eqx.nn.Linear

    def __init__(self, data_size, *, key, **kwargs):
        super().__init__(**kwargs)
        self.linear = eqx.nn.Linear(2 * data_size, data_size, key=key)

    def __call__(self, t, y, args, *, history):
        return self.linear(jnp.hstack([y, *history]))


class NeuralDDE(eqx.Module):
    func: Func
    delays: diffrax.Delays

    def __init__(self, data_size, delays, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, key=key)
        self.delays = delays

    def __call__(self, ts, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Bosh3(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=lambda t: y0,
            saveat=diffrax.SaveAt(ts=ts, dense=True),
            delays=self.delays,
        )
        return solution.ys


@eqx.filter_value_and_grad
def grad_loss(model, ti, yi):
    y_pred = model(ti, yi[0])
    return jnp.mean((yi - y_pred) ** 2)


@eqx.filter_value_and_grad
def grad_loss_batch(model, ti, yi):
    y_pred = jax.vmap(model, (None, 0))(ti, yi[:, 0])
    return jnp.mean((yi - y_pred) ** 2)


key = jrandom.PRNGKey(0)
ys = jnp.ones((1, 3, 1))
ts = jnp.linspace(0, 1.0, 3)
_, length_size, datasize = ys.shape
delays = diffrax.Delays(delays=(lambda t, y, args: 1.0,))
model_dde = NeuralDDE(datasize, delays, key=key)

print("Starting check_tracer_leaks()")
with jax.check_tracer_leaks():
    loss, grads = grad_loss(model_dde, ts, ys[0])
    print("SUCCESS with check_tracer_leaks() with grad_loss()")
    loss2, grads2 = grad_loss_batch(model_dde, ts, ys)
    print("SUCCESS with check_tracer_leaks() with grad_loss_batch()")
```

Neural DDE works with ... A modified version of the ...

To integrate a DDE, two methods can be used: ...
In hindsight, I would like to go for the first option (1/). This would mean that some code (e.g. https://github.com/thibmonsel/diffrax/blob/7e7d1b443e76c2573458e2d7f4a72223967cb01d/diffrax/_delays.py#L297) and many attributes of ...
New files:

- `discontinuity.py`, which does the root finding during integration steps.

Modified files:

- `integrate.py`: changed a bit of the code, but it essentially looks the same, just with more `if` statements. The discontinuity handling before each integration step is also done there. Added 2 new arguments to `_State` (`discontinuities`, `discontinuities_save_index`).
- `constant.py`: does the discontinuity checking and returns the next integration step. But as said in the WIP, the `prevbefore` and `nextafter` are done in `loop()`.

Followed your suggestion regarding dropping `y0_history` and putting it in `y0`. However, by doing this we must now pass `y0` to the `loop` function. Haven't done the `PyTree` handling of delays yet. Only works for the constant stepsize controller; doing adaptive now.

PS: I don't have the same saving format as you, so `terms.py` shows some deletions and additions for no reason.

Boilerplate code for a DDE: ...