Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for computing the bias acceleration of the center of mass #161

Merged
merged 3 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,179 @@ def average_centroidal_velocity_jacobian(
G_Mbb = locked_centroidal_spatial_inertia(model=model, data=data)

return jnp.linalg.inv(G_Mbb) @ G_J


@jax.jit
def bias_acceleration(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
r"""
Compute the bias linear acceleration of the center of mass.

Args:
model: The model to consider.
data: The data of the considered model.

Returns:
The bias linear acceleration of the center of mass in the active representation.

Note:
The bias acceleration is expressed in the mixed frame
:math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
active velocity representation is either inertial-fixed or mixed,
and :math:`[C] = [B]` if the active velocity representation is body-fixed.
"""

# Compute the pose of all links with forward kinematics.
W_H_L = js.model.forward_kinematics(model=model, data=data)

# Compute the bias acceleration of all links by zeroing the generalized velocity
# in the active representation.
v̇_bias_WL = js.model.link_bias_accelerations(model=model, data=data)

def other_representation_to_body(
C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector
) -> jtp.Vector:
"""
Helper to convert the body-fixed representation of the link bias acceleration
C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL.
"""

L_X_C = jaxsim.math.Adjoint.from_transform(transform=L_H_C)
C_X_L = jaxsim.math.Adjoint.inverse(L_X_C)

L_v̇_WL = L_X_C @ (C_v̇_WL + jaxsim.math.Cross.vx(C_X_L @ L_v_LC) @ C_v_WC)
return L_v̇_WL

# We need here to get the body-fixed bias acceleration of the links.
# Since it's computed in the active representation, we need to convert it to body.
match data.velocity_representation:

case VelRepr.Body:
L_a_bias_WL = v̇_bias_WL

case VelRepr.Inertial:

C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL
C_v_WC = W_v_WW = jnp.zeros(6)

L_H_C = L_H_W = jax.vmap(
lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L)
)(W_H_L)

L_v_LC = L_v_LW = jax.vmap(
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
)(jnp.arange(model.number_of_links()))

L_a_bias_WL = jax.vmap(
lambda i: other_representation_to_body(
C_v̇_WL=C_v̇_WL[i],
C_v_WC=C_v_WC,
L_H_C=L_H_C[i],
L_v_LC=L_v_LC[i],
)
)(jnp.arange(model.number_of_links()))

case VelRepr.Mixed:

C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL

C_v_WC = LW_v_W_LW = jax.vmap(
lambda i: js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed
)
.at[3:6]
.set(jnp.zeros(3))
)(jnp.arange(model.number_of_links()))

L_H_C = L_H_LW = jax.vmap(
lambda W_H_L: jaxsim.math.Transform.inverse(
W_H_L.at[0:3, 3].set(jnp.zeros(3))
)
)(W_H_L)

L_v_LC = L_v_L_LW = jax.vmap(
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
.at[0:3]
.set(jnp.zeros(3))
)(jnp.arange(model.number_of_links()))

L_a_bias_WL = jax.vmap(
lambda i: other_representation_to_body(
C_v̇_WL=C_v̇_WL[i],
C_v_WC=C_v_WC[i],
L_H_C=L_H_C[i],
L_v_LC=L_v_LC[i],
)
)(jnp.arange(model.number_of_links()))

case _:
raise ValueError(data.velocity_representation)

# Compute the bias of the 6D momentum derivative.
def bias_momentum_derivative_term(
link_index: jtp.Int, L_a_bias_WL: jtp.Vector
) -> jtp.Vector:

# Get the body-fixed 6D inertia matrix.
L_M_L = js.link.spatial_inertia(model=model, link_index=link_index)

# Compute the body-fixed 6D velocity.
L_v_WL = js.link.velocity(
model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body
)

# Compute the world-to-link transformations for 6D forces.
W_Xf_L = jaxsim.math.Adjoint.from_transform(
transform=W_H_L[link_index], inverse=True
).T

# Compute the contribution of the link to the bias acceleration of the CoM.
W_ḣ_bias_link_contribution = W_Xf_L @ (
L_M_L @ L_a_bias_WL + jaxsim.math.Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL
)

return W_ḣ_bias_link_contribution

# Sum the contributions of all links to the bias acceleration of the CoM.
W_ḣ_bias = jax.vmap(bias_momentum_derivative_term)(
jnp.arange(model.number_of_links()), L_a_bias_WL
).sum(axis=0)

# Compute the total mass of the model.
m = js.model.total_mass(model=model)

# Compute the position of the CoM.
W_p_CoM = com_position(model=model, data=data)

match data.velocity_representation:

# G := G[W] = (W_p_CoM, [W])
case VelRepr.Inertial | VelRepr.Mixed:

W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
GW_Xf_W = jaxsim.math.Adjoint.from_transform(W_H_GW).T

GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias
GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m

return GW_v̇l_com_bias

# G := G[B] = (W_p_CoM, [B])
case VelRepr.Body:

GB_Xf_W = jaxsim.math.Adjoint.from_transform(
transform=data.base_transform().at[0:3].set(W_p_CoM)
).T

GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias
GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m

return GB_v̇l_com_bias

case _:
raise ValueError(data.velocity_representation)
8 changes: 8 additions & 0 deletions tests/test_api_com.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,11 @@ def test_com_properties(
vl_com_idt = kin_dyn.com_velocity()
vl_com_js = js.com.com_linear_velocity(model=model, data=data)
assert pytest.approx(vl_com_idt) == vl_com_js

# iDynTree provides the bias acceleration in G[W] frame regardless of the velocity
# representation. JaxSim, instead, returns the bias acceleration in G[B] when the
# active representation is VelRepr.Body.
if data.velocity_representation is not VelRepr.Body:
G_v̇_bias_WG_idt = kin_dyn.com_bias_acceleration()
G_v̇_bias_WG_js = js.com.bias_acceleration(model=model, data=data)
assert pytest.approx(G_v̇_bias_WG_idt) == G_v̇_bias_WG_js
4 changes: 4 additions & 0 deletions tests/utils_idyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,10 @@ def com_velocity(self) -> npt.NDArray:
W_ṗ_G = self.kin_dyn.getCenterOfMassVelocity()
return W_ṗ_G.toNumPy()

def com_bias_acceleration(self) -> npt.NDArray:

return self.kin_dyn.getCenterOfMassBiasAcc().toNumPy()

def mass_matrix(self) -> npt.NDArray:

M = idt.MatrixDynSize()
Expand Down