184 changes: 27 additions & 157 deletions aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -72,38 +72,21 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
 MPSGraphTensor* secondaryCastTensor = newCachedGraph->secondaryTensor;

 // this type inference is only required at the time of graph creation
-const ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), other.scalar_type());
-
-// Condition -
-// 1. Division operation
-// 2. Inputs are not float
-bool div_condition = op_name.rfind("div", 0) == 0
-                 && (!(common_dtype == ScalarType::Float || common_dtype == ScalarType::Half));
-
-auto compute_type = ScalarType::Float;
-
-if (div_condition) {
-  if (output_.scalar_type() == ScalarType::Float || output_.scalar_type() == ScalarType::Half)
-    compute_type = output_.scalar_type();
-
-  primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, compute_type);
-  secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, compute_type);
-} else {
-  if (self.scalar_type() != common_dtype) {
-    primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype);
-  }
-  if (other.scalar_type() != common_dtype) {
-    secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
-  }
-}
+ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), other.scalar_type());
+// Integer inputs must be cast to float if the output is float
+if (isIntegralType(common_dtype, /*includeBool=*/true) && isFloatingType(output_.scalar_type())) {
+  common_dtype = output_.scalar_type();
+}
+if (self.scalar_type() != common_dtype) {
+  primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype);
+}
+if (other.scalar_type() != common_dtype) {
+  secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
+}
 newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
 // Cast the output tensor to the expected type if needed; this addresses the discrepancy when an
 // int64 scalar is added to an int32 tensor (the output should have been promoted but remains int32)
-if ((div_condition && compute_type != output_.scalar_type()) ||
-    output_.scalar_type() != common_dtype) {
+if (output_.scalar_type() != common_dtype) {
   newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, output_.scalar_type());
 }
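
The effect of the new cast selection can be checked in isolation from the graph code. A minimal sketch, assuming only the standard c10 dtype helpers used in the hunk above (c10::promoteTypes, isIntegralType, isFloatingType); compute_dtype is a hypothetical name introduced here for illustration, not a function in the tree:

// Sketch: mirrors the compute-dtype selection added in binaryOpTensor.
// compute_dtype is hypothetical; only the c10 helpers are real.
#include <c10/core/ScalarType.h>
#include <iostream>

static c10::ScalarType compute_dtype(c10::ScalarType self_t,
                                     c10::ScalarType other_t,
                                     c10::ScalarType out_t) {
  c10::ScalarType common = c10::promoteTypes(self_t, other_t);
  // Integer inputs compute in the output's float dtype (e.g. int div int -> Float)
  if (c10::isIntegralType(common, /*includeBool=*/true) && c10::isFloatingType(out_t)) {
    common = out_t;
  }
  return common;
}

int main() {
  using c10::ScalarType;
  // div(int, int) writing to a float output now computes in Float
  std::cout << c10::toString(compute_dtype(ScalarType::Int, ScalarType::Int, ScalarType::Float)) << "\n";
  // add(int32 tensor, int64 scalar) computes in Long; the final graph cast
  // then brings the result back to the int32 output dtype
  std::cout << c10::toString(compute_dtype(ScalarType::Int, ScalarType::Long, ScalarType::Int)) << "\n";
  return 0;
}

This replaces the old div-specific special case with one rule: promote the inputs, bump integral promotions up to the output's float dtype, and cast the result back only when it differs from the output dtype.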
@@ -319,139 +302,26 @@ Tensor floor_divide_mps(const Tensor& self, const Tensor& other) {

 TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
 {
-  using namespace mps;
-  MPSStream* stream = getCurrentMPSStream();
-
-  if (&output != &self) {
-    output.resize_(self.sizes());;
-  }
-
-  // Derive from MPSCachedGraph
-  struct CachedGraph : public MPSCachedGraph
-  {
-    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
-    MPSGraphTensor *inputTensor_ = nil;
-    MPSGraphTensor *otherTensor_ = nil;
-    MPSGraphTensor *outputTensor_ = nil;
-  };
-
-  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-
-  @autoreleasepool {
-    string key = "log_base_e_out_mps:" + getTensorsStringKey({self, other});
-    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
-
-    if(!cachedGraph) {
-      MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
-        CachedGraph *newCachedGraph = nil;
-
-        @autoreleasepool {
-          MPSGraph* mpsGraph = make_mps_graph();
-          newCachedGraph = new CachedGraph(mpsGraph);
-          MPSGraphTensor* xTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
-          MPSGraphTensor* yTensor = mpsGraphRankedPlaceHolder(mpsGraph, other);
-          MPSGraphTensor* ePowXTensor = [mpsGraph exponentWithTensor:xTensor
-                                                                name:nil];
-          MPSGraphTensor* ePowYTensor = [mpsGraph exponentWithTensor:yTensor
-                                                                name:nil];
-          MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:ePowXTensor
-                                                          secondaryTensor:ePowYTensor
-                                                                     name:nil];
-          MPSGraphTensor* outputTensor = [mpsGraph logarithmWithTensor:sumTensor
-                                                                  name:nil];
-
-          newCachedGraph->inputTensor_ = xTensor;
-          newCachedGraph->otherTensor_ = yTensor;
-          newCachedGraph->outputTensor_ = outputTensor;
-        }
-        return newCachedGraph;
-      });
-      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
-    }
-
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
-
-    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
-      selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
-      otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData()
-    };
-    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
-      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
-    };
-
-    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
-  }
-}
+  mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
+    MPSGraph* mpsGraph = cachedGraph->graph();
+    MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil]
+                                                    secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil]
+                                                               name:nil];
+    return [mpsGraph logarithmWithTensor:sumTensor name:nil];
+  };
+  mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp_out_mps", logaddexp_op_block);
+}
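
Both the deleted implementation and the new block evaluate logaddexp by its definition. In LaTeX, with the numerically stable rearrangement commonly used on other backends shown for comparison (an observation about the math, not something this diff changes):

\[
\operatorname{logaddexp}(x, y) \;=\; \log\bigl(e^{x} + e^{y}\bigr) \;=\; \max(x, y) + \log\bigl(1 + e^{-\lvert x - y \rvert}\bigr)
\]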

 TORCH_IMPL_FUNC(logaddexp2_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
 {
-  using namespace mps;
-  MPSStream* stream = getCurrentMPSStream();
-
-  if (&output != &self) {
-    output.resize_(self.sizes());;
-  }
-
-  // Derive from MPSCachedGraph
-  struct CachedGraph : public MPSCachedGraph
-  {
-    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
-    MPSGraphTensor *inputTensor_ = nil;
-    MPSGraphTensor *otherTensor_ = nil;
-    MPSGraphTensor *outputTensor_ = nil;
-  };
-
-  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-
-  @autoreleasepool {
-    string key = "log_base_two_out_mps:" + getTensorsStringKey({self, other});
-    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
-
-    if(!cachedGraph) {
-      MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
-        CachedGraph *newCachedGraph = nil;
-
-        @autoreleasepool {
-          MPSGraph* mpsGraph = make_mps_graph();
-          newCachedGraph = new CachedGraph(mpsGraph);
-          MPSGraphTensor* xTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
-          MPSGraphTensor* yTensor = mpsGraphRankedPlaceHolder(mpsGraph, other);
-          MPSGraphTensor* twoPowXTensor = [mpsGraph exponentBase2WithTensor:xTensor
-                                                                       name:nil];
-          MPSGraphTensor* twoPowYTensor = [mpsGraph exponentBase2WithTensor:yTensor
-                                                                       name:nil];
-          MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:twoPowXTensor
-                                                          secondaryTensor:twoPowYTensor
-                                                                     name:nil];
-          MPSGraphTensor* outputTensor = [mpsGraph logarithmBase2WithTensor:sumTensor
-                                                                       name:nil];
-
-          newCachedGraph->inputTensor_ = xTensor;
-          newCachedGraph->otherTensor_ = yTensor;
-          newCachedGraph->outputTensor_ = outputTensor;
-        }
-        return newCachedGraph;
-      });
-      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
-    }
-
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
-
-    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
-      selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
-      otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData()
-    };
-    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
-      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
-    };
-
-    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
-  }
-}
+  mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
+    MPSGraph* mpsGraph = cachedGraph->graph();
+    MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil]
+                                                    secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil]
+                                                               name:nil];
+    return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil];
+  };
+  mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp2_out_mps", logaddexp2_op_block);
+}
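
logaddexp2 follows the same route with the base-2 primitives, computing log2(2^x + 2^y). With both ops consolidated, any further element-wise binary op only needs to supply a BinaryOpBlock; a hypothetical sketch of the pattern (the op name and the division kernel are placeholders chosen for illustration, not part of this PR):

// Hypothetical example: routing another binary op through the shared helper.
// binaryOpTensor handles dtype promotion and the final output cast, so the
// block only has to emit the MPSGraph kernel itself.
mps::BinaryOpBlock example_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
  MPSGraph* mpsGraph = cachedGraph->graph();
  return [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
                             secondaryTensor:secondaryCastTensor
                                        name:nil];
};
mps::binaryOpTensor(self, other, Scalar(1.0), output, "example_out_mps", example_op_block);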

} // namespace native