Coupled SDE System Implementation #461
The classic (#446 (comment)) strikes once again 😉 There seem to be a few errors here. First, you return a diagonal, but `ControlTerm` expects a full matrix by default, so you need to wrap the diagonal in a `lineax.DiagonalLinearOperator`. Second, SEA requires space-time Lévy area (`diffrax.SpaceTimeLevyArea`) from the Brownian motion (this should go in the solver docs imo). Finally, SEA requires additive noise (i.e. g is not a function of x), so you can't use it with that noise term; `GeneralShARK` handles general noise (and also needs space-time Lévy area). With all three fixes you get something that works and looks like:

```python
import jax.numpy as jnp
import jax.random as jr
import diffrax
import lineax as lx


def sde_drift(t, y, args):
    # SIR dynamics; y[3] and y[4] hold log(beta) and log(gamma).
    N, _ = args
    beta = jnp.exp(y[3])
    gamma = jnp.exp(y[4])
    dS = -(beta * y[0] * y[1]) / N
    dI = (beta * y[0] * y[1]) / N - y[1] * gamma
    dR = y[1] * gamma
    # log(beta) and log(gamma) have only diffusion, no drift
    dbeta = 0.0
    dgamma = 0.0
    dy = jnp.array([dS, dI, dR, dbeta, dgamma])
    return dy


def sde_diffusion(t, y, args):
    _, sigma_1 = args
    y1, y2, y3, y4, y5 = y
    diagonal = jnp.array([0.0, 0.0, 0.0, sigma_1 * y4, sigma_1 * y5])
    # Wrap the diagonal so ControlTerm applies it elementwise to the
    # Brownian increment instead of treating it as a full matrix.
    return lx.DiagonalLinearOperator(diagonal)


def sde():
    t0 = 0
    t1 = 100
    dt0 = 0.1
    y0 = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
    args = (4000.0, 0.2)
    # The solver needs space-time Levy area, so request it from the Brownian tree.
    bm = diffrax.VirtualBrownianTree(
        t0, t1, tol=1e-2, shape=(5,), key=jr.PRNGKey(42),
        levy_area=diffrax.SpaceTimeLevyArea,
    )
    terms = diffrax.MultiTerm(
        diffrax.ODETerm(sde_drift), diffrax.ControlTerm(sde_diffusion, bm)
    )
    # GeneralShARK supports non-additive (state-dependent) noise; SEA does not.
    solver = diffrax.GeneralShARK()
    saveat = diffrax.SaveAt(dense=True)
    print(type(terms))
    sol = diffrax.diffeqsolve(
        terms, solver, t0, t1, dt0=dt0, y0=y0, args=args, saveat=saveat
    )
    print(sol)


sde()
```
Thanks for the reply; I must have missed some of the points you make about the solver in the docs. Trying the code you suggested, I get an error. For completeness, here are the library versions I'm using:
Thanks again for the assistance.
Yes, I was using diffrax 0.6.0.
That was it, thanks again!
Hi,
I updated all the packages to the versions above and I get the error:
Do you have any ideas?
What versions are you using?
Hi Owen,
Hi all,
thanks for the great library. I'm having an issue implementing a coupled system of SDEs. I'm getting a
ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`)
error. The system is:
The code is
Printing the type of `terms` yields `diffrax._term.MultiTerm`, so I'm not entirely sure where to look. What would you suggest I look at? Thanks in advance.