Add zero-padding support to the reference kernel (#571)
* Add zero-padding support to the reference kernel

* Cleanup asserts
Tombana authored Nov 20, 2020
1 parent 87fee7a commit 747b6b7
Showing 7 changed files with 83 additions and 61 deletions.
27 changes: 19 additions & 8 deletions larq_compute_engine/core/bconv2d/reference.h
@@ -67,6 +67,15 @@ inline void BConv2DReference(
   const int output_height = output_shape.Dims(1);
   const int output_width = output_shape.Dims(2);
 
+  const bool zero_padding =
+      bconv2d_params->padding_type == kTfLitePaddingSame &&
+      bconv2d_params->pad_value == 0;
+
+  // For n channels, a popcount of n/2 of the {0,1} bits would correspond to 0
+  // in the {-1,1} representation. So n/2 can be considered the 'zero point'.
+  const int binary_zero_point =
+      (bconv2d_params->channels_in / bconv2d_params->groups) / 2;
+
   TFLITE_DCHECK_EQ(input_depth_per_group * bconv2d_params->groups,
                    packed_input_shape.Dims(3));
   TFLITE_DCHECK_EQ(output_depth_per_group * bconv2d_params->groups,
@@ -84,16 +93,18 @@ inline void BConv2DReference(
         AccumScalar accum = AccumScalar(0);
         for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
           for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+            const int in_x = in_x_origin + dilation_width_factor * filter_x;
+            const int in_y = in_y_origin + dilation_height_factor * filter_y;
+            const bool inside = ((in_x >= 0) && (in_x < input_width) &&
+                                 (in_y >= 0) && (in_y < input_height));
+            if (zero_padding && !inside) {
+              accum += binary_zero_point;
+              continue;
+            }
             for (int in_channel = 0; in_channel < input_depth_per_group;
                  ++in_channel) {
-              const int in_x = in_x_origin + dilation_width_factor * filter_x;
-              const int in_y =
-                  in_y_origin + dilation_height_factor * filter_y;
-              // `pad_value=1`, which means the bitpacked value is 0, so we
-              // set `input_value=0`
-              TBitpacked input_value = 0;
-              if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
-                  (in_y < input_height)) {
+              TBitpacked input_value = 0;  // represents a +1
+              if (inside) {
                 input_value = packed_input_data[Offset(
                     packed_input_shape, batch, in_y, in_x,
                     group * input_depth_per_group + in_channel)];
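Note on the 'zero point' above: with inputs and weights in {-1, +1}, the XNOR-popcount accumulator counts matching bits, and the real dot product over n channels is 2 * popcount - n. A popcount of n/2 therefore corresponds to a dot product of 0, which is exactly what a zero-padded pixel should contribute; it also explains why the Prepare check below requires an even number of input channels. A minimal sketch of that identity (illustrative Python, assuming a {0,1} bit encodes -1/+1; this is not the kernel's actual data layout):

    def binary_dot(a_bits, b_bits):
        """XNOR-popcount dot product of two equal-length {0,1} bit vectors."""
        n = len(a_bits)
        popcount = sum(a == b for a, b in zip(a_bits, b_bits))  # XNOR == equality
        return 2 * popcount - n  # dot product in the {-1, +1} domain

    n = 8
    # A zero-padded pixel contributes no +-1 values, so instead of calling
    # binary_dot() the kernel adds n / 2 to the accumulator, which back-transforms
    # to 2 * (n / 2) - n == 0, i.e. no contribution to the convolution sum.
    assert 2 * (n // 2) - n == 0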
4 changes: 3 additions & 1 deletion larq_compute_engine/tests/end2end_test.py
@@ -134,7 +134,9 @@ def preprocess(data):
 
 
 def assert_model_output(model_lce, inputs, outputs, rtol, atol):
-    interpreter = Interpreter(model_lce, num_threads=min(os.cpu_count(), 4))
+    interpreter = Interpreter(
+        model_lce, num_threads=min(os.cpu_count(), 4), use_reference_bconv=False
+    )
     actual_outputs = interpreter.predict(inputs)
     np.testing.assert_allclose(actual_outputs, outputs, rtol=rtol, atol=atol)
 
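With use_reference_bconv now plumbed through the Python Interpreter, the same flatbuffer can be run through both the optimized and the reference BConv2D kernels and the outputs compared directly. A sketch of such a check (model_lce and inputs are assumed, as in the test above; the import path follows the file location shown in this commit, and the tolerances are illustrative):

    import numpy as np
    from larq_compute_engine.tflite.python.interpreter import Interpreter

    optimized = Interpreter(model_lce, num_threads=2)
    reference = Interpreter(model_lce, num_threads=2, use_reference_bconv=True)

    np.testing.assert_allclose(
        optimized.predict(inputs), reference.predict(inputs), rtol=1e-4, atol=1e-6
    )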
44 changes: 17 additions & 27 deletions larq_compute_engine/tflite/kernels/bconv2d.cc
@@ -124,14 +124,6 @@ void* Init(TfLiteContext* context, const char* buffer, std::size_t length) {
 
   op_data->fused_activation_function = ConvertActivation(
       (ActivationFunctionType)m["fused_activation_function"].AsInt32());
-  if (bconv2d_params->padding_type == kTfLitePaddingSame &&
-      bconv2d_params->pad_value != 1 &&
-      op_data->fused_activation_function != kTfLiteActNone) {
-    TF_LITE_KERNEL_LOG(
-        context,
-        "Fused activations are only supported with valid or one-padding.");
-    return op_data;
-  }
 
   // It's not possible to return an error code in this method. If we get to here
   // without returning early, initialisation has succeeded without error, and so
@@ -195,6 +187,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     bconv2d_params->groups = groups;
   }
 
+  if (bconv2d_params->padding_type == kTfLitePaddingSame &&
+      bconv2d_params->pad_value == 0) {
+    TF_LITE_ENSURE_MSG(
+        context,
+        (kernel_type == KernelType::kReference &&
+         bconv2d_params->channels_in % 2 == 0) ||
+            (kernel_type != KernelType::kReference &&
+             output->type == kTfLiteFloat32 &&
+             op_data->fused_activation_function == kTfLiteActNone),
+        "Zero-padding is only supported by the reference kernel with an even "
+        "number of input channels, or when using "
+        "float output with no fused activation function.");
+  }
+
   // Compute the padding and output values (height, width)
   int out_width, out_height;
   bconv2d_params->padding_values = ComputePaddingHeightWidth(
@@ -210,11 +216,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     TF_LITE_ENSURE_EQ(context, thresholds->type, kTfLiteInt32);
     TF_LITE_ENSURE_EQ(context, SizeOfDimension(thresholds, 0),
                       bconv2d_params->channels_out);
-    TF_LITE_ENSURE_MSG(context,
-                       bconv2d_params->padding_type != kTfLitePaddingSame ||
-                           bconv2d_params->pad_value == 1,
-                       "Writing bitpacked output is only supported with "
-                       "valid or one-padding.");
   } else {
     TF_LITE_ENSURE_EQ(context, post_activation_multiplier->type,
                       kTfLiteFloat32);
@@ -230,20 +231,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   if (output->type == kTfLiteInt8) {
     TF_LITE_ENSURE_EQ(context, output->quantization.type,
                       kTfLiteAffineQuantization);
-    TF_LITE_ENSURE_MSG(
-        context,
-        bconv2d_params->padding_type != kTfLitePaddingSame ||
-            bconv2d_params->pad_value == 1,
-        "8-bit quantization is only supported with valid or one-padding");
   }
 
-  if (kernel_type == KernelType::kReference) {
-    TF_LITE_ENSURE_MSG(
-        context,
-        bconv2d_params->padding_type != kTfLitePaddingSame ||
-            bconv2d_params->pad_value == 1,
-        "The reference kernel only supports valid or one-padding.");
-  } else if (kernel_type == KernelType::kOptimizedIndirectBGEMM) {
+  if (kernel_type == KernelType::kOptimizedIndirectBGEMM) {
     TF_LITE_ENSURE_MSG(
         context, input->allocation_type != kTfLiteDynamic,
         "The input tensor must not have dynamic allocation type");
@@ -374,9 +364,9 @@ void OneTimeSetup(TfLiteContext* context, TfLiteNode* node, OpData* op_data) {
   const std::int32_t backtransform_add =
       filter_shape.Dims(1) * filter_shape.Dims(2) * channels_in_per_group;
   const double output_scale =
-      output->type == kTfLiteInt8 ? output->params.scale : 1.0f;
+      output->type == kTfLiteInt8 ? output->params.scale : 1.0;
   const double output_zero_point =
-      output->type == kTfLiteInt8 ? output->params.zero_point : 0.0f;
+      output->type == kTfLiteInt8 ? output->params.zero_point : 0.0;
 
   for (int i = 0; i < bconv2d_params->channels_out; ++i) {
     const double post_mul =
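For reference, the consolidated Prepare-time check above replaces the three separate one-padding asserts that were removed: with SAME padding and pad_value == 0, the op is accepted only for the reference kernel with an even number of input channels, or for the optimized kernels with float output and no fused activation. A Python paraphrase of that condition (names are illustrative, not part of the C++ API):

    def zero_padding_supported(is_reference_kernel: bool, channels_in: int,
                               output_is_float: bool, has_fused_activation: bool) -> bool:
        """Mirrors the TF_LITE_ENSURE_MSG condition added to Prepare()."""
        if is_reference_kernel:
            # The reference kernel needs the binary zero point channels_in / 2,
            # so the number of input channels must be even.
            return channels_in % 2 == 0
        # The optimized kernels only handle zero-padding when writing float
        # output with no fused activation function.
        return output_is_float and not has_fused_activation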
13 changes: 10 additions & 3 deletions larq_compute_engine/tflite/kernels/lce_ops_register.h
@@ -13,17 +13,24 @@ namespace tflite {
 TfLiteRegistration* Register_QUANTIZE();
 TfLiteRegistration* Register_DEQUANTIZE();
 TfLiteRegistration* Register_BCONV_2D();
+TfLiteRegistration* Register_BCONV_2D_REF();
 TfLiteRegistration* Register_BMAXPOOL_2D();
 
 // By calling this function on TF lite mutable op resolver, all LCE custom ops
 // will be registerd to the op resolver.
-inline void RegisterLCECustomOps(::tflite::MutableOpResolver* resolver) {
+inline void RegisterLCECustomOps(::tflite::MutableOpResolver* resolver,
+                                 const bool use_reference_bconv = false) {
   resolver->AddCustom("LceQuantize",
                       compute_engine::tflite::Register_QUANTIZE());
   resolver->AddCustom("LceDequantize",
                       compute_engine::tflite::Register_DEQUANTIZE());
-  resolver->AddCustom("LceBconv2d",
-                      compute_engine::tflite::Register_BCONV_2D());
+  if (use_reference_bconv) {
+    resolver->AddCustom("LceBconv2d",
+                        compute_engine::tflite::Register_BCONV_2D_REF());
+  } else {
+    resolver->AddCustom("LceBconv2d",
+                        compute_engine::tflite::Register_BCONV_2D());
+  }
   resolver->AddCustom("LceBMaxPool2d",
                       compute_engine::tflite::Register_BMAXPOOL_2D());
 };
10 changes: 8 additions & 2 deletions larq_compute_engine/tflite/python/interpreter.py
@@ -43,6 +43,7 @@ class Interpreter:
     # Arguments
     flatbuffer_model: A serialized Larq Compute Engine model in the flatbuffer format.
     num_threads: The number of threads used by the interpreter.
+    use_reference_bconv: When True, uses the reference implementation of LceBconv2d.
 
     # Attributes
     input_types: Returns a list of input types.
@@ -51,9 +52,14 @@ class Interpreter:
     output_shapes: Returns a list of output shapes.
     """
 
-    def __init__(self, flatbuffer_model: bytes, num_threads: int = 1):
+    def __init__(
+        self,
+        flatbuffer_model: bytes,
+        num_threads: int = 1,
+        use_reference_bconv: bool = False,
+    ):
         self.interpreter = interpreter_wrapper_lite.LiteInterpreter(
-            flatbuffer_model, num_threads
+            flatbuffer_model, num_threads, use_reference_bconv
         )
 
     @property
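A minimal usage sketch of the new flag (flatbuffer_model is assumed to hold the bytes of a converted LCE model, and inputs a NumPy batch matching the model's input shape):

    from larq_compute_engine.tflite.python.interpreter import Interpreter

    interpreter = Interpreter(
        flatbuffer_model, num_threads=1, use_reference_bconv=True
    )
    outputs = interpreter.predict(inputs)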
11 changes: 7 additions & 4 deletions larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc
@@ -9,7 +9,8 @@ class LiteInterpreterWrapper
     : public InterpreterWrapperBase<tflite::Interpreter> {
  public:
   LiteInterpreterWrapper(const pybind11::bytes& flatbuffer,
-                         const int num_threads);
+                         const int num_threads = 1,
+                         const bool use_reference_bconv = false);
   ~LiteInterpreterWrapper(){};
 
  private:
@@ -21,7 +22,8 @@ class LiteInterpreterWrapper
 };
 
 LiteInterpreterWrapper::LiteInterpreterWrapper(
-    const pybind11::bytes& flatbuffer, const int num_threads = 1) {
+    const pybind11::bytes& flatbuffer, const int num_threads,
+    const bool use_reference_bconv) {
   // Make a copy of the flatbuffer because it can get deallocated after the
   // constructor is done
   flatbuffer_ = static_cast<std::string>(flatbuffer);
@@ -34,7 +36,8 @@ LiteInterpreterWrapper::LiteInterpreterWrapper(
 
   // Build the interpreter
   resolver_ = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
-  compute_engine::tflite::RegisterLCECustomOps(resolver_.get());
+  compute_engine::tflite::RegisterLCECustomOps(resolver_.get(),
+                                               use_reference_bconv);
 
   tflite::InterpreterBuilder builder(*model_, *resolver_);
   builder(&interpreter_, num_threads);
@@ -46,7 +49,7 @@ LiteInterpreterWrapper::LiteInterpreterWrapper(
 
 PYBIND11_MODULE(interpreter_wrapper_lite, m) {
   pybind11::class_<LiteInterpreterWrapper>(m, "LiteInterpreter")
-      .def(pybind11::init<const pybind11::bytes&, const int>())
+      .def(pybind11::init<const pybind11::bytes&, const int, const bool>())
       .def_property("input_types", &LiteInterpreterWrapper::get_input_types,
                     nullptr)
       .def_property("output_types", &LiteInterpreterWrapper::get_output_types,
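Because the pybind11 init above is registered with purely positional arguments, the low-level wrapper can also be constructed directly; the order is (flatbuffer, num_threads, use_reference_bconv). A sketch (the import path is assumed from the module location above; the Interpreter class is normally the intended entry point):

    from larq_compute_engine.tflite.python import interpreter_wrapper_lite

    # flatbuffer_model: bytes of a converted LCE model (assumed available).
    wrapper = interpreter_wrapper_lite.LiteInterpreter(flatbuffer_model, 1, True)
    print(wrapper.input_types, wrapper.output_types)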
35 changes: 19 additions & 16 deletions larq_compute_engine/tflite/tests/bconv2d_test.cc
@@ -455,6 +455,7 @@ void runTest(const TestParam& param) {
   constexpr bool write_bitpacked_output =
       std::is_same<TOutput, TBitpacked>::value;
   constexpr bool int8_output = std::is_same<TOutput, std::int8_t>::value;
+  constexpr bool float_output = std::is_same<TOutput, float>::value;
 
   const Padding builtin_padding =
       (padding == Padding_ONE ? Padding_VALID : padding);
@@ -492,22 +493,24 @@ void runTest(const TestParam& param) {
   const int bitpacked_filters_num_elem =
       filter_count * filter_height * filter_width * packed_channels_per_group;
 
-  // the reference implementation only support one-padding
   const auto is_reference_registration =
       (registration == compute_engine::tflite::Register_BCONV_2D_REF);
 
-  if (padding == Padding_SAME &&
-      (is_reference_registration || activation == ActivationFunctionType_RELU ||
-       write_bitpacked_output || int8_output)) {
-    // Zero-padding is not supported in combination with:
-    // - The reference implementation
-    // - Fused ReLu
-    // - Writing bitpacked output
-    // - Int8 output
-    // We could use `EXPECT_DEATH` here but it is extremely slow. Therefore we
-    // have a separate test below, and here we just skip.
-    GTEST_SKIP();
-    return;
+  if (padding == Padding_SAME) {
+    if (is_reference_registration) {
+      if (input_depth % 2 != 0) {
+        GTEST_SKIP();
+        return;
+      }
+    } else {
+      if (!float_output || activation == ActivationFunctionType_RELU) {
+        // We could use `EXPECT_DEATH` here but it is
+        // extremely slow. Therefore we have a separate test below, and here we
+        // just skip.
+        GTEST_SKIP();
+        return;
+      }
+    }
   }
 
   std::random_device rd;
@@ -874,7 +877,7 @@ TEST(BConv2DTests, ReluErrorDeathTest) {
                     threshold_tensor, 64, 1, 1, Padding_SAME, 0,
                     ActivationFunctionType_RELU, 1, 1, 1);
       },
-      "Fused activations are only supported with valid or one-padding.");
+      "Zero-padding is only supported by");
 
   // Test if writing bitpacked output throws an error in combination with
   // zero-padding.
@@ -886,7 +889,7 @@ TEST(BConv2DTests, ReluErrorDeathTest) {
                     post_tensor, threshold_tensor, 64, 1, 1, Padding_SAME, 0,
                     ActivationFunctionType_NONE, 1, 1, 1);
       },
-      "Writing bitpacked output is only supported with valid or one-padding.");
+      "Zero-padding is only supported by");
 }
 
 TEST(BConv2DTests, Int8ErrorDeathTest) {
@@ -910,7 +913,7 @@ TEST(BConv2DTests, Int8ErrorDeathTest) {
                     threshold_tensor, 64, 1, 1, Padding_SAME, 0,
                     ActivationFunctionType_NONE, 1, 1, 1);
       },
-      "8-bit quantization is only supported with valid or one-padding.");
+      "Zero-padding is only supported by");
 }
 
 } // namespace testing
