[CMSIS-NN] Re-use CPU Target Parser #12320

Merged: 1 commit, merged on Aug 18, 2022
26 changes: 17 additions & 9 deletions src/relay/backend/contrib/cmsisnn/buffer_size.cc
@@ -27,7 +27,7 @@ namespace relay {
namespace contrib {
namespace cmsisnn {

int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, int32_t input_n,
int Conv2dBufferSize(Target target, int32_t padding_w, int32_t padding_h, int32_t input_n,
int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
int32_t stride_w, int32_t stride_h, int32_t dilation_w, int32_t dilation_h,
int32_t filter_w, int32_t filter_h) {
@@ -37,18 +37,20 @@ int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, i
bool is1xN = (output_h == 1) && (input_h == 1) && (filter_h == 1) && (output_w % 4 == 0) &&
(input_n == 1) && (dilation_w == 1) && (dilation_h == 1);

bool has_mve = target->GetFeature<Bool>("has_mve").value_or(Bool(false));

if (is1x1) {
return 0;
}

if (is1xN) {
if (!flags.mve) {
if (!has_mve) {
return (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
}
return 0;
}

if (flags.mve) {
if (has_mve) {
int32_t col_length = input_c * filter_w * filter_h;
col_length = (col_length + 7) / 8;
return 4 * col_length * 8 * (int32_t)sizeof(int8_t);
@@ -58,21 +60,27 @@ int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, i
return 0;
}

int DepthwiseConv2dBufferSize(CMSISNNFlags flags, int32_t input_n, int32_t input_c,
int32_t output_c, int32_t filter_w, int32_t filter_h) {
int DepthwiseConv2dBufferSize(Target target, int32_t input_n, int32_t input_c, int32_t output_c,
int32_t filter_w, int32_t filter_h) {
bool has_mve = target->GetFeature<Bool>("has_mve").value_or(Bool(false));
bool has_dsp = target->GetFeature<Bool>("has_dsp").value_or(Bool(false));

if (input_c == output_c && input_n == 1) {
if (flags.mve) {
if (has_mve) {
return (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t) + 4;
}
if (flags.dsp) {
if (has_dsp) {
return (input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
}
}
return 0;
}

int AvgPoolBufferSize(CMSISNNFlags flags, int32_t input_c) {
if (flags.dsp && !flags.mve) {
int AvgPoolBufferSize(Target target, int32_t input_c) {
bool has_mve = target->GetFeature<Bool>("has_mve").value_or(Bool(false));
bool has_dsp = target->GetFeature<Bool>("has_dsp").value_or(Bool(false));

if (has_dsp && !has_mve) {
return (input_c * sizeof(int32_t));
}
return 0;
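For context, a minimal sketch of the feature-query pattern this file now uses. This is not part of the diff; the helper name is made up, the target string mirrors the ones defined in buffer_size_test.cc further down, and the buffer formula is copied from the depthwise branch above:

#include <cstdint>
#include <tvm/target/target.h>

using tvm::Bool;
using tvm::Target;

// Illustrative only: reproduces the depthwise buffer selection above, reading
// the "has_mve"/"has_dsp" features from a parsed Target instead of CMSISNNFlags.
int32_t SketchDepthwiseBufferSize(int32_t input_c, int32_t filter_w, int32_t filter_h) {
  // Relies on the "cmsis-nn" target kind registered in target.cc with the CPU parser.
  Target target("cmsis-nn -mcpu=cortex-m55 -mattr=+nomve");  // DSP but no MVE

  bool has_mve = target->GetFeature<Bool>("has_mve").value_or(Bool(false));
  bool has_dsp = target->GetFeature<Bool>("has_dsp").value_or(Bool(false));

  if (has_mve) {
    return (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t) + 4;
  }
  if (has_dsp) {
    return (input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
  }
  return 0;
}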
13 changes: 7 additions & 6 deletions src/relay/backend/contrib/cmsisnn/buffer_size.h
@@ -39,7 +39,7 @@ namespace cmsisnn {
* See:
* https://github.com/ARM-software/CMSIS_5/blob/8c60448c0e1e50e426180b26db9bc31ddf774361/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L108-L127
*
* \param flags - CMSIS-NN feature flags
* \param target - CMSIS-NN Target
* \param padding_w - Width padding
* \param padding_h - Height padding
* \param input_n - Input batch size
@@ -54,7 +54,7 @@ namespace cmsisnn {
*
* \return Size of buffer to allocate for convolution
*/
int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, int32_t input_n,
int Conv2dBufferSize(Target target, int32_t padding_w, int32_t padding_h, int32_t input_n,
int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
int32_t stride_w, int32_t stride_h, int32_t dilation_w, int32_t dilation_h,
int32_t filter_w, int32_t filter_h);
@@ -64,7 +64,7 @@ int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, i
* See:
* https://github.com/ARM-software/CMSIS_5/blob/325443e52637b6c7eedbd160d238a6c462e89c9f/CMSIS/NN/Source/ConvolutionFunctions/arm_depthwise_conv_wrapper_s8.c#L115-L129
*
* \param flags - CMSIS-NN feature flags
* \param target - CMSIS-NN Target
* \param input_n - Input batch size
* \param input_c - Input channels
* \param output_c - Output channels
@@ -73,19 +73,20 @@ int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, i
*
* \return Size of buffer to allocate for depthwise convolution
*/
int DepthwiseConv2dBufferSize(CMSISNNFlags flags, int32_t input_n, int32_t input_c,
int32_t output_c, int32_t filter_w, int32_t filter_h);
int DepthwiseConv2dBufferSize(Target target, int32_t input_n, int32_t input_c, int32_t output_c,
int32_t filter_w, int32_t filter_h);

/*!
* \brief Calculates the appropriate buffer size for CMSIS-NN Average Pooling
* See:
* https://github.com/ARM-software/CMSIS_5/blob/bff28575f0c96a4ee9008947fea2b018a69b4900/CMSIS/NN/Source/PoolingFunctions/arm_avgpool_s8.c#L388-L398
*
* \param target - CMSIS-NN Target
* \param input_c - Input channels
*
* \return Size of buffer to allocate for average pooling
*/
int AvgPoolBufferSize(CMSISNNFlags flags, int32_t input_c);
int AvgPoolBufferSize(Target target, int32_t input_c);

} // namespace cmsisnn
} // namespace contrib
42 changes: 11 additions & 31 deletions src/relay/backend/contrib/cmsisnn/compiler_attrs.cc
@@ -20,6 +20,7 @@

#include <tvm/ir/attrs.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>

#include <string>

@@ -28,46 +29,25 @@ namespace relay {
namespace contrib {
namespace cmsisnn {

static const char* mveCPUs[] = {"cortex-m55"};
static const char* dspCPUs[] = {"cortex-m55", "cortex-m4", "cortex-m7", "cortex-m33",
"cortex-m35p"};

TVM_REGISTER_NODE_TYPE(CMSISNNCompilerConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.cmsisnn.options", CMSISNNCompilerConfig);

template <typename Container>
static inline bool MatchesCpu(std::string mcpu, const Container& cpus) {
auto matches_cpu = [mcpu](const char* cpu) { return mcpu.find(cpu) == 0; };
return std::find_if(std::begin(cpus), std::end(cpus), matches_cpu) != std::end(cpus);
}

static inline bool HasFlag(std::string attr, std::string flag) {
return attr.find(flag) != std::string::npos;
}

CMSISNNFlags GetCompilerFlags(const tvm::transform::PassContext& ctx) {
Target CreateTarget(const tvm::transform::PassContext& ctx) {
auto cfg = ctx->GetConfig<CMSISNNCompilerConfig>("relay.ext.cmsisnn.options");
if (!cfg.defined()) {
return kNoExt;
return Target("cmsis-nn");
}

std::string mcpu = cfg.value()->mcpu;
std::string mattr = cfg.value()->mattr;
String mcpu = cfg.value()->mcpu;
Array<String> mattr = {cfg.value()->mattr};

bool nomve = HasFlag(mcpu, "+nomve") || HasFlag(mattr, "+nomve");
bool nodsp = HasFlag(mcpu, "+nodsp") || HasFlag(mattr, "+nodsp");

auto has_mve = MatchesCpu(mcpu, mveCPUs);
if (has_mve && !nomve && !nodsp) {
return kHasMVE;
}

auto has_dsp = MatchesCpu(mcpu, dspCPUs);
if (has_dsp && !nodsp) {
return kHasDSP;
}
Target cmsis_nn_target(TargetJSON{
{"kind", String("cmsis-nn")},
{"mcpu", mcpu},
{"mattr", mattr},
});

return kNoExt;
return cmsis_nn_target;
}

} // namespace cmsisnn
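As a rough illustration (an assumption about equivalence, not something the diff itself asserts): the TargetJSON built by CreateTarget and the plain string form used in the unit tests describe the same "cmsis-nn" target kind, so either construction yields a Target whose features can be queried downstream. Both helpers below are hypothetical and use example values:

#include <tvm/target/target.h>

using namespace tvm;

// Hypothetical helper: build the target the way CreateTarget does.
Target CmsisNNTargetFromJson() {
  return Target(TargetJSON{
      {"kind", String("cmsis-nn")},
      {"mcpu", String("cortex-m55")},
      {"mattr", Array<String>{"+nomve"}},
  });
}

// Hypothetical helper: build the equivalent target from its string form,
// as buffer_size_test.cc does.
Target CmsisNNTargetFromString() {
  return Target("cmsis-nn -mcpu=cortex-m55 -mattr=+nomve");
}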
14 changes: 3 additions & 11 deletions src/relay/backend/contrib/cmsisnn/compiler_attrs.h
@@ -26,6 +26,7 @@
#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPILER_ATTRS_H_

#include <tvm/ir/transform.h>
#include <tvm/target/target.h>

namespace tvm {
namespace relay {
@@ -55,17 +56,8 @@ class CMSISNNCompilerConfig : public Attrs {
CMSISNNCompilerConfigNode);
};

/*! \brief Flags to configure the calculations for CMSIS-NN. */
struct CMSISNNFlags {
bool dsp; // Enable or disable dsp buffers
bool mve; // Enable or disable mve buffers
};

constexpr CMSISNNFlags kNoExt = {.dsp = false, .mve = false};
constexpr CMSISNNFlags kHasDSP = {.dsp = true, .mve = false};
constexpr CMSISNNFlags kHasMVE = {.dsp = true, .mve = true};

CMSISNNFlags GetCompilerFlags(const tvm::transform::PassContext& ctx);
/*! \brief Convert External Code Generator options to TVM Target. */
Target CreateTarget(const tvm::transform::PassContext& ctx);

} // namespace cmsisnn
} // namespace contrib
14 changes: 7 additions & 7 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -239,15 +239,15 @@ class RelayToTIRVisitor : public MixedModeMutator {
call_ext_args.push_back(output);

PrimExpr context_buffer_var = tir::StringImm("NULL");
CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current());
Target target = CreateTarget(transform::PassContext::Current());
size_t context_buffer_size;
if (is_depthwise) {
context_buffer_size =
DepthwiseConv2dBufferSize(flags, input_n, input_c, output_c, filter_w, filter_h);
DepthwiseConv2dBufferSize(target, input_n, input_c, output_c, filter_w, filter_h);
} else {
context_buffer_size = Conv2dBufferSize(flags, padding_w, padding_h, input_n, input_h, input_c,
output_h, output_w, stride_w, stride_h, dilation_w,
dilation_h, filter_w, filter_h);
context_buffer_size = Conv2dBufferSize(target, padding_w, padding_h, input_n, input_h,
input_c, output_h, output_w, stride_w, stride_h,
dilation_w, dilation_h, filter_w, filter_h);
}

if (context_buffer_size) {
@@ -440,9 +440,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
int context_buffer_size = 0;
PrimExpr context_buffer_var = tir::StringImm("NULL");
if (pool_name == "cmsis-nn.qnn_avg_pool2d") {
CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current());
Target target = CreateTarget(transform::PassContext::Current());
int32_t input_c = qnn::get_const_int(input_shape[3]);
context_buffer_size = AvgPoolBufferSize(flags, input_c);
context_buffer_size = AvgPoolBufferSize(target, input_c);
if (context_buffer_size) {
std::string context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
context_buffer_var = tir::Var(context_buffer_name,
8 changes: 6 additions & 2 deletions src/relay/backend/contrib/cmsisnn/target.cc
@@ -21,6 +21,9 @@
#include <tvm/relay/transform.h>
#include <tvm/target/target.h>

#include "../../../../target/parsers/cpu.h"
#include "compiler_attrs.h"

namespace tvm {

namespace relay {
@@ -31,10 +34,11 @@ tvm::transform::Pass RelayToTIR();
runtime::Module TIRToRuntime(IRModule mod, Target target);

TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
.add_attr_option<Array<String>>("mattr")
.add_attr_option<String>("mcpu")
Contributor (review comment on this line): just musing, i'm sensing we'll shortly want a way to express this as "whatever target_host did"

.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime)
.add_attr_option<Array<String>>("mattr")
.add_attr_option<String>("mcpu");
.set_target_parser(tvm::target::parsers::cpu::ParseTarget);

} // namespace cmsisnn
} // namespace contrib
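To make the before/after concrete, a simplified comparison sketch: the old helper below is condensed from the code removed in compiler_attrs.cc (it ignores the separate +nodsp handling), while the new one relies on the shared CPU parser attached to the target kind above.

#include <string>
#include <tvm/target/target.h>

using tvm::Bool;
using tvm::Target;

// Old approach (removed in this PR, condensed): match CPU names and flag
// substrings by hand.
static bool OldStyleHasMVE(const std::string& mcpu, const std::string& mattr) {
  bool is_mve_cpu = mcpu.find("cortex-m55") == 0;
  bool nomve = mcpu.find("+nomve") != std::string::npos ||
               mattr.find("+nomve") != std::string::npos;
  return is_mve_cpu && !nomve;
}

// New approach: the CPU target parser has already turned -mcpu/-mattr into
// boolean features on the Target.
static bool NewStyleHasMVE(const Target& target) {
  return target->GetFeature<Bool>("has_mve").value_or(Bool(false));
}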
31 changes: 18 additions & 13 deletions tests/cpp/relay/backend/contrib/cmsisnn/buffer_size_test.cc
@@ -23,6 +23,7 @@

#include <gtest/gtest.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>

#include <cmath>
#include <random>
@@ -39,12 +40,16 @@ static std::random_device rd;
static std::mt19937 gen(rd());
static std::uniform_int_distribution<> fake_parameters(2, 100);

static const Target kHasMVE("cmsis-nn -mcpu=cortex-m55");
static const Target kHasDSP("cmsis-nn -mcpu=cortex-m55 -mattr=+nomve");
static const Target kNoExt("cmsis-nn -mcpu=cortex-m55 -mattr=+nodsp,+nomve");

class CMSISNNCalculatedBufferSize : public testing::TestWithParam<std::array<int32_t, 3>> {};

TEST(CMSISNNConv2dBufferSize, Conv1x1) {
int32_t any = fake_parameters(gen);
auto conv2d_1x1 = [=](CMSISNNFlags flags, int32_t input_c) {
return Conv2dBufferSize(flags, 0, 0, any, any, input_c, any, any, 1, 1, 1, 1, 1, 1);
auto conv2d_1x1 = [=](Target target, int32_t input_c) {
return Conv2dBufferSize(target, 0, 0, any, any, input_c, any, any, 1, 1, 1, 1, 1, 1);
};

ASSERT_EQ(conv2d_1x1(kNoExt, 4), 0);
Expand Down Expand Up @@ -73,8 +78,8 @@ TEST(CMSISNNConv2dBufferSize, Conv1xN) {
int32_t filter_h = 1;
int32_t calculated_buffer = (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);

auto conv2d_1xn = [=](CMSISNNFlags flags, int32_t output_w) {
return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any, any, 1, 1, filter_w,
auto conv2d_1xn = [=](Target target, int32_t output_w) {
return Conv2dBufferSize(target, any, any, 1, 1, input_c, 1, output_w, any, any, 1, 1, filter_w,
filter_h);
};

@@ -108,8 +113,8 @@ TEST(CMSISNNConv2dBufferSize, Default) {
col_length = (col_length + 7) / 8;
int32_t calculated_buffer_mve = 4 * col_length * 8 * (int32_t)sizeof(int8_t);

auto conv2d = [=](CMSISNNFlags flags, int32_t output_w) {
return Conv2dBufferSize(flags, any, any, 1, 1, input_c, 1, output_w, any, any, any, any,
auto conv2d = [=](Target target, int32_t output_w) {
return Conv2dBufferSize(target, any, any, 1, 1, input_c, 1, output_w, any, any, any, any,
filter_w, filter_h);
};

@@ -137,8 +142,8 @@ TEST(CMSISNNDepthwiseConv2dBufferSize, UnEvenChannels) {
int32_t filter_h = fake_parameters(gen);
int32_t input_n = 1;

auto depthwise_conv2d_with_channels = [=](CMSISNNFlags flags, int32_t input_c, int32_t output_c) {
return DepthwiseConv2dBufferSize(flags, input_n, input_c, output_c, filter_w, filter_h);
auto depthwise_conv2d_with_channels = [=](Target target, int32_t input_c, int32_t output_c) {
return DepthwiseConv2dBufferSize(target, input_n, input_c, output_c, filter_w, filter_h);
};

ASSERT_EQ(depthwise_conv2d_with_channels(kNoExt, 4, 6), 0);
@@ -154,8 +159,8 @@ TEST(CMSISNNDepthwiseConv2dBufferSize, MultipleBatches) {
int32_t filter_w = fake_parameters(gen);
int32_t filter_h = fake_parameters(gen);

auto depthwise_conv2d_with_batch = [=](CMSISNNFlags flags, int32_t input_n) {
return DepthwiseConv2dBufferSize(flags, input_n, input_output_c, input_output_c, filter_w,
auto depthwise_conv2d_with_batch = [=](Target target, int32_t input_n) {
return DepthwiseConv2dBufferSize(target, input_n, input_output_c, input_output_c, filter_w,
filter_h);
};

@@ -177,8 +182,8 @@ TEST(CMSISNNDepthwiseConv2dBufferSize, Default) {
(2 * input_output_c * filter_w * filter_h) * (int32_t)sizeof(int16_t) + 4;
int32_t dsp_calculated_buffer = (input_output_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);

auto depthwise_conv2d = [=](CMSISNNFlags flags) {
return DepthwiseConv2dBufferSize(flags, input_n, input_output_c, input_output_c, filter_w,
auto depthwise_conv2d = [=](Target target) {
return DepthwiseConv2dBufferSize(target, input_n, input_output_c, input_output_c, filter_w,
filter_h);
};

@@ -194,7 +199,7 @@ TEST(CMSISNNAvgPoolBufferSize, Default) {
int32_t input_c = fake_parameters(gen);
int32_t calculated_buffer = (input_c * sizeof(int32_t));

auto avg_pool = [=](CMSISNNFlags flags) { return AvgPoolBufferSize(flags, input_c); };
auto avg_pool = [=](Target target) { return AvgPoolBufferSize(target, input_c); };

ASSERT_EQ(avg_pool(kNoExt), 0);
ASSERT_EQ(avg_pool(kHasDSP), calculated_buffer);