diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 515060595f0e7..19c2cf7b1b9e8 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -561,7 +561,7 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
  * \param s integer shift
  * \return The constructed expression.
  */
-TVM_DLL PrimExpr fixed_point_multiply(PrimExpr x, PrimExpr m, PrimExpr s);
+TVM_DLL PrimExpr fixed_point_multiply(PrimExpr x, PrimExpr y, int32_t n);
 
 // Intrinsic operators
 #define TVM_DECLARE_INTRIN_UNARY(OpName) \
diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index 1db4e08bed259..f272bafa06b0d 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -20,8 +20,9 @@
 from tvm.te.hybrid import script
 import topi
 
+from . import strategy
 from .op import register_compute, register_shape_func
-from .op import register_broadcast_schedule, register_injective_schedule
+from .op import register_broadcast_schedule, register_injective_schedule, register_strategy
 from .op import register_pattern, OpPattern
@@ -131,13 +132,9 @@ def clip_compute(attrs, inputs, output_type):
 register_injective_schedule("clip")
 
 
-# fixed point multiply
-@register_compute("fixed_point_multiply")
-def fixed_point_multiply_compute(attrs, inputs, output_type):
-    assert len(inputs) == 1
-    return [topi.fixed_point_multiply(inputs[0], attrs.multiplier, attrs.shift)]
-register_injective_schedule("fixed_point_multiply")
+# fixed point multiply
+register_strategy("fixed_point_multiply", strategy.fixed_point_multiply_strategy)
 
 
 # full
 @script
diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index d682aad63bec8..027b72e5be0a0 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -37,7 +37,7 @@ def schedule_concatenate_arm_cpu(_, outs, target):
     """schedule concatenate for arm cpu"""
     with target:
         return topi.arm_cpu.schedule_concatenate(outs)
-    
+
 @conv2d_strategy.register(["arm_cpu", "micro_dev"])
 def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
     """conv2d arm cpu strategy"""
@@ -332,3 +332,13 @@ def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target):
         wrap_topi_schedule(topi.arm_cpu.schedule_bitserial_dense),
         name="bitserial_dense.arm_cpu")
     return strategy
+
+@fixed_point_multiply_strategy.register("arm_cpu")
+def schedule_fixed_point_multiply(attrs, inputs, out_type, target):
+    """fixed_point_multiply arm cpu strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_fixed_point_multiply(topi.arm_cpu.fixed_point_multiply),
+        wrap_topi_schedule(topi.arm_cpu.schedule_fixed_point_multiply),
+        name="fixed_point_multiply.arm_cpu")
+    return strategy
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 632445b71bd2f..93648784578ca 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -124,6 +124,23 @@ def softmax_strategy(attrs, inputs, out_type, target):
         name="softmax.generic")
     return strategy
 
+# fixed_point_multiply
+def wrap_compute_fixed_point_multiply(topi_compute):
+    """Wrap fixed_point_multiply topi compute"""
+    def _compute_fixed_point_multiply(attrs, inputs, out_type):
+        return [topi_compute(inputs[0], attrs.multiplier, attrs.shift)]
+    return _compute_fixed_point_multiply
+
+@override_native_generic_func("fixed_point_multiply_strategy")
+def fixed_point_multiply_strategy(attrs, inputs, out_type, target):
+    """fixed_point_multiply generic strategy"""
strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_fixed_point_multiply(topi.math.fixed_point_multiply), + wrap_topi_schedule(topi.generic.schedule_injective), + name="fixed_point_multiply.generic") + return strategy + # log_softmax @generic_func def schedule_log_softmax(attrs, outs, target): diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index fa2eef3f8f2a1..d6f580191d27c 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -123,42 +123,22 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.fixed_point_multiply") const tir::CallNode* call = e.as(); CHECK(call != nullptr); - PrimExpr tensor = call->args[0]; - PrimExpr fixed_point_multiplier = call->args[1]; - PrimExpr shift = call->args[2]; + PrimExpr x = call->args[0]; + PrimExpr y = call->args[1]; + PrimExpr n = call->args[2]; // Only int32 types are supported (any number of lanes is allowed) - CHECK(tensor.dtype().code() == DLDataTypeCode::kDLInt && tensor.dtype().bits() == 32); - CHECK(fixed_point_multiplier.dtype().code() == DLDataTypeCode::kDLInt && - fixed_point_multiplier.dtype().bits() == 32); - CHECK(shift.dtype().code() == DLDataTypeCode::kDLInt && shift.dtype().bits() == 32); + CHECK(x.dtype().code() == DLDataTypeCode::kDLInt && x.dtype().bits() == 32); + CHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); - DataType hp_dtype = DataType::Int(64, tensor.dtype().lanes()); - DataType lp_dtype = DataType::Int(32, tensor.dtype().lanes()); + DataType hp_dtype = DataType::Int(64, x.dtype().lanes()); + DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); + PrimExpr K = (make_const(hp_dtype, 1) << (n - 1)); - // 1) Calculating the integer multiplier and integer shift - PrimExpr zero = make_const(shift.dtype(), 0); - PrimExpr left_shift = tir::Select((shift > zero), shift, zero); - PrimExpr right_shift = tir::Select(shift > zero, zero, -shift); - - // 2) Multiply the integer multiplier - tensor = tir::Select(left_shift != zero, tensor << cast(hp_dtype, left_shift), - cast(hp_dtype, tensor)); - - // 3) Perform the multiplication in higher precision. - tensor = tensor * fixed_point_multiplier; - - // 4) Find the rounding scalar - PrimExpr total_right_shift = right_shift + 31; - PrimExpr pos_rounding_value = (make_const(hp_dtype, 1) << (total_right_shift - 1)); - - tensor = tensor + pos_rounding_value; - - // 5) Simply right shift the result to get the final output. - tensor = tensor >> total_right_shift; - - // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32. - *rv = cast(lp_dtype, tensor); + x = cast(hp_dtype, x) * cast(hp_dtype, y); + x = x + K; + x = x >> n; + *rv = cast(lp_dtype, x); }); } // namespace intrin diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index aec41001bbf28..7c8b0a7afb5f0 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -91,8 +91,8 @@ PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { } // fixed_point_multiply -PrimExpr fixed_point_multiply(PrimExpr x, PrimExpr m, PrimExpr s) { - return tir::Call(x.dtype(), tir::builtin::fixed_point_multiply(), {x, m, s}); +PrimExpr fixed_point_multiply(PrimExpr x, PrimExpr y, int32_t n) { + return tir::Call(x.dtype(), tir::builtin::fixed_point_multiply(), {x, y, make_const(DataType::UInt(32), n)}); } // The public function with a quick checking path. 
diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py
index fb52b30305825..d671a1bf88d1a 100644
--- a/tests/python/relay/test_op_qnn_requantize.py
+++ b/tests/python/relay/test_op_qnn_requantize.py
@@ -90,7 +90,9 @@ def test_downscale():
     # Try positive values
     # 8 corresponds to 0.5, resulting in 1
     golden_data = np.arange(0, 32, 1).astype('int32')
+    print(golden_data)
     golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+    print(golden_output)
     verify(mod, (golden_data, golden_output))
 
     # Try negative values
diff --git a/topi/python/topi/arm_cpu/__init__.py b/topi/python/topi/arm_cpu/__init__.py
index e121fbc7ec6d0..4eb8f0ff7c637 100644
--- a/topi/python/topi/arm_cpu/__init__.py
+++ b/topi/python/topi/arm_cpu/__init__.py
@@ -26,3 +26,4 @@
 from .bitserial_dense import *
 from .injective import *
 from . import cortex_m7
+from .math import fixed_point_multiply
diff --git a/topi/python/topi/arm_cpu/injective.py b/topi/python/topi/arm_cpu/injective.py
index 54306deadd85f..2d4817c859530 100644
--- a/topi/python/topi/arm_cpu/injective.py
+++ b/topi/python/topi/arm_cpu/injective.py
@@ -76,6 +76,37 @@ def schedule_injective(outs):
         schedule_injective_from_existing(s, x)
     return s
 
+def schedule_fixed_point_multiply(outs):
+    """ARM CPU schedule for the fixed point multiply op.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of fixed point multiply in the format
+        of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    x = outs[0]
+    ins = x.op.input_tensors
+    dtype = ins[0].dtype if len(ins) > 0 else x.dtype
+    max_vlen = 4 if dtype == 'int32' else 8
+
+    if list(s[x].op.axis):
+        # do not vectorize for broadcast
+        (io, ii) = s[x].split(list(s[x].op.axis)[-1], max_vlen)
+        s[x].vectorize(ii)
+    tvm.te.schedule.AutoInlineInjective(s)
+
+    if not is_empty_shape(x.shape):
+        schedule_injective_from_existing(s, x)
+    return s
+
 
 def schedule_concatenate(outs):
     """Schedule for concatenate op.
diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py
index fa004820d6c31..00caac6559eaf 100644
--- a/topi/python/topi/arm_cpu/tensor_intrin.py
+++ b/topi/python/topi/arm_cpu/tensor_intrin.py
@@ -20,6 +20,7 @@
 import tvm
 from tvm import te
 from tvm.contrib import util, clang
+from tvm.ir import Array, Op
 
 def gemv_quantized_impl(M, N, data_type='uint8'):
     """ Assembly implementation of a blocked gemv. Given
@@ -452,7 +453,7 @@ def _instr(index):
         C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
         default_buffer_params=buffer_params)
 
-def _fixed_point_multiply_arm(op):
+def fixed_point_multiply_arm_rule(op):
     """
     Implementation of fixed point multiplication through arm intrinsics
     sqrdmulh and srshl
@@ -460,10 +461,14 @@ def _fixed_point_multiply_arm(op):
     x = op.args[0]
     multiplier = op.args[1]
     shift = op.args[2]
-
-    # Don't use this intrinsic if we don't have a int32x4 vector
+
+    # Don't use AArch64 intrinsics if we don't have an int32x4 vector
     if x.dtype != "int32x4":
-        return op
+        left_shift = tvm.tir.if_then_else(shift > 0, shift, 0)
+        right_shift = tvm.tir.if_then_else(shift > 0, 0, -shift)
+        x = tvm.tir.if_then_else(left_shift > 0, x << left_shift, x)
+        mulq = tvm.tir.fixed_point_multiply(x, multiplier, 31)
+        return mulq >> right_shift
 
     # Case 1, shift is negative
     sqrdmulh = tvm.tir.call_llvm_intrin(op.dtype,
@@ -477,7 +482,7 @@ def _fixed_point_multiply_arm(op):
     out_1 = tvm.tir.call_llvm_intrin(op.dtype,
                                      'llvm.aarch64.neon.srshl',
                                      tvm.tir.const(2, 'uint32'),
-                                     sqrdmulh,
+                                     fixed_up_x,
                                      shift)
 
     # Case 2, shift is positive
@@ -491,6 +496,16 @@ def _fixed_point_multiply_arm(op):
     # Select depending on the shift
     return tvm.tir.Select(shift < 0, out_1, out_2)
 
+
+def fixed_point_multiply_arm(x, m, s):
+    """customized fixed point multiply intrinsic function for AArch64"""
+    return tvm.tir.call_intrin(x.dtype, "tir.fixed_point_multiply_arm", x, m, s)
+
+
 tvm.target.intrin.register_intrin_rule("llvm.aarch64",
-                                       "fixed_point_multiply",
-                                       _fixed_point_multiply_arm, override=True)
+                                       "fixed_point_multiply_arm",
+                                       fixed_point_multiply_arm_rule, override=False)
+
+tvm.ir.register_op_attr("tir.fixed_point_multiply_arm", "TCallEffectKind", tvm.tir.CallEffectKind.Pure)
+tvm.ir.register_op_attr("tir.fixed_point_multiply_arm", "TVectorizable", True)
+
diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py
index 475f8074fb353..c6aa9e986be2e 100644
--- a/topi/python/topi/math.py
+++ b/topi/python/topi/math.py
@@ -612,7 +612,6 @@ def _compute(*indices):
         return tvm.te.max(tvm.te.min(value, const_max), const_min)
     return te.compute(x.shape, _compute)
 
 
-@tvm.te.tag_scope(tag=tag.ELEMWISE)
 def fixed_point_multiply(x, multiplier, shift):
     """
@@ -628,15 +627,20 @@
     y : tvm.te.Tensor
        The result.
    """
+    left_shift = shift if shift > 0 else 0
+    right_shift = 0 if shift > 0 else -shift
+
    def _compute(*indices):
-        value = x(*indices)
-        m = tvm.tir.const(multiplier, x.dtype)
-        s = tvm.tir.const(shift, x.dtype)
-        return tvm.tir.fixed_point_multiply(value, m, s)
+        val = x(*indices)
+        val = tvm.tir.if_then_else(left_shift > 0, val << left_shift, val)
+        mulq = tvm.tir.fixed_point_multiply(val, multiplier, 31)
+        nudge = (1 << (right_shift - 1)) if right_shift > 0 else 0
+        return (mulq + nudge) >> right_shift
 
     assert x.dtype == "int32", "input tensor type needs to be int32"
     return te.compute(x.shape, _compute)
 
 
+
 def cast(x, dtype):
     """Cast input to specified data type.
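Reading note (not part of the patch): topi.fixed_point_multiply now splits the former single intrinsic call into an optional left shift, a Q31 multiply through the reworked intrinsic with n = 31, and a final rounding right shift. Below is a scalar Python sketch of that decomposition; the helper names and the (multiplier, shift) pair, which encodes an assumed scale of 1/16 as 0.5 * 2^-3, are illustrative and not taken from the patch or its tests.

def q31_multiply(x, y, n=31):
    """Scalar model of tir.fixed_point_multiply(x, y, n) after this patch."""
    return (x * y + (1 << (n - 1))) >> n

def fixed_point_multiply_model(x, multiplier, shift):
    """Scalar model of the decomposition used by topi.fixed_point_multiply."""
    left_shift = shift if shift > 0 else 0
    right_shift = 0 if shift > 0 else -shift
    val = x << left_shift                       # optional left shift
    mulq = q31_multiply(val, multiplier, 31)    # Q31 multiply with rounding
    nudge = (1 << (right_shift - 1)) if right_shift > 0 else 0
    return (mulq + nudge) >> right_shift        # rounding right shift

multiplier, shift = 1 << 30, -3                 # assumed example: scale of 1/16
print(fixed_point_multiply_model(8, multiplier, shift))   # 8 / 16 = 0.5  -> 1
print(fixed_point_multiply_model(24, multiplier, shift))  # 24 / 16 = 1.5 -> 2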