Skip to content

Commit 7b16281

Browse files
committed
rebase
1 parent 34e5142 commit 7b16281

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

ggml/src/ggml-cuda/softmax.cu

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -183,8 +183,8 @@ static __global__ void soft_max_back_f32(
183183
}
184184

185185
template<int... Ns, typename T>
186-
static void launch_soft_max_kernels(float * x, const T * mask, float * dst,
187-
const soft_max_params & p, cudaStream_t stream)
186+
static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
187+
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
188188
{
189189
const int id = ggml_cuda_get_device();
190190
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
@@ -195,7 +195,7 @@ static void launch_soft_max_kernels(float * x, const T * mask, float * dst,
195195

196196
if (p.ncols == ncols) {
197197
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
198-
soft_max_f32<true, ncols, block><<<p.ne01, p.ne02, p.ne03, stream>>>
198+
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
199199
(x, mask, dst, p);
200200
return true;
201201
}
@@ -209,7 +209,7 @@ static void launch_soft_max_kernels(float * x, const T * mask, float * dst,
209209

210210
//default case
211211
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
212-
soft_max_f32<true, 0, 0><<<p.ne01, p.ne02, p.ne03, stream>>>(x, mask, dst, p);
212+
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
213213
}
214214

215215

@@ -230,8 +230,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
230230

231231

232232
if (nbytes_shared <= smpbo) {
233-
234-
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream);
233+
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
235234
} else {
236235
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
237236
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);

tests/test-backend-ops.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4901,12 +4901,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
49014901
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
49024902

49034903
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
4904-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 1024, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
4904+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
49054905
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
49064906
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
49074907
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
4908-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 1024, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
4909-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
4908+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
4909+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
4910+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
49104911

49114912
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
49124913
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));

0 commit comments

Comments
 (0)