# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14+
15+ import random
116import unittest
17+
218import numpy as np
3- import random
419
520import paddle
621import paddle .distributed as dist
722from paddle import nn
8- from paddle .io import DataLoader , Dataset
9-
1023from paddle .distributed import fleet
1124from paddle .distributed .fleet .meta_parallel import (
1225 LayerDesc ,
13- SharedLayerDesc ,
1426 PipelineLayer ,
27+ SharedLayerDesc ,
1528)
29+ from paddle .io import DataLoader , Dataset
1630
1731batch_size = 5
1832micro_batch_size = 1
1933
34+
def set_random_seed(seed, dp_id, rank_id):
    """Seed the Python, NumPy and Paddle RNGs for a reproducible run.

    The NumPy and Paddle generators are offset by ``dp_id`` so that every
    data-parallel group draws from a distinct random stream, while the
    plain ``random`` module uses the bare seed.
    NOTE(review): ``rank_id`` is accepted but never used here — presumably
    kept for call-site symmetry; confirm with callers before removing.
    """
    random.seed(seed)
    per_group_seed = seed + dp_id
    np.random.seed(per_group_seed)
    paddle.seed(per_group_seed)
2540
41+
2642class RandomDataset (Dataset ):
2743 def __init__ (self , num_samples ):
2844 self .num_samples = num_samples
@@ -39,22 +55,22 @@ def __len__(self):
3955vocab_size = 1024
4056hidden_size = 64
4157
58+
class EmbeddingPipe(nn.Layer):
    """Token-embedding stage for a pipeline-parallel model.

    ``forward`` returns a ``(hidden_states, input_ids)`` tuple so that
    downstream pipeline stages still have access to the original token
    ids (the MTP stage re-embeds them).
    """

    def __init__(self, **kwargs):
        """Build the embedding table.

        Expects ``num_embeddings`` and ``embedding_dim`` in ``kwargs``.
        """
        super().__init__()
        num_embeddings = kwargs["num_embeddings"]
        embedding_dim = kwargs["embedding_dim"]
        self.embed_tokens = nn.Embedding(num_embeddings, embedding_dim)

    def forward(self, input_ids):
        # Invokes .forward directly (as the original does), which skips
        # any layer __call__ hooks.
        hidden = self.embed_tokens.forward(input_ids)
        return (hidden, input_ids)

    @property
    def embedding_weight(self):
        # Stable attribute name so SharedLayerDesc can tie this weight
        # across pipeline stages.
        return self.embed_tokens.weight
