Skip to content

Commit 74474d6

Browse files
committed
add integrated the unbind tests
1 parent 1e0f44a commit 74474d6

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,82 @@ def forward(self, x):
520520
)
521521
_testing.assert_onnx_program(onnx_program)
522522

523+
def test_unbind_dim0(self):
524+
"""Test unbind along dimension 0 (pytorch/pytorch#168969)."""
525+
526+
class UnbindModel(torch.nn.Module):
527+
def forward(self, x):
528+
tensors = torch.unbind(x, dim=0)
529+
return sum(tensors)
530+
531+
model = UnbindModel()
532+
x = torch.randn(3, 4, 5)
533+
onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
534+
_testing.assert_onnx_program(onnx_program)
535+
536+
def test_unbind_dim1(self):
537+
"""Test unbind along dimension 1 (pytorch/pytorch#168969)."""
538+
539+
class UnbindModel(torch.nn.Module):
540+
def forward(self, x):
541+
tensors = torch.unbind(x, dim=1)
542+
return sum(tensors)
543+
544+
model = UnbindModel()
545+
x = torch.randn(2, 3, 4)
546+
onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
547+
_testing.assert_onnx_program(onnx_program)
548+
549+
def test_unbind_negative_dim(self):
550+
"""Test unbind with negative dimension (pytorch/pytorch#168969)."""
551+
552+
class UnbindModel(torch.nn.Module):
553+
def forward(self, x):
554+
tensors = torch.unbind(x, dim=-1)
555+
return sum(tensors)
556+
557+
model = UnbindModel()
558+
x = torch.randn(2, 3, 4)
559+
onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
560+
_testing.assert_onnx_program(onnx_program)
561+
562+
def test_unbind_size_one(self):
563+
"""Test unbind with dimension of size 1 (pytorch/pytorch#168969)."""
564+
565+
class UnbindModel(torch.nn.Module):
566+
def forward(self, x):
567+
tensors = torch.unbind(x, dim=0)
568+
return tensors[0]
569+
570+
model = UnbindModel()
571+
x = torch.randn(1, 4, 5)
572+
onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
573+
_testing.assert_onnx_program(onnx_program)
574+
575+
def test_unbind_with_lstm(self):
576+
"""Test unbind in LSTM context (pytorch/pytorch#168969)."""
577+
578+
class LSTMDecoder(torch.nn.Module):
579+
def __init__(self):
580+
super().__init__()
581+
self.embedding = torch.nn.Embedding(100, 64)
582+
self.lstm = torch.nn.LSTM(64, 64, 2, batch_first=True) # 2 layers
583+
self.fc = torch.nn.Linear(64, 100)
584+
585+
def forward(self, tokens, h, c):
586+
embedded = self.embedding(tokens).unsqueeze(0)
587+
output, (h_out, c_out) = self.lstm(embedded, (h, c))
588+
logits = self.fc(output.squeeze(0).squeeze(0))
589+
return logits, h_out, c_out
590+
591+
model = LSTMDecoder()
592+
model.eval()
593+
tokens = torch.tensor([1])
594+
h = torch.randn(2, 1, 64) # 2 layers
595+
c = torch.randn(2, 1, 64) # 2 layers
596+
onnx_program = torch.onnx.export(model, (tokens, h, c), dynamo=True, verbose=False)
597+
_testing.assert_onnx_program(onnx_program)
598+
523599

524600
if __name__ == "__main__":
525601
unittest.main()

0 commit comments

Comments
 (0)