Skip to content

Commit 48544cd

Browse files
committed
Revert "Revert "ggml : add ggml_soft_max_ext (ggml-org#4256)""
This reverts commit a8e66ef.
1 parent 6570a20 commit 48544cd

File tree

8 files changed: +298 -183 lines changed

examples/batched-bench/batched-bench.cpp

+1 -1
Original file line number | Diff line number | Diff line change
@@ -155,7 +155,7 @@ int main(int argc, char ** argv) {
155155
}
156156

157157
LOG_TEE("\n");
158-
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq);
158+
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %d, n_threads_batch = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch);
159159
LOG_TEE("\n");
160160

161161
LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");

ggml-alloc.c

+1 -1
Original file line number | Diff line number | Diff line change
@@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
137137

138138
#ifdef GGML_ALLOCATOR_DEBUG
139139
add_allocated_tensor(alloc, tensor);
140-
size_t cur_max = (char*)addr - (char*)alloc->data + size;
140+
size_t cur_max = (char*)addr - (char*)alloc->base + size;
141141
if (cur_max > alloc->max_size) {
142142
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
143143
for (int i = 0; i < 1024; i++) {

ggml-cuda.cu

+87 -43
Original file line number | Diff line number | Diff line change
@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
443443
#define CUDA_SCALE_BLOCK_SIZE 256
444444
#define CUDA_CLAMP_BLOCK_SIZE 256
445445
#define CUDA_ROPE_BLOCK_SIZE 256
446+
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
446447
#define CUDA_ALIBI_BLOCK_SIZE 32
447448
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
448449
#define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -502,6 +503,31 @@ static size_t g_scratch_offset = 0;
502503

503504
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
504505

506+
static __device__ __forceinline__ float warp_reduce_sum(float x) {
507+
#pragma unroll
508+
for (int mask = 16; mask > 0; mask >>= 1) {
509+
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
510+
}
511+
return x;
512+
}
513+
514+
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
515+
#pragma unroll
516+
for (int mask = 16; mask > 0; mask >>= 1) {
517+
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
518+
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
519+
}
520+
return a;
521+
}
522+
523+
static __device__ __forceinline__ float warp_reduce_max(float x) {
524+
#pragma unroll
525+
for (int mask = 16; mask > 0; mask >>= 1) {
526+
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
527+
}
528+
return x;
529+
}
530+
505531
static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
506532
const int i = blockDim.x*blockIdx.x + threadIdx.x;
507533

@@ -578,15 +604,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
578604
dst[i] = x[i] * x[i];
579605
}
580606

581-
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
582-
#pragma unroll
583-
for (int mask = 16; mask > 0; mask >>= 1) {
584-
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
585-
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
586-
}
587-
return a;
588-
}
589-
590607
template <int block_size>
591608
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
592609
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -625,14 +642,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
625642
}
626643
}
627644

628-
static __device__ __forceinline__ float warp_reduce_sum(float x) {
629-
#pragma unroll
630-
for (int mask = 16; mask > 0; mask >>= 1) {
631-
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
632-
}
633-
return x;
634-
}
635-
636645
template <int block_size>
637646
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
638647
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4718,45 +4727,74 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
47184727
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
47194728
}
47204729

4721-
// the CUDA soft max implementation differs from the CPU implementation
4722-
// instead of doubles floats are used
4723-
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
4724-
const int row = blockDim.x*blockIdx.x + threadIdx.x;
4725-
const int block_size = blockDim.y;
4726-
const int tid = threadIdx.y;
4730+
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
4731+
const int tid = threadIdx.x;
4732+
const int rowx = blockIdx.x;
4733+
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
4734+
4735+
const int block_size = blockDim.x;
4736+
4737+
const int warp_id = threadIdx.x / WARP_SIZE;
4738+
const int lane_id = threadIdx.x % WARP_SIZE;
4739+
4740+
__shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
47274741

47284742
float max_val = -INFINITY;
47294743

47304744
for (int col = tid; col < ncols; col += block_size) {
4731-
const int i = row*ncols + col;
4732-
max_val = max(max_val, x[i]);
4745+
const int ix = rowx*ncols + col;
4746+
const int iy = rowy*ncols + col;
4747+
max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
47334748
}
47344749

