Skip to content
Merged
1 change: 1 addition & 0 deletions examples/deepseek_mla/test_example_mla_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mla_decode():
    """Run the MLA decode example's ``main()`` as if launched with no CLI arguments.

    ``sys.argv`` is temporarily replaced by a list holding only the script
    name, so whatever argument handling ``main()`` performs does not see the
    test runner's own command line.
    """
    # NOTE(review): the decorator names suggest the test is gated on CUDA
    # availability and compute capability >= 9.0 -- confirm in tilelang.testing.
    argv_override = mock.patch.object(sys, 'argv', ["example_mla_decode.py"])
    with argv_override:
        example_mla_decode.main()
Expand Down
6 changes: 4 additions & 2 deletions src/tl_templates/cuda/gemm_sm80.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ struct OperandTraits<16, N, K, false, num_warp_n,
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = SM75_U16x8_LDSM_T;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>::type;
};

template <int N, int K, int num_warp_n>
Expand All @@ -108,7 +109,8 @@ struct OperandTraits<16, N, K, false, num_warp_n,
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = SM75_U16x8_LDSM_T;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>::type;
};

template <int N, int K, int num_warp_n>
Expand Down
6 changes: 4 additions & 2 deletions src/tl_templates/cuda/gemm_sm89.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ struct OperandTraits<16, N, K, false, num_warp_n,
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = SM75_U16x8_LDSM_T;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>::type;
};

template <int N, int K, int num_warp_n>
Expand All @@ -211,7 +212,8 @@ struct OperandTraits<16, N, K, false, num_warp_n,
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = SM75_U16x8_LDSM_T;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>::type;
};

template <int N, int K, int num_warp_n>
Expand Down
6 changes: 4 additions & 2 deletions src/tl_templates/cuda/gemm_sm90.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ struct OperandTraits<16, N, K, false, num_warp_n,
Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = SM75_U16x8_LDSM_T;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>::type;
};

template <int N, int K, int num_warp_n>
Expand All @@ -265,7 +266,8 @@ struct OperandTraits<16, N, K, false, num_warp_n,
Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
Step<_2, _1>{}));
using Copy = SM75_U16x8_LDSM_T;
using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
SM75_U16x8_LDSM_T>::type;
};

template <int N, int K, int num_warp_n>
Expand Down
Loading