Commit f3ad116

LiYuRio authored and Hz188 committed
fix unit test case
1 parent 5c29809 commit f3ad116

test/collective/fleet/hybrid_pp_unified_dygraph_model.py

Lines changed: 83 additions & 16 deletions
@@ -1,5 +1,6 @@
 import unittest
 import numpy as np
+import random

 import paddle
 import paddle.distributed as dist
@@ -16,13 +17,19 @@
 batch_size = 5
 micro_batch_size = 1

+def set_random_seed(seed, dp_id, rank_id):
+    """Set random seed for reproducibility."""
+    random.seed(seed)
+    np.random.seed(seed + dp_id)
+    paddle.seed(seed + dp_id)
+
 class RandomDataset(Dataset):
     def __init__(self, num_samples):
         self.num_samples = num_samples

     def __getitem__(self, idx):
-        input_ids = np.random.random([5]).astype('int64')
-        label = np.random.randint(0, 5, (5)).astype('int64')
+        input_ids = np.random.randint(0, 20, [10]).astype('int64')
+        label = np.random.randint(0, 20, (10)).astype('int64')
         return input_ids, label

     def __len__(self):
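For context, the new helper seeds the Python, NumPy, and Paddle RNGs with an offset by data-parallel rank, so every process in the same data-parallel group draws identical synthetic batches from RandomDataset. A minimal standalone sketch of that property (the assert is illustrative and not part of the commit):

import random

import numpy as np
import paddle


def set_random_seed(seed, dp_id, rank_id):
    # NumPy/Paddle seeds vary only with dp_id, so pipeline stages in one
    # data-parallel group share randomness regardless of their global rank
    random.seed(seed)
    np.random.seed(seed + dp_id)
    paddle.seed(seed + dp_id)


set_random_seed(1024, dp_id=0, rank_id=0)
sample_rank0 = np.random.randint(0, 20, [10]).astype('int64')
set_random_seed(1024, dp_id=0, rank_id=1)  # different global rank, same dp group
sample_rank1 = np.random.randint(0, 20, [10]).astype('int64')
assert (sample_rank0 == sample_rank1).all()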
@@ -39,6 +46,7 @@ def __init__(self, **kwargs):

     def forward(self, input_ids):
         hidden_states = self.embed_tokens.forward(input_ids)
+        #print(f"liyurui, hidden_states of embedding pipe {hidden_states._md5sum()}")
         return (hidden_states, input_ids)

     @property
@@ -116,16 +124,23 @@ def __init__ (self, **kwargs):
                ), f"layer.{i}"
            )

-        self.add_sequential_layer(
-            SharedLayerDesc(
-                key="embed_weight_share",
-                layer_func=MTPEmbeddingPipe,
-                shared_weight_attr="embedding_weight",
-                num_embeddings=vocab_size,
-                embedding_dim=hidden_size,
-            ),
-            "embed_shared",
-        )
+        #self.add_sequential_layer(
+        #    SharedLayerDesc(
+        #        key="embed_weight_share",
+        #        layer_func=MTPEmbeddingPipe,
+        #        shared_weight_attr="embedding_weight",
+        #        num_embeddings=vocab_size,
+        #        embedding_dim=hidden_size,
+        #    ),
+        #    "embed_shared",
+        #)
+        #self.add_sequential_layer(
+        #    LayerDesc(
+        #        MTPEmbeddingPipe,
+        #        num_embeddings=vocab_size,
+        #        embedding_dim=hidden_size,
+        #    ), "embed"
+        #)

         self.add_sequential_layer(
             LayerDesc(
@@ -177,6 +192,12 @@ def wrapper_mix_precision(self, model, optimizer):
         return model, optimizer

     def test_unified_pp_model(self):
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_id = hcg.get_data_parallel_rank()
+        pp_id = hcg.get_stage_id()
+        rank_id = dist.get_rank()
+        set_random_seed(1024, dp_id, rank_id)
+
         unified_model_pp = UnifiedPPModel(num_stages=self.pipeline_parallel_size)
         unified_scheduler_pp, unified_optimizer_pp = self.build_optimizer(unified_model_pp)
         unified_model_pp, unified_optimizer_pp = self.wrapper_mix_precision(unified_model_pp, unified_optimizer_pp)
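The hybrid communicate group queried at the top of test_unified_pp_model is presumably created during test setup rather than in this diff; a sketch of the usual fleet initialization for a two-stage pipeline, with the degrees and accumulation settings assumed for illustration:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,   # assumed: no data parallelism in this test
    "mp_degree": 1,
    "pp_degree": 2,   # assumed to match self.pipeline_parallel_size
}
strategy.pipeline_configs = {
    "accumulate_steps": 5,   # batch_size // micro_batch_size
    "micro_batch_size": 1,
}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
print(hcg.get_data_parallel_rank(), hcg.get_stage_id())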
@@ -186,6 +207,31 @@ def test_unified_pp_model(self):
         unified_model_nonpp = UnifiedPPModel(num_stages=1)
         unified_scheduler_nonpp, unified_optimizer_nonpp = self.build_optimizer(unified_model_nonpp)

+        pp_id_sname = {}
+        for n, p in unified_model_pp.named_parameters():
+            pp_id_sname[id(p)] = n
+
+        #for p in unified_model_pp.parameters():
+        #    print(f"liyurui, pp parameter is {pp_id_sname[id(p)]}, {p.name}, {p.shape}")
+
+        nonpp_id_sname = {}
+        for n, p in unified_model_nonpp.named_parameters():
+            nonpp_id_sname[id(p)] = n
+
+        #for p in unified_model_nonpp.parameters():
+        #    print(f"liyurui, nonpp parameter is {nonpp_id_sname[id(p)]}, {p.name}, {p.shape}")
+
+        # reset to make pp and nonpp model have same parameters value
+        if pp_id == 0:
+            unified_model_pp.parameters()[0].set_value(unified_model_nonpp.parameters()[0])
+            unified_model_pp.parameters()[1].set_value(unified_model_nonpp.parameters()[1])
+            unified_model_pp.parameters()[2].set_value(unified_model_nonpp.parameters()[2])
+        else:
+            unified_model_pp.parameters()[0].set_value(unified_model_nonpp.parameters()[3])
+            unified_model_pp.parameters()[1].set_value(unified_model_nonpp.parameters()[4])
+            unified_model_pp.parameters()[2].set_value(unified_model_nonpp.parameters()[5])
+            #unified_model_pp.parameters()[3].set_value(unified_model_nonpp.parameters()[6])
+
         dataset = RandomDataset(5 * batch_size)

         train_reader = DataLoader(
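The index-by-index copies above align the pipeline model's per-stage parameters with the single-stage model before training, assuming stage 0 owns the first three parameters and stage 1 the next three, in matching order. Under that same assumption the copy could be written as a loop (a sketch reusing the test's variable names, not part of the commit):

# hypothetical generalization of the hard-coded set_value calls above
pp_params = unified_model_pp.parameters()
nonpp_params = unified_model_nonpp.parameters()
offset = pp_id * len(pp_params)  # each stage owns a contiguous slice (assumed)
for dst, src in zip(pp_params, nonpp_params[offset:offset + len(pp_params)]):
    dst.set_value(src)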
@@ -201,12 +247,33 @@ def test_unified_pp_model(self):
             pp_loss = unified_model_pp.train_batch([input_ids, label], unified_optimizer_pp, unified_scheduler_pp)
             print(f"liyurui, pp_loss is {pp_loss}")

-            nonpp_output = unified_model_nonpp(input_ids)
-            loss_fn = nn.loss.CrossEntropyLoss()
-            nonpp_loss = loss_fn(nonpp_output[0], label)
+            num_acc = batch_size // micro_batch_size
+            micro_input_ids = paddle.split(input_ids, num_acc)
+            micro_labels = paddle.split(label, num_acc)
+
+            nonpp_loss = 0
+            for micro_input, micro_label in zip(micro_input_ids, micro_labels):
+                nonpp_output = unified_model_nonpp(micro_input)
+                loss_fn = nn.loss.CrossEntropyLoss()
+                loss = loss_fn(nonpp_output[0], micro_label) / num_acc
+                loss.backward()
+                nonpp_loss += loss.detach()
+            #nonpp_loss /= num_acc
             print(f"liyurui, nonpp_loss is {nonpp_loss}")

-            return
+            for p in unified_model_nonpp.parameters():
+                print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.grad._md5sum()}")
+                #if hasattr(p, "main_grad") and p.main_grad is not None:
+                #    print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.main_grad._md5sum()}")
+
+            for p in unified_model_pp.parameters():
+                print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.grad._md5sum()}")
+                #if hasattr(p, "main_grad") and p.main_grad is not None:
+                #    print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.main_grad._md5sum()}")
+
+            unified_optimizer_nonpp.step()
+            unified_optimizer_nonpp.clear_grad()
+            unified_scheduler_nonpp.step()


 if __name__ == "__main__":
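The reworked non-pipeline reference mirrors the pipeline scheduler's gradient accumulation: the batch is split into num_acc equal micro-batches, and each contributes loss / num_acc to both the backward pass and the reported total, so the accumulated value equals the mean per-micro-batch loss. A quick numeric check of that identity (illustrative numbers only):

import numpy as np

micro_losses = np.array([0.9, 1.1, 1.0, 1.2, 0.8])  # one mean loss per micro-batch
num_acc = len(micro_losses)                          # batch_size // micro_batch_size
accumulated = sum(loss / num_acc for loss in micro_losses)  # what the test sums
assert np.isclose(accumulated, micro_losses.mean())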
