```
(mjx_test) ➜ mjx python testspeed.py --mjcf humanoid/humanoid.xml
Rolling out 1000 steps at dt = 0.005...
I0423 16:03:29.898261 140143597762368 xla_bridge.py:863] Unable to initialize backend 'cuda':
I0423 16:03:29.898374 140143597762368 xla_bridge.py:863] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0423 16:03:29.898871 140143597762368 xla_bridge.py:863] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
W0423 16:03:29.898980 140143597762368 xla_bridge.py:901] An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
2024-04-23 16:03:53.934908: W external/xla/xla/service/cpu/onednn_matmul.cc:293] [Perf]: MatMul reference implementation being executed
2024-04-23 16:03:53.993505: W external/xla/xla/service/cpu/onednn_matmul.cc:293] [Perf]: MatMul reference implementation being executed
....
2024-04-23 16:04:45.988970: W external/xla/xla/service/cpu/onednn_matmul.cc:293] [Perf]: MatMul reference implementation being executed
result.qpos: [[[nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  ...
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]]]
Summary for 1024 parallel rollouts

Total JIT time: 23.58 s
Total simulation time: 52.14 s
Total steps per second: 19640
Total realtime factor: 98.20 x
Total time per step: 50.92 µs
```
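The summary numbers are internally consistent with each other, which also means they only measure throughput: the benchmark reports them even though every rollout has diverged to NaN. Using the values above (1024 rollouts × 1000 steps, dt = 0.005 s, 52.14 s of simulation time):

$$
\text{steps/s} = \frac{1024 \times 1000}{52.14\ \text{s}} \approx 1.964 \times 10^{4}, \qquad
\text{realtime factor} = 19640 \times 0.005 \approx 98.2, \qquad
\text{time per step} = \frac{1}{19640}\ \text{s} \approx 50.92\ \mu\text{s}.
$$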
The `result.qpos: ...` line was printed after I added `print(f"result.qpos: {result.qpos}")` here.
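Printing the final `qpos` only shows that the end state is NaN. One way to narrow this down is to record, at every step, whether the new state contains a NaN and report the first step where one appears. The sketch below assumes a generic `step_fn(state) -> state` and uses a toy diverging step; the helpers `nan_flags` and `first_nan_step` are hypothetical and not part of MJX. For MJX, `step_fn` would presumably be a closure around `mjx.step(m, d)`, but that wiring is an assumption.

```python
import jax
import jax.numpy as jnp


def nan_flags(step_fn, state, nstep):
    """Run `step_fn` for `nstep` steps; return one bool per step (True = NaN in new state)."""
    def body(carry, _):
        new = step_fn(carry)
        # Check every leaf of the (possibly nested) state pytree for NaNs.
        leaves = jax.tree_util.tree_leaves(new)
        has_nan = jnp.any(jnp.stack([jnp.isnan(x).any() for x in leaves]))
        return new, has_nan

    _, flags = jax.lax.scan(body, state, None, length=nstep)
    return flags


def first_nan_step(step_fn, state, nstep):
    """Index of the first step whose output contains a NaN, or -1 if none does."""
    flags = nan_flags(step_fn, state, nstep)
    # argmax picks the first True; guard the all-False case explicitly.
    first = jnp.argmax(flags.astype(jnp.int32))
    return int(jnp.where(flags.any(), first, -1))


if __name__ == "__main__":
    # Hypothetical toy dynamics: 4 -> 1 -> 0 -> -1 -> sqrt(-1) - 1 = nan.
    toy_step = lambda x: jnp.sqrt(x) - 1.0
    print(first_nan_step(toy_step, jnp.float32(4.0), 10))
```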
The output of `pip list` is:
I came across an issue mentioning the `MatMul reference implementation being executed` warning here. When I ran the same benchmark on a Mac M1, neither the MatMul reference warning nor the NaNs appeared.
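Since the NaNs show up on the Ubuntu CPU run but not on the Mac M1, it may help to record exactly which backend and default float type each run used (the log above already shows JAX falling back to CPU on Ubuntu). A minimal sketch using only standard JAX calls (`jax.default_backend`, `jax.devices`); the `describe_backend` helper is hypothetical.

```python
import platform

import jax
import jax.numpy as jnp


def describe_backend():
    """Collect facts worth comparing between the Ubuntu and Mac runs."""
    return {
        "jax_version": jax.__version__,
        "backend": jax.default_backend(),        # e.g. "cpu" after the CUDA fallback
        "devices": [str(d) for d in jax.devices()],
        "machine": platform.machine(),           # e.g. "x86_64" vs "arm64"
        # "float64" only when jax_enable_x64 is on, "float32" otherwise.
        "default_float": str(jnp.asarray(0.0).dtype),
    }


if __name__ == "__main__":
    for key, value in describe_backend().items():
        print(f"{key}: {value}")
```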
If you have any idea what is going on with MJX on Ubuntu (CPU), I would be happy to hear it.
Thanks for your time!
EDIT:
I have also checked other issues here. Increasing the precision via `jax.config.update("jax_enable_x64", True)` does not work either; the benchmark now fails before the rollout runs, with a `lax.scan` carry dtype mismatch:
```
(mjx_test) ➜ mjx python testspeed.py --mjcf humanoid/humanoid.xml
Rolling out 1000 steps at dt = 0.005...
I0423 17:15:45.344351 139876225578816 xla_bridge.py:863] Unable to initialize backend 'cuda':
I0423 17:15:45.344468 139876225578816 xla_bridge.py:863] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0423 17:15:45.344967 139876225578816 xla_bridge.py:863] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
W0423 17:15:45.345096 139876225578816 xla_bridge.py:901] An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/testspeed.py", line 88, in <module>
    main()
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/testspeed.py", line 84, in main
    app.run(_main)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/testspeed.py", line 58, in _main
    jit_time, run_time, steps = mjx.benchmark(
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/_src/test_util.py", line 106, in benchmark
    jit_time, run_time = _measure(unroll, d)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/_src/test_util.py", line 41, in _measure
    compiled_fn = fn.lower(*args).compile()
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/_src/test_util.py", line 102, in unroll
    d, _ = jax.lax.scan(step, d, None, length=nstep, unroll=unroll_steps)
TypeError: Scanned function carry input and carry output must have equal types (e.g. shapes and dtypes of arrays), but they differ:
  * the input carry component d.contact.geom1 has type int64[1024,8] but the corresponding output carry component has type int32[1024,8], so the dtypes do not match
  * the input carry component d.contact.geom2 has type int64[1024,8] but the corresponding output carry component has type int32[1024,8], so the dtypes do not match

Revise the scanned function so that all output types (e.g. shapes and dtypes) match the corresponding input types.
```
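The error comes from `jax.lax.scan` requiring the carry to keep identical dtypes across iterations: with `jax_enable_x64` on, `d.contact.geom1`/`geom2` enter the scan as `int64` but leave the step as `int32`. The sketch below reproduces that pattern with a hypothetical toy `step` and shows one generic workaround, casting every leaf of the step output back to the incoming carry's dtype with `jax.tree_util.tree_map`. Whether such a cast is appropriate inside MJX itself is an assumption; this only illustrates the scan constraint.

```python
import jax
import jax.numpy as jnp

# Mirror the reporter's setting: 64-bit types enabled.
jax.config.update("jax_enable_x64", True)


def cast_like(new, old):
    """Cast every leaf of `new` to the dtype of the matching leaf in `old`."""
    return jax.tree_util.tree_map(lambda n, o: n.astype(o.dtype), new, old)


def step(carry):
    # Hypothetical step that, like the traceback above, returns int32 indices
    # even though the incoming carry holds int64 indices.
    geom = (carry["geom"] + 1).astype(jnp.int32)
    return {"geom": geom, "qpos": carry["qpos"] * 0.5}


def rollout(carry, nstep, fix):
    def body(c, _):
        out = step(c)
        # Without the cast, scan raises the same carry-type TypeError.
        return (cast_like(out, c) if fix else out), None

    final, _ = jax.lax.scan(body, carry, None, length=nstep)
    return final


if __name__ == "__main__":
    init = {"geom": jnp.zeros(3, dtype=jnp.int64), "qpos": jnp.ones(3)}
    print(rollout(init, 4, fix=True)["geom"].dtype)
```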
Enabling `config.update("jax_debug_nans", True)` did not give me any useful information either:
```
Traceback (most recent call last):
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/testspeed.py", line 85, in <module>
    main()
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/testspeed.py", line 81, in main
    app.run(_main)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/testspeed.py", line 55, in _main
    jit_time, run_time, steps = mjx.benchmark(
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/_src/test_util.py", line 104, in benchmark
    jit_time, run_time = _measure(unroll, d)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/mujoco/mjx/_src/test_util.py", line 44, in _measure
    result = compiled_fn(*args)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/stages.py", line 594, in __call__
    return self._call(*args, **kwargs)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/stages.py", line 591, in cpp_call_fallback
    outs, _, _ = Compiled.call(params, *args, **kwargs)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/stages.py", line 563, in call
    out_flat = params.executable.call(*args_flat)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1089, in call
    return self.unsafe_call(*args)  # pylint: disable=not-callable
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1217, in __call__
    dispatch.check_special(self.name, arrays)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/dispatch.py", line 314, in check_special
    _check_special(name, buf.dtype, buf)
  File "/home/fschramm/mambaforge/envs/mjx_test/lib/python3.10/site-packages/jax/_src/dispatch.py", line 319, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in parallel computation
```
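Here `jax_debug_nans` can only say "parallel computation" because, per the traceback, the check (`check_special`) runs on the outputs of the whole ahead-of-time compiled rollout (`fn.lower(*args).compile()`). Running the failing computation eagerly under `jax.disable_jit()` should let the same flag stop at the first primitive that produces a NaN. A minimal sketch with a hypothetical toy `step` and a hypothetical `find_nan_op` helper; an eager MJX rollout would be very slow, so in practice one would rerun only a few steps from a state shortly before the NaN appears.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)


def step(x):
    # Hypothetical step: log of a negative argument produces a NaN.
    return jnp.log(x - 2.0) + 1.0


def find_nan_op(fn, *args):
    """Run `fn` eagerly so jax_debug_nans reports the op that produced the NaN.

    Returns the FloatingPointError message, or None if no NaN was produced.
    """
    try:
        with jax.disable_jit():
            fn(*args)
    except FloatingPointError as err:
        return str(err)
    return None


if __name__ == "__main__":
    print(find_nan_op(step, jnp.float32(1.0)))
```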
Setup: I am on Ubuntu 22.04. I created a clean env, installed `mujoco-mjx`, and launched the `testspeed.py` benchmark with `python testspeed.py --mjcf humanoid/humanoid.xml`, which produces the NaNs shown above. Thanks for maintaining this library.