Skip to content

Commit

Permalink
change the kernel to support other layouts
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrookhart committed Feb 26, 2021
1 parent 0c458e3 commit f84ea99
Showing 1 changed file with 40 additions and 22 deletions.
62 changes: 40 additions & 22 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,14 @@ class SimplifyConvPad : public SimplifyPattern {
pattern_ = conv_;
}
template <typename T>
Attrs MakeConvAttrs(const Attrs& attrs, const Array<PrimExpr> padding) const {
const T* old_attrs = attrs.as<T>();
Attrs MakeConvAttrs(const T* old_attrs, const Array<PrimExpr> padding) const {
ICHECK(old_attrs);
ICHECK(padding.size() == old_attrs->padding.size())
<< "Number of dimensions to pad and convolution padding attributes should have the same "
"extent";

auto new_attrs = make_object<T>();
Array<PrimExpr> combined_padding;
ICHECK(padding.size() == old_attrs->padding.size());
for (size_t i = 0; i < padding.size(); ++i) {
combined_padding.push_back(padding[i] + old_attrs->padding[i]);
}
Expand All @@ -120,6 +122,35 @@ class SimplifyConvPad : public SimplifyPattern {
new_attrs->out_dtype = old_attrs->out_dtype;
return Attrs(new_attrs);
}
template <typename T>
Attrs GetAttrs(const PadAttrs* param, const T* attrs) const {
ICHECK(param);
ICHECK(attrs);
ICHECK(attrs->data_layout.size() == param->pad_width.size())
<< "Data Layout and padding attributes should have the same extent";

std::string data_layout = attrs->data_layout;
std::set<char> image_dims({'H', 'W', 'D'});
Array<PrimExpr> padding;
for (size_t i = 0; i < param->pad_width.size(); ++i) {
if (!image_dims.count(data_layout[i])) {
for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
if (param->pad_width[i][j] != 0) {
return Attrs();
}
}
}
}
for (size_t j = 0; j < param->pad_width[0].size(); ++j) {
for (size_t i = 0; i < param->pad_width.size(); ++i) {
if (image_dims.count(data_layout[i])) {
padding.push_back(param->pad_width[i][j]);
}
}
}

return MakeConvAttrs(attrs, padding);
}
Expr callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
const CallNode* call_node = post.as<CallNode>();
Expand All @@ -130,32 +161,19 @@ class SimplifyConvPad : public SimplifyPattern {
const PadAttrs* param = pad_node->attrs.as<PadAttrs>();
ICHECK(param);
if (param->pad_mode == "constant" && param->pad_value == 0.0) {
for (size_t i = 0; i < 2; ++i) {
for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
if (param->pad_width[i][j] != 0) {
return post;
}
}
}
Array<PrimExpr> padding;
for (size_t j = 0; j < param->pad_width[0].size(); ++j) {
for (size_t i = 2; i < param->pad_width.size(); ++i) {
padding.push_back(param->pad_width[i][j]);
}
}
Attrs attrs;
if (node_map.count(conv1d_)) {
ICHECK(padding.size() == 2);
attrs = MakeConvAttrs<Conv1DAttrs>(call_node->attrs, padding);
attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
} else if (node_map.count(conv2d_)) {
ICHECK(padding.size() == 4);
attrs = MakeConvAttrs<Conv2DAttrs>(call_node->attrs, padding);
attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
} else if (node_map.count(conv3d_)) {
ICHECK(padding.size() == 6);
attrs = MakeConvAttrs<Conv3DAttrs>(call_node->attrs, padding);
attrs = GetAttrs(param, call_node->attrs.as<Conv3DAttrs>());
} else {
return post;
}
if (!attrs.defined()) {
return post;
}
auto x = node_map[x_][0];
auto w = node_map[w_][0];
return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
Expand Down

0 comments on commit f84ea99

Please sign in to comment.