Skip to content

Commit

Permalink
Remove uses of deprecated xb, xc, and xe abbreviation from jax.interp…
Browse files Browse the repository at this point in the history
…reters.xla

PiperOrigin-RevId: 658688048
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Aug 2, 2024
1 parent 66d8b74 commit 54a1eba
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions haiku/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,8 @@ def transform(f, *, apply_rng=True) -> Transformed:

return without_state(transform_with_state(f))

COMPILED_FN_TYPES = (jax.interpreters.xla.xe.PjitFunction,
jax.interpreters.xla.xe.PmapFunction) # pytype: disable=name-error
COMPILED_FN_TYPES = (jax.lib.xla_extension.PjitFunction,
jax.lib.xla_extension.PmapFunction) # pytype: disable=name-error


def check_not_jax_transformed(f):
Expand Down

0 comments on commit 54a1eba

Please sign in to comment.