[TESTS] Skip tests with tl.dot() on P100
shintaro-iwasaki committed Oct 31, 2022
1 parent 385f44e · commit 55a1ade
Showing 3 changed files with 13 additions and 0 deletions.
python/test/unit/language/test_core.py (6 additions, 0 deletions)
@@ -1046,6 +1046,8 @@ def kernel(X, stride_xm, stride_xn,
                           if not (allow_tf32 and (dtype in ['float16']))])
 def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
     cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
     if cc < 80:
         if dtype == 'int8':
             pytest.skip("Only test int8 on devices with sm >= 80")
@@ -1231,6 +1233,10 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
 
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_masked_load_shared_memory(dtype, device='cuda'):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
+
     check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested
 
     M = 32
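
For context, the P100 is a Pascal-generation GPU with compute capability 6.0, so it falls below the sm >= 70 (Volta) threshold these added guards check for. Below is a minimal sketch of the same skip logic written against the public PyTorch API rather than Triton's internal runtime helper; the test name test_uses_tl_dot is hypothetical and not part of this commit.

import pytest
import torch


def test_uses_tl_dot(device='cuda'):
    # torch.cuda.get_device_capability() returns (major, minor); a P100 reports
    # (6, 0), so cc = 60 and the test is skipped, matching the cc < 70 guard above.
    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    cc = major * 10 + minor
    if cc < 70:
        pytest.skip("Only test tl.dot() on devices with sm >= 70")
    # ... body exercising tl.dot() would go here ...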
python/test/unit/operators/test_blocksparse.py (5 additions, 0 deletions)
@@ -2,6 +2,7 @@
 import torch
 
 import triton
+import triton._C.libtriton.triton as _triton
 
 
 @pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -125,6 +126,10 @@ def test_attention_fwd_bwd(
     batch_size=2,
     n_heads=2,
 ):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
+
     # inputs
     qkv_shape = (batch_size, n_heads, n_ctx, 64)
     qkvs = [
python/test/unit/operators/test_matmul.py (2 additions, 0 deletions)
@@ -68,6 +68,8 @@
 )
 def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
     cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 70:
+        pytest.skip("Only test tl.dot() on devices with sm >= 70")
     if cc < 80 and DTYPE == "bfloat16":
         pytest.skip("Only test bfloat16 on devices with sm >= 80")
     if DTYPE == "bfloat16" and SPLIT_K != 1:
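
Since the same three-line guard now appears in several test files, one possible follow-up (not part of this commit) would be to factor it into a shared pytest helper. The sketch below assumes the same _triton runtime API imported by the tests above; the helper name skip_if_cc_below and the example test are hypothetical.

import pytest
import torch
import triton._C.libtriton.triton as _triton


def skip_if_cc_below(min_cc=70, reason="Only test tl.dot() on devices with sm >= 70"):
    # Query the compute capability of the current CUDA device via Triton's runtime,
    # exactly as the tests above do, and skip the calling test if the device is too old.
    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
    if cc < min_cc:
        pytest.skip(reason)


def test_example_with_dot():
    skip_if_cc_below(70)
    # ... tl.dot()-based test body ...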
