Skip to content

Commit

Permalink
Checking the correct dtypes for choosing the Intel int8 instructions.
Browse files Browse the repository at this point in the history
Intel VNNI and Skylake hardware-supported int8 instructions require one tensor to be
unsigned int8 and the other to be signed int8. This commit adds kernel dtype checks to
the x86 TOPI conv2d code so the int8 path is chosen only when that requirement is met.
  • Loading branch information
anijain2305 committed Jul 9, 2019
1 parent 2a7aebe commit 36e636e
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions topi/python/topi/x86/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
kh, kw, oc, _ = kshape
elif pat.match(layout) is not None:
n, ic_chunk, h, w, ic_bn = dshape
if data.dtype == 'uint8':
if data.dtype == 'uint8' and kernel.dtype == 'int8': # VNNI takes u8 x s8
oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
ic = ic_chunk*ic_bn
assert ic == k_ic*k_ic_f*kic_s
Expand Down Expand Up @@ -275,7 +275,7 @@ def traverse(op):
data = data_pad.op.input_tensors[0]

args = [s, cfg, data_vec, conv_out, outs[0]]
if data.dtype == 'uint8':
if data.dtype == 'uint8' and kernel.dtype == 'int8': # VNNI takes u8 x s8
# int8 conv kernel is 7-dim
kh, kw, _, _, _ = get_const_tuple(kernel.shape)
if kh == 1 and kw == 1:
Expand Down Expand Up @@ -505,7 +505,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,

n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
in_channel = ic_chunk * ic_bn
if data.dtype == 'uint8':
if data.dtype == 'uint8' and kernel.dtype == 'int8': # VNNI takes u8 x s8
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
get_const_tuple(kernel.shape)
else:
Expand Down Expand Up @@ -539,7 +539,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
kh = tvm.reduce_axis((0, kernel_height), name='kh')
kw = tvm.reduce_axis((0, kernel_width), name='kw')

if data.dtype == 'uint8' and groups == 1:
# Check for Intel VNNI input datatypes
if data.dtype == 'uint8' and kernel.dtype == 'int8' and groups == 1:
assert out_dtype == "int32", \
"INT8 convolution requires input dtype = uint8 and output dtype=int32"
# Intel performs dot product of 2 "4" Int8 values
Expand All @@ -559,7 +560,9 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
oc_block, ic_s_inner].astype(out_dtype),
axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
if data.dtype == 'uint8':

# Check for Intel VNNI input datatypes
if data.dtype == 'uint8' and kernel.dtype == 'int8':
# for int8 group conv support
n_elems = 4
ic_chunk = in_channel//ic_bn
Expand Down Expand Up @@ -615,7 +618,8 @@ def traverse(op):
data = data_pad.op.input_tensors[0]

args = [s, cfg, data_vec, conv_out, outs[0]]
if data.dtype == 'uint8':
# VNNI takes u8 x s8
if data.dtype == 'uint8' and kernel.dtype == 'int8':
# int8 conv kernel is 7-dim
_, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
if kh == 1 and kw == 1:
Expand Down

0 comments on commit 36e636e

Please sign in to comment.