Skip to content

Increase performance for conversions including axis angles #1948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 79 additions & 31 deletions pytorch3d/transforms/rotation_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Ten
return out[..., 1:]


def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
def axis_angle_to_matrix(axis_angle: torch.Tensor, fast: bool=False) -> torch.Tensor:
"""
Convert rotations given as axis/angle to rotation matrices.

Expand All @@ -472,27 +472,95 @@ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
as a tensor of shape (..., 3), where the magnitude is
the angle turned anticlockwise in radians around the
vector's direction.
fast: Whether to use the new faster implementation (based on the
Rodrigues formula) instead of the original implementation (which
first converted to a quaternion and then back to a rotation matrix).

Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
if not fast:
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))

shape = axis_angle.shape
device, dtype = axis_angle.device, axis_angle.dtype

def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True).unsqueeze(-1)

rx, ry, rz = axis_angle[..., 0], axis_angle[..., 1], axis_angle[..., 2]
zeros = torch.zeros(shape[:-1], dtype=dtype, device=device)
cross_product_matrix = torch.stack(
[zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1
).view(shape + torch.Size([3]))
cross_product_matrix_sqrd = cross_product_matrix @ cross_product_matrix

identity = torch.eye(3, dtype=dtype, device=device)
angles_sqrd = angles * angles
angles_sqrd = torch.where(angles_sqrd == 0, 1, angles_sqrd)
return (
identity.expand(cross_product_matrix.shape)
+ torch.sinc(angles/torch.pi) * cross_product_matrix
+ ((1 - torch.cos(angles)) / angles_sqrd) * cross_product_matrix_sqrd
)


def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool=False) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to axis/angle.

Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
fast: Whether to use the new faster implementation (based on the
Rodrigues formula) instead of the original implementation (which
first converted to a quaternion and then back to a rotation matrix).

Returns:
Rotations given as a vector in axis angle form, as a tensor
of shape (..., 3), where the magnitude is the angle
turned anticlockwise in radians around the vector's
direction.

"""
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
if not fast:
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))

if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

omegas = torch.stack(
[
matrix[..., 2, 1] - matrix[..., 1, 2],
matrix[..., 0, 2] - matrix[..., 2, 0],
matrix[..., 1, 0] - matrix[..., 0, 1],
],
dim=-1,
)
norms = torch.norm(omegas, p=2, dim=-1, keepdim=True)
traces = torch.diagonal(matrix, dim1=-2, dim2=-1).sum(-1).unsqueeze(-1)
angles = torch.atan2(norms, traces - 1)

zeros = torch.zeros(3, dtype=matrix.dtype, device=matrix.device)
omegas = torch.where(
torch.isclose(angles, torch.zeros_like(angles)), zeros, omegas
)

near_pi = torch.isclose(
((traces - 1) / 2).abs(), torch.ones_like(traces)
).squeeze(-1)

axis_angles = torch.empty_like(omegas)
axis_angles[~near_pi] = 0.5 * omegas[~near_pi] / torch.sinc(
angles[~near_pi] / torch.pi
)

# this derives from: nnT = (R + 1) / 2
n = 0.5 * (
matrix[near_pi][..., 0, :] +
torch.eye(1, 3, dtype=matrix.dtype, device=matrix.device)
)
axis_angles[near_pi] = angles[near_pi] * n / torch.norm(n)

return axis_angles


def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
Expand All @@ -509,22 +577,11 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
quaternions with real part first, as tensor of shape (..., 4).
"""
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
half_angles = angles * 0.5
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
)
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - (angles[small_angles] * angles[small_angles]) / 48
sin_half_angles_over_angles = 0.5 * torch.sinc(0.5 * angles / torch.pi)
return torch.cat(
[torch.cos(0.5 * angles), axis_angle * sin_half_angles_over_angles],
dim=-1
)
quaternions = torch.cat(
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
)
return quaternions


def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
Expand All @@ -543,18 +600,9 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
"""
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
half_angles = torch.atan2(norms, quaternions[..., :1])
angles = 2 * half_angles
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
)
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - (angles[small_angles] * angles[small_angles]) / 48
)
sin_half_angles_over_angles = 0.5 * torch.sinc(half_angles / torch.pi)
# angles/2 are between [-pi/2, pi/2], thus sin_half_angles_over_angles
# can't be zero
return quaternions[..., 1:] / sin_half_angles_over_angles


Expand Down