[TPU] XLA: Channel is used for multiple host instructions #628
Comments
This looks like a bug in XLA:TPU. I'd suggest filing it either on the main JAX repository or on the XLA repo. You'll probably need to find a MWE, though.
Equinox doesn't maintain any state, actually! It was an important part of Equinox's design that we not go around mutating things. The Module is a PyTree of arrays, and these are explicitly updated, by you, when you do things like gradient descent. Equinox actually uses callbacks sparingly -- these are the only functions which use them, and they're all fairly uncommon:
are you explicitly using any of these?
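(For concreteness, a minimal sketch of the "no hidden state" point -- a `Module` is just an immutable PyTree of arrays, and any parameter update is an explicit, functional step performed by the user. The model and hyperparameters below are arbitrary, just for illustration:)

```python
import equinox as eqx
import jax
import jax.numpy as jnp

# A Module is an immutable PyTree of arrays -- nothing is mutated behind the scenes.
model = eqx.nn.Linear(in_features=2, out_features=1, key=jax.random.PRNGKey(0))

def loss_fn(m, x, y):
    pred = jax.vmap(m)(x)
    return jnp.mean((pred - y) ** 2)

x = jnp.ones((8, 2))
y = jnp.zeros((8, 1))
grads = eqx.filter_grad(loss_fn)(model, x, y)

# The "state update" is done explicitly by the user, producing a new PyTree.
model = eqx.apply_updates(model, jax.tree_util.tree_map(lambda g: -1e-2 * g, grads))
```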
Nope. So it's probably some constraint on TPUs placed by XLA 😥 I guess I'll try and debug it, but so far I haven't had any luck.
@patrick-kidger Turns out, the problem was using
I don't know if you feel it's worth it to resolve the bug on TPUs 😅 If you have access to them, I can try to make a repro for them.
Ah! Indeed, I'd forgotten. I'd suggest reporting this as an XLA bug regardless, but it's probably not a bug I can resolve directly. As a possible workaround, you can try commenting out every
Thanks! I tried commenting out all the callbacks and rebuilt equinox like this. I also talked to James Bradbury on twitter, and he said it's a bit harder to work with equinox, as host callbacks are "messy" and closer to a "hack", so it's likely that it's clashing with some TPU-specific optimizations built directly into XLA. I guess it'd require quite a bit of surgery to fix this bug - so it might be some time before it's fixed 🙂

Until that's fixed, I suppose in the next release, maybe you could expose some flag for

Again, thanks for everything and providing such a lovely library ❤️ and have a good weekend!
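(As an illustration of the kind of flag being suggested -- a hypothetical sketch, not Equinox's actual API; the names `MYLIB_SKIP_RUNTIME_CHECKS` and `check_finite` are made up. The idea is that when the flag is set, no host callback is ever staged out, so XLA never sees it:)

```python
import os

import jax
import jax.numpy as jnp

# Hypothetical environment flag -- not a real Equinox option.
_SKIP_RUNTIME_CHECKS = os.environ.get("MYLIB_SKIP_RUNTIME_CHECKS", "0") == "1"

def check_finite(x):
    """Warn (via a host callback) if `x` contains non-finite values."""
    if _SKIP_RUNTIME_CHECKS:
        return x  # skip the check entirely: no callback appears in the HLO

    def _warn(bad):
        if bad:
            print("warning: non-finite value encountered")

    bad = jnp.any(~jnp.isfinite(x))
    jax.debug.callback(_warn, bad)
    return x
```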
This is to fix a crash on TPUs, see #628.
Ah, marvellous! I'm glad that's working for you. I've just written #631 to fix this up for the next release. Since it is specifically
Looks like using #631 looks good - I guess in the future, if more XLA bugs crop up, we could set up a dedicated

Thanks for everything again and have a great weekend!
I'm training a custom architecture of mine, and had a use case where I wanted to perform 2 (different) forward passes which have different computational graphs. I wanted to take the outputs from both flows and evaluate an aggregated loss.
But apparently, if I compute two branches, I get the below error.
Traceback
Code
I don't have a repro, unfortunately. But my code flow looks like this:
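(The sketch below is a simplified stand-in rather than the real code; the `Model` class, the `prev_thought` handling, and the loss are all illustrative:)

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class Model(eqx.Module):
    linear: eqx.nn.Linear

    def __init__(self, key):
        self.linear = eqx.nn.Linear(4, 4, key=key)

    def __call__(self, x, prev_thought=None):
        # The computational graph differs depending on whether `prev_thought` is given.
        if prev_thought is None:
            return self.linear(x)
        return self.linear(x + prev_thought)

def loss_fn(model, x, prev_thought):
    # Two forward passes through the *same* model, with slightly different
    # arguments; both outputs feed one aggregated loss.
    out_a = model(x, prev_thought=prev_thought)
    out_b = model(x)
    return jnp.mean(out_a**2) + jnp.mean(out_b**2)

model = Model(jax.random.PRNGKey(0))
x = jnp.ones(4)
prev_thought = jnp.zeros(4)
grads = eqx.filter_grad(loss_fn)(model, x, prev_thought)
```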
This is a bit convoluted, but the core is that there are 2 different forward passes which explicitly depend on the same `model`, which is reused here with slightly different arguments (mainly `prev_thought`).

Because the error occurs only when both of the flows are present - either through using `jax.lax.cond` to dynamically switch between both, or simply aggregating outputs from both forward passes simultaneously - the common problem seems to be that XLA is unable to handle both computational flows. (Note: `jax.lax.cond` is lowered to `select` when `vmap`-ed, which is why both flows do end up getting computed too.)

This error is triggered only on TPUs, not on GPUs, so perhaps it might just turn out to be a limitation of `equinox`. I don't understand much of how `equinox` maintains state - my basic understanding is that the actual state at runtime is held by `jax` internally, and `equinox` just issues host callbacks to mutate that state as needed, where the `Module` is just the abstract `PyTree` representation? The error kind of sounds like `equinox` issued multiple host callbacks and they collide. Why it's only a problem on TPUs specifically could be down to TPU-specific optimizations in XLA.

Would you have any idea regarding this?
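(For reference, an illustrative, self-contained example of the cond-under-vmap point above: with a batched predicate, `jax.lax.cond` is turned into a `select`, so both branches end up being computed for every batch element.)

```python
import jax
import jax.numpy as jnp

def f(pred, x):
    return jax.lax.cond(pred, lambda v: v + 1.0, lambda v: v - 1.0, x)

# With a batched predicate, the cond becomes a select and both branches
# are evaluated for every element of the batch.
preds = jnp.array([True, False, True])
xs = jnp.arange(3.0)
print(jax.vmap(f)(preds, xs))  # [1. 0. 3.]
```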