Skip to content

Commit

Permalink
[x86 schedule] Fallback schedule for Int8 depthwise. (apache#4733)
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 authored and alexwong committed Feb 28, 2020
1 parent 1838fcb commit 708e9ad
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 8 deletions.
30 changes: 30 additions & 0 deletions tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,35 @@ def _has_fast_int8_instructions(asm, target):
assert "vpmulld" in asm and "vpadd" in asm


def test_depthwise_conv2d_int8():
    """Build a uint8 x int8 depthwise conv2d for AVX-512-capable x86 targets.

    groups (64) equals the number of input channels, which makes the conv2d
    depthwise.  The test only checks that Relay compilation at opt_level=3
    succeeds for each target; the convolution output is never executed or
    compared numerically.
    """
    in_dtype, w_dtype, acc_dtype = 'uint8', 'int8', 'int32'

    data_shape = (1, 64, 56, 56)
    kernel_shape = (64, 1, 3, 3)

    data = relay.var("x", relay.TensorType(data_shape, in_dtype))
    kernel = relay.var("weight", relay.TensorType(kernel_shape, w_dtype))

    conv = relay.nn.conv2d(
        data,
        kernel,
        kernel_size=(3, 3),
        groups=64,  # groups == in-channels -> depthwise convolution
        padding=(1, 1),
        dilation=(1, 1),
        out_dtype=acc_dtype)
    func = relay.Function([data, kernel], conv)

    raw_weight = np.random.rand(*kernel_shape) * 10
    parameters = {"weight": tvm.nd.array(raw_weight.astype(w_dtype))}

    # Int8 codegen for these CPUs is only attempted on LLVM >= 8; the check
    # is target-independent, so it is hoisted out of the target loop.
    if tvm.codegen.llvm_version_major() >= 8:
        for target in ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]:
            with relay.build_config(opt_level=3):
                graph, lib, params = relay.build(func, target, params=parameters)


def test_bitserial_conv2d_infer_type():
# Basic shape test with ambiguous batch.
n, c, h, w = tvm.size_var("n"), 32, 224, 224
Expand Down Expand Up @@ -1234,3 +1263,4 @@ def test_bitpack_infer_type():
test_upsampling()
test_upsampling3d()
test_conv2d_int8_intrinsics()
test_depthwise_conv2d_int8()
22 changes: 14 additions & 8 deletions topi/python/topi/x86/conv2d_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ..nn.util import get_pad_tuple
from ..util import get_const_tuple
from ..nn.conv2d import conv2d_NCHWc_int8
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from .. import nn
from . import conv2d_avx_1x1, conv2d_avx_common

Expand All @@ -36,15 +37,20 @@ def _get_default_config_int8(cfg, data, kernel, strides, padding, out_dtype, is_
"""
Get default schedule config for the workload
"""
assert not is_depthwise, "Depthwise Int8 not supported"
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
if is_kernel_1x1:
conv2d_generic.fallback_schedule_cpu_1x1_int8(
cfg, wkl, int32_lanes=16, num_int8_elements=4)
if is_depthwise:
# Fallback to FP32 default config until a VNNI schedule is defined.
wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
from .depthwise_conv2d import _fallback_schedule
_fallback_schedule(cfg, wkl)
else:
conv2d_generic.fallback_schedule_cpu_common_int8(
cfg, wkl, int32_lanes=16, num_int8_elements=4)
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
if is_kernel_1x1:
conv2d_generic.fallback_schedule_cpu_1x1_int8(
cfg, wkl, int32_lanes=16, num_int8_elements=4)
else:
conv2d_generic.fallback_schedule_cpu_common_int8(
cfg, wkl, int32_lanes=16, num_int8_elements=4)


def _is_int8_hw_support(data_dtype, kernel_dtype):
Expand Down

0 comments on commit 708e9ad

Please sign in to comment.