Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update PJRT plugin API version to 0.57 #19241

Merged
merged 3 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/pkgci_test_pjrt.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,8 @@ jobs:
# install editable into venv
source ${VENV_DIR}/bin/activate
python -m pip install -v --no-deps -e integrations/pjrt/python_packages/iree_${{ matrix.pjrt_platform }}_plugin
# install jax (must be no larger than 0.4.20, refer to #19223)
# TODO: switch to the latest JAX after #19223 is fixed
python -m pip install jax==0.4.20 jaxlib==0.4.20 'numpy<2'
# install
python -m pip install jax==0.4.35
- name: Run tests
run: |
source ${VENV_DIR}/bin/activate
Expand Down
12 changes: 2 additions & 10 deletions build_tools/testing/run_jax_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,9 @@ diff_jax_test() {
echo "no difference found"
}

# FIXME: due to #19223, we need to use jax no higher than 0.4.20,
# but in such version of jax, 'stablehlo.broadcast_in_dim' op
# will be emitted without attribute 'broadcast_dimensions',
# which leads to an error in IREE PJRT plugin.
# So currently any program with broadcast will fail,
# e.g. test/test_simple.py.
# After #19223 is fixed, we can uncomment the line below.

# diff_jax_test test/test_simple.py

diff_jax_test test/test_add.py
diff_jax_test test/test_degenerate.py
diff_jax_test test/test_simple.py


# FIXME: we can also utilize the native test cases from JAX,
Expand Down
9 changes: 9 additions & 0 deletions integrations/pjrt/src/iree_pjrt/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2164,6 +2164,15 @@ void BindMonomorphicApi(PJRT_Api* api) {
BindUndefineds(api);
ErrorInstance::BindApi(api);

// PJRT_Plugin_Attributes should be implemented since it will always be
// called from the PJRT client in the initial phase.
// here we provide a blank implementation to avoid crash due to unimplemented.
api->PJRT_Plugin_Attributes =
Comment on lines +2167 to +2170
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, classic. An internal bug link for API documentation.
openxla/xla@619ca7b

// Returns an array of plugin attributes which are key-value pairs. One example
// attribute is the minimum supported StableHLO version.
// TODO(b/280349977): standardize the list of attributes.
typedef PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args);

Fine to implement as you have it here for now.

+[](PJRT_Plugin_Attributes_Args* args) -> PJRT_Error* {
args->num_attributes = 0;
args->attributes = nullptr;
return nullptr;
};
api->PJRT_Plugin_Initialize =
+[](PJRT_Plugin_Initialize_Args* args) -> PJRT_Error* { return nullptr; };

Expand Down
5 changes: 5 additions & 0 deletions integrations/pjrt/src/iree_pjrt/common/stubs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,8 @@
_STUB(PJRT_Client_CreateViewOfDeviceBuffer);
_STUB(PJRT_Executable_Fingerprint);
_STUB(PJRT_Client_TopologyDescription);
_STUB(PJRT_Executable_GetCompiledMemoryStats);
_STUB(PJRT_Memory_Kind_Id);
_STUB(PJRT_ExecuteContext_Create);
_STUB(PJRT_ExecuteContext_Destroy);
_STUB(PJRT_Buffer_CopyRawToHost);
2 changes: 1 addition & 1 deletion integrations/pjrt/third_party/pjrt_c_api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ releases.
Last synced from:

* https://github.com/openxla/xla.git
* commit: 96d1250d70c0bd6adf2778f31a266c1813fd107a
* commit: a454e14ab0b10e35fb8ad73bd6db7d93782114f6
Loading
Loading