Closed
Labels: bug (Something isn't working)
Description
The following script, run as tmp.py:
import weakref
import jax

cache = weakref.WeakKeyDictionary()

def foo():
    def key(_):
        pass

    def bar(fn):
        pass

    bar = cache[key] = jax.jit(bar, static_argnums=0)
    bar(key)

foo()

produces either:
python(73214,0x20e465f00) malloc: *** error for object 0x60000398e490: pointer being freed was not allocated
python(73214,0x20e465f00) malloc: *** set a breakpoint in malloc_error_break to debug
fish: Job 1, 'python tmp.py' terminated by signal SIGABRT (Abort)
or
Job 1, 'python tmp.py' terminated by signal SIGSEGV (Address boundary error)
This looks like a double free, possibly caused by the reference cycle the script creates.
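The ownership pattern in the repro can be illustrated in pure Python. This is a minimal sketch, not a reproduction of the crash, and it rests on an assumption: functools.lru_cache stands in for the cache that jax.jit keeps for static arguments, i.e. it is assumed that the jitted function holds a strong reference to the static argument it was called with. Under that assumption, the WeakKeyDictionary value ends up strongly referencing its own key, so the key can never be collected and the "weak" entry lives as long as the dictionary does. The helper name make_entry is hypothetical.

```python
import functools
import gc
import weakref

cache = weakref.WeakKeyDictionary()

def make_entry():
    def key(_):
        pass

    # Hypothetical stand-in for jax.jit(bar, static_argnums=0):
    # lru_cache keeps a strong reference to every argument it has seen.
    @functools.lru_cache(maxsize=None)
    def bar(fn):
        return None

    cache[key] = bar  # dict value: bar (held strongly by the dict)
    bar(key)          # bar's cache now holds key strongly as well
    return weakref.ref(key)

key_ref = make_entry()
gc.collect()

# cache -> bar -> bar's lru_cache -> key, so the weak key never dies
# and the entry is never removed.
print(key_ref() is not None, len(cache))  # → True 1
```

In the real repro the value is a jitted function rather than an lru_cache wrapper, so this only shows why the entry is kept alive until interpreter teardown; it says nothing about where the double free itself happens.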
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.7.0
jaxlib: 0.7.0
numpy: 2.0.2
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:49:36) [Clang 16.0.6 ]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='Air.local', release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:26 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8112', machine='arm64')
JAX_TRACEBACK_FILTERING=tracebackhide