diff --git a/larq_compute_engine/core/bconv2d/optimized_bgemm.h b/larq_compute_engine/core/bconv2d/optimized_bgemm.h
index 2e7b07963..c10b91478 100644
--- a/larq_compute_engine/core/bconv2d/optimized_bgemm.h
+++ b/larq_compute_engine/core/bconv2d/optimized_bgemm.h
@@ -112,7 +112,7 @@ inline void BConv2DOptimizedBGEMM(
   //   output tensor with zeroes in advance so that the BGEMM doesn't have to
   //   worry about doing the padding.
   if (std::is_same<DstScalar, TBitpacked>::value &&
-      output_shape.Dims(3) % 32 != 0) {
+      output_shape.Dims(3) % bitpacking_bitwidth != 0) {
     std::fill(
         output_data,
         output_data + FlatSizeSkipDim(output_shape, 3) *
diff --git a/larq_compute_engine/core/bconv2d/params.h b/larq_compute_engine/core/bconv2d/params.h
index b800b93f7..64e0d075e 100644
--- a/larq_compute_engine/core/bconv2d/params.h
+++ b/larq_compute_engine/core/bconv2d/params.h
@@ -15,6 +15,7 @@ struct BConv2DParams {
   std::int32_t filter_height;
   std::int32_t channels_in;
   std::int32_t channels_out;
+  std::int32_t groups;
 
   // Strides
   std::int32_t stride_height;
diff --git a/larq_compute_engine/core/bconv2d/reference.h b/larq_compute_engine/core/bconv2d/reference.h
index 081024748..2d5896262 100644
--- a/larq_compute_engine/core/bconv2d/reference.h
+++ b/larq_compute_engine/core/bconv2d/reference.h
@@ -57,9 +57,9 @@ inline void BConv2DReference(
   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
 
   const int batches = MatchingDim(packed_input_shape, 0, output_shape, 0);
-  const int input_depth =
-      MatchingDim(packed_input_shape, 3, packed_filter_shape, 3);
+  const int input_depth_per_group = packed_filter_shape.Dims(3);
   const int output_depth = packed_filter_shape.Dims(0);
+  const int output_depth_per_group = output_depth / bconv2d_params->groups;
   const int input_height = packed_input_shape.Dims(1);
   const int input_width = packed_input_shape.Dims(2);
   const int filter_height = packed_filter_shape.Dims(1);
@@ -67,6 +67,11 @@ inline void BConv2DReference(
   const int output_height = output_shape.Dims(1);
   const int output_width = output_shape.Dims(2);
 
+  TFLITE_DCHECK_EQ(input_depth_per_group * bconv2d_params->groups,
+                   packed_input_shape.Dims(3));
+  TFLITE_DCHECK_EQ(output_depth_per_group * bconv2d_params->groups,
+                   output_depth);
+
   for (int batch = 0; batch < batches; ++batch) {
     for (int out_y = 0; out_y < output_height; ++out_y) {
       for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -75,10 +80,12 @@ inline void BConv2DReference(
         for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
           const int in_x_origin = (out_x * stride_width) - pad_width;
           const int in_y_origin = (out_y * stride_height) - pad_height;
+          const int group = out_channel / output_depth_per_group;
           AccumScalar accum = AccumScalar(0);
           for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
-              for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+              for (int in_channel = 0; in_channel < input_depth_per_group;
+                   ++in_channel) {
                 const int in_x = in_x_origin + dilation_width_factor * filter_x;
                 const int in_y =
                     in_y_origin + dilation_height_factor * filter_y;
@@ -88,7 +95,8 @@ inline void BConv2DReference(
                 if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                     (in_y < input_height)) {
                   input_value = packed_input_data[Offset(
-                      packed_input_shape, batch, in_y, in_x, in_channel)];
+                      packed_input_shape, batch, in_y, in_x,
+                      group * input_depth_per_group + in_channel)];
                 }
                 TBitpacked filter_value =
                     packed_filter_data[Offset(packed_filter_shape, out_channel,
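
The grouped reference kernel above reduces to a channel remapping: output channel out_channel belongs to group out_channel / output_depth_per_group, and its filter taps only read the input slice starting at group * input_depth_per_group. A minimal standalone sketch of that indexing (illustrative C++ only, not LCE code; values are the +1/-1 of a binary convolution and the spatial taps are collapsed to a single 1x1 position):

    #include <vector>

    // Sketch of the channel indexing performed by the grouped reference
    // kernel: each output channel reads only its own group's input slice.
    std::vector<int> GroupedBinaryDot(
        const std::vector<int>& input,   // [in_channels], values +1/-1
        const std::vector<int>& filter,  // [out_channels * in_per_group]
        int in_channels, int out_channels, int groups) {
      const int in_per_group = in_channels / groups;
      const int out_per_group = out_channels / groups;
      std::vector<int> out(out_channels, 0);
      for (int oc = 0; oc < out_channels; ++oc) {
        const int group = oc / out_per_group;  // same formula as the diff above
        for (int ic = 0; ic < in_per_group; ++ic) {
          // The remapped input index, group * in_per_group + ic, mirrors the
          // Offset(...) call in BConv2DReference:
          out[oc] += input[group * in_per_group + ic] *
                     filter[oc * in_per_group + ic];
        }
      }
      return out;
    }

With groups == 1 this degenerates to a full dot product over all input channels, which is why the ungrouped path through the kernel is unchanged.
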
diff --git a/larq_compute_engine/tflite/kernels/bconv2d.cc b/larq_compute_engine/tflite/kernels/bconv2d.cc
index 43747bc56..bea24f357 100644
--- a/larq_compute_engine/tflite/kernels/bconv2d.cc
+++ b/larq_compute_engine/tflite/kernels/bconv2d.cc
@@ -176,6 +176,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   bconv2d_params->filter_height = SizeOfDimension(filter, 1);
   bconv2d_params->filter_width = SizeOfDimension(filter, 2);
 
+  if (SizeOfDimension(filter, 3) ==
+      GetBitpackedSize(bconv2d_params->channels_in)) {
+    bconv2d_params->groups = 1;
+  } else {
+    TF_LITE_ENSURE_MSG(
+        context, kernel_type == KernelType::kReference,
+        "Grouped binary convolutions are not supported with this kernel.");
+    TF_LITE_ENSURE_EQ(context,
+                      GetBitpackedSize(bconv2d_params->channels_in) %
+                          SizeOfDimension(filter, 3),
+                      0);
+    const std::int32_t groups = GetBitpackedSize(bconv2d_params->channels_in) /
+                                SizeOfDimension(filter, 3);
+    const std::int32_t group_size = bconv2d_params->channels_in / groups;
+    TF_LITE_ENSURE_EQ(context, group_size % core::bitpacking_bitwidth, 0);
+    TF_LITE_ENSURE_EQ(context, bconv2d_params->channels_out % groups, 0);
+    bconv2d_params->groups = groups;
+  }
+
   // Compute the padding and output values (height, width)
   int out_width, out_height;
   bconv2d_params->padding_values = ComputePaddingHeightWidth(
@@ -273,7 +292,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }
 
   // Resize the im2col tensor
-  int bitpacked_channels_in = GetBitpackedSize(bconv2d_params->channels_in);
+  const std::int32_t bitpacked_channels_in =
+      GetBitpackedSize(bconv2d_params->channels_in);
   TfLiteIntArray* im2col_size = TfLiteIntArrayCopy(output_shape);
   im2col_size->data[3] =
       bitpacked_channels_in * bconv2d_params->filter_height *
@@ -321,6 +341,11 @@ void OneTimeSetup(TfLiteContext* context, TfLiteNode* node, OpData* op_data) {
   const auto* post_activation_bias = GetInput(context, node, 3);
   const auto* output = GetOutput(context, node, 0);
 
+  // Division is safe because at this point we know that channels_in is a
+  // multiple of the number of groups.
+  const std::int32_t channels_in_per_group =
+      bconv2d_params->channels_in / bconv2d_params->groups;
+
   // For 'same-zero' padding, compute the padding-correction.
   if (bconv2d_params->padding_type == kTfLitePaddingSame &&
       bconv2d_params->pad_value == 0) {
@@ -331,7 +356,7 @@ void OneTimeSetup(TfLiteContext* context, TfLiteNode* node, OpData* op_data) {
     zero_padding_correction::CacheCorrectionValues(
         GetTensorData<TBitpacked>(filter), bconv2d_params->filter_height,
         bconv2d_params->filter_width, bconv2d_params->channels_out,
-        bconv2d_params->channels_in, bconv2d_params->dilation_height_factor,
+        channels_in_per_group, bconv2d_params->dilation_height_factor,
        bconv2d_params->dilation_width_factor,
        GetTensorData<float>(post_activation_multiplier),
        op_data->padding_buffer.data());
@@ -346,9 +371,8 @@ void OneTimeSetup(TfLiteContext* context, TfLiteNode* node, OpData* op_data) {
                        LCE_EXTRA_BYTES / sizeof(float));
 
   const auto filter_shape = GetTensorShape(GetInput(context, node, 1));
-  const std::int32_t backtransform_add = filter_shape.Dims(1) *
-                                         filter_shape.Dims(2) *
-                                         bconv2d_params->channels_in;
+  const std::int32_t backtransform_add =
+      filter_shape.Dims(1) * filter_shape.Dims(2) * channels_in_per_group;
   const double output_scale =
       output->type == kTfLiteInt8 ? output->params.scale : 1.0f;
   const double output_zero_point =
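
Prepare() cannot recover the true input channel count from the tensor shapes alone: with 32-bit words, both 30 and 32 true channels bitpack to a single word, which is why channels_in travels as an op attribute. The group count then falls out of comparing the bitpacked channels_in with the filter's innermost dimension. A sketch of that inference rule (hypothetical helper names; the real checks are the TF_LITE_ENSURE calls above, and 32-bit packing words are an assumption):

    #include <cstdint>
    #include <stdexcept>

    constexpr std::int32_t kBitpackingBitwidth = 32;  // assumed word size

    // ceil(channels / bitwidth), mirroring what GetBitpackedSize computes.
    std::int32_t BitpackedSize(std::int32_t channels) {
      return (channels + kBitpackingBitwidth - 1) / kBitpackingBitwidth;
    }

    // Filter dim 3 holds the *bitpacked per-group* input depth, so its ratio
    // to the full bitpacked input depth yields the number of groups.
    std::int32_t InferGroups(std::int32_t channels_in,
                             std::int32_t channels_out,
                             std::int32_t filter_dim3) {
      const std::int32_t packed_in = BitpackedSize(channels_in);
      if (filter_dim3 == packed_in) return 1;  // the common, ungrouped case
      if (packed_in % filter_dim3 != 0)
        throw std::invalid_argument("filter shape incompatible with input");
      const std::int32_t groups = packed_in / filter_dim3;
      // Each group must fill whole packing words, and the output channels
      // must divide evenly across the groups:
      if ((channels_in / groups) % kBitpackingBitwidth != 0 ||
          channels_out % groups != 0)
        throw std::invalid_argument("unsupported group configuration");
      return groups;
    }
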
diff --git a/larq_compute_engine/tflite/tests/bconv2d_op_model.h b/larq_compute_engine/tflite/tests/bconv2d_op_model.h
index 2d3435ff6..b4bf04227 100644
--- a/larq_compute_engine/tflite/tests/bconv2d_op_model.h
+++ b/larq_compute_engine/tflite/tests/bconv2d_op_model.h
@@ -42,9 +42,6 @@ class BaseBConv2DOpModel : public SingleOpModel {
 
     flexbuffers::Builder fbb;
     fbb.Map([&]() {
-      // This attribute is necessary because if the filters are bitpacked and
-      // we're reading bitpacked input then we don't have access to the original
-      // 'true' number of input channels.
       fbb.Int("channels_in", channels_in);
       fbb.Int("stride_height", stride_height);
       fbb.Int("stride_width", stride_width);
diff --git a/larq_compute_engine/tflite/tests/bconv2d_test.cc b/larq_compute_engine/tflite/tests/bconv2d_test.cc
index 3c18e210d..b76c21268 100644
--- a/larq_compute_engine/tflite/tests/bconv2d_test.cc
+++ b/larq_compute_engine/tflite/tests/bconv2d_test.cc
@@ -52,6 +52,7 @@ TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT();
 
 namespace {
 
+using compute_engine::core::bitpacking_bitwidth;
 using compute_engine::core::TBitpacked;
 
 using namespace compute_engine::core::bitpacking;
@@ -194,6 +195,7 @@ namespace testing {
 
 typedef std::tuple<std::array<int, 4>,  // input shape [BHWI]
                    std::array<int, 3>,  // filter shape [HWO]
+                   int,                 // number of groups
                    std::array<int, 2>,  // strides [HW]
                    std::array<int, 2>,  // dilations [HW]
                    Padding,             // paddding
@@ -213,15 +215,16 @@ struct TestParam {
         filter_height(::testing::get<1>(param_tuple)[0]),
         filter_width(::testing::get<1>(param_tuple)[1]),
         filter_count(::testing::get<1>(param_tuple)[2]),
-        stride_height(::testing::get<2>(param_tuple)[0]),
-        stride_width(::testing::get<2>(param_tuple)[1]),
-        dilation_height_factor(::testing::get<3>(param_tuple)[0]),
-        dilation_width_factor(::testing::get<3>(param_tuple)[1]),
-        padding(::testing::get<4>(param_tuple)),
-        activation(::testing::get<5>(param_tuple)),
-        num_threads(::testing::get<6>(param_tuple)),
-        kernel_name(::testing::get<7>(param_tuple).first),
-        registration(::testing::get<7>(param_tuple).second) {}
+        groups(::testing::get<2>(param_tuple)),
+        stride_height(::testing::get<3>(param_tuple)[0]),
+        stride_width(::testing::get<3>(param_tuple)[1]),
+        dilation_height_factor(::testing::get<4>(param_tuple)[0]),
+        dilation_width_factor(::testing::get<4>(param_tuple)[1]),
+        padding(::testing::get<5>(param_tuple)),
+        activation(::testing::get<6>(param_tuple)),
+        num_threads(::testing::get<7>(param_tuple)),
+        kernel_name(::testing::get<8>(param_tuple).first),
+        registration(::testing::get<8>(param_tuple).second) {}
 
   static std::string TestNameSuffix(
       const ::testing::TestParamInfo<TestParamTuple>& info) {
@@ -244,10 +247,11 @@ struct TestParam {
 
     // WARNING: substitute accepts only 11 arguments
     return absl::Substitute(
-        "Op$0_I$1_K$2_P$3_PV$4_S$5_D$6_Act$7_T$8", param.kernel_name,
-        param_input_oss.str(), param_filter_oss.str(), GetPaddingName(padding),
-        pad_values, param_strides_oss.str(), param_dilation_oss.str(),
-        getActivationString(param.activation), param.num_threads);
+        "Op$0_I$1_K$2_G$3_P$4_PV$5_S$6_D$7_Act$8_T$9", param.kernel_name,
+        param_input_oss.str(), param_filter_oss.str(), param.groups,
+        GetPaddingName(padding), pad_values, param_strides_oss.str(),
+        param_dilation_oss.str(), getActivationString(param.activation),
+        param.num_threads);
   }
 
   int input_batch_count = 1;
@@ -259,6 +263,8 @@ struct TestParam {
   int filter_width = 3;
   int filter_count = 1;
 
+  int groups = 1;
+
   int stride_height = 1;
   int stride_width = 1;
 
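
Both backtransform_add in OneTimeSetup above and ComputeThresholds in the hunks below swap channels_in for the per-group input depth. They lean on the standard binary-convolution back-transformation: for a, b in {-1,+1}^n, a.b = n - 2 * popcount(a XOR b), and with groups each output channel's n shrinks to filter_height * filter_width * input_depth_per_group. A self-contained illustration (one 32-bit word assumed; not LCE code):

    #include <bitset>
    #include <cstdint>

    // Recover the +/-1 dot product from an XOR-popcount accumulator. A set
    // bit encodes -1; n is the receptive-field size, i.e. the role played by
    // filter_height * filter_width * input_depth_per_group in the diff.
    int BinaryDot(std::uint32_t a, std::uint32_t b, int n) {
      const int xor_popcount = static_cast<int>(std::bitset<32>(a ^ b).count());
      return n - 2 * xor_popcount;  // the "backtransform"
    }

Using the full input_depth for n in a grouped convolution would shift every output by a constant, which is precisely the bug the per-group substitution avoids.
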
@@ -318,7 +324,8 @@ class BConv2DOpTest : public ::testing::TestWithParam<TestParam> {
   }
 };
 
-void ComputeThresholds(int input_depth, int filter_height, int filter_width,
+void ComputeThresholds(int input_depth_per_group, int filter_height,
+                       int filter_width,
                        const std::vector<float>& post_activation_multiplier,
                        const std::vector<float>& post_activation_bias,
                        enum ActivationFunctionType activation,
@@ -340,7 +347,8 @@ void ComputeThresholds(int input_depth, int filter_height, int filter_width,
   }
 
   // We do all intermediate computations here in double to keep accuracy
-  const double backtransform_add = filter_height * filter_width * input_depth;
+  const double backtransform_add =
+      filter_height * filter_width * input_depth_per_group;
   for (size_t i = 0; i < post_activation_multiplier.size(); ++i) {
     const double post_mul = post_activation_multiplier[i];
     const double post_bias = post_activation_bias[i];
@@ -436,6 +444,7 @@ void runTest(const TestParam& param) {
   const int filter_height = param.filter_height;
   const int filter_width = param.filter_width;
   const int filter_count = param.filter_count;
+  const int groups = param.groups;
   const int stride_height = param.stride_height;
   const int stride_width = param.stride_width;
   const int dilation_height_factor = param.dilation_height_factor;
@@ -453,17 +462,35 @@ void runTest(const TestParam& param) {
       (padding == Padding_ONE ? Padding_SAME : padding);
   const int pad_values = (padding == Padding_ONE ? 1 : 0);
 
+  if (groups > 1 &&
+      (registration != compute_engine::tflite::Register_BCONV_2D_REF ||
+       input_depth % groups != 0 || filter_count % groups != 0 ||
+       (input_depth / groups) % bitpacking_bitwidth != 0)) {
+    // Grouped convolutions are only supported in the reference kernel, and
+    // require compatible input and filter dimensions, with the additional
+    // requirement that the per-group input depth must be a multiple of the
+    // bitpacking bitwidth.
+    GTEST_SKIP();
+    return;
+  }
+
+  const int input_depth_per_group = input_depth / groups;
+  const int output_depth_per_group = filter_count / groups;
+
   const int packed_channels = GetBitpackedSize(input_depth);
+  // This is valid because of the constraint on group sizes being a multiple of
+  // the bitpacking bitwidth.
+  const int packed_channels_per_group = GetBitpackedSize(input_depth_per_group);
 
   const int input_num_elem =
       input_batch_count * input_height * input_width * input_depth;
-  const int packed_input_num_elem =
+  const int bitpacked_input_num_elem =
       input_batch_count * input_height * input_width * packed_channels;
   const int filters_num_elem =
-      filter_height * filter_width * input_depth * filter_count;
-  const int packed_filters_num_elem =
-      filter_count * filter_height * filter_width * packed_channels;
+      filter_height * filter_width * input_depth_per_group * filter_count;
+  const int bitpacked_filters_num_elem =
+      filter_count * filter_height * filter_width * packed_channels_per_group;
 
   // the reference implementation only support one-padding
   const auto is_reference_registration =
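
The "this is valid" comment above is an arithmetic invariant: because each group's depth is a whole number of packing words, packing the full input equals packing each group independently, GetBitpackedSize(input_depth) == groups * GetBitpackedSize(input_depth_per_group), and no packed word straddles a group boundary. A worked check (assuming the default 32-bit bitpacking width):

    #include <cassert>

    int main() {
      constexpr int bitwidth = 32;  // assumed bitpacking word size
      constexpr int input_depth = 512, groups = 4;
      constexpr int per_group = input_depth / groups;  // 128 channels per group
      auto packed = [](int c) { return (c + bitwidth - 1) / bitwidth; };
      // Whole-word groups mean word and group boundaries coincide:
      static_assert(per_group % bitwidth == 0, "group must fill whole words");
      assert(packed(input_depth) == groups * packed(per_group));  // 16 == 4 * 4
      return 0;
    }

A per-group depth of 48, by contrast, would fail the static_assert: the second packed word would mix channels from two groups, and the single-slice indexing in the reference kernel could no longer work.
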
@@ -489,7 +516,7 @@ void runTest(const TestParam& param) {
   LceTensor<float> input_tensor(
       {input_batch_count, input_height, input_width, input_depth});
 
-  // Shape will be changed later if padding is required
+  // The shape will be changed later if padding is required
   LceTensor<float> padded_input_tensor(
       {input_batch_count, input_height, input_width, input_depth});
 
@@ -497,10 +524,10 @@ void runTest(const TestParam& param) {
   LceTensor<TBitpacked> packed_input_tensor(
       {input_batch_count, input_height, input_width, packed_channels});
 
   LceTensor<float> filter_tensor(
-      {filter_count, filter_height, filter_width, input_depth});
+      {filter_count, filter_height, filter_width, input_depth_per_group});
 
   LceTensor<TBitpacked> packed_filter_tensor(
-      {filter_count, filter_height, filter_width, packed_channels});
+      {filter_count, filter_height, filter_width, packed_channels_per_group});
 
   // We can use the same tensor object for multiply and bias
   // because they have the same shape and datatype
@@ -519,7 +546,8 @@ void runTest(const TestParam& param) {
     }
   }
 
-  std::vector<float> input_data, padded_input_data, filters_data, bias_data;
+  std::vector<float> input_data, padded_input_data, filters_data, bias_data,
+      builtin_output;
   std::vector<TBitpacked> packed_input_data;
   std::vector<float> post_activation_multiplier_data, post_activation_bias_data;
@@ -528,8 +556,8 @@ void runTest(const TestParam& param) {
 
   input_data.resize(input_num_elem);
   filters_data.resize(filters_num_elem);
-  packed_filters_data.resize(packed_filters_num_elem);
-  packed_input_data.resize(packed_input_num_elem);
+  packed_filters_data.resize(bitpacked_filters_num_elem);
+  packed_input_data.resize(bitpacked_input_num_elem);
   bias_data.resize(filter_count, 0.0f);
   post_activation_multiplier_data.resize(filter_count, 1);
   post_activation_bias_data.resize(filter_count, 0);
@@ -550,7 +578,7 @@ void runTest(const TestParam& param) {
             std::end(post_activation_bias_data), float_generator);
 
   if (write_bitpacked_output) {
-    ComputeThresholds(input_depth, filter_height, filter_width,
+    ComputeThresholds(input_depth_per_group, filter_height, filter_width,
                       post_activation_multiplier_data,
                       post_activation_bias_data, activation, threshold_data);
   }
@@ -560,8 +588,8 @@ void runTest(const TestParam& param) {
                  input_batch_count * input_height * input_width, input_depth,
                  packed_input_data.data(), input_tensor.zero_point);
   bitpack_matrix(filters_data.data(),
-                 filter_count * filter_height * filter_width, input_depth,
-                 packed_filters_data.data());
+                 filter_count * filter_height * filter_width,
+                 input_depth_per_group, packed_filters_data.data());
 
   int output_height, output_width;
   TfLitePaddingValues padding_values = ComputePaddingHeightWidth(
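
The per-group filter packing above works because bitpack_matrix packs each row independently, and the row length is now input_depth_per_group, so one packed row never crosses a group. A rough sketch of such a row-wise packing routine (an illustration of the layout, not LCE's implementation; mapping negative values to 1 bits follows the usual {-1,+1} binary-network encoding, and the real routine also handles zero points and other word widths):

    #include <cstdint>
    #include <vector>

    // Pack each row of a num_rows x num_cols float matrix into
    // ceil(num_cols / 32) words, one sign bit per value.
    std::vector<std::uint32_t> BitpackRows(const std::vector<float>& data,
                                           int num_rows, int num_cols) {
      const int words_per_row = (num_cols + 31) / 32;
      std::vector<std::uint32_t> packed(num_rows * words_per_row, 0);
      for (int r = 0; r < num_rows; ++r) {
        for (int c = 0; c < num_cols; ++c) {
          if (data[r * num_cols + c] < 0.0f) {  // negative -> bit set
            packed[r * words_per_row + c / 32] |= (1u << (c % 32));
          }
        }
      }
      return packed;
    }

For the filters above, num_rows is filter_count * filter_height * filter_width and num_cols is input_depth_per_group, matching the bitpack_matrix call in the hunk.
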
@@ -612,17 +640,87 @@ void runTest(const TestParam& param) {
   /*-----------------
      Run built-in op.
    -----------------*/
 
-  ConvolutionOpModel m_builtin(
-      ::tflite::ops::builtin::Register_CONVOLUTION_GENERIC_OPT(),
-      padded_input_tensor, filter_tensor, builtin_output_tensor, stride_width,
-      stride_height, builtin_padding, activation, dilation_width_factor,
-      dilation_height_factor, num_threads);
+  std::vector<int> builtin_output_shape;
+
+  if (groups == 1) {
+    ConvolutionOpModel m_builtin(
+        ::tflite::ops::builtin::Register_CONVOLUTION_GENERIC_OPT(),
+        padded_input_tensor, filter_tensor, builtin_output_tensor,
+        stride_width, stride_height, builtin_padding, activation,
+        dilation_width_factor, dilation_height_factor, num_threads);
+
+    m_builtin.SetInput(padded_input_data);
+    m_builtin.SetFilter(filters_data);
+    m_builtin.SetBias(bias_data);
+    m_builtin.Invoke();
+    builtin_output = m_builtin.GetOutput();
+    builtin_output_shape = m_builtin.GetOutputShape();
+  } else {
+    // A grouped convolution. As the built-in Conv2D op doesn't support
+    // groups, we have to simulate it with multiple Conv2D ops.
+
+    // Configure group-wise tensor shapes.
+    LceTensor<float> per_group_input_tensor(padded_input_tensor);
+    per_group_input_tensor.shape[3] = input_depth_per_group;
+    LceTensor<float> per_group_filter_tensor(filter_tensor);
+    per_group_filter_tensor.shape[0] = output_depth_per_group;
+    LceTensor<float> per_group_builtin_output_tensor;
+
+    // Compute the result for each group in sequence.
+    for (int group_id = 0; group_id < groups; group_id++) {
+      // Copy the input data for this Conv2D op into a temporary input tensor.
+      std::vector<float> per_group_input_data(padded_input_data.size() /
+                                              groups);
+      auto per_group_input_ptr = per_group_input_data.data();
+      for (int offset = 0; offset < padded_input_data.size();
+           offset += input_depth) {
+        std::memcpy(per_group_input_ptr,
+                    padded_input_data.data() + offset +
+                        group_id * input_depth_per_group,
+                    input_depth_per_group * sizeof(float));
+        per_group_input_ptr += input_depth_per_group;
+      }
+
+      // Copy the filter data for this Conv2D op into a temporary filter
+      // tensor.
+      const int num_filter_elems_per_group = output_depth_per_group *
+                                             filter_height * filter_width *
+                                             input_depth_per_group;
+      std::vector<float> per_group_filter_data(
+          filters_data.begin() + group_id * num_filter_elems_per_group,
+          filters_data.begin() + (group_id + 1) * num_filter_elems_per_group);
+
+      // Create and invoke the built-in Conv2D op.
+      ConvolutionOpModel m_builtin(
+          ::tflite::ops::builtin::Register_CONVOLUTION_GENERIC_OPT(),
+          per_group_input_tensor, per_group_filter_tensor,
+          per_group_builtin_output_tensor, stride_width, stride_height,
+          builtin_padding, activation, dilation_width_factor,
+          dilation_height_factor, num_threads);
+      m_builtin.SetInput(per_group_input_data);
+      m_builtin.SetFilter(per_group_filter_data);
+      m_builtin.SetBias(bias_data);
+      m_builtin.Invoke();
+      auto per_group_builtin_output = m_builtin.GetOutput();
+
+      if (group_id == 0) {
+        builtin_output.resize(per_group_builtin_output.size() * groups);
+        builtin_output_shape = m_builtin.GetOutputShape();
+        builtin_output_shape.at(3) *= groups;
+      }
 
-  m_builtin.SetInput(padded_input_data);
-  m_builtin.SetFilter(filters_data);
-  m_builtin.SetBias(bias_data);
-  m_builtin.Invoke();
-  auto builtin_output = m_builtin.GetOutput();
+      // Copy the temporary output into the correct place in the overall
+      // builtin output tensor.
+      auto builtin_output_ptr =
+          builtin_output.data() + group_id * output_depth_per_group;
+      for (int offset = 0; offset < per_group_builtin_output.size();
+           offset += output_depth_per_group) {
+        std::memcpy(builtin_output_ptr,
+                    per_group_builtin_output.data() + offset,
+                    output_depth_per_group * sizeof(float));
+        builtin_output_ptr += filter_count;
+      }
+    }
+  }
 
   // Apply the post multiply and add to the TFLite model.
   // We cannot fuse it into the tflite bias because it should happen *after*
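
The grouped branch above emulates one grouped convolution with `groups` ordinary Conv2D runs: slice out each group's input channels, run the built-in op, then interleave each group's output_depth_per_group channels back into the full filter_count-channel output. The channel slice is the reusable piece; a compact sketch of it (hypothetical helper, same memcpy pattern as the loop above):

    #include <cstring>
    #include <vector>

    // Extract group g's channel slice from an NHWC buffer whose innermost
    // dimension is `depth`; `per_group` channels are kept per position.
    std::vector<float> SliceChannels(const std::vector<float>& src, int depth,
                                     int per_group, int g) {
      std::vector<float> dst(src.size() / (depth / per_group));
      float* out = dst.data();
      for (std::size_t offset = 0; offset < src.size(); offset += depth) {
        std::memcpy(out, src.data() + offset + g * per_group,
                    per_group * sizeof(float));
        out += per_group;
      }
      return dst;
    }

The output interleave is the mirror image: per spatial position, copy output_depth_per_group floats into the full buffer at channel offset group_id * output_depth_per_group, advancing the destination by filter_count each time.
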
@@ -665,9 +763,10 @@ void runTest(const TestParam& param) {
   m_lce.Invoke();
   const bool use_high_error_tolerance =
       filter_height * filter_width * input_depth > 1 << 12;
-  test_lce_op_output(m_lce.GetOutput(), m_builtin.GetOutputShape(),
-                     builtin_output, output_tensor.zero_point,
-                     output_tensor.scale, use_high_error_tolerance);
+  auto lce_output = m_lce.GetOutput();
+  test_lce_op_output(lce_output, builtin_output_shape, builtin_output,
+                     output_tensor.zero_point, output_tensor.scale,
+                     use_high_error_tolerance);
 }
 
 TEST_P(BConv2DOpTest, BitpackedOutput) {
@@ -695,6 +794,7 @@ INSTANTIATE_TEST_SUITE_P(
                std::array<int, 4>{1, 4, 4, 256}),  // input shape [BHWI]
         Values(std::array<int, 3>{1, 1, 1}, std::array<int, 3>{3, 3, 4},
                std::array<int, 3>{3, 3, 64}),  // filter shape [HWO]
+        Values(1, 2),                          // number of groups
         Values(std::array<int, 2>{1, 1}),      // strides height/width
         Values(std::array<int, 2>{1, 1}),      // dilation height/width
         Values(Padding_VALID, Padding_ONE),    // padding
@@ -712,6 +812,7 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Combine(
         Values(std::array<int, 4>{1, 6, 6, 3072}),  // input shape [BHWI]
         Values(std::array<int, 3>{5, 5, 4}),        // filter shape [HWO]
+        Values(1),                                  // number of groups
         Values(std::array<int, 2>{1, 1}),           // strides height/width
         Values(std::array<int, 2>{1, 1}),           // dilation height/width
         Values(Padding_VALID, Padding_ONE),         // padding
@@ -729,15 +830,17 @@ INSTANTIATE_TEST_SUITE_P(
     BigTest, BConv2DOpTest,
     ::testing::Combine(
         Values(std::array<int, 4>{1, 7, 7, 4}, std::array<int, 4>{3, 8, 5, 64},
-               std::array<int, 4>{10, 7, 7, 96},
+               std::array<int, 4>{5, 7, 7, 96},
               std::array<int, 4>{1, 8, 5, 128}, std::array<int, 4>{1, 7, 7, 192},
-               std::array<int, 4>{1, 8, 5, 256}),  // input shape [BHWI]
+               std::array<int, 4>{1, 8, 5, 256},
+               std::array<int, 4>{1, 7, 7, 512}),  // input shape [BHWI]
         Values(std::array<int, 3>{1, 1, 1}, std::array<int, 3>{3, 3, 1},
               std::array<int, 3>{2, 3, 2}, std::array<int, 3>{1, 1, 3},
               std::array<int, 3>{3, 3, 4}, std::array<int, 3>{2, 3, 5},
               std::array<int, 3>{1, 1, 6}, std::array<int, 3>{3, 3, 7},
               std::array<int, 3>{2, 3, 32}),  // filter shape [HWO]
+        Values(1, 2, 4),                       // number of groups
         Values(std::array<int, 2>{1, 1},
               std::array<int, 2>{2, 3}),  // strides height/width
         Values(std::array<int, 2>{1, 1},