aten/src/ATen/native/mps/operations/Activation.mm (83 additions, 95 deletions)
@@ -2320,11 +2320,10 @@ Tensor hardswish_mps(const Tensor& self) {
Tensor hardswish_backward_mps(const Tensor& grad_output, const Tensor& self) {
using namespace mps;

- if (grad_output.numel() == 0) {
- return grad_output;
- }
-
Tensor grad_input = at::empty_like(self, self.suggest_memory_format());
+ if (grad_input.numel() == 0) {
+ return grad_input;
+ }

struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
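
Note on the hunk above: the empty-tensor early return now checks grad_input, an empty tensor allocated with self's shape and suggested memory format, instead of returning grad_output directly, so the fast path hands back a gradient that matches self without ever building a graph. A minimal sketch of that behavior through the ATen C++ API (illustrative only, not part of this PR; assumes an MPS-enabled build and the usual at::hardswish_backward binding generated from native_functions.yaml):

#include <ATen/ATen.h>
#include <cassert>

int main() {
  // Empty input: hardswish_backward should take the early-return path added
  // above and hand back an empty gradient shaped like `self`.
  at::Tensor self = at::empty({0}, at::device(at::kMPS).dtype(at::kFloat));
  at::Tensor grad_output = at::empty_like(self);
  at::Tensor grad_input = at::hardswish_backward(grad_output, self);
  assert(grad_input.numel() == 0);
  assert(grad_input.sizes() == self.sizes());
  return 0;
}
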
@@ -2335,113 +2334,102 @@ Tensor hardswish_backward_mps(const Tensor& grad_output, const Tensor& self) {

MPSGraphCache* cache_ = MPSGraphCache::getInstance();

- MPSStream* stream = at::mps::getCurrentMPSStream();
-
@autoreleasepool {
string key = "hardswish_backward_mps" + getTensorsStringKey({self});
- CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
+ CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if (!cachedGraph) {
- MPSCachedGraph* tmpCachedGraph =
- cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
- CachedGraph* newCachedGraph = nil;
- @autoreleasepool {
- MPSGraph* mpsGraph = make_mps_graph();
- newCachedGraph = new CachedGraph(mpsGraph);
- MPSGraphTensor* gradOutputTensor =
- mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
- MPSGraphTensor* inputTensor =
- mpsGraphRankedPlaceHolder(mpsGraph, self);
-
- MPSGraphTensor* zeroTensor = [mpsGraph
- constantWithScalar:0.0f
- shape:@[ @1 ]
- dataType:getMPSDataType(grad_output.scalar_type())];
-
- MPSGraphTensor* unitTensor = [mpsGraph
- constantWithScalar:1.0f
- shape:@[ @1 ]
- dataType:getMPSDataType(grad_output.scalar_type())];
-
- MPSGraphTensor* threeTensor = [mpsGraph
- constantWithScalar:3.0f
- shape:@[ @1 ]
- dataType:getMPSDataType(grad_output.scalar_type())];
-
- MPSGraphTensor* negativeThreeTensor = [mpsGraph
- constantWithScalar:-3.0f
- shape:@[ @1 ]
- dataType:getMPSDataType(grad_output.scalar_type())];
-
- MPSGraphTensor* halfTensor = [mpsGraph
- constantWithScalar:0.5f
- shape:@[ @1 ]
- dataType:getMPSDataType(grad_output.scalar_type())];
-
- MPSGraphTensor* tempTensor =
- [mpsGraph divisionWithPrimaryTensor:inputTensor
- secondaryTensor:threeTensor
- name:nil];
-
- MPSGraphTensor* weightedTensor =
- [mpsGraph additionWithPrimaryTensor:tempTensor
- secondaryTensor:halfTensor
- name:nil];
-
- MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph
- lessThanOrEqualToWithPrimaryTensor:inputTensor
- secondaryTensor:negativeThreeTensor
- name:nil];
-
- MPSGraphTensor* lessThanMaxPredicateTensor =
- [mpsGraph lessThanWithPrimaryTensor:inputTensor
- secondaryTensor:threeTensor
- name:nil];
-
- MPSGraphTensor* lessThanMaxGradTensor =
- [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor
- truePredicateTensor:weightedTensor
- falsePredicateTensor:unitTensor
- name:nil];
+ cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^MPSCachedGraph*() {
+ CachedGraph* newCachedGraph = nil;
+ @autoreleasepool {
+ MPSGraph* mpsGraph = make_mps_graph();
+ newCachedGraph = new CachedGraph(mpsGraph);
+ MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
+ MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);

- MPSGraphTensor* gradTensor =
- [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor
- truePredicateTensor:zeroTensor
- falsePredicateTensor:lessThanMaxGradTensor
- name:nil];
- MPSGraphTensor* gradInputTensor =
- [mpsGraph multiplicationWithPrimaryTensor:gradTensor
- secondaryTensor:gradOutputTensor
- name:nil];
+ MPSGraphTensor* zeroTensor = [mpsGraph
+ constantWithScalar:0.0f
+ shape:@[ @1 ]
+ dataType:getMPSDataType(grad_output.scalar_type())];
+
+ MPSGraphTensor* unitTensor = [mpsGraph
+ constantWithScalar:1.0f
+ shape:@[ @1 ]
+ dataType:getMPSDataType(grad_output.scalar_type())];
+
+ MPSGraphTensor* threeTensor = [mpsGraph
+ constantWithScalar:3.0f
+ shape:@[ @1 ]
+ dataType:getMPSDataType(grad_output.scalar_type())];
+
+ MPSGraphTensor* negativeThreeTensor = [mpsGraph
+ constantWithScalar:-3.0f
+ shape:@[ @1 ]
+ dataType:getMPSDataType(grad_output.scalar_type())];
+
+ MPSGraphTensor* halfTensor = [mpsGraph
+ constantWithScalar:0.5f
+ shape:@[ @1 ]
+ dataType:getMPSDataType(grad_output.scalar_type())];
+
+ MPSGraphTensor* tempTensor =
+ [mpsGraph divisionWithPrimaryTensor:inputTensor
+ secondaryTensor:threeTensor
+ name:nil];
+
+ MPSGraphTensor* weightedTensor =
+ [mpsGraph additionWithPrimaryTensor:tempTensor
+ secondaryTensor:halfTensor
+ name:nil];
+
+ MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph
+ lessThanOrEqualToWithPrimaryTensor:inputTensor
+ secondaryTensor:negativeThreeTensor
+ name:nil];
+
+ MPSGraphTensor* lessThanMaxPredicateTensor =
+ [mpsGraph lessThanWithPrimaryTensor:inputTensor
+ secondaryTensor:threeTensor
+ name:nil];
+
+ MPSGraphTensor* lessThanMaxGradTensor =
+ [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor
+ truePredicateTensor:weightedTensor
+ falsePredicateTensor:unitTensor
+ name:nil];
+
+ MPSGraphTensor* gradTensor =
+ [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor
+ truePredicateTensor:zeroTensor
+ falsePredicateTensor:lessThanMaxGradTensor
+ name:nil];
+ MPSGraphTensor* gradInputTensor =
+ [mpsGraph multiplicationWithPrimaryTensor:gradTensor
+ secondaryTensor:gradOutputTensor
+ name:nil];

- newCachedGraph->gradOutputTensor_ = gradOutputTensor;
- newCachedGraph->inputTensor_ = inputTensor;
- newCachedGraph->gradInputTensor_ = gradInputTensor;
- }
- return newCachedGraph;
- });
- cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
+ newCachedGraph->gradOutputTensor_ = gradOutputTensor;
+ newCachedGraph->inputTensor_ = inputTensor;
+ newCachedGraph->gradInputTensor_ = gradInputTensor;
+ }
+ return newCachedGraph;
+ });
}

- Placeholder gradOutputPlaceholder =
- Placeholder(cachedGraph->gradOutputTensor_, grad_output);
+ Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
- Placeholder gradInputPlaceholder =
- Placeholder(cachedGraph->gradInputTensor_, grad_input);
+ Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);

// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
- gradOutputPlaceholder.getMPSGraphTensor() :
- gradOutputPlaceholder.getMPSGraphTensorData(),
- selfPlaceholder.getMPSGraphTensor() :
- selfPlaceholder.getMPSGraphTensorData()
+ gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
+ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
};

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
- gradInputPlaceholder.getMPSGraphTensor() :
- gradInputPlaceholder.getMPSGraphTensorData()
+ gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
};

- runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+ runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
return grad_input;
}
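
For reference, the graph built in the second hunk encodes the usual hardswish derivative: 0 where self <= -3, 1 where self >= 3, and self / 3 + 0.5 in between, multiplied elementwise by grad_output. A CPU-side scalar sketch of that same rule (illustrative only, not part of this PR; the helper name and test values are made up):

#include <cassert>

// Scalar mirror of the select/arithmetic chain above:
//   zeroTensor branch     -> 0            for x <= -3
//   unitTensor branch     -> 1            for x >= 3
//   weightedTensor branch -> x / 3 + 0.5  otherwise
float hardswish_backward_ref(float grad_output, float x) {
  if (x <= -3.0f) {
    return 0.0f;
  }
  float grad = (x < 3.0f) ? (x / 3.0f + 0.5f) : 1.0f;
  return grad * grad_output;  // multiplicationWithPrimaryTensor:secondaryTensor:
}

int main() {
  assert(hardswish_backward_ref(1.0f, -4.0f) == 0.0f);  // saturated low
  assert(hardswish_backward_ref(1.0f, 4.0f) == 1.0f);   // saturated high
  assert(hardswish_backward_ref(2.0f, 0.0f) == 1.0f);   // 2 * (0/3 + 0.5)
  return 0;
}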