Skip to content

Commit 311d724

Browse files
abhudev authored and kulinseth committed
Add support for bool input; expand testing (#54)
1 parent 9c07bcd commit 311d724

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

aten/src/ATen/native/mps/operations/Shape.mm

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -639,11 +639,19 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
639639

640640
// Create placeholders
641641
MPSGraphTensor* inputMPSGraphTensors[inputs.size()];
642+
MPSGraphTensor* castInputMPSGraphTensors[inputs.size()];
642643

643-
for(int i = 0; i < inputs.size(); i++)
644+
for(int i = 0; i < inputs.size(); i++) {
644645
inputMPSGraphTensors[i] = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(result_type(inputs)));
646+
if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool)
647+
castInputMPSGraphTensors[i] = [mpsGraph castTensor:inputMPSGraphTensors[i]
648+
toType:MPSDataTypeInt32
649+
name:[NSString stringWithFormat:@"inputTensor_%@", [NSNumber numberWithInt:i]]];
650+
else
651+
castInputMPSGraphTensors[i] = inputMPSGraphTensors[i];
652+
}
645653

646-
auto inputTensorsArray = [NSArray arrayWithObjects:inputMPSGraphTensors
654+
auto inputTensorsArray = [NSArray arrayWithObjects:castInputMPSGraphTensors
647655
count:inputs.size()];
648656
// Use concatTensors to concatenate
649657
MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
@@ -654,6 +662,10 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
654662

655663
for(int i = 0; i < inputs.size(); i++)
656664
newCachedGraph->inputMPSGraphTensors_[i] = inputMPSGraphTensors[i];
665+
if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool)
666+
outputTensor = [mpsGraph castTensor:outputTensor
667+
toType:MPSDataTypeBool
668+
name:@"outputTensor"];
657669
newCachedGraph->outputTensor_ = outputTensor;
658670
}
659671
return newCachedGraph;

test/test_mps.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3399,22 +3399,44 @@ def helper(shape, padding, op):
33993399
# Test stack forward
34003400
def test_stack(self):
34013401
# All shapes must be same
3402-
def helper(shape):
3403-
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3404-
x = cpu_x.detach().clone().to('mps')
3402+
def helper(shape, dtype=torch.float32):
34053403

3406-
cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3407-
y = cpu_y.detach().clone().to('mps')
3404+
x, cpu_x = None, None
3405+
y, cpu_y = None, None
3406+
z, cpu_z = None, None
34083407

3409-
cpu_z = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3410-
z = cpu_z.detach().clone().to('mps')
3408+
if(dtype not in [torch.float32, torch.bool]):
3409+
cpu_x = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
3410+
x = cpu_x.detach().clone().to('mps')
3411+
cpu_y = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
3412+
y = cpu_y.detach().clone().to('mps')
3413+
cpu_z = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
3414+
z = cpu_z.detach().clone().to('mps')
3415+
elif (dtype == torch.bool):
3416+
cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
3417+
x = cpu_x.detach().clone().to('mps')
3418+
cpu_y = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
3419+
y = cpu_y.detach().clone().to('mps')
3420+
cpu_z = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
3421+
z = cpu_z.detach().clone().to('mps')
3422+
else:
3423+
cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
3424+
x = cpu_x.detach().clone().to('mps').requires_grad_()
3425+
cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
3426+
y = cpu_y.detach().clone().to('mps').requires_grad_()
3427+
cpu_z = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
3428+
z = cpu_z.detach().clone().to('mps').requires_grad_()
34113429

34123430
stack = torch.stack([x, y, z], dim=1)
34133431
stack_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1)
34143432

34153433
self.assertEqual(stack, stack_cpu)
34163434

34173435
helper([2, 8, 4, 5])
3436+
helper([2, 8, 4, 5], dtype=torch.float16)
3437+
helper([2, 8, 4, 5], dtype=torch.int32)
3438+
helper([2, 8, 4, 5], dtype=torch.int64)
3439+
helper([2, 8, 4, 5], dtype=torch.bool)
34183440
# Empty test - Currently failing! Empty tensor not handled!
34193441
# helper([0, 2, 4, 5])
34203442

0 commit comments

Comments (0)