diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm
index 0fe690698c3b8..753306a98285a 100644
--- a/aten/src/ATen/native/mps/operations/Convolution.mm
+++ b/aten/src/ATen/native/mps/operations/Convolution.mm
@@ -33,8 +33,9 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
   descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ?
       MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC;
-  descriptor_.weightsLayout = (memory_format == at::MemoryFormat::Contiguous) ?
-      MPSGraphTensorNamedDataLayoutOIHW : MPSGraphTensorNamedDataLayoutHWIO;
+
+  // PyTorch always uses OIHW memory layout for weights
+  descriptor_.weightsLayout = MPSGraphTensorNamedDataLayoutOIHW;
   descriptor_.groups = groups;
 }
 
@@ -61,6 +62,7 @@ Tensor _mps_convolution(
     bias_defined = bias_opt->defined();
 
   auto memory_format = input_t.suggest_memory_format();
+  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
   auto output_t = at::empty(
                     conv_output_size(input->sizes(), weight->sizes(),
                                      padding, stride, dilation),
@@ -68,7 +70,7 @@ Tensor _mps_convolution(
                     c10::nullopt,
                     kMPS,
                     c10::nullopt,
-                    memory_format);
+                    c10::nullopt);
   if (output_t.numel() == 0) {
     return output_t;
   }
@@ -122,6 +124,15 @@ Tensor _mps_convolution(
                   + mps::getTensorsStringKey({input_t, weight_t}) + ":"
                   + to_string(bias_defined) + ":" + bias_shape_key;
     CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+    MPSShape* inputShape = nil;
+
+    if (is_channels_last) {
+      const auto inputSizes = input_t.sizes();
+      IntArrayRef input_nhwc = {inputSizes[0], inputSizes[2], inputSizes[3], inputSizes[1]};
+      inputShape = native_mps::getMPSShape(input_nhwc);
+    } else {
+      inputShape = native_mps::getMPSShape(input_t);
+    }
 
     if(!cachedGraph) {
       native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
@@ -138,21 +149,29 @@ Tensor _mps_convolution(
                        padding[1], padding[0],
                        memory_format, groups);
 
-        MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
+        MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
         MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
+
         MPSGraphTensor* biasTensor = nil;
         if(bias_defined)
           biasTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType((bias_opt.value()).scalar_type()));
 
-        MPSGraphTensor* outputTensor = [mpsGraph convolution2DWithSourceTensor:inputTensor
-                                                                 weightsTensor:weightTensor
-                                                                    descriptor:descriptor_
-                                                                          name:nil];
+        MPSGraphTensor* outputTensor = [mpsGraph convolution2DWithSourceTensor: inputTensor
+                                                                  weightsTensor: weightTensor
+                                                                     descriptor: descriptor_
+                                                                           name: nil];
+        if (is_channels_last) {
+          // NHWC -> NCHW
+          outputTensor = [mpsGraph transposeTensor: [mpsGraph transposeTensor:outputTensor dimension:-1 withDimension:-2 name:nil]
+                                          dimension: -2
+                                      withDimension: -3
+                                               name: nil];
+        }
 
         if(bias_defined) {
-          outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor
-                                             secondaryTensor:biasTensor
-                                                        name:nil];
+          outputTensor = [mpsGraph additionWithPrimaryTensor: outputTensor
+                                             secondaryTensor: biasTensor
+                                                        name: nil];
         }
 
         newCachedGraph->inputTensor_ = inputTensor;
@@ -165,7 +184,7 @@ Tensor _mps_convolution(
       cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }
 
-    auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
+    auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
     auto weightsPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, weight_t);
     auto biasPlaceholder = native_mps::Placeholder();
     // Reshape the bias to be broadcastable with output of conv2d
diff --git a/test/test_mps.py b/test/test_mps.py
index 01755a6475e94..31acc32177cb5 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1096,6 +1096,31 @@ def helper(input_shape, wt_shape,
         helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
                padding=padding, output_padding=output_padding, dilation=dilation)
 
+    def test_conv1d_channels_last(self):
+        model_cpu = torch.nn.Conv1d(1, 128, 3)
+        a_cpu = torch.arange((128 * 176), dtype=torch.float32)
+        a_cpu = a_cpu.view(128, 176, 1).permute(0, 2, 1)
+        out_cpu = model_cpu(a_cpu)  # pass
+
+        a_mps = a_cpu.detach().clone().to("mps")
+        model_mps = model_cpu.to("mps")
+        out_mps = model_mps(a_mps)
+
+        torch.testing.assert_allclose(out_cpu.shape, out_mps.shape)
+        torch.testing.assert_allclose(out_cpu, out_mps.cpu())
+
+    def test_conv1d_contiguous(self):
+        model_cpu = torch.nn.Conv1d(1, 128, 3)
+        a_cpu = torch.ones(128, 1, 176)
+        out_cpu = model_cpu(a_cpu)
+
+        a_mps = a_cpu.detach().clone().to("mps")
+        model_mps = model_cpu.to("mps")
+        out_mps = model_mps(a_mps)
+
+        torch.testing.assert_allclose(out_cpu.shape, out_mps.shape)
+        torch.testing.assert_allclose(out_cpu, out_mps.cpu())
+
     # Test sigmoid
     def test_sigmoid(self):
         def helper(shape):