Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions aten/src/ATen/native/mps/operations/Shape.mm
Original file line number Diff line number Diff line change
Expand Up @@ -639,11 +639,19 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,

// Create placeholders
MPSGraphTensor* inputMPSGraphTensors[inputs.size()];
MPSGraphTensor* castInputMPSGraphTensors[inputs.size()];

for(int i = 0; i < inputs.size(); i++)
for(int i = 0; i < inputs.size(); i++) {
inputMPSGraphTensors[i] = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(result_type(inputs)));
if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this code work for other types? Did you run the consistency test to make sure the other types work correctly?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this code works for all types.
Also, I think the test I added should remain: it fails without my change, whereas testConsistency passes even without my change.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay.

castInputMPSGraphTensors[i] = [mpsGraph castTensor:inputMPSGraphTensors[i]
toType:MPSDataTypeInt32
name:[NSString stringWithFormat:@"inputTensor_%@", [NSNumber numberWithInt:i]]];
else
castInputMPSGraphTensors[i] = inputMPSGraphTensors[i];
}

auto inputTensorsArray = [NSArray arrayWithObjects:inputMPSGraphTensors
auto inputTensorsArray = [NSArray arrayWithObjects:castInputMPSGraphTensors
count:inputs.size()];
// Use concatTensors to concatenate
MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
Expand All @@ -654,6 +662,10 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,

for(int i = 0; i < inputs.size(); i++)
newCachedGraph->inputMPSGraphTensors_[i] = inputMPSGraphTensors[i];
if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool)
outputTensor = [mpsGraph castTensor:outputTensor
toType:MPSDataTypeBool
name:@"outputTensor"];
newCachedGraph->outputTensor_ = outputTensor;
}
return newCachedGraph;
Expand Down
36 changes: 29 additions & 7 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3390,22 +3390,44 @@ def helper(shape, padding, op):
# Test stack forward
def test_stack(self):
    """Verify that torch.stack on the MPS device matches the CPU result
    across float, integer, and bool dtypes.

    All stacked tensors must share the same shape.
    """
    def helper(shape, dtype=torch.float32):
        # Build one (cpu_tensor, mps_tensor) pair appropriate for the dtype:
        #  - bool:            random 0/1 values
        #  - other non-float: random ints in [0, 50)
        #  - float dtypes:    randn with grad enabled (covers stack's autograd path)
        def make_pair():
            if dtype == torch.bool:
                cpu_t = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
                return cpu_t, cpu_t.detach().clone().to('mps')
            if dtype not in [torch.float32, torch.bool]:
                cpu_t = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
                return cpu_t, cpu_t.detach().clone().to('mps')
            cpu_t = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
            return cpu_t, cpu_t.detach().clone().to('mps').requires_grad_()

        cpu_x, x = make_pair()
        cpu_y, y = make_pair()
        cpu_z, z = make_pair()

        stack = torch.stack([x, y, z], dim=1)
        stack_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1)

        self.assertEqual(stack, stack_cpu)

    helper([2, 8, 4, 5])
    helper([2, 8, 4, 5], dtype=torch.float16)
    helper([2, 8, 4, 5], dtype=torch.int32)
    helper([2, 8, 4, 5], dtype=torch.int64)
    helper([2, 8, 4, 5], dtype=torch.bool)
    # Empty test - Currently failing! Empty tensor not handled!
    # helper([0, 2, 4, 5])

Expand Down