From 937ecf0be7f986f76bc42d65568140d0b55b54fd Mon Sep 17 00:00:00 2001 From: Diego Ferigo Date: Thu, 22 Sep 2022 11:46:30 +0200 Subject: [PATCH 1/2] Add quaternion conversions wxyz / xyzw --- src/jaxsim/math/quaternion.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index 4081c0897..98706e221 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -6,6 +6,16 @@ class Quaternion: + @staticmethod + def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector: + + return wxyz.squeeze()[jnp.array([1, 2, 3, 0])] + + @staticmethod + def to_wxyz(xyzw: jtp.Vector) -> jtp.Vector: + + return xyzw.squeeze()[jnp.array([3, 0, 1, 2])] + @staticmethod def to_dcm(quaternion: jtp.Vector) -> jtp.Matrix: From ea409d1cb5c37c3fa595a084bd80f6f3637b26ac Mon Sep 17 00:00:00 2001 From: Diego Ferigo Date: Thu, 22 Sep 2022 18:07:00 +0200 Subject: [PATCH 2/2] Remove usage of Plucker coordinates --- src/jaxsim/math/adjoint.py | 101 +++++++++--------- src/jaxsim/math/joint.py | 44 ++++++-- src/jaxsim/math/quaternion.py | 61 +++++------ src/jaxsim/math/rotation.py | 51 ++------- src/jaxsim/parsers/sdf/parser.py | 6 +- src/jaxsim/physics/algos/aba.py | 18 +++- .../physics/algos/forward_kinematics.py | 11 +- src/jaxsim/physics/algos/rnea.py | 13 ++- src/jaxsim/physics/algos/soft_contacts.py | 10 +- src/jaxsim/physics/model/physics_model.py | 18 ++-- 10 files changed, 163 insertions(+), 170 deletions(-) diff --git a/src/jaxsim/math/adjoint.py b/src/jaxsim/math/adjoint.py index f0f6fa54b..7735eb140 100644 --- a/src/jaxsim/math/adjoint.py +++ b/src/jaxsim/math/adjoint.py @@ -1,72 +1,75 @@ import jax.numpy as jnp import jaxsim.typing as jtp +from jaxsim.sixd import so3 + +from .quaternion import Quaternion +from .skew import Skew class Adjoint: @staticmethod - def rotate_x(theta: float) -> jtp.Matrix: + def from_quaternion_and_translation( + quaternion: jtp.Vector = jnp.array([1.0, 0, 0, 0]), + translation: jtp.Vector = jnp.zeros(3), + inverse: bool = False, + normalize_quaternion: bool = False, + ) -> jtp.Matrix: - c = jnp.cos(theta).squeeze() - s = jnp.sin(theta).squeeze() + assert quaternion.size == 4 + assert translation.size == 3 - return jnp.array( - [ - [1, 0, 0, 0, 0, 0], - [0, c, s, 0, 0, 0], - [0, -s, c, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, c, s], - [0, 0, 0, 0, -s, c], - ] + Q_sixd = so3.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(quaternion)) + Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize() + + return Adjoint.from_rotation_and_translation( + rotation=Q_sixd.as_matrix(), translation=translation, inverse=inverse ) @staticmethod - def rotate_y(theta: float) -> jtp.Matrix: + def from_rotation_and_translation( + rotation: jtp.Matrix = jnp.eye(3), + translation: jtp.Vector = jnp.zeros(3), + inverse: bool = False, + ) -> jtp.Matrix: - c = jnp.cos(theta).squeeze() - s = jnp.sin(theta).squeeze() - - return jnp.array( - [ - [c, 0, -s, 0, 0, 0], - [0, 1, 0, 0, 0, 0], - [s, 0, c, 0, 0, 0], - [0, 0, 0, c, 0, -s], - [0, 0, 0, 0, 1, 0], - [0, 0, 0, s, 0, c], - ] - ) + assert rotation.shape == (3, 3) + assert translation.size == 3 - @staticmethod - def rotate_z(theta: float) -> jtp.Matrix: + A_R_B = rotation.squeeze() + A_o_B = translation.squeeze() - c = jnp.cos(theta).squeeze() - s = jnp.sin(theta).squeeze() + if not inverse: + X = A_X_B = jnp.block( + [ + [A_R_B, Skew.wedge(A_o_B) @ A_R_B], + [jnp.zeros(shape=(3, 3)), A_R_B], + ] + ) + else: + X = B_X_A = jnp.block( + [ + [A_R_B.T, -A_R_B.T @ Skew.wedge(A_o_B)], + [jnp.zeros(shape=(3, 3)), A_R_B.T], + ] + ) - return jnp.array( - [ - [c, s, 0, 0, 0, 0], - [-s, c, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0], - [0, 0, 0, c, s, 0], - [0, 0, 0, -s, c, 0], - [0, 0, 0, 0, 0, 1], - ] - ) + return X @staticmethod - def translate(direction: jtp.Vector) -> jtp.Matrix: + def to_transform(adjoint: jtp.Matrix) -> jtp.Matrix: + + X = adjoint.squeeze() + assert X.shape == (6, 6) - x, y, z = direction + R = X[0:3, 0:3] + o_x_R = X[0:3, 3:6] - return jnp.array( + H = jnp.block( [ - [1, 0, 0, 0, z, -y], - [0, 1, 0, -z, 0, x], - [0, 0, 1, y, -x, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 1], + [R, Skew.vee(matrix=o_x_R @ R.T)], + [0, 0, 0, 1], ] ) + + return H diff --git a/src/jaxsim/math/joint.py b/src/jaxsim/math/joint.py index 579d8cda2..2fb02c530 100644 --- a/src/jaxsim/math/joint.py +++ b/src/jaxsim/math/joint.py @@ -6,7 +6,6 @@ from jaxsim.parsers.descriptions import JointDescriptor, JointGenericAxis, JointType from .adjoint import Adjoint -from .plucker import Plucker from .rotation import Rotation @@ -27,46 +26,69 @@ def jcalc( elif code is JointType.R: jtyp: JointGenericAxis - Xj = Plucker.from_rot_and_trans( - dcm=Rotation.from_axis_angle(vector=(q * jtyp.axis)), - translation=jnp.zeros(3), + + Xj = Adjoint.from_rotation_and_translation( + rotation=Rotation.from_axis_angle(vector=(q * jtyp.axis)), inverse=True ) + S = jnp.vstack(jnp.hstack([jnp.zeros(3), jtyp.axis.squeeze()])) elif code is JointType.P: jtyp: JointGenericAxis - Xj = Adjoint.translate(direction=(q * jtyp.axis)) + + Xj = Adjoint.from_rotation_and_translation( + translation=jnp.array(q * jtyp.axis), inverse=True + ) + S = jnp.vstack(jnp.hstack([jtyp.axis.squeeze(), jnp.zeros(3)])) elif code is JointType.Rx: - Xj = Adjoint.rotate_x(theta=q) + Xj = Adjoint.from_rotation_and_translation( + rotation=Rotation.x(theta=q), inverse=True + ) + S = jnp.vstack([0, 0, 0, 1.0, 0, 0]) elif code is JointType.Ry: - Xj = Adjoint.rotate_y(theta=q) + Xj = Adjoint.from_rotation_and_translation( + rotation=Rotation.y(theta=q), inverse=True + ) + S = jnp.vstack([0, 0, 0, 0, 1.0, 0]) elif code is JointType.Rz: - Xj = Adjoint.rotate_z(theta=q) + Xj = Adjoint.from_rotation_and_translation( + rotation=Rotation.z(theta=q), inverse=True + ) + S = jnp.vstack([0, 0, 0, 0, 0, 1.0]) elif code is JointType.Px: - Xj = Adjoint.translate(direction=jnp.hstack([q, 0.0, 0.0])) + Xj = Adjoint.from_rotation_and_translation( + translation=jnp.array([q, 0.0, 0.0]), inverse=True + ) + S = jnp.vstack([1.0, 0, 0, 0, 0, 0]) elif code is JointType.Py: - Xj = Adjoint.translate(direction=jnp.hstack([0.0, q, 0.0])) + Xj = Adjoint.from_rotation_and_translation( + translation=jnp.array([0.0, q, 0.0]), inverse=True + ) + S = jnp.vstack([0, 1.0, 0, 0, 0, 0]) elif code is JointType.Pz: - Xj = Adjoint.translate(direction=jnp.hstack([0.0, 0.0, q])) + Xj = Adjoint.from_rotation_and_translation( + translation=jnp.array([0.0, 0.0, q]), inverse=True + ) + S = jnp.vstack([0, 0, 1.0, 0, 0, 0]) else: diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index 98706e221..326a0572e 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -1,8 +1,8 @@ +import jax.lax import jax.numpy as jnp import jaxsim.typing as jtp - -from .skew import Skew +from jaxsim.sixd import so3 class Quaternion: @@ -19,40 +19,16 @@ def to_wxyz(xyzw: jtp.Vector) -> jtp.Vector: @staticmethod def to_dcm(quaternion: jtp.Vector) -> jtp.Matrix: - q = quaternion / jnp.linalg.norm(quaternion) - - q0s = q[0] * q[0] - q1s = q[1] * q[1] - q2s = q[2] * q[2] - q3s = q[3] * q[3] - q01 = q[0] * q[1] - q02 = q[0] * q[2] - q03 = q[0] * q[3] - q12 = q[1] * q[2] - q13 = q[3] * q[1] - q23 = q[2] * q[3] - - R = 2 * jnp.array( - [ - [q0s + q1s - 0.5, q12 + q03, q13 - q02], - [q12 - q03, q0s + q2s - 0.5, q23 + q01], - [q13 + q02, q23 - q01, q0s + q3s - 0.5], - ] - ) - - return R.squeeze() + return so3.SO3.from_quaternion_xyzw( + xyzw=Quaternion.to_xyzw(quaternion) + ).as_matrix() @staticmethod def from_dcm(dcm: jtp.Matrix) -> jtp.Vector: - R = dcm.squeeze() - - tr = jnp.trace(R) - v = -Skew.vee(R) - - q = jnp.vstack([(tr + 1) / 2.0, v]) - - return jnp.vstack(q) / jnp.linalg.norm(q) + return Quaternion.to_wxyz( + xyzw=so3.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw() + ) @staticmethod def derivative( @@ -63,11 +39,13 @@ def derivative( ) -> jtp.Vector: w = omega.squeeze() - qw, qx, qy, qz = quaternion.squeeze() + quaternion = quaternion.squeeze() + + def Q_body(q: jtp.Vector) -> jtp.Matrix: - if omega_in_body_fixed: + qw, qx, qy, qz = q - Q = jnp.array( + return jnp.array( [ [qw, -qx, -qy, -qz], [qx, qw, -qz, qy], @@ -76,9 +54,11 @@ def derivative( ] ) - else: + def Q_inertial(q: jtp.Vector) -> jtp.Matrix: + + qw, qx, qy, qz = q - Q = jnp.array( + return jnp.array( [ [qw, -qx, -qy, -qz], [qx, qw, qz, -qy], @@ -87,6 +67,13 @@ def derivative( ] ) + Q = jax.lax.cond( + pred=omega_in_body_fixed, + true_fun=Q_body, + false_fun=Q_inertial, + operand=quaternion, + ) + qd = 0.5 * ( Q @ jnp.hstack( diff --git a/src/jaxsim/math/rotation.py b/src/jaxsim/math/rotation.py index 8352116c5..da805f965 100644 --- a/src/jaxsim/math/rotation.py +++ b/src/jaxsim/math/rotation.py @@ -4,63 +4,34 @@ import jax.numpy as jnp import jaxsim.typing as jtp +from jaxsim.sixd import so3 from .skew import Skew class Rotation: @staticmethod - def x(theta: float) -> jtp.Matrix: + def x(theta: jtp.Float) -> jtp.Matrix: - c = jnp.cos(theta) - s = jnp.sin(theta) - - return jnp.array( - [ - [1, 0, 0], - [0, c, s], - [0, -s, c], - ] - ) + return so3.SO3.from_x_radians(theta=theta).as_matrix() @staticmethod - def y(theta: float) -> jtp.Matrix: + def y(theta: jtp.Float) -> jtp.Matrix: - c = jnp.cos(theta) - s = jnp.sin(theta) - - return jnp.array( - [ - [c, 0, -s], - [0, 1, 0], - [s, 0, c], - ] - ) + return so3.SO3.from_y_radians(theta=theta).as_matrix() @staticmethod - def z(theta: float) -> jtp.Matrix: + def z(theta: jtp.Float) -> jtp.Matrix: - c = jnp.cos(theta) - s = jnp.sin(theta) - - return jnp.array( - [ - [c, s, 0], - [-s, c, 0], - [0, 0, 1], - ] - ) + return so3.SO3.from_z_radians(theta=theta).as_matrix() @staticmethod def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix: + vector = vector.squeeze() theta = jnp.linalg.norm(vector) - def theta_is_zero(theta_and_v: Tuple[float, jtp.Vector]) -> jtp.Matrix: - - return jnp.eye(3) - - def theta_is_not_zero(theta_and_v: Tuple[float, jtp.Vector]) -> jtp.Matrix: + def theta_is_not_zero(theta_and_v: Tuple[jtp.Float, jtp.Vector]) -> jtp.Matrix: theta, v = theta_and_v @@ -74,11 +45,11 @@ def theta_is_not_zero(theta_and_v: Tuple[float, jtp.Vector]) -> jtp.Matrix: R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T - return R + return R.transpose() return jax.lax.cond( pred=(theta == 0.0), - true_fun=theta_is_zero, + true_fun=lambda operand: jnp.eye(3), false_fun=theta_is_not_zero, operand=(theta, vector), ) diff --git a/src/jaxsim/parsers/sdf/parser.py b/src/jaxsim/parsers/sdf/parser.py index c02e0ce28..a2b87d356 100644 --- a/src/jaxsim/parsers/sdf/parser.py +++ b/src/jaxsim/parsers/sdf/parser.py @@ -5,9 +5,9 @@ import jax.numpy as jnp import numpy as np import pysdf -from scipy.spatial.transform.rotation import Rotation as R from jaxsim import logging +from jaxsim.math.quaternion import Quaternion from jaxsim.parsers import descriptions, kinematic_graph from . import utils as utils @@ -54,11 +54,9 @@ def extract_data_from_sdf( else: w_H_m = utils.from_sdf_pose(pose=sdf_tree.model.pose) - xyzw_to_wxyz = np.array([3, 0, 1, 2]) - w_quat_m = R.from_matrix(w_H_m[0:3, 0:3]).as_quat()[xyzw_to_wxyz] model_pose = kinematic_graph.RootPose( root_position=w_H_m[0:3, 3], - root_quaternion=w_quat_m, + root_quaternion=Quaternion.from_dcm(dcm=w_H_m[0:3, 0:3]), ) # =========== diff --git a/src/jaxsim/physics/algos/aba.py b/src/jaxsim/physics/algos/aba.py index ad544446a..540203841 100644 --- a/src/jaxsim/physics/algos/aba.py +++ b/src/jaxsim/physics/algos/aba.py @@ -5,9 +5,8 @@ import jax.numpy as jnp import jaxsim.typing as jtp +from jaxsim.math.adjoint import Adjoint from jaxsim.math.cross import Cross -from jaxsim.math.plucker import Plucker -from jaxsim.math.quaternion import Quaternion from jaxsim.physics.model.physics_model import PhysicsModel from . import utils @@ -44,8 +43,13 @@ def aba( base_quat = jnp.vstack(x_fb[0:4]) base_pos = jnp.vstack(x_fb[4:7]) base_vel = jnp.vstack(jnp.hstack([x_fb[10:13], x_fb[7:10]])) - B_X_W = Plucker.from_rot_and_trans( - dcm=Quaternion.to_dcm(quaternion=base_quat), translation=base_pos + + # 6D transform of base velocity + B_X_W = Adjoint.from_quaternion_and_translation( + quaternion=base_quat, + translation=base_pos, + inverse=True, + normalize_quaternion=True, ) i_X_λi = i_X_λi.at[0].set(B_X_W) @@ -208,8 +212,12 @@ def propagate(MA_pA): qdd = jnp.atleast_1d(qdd.squeeze()) qdd = jnp.vstack(qdd) if qdd.size > 0 else jnp.empty(shape=(0, 1)) + # Get the resulting base acceleration (w/o gravity) in body-fixed representation + B_a_WB = a[0] + + # Convert the base acceleration to inertial-fixed representation, and add gravity W_a_WB = jnp.vstack( - jnp.linalg.solve(B_X_W, a[0]) + jnp.vstack(model.gravity) + jnp.linalg.solve(B_X_W, B_a_WB) + jnp.vstack(model.gravity) if model.is_floating_base else jnp.zeros(6) ) diff --git a/src/jaxsim/physics/algos/forward_kinematics.py b/src/jaxsim/physics/algos/forward_kinematics.py index 4d6fb9e99..c57fbdf9f 100644 --- a/src/jaxsim/physics/algos/forward_kinematics.py +++ b/src/jaxsim/physics/algos/forward_kinematics.py @@ -5,8 +5,7 @@ import numpy as np import jaxsim.typing as jtp -from jaxsim.math.plucker import Plucker -from jaxsim.math.quaternion import Quaternion +from jaxsim.math.adjoint import Adjoint from jaxsim.physics.model.physics_model import PhysicsModel from . import utils @@ -20,9 +19,9 @@ def forward_kinematics_model( physics_model=model, xfb=xfb, q=q, qd=None, tau=None, f_ext=None ) - qn = jnp.vstack(x_fb[0:4]) - r = jnp.vstack(x_fb[4:7]) - W_X_0 = jnp.linalg.inv(Plucker.from_rot_and_trans(Quaternion.to_dcm(qn), r)) + W_X_0 = Adjoint.from_quaternion_and_translation( + quaternion=x_fb[0:4], translation=x_fb[4:7] + ) # This is the 6D velocity transform from i-th link frame to the world frame W_X_i = jnp.zeros(shape=[model.NB, 6, 6]) @@ -61,7 +60,7 @@ def propagate_kinematics( xs=np.arange(start=1, stop=model.NB), ) - return jnp.stack([Plucker.to_transform(X) for X in list(W_X_i)]) + return jnp.stack([Adjoint.to_transform(adjoint=X) for X in list(W_X_i)]) def forward_kinematics( diff --git a/src/jaxsim/physics/algos/rnea.py b/src/jaxsim/physics/algos/rnea.py index 9bee29d90..bdcad9c9b 100644 --- a/src/jaxsim/physics/algos/rnea.py +++ b/src/jaxsim/physics/algos/rnea.py @@ -4,9 +4,8 @@ import numpy as np import jaxsim.typing as jtp +from jaxsim.math.adjoint import Adjoint from jaxsim.math.cross import Cross -from jaxsim.math.plucker import Plucker -from jaxsim.math.quaternion import Quaternion from jaxsim.physics.model.physics_model import PhysicsModel from . import utils @@ -45,9 +44,13 @@ def rnea( a: Dict[int, jtp.VectorJax] = dict() f: Dict[int, jtp.VectorJax] = dict() - qn = jnp.vstack(xfb[0:4]) - r = jnp.vstack(xfb[4:7]) - Xup_0 = B_X_W = Plucker.from_rot_and_trans(Quaternion.to_dcm(qn), r) + # 6D transform of base velocity + Xup_0 = B_X_W = Adjoint.from_quaternion_and_translation( + quaternion=xfb[0:4], + translation=xfb[4:7], + inverse=True, + normalize_quaternion=True, + ) Xup = Xup.at[0].set(Xup_0) v[0] = jnp.zeros(shape=(6, 1)) diff --git a/src/jaxsim/physics/algos/soft_contacts.py b/src/jaxsim/physics/algos/soft_contacts.py index ca0f6f407..8ae456ad2 100644 --- a/src/jaxsim/physics/algos/soft_contacts.py +++ b/src/jaxsim/physics/algos/soft_contacts.py @@ -8,9 +8,8 @@ import jaxsim.physics.model.physics_model import jaxsim.typing as jtp +from jaxsim.math.adjoint import Adjoint from jaxsim.math.conv import Convert -from jaxsim.math.plucker import Plucker -from jaxsim.math.quaternion import Quaternion from jaxsim.math.skew import Skew from jaxsim.physics.algos.terrain import FlatTerrain, Terrain from jaxsim.physics.model.physics_model import PhysicsModel @@ -66,9 +65,10 @@ def collidable_points_pos_vel( Xa = jnp.array([jnp.eye(6)] * (model.NB)) vb = jnp.array([jnp.zeros([6, 1])] * (model.NB)) - qn = xfb[0:4] - r = xfb[4:7] - Xa_0 = Plucker.from_rot_and_trans(Quaternion.to_dcm(qn), r) + # 6D transform of base velocity + Xa_0 = B_X_W = Adjoint.from_quaternion_and_translation( + quaternion=xfb[0:4], translation=xfb[4:7], inverse=True + ) Xa = Xa.at[0].set(Xa_0) vfb = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]])) diff --git a/src/jaxsim/physics/model/physics_model.py b/src/jaxsim/physics/model/physics_model.py index 7486603c3..0de2d65db 100644 --- a/src/jaxsim/physics/model/physics_model.py +++ b/src/jaxsim/physics/model/physics_model.py @@ -10,9 +10,9 @@ import jaxsim.parsers import jaxsim.physics import jaxsim.typing as jtp -from jaxsim.math.plucker import Plucker from jaxsim.parsers.descriptions import JointDescriptor, JointType from jaxsim.physics import default_gravity +from jaxsim.sixd import se3 from jaxsim.utils import JaxsimDataclass, tracing from .ground_contact import GroundContact @@ -94,8 +94,14 @@ def build_from( # (this is just the pose of the base link in the SDF description) base_link = model_description.links_dict[model_description.link_names()[0]] R_H_B = model_description.transform(name=base_link.name) - B_H_R = np.linalg.inv(R_H_B) - tree_transform_0 = Plucker.from_transform(transform=B_H_R) + tree_transform_0 = se3.SE3.from_matrix(matrix=R_H_B).adjoint() + + # Helper to compute the transform pre(i)_H_λ(i). + # Given a joint 'i', it is the coordinate transform between its predecessor + # frame [pre(i)] and the frame of its parent link [λ(i)]. + prei_H_λi = lambda j: model_description.relative_transform( + relative_to=j.name, name=j.parent.name + ) # Compute the tree transforms: pre(i)_X_λ(i). # Given a joint 'i', it is the coordinate transform between its predecessor @@ -103,11 +109,7 @@ def build_from( tree_transforms_dict = { 0: tree_transform_0, **{ - j.index: Plucker.from_transform( - transform=model_description.relative_transform( - relative_to=j.name, name=j.parent.name - ) - ) + j.index: se3.SE3.from_matrix(matrix=prei_H_λi(j)).adjoint() for j in model_description.joints }, }