From 5c8fa05808382e8515d5ad022f3cbc566fd92527 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 17 Apr 2020 16:24:42 +0800 Subject: [PATCH 01/13] enable blocking format in x86 conv2d and fold scale axis --- python/tvm/relay/op/strategy/x86.py | 22 ++++- src/relay/transforms/fold_scale_axis.cc | 57 ++++++++++--- topi/python/topi/x86/conv2d_alter_op.py | 106 +++++++++++++----------- 3 files changed, 121 insertions(+), 64 deletions(-) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index ba0b3d20b549..0c5768efb403 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -22,9 +22,13 @@ from tvm.te import SpecializedCondition from .generic import * from .. import op as _op +import re logger = logging.getLogger('strategy') +_NCHWc_matcher = re.compile("^NCHW[-+]?[0-9]+c$") +_OIHWio_matcher = re.compile("^OIHW[-+]?[0-9]+i[-+]?[0-9]+o$") + @schedule_injective.register("cpu") def schedule_injective_cpu(attrs, outs, target): """schedule injective ops for x86""" @@ -84,8 +88,13 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): raise ValueError("dilation should be positive value") if groups == 1: - if layout == "NCHW": - assert kernel_layout == "OIHW" + if layout.startswith("NCHW"): + if layout != "NCHW": + #check if layout is NCHWxc + assert _NCHWc_matcher.match(layout) + assert _OIHWio_matcher.match(kernel_layout) + else: + assert kernel_layout == "OIHW" if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype): strategy.add_implementation( wrap_compute_conv2d(topi.x86.conv2d_nchw_int8), @@ -113,8 +122,13 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): else: raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout)) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): - if layout == "NCHW": - assert kernel_layout == "OIHW" + if layout.startswith("NCHW"): + if layout != "NCHW": + #check if layout is NCHWxc + assert _NCHWc_matcher.match(layout) + assert _OIHWio_matcher.match(kernel_layout) + else: + assert kernel_layout == "OIHW" channel_multiplier = get_const_tuple(inputs[1].shape)[1] if channel_multiplier == 1 and dilation_h == 1 and dilation_w == 1: strategy.add_implementation( diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 57e3d6925b20..1ec91191c9f3 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -39,6 +39,10 @@ namespace relay { * * Use namespace to reduce potential naming conflict. */ + +extern Expr MakeReshape(Expr data, + Array newshape); + namespace fold_scale_axis { using runtime::TypedPackedFunc; @@ -829,13 +833,20 @@ Message Conv2DBackwardPrep(const Call& call, const Array& in_messages) // only handle depthwise or full conv2d. // TODO(tvm-team) handle grouped conv by reshape + bcast bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); - if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 && - kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 && - (param->groups == 1 || is_depthwise_conv2d)) { - return Message({c_big_axis}, false); - } else { - return NullValue(); + if(param->groups == 1 || is_depthwise_conv2d) { + auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o')); + auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); + if ( (ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || //simple layout + (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) //blocked layout + { + Array arr{c_big_axis}; + if (c_small_axis >= 0) { + arr.push_back(c_small_axis); + } + return Message(arr, false); + } } + return NullValue(); } // Conv2D consumes the scale axis during transformation. @@ -852,19 +863,41 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp CHECK_GE(c_big_axis, 0); // For now, we only support simple pattern (no folded weight/data) // TODO(tvm-team) support general data layout - CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1); - CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1); - CHECK(message->axes.size() == 1 && c_big_axis == message->axes[0]->value); - - int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); + int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o')); + int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); + int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I')); + int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); // Check it must be depthwise or full conv2d. bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); CHECK(param->groups == 1 || is_depthwise_conv2d); + bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0); + bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0); + CHECK(is_simple || is_blocking); Expr data = transformer->Transform(call->args[0], NullValue(), NullValue()); Expr weight = transformer->Transform(call->args[1], NullValue(), NullValue()); // scale on input for deptwise. - Expr wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_oc_axis}); + Expr wscale; + if (is_simple) { + wscale = ExpandBiasToMatchAxis( + scale, kernel_layout.ndim(), {big_ko_axis}); + } else { + auto& wshape = weight->type_as()->shape; + Array arr; + for(size_t i=0; i(small_ko_axis) || i == static_cast(big_ko_axis)) { + auto node = wshape[i].as(); + if(!node) { + // if the shape is not a constant, use normal transform + return transformer->NormalCallTransform(call.operator->()); + } + arr.push_back(node->value); + } else { + arr.push_back(1); + } + } + wscale = MakeReshape(scale, std::move(arr)); + } weight = Multiply(weight, wscale); return Call(call->op, {data, weight}, call->attrs, call->type_args); } diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index 5ee691b07362..9d9c532741f9 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -28,9 +28,13 @@ from ..util import get_const_tuple from ..nn import conv2d_legalize, conv2d_alter_layout from ..nn.util import get_pad_tuple +import re logger = logging.getLogger('topi') +_NCHWc_matcher = re.compile("^NCHW[-+]?[0-9]+c$") +_OIHWio_matcher = re.compile("^OIHW[-+]?[0-9]+i[-+]?[0-9]+o$") + @conv2d_alter_layout.register("cpu") def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) @@ -64,30 +68,33 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): if topi_tmpl == "conv2d_NCHWc.x86": # we only convert conv2d_NCHW to conv2d_NCHWc for x86 - assert data_layout == "NCHW" and kernel_layout == "OIHW" - if cfg.is_fallback: - _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, - out_dtype, False, data_layout) - batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) - out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape) - ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] - - # update new attrs - new_attrs['channels'] = out_channel - new_attrs['data_layout'] = 'NCHW%dc' % ic_bn - # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) - new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) - new_attrs['out_layout'] = 'NCHW%dc' % oc_bn - - # Store altered operator's config - new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), - dtype=data_dtype) - new_kernel = te.placeholder((out_channel//oc_bn, in_channel//ic_bn, - kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype) - new_workload = autotvm.task.args_to_workload( - [new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"], - new_attrs["out_layout"], out_dtype], topi_tmpl) - dispatch_ctx.update(target, new_workload, cfg) + if data_layout=="NCHW" and kernel_layout=="OIHW": + if cfg.is_fallback: + _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, + out_dtype, False, data_layout) + batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) + out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape) + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + + # update new attrs + new_attrs['channels'] = out_channel + new_attrs['data_layout'] = 'NCHW%dc' % ic_bn + # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) + new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) + new_attrs['out_layout'] = 'NCHW%dc' % oc_bn + + # Store altered operator's config + new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), + dtype=data_dtype) + new_kernel = te.placeholder((out_channel//oc_bn, in_channel//ic_bn, + kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"], + new_attrs["out_layout"], out_dtype], topi_tmpl) + dispatch_ctx.update(target, new_workload, cfg) + else: + assert _NCHWc_matcher.match(data_layout) + assert _OIHWio_matcher.match(kernel_layout) return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs) if topi_tmpl == "conv2d_NCHWc_int8.x86": @@ -136,30 +143,33 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): return relay.nn.contrib_conv2d_nchwc(data_expr, kernel_OIHWioe, **new_attrs) if topi_tmpl == "depthwise_conv2d_NCHWc.x86": - assert data_layout == "NCHW" and kernel_layout == "OIHW" - if cfg.is_fallback: - _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, - out_dtype, True, data_layout) - - batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) - out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape) - ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] - assert channel_multiplier == 1 - - # update new attrs - new_attrs['channels'] = out_channel - new_attrs['data_layout'] = 'NCHW%dc' % ic_bn - new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn - new_attrs['out_layout'] = 'NCHW%dc' % oc_bn - - # Store altered operator's config. - new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), - dtype=data_dtype) - new_kernel = te.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype) - new_workload = autotvm.task.args_to_workload( - [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'], - new_attrs['out_layout'], out_dtype], topi_tmpl) - dispatch_ctx.update(target, new_workload, cfg) + if data_layout=="NCHW" and kernel_layout=="OIHW": + if cfg.is_fallback: + _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, + out_dtype, True, data_layout) + + batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) + out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape) + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + assert channel_multiplier == 1 + + # update new attrs + new_attrs['channels'] = out_channel + new_attrs['data_layout'] = 'NCHW%dc' % ic_bn + new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn + new_attrs['out_layout'] = 'NCHW%dc' % oc_bn + + # Store altered operator's config. + new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), + dtype=data_dtype) + new_kernel = te.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'], + new_attrs['out_layout'], out_dtype], topi_tmpl) + dispatch_ctx.update(target, new_workload, cfg) + else: + assert _NCHWc_matcher.match(data_layout) + assert _OIHWio_matcher.match(kernel_layout) return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs) return None From 18e3b7f0fafb5f1afaa469cfd31e030b00c9a708 Mon Sep 17 00:00:00 2001 From: Menooker Date: Fri, 17 Apr 2020 16:56:50 +0800 Subject: [PATCH 02/13] code style changes --- src/relay/transforms/fold_scale_axis.cc | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 1ec91191c9f3..d68040ee2fcc 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -833,12 +833,11 @@ Message Conv2DBackwardPrep(const Call& call, const Array& in_messages) // only handle depthwise or full conv2d. // TODO(tvm-team) handle grouped conv by reshape + bcast bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); - if(param->groups == 1 || is_depthwise_conv2d) { + if (param->groups == 1 || is_depthwise_conv2d) { auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o')); auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); - if ( (ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || //simple layout - (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) //blocked layout - { + if ( (ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout + (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout Array arr{c_big_axis}; if (c_small_axis >= 0) { arr.push_back(c_small_axis); @@ -884,10 +883,10 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp } else { auto& wshape = weight->type_as()->shape; Array arr; - for(size_t i=0; i(small_ko_axis) || i == static_cast(big_ko_axis)) { + for (size_t i = 0; i < wshape.size(); i++) { + if (i == static_cast(small_ko_axis) || i == static_cast(big_ko_axis)) { auto node = wshape[i].as(); - if(!node) { + if (!node) { // if the shape is not a constant, use normal transform return transformer->NormalCallTransform(call.operator->()); } From 2b4019eb72617366862d91aca2f4b29dd7da99dc Mon Sep 17 00:00:00 2001 From: Menooker Date: Mon, 20 Apr 2020 09:20:39 +0800 Subject: [PATCH 03/13] style change --- src/relay/transforms/fold_scale_axis.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index d68040ee2fcc..5ce7b7999645 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -836,8 +836,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array& in_messages) if (param->groups == 1 || is_depthwise_conv2d) { auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o')); auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); - if ( (ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout - (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout + if ( (ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout + (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout Array arr{c_big_axis}; if (c_small_axis >= 0) { arr.push_back(c_small_axis); From fb274d82babcc88c4e8fe999afbf07bf128fba30 Mon Sep 17 00:00:00 2001 From: Menooker Date: Mon, 20 Apr 2020 09:52:15 +0800 Subject: [PATCH 04/13] style change --- python/tvm/relay/op/strategy/x86.py | 6 +++--- topi/python/topi/x86/conv2d_alter_op.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 0c5768efb403..059e9eaafa56 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -18,11 +18,11 @@ # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import import logging +import re import topi from tvm.te import SpecializedCondition from .generic import * from .. import op as _op -import re logger = logging.getLogger('strategy') @@ -90,7 +90,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): if groups == 1: if layout.startswith("NCHW"): if layout != "NCHW": - #check if layout is NCHWxc + # check if layout is NCHWxc assert _NCHWc_matcher.match(layout) assert _OIHWio_matcher.match(kernel_layout) else: @@ -124,7 +124,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): if layout.startswith("NCHW"): if layout != "NCHW": - #check if layout is NCHWxc + # check if layout is NCHWxc assert _NCHWc_matcher.match(layout) assert _OIHWio_matcher.match(kernel_layout) else: diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index 9d9c532741f9..44e943d6d347 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -19,6 +19,7 @@ import logging +import re import tvm from tvm import te from tvm import relay @@ -28,7 +29,6 @@ from ..util import get_const_tuple from ..nn import conv2d_legalize, conv2d_alter_layout from ..nn.util import get_pad_tuple -import re logger = logging.getLogger('topi') From 7390e58a25093824170c475d52a88dc97d9ce364 Mon Sep 17 00:00:00 2001 From: Menooker Date: Mon, 20 Apr 2020 10:01:44 +0800 Subject: [PATCH 05/13] pylint changes --- topi/python/topi/x86/conv2d_alter_op.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index 44e943d6d347..b263c2383ca1 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -68,7 +68,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): if topi_tmpl == "conv2d_NCHWc.x86": # we only convert conv2d_NCHW to conv2d_NCHWc for x86 - if data_layout=="NCHW" and kernel_layout=="OIHW": + if data_layout == "NCHW" and kernel_layout == "OIHW": if cfg.is_fallback: _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, False, data_layout) @@ -85,12 +85,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): # Store altered operator's config new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), - dtype=data_dtype) + dtype=data_dtype) new_kernel = te.placeholder((out_channel//oc_bn, in_channel//ic_bn, - kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype) + kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype) new_workload = autotvm.task.args_to_workload( [new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"], - new_attrs["out_layout"], out_dtype], topi_tmpl) + new_attrs["out_layout"], out_dtype], topi_tmpl) dispatch_ctx.update(target, new_workload, cfg) else: assert _NCHWc_matcher.match(data_layout) @@ -143,7 +143,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): return relay.nn.contrib_conv2d_nchwc(data_expr, kernel_OIHWioe, **new_attrs) if topi_tmpl == "depthwise_conv2d_NCHWc.x86": - if data_layout=="NCHW" and kernel_layout=="OIHW": + if data_layout == "NCHW" and kernel_layout == "OIHW": if cfg.is_fallback: _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, True, data_layout) @@ -161,11 +161,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): # Store altered operator's config. new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), - dtype=data_dtype) - new_kernel = te.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype) + dtype=data_dtype) + new_kernel = te.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), + dtype=kernel_dtype) new_workload = autotvm.task.args_to_workload( [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'], - new_attrs['out_layout'], out_dtype], topi_tmpl) + new_attrs['out_layout'], out_dtype], topi_tmpl) dispatch_ctx.update(target, new_workload, cfg) else: assert _NCHWc_matcher.match(data_layout) From ef9275771eba1abe89b5ace47e9e77f7e89d03a2 Mon Sep 17 00:00:00 2001 From: Menooker Date: Thu, 23 Apr 2020 12:50:59 +0800 Subject: [PATCH 06/13] add forward fold axis, pass tests --- src/relay/transforms/fold_scale_axis.cc | 131 +++-- .../python/relay/test_pass_fold_scale_axis.py | 484 ++++++++++++------ 2 files changed, 412 insertions(+), 203 deletions(-) diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 5ce7b7999645..c94ec2315941 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -309,6 +309,40 @@ class ForwardPrep : private ExprVisitor { } }; +static bool IsIntInArray(const Array& axis, int v) { + for (size_t i = 0; i < axis.size(); i++) { + if (axis[i] == v) + return true; + } + return false; +} + +static Expr ReshapeToMatchAxis(Expr scale, const Array& shape, const Array& axis) { + Array arr; + for (size_t i = 0; i < shape.size(); i++) { + if (IsIntInArray(axis, i)) { + auto node = shape[i].as(); + if (!node) { + // if the shape is not a constant, use normal transform + return Expr(); + } + arr.push_back(node->value); + } else { + arr.push_back(1); + } + } + return MakeReshape(scale, std::move(arr)); +} + +// if only one axis, use expand dim. Else, use reshape +static Expr ReshapeOrExpandToMatchAxis(Expr scale, const Array& shape, const Array& axis) { + if (axis.size() > 1) { + return ReshapeToMatchAxis(scale, shape, axis); + } else { + return ExpandBiasToMatchAxis(scale, shape.size(), axis); + } +} + //---------------------------------------------- // Per operator defs for FScaleAxisForward //---------------------------------------------- @@ -369,7 +403,9 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array& new_args, if (slhs != nullptr) { CHECK(srhs == nullptr); CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes)); - Expr scale = ExpandBiasToMatchAxis(slhs->scale, tlhs->shape.size(), slhs->axes); + Expr scale = ReshapeOrExpandToMatchAxis( + slhs->scale, tlhs->shape, slhs->axes); + CHECK(scale.defined()); Expr rhs = Divide(new_args[1], scale); rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args); rnode->scale = slhs->scale; @@ -377,7 +413,9 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array& new_args, } else { CHECK(srhs != nullptr); CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes)); - Expr scale = ExpandBiasToMatchAxis(srhs->scale, trhs->shape.size(), srhs->axes); + Expr scale = ReshapeOrExpandToMatchAxis( + srhs->scale, trhs->shape, srhs->axes); + CHECK(scale.defined()); Expr lhs = Divide(new_args[0], scale); rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args); rnode->scale = srhs->scale; @@ -449,7 +487,6 @@ Array Conv2DForwardPrep(const Call& call, const Message& out_message) { CHECK_GE(c_big_axis, 0); Message none = NullValue(); - AxesSet data_axes = NullValue(); // For now, we only support simple pattern (no folded weight/data) // More general layout can be supported under the current framework. // By using a unified layout transformation. @@ -458,12 +495,17 @@ Array Conv2DForwardPrep(const Call& call, const Message& out_message) { // only handle depthwise or full conv2d. // TODO(tvm-team) handle grouped conv by reshape + bcast bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); - if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 && - (param->groups == 1 || is_depthwise_conv2d)) { - data_axes = {c_big_axis}; - } - if (data_axes.defined()) { - return {Message(data_axes, false), none}; + if (param->groups == 1 || is_depthwise_conv2d) { + auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o')); + auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); + if ( (ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout + (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout + Array arr{c_big_axis}; + if (c_small_axis >= 0) { + arr.push_back(c_small_axis); + } + return {Message(arr, false), none}; + } } return {none, none}; } @@ -482,12 +524,14 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array& new_args, Layout kernel_layout(param->kernel_layout); int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C')); CHECK_GE(c_big_axis, 0); - // For now, we only support simple pattern (no folded weight/data) - // TODO(tvm-team) support general data layout - CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1); - CHECK(sdata->axes.size() == 1 && c_big_axis == sdata->axes[0]->value); - int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); - int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I')); + int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o')); + int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); + int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I')); + int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); + + bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0); + bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0); + CHECK(is_simple || is_blocking); // Check it must be depthwise or full conv2d. bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout); @@ -497,11 +541,30 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array& new_args, // match the ic_axis if (is_depthwise_conv2d) { - Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_oc_axis}); - weight = Multiply(weight, scale); + if (is_simple) { + Expr scale = ExpandBiasToMatchAxis( + sdata->scale, kernel_layout.ndim(), {big_ko_axis}); + weight = Multiply(weight, scale); + } else { + weight = Multiply(weight, ReshapeToMatchAxis(sdata->scale, + weight->type_as()->shape, + {big_ko_axis, small_ko_axis})); + if (!weight.defined()) + return Expr(); + } + } else { - Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ic_axis}); - weight = Multiply(weight, scale); + if (is_simple) { + Expr scale = ExpandBiasToMatchAxis( + sdata->scale, kernel_layout.ndim(), {big_ki_axis}); + weight = Multiply(weight, scale); + } else { + weight = Multiply(weight, ReshapeToMatchAxis(sdata->scale, + weight->type_as()->shape, + {big_ki_axis, small_ki_axis})); + if (!weight.defined()) + return Expr(); + } } // return transformed conv2d return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args); @@ -755,15 +818,20 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp } else if (lhs_message.defined()) { CHECK(equal(message->axes, lhs_message->axes)); Expr lhs = transformer->Transform(call->args[0], message, scale); - Expr rhs = transformer->Transform(call->args[1], NullValue(), NullValue()); - Expr rhs_scale = ExpandBiasToMatchAxis(scale, tlhs->shape.size(), message->axes); + Expr rhs = transformer->Transform( + call->args[1], NullValue(), NullValue()); + Expr rhs_scale = ReshapeOrExpandToMatchAxis(scale, tlhs->shape, + message->axes); + CHECK(rhs_scale.defined()); rhs = Multiply(rhs, rhs_scale); return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); } else if (rhs_message.defined()) { CHECK(equal(message->axes, rhs_message->axes)); Expr lhs = transformer->Transform(call->args[0], NullValue(), NullValue()); Expr rhs = transformer->Transform(call->args[1], message, scale); - Expr lhs_scale = ExpandBiasToMatchAxis(scale, trhs->shape.size(), message->axes); + Expr lhs_scale = ReshapeOrExpandToMatchAxis( + scale, trhs->shape, message->axes); + CHECK(lhs_scale.defined()); lhs = Multiply(lhs, lhs_scale); return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); } else { @@ -881,21 +949,10 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp wscale = ExpandBiasToMatchAxis( scale, kernel_layout.ndim(), {big_ko_axis}); } else { - auto& wshape = weight->type_as()->shape; - Array arr; - for (size_t i = 0; i < wshape.size(); i++) { - if (i == static_cast(small_ko_axis) || i == static_cast(big_ko_axis)) { - auto node = wshape[i].as(); - if (!node) { - // if the shape is not a constant, use normal transform - return transformer->NormalCallTransform(call.operator->()); - } - arr.push_back(node->value); - } else { - arr.push_back(1); - } - } - wscale = MakeReshape(scale, std::move(arr)); + wscale = ReshapeToMatchAxis(scale, weight->type_as()->shape, + {big_ko_axis, small_ko_axis}); + if (!wscale.defined()) + return transformer->NormalCallTransform(call.operator->()); } weight = Multiply(weight, wscale); return Call(call->op, {data, weight}, call->attrs, call->type_args); diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py index d7c437adcc99..8aecf3f891f3 100644 --- a/tests/python/relay/test_pass_fold_scale_axis.py +++ b/tests/python/relay/test_pass_fold_scale_axis.py @@ -35,58 +35,75 @@ def run_opt_pass(expr, opt_pass): def test_fold_fwd_simple(): """Simple testcase.""" - def before(x, conv_weight, in_bias, in_scale, channels): + def before(x, conv_weight, in_bias, in_scale, channels, blocking): args = [x, conv_weight, in_bias] - in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2) x = relay.multiply(x, in_scale) x = relay.nn.relu(x) x = relay.add(x, in_bias) y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW") return relay.Function(args, y) - def expected(x, conv_weight, in_bias, in_scale, channels): + def expected(x, conv_weight, in_bias, in_scale, in_channels, channels, blocking): # use a fixed order of args so alpha equal check can pass args = [x, conv_weight, in_bias] - in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2) - squeezed_scale = relay.squeeze(in_scale, axis=[1,2]) - x = relay.nn.relu(x) - in_bias = relay.divide(in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) - x = relay.add(x, in_bias) - conv_weight = relay.multiply( - conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) + if blocking: + squeezed_scale = relay.squeeze(in_scale, axis=[0,2,3]) + x = relay.nn.relu(x) + in_bias = relay.divide(in_bias, + relay.reshape(squeezed_scale, (1, in_channels // blocking[0], 1, 1, blocking[0]))) #NCHWc + x = relay.add(x, in_bias) + conv_weight = relay.multiply(conv_weight, + relay.reshape(squeezed_scale, (1, in_channels//2, 1, 1, 2, 1))) #OIHWio + else: + squeezed_scale = relay.squeeze(in_scale, axis=[1,2]) + x = relay.nn.relu(x) + in_bias = relay.divide(in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) + x = relay.add(x, in_bias) + conv_weight = relay.multiply( + conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) + y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW") return relay.Function(args, y) - def check(shape, channels): + def check(shape, channels, blocking): x = relay.var("x", shape=shape) - in_channels = shape[1] weight = relay.var("weight") - in_bias = relay.var("in_bias", shape=(in_channels,)) - in_scale = relay.const(_get_positive_scale((in_channels, 1, 1))) - y1 = before(x, weight, in_bias, in_scale, channels) + if blocking: + in_channels = shape[1] * shape[4] + in_bias = relay.var("in_bias", shape=(1, in_channels // blocking[0], 1, 1, blocking[0])) + in_scale = relay.const(_get_positive_scale((1, in_channels // blocking[0], 1, 1, blocking[0]))) + else: + in_channels = shape[1] + in_bias = relay.var("in_bias", shape=(in_channels, 1, 1)) + in_scale = relay.const(_get_positive_scale((in_channels, 1, 1))) + y1 = before(x, weight, in_bias, in_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) - y1_expected = expected(x, weight, in_bias, in_scale, channels) + y1_expected = expected(x, weight, in_bias, in_scale, in_channels, channels, blocking) y1_folded = run_opt_pass(y1_folded, transform.InferType()) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 10), 2) - + check((2, 4, 10, 10), 2, None) + check((2, 2, 10, 10, 2), 8, (2, 4)) def test_fold_fwd_dual_path(): """scale axis being consumed by two consumers""" - def before(x, conv_weight, in_bias, in_scale, channels): + def before(x, conv_weight, in_bias, in_scale, channels, blocking): args = [x, conv_weight, in_bias] x = relay.multiply(in_scale, x) x = relay.nn.relu(x) @@ -94,363 +111,474 @@ def before(x, conv_weight, in_bias, in_scale, channels): y1 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - data_layout="NHWC", - kernel_layout="HWIO", + data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", + kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", groups=channels, padding=(1, 1)) y2 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - data_layout="NHWC", - kernel_layout="HWIO", + data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", + kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", groups=channels, padding=(1, 1)) z = relay.add(y1, y2) return relay.Function(args, z) - def expected(x, conv_weight, in_bias, in_scale, channels): + def expected(x, conv_weight, in_bias, in_scale, channels, blocking): args = [x, conv_weight, in_bias] x = relay.nn.relu(x) - in_bias = relay.divide(in_bias, in_scale) + if blocking: + _in_scale = relay.reshape(in_scale, (1, 1, 1, channels//blocking[0], blocking[0])) #NHWCc + else: + _in_scale = in_scale + in_bias = relay.divide(in_bias, _in_scale) x = relay.subtract(x, in_bias) + if blocking: + _in_scale = relay.reshape(in_scale, (1, 1, 1, channels//blocking[0], 1, blocking[0])) #HWIOio y1 = relay.nn.conv2d(x, - relay.multiply(conv_weight, in_scale), + relay.multiply(conv_weight, _in_scale), channels=channels, kernel_size=(3, 3), - data_layout="NHWC", - kernel_layout="HWIO", + data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", + kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", groups=channels, padding=(1, 1)) + if blocking: + _in_scale = relay.reshape(in_scale, (1, 1, 1, channels//blocking[0], 1, blocking[0])) #HWIOio y2 = relay.nn.conv2d(x, - relay.multiply(conv_weight, in_scale), + relay.multiply(conv_weight, _in_scale), channels=channels, kernel_size=(3, 3), - data_layout="NHWC", - kernel_layout="HWIO", + data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", + kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", groups=channels, padding=(1, 1)) z = relay.add(y1, y2) return relay.Function(args, z) - def check(dshape, channels): + def check(dshape, channels, blocking): x = relay.var("x", shape=dshape) - in_channels = dshape[-1] + if blocking: + in_channels = dshape[3] * dshape[4] + wshape = (3, 3, 1, channels//blocking[1], 1, blocking[1]) # HWIOio + weight = relay.var("weight", shape=wshape) + in_bias = relay.var("in_bias", shape=(in_channels//blocking[0],blocking[0])) + in_scale = relay.const(_get_positive_scale((in_channels//blocking[0],blocking[0]))) + else: + in_channels = dshape[-1] + wshape = (3, 3, 1, channels) # HWIO + weight = relay.var("weight", shape=wshape) + in_bias = relay.var("in_bias", shape=(in_channels,)) + in_scale = relay.const(_get_positive_scale(in_channels,)) + # test depthwise assert in_channels == channels - wshape = (3, 3, 1, channels) # HWIO - weight = relay.var("weight", shape=wshape) - in_bias = relay.var("in_bias", shape=(in_channels,)) - in_scale = relay.const(_get_positive_scale(in_channels,)) - y1 = before(x, weight, in_bias, in_scale, channels) + + y1 = before(x, weight, in_bias, in_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) - y1_expected = expected(x, weight, in_bias, in_scale, channels) + y1_expected = expected(x, weight, in_bias, in_scale, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 3), 3) - + check((2, 4, 10, 3), 3, None) + check((2, 4, 10, 2, 2), 4, (2, 2)) def test_fold_fwd_fail(): """testcase where we canont fold""" - def before(x, conv_weight, in_bias, in_scale, channels): + def before(x, conv_weight, in_bias, in_scale, channels, blocking): x = relay.multiply(x, in_scale) xx = relay.nn.leaky_relu(x, alpha=0.1) y1 = relay.nn.conv2d(xx, conv_weight, channels=channels, kernel_size=(3, 3), - data_layout="NHWC", + data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", + kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", padding=(1, 1)) z = relay.add(y1, x) return relay.Function(relay.analysis.free_vars(z), z) - def check(shape, channels): + def check(shape, channels, blocking): x = relay.var("x", shape=shape) - in_channels = shape[-1] + if blocking: + in_channels = shape[3] * shape[4] + in_bias = relay.var("in_bias", shape=(in_channels//blocking[0],blocking[0])) + in_scale = relay.const(_get_positive_scale((in_channels//blocking[0],blocking[0]))) + else: + in_channels = shape[-1] + in_bias = relay.var("in_bias", shape=(in_channels,)) + in_scale = relay.const(_get_positive_scale(size=(in_channels,))) # test depthwise assert in_channels == channels weight = relay.var("weight") - in_bias = relay.var("in_bias", shape=(in_channels,)) - in_scale = relay.const(_get_positive_scale(size=(in_channels,))) - y1 = before(x, weight, in_bias, in_scale, channels) + y1 = before(x, weight, in_bias, in_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) assert tvm.ir.structural_equal(y1, y1_folded) - check((2, 11, 10, 4), 4) - + check((2, 11, 10, 4), 4, None) + check((2, 11, 10, 2, 2), 4, (2,2)) def test_fold_fwd_relu_fail(): """testcase where we canont fold because scale can not pass relu""" - def before(x, conv_weight, in_bias, in_scale, channels): + def before(x, conv_weight, in_bias, in_scale, channels, blocking): x = relay.multiply(x, in_scale) xx = relay.nn.relu(x) y1 = relay.nn.conv2d(xx, conv_weight, channels=channels, kernel_size=(3, 3), - data_layout="NHWC", + data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", + kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", padding=(1, 1)) z = relay.add(y1, x) return relay.Function(relay.analysis.free_vars(z), z) - def check(shape, channels, in_scale): + def check(shape, channels, blocking, in_scale): x = relay.var("x", shape=shape) - in_channels = shape[-1] - # test depthwise - assert in_channels == channels weight = relay.var("weight") - in_bias = relay.var("in_bias", shape=(in_channels,)) - y1 = before(x, weight, in_bias, in_scale, channels) + if blocking: + in_channels = shape[3] * shape[4] + in_bias = relay.var("in_bias", shape=(1, in_channels // blocking[0], 1, 1, blocking[0])) + else: + in_channels = shape[-1] + in_bias = relay.var("in_bias", shape=(in_channels,)) + + assert in_channels == channels + y1 = before(x, weight, in_bias, in_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) assert tvm.ir.structural_equal(y1, y1_folded) in_scale = relay.var("in_scale", shape=(4,)) - check((2, 11, 10, 4), 4, in_scale) + check((2, 11, 10, 4), 4, None, in_scale) in_scale = relay.const(-_get_positive_scale((4,))) - check((2, 11, 10, 4), 4, in_scale) + check((2, 11, 10, 4), 4, None, in_scale) + + in_scale = relay.var("in_scale", shape=(1,1,1,2,2)) + check((2, 11, 10, 2, 2), 4, (2, 2), in_scale) + in_scale = relay.const(-_get_positive_scale((1,1,1,2,2))) + check((2, 11, 10, 2, 2), 4, (2, 2), in_scale) + + def test_fold_fwd_negative_scale(): """Testcase of folding negative scale""" - def before(x, conv_weight, in_scale, channels): + def before(x, conv_weight, in_scale, channels, blocking): args = [x, conv_weight] x = relay.multiply(x, in_scale) y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW") return relay.Function(args, y) - def expected(x, conv_weight, in_scale, channels): + def expected(x, conv_weight, in_scale, in_channels, channels, blocking): # use a fixed order of args so alpha equal check can pass args = [x, conv_weight] - squeezed_scale = relay.squeeze(in_scale, axis=[1,2]) - conv_weight = relay.multiply( - conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) + if blocking: + squeezed_scale = relay.squeeze(in_scale, axis=[0,2,3]) + conv_weight = relay.multiply( + conv_weight , relay.reshape(squeezed_scale, (1, in_channels//4, 1, 1, 4, 1))) + #blocking by "i" in OIHWio + else: + squeezed_scale = relay.squeeze(in_scale, axis=[1,2]) + conv_weight = relay.multiply( + conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW") return relay.Function(args, y) - def check(shape, channels): + def check(shape, channels, blocking): x = relay.var("x", shape=shape) - in_channels = shape[1] - in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1))) + if blocking: + in_channels = shape[1] * shape[4] + in_scale = relay.const(-_get_positive_scale((1, shape[1], 1, 1, shape[4]))) + else: + in_channels = shape[1] + in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1))) weight = relay.var("weight") - y1 = before(x, weight, in_scale, channels) + y1 = before(x, weight, in_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) - y1_expected = expected(x, weight, in_scale, channels) + y1_expected = expected(x, weight, in_scale, in_channels, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 10), 4) - + check((2, 4, 10, 10), 4, None) + check((2, 2, 10, 10, 2), 8, (2, 2)) def test_fold_bwd_simple(): """Simple testcase.""" - def before(x, conv_weight, out_bias, out_scale, channels): + def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): args = [x, conv_weight, out_bias] - out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) + if blocking: + out_bias = relay.reshape(out_bias, (1, channels//blocking[1], 1, 1, blocking[1])) + else: + out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y = relay.add(y, out_bias) y = relay.nn.relu(y) + if blocking: + out_scale = relay.reshape(out_scale, (1, channels//blocking[1], 1, 1, blocking[1])) y = relay.multiply(y, out_scale) return relay.Function(args, y) - def expected(x, conv_weight, out_bias, out_scale, channels): + def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): # use a fixed order of args so alpha equal check can pass args = [x, conv_weight, out_bias] - out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) - squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) - conv_weight = relay.multiply( - conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) + if blocking: + out_bias = relay.reshape(out_bias, (1, channels//blocking[1], 1, 1, blocking[1])) + out_scale = relay.reshape(out_scale, (1, channels//blocking[1], 1, 1, blocking[1])) + squeezed_scale = relay.squeeze(out_scale, axis=[0, 2, 3]) + conv_weight = relay.multiply( + conv_weight , relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1]))) + else: + out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) + squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) + conv_weight = relay.multiply( + conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) - out_bias = relay.multiply(out_bias, + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") + if blocking: + out_bias = relay.multiply(out_bias, + relay.reshape(squeezed_scale, (1, channels//blocking[1], 1, 1, blocking[1]))) + else: + out_bias = relay.multiply(out_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) y = relay.add(y, out_bias) y = relay.nn.relu(y) return relay.Function(args, y) - def check(shape, channels): + def check(shape, in_channels, channels, blocking): x = relay.var("x", shape=shape) - in_channels = shape[1] weight = relay.var("weight") out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(_get_positive_scale((channels, 1, 1))) - - y1 = before(x, weight, out_bias, out_scale, channels) + if blocking: + out_scale = relay.const(_get_positive_scale((channels,))) + else: + out_scale = relay.const(_get_positive_scale((channels,1, 1))) + y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) - y1_expected = expected(x, weight, out_bias, out_scale, channels) + y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 10), 8) + check((2, 4, 10, 10), 4, 8, None) + check((2, 2, 10, 10, 16), 32, 64, (16, 16)) def test_fold_bwd_dual_path(): """Dual path testcase.""" - def before(x, conv_weight, out_bias, out_scale, channels): + def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): args = [x, conv_weight, out_bias] y1 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y1 = relay.nn.relu(y1) y2 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y2 = relay.nn.relu(y2) y = relay.add(y1, y2) y = relay.multiply(y, out_scale) return relay.Function(args, y) - def expected(x, conv_weight, out_bias, out_scale, channels): + def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): # use a fixed order of args so alpha equal check can pass args = [x, conv_weight, out_bias] - out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) + if not blocking: + out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) def fold_conv_weight(): - return relay.multiply( - conv_weight , - relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) + if blocking: + return relay.multiply( + conv_weight , + relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1]))) + else: + return relay.multiply( + conv_weight , + relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) y1 = relay.nn.conv2d(x, fold_conv_weight(), channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y1 = relay.nn.relu(y1) y2 = relay.nn.conv2d(x, fold_conv_weight(), channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y2 = relay.nn.relu(y2) y = relay.add(y1, y2) return relay.Function(args, y) - def check(shape, channels): + def check(shape, in_channels, channels, blocking): x = relay.var("x", shape=shape) - in_channels = shape[1] weight = relay.var("weight") - out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(_get_positive_scale((channels, 1, 1))) - - y1 = before(x, weight, out_bias, out_scale, channels) + if blocking: + out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1])) + out_scale = relay.const(_get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))) + else: + out_bias = relay.var("out_bias", shape=(channels,)) + out_scale = relay.const(_get_positive_scale((channels, 1, 1))) + + y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) - y1_expected = expected(x, weight, out_bias, out_scale, channels) + y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 10), 8) - + check((2, 4, 10, 10), 4, 8, None) + check((2, 2, 10, 10, 2), 4, 8, (2, 2)) def test_fold_bwd_dual_consumer(): - def before(x, conv_weight, out_bias, out_scale, channels): + def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): args = [x, conv_weight, out_bias] y0 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y0 = relay.multiply(y0, out_scale) y0 = relay.nn.relu(y0) y1 = relay.nn.conv2d(y0, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y1 = relay.multiply(y1, out_scale) y1 = relay.nn.relu(y1) y2 = relay.nn.conv2d(y0, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y2 = relay.multiply(y2, out_scale) y2 = relay.nn.relu(y2) y = relay.add(y1, y2) return relay.Function(args, y) - def expected(x, conv_weight, out_bias, out_scale, channels): + def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): # use a fixed order of args so alpha equal check can pass args = [x, conv_weight, out_bias] def fold_conv_weight(): squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) - return relay.multiply( - conv_weight , - relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) + if blocking: + return relay.multiply( + conv_weight , + relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1]))) + else: + return relay.multiply( + conv_weight , + relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) y0 = relay.nn.conv2d(x, fold_conv_weight(), channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y0 = relay.nn.relu(y0) y1 = relay.nn.conv2d(y0, fold_conv_weight(), channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y1 = relay.nn.relu(y1) y2 = relay.nn.conv2d(y0, fold_conv_weight(), channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y2 = relay.nn.relu(y2) y = relay.add(y1, y2) return relay.Function(args, y) - def check(shape, channels): + def check(shape, in_channels, channels, blocking): x = relay.var("x", shape=shape) - in_channels = shape[1] weight = relay.var("weight") - out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(_get_positive_scale((channels,1, 1))) - - y1 = before(x, weight, out_bias, out_scale, channels) + if blocking: + out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1])) + out_scale = relay.const(_get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))) + else: + out_bias = relay.var("out_bias", shape=(channels,)) + out_scale = relay.const(_get_positive_scale((channels, 1, 1))) + + y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) - y1_expected = expected(x, weight, out_bias, out_scale, channels) + y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 10), 4) - + check((2, 4, 10, 10), 4, 4, None) + check((2, 2, 10, 10, 2), 4, 4, (2, 2)) def test_fold_bwd_fail(): """Dual path testcase.""" - def fail1(x, conv_weight, out_bias, out_scale, channels): + def fail1(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): args = [x, conv_weight, out_bias] - out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) y1 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y1 = relay.nn.relu(y1) y2 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), padding=(1, 1), - out_layout="CNHW") + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", + out_layout="CNHW{}c".format(blocking[1]) if blocking else "CNHW") # fold will fail because the axis from two path # differs from each other. y2 = relay.nn.relu(y2) @@ -458,99 +586,123 @@ def fail1(x, conv_weight, out_bias, out_scale, channels): y = relay.multiply(y, out_scale) return relay.Function(args, y) - def fail2(x, conv_weight, out_bias, out_scale, channels): + def fail2(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): args = [x, conv_weight, out_bias] - out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) y1 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y2 = relay.nn.relu(y1) # fold will fail because y1 is referred also by y2 y1 = relay.multiply(y1, out_scale) y = relay.add(y1, y2) return relay.Function(args, y) - def check(shape, channels, fbefore): + def check(shape, in_channels, channels, blocking, fbefore): x = relay.var("x", shape=shape) - in_channels = shape[1] weight = relay.var("weight") - out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(_get_positive_scale((channels, 1, 1))) - y1 = fbefore(x, weight, out_bias, out_scale, channels) + if blocking: + out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1])) + out_scale = relay.const(_get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))) + else: + out_bias = relay.var("out_bias", shape=(channels, 1, 1)) + out_scale = relay.const(_get_positive_scale((channels, 1, 1))) + y1 = fbefore(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) assert tvm.ir.structural_equal(y1_folded, y1) - check((4, 4, 10, 10), 4, fail1) - check((4, 4, 10, 10), 4, fail2) + check((4, 4, 10, 10), 4, 4, None, fail1) + check((2, 2, 10, 10, 2), 4, 4, (2, 2), fail1) + check((4, 4, 10, 10), 4, 4, None, fail2) + check((4, 2, 10, 10, 2), 4, 4, (2, 2), fail2) def test_fold_bwd_relu_fail(): """testcase where we canont fold because scale can not pass relu""" - def before(x, conv_weight, out_scale, channels): + def before(x, conv_weight, out_scale, channels, blocking): y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - data_layout="NCHW", - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y = relay.nn.relu(y) y = relay.multiply(x, out_scale) return relay.Function(relay.analysis.free_vars(y), y) - def check(shape, channels, out_scale): + def check(shape, channels, blocking, out_scale): x = relay.var("x", shape=shape) in_channels = shape[1] weight = relay.var("weight") - y1 = before(x, weight, out_scale, channels) + y1 = before(x, weight, out_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) assert tvm.ir.structural_equal(y1, y1_folded) out_scale = relay.var("in_scale", shape=(4, 1, 1)) - check((4, 4, 10, 10), 4, out_scale) + check((4, 4, 10, 10), 4, None, out_scale) out_scale = relay.const(np.random.uniform(size=(4, 1, 1), low=-1.0, high=0.0)).astype("float32") - check((4, 4, 10, 10), 4, out_scale) + check((4, 4, 10, 10), 4, None, out_scale) + + out_scale = relay.var("in_scale", shape=(1, 2, 1, 1, 2)) + check((4, 2, 10, 10, 2), 4, (2, 2), out_scale) + out_scale = relay.const(np.random.uniform(size=(1, 2, 1, 1, 2), low=-1.0, high=0.0)).astype("float32") + check((4, 2, 10, 10, 2), 4, (2, 2), out_scale) def test_fold_bwd_negative_scale(): """Testcase of folding negative scale""" - def before(x, conv_weight, out_scale, channels): + def before(x, conv_weight, out_scale, channels, blocking): args = [x, conv_weight] y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y = relay.multiply(y, out_scale) return relay.Function(args, y) - def expected(x, conv_weight, out_scale, channels): + def expected(x, conv_weight, out_scale, channels, blocking): # use a fixed order of args so alpha equal check can pass args = [x, conv_weight] - squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) - conv_weight = relay.multiply( - conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) + if blocking: + squeezed_scale = relay.squeeze(out_scale, axis=[0,2,3]) + conv_weight = relay.multiply( + conv_weight , relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1]))) + else: + squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) + conv_weight = relay.multiply( + conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") return relay.Function(args, y) - def check(shape, channels): + def check(shape, channels, blocking): x = relay.var("x", shape=shape) weight = relay.var("weight") - out_scale = relay.const(-_get_positive_scale((channels, 1, 1))) - y1 = before(x, weight, out_scale, channels) + if blocking: + out_scale = relay.const(-_get_positive_scale((1,channels//blocking[1], 1, 1, blocking[1]))) + else: + out_scale = relay.const(-_get_positive_scale((channels, 1, 1))) + y1 = before(x, weight, out_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) - y1_expected = expected(x, weight, out_scale, channels) + y1_expected = expected(x, weight, out_scale, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 10), 8) - + check((2, 4, 10, 10), 8, None) + check((2, 2, 10, 10, 2), 8, (2, 2)) if __name__ == "__main__": test_fold_fwd_simple() From 2e1a64950ce4037baa7734066da9fc47c1cc00c6 Mon Sep 17 00:00:00 2001 From: Menooker Date: Thu, 23 Apr 2020 16:24:01 +0800 Subject: [PATCH 07/13] style changes --- src/relay/transforms/fold_scale_axis.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index c94ec2315941..a37e11bb6e0e 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -317,7 +317,8 @@ static bool IsIntInArray(const Array& axis, int v) { return false; } -static Expr ReshapeToMatchAxis(Expr scale, const Array& shape, const Array& axis) { +static Expr ReshapeToMatchAxis(Expr scale, const Array& shape, + const Array& axis) { Array arr; for (size_t i = 0; i < shape.size(); i++) { if (IsIntInArray(axis, i)) { @@ -331,11 +332,12 @@ static Expr ReshapeToMatchAxis(Expr scale, const Array& shape, const A arr.push_back(1); } } - return MakeReshape(scale, std::move(arr)); + return MakeReshape(scale, std::move(arr)); } // if only one axis, use expand dim. Else, use reshape -static Expr ReshapeOrExpandToMatchAxis(Expr scale, const Array& shape, const Array& axis) { +static Expr ReshapeOrExpandToMatchAxis(Expr scale, const Array& shape, + const Array& axis) { if (axis.size() > 1) { return ReshapeToMatchAxis(scale, shape, axis); } else { @@ -528,7 +530,7 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array& new_args, int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I')); int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); - + bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0); bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0); CHECK(is_simple || is_blocking); @@ -560,7 +562,7 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array& new_args, weight = Multiply(weight, scale); } else { weight = Multiply(weight, ReshapeToMatchAxis(sdata->scale, - weight->type_as()->shape, + weight->type_as()->shape, {big_ki_axis, small_ki_axis})); if (!weight.defined()) return Expr(); From e9cfec2c826a5ba9c504d4003d54019f00bd2655 Mon Sep 17 00:00:00 2001 From: Menooker Date: Sun, 26 Apr 2020 12:01:51 +0800 Subject: [PATCH 08/13] Let the optimization fail if shape is not const --- src/relay/transforms/fold_scale_axis.cc | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index a37e11bb6e0e..31e00352d38f 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -407,7 +407,8 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array& new_args, CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes)); Expr scale = ReshapeOrExpandToMatchAxis( slhs->scale, tlhs->shape, slhs->axes); - CHECK(scale.defined()); + if (!scale.defined()) + return Expr(); Expr rhs = Divide(new_args[1], scale); rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args); rnode->scale = slhs->scale; @@ -417,7 +418,8 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array& new_args, CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes)); Expr scale = ReshapeOrExpandToMatchAxis( srhs->scale, trhs->shape, srhs->axes); - CHECK(scale.defined()); + if (!scale.defined()) + return Expr(); Expr lhs = Divide(new_args[0], scale); rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args); rnode->scale = srhs->scale; @@ -824,7 +826,9 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp call->args[1], NullValue(), NullValue()); Expr rhs_scale = ReshapeOrExpandToMatchAxis(scale, tlhs->shape, message->axes); - CHECK(rhs_scale.defined()); + if (!rhs_scale.defined()) { + return transformer->NormalCallTransform(call.operator->()); + } rhs = Multiply(rhs, rhs_scale); return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); } else if (rhs_message.defined()) { @@ -833,7 +837,9 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp Expr rhs = transformer->Transform(call->args[1], message, scale); Expr lhs_scale = ReshapeOrExpandToMatchAxis( scale, trhs->shape, message->axes); - CHECK(lhs_scale.defined()); + if (!lhs_scale.defined()) { + return transformer->NormalCallTransform(call.operator->()); + } lhs = Multiply(lhs, lhs_scale); return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); } else { From 001b4e706277a9e80f9758281c3dbdb689315c7d Mon Sep 17 00:00:00 2001 From: Menooker Date: Wed, 29 Apr 2020 11:28:39 +0800 Subject: [PATCH 09/13] move decl of MakeReshape --- src/relay/op/tensor/transform.h | 3 +++ src/relay/transforms/fold_scale_axis.cc | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index 62433c297e8e..c107c5d2543a 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -38,6 +38,9 @@ namespace tvm { namespace relay { +extern Expr MakeReshape(Expr data, + Array newshape); + template bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 31e00352d38f..95fa3599ff92 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -28,6 +28,9 @@ #include #include #include +#include "../op/tensor/transform.h" +#include "pattern_util.h" +#include "pass_util.h" #include "pass_util.h" #include "pattern_util.h" @@ -40,9 +43,6 @@ namespace relay { * Use namespace to reduce potential naming conflict. */ -extern Expr MakeReshape(Expr data, - Array newshape); - namespace fold_scale_axis { using runtime::TypedPackedFunc; From 6bfedb183ac07e719823dc5a6ccc9e8c2280279e Mon Sep 17 00:00:00 2001 From: Menooker Date: Sat, 2 May 2020 14:11:26 +0800 Subject: [PATCH 10/13] format change, remove [-+] in format RE --- python/tvm/relay/op/strategy/x86.py | 36 +++++++++++++------------ src/relay/op/tensor/transform.h | 2 +- src/relay/transforms/fold_scale_axis.cc | 10 ++++--- topi/python/topi/x86/conv2d_alter_op.py | 4 +-- 4 files changed, 28 insertions(+), 24 deletions(-) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 059e9eaafa56..2b8d3656831d 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -26,8 +26,8 @@ logger = logging.getLogger('strategy') -_NCHWc_matcher = re.compile("^NCHW[-+]?[0-9]+c$") -_OIHWio_matcher = re.compile("^OIHW[-+]?[0-9]+i[-+]?[0-9]+o$") +_NCHWc_matcher = re.compile("^NCHW[0-9]+c$") +_OIHWio_matcher = re.compile("^OIHW[0-9]+i[-+]?[0-9]+o$") @schedule_injective.register("cpu") def schedule_injective_cpu(attrs, outs, target): @@ -88,13 +88,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): raise ValueError("dilation should be positive value") if groups == 1: - if layout.startswith("NCHW"): - if layout != "NCHW": - # check if layout is NCHWxc - assert _NCHWc_matcher.match(layout) - assert _OIHWio_matcher.match(kernel_layout) - else: - assert kernel_layout == "OIHW" + def add_implementation_nchw(): if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype): strategy.add_implementation( wrap_compute_conv2d(topi.x86.conv2d_nchw_int8), @@ -105,6 +99,12 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.x86.conv2d_nchw), wrap_topi_schedule(topi.x86.schedule_conv2d_nchw), name="conv2d_nchw.x86") + if layout == "NCHW": + assert kernel_layout == "OIHW" + add_implementation_nchw() + elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc + assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio + add_implementation_nchw() elif layout == "NHWC": assert kernel_layout == "HWIO" logger.warning("For x86 target, NCHW layout is recommended for conv2d.") @@ -122,14 +122,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): else: raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout)) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): - if layout.startswith("NCHW"): - if layout != "NCHW": - # check if layout is NCHWxc - assert _NCHWc_matcher.match(layout) - assert _OIHWio_matcher.match(kernel_layout) - else: - assert kernel_layout == "OIHW" - channel_multiplier = get_const_tuple(inputs[1].shape)[1] + def add_implementation_depthwise_nchw(channel_multiplier): if channel_multiplier == 1 and dilation_h == 1 and dilation_w == 1: strategy.add_implementation( wrap_compute_conv2d(topi.x86.depthwise_conv2d_nchw), @@ -142,6 +135,15 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw), name="depthwise_conv2d_nchw.generic") + if layout == "NCHW": + assert kernel_layout == "OIHW" + channel_multiplier = get_const_tuple(inputs[1].shape)[1] + add_implementation_depthwise_nchw(channel_multiplier) + elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc + assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio + kernel_shape = get_const_tuple(inputs[1].shape) + channel_multiplier = kernel_shape[1] * kernel_shape[4] + add_implementation_depthwise_nchw(channel_multiplier) elif layout == "NHWC": assert kernel_layout == "HWOI" logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.") diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index c107c5d2543a..8a6dfca6b0cc 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -39,7 +39,7 @@ namespace tvm { namespace relay { extern Expr MakeReshape(Expr data, - Array newshape); + Array newshape); template bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs, diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 95fa3599ff92..f6ea23c84f0a 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -318,7 +318,7 @@ static bool IsIntInArray(const Array& axis, int v) { } static Expr ReshapeToMatchAxis(Expr scale, const Array& shape, - const Array& axis) { + const Array& axis) { Array arr; for (size_t i = 0; i < shape.size(); i++) { if (IsIntInArray(axis, i)) { @@ -337,7 +337,7 @@ static Expr ReshapeToMatchAxis(Expr scale, const Array& shape, // if only one axis, use expand dim. Else, use reshape static Expr ReshapeOrExpandToMatchAxis(Expr scale, const Array& shape, - const Array& axis) { + const Array& axis) { if (axis.size() > 1) { return ReshapeToMatchAxis(scale, shape, axis); } else { @@ -407,8 +407,9 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array& new_args, CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes)); Expr scale = ReshapeOrExpandToMatchAxis( slhs->scale, tlhs->shape, slhs->axes); - if (!scale.defined()) + if (!scale.defined()) { return Expr(); + } Expr rhs = Divide(new_args[1], scale); rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args); rnode->scale = slhs->scale; @@ -418,8 +419,9 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array& new_args, CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes)); Expr scale = ReshapeOrExpandToMatchAxis( srhs->scale, trhs->shape, srhs->axes); - if (!scale.defined()) + if (!scale.defined()) { return Expr(); + } Expr lhs = Divide(new_args[0], scale); rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args); rnode->scale = srhs->scale; diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index b263c2383ca1..63d8d9b13db8 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -32,8 +32,8 @@ logger = logging.getLogger('topi') -_NCHWc_matcher = re.compile("^NCHW[-+]?[0-9]+c$") -_OIHWio_matcher = re.compile("^OIHW[-+]?[0-9]+i[-+]?[0-9]+o$") +_NCHWc_matcher = re.compile("^NCHW[0-9]+c$") +_OIHWio_matcher = re.compile("^OIHW[0-9]+i[-+]?[0-9]+o$") @conv2d_alter_layout.register("cpu") def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): From c8efc1b28b264d69efcb7b4a0b78b008173abd0a Mon Sep 17 00:00:00 2001 From: Menooker Date: Tue, 5 May 2020 20:21:41 +0800 Subject: [PATCH 11/13] directly use HCHWc impl in conv2d_strategy_cpu || refine REGEX --- python/tvm/relay/op/strategy/x86.py | 22 ++++++++-------------- topi/python/topi/x86/conv2d_alter_op.py | 2 +- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 2b8d3656831d..fbc2ed24548b 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -27,7 +27,7 @@ logger = logging.getLogger('strategy') _NCHWc_matcher = re.compile("^NCHW[0-9]+c$") -_OIHWio_matcher = re.compile("^OIHW[0-9]+i[-+]?[0-9]+o$") +_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$") @schedule_injective.register("cpu") def schedule_injective_cpu(attrs, outs, target): @@ -88,7 +88,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): raise ValueError("dilation should be positive value") if groups == 1: - def add_implementation_nchw(): + if layout == "NCHW": + assert kernel_layout == "OIHW" if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype): strategy.add_implementation( wrap_compute_conv2d(topi.x86.conv2d_nchw_int8), @@ -99,12 +100,9 @@ def add_implementation_nchw(): wrap_compute_conv2d(topi.x86.conv2d_nchw), wrap_topi_schedule(topi.x86.schedule_conv2d_nchw), name="conv2d_nchw.x86") - if layout == "NCHW": - assert kernel_layout == "OIHW" - add_implementation_nchw() elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio - add_implementation_nchw() + return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWIO" logger.warning("For x86 target, NCHW layout is recommended for conv2d.") @@ -122,7 +120,9 @@ def add_implementation_nchw(): else: raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout)) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): - def add_implementation_depthwise_nchw(channel_multiplier): + if layout == "NCHW": + assert kernel_layout == "OIHW" + channel_multiplier = get_const_tuple(inputs[1].shape)[1] if channel_multiplier == 1 and dilation_h == 1 and dilation_w == 1: strategy.add_implementation( wrap_compute_conv2d(topi.x86.depthwise_conv2d_nchw), @@ -135,15 +135,9 @@ def add_implementation_depthwise_nchw(channel_multiplier): wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw), name="depthwise_conv2d_nchw.generic") - if layout == "NCHW": - assert kernel_layout == "OIHW" - channel_multiplier = get_const_tuple(inputs[1].shape)[1] - add_implementation_depthwise_nchw(channel_multiplier) elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio - kernel_shape = get_const_tuple(inputs[1].shape) - channel_multiplier = kernel_shape[1] * kernel_shape[4] - add_implementation_depthwise_nchw(channel_multiplier) + return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWOI" logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.") diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index 63d8d9b13db8..d1c607f6a3e5 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -33,7 +33,7 @@ logger = logging.getLogger('topi') _NCHWc_matcher = re.compile("^NCHW[0-9]+c$") -_OIHWio_matcher = re.compile("^OIHW[0-9]+i[-+]?[0-9]+o$") +_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$") @conv2d_alter_layout.register("cpu") def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): From 6fbab1e799e542c62005ea64be892437b784e52a Mon Sep 17 00:00:00 2001 From: Menooker Date: Tue, 12 May 2020 14:57:28 +0800 Subject: [PATCH 12/13] clang format --- src/relay/transforms/fold_scale_axis.cc | 77 ++++++++++--------------- 1 file changed, 31 insertions(+), 46 deletions(-) diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index f6ea23c84f0a..6c307cecd42e 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -29,9 +29,6 @@ #include #include #include "../op/tensor/transform.h" -#include "pattern_util.h" -#include "pass_util.h" - #include "pass_util.h" #include "pattern_util.h" @@ -311,8 +308,7 @@ class ForwardPrep : private ExprVisitor { static bool IsIntInArray(const Array& axis, int v) { for (size_t i = 0; i < axis.size(); i++) { - if (axis[i] == v) - return true; + if (axis[i] == v) return true; } return false; } @@ -405,8 +401,7 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array& new_args, if (slhs != nullptr) { CHECK(srhs == nullptr); CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes)); - Expr scale = ReshapeOrExpandToMatchAxis( - slhs->scale, tlhs->shape, slhs->axes); + Expr scale = ReshapeOrExpandToMatchAxis(slhs->scale, tlhs->shape, slhs->axes); if (!scale.defined()) { return Expr(); } @@ -417,8 +412,7 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array& new_args, } else { CHECK(srhs != nullptr); CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes)); - Expr scale = ReshapeOrExpandToMatchAxis( - srhs->scale, trhs->shape, srhs->axes); + Expr scale = ReshapeOrExpandToMatchAxis(srhs->scale, trhs->shape, srhs->axes); if (!scale.defined()) { return Expr(); } @@ -504,14 +498,14 @@ Array Conv2DForwardPrep(const Call& call, const Message& out_message) { if (param->groups == 1 || is_depthwise_conv2d) { auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o')); auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); - if ( (ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout + if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout - Array arr{c_big_axis}; - if (c_small_axis >= 0) { - arr.push_back(c_small_axis); - } - return {Message(arr, false), none}; + Array arr{c_big_axis}; + if (c_small_axis >= 0) { + arr.push_back(c_small_axis); } + return {Message(arr, false), none}; + } } return {none, none}; } @@ -548,28 +542,24 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array& new_args, // match the ic_axis if (is_depthwise_conv2d) { if (is_simple) { - Expr scale = ExpandBiasToMatchAxis( - sdata->scale, kernel_layout.ndim(), {big_ko_axis}); + Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis}); weight = Multiply(weight, scale); } else { - weight = Multiply(weight, ReshapeToMatchAxis(sdata->scale, - weight->type_as()->shape, - {big_ko_axis, small_ko_axis})); - if (!weight.defined()) - return Expr(); + weight = Multiply(weight, + ReshapeToMatchAxis(sdata->scale, weight->type_as()->shape, + {big_ko_axis, small_ko_axis})); + if (!weight.defined()) return Expr(); } } else { if (is_simple) { - Expr scale = ExpandBiasToMatchAxis( - sdata->scale, kernel_layout.ndim(), {big_ki_axis}); + Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ki_axis}); weight = Multiply(weight, scale); } else { - weight = Multiply(weight, ReshapeToMatchAxis(sdata->scale, - weight->type_as()->shape, - {big_ki_axis, small_ki_axis})); - if (!weight.defined()) - return Expr(); + weight = Multiply(weight, + ReshapeToMatchAxis(sdata->scale, weight->type_as()->shape, + {big_ki_axis, small_ki_axis})); + if (!weight.defined()) return Expr(); } } // return transformed conv2d @@ -824,10 +814,8 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp } else if (lhs_message.defined()) { CHECK(equal(message->axes, lhs_message->axes)); Expr lhs = transformer->Transform(call->args[0], message, scale); - Expr rhs = transformer->Transform( - call->args[1], NullValue(), NullValue()); - Expr rhs_scale = ReshapeOrExpandToMatchAxis(scale, tlhs->shape, - message->axes); + Expr rhs = transformer->Transform(call->args[1], NullValue(), NullValue()); + Expr rhs_scale = ReshapeOrExpandToMatchAxis(scale, tlhs->shape, message->axes); if (!rhs_scale.defined()) { return transformer->NormalCallTransform(call.operator->()); } @@ -837,8 +825,7 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp CHECK(equal(message->axes, rhs_message->axes)); Expr lhs = transformer->Transform(call->args[0], NullValue(), NullValue()); Expr rhs = transformer->Transform(call->args[1], message, scale); - Expr lhs_scale = ReshapeOrExpandToMatchAxis( - scale, trhs->shape, message->axes); + Expr lhs_scale = ReshapeOrExpandToMatchAxis(scale, trhs->shape, message->axes); if (!lhs_scale.defined()) { return transformer->NormalCallTransform(call.operator->()); } @@ -914,14 +901,14 @@ Message Conv2DBackwardPrep(const Call& call, const Array& in_messages) if (param->groups == 1 || is_depthwise_conv2d) { auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o')); auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); - if ( (ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout + if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout - Array arr{c_big_axis}; - if (c_small_axis >= 0) { - arr.push_back(c_small_axis); - } - return Message(arr, false); + Array arr{c_big_axis}; + if (c_small_axis >= 0) { + arr.push_back(c_small_axis); } + return Message(arr, false); + } } return NullValue(); } @@ -956,13 +943,11 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp // scale on input for deptwise. Expr wscale; if (is_simple) { - wscale = ExpandBiasToMatchAxis( - scale, kernel_layout.ndim(), {big_ko_axis}); + wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_ko_axis}); } else { wscale = ReshapeToMatchAxis(scale, weight->type_as()->shape, - {big_ko_axis, small_ko_axis}); - if (!wscale.defined()) - return transformer->NormalCallTransform(call.operator->()); + {big_ko_axis, small_ko_axis}); + if (!wscale.defined()) return transformer->NormalCallTransform(call.operator->()); } weight = Multiply(weight, wscale); return Call(call->op, {data, weight}, call->attrs, call->type_args); From c4346d83bbebef5a10cb2251ecb31e0e80904af4 Mon Sep 17 00:00:00 2001 From: Menooker Date: Tue, 12 May 2020 15:04:23 +0800 Subject: [PATCH 13/13] format changes --- src/relay/op/tensor/transform.h | 3 +-- src/relay/transforms/fold_scale_axis.cc | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index 8a6dfca6b0cc..1d1f9c0b64ee 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -38,8 +38,7 @@ namespace tvm { namespace relay { -extern Expr MakeReshape(Expr data, - Array newshape); +extern Expr MakeReshape(Expr data, Array newshape); template bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs, diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 6c307cecd42e..4c8025a8d382 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -28,6 +28,7 @@ #include #include #include + #include "../op/tensor/transform.h" #include "pass_util.h" #include "pattern_util.h"