Skip to content

Commit 4ecc9ea

Browse files
d4l3kfacebook-github-bot
authored andcommitted
shader: fix HardDepthShader sizes + tests (#1252)
Summary: This fixes a indexing bug in HardDepthShader and adds proper unit tests for both of the DepthShaders. This bug was introduced when updating the shader sizes and discovered when I switched my local model onto pytorch3d trunk instead of the patched copy. Pull Request resolved: #1252 Test Plan: Unit test + custom model code ``` pytest tests/test_shader.py ``` ![image](https://user-images.githubusercontent.com/909104/178397456-f478d0e0-9f6c-467a-a85b-adb4c47adfee.png) Reviewed By: bottler Differential Revision: D37775767 Pulled By: d4l3k fbshipit-source-id: 5f001903985976d7067d1fa0a3102d602790e3e8
1 parent 8d10ba5 commit 4ecc9ea

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

pytorch3d/renderer/mesh/shader.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -374,11 +374,11 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
374374
cameras = super()._get_cameras(**kwargs)
375375

376376
zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
377-
mask = fragments.pix_to_face < 0
377+
mask = fragments.pix_to_face[..., 0:1] < 0
378378

379-
zbuf = fragments.zbuf[..., 0].clone()
379+
zbuf = fragments.zbuf[..., 0:1].clone()
380380
zbuf[mask] = zfar
381-
return zbuf.unsqueeze(3)
381+
return zbuf
382382

383383

384384
class SoftDepthShader(ShaderBase):

tests/test_shader.py

+32
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,35 @@ def test_cameras_check(self):
9191

9292
with self.assertRaises(ValueError):
9393
shader(fragments, meshes)
94+
95+
def test_depth_shader(self):
96+
shader_classes = [
97+
HardDepthShader,
98+
SoftDepthShader,
99+
]
100+
101+
verts = torch.tensor(
102+
[[-1, -1, 0], [1, -1, 1], [1, 1, 0], [-1, 1, 1]], dtype=torch.float32
103+
)
104+
faces = torch.tensor([[0, 1, 2], [2, 3, 0]], dtype=torch.int64)
105+
meshes = Meshes(verts=[verts], faces=[faces])
106+
107+
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
108+
barycentric_coords = torch.tensor(
109+
[[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], dtype=torch.float32
110+
).view(1, 1, 1, 2, -1)
111+
for faces_per_pixel in [1, 2]:
112+
fragments = Fragments(
113+
pix_to_face=pix_to_face[:, :, :, :faces_per_pixel],
114+
bary_coords=barycentric_coords[:, :, :, :faces_per_pixel],
115+
zbuf=torch.ones_like(pix_to_face),
116+
dists=torch.ones_like(pix_to_face),
117+
)
118+
R, T = look_at_view_transform()
119+
cameras = PerspectiveCameras(R=R, T=T)
120+
121+
for shader_class in shader_classes:
122+
shader = shader_class()
123+
124+
out = shader(fragments, meshes, cameras=cameras)
125+
self.assertEqual(out.shape, (1, 1, 1, 1))

0 commit comments

Comments
 (0)