@@ -1,5 +1,6 @@
 import unittest
 import numpy as np
+import random
 
 import paddle
 import paddle.distributed as dist
@@ -16,13 +17,19 @@
 batch_size = 5
 micro_batch_size = 1
 
+def set_random_seed(seed, dp_id, rank_id):
+    """Set random seed for reproducibility."""
+    random.seed(seed)
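+    # Offset by dp_id so each data-parallel group draws distinct values
+    # while all pipeline stages within a group stay in sync (rank_id is
+    # currently unused).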
+    np.random.seed(seed + dp_id)
+    paddle.seed(seed + dp_id)
+
 class RandomDataset(Dataset):
     def __init__(self, num_samples):
         self.num_samples = num_samples
 
     def __getitem__(self, idx):
-        input_ids = np.random.random([5]).astype('int64')
-        label = np.random.randint(0, 5, (5)).astype('int64')
+        input_ids = np.random.randint(0, 20, [10]).astype('int64')
+        label = np.random.randint(0, 20, (10)).astype('int64')
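+        # randint draws genuine integer token ids; the removed
+        # np.random.random(...) call floored every value in [0, 1) to 0
+        # after the int64 cast.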
         return input_ids, label
 
     def __len__(self):
@@ -36,18 +43,29 @@ class EmbeddingPipe(nn.Layer):
     def __init__(self, **kwargs):
         super().__init__()
         self.embed_tokens = nn.Embedding(kwargs["num_embeddings"], kwargs["embedding_dim"])
+        #print(f"liyurui, embedding weight init = {self.embedding_weight._md5sum()}")
 
     def forward(self, input_ids):
+        #print(f"liyurui, input_ids is {input_ids}")
+        #print(f"liyurui, input_ids is {input_ids._md5sum()}, weight={self.embedding_weight._md5sum()}")
         hidden_states = self.embed_tokens.forward(input_ids)
+        #print(f"liyurui, hidden_states of embedding pipe {hidden_states._md5sum()}")
         return (hidden_states, input_ids)
 
     @property
     def embedding_weight(self):
         return getattr(self.embed_tokens, "weight")
 
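+# Plain-function forward for the shared embedding stage: it re-embeds
+# input_ids with the shared weights and adds the embedding to the incoming
+# hidden states, i.e. the same computation as MTPEmbeddingPipe.forward
+# below. It is handed to SharedLayerDesc via forward_func.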
+def mtp_forward(layer, args):
+    hidden_states = args[0]
+    input_ids = args[1]
+    embed = layer.forward(input_ids)
+    output = embed[0] + hidden_states
+    return (output, input_ids)
 
 class MTPEmbeddingPipe(EmbeddingPipe):
     def forward(self, args):
+        #print(f"liyurui, input of MTPEmbedding is {args}")
         hidden_states = args[0]
         input_ids = args[1]
         embed = super().forward(input_ids)
@@ -87,23 +105,23 @@ def __init__(self, **kwargs):
         self._sequential_layers = []
         self.num_layer = 4
 
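+        # Register the embedding via SharedLayerDesc: every stage that
+        # registers the same key ("embed_weight_share") binds to one weight
+        # tensor, which is what lets the "embed_shared" stage below tie its
+        # embedding weight to this one.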
-        #self.add_sequential_layer(
-        #    SharedLayerDesc(
-        #        key="embed_weight_share",
-        #        layer_func=EmbeddingPipe,
-        #        shared_weight_attr="embedding_weight",
-        #        num_embeddings=vocab_size,
-        #        embedding_dim=hidden_size,
-        #    ),
-        #    "embed",
-        #)
         self.add_sequential_layer(
-            LayerDesc(
-                EmbeddingPipe,
-                num_embeddings=vocab_size,
-                embedding_dim=hidden_size,
-            ), "embed"
+            SharedLayerDesc(
+                key="embed_weight_share",
+                layer_func=EmbeddingPipe,
+                shared_weight_attr="embedding_weight",
+                num_embeddings=vocab_size,
+                embedding_dim=hidden_size,
+            ),
+            "embed",
         )
+        #self.add_sequential_layer(
+        #    LayerDesc(
+        #        EmbeddingPipe,
+        #        num_embeddings=vocab_size,
+        #        embedding_dim=hidden_size,
+        #    ), "embed"
+        #)
 
         for i in range(self.num_layer):
             self.add_sequential_layer(
@@ -119,13 +137,22 @@ def __init__(self, **kwargs):
         self.add_sequential_layer(
             SharedLayerDesc(
                 key="embed_weight_share",
-                layer_func=MTPEmbeddingPipe,
+                #layer_func=MTPEmbeddingPipe,
+                layer_func=EmbeddingPipe,
                 shared_weight_attr="embedding_weight",
+                forward_func=mtp_forward,
                 num_embeddings=vocab_size,
                 embedding_dim=hidden_size,
             ),
             "embed_shared",
         )
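+        # layer_func stays EmbeddingPipe so both shared stages construct the
+        # same layer type under one key; the MTP behaviour now comes from
+        # forward_func=mtp_forward rather than the MTPEmbeddingPipe subclass.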
+        #self.add_sequential_layer(
+        #    LayerDesc(
+        #        MTPEmbeddingPipe,
+        #        num_embeddings=vocab_size,
+        #        embedding_dim=hidden_size,
+        #    ), "embed"
+        #)
 
         self.add_sequential_layer(
             LayerDesc(
@@ -177,6 +204,12 @@ def wrapper_mix_precision(self, model, optimizer):
         return model, optimizer
 
     def test_unified_pp_model(self):
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_id = hcg.get_data_parallel_rank()
+        pp_id = hcg.get_stage_id()
+        rank_id = dist.get_rank()
+        set_random_seed(1024, dp_id, rank_id)
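+        # Seed before constructing either model so parameter initialization
+        # is deterministic for the pp/non-pp comparison.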
+
         unified_model_pp = UnifiedPPModel(num_stages=self.pipeline_parallel_size)
         unified_scheduler_pp, unified_optimizer_pp = self.build_optimizer(unified_model_pp)
         unified_model_pp, unified_optimizer_pp = self.wrapper_mix_precision(unified_model_pp, unified_optimizer_pp)
@@ -186,6 +219,32 @@ def test_unified_pp_model(self):
         unified_model_nonpp = UnifiedPPModel(num_stages=1)
         unified_scheduler_nonpp, unified_optimizer_nonpp = self.build_optimizer(unified_model_nonpp)
 
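+        # Map parameter object ids to structured names so the debug prints
+        # below can label parameters consistently in both models.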
+        pp_id_sname = {}
+        for n, p in unified_model_pp.named_parameters():
+            pp_id_sname[id(p)] = n
+
+        #for p in unified_model_pp.parameters():
+        #    print(f"liyurui, pp parameter is {pp_id_sname[id(p)]}, {p.name}, {p.shape}")
+
+        nonpp_id_sname = {}
+        for n, p in unified_model_nonpp.named_parameters():
+            nonpp_id_sname[id(p)] = n
+
+        #for p in unified_model_nonpp.parameters():
+        #    print(f"liyurui, nonpp parameter is {nonpp_id_sname[id(p)]}, {p.name}, {p.shape}")
+
+        # Reset so the pp and non-pp models start from identical parameter values.
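+        # Index mapping (as read from the parameter lists): stage 0 holds the
+        # shared embedding plus its first layers, matching non-pp params 0-2;
+        # on the other stage parameters()[0] appears to be the tied embedding
+        # copy, so only params 1-3 are reset, from non-pp params 3-5.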
+        if pp_id == 0:
+            unified_model_pp.parameters()[0].set_value(unified_model_nonpp.parameters()[0])
+            unified_model_pp.parameters()[1].set_value(unified_model_nonpp.parameters()[1])
+            unified_model_pp.parameters()[2].set_value(unified_model_nonpp.parameters()[2])
+        else:
+            #unified_model_pp.parameters()[0].set_value(unified_model_nonpp.parameters()[0])
+            unified_model_pp.parameters()[1].set_value(unified_model_nonpp.parameters()[3])
+            unified_model_pp.parameters()[2].set_value(unified_model_nonpp.parameters()[4])
+            unified_model_pp.parameters()[3].set_value(unified_model_nonpp.parameters()[5])
+            #unified_model_pp.parameters()[3].set_value(unified_model_nonpp.parameters()[6])
+
         dataset = RandomDataset(5 * batch_size)
 
         train_reader = DataLoader(
@@ -196,17 +255,43 @@ def test_unified_pp_model(self):
             num_workers=2,
         )
 
+        for p in unified_model_pp.parameters():
+            print(f"liyurui, pp parameter is {pp_id_sname[id(p)]}, {p.name}, {p._md5sum()}")
+
+        for p in unified_model_nonpp.parameters():
+            print(f"liyurui, nonpp parameter is {nonpp_id_sname[id(p)]}, {p.name}, {p._md5sum()}")
+
         for _, (input_ids, label) in enumerate(train_reader()):
-            print(f"liyurui, input_ids is {input_ids.shape}, {input_ids.dtype}, label is {label.shape}, {label.dtype}")
+            # print(f"liyurui, input_ids is {input_ids.shape}, {input_ids.dtype}, label is {label.shape}, {label.dtype}")
             pp_loss = unified_model_pp.train_batch([input_ids, label], unified_optimizer_pp, unified_scheduler_pp)
             print(f"liyurui, pp_loss is {pp_loss}")
 
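+            # Replay the pipeline schedule on the non-pp model: split the
+            # global batch into micro-batches, scale each loss by 1/num_acc,
+            # and accumulate gradients so the total matches train_batch.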
-            nonpp_output = unified_model_nonpp(input_ids)
-            loss_fn = nn.loss.CrossEntropyLoss()
-            nonpp_loss = loss_fn(nonpp_output[0], label)
+            num_acc = batch_size // micro_batch_size
+            micro_input_ids = paddle.split(input_ids, num_acc)
+            micro_labels = paddle.split(label, num_acc)
+
+            nonpp_loss = 0
+            for micro_input, micro_label in zip(micro_input_ids, micro_labels):
+                nonpp_output = unified_model_nonpp(micro_input)
+                loss_fn = nn.loss.CrossEntropyLoss()
+                loss = loss_fn(nonpp_output[0], micro_label) / num_acc
+                loss.backward()
+                nonpp_loss += loss.detach()
             print(f"liyurui, nonpp_loss is {nonpp_loss}")
 
-            return
+            #for p in unified_model_nonpp.parameters():
+            #    print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.grad._md5sum()}")
+            #    #if hasattr(p, "main_grad") and p.main_grad is not None:
+            #    #    print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.main_grad._md5sum()}")
+
+            #for p in unified_model_pp.parameters():
+            #    print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.grad._md5sum()}")
+            #    #if hasattr(p, "main_grad") and p.main_grad is not None:
+            #    #    print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.main_grad._md5sum()}")
+
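+            # Apply the accumulated non-pp gradients and advance the LR
+            # schedule, mirroring what train_batch does for the pp model.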
+            unified_optimizer_nonpp.step()
+            unified_optimizer_nonpp.clear_grad()
+            unified_scheduler_nonpp.step()
 
 
 if __name__ == "__main__":