Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOSA] Add logit, log1p, log10 and add promote type to unary fponly ops #3900

Merged
merged 1 commit into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 200 additions & 24 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ using namespace mlir::torch::Torch;

namespace {

// These legalizations are for unary ops with only for floating point datatypes.
// There is no supported quantized integer mode for these.
// These legalizations are for unary ops with promoting input to floating-point
// datatypes only. There is no supported quantized integer mode for these.
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
Expand All @@ -51,17 +51,22 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");

if (isa<mlir::FloatType>(selfTy.getElementType())) {
rewriter.replaceOpWithNewOp<TosaOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
self);
return success();
} else {
auto resultTy = dyn_cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));

if (!isa<mlir::FloatType>(resultTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
}
op, "Only floating-point datatype result types are supported");

// Non floating point inputs are not supported in TOSA so we cast the input
// to result type
if (!isa<mlir::FloatType>(selfTy.getElementType()))
self = tosa::promoteType(rewriter, self, resultTy);

rewriter.replaceOpWithNewOp<TosaOpT>(op, resultTy, self);

return success();
}
};

Expand Down Expand Up @@ -2922,24 +2927,32 @@ template <>
LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
AtenLog2Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();

// Not a tensor type.
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(
op, "Only tensor types are currently supported");

auto outType =
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));

// If input is not a float type then cast it to output type
auto selfElemTy = selfType.getElementType();
if (!isa<mlir::FloatType>(selfElemTy))
self = tosa::promoteType(rewriter, self, outType);

// Constant value of ln2.
SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
auto ln2Op = tosa::getConstTensor<float>(rewriter, op, {0.69314718056f},
ln2Shape, selfType.getElementType())
ln2Shape, outType.getElementType())
.value();

auto rcpOp =
rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);

auto outType = getTypeConverter()->convertType(op.getType());
auto logOp =
rewriter.create<tosa::LogOp>(op.getLoc(), outType, adaptor.getSelf());
auto logOp = rewriter.create<tosa::LogOp>(op.getLoc(), outType, self);
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, outType, logOp, rcpOp,
/*shift=*/0);

Expand Down Expand Up @@ -8025,6 +8038,166 @@ class ConvertUpsampleNearest2dForward : public OpConversionPattern<AtenOpT> {
}
};

// Legalization for aten.logit
template <>
LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
AtenLogitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Logit formula:
// result = log(zi / (1 - zi))
// Where: if eps is not None:
// zi = input clampled to [eps, 1 - eps]
// else:
// zi = input
auto self = adaptor.getSelf();

auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultElemTy = resultType.getElementType();

if (!isa<mlir::FloatType>(resultElemTy))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype result types are supported");

// If input is not a float type then cast it to result element type
auto selfElemTy = selfType.getElementType();
if (!isa<mlir::FloatType>(selfElemTy))
self = tosa::promoteType(rewriter, self, resultType);

bool isEpsNone = isa<Torch::NoneType>(op.getEps().getType());

double eps;
if (!isEpsNone && !matchPattern(op.getEps(), m_TorchConstantFloat(&eps)))
return rewriter.notifyMatchFailure(op,
"Non-const eps value is not supported");

auto zi = self;

// Clamp input to [eps, 1 - eps] when eps is not None
if (!isEpsNone) {
zi = rewriter
.create<tosa::ClampOp>(
op->getLoc(), resultType, self,
rewriter.getI64IntegerAttr(static_cast<int64_t>(eps)),
rewriter.getI64IntegerAttr(static_cast<int64_t>(1 - eps)),
rewriter.getF32FloatAttr(static_cast<float>(eps)),
rewriter.getF32FloatAttr(static_cast<float>(1 - eps)))
.getResult();
}

auto one =
tosa::getConstTensor<float>(rewriter, op, 1.0f, {}, resultElemTy).value();

auto oneMinusZi =
rewriter.create<tosa::SubOp>(op->getLoc(), resultType, one, zi);

auto oneMinusZiReciprocal = rewriter.create<tosa::ReciprocalOp>(
op->getLoc(), resultType, oneMinusZi.getResult());

auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), resultType, zi,
oneMinusZiReciprocal.getResult(),
/*shift=*/0);

auto result =
rewriter.create<tosa::LogOp>(op->getLoc(), resultType, mulOp.getResult());

rewriter.replaceOp(op, {result.getResult()});

return success();
}

// Legalization for aten.log1p
template <>
LogicalResult ConvertAtenOp<AtenLog1pOp>::matchAndRewrite(
AtenLog1pOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// log1p formula:
// yi = log(xi + 1)
auto self = adaptor.getSelf();

auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultElemTy = resultType.getElementType();

if (!isa<mlir::FloatType>(resultElemTy))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype result types are supported");

// If input is not a float type then cast it to result element type
auto selfElemTy = selfType.getElementType();
if (!isa<mlir::FloatType>(selfElemTy))
self = tosa::promoteType(rewriter, self, resultType);

auto one =
tosa::getConstTensor<float>(rewriter, op, 1.0f, {}, resultElemTy).value();

auto addOp =
rewriter.create<tosa::AddOp>(op->getLoc(), resultType, self, one);

auto result =
rewriter.create<tosa::LogOp>(op->getLoc(), resultType, addOp.getResult());

rewriter.replaceOp(op, {result.getResult()});

return success();
}

