diff --git a/examples/deepseek_mla/test_example_mla_decode.py b/examples/deepseek_mla/test_example_mla_decode.py index d011526b3..9cde90b83 100644 --- a/examples/deepseek_mla/test_example_mla_decode.py +++ b/examples/deepseek_mla/test_example_mla_decode.py @@ -8,6 +8,7 @@ @tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) def test_example_mla_decode(): with mock.patch.object(sys, 'argv', ["example_mla_decode.py"]): example_mla_decode.main() diff --git a/src/tl_templates/cuda/gemm_sm80.h b/src/tl_templates/cuda/gemm_sm80.h index a79a5ccf1..55d18c1b1 100644 --- a/src/tl_templates/cuda/gemm_sm80.h +++ b/src/tl_templates/cuda/gemm_sm80.h @@ -98,7 +98,8 @@ struct OperandTraits<16, N, K, false, num_warp_n, Swizzle<2, 3, 3>{}, Layout, Stride<_1, _32>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); - using Copy = SM75_U16x8_LDSM_T; + using Copy = typename std::conditional::type; }; template @@ -108,7 +109,8 @@ struct OperandTraits<16, N, K, false, num_warp_n, Swizzle<3, 3, 3>{}, Layout, Stride<_1, _64>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); - using Copy = SM75_U16x8_LDSM_T; + using Copy = typename std::conditional::type; }; template diff --git a/src/tl_templates/cuda/gemm_sm89.h b/src/tl_templates/cuda/gemm_sm89.h index 4f7058896..8e326f86d 100644 --- a/src/tl_templates/cuda/gemm_sm89.h +++ b/src/tl_templates/cuda/gemm_sm89.h @@ -201,7 +201,8 @@ struct OperandTraits<16, N, K, false, num_warp_n, Swizzle<2, 3, 3>{}, Layout, Stride<_1, _32>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); - using Copy = SM75_U16x8_LDSM_T; + using Copy = typename std::conditional::type; }; template @@ -211,7 +212,8 @@ struct OperandTraits<16, N, K, false, num_warp_n, Swizzle<3, 3, 3>{}, Layout, Stride<_1, _64>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); - using Copy = SM75_U16x8_LDSM_T; + using Copy = typename std::conditional::type; }; template diff --git a/src/tl_templates/cuda/gemm_sm90.h b/src/tl_templates/cuda/gemm_sm90.h index 313793cd2..bf55499c8 100644 --- a/src/tl_templates/cuda/gemm_sm90.h +++ b/src/tl_templates/cuda/gemm_sm90.h @@ -255,7 +255,8 @@ struct OperandTraits<16, N, K, false, num_warp_n, Swizzle<2, 3, 3>{}, Layout, Stride<_1, _32>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); - using Copy = SM75_U16x8_LDSM_T; + using Copy = typename std::conditional::type; }; template @@ -265,7 +266,8 @@ struct OperandTraits<16, N, K, false, num_warp_n, Swizzle<3, 3, 3>{}, Layout, Stride<_1, _64>>{})); using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape, Int>{}, Step<_2, _1>{})); - using Copy = SM75_U16x8_LDSM_T; + using Copy = typename std::conditional::type; }; template