Skip to content

Commit

Permalink
[NestedTensor] Extend coverage for unbind when ragged_idx != 1 (#127493)
Browse files Browse the repository at this point in the history
Summary:
Extend coverage for the `NestedTensor` `unbind` operator to cases in which `ragged_idx != 1`.

Currently, the `unbind` operator in the `NestedTensor` class splits a tensor along the 0-th dimension, where the `ragged_idx` property, which controls the jagged dimension upon which `unbind` splits, is 1. This diff extends support for `ragged_idx != 1` in `NestedTensor`s, allowing `unbind` to split a tensor along a jagged dimension greater than 0 for `NestedTensor`s with and without the `lengths` property.

Test Plan:
Added the following unit tests:

`test_unbind_ragged_idx_equals_2_cpu`, `test_unbind_ragged_idx_equals_3_cpu`, and `test_unbind_ragged_idx_equals_last_dim_cpu` verify that `unbind` works for all jagged dimensions greater than 1, for `NestedTensor`s without `lengths`.
```
test_unbind_ragged_idx_equals_2_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
test_unbind_ragged_idx_equals_3_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
test_unbind_ragged_idx_equals_last_dim_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_lengths_cpu` and `test_unbind_with_lengths_ragged_idx_equals_1_cpu` verify that `unbind` works when the jagged dimension is 1, for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
test_unbind_with_lengths_ragged_idx_equals_1_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_lengths_ragged_idx_equals_2_cpu` and `test_unbind_with_lengths_ragged_idx_equals_3_cpu` verify that `unbind` works when the jagged dimension is greater than 1, for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_ragged_idx_equals_2_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
test_unbind_with_lengths_ragged_idx_equals_3_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_lengths_ragged_idx_equals_0_cpu` verifies that `unbind` fails when the jagged dimension is 0 (the batch dimension), for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_ragged_idx_equals_0_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu` verifies that `unbind` fails when there is a mismatch between the offsets and the jagged dimension, for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_wrong_lengths_cpu` verifies that `unbind` fails when the lengths exceed the limitations set by offsets, for `NestedTensor`s with `lengths`.

```
test_unbind_with_wrong_lengths_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

Differential Revision: D57942686

Pull Request resolved: #127493
Approved by: https://github.com/davidberard98
  • Loading branch information
jananisriram authored and pytorchmergebot committed Jun 3, 2024
1 parent 4d32de1 commit 7c3740d
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 7 deletions.
139 changes: 138 additions & 1 deletion test/test_nestedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3048,6 +3048,15 @@ def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_g
_make_tensor(5, 5, 6),
_make_tensor(6, 5, 6),
],
# (B, *, D_0, D_1, D_2) with B=6
[
_make_tensor(2, 5, 6, 7),
_make_tensor(3, 5, 6, 7),
_make_tensor(4, 5, 6, 7, requires_grad=False),
_make_tensor(5, 5, 6, 7),
_make_tensor(6, 5, 6, 7),
_make_tensor(7, 5, 6, 7),
],
]

if include_list_of_lists:
Expand Down Expand Up @@ -3786,12 +3795,140 @@ def test_unbind(self, device):
nt = torch.nested.nested_tensor(
tensor_list,
layout=torch.jagged,
device=device)
device=device) # ragged_idx = 1
out = nt.unbind()
self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(t, tensor_list[i])

@parametrize("ragged_idx", [2, 3])
def test_unbind_transpose(self, device, ragged_idx):
for tensor_list in self._get_example_tensor_lists():
nt = torch.nested.nested_tensor(
tensor_list,
layout=torch.jagged,
device=device)
if ragged_idx < nt.dim():
nt = nt.transpose(1, ragged_idx) # set ragged_idx
out = nt.unbind()
self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(t.transpose(0, ragged_idx - 1), tensor_list[i]) # transpose back each element of result

def test_unbind_transpose_ragged_idx_last_dim(self, device):
for tensor_list in self._get_example_tensor_lists():
nt = torch.nested.nested_tensor(
tensor_list,
layout=torch.jagged,
device=device).transpose(1, -1) # set ragged_idx = last dimension
out = nt.unbind()
self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(t.transpose(0, -1), tensor_list[i]) # transpose back each element of result

def test_unbind_lengths(self, device):
values = torch.randn(16, 128, device=device)
offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
lengths = torch.tensor([6, 2, 1, 2], device=device)
nt = torch.nested.nested_tensor_from_jagged(
values,
offsets=offsets,
lengths=lengths) # 3D nested tensor

tensor_list = []
for i in range(offsets.shape[0] - 1):
tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i])])

out = nt.unbind()
self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(t, tensor_list[i])

def test_unbind_lengths_ragged_idx_1(self, device):
values = torch.randn(16, 8, 128, device=device)
offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
lengths = torch.tensor([6, 2, 1, 2], device=device)
ragged_idx = 1
nt = torch.nested._internal.nested_tensor.NestedTensor(
values,
offsets=offsets,
lengths=lengths,
_ragged_idx=ragged_idx) # 4D nested tensor

tensor_list = []
for i in range(offsets.shape[0] - 1):
tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i]), :, :])

out = nt.unbind()

self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(t, tensor_list[i])

def test_unbind_lengths_ragged_idx_2(self, device):
values = torch.randn(16, 8, 128, device=device)
offsets = torch.tensor([0, 2, 4, 8], device=device)
lengths = torch.tensor([2, 1, 3], device=device)
ragged_idx = 2
nt = torch.nested._internal.nested_tensor.NestedTensor(
values,
offsets=offsets,
lengths=lengths,
_ragged_idx=ragged_idx) # 4D nested tensor

tensor_list = []
for i in range(offsets.shape[0] - 1):
tensor_list.append(values[:, offsets[i] : (offsets[i] + lengths[i]), :])

out = nt.unbind()

self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(t, tensor_list[i])

def test_unbind_lengths_ragged_idx_3(self, device):
values = torch.randn(16, 8, 128, device=device)
offsets = torch.tensor([0, 100, 128], device=device)
lengths = torch.tensor([50, 28], device=device)
ragged_idx = 3
nt = torch.nested._internal.nested_tensor.NestedTensor(
values,
offsets=offsets,
lengths=lengths,
_ragged_idx=ragged_idx) # 4D nested tensor

tensor_list = []
for i in range(offsets.shape[0] - 1):
tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])

out = nt.unbind()

self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(t, tensor_list[i])

@skipIfTorchDynamo("TorchDynamo raises an error for ragged_idx == 0 earlier than Torch")
def test_unbind_lengths_ragged_idx_0(self, device):
values = torch.randn(16, 8, 128, device=device)
offsets = torch.tensor([0, 100, 128], device=device)
lengths = torch.tensor([50, 28], device=device)
ragged_idx = 0
nt = torch.nested._internal.nested_tensor.NestedTensor(
values,
offsets=offsets,
lengths=lengths,
_ragged_idx=ragged_idx) # 4D nested tensor

tensor_list = []
for i in range(offsets.shape[0] - 1):
tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])

self.assertRaisesRegex(
RuntimeError,
r"unbind\(\): nested tensor.*out of bounds",
lambda: nt.unbind()
)

@xfailIfTorchDynamo
def test_layer_norm_2(self, device):
test_tensor_list = self._get_list_for_jagged_tensor(
Expand Down
14 changes: 8 additions & 6 deletions torch/nested/_internal/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,16 +616,18 @@ def unbind_int(func, *args, **kwargs):
values = inp.values()
offsets = inp.offsets()
lengths = inp.lengths()
ragged_idx = inp._ragged_idx

if inp._ragged_idx != 1:
if lengths is None:
return torch.split(values, offsets.diff().tolist(), dim=(ragged_idx - 1))

if ragged_idx <= 0:
raise RuntimeError(
"unbind(): only supported for NestedTensor when jagged dimension is 1"
"unbind(): nested tensor ragged_idx out of bounds (should be >= 1)"
)

if lengths is None:
return torch.split(values, offsets.diff().tolist())
return [
values[offsets[i] : (offsets[i] + lengths[i])] for i in range(lengths.shape[0])
torch.narrow(values, dim=(ragged_idx - 1), start=offsets[i], length=lengths[i])
for i in range(lengths.shape[0])
]


Expand Down

0 comments on commit 7c3740d

Please sign in to comment.