Apple Silicon: error: failed to legalize operation 'mhlo.pad' #16366

Closed
mlaves opened this issue Jun 12, 2023 · 13 comments

@mlaves

mlaves commented Jun 12, 2023

Description

When following the MNIST example from flax (https://github.com/google/flax/tree/main/examples/mnist/), the following error occurs when using the latest jax-metal plugin installed as described at https://developer.apple.com/metal/jax/ :

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: error: failed to legalize operation 'mhlo.pad'
  y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)

What jax/jaxlib version are you using?

jax 0.4.11, jaxlib v0.4.10

Which accelerator(s) are you using?

MPS

Additional system info

Python 3.11, macOS 13.4, Mac Mini M2 Pro

NVIDIA GPU info

No response

@mlaves mlaves added the bug Something isn't working label Jun 12, 2023
@hawkinsp
Collaborator

@kulinseth @shuhand0

@hawkinsp
Collaborator

@mlaves is there more to the error? In particular, I think more details about the operation should be printed?

@mlaves
Author

mlaves commented Jun 13, 2023

@hawkinsp Sure, here's the full stacktrace.

Traceback (most recent call last):
  File "/Users/max/flax/examples/mnist/main.py", line 65, in <module>
    app.run(main)
  File "/Users/max/jax-metal/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/Users/max/jax-metal/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/Users/max/flax/examples/mnist/main.py", line 60, in main
    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
  File "/Users/max/flax/examples/mnist/train.py", line 140, in train_and_evaluate
    state, train_loss, train_accuracy = train_epoch(state, train_ds,
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/flax/examples/mnist/train.py", line 89, in train_epoch
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 249, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
                                          ^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 160, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 2647, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1193, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1177, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1113, in _pjit_call_impl_python
    always_lower=False, lowering_platform=None).compile()
                                                ^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2319, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2638, in from_hlo
    xla_executable, compile_options = _cached_compilation(
                                      ^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: error: failed to legalize operation 'mhlo.pad'
  y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
     ^
/Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: note: see current operation: %207 = "mhlo.pad"(%206, %19) {edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>, edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>, interior_padding = dense<[0, 1, 1, 0]> : tensor<4xi64>} : (tensor<128x7x7x64xf32>, tensor<f32>) -> tensor<128x15x15x64xf32>

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/max/flax/examples/mnist/main.py", line 65, in <module>
    app.run(main)
  File "/Users/max/jax-metal/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/Users/max/jax-metal/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/Users/max/flax/examples/mnist/main.py", line 60, in main
    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
  File "/Users/max/flax/examples/mnist/train.py", line 140, in train_and_evaluate
    state, train_loss, train_accuracy = train_epoch(state, train_ds,
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/flax/examples/mnist/train.py", line 89, in train_epoch
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: error: failed to legalize operation 'mhlo.pad'
  y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
     ^
/Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: note: see current operation: %207 = "mhlo.pad"(%206, %19) {edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>, edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>, interior_padding = dense<[0, 1, 1, 0]> : tensor<4xi64>} : (tensor<128x7x7x64xf32>, tensor<f32>) -> tensor<128x15x15x64xf32>
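For reference, the shape change in the note above follows from how pad sizes each dimension: low + high + n + (n - 1) * interior. With low = high = interior = 1 on the two 7-wide spatial axes, that gives 1 + 1 + 7 + 6 = 15. A minimal sketch that checks this arithmetic with jax.lax.pad under jax.eval_shape (abstract evaluation only, so nothing is compiled for any device); padded_dim is a hypothetical helper name:

```python
import jax
import jax.numpy as jnp
from jax import lax


def padded_dim(n, low, high, interior):
    """Size of one dimension after lax.pad / mhlo.pad with (low, high, interior)."""
    return low + high + n + max(n - 1, 0) * interior


# Per-dimension (low, high, interior), copied from the note:
# edge_padding_low = edge_padding_high = interior_padding = [0, 1, 1, 0].
config = [(0, 0, 0), (1, 1, 1), (1, 1, 1), (0, 0, 0)]
in_shape = (128, 7, 7, 64)

expected = tuple(padded_dim(n, *c) for n, c in zip(in_shape, config))

# eval_shape only traces lax.pad abstractly and returns a ShapeDtypeStruct.
out = jax.eval_shape(
    lambda x: lax.pad(x, jnp.zeros((), x.dtype), config),
    jax.ShapeDtypeStruct(in_shape, jnp.float32),
)
print(expected, out.shape)  # (128, 15, 15, 64) (128, 15, 15, 64)
```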

@BradBalderson

BradBalderson commented Jun 14, 2023

I get the same error with almost identical specs, except Python 3.9 on an Apple Mac M2 Pro. However, it seems to be coming from within jax rather than jax-metal.

File ~/mambaforge/envs/mrff/lib/python3.9/site-packages/jax/_src/dispatch.py:465, in backend_compile(backend, module, options, host_callbacks)
    460   return backend.compile(built_c, compile_options=options,
    461                          host_callbacks=host_callbacks)
    462 # Some backends don't have `host_callbacks` option yet
    463 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    464 # to take in `host_callbacks`
--> 465 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: UNKNOWN: io.py:164:0: error: failed to legalize operation 'mhlo.pad'

@shuhand0
Collaborator

Thanks for sending the bug report. The JAX-Metal plugin does not support pad with non-zero interior_padding. We will look into expanding the coverage and post an update here.
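To make the distinction concrete: each dimension of jax.lax.pad takes a (low, high, interior) triple, and interior inserts that many pad values between neighbouring elements. A small sketch on a 3-element vector, meant to run on a backend such as CPU; going by the comment above, the second call is the kind the plugin could not compile at the time:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3], dtype=jnp.int32)
zero = jnp.zeros((), x.dtype)  # pad value must be a scalar of the operand's dtype

# Edge padding only: one pad value before and one after, interior = 0.
edge_only = lax.pad(x, zero, [(1, 1, 0)])
# Interior padding: one pad value between each pair of neighbours.
with_interior = lax.pad(x, zero, [(0, 0, 1)])

print(edge_only.tolist())      # [0, 1, 2, 3, 0]
print(with_interior.tolist())  # [1, 0, 2, 0, 3]
```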

@BradBalderson

I am running a pretrained model. I wonder if there is a way to change my inputs/tokenisation to try and add interior_padding to circumvent this issue?

@hawkinsp
Copy link
Collaborator

@BradBalderson It will be impossible to say without more details on how the operator is used in the model. If it is applied to one of the model inputs, perhaps.

@BradBalderson

Ah OK, I will see if it is due to the inputs. Thanks for the fast reply and feedback @hawkinsp and @shuhand0

@hawkinsp
Collaborator

BTW, you can implement interior padding with edge padding, if the interior padding is from your user code.

For example, to pad the innermost dimension, you do this:

  • Reshape the input to add a new inner dimension of size 1.
  • Pad that new dimension at its high end to size 1 + interior_padding.
  • Reshape back to the original number of dimensions, flattening the new inner dimension into the original inner dimension.
  • Truncate interior_padding elements off the end.
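A minimal sketch of that recipe for the last axis, written with jax.numpy. The function name interior_pad_last_axis is hypothetical, and whether it avoids the Metal error assumes the interior padding comes from your own code rather than from an op JAX generates internally (for example inside a gradient). On a backend that does support interior padding, such as CPU, the result can be compared against jax.lax.pad:

```python
import jax.numpy as jnp
from jax import lax


def interior_pad_last_axis(x, interior, pad_value=0):
    """Insert `interior` pad values between neighbouring elements of the last axis,
    using only a reshape, high-side edge padding and a slice (no interior padding)."""
    if interior == 0:
        return x
    n = x.shape[-1]
    # 1. Add a new inner dimension of size 1: (..., n) -> (..., n, 1).
    y = x[..., None]
    # 2. Edge-pad that new dimension at its high end to size 1 + interior.
    pad_width = [(0, 0)] * x.ndim + [(0, interior)]
    y = jnp.pad(y, pad_width, constant_values=pad_value)
    # 3. Flatten the new dimension into the last axis: (..., n * (1 + interior)).
    y = y.reshape(x.shape[:-1] + (n * (1 + interior),))
    # 4. Drop the `interior` trailing pad values after the last element.
    return y[..., : n * (1 + interior) - interior]


x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
out = interior_pad_last_axis(x, 2)
ref = lax.pad(x, jnp.zeros((), x.dtype), [(0, 0, 0), (0, 0, 2)])
print(out.shape, bool((out == ref).all()))  # (2, 7) True
```

For another axis, one option is to jnp.moveaxis it to the end first and move it back afterwards.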

But... it might just be better to wait for our colleagues from Apple to fix the plugin :-)

@hawkinsp hawkinsp added enhancement New feature or request and removed bug Something isn't working labels Jun 15, 2023
agalashov added a commit to tedmoskovitz/directorv3 that referenced this issue Jul 11, 2023
The caveat: some XLA ops are not compiled correctly. According to
jax-ml/jax#16366, some of the XLA ops are not
yet supported on Apple Silicon.
@Mixpap

Mixpap commented Dec 27, 2023

I encountered the same bug while trying to compute the gradient of a loss function for a physics-informed NN problem on an M1 Mac.

Jax version: 0.4.20
Below is the stack trace. I am trying to construct a minimal reproducible example and will post it in a new comment, but the problem is very complex and it is difficult to recreate a simpler version of it.

{
	"name": "XlaRuntimeError",
	"message": "UNKNOWN: /var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: error: failed to legalize operation 'mhlo.pad'
/var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: note: see current operation: %99 = \"mhlo.pad\"(%98, %1) {edge_padding_high = dense<0> : tensor<2xi64>, edge_padding_low = dense<[0, -1]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<1024x2xf32>, tensor<f32>) -> tensor<1024x1xf32>
",
	"stack": "---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()

File <frozen runpy>:88, in _run_code()

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel_launcher.py:17
     15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/traitlets/config/application.py:1077, in launch_instance()
   1076 app.initialize(argv)
-> 1077 app.start()

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelapp.py:701, in start()
    700 try:
--> 701     self.io_loop.start()
    702 except KeyboardInterrupt:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/tornado/platform/asyncio.py:195, in start()
    194 def start(self) -> None:
--> 195     self.asyncio_loop.run_forever()

File ~/miniconda3/envs/metal/lib/python3.11/asyncio/base_events.py:607, in run_forever()
    606 while True:
--> 607     self._run_once()
    608     if self._stopping:

File ~/miniconda3/envs/metal/lib/python3.11/asyncio/base_events.py:1922, in _run_once()
   1921     else:
-> 1922         handle._run()
   1923 handle = None

File ~/miniconda3/envs/metal/lib/python3.11/asyncio/events.py:80, in _run()
     79 try:
---> 80     self._context.run(self._callback, *self._args)
     81 except (SystemExit, KeyboardInterrupt):

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelbase.py:534, in dispatch_queue()
    533 try:
--> 534     await self.process_one()
    535 except Exception:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelbase.py:523, in process_one()
    522         return
--> 523 await dispatch(*args)

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelbase.py:429, in dispatch_shell()
    428     if inspect.isawaitable(result):
--> 429         await result
    430 except Exception:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelbase.py:767, in execute_request()
    766 if inspect.isawaitable(reply_content):
--> 767     reply_content = await reply_content
    769 # Flush output before sending the reply.

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/ipkernel.py:429, in do_execute()
    428 if accepts_params[\"cell_id\"]:
--> 429     res = shell.run_cell(
    430         code,
    431         store_history=store_history,
    432         silent=silent,
    433         cell_id=cell_id,
    434     )
    435 else:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/zmqshell.py:549, in run_cell()
    548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3051, in run_cell()
   3050 try:
-> 3051     result = self._run_cell(
   3052         raw_cell, store_history, silent, shell_futures, cell_id
   3053     )
   3054 finally:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3106, in _run_cell()
   3105 try:
-> 3106     result = runner(coro)
   3107 except BaseException as e:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner()
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3311, in run_cell_async()
   3308 interactivity = \"none\" if silent else self.ast_node_interactivity
-> 3311 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3312        interactivity=interactivity, compiler=compiler, result=result)
   3314 self.last_execution_succeeded = not has_raised

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3493, in run_ast_nodes()
   3492     asy = compare(code)
-> 3493 if await self.run_code(code, result, async_=asy):
   3494     return True

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3553, in run_code()
   3552     else:
-> 3553         exec(code_obj, self.user_global_ns, self.user_ns)
   3554 finally:
   3555     # Reset our crash handler in place

Cell In[288], line 1
----> 1 jax.grad(los1)(params)

Cell In[286], line 1, in los1()
----> 1 def los1(params): return jnp.mean(los_physics1(params,tt,xx)**2)

Cell In[243], line 15, in los_physics1()
     13 rho_V_x_x = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),1)(t,x)
---> 15 return rho_t(tc,xc)+rho_V_x_x(tc,xc)

Cell In[243], line 13, in los_physics1.<locals>.<lambda>()
     12 #rho_V_x_t = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),0)(t,x)
---> 13 rho_V_x_x = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),1)(t,x)
     15 return rho_t(tc,xc)+rho_V_x_x(tc,xc)

Cell In[243], line 13, in los_physics1.<locals>.<lambda>()
     12 #rho_V_x_t = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),0)(t,x)
---> 13 rho_V_x_x = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),1)(t,x)
     15 return rho_t(tc,xc)+rho_V_x_x(tc,xc)

Cell In[243], line 10, in los_physics1.<locals>.<lambda>()
      9 rho_t= lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)),0)(t,x)
---> 10 Vx_x= lambda t,x: jax.grad(lambda t,x: jnp.sum(Vx(t,x)),1)(t,x)
     12 #rho_V_x_t = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),0)(t,x)

Cell In[243], line 10, in los_physics1.<locals>.<lambda>()
      9 rho_t= lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)),0)(t,x)
---> 10 Vx_x= lambda t,x: jax.grad(lambda t,x: jnp.sum(Vx(t,x)),1)(t,x)
     12 #rho_V_x_t = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),0)(t,x)

Cell In[243], line 6, in los_physics1.<locals>.<lambda>()
      5 rho=lambda t,x: NN(params,t,x)[0]
----> 6 Vx=lambda t,x: NN(params,t,x)[1]
      7 #p=lambda t,x: NN(params,t,x)[2]

JaxStackTraceBeforeTransformation: jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: error: failed to legalize operation 'mhlo.pad'
/var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: note: see current operation: %99 = \"mhlo.pad\"(%98, %1) {edge_padding_high = dense<0> : tensor<2xi64>, edge_padding_low = dense<[0, -1]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<1024x2xf32>, tensor<f32>) -> tensor<1024x1xf32>

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

XlaRuntimeError                           Traceback (most recent call last)
Cell In[288], line 1
----> 1 jax.grad(los1)(params)

    [... skipping hidden 21 frame]

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/jax/_src/compiler.py:255, in backend_compile(backend, module, options, host_callbacks)
    250   return backend.compile(built_c, compile_options=options,
    251                          host_callbacks=host_callbacks)
    252 # Some backends don't have `host_callbacks` option yet
    253 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    254 # to take in `host_callbacks`
--> 255 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: UNKNOWN: /var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: error: failed to legalize operation 'mhlo.pad'
/var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: note: see current operation: %99 = \"mhlo.pad\"(%98, %1) {edge_padding_high = dense<0> : tensor<2xi64>, edge_padding_low = dense<[0, -1]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<1024x2xf32>, tensor<f32>) -> tensor<1024x1xf32>
"
}
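A note on the operation in this trace: unlike the other reports here, its interior_padding is 0, while edge_padding_low = [0, -1] is negative. A negative edge pad removes elements, so this op turns the 1024x2 operand into 1024x1 by dropping the first column, which is equivalent to a slice. A small sketch of that equivalence with jax.lax.pad on a 4x2 stand-in array, meant to run on a backend such as CPU. Whether rewriting user code as an explicit slice avoids the op depends on where JAX generates it, which the trace does not show:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(8, dtype=jnp.float32).reshape(4, 2)  # small stand-in for the 1024x2 operand
zero = jnp.zeros((), x.dtype)

# (low, high, interior) per dimension; low = -1 on axis 1 drops the first column.
via_pad = lax.pad(x, zero, [(0, 0, 0), (-1, 0, 0)])
# The same result expressed as an ordinary slice.
via_slice = x[:, 1:]

print(via_pad.shape, bool((via_pad == via_slice).all()))  # (4, 1) True
```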

@mar-muel

Hello, I'm getting this error when running the following very simple operation:

import jax.numpy as jnp
jnp.cumprod(jnp.arange(10))
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/martin/miniconda3/envs/jax-metal/lib/python3.10/site-packages/jax/_src/numpy/reductions.py", line 657, in cumulative_reduction
    return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'mhlo.pad'
<stdin>:1:0: note: called from
<stdin>:1:0: note: see current operation: %43 = "mhlo.pad"(%42, %1) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<1> : tensor<1xi64>} : (tensor<1xsi32>, tensor<si32>) -> tensor<2xsi32>

My env

>>> jax.print_environment_info()
jax:    0.4.20
jaxlib: 0.4.20
numpy:  1.26.4
python: 3.10.13 (main, Sep 11 2023, 08:16:02) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1

It would be really cool if someone could fix this, as it makes jax-metal pretty much unusable for me :(
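The note in the jnp.cumprod trace shows the same unsupported pattern: mhlo.pad with interior_padding = 1. On non-TPU backends jnp.cumprod appears to be lowered through an associative scan whose interleave step uses interior padding, which would explain why such a small call fails. A possible stopgap sketch, assuming the plugin can compile the loop that jax.lax.scan lowers to (not verified on Metal here), is a sequential cumulative product. cumprod_scan is a hypothetical name, and it only scans over the leading axis:

```python
import jax
import jax.numpy as jnp


def cumprod_scan(x):
    """Cumulative product over the leading axis using lax.scan instead of jnp.cumprod."""
    def step(carry, v):
        carry = carry * v      # running product so far
        return carry, carry    # (new carry, value emitted for this position)

    init = jnp.ones(x.shape[1:], x.dtype)  # multiplicative identity, same shape as one row
    _, out = jax.lax.scan(step, init, x)
    return out


print(cumprod_scan(jnp.arange(1, 6)).tolist())  # [1, 2, 6, 24, 120]
```

The trade-off is that a scan runs sequentially, so it may be slower than the parallel scan for long inputs.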

@shuhand0
Collaborator

The fix for pad with interior_padding will be in the upcoming jax-metal release and will work on macOS 14.4.

@shuhand0
Collaborator

The fix is in jax-metal 0.0.6. Some output from running flax/examples/mnist:

python main.py --workdir=/tmp/mnist --config=configs/default.py \
--config.learning_rate=0.05 --config.num_epochs=5
I0312 14:30:01.288698 7932478208 xla_bridge.py:660] Unable to initialize backend 'cuda': 
I0312 14:30:01.288793 7932478208 xla_bridge.py:660] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0312 14:30:01.290362 7932478208 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/shuhan/miniconda3/envs/test/bin/../lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
W0312 14:30:01.290431 7932478208 xla_bridge.py:758] Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-12 14:30:01.290494: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

I0312 14:30:01.301946 7932478208 main.py:50] JAX process: 0 / 1
I0312 14:30:01.302008 7932478208 main.py:51] JAX local devices: [METAL(id=0)]
I0312 14:30:01.302042 7932478208 local.py:45] Setting task status: process_index: 0, process_count: 1
I0312 14:30:01.302149 7932478208 local.py:50] Created artifact workdir of type ArtifactType.DIRECTORY and value /tmp/mnist.
I0312 14:30:01.302522 7932478208 dataset_info.py:358] Load dataset info from /Users/shuhan/tensorflow_datasets/mnist/3.0.1
I0312 14:30:01.303414 7932478208 dataset_info.py:411] Field info.citation from disk and from code do not match. Keeping the one from code.
I0312 14:30:01.303492 7932478208 dataset_info.py:411] Field info.splits from disk and from code do not match. Keeping the one from code.
I0312 14:30:01.303527 7932478208 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I0312 14:30:01.303578 7932478208 dataset_builder.py:351] Reusing dataset mnist (/Users/shuhan/tensorflow_datasets/mnist/3.0.1)
I0312 14:30:01.303609 7932478208 logging_logger.py:35] Constructing tf.data.Dataset mnist for split train, from /Users/shuhan/tensorflow_datasets/mnist/3.0.1
WARNING:tensorflow:From /Users/shuhan/miniconda3/envs/test/lib/python3.9/site-packages/tensorflow_datasets/core/dataset_builder.py:622: get_single_element (from tensorflow.python.data.experimental.ops.get_single_element) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.
W0312 14:30:01.332822 7932478208 deprecation.py:50] From /Users/shuhan/miniconda3/envs/test/lib/python3.9/site-packages/tensorflow_datasets/core/dataset_builder.py:622: get_single_element (from tensorflow.python.data.experimental.ops.get_single_element) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.
I0312 14:30:01.966295 7932478208 logging_logger.py:35] Constructing tf.data.Dataset mnist for split test, from /Users/shuhan/tensorflow_datasets/mnist/3.0.1
I0312 14:30:07.488775 7932478208 train.py:146] epoch:  1, train_loss: 0.2369, train_accuracy: 93.05, test_loss: 0.0590, test_accuracy: 98.07
I0312 14:30:10.687863 7932478208 train.py:146] epoch:  2, train_loss: 0.0611, train_accuracy: 98.11, test_loss: 0.0547, test_accuracy: 98.17
I0312 14:30:13.719958 7932478208 train.py:146] epoch:  3, train_loss: 0.0423, train_accuracy: 98.68, test_loss: 0.0330, test_accuracy: 98.73
I0312 14:30:16.748282 7932478208 train.py:146] epoch:  4, train_loss: 0.0308, train_accuracy: 99.00, test_loss: 0.0302, test_accuracy: 99.04
I0312 14:30:19.781277 7932478208 train.py:146] epoch:  5, train_loss: 0.0250, train_accuracy: 99.21, test_loss: 0.0306, test_accuracy: 99.03
