Skip to content

Commit 3872d20

Browse files
authored
Add support for bool input; expand testing (#54)
1 parent c267d73 commit 3872d20

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

aten/src/ATen/native/mps/operations/Shape.mm

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -639,11 +639,19 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
639639

640640
// Create placeholders
641641
MPSGraphTensor* inputMPSGraphTensors[inputs.size()];
642+
MPSGraphTensor* castInputMPSGraphTensors[inputs.size()];
642643

643-
for(int i = 0; i < inputs.size(); i++)
644+
for(int i = 0; i < inputs.size(); i++) {
644645
inputMPSGraphTensors[i] = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(result_type(inputs)));
646+
if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool)
647+
castInputMPSGraphTensors[i] = [mpsGraph castTensor:inputMPSGraphTensors[i]
648+
toType:MPSDataTypeInt32
649+
name:[NSString stringWithFormat:@"inputTensor_%@", [NSNumber numberWithInt:i]]];
650+
else
651+
castInputMPSGraphTensors[i] = inputMPSGraphTensors[i];
652+
}
645653

646-
auto inputTensorsArray = [NSArray arrayWithObjects:inputMPSGraphTensors
654+
auto inputTensorsArray = [NSArray arrayWithObjects:castInputMPSGraphTensors
647655
count:inputs.size()];
648656
// Use concatTensors to concatenate
649657
MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
@@ -654,6 +662,10 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
654662

655663
for(int i = 0; i < inputs.size(); i++)
656664
newCachedGraph->inputMPSGraphTensors_[i] = inputMPSGraphTensors[i];
665+
if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool)
666+
outputTensor = [mpsGraph castTensor:outputTensor
667+
toType:MPSDataTypeBool
668+
name:@"outputTensor"];
657669
newCachedGraph->outputTensor_ = outputTensor;
658670
}
659671
return newCachedGraph;

test/test_mps.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3390,22 +3390,44 @@ def helper(shape, padding, op):
33903390
# Test stack forward
33913391
def test_stack(self):
33923392
# All shapes must be same
3393-
def helper(shape):
3394-
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3395-
x = cpu_x.detach().clone().to('mps')
3393+
def helper(shape, dtype=torch.float32):
33963394

3397-
cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3398-
y = cpu_y.detach().clone().to('mps')
3395+
x, cpu_x = None, None
3396+
y, cpu_y = None, None
3397+
z, cpu_z = None, None
33993398

3400-
cpu_z = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3401-
z = cpu_z.detach().clone().to('mps')
3399+
if(dtype not in [torch.float32, torch.bool]):
3400+
cpu_x = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
3401+
x = cpu_x.detach().clone().to('mps')
3402+
cpu_y = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
3403+
y = cpu_y.detach().clone().to('mps')
3404+
cpu_z = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
3405+
z = cpu_z.detach().clone().to('mps')
3406+
elif (dtype == torch.bool):
3407+
cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
3408+
x = cpu_x.detach().clone().to('mps')
3409+
cpu_y = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
3410+
y = cpu_y.detach().clone().to('mps')
3411+
cpu_z = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
3412+
z = cpu_z.detach().clone().to('mps')
3413+
else:
3414+
cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
3415+
x = cpu_x.detach().clone().to('mps').requires_grad_()
3416+
cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
3417+
y = cpu_y.detach().clone().to('mps').requires_grad_()
3418+
cpu_z = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
3419+
z = cpu_z.detach().clone().to('mps').requires_grad_()
34023420

34033421
stack = torch.stack([x, y, z], dim=1)
34043422
stack_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1)
34053423

34063424
self.assertEqual(stack, stack_cpu)
34073425

34083426
helper([2, 8, 4, 5])
3427+
helper([2, 8, 4, 5], dtype=torch.float16)
3428+
helper([2, 8, 4, 5], dtype=torch.int32)
3429+
helper([2, 8, 4, 5], dtype=torch.int64)
3430+
helper([2, 8, 4, 5], dtype=torch.bool)
34093431
# Empty test - Currently failing! Empty tensor not handled!
34103432
# helper([0, 2, 4, 5])
34113433

0 commit comments

Comments
 (0)