@@ -26,12 +26,13 @@ void test_moe_gate_dispatch_spmd(
2626 int64_t k,
2727 int64_t capacity,
2828 bool use_pad,
29- bool test_bwd_spmd = false ) {
29+ bool test_bwd_spmd = false ,
30+ bool optional = true ) {
3031 size_t num_inputs = 0 ;
3132 if (test_bwd_spmd) {
3233 num_inputs = 5 ;
3334 } else {
34- num_inputs = 2 ;
35+ num_inputs = 3 ;
3536 }
3637
3738 EXPECT_EQ (input_shapes.size (), num_inputs)
@@ -68,17 +69,23 @@ void test_moe_gate_dispatch_spmd(
6869 phi::distributed::SpmdInfo spmd_info;
6970 if (test_bwd_spmd) {
7071 spmd_info =
71- phi::distributed::MoEGateDispatchBwdInferSpmd (dist_meta_tensors[0 ],
72- dist_meta_tensors[1 ],
73- dist_meta_tensors[2 ],
74- dist_meta_tensors[3 ],
75- dist_meta_tensors[4 ],
76- k,
77- capacity,
78- use_pad);
72+ phi::distributed::MoEGateDispatchGradInferSpmd (dist_meta_tensors[0 ],
73+ dist_meta_tensors[1 ],
74+ dist_meta_tensors[2 ],
75+ dist_meta_tensors[3 ],
76+ dist_meta_tensors[4 ],
77+ k,
78+ capacity,
79+ use_pad);
7980 } else {
80- spmd_info = phi::distributed::MoEGateDispatchFwdInferSpmd (
81- dist_meta_tensors[0 ], dist_meta_tensors[1 ], k, capacity, use_pad);
81+ phi::distributed::DistMetaTensor uninitialized_tensor;
82+ spmd_info = phi::distributed::MoEGateDispatchInferSpmd (
83+ dist_meta_tensors[0 ],
84+ dist_meta_tensors[1 ],
85+ optional ? dist_meta_tensors[2 ] : uninitialized_tensor,
86+ k,
87+ capacity,
88+ use_pad);
8289 }
8390
8491 for (size_t i = 0 ; i < 2 ; ++i) {
@@ -106,17 +113,18 @@ void test_moe_gate_dispatch_spmd(
106113TEST (MoECombineSPMDRule, test_moe_gate_dispatch_spmd) {
107114 int64_t s = 1024 , h = 512 , k = 2 , e = 8 , capacity = 1024 ;
108115 bool use_pad = true ;
109- const std::vector<std::vector<int64_t >>& forward_input_shapes = {{s, h},
110- {s, e}};
116+ const std::vector<std::vector<int64_t >>& forward_input_shapes = {
117+ {s, h}, {s, e}, { e}};
111118 const std::vector<std::vector<int64_t >>& backward_input_shapes = {
112119 {s, k}, {k, s}, {s, k}, {e, capacity, h}, {s, k}};
113120
114121 // replicated case, forward
115- std::vector<std::vector<int64_t >> input_dims_mappings = {{-1 , -1 }, {-1 , -1 }};
122+ std::vector<std::vector<int64_t >> input_dims_mappings = {
123+ {-1 , -1 }, {-1 , -1 }, {-1 }};
116124 std::pair<std::vector<std::vector<int64_t >>,
117125 std::vector<std::vector<int64_t >>>
118126 expected_dims_mappings = {
119- {{-1 , -1 }, {-1 , -1 }},
127+ {{-1 , -1 }, {-1 , -1 }, {- 1 } },
120128 {{-1 , -1 , -1 }, {-1 , -1 }, {-1 , -1 }, {-1 }, {-1 , -1 }}};
121129 test_moe_gate_dispatch_spmd (forward_input_shapes,
122130 input_dims_mappings,
@@ -139,8 +147,8 @@ TEST(MoECombineSPMDRule, test_moe_gate_dispatch_spmd) {
139147 true );
140148
141149 // ep case, forward
142- input_dims_mappings = {{0 , -1 }, {-1 , -1 }};
143- expected_dims_mappings = {{{0 , -1 }, {0 , -1 }},
150+ input_dims_mappings = {{0 , -1 }, {-1 , -1 }, {- 1 } };
151+ expected_dims_mappings = {{{0 , -1 }, {0 , -1 }, {- 1 } },
144152 {{-1 , 0 , -1 }, {0 , -1 }, {-1 , 0 }, {-1 }, {0 , -1 }}};
145153 test_moe_gate_dispatch_spmd (forward_input_shapes,
146154 input_dims_mappings,
@@ -160,6 +168,19 @@ TEST(MoECombineSPMDRule, test_moe_gate_dispatch_spmd) {
160168 capacity,
161169 use_pad,
162170 true );
171+
172+ // ep, corr_bias is none case, forward
173+ input_dims_mappings = {{0 , -1 }, {-1 , -1 }, {-1 }};
174+ expected_dims_mappings = {{{0 , -1 }, {0 , -1 }, {}},
175+ {{-1 , 0 , -1 }, {0 , -1 }, {-1 , 0 }, {-1 }, {0 , -1 }}};
176+ test_moe_gate_dispatch_spmd (forward_input_shapes,
177+ input_dims_mappings,
178+ expected_dims_mappings,
179+ k,
180+ capacity,
181+ use_pad,
182+ false ,
183+ false );
163184}
164185
165186} // namespace auto_parallel
0 commit comments