Skip to content

Commit

Permalink
Avoid depending on JAX internals, which are about to change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689265912
  • Loading branch information
dougalm authored and copybara-github committed Oct 24, 2024
1 parent c252c9c commit 3423f06
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions haiku/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ class JaxTraceLevel(NamedTuple):
@classmethod
def current(cls):
if jax.__version_info__ <= (0, 4, 33):
trace_stack = jax_core.thread_local_state.trace_state.trace_stack.stack
trace_stack = jax_core.thread_local_state.trace_state.trace_stack.stack # type: ignore
top_type = trace_stack[0].trace_type
level = trace_stack[-1].level
sublevel = jax_core.cur_sublevel()
return JaxTraceLevel(opaque=(top_type, level, sublevel))
sublevel = jax_core.cur_sublevel() # type: ignore
return JaxTraceLevel(opaque=(top_type, level, sublevel)) # type: ignore

ts = jax_core.get_opaque_trace_state(convention="haiku")
return JaxTraceLevel(opaque=ts)
Expand Down

0 comments on commit 3423f06

Please sign in to comment.