Add grouped binary convolution support (3/3): indirect BGEMM kernel. #551

Merged · 6 commits · Jun 7, 2021
larq_compute_engine/core/bconv2d/BUILD (1 change: 0 additions & 1 deletion)

@@ -65,7 +65,6 @@ cc_library(
     deps = [
         ":zero_padding_correction",
         "//larq_compute_engine/core/indirect_bgemm:kernels",
-        "//larq_compute_engine/core/indirect_bgemm:prepare",
         "@org_tensorflow//tensorflow/lite/kernels:cpu_backend_context",
         "@org_tensorflow//tensorflow/lite/kernels:cpu_backend_gemm",
         "@org_tensorflow//tensorflow/lite/kernels:padding",
larq_compute_engine/core/bconv2d/optimized_indirect_bgemm.h (41 changes: 20 additions & 21 deletions)

@@ -12,27 +12,25 @@ namespace bconv2d {

 template <typename AccumScalar, typename DstScalar>
 inline void BConv2DOptimizedIndirectBGEMM(
-    const indirect_bgemm::IndirectBGEMMKernel<DstScalar> kernel,
-    const BConv2DParams* bconv2d_params,
+    const indirect_bgemm::Kernel* kernel, const BConv2DParams* bconv2d_params,
     const RuntimeShape& bitpacked_input_shape, const RuntimeShape& output_shape,
-    const OutputTransform<DstScalar>& output_transform,
-    const TBitpacked* packed_weights, const TBitpacked** indirection_buffer,
-    DstScalar* output_data, const float* padding_buffer, const int pad_value) {
-  TF_LITE_ASSERT_EQ(bitpacked_input_shape.DimensionsCount(), 4);
-  TF_LITE_ASSERT_EQ(output_shape.DimensionsCount(), 4);
-
+    DstScalar* output_ptr, const float* padding_buffer, const int pad_value) {
   ruy::profiler::ScopeLabel label("BConv2D (optimized, indirect BGEMM)");

-  const std::int32_t conv_kernel_size =
-      bconv2d_params->filter_height * bconv2d_params->filter_width;
-  const std::int32_t bitpacked_input_channels = bitpacked_input_shape.Dims(3);
-  const std::int32_t output_size =
-      output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
-  const std::int32_t output_channels = bconv2d_params->channels_out;
+  // If writing bitpacked output with a channel count that isn't a multiple of
+  // 32 (i.e. where padding bits will be required in the output), fill the
+  // output tensor with zeroes in advance so that the BGEMM doesn't have to
+  // worry about doing the padding.
+  if (std::is_same<DstScalar, TBitpacked>::value &&
+      (kernel->output_channels % bitpacking_bitwidth != 0)) {
+    std::fill(
+        output_ptr,
+        output_ptr + kernel->num_output_pixels *
+                         bitpacking::GetBitpackedSize(kernel->output_channels),
+        TBitpacked(0));
+  }

-  indirect_bgemm::RunKernel(kernel, conv_kernel_size, bitpacked_input_channels,
-                            output_size, output_channels, output_transform,
-                            packed_weights, indirection_buffer, output_data);
+  kernel->Dispatch(reinterpret_cast<void*>(output_ptr));

   if (std::is_same<DstScalar, float>::value &&
       bconv2d_params->padding_type == TfLitePadding::kTfLitePaddingSame &&
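A note on the new pre-zeroing step above: when DstScalar is TBitpacked and the channel count isn't a multiple of the 32-bit word size, the last bitpacked word of every output pixel contains padding bits, and zero-filling the buffer up front means the kernels never have to mask them. Below is a minimal standalone sketch of that logic; bitpacking_bitwidth, TBitpacked, and GetBitpackedSize are local stand-ins mirroring the LCE definitions, not the library's own headers.

#include <algorithm>
#include <cstdint>
#include <vector>

// Local stand-ins mirroring the LCE definitions (32-bit packing words).
constexpr int bitpacking_bitwidth = 32;
using TBitpacked = std::int32_t;

// Ceiling division: a partial word of channels still needs a whole word.
constexpr int GetBitpackedSize(int unpacked_elements) {
  return (unpacked_elements + bitpacking_bitwidth - 1) / bitpacking_bitwidth;
}

int main() {
  const int output_channels = 40;   // not a multiple of 32 -> padding bits
  const int num_output_pixels = 6;  // batches * output_height * output_width

  // 40 channels pack into 2 words per pixel; the top 24 bits of each
  // second word are padding that the BGEMM kernel never writes.
  std::vector<TBitpacked> output(num_output_pixels *
                                 GetBitpackedSize(output_channels));

  // (A std::vector is already zero-initialized; the explicit fill below
  // mirrors what BConv2DOptimizedIndirectBGEMM does to a raw tensor
  // buffer, which is not.)
  if (output_channels % bitpacking_bitwidth != 0) {
    std::fill(output.begin(), output.end(), TBitpacked(0));
  }
  return 0;
}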
@@ -44,7 +42,8 @@ inline void BConv2DOptimizedIndirectBGEMM(
     const int dilation_width_factor = bconv2d_params->dilation_width_factor;
     const int dilation_height_factor = bconv2d_params->dilation_height_factor;
     const int batches = MatchingDim(bitpacked_input_shape, 0, output_shape, 0);
-    const int input_depth = bconv2d_params->channels_in;
+    const int input_depth_per_group =
+        bconv2d_params->channels_in / bconv2d_params->groups;
     const int input_width = bitpacked_input_shape.Dims(2);
     const int input_height = bitpacked_input_shape.Dims(1);
     const int filter_height = bconv2d_params->filter_height;

@@ -54,10 +53,10 @@ inline void BConv2DOptimizedIndirectBGEMM(
     const int output_height = output_shape.Dims(1);

     zero_padding_correction::ApplyCorrection(
-        batches, input_height, input_width, input_depth, filter_height,
-        filter_width, output_depth, stride_height, stride_width,
+        batches, input_height, input_width, input_depth_per_group,
+        filter_height, filter_width, output_depth, stride_height, stride_width,
         dilation_height_factor, dilation_width_factor,
-        reinterpret_cast<float*>(output_data), output_height, output_width,
+        reinterpret_cast<float*>(output_ptr), output_height, output_width,
         padding_buffer);
   }
 }
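The input_depth_per_group change above is the grouped-convolution fix this PR series is about: with groups > 1, each filter only sees channels_in / groups input channels, so the "same"-padding correction must scale by the per-group depth rather than the full input depth. A tiny sketch of the arithmetic, with a stand-in struct for the two relevant BConv2DParams fields:

#include <cassert>
#include <cstdint>

// Stand-in for the two BConv2DParams fields used here.
struct Params {
  std::int32_t channels_in;
  std::int32_t groups;
};

int main() {
  const Params p{/*channels_in=*/128, /*groups=*/4};
  assert(p.channels_in % p.groups == 0);

  // Each group's filters convolve over only 128 / 4 = 32 input channels,
  // so the zero-padding correction must use 32 here, not 128.
  const int input_depth_per_group = p.channels_in / p.groups;
  assert(input_depth_per_group == 32);
  return 0;
}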
larq_compute_engine/core/bitpacking/bitpack.h (2 changes: 1 addition & 1 deletion)

@@ -22,7 +22,7 @@ namespace bitpacking {

 // Utility functions

 constexpr int GetBitpackedSize(int unpacked_elements) {
-  return (unpacked_elements + bitpacking_bitwidth - 1) / bitpacking_bitwidth;
+  return CeilDiv(unpacked_elements, bitpacking_bitwidth);
 }

 constexpr int GetBitpackedMatrixSize(int rows, int cols) {
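The bitpack.h change is behavior-preserving: both forms compute the ceiling of unpacked_elements / bitpacking_bitwidth. A short self-contained check, assuming CeilDiv has the usual (a + b - 1) / b definition (presumably provided by the core types header, which is already in this target's deps):

// Self-contained equivalence check; CeilDiv here is a local stand-in with
// the usual definition for non-negative a and positive b.
constexpr int CeilDiv(int a, int b) { return (a + b - 1) / b; }

constexpr int bitpacking_bitwidth = 32;

constexpr int GetBitpackedSize(int unpacked_elements) {
  return CeilDiv(unpacked_elements, bitpacking_bitwidth);
}

// 32 channels fit exactly in one word; 33 spill into a second.
static_assert(GetBitpackedSize(32) == 1, "one full word");
static_assert(GetBitpackedSize(33) == 2, "partial word rounds up");
static_assert(GetBitpackedSize(64) == 2, "two full words");

int main() { return 0; }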
larq_compute_engine/core/indirect_bgemm/BUILD (20 changes: 6 additions & 14 deletions)

@@ -2,31 +2,23 @@ licenses(["notice"])  # Apache 2.0

 package(default_visibility = ["//visibility:public"])

-cc_library(
-    name = "prepare",
-    hdrs = [
-        "prepare.h",
-    ],
-    deps = [
-        "//larq_compute_engine/core:types",
-        "//larq_compute_engine/core/bconv2d:params",
-        "@org_tensorflow//tensorflow/lite/kernels/internal:types",
-    ],
-)
-
 cc_library(
     name = "kernels",
-    hdrs = [
-        "kernel.h",
+    srcs = [
         "kernel_4x2_portable.h",
         "kernel_8x4x1_aarch64.h",
         "kernel_8x4x2_aarch64.h",
         "kernel_8x4x4_aarch64.h",
     ],
+    hdrs = [
+        "kernel.h",
+        "select_kernel.h",
+    ],
     deps = [
         "//larq_compute_engine/core:types",
         "//larq_compute_engine/core/bconv2d:output_transform",
+        "//larq_compute_engine/core/bconv2d:params",
         "//larq_compute_engine/core/bitpacking:bitpack",
         "@org_tensorflow//tensorflow/lite/kernels/internal:types",
         "@ruy//ruy/profiler:instrumentation",
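This BUILD restructuring reflects the new kernel API used in optimized_indirect_bgemm.h: the separate :prepare target is gone, the per-architecture kernel implementations move to srcs, and the public surface becomes kernel.h plus the new select_kernel.h. The diff itself only shows that a kernel object carries output_channels and num_output_pixels and runs via Dispatch(void*); the sketch below illustrates that interface shape and is an illustrative stand-in, not the actual declarations from the PR.

// Illustrative stand-in for the interface shape implied by this diff,
// not the actual kernel.h from the PR.
namespace indirect_bgemm {

class Kernel {
 public:
  virtual ~Kernel() = default;

  // Runs the indirect BGEMM, writing the result to output_ptr. The caller
  // pre-zeroes the buffer when bitpacked output channels aren't a multiple
  // of 32 (see optimized_indirect_bgemm.h above).
  virtual void Dispatch(void* output_ptr) = 0;

  // State the BConv2D wrapper reads when deciding whether to pre-zero.
  int output_channels = 0;
  int num_output_pixels = 0;
};

}  // namespace indirect_bgemm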