Add support for conv1d channels last memory format (#78)
* Add support for conv1d channels last memory format

* Remove unnecessary NCHW -> NHWC permute when the input is already in NHWC
DenisVieriu97 authored and kulinseth committed Aug 8, 2022
1 parent 06aa934 commit e957317
Showing 2 changed files with 56 additions and 12 deletions.
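For context, here is the case this commit enables on the MPS backend: torch.nn.Conv1d applied to an input whose memory layout is channels last (e.g. produced by a permute). A minimal sketch, adapted from the test added below; it assumes a machine with an MPS device:

    import torch

    conv = torch.nn.Conv1d(in_channels=1, out_channels=128, kernel_size=3).to("mps")
    # Build an (N, C, L) input by permuting an (N, L, C) buffer, so the
    # underlying memory is laid out channels last rather than contiguously.
    x = torch.arange(128 * 176, dtype=torch.float32, device="mps")
    x = x.view(128, 176, 1).permute(0, 2, 1)
    y = conv(x)
    print(y.shape)  # torch.Size([128, 128, 174])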
aten/src/ATen/native/mps/operations/Convolution.mm (31 additions, 12 deletions)
@@ -33,8 +33,9 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
 
   descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ?
        MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC;
-  descriptor_.weightsLayout = (memory_format == at::MemoryFormat::Contiguous) ?
-       MPSGraphTensorNamedDataLayoutOIHW : MPSGraphTensorNamedDataLayoutHWIO;
+
+  // PyTorch always uses OIHW memory layout for weights
+  descriptor_.weightsLayout = MPSGraphTensorNamedDataLayoutOIHW;
   descriptor_.groups = groups;
 }
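The new comment captures the rationale: PyTorch stores convolution weights as (out_channels, in_channels, kH, kW), i.e. OIHW, regardless of the input's memory format, so the weights layout should not follow memory_format. A quick check in plain PyTorch:

    import torch

    conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5)
    # Weight is (O, I, H, W) even when inputs are channels last.
    print(conv.weight.shape)  # torch.Size([8, 3, 5, 5])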

@@ -61,14 +62,15 @@ Tensor _mps_convolution(
     bias_defined = bias_opt->defined();
 
   auto memory_format = input_t.suggest_memory_format();
+  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
   auto output_t = at::empty(
                     conv_output_size(input->sizes(), weight->sizes(),
                                      padding, stride, dilation),
                     input->scalar_type(),
                     c10::nullopt,
                     kMPS,
                     c10::nullopt,
-                    memory_format);
+                    c10::nullopt);
 
   if (output_t.numel() == 0) {
     return output_t;
@@ -122,6 +124,15 @@ Tensor _mps_convolution(
                   + mps::getTensorsStringKey({input_t, weight_t}) + ":"
                   + to_string(bias_defined) + ":" + bias_shape_key;
     CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+    MPSShape* inputShape = nil;
+
+    if (is_channels_last) {
+      const auto inputSizes = input_t.sizes();
+      IntArrayRef input_nhwc = {inputSizes[0], inputSizes[2], inputSizes[3], inputSizes[1]};
+      inputShape = native_mps::getMPSShape(input_nhwc);
+    } else {
+      inputShape = native_mps::getMPSShape(input_t);
+    }
 
     if(!cachedGraph) {
       native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
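The block above derives the placeholder shape by permuting the NCHW sizes into NHWC order. This mirrors how channels last works in PyTorch: a channels-last tensor keeps its logical NCHW sizes and only its strides change, so the raw buffer is ordered NHWC. A small illustration:

    import torch

    x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
    print(x.shape)     # torch.Size([2, 3, 4, 5]); logical sizes stay (N, C, H, W)
    print(x.stride())  # (60, 1, 15, 3); memory is ordered (N, H, W, C)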
@@ -138,21 +149,29 @@ Tensor _mps_convolution(
                        padding[1], padding[0],
                        memory_format, groups);
 
-        MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
+        MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
         MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
 
         MPSGraphTensor* biasTensor = nil;
         if(bias_defined)
           biasTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType((bias_opt.value()).scalar_type()));
 
-        MPSGraphTensor* outputTensor = [mpsGraph convolution2DWithSourceTensor:inputTensor
-                                                                 weightsTensor:weightTensor
-                                                                    descriptor:descriptor_
-                                                                          name:nil];
+        MPSGraphTensor* outputTensor = [mpsGraph convolution2DWithSourceTensor: inputTensor
+                                                                 weightsTensor: weightTensor
+                                                                    descriptor: descriptor_
+                                                                          name: nil];
+        if (is_channels_last) {
+          // NHWC -> NCHW
+          outputTensor = [mpsGraph transposeTensor: [mpsGraph transposeTensor:outputTensor dimension:-1 withDimension:-2 name:nil]
+                                         dimension: -2
+                                     withDimension: -3
+                                              name: nil];
+        }
+
         if(bias_defined) {
-          outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor
-                                             secondaryTensor:biasTensor
-                                                        name:nil];
+          outputTensor = [mpsGraph additionWithPrimaryTensor: outputTensor
+                                             secondaryTensor: biasTensor
+                                                        name: nil];
         }
 
         newCachedGraph->inputTensor_ = inputTensor;
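When the channels-last path is taken, the graph computes in NHWC, so the result is transposed back to NCHW before the bias addition and the copy into the output. The pair of transposes above amounts to a single NHWC -> NCHW permute; a sketch of the same index shuffle in plain PyTorch:

    import torch

    t = torch.randn(2, 4, 5, 3)                      # (N, H, W, C)
    nchw = t.transpose(-1, -2).transpose(-2, -3)     # swap W<->C, then H<->C
    assert nchw.shape == (2, 3, 4, 5)                # (N, C, H, W)
    assert torch.equal(nchw, t.permute(0, 3, 1, 2))  # same as one permute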
@@ -165,7 +184,7 @@ Tensor _mps_convolution(
       cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }
 
-    auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
+    auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
     auto weightsPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, weight_t);
     auto biasPlaceholder = native_mps::Placeholder();
     // Reshape the bias to be broadcastable with output of conv2d
test/test_mps.py (25 additions, 0 deletions)
@@ -1096,6 +1096,31 @@ def helper(input_shape, wt_shape,
             helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
                    padding=padding, output_padding=output_padding, dilation=dilation)
 
+    def test_conv1d_channels_last(self):
+        model_cpu = torch.nn.Conv1d(1, 128, 3)
+        a_cpu = torch.arange((128 * 176), dtype=torch.float32)
+        a_cpu = a_cpu.view(128, 176, 1).permute(0, 2, 1)
+        out_cpu = model_cpu(a_cpu)
+
+        a_mps = a_cpu.detach().clone().to("mps")
+        model_mps = model_cpu.to("mps")
+        out_mps = model_mps(a_mps)
+
+        torch.testing.assert_allclose(out_cpu.shape, out_mps.shape)
+        torch.testing.assert_allclose(out_cpu, out_mps.cpu())
+
+    def test_conv1d_contiguous(self):
+        model_cpu = torch.nn.Conv1d(1, 128, 3)
+        a_cpu = torch.ones(128, 1, 176)
+        out_cpu = model_cpu(a_cpu)
+
+        a_mps = a_cpu.detach().clone().to("mps")
+        model_mps = model_cpu.to("mps")
+        out_mps = model_mps(a_mps)
+
+        torch.testing.assert_allclose(out_cpu.shape, out_mps.shape)
+        torch.testing.assert_allclose(out_cpu, out_mps.cpu())
+
     # Test sigmoid
     def test_sigmoid(self):
         def helper(shape):
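A note on why test_conv1d_channels_last exercises the new path even though its tensor is only 3-D: conv1d is lowered to conv2d on a 4-D view, and the strides produced by the permute line up with the channels-last layout on that view. A sketch, assuming the usual unsqueeze at dim 2 in the conv1d -> conv2d lowering:

    import torch

    a = torch.arange(128 * 176, dtype=torch.float32).view(128, 176, 1).permute(0, 2, 1)
    x4d = a.unsqueeze(2)   # the (N, C, 1, L) view conv2d would see
    print(x4d.stride())    # (176, 1, 176, 1): channels-last strides
    print(x4d.is_contiguous(memory_format=torch.channels_last))  # True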