diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index e097980b060c1..9236d6e55fa04 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -546,9 +546,11 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
         n, h, w, ch, cw = 1, 64, 64, 3, 3
         if data_layout == 'NCHW':
-            x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype))
+            data_shape = (n, ic, h, w)
+            x = relay.var("x", relay.TensorType(data_shape, input_dtype))
         elif data_layout == 'NHWC':
-            x = relay.var("x", relay.TensorType((n, h, w, ic), input_dtype))
+            data_shape = (n, h, w, ic)
+            x = relay.var("x", relay.TensorType(data_shape, input_dtype))
         else:
             raise ValueError('Not supported')
@@ -559,8 +561,8 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
         else:
             raise ValueError('Not supported')

-        w = relay.var("w", relay.TensorType(kernel_shape, weight_dtype))
-        y = relay.nn.conv2d(x, w,
+        weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))
+        y = relay.nn.conv2d(x, weight,
                             kernel_size=(ch, cw),
                             channels=oc,
                             padding=(1, 1),
@@ -568,11 +570,13 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
                             data_layout=data_layout,
                             kernel_layout=kernel_layout,
                             out_dtype=output_dtype)
-        func = relay.Function([x, w], y)
+        func = relay.Function([x, weight], y)
         wdata = np.random.rand(*kernel_shape) * 10
-        parameters = {"w": tvm.nd.array(wdata.astype(weight_dtype))}
+        parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}
+
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build(func, target, params=parameters)
+
         assembly = lib.get_source("asm")
         return assembly
@@ -589,58 +593,63 @@ def _has_fast_int8_instructions(asm, target):
     llvm_version = tvm.codegen.llvm_version_major()
     for target in targets:
         if llvm_version >= 8:
-            fast_int8_dtypes = ('uint8', 'int8', 'int32')
+            dtypes = ('uint8', 'int8', 'int32')
             # Sweep the input channels to check int8 robustness
             # Input channels should be a multiple of 4 internally.
             for ic in [1, 4, 6]:
-                asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW",
+                asm = _compile(ic=ic, oc=16, target=target, data_layout="NCHW",
                                kernel_layout='OIHW',
-                               dtypes=fast_int8_dtypes)
+                               dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)

             for ic in [1, 4, 6]:
-                asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC",
+                asm = _compile(ic=ic, oc=16, target=target, data_layout="NHWC",
                                kernel_layout='HWIO',
-                               dtypes=fast_int8_dtypes)
+                               dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)
-
             # Sweep the output channels to check int8 robustness
             # Output channels should be a multiple of 16 internally.
             for oc in [4, 16, 20]:
-                asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW",
+                asm = _compile(ic=8, oc=oc, target=target, data_layout="NCHW",
                                kernel_layout='OIHW',
-                               dtypes=fast_int8_dtypes)
+                               dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)

             for oc in [4, 16, 20]:
-                asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC",
+                asm = _compile(ic=8, oc=oc, target=target, data_layout="NHWC",
                                kernel_layout='HWIO',
-                               dtypes=fast_int8_dtypes)
+                               dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)

             # Check that both non-divisible oc and ic work
             asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW",
                            kernel_layout='OIHW',
-                           dtypes=fast_int8_dtypes)
+                           dtypes=dtypes)
             assert _has_fast_int8_instructions(asm, target)

             asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC",
                            kernel_layout='HWIO',
-                           dtypes=fast_int8_dtypes)
+                           dtypes=dtypes)
             assert _has_fast_int8_instructions(asm, target)

-            # Ensure that code is generated when datatypes are not HW supported.
-            dtypes = ('int8', 'int8', 'int32')
-            asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+    # Check that int8 x int8 goes through legalization so that fast instructions can be picked up.
+    for target in targets:
+        if llvm_version >= 8:
+            dtypes = ('int8', 'int8', 'int32')
+            # Check that both non-divisible oc and ic work
+            asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
                            dtypes=dtypes)
-            # Check that intrinisic is not present in the assembly.
-            assert not _has_fast_int8_instructions(asm, target)
+            assert _has_fast_int8_instructions(asm, target)

-            # Ensure that code is generated when datatypes are not HW supported.
-            dtypes = ('uint8', 'uint8', 'int32')
-            asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+            asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
                            dtypes=dtypes)
-            # Check that intrinisic is not present in the assembly.
-            assert not _has_fast_int8_instructions(asm, target)
+            assert _has_fast_int8_instructions(asm, target)
+
+            # Ensure that code is generated when datatypes are not HW supported.
+            dtypes = ('uint8', 'uint8', 'int32')
+            asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+                           dtypes=dtypes)
+            # Check that intrinsic is not present in the assembly.
+            assert not _has_fast_int8_instructions(asm, target)

     # Check that a vectorized instruction is generated for older Intel
     # generations, because we default to NCHWc layout.
diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py
index aec3efc7a86c3..0658562f26218 100644
--- a/topi/python/topi/x86/conv2d_alter_op.py
+++ b/topi/python/topi/x86/conv2d_alter_op.py
@@ -192,24 +192,72 @@ def _conv2d_legalize(attrs, inputs, arg_types):
         The legalized expr
     """

+    # Dilation is not supported yet. Return None if dilation is not (1, 1).
+    dilation = attrs.get_int_tuple("dilation")
+    if not (dilation[0] == 1 and dilation[1] == 1):
+        return None
+
     # Collect the input tensors.
     data_tensor, kernel_tensor = arg_types[0], arg_types[1]
+    data_dtype = data_tensor.dtype
+    kernel_dtype = kernel_tensor.dtype

     # Collect the output tensor.
     output_tensor = arg_types[2]

+    # Collect the input exprs.
+    data, kernel = inputs
+
+    # Get the conv attrs.
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+    is_int8_inputs = False
+    # If both the inputs are int8, we can add 128 to make the input dtype uint8, and then adjust the
+    # output. This helps pick up Intel VNNI instructions.
+    # Original --> C = A (conv) B
+    #   A and B are int8
+    #   C = (A + 128 - 128) (conv) B
+    #   C = (A' conv B) - (128 (conv) B), where A' = A + 128
+    #   and (128 (conv) B) is basically a reduce on the CRS axes of the weights.
+    if data_tensor.dtype == 'int8' and kernel_tensor.dtype == 'int8':
+        is_int8_inputs = True
+        padding = attrs.get_int_tuple("padding")
+
+        if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
+            adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(0, 1, 2))
+            pad_width = ((0, 0), (padding[0], padding[0]), (padding[1], padding[1]), (0, 0))
+        elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
+            pad_width = ((0, 0), (0, 0), (padding[0], padding[0]), (padding[1], padding[1]))
+            adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(1, 2, 3))
+            adjust_shift = relay.expand_dims(adjust_shift, axis=1, num_newaxis=2)
+        else:
+            return None
+
+        data = relay.cast(data, 'int32')
+        data = relay.add(data, relay.const(128, 'int32'))
+        data = relay.cast(data, 'uint8')
+
+        # Do external padding as the pad value has to be 128.
+        if not (padding[0] == 0 and padding[1] == 0):
+            data = relay.nn.pad(data, pad_width=pad_width, pad_value=128)
+        new_attrs['padding'] = (0, 0)
+
+        # The data type is now shifted to uint8.
+        data_dtype = 'uint8'
+
+        # Multiply by 128 to get the adjustment term.
+        adjust_shift = relay.multiply(adjust_shift, relay.const(128, 'int32'))
+
     # Legalize if the datatypes are suitable for fast Int8 instructions. Int8 instructions require
     # input channel to be a multiple of 4 and output channels to be a multiple of 16. For input
     # channels, we pad both the inputs and weights input channels. For output channels, we pad the
     # weight and stride_slice the output.
-    if _is_int8_hw_support(data_tensor.dtype, kernel_tensor.dtype):
+    if _is_int8_hw_support(data_dtype, kernel_dtype):
         # Flags to remember if the expr is modified
         ic_modified = False
         oc_modified = False

-        # Collect the input exprs.
-        data, kernel = inputs
-
         # Find the value of input and output channel.
         in_channel = -1
         out_channel = -1
@@ -250,16 +298,16 @@ def _conv2d_legalize(attrs, inputs, arg_types):
         else:
             return None

-        if not (ic_modified or oc_modified):
-            return None
-
-        if ic_modified and not oc_modified:
-            return relay.nn.conv2d(data, kernel, **attrs)
-
         if oc_modified:
-            new_attrs = {k: attrs[k] for k in attrs.keys()}
             new_attrs['channels'] = new_out_channel
             out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
             original_out_shape = [x.value for x in output_tensor.shape]
-            return relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
+            out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
+        else:
+            out = relay.nn.conv2d(data, kernel, **new_attrs)
+
+        if is_int8_inputs:
+            out = relay.subtract(out, adjust_shift)
+
+        return out
     return None
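
A note on the arithmetic behind the _conv2d_legalize change: since (A + 128) conv B = A conv B + 128 * (1 conv B), and 1 conv B collapses to the kernel summed over its reduction (CRS) axes, subtracting adjust_shift recovers the exact int8 result. The NumPy sketch below checks that identity numerically; conv2d_nchw is a naive stand-in written for this note, not the TOPI implementation.

import numpy as np

def conv2d_nchw(data, kernel):
    # Naive NCHW convolution, stride 1, no padding (for illustration only).
    n, c, h, w = data.shape
    o, _, kh, kw = kernel.shape
    out = np.zeros((n, o, h - kh + 1, w - kw + 1), dtype=np.int64)
    flat_kernel = kernel.reshape(o, -1)
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            patch = data[:, :, i:i + kh, j:j + kw].reshape(n, -1)
            out[:, :, i, j] = patch.dot(flat_kernel.T)
    return out

rng = np.random.RandomState(0)
a = rng.randint(-128, 128, size=(1, 3, 8, 8)).astype(np.int64)  # int8-range data
b = rng.randint(-128, 128, size=(4, 3, 3, 3)).astype(np.int64)  # int8-range kernel

# adjust_shift as the patch computes it for NCHW/OIHW: 128 * the kernel reduced
# over its (C, KH, KW) axes, one value per output channel.
adjust_shift = 128 * b.sum(axis=(1, 2, 3)).reshape(1, -1, 1, 1)

# A conv B == (A + 128) conv B - adjust_shift
assert np.array_equal(conv2d_nchw(a, b), conv2d_nchw(a + 128, b) - adjust_shift)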
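On the test side, _has_fast_int8_instructions (its body lies outside these hunks) is what ties the channel sweeps to actual codegen: the fast int8 path shows up as the vpmaddubsw family on skylake-avx512 and as the AVX512-VNNI vpdpbusd instruction on cascadelake. Both consume uint8 x int8 operand pairs, which is why int8 x int8 convolutions have to be legalized into a uint8 x int8 conv plus the adjustment term rather than compiled directly. A plausible sketch of the helper, assuming it simply greps the generated assembly:

def _has_fast_int8_instructions(asm, target):
    # Assumed substring checks; the real helper is not shown in this diff.
    if 'skylake-avx512' in target:
        return 'pmaddubs' in asm  # vpmaddubsw: uint8 x int8 multiply-add
    if 'cascadelake' in target:
        return 'vpdpbusd' in asm  # AVX512-VNNI uint8 x int8 dot product
    raise ValueError('Unhandled target: ' + target)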