Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Testing] Add decorator tvm.testing.requires_cuda_compute_version #12778

Merged
merged 2 commits into from
Sep 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,50 @@ def inner(func):
return inner


def requires_cuda_compute_version(major_version, minor_version=0):
    """Mark a test as requiring a minimum CUDA compute capability.

    A unit test carrying this decorator runs only when the compute
    architecture of the available GPU is at least ``(major_version,
    minor_version)``.  The test is additionally marked as requiring
    CUDA support.

    Parameters
    ----------
    major_version: int

        The major version of the (major,minor) version tuple.

    minor_version: int

        The minor version of the (major,minor) version tuple.
    """
    required = (major_version, minor_version)
    try:
        available = tvm.contrib.nvcc.parse_compute_version(
            tvm.contrib.nvcc.get_target_compute_version()
        )
    except ValueError:
        # No GPU present.  Use (0, 0) so the skipif condition below is
        # true; the requires_cuda() marks skip the test in this
        # situation as well.
        available = (0, 0)

    # Same message text as before, built with str.format instead of an
    # f-string.
    reason = "Requires CUDA compute >= {}, but have {}".format(
        ".".join(str(part) for part in required),
        ".".join(str(part) for part in available),
    )

    # Keep the skipif mark ahead of the requires_cuda marks, matching
    # the original composition order.
    marks = [pytest.mark.skipif(available < required, reason=reason)]
    marks.extend(requires_cuda.marks())

    def wrapper(func):
        return _compose([func], marks)

    return wrapper


def skip_if_32bit(reason):
def decorator(*args):
if "32bit" in platform.architecture()[0]:
Expand Down
7 changes: 1 addition & 6 deletions tests/python/unittest/test_tir_ptx_cp_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,9 @@ def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "floa
B[tx, i] = A_shared[tx, i]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(8)
def test_ptx_cp_async():
f = ptx_cp_async
arch = tvm.contrib.nvcc.get_target_compute_version()
major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# Require at least SM80
return

mod = tvm.build(f, target="cuda")
A_np = np.random.rand(32, 128).astype("float16")
Expand Down
8 changes: 2 additions & 6 deletions tests/python/unittest/test_tir_ptx_ldmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,11 @@ def ptx_ldmatrix(
B[8 * j + tx // 4, 8 * k + (tx % 4) * 2 + i] = A_local[4 * k + 2 * j + i]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(7, 5)
def test_ptx_ldmatrix():
f = ptx_ldmatrix
_, _, param_num, param_trans = f.params
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major * 10 + minor < 75:
# Require at least SM75
return

for num in [1, 2, 4]:
for trans in [False, True]:
mod = tvm.build(f.specialize({param_num: num, param_trans: trans}), target="cuda")
Expand Down
146 changes: 18 additions & 128 deletions tests/python/unittest/test_tir_ptx_mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,9 @@ def gemm_mma_m8n8k4_row_col_fp64pf64fp64(a: T.handle, b: T.handle, c: T.handle):
C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m8n8k4_row_col_fp64pf64fp64():
sch = tvm.tir.Schedule(gemm_mma_m8n8k4_row_col_fp64pf64fp64)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# Require at least SM80
return
cuda_mod = tvm.build(sch.mod, target="cuda")

A_np = np.random.uniform(-1, 1, [8, 4]).astype("float64")
Expand Down Expand Up @@ -147,14 +142,9 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(7)
def test_gemm_mma_m8n8k4_row_row_fp16fp16fp16():
sch = tvm.tir.Schedule(gemm_mma_m8n8k4_row_row_fp16fp16fp16)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 7:
# Require at least SM70
return
cuda_mod = tvm.build(sch.mod, target="cuda")

A_np = np.random.uniform(-1, 1, [16, 4]).astype("float16")
Expand Down Expand Up @@ -235,14 +225,9 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(7)
def test_gemm_mma_m8n8k4_row_row_fp16fp16fp32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k4_row_row_fp16fp16fp32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 7:
# Require at least SM70
return
cuda_mod = tvm.build(sch.mod, target="cuda")

A_np = np.random.uniform(-1, 1, [16, 4]).astype("float16")
Expand Down Expand Up @@ -311,14 +296,9 @@ def gemm_mma_m8n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(7, 5)
def test_gemm_mma_m8n8k16_row_col_s8s8s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k16_row_col_s8s8s32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major * 10 + minor < 75:
# Require at least SM75
return
cuda_mod = tvm.build(sch.mod, target="cuda")

A_np = np.random.uniform(-10, 10, [8, 16]).astype("int8")
Expand Down Expand Up @@ -387,14 +367,9 @@ def gemm_mma_m8n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle):
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(7, 5)
def test_gemm_mma_m8n8k16_row_col_s8u8s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k16_row_col_s8u8s32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major * 10 + minor < 75:
# Require at least SM75
return
cuda_mod = tvm.build(sch.mod, target="cuda")

A_np = np.random.uniform(-10, 10, [8, 16]).astype("int8")
Expand Down Expand Up @@ -463,14 +438,9 @@ def gemm_mma_m8n8k32_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle):
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(7, 5)
def test_gemm_mma_m8n8k32_row_col_s4s4s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k32_row_col_s4s4s32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major * 10 + minor < 75:
# Require at least SM75
return
cuda_mod = tvm.build(sch.mod, target="cuda")

