Skip to content

Commit

Permalink
[hannk] requantize() should never skip the operation (#6350)
Browse files Browse the repository at this point in the history
* [hannk] requantize() should never skip the operation

Even if inq == outq, the incoming buffer can contain out-of-range values; we shouldn't try to optimize the op away, since it's cheap.

* Update ops.cpp

* Update ops.cpp
  • Loading branch information
steven-johnson authored Oct 26, 2021
1 parent d6d7bbc commit 50517cb
Showing 1 changed file with 40 additions and 17 deletions.
57 changes: 40 additions & 17 deletions apps/hannk/interpreter/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,24 +484,43 @@ void mul_uint8(const HalideBuffer<const void> &in1, const QuantizationInfo &in1q
elementwise_loop_nest<2>(mul_rank2, in1, in2, out);
}

void requantize(const HalideBuffer<const void> &in, const QuantizationInfo &inq,
HalideBuffer<void> out, const QuantizationInfo &outq,
ActivationFunction activation = ActivationFunction::None) {
if (inq == outq) {
// Some of these are just copies, or no-ops.
if (is_alias(in.raw_buffer(), out.raw_buffer())) {
return;
} else {
out.copy_from(in);
}
} else if (in.type() == halide_type_of<uint8_t>() &&
out.type() == halide_type_of<uint8_t>()) {
bool try_requantize(const HalideBuffer<const void> &in, const QuantizationInfo &inq,
HalideBuffer<void> out, const QuantizationInfo &outq,
ActivationFunction activation = ActivationFunction::None) {
if (in.type() != out.type()) {
HLOG(ERROR) << "requantize: input and output types must match";
return false;
}

if (in.type() == halide_type_of<uint8_t>() &&
out.type() == halide_type_of<uint8_t>()) {
// TODO: Maybe a dedicated pipeline for this would be better. It
// could be a little faster, and avoid some quantization error.
add_uint8(in, inq, 1, in, inq, 0, out, outq, activation);
} else {
HLOG(FATAL) << "Unable to requantize " << in.type() << " -> " << out.type() << "\n";
return true;
}

return false;
}

// Writes `in` into `out`, preferring try_requantize() and falling back to a
// plain buffer copy when try_requantize() reports that it did not handle the
// buffers.
//
// Input and output buffer types must match; on a mismatch an error is logged
// and false is returned before `out` is written.
// If the input and output buffers are quantized, we always call requantize.
// If not, we simply copy.
//
// Returns true on both the requantize path and the copy path (including the
// case where the copy is skipped because `in` and `out` alias); returns false
// only for the type mismatch above.
bool requantize_or_copy(const HalideBuffer<const void> &in, const QuantizationInfo &inq,
HalideBuffer<void> out, const QuantizationInfo &outq,
ActivationFunction activation = ActivationFunction::None) {
// Reject mismatched element types up front; neither path below is attempted.
if (in.type() != out.type()) {
HLOG(ERROR) << "requantize_or_copy: input and output types must match";
return false;
}
// A true result means try_requantize() already produced `out`.
if (try_requantize(in, inq, out, outq, activation)) {
return true;
}

// Fallback: copy the data unless is_alias() says the two raw buffers alias
// (presumably meaning they share storage, so there is nothing to copy --
// confirm against is_alias()'s definition).
// NOTE(review): `activation` is not applied on this path.
if (!is_alias(in.raw_buffer(), out.raw_buffer())) {
out.copy_from(in);
}
return true;
}

ActivationFunction to_activation(UnaryOp::Operator op) {
Expand Down Expand Up @@ -729,7 +748,9 @@ void ConcatenationOp::execute() {

auto output_crop = output_buf;
crop_to_union(output_crop, input_buf);
requantize(input_buf, input(i)->quantization(), output_crop, output()->quantization());

bool copied = requantize_or_copy(input_buf, input(i)->quantization(), output_crop, output()->quantization());
HCHECK(copied);
}
}

Expand Down Expand Up @@ -1636,7 +1657,8 @@ void SplitOp::execute() {
assert(output_buf.dim(axis_).min() == 0);

output_buf.translate(axis_, concatenated_i);
requantize(input_buf, input()->quantization(), output_buf, output(i)->quantization());
bool copied = requantize_or_copy(input_buf, input()->quantization(), output_buf, output(i)->quantization());
HCHECK(copied);

concatenated_i += output_buf.dim(axis_).extent();
}
Expand Down Expand Up @@ -1786,7 +1808,8 @@ void UnaryOp::execute() {
mul_uint8(in_buf, in->quantization(), in_buf, in->quantization(), out_buf, out->quantization());
return;
} else if (op_ == Relu || op_ == Relu6 || op_ == ReluN1To1) {
requantize(in_buf, in->quantization(), out_buf, out->quantization(), to_activation(op_));
bool copied = try_requantize(in_buf, in->quantization(), out_buf, out->quantization(), to_activation(op_));
HCHECK(copied);
return;
}
}
Expand Down

0 comments on commit 50517cb

Please sign in to comment.