From 80340c48592bc335d24cb3f9cfe95e32fa00d2c7 Mon Sep 17 00:00:00 2001 From: LiYuRio Date: Tue, 12 Aug 2025 18:08:30 +0800 Subject: [PATCH 1/4] Pipeline Layer and SharedLayerDesc support nonpp parallel --- .../fleet/hybrid_pp_unified_dygraph_model.py | 213 ++++++++++++++++++ .../fleet/test_pp_unified_dygraph_model.py | 14 ++ 2 files changed, 227 insertions(+) create mode 100644 test/collective/fleet/hybrid_pp_unified_dygraph_model.py create mode 100644 test/collective/fleet/test_pp_unified_dygraph_model.py diff --git a/test/collective/fleet/hybrid_pp_unified_dygraph_model.py b/test/collective/fleet/hybrid_pp_unified_dygraph_model.py new file mode 100644 index 00000000000000..e3c6337a3ca253 --- /dev/null +++ b/test/collective/fleet/hybrid_pp_unified_dygraph_model.py @@ -0,0 +1,213 @@ +import unittest +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle import nn +from paddle.io import DataLoader, Dataset + +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import ( + LayerDesc, + SharedLayerDesc, + PipelineLayer, +) + +batch_size = 5 +micro_batch_size = 1 + +class RandomDataset(Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + input_ids = np.random.random([5]).astype('int64') + label = np.random.randint(0, 5, (5)).astype('int64') + return input_ids, label + + def __len__(self): + return self.num_samples + + +vocab_size = 1024 +hidden_size = 64 + +class EmbeddingPipe(nn.Layer): + def __init__(self, **kwargs): + super().__init__() + self.embed_tokens = nn.Embedding(kwargs["num_embeddings"], kwargs["embedding_dim"]) + + def forward(self, input_ids): + hidden_states = self.embed_tokens.forward(input_ids) + return (hidden_states, input_ids) + + @property + def embedding_weight(self): + return getattr(self.embed_tokens, "weight") + + +class MTPEmbeddingPipe(EmbeddingPipe): + def forward(self, args): + hidden_states = args[0] + input_ids = args[1] + embed = super().forward(input_ids) + output = embed[0] + hidden_states + return (output, input_ids) + + +class LinearPipe(nn.Linear): + def __init__( + self, + in_features, + out_features, + weight_attr = None, + bias_attr = None, + name = None, + layer_idx = 0 + ): + self.layer_idx = layer_idx + super().__init__(in_features, out_features, bias_attr=bias_attr) + + def forward(self, args): + hidden_states = args[0] + input_ids = args[1] + output = super().forward(hidden_states) + return (output, input_ids) + + +class CrossEntropyLossPipe(nn.loss.CrossEntropyLoss): + def forward(self, logits, label): + if isinstance(logits, tuple): + logits = logits[0] + return super().forward(logits, label) + + +class UnifiedPPModel(PipelineLayer): + def __init__ (self, **kwargs): + self._sequential_layers = [] + self.num_layer = 4 + + #self.add_sequential_layer( + # SharedLayerDesc( + # key="embed_weight_share", + # layer_func=EmbeddingPipe, + # shared_weight_attr="embedding_weight", + # num_embeddings=vocab_size, + # embedding_dim=hidden_size, + # ), + # "embed", + #) + self.add_sequential_layer( + LayerDesc( + EmbeddingPipe, + num_embeddings=vocab_size, + embedding_dim=hidden_size, + ), "embed" + ) + + for i in range(self.num_layer): + self.add_sequential_layer( + LayerDesc( + LinearPipe, + hidden_size, + hidden_size, + bias_attr=False, + layer_idx=i, + ), f"layer.{i}" + ) + + self.add_sequential_layer( + SharedLayerDesc( + key="embed_weight_share", + layer_func=MTPEmbeddingPipe, + shared_weight_attr="embedding_weight", + num_embeddings=vocab_size, + embedding_dim=hidden_size, + ), + "embed_shared", + ) + + self.add_sequential_layer( + LayerDesc( + LinearPipe, + hidden_size, + hidden_size, + bias_attr=False, + layer_idx=self.num_layer + ), "last_layer" + ) + + super().__init__(layers=self.get_sequential_layer(), loss_fn=CrossEntropyLossPipe(), **kwargs) + + def add_sequential_layer(self, layer_desc, name_prefix=""): + self._sequential_layers.append({"layer": layer_desc, "name_prefix": name_prefix}) + + def get_sequential_layer(self): + return [x["layer"] for x in self._sequential_layers] + + + +class TestDistPPTraining(unittest.TestCase): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + strategy.pipeline_configs = { + "accumulate_steps": batch_size // micro_batch_size, + "micro_batch_size": micro_batch_size, + } + fleet.init(is_collective=True, strategy=strategy) + + def build_optimizer(self, model): + scheduler = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True + ) + optimizer = paddle.optimizer.SGD( + learning_rate=scheduler, parameters=model.parameters() + ) + return scheduler, optimizer + + def wrapper_mix_precision(self, model, optimizer): + return model, optimizer + + def test_unified_pp_model(self): + unified_model_pp = UnifiedPPModel(num_stages=self.pipeline_parallel_size) + unified_scheduler_pp, unified_optimizer_pp = self.build_optimizer(unified_model_pp) + unified_model_pp, unified_optimizer_pp = self.wrapper_mix_precision(unified_model_pp, unified_optimizer_pp) + unified_model_pp = fleet.distributed_model(unified_model_pp) + unified_optimizer_pp = fleet.distributed_optimizer(unified_optimizer_pp) + + unified_model_nonpp = UnifiedPPModel(num_stages=1) + unified_scheduler_nonpp, unified_optimizer_nonpp = self.build_optimizer(unified_model_nonpp) + + dataset = RandomDataset(5 * batch_size) + + train_reader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + drop_last=True, + num_workers=2, + ) + + for _, (input_ids, label) in enumerate(train_reader()): + print(f"liyurui, input_ids is {input_ids.shape}, {input_ids.dtype}, label is {label.shape}, {label.dtype}") + pp_loss = unified_model_pp.train_batch([input_ids, label], unified_optimizer_pp, unified_scheduler_pp) + print(f"liyurui, pp_loss is {pp_loss}") + + nonpp_output = unified_model_nonpp(input_ids) + loss_fn = nn.loss.CrossEntropyLoss() + nonpp_loss = loss_fn(nonpp_output[0], label) + print(f"liyurui, nonpp_loss is {nonpp_loss}") + + return + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/test/collective/fleet/test_pp_unified_dygraph_model.py b/test/collective/fleet/test_pp_unified_dygraph_model.py new file mode 100644 index 00000000000000..1b301b21a8d33e --- /dev/null +++ b/test/collective/fleet/test_pp_unified_dygraph_model.py @@ -0,0 +1,14 @@ +import unittest + +from legacy_test.test_parallel_dygraph_dataparallel import ( + TestMultipleAccelerators, +) + + +class TestPipelineParallel(TestMultipleAccelerators): + def test_pipeline_parallel(self): + self.run_mnist_2accelerators('hybrid_pp_unified_dygraph_model.py') + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 75dffc9473c2687471245398369a2a88c8fdd691 Mon Sep 17 00:00:00 2001 From: AlAuAu <458134681@qq.com> Date: Wed, 13 Aug 2025 19:54:53 +0800 Subject: [PATCH 2/4] pp and nopp unify --- .../meta_parallel/parallel_layers/pp_layers.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 8ab4d4990e88ff..f835231ce7c3c3 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -510,8 +510,10 @@ def __init__( self._build_layer() self.comm_key_to_layer_name = {} - - self.shared_comm = self._construct_shared_comm() + if self._num_stages > 1: + self.shared_comm = self._construct_shared_comm() + else: + self.shared_comm = {} self._synchronize_shared_weights() def get_stage_from_index(self, layer_idx): @@ -1002,7 +1004,12 @@ def flush_into_run_function(): param.is_firstly_shared = True if layer.forward_func is None: - run_function.append(self.shared_layers[layer.layer_name]) + if self._num_stages == 1: + run_function.append(layer.build_layer()) + else: + run_function.append( + self.shared_layers[layer.layer_name] + ) else: run_function.append( From 8ac0915cb4f14dfcb78d5a12eae060fc27449998 Mon Sep 17 00:00:00 2001 From: LiYuRio Date: Wed, 13 Aug 2025 20:17:39 +0800 Subject: [PATCH 3/4] fix unit test case --- .../parallel_layers/pp_layers.py | 16 +-- .../fleet/hybrid_pp_unified_dygraph_model.py | 131 +++++++++++++++--- 2 files changed, 113 insertions(+), 34 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index f835231ce7c3c3..c0ec17e5a67a2a 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -510,10 +510,7 @@ def __init__( self._build_layer() self.comm_key_to_layer_name = {} - if self._num_stages > 1: - self.shared_comm = self._construct_shared_comm() - else: - self.shared_comm = {} + self.shared_comm = self._construct_shared_comm() self._synchronize_shared_weights() def get_stage_from_index(self, layer_idx): @@ -544,7 +541,7 @@ def get_model_chunks(self): def _construct_shared_comm(self): shared_comm = {} if self._topo.get_dim("pipe") == 1: - return + return shared_comm # The first loop gets the pivot stage and all different shared_weight_attrs for one layer name. # Maps stage idx to all shared attrs of each different layer names on that stage. @@ -1004,12 +1001,9 @@ def flush_into_run_function(): param.is_firstly_shared = True if layer.forward_func is None: - if self._num_stages == 1: - run_function.append(layer.build_layer()) - else: - run_function.append( - self.shared_layers[layer.layer_name] - ) + run_function.append( + self.shared_layers[layer.layer_name] + ) else: run_function.append( diff --git a/test/collective/fleet/hybrid_pp_unified_dygraph_model.py b/test/collective/fleet/hybrid_pp_unified_dygraph_model.py index e3c6337a3ca253..6d68d5560a3d07 100644 --- a/test/collective/fleet/hybrid_pp_unified_dygraph_model.py +++ b/test/collective/fleet/hybrid_pp_unified_dygraph_model.py @@ -1,5 +1,6 @@ import unittest import numpy as np +import random import paddle import paddle.distributed as dist @@ -16,13 +17,19 @@ batch_size = 5 micro_batch_size = 1 +def set_random_seed(seed, dp_id, rank_id): + """Set random seed for reproducibility.""" + random.seed(seed) + np.random.seed(seed + dp_id) + paddle.seed(seed + dp_id) + class RandomDataset(Dataset): def __init__(self, num_samples): self.num_samples = num_samples def __getitem__(self, idx): - input_ids = np.random.random([5]).astype('int64') - label = np.random.randint(0, 5, (5)).astype('int64') + input_ids = np.random.randint(0, 20, [10]).astype('int64') + label = np.random.randint(0, 20, (10)).astype('int64') return input_ids, label def __len__(self): @@ -36,18 +43,29 @@ class EmbeddingPipe(nn.Layer): def __init__(self, **kwargs): super().__init__() self.embed_tokens = nn.Embedding(kwargs["num_embeddings"], kwargs["embedding_dim"]) + #print(f"liyurui, embeding weight init = {self.embedding_weight._md5sum()}") def forward(self, input_ids): + #print(f"liyurui, input_ids is {input_ids}") + #print(f"liyurui, input_ids is {input_ids._md5sum()}, weight={self.embedding_weight._md5sum()}") hidden_states = self.embed_tokens.forward(input_ids) + #print(f"liyurui, hidden_states of embedding pipe {hidden_states._md5sum()}") return (hidden_states, input_ids) @property def embedding_weight(self): return getattr(self.embed_tokens, "weight") +def mtp_forward(layer, args): + hidden_states = args[0] + input_ids = args[1] + embed = layer.forward(input_ids) + output = embed[0] + hidden_states + return (output, input_ids) class MTPEmbeddingPipe(EmbeddingPipe): def forward(self, args): + #print(f"liyurui, input of MTPEmbedding is {args}") hidden_states = args[0] input_ids = args[1] embed = super().forward(input_ids) @@ -87,23 +105,23 @@ def __init__ (self, **kwargs): self._sequential_layers = [] self.num_layer = 4 - #self.add_sequential_layer( - # SharedLayerDesc( - # key="embed_weight_share", - # layer_func=EmbeddingPipe, - # shared_weight_attr="embedding_weight", - # num_embeddings=vocab_size, - # embedding_dim=hidden_size, - # ), - # "embed", - #) self.add_sequential_layer( - LayerDesc( - EmbeddingPipe, - num_embeddings=vocab_size, - embedding_dim=hidden_size, - ), "embed" + SharedLayerDesc( + key="embed_weight_share", + layer_func=EmbeddingPipe, + shared_weight_attr="embedding_weight", + num_embeddings=vocab_size, + embedding_dim=hidden_size, + ), + "embed", ) + #self.add_sequential_layer( + # LayerDesc( + # EmbeddingPipe, + # num_embeddings=vocab_size, + # embedding_dim=hidden_size, + # ), "embed" + #) for i in range(self.num_layer): self.add_sequential_layer( @@ -119,13 +137,22 @@ def __init__ (self, **kwargs): self.add_sequential_layer( SharedLayerDesc( key="embed_weight_share", - layer_func=MTPEmbeddingPipe, + #layer_func=MTPEmbeddingPipe, + layer_func=EmbeddingPipe, shared_weight_attr="embedding_weight", + forward_func=mtp_forward, num_embeddings=vocab_size, embedding_dim=hidden_size, ), "embed_shared", ) + #self.add_sequential_layer( + # LayerDesc( + # MTPEmbeddingPipe, + # num_embeddings=vocab_size, + # embedding_dim=hidden_size, + # ), "embed" + #) self.add_sequential_layer( LayerDesc( @@ -177,6 +204,12 @@ def wrapper_mix_precision(self, model, optimizer): return model, optimizer def test_unified_pp_model(self): + hcg = fleet.get_hybrid_communicate_group() + dp_id = hcg.get_data_parallel_rank() + pp_id = hcg.get_stage_id() + rank_id = dist.get_rank() + set_random_seed(1024, dp_id, rank_id) + unified_model_pp = UnifiedPPModel(num_stages=self.pipeline_parallel_size) unified_scheduler_pp, unified_optimizer_pp = self.build_optimizer(unified_model_pp) unified_model_pp, unified_optimizer_pp = self.wrapper_mix_precision(unified_model_pp, unified_optimizer_pp) @@ -186,6 +219,32 @@ def test_unified_pp_model(self): unified_model_nonpp = UnifiedPPModel(num_stages=1) unified_scheduler_nonpp, unified_optimizer_nonpp = self.build_optimizer(unified_model_nonpp) + pp_id_sname = {} + for n, p in unified_model_pp.named_parameters(): + pp_id_sname[id(p)] = n + + #for p in unified_model_pp.parameters(): + # print(f"liyurui, pp parameter is {pp_id_sname[id(p)]}, {p.name}, {p.shape}") + + nonpp_id_sname = {} + for n, p in unified_model_nonpp.named_parameters(): + nonpp_id_sname[id(p)] = n + + #for p in unified_model_nonpp.parameters(): + # print(f"liyurui, nonpp parameter is {nonpp_id_sname[id(p)]}, {p.name}, {p.shape}") + + # reset to make pp and nonpp model have same parameters value + if pp_id == 0: + unified_model_pp.parameters()[0].set_value(unified_model_nonpp.parameters()[0]) + unified_model_pp.parameters()[1].set_value(unified_model_nonpp.parameters()[1]) + unified_model_pp.parameters()[2].set_value(unified_model_nonpp.parameters()[2]) + else: + #unified_model_pp.parameters()[0].set_value(unified_model_nonpp.parameters()[0]) + unified_model_pp.parameters()[1].set_value(unified_model_nonpp.parameters()[3]) + unified_model_pp.parameters()[2].set_value(unified_model_nonpp.parameters()[4]) + unified_model_pp.parameters()[3].set_value(unified_model_nonpp.parameters()[5]) + #unified_model_pp.parameters()[3].set_value(unified_model_nonpp.parameters()[6]) + dataset = RandomDataset(5 * batch_size) train_reader = DataLoader( @@ -196,17 +255,43 @@ def test_unified_pp_model(self): num_workers=2, ) + for p in unified_model_pp.parameters(): + print(f"liyurui, pp parameter is {pp_id_sname[id(p)]}, {p.name}, {p._md5sum()}") + + for p in unified_model_nonpp.parameters(): + print(f"liyurui, nonpp parameter is {nonpp_id_sname[id(p)]}, {p.name}, {p._md5sum()}") + for _, (input_ids, label) in enumerate(train_reader()): - print(f"liyurui, input_ids is {input_ids.shape}, {input_ids.dtype}, label is {label.shape}, {label.dtype}") + #print(f"liyurui, input_ids is {input_ids.shape}, {input_ids.dtype}, label is {label.shape}, {label.dtype}") pp_loss = unified_model_pp.train_batch([input_ids, label], unified_optimizer_pp, unified_scheduler_pp) print(f"liyurui, pp_loss is {pp_loss}") - nonpp_output = unified_model_nonpp(input_ids) - loss_fn = nn.loss.CrossEntropyLoss() - nonpp_loss = loss_fn(nonpp_output[0], label) + num_acc = batch_size // micro_batch_size + micro_input_ids = paddle.split(input_ids, num_acc) + micro_labels = paddle.split(label, num_acc) + + nonpp_loss = 0 + for micro_input, micro_label in zip(micro_input_ids, micro_labels): + nonpp_output = unified_model_nonpp(micro_input) + loss_fn = nn.loss.CrossEntropyLoss() + loss = loss_fn(nonpp_output[0], micro_label) / num_acc + loss.backward() + nonpp_loss += loss.detach() print(f"liyurui, nonpp_loss is {nonpp_loss}") - return + #for p in unified_model_nonpp.parameters(): + # print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.grad._md5sum()}") + # #if hasattr(p, "main_grad") and p.main_grad is not None: + # # print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.main_grad._md5sum()}") + + #for p in unified_model_pp.parameters(): + # print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.grad._md5sum()}") + # #if hasattr(p, "main_grad") and p.main_grad is not None: + # # print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.main_grad._md5sum()}") + + unified_optimizer_nonpp.step() + unified_optimizer_nonpp.clear_grad() + unified_scheduler_nonpp.step() if __name__ == "__main__": From 56b5eb90b899acd0a4cdcc4c7bdfe91f00d83563 Mon Sep 17 00:00:00 2001 From: LiYuRio Date: Tue, 19 Aug 2025 22:57:04 +0800 Subject: [PATCH 4/4] refine test case with shared --- .../parallel_layers/pp_layers.py | 4 +- test/collective/fleet/CMakeLists.txt | 14 ++ .../fleet/hybrid_pp_unified_dygraph_model.py | 207 +++++++++--------- .../fleet/test_pp_unified_dygraph_model.py | 16 +- 4 files changed, 132 insertions(+), 109 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index c0ec17e5a67a2a..bad5a5d4700903 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -1001,9 +1001,7 @@ def flush_into_run_function(): param.is_firstly_shared = True if layer.forward_func is None: - run_function.append( - self.shared_layers[layer.layer_name] - ) + run_function.append(self.shared_layers[layer.layer_name]) else: run_function.append( diff --git a/test/collective/fleet/CMakeLists.txt b/test/collective/fleet/CMakeLists.txt index e99618dadd09e8..62850027500f1b 100644 --- a/test/collective/fleet/CMakeLists.txt +++ b/test/collective/fleet/CMakeLists.txt @@ -850,3 +850,17 @@ if((WITH_GPU) AND LOCAL_ALL_PLAT) ) set_tests_properties(test_pp_send_recv_dict PROPERTIES TIMEOUT "500") endif() +if((WITH_GPU) AND LOCAL_ALL_PLAT) + bash_test_modules( + test_pp_unified_dygraph_model + START_BASH + ../../legacy_test/dist_test.sh + TIMEOUT + "500" + LABELS + "RUN_TYPE=DIST" + ENVS + "PADDLE_DIST_UT_PORT=21282;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" + ) + set_tests_properties(test_pp_unified_dygraph_model PROPERTIES TIMEOUT "500") +endif() diff --git a/test/collective/fleet/hybrid_pp_unified_dygraph_model.py b/test/collective/fleet/hybrid_pp_unified_dygraph_model.py index 6d68d5560a3d07..b544f596d9aeb1 100644 --- a/test/collective/fleet/hybrid_pp_unified_dygraph_model.py +++ b/test/collective/fleet/hybrid_pp_unified_dygraph_model.py @@ -1,28 +1,44 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random import unittest + import numpy as np -import random import paddle import paddle.distributed as dist from paddle import nn -from paddle.io import DataLoader, Dataset - from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import ( LayerDesc, - SharedLayerDesc, PipelineLayer, + SharedLayerDesc, ) +from paddle.io import DataLoader, Dataset batch_size = 5 micro_batch_size = 1 + def set_random_seed(seed, dp_id, rank_id): """Set random seed for reproducibility.""" random.seed(seed) np.random.seed(seed + dp_id) paddle.seed(seed + dp_id) + class RandomDataset(Dataset): def __init__(self, num_samples): self.num_samples = num_samples @@ -39,22 +55,22 @@ def __len__(self): vocab_size = 1024 hidden_size = 64 + class EmbeddingPipe(nn.Layer): def __init__(self, **kwargs): super().__init__() - self.embed_tokens = nn.Embedding(kwargs["num_embeddings"], kwargs["embedding_dim"]) - #print(f"liyurui, embeding weight init = {self.embedding_weight._md5sum()}") + self.embed_tokens = nn.Embedding( + kwargs["num_embeddings"], kwargs["embedding_dim"] + ) def forward(self, input_ids): - #print(f"liyurui, input_ids is {input_ids}") - #print(f"liyurui, input_ids is {input_ids._md5sum()}, weight={self.embedding_weight._md5sum()}") hidden_states = self.embed_tokens.forward(input_ids) - #print(f"liyurui, hidden_states of embedding pipe {hidden_states._md5sum()}") return (hidden_states, input_ids) @property def embedding_weight(self): - return getattr(self.embed_tokens, "weight") + return self.embed_tokens.weight + def mtp_forward(layer, args): hidden_states = args[0] @@ -63,9 +79,9 @@ def mtp_forward(layer, args): output = embed[0] + hidden_states return (output, input_ids) + class MTPEmbeddingPipe(EmbeddingPipe): def forward(self, args): - #print(f"liyurui, input of MTPEmbedding is {args}") hidden_states = args[0] input_ids = args[1] embed = super().forward(input_ids) @@ -78,10 +94,10 @@ def __init__( self, in_features, out_features, - weight_attr = None, - bias_attr = None, - name = None, - layer_idx = 0 + weight_attr=None, + bias_attr=None, + name=None, + layer_idx=0, ): self.layer_idx = layer_idx super().__init__(in_features, out_features, bias_attr=bias_attr) @@ -101,27 +117,20 @@ def forward(self, logits, label): class UnifiedPPModel(PipelineLayer): - def __init__ (self, **kwargs): + def __init__(self, **kwargs): self._sequential_layers = [] self.num_layer = 4 self.add_sequential_layer( - SharedLayerDesc( - key="embed_weight_share", - layer_func=EmbeddingPipe, - shared_weight_attr="embedding_weight", - num_embeddings=vocab_size, - embedding_dim=hidden_size, - ), - "embed", + SharedLayerDesc( + key="embed_weight_share", + layer_func=EmbeddingPipe, + shared_weight_attr="embedding_weight", + num_embeddings=vocab_size, + embedding_dim=hidden_size, + ), + "embed", ) - #self.add_sequential_layer( - # LayerDesc( - # EmbeddingPipe, - # num_embeddings=vocab_size, - # embedding_dim=hidden_size, - # ), "embed" - #) for i in range(self.num_layer): self.add_sequential_layer( @@ -131,49 +140,48 @@ def __init__ (self, **kwargs): hidden_size, bias_attr=False, layer_idx=i, - ), f"layer.{i}" + ), + f"layer.{i}", ) self.add_sequential_layer( - SharedLayerDesc( - key="embed_weight_share", - #layer_func=MTPEmbeddingPipe, - layer_func=EmbeddingPipe, - shared_weight_attr="embedding_weight", - forward_func=mtp_forward, - num_embeddings=vocab_size, - embedding_dim=hidden_size, - ), - "embed_shared", + SharedLayerDesc( + key="embed_weight_share", + layer_func=EmbeddingPipe, + shared_weight_attr="embedding_weight", + forward_func=mtp_forward, + num_embeddings=vocab_size, + embedding_dim=hidden_size, + ), + "embed_shared", ) - #self.add_sequential_layer( - # LayerDesc( - # MTPEmbeddingPipe, - # num_embeddings=vocab_size, - # embedding_dim=hidden_size, - # ), "embed" - #) self.add_sequential_layer( - LayerDesc( - LinearPipe, - hidden_size, - hidden_size, - bias_attr=False, - layer_idx=self.num_layer - ), "last_layer" + LayerDesc( + LinearPipe, + hidden_size, + hidden_size, + bias_attr=False, + layer_idx=self.num_layer, + ), + "last_layer", ) - super().__init__(layers=self.get_sequential_layer(), loss_fn=CrossEntropyLossPipe(), **kwargs) + super().__init__( + layers=self.get_sequential_layer(), + loss_fn=CrossEntropyLossPipe(), + **kwargs, + ) def add_sequential_layer(self, layer_desc, name_prefix=""): - self._sequential_layers.append({"layer": layer_desc, "name_prefix": name_prefix}) + self._sequential_layers.append( + {"layer": layer_desc, "name_prefix": name_prefix} + ) def get_sequential_layer(self): return [x["layer"] for x in self._sequential_layers] - class TestDistPPTraining(unittest.TestCase): def setUp(self): strategy = fleet.DistributedStrategy() @@ -210,40 +218,44 @@ def test_unified_pp_model(self): rank_id = dist.get_rank() set_random_seed(1024, dp_id, rank_id) - unified_model_pp = UnifiedPPModel(num_stages=self.pipeline_parallel_size) - unified_scheduler_pp, unified_optimizer_pp = self.build_optimizer(unified_model_pp) - unified_model_pp, unified_optimizer_pp = self.wrapper_mix_precision(unified_model_pp, unified_optimizer_pp) + unified_model_pp = UnifiedPPModel( + num_stages=self.pipeline_parallel_size + ) + unified_scheduler_pp, unified_optimizer_pp = self.build_optimizer( + unified_model_pp + ) + unified_model_pp, unified_optimizer_pp = self.wrapper_mix_precision( + unified_model_pp, unified_optimizer_pp + ) unified_model_pp = fleet.distributed_model(unified_model_pp) unified_optimizer_pp = fleet.distributed_optimizer(unified_optimizer_pp) unified_model_nonpp = UnifiedPPModel(num_stages=1) - unified_scheduler_nonpp, unified_optimizer_nonpp = self.build_optimizer(unified_model_nonpp) - - pp_id_sname = {} - for n, p in unified_model_pp.named_parameters(): - pp_id_sname[id(p)] = n - - #for p in unified_model_pp.parameters(): - # print(f"liyurui, pp parameter is {pp_id_sname[id(p)]}, {p.name}, {p.shape}") - - nonpp_id_sname = {} - for n, p in unified_model_nonpp.named_parameters(): - nonpp_id_sname[id(p)] = n - - #for p in unified_model_nonpp.parameters(): - # print(f"liyurui, nonpp parameter is {nonpp_id_sname[id(p)]}, {p.name}, {p.shape}") + unified_scheduler_nonpp, unified_optimizer_nonpp = self.build_optimizer( + unified_model_nonpp + ) # reset to make pp and nonpp model have same parameters value if pp_id == 0: - unified_model_pp.parameters()[0].set_value(unified_model_nonpp.parameters()[0]) - unified_model_pp.parameters()[1].set_value(unified_model_nonpp.parameters()[1]) - unified_model_pp.parameters()[2].set_value(unified_model_nonpp.parameters()[2]) + unified_model_pp.parameters()[0].set_value( + unified_model_nonpp.parameters()[0] + ) + unified_model_pp.parameters()[1].set_value( + unified_model_nonpp.parameters()[1] + ) + unified_model_pp.parameters()[2].set_value( + unified_model_nonpp.parameters()[2] + ) else: - #unified_model_pp.parameters()[0].set_value(unified_model_nonpp.parameters()[0]) - unified_model_pp.parameters()[1].set_value(unified_model_nonpp.parameters()[3]) - unified_model_pp.parameters()[2].set_value(unified_model_nonpp.parameters()[4]) - unified_model_pp.parameters()[3].set_value(unified_model_nonpp.parameters()[5]) - #unified_model_pp.parameters()[3].set_value(unified_model_nonpp.parameters()[6]) + unified_model_pp.parameters()[1].set_value( + unified_model_nonpp.parameters()[3] + ) + unified_model_pp.parameters()[2].set_value( + unified_model_nonpp.parameters()[4] + ) + unified_model_pp.parameters()[3].set_value( + unified_model_nonpp.parameters()[5] + ) dataset = RandomDataset(5 * batch_size) @@ -255,16 +267,10 @@ def test_unified_pp_model(self): num_workers=2, ) - for p in unified_model_pp.parameters(): - print(f"liyurui, pp parameter is {pp_id_sname[id(p)]}, {p.name}, {p._md5sum()}") - - for p in unified_model_nonpp.parameters(): - print(f"liyurui, nonpp parameter is {nonpp_id_sname[id(p)]}, {p.name}, {p._md5sum()}") - for _, (input_ids, label) in enumerate(train_reader()): - #print(f"liyurui, input_ids is {input_ids.shape}, {input_ids.dtype}, label is {label.shape}, {label.dtype}") - pp_loss = unified_model_pp.train_batch([input_ids, label], unified_optimizer_pp, unified_scheduler_pp) - print(f"liyurui, pp_loss is {pp_loss}") + pp_loss = unified_model_pp.train_batch( + [input_ids, label], unified_optimizer_pp, unified_scheduler_pp + ) num_acc = batch_size // micro_batch_size micro_input_ids = paddle.split(input_ids, num_acc) @@ -277,17 +283,8 @@ def test_unified_pp_model(self): loss = loss_fn(nonpp_output[0], micro_label) / num_acc loss.backward() nonpp_loss += loss.detach() - print(f"liyurui, nonpp_loss is {nonpp_loss}") - - #for p in unified_model_nonpp.parameters(): - # print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.grad._md5sum()}") - # #if hasattr(p, "main_grad") and p.main_grad is not None: - # # print(f"nonpp {p.name}@grad, sname: {nonpp_id_sname[id(p)]}, {p.main_grad._md5sum()}") - #for p in unified_model_pp.parameters(): - # print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.grad._md5sum()}") - # #if hasattr(p, "main_grad") and p.main_grad is not None: - # # print(f"pp {p.name}@grad, sname: {pp_id_sname[id(p)]}, {p.main_grad._md5sum()}") + np.testing.assert_equal(nonpp_loss.numpy(), pp_loss.numpy()) unified_optimizer_nonpp.step() unified_optimizer_nonpp.clear_grad() @@ -295,4 +292,4 @@ def test_unified_pp_model(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/collective/fleet/test_pp_unified_dygraph_model.py b/test/collective/fleet/test_pp_unified_dygraph_model.py index 1b301b21a8d33e..74f8153de1ab80 100644 --- a/test/collective/fleet/test_pp_unified_dygraph_model.py +++ b/test/collective/fleet/test_pp_unified_dygraph_model.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from legacy_test.test_parallel_dygraph_dataparallel import ( @@ -11,4 +25,4 @@ def test_pipeline_parallel(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()