diff --git a/ppfleetx/models/language_model/utils.py b/ppfleetx/models/language_model/utils.py
index 28bc4c259..761d22b4f 100644
--- a/ppfleetx/models/language_model/utils.py
+++ b/ppfleetx/models/language_model/utils.py
@@ -121,8 +121,10 @@ def process_optim_configs(config):
 
     nranks = dist.get_world_size()
     dp_degree = config['Distributed']['dp_degree']
+    sharding_degree = config['Distributed']['sharding']['sharding_degree']
     if config['Optimizer']['tensor_fusion']:
-        assert nranks == dp_degree, "tensor_fusion only support single card train or data parallel train"
+        assert nranks == dp_degree * sharding_degree, \
+            "tensor_fusion only support single card train or data/sharding parallel train"
 
 
 def process_data_configs(config):
diff --git a/ppfleetx/optims/optimizer.py b/ppfleetx/optims/optimizer.py
index f858f89f2..005e39e87 100644
--- a/ppfleetx/optims/optimizer.py
+++ b/ppfleetx/optims/optimizer.py
@@ -14,6 +14,7 @@
 import sys
 
 import paddle
+import paddle.distributed.fleet as fleet
 from ppfleetx.utils.tensor_fusion_helper import fused_parameters
 
 from paddle.optimizer import Adam, AdamW, Momentum
@@ -30,9 +31,13 @@ class FusedAdamW(paddle.optimizer.AdamW):
     def __init__(self, learning_rate, parameters, grad_clip, **config):
         tensor_fusion = config.pop("tensor_fusion", False)
 
+        if paddle.distributed.get_world_size() > 1:
+            hcg = fleet.get_hybrid_communicate_group()
+            sharding_size = hcg.get_sharding_parallel_world_size()
+
         if tensor_fusion:
             self.decay_fused_tensors, self.all_fused_tensors = fused_parameters(
-                parameters)
+                parameters, sharding_size > 1)
             decay_params = [p.name for p in self.decay_fused_tensors]
         else:
             decay_params = [
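
For context, below is a minimal standalone sketch of the relaxed constraint that `process_optim_configs` now enforces. The helper name `check_tensor_fusion_degrees` and the sample degree values are illustrative, not part of this patch; the real check reads `dp_degree` and `sharding_degree` from the YAML-backed config dict and compares them against `dist.get_world_size()`.

```python
# Illustrative sketch (hypothetical helper): with this patch, tensor_fusion is
# allowed whenever every rank belongs to the data-parallel x sharding mesh,
# i.e. no tensor- or pipeline-parallel dimension is left over.

def check_tensor_fusion_degrees(nranks, dp_degree, sharding_degree):
    assert nranks == dp_degree * sharding_degree, \
        "tensor_fusion only support single card train or data/sharding parallel train"


# Single card: allowed both before and after this patch (1 == 1 * 1).
check_tensor_fusion_degrees(nranks=1, dp_degree=1, sharding_degree=1)

# 8 ranks split as dp=2 x sharding=4: newly allowed by this patch
# (previously rejected because nranks != dp_degree).
check_tensor_fusion_degrees(nranks=8, dp_degree=2, sharding_degree=4)

# 8 ranks with dp=2 x sharding=2 (the remaining 2x taken by mp/pp): rejected.
try:
    check_tensor_fusion_degrees(nranks=8, dp_degree=2, sharding_degree=2)
except AssertionError as err:
    print("rejected:", err)
```

On multi-rank runs, `FusedAdamW` additionally queries the sharding world size from the hybrid communicate group and passes `sharding_size > 1` into `fused_parameters`, so the fusion helper can account for sharding-parallel training.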