Skip to content

Commit

Permalink
rebase and improve
Browse files Browse the repository at this point in the history
  • Loading branch information
yunjing.lh committed Apr 6, 2020
1 parent df5935c commit d6e4bb8
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 23 deletions.
20 changes: 8 additions & 12 deletions src/relay/qnn/op/add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
// Get the input dtype and shape.
QnnBinaryOpTensorType input_type(arg_types, 0);


Expr output;
if (rounding == "TFLITE") {
float lhs_scale_val = GetScalarFromConstant<float>(args.lhs_scale);
Expand All @@ -58,7 +57,7 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
float twice_max_input_scale = 2 * std::max(lhs_scale_val, rhs_scale_val);
float real_lhs_scale_val = lhs_scale_val / twice_max_input_scale * (1 << 20);
float real_rhs_scale_val = rhs_scale_val / twice_max_input_scale * (1 << 20);
float real_out_scale_val = twice_max_input_scale / ((1 << 20) * out_scale_val);
float real_out_scale_val = out_scale_val / twice_max_input_scale * (1 << 20);

auto real_lhs_scale = MakeConstantScalar<float>(
DataType::Float(32), real_lhs_scale_val);
Expand All @@ -70,23 +69,20 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
DataType::Float(32), 1);
auto itmd_zero_point = MakeConstantScalar<int>(
DataType::Int(32), 0);

printf("tvm lhs_scale: %f rhs_scale: %f out_scale: %f twice_max_iscale: %f\n",
real_lhs_scale_val, real_rhs_scale_val, real_out_scale_val, twice_max_input_scale);

auto requantized_lhs = Requantize(args.lhs, input_type.shape,
real_lhs_scale, args.lhs_zero_point,
itmd_out_scale, itmd_zero_point,
DataType::Int(32), rounding);
DataType::Int(32));

auto requantized_rhs = Requantize(args.rhs, input_type.shape,
real_rhs_scale, args.rhs_zero_point,
itmd_out_scale, itmd_zero_point,
DataType::Int(32), rounding);
DataType::Int(32));

output = Add(requantized_lhs, requantized_rhs);
output = Requantize(output, input_type.shape,
real_out_scale, itmd_zero_point,
itmd_out_scale, args.output_zero_point,
itmd_out_scale, itmd_zero_point,
real_out_scale, args.output_zero_point,
DataType::Int(32), rounding);
} else {
// Requantize LHS if necessary. Computes Q_a'
Expand All @@ -100,8 +96,8 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
args.output_scale, args.output_zero_point,
input_type.shape);
// Computes Q_a' + Q_b'
auto output = Add(requantized_lhs, requantized_rhs);
output = Add(requantized_lhs, requantized_rhs);

// Subtract zero point.
auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
if (!IsEqualScalar(args.output_zero_point, zero_scalar)) {
Expand Down
3 changes: 1 addition & 2 deletions src/relay/qnn/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,9 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array<IndexExpr>&

auto zero_t = Zeros(input_shape, hp_dtype);
round_scalar = nearest_rounding_scalar(high32_t, right_shift);
scaled_tensor = right_shift > 0 ? Add(high32_t, round_scalar) : high32_t;
auto rshift_expr = MakeConstantScalar(hp_dtype, right_shift);
auto right_shift_t = Full(rshift_expr, input_shape, hp_dtype);
scaled_tensor = Where(Greater(right_shift_t, zero_t),
Add(high32_t, round_scalar), high32_t);
return RightShift(scaled_tensor, right_shift_t);
} else {
LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
Expand Down
13 changes: 7 additions & 6 deletions src/relay/transforms/legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@ class Legalizer : public ExprMutator {
explicit Legalizer(const std::string& legalize_map_attr_name)
: legalize_map_attr_name_{legalize_map_attr_name} {}

#if defined(__clang__)
__attribute__((optnone)) Expr VisitExpr_(const CallNode* call_node) {
#elif defined(__GNUC__)
Expr VisitExpr_(const CallNode* call_node) __attribute__((optimize(0))) {
#else
// #if defined(__clang__)
// __attribute__((optnone)) Expr VisitExpr_(const CallNode* call_node) {
// #elif defined(__GNUC__)
// Expr VisitExpr_(const CallNode* call_node) __attribute__((optimize(0))) {
// #else
// Expr VisitExpr_(const CallNode* call_node) {
// #endif
Expr VisitExpr_(const CallNode* call_node) {
#endif
// Get the new_call node without any changes to current call node.
Expr new_e = ExprMutator::VisitExpr_(call_node);
Call new_call = Downcast<Call>(new_e);
Expand Down
6 changes: 3 additions & 3 deletions src/relay/transforms/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,17 +497,17 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {

static inline Expr Greater(const Expr& lhs, const Expr& rhs) {
static const Op& op = Op::Get("greater");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
return Call(op, {lhs, rhs}, Attrs(), {});
}

static inline Expr Equal(const Expr& lhs, const Expr& rhs) {
static const Op& op = Op::Get("equal");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
return Call(op, {lhs, rhs}, Attrs(), {});
}

static inline Expr LogicalAnd(const Expr& lhs, const Expr& rhs) {
static const Op& op = Op::Get("logical_and");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
return Call(op, {lhs, rhs}, Attrs(), {});
}

static inline Expr Full(Expr fill_value,
Expand Down

0 comments on commit d6e4bb8

Please sign in to comment.