|
4 | 4 | # This source code is licensed under the BSD-style license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
7 |
| - |
8 | 7 | import math
|
| 8 | +import os |
9 | 9 | import unittest
|
| 10 | +from unittest import mock |
10 | 11 |
|
11 | 12 | import torch
|
12 | 13 | from pytorch3d.transforms import random_rotations
|
@@ -191,7 +192,25 @@ def test_translate(self):
|
191 | 192 | self.assertTrue(torch.allclose(points_out, points_out_expected))
|
192 | 193 | self.assertTrue(torch.allclose(normals_out, normals_out_expected))
|
193 | 194 |
|
194 |
| - def test_rotate(self): |
| 195 | + @mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "1"}, clear=True) |
| 196 | + def test_rotate_check_rot_valid_on(self): |
| 197 | + R = so3_exp_map(torch.randn((1, 3))) |
| 198 | + t = Transform3d().rotate(R) |
| 199 | + points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view( |
| 200 | + 1, 3, 3 |
| 201 | + ) |
| 202 | + normals = torch.tensor( |
| 203 | + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]] |
| 204 | + ).view(1, 3, 3) |
| 205 | + points_out = t.transform_points(points) |
| 206 | + normals_out = t.transform_normals(normals) |
| 207 | + points_out_expected = torch.bmm(points, R) |
| 208 | + normals_out_expected = torch.bmm(normals, R) |
| 209 | + self.assertTrue(torch.allclose(points_out, points_out_expected)) |
| 210 | + self.assertTrue(torch.allclose(normals_out, normals_out_expected)) |
| 211 | + |
| 212 | + @mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "0"}, clear=True) |
| 213 | + def test_rotate_check_rot_valid_off(self): |
195 | 214 | R = so3_exp_map(torch.randn((1, 3)))
|
196 | 215 | t = Transform3d().rotate(R)
|
197 | 216 | points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
|
|
0 commit comments