ctx = tvm.cuda()
Expand Down Expand Up @@ -531,14 +501,9 @@ def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(7, 5)
def test_gemm_mma_m8n8k32_row_col_s4u4s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k32_row_col_s4u4s32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major * 10 + minor < 75:
# Require at least SM75
return
cuda_mod = tvm.build(sch.mod, target="cuda")

ctx = tvm.cuda()
Expand Down Expand Up @@ -601,14 +566,9 @@ def gemm_mma_m16n8k8_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle)
]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k8_row_col_fp16fp16fp32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k8_row_col_fp16fp16fp32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# Require at least SM80
return
cuda_mod = tvm.build(sch.mod, target="cuda")

A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16")
Expand Down Expand Up @@ -682,15 +642,9 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle
] = Accum[mma_accum_c_id]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k16_row_col_fp16fp16fp16():
sch = tvm.tir.Schedule(gemm_mma_m16n8k16_row_col_fp16fp16fp16)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# Require at least SM80
return
cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")

A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16")
Expand Down Expand Up @@ -764,15 +718,9 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle
] = Accum[mma_accum_c_id]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k16_row_col_fp16fp16fp32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k16_row_col_fp16fp16fp32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# Require at least SM80
return
cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")

A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16")
Expand Down Expand Up @@ -846,15 +794,9 @@ def gemm_mma_m16n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k16_row_col_s8s8s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k16_row_col_s8s8s32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# Require at least SM80
return
cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")

A_np = np.random.uniform(-10, 10, [16, 16]).astype("int8")
Expand Down Expand Up @@ -928,15 +870,9 @@ def gemm_mma_m16n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k16_row_col_s8u8s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k16_row_col_s8u8s32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# Require at least SM80
return
cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")

A_np = np.random.uniform(-10, 10, [16, 16]).astype("int8")
Expand Down Expand Up @@ -1010,15 +946,9 @@ def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k32_row_col_s8s8s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k32_row_col_s8s8s32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# Require at least SM80
return
cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")

A_np = np.random.uniform(-10, 10, [16, 32]).astype("int8")
Expand Down Expand Up @@ -1092,15 +1022,9 @@ def gemm_mma_m16n8k32_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k32_row_col_s8u8s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k32_row_col_s8u8s32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# Require at least SM80
return
cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")

A_np = np.random.uniform(-10, 10, [16, 32]).astype("int8")
Expand Down Expand Up @@ -1174,15 +1098,9 @@ def gemm_mma_m16n8k64_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k64_row_col_s4s4s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k64_row_col_s4s4s32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# Require at least SM80
return
cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")

ctx = tvm.cuda()
Expand Down Expand Up @@ -1248,15 +1166,9 @@ def gemm_mma_m16n8k64_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k64_row_col_s4u4s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k64_row_col_s4u4s32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# Require at least SM80
return
cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")

ctx = tvm.cuda()
Expand Down Expand Up @@ -1323,15 +1235,9 @@ def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle):
] = Accum[mma_accum_c_id]


@tvm.testing.requires_cuda
@tvm.testing.requires_cuda_compute_version(8)
def test_gemm_mma_m16n8k256_row_col_b1b1s32():
sch = tvm.tir.Schedule(gemm_mma_m16n8k256_row_col_b1b1s32)
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
if major < 8:
# Require at least SM80
return
cuda_mod = tvm.build(sch.mod, target="cuda")
cuda_mod = tvm.build(sch.mod, target="cuda")

ctx = tvm.cuda()
Expand All @@ -1345,20 +1251,4 @@ def test_gemm_mma_m16n8k256_row_col_b1b1s32():


if __name__ == "__main__":
test_gemm_mma_m8n8k4_row_col_fp64pf64fp64()
test_gemm_mma_m8n8k4_row_row_fp16fp16fp16()
test_gemm_mma_m8n8k4_row_row_fp16fp16fp32()
test_gemm_mma_m8n8k16_row_col_s8s8s32()
test_gemm_mma_m8n8k16_row_col_s8u8s32()
test_gemm_mma_m8n8k32_row_col_s4s4s32()
test_gemm_mma_m8n8k32_row_col_s4u4s32()
test_gemm_mma_m16n8k8_row_col_fp16fp16fp32()
test_gemm_mma_m16n8k16_row_col_fp16fp16fp16()
test_gemm_mma_m16n8k16_row_col_fp16fp16fp32()
test_gemm_mma_m16n8k16_row_col_s8s8s32()
test_gemm_mma_m16n8k16_row_col_s8u8s32()
test_gemm_mma_m16n8k32_row_col_s8s8s32()
test_gemm_mma_m16n8k32_row_col_s8u8s32()
test_gemm_mma_m16n8k64_row_col_s4s4s32()
test_gemm_mma_m16n8k64_row_col_s4u4s32()
test_gemm_mma_m16n8k256_row_col_b1b1s32()
tvm.testing.main()
Loading