Skip to content

Commit

Permalink
Merge pull request #112 from ami-iit/transition_to_functional
Browse files Browse the repository at this point in the history
Migrate `jaxsim.api` package away from `PhysicsModel`
  • Loading branch information
diegoferigo authored Mar 19, 2024
2 parents 1dbd1c8 + 900ea0d commit ca32a7e
Show file tree
Hide file tree
Showing 49 changed files with 2,673 additions and 2,100 deletions.
7 changes: 3 additions & 4 deletions src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def _is_editable() -> bool:
del _np_options
del _is_editable

from . import high_level, logging, math, simulation, sixd
from .high_level.common import VelRepr
from .simulation.ode_integration import IntegratorType
from .simulation.simulator import JaxSim
from . import terrain # isort:skip
from . import api, integrators, logging, math, rbda
from .api.common import VelRepr
3 changes: 2 additions & 1 deletion src/jaxsim/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from . import common # isort:skip
from . import model, data # isort:skip
from . import common, contact, joint, kin_dyn_parameters, link, ode, references
from . import contact, joint, kin_dyn_parameters, link, ode, ode_data, references
13 changes: 12 additions & 1 deletion src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import contextlib
import dataclasses
import enum
import functools
from typing import ContextManager

Expand All @@ -11,7 +12,6 @@
from jax_dataclasses import Static

import jaxsim.typing as jtp
from jaxsim.high_level.common import VelRepr
from jaxsim.utils import JaxsimDataclass, Mutability

try:
Expand All @@ -20,6 +20,17 @@
from typing_extensions import Self


@enum.unique
class VelRepr(enum.IntEnum):
"""
Enumeration of all supported 6D velocity representations.
"""

Body = enum.auto()
Mixed = enum.auto()
Inertial = enum.auto()


@jax_dataclasses.pytree_dataclass
class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
"""
Expand Down
66 changes: 36 additions & 30 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim.physics.algos import soft_contacts

from .common import VelRepr


@jax.jit
Expand All @@ -28,16 +30,20 @@ def collidable_point_kinematics(
the linear component of the mixed 6D frame velocity.
"""

from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel

W_p_Ci, W_ṗ_Ci = collidable_points_pos_vel(
model=model.physics_model,
q=data.state.physics_model.joint_positions,
qd=data.state.physics_model.joint_velocities,
xfb=data.state.physics_model.xfb(),
)
from jaxsim.rbda import collidable_points

with data.switch_velocity_representation(VelRepr.Inertial):
W_p_Ci, W_ṗ_Ci = collidable_points.collidable_points_pos_vel(
model=model,
base_position=data.base_position(),
base_quaternion=data.base_orientation(dcm=False),
joint_positions=data.joint_positions(model=model),
base_linear_velocity=data.base_velocity()[0:3],
base_angular_velocity=data.base_velocity()[3:6],
joint_velocities=data.joint_velocities(model=model),
)

return W_p_Ci.T, W_ṗ_Ci.T
return W_p_Ci, W_ṗ_Ci


@jax.jit
Expand Down Expand Up @@ -101,24 +107,17 @@ def in_contact(
if set(link_names) - set(model.link_names()) != set():
raise ValueError("One or more link names are not part of the model")

from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel

W_p_Ci, _ = collidable_points_pos_vel(
model=model.physics_model,
q=data.state.physics_model.joint_positions,
qd=data.state.physics_model.joint_velocities,
xfb=data.state.physics_model.xfb(),
)
W_p_Ci = collidable_point_positions(model=model, data=data)

terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))(
W_p_Ci[0, :], W_p_Ci[1, :]
W_p_Ci[:, 0], W_p_Ci[:, 1]
)

below_terrain = W_p_Ci[2, :] <= terrain_height
below_terrain = W_p_Ci[:, 2] <= terrain_height

