Skip to content

Commit 5d6eb72

Browse files
committed
warp size fixes
1 parent 33091a9 commit 5d6eb72

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

ggml-cuda.cu

+6 -2
Original file line number | Diff line number | Diff line change
@@ -182,7 +182,11 @@ typedef struct {
182182
} block_q6_k;
183183
static_assert(sizeof(block_q6_k) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_k block size/padding");
184184

185+
#if defined(GGML_USE_HIPBLAS)
186+
#define WARP_SIZE warpSize
187+
#else
185188
#define WARP_SIZE 32
189+
#endif
186190

187191
#define CUDA_MUL_BLOCK_SIZE 256
188192

@@ -679,8 +683,8 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
679683
// sum up partial sums and write back result
680684
__syncthreads();
681685
#pragma unroll
682-
for (int mask = 16; mask > 0; mask >>= 1) {
683-
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
686+
for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
687+
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, WARP_SIZE);
684688
}
685689

686690
if (tid == 0) {

0 commit comments

Comments
 (0)