@@ -3390,22 +3390,44 @@ def helper(shape, padding, op):
33903390 # Test stack forward
33913391 def test_stack (self ):
33923392 # All shapes must be same
3393- def helper (shape ):
3394- cpu_x = torch .randn (shape , device = 'cpu' , dtype = torch .float , requires_grad = False )
3395- x = cpu_x .detach ().clone ().to ('mps' )
3393+ def helper (shape , dtype = torch .float32 ):
33963394
3397- cpu_y = torch .randn (shape , device = 'cpu' , dtype = torch .float , requires_grad = False )
3398- y = cpu_y .detach ().clone ().to ('mps' )
3395+ x , cpu_x = None , None
3396+ y , cpu_y = None , None
3397+ z , cpu_z = None , None
33993398
3400- cpu_z = torch .randn (shape , device = 'cpu' , dtype = torch .float , requires_grad = False )
3401- z = cpu_z .detach ().clone ().to ('mps' )
3399+ if (dtype not in [torch .float32 , torch .bool ]):
3400+ cpu_x = torch .randint (50 , shape , device = 'cpu' , dtype = dtype , requires_grad = False )
3401+ x = cpu_x .detach ().clone ().to ('mps' )
3402+ cpu_y = torch .randint (50 , shape , device = 'cpu' , dtype = dtype , requires_grad = False )
3403+ y = cpu_y .detach ().clone ().to ('mps' )
3404+ cpu_z = torch .randint (50 , shape , device = 'cpu' , dtype = dtype , requires_grad = False )
3405+ z = cpu_z .detach ().clone ().to ('mps' )
3406+ elif (dtype == torch .bool ):
3407+ cpu_x = torch .randint (2 , shape , device = 'cpu' , dtype = dtype , requires_grad = False )
3408+ x = cpu_x .detach ().clone ().to ('mps' )
3409+ cpu_y = torch .randint (2 , shape , device = 'cpu' , dtype = dtype , requires_grad = False )
3410+ y = cpu_y .detach ().clone ().to ('mps' )
3411+ cpu_z = torch .randint (2 , shape , device = 'cpu' , dtype = dtype , requires_grad = False )
3412+ z = cpu_z .detach ().clone ().to ('mps' )
3413+ else :
3414+ cpu_x = torch .randn (shape , device = 'cpu' , dtype = dtype , requires_grad = True )
3415+ x = cpu_x .detach ().clone ().to ('mps' ).requires_grad_ ()
3416+ cpu_y = torch .randn (shape , device = 'cpu' , dtype = dtype , requires_grad = True )
3417+ y = cpu_y .detach ().clone ().to ('mps' ).requires_grad_ ()
3418+ cpu_z = torch .randn (shape , device = 'cpu' , dtype = dtype , requires_grad = True )
3419+ z = cpu_z .detach ().clone ().to ('mps' ).requires_grad_ ()
34023420
34033421 stack = torch .stack ([x , y , z ], dim = 1 )
34043422 stack_cpu = torch .stack ([cpu_x , cpu_y , cpu_z ], dim = 1 )
34053423
34063424 self .assertEqual (stack , stack_cpu )
34073425
34083426 helper ([2 , 8 , 4 , 5 ])
3427+ helper ([2 , 8 , 4 , 5 ], dtype = torch .float16 )
3428+ helper ([2 , 8 , 4 , 5 ], dtype = torch .int32 )
3429+ helper ([2 , 8 , 4 , 5 ], dtype = torch .int64 )
3430+ helper ([2 , 8 , 4 , 5 ], dtype = torch .bool )
34093431 # Empty test - Currently failing! Empty tensor not handled!
34103432 # helper([0, 2, 4, 5])
34113433
0 commit comments