
Commit 2040bbf

am17an authored and Nexesenex committed
CUDA: add dynamic shared mem to softmax, refactor general usage (ggml-org#14497)
1 parent 2e0f87d commit 2040bbf

1 file changed: +38 -40 lines

ggml/src/ggml-cuda/softmax.cu

Lines changed: 38 additions & 40 deletions
@@ -2,6 +2,7 @@
 #include "ggml.h"
 #include "softmax.cuh"
 #include <cstdint>
+#include <utility>
 
 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -188,6 +189,37 @@ static __global__ void soft_max_back_f32(
     }
 }
 
+template<int... Ns, typename T>
+static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+    const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
+{
+    const int id = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    auto launch_kernel = [=](auto I) -> bool {
+        constexpr int ncols = decltype(I)::value;
+        constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+        if (p.ncols == ncols) {
+            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, mask, dst, p);
+            return true;
+        }
+        return false;
+    };
+
+    // unary fold over launch_kernel
+    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+        return;
+    }
+
+    // default case
+    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+}
+
+
 template<typename T>
 static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
@@ -200,46 +232,12 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
 
-    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-        }
+    const int id = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+
+    if (nbytes_shared <= smpbo) {
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
         soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
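
For context on the "dynamic shared mem" part of this change: CUDA_SET_SHARED_MEMORY_LIMIT in the diff raises a kernel's dynamic shared memory cap up to the device's per-block opt-in limit (the cached smpbo value), which is what lets rows wider than the default 48 KiB still take the specialized path. Below is a minimal standalone sketch of that opt-in mechanism using the plain CUDA runtime API; the sum_rows kernel and the 16384-column size are illustrative, not part of the commit, and it assumes the ggml macro ultimately wraps cudaFuncSetAttribute in this way.

// Standalone sketch (illustrative names, not ggml code): opting a kernel into
// more dynamic shared memory than the default per-block limit before launching it.
#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel that stages one row per block in dynamic shared memory.
__global__ void sum_rows(const float * x, float * dst, const int ncols) {
    extern __shared__ float buf[];      // sized at launch time via the <<<>>> argument
    for (int i = threadIdx.x; i < ncols; i += blockDim.x) {
        buf[i] = x[blockIdx.x*ncols + i];
    }
    __syncthreads();
    if (threadIdx.x == 0) {             // naive single-thread reduction, kept simple on purpose
        float acc = 0.0f;
        for (int i = 0; i < ncols; ++i) {
            acc += buf[i];
        }
        dst[blockIdx.x] = acc;
    }
}

int main() {
    int device = 0;
    cudaGetDevice(&device);

    // Per-block shared memory a kernel may opt into (what the smpbo field appears to cache).
    int smpbo = 0;
    cudaDeviceGetAttribute(&smpbo, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    const int    ncols         = 16384;                 // hypothetical row width
    const size_t nbytes_shared = ncols*sizeof(float);   // 64 KiB, above the 48 KiB default

    if (nbytes_shared <= (size_t) smpbo) {
        // The opt-in: without this call, launches requesting > 48 KiB fail on most GPUs.
        cudaFuncSetAttribute(sum_rows, cudaFuncAttributeMaxDynamicSharedMemorySize, (int) nbytes_shared);
        // sum_rows<<<nrows, 1024, nbytes_shared, stream>>>(x, dst, ncols);
        std::printf("kernel may now be launched with %zu bytes of dynamic shared memory\n", nbytes_shared);
    } else {
        // Mirrors the diff's else branch: too wide even for the opt-in limit.
        std::printf("row too wide for the %d-byte opt-in limit; use the generic path\n", smpbo);
    }
    return 0;
}

The else branch kept by the diff plays the same role as the second case here: when nbytes_shared exceeds even the opt-in limit, the soft_max_f32<false, 0, 0> variant is launched with only WARP_SIZE*sizeof(float) bytes of shared memory.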

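The refactor in launch_soft_max_kernels replaces the removed switch over ncols_x with a fold expression over a template parameter pack: each candidate column count is passed as a std::integral_constant, the || fold short-circuits on the first match, and anything else falls through to the generic ncols == 0 instantiation. Here is a minimal host-only C++17 sketch of that dispatch pattern (no CUDA required; dispatch, run_fixed and run_generic are illustrative names, not from the commit):

// Sketch of compile-time dispatch via a unary fold over ||, mirroring
// the pattern used by launch_soft_max_kernels above.
#include <cstdio>
#include <type_traits>

template <int ncols>
static void run_fixed() {               // stand-in for a kernel instantiated per ncols
    std::printf("specialized path for ncols = %d\n", ncols);
}

static void run_generic(int ncols) {    // stand-in for the ncols == 0 fallback
    std::printf("generic path for ncols = %d\n", ncols);
}

template <int... Ns>
static void dispatch(int ncols) {
    auto try_one = [&](auto I) -> bool {
        constexpr int N = decltype(I)::value;
        if (ncols == N) {
            run_fixed<N>();             // instantiated once per N in Ns...
            return true;
        }
        return false;
    };
    // Fold over ||: evaluates left to right and stops at the first matching N,
    // just like the switch it replaces.
    if ((try_one(std::integral_constant<int, Ns>{}) || ...)) {
        return;
    }
    run_generic(ncols);                 // no specialized size matched
}

int main() {
    dispatch<32, 64, 128, 256>(128);    // prints "specialized path for ncols = 128"
    dispatch<32, 64, 128, 256>(100);    // prints "generic path for ncols = 100"
}

Compared with the switch it replaces, the pack <32, 64, 128, 256, 512, 1024, 2048, 4096> lives in one place at the call site, so adding or removing a specialized size no longer means editing a block of near-identical case labels.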