Lighting broadcasting bug fix
Summary: Fixed multiple issues with shape broadcasting in lighting, shading, and blending, and updated the tests.

Reviewed By: bottler

Differential Revision: D28997941

fbshipit-source-id: d3ef93f979344076b1d9098a86178b4da63844c8
nikhilaravi authored and facebook-github-bot committed Jun 14, 2021
1 parent 9de627e commit bc8361f
Showing 4 changed files with 73 additions and 31 deletions.
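Note (illustration only, not part of the commit): the broadcasting problem the diffs below address is that per-batch tensors, e.g. a light location of shape (N, 3) or a zfar of shape (N,), do not line up with the per-pixel (N, H, W, K, ...) tensors produced by rasterization unless singleton dimensions are inserted first. A minimal PyTorch sketch with made-up shapes:

import torch

# Illustrative shapes only: batch N, image H x W, K faces per pixel.
N, H, W, K = 4, 8, 8, 4
points = torch.rand(N, H, W, K, 3)   # per-pixel shaded points
location = torch.rand(N, 3)          # one light location per batch element

# Naive subtraction aligns trailing dimensions, so the batch axis of
# `location` pairs with the faces-per-pixel axis of `points` (or errors
# outright when K != N).
wrong = location - points                           # (N, H, W, K, 3), mis-paired axes
# Inserting singleton dimensions lines each axis up with its counterpart.
right = location[:, None, None, None, :] - points   # (N, H, W, K, 3)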
18 changes: 13 additions & 5 deletions pytorch3d/renderer/blending.py
@@ -1,12 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


from typing import NamedTuple, Sequence
from typing import NamedTuple, Sequence, Union

import torch
from pytorch3d import _C # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.


# Example functions for blending the top K colors per pixel using the outputs
# from rasterization.
# NOTE: All blending function should return an RGBA image per batch element
@@ -117,7 +116,11 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:


def softmax_rgb_blend(
colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100
colors,
fragments,
blend_params,
znear: Union[float, torch.Tensor] = 1.0,
zfar: Union[float, torch.Tensor] = 100,
) -> torch.Tensor:
"""
RGB and alpha channel blending to return an RGBA image based on the method
@@ -184,11 +187,16 @@ def softmax_rgb_blend(
# overflow. zbuf shape (N, H, W, K), find max over K.
# TODO: there may still be some instability in the exponent calculation.

# Reshape to be compatible with (N, H, W, K) values in fragments
if torch.is_tensor(zfar):
# pyre-fixme[16]
zfar = zfar[:, None, None, None]
if torch.is_tensor(znear):
znear = znear[:, None, None, None]

z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask
# pyre-fixme[16]: `Tuple` has no attribute `values`.
# pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
z_inv_max = torch.max(z_inv, dim=-1).values[..., None].clamp(min=eps)
# pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)

# Also apply exp normalize trick for the background color weight.
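A standalone sketch of the znear/zfar handling added above (made-up values and shapes, not the commit's test code):

import torch

N, H, W, K = 2, 8, 8, 3
zbuf = torch.rand(N, H, W, K) * 90 + 5   # fake depth buffer
mask = zbuf > 0                          # valid fragments
eps = 1e-10

# znear / zfar may be plain floats or per-batch tensors of shape (N,).
znear = torch.tensor([1.0, 2.0])
zfar = torch.tensor([100.0, 80.0])

# Reshape per-batch tensors to (N, 1, 1, 1) so they broadcast over (N, H, W, K).
if torch.is_tensor(zfar):
    zfar = zfar[:, None, None, None]
if torch.is_tensor(znear):
    znear = znear[:, None, None, None]

z_inv = (zfar - zbuf) / (zfar - znear) * mask
z_inv_max = torch.max(z_inv, dim=-1).values[..., None].clamp(min=eps)
print(z_inv.shape, z_inv_max.shape)   # (2, 8, 8, 3) (2, 8, 8, 1)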
18 changes: 16 additions & 2 deletions pytorch3d/renderer/lighting.py
@@ -253,12 +253,26 @@ def clone(self):
other = self.__class__(device=self.device)
return super().clone(other)

def reshape_location(self, points) -> torch.Tensor:
"""
Reshape the location tensor to have dimensions
compatible with the points which can either be of
shape (P, 3) or (N, H, W, K, 3).
"""
if self.location.ndim == points.ndim:
# pyre-fixme[7]
return self.location
# pyre-fixme[29]
return self.location[:, None, None, None, :]

def diffuse(self, normals, points) -> torch.Tensor:
direction = self.location - points
location = self.reshape_location(points)
direction = location - points
return diffuse(normals=normals, color=self.diffuse_color, direction=direction)

def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
direction = self.location - points
location = self.reshape_location(points)
direction = location - points
return specular(
points=points,
normals=normals,
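A standalone sketch of what reshape_location enables (assumed shapes; the helper below simply mirrors the method added above):

import torch

def reshape_location(location, points):
    # Mirrors PointLights.reshape_location: leave the location alone when the
    # points are packed as (P, 3); otherwise insert singleton dims so an
    # (N, 3) location broadcasts against (N, H, W, K, 3) points.
    if location.ndim == points.ndim:
        return location
    return location[:, None, None, None, :]

N, H, W, K = 3, 8, 8, 2
location = torch.rand(N, 3)                 # one light location per batch element
pixel_points = torch.rand(N, H, W, K, 3)    # per-pixel points from rasterization

direction = reshape_location(location, pixel_points) - pixel_points
print(direction.shape)   # torch.Size([3, 8, 8, 2, 3])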
10 changes: 8 additions & 2 deletions pytorch3d/renderer/mesh/shading.py
@@ -14,8 +14,8 @@ def _apply_lighting(
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
points: torch tensor of shape (N, P, 3) or (P, 3).
normals: torch tensor of shape (N, P, 3) or (P, 3)
points: torch tensor of shape (N, ..., 3) or (P, 3).
normals: torch tensor of shape (N, ..., 3) or (P, 3)
lights: instance of the Lights class.
cameras: instance of the Cameras class.
materials: instance of the Materials class.
@@ -35,13 +35,19 @@ def _apply_lighting(
ambient_color = materials.ambient_color * lights.ambient_color
diffuse_color = materials.diffuse_color * light_diffuse
specular_color = materials.specular_color * light_specular

if normals.dim() == 2 and points.dim() == 2:
# If given packed inputs remove batch dim in output.
return (
ambient_color.squeeze(),
diffuse_color.squeeze(),
specular_color.squeeze(),
)

if ambient_color.ndim != diffuse_color.ndim:
# Reshape from (N, 3) to have dimensions compatible with
# diffuse_color which is of shape (N, H, W, K, 3)
ambient_color = ambient_color[:, None, None, None, :]
return ambient_color, diffuse_color, specular_color


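A standalone sketch of the ambient-color alignment added above; the last line is roughly how the mesh shaders combine the three terms (made-up shapes and values):

import torch

N, H, W, K = 2, 8, 8, 3
# Ambient light is constant per batch element, while the diffuse and specular
# terms are computed per pixel, so their ranks can differ.
ambient_color = torch.rand(N, 3)
diffuse_color = torch.rand(N, H, W, K, 3)
specular_color = torch.rand(N, H, W, K, 3)

if ambient_color.ndim != diffuse_color.ndim:
    # Same reshape as the fix in _apply_lighting: (N, 3) -> (N, 1, 1, 1, 3).
    ambient_color = ambient_color[:, None, None, None, :]

# All three terms now broadcast against (N, H, W, K, 3) texel colors.
texels = torch.rand(N, H, W, K, 3)
pixel_colors = (ambient_color + diffuse_color) * texels + specular_color
print(pixel_colors.shape)   # torch.Size([2, 8, 8, 3, 3])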
58 changes: 36 additions & 22 deletions tests/test_render_meshes.py
@@ -6,6 +6,7 @@
"""
import os
import unittest
from collections import namedtuple

import numpy as np
import torch
@@ -53,6 +54,8 @@
DATA_DIR = get_tests_dir() / "data"
TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"

ShaderTest = namedtuple("ShaderTest", ["shader", "reference_name", "debug_name"])


class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
def test_simple_sphere(self, elevated_camera=False, check_depth=False):
@@ -107,13 +110,13 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False):
blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))

# Test several shaders
shaders = {
"phong": HardPhongShader,
"gouraud": HardGouraudShader,
"flat": HardFlatShader,
}
for (name, shader_init) in shaders.items():
shader = shader_init(
shader_tests = [
ShaderTest(HardPhongShader, "phong", "hard_phong"),
ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
ShaderTest(HardFlatShader, "flat", "hard_flat"),
]
for test in shader_tests:
shader = test.shader(
lights=lights,
cameras=cameras,
materials=materials,
Expand All @@ -135,7 +138,7 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False):

rgb = images[0, ..., :3].squeeze().cpu()
filename = "simple_sphere_light_%s%s%s.png" % (
name,
test.reference_name,
postfix,
cam_type.__name__,
)
@@ -144,7 +147,12 @@ def test_simple_sphere(self, elevated_camera=False, check_depth=False):
self.assertClose(rgb, image_ref, atol=0.05)

if DEBUG:
filename = "DEBUG_%s" % filename
debug_filename = "simple_sphere_light_%s%s%s.png" % (
test.debug_name,
postfix,
cam_type.__name__,
)
filename = "DEBUG_%s" % debug_filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
@@ -269,7 +277,8 @@ def test_simple_sphere_screen(self):
def test_simple_sphere_batched(self):
"""
Test a mesh with vertex textures can be extended to form a batch, and
is rendered correctly with Phong, Gouraud and Flat Shaders.
is rendered correctly with Phong, Gouraud and Flat Shaders with batched
lighting and hard and soft blending.
"""
batch_size = 5
device = torch.device("cuda:0")
@@ -291,24 +300,28 @@ def test_simple_sphere_batched(self):
R, T = look_at_view_transform(dist, elev, azim)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1
image_size=512, blur_radius=0.0, faces_per_pixel=4
)

# Init shader settings
materials = Materials(device=device)
lights = PointLights(device=device)
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
lights_location = torch.tensor([0.0, 0.0, +2.0], device=device)
lights_location = lights_location[None].expand(batch_size, -1)
lights = PointLights(device=device, location=lights_location)
blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))

# Init renderer
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
shaders = {
"phong": HardPhongShader,
"gouraud": HardGouraudShader,
"flat": HardFlatShader,
}
for (name, shader_init) in shaders.items():
shader = shader_init(
shader_tests = [
ShaderTest(HardPhongShader, "phong", "hard_phong"),
ShaderTest(SoftPhongShader, "phong", "soft_phong"),
ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
ShaderTest(HardFlatShader, "flat", "hard_flat"),
]
for test in shader_tests:
reference_name = test.reference_name
debug_name = test.debug_name
shader = test.shader(
lights=lights,
cameras=cameras,
materials=materials,
@@ -317,14 +330,15 @@ def test_simple_sphere_batched(self):
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_meshes)
image_ref = load_rgb_image(
"test_simple_sphere_light_%s_%s.png" % (name, type(cameras).__name__),
"test_simple_sphere_light_%s_%s.png"
% (reference_name, type(cameras).__name__),
DATA_DIR,
)
for i in range(batch_size):
rgb = images[i, ..., :3].squeeze().cpu()
if i == 0 and DEBUG:
filename = "DEBUG_simple_sphere_batched_%s_%s.png" % (
name,
debug_name,
type(cameras).__name__,
)
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
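A small usage sketch of what the updated batched test exercises, on CPU for simplicity and with illustrative shapes rather than a real rasterizer output (the expected output shape is an assumption based on the fixed broadcasting, not taken from the diff):

import torch
from pytorch3d.renderer import PointLights

batch_size, H, W, K = 5, 16, 16, 4

# One light location per batch element, as in the updated test.
location = torch.tensor([0.0, 0.0, 2.0])[None].expand(batch_size, -1)
lights = PointLights(location=location)

# Pixel-shaped points/normals such as a shader would pass in.
points = torch.rand(batch_size, H, W, K, 3)
normals = torch.rand(batch_size, H, W, K, 3)

diffuse = lights.diffuse(normals=normals, points=points)
print(diffuse.shape)   # expected: torch.Size([5, 16, 16, 4, 3])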
