diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 0d078d39372d..b61f209505d8 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1099,6 +1099,19 @@ def SimplifyExpr():
     return _ffi_api.SimplifyExpr()
 
 
+def FoldExplicitPadding():
+    """
+    FoldExplicitPadding finds explicit padding before an op that can support
+    implicit padding and fuses them.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered FoldExplicitPadding pass.
+    """
+    return _ffi_api.FoldExplicitPadding()
+
+
 def AnnotateSpans():
     """
     Annotate a program with span information by first generating its textual
diff --git a/src/relay/transforms/fold_explicit_padding.cc b/src/relay/transforms/fold_explicit_padding.cc
new file mode 100644
index 000000000000..d606eb445a79
--- /dev/null
+++ b/src/relay/transforms/fold_explicit_padding.cc
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/fold_explicit_padding.cc
+ * \brief A pass for folding explicit pads into other ops.
+ */
+
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/tir/op.h>
+
+#include "../op/tensor/transform.h"
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief SimplifyConvPad matches a pad followed by a conv/convtranspose/pool/etc
+ * with a pad attribute and merges the explicit pad into the op's padding attribute.
+ */
+class SimplifyConvPad {
+ public:
+  DFPattern pattern() const { return pattern_; }
+
+  SimplifyConvPad() {
+    x_ = IsWildcard();
+    w_ = IsWildcard();
+    pad_ = IsOp("nn.pad")({x_});
+    conv1d_ = IsOp("nn.conv1d");
+    conv2d_ = IsOp("nn.conv2d");
+    conv3d_ = IsOp("nn.conv3d");
+    conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
+    pattern_ = conv_;
+  }
+
+  template <typename T>
+  Attrs MakeConvAttrs(const T* old_attrs, const Array<PrimExpr> padding) const {
+    ICHECK(old_attrs);
+    ICHECK(padding.size() == old_attrs->padding.size())
+        << "Number of dimensions to pad and convolution padding attributes should have the same "
+           "extent";
+
+    auto new_attrs = make_object<T>();
+    Array<PrimExpr> combined_padding;
+    for (size_t i = 0; i < padding.size(); ++i) {
+      combined_padding.push_back(padding[i] + old_attrs->padding[i]);
+    }
+    new_attrs->strides = old_attrs->strides;
+    new_attrs->padding = combined_padding;
+    new_attrs->dilation = old_attrs->dilation;
+    new_attrs->groups = old_attrs->groups;
+    new_attrs->channels = old_attrs->channels;
+    new_attrs->kernel_size = old_attrs->kernel_size;
+    new_attrs->data_layout = old_attrs->data_layout;
+    new_attrs->kernel_layout = old_attrs->kernel_layout;
+    new_attrs->out_layout = old_attrs->out_layout;
+    new_attrs->out_dtype = old_attrs->out_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs GetAttrs(const PadAttrs* param, const T* attrs) const {
+    ICHECK(param);
+    ICHECK(attrs);
+    ICHECK(attrs->data_layout.size() == param->pad_width.size())
+        << "Data Layout and padding attributes should have the same extent";
+
+    std::string data_layout = attrs->data_layout;
+    std::set<char> image_dims({'H', 'W', 'D'});
+    Array<PrimExpr> padding;
+    // If we're padding a non-spatial dimension, don't simplify
+    // Convolution can only pad on spatial axes
+    for (size_t i = 0; i < param->pad_width.size(); ++i) {
+      if (!image_dims.count(data_layout[i])) {
+        for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
+          if (param->pad_width[i][j] != 0) {
+            return Attrs();
+          }
+        }
+      }
+    }
+    for (size_t j = 0; j < param->pad_width[0].size(); ++j) {
+      for (size_t i = 0; i < param->pad_width.size(); ++i) {
+        if (image_dims.count(data_layout[i])) {
+          padding.push_back(param->pad_width[i][j]);
+        }
+      }
+    }
+
+    return MakeConvAttrs(attrs, padding);
+  }
+
+  Expr callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = post.as<CallNode>();
+    ICHECK(call_node);
+    auto pad = node_map[pad_][0];
+    const CallNode* pad_node = pad.as<CallNode>();
+    ICHECK(pad_node);
+    const PadAttrs* param = pad_node->attrs.as<PadAttrs>();
+    ICHECK(param);
+    if (param->pad_mode == "constant" && param->pad_value == 0.0) {
+      Attrs attrs;
+      if (node_map.count(conv1d_)) {
+        attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
+      } else if (node_map.count(conv2d_)) {
+        attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
+      } else if (node_map.count(conv3d_)) {
+        attrs = GetAttrs(param, call_node->attrs.as<Conv3DAttrs>());
+      } else {
+        return post;
+      }
+      if (!attrs.defined()) {
+        return post;
+      }
+      auto x = node_map[x_][0];
+      auto w = node_map[w_][0];
+      return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
+    }
+    return post;
+  }
+
+ private:
+  /*! \brief Pattern for rewriting */
+  DFPattern pattern_;
+  /*! \brief Pattern input */
+  DFPattern x_;
+  /*! \brief Pattern input weight */
+  DFPattern w_;
+  /*! \brief Pattern pad */
+  DFPattern pad_;
+  /*! \brief Pattern conv */
+  DFPattern conv_;
+  DFPattern conv1d_;
+  DFPattern conv2d_;
+  DFPattern conv3d_;
+};
+
+class SimplifyExplicitPadding {
+ public:
+  explicit SimplifyExplicitPadding(IRModule mod) : mod_(mod) {
+    CreateCallback(SimplifyConvPad());
+    // TODO(mbrookhart): ConvTranspose(Pad(x)), Pool(Pad(x))
+  }
+
+  template <typename T>
+  void CreateCallback(const T& pattern) {
+    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
+      Expr pre = args[0];
+      Expr post = args[1];
+      Map<DFPattern, Array<Expr>> node_map = args[2];
+      *rv = pattern.callback(pre, post, node_map);
+    };
+    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+  }
+
+  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+
+ private:
+  IRModule mod_;
+  /*! \brief Callbacks for expr simplification */
+  Array<DFPatternCallback> callbacks_;
+};
+
+/*!
+ * \brief FoldExplicitPadding finds explicit padding before an op that can
+ * support implicit padding and fuses them.
+ */
+Expr FoldExplicitPadding(const Expr& expr, const IRModule& mod) {
+  return SimplifyExplicitPadding(mod).Simplify(expr);
+}
+
+namespace transform {
+
+Pass FoldExplicitPadding() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(FoldExplicitPadding(f, m));
+      };
+  return CreateFunctionPass(pass_func, 0, "FoldExplicitPadding", {"InferType"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.FoldExplicitPadding").set_body_typed(FoldExplicitPadding);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index bfe04e10a9d0..74e48dc4bc54 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -82,121 +82,6 @@ class SimplifyReshape : public SimplifyPattern {
   DFPattern x_;
 };
 
-/*!
- * \brief SimplifyConvPad matches a pad followed by a conv/convtranspose/pool/etc
- * with a pad attribute and merges the padding into the kernel.
- */
-class SimplifyConvPad : public SimplifyPattern {
- public:
-  SimplifyConvPad() {
-    x_ = IsWildcard();
-    w_ = IsWildcard();
-    pad_ = IsOp("nn.pad")({x_});
-    conv1d_ = IsOp("nn.conv1d");
-    conv2d_ = IsOp("nn.conv2d");
-    conv3d_ = IsOp("nn.conv3d");
-    conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
-    pattern_ = conv_;
-  }
-  template <typename T>
-  Attrs MakeConvAttrs(const T* old_attrs, const Array<PrimExpr> padding) const {
-    ICHECK(old_attrs);
-    ICHECK(padding.size() == old_attrs->padding.size())
-        << "Number of dimensions to pad and convolution padding attributes should have the same "
-           "extent";
-
-    auto new_attrs = make_object<T>();
-    Array<PrimExpr> combined_padding;
-    for (size_t i = 0; i < padding.size(); ++i) {
-      combined_padding.push_back(padding[i] + old_attrs->padding[i]);
-    }
-    new_attrs->strides = old_attrs->strides;
-    new_attrs->padding = combined_padding;
-    new_attrs->dilation = old_attrs->dilation;
-    new_attrs->groups = old_attrs->groups;
-    new_attrs->channels = old_attrs->channels;
-    new_attrs->kernel_size = old_attrs->kernel_size;
-    new_attrs->data_layout = old_attrs->data_layout;
-    new_attrs->kernel_layout = old_attrs->kernel_layout;
-    new_attrs->out_layout = old_attrs->out_layout;
-    new_attrs->out_dtype = old_attrs->out_dtype;
-    return Attrs(new_attrs);
-  }
-  template <typename T>
-  Attrs GetAttrs(const PadAttrs* param, const T* attrs) const {
-    ICHECK(param);
-    ICHECK(attrs);
-    ICHECK(attrs->data_layout.size() == param->pad_width.size())
-        << "Data Layout and padding attributes should have the same extent";
-
-    std::string data_layout = attrs->data_layout;
-    std::set<char> image_dims({'H', 'W', 'D'});
-    Array<PrimExpr> padding;
-    // If we're padding a non-spatial dimension, don't simplify
-    // Convolution can only pad on spatial axes
-    for (size_t i = 0; i < param->pad_width.size(); ++i) {
-      if (!image_dims.count(data_layout[i])) {
-        for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
-          if (param->pad_width[i][j] != 0) {
-            return Attrs();
-          }
-        }
-      }
-    }
-    for (size_t j = 0; j < param->pad_width[0].size(); ++j) {
-      for (size_t i = 0; i < param->pad_width.size(); ++i) {
-        if (image_dims.count(data_layout[i])) {
-          padding.push_back(param->pad_width[i][j]);
-        }
-      }
-    }
-
-    return MakeConvAttrs(attrs, padding);
-  }
-  Expr callback(const Expr& pre, const Expr& post,
-                const Map<DFPattern, Array<Expr>>& node_map) const override {
-    const CallNode* call_node = post.as<CallNode>();
-    ICHECK(call_node);
-    auto pad = node_map[pad_][0];
-    const CallNode* pad_node = pad.as<CallNode>();
-    ICHECK(pad_node);
-    const PadAttrs* param = pad_node->attrs.as<PadAttrs>();
-    ICHECK(param);
-    if (param->pad_mode == "constant" && param->pad_value == 0.0) {
-      Attrs attrs;
-      if (node_map.count(conv1d_)) {
-        attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
-      } else if (node_map.count(conv2d_)) {
-        attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
-      } else if (node_map.count(conv3d_)) {
-        attrs = GetAttrs(param, call_node->attrs.as<Conv3DAttrs>());
-      } else {
-        return post;
-      }
-      if (!attrs.defined()) {
-        return post;
-      }
-      auto x = node_map[x_][0];
-      auto w = node_map[w_][0];
-      return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
-    }
-    return post;
-  }
-
- private:
-  /*! \brief Pattern input */
-  DFPattern x_;
-  /*! \brief Pattern input weight */
-  DFPattern w_;
-  /*! \brief Pattern pad */
-  DFPattern pad_;
-  /*! \brief Pattern conv */
-  DFPattern conv_;
-  DFPattern conv1d_;
-  DFPattern conv2d_;
-  DFPattern conv3d_;
-};
-
 /*!
  * \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
  */
@@ -278,7 +163,6 @@ class ExprSimplifier {
   explicit ExprSimplifier(IRModule mod) : mod_(mod) {
     CreateCallback(SimplifyReshape());
     CreateCallback(FullElementwise());
-    CreateCallback(SimplifyConvPad());
   }
   template <typename T>
   void CreateCallback(const T& pattern) {
diff --git a/tests/python/relay/test_pass_fold_explicit_padding.py b/tests/python/relay/test_pass_fold_explicit_padding.py
new file mode 100644
index 000000000000..302a2b91bb8f
--- /dev/null
+++ b/tests/python/relay/test_pass_fold_explicit_padding.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.testing import run_opt_pass
+
+import numpy as np
+
+
+def test_simplify_conv_pad():
+    convs = [relay.nn.conv1d, relay.nn.conv2d, relay.nn.conv3d]
+
+    def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
+        if layout[1] == "C":
+            shape = [1, 3] + [10] * ndim
+            wshape = [8, 3] + [3] * ndim
+        elif layout[-1] == "C":
+            shape = [1] + [10] * ndim + [3]
+            wshape = [8] + [3] * ndim + [3]
+        else:
+            raise ValueError("This test only supports NC* and N*C")
+
+        x = relay.var("x", shape=shape, dtype="float32")
+        w = relay.var("w", shape=wshape, dtype="float32")
+        pad = relay.nn.pad(x, pad_width, pad_value, pad_mode)
+        if layout[1] == "C":
+            conv = convs[ndim - 1](pad, w, padding=orig_padding)
+        else:
+            conv = convs[ndim - 1](
+                pad, w, padding=orig_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
+            )
+
+        if pad_mode == "constant" and pad_value == 0:
+            new_padding = []
+            for j in range(2):
+                for i in range(len(pad_width)):
+                    if layout[i] in ["D", "H", "W"]:
+                        new_padding.append(pad_width[i][j])
+            for i in range(len(new_padding)):
+                new_padding[i] += orig_padding[i]
+            if layout[1] == "C":
+                after = convs[ndim - 1](x, w, padding=new_padding)
+            else:
+                after = convs[ndim - 1](
+                    x, w, padding=new_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
+                )
+        else:
+            after = conv
+
+        zz = run_opt_pass(conv, transform.FoldExplicitPadding())
+        expected = run_opt_pass(after, transform.InferType())
+        assert tvm.ir.structural_equal(zz, expected)
+
+        mod1 = tvm.IRModule.from_expr(conv)
+        mod2 = tvm.IRModule.from_expr(zz)
+
+        with tvm.transform.PassContext():
+            ex1 = relay.create_executor("vm", mod=mod1, ctx=tvm.cpu(), target="llvm")
+            ex2 = relay.create_executor("vm", mod=mod2, ctx=tvm.cpu(), target="llvm")
+            x_np = np.random.rand(*shape).astype("float32")
+            w_np = np.random.rand(*wshape).astype("float32")
+            result1 = ex1.evaluate()(x_np, w_np)
+            result2 = ex2.evaluate()(x_np, w_np)
+
+            tvm.testing.assert_allclose(result1.asnumpy(), result2.asnumpy(), rtol=1e-5, atol=1e-5)
+
+    for orig_pad in [[0, 0], [2, 0], [0, 2]]:
+        for i_pad in [[0, 0], [1, 1], [1, 0]]:
+            for ndim in [1, 2, 3]:
+                for channels_last in [0, 1]:
+                    if channels_last:
+                        layout = "NDHWC"
+                        layout = layout[0:1] + layout[4 - ndim : 4] + layout[-1:]
+                        padding = [[0, 0]] + [i_pad] * ndim + [[0, 0]]
+                    else:
+                        layout = "NCDHW"
+                        layout = layout[0:2] + layout[5 - ndim :]
+                        padding = [[0, 0]] * 2 + [i_pad] * ndim
+
+                    validate(ndim, padding, 0, "constant", orig_pad * ndim, layout)
+            ndim = 2
+            validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 1, "constant", orig_pad * ndim, "NCHW")
+            validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim, "NCHW")
+
+
+if __name__ == "__main__":
+    test_simplify_conv_pad()
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index e3e497e930f9..9531d896b2ed 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -124,82 +124,6 @@ def after_right(x, elem_op, value):
         validate(shape, value, dtype)
 
 
-def test_simplify_conv_pad():
-    convs = [relay.nn.conv1d, relay.nn.conv2d, relay.nn.conv3d]
-
-    def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
-        if layout[1] == "C":
-            shape = [1, 3] + [10] * ndim
-            wshape = [8, 3] + [3] * ndim
-        elif layout[-1] == "C":
-            shape = [1] + [10] * ndim + [3]
-            wshape = [8] + [3] * ndim + [3]
-        else:
-            raise ValueError("This test only supports NC* and N*C")
-
-        x = relay.var("x", shape=shape, dtype="float32")
-        w = relay.var("w", shape=wshape, dtype="float32")
-        pad = relay.nn.pad(x, pad_width, pad_value, pad_mode)
-        if layout[1] == "C":
-            conv = convs[ndim - 1](pad, w, padding=orig_padding)
-        else:
-            conv = convs[ndim - 1](
-                pad, w, padding=orig_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
-            )
-
-        if pad_mode == "constant" and pad_value == 0:
-            new_padding = []
-            for j in range(2):
-                for i in range(len(pad_width)):
-                    if layout[i] in ["D", "H", "W"]:
-                        new_padding.append(pad_width[i][j])
-            for i in range(len(new_padding)):
-                new_padding[i] += orig_padding[i]
-            if layout[1] == "C":
-                after = convs[ndim - 1](x, w, padding=new_padding)
-            else:
-                after = convs[ndim - 1](
-                    x, w, padding=new_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
-                )
-        else:
-            after = conv
-
-        zz = run_opt_pass(conv, transform.SimplifyExpr())
-        expected = run_opt_pass(after, transform.InferType())
-        assert tvm.ir.structural_equal(zz, expected)
-
-        mod1 = tvm.IRModule.from_expr(conv)
-        mod2 = tvm.IRModule.from_expr(zz)
-
-        with tvm.transform.PassContext(disabled_pass="SimplifyExpr"):
-            ex1 = relay.create_executor("vm", mod=mod1, ctx=tvm.cpu(), target="llvm")
-            ex2 = relay.create_executor("vm", mod=mod2, ctx=tvm.cpu(), target="llvm")
-            x_np = np.random.rand(*shape).astype("float32")
-            w_np = np.random.rand(*wshape).astype("float32")
-            result1 = ex1.evaluate()(x_np, w_np)
-            result2 = ex2.evaluate()(x_np, w_np)
-
-            tvm.testing.assert_allclose(result1.asnumpy(), result2.asnumpy())
-
-    for orig_pad in [[0, 0], [2, 0], [0, 2]]:
-        for i_pad in [[0, 0], [1, 1], [1, 0]]:
-            for ndim in [1, 2, 3]:
-                for channels_last in [0, 1]:
-                    if channels_last:
-                        layout = "NDHWC"
-                        layout = layout[0:1] + layout[4 - ndim : 4] + layout[-1:]
-                        padding = [[0, 0]] + [i_pad] * ndim + [[0, 0]]
-                    else:
-                        layout = "NCDHW"
-                        layout = layout[0:2] + layout[5 - ndim :]
-                        padding = [[0, 0]] * 2 + [i_pad] * ndim
-
-                    validate(ndim, padding, 0, "constant", orig_pad * ndim, layout)
-            ndim = 2
-            validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 1, "constant", orig_pad * ndim, "NCHW")
-            validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim, "NCHW")
-
-
 if __name__ == "__main__":
     test_simplify_reshape()
     test_simplify_full_elementwise()
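
For reviewers, a minimal usage sketch of the new pass (shapes and variable names here are illustrative, not from the patch; `FoldExplicitPadding` and `run_opt_pass` are the APIs introduced/used above):

```python
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass

# Build the pattern the pass targets: an explicit nn.pad feeding a conv2d.
x = relay.var("x", shape=(1, 3, 10, 10), dtype="float32")
w = relay.var("w", shape=(8, 3, 3, 3), dtype="float32")
pad = relay.nn.pad(x, ((0, 0), (0, 0), (1, 1), (1, 1)), 0.0, "constant")
conv = relay.nn.conv2d(pad, w, padding=[0, 0, 0, 0])

# The pass removes the nn.pad and folds its spatial widths into the
# conv2d padding attribute; expect a single conv2d with padding=[1, 1, 1, 1].
folded = run_opt_pass(conv, transform.FoldExplicitPadding())
print(folded)
```

Non-zero pad values and non-"constant" pad modes are left untouched, as exercised by the last two `validate` calls in the new test.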