Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions aten/src/ATen/native/mps/operations/Convolution.mm
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
descriptor_.groups = groups;
}

// Returns the MPSShape to use for a conv input/grad tensor.
// For channels-last tensors the logical NCHW sizes are permuted to the
// NHWC layout MPSGraph expects; otherwise the tensor's own shape is used.
static MPSShape* get_mps_conv_shape(bool is_channels_last, const Tensor& tensor) {
  if (!is_channels_last) {
    return at::native::mps::getMPSShape(tensor);
  }
  const auto sizes = tensor.sizes();
  const NSUInteger batch = sizes[0];
  const NSUInteger channels = sizes[1];
  const NSUInteger height = sizes[2];
  const NSUInteger width = sizes[3];
  // NCHW -> NHWC
  return @[@(batch), @(height), @(width), @(channels)];
}

Tensor _mps_convolution(
const Tensor& input_t,
const Tensor& weight_t,
Expand Down Expand Up @@ -126,19 +139,7 @@ Tensor _mps_convolution(
+ mps::getTensorsStringKey({input_t, weight_t}) + ":"
+ to_string(bias_defined) + ":" + bias_shape_key;
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
MPSShape* inputShape = nil;

if (is_channels_last) {
const auto inputSizes = input_t.sizes();
const NSUInteger N = inputSizes[0];
const NSUInteger C = inputSizes[1];
const NSUInteger H = inputSizes[2];
const NSUInteger W = inputSizes[3];
inputShape = @[@(N), @(H), @(W), @(C)];
} else {
inputShape = native_mps::getMPSShape(input_t);
}

MPSShape* inputShape = get_mps_conv_shape(is_channels_last, input_t);
if(!cachedGraph) {
native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {

Expand Down Expand Up @@ -333,6 +334,9 @@ Tensor mps_convolution_backward_weights(
using namespace mps;
CheckedFrom c = "mps_convolution_backward_weights";
auto memory_format = input_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
MPSShape* inputShape = get_mps_conv_shape(is_channels_last, input_t);
MPSShape* gradOutputShape = get_mps_conv_shape(is_channels_last, grad_output_t);

// For uniformity with everything else, although it seems grad_weight
// would be unambiguous too.
Expand Down Expand Up @@ -399,8 +403,8 @@ Tensor mps_convolution_backward_weights(
padding[1], padding[0],
memory_format, groups);

MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, grad_output_t);
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);

MPSGraphTensor* gradWeightTensor = [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensor
sourceTensor:inputTensor
Expand All @@ -417,8 +421,8 @@ Tensor mps_convolution_backward_weights(
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}

auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t);
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
auto outputPlaceholder = Placeholder(cachedGraph->gradWeightTensor_, grad_weight_t);

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
Expand Down
13 changes: 13 additions & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5953,6 +5953,19 @@ def test_conv_transpose_1d_nn_functional(self):

self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)

def test_conv_backward_1d_channels_last(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/84511
    # (Conv1d backward on a channels-last-permuted input on MPS).
    conv_cpu = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)
    conv_mps = copy.deepcopy(conv_cpu).to(device='mps')

    data = torch.rand(1, 176, 1, dtype=torch.float32)
    x_cpu = data.permute(0, 2, 1).contiguous()
    x_mps = data.permute(0, 2, 1).contiguous().to("mps")
    conv_cpu(x_cpu).sum().backward()
    conv_mps(x_mps).sum().backward()

    # NOTE: Tensor.backward() returns None, so comparing its return values
    # (as the original test did) asserted nothing. Compare the accumulated
    # parameter gradients instead, which is what the backward pass produces.
    self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad.cpu(),
                     rtol=2.6e-05, atol=2e-04)
    self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad.cpu(),
                     rtol=2.6e-05, atol=2e-04)

def test_conv1d_contiguous(self):
model_cpu = torch.nn.Conv1d(1, 128, 3)
a_cpu = torch.ones(128, 1, 176)
Expand Down