Skip to content

Commit b260213

Browse files
[SYCL] refactor soft_max, add soft_max_back (#16472)
* refactor to support soft_max_ext * fix error and support soft_max_back * rm unused functions * fix format issue --------- Co-authored-by: Zhang Jianyu <zhang.jianyu@outlook.com>
1 parent e08db42 commit b260213

File tree

5 files changed

+436
-188
lines changed

5 files changed

+436
-188
lines changed

ggml/src/ggml-sycl/common.hpp

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ struct sycl_device_info {
197197
int cc; // compute capability
198198
// int nsm; // number of streaming multiprocessors
199199
// size_t smpb; // max. shared memory per block
200+
size_t smpbo; // max. shared memory per block (with opt-in)
200201
bool vmm; // virtual memory support
201202
size_t total_vram;
202203
//sycl_hw_info hw_info; \\ device id and aarch, currently not used
@@ -416,13 +417,6 @@ static __dpct_inline__ float warp_reduce_sum(float x,
416417
const sycl::nd_item<3>& item_ct1) {
417418
#pragma unroll
418419
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
419-
/*
420-
DPCT1096:98: The right-most dimension of the work-group used in the SYCL
421-
kernel that calls this function may be less than "32". The function
422-
"dpct::permute_sub_group_by_xor" may return an unexpected result on the
423-
CPU device. Modify the size of the work-group to ensure that the value
424-
of the right-most dimension is a multiple of "32".
425-
*/
426420
x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask);
427421
}
428422
return x;
@@ -440,17 +434,67 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
440434
return a;
441435
}
442436

437+
template <int width = WARP_SIZE>
438+
static __dpct_inline__ int warp_reduce_sum(int x) {
439+
return sycl::reduce_over_group(
440+
sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>());
441+
}
442+
443+
template <int width = WARP_SIZE>
444+
static __dpct_inline__ float warp_reduce_sum(float x) {
445+
#pragma unroll
446+
for (int offset = width / 2; offset > 0; offset >>= 1) {
447+
x += dpct::permute_sub_group_by_xor(
448+
sycl::ext::oneapi::this_work_item::get_sub_group(), x, offset, width);
449+
}
450+
return x;
451+
}
452+
453+
template <int width = WARP_SIZE>
454+
static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
455+
#pragma unroll
456+
for (int offset = width / 2; offset > 0; offset >>= 1) {
457+
a.x() += dpct::permute_sub_group_by_xor(
458+
sycl::ext::oneapi::this_work_item::get_sub_group(), a.x(), offset,
459+
width);
460+
a.y() += dpct::permute_sub_group_by_xor(
461+
sycl::ext::oneapi::this_work_item::get_sub_group(), a.y(), offset,
462+
width);
463+
}
464+
return a;
465+
}
466+
467+
template <int width = WARP_SIZE>
468+
static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) {
469+
#pragma unroll
470+
for (int offset = width / 2; offset > 0; offset >>= 1) {
471+
a = a + dpct::permute_sub_group_by_xor(
472+
sycl::ext::oneapi::this_work_item::get_sub_group(), a, offset,
473+
width);
474+
}
475+
return a;
476+
}
477+
478+
static constexpr int ggml_sycl_get_physical_warp_size() {
479+
// todo: for old iGPU + dGPU case, need to be changed.
480+
return WARP_SIZE;
481+
}
482+
483+
template <int width = WARP_SIZE>
484+
static __dpct_inline__ float warp_reduce_max(float x) {
485+
#pragma unroll
486+
for (int offset = width / 2; offset > 0; offset >>= 1) {
487+
x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
488+
sycl::ext::oneapi::this_work_item::get_sub_group(), x,
489+
offset, width));
490+
}
491+
return x;
492+
}
493+
443494
static __dpct_inline__ float warp_reduce_max(float x,
444495
const sycl::nd_item<3>& item_ct1) {
445496
#pragma unroll
446497
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
447-
/*
448-
DPCT1096:97: The right-most dimension of the work-group used in the SYCL
449-
kernel that calls this function may be less than "32". The function
450-
"dpct::permute_sub_group_by_xor" may return an unexpected result on the
451-
CPU device. Modify the size of the work-group to ensure that the value
452-
of the right-most dimension is a multiple of "32".
453-
*/
454498
x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
455499
item_ct1.get_sub_group(), x, mask));
456500
}
@@ -558,4 +602,18 @@ struct scope_op_debug_print {
558602
std::string_view func_suffix;
559603
};
560604

605+
/// Computes the slope factor for head index `h`.
///
/// - If `max_bias` is not positive, the result is exactly 1.0f.
/// - Otherwise the result is `base` raised to an integer exponent, where
///   * for h <  n_head_log2: base = m0, exponent = h + 1
///   * for h >= n_head_log2: base = m1, exponent = 2 * (h - n_head_log2) + 1
///
/// @param max_bias     Values <= 0 short-circuit the computation to 1.0f.
/// @param h            Head index.
/// @param n_head_log2  Threshold that selects between the m0 and m1 branches.
/// @param m0           Base used when h < n_head_log2.
/// @param m1           Base used when h >= n_head_log2.
/// @return 1.0f, or dpct::pow(base, exponent) as described above.
static __dpct_inline__ float get_alibi_slope(const float    max_bias,
                                             const uint32_t h,
                                             const uint32_t n_head_log2,
                                             const float    m0,
                                             const float    m1) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }

    const bool  below_threshold = h < n_head_log2;
    const float base            = below_threshold ? m0 : m1;
    const int   exph            = below_threshold ? h + 1 : 2*(h - n_head_log2) + 1;

    return dpct::pow(base, exph);
}
618+
561619
#endif // GGML_SYCL_COMMON_HPP

ggml/src/ggml-sycl/dpct/helper.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,26 @@ namespace dpct
277277

278278
} // namespace detail
279279

280+
// Copied from the DPCT header files.
281+
/// dim3 is used to store 3 component dimensions.
282+
class dim3 {
283+
public:
284+
unsigned x, y, z;
285+
286+
constexpr dim3(unsigned x = 1, unsigned y = 1, unsigned z = 1)
287+
: x(x), y(y), z(z) {}
288+
289+
dim3(const sycl::id<3> &r) : dim3(r[2], r[1], r[0]) {}
290+
291+
operator sycl::range<3>() const { return sycl::range<3>(z, y, x); }
292+
}; // namespace dim3
293+
294+
/// Component-wise product of two dim3 values.
///
/// @param a Left operand.
/// @param b Right operand.
/// @return dim3 whose x, y and z are a.x * b.x, a.y * b.y and a.z * b.z.
inline dim3 operator*(const dim3 &a, const dim3 &b) {
    const unsigned prod_x = a.x * b.x;
    const unsigned prod_y = a.y * b.y;
    const unsigned prod_z = a.z * b.z;
    return dim3{prod_x, prod_y, prod_z};
}
297+
// Copied from the DPCT header files.
298+
299+
280300
/// Pitched 2D/3D memory data.
281301
class pitched_data
282302
{

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ static ggml_sycl_device_info ggml_sycl_init() {
8787
100 * prop.get_major_version() + 10 * prop.get_minor_version();
8888
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
8989
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
90+
info.devices[i].smpbo = prop.get_local_mem_size();
9091
}
9192

9293
for (int id = 0; id < info.device_count; ++id) {
@@ -3741,6 +3742,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
37413742
case GGML_OP_SOFT_MAX:
37423743
ggml_sycl_op_soft_max(ctx, dst);
37433744
break;
3745+
case GGML_OP_SOFT_MAX_BACK:
3746+
ggml_sycl_op_soft_max_back(ctx, dst);
3747+
break;
37443748
case GGML_OP_ROPE:
37453749
ggml_sycl_rope(ctx, dst);
37463750
break;
@@ -3778,6 +3782,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
37783782
return true;
37793783
} catch (sycl::exception & e) {
37803784
std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
3785+
std::cerr << "Error OP "<<ggml_op_name(dst->op)<< std::endl;
37813786
std::exit(1);
37823787
}
37833788

@@ -4386,19 +4391,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
43864391
return true;
43874392
case GGML_OP_CONT:
43884393
return op->src[0]->type != GGML_TYPE_BF16;
4389-
case GGML_OP_SOFT_MAX:
4390-
// TODO: support batching
4391-
if (op->src[0]->ne[3] != 1) {
4392-
return false;
4393-
}
4394-
// TODO: support attention sinks [TAG_ATTN_SINKS]
4395-
if (op->src[2]) {
4396-
return false;
4397-
}
4398-
// TODO: support broadcast
4399-
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
4400-
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
44014394
case GGML_OP_DIAG_MASK_INF:
4395+
return true;
4396+
case GGML_OP_SOFT_MAX:
4397+
return true;
4398+
case GGML_OP_SOFT_MAX_BACK: {
4399+
float max_bias = 0.0f;
4400+
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
4401+
return max_bias == 0.0f;
4402+
}
44024403
case GGML_OP_ROPE:
44034404
case GGML_OP_IM2COL:
44044405
return true;

0 commit comments

Comments
 (0)