Can't use Equinox inside term #446

Open
pascal-mueller opened this issue Jun 24, 2024 · 4 comments
Labels
question User queries

Comments

@pascal-mueller

I have this code solving an ODE. If I set the force inside equations to a numerical value, everything works, but if I replace it with a neural network, I get:


% python osc.py
Traceback (most recent call last):
  File ".../project_3/osc.py", line 71, in <module>
    solution = dfx.diffeqsolve(
               ^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/equinox/_jit.py", line 206, in __call__
    return self._call(False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/equinox/_module.py", line 1053, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/equinox/_jit.py", line 200, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 327, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 175, in _python_pjit_helper
    attrs_tracked) = _infer_params(jit_info, args, kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 627, in _infer_params
    jaxpr, consts, out_shardings_flat, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
                                                                         ^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1275, in _pjit_jaxpr
    jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr(
                                                   ^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 350, in memoized_fun
    ans = call(fun, *args)
          ^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1189, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2347, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2370, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/equinox/_jit.py", line 49, in fun_wrapped
    out = fun(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^
  File ".../project_3/venv/lib/python3.12/site-packages/diffrax/_integrate.py", line 781, in diffeqsolve
    raise ValueError(
ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with structure <class 'diffrax._term.AbstractTerm'>

Note the breakpoint I set and the type it reports for term:

% python osc.py
> /.../osc.py(69)<module>()
-> saveAt = dfx.SaveAt(ts=jnp.linspace(t_start, t_end, num_points))
(Pdb) type(term)
<class 'diffrax._term.ODETerm'>

The type is what is expected. So what exactly am I doing wrong?

Code:

import jax
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import matplotlib.pyplot as plt
from jax import random


# Define the neural network for the external force using Equinox
class ForceMLP(eqx.Module):
    input: eqx.nn.Linear
    dense1: eqx.nn.Linear
    dense2: eqx.nn.Linear
    output: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.input = eqx.nn.Linear(1, 256, key=key1)
        self.dense1 = eqx.nn.Linear(256, 256, key=key2)
        self.dense2 = eqx.nn.Linear(256, 256, key=key3)
        self.output = eqx.nn.Linear(256, 1, key=key4)

    def __call__(self, t):
        x = self.input(t)
        x = jax.nn.tanh(x)
        x = self.dense1(x)
        x = jax.nn.relu(x)
        x = self.dense2(x)
        x = jax.nn.relu(x)
        F = self.output(x)
        return F


# Initialize the neural network
key = random.PRNGKey(0)
force_mlp = ForceMLP(key)


def get_force(t):
    return force_mlp(t)


# Define the equations for the ODE
def equations(t, y, args):
    position, velocity = y
    force = get_force(t)

    # Damped harmonic oscillator equations
    damping = 0.1
    spring_constant = 1.0

    dposition_dt = velocity
    dvelocity_dt = -damping * velocity - spring_constant * position + force

    return jnp.array([dposition_dt, dvelocity_dt])


# Initial conditions and time span
y0 = jnp.array([1.0, 0.0])  # Initial position and velocity
t_start = 0.0
t_end = 10.0
num_points = 100

# ODE solver using diffrax
solver = dfx.Tsit5()  # Tsitouras 5th order method
stepsize_controller = dfx.PIDController(rtol=1e-6, atol=1e-6)
term = dfx.ODETerm(equations)
breakpoint()
saveAt = dfx.SaveAt(ts=jnp.linspace(t_start, t_end, num_points))

# Solve the ODE
solution = dfx.diffeqsolve(
    term,
    solver,
    t0=t_start,
    t1=t_end,
    dt0=0.1,
    y0=y0,
    saveat=saveAt,
    stepsize_controller=stepsize_controller,
)

# Plot the solution
ts = solution.ts
ys = solution.ys

plt.plot(ts, ys[:, 0], label="Position")
plt.plot(ts, ys[:, 1], label="Velocity")
plt.xlabel("Time")
plt.ylabel("Values")
plt.legend()
plt.title("Damped Harmonic Oscillator with Neural Network Force")
plt.show()

@lockwo
Contributor

lockwo commented Jun 24, 2024

Ahh the classic

ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with structure <class 'diffrax._term.AbstractTerm'>

I do think there could be a more informative error message here, because in my experience, 9 times out of 10 this is caused by a shape mismatch in the term's input or output. A lot of the time, this can be revealed by just manually evaluating the drift function. In this case, if we run print(equations(t_start, y0, None)), it fails with a shape error: ValueError: matmul input operand 1 must have ndim at least 1, but it has ndim 0. Basically, t is a scalar, but the matmul expects something with at least 1 dimension. The fix is to add a dimension to t and squeeze the network output back to a scalar.
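The shape problem can also be reproduced in isolation. The following is a minimal sketch in plain JAX, not the code from this issue: `linear`, `weight`, and `bias` are hypothetical stand-ins for the first layer, assuming it computes `weight @ x + bias` with a weight of shape (256, 1).

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a Linear(1, 256) layer.
# Assumption: the layer computes `weight @ x + bias`.
w_key, b_key = jax.random.split(jax.random.PRNGKey(0))
weight = jax.random.normal(w_key, (256, 1))
bias = jax.random.normal(b_key, (256,))


def linear(x):
    return weight @ x + bias


t_scalar = jnp.asarray(0.0)          # shape (), like the t the drift function receives
t_vector = jnp.atleast_1d(t_scalar)  # shape (1,), which the matmul can consume

# The scalar input is rejected by the matmul.
try:
    linear(t_scalar)
    scalar_error = None
except (TypeError, ValueError) as e:
    scalar_error = e

# The one-element vector goes through and yields a (256,) activation.
hidden = linear(t_vector)

print("scalar input error:", scalar_error)
print("vector input output shape:", hidden.shape)
```

`jnp.atleast_1d(t)` plays the same role here as `jnp.array([t])` in the fix.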

Here is the full fix:

import jax
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import matplotlib.pyplot as plt
from jax import random


# Define the neural network for the external force using Equinox
class ForceMLP(eqx.Module):
    input: eqx.nn.Linear
    dense1: eqx.nn.Linear
    dense2: eqx.nn.Linear
    output: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.input = eqx.nn.Linear(1, 256, key=key1)
        self.dense1 = eqx.nn.Linear(256, 256, key=key2)
        self.dense2 = eqx.nn.Linear(256, 256, key=key3)
        self.output = eqx.nn.Linear(256, 1, key=key4)

    def __call__(self, t):
        x = self.input(t)
        x = jax.nn.tanh(x)
        x = self.dense1(x)
        x = jax.nn.relu(x)
        x = self.dense2(x)
        x = jax.nn.relu(x)
        F = self.output(x)
        return F


# Initialize the neural network
key = random.PRNGKey(0)
force_mlp = ForceMLP(key)


def get_force(t):
    return force_mlp(t)


# Define the equations for the ODE
def equations(t, y, args):
    position, velocity = y
    t = jnp.array([t])
    force = get_force(t).squeeze()

    # Damped harmonic oscillator equations
    damping = 0.1
    spring_constant = 1.0

    dposition_dt = velocity
    dvelocity_dt = -damping * velocity - spring_constant * position + force

    return jnp.array([dposition_dt, dvelocity_dt])


# Initial conditions and time span
y0 = jnp.array([1.0, 0.0])  # Initial position and velocity
t_start = 0.0
t_end = 10.0
num_points = 100

print(equations(t_start, y0, None))

# ODE solver using diffrax
solver = dfx.Tsit5()  # Tsitouras 5th order method
stepsize_controller = dfx.PIDController(rtol=1e-6, atol=1e-6)
term = dfx.ODETerm(equations)
saveAt = dfx.SaveAt(ts=jnp.linspace(t_start, t_end, num_points))

# Solve the ODE
solution = dfx.diffeqsolve(
    term,
    solver,
    t0=t_start,
    t1=t_end,
    dt0=0.1,
    y0=y0,
    saveat=saveAt,
    stepsize_controller=stepsize_controller,
)

# Plot the solution
ts = solution.ts
ys = solution.ys

plt.plot(ts, ys[:, 0], label="Position")
plt.plot(ts, ys[:, 1], label="Velocity")
plt.xlabel("Time")
plt.ylabel("Values")
plt.legend()
plt.title("Damped Harmonic Oscillator with Neural Network Force")
plt.show()

[Screenshot attached: 2024-06-24, 11:46 AM]

@patrick-kidger
Owner

patrick-kidger commented Jun 24, 2024

Thank you @lockwo for the help!

By the way, I'd be very happy to take a pull request improving this error message, describing whatever you think most needs adding!

(It's new to me that this is apparently 'classic' -- probably that's because I'm used to navigating this library in a rather different way... :D )

patrick-kidger added the "question" (User queries) label on Jun 24, 2024
@pascal-mueller
Author

Thanks for the help. I am very new to JAX and Diffrax, so I had some trouble debugging this: I usually step through my programs with a debugger, and I'm not used to the functional and JIT nature of JAX yet.

I realized pretty quickly that there was a "hidden error" behind this error message, but I couldn't figure out what it was or how to get at it. That's mainly due to inexperience.

Thanks a lot
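
For readers in the same situation, one way to get at the hidden error is to call the drift function directly, outside diffeqsolve. The following is a rough sketch in plain JAX, assuming a toy `vector_field` and `weight` (both hypothetical) that repeat the scalar-into-matmul mistake from this issue. The `jax.disable_jit()` context is optional here; it matters when the function calls jitted code that you want to step through with a debugger.

```python
import jax
import jax.numpy as jnp

weight = jnp.ones((4, 1))  # toy layer weight; it needs an input of shape (1,)


def vector_field(t, y, args):
    # Toy stand-in for `equations`: the bug is passing the scalar t to a matmul.
    force = jnp.sum(weight @ t)
    return jnp.stack([y[1], -y[0] + force])


t0 = jnp.asarray(0.0)
y0 = jnp.array([1.0, 0.0])

# Call the function eagerly so the original exception is raised directly,
# instead of being replaced by the generic message about `terms`.
try:
    with jax.disable_jit():
        vector_field(t0, y0, None)
    hidden_error = None
except (TypeError, ValueError) as e:
    hidden_error = e

print("hidden error:", hidden_error)
```

Once the direct call succeeds and returns something shaped like y0, the same function can be wrapped in ODETerm again.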

@lockwo
Contributor

lockwo commented Jun 25, 2024

> Thank you @lockwo for the help!
>
> By the way, I'd be very happy to take a pull request improving this error message, describing whatever you think most needs adding!
>
> (It's new to me that this is apparently 'classic' -- probably that's because I'm used to navigating this library in a rather different way... :D )

Maybe "classic" is a strong word haha, but I've been doing a lot of 1D systems this past week, so I saw it a lot whenever I messed up how I squeezed or unsqueezed things, and I got used to seeing the message. I will see about adding a clearer error message about shapes where possible.
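
As an illustration of what such a shape check could look like in user code, here is a sketch of a hypothetical helper, `check_vector_field`. It is not part of Diffrax and not how Diffrax validates terms. It assumes an ODE-style drift whose output should have the same PyTree structure and shapes as y0, and it uses `jax.eval_shape` so nothing is actually computed.

```python
import jax
import jax.numpy as jnp


def check_vector_field(vector_field, t0, y0, args=None):
    """Hypothetical helper (not part of Diffrax): compare f(t0, y0, args) with y0."""
    try:
        # eval_shape traces the function abstractly, so no real computation runs.
        out = jax.eval_shape(vector_field, t0, y0, args)
    except Exception as e:
        raise ValueError(f"vector field failed when traced at (t0, y0): {e}") from e
    out_struct = jax.tree_util.tree_structure(out)
    y0_struct = jax.tree_util.tree_structure(y0)
    if out_struct != y0_struct:
        raise ValueError(f"structure mismatch: output {out_struct}, y0 {y0_struct}")
    out_shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves(out)]
    y0_shapes = [jnp.shape(leaf) for leaf in jax.tree_util.tree_leaves(y0)]
    if out_shapes != y0_shapes:
        raise ValueError(f"shape mismatch: output {out_shapes}, y0 {y0_shapes}")
    return out_shapes


t0 = jnp.asarray(0.0)
y0 = jnp.array([1.0, 0.0])


def good(t, y, args):
    return jnp.stack([y[1], -y[0]])


def bad(t, y, args):
    # Forgets to squeeze: each entry keeps its axis, so the result is (2, 1).
    return jnp.stack([y[1:2], -y[0:1]])


good_shapes = check_vector_field(good, t0, y0)
try:
    check_vector_field(bad, t0, y0)
    bad_message = None
except ValueError as e:
    bad_message = str(e)

print(good_shapes)
print(bad_message)
```

Running a check like this before diffeqsolve reports the offending shapes by name, which is the kind of information the generic message about `terms` leaves out.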
