Skip to content

Commit

Permalink
[Sprint] Set mixed as default representation in data.build (#361)
Browse files Browse the repository at this point in the history
* Rename base linear and angular velocity parameters in `forward_kinematics_model`

* Update JaxSimModelData to use Mixed as default repr. in `build` function

* Restore `step` function to accept link forces in the same reprensentation of data

* Refactor tests to reflect changes in API

* Remove redundant `system_velocity_dynamics` function and update `system_dynamics` to use `system_acceleration` directly

* Update `step` function to use `system_acceleration` instead of `system_velocity_dynamics`

* Format `test_automatic_differentiation.py`
  • Loading branch information
xela-95 authored and CarlottaSartore committed Feb 4, 2025
1 parent 4445ba1 commit a1c5770
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 97 deletions.
40 changes: 21 additions & 19 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
base_transform: The base transform.
joint_transforms: The joint transforms.
link_transforms: The link transforms.
link_velocities: The link velocities.
link_velocities: The link velocities in inertial-fixed representation.
"""

# Joint state
Expand Down Expand Up @@ -69,7 +69,7 @@ def build(
base_linear_velocity: jtp.VectorLike | None = None,
base_angular_velocity: jtp.VectorLike | None = None,
joint_velocities: jtp.VectorLike | None = None,
velocity_representation: VelRepr = VelRepr.Inertial,
velocity_representation: VelRepr = VelRepr.Mixed,
) -> JaxSimModelData:
"""
Create a `JaxSimModelData` object with the given state.
Expand All @@ -84,7 +84,7 @@ def build(
base_angular_velocity:
The base angular velocity in the selected representation.
joint_velocities: The joint velocities.
velocity_representation: The velocity representation to use.
velocity_representation: The velocity representation to use. It defaults to mixed if not provided.
Returns:
A `JaxSimModelData` initialized with the given state.
Expand Down Expand Up @@ -144,7 +144,7 @@ def build(
translation=base_position, quaternion=base_quaternion
)

v_WB = JaxSimModelData.other_representation_to_inertial(
W_v_WB = JaxSimModelData.other_representation_to_inertial(
array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
other_representation=velocity_representation,
transform=W_H_B,
Expand All @@ -155,28 +155,30 @@ def build(
joint_positions=joint_positions, base_transform=W_H_B
)

link_transforms, link_velocities = jaxsim.rbda.forward_kinematics_model(
model=model,
base_position=base_position,
base_quaternion=base_quaternion,
joint_positions=joint_positions,
base_linear_velocity=v_WB[0:3],
base_angular_velocity=v_WB[3:6],
joint_velocities=joint_velocities,
link_transforms, link_velocities_inertial = (
jaxsim.rbda.forward_kinematics_model(
model=model,
base_position=base_position,
base_quaternion=base_quaternion,
joint_positions=joint_positions,
base_linear_velocity_inertial=W_v_WB[0:3],
base_angular_velocity_inertial=W_v_WB[3:6],
joint_velocities=joint_velocities,
)
)

model_data = JaxSimModelData(
base_quaternion=base_quaternion,
base_position=base_position,
joint_positions=joint_positions,
base_linear_velocity=v_WB[0:3],
base_angular_velocity=v_WB[3:6],
base_linear_velocity=W_v_WB[0:3],
base_angular_velocity=W_v_WB[3:6],
joint_velocities=joint_velocities,
velocity_representation=velocity_representation,
base_transform=W_H_B,
joint_transforms=joint_transforms,
link_transforms=link_transforms,
link_velocities=link_velocities,
link_velocities=link_velocities_inertial,
)

if not model_data.valid(model=model):
Expand All @@ -189,14 +191,14 @@ def build(
@staticmethod
def zero(
model: js.model.JaxSimModel,
velocity_representation: VelRepr = VelRepr.Inertial,
velocity_representation: VelRepr = VelRepr.Mixed,
) -> JaxSimModelData:
"""
Create a `JaxSimModelData` object with zero state.
Args:
model: The model for which to create the state.
velocity_representation: The velocity representation to use.
velocity_representation: The velocity representation to use. It defaults to mixed if not provided.
Returns:
A `JaxSimModelData` initialized with zero state.
Expand Down Expand Up @@ -603,8 +605,8 @@ def update_cached(self, model: js.model.JaxSimModel) -> JaxSimModelData:
base_quaternion=self.base_quaternion,
joint_positions=self.joint_positions,
joint_velocities=self.joint_velocities,
base_linear_velocity=self.base_linear_velocity,
base_angular_velocity=self.base_angular_velocity,
base_linear_velocity_inertial=self.base_linear_velocity,
base_angular_velocity_inertial=self.base_angular_velocity,
)

return self.replace(
Expand Down
25 changes: 18 additions & 7 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1989,7 +1989,7 @@ def step(
model: JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_forces_inertial: jtp.MatrixLike | None = None,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
) -> js.data.JaxSimModelData:
"""
Expand All @@ -1999,8 +1999,8 @@ def step(
model: The model to consider.
data: The data of the considered model.
dt: The time step to consider. If not specified, it is read from the model.
link_forces_inertial:
The 6D forces to apply to the links expressed in inertial-representation.
link_forces:
The 6D forces to apply to the links expressed in same representation of data.
joint_force_references: The joint force references to consider.
Returns:
Expand All @@ -2016,11 +2016,22 @@ def step(
# the enabled collidable points

# Extract the inputs
W_f_L_external = jnp.atleast_2d(
jnp.array(link_forces_inertial, dtype=float).squeeze()
if link_forces_inertial is not None
O_f_L_external = jnp.atleast_2d(
jnp.array(link_forces, dtype=float).squeeze()
if link_forces is not None
else jnp.zeros((model.number_of_links(), 6))
)

# Get the external forces in inertial-fixed representation.
W_f_L_external = jax.vmap(
lambda f_L, W_H_L: js.data.JaxSimModelData.other_representation_to_inertial(
f_L,
other_representation=data.velocity_representation,
transform=W_H_L,
is_force=True,
)
)(O_f_L_external, data.link_transforms)

τ_references = jnp.atleast_1d(
jnp.array(joint_force_references, dtype=float).squeeze()
if joint_force_references is not None
Expand Down Expand Up @@ -2063,7 +2074,7 @@ def step(
# ===============================

with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
W_v̇_WB, = js.ode.system_velocity_dynamics(
W_v̇_WB, = js.ode.system_acceleration(
model=model,
data=data,
link_forces=W_f_L_total,
Expand Down
54 changes: 1 addition & 53 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Any

import jax
import jax.numpy as jnp

Expand All @@ -15,56 +13,6 @@
# ==================================


@jax.jit
@js.common.named_scope
def system_velocity_dynamics(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_forces: jtp.Vector | None = None,
joint_torques: jtp.Vector | None = None,
) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
"""
Compute the dynamics of the system velocity.
Args:
model: The model to consider.
data: The data of the considered model.
link_forces:
The 6D forces to apply to the links expressed in inertial-fixed representation.
joint_torques: The joint torques acting on the joints.
Returns:
A tuple containing the derivative of the base 6D velocity in inertial-fixed
representation, the derivative of the joint velocities, and auxiliary data
returned by the system dynamics evaluation.
"""

# Build link forces if not provided.
# These forces are expressed in the frame corresponding to the velocity
# representation of data.
W_f_L = (
jnp.atleast_2d(link_forces.squeeze())
if link_forces is not None
else jnp.zeros((model.number_of_links(), 6))
).astype(float)

# ===========================
# Compute system acceleration
# ===========================

# Compute the system acceleration in inertial-fixed representation.
# This representation is useful for integration purpose.
W_v̇_WB, = system_acceleration(
model=model,
data=data,
joint_torques=joint_torques,
link_forces=W_f_L,
)

return W_v̇_WB,


def system_acceleration(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand Down Expand Up @@ -197,7 +145,7 @@ def system_dynamics(
"""

with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
W_v̇_WB, = system_velocity_dynamics(
W_v̇_WB, = system_acceleration(
model=model,
data=data,
joint_torques=joint_torques,
Expand Down
12 changes: 6 additions & 6 deletions src/jaxsim/rbda/forward_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def forward_kinematics_model(
base_position: jtp.VectorLike,
base_quaternion: jtp.VectorLike,
joint_positions: jtp.VectorLike,
base_linear_velocity: jtp.VectorLike,
base_angular_velocity: jtp.VectorLike,
base_linear_velocity_inertial: jtp.VectorLike,
base_angular_velocity_inertial: jtp.VectorLike,
joint_velocities: jtp.VectorLike,
) -> jtp.Array:
"""
Expand All @@ -27,8 +27,8 @@ def forward_kinematics_model(
base_position: The position of the base link.
base_quaternion: The quaternion of the base link.
joint_positions: The positions of the joints.
base_linear_velocity: The linear velocity of the base link.
base_angular_velocity: The angular velocity of the base link.
base_linear_velocity_inertial: The linear velocity of the base link in inertial-fixed representation.
base_angular_velocity_inertial: The angular velocity of the base link in inertial-fixed representation.
joint_velocities: The velocities of the joints.
Returns:
Expand All @@ -40,8 +40,8 @@ def forward_kinematics_model(
base_position=base_position,
base_quaternion=base_quaternion,
joint_positions=joint_positions,
base_linear_velocity=base_linear_velocity,
base_angular_velocity=base_angular_velocity,
base_linear_velocity=base_linear_velocity_inertial,
base_angular_velocity=base_angular_velocity_inertial,
joint_velocities=joint_velocities,
)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ def test_ad_fk(
base_position=W_p_B,
base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),
joint_positions=s,
base_linear_velocity=W_v_lin,
base_angular_velocity=W_v_ang,
base_linear_velocity_inertial=W_v_lin,
base_angular_velocity_inertial=W_v_ang,
joint_velocities=,
)

Expand Down Expand Up @@ -344,7 +344,7 @@ def step(
model=model,
data=data_x0,
joint_force_references=τ,
link_forces_inertial=W_f_L,
link_forces=W_f_L,
)

xf_W_p_B = data_xf.base_position
Expand Down
22 changes: 13 additions & 9 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_box_with_external_forces(
data = js.model.step(
model=model,
data=data,
link_forces_inertial=references._link_forces,
link_forces=references.link_forces(model, data),
)

# Check that the box didn't move.
Expand All @@ -84,6 +84,7 @@ def test_box_with_external_forces(

def test_box_with_zero_gravity(
jaxsim_model_box: js.model.JaxSimModel,
velocity_representation: VelRepr,
prng_key: jnp.ndarray,
):

Expand All @@ -101,14 +102,14 @@ def test_box_with_zero_gravity(
data0 = js.data.JaxSimModelData.build(
model=model,
base_position=jax.random.uniform(subkey, shape=(3,)),
velocity_representation=jaxsim.VelRepr.Inertial,
velocity_representation=velocity_representation,
)

# Initialize a references object that simplifies handling external forces.
references = js.references.JaxSimModelReferences.build(
model=model,
data=data0,
velocity_representation=jaxsim.VelRepr.Inertial,
velocity_representation=velocity_representation,
)

# Apply a link forces to the base link.
Expand Down Expand Up @@ -144,12 +145,15 @@ def test_box_with_zero_gravity(

# ... and step the simulation.
for _ in T:

data = js.model.step(
model=model,
data=data,
link_forces_inertial=references.link_forces(model=model, data=data),
)
with (
data.switch_velocity_representation(velocity_representation),
references.switch_velocity_representation(velocity_representation),
):
data = js.model.step(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
)

# Check that the box moved as expected.
assert data.base_position == pytest.approx(
Expand Down

0 comments on commit a1c5770

Please sign in to comment.