diff --git a/environment.yml b/environment.yml index 3426409d1..cd31033b1 100644 --- a/environment.yml +++ b/environment.yml @@ -12,6 +12,7 @@ dependencies: - jaxlie >= 1.3.0 - jax-dataclasses >= 1.4.0 - pptree + - qpax - rod >= 0.3.0 - typing_extensions # python<3.12 # ==================================== diff --git a/pyproject.toml b/pyproject.toml index d1c4a8a7b..ddefc64dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "jaxlie >= 1.3.0", "jax_dataclasses >= 1.4.0", "pptree", + "qpax", "rod >= 0.3.0", "typing_extensions ; python_version < '3.12'", ] @@ -187,6 +188,7 @@ mediapy = "*" mujoco = "*" notebook = "*" pptree = "*" +qpax = "*" rod = "*" sdformat14 = "*" typing_extensions = "*" diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 832dad963..ac1520970 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -114,19 +114,24 @@ def collidable_point_forces( @jax.jit def collidable_point_dynamics( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData -) -> tuple[jtp.Matrix, jtp.Matrix]: + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + link_forces: jtp.MatrixLike | None = None, +) -> tuple[jtp.Matrix, dict[str, jtp.Array]]: r""" - Compute the 6D force applied to each collidable point and the corresponding - material deformation rate. + Compute the 6D force applied to each collidable point. Args: model: The model to consider. data: The data of the considered model. + link_forces: + The 6D external forces to apply to the links expressed in the same + representation of data. Returns: - The 6D force applied to each collidable point and the corresponding - material deformation rate. + The 6D force applied to each collidable point and additional data based on the contact model configured: + - Soft: the material deformation rate. + - Rigid: nothing. Note: The material deformation rate is always returned in the mixed frame @@ -138,7 +143,8 @@ def collidable_point_dynamics( # all collidable points belonging to the robot. W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data) - # Import privately the soft contacts classes. + # Import privately the contacts classes. + from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState # Build the soft contact model. @@ -161,6 +167,28 @@ def collidable_point_dynamics( W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)( W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation ) + aux_data = dict(m_dot=CW_ṁ) + + case RigidContacts(): + assert isinstance(model.contact_model, RigidContacts) + assert isinstance(data.state.contact, RigidContactsState) + + # Build the contact model. + rigid_contacts = RigidContacts( + parameters=data.contacts_params, terrain=model.terrain + ) + + # Compute the 6D force expressed in the inertial frame and applied to each + # collidable point. + W_f_Ci, _ = rigid_contacts.compute_contact_forces( + position=W_p_Ci, + velocity=W_ṗ_Ci, + model=model, + data=data, + link_forces=link_forces, + ) + + aux_data = dict() case _: raise ValueError(f"Invalid contact model {model.contact_model}") @@ -175,7 +203,7 @@ def collidable_point_dynamics( ) )(W_f_Ci) - return f_Ci, CW_ṁ + return f_Ci, aux_data @functools.partial(jax.jit, static_argnames=["link_names"]) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index a91df0332..ad7edcec1 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -14,6 +14,7 @@ from jax_dataclasses import Static import jaxsim.api as js +import jaxsim.exceptions import jaxsim.terrain import jaxsim.typing as jtp from jaxsim.math import Adjoint, Cross @@ -1890,6 +1891,8 @@ def step( and the new state of the integrator. """ + from jaxsim.rbda.contacts.rigid import RigidContacts + # Extract the integrator kwargs. # The following logic allows using integrators having kwargs colliding with the # kwargs of this step function. @@ -1901,12 +1904,12 @@ def step( # Extract the initial resources. t0_ns = data.time_ns - state_x0 = data.state + state_t0 = data.state integrator_state_x0 = integrator_state # Step the dynamics forward. - state_xf, integrator_state_xf = integrator.step( - x0=state_x0, + state_tf, integrator_state_tf = integrator.step( + x0=state_t0, t0=jnp.array(t0_ns / 1e9).astype(float), dt=dt, params=integrator_state_x0, @@ -1928,11 +1931,61 @@ def step( ), ) - return ( + data_tf = ( # Store the new state of the model and the new time. data.replace( - state=state_xf, + state=state_tf, time_ns=t0_ns + jnp.array(dt * 1e9).astype(jnp.uint64), - ), - integrator_state_xf, + ) + ) + + # Post process the simulation state, if needed. + match model.contact_model: + + # Rigid contact models use an impact model that produces a discontinuous model velocity. + # Hence here we need to reset the velocity after each impact to guarantee that + # the linear velocity of the active collidable points is zero. + case RigidContacts(): + # Raise runtime error for not supported case in which Rigid contacts and Baumgarte stabilization + # enabled are used with ForwardEuler integrator. + jaxsim.exceptions.raise_runtime_error_if( + condition=jnp.logical_and( + isinstance( + integrator, + jaxsim.integrators.fixed_step.ForwardEuler + | jaxsim.integrators.fixed_step.ForwardEulerSO3, + ), + jnp.array( + [data_tf.contacts_params.K, data_tf.contacts_params.D] + ).any(), + ), + msg="Baumgarte stabilization is not supported with ForwardEuler integrators", + ) + + with data_tf.switch_velocity_representation(VelRepr.Mixed): + W_p_C = js.contact.collidable_point_positions(model, data_tf) + M = js.model.free_floating_mass_matrix(model, data_tf) + J_WC = js.contact.jacobian(model, data_tf) + px, py, _ = W_p_C.T + terrain_height = jax.vmap(model.terrain.height)(px, py) + inactive_collidable_points, _ = RigidContacts.detect_contacts( + W_p_C=W_p_C, + terrain_height=terrain_height, + ) + BW_nu_post_impact = RigidContacts.compute_impact_velocity( + data=data_tf, + inactive_collidable_points=inactive_collidable_points, + M=M, + J_WC=J_WC, + ) + data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6]) + data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:]) + # Restore the input velocity representation. + data_tf = data_tf.replace( + velocity_representation=data.velocity_representation, validate=False + ) + + return ( + data_tf, + integrator_state_tf, ) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index bd6636395..14e362490 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -50,7 +50,7 @@ def wrap_system_dynamics_for_integration( # The wrapped dynamics will hold a reference of this object. model_closed = model.copy() data_closed = data.copy().replace( - state=js.ode_data.ODEState.zero(model=model_closed) + state=js.ode_data.ODEState.zero(model=model_closed, data=data) ) def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]: @@ -88,7 +88,7 @@ def system_velocity_dynamics( *, joint_forces: jtp.Vector | None = None, link_forces: jtp.Vector | None = None, -) -> tuple[jtp.Vector, jtp.Vector, jtp.Matrix, dict[str, Any]]: +) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]: """ Compute the dynamics of the system velocity. @@ -102,18 +102,10 @@ def system_velocity_dynamics( Returns: A tuple containing the derivative of the base 6D velocity in inertial-fixed - representation, the derivative of the joint velocities, the derivative of - the material deformation, and the dictionary of auxiliary data returned by - the system dynamics evaluation. + representation, the derivative of the joint velocities, and auxiliary data + returned by the system dynamics evaluation. """ - # Build joint torques if not provided. - τ = ( - jnp.atleast_1d(joint_forces.squeeze()) - if joint_forces is not None - else jnp.zeros_like(data.joint_positions()) - ).astype(float) - # Build link forces if not provided. # These forces are expressed in the frame corresponding to the velocity # representation of data. @@ -123,6 +115,15 @@ def system_velocity_dynamics( else jnp.zeros((model.number_of_links(), 6)) ).astype(float) + # We expect that the 6D forces included in the `link_forces` argument are expressed + # in the frame corresponding to the velocity representation of `data`. + references = js.references.JaxSimModelReferences.build( + model=model, + link_forces=O_f_L, + data=data, + velocity_representation=data.velocity_representation, + ) + # ====================== # Compute contact forces # ====================== @@ -131,19 +132,17 @@ def system_velocity_dynamics( # with the terrain. W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float) - # Import privately the soft contacts classes. - from jaxsim.rbda.contacts.soft import SoftContactsState - - # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}. - assert isinstance(data.state.contact, SoftContactsState) - ṁ = jnp.zeros_like(data.state.contact.tangential_deformation).astype(float) - + aux_data = {} if len(model.kin_dyn_parameters.contact_parameters.body) > 0: # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point - # and the corresponding material deformation rates. + # along with contact-specific auxiliary states. with data.switch_velocity_representation(VelRepr.Inertial): - W_f_Ci, ṁ = js.contact.collidable_point_dynamics(model=model, data=data) + W_f_Ci, aux_data = js.contact.collidable_point_dynamics( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + ) # Construct the vector defining the parent link index of each collidable point. # We use this vector to sum the 6D forces of all collidable points rigidly @@ -161,6 +160,74 @@ def system_velocity_dynamics( W_f_Li_terrain = mask.T @ W_f_Ci + # =========================== + # Compute system acceleration + # =========================== + + # Compute the total link forces + with ( + data.switch_velocity_representation(VelRepr.Inertial), + references.switch_velocity_representation(VelRepr.Inertial), + ): + references = references.apply_link_forces( + model=model, + data=data, + forces=W_f_Li_terrain, + additive=True, + ) + # Get the link forces in the data representation + with references.switch_velocity_representation(data.velocity_representation): + f_L_total = references.link_forces(model=model, data=data) + + # The following method always returns the inertial-fixed acceleration, and expects + # the link_forces expressed in the inertial frame. + W_v̇_WB, s̈ = system_acceleration( + model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total + ) + + return W_v̇_WB, s̈, aux_data + + +def system_acceleration( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + joint_forces: jtp.VectorLike | None = None, + link_forces: jtp.MatrixLike | None = None, +) -> tuple[jtp.Vector, jtp.Vector]: + """ + Compute the system acceleration in inertial-fixed representation. + + Args: + model: The model to consider. + data: The data of the considered model. + joint_forces: The joint forces to apply. + link_forces: + The 6D forces to apply to the links expressed in the same representation of data. + + Returns: + A tuple containing the base 6D acceleration in inertial-fixed representation + and the joint accelerations. + """ + + # ==================== + # Validate input data + # ==================== + + # Build link forces if not provided. + f_L = ( + jnp.atleast_2d(link_forces.squeeze()) + if link_forces is not None + else jnp.zeros((model.number_of_links(), 6)) + ).astype(float) + + # Build joint torques if not provided. + τ = ( + jnp.atleast_1d(joint_forces.squeeze()) + if joint_forces is not None + else jnp.zeros_like(data.joint_positions()) + ).astype(float) + # ==================== # Enforce joint limits # ==================== @@ -198,29 +265,25 @@ def system_velocity_dynamics( references = js.references.JaxSimModelReferences.build( model=model, - joint_force_references=τ_total, - link_forces=O_f_L, data=data, velocity_representation=data.velocity_representation, + joint_force_references=τ_total, + link_forces=f_L, ) - with references.switch_velocity_representation(VelRepr.Inertial): - W_f_L = references.link_forces(model=model, data=data) - - # Compute the total external 6D forces applied to the links. - W_f_L_total = W_f_L + W_f_Li_terrain - # - Joint accelerations: s̈ ∈ ℝⁿ # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶ - with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): + with ( + data.switch_velocity_representation(velocity_representation=VelRepr.Inertial), + references.switch_velocity_representation(VelRepr.Inertial), + ): W_v̇_WB, s̈ = js.model.forward_dynamics_aba( model=model, data=data, - joint_forces=τ_total, - link_forces=W_f_L_total, + joint_forces=references.joint_force_references(), + link_forces=references.link_forces(), ) - - return W_v̇_WB, s̈, ṁ, dict() + return W_v̇_WB, s̈ @jax.jit @@ -291,14 +354,29 @@ def system_dynamics( by the system dynamics evaluation. """ + from jaxsim.rbda.contacts.rigid import RigidContacts + from jaxsim.rbda.contacts.soft import SoftContacts + # Compute the accelerations and the material deformation rate. - W_v̇_WB, s̈, ṁ, aux_dict = system_velocity_dynamics( + W_v̇_WB, s̈, aux_dict = system_velocity_dynamics( model=model, data=data, joint_forces=joint_forces, link_forces=link_forces, ) + ode_state_kwargs = {} + + match model.contact_model: + case SoftContacts(): + ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"] + + case RigidContacts(): + pass + + case _: + raise ValueError("Unable to determine contact state class prefix.") + # Extract the velocities. W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics( model=model, @@ -317,7 +395,7 @@ def system_dynamics( base_linear_velocity=W_v̇_WB[0:3], base_angular_velocity=W_v̇_WB[3:6], joint_velocities=s̈, - tangential_deformation=ṁ, + **ode_state_kwargs, ) return ode_state_derivative, aux_dict diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index 766f7f926..8b51a518d 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -6,6 +6,7 @@ import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.rbda import ContactsState +from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState from jaxsim.utils import JaxsimDataclass @@ -133,7 +134,7 @@ def build_from_jaxsim_model( base_quaternion: jtp.Vector | None = None, base_linear_velocity: jtp.Vector | None = None, base_angular_velocity: jtp.Vector | None = None, - tangential_deformation: jtp.Matrix | None = None, + **kwargs, ) -> ODEState: """ Build an `ODEState` from a `JaxSimModel`. @@ -148,9 +149,7 @@ def build_from_jaxsim_model( The linear velocity of the base link in inertial-fixed representation. base_angular_velocity: The angular velocity of the base link in inertial-fixed representation. - tangential_deformation: - The matrix of 3D tangential material deformations corresponding to - each collidable point. + kwargs: Additional arguments needed to build the contact state. Returns: The `ODEState` built from the `JaxSimModel`. @@ -163,6 +162,7 @@ def build_from_jaxsim_model( # Get the contact model from the `JaxSimModel`. match model.contact_model: case SoftContacts(): + tangential_deformation = kwargs.get("tangential_deformation", None) contact = SoftContactsState.build_from_jaxsim_model( model=model, **( @@ -171,6 +171,8 @@ def build_from_jaxsim_model( else dict() ), ) + case RigidContacts(): + contact = RigidContactsState.build() case _: raise ValueError("Unable to determine contact state class prefix.") @@ -214,7 +216,7 @@ def build( # Get the contact model from the `JaxSimModel`. match contact: - case SoftContactsState(): + case SoftContactsState() | RigidContactsState(): pass case None: contact = SoftContactsState.zero(model=model) @@ -224,7 +226,7 @@ def build( return ODEState(physics_model=physics_model_state, contact=contact) @staticmethod - def zero(model: js.model.JaxSimModel) -> ODEState: + def zero(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> ODEState: """ Build a zero `ODEState` from a `JaxSimModel`. @@ -235,7 +237,9 @@ def zero(model: js.model.JaxSimModel) -> ODEState: A zero `ODEState` instance. """ - model_state = ODEState.build(model=model) + model_state = ODEState.build( + model=model, contact=data.state.contact.zero(model=model) + ) return model_state diff --git a/src/jaxsim/integrators/common.py b/src/jaxsim/integrators/common.py index 65dcec5ee..ed7cead78 100644 --- a/src/jaxsim/integrators/common.py +++ b/src/jaxsim/integrators/common.py @@ -109,11 +109,14 @@ def step( integrator.params = params with integrator.mutable_context(mutability=Mutability.MUTABLE): - xf = integrator(x0, t0, dt, **kwargs) + xf, aux_dict = integrator(x0, t0, dt, **kwargs) - return xf, integrator.params | { - Integrator.AfterInitKey: jnp.array(False).astype(bool) - } + return ( + xf, + integrator.params + | {Integrator.AfterInitKey: jnp.array(False).astype(bool)} + | aux_dict, + ) @abc.abstractmethod def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState: @@ -277,15 +280,19 @@ def build( return integrator - def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState: + def __call__( + self, x0: State, t0: Time, dt: TimeStep, **kwargs + ) -> tuple[NextState, dict[str, Any]]: # Here z is a batched state with as many batch elements as b.T rows. # Note that z has multiple batches only if b.T has more than one row, # e.g. in Butcher tableau of embedded schemes. - z = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs) + z, aux_dict = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs) # The next state is the batch element located at the configured index of solution. - return jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z) + next_state = jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z) + + return next_state, aux_dict @classmethod def integrate_rk_stage( @@ -343,7 +350,7 @@ def post_process_state( def _compute_next_state( self, x0: State, t0: Time, dt: TimeStep, **kwargs - ) -> NextState: + ) -> tuple[NextState, dict[str, Any]]: """ Compute the next state of the system, returning all the output states. @@ -373,19 +380,21 @@ def _compute_next_state( ) # Apply FSAL property by passing ẋ0 = f(x0, t0) from the previous iteration. - get_ẋ0 = lambda: self.params.get("dxdt0", f(x0, t0)[0]) + get_ẋ0_and_aux_dict = lambda: self.params.get("dxdt0", f(x0, t0)) # We use a `jax.lax.scan` to compile the `f` function only once. # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code # would include 4 repetitions of the `f` logic, making everything extremely slow. - def scan_body(carry: jax.Array, i: int | jax.Array) -> tuple[jax.Array, None]: + def scan_body( + carry: jax.Array, i: int | jax.Array + ) -> tuple[jax.Array, dict[str, Any]]: """""" # Unpack the carry, i.e. the stacked kᵢ vectors. K = carry # Define the computation of the Runge-Kutta stage. - def compute_ki() -> jax.Array: + def compute_ki() -> tuple[jax.Array, dict[str, Any]]: # Compute ∑ⱼ aᵢⱼ kⱼ. op_sum_ak = lambda k: jnp.einsum("s,s...->...", A[i], k) @@ -398,13 +407,13 @@ def compute_ki() -> jax.Array: # Compute the next time for the kᵢ evaluation. ti = t0 + c[i] * Δt - # This is kᵢ = f(xᵢ, tᵢ). - return f(xi, ti)[0] + # This is kᵢ, aux_dict = f(xᵢ, tᵢ). + return f(xi, ti) # This selector enables FSAL property in the first iteration (i=0). - ki = jax.lax.cond( + ki, aux_dict = jax.lax.cond( pred=jnp.logical_and(i == 0, self.has_fsal), - true_fun=get_ẋ0, + true_fun=get_ẋ0_and_aux_dict, false_fun=compute_ki, ) @@ -413,10 +422,10 @@ def compute_ki() -> jax.Array: K = jax.tree_util.tree_map(op, K, ki) carry = K - return carry, None + return carry, aux_dict # Compute the state derivatives kᵢ. - K, _ = jax.lax.scan( + K, aux_dict = jax.lax.scan( f=scan_body, init=carry0, xs=jnp.arange(c.size), @@ -439,7 +448,7 @@ def compute_ki() -> jax.Array: lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt) )(z) - return z_transformed + return z_transformed, aux_dict @staticmethod def butcher_tableau_is_valid( diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 0928c96e1..9823b114d 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -5,6 +5,7 @@ import jaxsim.terrain import jaxsim.typing as jtp +from jaxsim.utils import JaxsimDataclass class ContactsState(abc.ABC): @@ -42,7 +43,7 @@ def valid(self, **kwargs) -> bool: pass -class ContactsParams(abc.ABC): +class ContactsParams(JaxsimDataclass): """ Abstract class representing the parameters of a contact model. """ @@ -67,7 +68,7 @@ def valid(self, *args, **kwargs) -> bool: pass -class ContactModel(abc.ABC): +class ContactModel(JaxsimDataclass): """ Abstract class representing a contact model. diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py new file mode 100644 index 000000000..6c0e3c1bb --- /dev/null +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -0,0 +1,478 @@ +from __future__ import annotations + +import dataclasses +from typing import Any + +import jax +import jax.numpy as jnp +import jax_dataclasses + +import jaxsim.api as js +import jaxsim.typing as jtp +from jaxsim import math +from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr +from jaxsim.terrain import FlatTerrain, Terrain + +from .common import ContactModel, ContactsParams, ContactsState + + +@jax_dataclasses.pytree_dataclass +class RigidContactsParams(ContactsParams): + """Parameters of the rigid contacts model.""" + + # Static friction coefficient + mu: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + # Baumgarte proportional term + K: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.0, dtype=float) + ) + + # Baumgarte derivative term + D: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.0, dtype=float) + ) + + def __hash__(self) -> int: + from jaxsim.utils.wrappers import HashedNumpyArray + + return hash( + ( + HashedNumpyArray.hash_of_array(self.mu), + HashedNumpyArray.hash_of_array(self.K), + HashedNumpyArray.hash_of_array(self.D), + ) + ) + + def __eq__(self, other: RigidContactsParams) -> bool: + return hash(self) == hash(other) + + @classmethod + def build( + cls, + mu: jtp.FloatLike | None = None, + K: jtp.FloatLike | None = None, + D: jtp.FloatLike | None = None, + ) -> RigidContactsParams: + """Create a `RigidContactParams` instance""" + return RigidContactsParams( + mu=mu or cls.__dataclass_fields__["mu"].default, + K=K or cls.__dataclass_fields__["K"].default, + D=D or cls.__dataclass_fields__["D"].default, + ) + + def valid(self) -> bool: + return bool( + jnp.all(self.mu >= 0.0) + and jnp.all(self.K >= 0.0) + and jnp.all(self.D >= 0.0) + ) + + +@jax_dataclasses.pytree_dataclass +class RigidContactsState(ContactsState): + """Class storing the state of the rigid contacts model.""" + + def __eq__(self, other: RigidContactsState) -> bool: + return hash(self) == hash(other) + + @staticmethod + def build(**kwargs) -> RigidContactsState: + """Create a `RigidContactsState` instance""" + + return RigidContactsState() + + @staticmethod + def zero(**kwargs) -> RigidContactsState: + """Build a zero `RigidContactsState` instance from a `JaxSimModel`.""" + return RigidContactsState.build() + + def valid(self, **kwargs) -> bool: + return True + + +@jax_dataclasses.pytree_dataclass +class RigidContacts(ContactModel): + """Rigid contacts model.""" + + parameters: RigidContactsParams = dataclasses.field( + default_factory=RigidContactsParams + ) + + terrain: jax_dataclasses.Static[Terrain] = dataclasses.field( + default_factory=FlatTerrain + ) + + @staticmethod + def detect_contacts( + W_p_C: jtp.ArrayLike, + terrain_height: jtp.ArrayLike, + ) -> tuple[jtp.Vector, jtp.Vector]: + """ + Detect contacts between the collidable points and the terrain. + + Args: + W_p_C: The position of the collidable points. + terrain_height: The height of the terrain at the collidable point position. + + Returns: + A tuple containing the activation state of the collidable points and the contact penetration depth h. + """ + + # TODO: reduce code duplication with js.contact.in_contact + def detect_contact( + W_p_C: jtp.ArrayLike, + terrain_height: jtp.FloatLike, + ) -> tuple[jtp.Bool, jtp.Float]: + """ + Detect contacts between the collidable points and the terrain. + """ + + # Unpack the position of the collidable point. + _, _, pz = W_p_C.squeeze() + + inactive = pz > terrain_height + + # Compute contact penetration depth + h = jnp.maximum(0.0, terrain_height - pz) + + return inactive, h + + inactive_collidable_points, h = jax.vmap(detect_contact)(W_p_C, terrain_height) + + return inactive_collidable_points, h + + @staticmethod + def compute_impact_velocity( + inactive_collidable_points: jtp.ArrayLike, + M: jtp.MatrixLike, + J_WC: jtp.MatrixLike, + data: js.data.JaxSimModelData, + ) -> jtp.Vector: + """Returns the new velocity of the system after a potential impact. + + Args: + inactive_collidable_points: The activation state of the collidable points. + M: The mass matrix of the system. + J_WC: The Jacobian matrix of the collidable points. + data: The `JaxSimModelData` instance. + """ + + def impact_velocity( + inactive_collidable_points: jtp.ArrayLike, + nu_pre: jtp.ArrayLike, + M: jtp.MatrixLike, + J_WC: jtp.MatrixLike, + data: js.data.JaxSimModelData, + ): + # Compute system velocity after impact maintaining zero linear velocity of active points + with data.switch_velocity_representation(VelRepr.Mixed): + sl = jnp.s_[:, 0:3, :] + Jl_WC = J_WC[sl] + # Zero out the jacobian rows of inactive points + Jl_WC = jnp.vstack( + jnp.where( + inactive_collidable_points[:, jnp.newaxis, jnp.newaxis], + jnp.zeros_like(Jl_WC), + Jl_WC, + ) + ) + + A = jnp.vstack( + [ + jnp.hstack([M, -Jl_WC.T]), + jnp.hstack( + [Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))] + ), + ] + ) + b = jnp.hstack([M @ nu_pre, jnp.zeros(Jl_WC.shape[0])]) + x = jnp.linalg.lstsq(A, b)[0] + nu_post = x[0 : M.shape[0]] + + return nu_post + + with data.switch_velocity_representation(VelRepr.Mixed): + BW_ν_pre_impact = data.generalized_velocity() + + BW_ν_post_impact = impact_velocity( + data=data, + inactive_collidable_points=inactive_collidable_points, + nu_pre=BW_ν_pre_impact, + M=M, + J_WC=J_WC, + ) + + return BW_ν_post_impact + + def compute_contact_forces( + self, + position: jtp.Vector, + velocity: jtp.Vector, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + link_forces: jtp.MatrixLike | None = None, + regularization_term: jtp.FloatLike = 1e-6, + ) -> tuple[jtp.Vector, tuple[Any, ...]]: + """ + Compute the contact forces. + + Args: + position: The position of the collidable point. + velocity: The linear velocity of the collidable point. + model: The `JaxSimModel` instance. + data: The `JaxSimModelData` instance. + link_forces: + Optional `(n_links, 6)` matrix of external forces acting on the links, + expressed in the same representation of data. + regularization_term: + The regularization term to add to the diagonal of the Delassus + matrix for better numerical conditioning. + + Returns: + A tuple containing the contact forces. + """ + + # Import qpax just in this method + import qpax + + link_forces = ( + link_forces + if link_forces is not None + else jnp.zeros((model.number_of_links(), 6)) + ) + + # Compute kin-dyn quantities used in the contact model + with data.switch_velocity_representation(VelRepr.Mixed): + M = js.model.free_floating_mass_matrix(model=model, data=data) + J_WC = js.contact.jacobian(model=model, data=data) + W_H_C = js.contact.transforms(model=model, data=data) + terrain_height = jax.vmap(self.terrain.height)(position[:, 0], position[:, 1]) + n_collidable_points = model.kin_dyn_parameters.contact_parameters.point.shape[0] + + # Compute the activation state of the collidable points + inactive_collidable_points, h = RigidContacts.detect_contacts( + W_p_C=position, + terrain_height=terrain_height, + ) + + delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC) + + # Add regularization for better numerical conditioning + delassus_matrix = delassus_matrix + regularization_term * jnp.eye( + delassus_matrix.shape[0] + ) + + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + velocity_representation=data.velocity_representation, + link_forces=link_forces, + ) + + with references.switch_velocity_representation(VelRepr.Mixed): + BW_ν̇_free = RigidContacts._compute_mixed_nu_dot_free( + model, data, references=references + ) + + free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points( + model, + data, + BW_ν̇_free, + ).flatten() + + # Compute stabilization term + ḣ = velocity[:, 2].squeeze() + baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term( + inactive_collidable_points=inactive_collidable_points, + h=h, + ḣ=ḣ, + K=self.parameters.K, + D=self.parameters.D, + ).flatten() + + free_contact_acc -= baumgarte_term + + # Setup optimization problem + Q = delassus_matrix + q = free_contact_acc + G = RigidContacts._compute_ineq_constraint_matrix( + inactive_collidable_points=inactive_collidable_points, mu=self.parameters.mu + ) + h_bounds = RigidContacts._compute_ineq_bounds( + n_collidable_points=n_collidable_points + ) + A = jnp.zeros((0, 3 * n_collidable_points)) + b = jnp.zeros((0,)) + + # Solve the optimization problem + solution, *_ = qpax.solve_qp(Q=Q, q=q, A=A, b=b, G=G, h=h_bounds) + + f_C_lin = solution.reshape(-1, 3) + + # Transform linear contact forces to 6D + CW_f_C = jnp.hstack( + ( + f_C_lin, + jnp.zeros((f_C_lin.shape[0], 3)), + ) + ) + + # Transform the contact forces to inertial-fixed representation + W_f_C = jax.vmap( + lambda CW_f_C, W_H_C: ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=CW_f_C, + transform=W_H_C, + other_representation=VelRepr.Mixed, + is_force=True, + ), + )( + CW_f_C, + W_H_C, + ) + + return W_f_C, () + + @staticmethod + def _delassus_matrix( + M: jtp.MatrixLike, + J_WC: jtp.MatrixLike, + ) -> jtp.Matrix: + sl = jnp.s_[:, 0:3, :] + J_WC_lin = jnp.vstack(J_WC[sl]) + + delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T + return delassus_matrix + + @staticmethod + def _compute_ineq_constraint_matrix( + inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike + ) -> jtp.Matrix: + def compute_G_single_point(mu: float, c: float) -> jtp.Matrix: + """ + Compute the inequality constraint matrix for a single collidable point + Rows 0-3: enforce the friction pyramid constraint, + Row 4: last one is for the non negativity of the vertical force + Row 5: contact complementarity condition + """ + G_single_point = jnp.array( + [ + [1, 0, -mu], + [0, 1, -mu], + [-1, 0, -mu], + [0, -1, -mu], + [0, 0, -1], + [0, 0, c], + ] + ) + return G_single_point + + G = jax.vmap(compute_G_single_point, in_axes=(None, 0))( + mu, inactive_collidable_points + ) + G = jax.scipy.linalg.block_diag(*G) + return G + + @staticmethod + def _compute_ineq_bounds(n_collidable_points: jtp.FloatLike) -> jtp.Vector: + n_constraints = 6 * n_collidable_points + return jnp.zeros(shape=(n_constraints,)) + + @staticmethod + def _compute_mixed_nu_dot_free( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + references: js.references.JaxSimModelReferences | None = None, + ) -> jtp.Array: + references = ( + references + if references is not None + else js.references.JaxSimModelReferences.zero(model=model, data=data) + ) + + with ( + data.switch_velocity_representation(VelRepr.Mixed), + references.switch_velocity_representation(VelRepr.Mixed), + ): + BW_v_WB = data.base_velocity() + W_ṗ_B, W_ω_WB = jnp.split(BW_v_WB, 2) + W_v̇_WB, s̈ = js.ode.system_acceleration( + model=model, + data=data, + joint_forces=references.joint_force_references(model=model), + link_forces=references.link_forces(model=model, data=data), + ) + + # Convert the inertial-fixed base acceleration to a mixed base acceleration. + W_H_B = data.base_transform() + W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) + BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True) + term1 = BW_X_W @ W_v̇_WB + term2 = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB)) + BW_v̇_WB = term1 - term2 + + BW_ν̇ = jnp.hstack([BW_v̇_WB, s̈]) + + return BW_ν̇ + + @staticmethod + def _linear_acceleration_of_collidable_points( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + mixed_nu_dot: jtp.ArrayLike, + ) -> jtp.Matrix: + with data.switch_velocity_representation(VelRepr.Mixed): + CW_J_WC_BW = js.contact.jacobian( + model=model, + data=data, + output_vel_repr=VelRepr.Mixed, + ) + CW_J̇_WC_BW = js.contact.jacobian_derivative( + model=model, + data=data, + output_vel_repr=VelRepr.Mixed, + ) + + BW_ν = data.generalized_velocity() + BW_ν̇ = mixed_nu_dot + + CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇ + CW_a_WC = CW_a_WC.reshape(-1, 6) + + return CW_a_WC[:, 0:3].squeeze() + + @staticmethod + def _compute_baumgarte_stabilization_term( + inactive_collidable_points: jtp.ArrayLike, + h: jtp.ArrayLike, + ḣ: jtp.ArrayLike, + K: jtp.FloatLike, + D: jtp.FloatLike, + ) -> jtp.Array: + def baumgarte_stabilization( + inactive: jtp.BoolLike, + h: jtp.FloatLike, + ḣ: jtp.FloatLike, + k_baumgarte: jtp.FloatLike, + d_baumgarte: jtp.FloatLike, + ) -> jtp.Array: + baumgarte_term = jax.lax.cond( + inactive, + lambda h, ḣ, K, D: jnp.zeros(shape=(3,)), + lambda h, ḣ, K, D: jnp.zeros(shape=(3,)).at[2].set(K * h + D * ḣ), + *( + h, + ḣ, + k_baumgarte, + d_baumgarte, + ), + ) + return baumgarte_term + + baumgarte_term = jax.vmap( + baumgarte_stabilization, in_axes=(0, 0, 0, None, None) + )(inactive_collidable_points, h, ḣ, K, D) + + return baumgarte_term