Skip to content

Commit

Permalink
Address code reviews
Browse files Browse the repository at this point in the history
Co-authored-by: Diego Ferigo <diego.ferigo@iit.it>
  • Loading branch information
flferretti and diegoferigo committed Oct 11, 2024
1 parent 9e44c6b commit 515522f
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ def _jnp_options() -> None:
# Determine if 64-bit precision is requested
use_x64 = os.environ.get("JAX_ENABLE_X64", "1") != "0"

# Raise an error if 64-bit precision is not allowed on TPU
# Notify the user if unsupported 64-bit precision was enforced on TPU.
if is_tpu and use_x64:
msg = "64-bit precision is not allowed on TPU. Enforcing 32bit precision."
logging.error(msg)
msg = "64bit precision is not allowed on TPU. Enforcing 32bit precision."
logging.warning(msg)
use_x64 = False

# Enable 64-bit precision in JAX
elif not is_tpu and use_x64:
if use_x64:
logging.info("Enabling JAX to use 64bit precision")
jax.config.update("jax_enable_x64", True)

Expand All @@ -31,10 +32,6 @@ def _jnp_options() -> None:
if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
logging.warning("Failed to enable 64bit precision in JAX")

# Enforce 32-bit precision on TPU
elif is_tpu:
logging.warning("JAX is running on TPU; 32bit precision is enforced.")

# Warn about experimental use of 32-bit precision
else:
logging.warning(
Expand Down

0 comments on commit 515522f

Please sign in to comment.