
Coupled SDE System Implementation #461

Open
aspannaus opened this issue Jul 10, 2024 · 7 comments

@aspannaus

aspannaus commented Jul 10, 2024

Hi all,

Thanks for the great library. I'm having trouble implementing a coupled system of SDEs: I get ``ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`)``. The system is:

$$\begin{aligned} \mathrm{d} S(t) &= -\beta(t)S(t)\frac{I(t)}{N}\, \mathrm{d} t, \\ \mathrm{d} I(t) &= \left(\beta(t)S(t)\frac{I(t)}{N} - \gamma(t) I(t)\right) \mathrm{d} t,\\ \mathrm{d} R(t) &= \gamma(t) I(t)\, \mathrm{d} t,\\ \mathrm{d} \log\beta(t) &= w_3\, \mathrm{d} B_w(t),\\ \mathrm{d} \log\gamma(t) &= u_3\, \mathrm{d}B_u(t) \end{aligned}$$
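For reference, the system above can be simulated with a plain Euler-Maruyama loop in JAX, without diffrax. This is a minimal sketch, not the thread's solution: it takes `w3` and `u3` to be constants, as the equations state (so the noise on log β and log γ is additive), and their value of 0.05 is an illustrative assumption; `N`, the initial state, the step size and the horizon match the code that follows.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def drift(y, N):
    # Deterministic part: SIR dynamics; log beta and log gamma have zero drift.
    S, I, R, log_beta, log_gamma = y
    beta, gamma = jnp.exp(log_beta), jnp.exp(log_gamma)
    infection = beta * S * I / N
    recovery = gamma * I
    return jnp.array([-infection, infection - recovery, recovery, 0.0, 0.0])


def euler_maruyama(y0, N, w3, u3, dt, num_steps, key):
    # Additive noise: only log beta and log gamma receive Brownian increments,
    # scaled by the constants w3 and u3 (assumed values, not from the thread).
    noise_scale = jnp.array([0.0, 0.0, 0.0, w3, u3])

    def step(y, k):
        dW = jnp.sqrt(dt) * jr.normal(k, shape=y.shape)
        y_next = y + drift(y, N) * dt + noise_scale * dW
        return y_next, y_next

    keys = jr.split(key, num_steps)
    _, path = jax.lax.scan(step, y0, keys)
    return path


y0 = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
path = euler_maruyama(y0, N=4000.0, w3=0.05, u3=0.05, dt=0.1, num_steps=1000, key=jr.PRNGKey(0))
print(path.shape)  # (1000, 5)
```

Because the drift terms for S, I and R sum to zero and carry no noise, S + I + R should stay at its initial value up to rounding, which is a handy sanity check for any solver applied to this system.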

The code is:

```python
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

import diffrax

def sde_drift(t, y, args):
    N, _ = args
    beta = jnp.exp(y[3])
    gamma = jnp.exp(y[4])
    dS = -(beta * y[0] * y[1]) / N
    dI = (beta * y[0] * y[1]) / N - y[1] * gamma
    dR = y[1] * gamma
    # only diffusion, no drift
    dbeta = 0.0  # jnp.array([0.0])
    dgamma = 0.0  # jnp.array([0.0])
    dy = jnp.array([dS, dI, dR, dbeta, dgamma])

    return dy

def sde_diffusion(t, y, args):
    _, sigma_1 = args
    y1, y2, y3, y4, y5 = y
    diagonal = jnp.array([0.0, 0.0, 0.0, sigma_1 * y4, sigma_1 * y5])
    return diagonal


def sde():

    t0 = 0
    t1 = 100
    dt0 = 0.1
    y0 = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
    args = (4000.0, 0.2)

    bm = diffrax.VirtualBrownianTree(t0, t1, tol=1e-2, shape=(5,), key=jr.PRNGKey(42))
    terms = diffrax.MultiTerm(diffrax.ODETerm(sde_drift), diffrax.ControlTerm(sde_diffusion, bm))
    solver = diffrax.SEA()
    saveat = diffrax.SaveAt(dense=True)

    print(type(terms))

    sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0=dt0, y0=y0, args=args, saveat=saveat)
    print(sol)
```

Printing the type of `terms` yields `diffrax._term.MultiTerm`, so I'm not sure where to look. What would you suggest checking?

Thanks in advance.

@lockwo
Contributor

lockwo commented Jul 11, 2024

The classic (#446 (comment)) strikes once again 😉

It seems like there are a few errors here. First, `sde_diffusion` returns the diagonal as a vector, but `ControlTerm` expects a full matrix by default, so you need to wrap it in a diagonal operator (`lx.DiagonalLinearOperator`). Second, `SEA` requires space-time Lévy area from the Brownian motion (`levy_area=diffrax.SpaceTimeLevyArea`; this should go in the solver docs imo). Finally, `SEA` requires additive noise (i.e. g is not a function of y), so you can't use this solver with that noise term.
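The first point can be checked without diffrax. A full diffusion matrix of shape (state, noise) is applied to the Brownian increment as a matrix-vector product; a diagonal diffusion instead scales each increment component by its own coefficient, which is what `lx.DiagonalLinearOperator` represents without building the full matrix. The plain-JAX sketch below (values taken from the thread's `y0` and `sigma_1`, PRNG key chosen arbitrarily) checks that the two conventions agree, and shows that contracting a bare length-5 vector with the increment gives a single scalar rather than a state-shaped update.

```python
import jax.numpy as jnp
import jax.random as jr

# Diffusion diagonal from the thread, evaluated at the initial state with sigma_1 = 0.2.
y0 = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
sigma_1 = 0.2
diagonal = jnp.array([0.0, 0.0, 0.0, sigma_1 * y0[3], sigma_1 * y0[4]])

# One Brownian increment over a step of size dt, one noise channel per state component.
dt = 0.1
dW = jnp.sqrt(dt) * jr.normal(jr.PRNGKey(0), (5,))

# Full-matrix convention: a (5, 5) diffusion matrix applied to dW.
full_matrix_update = jnp.diag(diagonal) @ dW
# Diagonal convention: scale each component of dW by its own coefficient.
diagonal_update = diagonal * dW
# Contracting the bare length-5 vector with dW instead collapses to one scalar.
collapsed = diagonal @ dW

print(jnp.allclose(full_matrix_update, diagonal_update, atol=1e-6))  # True
print(full_matrix_update.shape, collapsed.shape)  # (5,) ()
```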

Applying all three fixes (the third one means swapping `SEA` for `GeneralShARK`, which does not require additive noise), you get something that works:

```python
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

import diffrax
import lineax as lx  # lineax operators in ControlTerm need diffrax 0.6.0

def sde_drift(t, y, args):
    N, _ = args
    beta = jnp.exp(y[3])
    gamma = jnp.exp(y[4])
    dS = -(beta * y[0] * y[1]) / N
    dI = (beta * y[0] * y[1]) / N - y[1] * gamma
    dR = y[1] * gamma
    # only diffusion, no drift
    dbeta = 0.0  # jnp.array([0.0])
    dgamma = 0.0  # jnp.array([0.0])
    dy = jnp.array([dS, dI, dR, dbeta, dgamma])

    return dy

def sde_diffusion(t, y, args):
    _, sigma_1 = args
    y1, y2, y3, y4, y5 = y
    diagonal = jnp.array([0.0, 0.0, 0.0, sigma_1 * y4, sigma_1 * y5])
    # Wrap the diagonal so ControlTerm treats it as a diagonal matrix, not a full one.
    return lx.DiagonalLinearOperator(diagonal)


def sde():

    t0 = 0
    t1 = 100
    dt0 = 0.1
    y0 = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
    args = (4000.0, 0.2)

    # GeneralShARK, like SEA, needs space-time Levy area from the Brownian motion.
    bm = diffrax.VirtualBrownianTree(t0, t1, tol=1e-2, shape=(5,), key=jr.PRNGKey(42), levy_area=diffrax.SpaceTimeLevyArea)
    terms = diffrax.MultiTerm(diffrax.ODETerm(sde_drift), diffrax.ControlTerm(sde_diffusion, bm))
    # Unlike SEA, GeneralShARK does not assume additive noise.
    solver = diffrax.GeneralShARK()
    saveat = diffrax.SaveAt(dense=True)

    print(type(terms))

    sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0=dt0, y0=y0, args=args, saveat=saveat)
    print(sol)

sde()
```
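On the additive-noise point: noise is additive when the diffusion does not depend on the state, so one quick sanity check is to differentiate the diffusion with respect to `y`. The plain-JAX sketch below uses a hypothetical helper `is_additive` and compares the diffusion from this thread with a constant-coefficient variant that matches the question's equations as written (the values of `w3` and `u3` are illustrative assumptions). If the constant form is what was intended, it would meet `SEA`'s additive-noise requirement.

```python
import jax
import jax.numpy as jnp


def diffusion_from_thread(t, y, args):
    # Diagonal from the thread: depends on y[3] and y[4], so the noise is multiplicative.
    _, sigma_1 = args
    return jnp.array([0.0, 0.0, 0.0, sigma_1 * y[3], sigma_1 * y[4]])


def diffusion_additive(t, y, args):
    # Constant coefficients w3, u3 as in the question's equations (illustrative values).
    w3, u3 = 0.2, 0.2
    return jnp.array([0.0, 0.0, 0.0, w3, u3])


def is_additive(diffusion, t, y, args):
    # Hypothetical helper: checks d(diffusion)/dy == 0 at a single point (t, y),
    # so it is a sanity check, not a proof that the noise is additive everywhere.
    jac = jax.jacfwd(diffusion, argnums=1)(t, y, args)
    return bool(jnp.all(jac == 0.0))


y0 = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
args = (4000.0, 0.2)
print(is_additive(diffusion_from_thread, 0.0, y0, args))  # False
print(is_additive(diffusion_additive, 0.0, y0, args))  # True
```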

@aspannaus
Author

Thanks for the reply; I must have missed the solver requirements you mention in the docs.

Trying the code you suggested, I get the error `ValueError: Custom node type mismatch: expected type: <class 'lineax._operator.DiagonalLinearOperator'>, value: Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=2/0)>`. I had tried this previously without success, but perhaps it is correct and something is happening behind the scenes?

For completeness, here are the library versions I'm using:

  • jax: 0.4.30
  • diffrax: 0.5.1
  • lineax: 0.0.5
  • equinox: 0.11.4

Thanks again for the assistance.

@lockwo
Contributor

lockwo commented Jul 11, 2024

Yes, I was using diffrax 0.6.0; passing a lineax operator as the `ControlTerm` diffusion needs that release.
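One way to catch this mismatch up front is to check the installed diffrax version before building a `ControlTerm` from a lineax operator. The sketch below uses only the standard library; `release_tuple` and `diffrax_is_new_enough` are hypothetical helper names, and the 0.6.0 threshold comes from this thread rather than from the diffrax changelog.

```python
import re
from importlib.metadata import PackageNotFoundError, version

MIN_DIFFRAX = (0, 6, 0)  # threshold taken from this thread


def release_tuple(version_string):
    # "0.5.1" -> (0, 5, 1); stops at the first non-numeric part, e.g. "0.6.0rc1" -> (0, 6, 0).
    match = re.match(r"\d+(\.\d+)*", version_string)
    return tuple(int(part) for part in match.group(0).split(".")) if match else ()


def diffrax_is_new_enough(minimum=MIN_DIFFRAX):
    # False if diffrax is missing or older than the minimum release.
    try:
        installed = version("diffrax")
    except PackageNotFoundError:
        return False
    return release_tuple(installed) >= minimum


print(diffrax_is_new_enough())
```

Calling this before constructing the terms (and raising an error if it returns `False`) turns the opaque pytree mismatch into an explicit version message.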

@aspannaus
Author

That was it, thanks again!

@SoerenNagel

Hi,
I had the same issue as above:

`ValueError: Custom node type mismatch: expected type: <class 'lineax._operator.DiagonalLinearOperator'>`

I updated all the packages to the versions above, and now I get the error:

`AttributeError: module 'opt_einsum' has no attribute 'paths'`

Do you have any ideas?

@lockwo
Contributor

lockwo commented Aug 15, 2024

What versions are you using?
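When comparing environments like this, it helps to print the versions installed in the environment Python is actually running in, rather than what a package manager lists. A minimal standard-library sketch; the package list is the one discussed in this thread plus `opt_einsum` from the error above, and `report_versions` is a hypothetical helper:

```python
from importlib.metadata import PackageNotFoundError, version

PACKAGES = ["jax", "jaxlib", "equinox", "lineax", "diffrax", "opt_einsum"]


def report_versions(packages=PACKAGES):
    # Map each distribution name to its installed version, or None if it is missing.
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in report_versions().items():
    print(f"{name}: {ver if ver is not None else 'not installed'}")
```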

@SoerenNagel

Hi Owen,
I fixed the issue by setting up a new conda environment and installing jax, jaxlib, equinox, lineax and diffrax through pip rather than conda (where diffrax 0.6.0 is not yet available). I don't really know what the underlying issue was, but thanks anyway.
