Skip to content

Commit

Permalink
Streamline new API changes to alternative contact models
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jan 30, 2025
1 parent 7f0893f commit a7b7150
Show file tree
Hide file tree
Showing 12 changed files with 443 additions and 136 deletions.
95 changes: 87 additions & 8 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,11 @@ def collidable_point_kinematics(
the linear component of the mixed 6D frame velocity.
"""

# Switch to inertial-fixed since the RBDAs expect velocities in this representation.
with data.switch_velocity_representation(VelRepr.Inertial):

W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
model=model,
link_transforms=data.link_transforms,
link_velocities=data.link_velocities,
)
W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
model=model,
link_transforms=data.link_transforms,
link_velocities=data.link_velocities,
)

return W_p_Ci, W_ṗ_Ci

Expand Down Expand Up @@ -164,15 +161,24 @@ def estimate_good_soft_contacts_parameters(
def estimate_good_contact_parameters(
model: js.model.JaxSimModel,
*,
standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
static_friction_coefficient: jtp.FloatLike = 0.5,
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
damping_ratio: jtp.FloatLike = 1.0,
max_penetration: jtp.FloatLike | None = None,
**kwargs,
) -> jaxsim.rbda.contacts.ContactParamsTypes:
"""
Estimate good contact parameters.
Args:
model: The model to consider.
standard_gravity: The standard gravity acceleration.
static_friction_coefficient: The static friction coefficient.
number_of_active_collidable_points_steady_state:
The number of active collidable points in steady state.
damping_ratio: The damping ratio.
max_penetration: The maximum penetration allowed.
kwargs:
Additional model-specific parameters passed to the builder method of
the parameters class.
Expand All @@ -190,8 +196,81 @@ def estimate_good_contact_parameters(
specific application.
"""

def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
"""
Displacement between the CoM and the lowest collidable point using zero
joint positions.
"""

zero_data = js.data.JaxSimModelData.build(
model=model,
)

W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]

if model.floating_base():
W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
return 2 * (W_pz_CoM - W_pz_C.min())

return 2 * W_pz_CoM

max_δ = (
max_penetration
if max_penetration is not None
# Consider as default a 0.5% of the model height.
else 0.005 * estimate_model_height(model=model)
)

nc = number_of_active_collidable_points_steady_state

match model.contact_model:

case contacts.SoftContacts():
assert isinstance(model.contact_model, contacts.SoftContacts)

parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
**kwargs,
)

case contacts.ViscoElasticContacts():
assert isinstance(model.contact_model, contacts.ViscoElasticContacts)

parameters = (
contacts.ViscoElasticContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
**kwargs,
)
)

case contacts.RigidContacts():
assert isinstance(model.contact_model, contacts.RigidContacts)

# Disable Baumgarte stabilization by default since it does not play
# well with the forward Euler integrator.
K = kwargs.get("K", 0.0)

parameters = contacts.RigidContactsParams.build(
mu=static_friction_coefficient,
**(
dict(
K=K,
D=2 * jnp.sqrt(K),
)
| kwargs
),
)

case contacts.RelaxedRigidContacts():
assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)

Expand Down
14 changes: 9 additions & 5 deletions src/jaxsim/api/contact_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.rbda.contacts import SoftContacts


@jax.jit
Expand All @@ -15,7 +16,7 @@ def link_contact_forces(
*,
link_forces: jtp.MatrixLike | None = None,
joint_torques: jtp.VectorLike | None = None,
) -> jtp.Matrix:
) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]:
"""
Compute the 6D contact forces of all links of the model in inertial representation.
Expand All @@ -33,11 +34,14 @@ def link_contact_forces(
"""

# Compute the contact forces for each collidable point with the active contact model.
W_f_C, _ = model.contact_model.compute_contact_forces(
W_f_C, aux_dict = model.contact_model.compute_contact_forces(
model=model,
data=data,
link_forces=link_forces,
joint_force_references=joint_torques,
**(
dict(link_forces=link_forces, joint_force_references=joint_torques)
if not isinstance(model.contact_model, SoftContacts)
else {}
),
)

# Compute the 6D forces applied to the links equivalent to the forces applied
Expand All @@ -46,7 +50,7 @@ def link_contact_forces(
model=model, data=data, contact_forces=W_f_C
)

return W_f_L
return W_f_L, aux_dict


@staticmethod
Expand Down
21 changes: 21 additions & 0 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)

# Extended state for soft and rigid contact models.
contact_state: dict[str, jtp.Array] = dataclasses.field(default=None)

@staticmethod
def build(
model: js.model.JaxSimModel,
Expand All @@ -70,6 +73,8 @@ def build(
base_angular_velocity: jtp.VectorLike | None = None,
joint_velocities: jtp.VectorLike | None = None,
velocity_representation: VelRepr = VelRepr.Inertial,
*,
contact_state: dict[str, jtp.Array] | None = None,
) -> JaxSimModelData:
"""
Create a `JaxSimModelData` object with the given state.
Expand All @@ -85,6 +90,7 @@ def build(
The base angular velocity in the selected representation.
joint_velocities: The joint velocities.
velocity_representation: The velocity representation to use.
contact_state: The optional contact state.
Returns:
A `JaxSimModelData` initialized with the given state.
Expand Down Expand Up @@ -165,6 +171,20 @@ def build(
joint_velocities=joint_velocities,
)

contact_state = (
{
"tangential_deformation": jnp.zeros_like(
model.kin_dyn_parameters.contact_parameters.point
)
}
if isinstance(
model.contact_model,
jaxsim.rbda.contacts.SoftContacts
| jaxsim.rbda.contacts.ViscoElasticContacts,
)
else contact_state or {}
)

model_data = JaxSimModelData(
base_quaternion=base_quaternion,
base_position=base_position,
Expand All @@ -177,6 +197,7 @@ def build(
joint_transforms=joint_transforms,
link_transforms=link_transforms,
link_velocities=link_velocities,
contact_state=contact_state or {},
)

if not model_data.valid(model=model):
Expand Down
76 changes: 73 additions & 3 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class JaxSimModel(JaxsimDataclass):
default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
)

gravity: Static[float] = jaxsim.math.STANDARD_GRAVITY
gravity: Static[float] = -jaxsim.math.STANDARD_GRAVITY

contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field(
default=None, repr=False
Expand Down Expand Up @@ -111,6 +111,7 @@ def build_from_model_description(
terrain: jaxsim.terrain.Terrain | None = None,
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
contact_params: jaxsim.rbda.contacts.ContactsParams | None = None,
gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
is_urdf: bool | None = None,
considered_joints: Sequence[str] | None = None,
) -> JaxSimModel:
Expand All @@ -131,6 +132,7 @@ def build_from_model_description(
The contact model to consider.
If not specified, a soft contacts model is used.
contact_params: The parameters of the contact model.
gravity: The gravity constant.
is_urdf:
The optional flag to force the model description to be parsed as a URDF.
This is usually automatically inferred.
Expand Down Expand Up @@ -164,6 +166,7 @@ def build_from_model_description(
terrain=terrain,
contact_model=contact_model,
contacts_params=contact_params,
gravity=gravity,
)

# Store the origin of the model, in case downstream logic needs it.
Expand Down Expand Up @@ -247,7 +250,7 @@ def build(
terrain=terrain,
contact_model=contact_model,
contacts_params=contacts_params,
gravity=gravity,
gravity=-gravity,
# The following is wrapped as hashless since it's a static argument, and we
# don't want to trigger recompilation if it changes. All relevant parameters
# needed to compute kinematics and dynamics quantities are stored in the
Expand Down Expand Up @@ -447,6 +450,8 @@ def reduce(
time_step=model.time_step,
terrain=model.terrain,
contact_model=model.contact_model,
contacts_params=model.contacts_params,
gravity=model.gravity,
)

# Store the origin of the model, in case downstream logic needs it.
Expand Down Expand Up @@ -2045,7 +2050,7 @@ def step(

# Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
# with the terrain.
W_f_L_terrain = js.contact_model.link_contact_forces(
W_f_L_terrain, aux_dict = js.contact_model.link_contact_forces(
model=model,
data=data,
link_forces=W_f_L_external,
Expand All @@ -2058,6 +2063,33 @@ def step(

W_f_L_total = W_f_L_external + W_f_L_terrain

# =============================
# Update the contact state data
# =============================

contact_state = {}

match model.contact_model:

case jaxsim.rbda.contacts.SoftContacts():
contact_state["tangential_deformation"] = aux_dict["m_dot"]
data = data.replace(contact_state=contact_state)

case jaxsim.rbda.contacts.ViscoElasticContacts():
contact_state["tangential_deformation"] = jnp.zeros_like(
jnp.array(model.kin_dyn_parameters.contact_parameters.point)
)
data = data.replace(contact_state=contact_state)

case (
jaxsim.rbda.contacts.RigidContacts()
| jaxsim.rbda.contacts.RelaxedRigidContacts()
):
pass

case _:
raise ValueError(f"Invalid contact model: {model.contact_model}")

# ===============================
# Compute the system acceleration
# ===============================
Expand All @@ -2081,6 +2113,44 @@ def step(
joint_accelerations=,
)

if isinstance(model.contact_model, jaxsim.rbda.contacts.RigidContacts):
# Extract the indices corresponding to the enabled collidable points.
indices_of_enabled_collidable_points = (
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
)

W_p_C = js.contact.collidable_point_positions(model, data_tf)[
indices_of_enabled_collidable_points
]

# Compute the penetration depth of the collidable points.
δ, *_ = jax.vmap(
jaxsim.rbda.contacts.common.compute_penetration_data,
in_axes=(0, 0, None),
)(W_p_C, jnp.zeros_like(W_p_C), model.terrain)

with data_tf.switch_velocity_representation(VelRepr.Mixed):
J_WC = js.contact.jacobian(model, data_tf)[
indices_of_enabled_collidable_points
]
M = js.model.free_floating_mass_matrix(model, data_tf)
BW_ν_pre_impact = data_tf.generalized_velocity()

# Compute the impact velocity.
# It may be discontinuous in case new contacts are made.
BW_ν_post_impact = (
jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
generalized_velocity=BW_ν_pre_impact,
inactive_collidable_points=(δ <= 0),
M=M,
J_WC=J_WC,
)
)

# Reset the generalized velocity.
data_tf = data_tf.reset_base_velocity(BW_ν_post_impact[0:6])
data_tf = data_tf.reset_joint_velocities(BW_ν_post_impact[6:])

# ne parliamo dopo
# Restore the input velocity representation
data_tf = data_tf.replace(
Expand Down
5 changes: 4 additions & 1 deletion src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ def system_velocity_dynamics(
link_forces=W_f_L,
)

return W_v̇_WB,
return (
W_v̇_WB,
,
)


def system_acceleration(
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@


# Define the default standard gravity constant.
STANDARD_GRAVITY = -9.81
STANDARD_GRAVITY = 9.81
Loading

0 comments on commit a7b7150

Please sign in to comment.