links_in_contact = jax.vmap(
lambda link_index: jnp.where(
model.physics_model.gc.body == link_index,
model.kin_dyn_parameters.contact_parameters.body == link_index,
below_terrain,
jnp.zeros_like(below_terrain, dtype=bool),
).any()
Expand All @@ -130,16 +129,19 @@ def in_contact(
@jax.jit
def estimate_good_soft_contacts_parameters(
model: js.model.JaxSimModel,
*,
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
static_friction_coefficient: jtp.FloatLike = 0.5,
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
damping_ratio: jtp.FloatLike = 1.0,
max_penetration: jtp.FloatLike | None = None,
) -> soft_contacts.SoftContactsParams:
) -> jaxsim.rbda.soft_contacts.SoftContactsParams:
"""
Estimate good soft contacts parameters for the given model.
Args:
model: The model to consider.
standard_gravity: The standard gravity constant.
static_friction_coefficient: The static friction coefficient.
number_of_active_collidable_points_steady_state:
The number of active collidable points in steady state supporting
Expand All @@ -162,12 +164,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
""""""

zero_data = js.data.JaxSimModelData.build(
model=model, soft_contacts_params=soft_contacts.SoftContactsParams()
model=model,
soft_contacts_params=jaxsim.rbda.soft_contacts.SoftContactsParams(),
)

W_pz_CoM = js.model.com_position(model=model, data=zero_data)[2]

if model.physics_model.is_floating_base:
if model.floating_base():
W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
return 2 * (W_pz_CoM - W_pz_C.min())

Expand All @@ -181,12 +184,15 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:

nc = number_of_active_collidable_points_steady_state

sc_parameters = soft_contacts.SoftContactsParams.build_default_from_physics_model(
physics_model=model.physics_model,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
sc_parameters = (
jaxsim.rbda.soft_contacts.SoftContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
)
)

return sc_parameters
105 changes: 62 additions & 43 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,15 @@
import numpy as np

import jaxsim.api as js
import jaxsim.physics.algos.aba
import jaxsim.physics.algos.crba
import jaxsim.physics.algos.forward_kinematics
import jaxsim.physics.algos.rnea
import jaxsim.physics.model.physics_model
import jaxsim.physics.model.physics_model_state
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim.high_level.common import VelRepr
from jaxsim.physics.algos import soft_contacts
from jaxsim.simulation.ode_data import ODEState
from jaxsim.math import Quaternion
from jaxsim.utils import Mutability
from jaxsim.utils.tracing import not_tracing

from . import common
from .common import VelRepr
from .ode_data import ODEState

try:
from typing import Self
Expand All @@ -41,9 +37,8 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):

gravity: jtp.Array

soft_contacts_params: soft_contacts.SoftContactsParams = dataclasses.field(
repr=False
)
soft_contacts_params: jaxsim.rbda.SoftContactsParams = dataclasses.field(repr=False)

time_ns: jtp.Int = dataclasses.field(
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
)
Expand All @@ -60,9 +55,10 @@ def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
"""

valid = True
valid = valid and self.standard_gravity() > 0

if model is not None:
valid = valid and self.state.valid(physics_model=model.physics_model)
valid = valid and self.state.valid(model=model)

return valid

Expand Down Expand Up @@ -95,9 +91,9 @@ def build(
base_linear_velocity: jtp.Vector | None = None,
base_angular_velocity: jtp.Vector | None = None,
joint_velocities: jtp.Vector | None = None,
gravity: jtp.Vector | None = None,
soft_contacts_state: soft_contacts.SoftContactsState | None = None,
soft_contacts_params: soft_contacts.SoftContactsParams | None = None,
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
soft_contacts_state: js.ode_data.SoftContactsState | None = None,
soft_contacts_params: jaxsim.rbda.SoftContactsParams | None = None,
velocity_representation: VelRepr = VelRepr.Inertial,
time: jtp.FloatLike | None = None,
) -> JaxSimModelData:
Expand All @@ -114,7 +110,7 @@ def build(
base_angular_velocity:
The base angular velocity in the selected representation.
joint_velocities: The joint velocities.
gravity: The gravity 3D vector.
standard_gravity: The standard gravity constant.
soft_contacts_state: The state of the soft contacts.
soft_contacts_params: The parameters of the soft contacts.
velocity_representation: The velocity representation to use.
Expand Down Expand Up @@ -142,9 +138,7 @@ def build(
base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
).squeeze()

gravity = jnp.array(
gravity if gravity is not None else model.physics_model.gravity[0:3]
).squeeze()
gravity = jnp.zeros(3).at[2].set(-standard_gravity)

joint_positions = jnp.atleast_1d(
joint_positions.squeeze()
Expand All @@ -167,7 +161,9 @@ def build(
soft_contacts_params = (
soft_contacts_params
if soft_contacts_params is not None
else js.contact.estimate_good_soft_contacts_parameters(model=model)
else js.contact.estimate_good_soft_contacts_parameters(
model=model, standard_gravity=standard_gravity
)
)

W_H_B = jaxlie.SE3.from_rotation_and_translation(
Expand All @@ -184,20 +180,22 @@ def build(
is_force=False,
)

ode_state = ODEState.build(
physics_model=model.physics_model,
physics_model_state=jaxsim.physics.model.physics_model_state.PhysicsModelState(
base_position=base_position.astype(float),
base_quaternion=base_quaternion.astype(float),
joint_positions=joint_positions.astype(float),
base_linear_velocity=v_WB[0:3].astype(float),
base_angular_velocity=v_WB[3:6].astype(float),
joint_velocities=joint_velocities.astype(float),
ode_state = ODEState.build_from_jaxsim_model(
model=model,
base_position=base_position.astype(float),
base_quaternion=base_quaternion.astype(float),
joint_positions=joint_positions.astype(float),
base_linear_velocity=v_WB[0:3].astype(float),
base_angular_velocity=v_WB[3:6].astype(float),
joint_velocities=joint_velocities.astype(float),
tangential_deformation=(
soft_contacts_state.tangential_deformation
if soft_contacts_state is not None
else None
),
soft_contacts_state=soft_contacts_state,
)

if not ode_state.valid(physics_model=model.physics_model):
if not ode_state.valid(model=model):
raise ValueError(ode_state)

return JaxSimModelData(
Expand All @@ -222,6 +220,16 @@ def time(self) -> jtp.Float:

return self.time_ns.astype(float) / 1e9

def standard_gravity(self) -> jtp.Float:
"""
Get the standard gravity constant.
Returns:
The standard gravity constant.
"""

return -self.gravity[2]

@functools.partial(jax.jit, static_argnames=["joint_names"])
def joint_positions(
self,
Expand Down Expand Up @@ -250,9 +258,14 @@ def joint_positions(
"""

if model is None:
if joint_names is not None:
raise ValueError("Joint names cannot be provided without a model")

return self.state.physics_model.joint_positions

if not self.valid(model=model):
if not_tracing(self.state.physics_model.joint_positions) and not self.valid(
model=model
):
msg = "The data object is not compatible with the provided model"
raise ValueError(msg)

Expand Down Expand Up @@ -290,9 +303,14 @@ def joint_velocities(
"""

if model is None:
if joint_names is not None:
raise ValueError("Joint names cannot be provided without a model")

return self.state.physics_model.joint_velocities

if not self.valid(model=model):
if not_tracing(self.state.physics_model.joint_velocities) and not self.valid(
model=model
):
msg = "The data object is not compatible with the provided model"
raise ValueError(msg)

Expand Down Expand Up @@ -325,26 +343,27 @@ def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix
The base orientation.
"""

# Extract the base quaternion.
W_Q_B = self.state.physics_model.base_quaternion.squeeze()

# Always normalize the quaternion to avoid numerical issues.
# If the active scheme does not integrate the quaternion on its manifold,
# we introduce a Baumgarte stabilization to let the quaternion converge to
# a unit quaternion. In this case, it is not guaranteed that the quaternion
# stored in the state is a unit quaternion.
base_unit_quaternion = (
self.state.physics_model.base_quaternion.squeeze()
/ jnp.linalg.norm(self.state.physics_model.base_quaternion)
W_Q_B = jax.lax.select(
pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
on_true=W_Q_B,
on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
)

# Slice to convert quaternion wxyz -> xyzw
to_xyzw = np.array([1, 2, 3, 0])

return (
base_unit_quaternion
W_Q_B
if not dcm
else jaxlie.SO3.from_quaternion_xyzw(
base_unit_quaternion[to_xyzw]
Quaternion.to_xyzw(wxyz=W_Q_B)
).as_matrix()
)
).astype(float)

@jax.jit
def base_transform(self) -> jtp.MatrixJax:
Expand Down
Loading

0 comments on commit ca32a7e

Please sign in to comment.