diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 7876a8cb0..a62e42800 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -71,13 +71,15 @@ def __init__(self, config: TRLConfig, **kwargs): # Set up a reference model when hydra heads are not used if not hasattr(self.model, "frozen_head") and not self.model.peft_type: + # Full Reference Copy self.ref_model = self.get_arch(self.config) self.ref_model.base_model.resize_token_embeddings(len(self.tokenizer)) self.ref_model.to(self.accelerator.device) self.ref_model.eval() - else: - # resize hydra heads + elif hasattr(self.model, "frozen_head"): + # Hydra Reference: Use the frozen base layers and head as the reference model, resize hydra heads self.model.frozen_head.resize_token_embeddings(len(self.tokenizer)) + # TODO: else PEFT Reference, do something? # Set up the KL controller # This helps prevent large divergences in the controller (policy)