diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
index c95b62d32..ef8256baf 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
@@ -540,7 +540,7 @@ def main():
             print_rank_0(
                 f'Epoch: {epoch} | Step: {step} | PPO Epoch: {ppo_ep+1} | Actor Loss: {actor_loss_sum/inner_iter} | Critic Loss: {critic_loss_sum/inner_iter} | Unsupervised Loss: {unsup_loss_sum/inner_iter}',
                 args.global_rank)
-            print_throughput_step3(rlhf_engine.actor.model,
+            print_throughput_step3(rlhf_engine.actor.module,
                                    rlhf_engine.critic, args, e2e_time,
                                    trainer.generate_time, training_time,
                                    args.global_rank)
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py
index 7840de70a..4c55b37d6 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py
@@ -75,7 +75,7 @@ def _generate_sequence(self, prompts, mask, step):
         # This has been added due to a probability/nan error that happens after
         # meta-llama/Llama-2-7b-hf enabled do_sample:
         # https://huggingface.co/meta-llama/Llama-2-7b-hf/commit/6fdf2e60f86ff2481f2241aaee459f85b5b0bbb9
-        if self.actor_model.model.config.model_type == "llama":
+        if self.actor_model.module.config.model_type == "llama":
             kwargs = dict(do_sample=False)
         else:
             kwargs = dict()
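
Both `rlhf_engine.actor` and `self.actor_model` are DeepSpeed engine objects, and a DeepSpeed engine exposes the Hugging Face model it wraps through its `.module` attribute, so reading the model config there is the reliable path. Below is a minimal sketch of that wrapping, assuming `deepspeed` and `transformers` are installed and the script is launched with the `deepspeed` launcher (or with the usual `torch.distributed` environment variables set); the model name and config dict are illustrative only, not taken from the DeepSpeed-Chat scripts:

```python
# Minimal sketch: how deepspeed.initialize wraps a Hugging Face model
# and why the config is read from `.module` in the diff above.
import deepspeed
from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Illustrative DeepSpeed config, not the one used by DeepSpeed-Chat.
ds_config = {"train_batch_size": 8}

# deepspeed.initialize returns a DeepSpeedEngine that wraps hf_model.
engine, _, _, _ = deepspeed.initialize(model=hf_model, config=ds_config)

# The wrapped Hugging Face model (and its config) is exposed via `.module`,
# which is the attribute the patched code now uses.
print(engine.module.config.model_type)  # prints "opt" for this model
```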