@@ -3399,22 +3399,44 @@ def helper(shape, padding, op):
33993399 # Test stack forward
34003400 def test_stack (self ):
34013401 # All shapes must be same
3402- def helper (shape ):
3403- cpu_x = torch .randn (shape , device = 'cpu' , dtype = torch .float , requires_grad = False )
3404- x = cpu_x .detach ().clone ().to ('mps' )
3402+ def helper (shape , dtype = torch .float32 ):
34053403
3406- cpu_y = torch .randn (shape , device = 'cpu' , dtype = torch .float , requires_grad = False )
3407- y = cpu_y .detach ().clone ().to ('mps' )
3404+ x , cpu_x = None , None
3405+ y , cpu_y = None , None
3406+ z , cpu_z = None , None
34083407
3409- cpu_z = torch .randn (shape , device = 'cpu' , dtype = torch .float , requires_grad = False )
3410- z = cpu_z .detach ().clone ().to ('mps' )
3408+ if (dtype not in [torch .float32 , torch .bool ]):
3409+ cpu_x = torch .randint (50 , shape , device = 'cpu' , dtype = dtype , requires_grad = False )
3410+ x = cpu_x .detach ().clone ().to ('mps' )
3411+ cpu_y = torch .randint (50 , shape , device = 'cpu' , dtype = dtype , requires_grad = False )
3412+ y = cpu_y .detach ().clone ().to ('mps' )
3413+ cpu_z = torch .randint (50 , shape , device = 'cpu' , dtype = dtype , requires_grad = False )
3414+ z = cpu_z .detach ().clone ().to ('mps' )
3415+ elif (dtype == torch .bool ):
3416+ cpu_x = torch .randint (2 , shape , device = 'cpu' , dtype = dtype , requires_grad = False )
3417+ x = cpu_x .detach ().clone ().to ('mps' )
3418+ cpu_y = torch .randint (2 , shape , device = 'cpu' , dtype = dtype , requires_grad = False )
3419+ y = cpu_y .detach ().clone ().to ('mps' )
3420+ cpu_z = torch .randint (2 , shape , device = 'cpu' , dtype = dtype , requires_grad = False )
3421+ z = cpu_z .detach ().clone ().to ('mps' )
3422+ else :
3423+ cpu_x = torch .randn (shape , device = 'cpu' , dtype = dtype , requires_grad = True )
3424+ x = cpu_x .detach ().clone ().to ('mps' ).requires_grad_ ()
3425+ cpu_y = torch .randn (shape , device = 'cpu' , dtype = dtype , requires_grad = True )
3426+ y = cpu_y .detach ().clone ().to ('mps' ).requires_grad_ ()
3427+ cpu_z = torch .randn (shape , device = 'cpu' , dtype = dtype , requires_grad = True )
3428+ z = cpu_z .detach ().clone ().to ('mps' ).requires_grad_ ()
34113429
34123430 stack = torch .stack ([x , y , z ], dim = 1 )
34133431 stack_cpu = torch .stack ([cpu_x , cpu_y , cpu_z ], dim = 1 )
34143432
34153433 self .assertEqual (stack , stack_cpu )
34163434
34173435 helper ([2 , 8 , 4 , 5 ])
3436+ helper ([2 , 8 , 4 , 5 ], dtype = torch .float16 )
3437+ helper ([2 , 8 , 4 , 5 ], dtype = torch .int32 )
3438+ helper ([2 , 8 , 4 , 5 ], dtype = torch .int64 )
3439+ helper ([2 , 8 , 4 , 5 ], dtype = torch .bool )
34183440 # Empty test - Currently failing! Empty tensor not handled!
34193441 # helper([0, 2, 4, 5])
34203442
0 commit comments