Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding UNREACHABLE_CODE macro for non MI300 and MI250 cards #138

Merged
merged 8 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 32 additions & 31 deletions csrc/custom/custom_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,18 @@
#include <stdexcept>
#include <algorithm>

#if defined(__HIPCC__) && \
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__
#if defined(__HIPCC__) && (defined(__gfx90a) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300_MI250__
#endif

#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
#define UNREACHABLE_CODE assert(false);
#define NDEBUG
#else
#define UNREACHABLE_CODE assert(false);
#endif

constexpr int WARP_SIZE = 64;
Expand Down Expand Up @@ -334,7 +343,7 @@ __device__ __forceinline__ T loadnt(T* addr) {
#define M 1
#define DTYPE half

#if defined(__HIP__MI300__) // TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support

__global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
Expand Down Expand Up @@ -463,17 +472,15 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B,
}
}

#else // !defined(__HIP__MI300__) TODO: Add NAVI support
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support

__global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
assert(false);
}
const int CuCount){UNREACHABLE_CODE}

#endif // defined(__HIP__MI300__) TODO: Add NAVI support
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support

#if defined(__HIP__MI300__) // TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support

__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
Expand Down Expand Up @@ -820,15 +827,13 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
}
}

#else // !defined(__HIP__MI300__) TODO: Add NAVI support
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support

__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
assert(false);
}
const int CuCount){UNREACHABLE_CODE}

#endif // defined(__HIP__MI300__) TODO: Add NAVI support
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support

#undef YTILE
#undef UNRL
Expand All @@ -838,7 +843,7 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B,
#define UNRL 2
#define M 2

#if defined(__HIP__MI300__) // TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support

__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
Expand Down Expand Up @@ -1185,15 +1190,13 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
}
}

#else // !defined(__HIP__MI300__) TODO: Add NAVI support
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support

__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
assert(false);
}
const int CuCount){UNREACHABLE_CODE}

#endif // defined(__HIP__MI300__) TODO: Add NAVI support
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support

#undef YTILE
#undef UNRL
Expand All @@ -1203,7 +1206,7 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B,
#define UNRL 2
#define M 3

#if defined(__HIP__MI300__) // TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support

__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
Expand Down Expand Up @@ -1550,15 +1553,13 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
}
}

#else // !defined(__HIP__MI300__) TODO: Add NAVI support
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support

__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
assert(false);
}
const int CuCount){UNREACHABLE_CODE}

#endif // defined(__HIP__MI300__) TODO: Add NAVI support
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support

#undef YTILE
#undef UNRL
Expand All @@ -1568,7 +1569,7 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B,
#define UNRL 1
#define M 4

#if defined(__HIP__MI300__) // TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support

__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
Expand Down Expand Up @@ -1915,15 +1916,15 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B,
}
}

#else // !defined(__HIP__MI300__) TODO: Add NAVI support
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support

__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B,
const DTYPE* __restrict__ A, DTYPE* C,
const int CuCount) {
assert(false);
UNREACHABLE_CODE
}

#endif // defined(__HIP__MI300__) TODO: Add NAVI support
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support

void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in,
const int K_in, const int N_in, cudaStream_t stream,
Expand Down
27 changes: 17 additions & 10 deletions csrc/custom/paged_attention/attention_ll4mi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,26 @@

#include <algorithm>

#if defined(__HIPCC__) && \
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__
#if defined(__HIPCC__) && (defined(__gfx90a) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300_MI250__
#endif

#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
#define UNREACHABLE_CODE assert(false);
#define NDEBUG
#else
#define UNREACHABLE_CODE assert(false);
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define WARP_SIZE 64

#if defined(__HIP__MI300__) // TODO: Add NAVI support
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support

#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16
Expand Down Expand Up @@ -863,7 +872,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
}

#else // !defined(__HIP__MI300__) TODO: Add NAVI support
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support

template <typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE, int NUM_THREADS,
int GQA_RATIO>
Expand All @@ -889,7 +898,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size]
#endif
int max_ctx_blocks) {
assert(false);
UNREACHABLE_CODE
}

// Grid: (num_heads, num_seqs).
Expand All @@ -905,11 +914,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_partitions) {
assert(false);
}
const int max_num_partitions){UNREACHABLE_CODE}

#endif // defined(__HIP__MI300__) TODO: Add NAVI support
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support

#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \
paged_attention_ll4mi_QKV_kernel<T, BLOCK_SIZE, HEAD_SIZE, NTHR, GQA_RATIO> \
Expand Down
Loading