Apple Silicon: error: failed to legalize operation 'mhlo.pad' #16366
@mlaves Is there more to the error? In particular, I'd expect more details about the operation to be printed.
@hawkinsp Sure, here's the full stacktrace.
I get the same error with almost identical specs, except Python 3.9 and an Apple Mac M2 Pro. However, it seems to be coming from within jax rather than jax-metal.
Thanks for sending the bug report. The JAX-Metal plugin does not support pad with non-zero interior_padding. We will look into expanding the coverage and will post an update here.
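To make the unsupported case concrete: `jax.lax.pad` takes one `(low, high, interior)` triple per dimension, and a non-zero `interior` value inserts padding values between neighbouring elements. A minimal sketch contrasting interior padding with plain edge padding (the array values are arbitrary examples):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])
zero = jnp.zeros((), x.dtype)  # padding value with the same dtype as x

# Interior padding: one zero between each pair of neighbours.
# Resulting values: [1, 0, 2, 0, 3]
y = lax.pad(x, zero, [(0, 0, 1)])

# Edge padding only (interior=0): one zero before, two zeros after.
# Resulting values: [0, 1, 2, 3, 0, 0]
z = lax.pad(x, zero, [(1, 2, 0)])
```

Per the comment above, it is the first form (non-zero `interior`) that the plugin rejects.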
I am running a pretrained model. Is there a way to change my inputs/tokenisation so that the interior_padding case is avoided and this issue is circumvented?
@BradBalderson It will be impossible to say without more details on how the operator is used in the model. If it is applied to one of the model inputs, perhaps.
BTW, you can implement interior padding with edge padding, if the interior padding is from your user code. For example, to pad the innermost dimension, you do this:
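One way to do this, sketched under the assumption that the goal is to insert `k` padding values between neighbouring elements along the last axis (the helper name `interior_pad_last_axis` is hypothetical, not part of JAX): give each element its own trailing slot, edge-pad that slot, flatten, and trim the excess padding at the end.

```python
import jax.numpy as jnp
from jax import lax


def interior_pad_last_axis(x, k, value=0):
    """Insert `k` copies of `value` between neighbouring elements along the
    last axis, using only edge (low/high) padding.

    Hypothetical helper for illustration; assumes k >= 0 and x.shape[-1] >= 1.
    """
    n = x.shape[-1]
    # (..., n) -> (..., n, 1): each element gets its own trailing slot.
    y = x[..., None]
    # Edge-pad that slot with k values on the high side: (..., n, 1 + k).
    pad_width = [(0, 0)] * x.ndim + [(0, k)]
    y = jnp.pad(y, pad_width, constant_values=value)
    # Flatten the last two axes: e0, pad*k, e1, pad*k, ..., e_{n-1}, pad*k.
    y = y.reshape(x.shape[:-1] + (n * (1 + k),))
    # Drop the trailing k values so padding sits only between elements.
    return y[..., : n * (1 + k) - k]


x = jnp.arange(1.0, 7.0).reshape(2, 3)
padded = interior_pad_last_axis(x, 2)  # shape (2, 7)
```

On a backend that supports interior padding, the result should match `lax.pad(x, 0.0, [(0, 0, 0), (0, 0, 2)])`; whether every op in this sketch lowers cleanly on jax-metal has not been checked here.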
But... it might just be better to wait for our colleagues from Apple to fix the plugin :-)
The caveat: some XLA ops are not compiled correctly. According to jax-ml/jax#16366, some XLA ops are not yet supported on Apple Silicon.
I encountered the same bug while trying to compute the gradient of a loss function for a physics-informed NN problem on a Mac M1. JAX version: 0.4.20
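The model in this comment isn't shown, so the following is only an assumption about one common way a gradient ends up needing interior padding: in JAX, the backward (transpose) rule of a strided `lax.slice` scatters the cotangent back with a `pad` whose interior padding is `stride - 1`. A minimal sketch with a hypothetical `loss` that keeps every second element:

```python
import jax
import jax.numpy as jnp
from jax import lax


def loss(x):
    # Keep every second element via a strided slice (stride 2).
    return jnp.sum(lax.slice(x, (0,), (6,), (2,)) ** 2)


x = jnp.arange(6.0)
g = jax.grad(loss)(x)
# Gradient is 2*x at the kept positions and 0 elsewhere:
# values [0, 0, 4, 0, 8, 0]

# The traced backward pass contains a pad with a non-zero interior entry,
# which is the case the plugin rejected before the fix.
print(jax.make_jaxpr(jax.grad(loss))(x))
```

Forward passes that only use unit strides would not hit this path, which may explain why the error appears only once `grad` is involved.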
Hello, I'm getting this error when running the following very simple operation:
import jax.numpy as jnp
jnp.cumprod(jnp.arange(10))
My env
It would be really cool if someone could fix this, as it makes
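A plausible explanation, assumed rather than confirmed here: on backends without a dedicated cumulative-reduction lowering, `jnp.cumprod` may be lowered through a parallel prefix scan like `lax.associative_scan`, which interleaves partial results with `lax.pad(..., interior=1)`. The sketch below shows that `pad` in the associative scan's jaxpr and, as a hypothetical workaround (`cumprod_scan` is not a JAX function, and whether `lax.scan` lowers cleanly on jax-metal is untested here), computes the same result with a sequential `lax.scan`:

```python
import jax
import jax.numpy as jnp
from jax import lax


def cumprod_scan(x):
    """Cumulative product along axis 0 with a sequential lax.scan.

    Hypothetical workaround: multiplies one element at a time, so it does
    not rely on the pad-based interleaving used by lax.associative_scan.
    """
    def step(carry, xi):
        carry = carry * xi
        return carry, carry

    _, out = lax.scan(step, jnp.ones(x.shape[1:], x.dtype), x)
    return out


x = jnp.arange(1, 11)

# The parallel prefix scan's jaxpr contains pad ops with interior padding.
print(jax.make_jaxpr(lambda v: lax.associative_scan(jnp.multiply, v))(x))

print(cumprod_scan(x))
```

The sequential scan does O(n) dependent steps instead of O(log n) parallel ones, so it trades speed for avoiding the unsupported op.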
The fix for pad with interior_padding will be in the upcoming jax-metal release and will work on macOS 14.4.
The fix is in jax-metal 0.0.6. Some output from running flax/examples/mnist:
Description
When following the MNIST example from flax (https://github.com/google/flax/tree/main/examples/mnist/), the following error occurs when using the latest jax-metal plugin installed as described at https://developer.apple.com/metal/jax/:
What jax/jaxlib version are you using?
jax 0.4.11, jaxlib v0.4.10
Which accelerator(s) are you using?
MPS
Additional system info
Python 3.11, macOS 13.4, Mac Mini M2 Pro
NVIDIA GPU info
No response