@@ -1,5 +1,6 @@
 import unittest
 import numpy as np
+import random
 
 import paddle
 import paddle.distributed as dist
@@ -16,13 +17,19 @@
 batch_size = 5
 micro_batch_size = 1
 
+def set_random_seed(seed, dp_id, rank_id):
+    """Set random seed for reproducibility."""
+    random.seed(seed)
+    np.random.seed(seed + dp_id)
+    paddle.seed(seed + dp_id)
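+    # note: rank_id is currently unused; seeds vary only with dp_id, so all pp stages within a dp group draw identical data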
+
 class RandomDataset(Dataset):
     def __init__(self, num_samples):
         self.num_samples = num_samples
 
     def __getitem__(self, idx):
-        input_ids = np.random.random([5]).astype('int64')
-        label = np.random.randint(0, 5, (5)).astype('int64')
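+        # draw integer token ids instead: the old np.random.random floats in [0, 1) all cast to 0 under int64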
+        input_ids = np.random.randint(0, 20, [10]).astype('int64')
+        label = np.random.randint(0, 20, (10)).astype('int64')
         return input_ids, label
 
     def __len__(self):
@@ -39,6 +46,7 @@ def __init__(self, **kwargs):
 
     def forward(self, input_ids):
         hidden_states = self.embed_tokens.forward(input_ids)
+        # print(f"liyurui, hidden_states of embedding pipe {hidden_states._md5sum()}")
         return (hidden_states, input_ids)
 
     @property
@@ -116,16 +124,23 @@ def __init__(self, **kwargs):
                ), f"layer.{i}"
            )
 
-        self.add_sequential_layer(
-            SharedLayerDesc(
-                key="embed_weight_share",
-                layer_func=MTPEmbeddingPipe,
-                shared_weight_attr="embedding_weight",
-                num_embeddings=vocab_size,
-                embedding_dim=hidden_size,
-            ),
-            "embed_shared",
-        )
+        # self.add_sequential_layer(
+        #     SharedLayerDesc(
+        #         key="embed_weight_share",
+        #         layer_func=MTPEmbeddingPipe,
+        #         shared_weight_attr="embedding_weight",
+        #         num_embeddings=vocab_size,
+        #         embedding_dim=hidden_size,
+        #     ),
+        #     "embed_shared",
+        # )
+        # self.add_sequential_layer(
+        #     LayerDesc(
+        #         MTPEmbeddingPipe,
+        #         num_embeddings=vocab_size,
+        #         embedding_dim=hidden_size,
+        #     ), "embed"
+        # )
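+        # both embedding variants (weight-shared and standalone) are left disabled above, presumably to isolate the shared-weight path while debugging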
 
         self.add_sequential_layer(
             LayerDesc(
@@ -177,6 +192,12 @@ def wrapper_mix_precision(self, model, optimizer):
         return model, optimizer
 
     def test_unified_pp_model(self):
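+        # seed every rank in a dp group identically so the pp and single-stage models consume the same data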
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_id = hcg.get_data_parallel_rank()
+        pp_id = hcg.get_stage_id()
+        rank_id = dist.get_rank()
+        set_random_seed(1024, dp_id, rank_id)
+
         unified_model_pp = UnifiedPPModel(num_stages=self.pipeline_parallel_size)
         unified_scheduler_pp, unified_optimizer_pp = self.build_optimizer(unified_model_pp)
         unified_model_pp, unified_optimizer_pp = self.wrapper_mix_precision(unified_model_pp, unified_optimizer_pp)
@@ -186,6 +207,31 @@ def test_unified_pp_model(self):
         unified_model_nonpp = UnifiedPPModel(num_stages=1)
         unified_scheduler_nonpp, unified_optimizer_nonpp = self.build_optimizer(unified_model_nonpp)
 
+        pp_id_sname = {}
+        for n, p in unified_model_pp.named_parameters():
+            pp_id_sname[id(p)] = n
+
+        # for p in unified_model_pp.parameters():
+        #     print(f"liyurui, pp parameter is {pp_id_sname[id(p)]}, {p.name}, {p.shape}")
+
+        nonpp_id_sname = {}
+        for n, p in unified_model_nonpp.named_parameters():
+            nonpp_id_sname[id(p)] = n
+
+        # for p in unified_model_nonpp.parameters():
+        #     print(f"liyurui, nonpp parameter is {nonpp_id_sname[id(p)]}, {p.name}, {p.shape}")
+
+        # copy weights so the pp and non-pp models start from identical parameter values
+        if pp_id == 0:
+            unified_model_pp.parameters()[0].set_value(unified_model_nonpp.parameters()[0])
+            unified_model_pp.parameters()[1].set_value(unified_model_nonpp.parameters()[1])
+            unified_model_pp.parameters()[2].set_value(unified_model_nonpp.parameters()[2])
+        else:
+            unified_model_pp.parameters()[0].set_value(unified_model_nonpp.parameters()[3])
+            unified_model_pp.parameters()[1].set_value(unified_model_nonpp.parameters()[4])
+            unified_model_pp.parameters()[2].set_value(unified_model_nonpp.parameters()[5])
+        # unified_model_pp.parameters()[3].set_value(unified_model_nonpp.parameters()[6])
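+        # note: this mapping hardcodes 2 pp stages with 3 parameters each; stage 0 mirrors nonpp params 0-2, stage 1 mirrors 3-5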
+
         dataset = RandomDataset(5 * batch_size)
 
         train_reader = DataLoader(
@@ -201,12 +247,33 @@ def test_unified_pp_model(self):
             pp_loss = unified_model_pp.train_batch([input_ids, label], unified_optimizer_pp, unified_scheduler_pp)
             print(f"liyurui, pp_loss is {pp_loss}")
 
-            nonpp_output = unified_model_nonpp(input_ids)
-            loss_fn = nn.loss.CrossEntropyLoss()
-            nonpp_loss = loss_fn(nonpp_output[0], label)
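+            # mirror the pipeline schedule on the single-stage model: run micro batches and scale each loss by 1/num_acc so the accumulated sum matches train_batch's averaged loss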
+            num_acc = batch_size // micro_batch_size
+            micro_input_ids = paddle.split(input_ids, num_acc)
+            micro_labels = paddle.split(label, num_acc)
+
+            nonpp_loss = 0
+            for micro_input, micro_label in zip(micro_input_ids, micro_labels):
+                nonpp_output = unified_model_nonpp(micro_input)
+                loss_fn = nn.loss.CrossEntropyLoss()
+                loss = loss_fn(nonpp_output[0], micro_label) / num_acc
+                loss.backward()
+                nonpp_loss += loss.detach()
+            # nonpp_loss /= num_acc
             print(f"liyurui, nonpp_loss is {nonpp_loss}")
 
-            return
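+            # dump per-parameter gradient checksums so pp and non-pp grads can be compared side by side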
+            for p in unified_model_nonpp.parameters():
+                print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.grad._md5sum()}")
+                # if hasattr(p, "main_grad") and p.main_grad is not None:
+                #     print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.main_grad._md5sum()}")
+
+            for p in unified_model_pp.parameters():
+                print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.grad._md5sum()}")
+                # if hasattr(p, "main_grad") and p.main_grad is not None:
+                #     print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.main_grad._md5sum()}")
+
+            unified_optimizer_nonpp.step()
+            unified_optimizer_nonpp.clear_grad()
+            unified_scheduler_nonpp.step()
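+            # only the non-pp side needs explicit step/clear_grad here; train_batch already applies the optimizer for the pp model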
 
 
 if __name__ == "__main__":