From 06a216e6f9ef174048cb6a4754b3d9d481db9450 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 25 Mar 2019 02:38:46 +0800 Subject: [PATCH] [Relay][Op] Add group conv2d dispatch to topi function (#2870) * [Relay][Op] Add group conv2d dispatch to topi function * Rerun tests --- python/tvm/relay/op/nn/_nn.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 58de44c2e0b5..d38f40ac373b 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -96,6 +96,9 @@ def compute_conv2d(attrs, inputs, out_type, target): get_const_int(inputs[1].shape[3]) == 1: out = topi.nn.depthwise_conv2d_nhwc( inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) + elif layout in ['NCHW', 'NCHW4c']: + out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups, + out_dtype=out_dtype) else: raise ValueError("not support arbitrary group number for now") return [out] @@ -120,6 +123,8 @@ def schedule_conv2d(attrs, outs, target): return topi.generic.schedule_depthwise_conv2d_nchw(outs) if layout == "NHWC" and kernel_layout == "HWOI": return topi.generic.schedule_depthwise_conv2d_nhwc(outs) + if layout == "NCHW4c": + return topi.generic.schedule_group_conv2d_nchw(outs) raise ValueError("No compatible schedule")