diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index bcef8ab43a243..a76d61884a432 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -65,21 +65,22 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): wrap_topi_schedule(topi.x86.schedule_conv2d_nchw), name="conv2d_nchw.x86") # check if winograd algorithm is applicable - _, _, kh, kw = get_const_tuple(kernel.shape) - pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (kh, kw)) - if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \ - dilation_h == 1 and dilation_w == 1: - strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd), - name="conv2d_nchw_winograd.arm_cpu", - plevel=5) - if "nnpack" in target.libs and pt == 1 and pb == 1 and pl == 1 and pr == 1: + if data.dtype == "float32" and kernel.dtype == "float32": + _, _, kh, kw = get_const_tuple(kernel.shape) + pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (kh, kw)) + if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \ + dilation_h == 1 and dilation_w == 1: strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd_nnpack), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd_nnpack), - name="conv2d_nchw_winograd_nnpack.arm_cpu", - plevel=15) + wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd), + name="conv2d_nchw_winograd.arm_cpu", + plevel=5) + if "nnpack" in target.libs and pt == 1 and pb == 1 and pl == 1 and pr == 1: + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd_nnpack), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd_nnpack), + name="conv2d_nchw_winograd_nnpack.arm_cpu", + plevel=15) elif re.match(r"OIHW\d*o", kernel_layout): strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),