-
Notifications
You must be signed in to change notification settings - Fork 149
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Broadcast tensors before stacking in pytorch backend
Topic: broadcast_and_stack_symforce_pytorch Noticed in rare cases the generated pytorch functions didn't properly broadcast because the tensors used as input to stack weren't the same shape. For example, linear camera unprojection ``` _ray = torch.stack( [ (-cal[..., 2] + pixel[..., 0]) / cal[..., 0], (-cal[..., 3] + pixel[..., 1]) / cal[..., 1], torch.tensor(1, **tensor_kwargs), ], dim=-1, ) ``` This PR fixes this issue by broadcasting tensors to a common shape before stacking. ``` _ray = broadcast_and_stack( [ (-cal[..., 2] + pixel[..., 0]) / cal[..., 0], (-cal[..., 3] + pixel[..., 1]) / cal[..., 1], torch.tensor(1, **tensor_kwargs), ], dim=-1, ) ``` where ``` def broadcast_and_stack(tensors, dim=-1): # type: (T.List[torch.Tensor], int) -> torch.Tensor """ broadcast tensors to common shape then stack along new dimension """ broadcast_shape = torch.broadcast_shapes(*(x.size() for x in tensors)) broadcast_tensors = [x.broadcast_to(broadcast_shape) for x in tensors] return torch.stack(broadcast_tensors, dim=dim) ``` Reviewers: GitOrigin-RevId: a33e3605ba6672c20f4a7b3b9a86c78009ccb62f
- Loading branch information
1 parent
0e0fafc
commit 1387297
Showing
6 changed files
with
85 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters