Skip to content

Commit b39cc2d

Browse files
committed
SYCL: Add COUNT_EQUAL operator support (rebased on master)
1 parent b8e09f0 commit b39cc2d

File tree

4 files changed

+64
-0
lines changed

4 files changed

+64
-0
lines changed

ggml/src/ggml-sycl/binbcast.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,10 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
303303
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
304304
}
305305

306+
// Internal COUNT_EQUAL helper: forwards dst->src[0], dst->src[1] and dst to the
// shared ggml_sycl_op_bin_bcast driver, instantiated with
// bin_bcast_sycl<op_count_equal>, mirroring the op_sub / op_mul wrappers around it.
//
// NOTE(review): op_count_equal produces a per-element 1.0f/0.0f value.
// Presumably COUNT_EQUAL is meant to yield one total count of matching elements;
// confirm against the ggml operator contract whether a reduction (and a
// non-float dst type) is required instead of an element-wise broadcast kernel.
inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
307+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
308+
}
309+
306310
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
307311

308312
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
@@ -328,6 +332,11 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
328332
ggml_sycl_op_sub(ctx, dst);
329333
}
330334

335+
// Entry point for GGML_OP_COUNT_EQUAL; ggml_sycl_compute_forward calls it from
// its GGML_OP_COUNT_EQUAL case.
//
// ctx: SYCL backend context, passed through unchanged.
// dst: destination tensor; its two sources are read by the delegated helper.
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
336+
// Debug scope object tagged with this function's name and two sources; it
// lives until the end of the function body, so it spans the delegated call.
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
337+
// All work is delegated to the inline helper.
ggml_sycl_op_count_equal(ctx, dst);
338+
}
339+
331340
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
332341
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
333342
ggml_sycl_op_mul(ctx, dst);

ggml/src/ggml-sycl/binbcast.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,30 @@ static __dpct_inline__ float op_add(const float a, const float b) {
1313
}
1414

1515
static __dpct_inline__ float op_sub(const float a, const float b) {
16+
17+
static __dpct_inline__ float op_count_equal(const float a, const float b) {
18+
return (a == b) ? 1.0f : 0.0f;
19+
}
20+
21+
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
22+
1623
return a - b;
24+
25+
static __dpct_inline__ float op_count_equal(const float a, const float b) {
26+
return (a == b) ? 1.0f : 0.0f;
27+
}
28+
29+
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
30+
31+
}
32+
33+
static __dpct_inline__ float op_count_equal(const float a, const float b) {
34+
return (a == b) ? 1.0f : 0.0f;
1735
}
1836

37+
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
38+
39+
1940
static __dpct_inline__ float op_mul(const float a, const float b) {
2041
return a * b;
2142
}

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3577,6 +3577,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
35773577
case GGML_OP_SUB:
35783578
ggml_sycl_sub(ctx, dst);
35793579
break;
3580+
case GGML_OP_COUNT_EQUAL:
3581+
ggml_sycl_count_equal(ctx, dst);
3582+
break;
35803583
case GGML_OP_ACC:
35813584
ggml_sycl_acc(ctx, dst);
35823585
break;
@@ -4356,6 +4359,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
43564359
case GGML_OP_ADD:
43574360
case GGML_OP_ADD1:
43584361
case GGML_OP_SUB:
4362+
case GGML_OP_COUNT_EQUAL:
43594363
case GGML_OP_MUL:
43604364
case GGML_OP_DIV:
43614365
case GGML_OP_REPEAT:

tests/test-backend-ops.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2236,6 +2236,30 @@ struct test_count_equal : public test_case {
22362236
}
22372237
};
22382238

2239+
/* COUNT_EQUAL – typed test (no argmax), to cover F32/F16/I32/I16 */
2240+
struct test_count_equal_typed : public test_case {
2241+
const ggml_type type;
2242+
const std::array<int64_t, 4> ne;
2243+
2244+
test_count_equal_typed(ggml_type type = GGML_TYPE_F32,
2245+
std::array<int64_t, 4> ne = {128, 64, 1, 1})
2246+
: type(type), ne(ne) {}
2247+
2248+
std::string vars() override {
2249+
return VARS_TO_STR2(type, ne);
2250+
}
2251+
2252+
ggml_tensor * build_graph(ggml_context * ctx) override {
2253+
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
2254+
ggml_set_name(a, "a");
2255+
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
2256+
ggml_set_name(b, "b");
2257+
ggml_tensor * out = ggml_count_equal(ctx, a, b);
2258+
ggml_set_name(out, "out");
2259+
return out;
2260+
}
2261+
};
2262+
22392263
// GGML_OP_REPEAT
22402264
struct test_repeat : public test_case {
22412265
const ggml_type type;
@@ -5940,6 +5964,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
59405964

59415965
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
59425966
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
5967+
// COUNT_EQUAL – typed tests by dtype
5968+
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_F32, {1024, 1, 1, 1}));
5969+
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_F32, { 64, 64, 1, 1}));
5970+
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_F16, { 256, 32, 1, 1}));
5971+
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_I32, { 512, 16, 1, 1}));
5972+
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_I16, { 512, 16, 1, 1}));
59435973

59445974
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
59455975
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 513, 1, 1}));

0 commit comments

Comments
 (0)