Skip to content

Commit

Permalink
Fix: RuntimeError: 'weight' must be 2-D issue (huggingface#687)
Browse files Browse the repository at this point in the history
* Update dpo_trainer.py

* Fix: self.args.deepspeed > self.is_deepspeed_enabled

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
  • Loading branch information
2 people authored and Andrew Lapp committed May 10, 2024
1 parent 3f90f56 commit e6d35da
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,12 @@ def __init__(
"You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version."
)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
if self.is_deepspeed_enabled:
# Read more about the issue in https://github.com/huggingface/trl/pull/687
self.ref_model = self.accelerator._prepare_deepspeed(self.ref_model)
self.ref_model.eval()
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
"""Concatenate the chosen and rejected inputs into a single tensor.
Expand Down

0 comments on commit e6d35da

Please sign in to comment.