Skip to content

Commit f8f01cb

Browse files
fix(jax): fix the usage of jaxlib.xla_extension (#4824)
`jaxlib.xla_extension` has been removed in recent versions of JAX. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Improved exception handling by updating error type checks and removing an unnecessary dependency. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ef5869a commit f8f01cb

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

deepmd/jax/utils/auto_batch_size.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3-
import jaxlib
43

54
from deepmd.jax.env import (
65
jax,
@@ -52,7 +51,7 @@ def is_oom_error(self, e: Exception) -> bool:
5251
# several sources think CUSOLVER_STATUS_INTERNAL_ERROR is another out-of-memory error,
5352
# such as https://github.com/JuliaGPU/CUDA.jl/issues/1924
5453
# (the meaningless error message should be considered as a bug in cusolver)
55-
if isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)) and (
54+
if isinstance(e, (RuntimeError, ValueError)) and (
5655
"RESOURCE_EXHAUSTED:" in e.args[0]
5756
):
5857
return True

0 commit comments

Comments
 (0)