Skip to content

Commit 25eafc7

Browse files
authored
[Auto Parallel] add spmd_rule for moe_combine moe_combine_grad moe_gate_dispatch moe_gate_dispatch_grad (#74215)
* add moe spmd_rule to .yaml
* format spmd name
* update spmd_rule for MoEGateDispatchInferSpmd
* update test for moe_gate_dispatch_spmd
1 parent 583b68d commit 25eafc7

File tree

9 files changed

+124
-65
lines changed

9 files changed

+124
-65
lines changed

paddle/phi/infermeta/spmd_rules/moe_combine.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ limitations under the License. */
2525
namespace phi {
2626
namespace distributed {
2727

28-
SpmdInfo MoECombineFwdInferSpmd(const DistMetaTensor& x,
29-
const DistMetaTensor& combine_weights,
30-
const DistMetaTensor& scatter_index) {
28+
SpmdInfo MoECombineInferSpmd(const DistMetaTensor& x,
29+
const DistMetaTensor& combine_weights,
30+
const DistMetaTensor& scatter_index) {
3131
/* kernel logic:
3232
y is [seqlen, hidden_size]
3333
for kk in k:
@@ -107,10 +107,10 @@ SpmdInfo MoECombineFwdInferSpmd(const DistMetaTensor& x,
107107
{y_dist_attr_dst}};
108108
}
109109

110-
SpmdInfo MoECombineBwdInferSpmd(const DistMetaTensor& x,
111-
const DistMetaTensor& combine_weights,
112-
const DistMetaTensor& scatter_index,
113-
const DistMetaTensor& grad_y) {
110+
SpmdInfo MoECombineGradInferSpmd(const DistMetaTensor& x,
111+
const DistMetaTensor& combine_weights,
112+
const DistMetaTensor& scatter_index,
113+
const DistMetaTensor& grad_y) {
114114
/* kernel logic:
115115
for(int i = 0; i < s; ++i) {
116116
for(int j = 0; j < h; ++j) {

paddle/phi/infermeta/spmd_rules/moe_combine.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ limitations under the License. */
2222
namespace phi {
2323
namespace distributed {
2424

25-
SpmdInfo MoECombineFwdInferSpmd(const DistMetaTensor& x,
26-
const DistMetaTensor& combine_weights,
27-
const DistMetaTensor& scatter_index);
28-
29-
SpmdInfo MoECombineBwdInferSpmd(const DistMetaTensor& x,
30-
const DistMetaTensor& combine_weights,
31-
const DistMetaTensor& scatter_index,
32-
const DistMetaTensor& grad_y);
25+
SpmdInfo MoECombineInferSpmd(const DistMetaTensor& x,
26+
const DistMetaTensor& combine_weights,
27+
const DistMetaTensor& scatter_index);
28+
29+
SpmdInfo MoECombineGradInferSpmd(const DistMetaTensor& x,
30+
const DistMetaTensor& combine_weights,
31+
const DistMetaTensor& scatter_index,
32+
const DistMetaTensor& grad_y);
3333

3434
} // namespace distributed
3535
} // namespace phi

paddle/phi/infermeta/spmd_rules/moe_gate_dispatch.cc

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,17 @@ limitations under the License. */
2222
namespace phi {
2323
namespace distributed {
2424

25-
SpmdInfo MoEGateDispatchFwdInferSpmd(const DistMetaTensor& x,
26-
const DistMetaTensor& gate_logits,
27-
int64_t k,
28-
int64_t capacity,
29-
bool use_pad) {
25+
SpmdInfo MoEGateDispatchInferSpmd(const DistMetaTensor& x,
26+
const DistMetaTensor& gate_logits,
27+
const DistMetaTensor& corr_bias,
28+
int64_t k,
29+
int64_t capacity,
30+
bool use_pad) {
3031
/*
3132
inputs:
3233
x: [S, H], S = b*s
3334
gate_logits: [S, E]
35+
corr_bias: [E] (optional)
3436
outputs:
3537
y: [E, C, H] if use_pad is true, else [S, K, H]; currently only supports
3638
use_pad=true combine_weights: [S, K] scatter_index: [K, S] expert_offset: [E]
@@ -52,6 +54,15 @@ SpmdInfo MoEGateDispatchFwdInferSpmd(const DistMetaTensor& x,
5254
errors::InvalidArgument("gate_logits should be a 2-D tensor, but "
5355
"got gate_logits_shape.size() == %d",
5456
gate_logits_shape.size()));
57+
if (corr_bias.initialized()) {
58+
EXTRACT_SHAPE_AND_DIST_ATTR_WITH_DIM_CK(corr_bias);
59+
PADDLE_ENFORCE_EQ(
60+
corr_bias_shape.size(),
61+
1,
62+
errors::InvalidArgument("corr_bias should be a 1-D tensor, but "
63+
"got corr_bias_shape.size() == %d",
64+
corr_bias_shape.size()));
65+
}
5566
// infer axes dims_mapping
5667
std::string x_axes = "sh";
5768
std::string gate_logits_axes = "se";
@@ -73,6 +84,7 @@ SpmdInfo MoEGateDispatchFwdInferSpmd(const DistMetaTensor& x,
7384
TensorDistAttr gate_logits_dist_attr_dst =
7485
CopyTensorDistAttrForOutput(gate_logits_dist_attr_src);
7586
gate_logits_dist_attr_dst.set_dims_mapping(gate_logits_dims_mapping_dst);
87+
TensorDistAttr corr_bias_dist_attr_dst;
7688

7789
// output axes
7890
std::string y_axes = "esh";
@@ -107,22 +119,32 @@ SpmdInfo MoEGateDispatchFwdInferSpmd(const DistMetaTensor& x,
107119
TensorDistAttr expert_id_dist_attr =
108120
CopyTensorDistAttrForOutput(x_dist_attr_src);
109121
expert_id_dist_attr.set_dims_mapping(expert_id_dims_mapping);
110-
return {{x_dist_attr_dst, gate_logits_dist_attr_dst},
122+
if (corr_bias.initialized()) {
123+
EXTRACT_SHAPE_AND_DIST_ATTR(corr_bias);
124+
corr_bias_dist_attr_dst =
125+
CopyTensorDistAttrForOutput(corr_bias_dist_attr_src);
126+
corr_bias_dist_attr_dst.set_dims_mapping(
127+
std::vector<int64_t>{gate_logits_dist_attr_dst.dims_mapping().back()});
128+
} else {
129+
corr_bias_dist_attr_dst = TensorDistAttr();
130+
}
131+
return {{x_dist_attr_dst, gate_logits_dist_attr_dst, corr_bias_dist_attr_dst},
111132
{y_dist_attr_dst,
112133
combine_weights_dist_attr,
113134
scatter_index_dist_attr,
114135
expert_offset_dist_attr,
115136
expert_id_dist_attr}};
116137
}
117138

118-
SpmdInfo MoEGateDispatchBwdInferSpmd(const DistMetaTensor& combine_weights,
119-
const DistMetaTensor& scatter_index,
120-
const DistMetaTensor& expert_id,
121-
const DistMetaTensor& grad_y,
122-
const DistMetaTensor& grad_combine_weights,
123-
int64_t k,
124-
int64_t capacity,
125-
bool use_pad) {
139+
SpmdInfo MoEGateDispatchGradInferSpmd(
140+
const DistMetaTensor& combine_weights,
141+
const DistMetaTensor& scatter_index,
142+
const DistMetaTensor& expert_id,
143+
const DistMetaTensor& grad_y,
144+
const DistMetaTensor& grad_combine_weights,
145+
int64_t k,
146+
int64_t capacity,
147+
bool use_pad) {
126148
/*
127149
inputs:
128150
combine_weights: [S, K]

paddle/phi/infermeta/spmd_rules/moe_gate_dispatch.h

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,23 @@ limitations under the License. */
1919
namespace phi {
2020
namespace distributed {
2121

22-
SpmdInfo MoEGateDispatchFwdInferSpmd(const DistMetaTensor& x,
23-
const DistMetaTensor& gate_logits,
24-
int64_t k,
25-
int64_t capacity,
26-
bool use_pad);
22+
SpmdInfo MoEGateDispatchInferSpmd(const DistMetaTensor& x,
23+
const DistMetaTensor& gate_logits,
24+
const DistMetaTensor& corr_bias,
25+
int64_t k,
26+
int64_t capacity,
27+
bool use_pad);
2728
// out: "y", "combine_weights", "scatter_index", "expert_offset", "expert_id"
2829

29-
SpmdInfo MoEGateDispatchBwdInferSpmd(const DistMetaTensor& combine_weights,
30-
const DistMetaTensor& scatter_index,
31-
const DistMetaTensor& expert_id,
32-
const DistMetaTensor& grad_y,
33-
const DistMetaTensor& grad_combine_weights,
34-
int64_t k,
35-
int64_t capacity,
36-
bool use_pad);
30+
SpmdInfo MoEGateDispatchGradInferSpmd(
31+
const DistMetaTensor& combine_weights,
32+
const DistMetaTensor& scatter_index,
33+
const DistMetaTensor& expert_id,
34+
const DistMetaTensor& grad_y,
35+
const DistMetaTensor& grad_combine_weights,
36+
int64_t k,
37+
int64_t capacity,
38+
bool use_pad);
3739
// out: "x_grad", "gate_logits_grad"
3840

3941
} // namespace distributed

paddle/phi/infermeta/spmd_rules/rules.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,4 +835,14 @@ PD_REGISTER_SPMD_RULE(
835835
PD_REGISTER_SPMD_RULE(einsum,
836836
PD_INFER_SPMD(phi::distributed::EinsumInferSpmd),
837837
PD_INFER_SPMD(phi::distributed::EinsumGradInferSpmd));
838+
// moe_gate_dispatch
839+
PD_REGISTER_SPMD_RULE(
840+
moe_gate_dispatch,
841+
PD_INFER_SPMD(phi::distributed::MoEGateDispatchInferSpmd),
842+
PD_INFER_SPMD(phi::distributed::MoEGateDispatchGradInferSpmd));
843+
844+
// moe_combine
845+
PD_REGISTER_SPMD_RULE(moe_combine,
846+
PD_INFER_SPMD(phi::distributed::MoECombineInferSpmd),
847+
PD_INFER_SPMD(phi::distributed::MoECombineGradInferSpmd));
838848
} // namespace phi::distributed

paddle/phi/ops/yaml/backward.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2356,6 +2356,7 @@
23562356
output : Tensor(x_grad), Tensor(combine_weights_grad)
23572357
infer_meta :
23582358
func : MoeCombineGradInferMeta
2359+
spmd_rule : MoECombineGradInferSpmd
23592360
kernel :
23602361
func : moe_combine_grad
23612362

@@ -2376,6 +2377,7 @@
23762377
output : Tensor(x_grad), Tensor(gate_logits_grad)
23772378
infer_meta :
23782379
func : MoeGateDispatchGradInferMeta
2380+
spmd_rule : MoEGateDispatchGradInferSpmd
23792381
kernel :
23802382
func : moe_gate_dispatch_grad
23812383
data_type : y_grad

paddle/phi/ops/yaml/ops.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3689,6 +3689,7 @@
36893689
output : Tensor(y)
36903690
infer_meta :
36913691
func : MoeCombineInferMeta
3692+
spmd_rule : MoECombineInferSpmd
36923693
kernel :
36933694
func : moe_combine
36943695
data_type : x
@@ -3709,6 +3710,7 @@
37093710
output : Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id)
37103711
infer_meta :
37113712
func : MoeGateDispatchInferMeta
3713+
spmd_rule : MoEGateDispatchInferSpmd
37123714
kernel :
37133715
func : moe_gate_dispatch
37143716
data_type : x

test/cpp/auto_parallel/moe_combine_spmd_rule_test.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ void test_moe_combine_spmd(
6767

6868
phi::distributed::SpmdInfo spmd_info;
6969
if (test_bwd_spmd) {
70-
spmd_info = phi::distributed::MoECombineBwdInferSpmd(dist_meta_tensors[0],
71-
dist_meta_tensors[1],
72-
dist_meta_tensors[2],
73-
dist_meta_tensors[3]);
70+
spmd_info = phi::distributed::MoECombineGradInferSpmd(dist_meta_tensors[0],
71+
dist_meta_tensors[1],
72+
dist_meta_tensors[2],
73+
dist_meta_tensors[3]);
7474
} else {
75-
spmd_info = phi::distributed::MoECombineFwdInferSpmd(
75+
spmd_info = phi::distributed::MoECombineInferSpmd(
7676
dist_meta_tensors[0], dist_meta_tensors[1], dist_meta_tensors[2]);
7777
}
7878

test/cpp/auto_parallel/moe_gate_dispatch_spmd_rule_test.cc

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@ void test_moe_gate_dispatch_spmd(
2626
int64_t k,
2727
int64_t capacity,
2828
bool use_pad,
29-
bool test_bwd_spmd = false) {
29+
bool test_bwd_spmd = false,
30+
bool optional = true) {
3031
size_t num_inputs = 0;
3132
if (test_bwd_spmd) {
3233
num_inputs = 5;
3334
} else {
34-
num_inputs = 2;
35+
num_inputs = 3;
3536
}
3637

3738
EXPECT_EQ(input_shapes.size(), num_inputs)
@@ -68,17 +69,23 @@ void test_moe_gate_dispatch_spmd(
6869
phi::distributed::SpmdInfo spmd_info;
6970
if (test_bwd_spmd) {
7071
spmd_info =
71-
phi::distributed::MoEGateDispatchBwdInferSpmd(dist_meta_tensors[0],
72-
dist_meta_tensors[1],
73-
dist_meta_tensors[2],
74-
dist_meta_tensors[3],
75-
dist_meta_tensors[4],
76-
k,
77-
capacity,
78-
use_pad);
72+
phi::distributed::MoEGateDispatchGradInferSpmd(dist_meta_tensors[0],
73+
dist_meta_tensors[1],
74+
dist_meta_tensors[2],
75+
dist_meta_tensors[3],
76+
dist_meta_tensors[4],
77+
k,
78+
capacity,
79+
use_pad);
7980
} else {
80-
spmd_info = phi::distributed::MoEGateDispatchFwdInferSpmd(
81-
dist_meta_tensors[0], dist_meta_tensors[1], k, capacity, use_pad);
81+
phi::distributed::DistMetaTensor uninitialized_tensor;
82+
spmd_info = phi::distributed::MoEGateDispatchInferSpmd(
83+
dist_meta_tensors[0],
84+
dist_meta_tensors[1],
85+
optional ? dist_meta_tensors[2] : uninitialized_tensor,
86+
k,
87+
capacity,
88+
use_pad);
8289
}
8390

8491
for (size_t i = 0; i < 2; ++i) {
@@ -106,17 +113,18 @@ void test_moe_gate_dispatch_spmd(
106113
TEST(MoECombineSPMDRule, test_moe_gate_dispatch_spmd) {
107114
int64_t s = 1024, h = 512, k = 2, e = 8, capacity = 1024;
108115
bool use_pad = true;
109-
const std::vector<std::vector<int64_t>>& forward_input_shapes = {{s, h},
110-
{s, e}};
116+
const std::vector<std::vector<int64_t>>& forward_input_shapes = {
117+
{s, h}, {s, e}, {e}};
111118
const std::vector<std::vector<int64_t>>& backward_input_shapes = {
112119
{s, k}, {k, s}, {s, k}, {e, capacity, h}, {s, k}};
113120

114121
// replicated case, forward
115-
std::vector<std::vector<int64_t>> input_dims_mappings = {{-1, -1}, {-1, -1}};
122+
std::vector<std::vector<int64_t>> input_dims_mappings = {
123+
{-1, -1}, {-1, -1}, {-1}};
116124
std::pair<std::vector<std::vector<int64_t>>,
117125
std::vector<std::vector<int64_t>>>
118126
expected_dims_mappings = {
119-
{{-1, -1}, {-1, -1}},
127+
{{-1, -1}, {-1, -1}, {-1}},
120128
{{-1, -1, -1}, {-1, -1}, {-1, -1}, {-1}, {-1, -1}}};
121129
test_moe_gate_dispatch_spmd(forward_input_shapes,
122130
input_dims_mappings,
@@ -139,8 +147,8 @@ TEST(MoECombineSPMDRule, test_moe_gate_dispatch_spmd) {
139147
true);
140148

141149
// ep case, forward
142-
input_dims_mappings = {{0, -1}, {-1, -1}};
143-
expected_dims_mappings = {{{0, -1}, {0, -1}},
150+
input_dims_mappings = {{0, -1}, {-1, -1}, {-1}};
151+
expected_dims_mappings = {{{0, -1}, {0, -1}, {-1}},
144152
{{-1, 0, -1}, {0, -1}, {-1, 0}, {-1}, {0, -1}}};
145153
test_moe_gate_dispatch_spmd(forward_input_shapes,
146154
input_dims_mappings,
@@ -160,6 +168,19 @@ TEST(MoECombineSPMDRule, test_moe_gate_dispatch_spmd) {
160168
capacity,
161169
use_pad,
162170
true);
171+
172+
// ep, corr_bias is none case, forward
173+
input_dims_mappings = {{0, -1}, {-1, -1}, {-1}};
174+
expected_dims_mappings = {{{0, -1}, {0, -1}, {}},
175+
{{-1, 0, -1}, {0, -1}, {-1, 0}, {-1}, {0, -1}}};
176+
test_moe_gate_dispatch_spmd(forward_input_shapes,
177+
input_dims_mappings,
178+
expected_dims_mappings,
179+
k,
180+
capacity,
181+
use_pad,
182+
false,
183+
false);
163184
}
164185

165186
} // namespace auto_parallel

0 commit comments

Comments (0)