From df54d2fb812423f35b456fb77c90a974cb28983f Mon Sep 17 00:00:00 2001 From: Ramin Azarmehr Date: Wed, 25 Jan 2023 20:33:19 -0500 Subject: [PATCH] Fix the crash with hardswish_backward - Also fix indentation and formatting --- .../ATen/native/mps/operations/Activation.mm | 178 ++++++++---------- 1 file changed, 83 insertions(+), 95 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index b84436bd99f5a..69be087ee2aaf 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -2320,11 +2320,10 @@ Tensor hardswish_mps(const Tensor& self) { Tensor hardswish_backward_mps(const Tensor& grad_output, const Tensor& self) { using namespace mps; - if (grad_output.numel() == 0) { - return grad_output; - } - Tensor grad_input = at::empty_like(self, self.suggest_memory_format()); + if (grad_input.numel() == 0) { + return grad_input; + } struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} @@ -2335,113 +2334,102 @@ Tensor hardswish_backward_mps(const Tensor& grad_output, const Tensor& self) { MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - MPSStream* stream = at::mps::getCurrentMPSStream(); - @autoreleasepool { string key = "hardswish_backward_mps" + getTensorsStringKey({self}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { - MPSCachedGraph* tmpCachedGraph = - cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { - CachedGraph* newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* gradOutputTensor = - mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor* inputTensor = - mpsGraphRankedPlaceHolder(mpsGraph, self); - - MPSGraphTensor* zeroTensor = [mpsGraph - constantWithScalar:0.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output.scalar_type())]; - - MPSGraphTensor* unitTensor = [mpsGraph - constantWithScalar:1.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output.scalar_type())]; - - MPSGraphTensor* threeTensor = [mpsGraph - constantWithScalar:3.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output.scalar_type())]; - - MPSGraphTensor* negativeThreeTensor = [mpsGraph - constantWithScalar:-3.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output.scalar_type())]; - - MPSGraphTensor* halfTensor = [mpsGraph - constantWithScalar:0.5f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output.scalar_type())]; - - MPSGraphTensor* tempTensor = - [mpsGraph divisionWithPrimaryTensor:inputTensor - secondaryTensor:threeTensor - name:nil]; - - MPSGraphTensor* weightedTensor = - [mpsGraph additionWithPrimaryTensor:tempTensor - secondaryTensor:halfTensor - name:nil]; - - MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph - lessThanOrEqualToWithPrimaryTensor:inputTensor - secondaryTensor:negativeThreeTensor - name:nil]; - - MPSGraphTensor* lessThanMaxPredicateTensor = - [mpsGraph lessThanWithPrimaryTensor:inputTensor - secondaryTensor:threeTensor - name:nil]; - - MPSGraphTensor* lessThanMaxGradTensor = - [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor - truePredicateTensor:weightedTensor - falsePredicateTensor:unitTensor - name:nil]; + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* gradTensor = - [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor - truePredicateTensor:zeroTensor - falsePredicateTensor:lessThanMaxGradTensor - name:nil]; - MPSGraphTensor* gradInputTensor = - [mpsGraph multiplicationWithPrimaryTensor:gradTensor - secondaryTensor:gradOutputTensor - name:nil]; + MPSGraphTensor* zeroTensor = [mpsGraph + constantWithScalar:0.0f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output.scalar_type())]; + + MPSGraphTensor* unitTensor = [mpsGraph + constantWithScalar:1.0f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output.scalar_type())]; + + MPSGraphTensor* threeTensor = [mpsGraph + constantWithScalar:3.0f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output.scalar_type())]; + + MPSGraphTensor* negativeThreeTensor = [mpsGraph + constantWithScalar:-3.0f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output.scalar_type())]; + + MPSGraphTensor* halfTensor = [mpsGraph + constantWithScalar:0.5f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output.scalar_type())]; + + MPSGraphTensor* tempTensor = + [mpsGraph divisionWithPrimaryTensor:inputTensor + secondaryTensor:threeTensor + name:nil]; + + MPSGraphTensor* weightedTensor = + [mpsGraph additionWithPrimaryTensor:tempTensor + secondaryTensor:halfTensor + name:nil]; + + MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph + lessThanOrEqualToWithPrimaryTensor:inputTensor + secondaryTensor:negativeThreeTensor + name:nil]; + + MPSGraphTensor* lessThanMaxPredicateTensor = + [mpsGraph lessThanWithPrimaryTensor:inputTensor + secondaryTensor:threeTensor + name:nil]; + + MPSGraphTensor* lessThanMaxGradTensor = + [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor + truePredicateTensor:weightedTensor + falsePredicateTensor:unitTensor + name:nil]; + + MPSGraphTensor* gradTensor = + [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor + truePredicateTensor:zeroTensor + falsePredicateTensor:lessThanMaxGradTensor + name:nil]; + MPSGraphTensor* gradInputTensor = + [mpsGraph multiplicationWithPrimaryTensor:gradTensor + secondaryTensor:gradOutputTensor + name:nil]; - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->gradInputTensor_ = gradInputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); + newCachedGraph->gradOutputTensor_ = gradOutputTensor; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->gradInputTensor_ = gradInputTensor; + } + return newCachedGraph; + }); } - Placeholder gradOutputPlaceholder = - Placeholder(cachedGraph->gradOutputTensor_, grad_output); + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder gradInputPlaceholder = - Placeholder(cachedGraph->gradInputTensor_, grad_input); + Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); // Create dictionary of inputs and outputs NSDictionary* feeds = @{ - gradOutputPlaceholder.getMPSGraphTensor() : - gradOutputPlaceholder.getMPSGraphTensorData(), - selfPlaceholder.getMPSGraphTensor() : - selfPlaceholder.getMPSGraphTensorData() + gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() }; NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : - gradInputPlaceholder.getMPSGraphTensorData() + gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() }; - runMPSGraph(stream, cachedGraph->graph(), feeds, results); + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } return grad_input; }