Skip to content

Commit 3b6ba34

Browse files
committed
moe_mask_gen_cpu_impl is replaced by moe_mask_gen_gpu_impl
1 parent 23653bb commit 3b6ba34

File tree

5 files changed

+189
-5
lines changed

5 files changed

+189
-5
lines changed

src/plugins/intel_gpu/src/graph/impls/cpu/moe_mask_gen.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,5 @@ attach_moe_mask_gen_reshape_impl::attach_moe_mask_gen_reshape_impl() {
187187
} // namespace cpu
188188
} // namespace cldnn
189189

190-
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::cpu::moe_mask_gen_impl)
191-
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::moe_mask_gen)
192-
193190
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::cpu::moe_mask_gen_reshape_impl)
194191
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::moe_mask_gen_reshape)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright (C) 2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
#include "moe_mask_gen.hpp"
5+
6+
#include "../common_utils/dispatch_utils.hpp"
7+
#include "../common_utils/jitter.hpp"
8+
#include "../primitive_ocl_base.hpp"
9+
#include "../utils/kernel_generator.hpp"
10+
#include "intel_gpu/primitives/moe_mask_gen.hpp"
11+
12+
namespace ov::intel_gpu::ocl {
13+
namespace {
14+
15+
// Kernel generator for the reference "moe_mask_gen" OpenCL kernel.
// It supplies the JIT constants, the kernel argument list and the dispatch
// (work-group) configuration for the kernel registered under that name.
class MoeMaskGenRefGenerator : public KernelGenerator {
16+
public:
17+
// Binds this generator to the kernel named "moe_mask_gen".
MoeMaskGenRefGenerator() : KernelGenerator("moe_mask_gen") {}
18+
19+
protected:
20+
// Returns the base-class JIT constants extended with NUM_EXPERTS_PER_TOKEN,
// taken from the moe_mask_gen primitive descriptor.
[[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override {
21+
auto jit = KernelGenerator::get_jit_constants(params);
22+
23+
auto prim = params.typed_desc<moe_mask_gen>();
24+
jit.make("NUM_EXPERTS_PER_TOKEN", prim->num_experts_per_token);
25+
26+
return jit;
27+
}
28+
29+
// Describes the kernel arguments in order: an optional SHAPE_INFO entry
// (dynamic shapes only), input 0, then outputs 0..4.
// This order mirrors the moe_mask_gen kernel signature
// (OPTIONAL_SHAPE_INFO_ARG, topk_idx, five output buffers).
Arguments get_arguments_desc(const RuntimeParams& params) const override {
30+
Arguments args;
31+
if (params.is_dynamic()) {
32+
args.push_back({ArgumentDescriptor::Types::SHAPE_INFO, 0});
33+
}
34+
35+
args.push_back({ArgumentDescriptor::Types::INPUT, 0});
36+
37+
// NOTE(review): hard-coded output count; it has to stay in sync with the
// number of __global output pointers the kernel declares.
const uint32_t num_of_outputs = 5;
38+
for (uint32_t i = 0; i < num_of_outputs; i++) {
39+
args.push_back({ArgumentDescriptor::Types::OUTPUT, i});
40+
}
41+
42+
return args;
43+
}
44+
45+
// Returns the callback that fills in the work-group sizes.
// For non-dynamic params it sets global == local == {num_total_experts, 1, 1},
// i.e. a single work-group with one work-item per expert. The kernel relies
// on that: it uses get_local_id(0) as the expert id and work-group scans.
// For dynamic params the sizes are left untouched here
// (presumably set by a later call once shapes are known - TODO confirm).
// The rt_params argument is not used by this callback.
[[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override {
46+
return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) {
47+
auto& wgs = kd.params.workGroups;
48+
if (!params.is_dynamic()) {
49+
// NOTE(review): num_total_experts is used as the local size without a
// bound check - assumes it fits the device's max work-group size; verify.
auto num_total_experts = static_cast<size_t>(params.typed_desc<moe_mask_gen>()->num_total_experts);
50+
wgs.global = {num_total_experts, 1, 1};
51+
wgs.local = {num_total_experts, 1, 1};
52+
}
53+
}};
54+
}
55+
};
56+
57+
// OCL primitive implementation for moe_mask_gen consisting of a single stage
// driven by MoeMaskGenRefGenerator.
class MoeMaskGenRefImpl : public PrimitiveImplOCL {
58+
public:
59+
DECLARE_OBJECT_TYPE_SERIALIZATION(ov::intel_gpu::ocl::MoeMaskGenRefImpl)
60+
61+
// The only stage of this implementation; created from MoeMaskGenRefGenerator.
Stage::Ptr moe_mask_gen = make_stage<MoeMaskGenRefGenerator>();
62+
63+
// Default constructor: only passes MoeMaskGenRef's static type info to the base
// class; no stage is added here.
MoeMaskGenRefImpl() : PrimitiveImplOCL(MoeMaskGenRef::get_type_info_static()) {}
64+
// Delegates to the default constructor and then adds the moe_mask_gen stage
// with the given runtime params. `node` is not used in the body.
MoeMaskGenRefImpl(const program_node& node, const RuntimeParams& params) : MoeMaskGenRefImpl() {
65+
add_stage(moe_mask_gen, params);
66+
}
67+
68+
// Returns a copy of this implementation produced by make_deep_copy.
[[nodiscard]] std::unique_ptr<primitive_impl> clone() const override {
69+
return make_deep_copy<MoeMaskGenRefImpl>(this);
70+
}
71+
};
72+
73+
} // namespace
74+
75+
// Creates the OCL reference implementation (MoeMaskGenRefImpl) for `node`.
// The node type is checked with assert only, so the check disappears when
// NDEBUG is defined.
std::unique_ptr<primitive_impl> MoeMaskGenRef::create_impl(const program_node& node, const RuntimeParams& params) const {
76+
assert(node.is_type<moe_mask_gen>());
77+
return std::make_unique<MoeMaskGenRefImpl>(node, params);
78+
}
79+
80+
} // namespace ov::intel_gpu::ocl
81+
82+
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::moe_mask_gen)
83+
BIND_BINARY_BUFFER_WITH_TYPE(ov::intel_gpu::ocl::MoeMaskGenRefImpl)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (C) 2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include <memory>
8+
#include <utility>
9+
10+
#include "program_node.h"
11+
#include "registry/implementation_manager.hpp"
12+
13+
using namespace cldnn; // TODO: Remove once namespaces are aligned
14+
15+
namespace ov::intel_gpu::ocl {
16+
17+
// Implementation manager for the OCL reference moe_mask_gen implementation,
// registered under the name "ocl::moe_mask_gen::ref".
struct MoeMaskGenRef : public ImplementationManager {
18+
OV_GPU_PRIMITIVE_IMPL("ocl::moe_mask_gen::ref")
19+
// Forwards impl_types::ocl, the requested shape type and an optional
// validation callback to ImplementationManager.
explicit MoeMaskGenRef(shape_types shape_type, ValidateFunc vf = nullptr) : ImplementationManager(impl_types::ocl, shape_type, std::move(vf)) {}
20+
// Builds the primitive implementation for `node`; defined out of line.
[[nodiscard]] std::unique_ptr<primitive_impl> create_impl(const program_node& node, const RuntimeParams& params) const override;
21+
// Returns true only if input 0 and output 0 both use the bfyx format and
// both have an element type of f32, i32 or i64.
// NOTE(review): only output 0 is inspected - confirm whether the remaining
// outputs of this primitive need the same check.
// NOTE(review): std::array is used below but <array> is not among this
// header's visible includes; presumably it arrives transitively - verify.
[[nodiscard]] bool validate_impl(const program_node& node) const override {
22+
static constexpr std::array supported_fmts = {
23+
format::bfyx,
24+
};
25+
26+
static constexpr std::array supported_types = {
27+
ov::element::f32,
28+
ov::element::i32,
29+
ov::element::i64,
30+
};
31+
32+
const auto& in0_layout = node.get_input_layout(0);
33+
const auto& out_layout = node.get_output_layout(0);
34+
35+
// Reject any layout format outside supported_fmts.
if (!one_of(in0_layout.format, supported_fmts) || !one_of(out_layout.format, supported_fmts)) {
36+
return false;
37+
}
38+
39+
// Reject any element type outside supported_types.
if (!one_of(in0_layout.data_type, supported_types) || !one_of(out_layout.data_type, supported_types)) {
40+
return false;
41+
}
42+
43+
return true;
44+
}
45+
};
46+
47+
} // namespace ov::intel_gpu::ocl
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (C) 2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "include/batch_headers/common.cl"
6+
7+
// moe_mask_gen: groups token indices by the expert they were routed to.
// Each work-item handles the expert whose id equals its local id, so the
// kernel expects one work-item per expert inside a single work-group.
//
//   topk_idx                 in : flat list of expert ids; the
//                                 NUM_EXPERTS_PER_TOKEN entries of token t are
//                                 read consecutively
//   tokens_per_expert        out: token indices, stored contiguously per expert
//   experts_info_start_idx   out: per used expert, start offset of its tokens
//                                 in tokens_per_expert
//   experts_id               out: ids of the experts that received tokens
//   tokens_lens_per_expert   out: per used expert, number of tokens it received
//   num_actual_used_experts  out: element 0 = count of experts with >= 1 token
//
// The three per-expert outputs are indexed by the expert's rank among the
// used experts (the exclusive scan of the "is used" flag).
KERNEL(moe_mask_gen)(
8+
OPTIONAL_SHAPE_INFO_ARG
9+
const __global INPUT0_TYPE* topk_idx,
10+
__global OUTPUT_TYPE* tokens_per_expert,
11+
__global OUTPUT1_TYPE* experts_info_start_idx,
12+
__global OUTPUT2_TYPE* experts_id,
13+
__global OUTPUT3_TYPE* tokens_lens_per_expert,
14+
__global OUTPUT4_TYPE* num_actual_used_experts
15+
)
16+
{
17+
const size_t expert_id = get_local_id(0);
18+
int num_tokens = INPUT0_BATCH_NUM;
19+
20+
// Pass 1: count how many topk_idx entries select this expert.
int num_tokens_per_curr_expert = 0;
21+
for (int i = 0; i < num_tokens * NUM_EXPERTS_PER_TOKEN; ++i) {
22+
if (topk_idx[i] == expert_id) {
23+
num_tokens_per_curr_expert += 1;
24+
}
25+
}
26+
int is_used = (num_tokens_per_curr_expert > 0) ? 1 : 0;
27+
28+
// Exclusive prefix sums across the work-group:
//  - over the token counts -> this expert's write offset in tokens_per_expert
//  - over the used flags   -> this expert's slot among the used experts
// Both scans are issued before the early return below, so every work-item
// takes part in them.
int tokens_per_expert_iter = work_group_scan_exclusive_add(num_tokens_per_curr_expert);
29+
int experts_id_iter = work_group_scan_exclusive_add(is_used);
30+
31+
// The last work-item turns its exclusive sum into the inclusive total,
// i.e. the number of experts that received at least one token.
if ((expert_id + 1) == get_local_size(0)) {
32+
num_actual_used_experts[0] = experts_id_iter + is_used;
33+
}
34+
35+
// Experts without tokens write nothing further.
if (num_tokens_per_curr_expert == 0) {
36+
return;
37+
}
38+
39+
experts_info_start_idx[experts_id_iter] = tokens_per_expert_iter;
40+
experts_id[experts_id_iter] = expert_id;
41+
tokens_lens_per_expert[experts_id_iter] = num_tokens_per_curr_expert;
42+
43+
// Pass 2: walk topk_idx again in the same flat order and store the token
// index t of every match, starting at this expert's offset.
int token_idx = 0;
44+
for (int t = 0; t < num_tokens; ++t) {
45+
for (int e = 0; e < NUM_EXPERTS_PER_TOKEN; ++e) {
46+
if (topk_idx[token_idx] == expert_id) {
47+
tokens_per_expert[tokens_per_expert_iter] = t;
48+
tokens_per_expert_iter += 1;
49+
}
50+
token_idx += 1;
51+
}
52+
}
53+
}

src/plugins/intel_gpu/src/graph/registry/moe_mask_gen_impls.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66
#include "intel_gpu/primitives/moe_mask_gen.hpp"
77
#include "primitive_inst.h"
88

9+
#if OV_GPU_WITH_OCL
10+
#include "impls/ocl_v2/moe/moe_mask_gen.hpp"
11+
#endif
12+
913
namespace ov::intel_gpu {
1014

1115
using namespace cldnn;
1216

1317
const std::vector<std::shared_ptr<cldnn::ImplementationManager>>& Registry<moe_mask_gen>::get_implementations() {
1418
static const std::vector<std::shared_ptr<ImplementationManager>> impls = {
15-
OV_GPU_GET_INSTANCE_CPU(moe_mask_gen, shape_types::static_shape)
16-
OV_GPU_GET_INSTANCE_CPU(moe_mask_gen, shape_types::dynamic_shape)
19+
OV_GPU_CREATE_INSTANCE_OCL(ocl::MoeMaskGenRef, shape_types::static_shape)
20+
OV_GPU_CREATE_INSTANCE_OCL(ocl::MoeMaskGenRef, shape_types::dynamic_shape)
1721
};
1822

1923
return impls;

0 commit comments

Comments
 (0)