Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove references of Plucker coordinates #19

Merged
merged 2 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 52 additions & 49 deletions src/jaxsim/math/adjoint.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,75 @@
import jax.numpy as jnp

import jaxsim.typing as jtp
from jaxsim.sixd import so3

from .quaternion import Quaternion
from .skew import Skew


class Adjoint:
@staticmethod
def rotate_x(theta: float) -> jtp.Matrix:
def from_quaternion_and_translation(
quaternion: jtp.Vector = jnp.array([1.0, 0, 0, 0]),
translation: jtp.Vector = jnp.zeros(3),
inverse: bool = False,
normalize_quaternion: bool = False,
) -> jtp.Matrix:

c = jnp.cos(theta).squeeze()
s = jnp.sin(theta).squeeze()
assert quaternion.size == 4
assert translation.size == 3

return jnp.array(
[
[1, 0, 0, 0, 0, 0],
[0, c, s, 0, 0, 0],
[0, -s, c, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, c, s],
[0, 0, 0, 0, -s, c],
]
Q_sixd = so3.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(quaternion))
Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize()

return Adjoint.from_rotation_and_translation(
rotation=Q_sixd.as_matrix(), translation=translation, inverse=inverse
)

@staticmethod
def rotate_y(theta: float) -> jtp.Matrix:
def from_rotation_and_translation(
rotation: jtp.Matrix = jnp.eye(3),
translation: jtp.Vector = jnp.zeros(3),
inverse: bool = False,
) -> jtp.Matrix:

c = jnp.cos(theta).squeeze()
s = jnp.sin(theta).squeeze()

return jnp.array(
[
[c, 0, -s, 0, 0, 0],
[0, 1, 0, 0, 0, 0],
[s, 0, c, 0, 0, 0],
[0, 0, 0, c, 0, -s],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, s, 0, c],
]
)
assert rotation.shape == (3, 3)
assert translation.size == 3

@staticmethod
def rotate_z(theta: float) -> jtp.Matrix:
A_R_B = rotation.squeeze()
A_o_B = translation.squeeze()

c = jnp.cos(theta).squeeze()
s = jnp.sin(theta).squeeze()
if not inverse:
X = A_X_B = jnp.block(
[
[A_R_B, Skew.wedge(A_o_B) @ A_R_B],
[jnp.zeros(shape=(3, 3)), A_R_B],
]
)
else:
X = B_X_A = jnp.block(
[
[A_R_B.T, -A_R_B.T @ Skew.wedge(A_o_B)],
[jnp.zeros(shape=(3, 3)), A_R_B.T],
]
)

return jnp.array(
[
[c, s, 0, 0, 0, 0],
[-s, c, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 0, c, s, 0],
[0, 0, 0, -s, c, 0],
[0, 0, 0, 0, 0, 1],
]
)
return X

@staticmethod
def translate(direction: jtp.Vector) -> jtp.Matrix:
def to_transform(adjoint: jtp.Matrix) -> jtp.Matrix:

X = adjoint.squeeze()
assert X.shape == (6, 6)

x, y, z = direction
R = X[0:3, 0:3]
o_x_R = X[0:3, 3:6]

return jnp.array(
H = jnp.block(
[
[1, 0, 0, 0, z, -y],
[0, 1, 0, -z, 0, x],
[0, 0, 1, y, -x, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 1],
[R, Skew.vee(matrix=o_x_R @ R.T)],
[0, 0, 0, 1],
]
)

return H
44 changes: 33 additions & 11 deletions src/jaxsim/math/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from jaxsim.parsers.descriptions import JointDescriptor, JointGenericAxis, JointType

from .adjoint import Adjoint
from .plucker import Plucker
from .rotation import Rotation


Expand All @@ -27,46 +26,69 @@ def jcalc(
elif code is JointType.R:

jtyp: JointGenericAxis
Xj = Plucker.from_rot_and_trans(
dcm=Rotation.from_axis_angle(vector=(q * jtyp.axis)),
translation=jnp.zeros(3),

Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.from_axis_angle(vector=(q * jtyp.axis)), inverse=True
)

S = jnp.vstack(jnp.hstack([jnp.zeros(3), jtyp.axis.squeeze()]))

elif code is JointType.P:

jtyp: JointGenericAxis
Xj = Adjoint.translate(direction=(q * jtyp.axis))

Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array(q * jtyp.axis), inverse=True
)

S = jnp.vstack(jnp.hstack([jtyp.axis.squeeze(), jnp.zeros(3)]))

elif code is JointType.Rx:

Xj = Adjoint.rotate_x(theta=q)
Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.x(theta=q), inverse=True
)

