Skip to content

Commit

Permalink
Merge pull request #256 from ami-iit/remove_contact_state
Browse files Browse the repository at this point in the history
Remove `*ContactsState` classes
  • Loading branch information
diegoferigo authored Oct 7, 2024
2 parents 1bbf438 + 5d775de commit 4d85117
Show file tree
Hide file tree
Showing 12 changed files with 317 additions and 372 deletions.
15 changes: 3 additions & 12 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def collidable_point_dynamics(
The joint force references to apply to the joints.
Returns:
The 6D force applied to each collidable point and additional data based on the contact model configured:
The 6D force applied to each collidable point and additional data based
on the contact model configured:
- Soft: the material deformation rate.
- Rigid: no additional data.
- QuasiRigid: no additional data.
Expand All @@ -156,21 +157,13 @@ def collidable_point_dynamics(
"""

# Import privately the contacts classes.
from jaxsim.rbda.contacts import (
RelaxedRigidContacts,
RelaxedRigidContactsState,
RigidContacts,
RigidContactsState,
SoftContacts,
SoftContactsState,
)
from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts

# Build the soft contact model.
match model.contact_model:

case SoftContacts():
assert isinstance(model.contact_model, SoftContacts)
assert isinstance(data.state.contact, SoftContactsState)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point, and the corresponding material deformation rate.
Expand All @@ -187,7 +180,6 @@ def collidable_point_dynamics(

case RigidContacts():
assert isinstance(model.contact_model, RigidContacts)
assert isinstance(data.state.contact, RigidContactsState)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
Expand All @@ -203,7 +195,6 @@ def collidable_point_dynamics(

case RelaxedRigidContacts():
assert isinstance(model.contact_model, RelaxedRigidContacts)
assert isinstance(data.state.contact, RelaxedRigidContactsState)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
Expand Down
106 changes: 62 additions & 44 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import jaxsim.math
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim.rbda.contacts import SoftContacts
from jaxsim.utils import Mutability
from jaxsim.utils.tracing import not_tracing

Expand Down Expand Up @@ -107,17 +106,17 @@ def zero(
@staticmethod
def build(
model: js.model.JaxSimModel,
base_position: jtp.Vector | None = None,
base_quaternion: jtp.Vector | None = None,
joint_positions: jtp.Vector | None = None,
base_linear_velocity: jtp.Vector | None = None,
base_angular_velocity: jtp.Vector | None = None,
joint_velocities: jtp.Vector | None = None,
base_position: jtp.VectorLike | None = None,
base_quaternion: jtp.VectorLike | None = None,
joint_positions: jtp.VectorLike | None = None,
base_linear_velocity: jtp.VectorLike | None = None,
base_angular_velocity: jtp.VectorLike | None = None,
joint_velocities: jtp.VectorLike | None = None,
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
contact: jaxsim.rbda.contacts.ContactsState | None = None,
contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
velocity_representation: VelRepr = VelRepr.Inertial,
time: jtp.FloatLike | None = None,
extended_ode_state: dict[str, jtp.PyTree] | None = None,
) -> JaxSimModelData:
"""
Create a `JaxSimModelData` object with the given state.
Expand All @@ -133,56 +132,73 @@ def build(
The base angular velocity in the selected representation.
joint_velocities: The joint velocities.
standard_gravity: The standard gravity constant.
contact: The state of the soft contacts.
contacts_params: The parameters of the soft contacts.
velocity_representation: The velocity representation to use.
time: The time at which the state is created.
extended_ode_state:
Additional user-defined state variables that are not part of the
standard `ODEState` object. Useful to extend the system dynamics
considered by default in JaxSim.
Returns:
A `JaxSimModelData` object with the given state.
A `JaxSimModelData` initialized with the given state.
"""

base_position = jnp.array(
base_position if base_position is not None else jnp.zeros(3)
base_position if base_position is not None else jnp.zeros(3),
dtype=float,
).squeeze()

base_quaternion = jnp.array(
base_quaternion
if base_quaternion is not None
else jnp.array([1.0, 0, 0, 0])
(
base_quaternion
if base_quaternion is not None
else jnp.array([1.0, 0, 0, 0])
),
dtype=float,
).squeeze()

base_linear_velocity = jnp.array(
base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3),
dtype=float,
).squeeze()

base_angular_velocity = jnp.array(
base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
(
base_angular_velocity
if base_angular_velocity is not None
else jnp.zeros(3)
),
dtype=float,
).squeeze()

gravity = jnp.zeros(3).at[2].set(-standard_gravity)

joint_positions = jnp.atleast_1d(
joint_positions.squeeze()
if joint_positions is not None
else jnp.zeros(model.dofs())
jnp.array(
(
joint_positions
if joint_positions is not None
else jnp.zeros(model.dofs())
),
dtype=float,
).squeeze()
)

joint_velocities = jnp.atleast_1d(
joint_velocities.squeeze()
if joint_velocities is not None
else jnp.zeros(model.dofs())
jnp.array(
(
joint_velocities
if joint_velocities is not None
else jnp.zeros(model.dofs())
),
dtype=float,
).squeeze()
)

time_ns = (
jnp.array(
time * 1e9,
dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
)
if time is not None
else jnp.array(
0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
)
time_ns = jnp.array(
time * 1e9 if time is not None else 0.0,
dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
)

W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
Expand All @@ -194,21 +210,22 @@ def build(
other_representation=velocity_representation,
transform=W_H_B,
is_force=False,
)
).astype(float)

ode_state = ODEState.build_from_jaxsim_model(
model=model,
base_position=base_position.astype(float),
base_quaternion=base_quaternion.astype(float),
joint_positions=joint_positions.astype(float),
base_linear_velocity=v_WB[0:3].astype(float),
base_angular_velocity=v_WB[3:6].astype(float),
joint_velocities=joint_velocities.astype(float),
tangential_deformation=(
contact.tangential_deformation
if contact is not None and isinstance(model.contact_model, SoftContacts)
else None
),
base_position=base_position,
base_quaternion=base_quaternion,
joint_positions=joint_positions,
base_linear_velocity=v_WB[0:3],
base_angular_velocity=v_WB[3:6],
joint_velocities=joint_velocities,
# Unpack all the additional ODE states. If the contact model requires an
# additional state that is not explicitly passed to this builder, ODEState
# automatically populates that state with zeroed variables.
# This is not true for any other custom state that the user might want to
# pass to the integrator.
**(extended_ode_state if extended_ode_state else {}),
)

if not ode_state.valid(model=model):
Expand All @@ -220,13 +237,14 @@ def build(
contacts_params = js.contact.estimate_good_soft_contacts_parameters(
model=model, standard_gravity=standard_gravity
)

else:
contacts_params = model.contact_model.parameters

return JaxSimModelData(
time_ns=time_ns,
state=ode_state,
gravity=gravity.astype(float),
gravity=gravity,
contacts_params=contacts_params,
velocity_representation=velocity_representation,
)
Expand Down
45 changes: 28 additions & 17 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class JaxSimModel(JaxsimDataclass):
model_name: Static[str]

terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
default=jaxsim.terrain.FlatTerrain(), repr=False
default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
)

contact_model: jaxsim.rbda.contacts.ContactModel | None = dataclasses.field(
Expand Down Expand Up @@ -101,13 +101,14 @@ def build_from_model_description(
A path to an SDF/URDF file, a string containing
its content, or a pre-parsed/pre-built rod model.
model_name:
The optional name of the model that overrides the one in
the description.
terrain:
The optional terrain to consider.
The name of the model. If not specified, it is read from the description.
terrain: The terrain to consider (the default is a flat infinite plane).
contact_model:
The contact model to consider.
If not specified, a soft contacts model is used.
is_urdf:
The optional flag to force the model description to be parsed as a
URDF or a SDF. This is otherwise automatically inferred.
The optional flag to force the model description to be parsed as a URDF.
This is usually automatically inferred.
considered_joints:
The list of joints to consider. If None, all joints are considered.
Expand All @@ -120,7 +121,7 @@ def build_from_model_description(
# Parse the input resource (either a path to file or a string with the URDF/SDF)
# and build the -intermediate- model description.
intermediate_description = jaxsim.parsers.rod.build_model_description(
model_description=model_description
model_description=model_description, is_urdf=is_urdf
)

# Lump links together if not all joints are considered.
Expand Down Expand Up @@ -160,11 +161,11 @@ def build(
The intermediate model description defining the kinematics and dynamics
of the model.
model_name:
The optional name of the model overriding the physics model name.
terrain:
The optional terrain to consider.
The name of the model. If not specified, it is read from the description.
terrain: The terrain to consider (the default is a flat infinite plane).
contact_model:
The optional contact model to consider. If None, the soft contact model is used.
The contact model to consider.
If not specified, a soft contacts model is used.
Returns:
The built Model object.
Expand All @@ -173,21 +174,31 @@ def build(
# Set the model name (if not provided, use the one from the model description).
model_name = model_name if model_name is not None else model_description.name

# Set the terrain (if not provided, use the default flat terrain).
terrain = terrain or JaxSimModel.__dataclass_fields__["terrain"].default
contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts(
terrain=terrain
# Consider the default terrain (a flat infinite plane) if not specified.
terrain = (
terrain or JaxSimModel.__dataclass_fields__["terrain"].default_factory()
)

# Create the default contact model.
# It will be populated with an initial estimation of good parameters.
# While these might not be the best, they are a good starting point.
contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts.build(
terrain=terrain, parameters=None
)

# Build the model.
model = JaxSimModel(
model_name=model_name,
_description=wrappers.HashlessObject(obj=model_description),
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
model_description=model_description
),
terrain=terrain,
contact_model=contact_model,
# The following is wrapped as hashless since it's a static argument, and we
# don't want to trigger recompilation if it changes. All relevant parameters
# needed to compute kinematics and dynamics quantities are stored in the
# kin_dyn_parameters attribute.
_description=wrappers.HashlessObject(obj=model_description),
)

return model
Expand Down
16 changes: 9 additions & 7 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,8 @@ def system_dynamics(
corresponding derivative, and the dictionary of auxiliary data returned
by the system dynamics evaluation.
"""
from jaxsim.rbda.contacts.relaxed_rigid import RelaxedRigidContacts
from jaxsim.rbda.contacts.rigid import RigidContacts
from jaxsim.rbda.contacts.soft import SoftContacts

from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts

# Compute the accelerations and the material deformation rate.
W_v̇_WB, , aux_dict = system_velocity_dynamics(
Expand All @@ -382,17 +381,20 @@ def system_dynamics(
link_forces=link_forces,
)

ode_state_kwargs = {}
# Initialize the dictionary storing the derivative of the additional state variables
# that extend the state vector of the integrated ODE system.
extended_ode_state = {}

match model.contact_model:

case SoftContacts():
ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
extended_ode_state["tangential_deformation"] = aux_dict["m_dot"]

case RigidContacts() | RelaxedRigidContacts():
pass

case _:
raise ValueError("Unable to determine contact state class prefix.")
raise ValueError(f"Invalid contact model {model.contact_model}")

# Extract the velocities.
W_ṗ_B, W_Q̇_B, = system_position_dynamics(
Expand All @@ -412,7 +414,7 @@ def system_dynamics(
base_linear_velocity=W_v̇_WB[0:3],
base_angular_velocity=W_v̇_WB[3:6],
joint_velocities=,
**ode_state_kwargs,
**extended_ode_state,
)

return ode_state_derivative, aux_dict
Loading

0 comments on commit 4d85117

Please sign in to comment.