Skip to content

Commit

Permalink
[Topi, ARM] Disbale Winograd for quantized tensors.
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Apr 17, 2020
1 parent 84d1eec commit 717f86f
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,22 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
name="conv2d_nchw.x86")
# check if winograd algorithm is applicable
_, _, kh, kw = get_const_tuple(kernel.shape)
pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (kh, kw))
if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.arm_cpu",
plevel=5)
if "nnpack" in target.libs and pt == 1 and pb == 1 and pl == 1 and pr == 1:
if data.dtype == "float32" and kernel.dtype == "float32":
_, _, kh, kw = get_const_tuple(kernel.shape)
pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (kh, kw))
if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd_nnpack),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd_nnpack),
name="conv2d_nchw_winograd_nnpack.arm_cpu",
plevel=15)
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.arm_cpu",
plevel=5)
if "nnpack" in target.libs and pt == 1 and pb == 1 and pl == 1 and pr == 1:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd_nnpack),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd_nnpack),
name="conv2d_nchw_winograd_nnpack.arm_cpu",
plevel=15)
elif re.match(r"OIHW\d*o", kernel_layout):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
Expand Down

0 comments on commit 717f86f

Please sign in to comment.