73+
5874
5975def mtp_forward (layer , args ):
6076 hidden_states = args [0 ]
@@ -63,9 +79,9 @@ def mtp_forward(layer, args):
6379 output = embed [0 ] + hidden_states
6480 return (output , input_ids )
6581
82+
6683class MTPEmbeddingPipe (EmbeddingPipe ):
6784 def forward (self , args ):
68- #print(f"liyurui, input of MTPEmbedding is {args}")
6985 hidden_states = args [0 ]
7086 input_ids = args [1 ]
7187 embed = super ().forward (input_ids )
@@ -78,10 +94,10 @@ def __init__(
7894 self ,
7995 in_features ,
8096 out_features ,
81- weight_attr = None ,
82- bias_attr = None ,
83- name = None ,
84- layer_idx = 0
97+ weight_attr = None ,
98+ bias_attr = None ,
99+ name = None ,
100+ layer_idx = 0 ,
85101 ):
86102 self .layer_idx = layer_idx
87103 super ().__init__ (in_features , out_features , bias_attr = bias_attr )
@@ -101,27 +117,20 @@ def forward(self, logits, label):
101117
102118
103119class UnifiedPPModel (PipelineLayer ):
104- def __init__ (self , ** kwargs ):
120+ def __init__ (self , ** kwargs ):
105121 self ._sequential_layers = []
106122 self .num_layer = 4
107123
108124 self .add_sequential_layer (
109- SharedLayerDesc (
110- key = "embed_weight_share" ,
111- layer_func = EmbeddingPipe ,
112- shared_weight_attr = "embedding_weight" ,
113- num_embeddings = vocab_size ,
114- embedding_dim = hidden_size ,
115- ),
116- "embed" ,
125+ SharedLayerDesc (
126+ key = "embed_weight_share" ,
127+ layer_func = EmbeddingPipe ,
128+ shared_weight_attr = "embedding_weight" ,
129+ num_embeddings = vocab_size ,
130+ embedding_dim = hidden_size ,
131+ ),
132+ "embed" ,
117133 )
118- #self.add_sequential_layer(
119- # LayerDesc(
120- # EmbeddingPipe,
121- # num_embeddings=vocab_size,
122- # embedding_dim=hidden_size,
123- # ), "embed"
124- #)
125134
126135 for i in range (self .num_layer ):
127136 self .add_sequential_layer (
@@ -131,49 +140,48 @@ def __init__ (self, **kwargs):
131140 hidden_size ,
132141 bias_attr = False ,
133142 layer_idx = i ,
134- ), f"layer.{ i } "
143+ ),
144+ f"layer.{ i } " ,
135145 )
136146
137147 self .add_sequential_layer (
138- SharedLayerDesc (
139- key = "embed_weight_share" ,
140- #layer_func=MTPEmbeddingPipe,
141- layer_func = EmbeddingPipe ,
142- shared_weight_attr = "embedding_weight" ,
143- forward_func = mtp_forward ,
144- num_embeddings = vocab_size ,
145- embedding_dim = hidden_size ,
146- ),
147- "embed_shared" ,
148+ SharedLayerDesc (
149+ key = "embed_weight_share" ,
150+ layer_func = EmbeddingPipe ,
151+ shared_weight_attr = "embedding_weight" ,
152+ forward_func = mtp_forward ,
153+ num_embeddings = vocab_size ,
154+ embedding_dim = hidden_size ,
155+ ),
156+ "embed_shared" ,
148157 )
149- #self.add_sequential_layer(
150- # LayerDesc(
151- # MTPEmbeddingPipe,
152- # num_embeddings=vocab_size,
153- # embedding_dim=hidden_size,
154- # ), "embed"
155- #)
156158
157159 self .add_sequential_layer (
158- LayerDesc (
159- LinearPipe ,
160- hidden_size ,
161- hidden_size ,
162- bias_attr = False ,
163- layer_idx = self .num_layer
164- ), "last_layer"
160+ LayerDesc (
161+ LinearPipe ,
162+ hidden_size ,
163+ hidden_size ,
164+ bias_attr = False ,
165+ layer_idx = self .num_layer ,
166+ ),
167+ "last_layer" ,
165168 )
166169
167- super ().__init__ (layers = self .get_sequential_layer (), loss_fn = CrossEntropyLossPipe (), ** kwargs )
170+ super ().__init__ (
171+ layers = self .get_sequential_layer (),
172+ loss_fn = CrossEntropyLossPipe (),
173+ ** kwargs ,
174+ )
168175
169176 def add_sequential_layer (self , layer_desc , name_prefix = "" ):
170- self ._sequential_layers .append ({"layer" : layer_desc , "name_prefix" : name_prefix })
177+ self ._sequential_layers .append (
178+ {"layer" : layer_desc , "name_prefix" : name_prefix }
179+ )
171180
172181 def get_sequential_layer (self ):
173182 return [x ["layer" ] for x in self ._sequential_layers ]
174183
175184
176-
177185class TestDistPPTraining (unittest .TestCase ):
178186 def setUp (self ):
179187 strategy = fleet .DistributedStrategy ()
@@ -210,40 +218,44 @@ def test_unified_pp_model(self):
210218 rank_id = dist .get_rank ()
211219 set_random_seed (1024 , dp_id , rank_id )
212220
213- unified_model_pp = UnifiedPPModel (num_stages = self .pipeline_parallel_size )
214- unified_scheduler_pp , unified_optimizer_pp = self .build_optimizer (unified_model_pp )
215- unified_model_pp , unified_optimizer_pp = self .wrapper_mix_precision (unified_model_pp , unified_optimizer_pp )
221+ unified_model_pp = UnifiedPPModel (
222+ num_stages = self .pipeline_parallel_size
223+ )
224+ unified_scheduler_pp , unified_optimizer_pp = self .build_optimizer (
225+ unified_model_pp
226+ )
227+ unified_model_pp , unified_optimizer_pp = self .wrapper_mix_precision (
228+ unified_model_pp , unified_optimizer_pp
229+ )
216230 unified_model_pp = fleet .distributed_model (unified_model_pp )
217231 unified_optimizer_pp = fleet .distributed_optimizer (unified_optimizer_pp )
218232
219233 unified_model_nonpp = UnifiedPPModel (num_stages = 1 )
220- unified_scheduler_nonpp , unified_optimizer_nonpp = self .build_optimizer (unified_model_nonpp )
221-
222- pp_id_sname = {}
223- for n , p in unified_model_pp .named_parameters ():
224- pp_id_sname [id (p )] = n
225-
226- #for p in unified_model_pp.parameters():
227- # print(f"liyurui, pp parameter is {pp_id_sname[id(p)]}, {p.name}, {p.shape}")
228-
229- nonpp_id_sname = {}
230- for n , p in unified_model_nonpp .named_parameters ():
231- nonpp_id_sname [id (p )] = n
232-
233- #for p in unified_model_nonpp.parameters():
234- # print(f"liyurui, nonpp parameter is {nonpp_id_sname[id(p)]}, {p.name}, {p.shape}")
234+ unified_scheduler_nonpp , unified_optimizer_nonpp = self .build_optimizer (
235+ unified_model_nonpp
236+ )
235237
236238 # reset to make pp and nonpp model have same parameters value
237239 if pp_id == 0 :
238- unified_model_pp .parameters ()[0 ].set_value (unified_model_nonpp .parameters ()[0 ])
239- unified_model_pp .parameters ()[1 ].set_value (unified_model_nonpp .parameters ()[1 ])
240- unified_model_pp .parameters ()[2 ].set_value (unified_model_nonpp .parameters ()[2 ])
240+ unified_model_pp .parameters ()[0 ].set_value (
241+ unified_model_nonpp .parameters ()[0 ]
242+ )
243+ unified_model_pp .parameters ()[1 ].set_value (
244+ unified_model_nonpp .parameters ()[1 ]
245+ )
246+ unified_model_pp .parameters ()[2 ].set_value (
247+ unified_model_nonpp .parameters ()[2 ]
248+ )
241249 else :
242- #unified_model_pp.parameters()[0].set_value(unified_model_nonpp.parameters()[0])
243- unified_model_pp .parameters ()[1 ].set_value (unified_model_nonpp .parameters ()[3 ])
244- unified_model_pp .parameters ()[2 ].set_value (unified_model_nonpp .parameters ()[4 ])
245- unified_model_pp .parameters ()[3 ].set_value (unified_model_nonpp .parameters ()[5 ])
246- #unified_model_pp.parameters()[3].set_value(unified_model_nonpp.parameters()[6])
250+ unified_model_pp .parameters ()[1 ].set_value (
251+ unified_model_nonpp .parameters ()[3 ]
252+ )
253+ unified_model_pp .parameters ()[2 ].set_value (
254+ unified_model_nonpp .parameters ()[4 ]
255+ )
256+ unified_model_pp .parameters ()[3 ].set_value (
257+ unified_model_nonpp .parameters ()[5 ]
258+ )
247259
248260 dataset = RandomDataset (5 * batch_size )
249261
@@ -255,16 +267,10 @@ def test_unified_pp_model(self):
255267 num_workers = 2 ,
256268 )
257269
258- for p in unified_model_pp .parameters ():
259- print (f"liyurui, pp parameter is { pp_id_sname [id (p )]} , { p .name } , { p ._md5sum ()} " )
260-
261- for p in unified_model_nonpp .parameters ():
262- print (f"liyurui, nonpp parameter is { nonpp_id_sname [id (p )]} , { p .name } , { p ._md5sum ()} " )
263-
264270 for _ , (input_ids , label ) in enumerate (train_reader ()):
265- #print(f"liyurui, input_ids is {input_ids.shape}, {input_ids.dtype}, label is {label.shape}, {label.dtype}")
266- pp_loss = unified_model_pp . train_batch ( [input_ids , label ], unified_optimizer_pp , unified_scheduler_pp )
267- print ( f"liyurui, pp_loss is { pp_loss } " )
271+ pp_loss = unified_model_pp . train_batch (
272+ [input_ids , label ], unified_optimizer_pp , unified_scheduler_pp
273+ )
268274
269275 num_acc = batch_size // micro_batch_size
270276 micro_input_ids = paddle .split (input_ids , num_acc )
@@ -277,22 +283,13 @@ def test_unified_pp_model(self):
277283 loss = loss_fn (nonpp_output [0 ], micro_label ) / num_acc
278284 loss .backward ()
279285 nonpp_loss += loss .detach ()
280- print (f"liyurui, nonpp_loss is { nonpp_loss } " )
281-
282- #for p in unified_model_nonpp.parameters():
283- # print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.grad._md5sum()}")
284- # #if hasattr(p, "main_grad") and p.main_grad is not None:
285- # # print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.main_grad._md5sum()}")
286286
287- #for p in unified_model_pp.parameters():
288- # print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.grad._md5sum()}")
289- # #if hasattr(p, "main_grad") and p.main_grad is not None:
290- # # print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.main_grad._md5sum()}")
287+ np .testing .assert_equal (nonpp_loss .numpy (), pp_loss .numpy ())
291288
292289 unified_optimizer_nonpp .step ()
293290 unified_optimizer_nonpp .clear_grad ()
294291 unified_scheduler_nonpp .step ()
295292
296293
if __name__ == "__main__":
    # Run all TestCase classes defined in this module.
    unittest.main()
0 commit comments