
Commit 5fabe32

fix unit test case
1 parent 5c29809 commit 5fabe32

2 files changed (+113 lines, -34 lines)


python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py

Lines changed: 5 additions & 11 deletions
@@ -510,10 +510,7 @@ def __init__(
         self._build_layer()

         self.comm_key_to_layer_name = {}
-        if self._num_stages > 1:
-            self.shared_comm = self._construct_shared_comm()
-        else:
-            self.shared_comm = {}
+        self.shared_comm = self._construct_shared_comm()
         self._synchronize_shared_weights()

     def get_stage_from_index(self, layer_idx):
@@ -544,7 +541,7 @@ def get_model_chunks(self):
     def _construct_shared_comm(self):
         shared_comm = {}
         if self._topo.get_dim("pipe") == 1:
-            return
+            return shared_comm

         # The first loop gets the pivot stage and all different shared_weight_attrs for one layer name.
         # Maps stage idx to all shared attrs of each different layer names on that stage.
@@ -1004,12 +1001,9 @@ def flush_into_run_function():
                        param.is_firstly_shared = True

                if layer.forward_func is None:
-                    if self._num_stages == 1:
-                        run_function.append(layer.build_layer())
-                    else:
-                        run_function.append(
-                            self.shared_layers[layer.layer_name]
-                        )
+                    run_function.append(
+                        self.shared_layers[layer.layer_name]
+                    )

                else:
                    run_function.append(
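
The net effect of this change: `_construct_shared_comm` now owns the single-stage case by returning the empty `shared_comm` dict (instead of `None`), so `__init__` and the shared-layer handling no longer branch on the number of stages. A minimal standalone sketch of that pattern, using hypothetical names (`construct_shared_comm`, `shared_specs`, `pipe_dim`) rather than Paddle's actual API:

```python
# Hypothetical, simplified sketch (not Paddle's implementation): the builder
# returns an empty dict when there is only one pipeline stage, so the caller
# can invoke it unconditionally and iterate the result safely.

def construct_shared_comm(pipe_dim, shared_specs):
    shared_comm = {}
    if pipe_dim == 1:
        # Nothing is shared across stages, but return a dict rather than None.
        return shared_comm
    for key, stage_ranks in shared_specs.items():
        shared_comm[key] = {"ranks": sorted(set(stage_ranks))}
    return shared_comm


# The caller no longer needs an `if num_stages > 1` branch:
print(construct_shared_comm(1, {"embed_weight_share": [0, 3]}))  # {}
print(construct_shared_comm(4, {"embed_weight_share": [0, 3]}))  # {'embed_weight_share': {'ranks': [0, 3]}}
```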

test/collective/fleet/hybrid_pp_unified_dygraph_model.py

Lines changed: 108 additions & 23 deletions
@@ -1,5 +1,6 @@
 import unittest
 import numpy as np
+import random

 import paddle
 import paddle.distributed as dist
@@ -16,13 +17,19 @@
 batch_size = 5
 micro_batch_size = 1

+def set_random_seed(seed, dp_id, rank_id):
+    """Set random seed for reproducibility."""
+    random.seed(seed)
+    np.random.seed(seed + dp_id)
+    paddle.seed(seed + dp_id)
+
 class RandomDataset(Dataset):
     def __init__(self, num_samples):
         self.num_samples = num_samples

     def __getitem__(self, idx):
-        input_ids = np.random.random([5]).astype('int64')
-        label = np.random.randint(0, 5, (5)).astype('int64')
+        input_ids = np.random.randint(0, 20, [10]).astype('int64')
+        label = np.random.randint(0, 20, (10)).astype('int64')
         return input_ids, label

     def __len__(self):
@@ -36,18 +43,29 @@ class EmbeddingPipe(nn.Layer):
     def __init__(self, **kwargs):
         super().__init__()
         self.embed_tokens = nn.Embedding(kwargs["num_embeddings"], kwargs["embedding_dim"])
+        #print(f"liyurui, embeding weight init = {self.embedding_weight._md5sum()}")

     def forward(self, input_ids):
+        #print(f"liyurui, input_ids is {input_ids}")
+        #print(f"liyurui, input_ids is {input_ids._md5sum()}, weight={self.embedding_weight._md5sum()}")
         hidden_states = self.embed_tokens.forward(input_ids)
+        #print(f"liyurui, hidden_states of embedding pipe {hidden_states._md5sum()}")
         return (hidden_states, input_ids)

     @property
     def embedding_weight(self):
         return getattr(self.embed_tokens, "weight")

+def mtp_forward(layer, args):
+    hidden_states = args[0]
+    input_ids = args[1]
+    embed = layer.forward(input_ids)
+    output = embed[0] + hidden_states
+    return (output, input_ids)

 class MTPEmbeddingPipe(EmbeddingPipe):
     def forward(self, args):
+        #print(f"liyurui, input of MTPEmbedding is {args}")
         hidden_states = args[0]
         input_ids = args[1]
         embed = super().forward(input_ids)
@@ -87,23 +105,23 @@ def __init__ (self, **kwargs):
         self._sequential_layers = []
         self.num_layer = 4

-        #self.add_sequential_layer(
-        #    SharedLayerDesc(
-        #        key="embed_weight_share",
-        #        layer_func=EmbeddingPipe,
-        #        shared_weight_attr="embedding_weight",
-        #        num_embeddings=vocab_size,
-        #        embedding_dim=hidden_size,
-        #    ),
-        #    "embed",
-        #)
         self.add_sequential_layer(
-            LayerDesc(
-                EmbeddingPipe,
-                num_embeddings=vocab_size,
-                embedding_dim=hidden_size,
-            ), "embed"
+            SharedLayerDesc(
+                key="embed_weight_share",
+                layer_func=EmbeddingPipe,
+                shared_weight_attr="embedding_weight",
+                num_embeddings=vocab_size,
+                embedding_dim=hidden_size,
+            ),
+            "embed",
         )
+        #self.add_sequential_layer(
+        #    LayerDesc(
+        #        EmbeddingPipe,
+        #        num_embeddings=vocab_size,
+        #        embedding_dim=hidden_size,
+        #    ), "embed"
+        #)

         for i in range(self.num_layer):
             self.add_sequential_layer(
@@ -119,13 +137,22 @@ def __init__ (self, **kwargs):
         self.add_sequential_layer(
             SharedLayerDesc(
                 key="embed_weight_share",
-                layer_func=MTPEmbeddingPipe,
+                #layer_func=MTPEmbeddingPipe,
+                layer_func=EmbeddingPipe,
                 shared_weight_attr="embedding_weight",
+                forward_func=mtp_forward,
                 num_embeddings=vocab_size,
                 embedding_dim=hidden_size,
             ),
             "embed_shared",
         )
+        #self.add_sequential_layer(
+        #    LayerDesc(
+        #        MTPEmbeddingPipe,
+        #        num_embeddings=vocab_size,
+        #        embedding_dim=hidden_size,
+        #    ), "embed"
+        #)

         self.add_sequential_layer(
             LayerDesc(
@@ -177,6 +204,12 @@ def wrapper_mix_precision(self, model, optimizer):
         return model, optimizer

     def test_unified_pp_model(self):
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_id = hcg.get_data_parallel_rank()
+        pp_id = hcg.get_stage_id()
+        rank_id = dist.get_rank()
+        set_random_seed(1024, dp_id, rank_id)
+
         unified_model_pp = UnifiedPPModel(num_stages=self.pipeline_parallel_size)
         unified_scheduler_pp, unified_optimizer_pp = self.build_optimizer(unified_model_pp)
         unified_model_pp, unified_optimizer_pp = self.wrapper_mix_precision(unified_model_pp, unified_optimizer_pp)
@@ -186,6 +219,32 @@ def test_unified_pp_model(self):
         unified_model_nonpp = UnifiedPPModel(num_stages=1)
         unified_scheduler_nonpp, unified_optimizer_nonpp = self.build_optimizer(unified_model_nonpp)

+        pp_id_sname = {}
+        for n, p in unified_model_pp.named_parameters():
+            pp_id_sname[id(p)] = n
+
+        #for p in unified_model_pp.parameters():
+        #    print(f"liyurui, pp parameter is {pp_id_sname[id(p)]}, {p.name}, {p.shape}")
+
+        nonpp_id_sname = {}
+        for n, p in unified_model_nonpp.named_parameters():
+            nonpp_id_sname[id(p)] = n
+
+        #for p in unified_model_nonpp.parameters():
+        #    print(f"liyurui, nonpp parameter is {nonpp_id_sname[id(p)]}, {p.name}, {p.shape}")
+
+        # reset to make pp and nonpp model have same parameters value
+        if pp_id == 0:
+            unified_model_pp.parameters()[0].set_value(unified_model_nonpp.parameters()[0])
+            unified_model_pp.parameters()[1].set_value(unified_model_nonpp.parameters()[1])
+            unified_model_pp.parameters()[2].set_value(unified_model_nonpp.parameters()[2])
+        else:
+            #unified_model_pp.parameters()[0].set_value(unified_model_nonpp.parameters()[0])
+            unified_model_pp.parameters()[1].set_value(unified_model_nonpp.parameters()[3])
+            unified_model_pp.parameters()[2].set_value(unified_model_nonpp.parameters()[4])
+            unified_model_pp.parameters()[3].set_value(unified_model_nonpp.parameters()[5])
+            #unified_model_pp.parameters()[3].set_value(unified_model_nonpp.parameters()[6])
+
         dataset = RandomDataset(5 * batch_size)

         train_reader = DataLoader(
@@ -196,17 +255,43 @@ def test_unified_pp_model(self):
             num_workers=2,
         )

+        for p in unified_model_pp.parameters():
+            print(f"liyurui, pp parameter is {pp_id_sname[id(p)]}, {p.name}, {p._md5sum()}")
+
+        for p in unified_model_nonpp.parameters():
+            print(f"liyurui, nonpp parameter is {nonpp_id_sname[id(p)]}, {p.name}, {p._md5sum()}")
+
         for _, (input_ids, label) in enumerate(train_reader()):
-            print(f"liyurui, input_ids is {input_ids.shape}, {input_ids.dtype}, label is {label.shape}, {label.dtype}")
+            #print(f"liyurui, input_ids is {input_ids.shape}, {input_ids.dtype}, label is {label.shape}, {label.dtype}")
             pp_loss = unified_model_pp.train_batch([input_ids, label], unified_optimizer_pp, unified_scheduler_pp)
             print(f"liyurui, pp_loss is {pp_loss}")

-            nonpp_output = unified_model_nonpp(input_ids)
-            loss_fn = nn.loss.CrossEntropyLoss()
-            nonpp_loss = loss_fn(nonpp_output[0], label)
+            num_acc = batch_size // micro_batch_size
+            micro_input_ids = paddle.split(input_ids, num_acc)
+            micro_labels = paddle.split(label, num_acc)
+
+            nonpp_loss = 0
+            for micro_input, micro_label in zip(micro_input_ids, micro_labels):
+                nonpp_output = unified_model_nonpp(micro_input)
+                loss_fn = nn.loss.CrossEntropyLoss()
+                loss = loss_fn(nonpp_output[0], micro_label) / num_acc
+                loss.backward()
+                nonpp_loss += loss.detach()
             print(f"liyurui, nonpp_loss is {nonpp_loss}")

-            return
+            #for p in unified_model_nonpp.parameters():
+            #    print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.grad._md5sum()}")
+            #    #if hasattr(p, "main_grad") and p.main_grad is not None:
+            #    #    print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.main_grad._md5sum()}")
+
+            #for p in unified_model_pp.parameters():
+            #    print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.grad._md5sum()}")
+            #    #if hasattr(p, "main_grad") and p.main_grad is not None:
+            #    #    print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.main_grad._md5sum()}")
+
+            unified_optimizer_nonpp.step()
+            unified_optimizer_nonpp.clear_grad()
+            unified_scheduler_nonpp.step()


 if __name__ == "__main__":
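
The reworked non-pipeline reference now mirrors the pipeline schedule's gradient accumulation: the batch is split into `num_acc = batch_size // micro_batch_size` micro-batches, each micro-batch loss is scaled by `1 / num_acc` before `backward()`, and the optimizer steps once per batch. A NumPy-only sanity check of that accumulation identity (an illustrative sketch, not part of the test):

```python
# Standalone check (NumPy only) of the accumulation identity the test relies on:
# summing per-micro-batch mean losses scaled by 1/num_acc gives the same gradient
# as one mean loss over the full batch, when micro-batches are equal-sized.
import numpy as np

rng = np.random.default_rng(0)
batch, features, num_acc = 8, 4, 4
x = rng.normal(size=(batch, features))
y = rng.normal(size=(batch, 1))
w = rng.normal(size=(features, 1))

def grad_mse(xb, yb, w):
    # Gradient of 0.5 * mean((xb @ w - yb) ** 2) with respect to w.
    return xb.T @ (xb @ w - yb) / len(xb)

g_full = grad_mse(x, y, w)        # one pass over the full batch

g_acc = np.zeros_like(w)          # accumulated micro-batch gradients
for xb, yb in zip(np.split(x, num_acc), np.split(y, num_acc)):
    g_acc += grad_mse(xb, yb, w) / num_acc

assert np.allclose(g_full, g_acc)
```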
