diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 8fad0686dd42e..431bc6d7bc389 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -329,6 +329,7 @@ def _broadcast_final_loss(self):
     def _optimizer_step(self):
         if self.scaler:
             self.scaler.step(self.optimizer)
+            self.scaler.update()
         else:
             self.optimizer.step()
 
diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py
index 083ad319305f3..4c966585d5f1f 100644
--- a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py
+++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_amp.py
@@ -48,6 +48,7 @@ def train_batch(self, batch, model, optimizer, is_mp):
         scaled.backward()  # do backward
 
         scaler.step(optimizer)  # update parameters
+        scaler.update()
         optimizer.clear_grad()
         return scaled
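
Context for reviewers: with `paddle.amp.GradScaler`, `scaler.step(optimizer)` applies the unscaled gradients, and `scaler.update()` refreshes the loss-scaling factor for the next iteration; the patch adds the missing `update()` call after `step()` in the pipeline-parallel optimizer step and in the MP AMP unit test. Below is a minimal single-card sketch of the same step/update pattern, using an illustrative model, optimizer, and data (not from the patch) to show where the added call sits in a training loop.

```python
import paddle

# Illustrative model/optimizer; the patch itself only changes the scaler calls.
model = paddle.nn.Linear(16, 16)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.01, parameters=model.parameters()
)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

for _ in range(3):
    data = paddle.rand([4, 16])
    with paddle.amp.auto_cast():
        loss = model(data).mean()

    scaled = scaler.scale(loss)  # scale the loss
    scaled.backward()            # do backward on the scaled loss

    scaler.step(optimizer)       # unscale gradients and update parameters
    scaler.update()              # adjust the loss scaling for the next step
    optimizer.clear_grad()
```

Without the `scaler.update()` call, the dynamic loss-scaling factor is never adjusted after a step, which is what the change to `_optimizer_step` corrects for the pipeline-parallel path.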