Skip to content

Commit

Permalink
Addressed review comments: removed disconcerting comments
Browse files Browse the repository at this point in the history
Change-Id: I2e9ca3edec85991be394bbdb2d1739c3afc62d5e
  • Loading branch information
ashutosh-arm committed Dec 3, 2021
1 parent c4d7c89 commit a51669a
Showing 1 changed file with 3 additions and 13 deletions.
16 changes: 3 additions & 13 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,14 @@ class RelayToTIRVisitor : public MixedModeMutator {
ToArg(dilation_w), ToArg(dilation_h), ToArg(clip_min),
ToArg(clip_max)};

// cmsis_nn_dims *input_dims (NHWC)
// layout NHWC
Array<PrimExpr> input_shape = conv2d_call->args[0]->type_as<TensorTypeNode>()->shape;

// cmsis_nn_dims *filter_dims (OHWI for Conv2D and IHWO for depthwise)
// OHWI for Conv2D and IHWO for depthwise
Array<PrimExpr> filter_shape = conv2d_call->args[1]->type_as<TensorTypeNode>()->shape;

// cmsis_nn_dims *bias_dims
Array<PrimExpr> bias_shape{1, 1, 1, out_channels};

// cmsis_nn_dims *output_dims (same order as input_dims)
Array<PrimExpr> output_shape = conv2d_call->type_as<TensorTypeNode>()->shape;

int32_t depth_multiplier = -1;
Expand Down Expand Up @@ -287,23 +285,18 @@ class RelayToTIRVisitor : public MixedModeMutator {
ToArg(input_offset), ToArg(filter_offset), ToArg(output_offset), ToArg(clip_min),
ToArg(clip_max), ToArg(multiplier), ToArg(shift)};

// cmsis_nn_dims *input_dims
Array<PrimExpr> input_shape = fc_call->args[0]->type_as<TensorTypeNode>()->shape;
int32_t batch_size = qnn::get_const_int(input_shape[0]);
int32_t in_channels = qnn::get_const_int(input_shape[1]);
Array<PrimExpr> cmsisnn_input_shape{input_shape[0], 1, 1, input_shape[1]};

// cmsis_nn_dims *filter_dims
Array<PrimExpr> cmsisnn_filter_shape{in_channels, 1, 1, out_channels};

// cmsis_nn_dims *bias_dims
Array<PrimExpr> bias_shape{1, 1, 1, out_channels};

// cmsis_nn_dims *output_dims
Array<PrimExpr> cmsisnn_output_shape{batch_size, 1, 1, out_channels};

std::string cmsisnn_api = "arm_fully_connected_s8";
tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, filter};
tvm::Array<PrimExpr> call_ext_args = {tir::StringImm("arm_fully_connected_s8"), input, filter};
if (bias_add_call) {
call_ext_args.push_back(bias);
}
Expand Down Expand Up @@ -384,14 +377,11 @@ class RelayToTIRVisitor : public MixedModeMutator {
tvm::Array<PrimExpr> scalar_args = {ToArg(stride_h), ToArg(stride_w), ToArg(padding_h),
ToArg(padding_w), ToArg(clip_min), ToArg(clip_max)};

// cmsis_nn_dims *input_dims
Array<PrimExpr> input_shape = pool->args[0]->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> cmsisnn_input_shape{1, input_shape[1], input_shape[2], input_shape[3]};

// cmsis_nn_dims *filter_dims
Array<PrimExpr> cmsisnn_filter_shape{1, pool_size_h, pool_size_w, 1};

// cmsis_nn_dims *output_dims
Array<PrimExpr> output_shape = pool->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> cmsisnn_output_shape{1, output_shape[1], output_shape[2], output_shape[3]};

Expand Down

0 comments on commit a51669a

Please sign in to comment.