Skip to content

Commit

Permalink
[CMSIS-NN] Re-use CPU Target Parser
Browse files Browse the repository at this point in the history
Previously, `CMSISNNFlags` was derived using logic specific to the external code generator. This commit instead converts the external code generator options into a `Target`, so that the shared CPU target parser can be re-used.
  • Loading branch information
Mousius committed Aug 7, 2022
1 parent 1ef595f commit bfed2a9
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 178 deletions.
26 changes: 17 additions & 9 deletions src/relay/backend/contrib/cmsisnn/buffer_size.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace relay {
namespace contrib {
namespace cmsisnn {

int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, int32_t input_n,
int Conv2dBufferSize(Target target, int32_t padding_w, int32_t padding_h, int32_t input_n,
int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
int32_t stride_w, int32_t stride_h, int32_t dilation_w, int32_t dilation_h,
int32_t filter_w, int32_t filter_h) {
Expand All @@ -37,18 +37,20 @@ int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, i
bool is1xN = (output_h == 1) && (input_h == 1) && (filter_h == 1) && (output_w % 4 == 0) &&
(input_n == 1) && (dilation_w == 1) && (dilation_h == 1);

bool has_mve = target->GetFeature<Bool>("has_mve").value_or(Bool(false));

if (is1x1) {
return 0;
}

if (is1xN) {
if (!flags.mve) {
if (!has_mve) {
return (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
}
return 0;
}

if (flags.mve) {
if (has_mve) {
int32_t col_length = input_c * filter_w * filter_h;
col_length = (col_length + 7) / 8;
return 4 * col_length * 8 * (int32_t)sizeof(int8_t);
Expand All @@ -58,21 +60,27 @@ int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, i
return 0;
}

int DepthwiseConv2dBufferSize(CMSISNNFlags flags, int32_t input_n, int32_t input_c,
int32_t output_c, int32_t filter_w, int32_t filter_h) {
int DepthwiseConv2dBufferSize(Target target, int32_t input_n, int32_t input_c, int32_t output_c,
int32_t filter_w, int32_t filter_h) {
bool has_mve = target->GetFeature<Bool>("has_mve").value_or(Bool(false));
bool has_dsp = target->GetFeature<Bool>("has_dsp").value_or(Bool(false));

if (input_c == output_c && input_n == 1) {
if (flags.mve) {
if (has_mve) {
return (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t) + 4;
}
if (flags.dsp) {
if (has_dsp) {
return (input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
}
}
return 0;
}

int AvgPoolBufferSize(CMSISNNFlags flags, int32_t input_c) {
if (flags.dsp && !flags.mve) {
int AvgPoolBufferSize(Target target, int32_t input_c) {
bool has_mve = target->GetFeature<Bool>("has_mve").value_or(Bool(false));
bool has_dsp = target->GetFeature<Bool>("has_dsp").value_or(Bool(false));

if (has_dsp && !has_mve) {
return (input_c * sizeof(int32_t));
}
return 0;
Expand Down
13 changes: 7 additions & 6 deletions src/relay/backend/contrib/cmsisnn/buffer_size.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace cmsisnn {
* See:
* https://github.com/ARM-software/CMSIS_5/blob/8c60448c0e1e50e426180b26db9bc31ddf774361/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L108-L127
*
* \param flags - CMSIS-NN feature flags
* \param target - CMSIS-NN Target
* \param padding_w - Width padding
* \param padding_h - Height padding
* \param input_n - Input batch size
Expand All @@ -54,7 +54,7 @@ namespace cmsisnn {
*
* \return Size of buffer to allocate for convolution
*/
int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, int32_t input_n,
int Conv2dBufferSize(Target target, int32_t padding_w, int32_t padding_h, int32_t input_n,
int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
int32_t stride_w, int32_t stride_h, int32_t dilation_w, int32_t dilation_h,
int32_t filter_w, int32_t filter_h);
Expand All @@ -64,7 +64,7 @@ int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, i
* See:
* https://github.com/ARM-software/CMSIS_5/blob/325443e52637b6c7eedbd160d238a6c462e89c9f/CMSIS/NN/Source/ConvolutionFunctions/arm_depthwise_conv_wrapper_s8.c#L115-L129
*
* \param flags - CMSIS-NN feature flags
* \param target - CMSIS-NN Target
* \param input_n - Input batch size
* \param input_c - Input channels
* \param output_c - Output channels
Expand All @@ -73,19 +73,20 @@ int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, i
*
* \return Size of buffer to allocate for depthwise convolution
*/
int DepthwiseConv2dBufferSize(CMSISNNFlags flags, int32_t input_n, int32_t input_c,
int32_t output_c, int32_t filter_w, int32_t filter_h);
int DepthwiseConv2dBufferSize(Target target, int32_t input_n, int32_t input_c, int32_t output_c,
int32_t filter_w, int32_t filter_h);

/*!
* \brief Calculates the appropriate buffer size for CMSIS-NN Average Pooling
* See:
* https://github.com/ARM-software/CMSIS_5/blob/bff28575f0c96a4ee9008947fea2b018a69b4900/CMSIS/NN/Source/PoolingFunctions/arm_avgpool_s8.c#L388-L398
*
* \param target - CMSIS-NN Target
* \param input_c - Input channels
*
* \return Size of buffer to allocate for average pooling
*/
int AvgPoolBufferSize(CMSISNNFlags flags, int32_t input_c);
int AvgPoolBufferSize(Target target, int32_t input_c);

} // namespace cmsisnn
} // namespace contrib
Expand Down
42 changes: 11 additions & 31 deletions src/relay/backend/contrib/cmsisnn/compiler_attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <tvm/ir/attrs.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>

#include <string>

Expand All @@ -28,46 +29,25 @@ namespace relay {
namespace contrib {
namespace cmsisnn {

static const char* mveCPUs[] = {"cortex-m55"};
static const char* dspCPUs[] = {"cortex-m55", "cortex-m4", "cortex-m7", "cortex-m33",
"cortex-m35p"};

TVM_REGISTER_NODE_TYPE(CMSISNNCompilerConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.cmsisnn.options", CMSISNNCompilerConfig);

template <typename Container>
static inline bool MatchesCpu(std::string mcpu, const Container& cpus) {
auto matches_cpu = [mcpu](const char* cpu) { return mcpu.find(cpu) == 0; };
return std::find_if(std::begin(cpus), std::end(cpus), matches_cpu) != std::end(cpus);
}

static inline bool HasFlag(std::string attr, std::string flag) {
return attr.find(flag) != std::string::npos;
}

CMSISNNFlags GetCompilerFlags(const tvm::transform::PassContext& ctx) {
Target CreateTarget(const tvm::transform::PassContext& ctx) {
auto cfg = ctx->GetConfig<CMSISNNCompilerConfig>("relay.ext.cmsisnn.options");
if (!cfg.defined()) {
return kNoExt;
return Target("cmsis-nn");
}

std::string mcpu = cfg.value()->mcpu;
std::string mattr = cfg.value()->mattr;
String mcpu = cfg.value()->mcpu;
Array<String> mattr = {cfg.value()->mattr};

bool nomve = HasFlag(mcpu, "+nomve") || HasFlag(mattr, "+nomve");
bool nodsp = HasFlag(mcpu, "+nodsp") || HasFlag(mattr, "+nodsp");

auto has_mve = MatchesCpu(mcpu, mveCPUs);
if (has_mve && !nomve && !nodsp) {
return kHasMVE;
}

auto has_dsp = MatchesCpu(mcpu, dspCPUs);
if (has_dsp && !nodsp) {
return kHasDSP;
}
Target cmsis_nn_target(TargetJSON{
{"kind", String("cmsis-nn")},
{"mcpu", mcpu},
{"mattr", mattr},
});

return kNoExt;
return cmsis_nn_target;
}

} // namespace cmsisnn
Expand Down
14 changes: 3 additions & 11 deletions src/relay/backend/contrib/cmsisnn/compiler_attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPILER_ATTRS_H_

#include <tvm/ir/transform.h>
#include <tvm/target/target.h>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -55,17 +56,8 @@ class CMSISNNCompilerConfig : public Attrs {
CMSISNNCompilerConfigNode);
};

/*! \brief Flags to configure the calculations for CMSIS-NN. */
struct CMSISNNFlags {
bool dsp; // Enable or disable dsp buffers
bool mve; // Enable or disable mve buffers
};

constexpr CMSISNNFlags kNoExt = {.dsp = false, .mve = false};
constexpr CMSISNNFlags kHasDSP = {.dsp = true, .mve = false};
constexpr CMSISNNFlags kHasMVE = {.dsp = true, .mve = true};

CMSISNNFlags GetCompilerFlags(const tvm::transform::PassContext& ctx);
/*! \brief Convert External Code Generator options to TVM Target. */
Target CreateTarget(const tvm::transform::PassContext& ctx);

} // namespace cmsisnn
} // namespace contrib
Expand Down
14 changes: 7 additions & 7 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,15 @@ class RelayToTIRVisitor : public MixedModeMutator {
call_ext_args.push_back(output);

PrimExpr context_buffer_var = tir::StringImm("NULL");
CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current());
Target target = CreateTarget(transform::PassContext::Current());
size_t context_buffer_size;
if (is_depthwise) {
context_buffer_size =
DepthwiseConv2dBufferSize(flags, input_n, input_c, output_c, filter_w, filter_h);
DepthwiseConv2dBufferSize(target, input_n, input_c, output_c, filter_w, filter_h);
} else {
context_buffer_size = Conv2dBufferSize(flags, padding_w, padding_h, input_n, input_h, input_c,
output_h, output_w, stride_w, stride_h, dilation_w,
dilation_h, filter_w, filter_h);
context_buffer_size = Conv2dBufferSize(target, padding_w, padding_h, input_n, input_h,
input_c, output_h, output_w, stride_w, stride_h,
dilation_w, dilation_h, filter_w, filter_h);
}

if (context_buffer_size) {
Expand Down Expand Up @@ -440,9 +440,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
int context_buffer_size = 0;
PrimExpr context_buffer_var = tir::StringImm("NULL");
if (pool_name == "cmsis-nn.qnn_avg_pool2d") {
CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current());
Target target = CreateTarget(transform::PassContext::Current());
int32_t input_c = qnn::get_const_int(input_shape[3]);
context_buffer_size = AvgPoolBufferSize(flags, input_c);
context_buffer_size = AvgPoolBufferSize(target, input_c);
if (context_buffer_size) {
std::string context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
context_buffer_var = tir::Var(context_buffer_name,
Expand Down
8 changes: 6 additions & 2 deletions src/relay/backend/contrib/cmsisnn/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
#include <tvm/relay/transform.h>
#include <tvm/target/target.h>

#include "../../../../target/parsers/cpu.h"
#include "compiler_attrs.h"

namespace tvm {

namespace relay {
Expand All @@ -31,10 +34,11 @@ tvm::transform::Pass RelayToTIR();
runtime::Module TIRToRuntime(IRModule mod, Target target);

TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
.add_attr_option<Array<String>>("mattr")
.add_attr_option<String>("mcpu")
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime)
.add_attr_option<Array<String>>("mattr")
.add_attr_option<String>("mcpu");
.set_target_parser(tvm::target::parsers::cpu::ParseTarget);

} // namespace cmsisnn
} // namespace contrib
Expand Down
31 changes: 18 additions & 13 deletions tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include <gtest/gtest.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>

#include <cmath>
#include <random>
Expand All @@ -39,12 +40,16 @@ static std::random_device rd;
static std::mt19937 gen(rd());
static std::uniform_int_distribution<> fake_parameters(2, 100);

static const Target kHasMVE("cmsis-nn -mcpu=cortex-m55");
static const Target kHasDSP("cmsis-nn -mcpu=cortex-m55 -mattr=+nomve");
static const Target kNoExt("cmsis-nn -mcpu=cortex-m55 -mattr=+nodsp,+nomve");

class CMSISNNCalculatedBufferSize : public testing::TestWithParam<std::array<int32_t, 3>> {};

TEST(CMSISNNConv2dBufferSize, Conv1x1) {
int32_t any = fake_parameters(gen);
auto conv2d_1x1 = [=](CMSISNNFlags flags, int32_t input_c) {
return Conv2dBufferSize(flags, 0, 0, any, any, input_c, any, any, 1, 1, 1, 1, 1, 1);
auto conv2d_1x1 = [=](Target target, int32_t input_c) {
return Conv2dBufferSize(target, 0, 0, any, any, input_c, any, any, 1, 1, 1, 1, 1, 1);
};

ASSERT_EQ(conv2d_1x1(kNoExt, 4), 0);
Expand Down Expand Up @@ -73,8 +78,8 @@ TEST(CMSISNNConv2dBufferSize, Conv1xN) {
int32_t filter_h = 1;
int32_t calculated_buffer = (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);

auto conv2d_1xn = [=](CMSISNNFlags flags, int32_t output_w) {
return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any, any, 1, 1, filter_w,
auto conv2d_1xn = [=](Target target, int32_t output_w) {
return Conv2dBufferSize(target, any, any, 1, 1, input_c, 1, output_w, any, any, 1, 1, filter_w,
filter_h);
};

Expand Down Expand Up @@ -108,8 +113,8 @@ TEST(CMSISNNConv2dBufferSize, Default) {
col_length = (col_length + 7) / 8;
int32_t calculated_buffer_mve = 4 * col_length * 8 * (int32_t)sizeof(int8_t);

auto conv2d = [=](CMSISNNFlags flags, int32_t output_w) {
return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any, any, any, any,
auto conv2d = [=](Target target, int32_t output_w) {
return Conv2dBufferSize(target, any, any, 1, 1, input_c, 1, output_w, any, any, any, any,
filter_w, filter_h);
};

Expand Down Expand Up @@ -137,8 +142,8 @@ TEST(CMSISNNDepthwiseConv2dBufferSize, UnEvenChannels) {
int32_t filter_h = fake_parameters(gen);
int32_t input_n = 1;

auto depthwise_conv2d_with_channels = [=](CMSISNNFlags flags, int32_t input_c, int32_t output_c) {
return DepthwiseConv2dBufferSize(flags, input_n, input_c, output_c, filter_w, filter_h);
auto depthwise_conv2d_with_channels = [=](Target target, int32_t input_c, int32_t output_c) {
return DepthwiseConv2dBufferSize(target, input_n, input_c, output_c, filter_w, filter_h);
};

ASSERT_EQ(depthwise_conv2d_with_channels(kNoExt, 4, 6), 0);
Expand All @@ -154,8 +159,8 @@ TEST(CMSISNNDepthwiseConv2dBufferSize, MultipleBatches) {
int32_t filter_w = fake_parameters(gen);
int32_t filter_h = fake_parameters(gen);

auto depthwise_conv2d_with_batch = [=](CMSISNNFlags flags, int32_t input_n) {
return DepthwiseConv2dBufferSize(flags, input_n, input_output_c, input_output_c, filter_w,
auto depthwise_conv2d_with_batch = [=](Target target, int32_t input_n) {
return DepthwiseConv2dBufferSize(target, input_n, input_output_c, input_output_c, filter_w,
filter_h);
};

Expand All @@ -177,8 +182,8 @@ TEST(CMSISNNDepthwiseConv2dBufferSize, Default) {
(2 * input_output_c * filter_w * filter_h) * (int32_t)sizeof(int16_t) + 4;
int32_t dsp_calculated_buffer = (input_output_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);

auto depthwise_conv2d = [=](CMSISNNFlags flags) {
return DepthwiseConv2dBufferSize(flags, input_n, input_output_c, input_output_c, filter_w,
auto depthwise_conv2d = [=](Target target) {
return DepthwiseConv2dBufferSize(target, input_n, input_output_c, input_output_c, filter_w,
filter_h);
};

Expand All @@ -194,7 +199,7 @@ TEST(CMSISNNAvgPoolBufferSize, Default) {
int32_t input_c = fake_parameters(gen);
int32_t calculated_buffer = (input_c * sizeof(int32_t));

auto avg_pool = [=](CMSISNNFlags flags) { return AvgPoolBufferSize(flags, input_c); };
auto avg_pool = [=](Target target) { return AvgPoolBufferSize(target, input_c); };

ASSERT_EQ(avg_pool(kNoExt), 0);
ASSERT_EQ(avg_pool(kHasDSP), calculated_buffer);
Expand Down
Loading

0 comments on commit bfed2a9

Please sign in to comment.