fix tflite rounding
yunjing.lh committed May 17, 2020
1 parent 63f84a1 commit 960cc5c
Showing 20 changed files with 673 additions and 227 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -95,6 +95,7 @@ if(MSVC)
add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj")
if(USE_MSVC_MT)
foreach(flag_var
28 changes: 24 additions & 4 deletions include/tvm/relay/qnn/attrs.h
@@ -32,6 +32,25 @@ namespace tvm {
namespace relay {
namespace qnn {

/*! \brief Attribute for qnn add operator */
struct QnnAddAttrs : public tvm::AttrsNode<QnnAddAttrs> {
std::string rounding;

TVM_DECLARE_ATTRS(QnnAddAttrs, "relay.attrs.QnnAddAttrs") {
TVM_ATTR_FIELD(rounding).set_default("UPWARD").describe(
"Defines the rounding direction when the value is midway between"
"two representable values. There are two 3 modes - UPWARD, TONEAREST"
"or TFLITE. UP/TONEAREST modes behave exactly same except at the"
"midpoints between the two representable values. At the midpoint,"
"UPWARD rounds towards positive infinity (for example -1.5 will be"
"rounded to -1). TONEAREST is the standard rounding where the"
"value is rounded away from zero at midpoints (for example, -1.5"
"rounds to -2). More context can be found at following glibc manual"
"https://www.gnu.org/software/libc/manual/html_node/Rounding.html."
"TFLITE mode is more complicated, referring to tflite implementation.");
}
};

/*! \brief Attribute for requantize operator */
struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
int axis;
@@ -46,14 +65,15 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
.set_default(-1);
TVM_ATTR_FIELD(rounding).set_default("UPWARD").describe(
"Defines the rounding direction when the value is midway between"
"two representable values. There are two supported modes - UPWARD"
"or TONEAREST. Both modes behave exactly same except at the"
"two representable values. There are two 3 modes - UPWARD, TONEAREST"
"or TFLITE. UP/TONEAREST modes behave exactly same except at the"
"midpoints between the two representable values. At the midpoint,"
"UPWARD rounds towards positive infinity (for example -1.5 will be"
"rounded to -1). TONEAREST is the standard rounding where the"
"value is rounded away from zero at midpoints (for example, -1.5"
"rounds to -2). More context can be found at following gblic manual"
"https://www.gnu.org/software/libc/manual/html_node/Rounding.html.");
"rounds to -2). More context can be found at following glibc manual"
"https://www.gnu.org/software/libc/manual/html_node/Rounding.html."
"TFLITE mode is more complicated, referring to tflite implementation.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
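
To make the midpoint behavior described above concrete, here is a minimal Python sketch of the two basic modes; the helper functions are illustrative only and are not part of TVM:

import math

def round_upward(x):
    # Ties go toward positive infinity: -1.5 -> -1, 1.5 -> 2.
    return math.floor(x + 0.5)

def round_tonearest(x):
    # Ties go away from zero: -1.5 -> -2, 1.5 -> 2.
    return math.copysign(math.floor(abs(x) + 0.5), x)

for v in (-2.5, -1.5, -0.5, 0.5, 1.5):
    print(v, round_upward(v), round_tonearest(v))
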
53 changes: 40 additions & 13 deletions python/tvm/relay/frontend/tflite.py
@@ -45,7 +45,7 @@ def __init__(self, tensor_idx, tensor, buffer, qnn_params=None):

class OperatorConverter(object):
"""Operator Converted for converting TFLite ops to Relay ops"""
def __init__(self, model, subgraph, exp_tab):
def __init__(self, model, subgraph, exp_tab, rounding):

try:
from tflite.BuiltinOperator import BuiltinOperator
@@ -60,6 +60,7 @@ def __init__(self, model, subgraph, exp_tab):
self.builtin_op_code = build_str_map(BuiltinOperator())
self.activation_fn_type = build_str_map(ActivationFunctionType())
self.builtin_options = build_str_map(BuiltinOptions())
self.rounding = rounding

# Add more operators
self.convert_map = {
@@ -643,6 +644,9 @@ def _hard_swish(data):

return out

# TODO: in quantized mode, the concat op implicitly invokes requantize in the
# C++ implementation; TFLITE-mode rounding must be selected in order to get
# bit-exact execution.
def convert_concatenation(self, op):
"""Convert TFLite concatenation"""
try:
@@ -854,14 +858,26 @@ def _convert_elemwise(self, relay_op, op):
if lhs_tensor.qnn_params:
assert rhs_tensor.qnn_params, "Both tensors should be quantized."
assert output_tensor.qnn_params, "Output tensor should be quantized."
out = relay_op(lhs=lhs_expr,
rhs=rhs_expr,
lhs_scale=lhs_tensor.qnn_params['scale'],
lhs_zero_point=lhs_tensor.qnn_params['zero_point'],
rhs_scale=rhs_tensor.qnn_params['scale'],
rhs_zero_point=rhs_tensor.qnn_params['zero_point'],
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'])
has_tflite_rounding_mode = [_qnn.op.add]
if relay_op in has_tflite_rounding_mode:
out = relay_op(lhs=lhs_expr,
rhs=rhs_expr,
lhs_scale=lhs_tensor.qnn_params['scale'],
lhs_zero_point=lhs_tensor.qnn_params['zero_point'],
rhs_scale=rhs_tensor.qnn_params['scale'],
rhs_zero_point=rhs_tensor.qnn_params['zero_point'],
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
rounding=self.rounding)
else:
out = relay_op(lhs=lhs_expr,
rhs=rhs_expr,
lhs_scale=lhs_tensor.qnn_params['scale'],
lhs_zero_point=lhs_tensor.qnn_params['zero_point'],
rhs_scale=rhs_tensor.qnn_params['scale'],
rhs_zero_point=rhs_tensor.qnn_params['zero_point'],
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'])
else:
out = relay_op(lhs_expr, rhs_expr)

@@ -924,6 +940,9 @@ def convert_sub(self, op):
return self._convert_elemwise(_qnn.op.subtract, op)
return self._convert_elemwise(_op.subtract, op)

# TODO: in quantized mode, the mul op implicitly invokes requantize in the
# C++ implementation; TFLITE-mode rounding must be selected in order to get
# bit-exact execution.
def convert_mul(self, op):
"""Convert TFLite MUL"""
# Check if the input tensor is quantized, call QNN op
@@ -1327,6 +1346,7 @@ def _convert_reduce(self, relay_op, op):
input_zero_point=input_tensor.qnn_params['zero_point'],
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
rounding=self.rounding,
out_dtype=output_tensor_type_str)

return out
@@ -1452,6 +1472,7 @@ def convert_fully_connected(self, op):
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
rounding=self.rounding,
out_dtype=output_tensor_type_str)

# Call activation function
@@ -1667,6 +1688,7 @@ def convert_conv(self, op, conv_type):
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
rounding=self.rounding,
out_dtype=output_tensor_type_str)

# Call activation function
@@ -2454,8 +2476,10 @@ def get_scalar_from_constant(expr):
assert isinstance(expr, _expr.Constant) and not expr.data.shape, \
"Expr is not a constant scalar."
value = expr.data.asnumpy()
assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
"value must be float32/int32"
assert value.dtype == np.dtype(np.int32) or \
value.dtype == np.dtype(np.float32) or \
value.dtype == np.dtype(np.float64), \
"value must be float32/float64/int32"
return np.asscalar(value)


@@ -2524,7 +2548,7 @@ def get_tensor_name(subgraph, tensor_idx):
return subgraph.Tensors(tensor_idx).Name().decode("utf-8")


def from_tflite(model, shape_dict, dtype_dict):
def from_tflite(model, shape_dict, dtype_dict, rounding='TFLITE'):
"""Convert from tflite model into compatible relay Function.
Parameters
@@ -2538,6 +2562,9 @@ def from_tflite(model, shape_dict, dtype_dict):
dtype_dict : dict of str to str
Input types of the model.
rounding : str, optional
Rounding mode used by quantized operators; one of 'UPWARD', 'TONEAREST'
or 'TFLITE'. Defaults to 'TFLITE' to match the TFLite runtime bit-exactly.
Returns
-------
mod : tvm.IRModule
@@ -2576,7 +2603,7 @@
exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype))

# op code in model
op_converter = OperatorConverter(model, subgraph, exp_tab)
op_converter = OperatorConverter(model, subgraph, exp_tab, rounding)
op_converter.check_unsupported_ops()
op_converter.convert_op_to_relay()

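
A usage sketch for the new parameter follows; the model path, input name and shape are hypothetical:

import tflite
from tvm import relay

with open("model.tflite", "rb") as f:
    tflite_model = tflite.Model.GetRootAsModel(f.read(), 0)

mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict={"input": (1, 224, 224, 3)},
    dtype_dict={"input": "uint8"},
    rounding="TFLITE")  # the default; pass "UPWARD" for the previous behavior
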
9 changes: 7 additions & 2 deletions python/tvm/relay/qnn/op/qnn.py
@@ -300,7 +300,8 @@ def add(lhs,
rhs_scale,
rhs_zero_point,
output_scale,
output_zero_point):
output_zero_point,
rounding="UPWARD"):
"""Quantized addition with numpy-style broadcasting.
Parameters
@@ -329,6 +330,9 @@ def add(lhs,
output_zero_point: relay.Expr
The zero point of output quantized expr.
rounding: str, optional
Rounding mode used by the internal requantize; one of 'UPWARD',
'TONEAREST' or 'TFLITE'. Defaults to 'UPWARD'.
Returns
-------
result : relay.Expr
@@ -338,7 +342,8 @@ def add(lhs,
return _make.add(lhs, rhs,
lhs_scale, lhs_zero_point,
rhs_scale, rhs_zero_point,
output_scale, output_zero_point)
output_scale, output_zero_point,
rounding)


def dense(data,
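
A short sketch of calling the op with the new argument; the quantization parameters are made-up illustrative values:

from tvm import relay

lhs = relay.var("lhs", shape=(1, 8), dtype="uint8")
rhs = relay.var("rhs", shape=(1, 8), dtype="uint8")
out = relay.qnn.op.add(
    lhs, rhs,
    lhs_scale=relay.const(0.25, "float32"),
    lhs_zero_point=relay.const(128, "int32"),
    rhs_scale=relay.const(0.5, "float32"),
    rhs_zero_point=relay.const(128, "int32"),
    output_scale=relay.const(0.5, "float32"),
    output_zero_point=relay.const(128, "int32"),
    rounding="TFLITE")
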
2 changes: 1 addition & 1 deletion python/tvm/relay/quantize/quantize.py
@@ -172,7 +172,7 @@ def qconfig(**kwargs):
is None, which means it will try to call all operators' annotate rewrite
functions.
rounding: "UPWARD" or "TONEAREST"
rounding: "UPWARD" or "TONEAREST" or "TFLITE"
Rounding direction for fixed point multiplications.
Returns
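
For the automatic quantization path, the mode is selected through qconfig; a minimal sketch, assuming mod and params come from a frontend such as from_tflite:

from tvm import relay

with relay.quantize.qconfig(rounding="TFLITE"):
    qmod = relay.quantize.quantize(mod, params)
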
66 changes: 64 additions & 2 deletions src/relay/qnn/op/add.cc
@@ -30,9 +30,11 @@ namespace tvm {
namespace relay {
namespace qnn {

TVM_REGISTER_NODE_TYPE(QnnAddAttrs);

/*
* \brief Canonicalizes the QNN add op.
* \param attrs The empty attribute.
* \param attrs The QNN add attrs.
* \param new_args The new mutated args to the call node.
* \param arg_types The types of input and output.
* \return The sequence of Relay ops for add op.
@@ -42,9 +44,55 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
// Get the args.
QnnBinaryOpArguments args(new_args);

// Get the attrs.
const QnnAddAttrs* add_attrs = attrs.as<QnnAddAttrs>();
CHECK(add_attrs != nullptr);
auto& rounding = add_attrs->rounding;

// Get the input dtype and shape.
QnnBinaryOpTensorType input_type(arg_types, 0);

if (rounding == "TFLITE") {
double lhs_scale_val = GetScalarFromConstant<float>(args.lhs_scale);
double rhs_scale_val = GetScalarFromConstant<float>(args.rhs_scale);
double out_scale_val = GetScalarFromConstant<float>(args.output_scale);
double twice_max_input_scale = 2 * std::max(lhs_scale_val, rhs_scale_val);
double real_lhs_scale_val = lhs_scale_val / twice_max_input_scale;
double real_rhs_scale_val = rhs_scale_val / twice_max_input_scale;
double real_out_scale_val = twice_max_input_scale / ((1 << 20) * out_scale_val);

auto real_lhs_scale = MakeConstantScalar<double>(DataType::Float(64), real_lhs_scale_val);
auto real_rhs_scale = MakeConstantScalar<double>(DataType::Float(64), real_rhs_scale_val);
auto real_out_scale = MakeConstantScalar<double>(DataType::Float(64), real_out_scale_val);
auto one_scalar = MakeConstantScalar<double>(DataType::Float(64), 1);
auto zero_scalar = MakeConstantScalar<int>(DataType::Int(32), 0);
auto left_shift_scalar = MakeConstantScalar<int>(DataType::Int(32), 1 << 20);

Expr adapted_lhs = Cast(args.lhs, DataType::Int(32));
if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
adapted_lhs = Subtract(adapted_lhs, Cast(args.lhs_zero_point, DataType::Int(32)));
}
adapted_lhs = Multiply(adapted_lhs, left_shift_scalar);

Expr adapted_rhs = Cast(args.rhs, DataType::Int(32));
if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
adapted_rhs = Subtract(adapted_rhs, Cast(args.rhs_zero_point, DataType::Int(32)));
}
adapted_rhs = Multiply(adapted_rhs, left_shift_scalar);

auto requantized_lhs = Requantize(adapted_lhs, input_type.shape, real_lhs_scale, zero_scalar,
one_scalar, zero_scalar, DataType::Int(32), rounding);

auto requantized_rhs = Requantize(adapted_rhs, input_type.shape, real_rhs_scale, zero_scalar,
one_scalar, zero_scalar, DataType::Int(32), rounding);

auto output = Add(requantized_lhs, requantized_rhs);
output = Requantize(output, input_type.shape, real_out_scale, zero_scalar, one_scalar,
args.output_zero_point, DataType::Int(32), rounding);
// Go back to lower precision.
return ConvertDtype(output, input_type.dtype);
}

// FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize in
// the start, we can insert requantize at the end if both input tensors have same qnn params. In
// that case, we can first add the tensors, subtract the zero point, and requantize at the end.
@@ -86,9 +134,23 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
return ConvertDtype(output, input_type.dtype);
}

Expr MakeQnnAdd(Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale,
Expr rhs_zero_point, Expr output_scale, Expr output_zero_point,
std::string rounding) {
auto attrs = make_object<QnnAddAttrs>();
attrs->rounding = std::move(rounding);

static const Op& op = Op::Get("qnn.add");
return Call(op,
{lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale,
output_zero_point},
Attrs(attrs), {});
}

// QNN Addition operator.
QNN_REGISTER_BINARY_OP("add")
QNN_REGISTER_BINARY_OP_WITH_BODY("add", MakeQnnAdd)
.describe("Elementwise add with broadcasting for quantized tensors.")
.set_attrs_type<QnnAddAttrs>()
.set_support_level(11)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize);

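
To sanity-check the TFLITE branch in QnnAddCanonicalize above, here is a numeric trace of the same fixed-point recipe in plain Python. The scales and inputs are made-up, Python's round stands in for the selected requantize rounding, and saturation is ignored:

lhs_scale, rhs_scale, out_scale = 0.25, 0.5, 0.5
lhs_q, rhs_q = 10, 20       # int32 values with zero points already subtracted
SHIFT = 1 << 20             # the left shift that preserves precision

twice_max = 2 * max(lhs_scale, rhs_scale)         # 1.0
real_lhs_scale = lhs_scale / twice_max            # 0.25
real_rhs_scale = rhs_scale / twice_max            # 0.5
real_out_scale = twice_max / (SHIFT * out_scale)

acc = round(lhs_q * SHIFT * real_lhs_scale) + round(rhs_q * SHIFT * real_rhs_scale)
out_q = round(acc * real_out_scale)               # TVM then adds the output zero point
print(out_q)  # 25, i.e. (10 * 0.25 + 20 * 0.5) / 0.5
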
13 changes: 10 additions & 3 deletions src/relay/qnn/op/concatenate.cc
@@ -48,7 +48,10 @@ bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& at
<< PrettyPrint(types[1]));
}
for (const auto& input_scale : input_scales_tuple->fields) {
CHECK(IsScalarType(input_scale, DataType::Float(32))); // input_scales[idx]
const auto* input_scale_type = input_scale.as<TensorTypeNode>();
CHECK(input_scale_type && IsScalarType(input_scale, input_scale_type->dtype) &&
(input_scale_type->dtype == DataType::Float(32) ||
input_scale_type->dtype == DataType::Float(64))); // input_scale[idx]
}

const auto* input_zero_points_tuple = types[2].as<TupleTypeNode>();
@@ -61,8 +64,12 @@ bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& at
CHECK(IsScalarType(input_zero_point, DataType::Int(32))); // input_zero_points[idx]
}

CHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale
CHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point
const auto* output_scale_type = types[3].as<TensorTypeNode>();
CHECK(output_scale_type && IsScalarType(types[3], output_scale_type->dtype) &&
(output_scale_type->dtype == DataType::Float(32) ||
output_scale_type->dtype == DataType::Float(64))); // output_scale

CHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point

// Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
// Concatenate infer type function.
13 changes: 9 additions & 4 deletions src/relay/qnn/op/convolution.cc
@@ -57,13 +57,18 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

// Check the types of scale and zero points.
CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
CHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point
CHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale
CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
CHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point
const auto* input_scale_type = types[4].as<TensorTypeNode>();
CHECK(input_scale_type && IsScalarType(types[4], input_scale_type->dtype) &&
(input_scale_type->dtype == DataType::Float(32) ||
input_scale_type->dtype == DataType::Float(64))); // input_scale

// Kernel scale can be a vector of length output_channels or a scalar.
size_t axis = param->kernel_layout.find('O');
CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale
const auto* kernel_scale_type = types[5].as<TensorTypeNode>();
AssignType(types[5], kernel_scale_type->dtype, weight->shape[axis], reporter); // kernel scale

// Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
// Conv2D infer type function.
9 changes: 6 additions & 3 deletions src/relay/qnn/op/dequantize.cc
@@ -46,8 +46,10 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< input_dtype;

// Check the types of scale and zero points.
CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale
CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
const auto* input_scale_type = types[1].as<TensorTypeNode>();
CHECK(input_scale_type && (input_scale_type->dtype == DataType::Float(32) ||
input_scale_type->dtype == DataType::Float(64)));
CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point

const Array<tvm::PrimExpr> oshape = data->shape;
// assign output type, output will always be float 32.
@@ -66,7 +68,8 @@ Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point) {
Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point) {
auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point);
auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale);
auto scaled_output =
Multiply(Cast(shift, DataType::Float(32)), Cast(input_scale, DataType::Float(32)));
return scaled_output;
}

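
The lowering above is simply (q - zero_point) * scale; a one-line numeric check with illustrative values:

q, zero_point, scale = 130, 128, 0.25
print((q - zero_point) * scale)  # 0.5
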
