diff --git a/src/jaxsim/high_level/model.py b/src/jaxsim/high_level/model.py index c1f1687ba..4bd1e0839 100644 --- a/src/jaxsim/high_level/model.py +++ b/src/jaxsim/high_level/model.py @@ -402,6 +402,41 @@ def joints(self, joint_names: List[str] = None) -> List[high_level.joint.Joint]: return [self._joints[name] for name in joint_names] + def in_contact( + self, + link_names: Optional[List[str]] = None, + terrain: Terrain = FlatTerrain(), + ) -> jtp.Vector: + """""" + + link_names = link_names if link_names is not None else self.link_names() + + if set(link_names) - set(self._links.keys()) != set(): + raise ValueError("One or more link names are not part of the model") + + from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel + + W_p_Ci, _ = collidable_points_pos_vel( + model=self.physics_model, + q=self.data.model_state.joint_positions, + qd=self.data.model_state.joint_velocities, + xfb=self.data.model_state.xfb(), + ) + + terrain_height = jax.vmap(terrain.height)(W_p_Ci[0, :], W_p_Ci[1, :]) + + below_terrain = W_p_Ci[2, :] <= terrain_height + + links_in_contact = jax.vmap( + lambda link_index: jnp.where( + self.physics_model.gc.body == link_index, + below_terrain, + jnp.zeros_like(below_terrain, dtype=bool), + ).any() + )(jnp.array([link.index() for link in self.links(link_names=link_names)])) + + return links_in_contact + # ================== # Vectorized methods # ==================