[UnitTest] Parametrized test_conv2d_int8_intrinsics #9143

Merged · 1 commit · Sep 29, 2021
225 changes: 106 additions & 119 deletions in tests/python/relay/test_op_level2.py
@@ -1587,156 +1587,143 @@ def test_upsampling3d():
    _test_upsampling3d("NDHWC", "trilinear", "align_corners")


@pytest.mark.skipif(tvm.target.codegen.llvm_version_major() < 8, reason="Requires LLVM 8")
class TestConv2DInt8Intrinsics:
    supported_targets = [
        "llvm -mcpu=nehalem",
        "llvm -mcpu=core-avx2",
        "llvm -mcpu=skylake-avx512",
        "llvm -mcpu=cascadelake",
    ]

    unsupported_targets = [
        "llvm -mcpu=x86-64",
    ]

    data_layout, kernel_layout = tvm.testing.parameters(
        ("NCHW", "OIHW"),
        # TODO(@anijain2305, @icemelon9): disable conv2d_int8 for NHWC data layout.
        # Re-enable this after adding conv2d_NCHWc_int8 support for NHWC.
        # ("NHWC", "HWIO"),
    )

    input_channels, output_channels = tvm.testing.parameters(
        # Sweep the input channels to check int8 robustness.
        # Input channels should be a multiple of 4 internally.
        (1, 16),
        (4, 16),
        (6, 16),
        # Sweep the output channels to check int8 robustness.
        # Output channels should be a multiple of 16 internally.
        (8, 4),
        (8, 16),
        (8, 20),
        # Check that both non-divisible oc and ic work.
        (17, 29),
    )

    @tvm.testing.fixture
    def fast_int8_intrinsic(self, target):
        if "nehalem" in target or "core-avx2" in target or "skylake-avx512" in target:
            return "pmaddubs"
        elif "cascadelake" in target:
            return "vpdpbusd"
        else:
            assert False, f"No fast int8 intrinsic known for target: {target}"

    @tvm.testing.fixture
    def assembly(
        self,
        target,
        dtypes,
        input_channels,
        output_channels,
        data_layout,
        kernel_layout,
    ):
        input_dtype, weight_dtype, output_dtype = dtypes

        image_size = (64, 64)
        kernel_size = (3, 3)
        batch_size = 1

        h, w = image_size

        if data_layout == "NCHW":
            data_shape = (batch_size, input_channels, *image_size)
        elif data_layout == "NHWC":
            data_shape = (batch_size, *image_size, input_channels)
        else:
            raise ValueError(f"Unsupported data layout: {data_layout}")
        x = relay.var("x", relay.TensorType(data_shape, input_dtype))

        if kernel_layout == "OIHW":
            kernel_shape = (output_channels, input_channels, *kernel_size)
        elif kernel_layout == "HWIO":
            kernel_shape = (*kernel_size, input_channels, output_channels)
        else:
            raise ValueError(f"Unsupported kernel layout: {kernel_layout}")

        weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))

        y = relay.nn.conv2d(
            x,
            weight,
            kernel_size=kernel_size,
            channels=output_channels,
            padding=(0, 0, 0, 1),
            dilation=(1, 1),
            data_layout=data_layout,
            kernel_layout=kernel_layout,
            out_dtype=output_dtype,
        )

        func = relay.Function([x, weight], y)

        wdata = np.random.rand(*kernel_shape) * 10
        parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}

        with tvm.transform.PassContext(opt_level=3):
            graph, lib, params = relay.build(func, target, params=parameters)

        return lib.get_source("asm")

    # Ensure that the code uses the fast int8 instructions when available.
    @tvm.testing.parametrize_targets(*supported_targets)
    @pytest.mark.parametrize(
        "dtypes",
        [
            # Compile conv2d for x86 (Skylake, Cascade Lake) and check that
            # the assembly contains *pmadd* instructions.
            ("uint8", "int8", "int32"),
            # Check that int8 x int8 goes through legalization so that the
            # fast instructions can still be picked up.
            ("int8", "int8", "int32"),
        ],
    )
    def test_uses_intrinsic(self, fast_int8_intrinsic, assembly):
        assert fast_int8_intrinsic in assembly

    # For datatypes that don't have HW support, ensure that code is
    # generated without the fast int8 intrinsic.
    @tvm.testing.parametrize_targets(*supported_targets)
    @pytest.mark.parametrize("dtypes", [("uint8", "uint8", "int32")])
    def test_no_intrinsic(self, fast_int8_intrinsic, assembly):
        assert fast_int8_intrinsic not in assembly

    # Check that a vectorized instruction is generated for older Intel
    # generations, because we default to the NCHWc layout.
    @tvm.testing.parametrize_targets(*unsupported_targets)
    @pytest.mark.parametrize("dtypes", [("uint8", "int8", "int32")])
    def test_uses_vectorized_instruction(self, assembly):
        # Check that vector int multiply and add instructions are generated.
        assert "pmulhw" in assembly and "paddd" in assembly

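The refactor also moves compilation into the assembly fixture, so each test body reduces to a single assertion, and @tvm.testing.parametrize_targets reports a skip for targets that are not enabled instead of silently bypassing them in a loop, as the old code did with device_enabled(). A minimal plain-pytest sketch of that shape (the fixture body is a hypothetical stand-in, no TVM required):

import pytest

TARGETS = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]

@pytest.fixture(params=TARGETS)
def target(request):
    return request.param

@pytest.fixture
def assembly(target):
    # Stand-in for relay.build(...) followed by lib.get_source("asm").
    return "vpdpbusd ..." if "cascadelake" in target else "pmaddubsw ..."

def test_uses_intrinsic(target, assembly):
    expected = "vpdpbusd" if "cascadelake" in target else "pmaddubs"
    assert expected in assembly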

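Why those two instruction names: on the pre-VNNI targets, a u8 x i8 dot product over 4 input channels is assembled from pmaddubsw followed by pmaddwd (hence the note that input channels are padded to a multiple of 4 internally), while Cascade Lake's VNNI fuses the sequence into a single vpdpbusd. A simplified NumPy model of one 32-bit lane, sketching the instruction semantics rather than TVM's codegen:

import numpy as np

def pmaddubsw(a_u8, b_i8):
    # Multiply unsigned bytes by signed bytes, add adjacent pairs,
    # saturate each pair sum to int16.
    prod = a_u8.astype(np.int32) * b_i8.astype(np.int32)
    return np.clip(prod[0::2] + prod[1::2], -(2**15), 2**15 - 1).astype(np.int16)

def pmaddwd(a_i16, b_i16):
    # Multiply int16 lanes, add adjacent pairs into int32 (no saturation).
    prod = a_i16.astype(np.int64) * b_i16.astype(np.int64)
    return (prod[0::2] + prod[1::2]).astype(np.int32)

def vpdpbusd(acc, a_u8, b_i8):
    # VNNI: 4-way u8 x i8 dot product accumulated directly into int32.
    return acc + int(np.dot(a_u8.astype(np.int32), b_i8.astype(np.int32)))

data = np.array([1, 2, 3, 4], np.uint8)     # four input channels
weight = np.array([5, -6, 7, -8], np.int8)

legacy = int(pmaddwd(pmaddubsw(data, weight), np.ones(2, np.int16))[0])
vnni = vpdpbusd(0, data, weight)
assert legacy == vnni == 1 * 5 - 2 * 6 + 3 * 7 - 4 * 8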
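Finally, the ("int8", "int8", "int32") case in test_uses_intrinsic exists because these instructions take an unsigned-times-signed operand pair, so signed int8 activations must be legalized before they can use the fast path. The shift-and-compensate identity behind that kind of legalization, as a sketch (the actual rewrite lives in Relay's legalization pass):

import numpy as np

rng = np.random.default_rng(0)
data = rng.integers(-128, 128, size=16).astype(np.int32)    # int8 activations
weight = rng.integers(-128, 128, size=16).astype(np.int32)  # int8 weights

direct = int(np.dot(data, weight))
# Shift activations into uint8 range, then subtract the compensation term:
# (a + 128) . w - 128 * sum(w) == a . w
legalized = int(np.dot(data + 128, weight)) - 128 * int(weight.sum())
assert direct == legalized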