From bc8689679c134a766d3f497ea02c80027437a1b7 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Wed, 18 Dec 2024 20:05:31 +0800 Subject: [PATCH] update trainer.py --- paddlenlp/trainer/trainer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 58e747865c7f..d4f832820b3a 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -1904,9 +1904,14 @@ def _load_rng_state(self, checkpoint): if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state: if self.args.tensor_parallel_degree <= 1: checkpoint_rng_state["hybrid_parallel_rng_state_tracker"].pop("model_parallel_rng", None) - fleet.meta_parallel.get_rng_state_tracker().set_states_tracker( - checkpoint_rng_state["hybrid_parallel_rng_state_tracker"] - ) + try: + fleet.meta_parallel.get_rng_state_tracker().set_states_tracker( + checkpoint_rng_state["hybrid_parallel_rng_state_tracker"] + ) + except: + logger.warning( + "Hybrid paralell rng states change when training environment differs, so we dot not set state tracker here." + ) else: logger.warning("Not found hybrid parallel RNG state.")