Skip to content

Commit

Permalink
Use polymorphism to update the post impact velocity
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Feb 10, 2025
1 parent 6972075 commit a2e6a1a
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 49 deletions.
52 changes: 3 additions & 49 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2101,54 +2101,8 @@ def step(
),
)

if isinstance(model.contact_model, jaxsim.rbda.contacts.RigidContacts):
# Extract the indices corresponding to the enabled collidable points.
indices_of_enabled_collidable_points = (
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
)

W_p_C = js.contact.collidable_point_positions(model, data_tf)[
indices_of_enabled_collidable_points
]

# Compute the penetration depth of the collidable points.
δ, *_ = jax.vmap(
jaxsim.rbda.contacts.common.compute_penetration_data,
in_axes=(0, 0, None),
)(W_p_C, jnp.zeros_like(W_p_C), model.terrain)

with data_tf.switch_velocity_representation(VelRepr.Mixed):
J_WC = js.contact.jacobian(model, data_tf)[
indices_of_enabled_collidable_points
]
M = js.model.free_floating_mass_matrix(model, data_tf)
BW_ν_pre_impact = data_tf.generalized_velocity

# Compute the impact velocity.
# It may be discontinuous in case new contacts are made.
BW_ν_post_impact = (
jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
generalized_velocity=BW_ν_pre_impact,
inactive_collidable_points=(δ <= 0),
M=M,
J_WC=J_WC,
)
)

BW_ν_post_impact_inertial = data_tf.other_representation_to_inertial(
array=BW_ν_post_impact[0:6],
other_representation=VelRepr.Mixed,
transform=data_tf._base_transform.at[0:3, 0:3].set(jnp.eye(3)),
is_force=False,
)

# Reset the generalized velocity.
data_tf = dataclasses.replace(
data_tf,
velocity_representation=data.velocity_representation,
_base_linear_velocity=BW_ν_post_impact_inertial[0:3],
_base_angular_velocity=BW_ν_post_impact_inertial[3:6],
_joint_velocities=BW_ν_post_impact[6:],
)
data_tf = model.contact_model.update_velocity_after_impact(
model=model, data=data_tf
)

return data_tf
76 changes: 76 additions & 0 deletions src/jaxsim/rbda/contacts/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import abc
import dataclasses
import functools

import jax
Expand Down Expand Up @@ -275,3 +276,78 @@ def update_contact_state(
return {"tangential_deformation": old_contact_state["m_dot"]}
case RigidContacts() | RelaxedRigidContacts():
return {}

@jax.jit
@js.common.named_scope
def update_velocity_after_impact(
self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> js.data.JaxSimModelData:
"""
Update the velocity after an impact.
Args:
model: The robot model considered by the contact model.
data: The data of the considered model.
Returns:
The updated data of the considered model.
"""

# Import the rigid contact model to avoid circular imports.
from jaxsim.api.common import VelRepr

from .rigid import RigidContacts

if isinstance(self, RigidContacts):
# Extract the indices corresponding to the enabled collidable points.
indices_of_enabled_collidable_points = (
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
)

W_p_C = js.contact.collidable_point_positions(model, data)[
indices_of_enabled_collidable_points
]

# Compute the penetration depth of the collidable points.
δ, *_ = jax.vmap(
jaxsim.rbda.contacts.common.compute_penetration_data,
in_axes=(0, 0, None),
)(W_p_C, jnp.zeros_like(W_p_C), model.terrain)

original_representation = data.velocity_representation

with data.switch_velocity_representation(VelRepr.Mixed):
J_WC = js.contact.jacobian(model, data)[
indices_of_enabled_collidable_points
]
M = js.model.free_floating_mass_matrix(model, data)
BW_ν_pre_impact = data.generalized_velocity

# Compute the impact velocity.
# It may be discontinuous in case new contacts are made.
BW_ν_post_impact = (
jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
generalized_velocity=BW_ν_pre_impact,
inactive_collidable_points=(δ <= 0),
M=M,
J_WC=J_WC,
)
)

BW_ν_post_impact_inertial = data.other_representation_to_inertial(
array=BW_ν_post_impact[0:6],
other_representation=VelRepr.Mixed,
transform=data._base_transform.at[0:3, 0:3].set(jnp.eye(3)),
is_force=False,
)

# Reset the generalized velocity.
data = dataclasses.replace(
data,
velocity_representation=original_representation,
_base_linear_velocity=BW_ν_post_impact_inertial[0:3],
_base_angular_velocity=BW_ν_post_impact_inertial[3:6],
_joint_velocities=BW_ν_post_impact[6:],
)

return data

0 comments on commit a2e6a1a

Please sign in to comment.