Skip to content

Commit 96889de

Browse files
Krzysztof Chalupkafacebook-github-bot
Krzysztof Chalupka
authored andcommitted
SplatterPhongShader 1: Pull out common Shader functionality into ShaderBase
Summary: Most of the shaders copypaste exactly the same code into `__init__` and `to`. I will be adding a new shader in the next diff, so let's make it a bit easier. Reviewed By: bottler Differential Revision: D35767884 fbshipit-source-id: 0057e3e2ae3be4eaa49ae7e2bf3e4176953dde9d
1 parent 9f443ed commit 96889de

File tree

1 file changed

+20
-124
lines changed

1 file changed

+20
-124
lines changed

Diff for: pytorch3d/renderer/mesh/shader.py

+20-124
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,7 @@
3232
# - sample colors from a texture map
3333
# - apply per pixel lighting
3434
# - blend colors across top K faces per pixel.
35-
36-
37-
class HardPhongShader(nn.Module):
38-
"""
39-
Per pixel lighting - the lighting model is applied using the interpolated
40-
coordinates and normals for each pixel. The blending function hard assigns
41-
the color of the closest face for each pixel.
42-
43-
To use the default values, simply initialize the shader with the desired
44-
device e.g.
45-
46-
.. code-block::
47-
48-
shader = HardPhongShader(device=torch.device("cuda:0"))
49-
"""
50-
35+
class ShaderBase(nn.Module):
5136
def __init__(
5237
self,
5338
device: Device = "cpu",
@@ -74,6 +59,21 @@ def to(self, device: Device):
7459
self.lights = self.lights.to(device)
7560
return self
7661

62+
63+
class HardPhongShader(ShaderBase):
64+
"""
65+
Per pixel lighting - the lighting model is applied using the interpolated
66+
coordinates and normals for each pixel. The blending function hard assigns
67+
the color of the closest face for each pixel.
68+
69+
To use the default values, simply initialize the shader with the desired
70+
device e.g.
71+
72+
.. code-block::
73+
74+
shader = HardPhongShader(device=torch.device("cuda:0"))
75+
"""
76+
7777
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
7878
cameras = kwargs.get("cameras", self.cameras)
7979
if cameras is None:
@@ -97,7 +97,7 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
9797
return images
9898

9999

100-
class SoftPhongShader(nn.Module):
100+
class SoftPhongShader(ShaderBase):
101101
"""
102102
Per pixel lighting - the lighting model is applied using the interpolated
103103
coordinates and normals for each pixel. The blending function returns the
@@ -111,32 +111,6 @@ class SoftPhongShader(nn.Module):
111111
shader = SoftPhongShader(device=torch.device("cuda:0"))
112112
"""
113113

114-
def __init__(
115-
self,
116-
device: Device = "cpu",
117-
cameras: Optional[TensorProperties] = None,
118-
lights: Optional[TensorProperties] = None,
119-
materials: Optional[Materials] = None,
120-
blend_params: Optional[BlendParams] = None,
121-
) -> None:
122-
super().__init__()
123-
self.lights = lights if lights is not None else PointLights(device=device)
124-
self.materials = (
125-
materials if materials is not None else Materials(device=device)
126-
)
127-
self.cameras = cameras
128-
self.blend_params = blend_params if blend_params is not None else BlendParams()
129-
130-
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
131-
def to(self, device: Device):
132-
# Manually move to device modules which are not subclasses of nn.Module
133-
cameras = self.cameras
134-
if cameras is not None:
135-
self.cameras = cameras.to(device)
136-
self.materials = self.materials.to(device)
137-
self.lights = self.lights.to(device)
138-
return self
139-
140114
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
141115
cameras = kwargs.get("cameras", self.cameras)
142116
if cameras is None:
@@ -164,7 +138,7 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
164138
return images
165139

166140

167-
class HardGouraudShader(nn.Module):
141+
class HardGouraudShader(ShaderBase):
168142
"""
169143
Per vertex lighting - the lighting model is applied to the vertex colors and
170144
the colors are then interpolated using the barycentric coordinates to
@@ -179,32 +153,6 @@ class HardGouraudShader(nn.Module):
179153
shader = HardGouraudShader(device=torch.device("cuda:0"))
180154
"""
181155

182-
def __init__(
183-
self,
184-
device: Device = "cpu",
185-
cameras: Optional[TensorProperties] = None,
186-
lights: Optional[TensorProperties] = None,
187-
materials: Optional[Materials] = None,
188-
blend_params: Optional[BlendParams] = None,
189-
) -> None:
190-
super().__init__()
191-
self.lights = lights if lights is not None else PointLights(device=device)
192-
self.materials = (
193-
materials if materials is not None else Materials(device=device)
194-
)
195-
self.cameras = cameras
196-
self.blend_params = blend_params if blend_params is not None else BlendParams()
197-
198-
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
199-
def to(self, device: Device):
200-
# Manually move to device modules which are not subclasses of nn.Module
201-
cameras = self.cameras
202-
if cameras is not None:
203-
self.cameras = cameras.to(device)
204-
self.materials = self.materials.to(device)
205-
self.lights = self.lights.to(device)
206-
return self
207-
208156
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
209157
cameras = kwargs.get("cameras", self.cameras)
210158
if cameras is None:
@@ -231,7 +179,7 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
231179
return images
232180

233181

234-
class SoftGouraudShader(nn.Module):
182+
class SoftGouraudShader(ShaderBase):
235183
"""
236184
Per vertex lighting - the lighting model is applied to the vertex colors and
237185
the colors are then interpolated using the barycentric coordinates to
@@ -246,32 +194,6 @@ class SoftGouraudShader(nn.Module):
246194
shader = SoftGouraudShader(device=torch.device("cuda:0"))
247195
"""
248196

249-
def __init__(
250-
self,
251-
device: Device = "cpu",
252-
cameras: Optional[TensorProperties] = None,
253-
lights: Optional[TensorProperties] = None,
254-
materials: Optional[Materials] = None,
255-
blend_params: Optional[BlendParams] = None,
256-
) -> None:
257-
super().__init__()
258-
self.lights = lights if lights is not None else PointLights(device=device)
259-
self.materials = (
260-
materials if materials is not None else Materials(device=device)
261-
)
262-
self.cameras = cameras
263-
self.blend_params = blend_params if blend_params is not None else BlendParams()
264-
265-
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
266-
def to(self, device: Device):
267-
# Manually move to device modules which are not subclasses of nn.Module
268-
cameras = self.cameras
269-
if cameras is not None:
270-
self.cameras = cameras.to(device)
271-
self.materials = self.materials.to(device)
272-
self.lights = self.lights.to(device)
273-
return self
274-
275197
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
276198
cameras = kwargs.get("cameras", self.cameras)
277199
if cameras is None:
@@ -320,7 +242,7 @@ def TexturedSoftPhongShader(
320242
)
321243

322244

323-
class HardFlatShader(nn.Module):
245+
class HardFlatShader(ShaderBase):
324246
"""
325247
Per face lighting - the lighting model is applied using the average face
326248
position and the face normal. The blending function hard assigns
@@ -334,32 +256,6 @@ class HardFlatShader(nn.Module):
334256
shader = HardFlatShader(device=torch.device("cuda:0"))
335257
"""
336258

337-
def __init__(
338-
self,
339-
device: Device = "cpu",
340-
cameras: Optional[TensorProperties] = None,
341-
lights: Optional[TensorProperties] = None,
342-
materials: Optional[Materials] = None,
343-
blend_params: Optional[BlendParams] = None,
344-
) -> None:
345-
super().__init__()
346-
self.lights = lights if lights is not None else PointLights(device=device)
347-
self.materials = (
348-
materials if materials is not None else Materials(device=device)
349-
)
350-
self.cameras = cameras
351-
self.blend_params = blend_params if blend_params is not None else BlendParams()
352-
353-
# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
354-
def to(self, device: Device):
355-
# Manually move to device modules which are not subclasses of nn.Module
356-
cameras = self.cameras
357-
if cameras is not None:
358-
self.cameras = cameras.to(device)
359-
self.materials = self.materials.to(device)
360-
self.lights = self.lights.to(device)
361-
return self
362-
363259
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
364260
cameras = kwargs.get("cameras", self.cameras)
365261
if cameras is None:

0 commit comments

Comments
 (0)