Skip to content

Commit c3c5794

Browse files
gshtrasepwalsh
authored andcommitted
[Bugfix][ROCm] Fix for warp_size uses on host (vllm-project#21205)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
1 parent 84ea6f8 commit c3c5794

File tree

9 files changed

+67
-31
lines changed

9 files changed

+67
-31
lines changed

csrc/attention/attention_kernels.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
#include "attention_dtypes.h"
2626
#include "attention_utils.cuh"
27-
#include "cuda_compat.h"
27+
#include "../cuda_compat.h"
2828

2929
#ifdef USE_ROCM
3030
#include <hip/hip_bf16.h>

csrc/attention/paged_attention_v1.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616
* See the License for the specific language governing permissions and
1717
* limitations under the License.
1818
*/
19-
2019
#include "attention_kernels.cuh"
21-
#include "cuda_compat.h"
20+
#include "../cuda_compat.h"
2221

2322
#define MAX(a, b) ((a) > (b) ? (a) : (b))
2423
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -75,7 +74,7 @@ void paged_attention_v1_launcher(
7574
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
7675
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
7776

78-
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
77+
const int NUM_WARPS = NUM_THREADS / WARP_SIZE;
7978
int padded_max_seq_len =
8079
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
8180
int logits_size = padded_max_seq_len * sizeof(float);

csrc/attention/paged_attention_v2.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616
* See the License for the specific language governing permissions and
1717
* limitations under the License.
1818
*/
19-
2019
#include "attention_kernels.cuh"
21-
#include "cuda_compat.h"
20+
#include "../cuda_compat.h"
2221

2322
#define MAX(a, b) ((a) > (b) ? (a) : (b))
2423
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -79,7 +78,7 @@ void paged_attention_v2_launcher(
7978
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
8079
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
8180

82-
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
81+
const int NUM_WARPS = NUM_THREADS / WARP_SIZE;
8382
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
8483
int logits_size = PARTITION_SIZE * sizeof(float);
8584
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

csrc/cuda_compat.h

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,35 @@
44
#include <hip/hip_runtime.h>
55
#endif
66

7-
#if defined(USE_ROCM) && defined(__GFX9__)
8-
#define WARP_SIZE 64
7+
#ifdef USE_ROCM
8+
struct Utils {
9+
static __host__ int get_warp_size() {
10+
static bool is_cached = false;
11+
static int result;
12+
13+
if (!is_cached) {
14+
int device_id;
15+
cudaDeviceProp deviceProp;
16+
cudaGetDevice(&device_id);
17+
cudaGetDeviceProperties(&deviceProp, device_id);
18+
19+
result = deviceProp.warpSize;
20+
is_cached = true;
21+
}
22+
23+
return result;
24+
}
25+
26+
static __device__ constexpr int get_warp_size() {
27+
#ifdef __GFX9__
28+
return 64;
29+
#else
30+
return 32;
31+
#endif
32+
}
33+
};
34+
35+
#define WARP_SIZE Utils::get_warp_size()
936
#else
1037
#define WARP_SIZE 32
1138
#endif

csrc/moe/topk_softmax_kernels.cu

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
190190
2) This implementation assumes k is small, but will work for any k.
191191
*/
192192

193-
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
194-
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
193+
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType>
194+
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
195195
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
196196
int* source_rows, const int k, const int start_expert, const int end_expert)
197197
{
@@ -209,12 +209,12 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
209209

210210
// Restrictions based on previous section.
211211
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
212-
static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
212+
static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
213213
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
214-
static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
214+
static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM, "THREADS_PER_ROW can be at most warp size");
215215

216216
// We have NUM_EXPERTS elements per row. We specialize for small #experts
217-
static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
217+
static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT;
218218
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
219219
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
220220

@@ -393,41 +393,51 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
393393
namespace detail
394394
{
395395
// Constructs some constants needed to partition the work across threads at compile time.
396-
template <int EXPERTS, int BYTES_PER_LDG>
396+
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM>
397397
struct TopkConstants
398398
{
399399
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
400-
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
401-
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
400+
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, "");
401+
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
402402
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
403403
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
404-
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
404+
static const int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW;
405405
};
406406
} // namespace detail
407407

408-
template <int EXPERTS, int WARPS_PER_TB, typename IndType>
408+
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
409409
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
410410
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
411411
{
412412
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
413413

414414
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
415-
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
415+
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
416416
static constexpr int VPT = Constants::VPT;
417417
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
418418
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
419419
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
420420

421-
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
422-
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
421+
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
422+
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM><<<num_blocks, block_dim, 0, stream>>>(
423423
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
424424
}
425425

426-
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
427-
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
428-
gating_output, nullptr, topk_weights, topk_indices, \
429-
token_expert_indices, num_tokens, topk, 0, num_experts, \
430-
stream);
426+
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
427+
switch (warpSize) { \
428+
case 32: \
429+
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32>( \
430+
gating_output, nullptr, topk_weights, topk_indices, \
431+
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
432+
break; \
433+
case 64: \
434+
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64>( \
435+
gating_output, nullptr, topk_weights, topk_indices, \
436+
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
437+
break; \
438+
default: \
439+
TORCH_CHECK(false, "Unsupported warp size: ", warpSize); \
440+
}
431441

432442
template <typename IndType>
433443
void topkGatingSoftmaxKernelLauncher(
@@ -441,6 +451,7 @@ void topkGatingSoftmaxKernelLauncher(
441451
const int topk,
442452
cudaStream_t stream) {
443453
static constexpr int WARPS_PER_TB = 4;
454+
auto warpSize = WARP_SIZE;
444455
switch (num_experts) {
445456
case 1:
446457
LAUNCH_SOFTMAX(1, WARPS_PER_TB);

csrc/quantization/activation_kernels.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#include <cmath>
66
#include "core/math.hpp"
7-
#include "cuda_compat.h"
7+
#include "../cuda_compat.h"
88
#include "dispatch_utils.h"
99

1010
#include "quantization/fp8/common.cuh"

csrc/quantization/gguf/gguf_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <torch/all.h>
55
#include <c10/cuda/CUDAGuard.h>
66

7-
#include "cuda_compat.h"
7+
#include "../../cuda_compat.h"
88
#include "dispatch_utils.h"
99

1010
#include "ggml-common.h"

csrc/rocm/attention.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include <c10/cuda/CUDAGuard.h>
2020
#include <hip/hip_fp8.h>
2121
#include <hip/hip_bf16.h>
22-
#include "cuda_compat.h"
22+
#include "../cuda_compat.h"
2323

2424
#include <algorithm>
2525
#include "../attention/dtype_fp8.cuh"

csrc/rocm/skinny_gemms.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include <stdexcept>
1010
#include <algorithm>
1111

12-
#include "cuda_compat.h"
12+
#include "../cuda_compat.h"
1313
#include "dispatch_utils.h"
1414
#include "quantization/fp8/common.cuh"
1515

0 commit comments

Comments
 (0)