diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm
index aad32ed2b878f..eaad7c64f9fee 100644
--- a/aten/src/ATen/native/mps/operations/Convolution.mm
+++ b/aten/src/ATen/native/mps/operations/Convolution.mm
@@ -39,6 +39,19 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
   descriptor_.groups = groups;
 }
 
+static
+MPSShape* get_mps_conv_shape(bool is_channels_last, const Tensor& tensor) {
+  if (is_channels_last) {
+    const auto tensorSizes = tensor.sizes();
+    const NSUInteger N = tensorSizes[0];
+    const NSUInteger C = tensorSizes[1];
+    const NSUInteger H = tensorSizes[2];
+    const NSUInteger W = tensorSizes[3];
+    return @[@(N), @(H), @(W), @(C)];
+  }
+  return at::native::mps::getMPSShape(tensor);
+}
+
 Tensor _mps_convolution(
     const Tensor& input_t,
     const Tensor& weight_t,
@@ -126,19 +139,7 @@ Tensor _mps_convolution(
                  + mps::getTensorsStringKey({input_t, weight_t}) + ":" + to_string(bias_defined) + ":" + bias_shape_key;
     CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
 
-    MPSShape* inputShape = nil;
-
-    if (is_channels_last) {
-      const auto inputSizes = input_t.sizes();
-      const NSUInteger N = inputSizes[0];
-      const NSUInteger C = inputSizes[1];
-      const NSUInteger H = inputSizes[2];
-      const NSUInteger W = inputSizes[3];
-      inputShape = @[@(N), @(H), @(W), @(C)];
-    } else {
-      inputShape = native_mps::getMPSShape(input_t);
-    }
-
+    MPSShape* inputShape = get_mps_conv_shape(is_channels_last, input_t);
     if(!cachedGraph) {
       native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
 
@@ -333,6 +334,9 @@ Tensor mps_convolution_backward_weights(
   using namespace mps;
   CheckedFrom c = "mps_convolution_backward_weights";
   auto memory_format = input_t.suggest_memory_format();
+  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
+  MPSShape* inputShape = get_mps_conv_shape(is_channels_last, input_t);
+  MPSShape* gradOutputShape = get_mps_conv_shape(is_channels_last, grad_output_t);
 
   // For uniformity with everything else, although it seems grad_weight
   // would be unambiguous too.
@@ -399,8 +403,8 @@ Tensor mps_convolution_backward_weights(
                          padding[1], padding[0],
                          memory_format, groups);
 
-          MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, grad_output_t);
-          MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
+          MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
+          MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
 
           MPSGraphTensor* gradWeightTensor = [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensor
                                                                                                  sourceTensor:inputTensor
@@ -417,8 +421,8 @@ Tensor mps_convolution_backward_weights(
       cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }
 
-    auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t);
-    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
+    auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
+    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
     auto outputPlaceholder = Placeholder(cachedGraph->gradWeightTensor_, grad_weight_t);
 
     NSDictionary *feeds = @{
diff --git a/test/test_mps.py b/test/test_mps.py
index 23e6512bcebd9..3395d73ce234d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5953,6 +5953,19 @@ def test_conv_transpose_1d_nn_functional(self):
 
         self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)
 
+    def test_conv_backward_1d_channels_last(self):
+        # https://github.com/pytorch/pytorch/issues/84511
+        conv_cpu = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)
+        conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
+
+        data = torch.rand(1, 176, 1, dtype=torch.float32)
+        x_cpu = data.permute(0, 2, 1).contiguous()
+        x_mps = data.permute(0, 2, 1).contiguous().to("mps")
+        res_cpu = conv_cpu(x_cpu).sum().backward()
+        res_mps = conv_mps(x_mps).sum().backward()
+
+        self.assertEqual(res_cpu, res_mps)
+
     def test_conv1d_contiguous(self):
         model_cpu = torch.nn.Conv1d(1, 128, 3)
         a_cpu = torch.ones(128, 1, 176)