Skip to content

Commit 9a0f9ae

Browse files
shapovalov authored and
facebook-github-bot committed
Extending the API of Transform3d with SE(3) log
Summary: This is quite a thin wrapper – not sure we need it. The motivation is that `Transform3d` is no longer as matrix-centric: it can be converted to an SE(3) logarithm just as easily as to a matrix. This simplifies things like averaging cameras and getting the axis-angle of a camera rotation (previously, one would need to call `se3_log_map(cameras.get_world_to_camera_transform().get_matrix())`); now there is one fewer thing to call / discover. Reviewed By: bottler Differential Revision: D39928000 fbshipit-source-id: 85248d5b8af136618f1d08791af5297ea5179d19
1 parent 74bbd6f commit 9a0f9ae

File tree

2 files changed

+66
-7
lines changed

2 files changed

+66
-7
lines changed

Diff for: pytorch3d/transforms/transform3d.py

+55-7
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ..common.datatypes import Device, get_device, make_device
1414
from ..common.workaround import _safe_det_3x3
1515
from .rotation_conversions import _axis_angle_rotation
16+
from .se3 import se3_log_map
1617

1718

1819
class Transform3d:
@@ -130,13 +131,13 @@ class Transform3d:
130131
[Tx, Ty, Tz, 1],
131132
]
132133
133-
To apply the transformation to points which are row vectors, the M matrix
134-
can be pre multiplied by the points:
134+
To apply the transformation to points, which are row vectors, the latter are
135+
converted to homogeneous (4D) coordinates and right-multiplied by the M matrix:
135136
136137
.. code-block:: python
137138
138139
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
139-
transformed_points = points * M
140+
[transformed_points, 1] ∝ [points, 1] @ M
140141
141142
"""
142143

@@ -218,9 +219,10 @@ def compose(self, *others: "Transform3d") -> "Transform3d":
218219

219220
def get_matrix(self) -> torch.Tensor:
220221
"""
221-
Return a matrix which is the result of composing this transform
222-
with others stored in self.transforms. Where necessary transforms
223-
are broadcast against each other.
222+
Returns a 4×4 matrix corresponding to each transform in the batch.
223+
224+
If the transform was composed from others, the matrix for the composite
225+
transform will be returned.
224226
For example, if self.transforms contains transforms t1, t2, and t3, and
225227
given a set of points x, the following should be true:
226228
@@ -230,8 +232,11 @@ def get_matrix(self) -> torch.Tensor:
230232
y2 = t3.transform(t2.transform(t1.transform(x)))
231233
y1.get_matrix() == y2.get_matrix()
232234
235+
Where necessary, those transforms are broadcast against each other.
236+
233237
Returns:
234-
A transformation matrix representing the composed inputs.
238+
A (N, 4, 4) batch of transformation matrices representing
239+
the stored transforms. See the class documentation for the conventions.
235240
"""
236241
composed_matrix = self._matrix.clone()
237242
if len(self._transforms) > 0:
@@ -240,6 +245,49 @@ def get_matrix(self) -> torch.Tensor:
240245
composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
241246
return composed_matrix
242247

248+
def get_se3_log(self, eps: float = 1e-4, cos_bound: float = 1e-4) -> torch.Tensor:
    """
    Convert every transform in the batch to its 6D SE(3) logarithm.

    An SE(3) logarithm is a 6-dimensional vector
    `[log_translation | log_rotation]`, i.e. the concatenation of the
    3D vectors `log_translation` and `log_rotation`.

    Given the 4x4 SE(3) matrix of a transform, the 6D vector
    `log_transform = [log_translation | log_rotation]` is obtained as:
        ```
        log_transform = log(transform.get_matrix())
        log_translation = log_transform[3, :3]
        log_rotation = inv_hat(log_transform[:3, :3])
        ```
    where `log` denotes the matrix logarithm and `inv_hat` is the inverse
    of the Hat operator [2].

    This method is a thin wrapper: it evaluates `self.get_matrix()` and
    passes the result to `se3.se3_log_map`. See the docstring of
    `se3.se3_log_map` and [1], Sec 9.4.2. for a more detailed description.

    Args:
        eps: A threshold for clipping the squared norm of the rotation
            logarithm to avoid division by zero in the singular case.
        cos_bound: Clamps the cosine of the rotation angle to
            [-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs.
            The non-finite outputs can be caused by passing small rotation
            angles to the `acos` function in `so3_rotation_angle` of
            `so3_log_map`.

    Returns:
        A (N, 6) tensor whose rows are the SE(3) logarithms of the
        individual transforms stored in the object.

    Raises:
        ValueError if the stored transform is not Euclidean (e.g. R is not
            a rotation matrix or the last column has non-zeros in the
            first three places).

    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
    [2] https://en.wikipedia.org/wiki/Hat_operator
    """
    # Materialize the (possibly composed) batch of 4x4 matrices first, then
    # delegate the actual logarithm computation to the se3 module.
    matrices = self.get_matrix()
    return se3_log_map(matrices, eps=eps, cos_bound=cos_bound)
290+
243291
def _get_matrix_inverse(self) -> torch.Tensor:
244292
"""
245293
Return the inverse of self._matrix.

Diff for: tests/test_transforms.py

+11
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212
from pytorch3d.transforms import random_rotations
13+
from pytorch3d.transforms.se3 import se3_log_map
1314
from pytorch3d.transforms.so3 import so3_exp_map
1415
from pytorch3d.transforms.transform3d import (
1516
Rotate,
@@ -161,6 +162,16 @@ def test_init_with_custom_matrix_errors(self):
161162
matrix = torch.randn(*bad_shape).float()
162163
self.assertRaises(ValueError, Transform3d, matrix=matrix)
163164

165+
def test_get_se3(self):
    """
    Check that `Transform3d.get_se3_log` agrees with calling `se3_log_map`
    directly on the composed matrix, and that it yields one 6D vector per
    transform in the batch.
    """
    N = 16
    # Compose a random rotation with a random translation so that both
    # halves of the SE(3) log vector are non-trivial.
    tr = Translate(torch.rand((N, 3)))
    R = Rotate(random_rotations(N))
    transform = Transform3d().compose(R, tr)
    se3_log = transform.get_se3_log()
    # The wrapper is documented to return a (N, 6) tensor.
    self.assertEqual(tuple(se3_log.shape), (N, 6))
    gt_se3_log = se3_log_map(transform.get_matrix())
    self.assertClose(se3_log, gt_se3_log)
174+
164175
def test_translate(self):
165176
t = Transform3d().translate(1, 2, 3)
166177
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(

0 commit comments

Comments
 (0)