diff --git a/ppfleetx/models/language_model/gpt/dygraph/hybrid_model.py b/ppfleetx/models/language_model/gpt/dygraph/hybrid_model.py
index ab0b93727..362b05e8d 100644
--- a/ppfleetx/models/language_model/gpt/dygraph/hybrid_model.py
+++ b/ppfleetx/models/language_model/gpt/dygraph/hybrid_model.py
@@ -975,7 +975,8 @@ def __init__(self,
                  recompute_granularity="full",
                  virtual_pp_degree=1,
                  sequence_parallel=False,
-                 no_recompute_layers=None):
+                 no_recompute_layers=None,
+                 pp_recompute_interval=1):
 
         # forward desc
         self.descs = []
@@ -1057,7 +1058,11 @@ def _logits_helper(embedding, output):
 
         recompute_interval = 0
         if recompute and recompute_granularity == "full":
-            recompute_interval = 1
+            assert pp_recompute_interval <= \
+                num_layers // (virtual_pp_degree *
+                               fleet.get_hybrid_communicate_group().topology().get_dim_size("pipe")), \
+                "pp_recompute_interval should be no greater than the number of layers in each pp chunk"
+            recompute_interval = pp_recompute_interval
 
         super().__init__(
             layers=self.descs,
diff --git a/ppfleetx/utils/config.py b/ppfleetx/utils/config.py
index 08704d501..82a1ac50f 100644
--- a/ppfleetx/utils/config.py
+++ b/ppfleetx/utils/config.py
@@ -38,6 +38,7 @@ def process_dist_config(configs):
 
     mp_degree = config.setdefault("mp_degree", 1)
    pp_degree = config.setdefault("pp_degree", 1)
+    pp_recompute_interval = config.setdefault("pp_recompute_interval", 1)
 
     # sharding default
     sharding_config = config['sharding']
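
For context, the new assertion bounds pp_recompute_interval by the number of transformer layers held in each pipeline chunk, i.e. num_layers // (virtual_pp_degree * pp_degree). Below is a minimal standalone sketch of that check, assuming num_layers, virtual_pp_degree, and pp_degree are known up front; the helper name check_pp_recompute_interval is illustrative only and is not part of the patch, which instead reads the pipeline degree at runtime from fleet.get_hybrid_communicate_group().

def check_pp_recompute_interval(pp_recompute_interval, num_layers,
                                virtual_pp_degree, pp_degree):
    # Each pipeline rank owns num_layers // (virtual_pp_degree * pp_degree)
    # layers per virtual chunk, so the recompute interval cannot exceed it.
    layers_per_chunk = num_layers // (virtual_pp_degree * pp_degree)
    assert pp_recompute_interval <= layers_per_chunk, (
        "pp_recompute_interval should be no greater than the number of "
        "layers in each pp chunk")
    return pp_recompute_interval

# Example: 32 layers split over pp_degree=4 with virtual_pp_degree=2 gives
# 4 layers per chunk, so intervals 1..4 pass and 5 would trip the assert.
check_pp_recompute_interval(2, num_layers=32, virtual_pp_degree=2, pp_degree=4)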