Broadcast tensors before stacking in pytorch backend
Topic: broadcast_and_stack_symforce_pytorch

Noticed that in rare cases the generated PyTorch functions didn't broadcast properly, because the tensors passed to `torch.stack` weren't all the same shape. For example, linear camera unprojection:
```
_ray = torch.stack(
    [
        (-cal[..., 2] + pixel[..., 0]) / cal[..., 0],
        (-cal[..., 3] + pixel[..., 1]) / cal[..., 1],
        torch.tensor(1, **tensor_kwargs),
    ],
    dim=-1,
)
```
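
The first two entries here take on the batch shape of `cal` and `pixel`, while `torch.tensor(1, **tensor_kwargs)` is 0-dimensional, and `torch.stack` requires all of its inputs to have identical shapes. A minimal sketch of the failure (the shapes and `tensor_kwargs` below are hypothetical, not taken from the generated code):

```
import torch

cal = torch.rand(8, 4)    # hypothetical batch of 8 calibrations [fx, fy, cx, cy]
pixel = torch.rand(8, 2)  # hypothetical batch of 8 pixel coordinates
tensor_kwargs = {"dtype": torch.float32}

try:
    torch.stack(
        [
            (-cal[..., 2] + pixel[..., 0]) / cal[..., 0],  # shape (8,)
            (-cal[..., 3] + pixel[..., 1]) / cal[..., 1],  # shape (8,)
            torch.tensor(1, **tensor_kwargs),              # shape (), mismatched
        ],
        dim=-1,
    )
except RuntimeError as err:
    # "stack expects each tensor to be equal size, ..."
    print(err)
```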

This PR fixes this issue by broadcasting tensors to a common shape before stacking.

```
_ray = broadcast_and_stack(
    [
        (-cal[..., 2] + pixel[..., 0]) / cal[..., 0],
        (-cal[..., 3] + pixel[..., 1]) / cal[..., 1],
        torch.tensor(1, **tensor_kwargs),
    ],
    dim=-1,
)
```
where
```
def broadcast_and_stack(tensors, dim=-1):
    # type: (T.List[torch.Tensor], int) -> torch.Tensor
    """
    broadcast tensors to common shape then stack along new dimension
    """

    broadcast_shape = torch.broadcast_shapes(*(x.size() for x in tensors))
    broadcast_tensors = [x.broadcast_to(broadcast_shape) for x in tensors]

    return torch.stack(broadcast_tensors, dim=dim)
```
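
As a quick sanity check (a sketch, not part of the PR), the helper lets a 0-d constant ride along with batched terms instead of tripping up `torch.stack`:

```
import torch


def broadcast_and_stack(tensors, dim=-1):
    # same helper as above
    broadcast_shape = torch.broadcast_shapes(*(x.size() for x in tensors))
    broadcast_tensors = [x.broadcast_to(broadcast_shape) for x in tensors]
    return torch.stack(broadcast_tensors, dim=dim)


batched = torch.rand(8)       # hypothetical per-sample term, shape (8,)
constant = torch.tensor(1.0)  # 0-d constant, shape ()

out = broadcast_and_stack([batched, batched, constant], dim=-1)
print(out.shape)  # torch.Size([8, 3])
```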

Reviewers:
GitOrigin-RevId: a33e3605ba6672c20f4a7b3b9a86c78009ccb62f
zachary-teed-skydio authored and aaron-skydio committed May 31, 2023
1 parent 0e0fafc commit 1387297
Showing 6 changed files with 85 additions and 26 deletions.
@@ -20,6 +20,17 @@ class TensorKwargs(T.TypedDict):
dtype: torch.dtype


def _broadcast_and_stack(tensors, dim=-1):
# type: (T.List[torch.Tensor], int) -> torch.Tensor
"""
broadcast tensors to common shape then stack along new dimension
"""

broadcast_shape = torch.broadcast_shapes(*(x.size() for x in tensors))
broadcast_tensors = [x.broadcast_to(broadcast_shape) for x in tensors]

return torch.stack(broadcast_tensors, dim=dim)

{{ util.function_declaration(spec) }}
{% if spec.docstring %}
{{ util.print_docstring(spec.docstring) | indent(4) }}
4 changes: 2 additions & 2 deletions symforce/codegen/backends/pytorch/templates/util/util.jinja
@@ -172,13 +172,13 @@ def {{ function_name_and_args(spec) }}:
{% if cols == 1 %}
_{{ name }} = (
{% else %}
-_{{ name }} = torch.stack([
+_{{ name }} = _broadcast_and_stack([
{% endif %}
{% for j in range(cols) %}
{% if rows == 1 %}
{{ terms[j][1] }}{% if not loop.last %},{% endif %}
{% else %}
-torch.stack([
+_broadcast_and_stack([
{% for i in range(rows) %}
{# NOTE(brad): The order of the terms is the storage order of geo.Matrix. If the
storage order of geo.Matrix is changed (i.e., from column major to row major),
@@ -21,6 +21,18 @@ class TensorKwargs(T.TypedDict):
dtype: torch.dtype


def _broadcast_and_stack(tensors, dim=-1):
# type: (T.List[torch.Tensor], int) -> torch.Tensor
"""
broadcast tensors to common shape then stack along new dimension
"""

broadcast_shape = torch.broadcast_shapes(*(x.size() for x in tensors))
broadcast_tensors = [x.broadcast_to(broadcast_shape) for x in tensors]

return torch.stack(broadcast_tensors, dim=dim)


def backend_test_function(x, y, tensor_kwargs=None):
# type: (torch.Tensor, torch.Tensor, TensorKwargs) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
"""
@@ -21,6 +21,18 @@ class TensorKwargs(T.TypedDict):
dtype: torch.dtype


def _broadcast_and_stack(tensors, dim=-1):
# type: (T.List[torch.Tensor], int) -> torch.Tensor
"""
broadcast tensors to common shape then stack along new dimension
"""

broadcast_shape = torch.broadcast_shapes(*(x.size() for x in tensors))
broadcast_tensors = [x.broadcast_to(broadcast_shape) for x in tensors]

return torch.stack(broadcast_tensors, dim=dim)


def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, TensorKwargs) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
"""
@@ -72,18 +84,18 @@ def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
# Output terms
_a_out = a
_b_out = b
-_c_out = torch.stack([c[..., 0], c[..., 1], c[..., 2]], dim=-1)
-_d_out = torch.stack(
+_c_out = _broadcast_and_stack([c[..., 0], c[..., 1], c[..., 2]], dim=-1)
+_d_out = _broadcast_and_stack(
[
-torch.stack([d[..., 0, 0], d[..., 1, 0]], dim=-1),
-torch.stack([d[..., 0, 1], d[..., 1, 1]], dim=-1),
+_broadcast_and_stack([d[..., 0, 0], d[..., 1, 0]], dim=-1),
+_broadcast_and_stack([d[..., 0, 1], d[..., 1, 1]], dim=-1),
],
dim=-1,
)
-_e_out = torch.stack([e[..., 0], e[..., 1], e[..., 2], e[..., 3], e[..., 4]], dim=-1)
-_f_out = torch.stack(
+_e_out = _broadcast_and_stack([e[..., 0], e[..., 1], e[..., 2], e[..., 3], e[..., 4]], dim=-1)
+_f_out = _broadcast_and_stack(
[
-torch.stack(
+_broadcast_and_stack(
[
f[..., 0, 0],
f[..., 1, 0],
@@ -94,7 +106,7 @@ def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
],
dim=-1,
),
-torch.stack(
+_broadcast_and_stack(
[
f[..., 0, 1],
f[..., 1, 1],
@@ -105,7 +117,7 @@ def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
],
dim=-1,
),
-torch.stack(
+_broadcast_and_stack(
[
f[..., 0, 2],
f[..., 1, 2],
@@ -116,7 +128,7 @@ def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
],
dim=-1,
),
-torch.stack(
+_broadcast_and_stack(
[
f[..., 0, 3],
f[..., 1, 3],
@@ -127,7 +139,7 @@ def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
],
dim=-1,
),
-torch.stack(
+_broadcast_and_stack(
[
f[..., 0, 4],
f[..., 1, 4],
@@ -138,7 +150,7 @@ def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
],
dim=-1,
),
-torch.stack(
+_broadcast_and_stack(
[
f[..., 0, 5],
f[..., 1, 5],
@@ -21,6 +21,18 @@ class TensorKwargs(T.TypedDict):
dtype: torch.dtype


def _broadcast_and_stack(tensors, dim=-1):
# type: (T.List[torch.Tensor], int) -> torch.Tensor
"""
broadcast tensors to common shape then stack along new dimension
"""

broadcast_shape = torch.broadcast_shapes(*(x.size() for x in tensors))
broadcast_tensors = [x.broadcast_to(broadcast_shape) for x in tensors]

return torch.stack(broadcast_tensors, dim=dim)


def backend_test_function(x, y, tensor_kwargs=None):
# type: (torch.Tensor, torch.Tensor, TensorKwargs) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
"""
@@ -21,6 +21,18 @@ class TensorKwargs(T.TypedDict):
dtype: torch.dtype


def _broadcast_and_stack(tensors, dim=-1):
# type: (T.List[torch.Tensor], int) -> torch.Tensor
"""
broadcast tensors to common shape then stack along new dimension
"""

broadcast_shape = torch.broadcast_shapes(*(x.size() for x in tensors))
broadcast_tensors = [x.broadcast_to(broadcast_shape) for x in tensors]

return torch.stack(broadcast_tensors, dim=dim)


def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, TensorKwargs) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
"""
@@ -72,18 +84,18 @@ def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
# Output terms
_a_out = a
_b_out = b
-_c_out = torch.stack([c[..., 0], c[..., 1], c[..., 2]], dim=-1)
-_d_out = torch.stack(
+_c_out = _broadcast_and_stack([c[..., 0], c[..., 1], c[..., 2]], dim=-1)
+_d_out = _broadcast_and_stack(
[
-torch.stack([d[..., 0, 0], d[..., 1, 0]], dim=-1),
-torch.stack([d[..., 0, 1], d[..., 1, 1]], dim=-1),
+_broadcast_and_stack([d[..., 0, 0], d[..., 1, 0]], dim=-1),
+_broadcast_and_stack([d[..., 0, 1], d[..., 1, 1]], dim=-1),
],
dim=-1,
)
-_e_out = torch.stack([e[..., 0], e[..., 1], e[..., 2], e[..., 3], e[..., 4]], dim=-1)
-_f_out = torch.stack(
+_e_out = _broadcast_and_stack([e[..., 0], e[..., 1], e[..., 2], e[..., 3], e[..., 4]], dim=-1)
+_f_out = _broadcast_and_stack(
[
-torch.stack(
+_broadcast_and_stack(
[
f[..., 0, 0],
f[..., 1, 0],
@@ -94,7 +106,7 @@ def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
],
dim=-1,
),
-torch.stack(
+_broadcast_and_stack(
[
f[..., 0, 1],
f[..., 1, 1],
@@ -105,7 +117,7 @@ def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
],
dim=-1,
),
-torch.stack(
+_broadcast_and_stack(
[
f[..., 0, 2],
f[..., 1, 2],
@@ -116,7 +128,7 @@ def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
],
dim=-1,
),
-torch.stack(
+_broadcast_and_stack(
[
f[..., 0, 3],
f[..., 1, 3],
@@ -127,7 +139,7 @@ def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
],
dim=-1,
),
-torch.stack(
+_broadcast_and_stack(
[
f[..., 0, 4],
f[..., 1, 4],
@@ -138,7 +150,7 @@ def pytorch_func(a, b, c, d, e, f, tensor_kwargs=None):
],
dim=-1,
),
-torch.stack(
+_broadcast_and_stack(
[
f[..., 0, 5],
f[..., 1, 5],