Skip to content

Commit 39bb2ce

Browse files
bottler authored and
facebook-github-bot committed
Join cameras as batch
Summary: Function to join a list of cameras objects into a single batched object. FB: In the next diff I will remove the `concatenate_cameras` function in implicitron and update the callsites. Reviewed By: nikhilaravi Differential Revision: D33198209 fbshipit-source-id: 0c9f5f5df498a0def9dba756c984e6a946618158
1 parent 9e2bc3a commit 39bb2ce

File tree

5 files changed

+187
-9
lines changed

5 files changed

+187
-9
lines changed

Diff for: pytorch3d/renderer/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
sigmoid_alpha_blend,
1111
softmax_rgb_blend,
1212
)
13-
from .camera_utils import rotate_on_spot
13+
from .camera_utils import join_cameras_as_batch, rotate_on_spot
1414
from .cameras import OpenGLOrthographicCameras # deprecated
1515
from .cameras import OpenGLPerspectiveCameras # deprecated
1616
from .cameras import SfMOrthographicCameras # deprecated
@@ -29,6 +29,7 @@
2929
AbsorptionOnlyRaymarcher,
3030
EmissionAbsorptionRaymarcher,
3131
GridRaysampler,
32+
HarmonicEmbedding,
3233
ImplicitRenderer,
3334
MonteCarloRaysampler,
3435
NDCGridRaysampler,
@@ -37,7 +38,6 @@
3738
VolumeSampler,
3839
ray_bundle_to_ray_points,
3940
ray_bundle_variables_to_ray_points,
40-
HarmonicEmbedding,
4141
)
4242
from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular
4343
from .materials import Materials

Diff for: pytorch3d/renderer/camera_utils.py

+65-1
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Tuple
7+
from typing import Sequence, Tuple
88

99
import torch
1010
from pytorch3d.transforms import Transform3d
1111

12+
from .cameras import CamerasBase
13+
1214

1315
def camera_to_eye_at_up(
1416
world_to_view_transform: Transform3d,
@@ -141,3 +143,65 @@ def rotate_on_spot(
141143
new_T = torch.matmul(new_R.transpose(1, 2), old_RT)[:, :, 0]
142144

143145
return new_R, new_T
146+
147+
148+
def join_cameras_as_batch(cameras_list: Sequence[CamerasBase]) -> CamerasBase:
    """
    Create a batched cameras object by concatenating a list of input
    cameras objects. All the tensor attributes will be joined along
    the batch dimension.

    Args:
        cameras_list: Non-empty list of camera classes all of the same type
            and on the same device. Each represents one or more cameras.
    Returns:
        cameras: single batched cameras object of the same
            type as all the objects in the input list.
    Raises:
        ValueError: if the list is empty, if any element is not a
            CamerasBase, if the cameras differ in type or device, if a
            field is set on some cameras but None on others, if a shared
            field differs between cameras, or if a non-shared field is
            not a tensor.
    """
    if not cameras_list:
        raise ValueError("cameras_list must contain at least one cameras object")

    # Validate the element types before reading any camera-specific
    # attribute, so that a non-camera input is reported with the intended
    # ValueError instead of an AttributeError on `_FIELDS`.
    if not all(isinstance(c, CamerasBase) for c in cameras_list):
        raise ValueError("cameras in cameras_list must inherit from CamerasBase")

    # Get the type and fields to join from the first camera in the batch
    c0 = cameras_list[0]
    fields = c0._FIELDS
    shared_fields = c0._SHARED_FIELDS

    if not all(type(c) is type(c0) for c in cameras_list[1:]):
        raise ValueError("All cameras must be of the same type")

    if not all(c.device == c0.device for c in cameras_list[1:]):
        raise ValueError("All cameras in the batch must be on the same device")

    # Concat the fields to make a batched tensor
    kwargs = {"device": c0.device}

    for field in fields:
        attrs_list = [getattr(c, field) for c in cameras_list]

        field_not_none = [attr is not None for attr in attrs_list]
        if not any(field_not_none):
            # Unset on every camera: leave it to the constructor default.
            continue
        if not all(field_not_none):
            raise ValueError(f"Attribute {field} is inconsistently present")

        if field in shared_fields:
            # Only needs to be set once
            if not all(a == attrs_list[0] for a in attrs_list):
                raise ValueError(f"Attribute {field} is not constant across inputs")

            # e.g. "in_ndc" is set as attribute "_in_ndc" on the class
            # but provided as "in_ndc" in the input args
            kwarg_name = field[1:] if field.startswith("_") else field
            kwargs[kwarg_name] = attrs_list[0]
        elif isinstance(attrs_list[0], torch.Tensor):
            # In the init, all inputs will be converted to
            # batched tensors before set as attributes
            # Join as a tensor along the batch dimension
            kwargs[field] = torch.cat(attrs_list, dim=0)
        else:
            raise ValueError(f"Field {field} type is not supported for batching")

    return c0.__class__(**kwargs)

Diff for: pytorch3d/renderer/cameras.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,12 @@ class CamerasBase(TensorProperties):
7777

7878
# Used in __getitem__ to index the relevant fields
7979
# When creating a new camera, this should be set in the __init__
80-
_FIELDS: Tuple = ()
80+
_FIELDS: Tuple[str, ...] = ()
81+
82+
# Names of fields which are a constant property of the whole batch, rather
83+
# than themselves a batch of data.
84+
# When joining objects into a batch, they will have to agree.
85+
_SHARED_FIELDS: Tuple[str, ...] = ()
8186

8287
def get_projection_transform(self):
8388
"""
@@ -499,6 +504,8 @@ class FoVPerspectiveCameras(CamerasBase):
499504
"degrees",
500505
)
501506

507+
_SHARED_FIELDS = ("degrees",)
508+
502509
def __init__(
503510
self,
504511
znear=1.0,
@@ -997,6 +1004,8 @@ class PerspectiveCameras(CamerasBase):
9971004
"image_size",
9981005
)
9991006

1007+
_SHARED_FIELDS = ("_in_ndc",)
1008+
10001009
def __init__(
10011010
self,
10021011
focal_length=1.0,
@@ -1047,6 +1056,12 @@ def __init__(
10471056
else:
10481057
self.image_size = None
10491058

1059+
# When focal length is provided as one value, expand to
1060+
# create (N, 2) shape tensor
1061+
if self.focal_length.ndim == 1: # (N,)
1062+
self.focal_length = self.focal_length[:, None] # (N, 1)
1063+
self.focal_length = self.focal_length.expand(-1, 2) # (N, 2)
1064+
10501065
def get_projection_transform(self, **kwargs) -> Transform3d:
10511066
"""
10521067
Calculate the projection matrix using the
@@ -1227,6 +1242,8 @@ class OrthographicCameras(CamerasBase):
12271242
"image_size",
12281243
)
12291244

1245+
_SHARED_FIELDS = ("_in_ndc",)
1246+
12301247
def __init__(
12311248
self,
12321249
focal_length=1.0,
@@ -1276,6 +1293,12 @@ def __init__(
12761293
else:
12771294
self.image_size = None
12781295

1296+
# When focal length is provided as one value, expand to
1297+
# create (N, 2) shape tensor
1298+
if self.focal_length.ndim == 1: # (N,)
1299+
self.focal_length = self.focal_length[:, None] # (N, 1)
1300+
self.focal_length = self.focal_length.expand(-1, 2) # (N, 2)
1301+
12791302
def get_projection_transform(self, **kwargs) -> Transform3d:
12801303
"""
12811304
Calculate the projection matrix using

Diff for: tests/test_camera_pixels.py

-3
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,4 @@ def test_camera(self):
250250
],
251251
dim=1,
252252
)
253-
254-
print(wanted)
255-
print(camera_points[batch_idx])
256253
self.assertClose(camera_points[batch_idx], wanted)

Diff for: tests/test_cameras.py

+96-2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import numpy as np
3737
import torch
3838
from common_testing import TestCaseMixin
39+
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
3940
from pytorch3d.renderer.cameras import (
4041
CamerasBase,
4142
FoVOrthographicCameras,
@@ -688,6 +689,99 @@ def test_clone(self, batch_size: int = 10):
688689
else:
689690
self.assertTrue(val == val_clone)
690691

692+
def test_join_cameras_as_batch_errors(self):
    """Check that incompatible camera inputs are rejected with a ValueError."""
    perspective_gpu = PerspectiveCameras(device="cuda:0")
    ortho_gpu = OrthographicCameras(device="cuda:0")

    # Mixing two different camera classes must fail.
    with self.assertRaisesRegex(ValueError, "same type"):
        join_cameras_as_batch([perspective_gpu, ortho_gpu])

    # Mixing cameras that live on different devices must fail.
    ortho_cpu = OrthographicCameras(device="cpu")
    with self.assertRaisesRegex(ValueError, "same device"):
        join_cameras_as_batch([ortho_gpu, ortho_cpu])

    # Mixing coordinate systems must fail: every camera has to agree
    # on in_ndc, i.e. all in ndc or all in screen space.
    ortho_screen = OrthographicCameras(in_ndc=False, device="cuda:0")
    expected_message = "Attribute _in_ndc is not constant across inputs"
    with self.assertRaisesRegex(ValueError, expected_message):
        join_cameras_as_batch([ortho_gpu, ortho_screen])
711+
712+
def join_cameras_as_batch_fov(self, camera_cls):
    """
    Join two batches of `camera_cls` cameras (built from znear, zfar and R)
    and check the size, device and rotations of the result.
    """
    rot_six = torch.randn((6, 3, 3))
    rot_three = torch.randn((3, 3, 3))
    first = camera_cls(znear=10.0, zfar=100.0, R=rot_six, device="cuda:0")
    second = camera_cls(znear=10.0, zfar=200.0, R=rot_three, device="cuda:0")

    joined = join_cameras_as_batch([first, second])

    # The joined batch holds every camera from both inputs, in order.
    self.assertEqual(joined._N, first._N + second._N)
    self.assertEqual(joined.device, first.device)
    expected_R = torch.cat((rot_six, rot_three), dim=0).to(device="cuda:0")
    self.assertClose(joined.R, expected_R)
723+
724+
def join_cameras_as_batch(self, camera_cls):
    """
    Join batches of `camera_cls` cameras (built from R, focal_length and
    principal_point) and check the joined attributes, including focal
    lengths that were supplied in different forms.
    """
    rot_six = torch.randn((6, 3, 3))
    rot_three = torch.randn((3, 3, 3))
    pp_six = torch.randn((6, 2, 1))
    pp_three = torch.randn((3, 2, 1))
    focal_scalar = 5.0
    focal_n2 = torch.randn(3, 2)
    focal_n1 = torch.randn(3, 1)

    cam_six = camera_cls(
        R=rot_six, focal_length=focal_scalar, principal_point=pp_six
    )
    cam_three = camera_cls(
        R=rot_three, focal_length=focal_scalar, principal_point=pp_three
    )
    cam_n2 = camera_cls(R=rot_three, focal_length=focal_n2, principal_point=pp_three)
    cam_n1 = camera_cls(R=rot_three, focal_length=focal_n1, principal_point=pp_three)

    stacked_R = torch.cat((rot_six, rot_three), dim=0)

    # Two cameras that were both given the scalar focal length.
    joined = join_cameras_as_batch([cam_six, cam_three])
    self.assertEqual(joined._N, cam_six._N + cam_three._N)
    self.assertEqual(joined.device, cam_six.device)
    self.assertClose(joined.R, stacked_R)
    self.assertClose(joined.principal_point, torch.cat((pp_six, pp_three), dim=0))
    self.assertEqual(joined._in_ndc, cam_six._in_ndc)

    # Test one broadcasted value and one fixed value
    # Focal length as (N,) in one camera and (N, 2) in the other
    joined = join_cameras_as_batch([cam_six, cam_n2])
    self.assertEqual(joined._N, cam_six._N + cam_n2._N)
    self.assertClose(joined.R, stacked_R)
    broadcast_focal = torch.tensor([[focal_scalar, focal_scalar]]).expand(6, -1)
    self.assertClose(
        joined.focal_length,
        torch.cat([broadcast_focal, focal_n2], dim=0),
    )

    # Focal length as (N, 1) in one camera and (N, 2) in the other
    joined = join_cameras_as_batch([cam_n2, cam_n1])
    self.assertClose(
        joined.focal_length,
        torch.cat([focal_n2, focal_n1.expand(-1, 2)], dim=0),
    )
776+
777+
def test_join_batch_perspective(self):
    """Run the join-as-batch checks for both perspective camera classes."""
    checks = (
        (self.join_cameras_as_batch_fov, FoVPerspectiveCameras),
        (self.join_cameras_as_batch, PerspectiveCameras),
    )
    for run_check, camera_cls in checks:
        run_check(camera_cls)
780+
781+
def test_join_batch_orthographic(self):
    """Run the join-as-batch checks for both orthographic camera classes."""
    checks = (
        (self.join_cameras_as_batch_fov, FoVOrthographicCameras),
        (self.join_cameras_as_batch, OrthographicCameras),
    )
    for run_check, camera_cls in checks:
        run_check(camera_cls)
784+
691785

692786
############################################################
693787
# FoVPerspective Camera #
@@ -1055,7 +1149,7 @@ def test_getitem(self):
10551149
index = torch.tensor([1, 3, 5], dtype=torch.int64)
10561150
c135 = cam[index]
10571151
self.assertEqual(len(c135), 3)
1058-
self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
1152+
self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
10591153
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
10601154
self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
10611155

@@ -1131,7 +1225,7 @@ def test_getitem(self):
11311225
index = torch.tensor([1, 3, 5], dtype=torch.int64)
11321226
c135 = cam[index]
11331227
self.assertEqual(len(c135), 3)
1134-
self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
1228+
self.assertClose(c135.focal_length, torch.tensor([[5.0, 5.0]] * 3))
11351229
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
11361230
self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
11371231

0 commit comments

Comments
 (0)