Skip to content

Commit 9dc28f5

Browse files
davnov134facebook-github-bot
authored andcommitted
Fixes for RayBundle plotting
Summary: Fixes some issues with RayBundle plotting: - allows plotting raybundles on gpu - view -> reshape since we do not require contiguous raybundle tensors as input Reviewed By: bottler, shapovalov Differential Revision: D42665923 fbshipit-source-id: e9c6c7810428365dca4cb5ec80ef15ff28644163
1 parent a12612a commit 9dc28f5

File tree

3 files changed

+100
-3
lines changed

3 files changed

+100
-3
lines changed

pytorch3d/vis/__init__.py

+16
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,19 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
7+
import warnings
8+
9+
10+
try:
11+
from .plotly_vis import get_camera_wireframe, plot_batch_individually, plot_scene
12+
except ModuleNotFoundError as err:
13+
if "plotly" in str(err):
14+
warnings.warn(
15+
"Cannot import plotly-based visualization code."
16+
" Please install plotly to enable (pip install plotly)."
17+
)
18+
else:
19+
raise
20+
21+
from .texture_vis import texturesuv_image_matplotlib, texturesuv_image_PIL

pytorch3d/vis/plotly_vis.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ class Lighting(NamedTuple): # pragma: no cover
100100
vertexnormalsepsilon: float = 1e-12
101101

102102

103+
@torch.no_grad()
103104
def plot_scene(
104105
plots: Dict[str, Dict[str, Struct]],
105106
*,
@@ -407,6 +408,7 @@ def plot_scene(
407408
return fig
408409

409410

411+
@torch.no_grad()
410412
def plot_batch_individually(
411413
batched_structs: Union[
412414
List[Struct],
@@ -888,8 +890,12 @@ def _add_ray_bundle_trace(
888890
)
889891

890892
# make the ray lines for plotly plotting
891-
nan_tensor = torch.Tensor([[float("NaN")] * 3])
892-
ray_lines = torch.empty(size=(1, 3))
893+
nan_tensor = torch.tensor(
894+
[[float("NaN")] * 3],
895+
device=ray_lines_endpoints.device,
896+
dtype=ray_lines_endpoints.dtype,
897+
)
898+
ray_lines = torch.empty(size=(1, 3), device=ray_lines_endpoints.device)
893899
for ray_line in ray_lines_endpoints:
894900
# We combine the ray lines into a single tensor to plot them in a
895901
# single trace. The NaNs are inserted between sets of ray lines
@@ -952,7 +958,7 @@ def _add_ray_bundle_trace(
952958
current_layout = fig["layout"][plot_scene]
953959

954960
# update the bounds of the axes for the current trace
955-
all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3)
961+
all_ray_points = ray_bundle_to_ray_points(ray_bundle).reshape(-1, 3)
956962
ray_points_center = all_ray_points.mean(dim=0)
957963
max_expand = (all_ray_points.max(0)[0] - all_ray_points.min(0)[0]).max().item()
958964
_update_axes_bounds(ray_points_center, float(max_expand), current_layout)
@@ -1002,6 +1008,7 @@ def _update_axes_bounds(
10021008
max_expand: the maximum spread in any dimension of the trace's vertices.
10031009
current_layout: the plotly figure layout scene corresponding to the referenced trace.
10041010
"""
1011+
verts_center = verts_center.detach().cpu()
10051012
verts_min = verts_center - max_expand
10061013
verts_max = verts_center + max_expand
10071014
bounds = torch.t(torch.stack((verts_min, verts_max)))

tests/test_vis.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from pytorch3d.renderer import HeterogeneousRayBundle, PerspectiveCameras, RayBundle
11+
from pytorch3d.structures import Meshes, Pointclouds
12+
from pytorch3d.transforms import random_rotations
13+
14+
# Some of these imports are only needed for testing code coverage
15+
from pytorch3d.vis import ( # noqa: F401
16+
get_camera_wireframe, # noqa: F401
17+
plot_batch_individually, # noqa: F401
18+
plot_scene,
19+
texturesuv_image_PIL, # noqa: F401
20+
)
21+
22+
23+
class TestPlotlyVis(unittest.TestCase):
24+
def test_plot_scene(
25+
self,
26+
B: int = 3,
27+
n_rays: int = 128,
28+
n_pts_per_ray: int = 32,
29+
n_verts: int = 32,
30+
n_edges: int = 64,
31+
n_pts: int = 256,
32+
):
33+
"""
34+
Tests plotting of all supported structures using plot_scene.
35+
"""
36+
for device in ["cpu", "cuda:0"]:
37+
plot_scene(
38+
{
39+
"scene": {
40+
"ray_bundle": RayBundle(
41+
origins=torch.randn(B, n_rays, 3, device=device),
42+
xys=torch.randn(B, n_rays, 2, device=device),
43+
directions=torch.randn(B, n_rays, 3, device=device),
44+
lengths=torch.randn(
45+
B, n_rays, n_pts_per_ray, device=device
46+
),
47+
),
48+
"heterogeneous_ray_bundle": HeterogeneousRayBundle(
49+
origins=torch.randn(B * n_rays, 3, device=device),
50+
xys=torch.randn(B * n_rays, 2, device=device),
51+
directions=torch.randn(B * n_rays, 3, device=device),
52+
lengths=torch.randn(
53+
B * n_rays, n_pts_per_ray, device=device
54+
),
55+
camera_ids=torch.randint(
56+
low=0, high=B, size=(B * n_rays,), device=device
57+
),
58+
),
59+
"camera": PerspectiveCameras(
60+
R=random_rotations(B, device=device),
61+
T=torch.randn(B, 3, device=device),
62+
),
63+
"mesh": Meshes(
64+
verts=torch.randn(B, n_verts, 3, device=device),
65+
faces=torch.randint(
66+
low=0, high=n_verts, size=(B, n_edges, 3), device=device
67+
),
68+
),
69+
"point_clouds": Pointclouds(
70+
points=torch.randn(B, n_pts, 3, device=device),
71+
),
72+
}
73+
}
74+
)

0 commit comments

Comments
 (0)