47354750
// find the max value in the block
4736-
#pragma unroll
4737-
for (int mask = 16; mask > 0; mask >>= 1) {
4738-
max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
4751+
max_val = warp_reduce_max(max_val);
4752+
if (block_size > WARP_SIZE) {
4753+
if (warp_id == 0) {
4754+
buf[lane_id] = -INFINITY;
4755+
}
4756+
__syncthreads();
4757+
4758+
if (lane_id == 0) {
4759+
buf[warp_id] = max_val;
4760+
}
4761+
__syncthreads();
4762+
4763+
max_val = buf[lane_id];
4764+
max_val = warp_reduce_max(max_val);
47394765
}
47404766

47414767
float tmp = 0.f;
47424768

47434769
for (int col = tid; col < ncols; col += block_size) {
4744-
const int i = row*ncols + col;
4745-
const float val = expf(x[i] - max_val);
4770+
const int ix = rowx*ncols + col;
4771+
const int iy = rowy*ncols + col;
4772+
const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
47464773
tmp += val;
4747-
dst[i] = val;
4774+
dst[ix] = val;
47484775
}
47494776

4750-
// sum up partial sums
4751-
#pragma unroll
4752-
for (int mask = 16; mask > 0; mask >>= 1) {
4753-
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
4777+
// find the sum of exps in the block
4778+
tmp = warp_reduce_sum(tmp);
4779+
if (block_size > WARP_SIZE) {
4780+
if (warp_id == 0) {
4781+
buf[lane_id] = 0.f;
4782+
}
4783+
__syncthreads();
4784+
4785+
if (lane_id == 0) {
4786+
buf[warp_id] = tmp;
4787+
}
4788+
__syncthreads();
4789+
4790+
tmp = buf[lane_id];
4791+
tmp = warp_reduce_sum(tmp);
47544792
}
47554793

47564794
const float inv_tmp = 1.f / tmp;
47574795

47584796
for (int col = tid; col < ncols; col += block_size) {
4759-
const int i = row*ncols + col;
4797+
const int i = rowx*ncols + col;
47604798
dst[i] *= inv_tmp;
47614799
}
47624800
}
@@ -5793,10 +5831,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
57935831
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
57945832
}
57955833

5796-
static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
5797-
const dim3 block_dims(1, WARP_SIZE, 1);
5834+
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
5835+
int nth = WARP_SIZE;
5836+
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
5837+
const dim3 block_dims(nth, 1, 1);
57985838
const dim3 block_nums(nrows_x, 1, 1);
5799-
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
5839+
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
58005840
}
58015841

58025842
static void im2col_f32_f16_cuda(const float * x, half * dst,
@@ -6835,14 +6875,18 @@ inline void ggml_cuda_op_soft_max(
68356875
GGML_ASSERT(src0->type == GGML_TYPE_F32);
68366876
GGML_ASSERT( dst->type == GGML_TYPE_F32);
68376877

6878+
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
6879+
68386880
const int64_t ne00 = src0->ne[0];
6839-
const int64_t nrows = ggml_nrows(src0);
6881+
const int64_t nrows_x = ggml_nrows(src0);
6882+
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
68406883

6841-
soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
6884+
float scale = 1.0f;
6885+
memcpy(&scale, dst->op_params, sizeof(float));
6886+
6887+
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
68426888

6843-
(void) src1;
68446889
(void) dst;
6845-
(void) src1_dd;
68466890
}
68476891

68486892
inline void ggml_cuda_op_scale(

ggml-metal.m

+27 -16
Original file line number | Diff line number | Diff line change
@@ -1028,20 +1028,27 @@ void ggml_metal_graph_compute(
10281028
int nth = 32; // SIMD width
10291029

10301030
if (ne00%4 == 0) {
1031+
while (nth < ne00/4 && nth < 256) {
1032+
nth *= 2;
1033+
}
10311034
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
10321035
} else {
1033-
do {
1036+
while (nth < ne00 && nth < 1024) {
10341037
nth *= 2;
1035-
} while (nth <= ne00 && nth <= 1024);
1036-
nth /= 2;
1038+
}
10371039
[encoder setComputePipelineState:ctx->pipeline_soft_max];
10381040
}
1039-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1040-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1041-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1042-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1043-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1044-
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1041+
1042+
const float scale = ((float *) dst->op_params)[0];
1043+
1044+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1045+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1046+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1047+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1048+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1049+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1050+
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1051+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
10451052

10461053
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
10471054
} break;
@@ -1351,15 +1358,19 @@ void ggml_metal_graph_compute(
13511358
float eps;
13521359
memcpy(&eps, dst->op_params, sizeof(float));
13531360

1354-
const int nth = MIN(512, ne00);
1361+
int nth = 32; // SIMD width
1362+
1363+
while (nth < ne00/4 && nth < 1024) {
1364+
nth *= 2;
1365+
}
13551366

13561367
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
1357-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1358-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1359-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1360-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1361-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
1362-
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1368+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1369+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1370+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1371+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1372+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
1373+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
13631374

13641375
const int64_t nrows = ggml_nrows(src0);
13651376

0 commit comments

Comments
 (0)