S = jnp.vstack([0, 0, 0, 1.0, 0, 0])

elif code is JointType.Ry:

Xj = Adjoint.rotate_y(theta=q)
Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.y(theta=q), inverse=True
)

S = jnp.vstack([0, 0, 0, 0, 1.0, 0])

elif code is JointType.Rz:

Xj = Adjoint.rotate_z(theta=q)
Xj = Adjoint.from_rotation_and_translation(
rotation=Rotation.z(theta=q), inverse=True
)

S = jnp.vstack([0, 0, 0, 0, 0, 1.0])

elif code is JointType.Px:

Xj = Adjoint.translate(direction=jnp.hstack([q, 0.0, 0.0]))
Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array([q, 0.0, 0.0]), inverse=True
)

S = jnp.vstack([1.0, 0, 0, 0, 0, 0])

elif code is JointType.Py:

Xj = Adjoint.translate(direction=jnp.hstack([0.0, q, 0.0]))
Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array([0.0, q, 0.0]), inverse=True
)

S = jnp.vstack([0, 1.0, 0, 0, 0, 0])

elif code is JointType.Pz:

Xj = Adjoint.translate(direction=jnp.hstack([0.0, 0.0, q]))
Xj = Adjoint.from_rotation_and_translation(
translation=jnp.array([0.0, 0.0, q]), inverse=True
)

S = jnp.vstack([0, 0, 1.0, 0, 0, 0])

else:
Expand Down
69 changes: 33 additions & 36 deletions src/jaxsim/math/quaternion.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,34 @@
import jax.lax
import jax.numpy as jnp

import jaxsim.typing as jtp

from .skew import Skew
from jaxsim.sixd import so3


class Quaternion:
@staticmethod
def to_dcm(quaternion: jtp.Vector) -> jtp.Matrix:

q = quaternion / jnp.linalg.norm(quaternion)

q0s = q[0] * q[0]
q1s = q[1] * q[1]
q2s = q[2] * q[2]
q3s = q[3] * q[3]
q01 = q[0] * q[1]
q02 = q[0] * q[2]
q03 = q[0] * q[3]
q12 = q[1] * q[2]
q13 = q[3] * q[1]
q23 = q[2] * q[3]

R = 2 * jnp.array(
[
[q0s + q1s - 0.5, q12 + q03, q13 - q02],
[q12 - q03, q0s + q2s - 0.5, q23 + q01],
[q13 + q02, q23 - q01, q0s + q3s - 0.5],
]
)
def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector:

return R.squeeze()
return wxyz.squeeze()[jnp.array([1, 2, 3, 0])]

@staticmethod
def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:
def to_wxyz(xyzw: jtp.Vector) -> jtp.Vector:

R = dcm.squeeze()
return xyzw.squeeze()[jnp.array([3, 0, 1, 2])]

tr = jnp.trace(R)
v = -Skew.vee(R)
@staticmethod
def to_dcm(quaternion: jtp.Vector) -> jtp.Matrix:

q = jnp.vstack([(tr + 1) / 2.0, v])
return so3.SO3.from_quaternion_xyzw(
xyzw=Quaternion.to_xyzw(quaternion)
).as_matrix()

return jnp.vstack(q) / jnp.linalg.norm(q)
@staticmethod
def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:

return Quaternion.to_wxyz(
xyzw=so3.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw()
)

@staticmethod
def derivative(
Expand All @@ -53,11 +39,13 @@ def derivative(
) -> jtp.Vector:

w = omega.squeeze()
qw, qx, qy, qz = quaternion.squeeze()
quaternion = quaternion.squeeze()

def Q_body(q: jtp.Vector) -> jtp.Matrix:

if omega_in_body_fixed:
qw, qx, qy, qz = q

Q = jnp.array(
return jnp.array(
[
[qw, -qx, -qy, -qz],
[qx, qw, -qz, qy],
Expand All @@ -66,9 +54,11 @@ def derivative(
]
)

else:
def Q_inertial(q: jtp.Vector) -> jtp.Matrix:

Q = jnp.array(
qw, qx, qy, qz = q

return jnp.array(
[
[qw, -qx, -qy, -qz],
[qx, qw, qz, -qy],
Expand All @@ -77,6 +67,13 @@ def derivative(
]
)

Q = jax.lax.cond(
pred=omega_in_body_fixed,
true_fun=Q_body,
false_fun=Q_inertial,
operand=quaternion,
)

qd = 0.5 * (
Q
@ jnp.hstack(
Expand Down
Loading