Add grouped binary convolution support (3/3): indirect BGEMM kernel.
Add support for grouped binary convolutions to the optimised
indirect BGEMM kernel.
AdamHillier committed Nov 5, 2020
1 parent 1fe65da commit c1676a2
Showing 9 changed files with 764 additions and 503 deletions.
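For orientation, every kernel touched by this commit relies on the same grouped-convolution invariant: group g reads only its slice of the input channels and writes only its slice of the output channels, with both channel counts divisible by the number of groups. A minimal sketch of that slicing, using hypothetical names not taken from the diff:

#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the commit): the channel slice owned by
// one group, assuming channels_in and channels_out divide evenly by groups.
struct GroupSlice {
  std::int32_t input_begin, input_count;
  std::int32_t output_begin, output_count;
};

inline GroupSlice GetGroupSlice(std::int32_t group, std::int32_t groups,
                                std::int32_t channels_in,
                                std::int32_t channels_out) {
  assert(channels_in % groups == 0 && channels_out % groups == 0);
  const std::int32_t in_per_group = channels_in / groups;
  const std::int32_t out_per_group = channels_out / groups;
  return {group * in_per_group, in_per_group,
          group * out_per_group, out_per_group};
}

The three diffs below thread exactly this information (groups, per-group input depth, per-group output channels) from the BConv2D entry point down into the micro-kernels.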
28 changes: 21 additions & 7 deletions larq_compute_engine/core/bconv2d/optimized_indirect_bgemm.h
@@ -25,13 +25,26 @@ inline void BConv2DOptimizedIndirectBGEMM(

const std::int32_t conv_kernel_size =
conv_params->filter_height * conv_params->filter_width;
const std::int32_t bitpacked_input_channels = bitpacked_input_shape.Dims(3);
const std::int32_t output_size =
output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
const std::int32_t output_channels = conv_params->channels_out;
const int32_t groups = conv_params->groups;
const int32_t input_depth = bitpacked_input_shape.Dims(3);
const int32_t output_channels = conv_params->channels_out;

indirect_bgemm::RunKernel(kernel, conv_kernel_size, bitpacked_input_channels,
output_size, output_channels, output_transform,
// If writing bitpacked output with a channel count that isn't a multiple of
// 32 (i.e. where padding bits will be required in the output), fill the
// output tensor with zeroes in advance so that the BGEMM doesn't have to
// worry about doing the padding.
if (std::is_same<DstScalar, TBitpacked>::value &&
output_channels % bitpacking_bitwidth != 0) {
std::fill(output_data,
output_data +
output_size * bitpacking::GetBitpackedSize(output_channels),
TBitpacked(0));
}
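The fill length above relies on bitpacking::GetBitpackedSize rounding the channel count up to whole TBitpacked words; the assumed arithmetic, as a hypothetical standalone sketch:

#include <cstdint>

constexpr std::int32_t kBitwidthSketch = 32;  // matches bitpacking_bitwidth

// Assumed behaviour of bitpacking::GetBitpackedSize: the number of 32-bit
// words needed to hold `channels` one-bit values, rounded up.
constexpr std::int32_t BitpackedSizeSketch(std::int32_t channels) {
  return (channels + kBitwidthSketch - 1) / kBitwidthSketch;
}

// E.g. 40 output channels occupy two words per pixel, leaving 24 padding
// bits in the second word; those are the bits the std::fill zeroes here.
static_assert(BitpackedSizeSketch(40) == 2, "");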

indirect_bgemm::RunKernel(kernel, output_size, conv_kernel_size, groups,
input_depth, output_channels, output_transform,
packed_weights, indirection_buffer, output_data);

if (std::is_same<DstScalar, float>::value &&
@@ -44,7 +57,8 @@ inline void BConv2DOptimizedIndirectBGEMM(
const int dilation_width_factor = conv_params->dilation_width_factor;
const int dilation_height_factor = conv_params->dilation_height_factor;
const int batches = MatchingDim(bitpacked_input_shape, 0, output_shape, 0);
const int input_depth = conv_params->channels_in;
const int input_depth_per_group =
conv_params->channels_in / conv_params->groups;
const int input_width = bitpacked_input_shape.Dims(2);
const int input_height = bitpacked_input_shape.Dims(1);
const int filter_height = conv_params->filter_height;
@@ -54,8 +68,8 @@
const int output_height = output_shape.Dims(1);

zero_padding_correction::ApplyCorrection(
batches, input_height, input_width, input_depth, filter_height,
filter_width, output_depth, stride_height, stride_width,
batches, input_height, input_width, input_depth_per_group,
filter_height, filter_width, output_depth, stride_height, stride_width,
dilation_height_factor, dilation_width_factor,
reinterpret_cast<float*>(output_data), output_height, output_width,
padding_buffer);
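Passing input_depth_per_group rather than the full input depth reflects that, under grouping, each output value accumulates over only its group's slice of the input channels, so the correction must scale with the per-group depth. A worked check with hypothetical numbers:

// Hypothetical example: a 3x3 grouped binary convolution with 64 input
// channels split into 4 groups. Each output value accumulates over
// filter_height * filter_width * (channels_in / groups) input bits, which
// is the depth the zero-padding correction needs.
constexpr int kChannelsIn = 64, kGroups = 4;
constexpr int kFilterHeight = 3, kFilterWidth = 3;
constexpr int kInputDepthPerGroup = kChannelsIn / kGroups;  // 16
constexpr int kInputsPerOutput =
    kFilterHeight * kFilterWidth * kInputDepthPerGroup;  // 144
static_assert(kInputsPerOutput == 144, "");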
23 changes: 15 additions & 8 deletions larq_compute_engine/core/indirect_bgemm/kernel.h
@@ -28,6 +28,7 @@ template <typename DstScalar>
struct IndirectBGEMMKernel {
using MicroKernelFunction = void(const std::int32_t, const std::int32_t,
const std::int32_t, const std::int32_t,
const std::int32_t,
const bconv2d::OutputTransform<DstScalar>&,
const TBitpacked*, const TBitpacked**,
DstScalar*);
@@ -70,20 +71,23 @@ inline IndirectBGEMMKernel<DstScalar> SelectRuntimeKernel(
typename std::conditional<is_float_or_int8, DstScalar, float>::type;
using KFn = typename IndirectBGEMMKernel<DstScalar>::MicroKernelFunction;

if (bitpacked_input_shape.Dims(3) % 4 == 0) {
if (bitpacked_input_shape.Dims(3) > 4) {
const std::int32_t input_depth_per_group =
bitpacked_input_shape.Dims(3) / conv_params->groups;

if (input_depth_per_group % 4 == 0) {
if (input_depth_per_group > 4) {
return {(KFn*)&kernel_8x4x4_aarch64::RunKernel<DS, true>, 8, 4, 4};
} else {
return {(KFn*)&kernel_8x4x4_aarch64::RunKernel<DS, false>, 8, 4, 4};
}
} else if (bitpacked_input_shape.Dims(3) % 2 == 0) {
if (bitpacked_input_shape.Dims(3) > 2) {
} else if (input_depth_per_group % 2 == 0) {
if (input_depth_per_group > 2) {
return {(KFn*)&kernel_8x4x2_aarch64::RunKernel<DS, true>, 8, 4, 2};
} else {
return {(KFn*)&kernel_8x4x2_aarch64::RunKernel<DS, false>, 8, 4, 2};
}
} else {
if (bitpacked_input_shape.Dims(3) > 1) {
if (input_depth_per_group > 1) {
return {(KFn*)&kernel_8x4x1_aarch64::RunKernel<DS, true>, 8, 4, 1};
} else {
return {(KFn*)&kernel_8x4x1_aarch64::RunKernel<DS, false>, 8, 4, 1};
@@ -98,13 +102,16 @@ inline IndirectBGEMMKernel<DstScalar> SelectRuntimeKernel(

template <typename DstScalar>
void RunKernel(const IndirectBGEMMKernel<DstScalar>& kernel,
const std::int32_t conv_kernel_size,
const std::int32_t bitpacked_input_channels,
const std::int32_t output_size,
const std::int32_t conv_kernel_size, const std::int32_t groups,
const std::int32_t input_depth,
const std::int32_t output_channels,
const bconv2d::OutputTransform<DstScalar>& output_transform,
const TBitpacked* packed_weights_ptr,
const TBitpacked** indirection_buffer, DstScalar* output_ptr) {
TFLITE_DCHECK_EQ(input_depth % groups, 0);
TFLITE_DCHECK_EQ((input_depth / groups) % kernel.block_size_depth, 0);

// TODO: implement multithreading here.
for (std::int32_t pixel_start = 0; pixel_start < output_size;
pixel_start += kernel.block_size_pixels) {
@@ -114,7 +121,7 @@ void RunKernel(const IndirectBGEMMKernel<DstScalar>& kernel,
: output_channels;
kernel.micro_kernel_function(
std::min(output_size - pixel_start, kernel.block_size_pixels),
conv_kernel_size, bitpacked_input_channels, output_channels,
conv_kernel_size, groups, input_depth, output_channels,
output_transform, packed_weights_ptr,
indirection_buffer + pixel_start * conv_kernel_size,
output_ptr + pixel_start * output_stride);
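Two details of the dispatch above are worth spelling out. Kernel selection now keys on the per-group bitpacked depth (Dims(3) / groups), since that is the depth one micro-kernel pass actually consumes; and RunKernel tiles the flattened output pixels in blocks of kernel.block_size_pixels, with a possibly short final block that the micro-kernels must tolerate. A standalone sketch of the blocking pattern, with invented names:

#include <algorithm>
#include <cstdint>

// Sketch (hypothetical names) of the pixel-blocking loop in
// indirect_bgemm::RunKernel: each call handles at most block_size_pixels
// output pixels, and the final call may receive fewer.
template <typename ProcessBlockFn>
void ForEachPixelBlock(std::int32_t output_size,
                       std::int32_t block_size_pixels,
                       ProcessBlockFn&& process_block) {
  for (std::int32_t pixel_start = 0; pixel_start < output_size;
       pixel_start += block_size_pixels) {
    process_block(pixel_start,
                  std::min(output_size - pixel_start, block_size_pixels));
  }
}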
139 changes: 95 additions & 44 deletions larq_compute_engine/core/indirect_bgemm/kernel_4x2_portable.h
@@ -20,8 +20,9 @@ namespace kernel_4x2_portable {
*/
template <typename DstScalar>
void RunKernel(const std::int32_t block_num_pixels,
const std::int32_t conv_kernel_size,
const std::int32_t channels_in, const std::int32_t channels_out,
const std::int32_t conv_kernel_size, const std::int32_t groups,
const std::int32_t input_depth,
const std::int32_t output_channels,
const bconv2d::OutputTransform<DstScalar>& output_transform,
const TBitpacked* weights_ptr,
const TBitpacked** indirection_buffer, DstScalar* output_ptr) {
@@ -34,11 +35,17 @@ void RunKernel(const std::int32_t block_num_pixels,
TFLITE_DCHECK_GE(block_num_pixels, 1);
TFLITE_DCHECK_LE(block_num_pixels, 2);
TFLITE_DCHECK_GE(conv_kernel_size, 1);
TFLITE_DCHECK_GE(channels_in, 1);
TFLITE_DCHECK_GE(channels_out, 1);
TFLITE_DCHECK_GE(groups, 1);
TFLITE_DCHECK_GE(input_depth, 1);
TFLITE_DCHECK_GE(output_channels, 1);
TFLITE_DCHECK_EQ(input_depth % groups, 0);
TFLITE_DCHECK_EQ(output_channels % groups, 0);

const std::int32_t input_depth_per_group = input_depth / groups;
const std::int32_t output_channels_per_group = output_channels / groups;

DstScalar* output_ptr_0 = output_ptr;
DstScalar* output_ptr_1 = output_ptr + channels_out;
DstScalar* output_ptr_1 = output_ptr + output_channels;

// At the end of the output array we might get a block where the number of
// pixels is less than 2, if the overall output size is not a multiple of 2.
@@ -51,6 +58,9 @@ void RunKernel(const std::int32_t block_num_pixels,
output_ptr_1 = output_ptr_0;
}

std::int32_t input_depth_offset = 0;
std::int32_t group_end_output_channel = output_channels_per_group;

std::int32_t c_out_index = 0;
do {
// Accumulators
@@ -61,11 +71,13 @@

std::int32_t k_size_index = conv_kernel_size;
do {
const TBitpacked* activations_ptr_0 = indirection_buffer[0];
const TBitpacked* activations_ptr_1 = indirection_buffer[1];
const TBitpacked* activations_ptr_0 =
indirection_buffer[0] + input_depth_offset;
const TBitpacked* activations_ptr_1 =
indirection_buffer[1] + input_depth_offset;
indirection_buffer += 2;

std::int32_t c_in_index = channels_in;
std::int32_t c_in_index = input_depth_per_group;
do {
const TBitpacked w_0 = weights_ptr[0];
const TBitpacked w_1 = weights_ptr[1];
@@ -87,7 +99,7 @@
} while (--c_in_index > 0);
} while (--k_size_index > 0);

if (LCE_LIKELY(channels_out - c_out_index >= 4)) {
if (LCE_LIKELY(group_end_output_channel - c_out_index >= 4)) {
output_ptr_1[0] = output_transform.Run(acc_01, c_out_index);
output_ptr_1[1] = output_transform.Run(acc_11, c_out_index + 1);
output_ptr_1[2] = output_transform.Run(acc_21, c_out_index + 2);
@@ -99,10 +111,9 @@
output_ptr_0[3] = output_transform.Run(acc_30, c_out_index + 3);
output_ptr_0 += 4;

indirection_buffer -= 2 * conv_kernel_size;
c_out_index += 4;
} else {
if (channels_out - c_out_index >= 2) {
if (group_end_output_channel - c_out_index >= 2) {
output_ptr_1[0] = output_transform.Run(acc_01, c_out_index);
output_ptr_1[1] = output_transform.Run(acc_11, c_out_index + 1);
output_ptr_1 += 2;
@@ -114,14 +125,23 @@
acc_00 = acc_20;
c_out_index += 2;
}
if (channels_out - c_out_index >= 1) {
if (group_end_output_channel - c_out_index >= 1) {
output_ptr_1[0] = output_transform.Run(acc_01, c_out_index);
output_ptr_1 += 1;
output_ptr_0[0] = output_transform.Run(acc_00, c_out_index);
output_ptr_0 += 1;

c_out_index += 1;
}
}

c_out_index = channels_out;
indirection_buffer -= 2 * conv_kernel_size;

if (c_out_index == group_end_output_channel) {
input_depth_offset += input_depth_per_group;
group_end_output_channel += output_channels_per_group;
}
} while (c_out_index < channels_out);
} while (c_out_index < output_channels);
}
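Stripped of the accumulation and output-transform work, the grouping logic added to this kernel is a small state machine: input_depth_offset marks where the current group's input slice starts, group_end_output_channel marks where its output slice ends, and both advance when the output-channel index crosses a group boundary. A minimal sketch of just that walk (hypothetical function; single-channel steps instead of the kernel's blocks of four):

#include <cstdint>

// Sketch of the per-group offset bookkeeping used by the kernel above.
void WalkGroupedOutputChannels(std::int32_t output_channels,
                               std::int32_t groups,
                               std::int32_t input_depth) {
  const std::int32_t input_depth_per_group = input_depth / groups;
  const std::int32_t output_channels_per_group = output_channels / groups;
  std::int32_t input_depth_offset = 0;
  std::int32_t group_end_output_channel = output_channels_per_group;
  std::int32_t c_out_index = 0;
  while (c_out_index < output_channels) {
    // Accumulate output channel c_out_index over input channels
    // [input_depth_offset, input_depth_offset + input_depth_per_group).
    c_out_index += 1;  // the real kernel advances up to 4 channels at once
    if (c_out_index == group_end_output_channel) {
      input_depth_offset += input_depth_per_group;
      group_end_output_channel += output_channels_per_group;
    }
  }
}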

/**
@@ -130,7 +150,8 @@ void RunKernel(const std::int32_t block_num_pixels,
template <>
void RunKernel<TBitpacked>(
const std::int32_t block_num_pixels, const std::int32_t conv_kernel_size,
const std::int32_t channels_in, const std::int32_t channels_out,
const std::int32_t groups, const std::int32_t input_depth,
const std::int32_t output_channels,
const bconv2d::OutputTransform<TBitpacked>& output_transform,
const TBitpacked* weights_ptr, const TBitpacked** indirection_buffer,
TBitpacked* output_ptr) {
@@ -139,12 +160,18 @@ void RunKernel<TBitpacked>(
TFLITE_DCHECK_GE(block_num_pixels, 1);
TFLITE_DCHECK_LE(block_num_pixels, 2);
TFLITE_DCHECK_GE(conv_kernel_size, 1);
TFLITE_DCHECK_GE(channels_in, 1);
TFLITE_DCHECK_GE(channels_out, 1);
TFLITE_DCHECK_GE(groups, 1);
TFLITE_DCHECK_GE(input_depth, 1);
TFLITE_DCHECK_GE(output_channels, 1);
TFLITE_DCHECK_EQ(input_depth % groups, 0);
TFLITE_DCHECK_EQ(output_channels % groups, 0);

const std::int32_t input_depth_per_group = input_depth / groups;
const std::int32_t output_channels_per_group = output_channels / groups;

TBitpacked* output_ptr_0 = output_ptr;
TBitpacked* output_ptr_1 =
output_ptr + bitpacking::GetBitpackedSize(channels_out);
output_ptr + bitpacking::GetBitpackedSize(output_channels);

// At the end of the output array we might get a block where the number of
// pixels is less than 2, if the overall output size is not a multiple of 2.
@@ -161,6 +188,9 @@ void RunKernel<TBitpacked>(
// value when the columns are full.
TBitpacked output_col_0 = 0, output_col_1 = 0;

std::int32_t input_depth_offset = 0;
std::int32_t group_end_output_channel = output_channels_per_group;

std::int32_t c_out_index = 0;
do {
// Accumulators
@@ -171,11 +201,13 @@

std::int32_t k_size_index = conv_kernel_size;
do {
const TBitpacked* activations_ptr_0 = indirection_buffer[0];
const TBitpacked* activations_ptr_1 = indirection_buffer[1];
const TBitpacked* activations_ptr_0 =
indirection_buffer[0] + input_depth_offset;
const TBitpacked* activations_ptr_1 =
indirection_buffer[1] + input_depth_offset;
indirection_buffer += 2;

std::int32_t c_in_index = channels_in;
std::int32_t c_in_index = input_depth_per_group;
do {
const TBitpacked w_0 = weights_ptr[0];
const TBitpacked w_1 = weights_ptr[1];
@@ -197,44 +229,63 @@ void RunKernel<TBitpacked>(
} while (--c_in_index > 0);
} while (--k_size_index > 0);

// Correctness of the following section relies on the bitpacking bitwidth
// being 32.
static_assert(bitpacking_bitwidth == 32, "");

const int base_output_index = c_out_index % 16;
output_col_0 |= TBitpacked(output_transform.Run(acc_00, c_out_index))
<< (c_out_index % bitpacking_bitwidth);
<< base_output_index;
output_col_0 |= TBitpacked(output_transform.Run(acc_10, c_out_index + 1))
<< ((c_out_index + 1) % bitpacking_bitwidth);
<< base_output_index + 1;
output_col_0 |= TBitpacked(output_transform.Run(acc_20, c_out_index + 2))
<< ((c_out_index + 2) % bitpacking_bitwidth);
<< base_output_index + 2;
output_col_0 |= TBitpacked(output_transform.Run(acc_30, c_out_index + 3))
<< ((c_out_index + 3) % bitpacking_bitwidth);
<< base_output_index + 3;
output_col_1 |= TBitpacked(output_transform.Run(acc_01, c_out_index))
<< (c_out_index % bitpacking_bitwidth);
<< base_output_index;
output_col_1 |= TBitpacked(output_transform.Run(acc_11, c_out_index + 1))
<< ((c_out_index + 1) % bitpacking_bitwidth);
<< base_output_index + 1;
output_col_1 |= TBitpacked(output_transform.Run(acc_21, c_out_index + 2))
<< ((c_out_index + 2) % bitpacking_bitwidth);
<< base_output_index + 2;
output_col_1 |= TBitpacked(output_transform.Run(acc_31, c_out_index + 3))
<< ((c_out_index + 3) % bitpacking_bitwidth);
<< base_output_index + 3;

indirection_buffer -= 2 * conv_kernel_size;
c_out_index += 4;

// Write the bitpacked columns whenever they are full, or if we've computed
// the last output column value.
if (c_out_index % bitpacking_bitwidth == 0 || c_out_index >= channels_out) {
// If this is a 'leftover output channel' block (because the number of
// output channels isn't a multiple of four) then zero-out the extra bits.
if (c_out_index % bitpacking_bitwidth != 0) {

if (group_end_output_channel - c_out_index > 4) {
c_out_index += 4;
} else {
const int gap_to_group_end = group_end_output_channel - c_out_index;
if (gap_to_group_end < 4) {
output_col_0 &=
(TBitpacked(1) << (channels_out % bitpacking_bitwidth)) - 1;
(TBitpacked(1) << base_output_index + gap_to_group_end) - 1;
output_col_1 &=
(TBitpacked(1) << (channels_out % bitpacking_bitwidth)) - 1;
(TBitpacked(1) << base_output_index + gap_to_group_end) - 1;
}
c_out_index = group_end_output_channel;
input_depth_offset += input_depth_per_group;
group_end_output_channel += output_channels_per_group;
}

*output_ptr_1++ = output_col_1;
output_col_1 = 0;
*output_ptr_0++ = output_col_0;
output_col_0 = 0;
// If on the next iteration we will have 'wrapped around' the output
// columns, write the bottom halves to the output array.
if (c_out_index % 16 < base_output_index) {
*((std::int16_t*)output_ptr_1) = (std::int16_t)output_col_1;
*((std::int16_t*)output_ptr_0) = (std::int16_t)output_col_0;
output_col_1 >>= 16;
output_col_0 >>= 16;
output_ptr_1 = (TBitpacked*)(((std::int16_t*)output_ptr_1) + 1);
output_ptr_0 = (TBitpacked*)(((std::int16_t*)output_ptr_0) + 1);
}
} while (c_out_index < channels_out);
} while (c_out_index < output_channels);

// If we've got to the end and there are still un-written bits, make sure they
// get written now.
if (output_channels % 16 > 0) {
*((std::int16_t*)output_ptr_1) = (std::int16_t)output_col_1;
*((std::int16_t*)output_ptr_0) = (std::int16_t)output_col_0;
}
}

} // namespace kernel_4x2_portable
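With groups, the bitpacked specialisation above can no longer write one full word every 32 output channels, because a group boundary can leave c_out_index at an arbitrary offset within a word. Instead it accumulates into a 32-bit column and flushes complete 16-bit halves as the channel index wraps past them, shifting any spilled bits down for the next half. A simplified single-column sketch of that scheme (hypothetical helper; the real kernel packs four bits per step):

#include <cstdint>

// Simplified sketch of the 16-bit half-word flushing used above.
void PackBitsInHalves(const bool* bits, std::int32_t num_bits,
                      std::int16_t* out) {
  std::int32_t column = 0;  // pending bits; low half is flushed first
  for (std::int32_t i = 0; i < num_bits; ++i) {
    column |= std::int32_t(bits[i]) << (i % 16);
    if ((i + 1) % 16 == 0) {  // low half complete: write it out
      *out++ = (std::int16_t)column;
      column >>= 16;  // in the real kernel this preserves spilled bits
    }
  }
  if (num_bits % 16 != 0) {  // trailing bits, mirroring the epilogue above
    *out++ = (std::int16_t)column;
  }
}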
(Diffs for the remaining 6 changed files are not shown.)
