Review suggestions.
AdamHillier committed Jun 4, 2021
1 parent 1700bf4 · commit b4d3580
Showing 9 changed files with 37 additions and 20 deletions.
@@ -22,7 +22,7 @@ inline void BConv2DOptimizedIndirectBGEMM(
// output tensor with zeroes in advance so that the BGEMM doesn't have to
// worry about doing the padding.
if (std::is_same<DstScalar, TBitpacked>::value &&
kernel->output_channels % bitpacking_bitwidth != 0) {
(kernel->output_channels % bitpacking_bitwidth != 0)) {
std::fill(
output_ptr,
output_ptr + kernel->num_output_pixels *
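For context on when this branch fires (illustration only, not part of the commit): with the 32-bit TBitpacked implied by the sizeof(TBitpacked) == 4 assertions later in this diff, an output channel count that is not a multiple of the bitwidth leaves unused bits in the last packed word of every output pixel, which is why the whole output is zero-filled up front, as the comment above explains. The numbers below are hypothetical.

// Illustration only (hypothetical numbers): 70 output channels pack into
// CeilDiv(70, 32) == 3 words per pixel, and the last word holds just 6 valid
// bits, so 70 % 32 != 0 and the zero-fill branch above is taken.
constexpr int kBitwidth = 32;        // assumed bitpacking_bitwidth
constexpr int kOutputChannels = 70;  // hypothetical channel count
static_assert((kOutputChannels + kBitwidth - 1) / kBitwidth == 3, "");
static_assert(kOutputChannels % kBitwidth != 0, "zero-fill branch is taken");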
2 changes: 1 addition & 1 deletion larq_compute_engine/core/bitpacking/bitpack.h
@@ -22,7 +22,7 @@ namespace bitpacking {
// Utility functions

constexpr int GetBitpackedSize(int unpacked_elements) {
return (unpacked_elements + bitpacking_bitwidth - 1) / bitpacking_bitwidth;
return CeilDiv(unpacked_elements, bitpacking_bitwidth);
}

constexpr int GetBitpackedMatrixSize(int rows, int cols) {
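As a quick check of this refactor (illustration only, using a local stand-in for the CeilDiv added to larq_compute_engine/core/types.h in this commit), the new expression agrees with the one it replaces:

// Illustration only: a local copy of CeilDiv to show it matches the removed
// `(unpacked_elements + bitpacking_bitwidth - 1) / bitpacking_bitwidth`.
constexpr int CeilDivSketch(int a, int b) { return (a + b - 1) / b; }
static_assert(CeilDivSketch(65, 32) == (65 + 32 - 1) / 32, "same result");
static_assert(CeilDivSketch(65, 32) == 3, "65 values need 3 packed words");
static_assert(CeilDivSketch(64, 32) == 2, "exact multiple: no extra word");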
14 changes: 5 additions & 9 deletions larq_compute_engine/core/indirect_bgemm/kernel.h
@@ -13,7 +13,8 @@ namespace compute_engine {
namespace core {
namespace indirect_bgemm {

struct Kernel {
class Kernel {
public:
const std::int32_t block_size_output_channels;
const std::int32_t block_size_pixels;
const std::int32_t block_size_depth;
@@ -54,9 +55,7 @@ struct Kernel {
const std::int32_t input_depth_per_group = input_depth / groups;
const std::int32_t output_channels_per_group = output_channels / groups;
const std::int32_t rounded_up_output_channels_per_group =
block_size_output_channels *
((output_channels_per_group + block_size_output_channels - 1) /
block_size_output_channels);
Ceil(output_channels_per_group, block_size_output_channels);
packed_weights.resize(groups * rounded_up_output_channels_per_group *
filter_size * input_depth_per_group +
/* padding */ block_size_output_channels *
@@ -113,15 +112,12 @@ struct Kernel {
bconv2d_params->padding_values.height;
const std::int32_t input_padding_left =
bconv2d_params->padding_values.width;
const std::int32_t batch_size = bitpacked_input_shape.Dims(0);
const std::int32_t input_height = bitpacked_input_shape.Dims(1);
const std::int32_t input_width = bitpacked_input_shape.Dims(2);
const std::int32_t output_height = output_shape.Dims(1);
const std::int32_t output_width = output_shape.Dims(2);
const std::int32_t output_size = num_output_pixels;
const std::int32_t tiled_output_size =
block_size_pixels *
((output_size + block_size_pixels - 1) / block_size_pixels);
const std::int32_t tiled_output_size = Ceil(output_size, block_size_pixels);

// Create the indirection buffer with padding (+ block_size_pixels) and fill
// it with pointers to the first element of the input, so that the padding
@@ -156,7 +152,7 @@ struct Kernel {
const std::int32_t index = output_tile_start * filter_size +
kernel_index * block_size_pixels +
output_tile_offset;
if (0 <= input_x && input_x < input_width) {
if (FastBoundsCheck(input_x, input_width)) {
indirection_buffer.at(index) =
(input_ptr + (batch_index * input_height * input_width +
input_y * input_width + input_x) *
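Reading the index computation above, the indirection buffer appears to be laid out as filter_size groups of block_size_pixels input pointers per output tile; a sketch with hypothetical sizes (illustration only, the numbers are made up):

// Illustration only: entry position implied by
//   index = output_tile_start * filter_size
//         + kernel_index * block_size_pixels
//         + output_tile_offset;
// With block_size_pixels == 2 and a 3x3 filter (filter_size == 9), the second
// tile starts at output_tile_start == 2, so filter position 4, pixel 1 of the
// tile maps to entry 2 * 9 + 4 * 2 + 1 == 27 of the indirection buffer.
static_assert(2 * 9 + 4 * 2 + 1 == 27, "");

Together with the extra block_size_pixels of entries noted in the comment above (pre-filled with pointers to the first input element), this lets the final, partially filled tile be processed without reading out of bounds.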
6 changes: 4 additions & 2 deletions larq_compute_engine/core/indirect_bgemm/kernel_4x2_portable.h
@@ -20,13 +20,14 @@ namespace indirect_bgemm {
* A 4x2 C++ micro-kernel for float or int8 output.
*/
template <typename DstScalar>
struct Kernel4x2Portable : Kernel {
class Kernel4x2Portable : public Kernel {
static_assert(std::is_same<DstScalar, float>::value ||
std::is_same<DstScalar, std::int8_t>::value,
"");

const bconv2d::OutputTransform<DstScalar> output_transform;

public:
Kernel4x2Portable(const bconv2d::BConv2DParams* bconv2d_params,
const RuntimeShape& bitpacked_input_shape,
const RuntimeShape& output_shape,
@@ -161,9 +162,10 @@ struct Kernel4x2Portable : Kernel {
* A 4x2 C++ micro-kernel for bitpacked output.
*/
template <>
struct Kernel4x2Portable<TBitpacked> : Kernel {
class Kernel4x2Portable<TBitpacked> : public Kernel {
const bconv2d::OutputTransform<TBitpacked> output_transform;

public:
Kernel4x2Portable(
const bconv2d::BConv2DParams* bconv2d_params,
const RuntimeShape& bitpacked_input_shape,
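The struct-to-class switch here (and in the Aarch64 kernels below) makes output_transform private, since it is declared before public:, while the constructor stays public; the base also has to be spelled public Kernel because a class defaults to private inheritance. A minimal sketch of the access change, with placeholder names:

// Illustration only: default member access flips from public to private when
// `struct` becomes `class`.
class KernelSketch {
  int hidden_state;  // declared before `public:`, so private after the change
 public:
  explicit KernelSketch(int s) : hidden_state(s) {}  // still publicly callable
};
// KernelSketch k(1);     // fine
// k.hidden_state = 2;    // no longer compiles once KernelSketch is a class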
@@ -605,14 +605,15 @@ inline void OutputTransformAndLoadNextAndStore(
* An 8x4x1 Neon micro-kernel for float or int8 output on Aarch64.
*/
template <typename DstScalar, bool Depth2OrMore, bool IsGrouped>
struct Kernel8x4x1Aarch64 : Kernel {
class Kernel8x4x1Aarch64 : public Kernel {
static_assert(std::is_same<DstScalar, float>::value ||
std::is_same<DstScalar, std::int8_t>::value,
"");
static_assert(sizeof(TBitpacked) == 4, "");

const bconv2d::OutputTransform<DstScalar> output_transform;

public:
Kernel8x4x1Aarch64(
const bconv2d::BConv2DParams* bconv2d_params,
const RuntimeShape& bitpacked_input_shape,
@@ -656,14 +656,15 @@ inline void OutputTransformAndLoadNextAndStore(
* An 8x4x2 Neon micro-kernel for float or int8 output on Aarch64.
*/
template <typename DstScalar, bool Depth2OrMore, bool IsGrouped>
struct Kernel8x4x2Aarch64 : Kernel {
class Kernel8x4x2Aarch64 : public Kernel {
static_assert(std::is_same<DstScalar, float>::value ||
std::is_same<DstScalar, std::int8_t>::value,
"");
static_assert(sizeof(TBitpacked) == 4, "");

const bconv2d::OutputTransform<DstScalar> output_transform;

public:
Kernel8x4x2Aarch64(
const bconv2d::BConv2DParams* bconv2d_params,
const RuntimeShape& bitpacked_input_shape,
@@ -714,14 +714,15 @@ inline void OutputTransformAndLoadNextAndStore(
* An 8x4x4 Neon micro-kernel for float or int8 output on Aarch64.
*/
template <typename DstScalar, bool Depth2OrMore, bool IsGrouped>
struct Kernel8x4x4Aarch64 : Kernel {
class Kernel8x4x4Aarch64 : public Kernel {
static_assert(std::is_same<DstScalar, float>::value ||
std::is_same<DstScalar, std::int8_t>::value,
"");
static_assert(sizeof(TBitpacked) == 4, "");

const bconv2d::OutputTransform<DstScalar> output_transform;

public:
Kernel8x4x4Aarch64(
const bconv2d::BConv2DParams* bconv2d_params,
const RuntimeShape& bitpacked_input_shape,
17 changes: 17 additions & 0 deletions larq_compute_engine/core/types.h
@@ -31,6 +31,23 @@ inline int xor_popcount(const TBitpacked& a, const TBitpacked& b) {
return std::bitset<bitpacking_bitwidth>(a ^ b).count();
}

// Check that 0 <= index < limit using a single comparison, assuming
// that 0 <= limit if index is signed. Intended for use in performance
// critical contexts where 0 <= index < limit is almost always true.
inline bool FastBoundsCheck(const int index, const int limit) {
return LCE_LIKELY((unsigned)index < (unsigned)limit);
}

template <typename T, typename S>
constexpr T CeilDiv(T a, S b) {
return (a + b - 1) / b;
}

template <typename T, typename S>
constexpr T Ceil(T a, S b) {
return CeilDiv(a, b) * b;
}

} // namespace core
} // namespace compute_engine

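A small usage sketch of the three helpers added above (illustration only; it assumes the header shown in this diff can be included on its own):

#include "larq_compute_engine/core/types.h"

using compute_engine::core::Ceil;
using compute_engine::core::CeilDiv;
using compute_engine::core::FastBoundsCheck;

static_assert(CeilDiv(7, 4) == 2, "7 / 4 rounded up");
static_assert(Ceil(7, 4) == 8, "7 rounded up to the next multiple of 4");
static_assert(Ceil(8, 4) == 8, "exact multiples are left unchanged");

// FastBoundsCheck folds `0 <= index && index < limit` into one unsigned
// comparison: a negative index wraps to a huge unsigned value and fails.
//   FastBoundsCheck(5, 10)   -> true
//   FastBoundsCheck(-1, 10)  -> false
//   FastBoundsCheck(10, 10)  -> false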
7 changes: 3 additions & 4 deletions larq_compute_engine/mlir/transforms/prepare_tf.cc
@@ -16,6 +16,7 @@ namespace TFL {
namespace {

using compute_engine::core::bitpacking_bitwidth;
using compute_engine::core::CeilDiv;

// Prepare LCE operations in functions for subsequent legalization.
struct PrepareLCE : public PassWrapper<PrepareLCE, FunctionPass> {
@@ -107,10 +108,8 @@ bool IsSamePadding(Attribute paddings_attr, Value input, Value output,

return paddings.getValue<int>({0, 0}) == 0 &&
paddings.getValue<int>({0, 1}) == 0 &&
output_shape[1] ==
(input_shape[1] + stride_height - 1) / stride_height &&
output_shape[2] ==
(input_shape[2] + stride_width - 1) / stride_width &&
output_shape[1] == CeilDiv(input_shape[1], stride_height) &&
output_shape[2] == CeilDiv(input_shape[2], stride_width) &&
pad_height_left == pad_height / 2 &&
pad_height_right == (pad_height + 1) / 2 &&
pad_width_left == pad_width / 2 &&
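For reference, the relationship these rewritten conditions encode is the SAME-padding rule that each output spatial dimension is the input dimension divided by the stride, rounded up; a worked check (illustration only, hypothetical sizes):

// Illustration only: SAME padding gives output = ceil(input / stride).
static_assert((224 + 2 - 1) / 2 == 112, "CeilDiv(224, 2): stride 2 over 224");
static_assert((7 + 3 - 1) / 3 == 3, "CeilDiv(7, 3): stride 3 over 7");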
