Is there any path for integrating JAX with a custom accelerator? If not, is it on the roadmap to add such a feature?
I'm assuming here that you have, say, a novel GPU- or TPU-like accelerator device that you would like to integrate into JAX. If so, yes, this is on the roadmap.
The exact mechanics of how you do this will change, but in essence one implements new JAX devices by providing a compiler and runtime that implement this API:
https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/pjrt/pjrt_client.h
The PJRT client API is what JAX uses to compile and run MHLO programs and to perform other runtime tasks, such as buffer management.
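For a feel of the JAX side of that contract, the sketch below uses JAX's public ahead-of-time APIs (`jax.jit(...).lower()`, `.as_text()`, `.compile()`) to walk through the stages a program passes through before a backend runs it: lowering to an MLIR module, compilation by the backend, and execution. It is illustrative only. It runs on whichever backend is installed (for example, CPU), and `as_text()` emits MHLO in older JAX releases and StableHLO in newer ones, so the exact text varies by version.

```python
import jax
import jax.numpy as jnp


def f(x):
    # A tiny computation: elementwise sin, scale, then reduce.
    return jnp.sum(jnp.sin(x) * 2.0)


x = jnp.arange(4.0)

# Stage 1: trace and lower to an MLIR module. This is the kind of
# program a backend's compiler is handed (MHLO or StableHLO,
# depending on the JAX version).
lowered = jax.jit(f).lower(x)
hlo_text = lowered.as_text()
print(hlo_text)

# Stage 2: have the active backend compile the lowered module.
compiled = lowered.compile()

# Stage 3: run the executable; inputs and outputs are buffers
# managed by the backend's runtime on its device.
result = compiled(x)
print(result, jax.devices())
```

A new accelerator backend is responsible for stages 2 and 3 (plus the buffer handling around them), while JAX itself produces the module in stage 1.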
Currently, such an extension also requires building and shipping a custom jaxlib into which your custom plugin is integrated, however …
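To make the division of labor in pjrt_client.h more concrete, here is a purely hypothetical, pure-Python toy that mirrors the three responsibilities mentioned above: compiling a program, executing it, and managing device buffers. None of these class or method names (`ToyClient`, `buffer_from_host`, `compile`, `execute`, `to_host`) are the real PJRT API, which is C++ and considerably richer. The "program" is a made-up one-op string rather than an MHLO module, and the "device memory" is just a Python dict.

```python
import itertools
import operator
from typing import Callable, Dict, List, Sequence

# Hypothetical toy "program" format: the name of one binary elementwise op.
# A real backend would instead receive an MLIR module (MHLO/StableHLO).
_OPS: Dict[str, Callable[[float, float], float]] = {
    "add": operator.add,
    "multiply": operator.mul,
}


class ToyBuffer:
    """Handle to data that lives in the toy device's memory."""

    def __init__(self, client: "ToyClient", handle: int) -> None:
        self.client = client
        self.handle = handle

    def to_host(self) -> List[float]:
        # Simulated device-to-host copy.
        return list(self.client.memory[self.handle])

    def delete(self) -> None:
        # Free the simulated device memory backing this buffer.
        self.client.memory.pop(self.handle, None)


class ToyExecutable:
    """A 'compiled' program bound to the client that produced it."""

    def __init__(self, client: "ToyClient", op: Callable[[float, float], float]) -> None:
        self.client = client
        self.op = op

    def execute(self, args: Sequence[ToyBuffer]) -> List[ToyBuffer]:
        lhs, rhs = (self.client.memory[b.handle] for b in args)
        if len(lhs) != len(rhs):
            raise ValueError("shape mismatch")
        out = [self.op(a, b) for a, b in zip(lhs, rhs)]
        # Results come back as new device buffers, not host values.
        return [self.client.buffer_from_host(out)]


class ToyClient:
    """Hypothetical stand-in for a backend's client object."""

    platform_name = "toy"

    def __init__(self) -> None:
        self.memory: Dict[int, List[float]] = {}  # simulated device memory
        self._next_handle = itertools.count()

    def devices(self) -> List[str]:
        return ["toy:0"]

    def buffer_from_host(self, data: Sequence[float]) -> ToyBuffer:
        # Simulated host-to-device copy.
        handle = next(self._next_handle)
        self.memory[handle] = [float(v) for v in data]
        return ToyBuffer(self, handle)

    def compile(self, program: str) -> ToyExecutable:
        if program not in _OPS:
            raise NotImplementedError(f"unsupported op: {program}")
        return ToyExecutable(self, _OPS[program])


# Usage: transfer inputs, compile, execute, copy the result back.
client = ToyClient()
x = client.buffer_from_host([1, 2, 3])
y = client.buffer_from_host([10, 20, 30])
exe = client.compile("add")
(z,) = exe.execute([x, y])
print(z.to_host())  # [11.0, 22.0, 33.0]
```

The design point the toy tries to capture is that outputs stay as device buffers until they are explicitly copied back with `to_host`, which lets a runtime chain executions without round-tripping data through the host.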