
[TPU] XLA: Channel is used for multiple host instructions #628

Closed
neel04 opened this issue Dec 23, 2023 · 7 comments
Labels
question User queries

Comments

@neel04

neel04 commented Dec 23, 2023

I'm training a custom architecture of mine, and have a use case where I want to perform two (different) forward passes which have different computational graphs. I want to take the outputs from both flows and evaluate an aggregated loss.

But apparently, if I compute both branches, I get the error below.

Traceback

Traceback (most recent call last):
  File "/kaggle/working/ReAct_Jax/train_model.py", line 47, in <module>
    main(key)
  File "/kaggle/working/ReAct_Jax/train_model.py", line 43, in main
    trainer.train(args.epochs, trainloader, valloader, key)
  File "/kaggle/working/ReAct_Jax/ReAct/utils/trainer.py", line 235, in train
    loss, model, opt_state = make_step(model, seq, label, pad_mask, rndm_n, rndm_k,
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/equinox/_jit.py", line 206, in __call__
    return self._call(False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/equinox/_module.py", line 875, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/equinox/_jit.py", line 198, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: during context [pre-optimization]: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:2330) instructions.size() == 2 channel 11 is used for multiple host send/recv instructions
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Code

I don't have a repro unfortunately. But my codeflow looks like this:

import equinox as eqx
import jax
from jaxtyping import Array, PRNGKeyArray

@jax.jit
def n_k_loop(model: eqx.Module, input_arr: Array, pad_mask: Array, n: int, k: int, key: PRNGKeyArray) -> Array:
    key1, key2 = jax.random.split(key, 2)

    # forward pass the model without tracking grads
    _, intermediate_array = model(
        input_arr, n,
        pad_mask=pad_mask,
        prev_thought=None,
        key=key1)

    intermediate_array = jax.lax.stop_gradient(intermediate_array)

    # n-k passes, but track the gradient this time
    output, _ = model(input_arr, k, pad_mask=pad_mask, prev_thought=intermediate_array, key=key2)

    return output

@jax.jit
def k_loop(model: eqx.Module, input_arr: Array, pad_mask: Array, n: int, k: int, key: PRNGKeyArray) -> Array:
    # only a single forward pass here, so no key split is needed
    output, _ = model(input_arr, n, pad_mask=pad_mask, prev_thought=None, key=key)

    return output

This is a bit convoluted, but the core idea is that there are two different forward passes which explicitly depend on the same model, reused here with slightly different arguments (mainly prev_thought).

Because the error occurs only when both flows are present (either by using jax.lax.cond to dynamically switch between them, or by simply aggregating the outputs from both forward passes), the common factor seems to be that XLA is unable to handle both computational flows at once.

(Note: jax.lax.cond is lowered to select when vmap-ed, which is why both flows do end up getting computed.)
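As an illustration of that note (a minimal sketch, unrelated to my actual model), batching a jax.lax.cond over a vector of predicates means XLA traces and evaluates both branches for every element and then selects between the results:

```python
import jax
import jax.numpy as jnp

def f(pred, x):
    # Under jax.vmap, lax.cond is lowered to lax.select: BOTH branches
    # are traced and executed for every batch element, then selected.
    return jax.lax.cond(pred, lambda v: v * 2.0, lambda v: v + 1.0, x)

batched = jax.vmap(f)
out = batched(jnp.array([True, False, True]), jnp.array([1.0, 2.0, 3.0]))
# out -> [2.0, 3.0, 6.0]; the untaken branch was still computed per element.
```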

  • This error is triggered only on TPUs, not on GPUs, so perhaps it might just turn out to be a limitation of Equinox. I don't understand much of how Equinox maintains state - my basic understanding is that the actual state at runtime is held by JAX internally, and Equinox just issues host callbacks to mutate that state as needed, where the Module is just the abstract PyTree representation?

  • The error kind of sounds like Equinox issued multiple host callbacks and they collide. Why it's only a problem on TPUs specifically could come down to TPU-specific optimizations in XLA.
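For context on what "host callback" means here, a minimal sketch (hypothetical doubling function, not from my codebase): the device computation pauses, the operands are shipped to the Python host over a send/recv channel, and the result is shipped back. The channel numbers in the error message come from exactly this mechanism.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Runs on the host, outside of XLA; receives concrete NumPy values.
def host_fn(x):
    return np.asarray(x) * 2.0

@jax.jit
def f(x):
    # XLA inserts host send/recv instructions on a dedicated channel
    # to move x out to host_fn and bring the result back.
    return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

out = f(jnp.arange(3.0))  # -> [0., 2., 4.]
```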

Would you have any idea regarding this?

@patrick-kidger
Owner

This looks like a bug in XLA:TPU. I'd suggest filing it either on the main JAX repository or on the XLA repo. You'll probably need to find a MWE, though.

I don't understand much of how equinox maintains state - my basic understanding is that the actual state at runtime is held by jax internally and equinox just issues host callbacks to mutate that state as needed - where the Module is just the abstract PyTree representation?

Equinox doesn't maintain any state, actually! It was an important part of Equinox's design that we not go around mutating things. The Module is a PyTree of arrays, and these are explicitly updated, by you, when you do things like gradient descent.

Equinox actually uses callbacks sparingly -- these are the only functions which use them, and they're all fairly uncommon:

  • eqx.error_if, eqx.branched_error_if,
  • eqx.filter_pure_callback,
  • eqx.debug.store_dce,
  • eqx.internal.noinline (and this one isn't even a documented feature :) )

Are you explicitly using any of these?

@patrick-kidger patrick-kidger added the question User queries label Dec 23, 2023
@neel04
Author

neel04 commented Dec 25, 2023

are you explicitly using any of these?

Nope. So it's probably some constraint on TPUs placed by XLA 😥 I guess I'll try to debug it, but so far I haven't had any luck.

@neel04
Author

neel04 commented Dec 26, 2023

@patrick-kidger Turns out the problem was using eqx.internal.while_loop, specifically the checkpointed form. Using the bounded kind of loop works perfectly fine.

I don't know if you feel it's worth resolving the bug on TPUs 😅 If you have access to them, I can try to make a repro.

@patrick-kidger
Owner

Ah! Indeed, I'd forgotten: eqx.internal.while_loop uses eqx.error_if, which in turn uses a callback.

I'd suggest reporting this as an XLA bug regardless, but it's probably not a bug I can resolve directly.

As a possible workaround, you can try commenting out every error_if inside the implementation of eqx.internal.while_loop. I imagine I'll do something like that in the next release of Equinox anyway -- this code has now proven itself pretty reliable.

@neel04
Author

neel04 commented Dec 29, 2023

Thanks! I tried commenting out all the callbacks and rebuilding equinox like this, and kind='checkpointed' works pretty well now.

I talked to James Bradbury on Twitter, and he said it's a bit harder to work with Equinox since host callbacks are "messy" and closer to a "hack", so it's likely clashing with some TPU-specific optimizations built directly into XLA. I guess it'd require quite a bit of surgery to fix this bug, so it might be some time before it's fixed 🙂

Until that's fixed, I suppose in the next release you could expose a flag for internal.while_loop that optionally disables host callbacks if the user desires (at least the non-critical ones). It might go against the Equinox philosophy of simplicity, but IMO that's fine because it's already an internal API, so users are more likely to be careful when using it and (hopefully) read all the documentation.

Again, thanks for everything and providing such a lovely library ❤️ and have a good weekend!

patrick-kidger added a commit that referenced this issue Dec 30, 2023
This is to fix a crash on TPUs, see #628.
@patrick-kidger
Owner

patrick-kidger commented Dec 30, 2023

Ah, marvellous! I'm glad that's working for you. I've just written #631 to fix this up for the next release.

Since it is specifically error_if that is causing you problems, then one other option worth knowing is the environment variable EQX_ON_ERROR=nan. (Documentation here.) This will disable every eqx.error_if used anywhere in your program. This is really intended as a "debug vs release mode" optimisation -- to remove the checks once you're satisfied that your program should always work. But it could also help with this issue.
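For anyone landing here later, using that environment variable is just a matter of setting it before the program starts (the script name below is from the traceback in this issue; substitute your own entry point):

```shell
# Globally disable every eqx.error_if in the program; failed checks
# propagate NaN instead of raising. This also sidesteps the host
# send/recv channel collision on TPU described in this issue.
EQX_ON_ERROR=nan python train_model.py
```

Note this is a global "release mode" switch, so runtime error checking is lost everywhere, not just inside eqx.internal.while_loop.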

@neel04
Author

neel04 commented Dec 30, 2023

Looks like using EQX_ON_ERROR=nan works pretty well here and resolves the issue 😄

#631 looks good - I guess in the future, if more XLA bugs crop up, we could set up a dedicated TPU_DEBUGGING flag that would switch the various flags on and off to minimize collisions with XLA, since it's a bit of a black box that we usually need to work around...

Thanks for everything again and have a great weekend!
