Skip to content

Commit

Permalink
Address code reviews
Browse files Browse the repository at this point in the history
Co-authored-by: Diego Ferigo <diego.ferigo@iit.it>
  • Loading branch information
flferretti and diegoferigo committed Oct 11, 2024
1 parent 2e2af71 commit 779d37b
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,32 @@ def _jnp_options() -> None:
# Check if running on TPU
is_tpu = jax.devices()[0].platform == "tpu"

# Determine if 64-bit precision is requested
# Enable by default 64-bit precision to get accurate physics.
# Users can enforce 32-bit precision by setting the following variable to 0.
use_x64 = os.environ.get("JAX_ENABLE_X64", "1") != "0"

# Raise an error if 64-bit precision is not allowed on TPU
# Notify the user if unsupported 64-bit precision was enforced on TPU.
if is_tpu and use_x64:
msg = "64-bit precision is not allowed on TPU. Enforcing 32bit precision."
logging.error(msg)
logging.warning(msg)
use_x64 = False

# Enable 64-bit precision in JAX
elif not is_tpu and use_x64:
logging.info("Enabling JAX to use 64bit precision")
# Enable 64-bit precision in JAX.
if use_x64:
logging.info("Enabling JAX to use 64-bit precision")
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

# Verify 64-bit precision is correctly set
# Verify that 64-bit precision is correctly set.
if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
logging.warning("Failed to enable 64bit precision in JAX")
logging.warning("Failed to enable 64-bit precision in JAX")

# Enforce 32-bit precision on TPU
elif is_tpu:
logging.warning("JAX is running on TPU; 32bit precision is enforced.")

# Warn about experimental use of 32-bit precision
# Warn about experimental usage of 32-bit precision.
else:
logging.warning(
"Using 32bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
"Using 32-bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
)


Expand Down

0 comments on commit 779d37b

Please sign in to comment.