Fix the type cast issue with Binary Ops (#158)
* Fix data type issues for logaddexp and logaddexp2 ops

* Fix the type cast issue with Binary Ops

* Move several ops out of Blocklist in TestConsistency

* Move good ops from block list to allow list
razarmehr authored Oct 31, 2022
1 parent da16165 commit 2e118ff
Showing 2 changed files with 45 additions and 188 deletions.
184 changes: 27 additions & 157 deletions aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -72,38 +72,21 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
  MPSGraphTensor* secondaryCastTensor = newCachedGraph->secondaryTensor;

  // this type inference is only required at the time of graph creation
- const ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), other.scalar_type());
-
- // Condition -
- // 1. Division operation
- // 2. Inputs are not float
- bool div_condition = op_name.rfind("div", 0) == 0
-                      && (!(common_dtype == ScalarType::Float || common_dtype == ScalarType::Half));
-
- auto compute_type = ScalarType::Float;
-
- if (div_condition) {
-   if (output_.scalar_type() == ScalarType::Float || output_.scalar_type() == ScalarType::Half)
-     compute_type = output_.scalar_type();
-
-   primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, compute_type);
-   secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, compute_type);
- } else {
-   if (self.scalar_type() != common_dtype) {
-     primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype);
-   }
-   if (other.scalar_type() != common_dtype) {
-     secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
-   }
- }
+ ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), other.scalar_type());
+ // Integer input must be cast to float if output is float
+ if (isIntegralType(common_dtype, /*includeBool=*/true) && isFloatingType(output_.scalar_type())) {
+   common_dtype = output_.scalar_type();
+ }
+ if (self.scalar_type() != common_dtype) {
+   primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype);
+ }
+ if (other.scalar_type() != common_dtype) {
+   secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
+ }
  newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
  // Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to int32 tensor
  // Output tensor should have been promoted but it remains an int32 tensor
- if ((div_condition && compute_type != output_.scalar_type()) ||
-     output_.scalar_type() != common_dtype) {
+ if (output_.scalar_type() != common_dtype) {
    newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, output_.scalar_type());
  }
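
For context, this is the behavior the new cast rule targets: when the promoted input dtype is integral but the structured op's output is floating point (true division being the common case), the inputs are now computed in the output's float dtype. A minimal Python sketch, assuming a macOS build with the MPS backend available (tensor values are illustrative):

import torch

# Integer inputs; torch.div performs true division, so the kernel's
# output dtype is float32 and both inputs are cast to it in the graph.
a = torch.tensor([3, 7, 9], dtype=torch.int32, device="mps")
b = torch.tensor([2, 4, 5], dtype=torch.int32, device="mps")

out = torch.div(a, b)
print(out.dtype)  # torch.float32
print(out.cpu())  # tensor([1.5000, 1.7500, 1.8000])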
@@ -319,139 +302,26 @@ Tensor floor_divide_mps(const Tensor& self, const Tensor& other) {

TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
{
- using namespace mps;
- MPSStream* stream = getCurrentMPSStream();
-
- if (&output != &self) {
-   output.resize_(self.sizes());;
- }
-
- // Derive from MPSCachedGraph
- struct CachedGraph : public MPSCachedGraph
- {
-   CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
-   MPSGraphTensor *inputTensor_ = nil;
-   MPSGraphTensor *otherTensor_ = nil;
-   MPSGraphTensor *outputTensor_ = nil;
- };
-
- MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-
- @autoreleasepool {
-   string key = "log_base_e_out_mps:" + getTensorsStringKey({self, other});
-   CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
-
-   if (!cachedGraph) {
-     MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
-       CachedGraph *newCachedGraph = nil;
-
-       @autoreleasepool {
-         MPSGraph* mpsGraph = make_mps_graph();
-         newCachedGraph = new CachedGraph(mpsGraph);
-         MPSGraphTensor* xTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
-         MPSGraphTensor* yTensor = mpsGraphRankedPlaceHolder(mpsGraph, other);
-         MPSGraphTensor* ePowXTensor = [mpsGraph exponentWithTensor:xTensor name:nil];
-         MPSGraphTensor* ePowYTensor = [mpsGraph exponentWithTensor:yTensor name:nil];
-         MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:ePowXTensor
-                                                         secondaryTensor:ePowYTensor
-                                                                    name:nil];
-         MPSGraphTensor* outputTensor = [mpsGraph logarithmWithTensor:sumTensor name:nil];
-
-         newCachedGraph->inputTensor_ = xTensor;
-         newCachedGraph->otherTensor_ = yTensor;
-         newCachedGraph->outputTensor_ = outputTensor;
-       }
-       return newCachedGraph;
-     });
-     cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
-   }
-
-   Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-   Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other);
-   Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
-
-   NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
-     selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
-     otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData()
-   };
-   NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
-     outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
-   };
-
-   runMPSGraph(stream, cachedGraph->graph(), feeds, results);
- }
+ mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
+   MPSGraph* mpsGraph = cachedGraph->graph();
+   MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil]
+                                                   secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil]
+                                                              name:nil];
+   return [mpsGraph logarithmWithTensor:sumTensor name:nil];
+ };
+ mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp_out_mps", logaddexp_op_block);
}
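
The rewritten op now routes through mps::binaryOpTensor, which applies the casting logic above, and only supplies the math: log(exp(x) + exp(y)). A quick Python check of that identity (illustrative values; the naive form can overflow for large inputs, which is why a dedicated logaddexp op exists at all):

import torch

x = torch.tensor([0.5, -1.0, 2.0])
y = torch.tensor([1.5, 0.0, -3.0])

# logaddexp(x, y) == log(exp(x) + exp(y)) wherever exp() does not overflow
expected = torch.log(torch.exp(x) + torch.exp(y))
assert torch.allclose(torch.logaddexp(x, y), expected)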

TORCH_IMPL_FUNC(logaddexp2_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
{
- using namespace mps;
- MPSStream* stream = getCurrentMPSStream();
-
- if (&output != &self) {
-   output.resize_(self.sizes());;
- }
-
- // Derive from MPSCachedGraph
- struct CachedGraph : public MPSCachedGraph
- {
-   CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
-   MPSGraphTensor *inputTensor_ = nil;
-   MPSGraphTensor *otherTensor_ = nil;
-   MPSGraphTensor *outputTensor_ = nil;
- };
-
- MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-
- @autoreleasepool {
-   string key = "log_base_two_out_mps:" + getTensorsStringKey({self, other});
-   CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
-
-   if (!cachedGraph) {
-     MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
-       CachedGraph *newCachedGraph = nil;
-
-       @autoreleasepool {
-         MPSGraph* mpsGraph = make_mps_graph();
-         newCachedGraph = new CachedGraph(mpsGraph);
-         MPSGraphTensor* xTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
-         MPSGraphTensor* yTensor = mpsGraphRankedPlaceHolder(mpsGraph, other);
-         MPSGraphTensor* twoPowXTensor = [mpsGraph exponentBase2WithTensor:xTensor name:nil];
-         MPSGraphTensor* twoPowYTensor = [mpsGraph exponentBase2WithTensor:yTensor name:nil];
-         MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:twoPowXTensor
-                                                         secondaryTensor:twoPowYTensor
-                                                                    name:nil];
-         MPSGraphTensor* outputTensor = [mpsGraph logarithmBase2WithTensor:sumTensor name:nil];
-
-         newCachedGraph->inputTensor_ = xTensor;
-         newCachedGraph->otherTensor_ = yTensor;
-         newCachedGraph->outputTensor_ = outputTensor;
-       }
-       return newCachedGraph;
-     });
-     cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
-   }
-
-   Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-   Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other);
-   Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
-
-   NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
-     selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
-     otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData()
-   };
-   NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
-     outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
-   };
-
-   runMPSGraph(stream, cachedGraph->graph(), feeds, results);
- }
+ mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
+   MPSGraph* mpsGraph = cachedGraph->graph();
+   MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil]
+                                                   secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil]
+                                                              name:nil];
+   return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil];
+ };
+ mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp2_out_mps", logaddexp2_op_block);
}
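
logaddexp2 gets the same treatment with the base-2 graph primitives, computing log2(2^x + 2^y). The analogous Python check (illustrative values):

import torch

x = torch.tensor([0.5, -1.0, 2.0])
y = torch.tensor([1.5, 0.0, -3.0])

# logaddexp2(x, y) == log2(2**x + 2**y), the base-2 analogue of logaddexp
expected = torch.log2(torch.exp2(x) + torch.exp2(y))
assert torch.allclose(torch.logaddexp2(x, y), expected)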

} // namespace native
