Skip to content

Commit 46cb5aa

Browse files
Jiali Duanfacebook-github-bot
Jiali Duan
authored andcommitted
Omit _check_valid_rotation_matrix by default
Summary: According to the profiler trace D40326775, _check_valid_rotation_matrix is slow because of aten::all_close operation and _safe_det_3x3 bottlenecks. Disable the check by default unless environment variable PYTORCH3D_CHECK_ROTATION_MATRICES is set to 1. Comparison after applying the change: ``` Profiling/Function get_world_to_view (ms) Transform_points(ms) specular(ms) before 12.751 18.577 21.384 after 4.432 (34.7%) 9.248 (49.8%) 11.507 (53.8%) ``` Profiling trace: https://pxl.cl/2h687 More details in https://docs.google.com/document/d/1kfhEQfpeQToikr5OH9ZssM39CskxWoJ2p8DO5-t6eWk/edit?usp=sharing Reviewed By: kjchalup Differential Revision: D40442503 fbshipit-source-id: 954b58de47de235c9d93af441643c22868b547d0
1 parent 8339cf2 commit 46cb5aa

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

Diff for: pytorch3d/transforms/transform3d.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import math
8+
import os
89
import warnings
910
from typing import List, Optional, Union
1011

@@ -636,7 +637,10 @@ def __init__(
636637
msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
637638
raise ValueError(msg % repr(R.shape))
638639
R = R.to(device=device_, dtype=dtype)
639-
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
640+
if os.environ.get("PYTORCH3D_CHECK_ROTATION_MATRICES", "0") == "1":
641+
# Note: aten::all_close in the check is computationally slow, so we
642+
# only run the check when PYTORCH3D_CHECK_ROTATION_MATRICES is on.
643+
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
640644
N = R.shape[0]
641645
mat = torch.eye(4, dtype=dtype, device=device_)
642646
mat = mat.view(1, 4, 4).repeat(N, 1, 1)

Diff for: tests/test_transforms.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
87
import math
8+
import os
99
import unittest
10+
from unittest import mock
1011

1112
import torch
1213
from pytorch3d.transforms import random_rotations
@@ -191,7 +192,25 @@ def test_translate(self):
191192
self.assertTrue(torch.allclose(points_out, points_out_expected))
192193
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
193194

194-
def test_rotate(self):
195+
@mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "1"}, clear=True)
196+
def test_rotate_check_rot_valid_on(self):
197+
R = so3_exp_map(torch.randn((1, 3)))
198+
t = Transform3d().rotate(R)
199+
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
200+
1, 3, 3
201+
)
202+
normals = torch.tensor(
203+
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
204+
).view(1, 3, 3)
205+
points_out = t.transform_points(points)
206+
normals_out = t.transform_normals(normals)
207+
points_out_expected = torch.bmm(points, R)
208+
normals_out_expected = torch.bmm(normals, R)
209+
self.assertTrue(torch.allclose(points_out, points_out_expected))
210+
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
211+
212+
@mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "0"}, clear=True)
213+
def test_rotate_check_rot_valid_off(self):
195214
R = so3_exp_map(torch.randn((1, 3)))
196215
t = Transform3d().rotate(R)
197216
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(

0 commit comments

Comments
 (0)