diff --git a/stablehlo/conversions/tosa/tests/legalize_quant_ops_to_tosa_rescale.mlir b/stablehlo/conversions/tosa/tests/legalize_quant_ops_to_tosa_rescale.mlir index e5e14f07dbb..4d86683dfd9 100644 --- a/stablehlo/conversions/tosa/tests/legalize_quant_ops_to_tosa_rescale.mlir +++ b/stablehlo/conversions/tosa/tests/legalize_quant_ops_to_tosa_rescale.mlir @@ -27,3 +27,94 @@ func.func @sub(%arg0 : tensor<2x2x!quant.uniform>, -> tensor<2x2x!quant.uniform> return %0 : tensor<2x2x!quant.uniform> } + +// ----- +// CHECK-LABEL: @mul +func.func @mul(%arg0 : tensor<2x2x!quant.uniform>, + %arg1 : tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V2:.+]] = stablehlo.multiply %[[V0]], %[[V1]] : tensor<2x2xi32> + // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: return %[[V3]] : tensor<2x2x!quant.uniform> + %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) + -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- +// CHECK-LABEL: @div +func.func @div(%arg0 : tensor<2x2x!quant.uniform>, + %arg1 : tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_zp = -2 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V2:.+]] = stablehlo.divide %[[V0]], %[[V1]] : tensor<2x2xi32> + // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = -3 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: return %[[V3]] : tensor<2x2x!quant.uniform> + %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) + -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- +// CHECK-LABEL: @max +func.func @max(%arg0 : tensor<2x2x!quant.uniform>, + %arg1 : tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_zp = -2 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V2:.+]] = stablehlo.maximum %[[V0]], %[[V1]] : tensor<2x2xi32> + // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = -3 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: return %[[V3]] : tensor<2x2x!quant.uniform> + %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) + -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- +// CHECK-LABEL: @min +func.func @min(%arg0 : tensor<2x2x!quant.uniform>, + %arg1 : tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { + // CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_zp = -2 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V2:.+]] = stablehlo.minimum %[[V0]], %[[V1]] : tensor<2x2xi32> + // CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = -3 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: return %[[V3]] : tensor<2x2x!quant.uniform> + %0 = "stablehlo.minimum"(%arg0, %arg1) : (tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform>) + -> tensor<2x2x!quant.uniform> + return %0 : tensor<2x2x!quant.uniform> +} + +// ----- +// CHECK-LABEL: @abs +func.func @abs(%arg0 : tensor<20x20x!quant.uniform>) -> tensor<20x20x!quant.uniform> { + // CHECK: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V1:.+]] = stablehlo.abs %[[V0]] : tensor<20x20xi32> + // CHECK: %[[V3:.+]] = tosa.rescale %[[V1]] {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: return %[[V3]] : tensor<20x20x!quant.uniform> + %0 = "stablehlo.abs"(%arg0) : (tensor<20x20x!quant.uniform>) -> tensor<20x20x!quant.uniform> + return %0 : tensor<20x20x!quant.uniform> +} + +// ----- +// CHECK-LABEL: @compareGE +func.func @compareGE(%arg0 : tensor<20x20x!quant.uniform>, + %arg1 : tensor<20x20x!quant.uniform>) -> tensor<20x20xi1> { + // CHECK: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_zp = -2 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V2:.+]] = stablehlo.compare GE, %[[V0]], %[[V1]], TOTALORDER : + // CHECK: return %[[V2]] + %0 = stablehlo.compare GE, %arg0, %arg1, TOTALORDER : (tensor<20x20x!quant.uniform>, tensor<20x20x!quant.uniform>) -> tensor<20x20xi1> + return %0 : tensor<20x20xi1> +} + +// ----- +// CHECK-LABEL: @compareLT +func.func @compareLT(%arg0 : tensor<20x20x!quant.uniform>, + %arg1 : tensor<20x20x!quant.uniform>) -> tensor<20x20xi1> { + // CHECK: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_zp = -2 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} + // CHECK: %[[V2:.+]] = stablehlo.compare LT, %[[V0]], %[[V1]], TOTALORDER : + // CHECK: return %[[V2]] + %0 = stablehlo.compare LT, %arg0, %arg1, TOTALORDER : (tensor<20x20x!quant.uniform>, tensor<20x20x!quant.uniform>) -> tensor<20x20xi1> + return %0 : tensor<20x20xi1> +} diff --git a/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp b/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp index e5fa10450e5..44d1a0c7b8f 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp +++ b/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp @@ -41,7 +41,7 @@ namespace tosa { namespace { // create a tosa rescale op and return its result value -Value buildRescale(PatternRewriter& rewriter, Location loc, +Value buildRescale(PatternRewriter &rewriter, Location loc, ShapedType outputType, Value inputVal, int32_t multiplier, int32_t shift, int64_t inputZp, int64_t outputZp, bool doubleRound, bool scale32, bool perChannel) { @@ -58,7 +58,7 @@ Value buildRescale(PatternRewriter& rewriter, Location loc, } // Creates TOSA rescale op with int32 output -Value buildRescaleToInt32(PatternRewriter& rewriter, Location loc, +Value buildRescaleToInt32(PatternRewriter &rewriter, Location loc, Value inputVal, double inputScale, int64_t inputZp) { auto inputType = cast(inputVal.getType()); auto outputType = inputType.clone(rewriter.getI32Type()); @@ -76,7 +76,7 @@ Value buildRescaleToInt32(PatternRewriter& rewriter, Location loc, } // Creates TOSA rescale op with int32 input -Value buildRescaleFromInt32(PatternRewriter& rewriter, Location loc, +Value buildRescaleFromInt32(PatternRewriter &rewriter, Location loc, ShapedType outputType, Value inputVal, double outputScale, int64_t outputZp) { // Input should be int32 type @@ -96,32 +96,88 @@ Value buildRescaleFromInt32(PatternRewriter& rewriter, Location loc, /*perChannel=*/false); } +using UnaryRescaleScalesFn = + void (*)(const quant::UniformQuantizedType &operandQType, + const quant::UniformQuantizedType &resultQType, + double &operandRescaleScale, double &resultRescaleScale); + +void GetUnaryRescaleScales(const quant::UniformQuantizedType &operandQType, + const quant::UniformQuantizedType &resultQType, + double &operandRescaleScale, + double &resultRescaleScale) { + double operandScale = operandQType.getScale(); + double resultScale = resultQType.getScale(); + + // rescale inputs to I32 with scale=1.0 + // perform I32 unary operation + // rescale result to scale = operandScale / resultScale + + operandRescaleScale = 1.0f; + resultRescaleScale = operandScale / resultScale; +} + template -LogicalResult matchAndRewriteAddSub(StablehloOp op, PatternRewriter& rewriter) { - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); +LogicalResult matchAndRewriteUnaryOp( + StablehloOp op, PatternRewriter &rewriter, + UnaryRescaleScalesFn rescaleScalesFn = GetUnaryRescaleScales) { + Value operand = op.getOperand(); Value result = op.getResult(); - auto lhsType = cast(lhs.getType()); - auto rhsType = cast(rhs.getType()); + auto operandType = cast(operand.getType()); auto resultType = cast(result.getType()); - auto lhsQType = - dyn_cast(lhsType.getElementType()); - auto rhsQType = - dyn_cast(rhsType.getElementType()); + auto operandQType = + dyn_cast(operandType.getElementType()); auto resultQType = dyn_cast(resultType.getElementType()); - if (!lhsQType || !rhsQType || !resultQType) { + if (!operandQType || !resultQType) { return rewriter.notifyMatchFailure( op, "The conversion supports operands/results with per-tensor quantized " "types only"); } - // Following quantization described in tflite - // In details it does: + double operandRescaleScale, resultRescaleScale; + + rescaleScalesFn(operandQType, resultQType, operandRescaleScale, + resultRescaleScale); + + auto loc = op.getLoc(); + + // Implement single rounding only + Value rescaledOperand = buildRescaleToInt32( + rewriter, loc, operand, operandRescaleScale, operandQType.getZeroPoint()); + + auto rescaledResultType = resultType.clone(rewriter.getI32Type()); + Value rescaledResult = + rewriter.create(loc, rescaledResultType, rescaledOperand) + .getResult(); + + Value newOutput = + buildRescaleFromInt32(rewriter, loc, resultType, rescaledResult, + resultRescaleScale, resultQType.getZeroPoint()); + + rewriter.replaceOp(op, {newOutput}); + return success(); +} + +LogicalResult matchAndRewriteOp(stablehlo::AbsOp op, + PatternRewriter &rewriter) { + return matchAndRewriteUnaryOp(op, rewriter); +} + +using BinaryRescaleScalesFn = void (*)( + const quant::UniformQuantizedType &lhsQType, + const quant::UniformQuantizedType &rhsQType, + const quant::UniformQuantizedType &resultQType, double &lhsRescaleScale, + double &rhsRescaleScale, double &resultRescaleScale); + +void GetAddSubRescaleScales(const quant::UniformQuantizedType &lhsQType, + const quant::UniformQuantizedType &rhsQType, + const quant::UniformQuantizedType &resultQType, + double &lhsRescaleScale, double &rhsRescaleScale, + double &resultRescaleScale) { // 1. Rescale inputs to scale = 2.0 x max(lhs.scale, rhs.scale) // 2. Extra left shift to input to increase precision // Where input_shift = 20 if input is 8-bit @@ -139,22 +195,100 @@ LogicalResult matchAndRewriteAddSub(StablehloOp op, PatternRewriter& rewriter) { ? SHIFT_16_BIT : SHIFT_8_BIT; - double lhsRescaleScale = lhsScale / maxScale2x; - double rhsRescaleScale = rhsScale / maxScale2x; - double resultRescaleScale = + lhsRescaleScale = + (lhsScale / maxScale2x) * static_cast(1 << inputShift); + rhsRescaleScale = + (rhsScale / maxScale2x) * static_cast(1 << inputShift); + resultRescaleScale = maxScale2x / (resultScale * static_cast(1 << inputShift)); +} + +void GetMulDivRescaleScales(const quant::UniformQuantizedType &lhsQType, + const quant::UniformQuantizedType &rhsQType, + const quant::UniformQuantizedType &resultQType, + double &lhsRescaleScale, double &rhsRescaleScale, + double &resultRescaleScale) { + double lhsScale = lhsQType.getScale(); + double rhsScale = rhsQType.getScale(); + double resultScale = resultQType.getScale(); + + // rescale inputs to I32 with scale=1.0 + // perform I32 multiply or divide + // rescale result to scale=(lhsScale * rhsScale) / resultScale + + lhsRescaleScale = 1.0f; + rhsRescaleScale = 1.0f; + resultRescaleScale = lhsScale * rhsScale / resultScale; +} + +void GetMinMaxRescaleScales(const quant::UniformQuantizedType &lhsQType, + const quant::UniformQuantizedType &rhsQType, + const quant::UniformQuantizedType &resultQType, + double &lhsRescaleScale, double &rhsRescaleScale, + double &resultRescaleScale) { + // 1. Rescale inputs to scale = max(lhs.scale, rhs.scale) + // 2. Extra left shift to input to increase precision + // Where input_shift = 20 if input is 8-bit + // input_shift = 15 if input is 16-bit + + double lhsScale = lhsQType.getScale(); + double rhsScale = rhsQType.getScale(); + double resultScale = resultQType.getScale(); + + double maxScale = std::max(lhsScale, rhsScale); + + const int32_t SHIFT_8_BIT = 20; + const int32_t SHIFT_16_BIT = 15; + + int32_t inputShift = (resultQType.getStorageTypeIntegralWidth() == 16) + ? SHIFT_16_BIT + : SHIFT_8_BIT; + + lhsRescaleScale = + (lhsScale / maxScale) * static_cast(1 << inputShift); + rhsRescaleScale = + (rhsScale / maxScale) * static_cast(1 << inputShift); + resultRescaleScale = + maxScale / (resultScale * static_cast(1 << inputShift)); +} + +template +LogicalResult matchAndRewriteBinaryOp(StablehloOp op, PatternRewriter &rewriter, + BinaryRescaleScalesFn rescaleScalesFn) { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + Value result = op.getResult(); + + auto lhsType = cast(lhs.getType()); + auto rhsType = cast(rhs.getType()); + auto resultType = cast(result.getType()); + + auto lhsQType = + dyn_cast(lhsType.getElementType()); + auto rhsQType = + dyn_cast(rhsType.getElementType()); + auto resultQType = + dyn_cast(resultType.getElementType()); + + if (!lhsQType || !rhsQType || !resultQType) { + return rewriter.notifyMatchFailure( + op, + "The conversion supports operands/results with per-tensor quantized " + "types only"); + } + + double lhsRescaleScale, rhsRescaleScale, resultRescaleScale; + + rescaleScalesFn(lhsQType, rhsQType, resultQType, lhsRescaleScale, + rhsRescaleScale, resultRescaleScale); auto loc = op.getLoc(); // Implement single rounding only - Value rescaledLhs = buildRescaleToInt32( - rewriter, loc, lhs, - /*scale=*/lhsRescaleScale * static_cast(1 << inputShift), - lhsQType.getZeroPoint()); - Value rescaledRhs = buildRescaleToInt32( - rewriter, loc, rhs, - /*scale=*/rhsRescaleScale * static_cast(1 << inputShift), - rhsQType.getZeroPoint()); + Value rescaledLhs = buildRescaleToInt32(rewriter, loc, lhs, lhsRescaleScale, + lhsQType.getZeroPoint()); + Value rescaledRhs = buildRescaleToInt32(rewriter, loc, rhs, rhsRescaleScale, + rhsQType.getZeroPoint()); auto rescaledResultType = resultType.clone(rewriter.getI32Type()); Value rescaledResult = rewriter @@ -170,20 +304,115 @@ LogicalResult matchAndRewriteAddSub(StablehloOp op, PatternRewriter& rewriter) { return success(); } +LogicalResult matchAndRewriteOp(stablehlo::AddOp op, + PatternRewriter &rewriter) { + return matchAndRewriteBinaryOp(op, rewriter, GetAddSubRescaleScales); +} + +LogicalResult matchAndRewriteOp(stablehlo::SubtractOp op, + PatternRewriter &rewriter) { + return matchAndRewriteBinaryOp(op, rewriter, GetAddSubRescaleScales); +} + +LogicalResult matchAndRewriteOp(stablehlo::MulOp op, + PatternRewriter &rewriter) { + return matchAndRewriteBinaryOp(op, rewriter, GetMulDivRescaleScales); +} + +LogicalResult matchAndRewriteOp(stablehlo::DivOp op, + PatternRewriter &rewriter) { + return matchAndRewriteBinaryOp(op, rewriter, GetMulDivRescaleScales); +} + +LogicalResult matchAndRewriteOp(stablehlo::MinOp op, + PatternRewriter &rewriter) { + return matchAndRewriteBinaryOp(op, rewriter, GetMinMaxRescaleScales); +} + +LogicalResult matchAndRewriteOp(stablehlo::MaxOp op, + PatternRewriter &rewriter) { + return matchAndRewriteBinaryOp(op, rewriter, GetMinMaxRescaleScales); +} + +LogicalResult matchAndRewriteCompareOp(stablehlo::CompareOp op, + PatternRewriter &rewriter) { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + Value result = op.getResult(); + + auto lhsType = cast(lhs.getType()); + auto rhsType = cast(rhs.getType()); + auto resultType = cast(result.getType()); + + auto lhsQType = + dyn_cast(lhsType.getElementType()); + auto rhsQType = + dyn_cast(rhsType.getElementType()); + + if (!lhsQType || !rhsQType) { + return rewriter.notifyMatchFailure( + op, + "The conversion supports operands with per-tensor quantized " + "types only"); + } + + double lhsScale = lhsQType.getScale(); + double rhsScale = rhsQType.getScale(); + double maxScale = std::max(lhsScale, rhsScale); + + const int32_t SHIFT_8_BIT = 20; + const int32_t SHIFT_16_BIT = 15; + + // note: compare op require lhs/rhs have equal base storage width + int32_t inputShift = (lhsQType.getStorageTypeIntegralWidth() == 16) + ? SHIFT_16_BIT + : SHIFT_8_BIT; + + double lhsRescaleScale = + (lhsScale / maxScale) * static_cast(1 << inputShift); + double rhsRescaleScale = + (rhsScale / maxScale) * static_cast(1 << inputShift); + + auto loc = op.getLoc(); + + // Implement single rounding only + Value rescaledLhs = buildRescaleToInt32(rewriter, loc, lhs, lhsRescaleScale, + lhsQType.getZeroPoint()); + Value rescaledRhs = buildRescaleToInt32(rewriter, loc, rhs, rhsRescaleScale, + rhsQType.getZeroPoint()); + + auto compareDirection = op.getComparisonDirection(); + auto compareTypeAttr = op.getCompareTypeAttr(); + + Value newOutput = rewriter + .create( + loc, resultType, rescaledLhs, rescaledRhs, + compareDirection, compareTypeAttr) + .getResult(); + + rewriter.replaceOp(op, {newOutput}); + return success(); +} + +LogicalResult matchAndRewriteOp(stablehlo::CompareOp op, + PatternRewriter &rewriter) { + return matchAndRewriteCompareOp(op, rewriter); +} + template struct QuantizedStablehloOpConversion : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(StablehloOpType op, - PatternRewriter& rewriter) const override { - return matchAndRewriteAddSub(op, rewriter); + PatternRewriter &rewriter) const override { + return matchAndRewriteOp(op, rewriter); } }; struct StablehloQuantLegalizeToTosaRescalePass : impl::StablehloQuantLegalizeToTosaRescalePassBase< StablehloQuantLegalizeToTosaRescalePass> { - LogicalResult initialize(MLIRContext* ctx) override { + LogicalResult initialize(MLIRContext *ctx) override { RewritePatternSet patternList(ctx); populateStablehloQuantLegalizeToTosaRescalePatterns(&patternList, ctx); patterns = std::move(patternList); @@ -205,11 +434,25 @@ struct StablehloQuantLegalizeToTosaRescalePass } // namespace void populateStablehloQuantLegalizeToTosaRescalePatterns( - RewritePatternSet* patterns, MLIRContext* context) { + RewritePatternSet *patterns, MLIRContext *context) { + // unary ops + patterns->addWithLabel>( + {"StablehloQuantAbsOp"}, context); + // binary ops patterns->addWithLabel>( {"StablehloQuantAddOp"}, context); patterns->addWithLabel>( {"StablehloQuantSubtractOp"}, context); + patterns->addWithLabel>( + {"StablehloQuantMulOp"}, context); + patterns->addWithLabel>( + {"StablehloQuantDivOp"}, context); + patterns->addWithLabel>( + {"StablehloQuantMaxOp"}, context); + patterns->addWithLabel>( + {"StablehloQuantMinOp"}, context); + patterns->addWithLabel>( + {"StablehloQuantCompareOp"}, context); } } // namespace tosa