@@ -520,6 +520,82 @@ def forward(self, x):
520520 )
521521 _testing .assert_onnx_program (onnx_program )
522522
523+ def test_unbind_dim0 (self ):
524+ """Test unbind along dimension 0 (pytorch/pytorch#168969)."""
525+
526+ class UnbindModel (torch .nn .Module ):
527+ def forward (self , x ):
528+ tensors = torch .unbind (x , dim = 0 )
529+ return sum (tensors )
530+
531+ model = UnbindModel ()
532+ x = torch .randn (3 , 4 , 5 )
533+ onnx_program = torch .onnx .export (model , (x ,), dynamo = True , verbose = False )
534+ _testing .assert_onnx_program (onnx_program )
535+
536+ def test_unbind_dim1 (self ):
537+ """Test unbind along dimension 1 (pytorch/pytorch#168969)."""
538+
539+ class UnbindModel (torch .nn .Module ):
540+ def forward (self , x ):
541+ tensors = torch .unbind (x , dim = 1 )
542+ return sum (tensors )
543+
544+ model = UnbindModel ()
545+ x = torch .randn (2 , 3 , 4 )
546+ onnx_program = torch .onnx .export (model , (x ,), dynamo = True , verbose = False )
547+ _testing .assert_onnx_program (onnx_program )
548+
549+ def test_unbind_negative_dim (self ):
550+ """Test unbind with negative dimension (pytorch/pytorch#168969)."""
551+
552+ class UnbindModel (torch .nn .Module ):
553+ def forward (self , x ):
554+ tensors = torch .unbind (x , dim = - 1 )
555+ return sum (tensors )
556+
557+ model = UnbindModel ()
558+ x = torch .randn (2 , 3 , 4 )
559+ onnx_program = torch .onnx .export (model , (x ,), dynamo = True , verbose = False )
560+ _testing .assert_onnx_program (onnx_program )
561+
562+ def test_unbind_size_one (self ):
563+ """Test unbind with dimension of size 1 (pytorch/pytorch#168969)."""
564+
565+ class UnbindModel (torch .nn .Module ):
566+ def forward (self , x ):
567+ tensors = torch .unbind (x , dim = 0 )
568+ return tensors [0 ]
569+
570+ model = UnbindModel ()
571+ x = torch .randn (1 , 4 , 5 )
572+ onnx_program = torch .onnx .export (model , (x ,), dynamo = True , verbose = False )
573+ _testing .assert_onnx_program (onnx_program )
574+
575+ def test_unbind_with_lstm (self ):
576+ """Test unbind in LSTM context (pytorch/pytorch#168969)."""
577+
578+ class LSTMDecoder (torch .nn .Module ):
579+ def __init__ (self ):
580+ super ().__init__ ()
581+ self .embedding = torch .nn .Embedding (100 , 64 )
582+ self .lstm = torch .nn .LSTM (64 , 64 , 2 , batch_first = True ) # 2 layers
583+ self .fc = torch .nn .Linear (64 , 100 )
584+
585+ def forward (self , tokens , h , c ):
586+ embedded = self .embedding (tokens ).unsqueeze (0 )
587+ output , (h_out , c_out ) = self .lstm (embedded , (h , c ))
588+ logits = self .fc (output .squeeze (0 ).squeeze (0 ))
589+ return logits , h_out , c_out
590+
591+ model = LSTMDecoder ()
592+ model .eval ()
593+ tokens = torch .tensor ([1 ])
594+ h = torch .randn (2 , 1 , 64 ) # 2 layers
595+ c = torch .randn (2 , 1 , 64 ) # 2 layers
596+ onnx_program = torch .onnx .export (model , (tokens , h , c ), dynamo = True , verbose = False )
597+ _testing .assert_onnx_program (onnx_program )
598+
523599
524600if __name__ == "__main__" :
525601 unittest .main ()
0 commit comments