Skip to content

Commit c371a9a

Browse files
bottlerfacebook-github-bot
authored andcommitted
rasterizer.to without cameras
Summary: As reported in #1100, a rasterizer couldn't be moved if it was missing the optional cameras member. Fix that. This matters because the renderer.to calls rasterizer.to, so this to() could be called even by a user who never sets a cameras member. Reviewed By: nikhilaravi Differential Revision: D34643841 fbshipit-source-id: 7e26e32e8bc585eb1ee533052754a7b59bc7467a
1 parent 4a1f176 commit c371a9a

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

pytorch3d/renderer/mesh/rasterizer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def __init__(self, cameras=None, raster_settings=None) -> None:
110110

111111
def to(self, device):
112112
# Manually move to device cameras as it is not a subclass of nn.Module
113-
self.cameras = self.cameras.to(device)
113+
if self.cameras is not None:
114+
self.cameras = self.cameras.to(device)
114115
return self
115116

116117
def transform(self, meshes_world, **kwargs) -> torch.Tensor:

pytorch3d/renderer/points/rasterizer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def transform(self, point_clouds, **kwargs) -> torch.Tensor:
115115

116116
def to(self, device):
117117
# Manually move to device cameras as it is not a subclass of nn.Module
118-
self.cameras = self.cameras.to(device)
118+
if self.cameras is not None:
119+
self.cameras = self.cameras.to(device)
119120
return self
120121

121122
def forward(self, point_clouds, **kwargs) -> PointFragments:

tests/test_rasterizer.py

+12
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ def test_simple_sphere(self):
134134

135135
self.assertTrue(torch.allclose(image, image_ref))
136136

137+
def test_simple_to(self):
138+
# Check that to() works without a cameras object.
139+
device = torch.device("cuda:0")
140+
rasterizer = MeshRasterizer()
141+
rasterizer.to(device)
142+
137143

138144
class TestPointRasterizer(unittest.TestCase):
139145
def test_simple_sphere(self):
@@ -203,3 +209,9 @@ def test_simple_sphere(self):
203209
image[image >= 0] = 1.0
204210
image[image < 0] = 0.0
205211
self.assertTrue(torch.allclose(image, image_ref[..., 0]))
212+
213+
def test_simple_to(self):
214+
# Check that to() works without a cameras object.
215+
device = torch.device("cuda:0")
216+
rasterizer = PointsRasterizer()
217+
rasterizer.to(device)

0 commit comments

Comments
 (0)