Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add alternative rigid contact model #227

Merged
merged 50 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
1b3a59c
Add dummy `jaxsim.rbda.contacts.rigid` module
xela-95 Sep 10, 2024
a1c9128
[WIP] Add rigid contact model implementation
xela-95 Sep 10, 2024
9743166
Update `ODEState` to handle rigid contacts
xela-95 Sep 10, 2024
336d429
Pass new argument to `ODEState.zero` from `wrap_system_dynamics_for_i…
xela-95 Sep 10, 2024
71145b4
Handle rigid contacts in `system_velocity_dynamics` and `system_dynam…
xela-95 Sep 10, 2024
5fa00f2
Update `jaxsim.api.contact.collidable_point_dynamics`
xela-95 Sep 10, 2024
397116e
Update `collidable_point_dynamics` return type
xela-95 Sep 10, 2024
d950488
Make `ContactsParams` and `ContactModel` inherit from `JaxsimDataclass`
xela-95 Sep 10, 2024
b5503bc
Fixes in rigid contact model
xela-95 Sep 10, 2024
54b5a6a
Decouple functions computing system acceleration and contact forces
xela-95 Sep 10, 2024
38b5d42
Update `aux_dict` returned by `collidable_point_dynamics`
xela-95 Sep 10, 2024
125c546
Add `qpax` dependency to environment.yml
xela-95 Sep 10, 2024
e76e39e
Update installation of `qpax` through PyPI
xela-95 Sep 10, 2024
48efd53
Move `qpax` as private import
xela-95 Sep 10, 2024
d4fa9e7
Move `inactive_collidable_points_prev` to RigidContactState
xela-95 Sep 10, 2024
2b42f97
Generalize handling of contact states
xela-95 Sep 10, 2024
56e3dce
[WIP] Restore inactive collidable points status array as `RigidContac…
xela-95 Sep 10, 2024
97c9b14
Remove impact detection and activation state at previous iteration
xela-95 Sep 10, 2024
b8e46c0
Improve computation of post impact system velocity
xela-95 Sep 10, 2024
67e3bf8
[WIP] Update system velocity in `jazsim.api.model.step`
xela-95 Sep 10, 2024
44db326
Compute impact velocity in `jaxsmi.api.model.step`
xela-95 Sep 10, 2024
8090a6c
Consider external link forces in `rigid.py`
xela-95 Sep 10, 2024
671e4f4
Manage external link forces in `contact.collidable_point_dynamics`
xela-95 Sep 10, 2024
f9c003b
Manage external link forces in `ode.system_velocity_dynamics`
xela-95 Sep 10, 2024
9efe5f0
Remove reset of system velocity in `contact.collidable_point_dynamics`
xela-95 Sep 10, 2024
44e42a2
Remove reset of system velocity in `ode.system_velocity_dynamics`
xela-95 Sep 10, 2024
3b92a57
Apply suggestions from code review
xela-95 Sep 10, 2024
027d8ca
Update `RigidContacts` api and add docstrings for public methods
xela-95 Sep 10, 2024
a556a05
Apply suggestions from code review
xela-95 Sep 10, 2024
66dd571
Apply suggestions from review
xela-95 Sep 10, 2024
eb41af0
Apply suggestions from code review
xela-95 Sep 10, 2024
19778b7
Fix link forces in `ode.system_dynamics`
xela-95 Sep 10, 2024
064745e
Take default arguments of `RigidContactParams` from dataclass fields
xela-95 Sep 10, 2024
1cf550b
Reorder public/private methods in `RigidContacts`
xela-95 Sep 10, 2024
dc16c67
Apply suggestion from code review
xela-95 Sep 10, 2024
e560547
Update `contact.py`
xela-95 Sep 10, 2024
2c2b132
Update `model.py`
xela-95 Sep 10, 2024
688a401
Update `common.py`
xela-95 Sep 10, 2024
99abbd4
Update `rigid.py`
xela-95 Sep 10, 2024
80275f1
Update `contact.py`
xela-95 Sep 10, 2024
1cb350d
Update `rigid.compute_contact_forces`
xela-95 Sep 10, 2024
cfd94a9
Update `ode.py`
xela-95 Sep 10, 2024
8e67724
Update `model.py`
xela-95 Sep 10, 2024
2f041a4
Update Baumgarte stabilization term
xela-95 Sep 10, 2024
c9fdae0
Update `RigidContactParams` class name to `RigidContactsParams`
xela-95 Sep 11, 2024
4a46d53
Raise error in `model.step` for not supported case of FE with rigid c…
xela-95 Sep 11, 2024
40397f7
Update `ode.py`
xela-95 Sep 11, 2024
181b61d
Update `model.py`
xela-95 Sep 11, 2024
c7a286a
Update `contact.py`
xela-95 Sep 11, 2024
0d3ea67
Update `rigid.py`
xela-95 Sep 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies:
- jaxlie >= 1.3.0
- jax-dataclasses >= 1.4.0
- pptree
- qpax
- rod >= 0.3.0
- typing_extensions # python<3.12
# ====================================
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dependencies = [
"jaxlie >= 1.3.0",
"jax_dataclasses >= 1.4.0",
"pptree",
"qpax",
"rod >= 0.3.0",
"typing_extensions ; python_version < '3.12'",
]
Expand Down Expand Up @@ -187,6 +188,7 @@ mediapy = "*"
mujoco = "*"
notebook = "*"
pptree = "*"
qpax = "*"
rod = "*"
sdformat14 = "*"
typing_extensions = "*"
Expand Down
44 changes: 36 additions & 8 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,24 @@ def collidable_point_forces(

@jax.jit
def collidable_point_dynamics(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> tuple[jtp.Matrix, jtp.Matrix]:
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
link_forces: jtp.MatrixLike | None = None,
) -> tuple[jtp.Matrix, dict[str, jtp.Array]]:
r"""
Compute the 6D force applied to each collidable point and the corresponding
material deformation rate.
Compute the 6D force applied to each collidable point.

Args:
model: The model to consider.
data: The data of the considered model.
link_forces:
The 6D external forces to apply to the links expressed in the same
representation of data.

Returns:
The 6D force applied to each collidable point and the corresponding
material deformation rate.
The 6D force applied to each collidable point and additional data based on the contact model configured:
- Soft: the material deformation rate.
- Rigid: nothing.

Note:
The material deformation rate is always returned in the mixed frame
Expand All @@ -138,7 +143,8 @@ def collidable_point_dynamics(
# all collidable points belonging to the robot.
W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)

# Import privately the soft contacts classes.
# Import privately the contacts classes.
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState

# Build the soft contact model.
Expand All @@ -161,6 +167,28 @@ def collidable_point_dynamics(
W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
)
aux_data = dict(m_dot=CW_ṁ)

case RigidContacts():
assert isinstance(model.contact_model, RigidContacts)
assert isinstance(data.state.contact, RigidContactsState)

# Build the contact model.
rigid_contacts = RigidContacts(
parameters=data.contacts_params, terrain=model.terrain
)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
W_f_Ci, _ = rigid_contacts.compute_contact_forces(
position=W_p_Ci,
velocity=W_ṗ_Ci,
model=model,
data=data,
link_forces=link_forces,
)

aux_data = dict()

case _:
raise ValueError(f"Invalid contact model {model.contact_model}")
Expand All @@ -175,7 +203,7 @@ def collidable_point_dynamics(
)
)(W_f_Ci)

return f_Ci, CW_ṁ
return f_Ci, aux_data


@functools.partial(jax.jit, static_argnames=["link_names"])
Expand Down
67 changes: 60 additions & 7 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from jax_dataclasses import Static

import jaxsim.api as js
import jaxsim.exceptions
import jaxsim.terrain
import jaxsim.typing as jtp
from jaxsim.math import Adjoint, Cross
Expand Down Expand Up @@ -1890,6 +1891,8 @@ def step(
and the new state of the integrator.
"""

from jaxsim.rbda.contacts.rigid import RigidContacts

# Extract the integrator kwargs.
# The following logic allows using integrators having kwargs colliding with the
# kwargs of this step function.
Expand All @@ -1901,12 +1904,12 @@ def step(

# Extract the initial resources.
t0_ns = data.time_ns
state_x0 = data.state
state_t0 = data.state
integrator_state_x0 = integrator_state

# Step the dynamics forward.
state_xf, integrator_state_xf = integrator.step(
x0=state_x0,
state_tf, integrator_state_tf = integrator.step(
x0=state_t0,
t0=jnp.array(t0_ns / 1e9).astype(float),
dt=dt,
params=integrator_state_x0,
Expand All @@ -1928,11 +1931,61 @@ def step(
),
)

return (
data_tf = (
# Store the new state of the model and the new time.
data.replace(
state=state_xf,
state=state_tf,
time_ns=t0_ns + jnp.array(dt * 1e9).astype(jnp.uint64),
),
integrator_state_xf,
)
)

# Post process the simulation state, if needed.
match model.contact_model:

# Rigid contact models use an impact model that produces a discontinuous model velocity.
# Hence here we need to reset the velocity after each impact to guarantee that
# the linear velocity of the active collidable points is zero.
case RigidContacts():
# Raise runtime error for not supported case in which Rigid contacts and Baumgarte stabilization
# enabled are used with ForwardEuler integrator.
jaxsim.exceptions.raise_runtime_error_if(
condition=jnp.logical_and(
isinstance(
integrator,
jaxsim.integrators.fixed_step.ForwardEuler
| jaxsim.integrators.fixed_step.ForwardEulerSO3,
),
jnp.array(
[data_tf.contacts_params.K, data_tf.contacts_params.D]
).any(),
),
msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
)

with data_tf.switch_velocity_representation(VelRepr.Mixed):
W_p_C = js.contact.collidable_point_positions(model, data_tf)
M = js.model.free_floating_mass_matrix(model, data_tf)
J_WC = js.contact.jacobian(model, data_tf)
px, py, _ = W_p_C.T
terrain_height = jax.vmap(model.terrain.height)(px, py)
inactive_collidable_points, _ = RigidContacts.detect_contacts(
W_p_C=W_p_C,
terrain_height=terrain_height,
)
BW_nu_post_impact = RigidContacts.compute_impact_velocity(
data=data_tf,
inactive_collidable_points=inactive_collidable_points,
M=M,
J_WC=J_WC,
)
data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
# Restore the input velocity representation.
data_tf = data_tf.replace(
velocity_representation=data.velocity_representation, validate=False
)

return (
data_tf,
integrator_state_tf,
)
Loading
Loading