
SIGSEGV/SIGABRT with JAX 0.7.0 #30517

@patrick-kidger

Description

This:

```python
import weakref
import jax

cache = weakref.WeakKeyDictionary()

def foo():
    def key(_):
        pass

    def bar(fn):
        pass

    bar = cache[key] = jax.jit(bar, static_argnums=0)
    bar(key)

foo()
```

produces either

```
python(73214,0x20e465f00) malloc: *** error for object 0x60000398e490: pointer being freed was not allocated
python(73214,0x20e465f00) malloc: *** set a breakpoint in malloc_error_break to debug
fish: Job 1, 'python tmp.py' terminated by signal SIGABRT (Abort)
```

or

```
Job 1, 'python tmp.py' terminated by signal SIGSEGV (Address boundary error)
```

Looks like some kind of double-free / reference cycle shenanigans.
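For context, here is a pure-Python sketch of the CPython behaviour the repro sits on top of (no JAX involved; `Value` and `demo` are hypothetical names). Once the last strong reference to a `WeakKeyDictionary` key goes away, the entry is removed from inside the key's weakref callback, and if the dictionary held the only reference to the value, the value is torn down right there, mid-callback. This assumes CPython reference counting and that nothing else keeps the key alive. It is not a diagnosis of the crash. It only shows where a teardown of the cached `jax.jit` object could be triggered from.

```python
import weakref


class Value:
    """Hypothetical stand-in for the cached object (a jitted function in the report)."""


def demo():
    events = []
    cache = weakref.WeakKeyDictionary()

    def key():
        pass

    value = Value()
    # Record the moment the value is reclaimed.
    weakref.finalize(value, events.append, "value finalized")

    cache[key] = value
    del value  # the dictionary now holds the only strong reference to the value
    events.append(f"entries before: {len(cache)}")

    # Dropping the last strong reference to the key lets CPython free it at once.
    # Its weakref callback removes the entry, which in turn releases the value.
    del key
    events.append(f"entries after: {len(cache)}")
    return events


if __name__ == "__main__":
    for event in demo():
        print(event)
```

On CPython this prints `entries before: 1`, then `value finalized`, then `entries after: 0`. The value is finalized before `del key` even returns.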

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.7.0
jaxlib: 0.7.0
numpy:  2.0.2
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:49:36) [Clang 16.0.6 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='Air.local', release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:26 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8112', machine='arm64')
JAX_TRACEBACK_FILTERING=tracebackhide
```

Labels: bug (Something isn't working)