The API version of PJRT plugin is lower than jax>0.4.20 #19223

Closed
PragmaTwice opened this issue Nov 20, 2024 · 1 comment · Fixed by #19241
Labels
bug 🐞 Something isn't working integrations/pjrt OpenXLA PJRT Integration Work

Comments

@PragmaTwice
Member

PragmaTwice commented Nov 20, 2024

What happened?

  • The PJRT C API version of the IREE PJRT plugin is currently 0.38, while the latest JAX releases use a version around 0.5x.
  • PJRT APIs such as PJRT_Plugin_Attributes are not implemented in the IREE PJRT plugin, which causes crashes with recent versions of JAX:
```
(jax) ➜  iree JAX_PLATFORMS=iree_cpu python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"
WARNING:jax._src.xla_bridge:Platform 'iree_cpu' is experimental and not all JAX functionality may be correctly supported!
[IREE-PJRT] DEBUG: Using IREE compiler binary: /home/twice/miniconda3/envs/jax/lib/python3.12/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so
[IREE-PJRT] DEBUG: Compiler Version: 3.0.0rc20241118 @ 29c451b00ecc9f9e5466e9d1079e0d69147da700 (API version 1.4)
[IREE-PJRT] DEBUG: Partitioner was not enabled. The partitioner can be enabled by setting the 'PARTITIONER_LIB_PATH' config var ('IREE_PJRT_PARTITIONER_LIB_PATH' env var)
[IREE-PJRT] DEBUG: CPU driver created
F1120 14:20:14.636158   43187 pjrt_c_api_helpers.cc:241] Unexpected error status /home/twice/projects/iree/integrations/pjrt/src/iree_pjrt/common/stubs.inc:5: UNIMPLEMENTED; PJRT_Plugin_Attributes
*** Check failure stack trace: ***
    @     0x7fa413cd6fa4  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @     0x7fa413cd6ea4  absl::lts_20230802::log_internal::LogMessage::Flush()
    @     0x7fa413cd7349  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7fa40c7d3898  pjrt::LogFatalIfPjrtError()
    @     0x7fa40c7b74ab  xla::PjRtCApiClient::InitAttributes()
    @     0x7fa40c7b5f25  xla::PjRtCApiClient::PjRtCApiClient()
    @     0x7fa40c7c88f8  xla::GetCApiClient()
    @     0x7fa40c678fb6  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x7fa4124c41d1  nanobind::detail::nb_func_vectorcall_complex()
    @           0x53e131  PyObject_Vectorcall
[1]    43187 IOT instruction (core dumped)  JAX_PLATFORMS=iree_cpu python -c
```
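A quick way to confirm which PJRT C API version a plugin was built against is to read the version macros in its copy of `pjrt_c_api.h`. The Python sketch below assumes the header declares them as `#define PJRT_API_MAJOR <n>` and `#define PJRT_API_MINOR <n>`. Verify this against the actual header. `read_pjrt_api_version` is a hypothetical helper and is not part of IREE or XLA.

```python
import re
from typing import Tuple


def read_pjrt_api_version(header_text: str) -> Tuple[int, int]:
    """Return (major, minor) parsed from PJRT_API_MAJOR / PJRT_API_MINOR #defines."""
    parts = []
    for name in ("PJRT_API_MAJOR", "PJRT_API_MINOR"):
        # Match a line such as "#define PJRT_API_MINOR 38" (assumed header layout).
        match = re.search(rf"^\s*#\s*define\s+{name}\s+(\d+)\b", header_text, re.MULTILINE)
        if match is None:
            raise ValueError(f"{name} not found in header text")
        parts.append(int(match.group(1)))
    return parts[0], parts[1]


# Illustrative input mirroring the pre-fix plugin version mentioned in this issue.
sample_header = "#define PJRT_API_MAJOR 0\n#define PJRT_API_MINOR 38\n"
print(read_pjrt_api_version(sample_header))  # (0, 38)
```

To inspect a real header, pass in its text instead. For example, from the IREE repository root, use `Path("integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/c/pjrt_c_api.h").read_text()` with `pathlib.Path`.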

I tried several JAX versions and found that:

  • It works with jax==0.4.20 (and possibly some earlier versions).
  • It fails with jax>0.4.20 (every version after 0.4.20, up to the latest 0.4.35), in one of two ways:
    • with the older of these versions, it returns an error status indicating that the plugin's API version is lower than the one JAX requires;
    • with the newer ones, it crashes as shown above.
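The two failure modes above can be modeled in a few lines of Python. This is a hypothetical sketch of the observed behavior, not JAX or XLA code. `init_client`, `FatalError`, and the minimum version `(0, 40)` are made up for illustration. The model does not explain why newer JAX releases no longer report the version error. It simply encodes the older client as one that enforces a minimum plugin version and the newer client as one that queries plugin attributes during startup.

```python
from typing import Optional, Tuple

Version = Tuple[int, int]  # (major, minor), compared element-wise as integers


class FatalError(Exception):
    """Models the CHECK failure that aborts the process in the trace above."""


def init_client(plugin_version: Version,
                attributes_implemented: bool,
                min_version: Optional[Version],
                queries_attributes: bool) -> str:
    # Hypothetical older client: enforces a minimum plugin API version and
    # reports a mismatch as an ordinary error status.
    if min_version is not None and plugin_version < min_version:
        return f"error: plugin API {plugin_version} < required {min_version}"
    # Hypothetical newer client: queries plugin attributes during startup and
    # treats any failure as fatal.
    if queries_attributes and not attributes_implemented:
        raise FatalError("UNIMPLEMENTED; PJRT_Plugin_Attributes")
    return "ok"


# IREE plugin before the fix: API 0.38, PJRT_Plugin_Attributes stubbed out.
# The required minimum (0, 40) is an illustrative number, not JAX's real value.
print(init_client((0, 38), False, min_version=(0, 40), queries_attributes=False))
try:
    init_client((0, 38), False, min_version=None, queries_attributes=True)
except FatalError as err:
    print("fatal:", err)
```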

Steps to reproduce your issue

Following the PJRT plugin README:

  • Build and install the plugin via pip install.
  • Run a simple JAX program with the iree_cpu backend.
  • Observe the error.

What component(s) does this issue relate to?

Other

Version information

The latest commit on the main branch.

Additional context

No response

@PragmaTwice PragmaTwice added the bug 🐞 Something isn't working label Nov 20, 2024
@PragmaTwice PragmaTwice changed the title The API version of PJRT plugin is relatively low and cannot work in JAX The API version of PJRT plugin is lower than JAX Nov 20, 2024
@PragmaTwice PragmaTwice changed the title The API version of PJRT plugin is lower than JAX The API version of PJRT plugin is lower than latest JAX Nov 20, 2024
@PragmaTwice PragmaTwice changed the title The API version of PJRT plugin is lower than latest JAX The API version of PJRT plugin is lower than jax>0.4.20 Nov 20, 2024
@PragmaTwice
Member Author

I'll work on it soon : )

@ScottTodd ScottTodd added the integrations/pjrt OpenXLA PJRT Integration Work label Nov 20, 2024
ScottTodd pushed a commit that referenced this issue Nov 25, 2024
Closes #19223.

`integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/c/pjrt_c_api.h` is
updated to the latest version (raising the API version from 0.38 to
0.57), fetched from
https://github.com/openxla/xla/blob/a454e14ab0b10e35fb8ad73bd6db7d93782114f6/xla/pjrt/c/pjrt_c_api.h.

A blank implementation of `PJRT_Plugin_Attributes` is now provided,
because leaving it unimplemented makes PJRT plugin initialization fail
(and thus crash) in recent versions of PJRT clients.

The JAX version in the CI workflow is also updated from 0.4.20 to
0.4.35, which allows more tests to be enabled.

ci-exactly: build_packages, test_pjrt

---------

Signed-off-by: PragmaTwice <twice@apache.org>
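The fix can be pictured with a small Python model. This is a hedged illustration, not the plugin's actual code. It shows a client that aborts when the attributes query fails, called once with the old UNIMPLEMENTED stub and once with a blank implementation. The function names are hypothetical. Modeling "blank" as "success with an empty attribute set" is an assumption based on the commit message's wording.

```python
from typing import Callable, Dict, Optional, Tuple

Attributes = Dict[str, object]
# (error, attributes): error is None on success, loosely mirroring how PJRT
# C API calls return a null PJRT_Error* when they succeed.
Result = Tuple[Optional[str], Attributes]


def stub_plugin_attributes() -> Result:
    # Before the fix: the entry point exists but only reports UNIMPLEMENTED.
    return ("UNIMPLEMENTED; PJRT_Plugin_Attributes", {})


def blank_plugin_attributes() -> Result:
    # After the fix (as modeled here): succeed and report no attributes.
    return (None, {})


def client_init_attributes(query: Callable[[], Result]) -> Attributes:
    # Models a client that treats a failed attributes query as fatal during
    # startup, like the LogFatalIfPjrtError frame in the issue's stack trace.
    error, attributes = query()
    if error is not None:
        raise RuntimeError(f"Unexpected error status: {error}")
    return attributes


print(client_init_attributes(blank_plugin_attributes))  # {}
try:
    client_init_attributes(stub_plugin_attributes)
except RuntimeError as err:
    print(err)
```

In this model the client only needs the call to succeed. An empty attribute set is enough to get past initialization.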
Groverkss pushed a commit to Groverkss/iree that referenced this issue Dec 1, 2024
giacs-epic pushed a commit to giacs-epic/iree that referenced this issue Dec 4, 2024
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>