Skip to content

Commit

Permalink
Disable gradient calculation in _check_valid_rotation_matrix()
Browse files Browse the repository at this point in the history
Summary:
# Make `transform3d.py` a little bit better (performance and code quality)

## 1. Add decorator `torch.no_grad()` to the function `_check_valid_rotation_matrix()`

Function `_check_valid_rotation_matrix()` is needed to identify errors during forward pass only, it's not used for gradients.

## 2. Replace two calls `to` with the single one

Reviewed By: bottler

Differential Revision: D29656501

fbshipit-source-id: 4419e24dbf436c1b60abf77bda4376fb87a593be
  • Loading branch information
Alexey Sidnev authored and facebook-github-bot committed Jul 16, 2021
1 parent 0c02ae9 commit 2f668ec
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pytorch3d/transforms/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def __init__(
if R.shape[-2:] != (3, 3):
msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
raise ValueError(msg % repr(R.shape))
R = R.to(dtype=dtype).to(device=device_)
R = R.to(device=device_, dtype=dtype)
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
N = R.shape[0]
mat = torch.eye(4, dtype=dtype, device=device_)
Expand Down Expand Up @@ -752,6 +752,9 @@ def _broadcast_bmm(a, b):
return a.bmm(b)


# pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
# its type `no_grad` is not callable.
@torch.no_grad()
def _check_valid_rotation_matrix(R, tol: float = 1e-7):
"""
Determine if R is a valid rotation matrix by checking it satisfies the
Expand Down

0 comments on commit 2f668ec

Please sign in to comment.