Skip to content

Commit 521cf5e

Browse files
committed
Fix aten_unbind for torch >= 2.7 dynamo export
Replace Split op with explicit Slice operations to fix TypeError when unbind is called during ONNX export with dynamo=True. The Split op with num_outputs parameter returns a non-iterable SymbolicTensor instead of a sequence, causing the list comprehension to fail. The fix uses individual Slice + Squeeze operations for each output, which properly handles symbolic tensors during graph construction. Fixes pytorch/pytorch#168969
1 parent 9dbf685 commit 521cf5e

File tree

2 files changed

+90
-9
lines changed

2 files changed

+90
-9
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9200,16 +9200,21 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
92009200
def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
92019201
"""unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""
92029202

9203-
if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"):
9204-
# We can create a definitive split op if the input shape is static
9205-
# Only torch>=2.7 supports correctly generating the correct number of outputs for Split
9203+
if isinstance(self.shape[dim], int):
92069204
num_outputs = self.shape[dim]
9207-
if num_outputs != 1:
9208-
outputs = op.Split(self, axis=dim, num_outputs=num_outputs)
9209-
else:
9210-
outputs = [self]
9211-
9212-
return [op.Squeeze(out, [dim]) for out in outputs]
9205+
results = []
9206+
for i in range(num_outputs):
9207+
# Slice to get a single element at position i along dim
9208+
sliced = op.Slice(
9209+
self,
9210+
starts=op.Constant(value_ints=[i]),
9211+
ends=op.Constant(value_ints=[i + 1]),
9212+
axes=op.Constant(value_ints=[dim]),
9213+
)
9214+
# Squeeze to remove the dimension of size 1
9215+
squeezed = op.Squeeze(sliced, axes=[dim])
9216+
results.append(squeezed)
9217+
return results
92139218

92149219
return op.SplitToSequence(self, axis=dim, keepdims=False)
92159220

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,82 @@ def forward(self, x):
520520
)
521521
_testing.assert_onnx_program(onnx_program)
522522

523+
def test_unbind_dim0(self):
524+
"""Test unbind along dimension 0 (pytorch/pytorch#168969)."""
525+
526+
class UnbindModel(torch.nn.Module):
527+
def forward(self, x):
528+
tensors = torch.unbind(x, dim=0)
529+
return sum(tensors)
530+
531+
model = UnbindModel()
532+
x = torch.randn(3, 4, 5)
533+
onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
534+
_testing.assert_onnx_program(onnx_program)
535+
536+
def test_unbind_dim1(self):
537+
"""Test unbind along dimension 1 (pytorch/pytorch#168969)."""
538+
539+
class UnbindModel(torch.nn.Module):
540+
def forward(self, x):
541+
tensors = torch.unbind(x, dim=1)
542+
return sum(tensors)
543+
544+
model = UnbindModel()
545+
x = torch.randn(2, 3, 4)
546+
onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
547+
_testing.assert_onnx_program(onnx_program)
548+
549+
def test_unbind_negative_dim(self):
550+
"""Test unbind with negative dimension (pytorch/pytorch#168969)."""
551+
552+
class UnbindModel(torch.nn.Module):
553+
def forward(self, x):
554+
tensors = torch.unbind(x, dim=-1)
555+
return sum(tensors)
556+
557+
model = UnbindModel()
558+
x = torch.randn(2, 3, 4)
559+
onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
560+
_testing.assert_onnx_program(onnx_program)
561+
562+
def test_unbind_size_one(self):
563+
"""Test unbind with dimension of size 1 (pytorch/pytorch#168969)."""
564+
565+
class UnbindModel(torch.nn.Module):
566+
def forward(self, x):
567+
tensors = torch.unbind(x, dim=0)
568+
return tensors[0]
569+
570+
model = UnbindModel()
571+
x = torch.randn(1, 4, 5)
572+
onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
573+
_testing.assert_onnx_program(onnx_program)
574+
575+
def test_unbind_with_lstm(self):
576+
"""Test unbind in LSTM context (pytorch/pytorch#168969)."""
577+
578+
class LSTMDecoder(torch.nn.Module):
579+
def __init__(self):
580+
super().__init__()
581+
self.embedding = torch.nn.Embedding(100, 64)
582+
self.lstm = torch.nn.LSTM(64, 64, 2, batch_first=True) # 2 layers
583+
self.fc = torch.nn.Linear(64, 100)
584+
585+
def forward(self, tokens, h, c):
586+
embedded = self.embedding(tokens).unsqueeze(0)
587+
output, (h_out, c_out) = self.lstm(embedded, (h, c))
588+
logits = self.fc(output.squeeze(0).squeeze(0))
589+
return logits, h_out, c_out
590+
591+
model = LSTMDecoder()
592+
model.eval()
593+
tokens = torch.tensor([1])
594+
h = torch.randn(2, 1, 64) # 2 layers
595+
c = torch.randn(2, 1, 64) # 2 layers
596+
onnx_program = torch.onnx.export(model, (tokens, h, c), dynamo=True, verbose=False)
597+
_testing.assert_onnx_program(onnx_program)
598+
523599

524600
if __name__ == "__main__":
525601
unittest.main()

0 commit comments

Comments
 (0)