diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 74e48dc4bc54..bfe04e10a9d0 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -82,6 +82,121 @@ class SimplifyReshape : public SimplifyPattern {
   DFPattern x_;
 };
 
+/*!
+ * \brief SimplifyConvPad matches a pad followed by a conv/convtranspose/pool/etc
+ * with a pad attribute and merges the explicit padding into the op's own
+ * padding attribute (currently implemented for conv1d/conv2d/conv3d).
+ */
+class SimplifyConvPad : public SimplifyPattern {
+ public:
+  SimplifyConvPad() {
+    x_ = IsWildcard();
+    w_ = IsWildcard();
+    pad_ = IsOp("nn.pad")({x_});
+    conv1d_ = IsOp("nn.conv1d");
+    conv2d_ = IsOp("nn.conv2d");
+    conv3d_ = IsOp("nn.conv3d");
+    conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
+    pattern_ = conv_;
+  }
+  template <typename T>
+  Attrs MakeConvAttrs(const T* old_attrs, const Array<PrimExpr> padding) const {
+    ICHECK(old_attrs);
+    ICHECK(padding.size() == old_attrs->padding.size())
+        << "Number of dimensions to pad and convolution padding attributes should have the same "
+           "extent";
+
+    auto new_attrs = make_object<T>();
+    Array<PrimExpr> combined_padding;
+    for (size_t i = 0; i < padding.size(); ++i) {
+      combined_padding.push_back(padding[i] + old_attrs->padding[i]);
+    }
+    new_attrs->strides = old_attrs->strides;
+    new_attrs->padding = combined_padding;
+    new_attrs->dilation = old_attrs->dilation;
+    new_attrs->groups = old_attrs->groups;
+    new_attrs->channels = old_attrs->channels;
+    new_attrs->kernel_size = old_attrs->kernel_size;
+    new_attrs->data_layout = old_attrs->data_layout;
+    new_attrs->kernel_layout = old_attrs->kernel_layout;
+    new_attrs->out_layout = old_attrs->out_layout;
+    new_attrs->out_dtype = old_attrs->out_dtype;
+    return Attrs(new_attrs);
+  }
+  template <typename T>
+  Attrs GetAttrs(const PadAttrs* param, const T* attrs) const {
+    ICHECK(param);
+    ICHECK(attrs);
+    ICHECK(attrs->data_layout.size() == param->pad_width.size())
+        << "Data Layout and padding attributes should have the same extent";
+
+    std::string data_layout = attrs->data_layout;
+    std::set<char> image_dims({'H', 'W', 'D'});
+    Array<PrimExpr> padding;
+    // If we're padding a non-spatial dimension, don't simplify;
+    // convolution can only pad on spatial axes.
+    for (size_t i = 0; i < param->pad_width.size(); ++i) {
+      if (!image_dims.count(data_layout[i])) {
+        for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
+          if (param->pad_width[i][j] != 0) {
+            return Attrs();
+          }
+        }
+      }
+    }
+    // Reorder the per-axis (before, after) pairs from nn.pad into the flat
+    // [before..., after...] layout the conv padding attribute expects.
+    for (size_t j = 0; j < param->pad_width[0].size(); ++j) {
+      for (size_t i = 0; i < param->pad_width.size(); ++i) {
+        if (image_dims.count(data_layout[i])) {
+          padding.push_back(param->pad_width[i][j]);
+        }
+      }
+    }
+
+    return MakeConvAttrs(attrs, padding);
+  }
+  Expr callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    const CallNode* call_node = post.as<CallNode>();
+    ICHECK(call_node);
+    auto pad = node_map[pad_][0];
+    const CallNode* pad_node = pad.as<CallNode>();
+    ICHECK(pad_node);
+    const PadAttrs* param = pad_node->attrs.as<PadAttrs>();
+    ICHECK(param);
+    // Only constant zero-padding can be folded into the convolution.
+    if (param->pad_mode == "constant" && param->pad_value == 0.0) {
+      Attrs attrs;
+      if (node_map.count(conv1d_)) {
+        attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
+      } else if (node_map.count(conv2d_)) {
+        attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
+      } else if (node_map.count(conv3d_)) {
+        attrs = GetAttrs(param, call_node->attrs.as<Conv3DAttrs>());
+      } else {
+        return post;
+      }
+      if (!attrs.defined()) {
+        return post;
+      }
+      auto x = node_map[x_][0];
+      auto w = node_map[w_][0];
+      return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
+    }
+    return post;
+  }
+
+ private:
+  /*! \brief Pattern input */
+  DFPattern x_;
+  /*! \brief Pattern input weight */
+  DFPattern w_;
+  /*! \brief Pattern pad */
+  DFPattern pad_;
+  /*! \brief Pattern conv */
+  DFPattern conv_;
+  DFPattern conv1d_;
+  DFPattern conv2d_;
+  DFPattern conv3d_;
+};
+
 /*!
  * \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
  */
@@ -163,6 +278,7 @@ class ExprSimplifier {
   explicit ExprSimplifier(IRModule mod) : mod_(mod) {
     CreateCallback(SimplifyReshape());
     CreateCallback(FullElementwise());
+    CreateCallback(SimplifyConvPad());
   }
   template <typename T>
   void CreateCallback(const T& pattern) {
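The core of the rewrite is the index shuffle in `GetAttrs`: `nn.pad` supplies one `(before, after)` pair per input axis, while the convolution's `padding` attribute is a flat list of all spatial `before` values followed by all spatial `after` values. A rough standalone sketch of that reordering (names here are illustrative, not TVM API):

```python
# Hypothetical NCHW example of the reordering GetAttrs performs.
data_layout = "NCHW"
pad_width = [[0, 0], [0, 0], [1, 2], [3, 4]]  # (before, after) per axis from nn.pad
spatial = [i for i, dim in enumerate(data_layout) if dim in "DHW"]
padding = [pad_width[i][j] for j in range(2) for i in spatial]
assert padding == [1, 3, 2, 4]  # conv2d order: [top, left, bottom, right]
```

`MakeConvAttrs` then adds these values elementwise to the convolution's existing `padding`, so any implicit padding the conv already had is preserved.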
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index 423f0a4f213d..e3e497e930f9 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -19,6 +19,8 @@ from tvm.relay import transform
 from tvm.relay.testing import run_opt_pass
 
+import numpy as np
+
 
 def test_simplify_reshape():
     def before():
@@ -122,6 +124,82 @@ def after_right(x, elem_op, value):
         validate(shape, value, dtype)
 
 
+def test_simplify_conv_pad():
+    convs = [relay.nn.conv1d, relay.nn.conv2d, relay.nn.conv3d]
+
+    def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
+        if layout[1] == "C":
+            shape = [1, 3] + [10] * ndim
+            wshape = [8, 3] + [3] * ndim
+        elif layout[-1] == "C":
+            shape = [1] + [10] * ndim + [3]
+            wshape = [8] + [3] * ndim + [3]
+        else:
+            raise ValueError("This test only supports NC* and N*C")
+
+        x = relay.var("x", shape=shape, dtype="float32")
+        w = relay.var("w", shape=wshape, dtype="float32")
+        pad = relay.nn.pad(x, pad_width, pad_value, pad_mode)
+        if layout[1] == "C":
+            conv = convs[ndim - 1](pad, w, padding=orig_padding)
+        else:
+            conv = convs[ndim - 1](
+                pad, w, padding=orig_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
+            )
+
+        if pad_mode == "constant" and pad_value == 0:
+            # Expected result: the pad is folded into the conv's padding
+            # attribute, added on top of the original padding.
+            new_padding = []
+            for j in range(2):
+                for i in range(len(pad_width)):
+                    if layout[i] in ["D", "H", "W"]:
+                        new_padding.append(pad_width[i][j])
+            for i in range(len(new_padding)):
+                new_padding[i] += orig_padding[i]
+            if layout[1] == "C":
+                after = convs[ndim - 1](x, w, padding=new_padding)
+            else:
+                after = convs[ndim - 1](
+                    x, w, padding=new_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
+                )
+        else:
+            after = conv
+
+        zz = run_opt_pass(conv, transform.SimplifyExpr())
+        expected = run_opt_pass(after, transform.InferType())
+        assert tvm.ir.structural_equal(zz, expected)
+
+        # Also check numerical equivalence before and after simplification.
+        mod1 = tvm.IRModule.from_expr(conv)
+        mod2 = tvm.IRModule.from_expr(zz)
+
+        with tvm.transform.PassContext(disabled_pass="SimplifyExpr"):
+            ex1 = relay.create_executor("vm", mod=mod1, ctx=tvm.cpu(), target="llvm")
+            ex2 = relay.create_executor("vm", mod=mod2, ctx=tvm.cpu(), target="llvm")
+            x_np = np.random.rand(*shape).astype("float32")
+            w_np = np.random.rand(*wshape).astype("float32")
+            result1 = ex1.evaluate()(x_np, w_np)
+            result2 = ex2.evaluate()(x_np, w_np)
+
+            tvm.testing.assert_allclose(result1.asnumpy(), result2.asnumpy())
+
+    for orig_pad in [[0, 0], [2, 0], [0, 2]]:
+        for i_pad in [[0, 0], [1, 1], [1, 0]]:
+            for ndim in [1, 2, 3]:
+                for channels_last in [0, 1]:
+                    if channels_last:
+                        layout = "NDHWC"
+                        layout = layout[0:1] + layout[4 - ndim : 4] + layout[-1:]
+                        padding = [[0, 0]] + [i_pad] * ndim + [[0, 0]]
+                    else:
+                        layout = "NCDHW"
+                        layout = layout[0:2] + layout[5 - ndim :]
+                        padding = [[0, 0]] * 2 + [i_pad] * ndim
+
+                    validate(ndim, padding, 0, "constant", orig_pad * ndim, layout)
+    ndim = 2
+    # Cases that must NOT be simplified: nonzero constant pad value, "edge" mode.
+    validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 1, "constant", orig_pad * ndim, "NCHW")
+    validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim, "NCHW")
+
+
 if __name__ == "__main__":
     test_simplify_reshape()
     test_simplify_full_elementwise()
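For reference, a minimal end-to-end sketch of what the new pattern does (shapes are arbitrary; only APIs already exercised in the tests above appear here):

```python
import tvm
from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(1, 3, 10, 10), dtype="float32")
w = relay.var("w", shape=(8, 3, 3, 3), dtype="float32")
pad = relay.nn.pad(x, [[0, 0], [0, 0], [1, 1], [1, 1]])
conv = relay.nn.conv2d(pad, w, padding=[0, 0, 0, 0])
mod = tvm.IRModule.from_expr(relay.Function([x, w], conv))

# After SimplifyExpr, the nn.pad call disappears and the conv2d
# carries padding=[1, 1, 1, 1] instead.
mod = tvm.transform.Sequential([transform.InferType(), transform.SimplifyExpr()])(mod)
print(mod)
```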