
Integrating JAX with a custom accelerator #11439

Answered by hawkinsp
slai-nick asked this question in Q&A

I'm assuming here that you have, say, a novel GPU- or TPU-like accelerator that you would like to integrate into JAX. Yes, this is something that is on the roadmap.

The exact mechanics of how you do this will change, but in essence one implements new JAX devices by providing a compiler and runtime that implement this API:

https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/pjrt/pjrt_client.h
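To make the division of labor concrete, here is a toy Python sketch of the kinds of responsibilities such a backend takes on: moving host data into device buffers, compiling a program into an executable, and running that executable on buffers. This is only an illustration under stated assumptions. Every class and method name below (`BackendClient`, `buffer_from_host`, `ToyCpuClient`, and so on) is hypothetical and is not the real PJRT API. The real interface is the C++ header linked above and is considerably larger. Here "compiling" simply wraps a NumPy function, where a real backend would lower an HLO-level module to device code.

```python
from dataclasses import dataclass
from typing import Callable, List, Protocol

import numpy as np


# Hypothetical, heavily simplified stand-ins for PJRT-style concepts.
# These are NOT the real PJRT classes or method names.
@dataclass
class DeviceBuffer:
    data: np.ndarray  # on a real accelerator this would live in device memory
    device_id: int


class Executable(Protocol):
    def execute(self, args: List[DeviceBuffer]) -> List[DeviceBuffer]: ...


class BackendClient(Protocol):
    def buffer_from_host(self, array: np.ndarray, device_id: int = 0) -> DeviceBuffer: ...
    def buffer_to_host(self, buf: DeviceBuffer) -> np.ndarray: ...
    def compile(self, program: Callable[..., np.ndarray]) -> Executable: ...


class ToyExecutable:
    def __init__(self, fn: Callable[..., np.ndarray]):
        self._fn = fn

    def execute(self, args: List[DeviceBuffer]) -> List[DeviceBuffer]:
        # Run the "compiled" program on the buffers' contents; place the
        # result on the same device as the first argument.
        out = self._fn(*[b.data for b in args])
        device_id = args[0].device_id if args else 0
        return [DeviceBuffer(np.asarray(out), device_id)]


class ToyCpuClient:
    """Toy backend: 'compiling' just wraps a NumPy function."""

    def buffer_from_host(self, array: np.ndarray, device_id: int = 0) -> DeviceBuffer:
        # Copy so the "device" buffer does not alias host memory.
        return DeviceBuffer(np.array(array, copy=True), device_id)

    def buffer_to_host(self, buf: DeviceBuffer) -> np.ndarray:
        return np.array(buf.data, copy=True)

    def compile(self, program: Callable[..., np.ndarray]) -> ToyExecutable:
        # A real backend would lower an MHLO/StableHLO module to machine code here.
        return ToyExecutable(program)


client = ToyCpuClient()
x = client.buffer_from_host(np.arange(4.0))
exe = client.compile(lambda a: a * 2 + 1)
(out,) = exe.execute([x])
print(client.buffer_to_host(out))  # prints [1. 3. 5. 7.]
```

The point of the split is that JAX itself only needs to talk to the client interface; everything device-specific (memory layout, code generation, scheduling) sits behind it.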

The PJRT client API is what JAX uses to compile and run MHLO programs and to perform other runtime tasks, such as buffer management.
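From the user's side, those three jobs (buffer management, compilation, execution) map onto ordinary public JAX calls. The sketch below assumes a recent JAX release where `jax.jit(...).lower(...)` and `.compile()` are available. Depending on the version, the printed module may be StableHLO (the successor to MHLO) rather than MHLO. The function `f` is just an illustrative example. With a custom backend plugged in, the same calls would be routed to that backend's PJRT client instead of the CPU/GPU/TPU one.

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(x):
    # Illustrative function to compile and run.
    return jnp.sin(x) * 2.0


# Devices are provided by the PJRT client JAX has loaded for this platform.
dev = jax.devices()[0]
print(dev.platform)  # e.g. "cpu", "gpu", or "tpu"

# Buffer management: copy a host array into a buffer on that device.
x = jax.device_put(np.arange(4, dtype=np.float32), dev)

# Compilation: trace f to an HLO-level module, then hand it to the backend compiler.
lowered = jax.jit(f).lower(x)
print(lowered.as_text()[:300])  # start of the module text passed to the compiler
compiled = lowered.compile()

# Execution: run the compiled executable on the device buffer.
y = compiled(x)
print(np.asarray(y))
```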

Currently such an extension also requires building and shipping a custom jaxlib into which your custom plugin is integrated, however …


Answer selected by slai-nick