// Legalization for aten.log10
template <>
LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
AtenLog10Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// log10 formula (using log base changing formula since TOSA doesn't have a
// builtin log10 op):
// yi = log(xi) / log(10)
auto self = adaptor.getSelf();

auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultElemTy = resultType.getElementType();

if (!isa<mlir::FloatType>(resultElemTy))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype result types are supported");

// If input is not a float type then cast it to result element type
auto selfElemTy = selfType.getElementType();
if (!isa<mlir::FloatType>(selfElemTy))
self = tosa::promoteType(rewriter, self, resultType);

auto ten = tosa::getConstTensor<float>(rewriter, op, 10.0f, {}, resultElemTy)
.value();

auto logOfSelf = rewriter.create<tosa::LogOp>(op->getLoc(), resultType, self);

auto constType = RankedTensorType::get({}, resultElemTy);

auto logOfTen = rewriter.create<tosa::LogOp>(op->getLoc(), constType, ten);

auto reciprocalOp = rewriter.create<tosa::ReciprocalOp>(
op->getLoc(), constType, logOfTen.getResult());

auto result = rewriter.create<tosa::MulOp>(
op->getLoc(), resultType, logOfSelf.getResult(), reciprocalOp.getResult(),
/*shift=*/0);

rewriter.replaceOp(op, {result.getResult()});

return success();
}

} // namespace

// -----------------------------------------------------------------------------
Expand Down Expand Up @@ -8069,13 +8242,13 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {

RewritePatternSet patterns(context);

#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \
#define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, TosaOp>>(typeConverter, \
context);
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp)
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, tosa::ExpOp)
#undef INSERT_UNARY_FPONLY_PATTERN
patterns.add<ConvertAtenUnaryPromoteToFPOp<AtenOp, TosaOp>>(typeConverter, \
context);
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, tosa::LogOp)
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, tosa::ExpOp)
#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN

#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
Expand Down Expand Up @@ -8364,6 +8537,9 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp);
INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
INSERT_ATENOP_PATTERN(AtenOuterOp);
INSERT_ATENOP_PATTERN(AtenLogitOp);
INSERT_ATENOP_PATTERN(AtenLog1pOp);
INSERT_ATENOP_PATTERN(AtenLog10Op);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
Expand Down
35 changes: 24 additions & 11 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,22 @@
"RandIntPinMemoryModule_basic",
"RenormModuleFloat16_basic",
"SplitDimStaticModule_basic",
"Deg2radModule_basic",
"ElementwiseExpIntModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog10Module_basic",
"ElementwiseLog1pModule_basic",
"ElementwiseLog2IntModule_basic",
"ElementwiseLogIntModule_basic",
"ElementwiseLogitModule_basic",
"ElementwiseMishModule_basic",
"L1LossMeanReductionModule_basic",
"L1LossNoReductionModule_basic",
"L1LossSumReductionModule_basic",
"RandIntLowModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"SoftplusModule_basic",
"ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic",
"ReflectionPad1dModule3dInput_Left",
Expand Down Expand Up @@ -3416,6 +3432,8 @@
}

FX_IMPORTER_TOSA_XFAIL_SET = {
"AtenFftRfft2DLastDim_basic",
"AtenFftRfft2DMiddleDim_basic",
"IsInfiniteModule_basic",
"LayerNormFwAndBwModule_basic",
"LayerNormManualFwAndBwModule_basic",
Expand Down Expand Up @@ -3627,17 +3645,9 @@
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseExpIntModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog10Module_basic",
"ElementwiseLog1pModule_basic",
"ElementwiseLog2IntModule_basic",
"ElementwiseLogIntModule_basic",
"ElementwiseLogitModule_basic",
"ElementwiseMishModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
Expand Down Expand Up @@ -3755,6 +3765,7 @@
"NumelModule_basic",
"NumelZeroRankModule_basic",
"OnesLikeModule_falsePinMemory",
"PowIntIntModule_basic",
"PowIntFloatModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic",
Expand Down Expand Up @@ -3822,7 +3833,6 @@
"SliceOutOfLowerBoundEndIndexModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceSizeTwoStepModule_basic",
"SoftplusModule_basic",
"SortIntListReverse_basic",
"SortIntList_basic",
"SortTensorDescending_basic",
Expand Down Expand Up @@ -3902,6 +3912,11 @@
}

ONNX_TOSA_XFAIL_SET = {
"AtenFftRfft2DLastDim_basic",
"AtenFftRfft2DMiddleDim_basic",
"PowFloatIntModule_basic",
"PowIntFloatModule_basic",
"PowIntIntModule_basic",
"ColumnStack0dModule_basic",
"ColumnStack1dModule_basic",
"ColumnStackBasicIntModule_basic",
Expand Down Expand Up @@ -4311,7 +4326,6 @@
"ElementwiseLog2IntModule_basic",
"ElementwiseLogIntModule_basic",
"ElementwiseLtDiffWidthScalarModule_basic",
"ElementwiseMishModule_basic",
"ElementwiseMulScalarModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseMulTensorComplexModule_basic",
Expand Down Expand Up @@ -4755,7 +4769,6 @@
"SoftmaxIntModule_basic",
"SoftmaxIntNegDimModule_basic",
"SoftmaxIntNonNoneDtypeModule_basic",
"SoftplusModule_basic",
"SortIntListReverse_basic",
"SortIntList_basic",
"SortTensorDescending_basic",
Expand Down
Loading
Loading