Apparently `PeftModelForCausalLM.generate` does not accept the inputs as a positional argument. This causes the following traceback when `DPOTrainer` runs evaluation:
```
    dpo_trainer.train()
  File "/mnt/seagate/miniconda3/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
    return inner_training_loop(
  File "/mnt/seagate/miniconda3/lib/python3.10/site-packages/transformers/trainer.py", line 1984, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/mnt/seagate/miniconda3/lib/python3.10/site-packages/transformers/trainer.py", line 2328, in _maybe_log_save_evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/mnt/seagate/miniconda3/lib/python3.10/site-packages/transformers/trainer.py", line 3066, in evaluate
    output = eval_loop(
  File "/mnt/seagate/miniconda3/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 617, in evaluation_loop
    policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch)
  File "/mnt/seagate/miniconda3/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 514, in get_batch_samples
    policy_output = model.generate(
TypeError: PeftModelForCausalLM.generate() takes 1 positional argument but 2 were given
```
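The error message says `generate` accepts no positional argument besides `self`, which is exactly what a signature like `def generate(self, **kwargs)` produces. The sketch below reproduces that failure with a hypothetical stand-in class. `StubPeftModel` and the token ids are made up; the only thing carried over from the report is the keyword-only signature implied by the error.

```python
class StubPeftModel:
    """Hypothetical stand-in: as the error implies, generate() takes only keyword arguments."""

    def generate(self, **kwargs):
        # Echo back whatever keyword arguments were received.
        return kwargs


def try_positional_call():
    """Pass the prompt ids positionally, as the issue describes, and return the error text."""
    model = StubPeftModel()
    prompt_input_ids = [[101, 2023, 2003]]  # made-up token ids
    try:
        model.generate(prompt_input_ids, max_length=16)
    except TypeError as err:
        return str(err)
    return None


if __name__ == "__main__":
    print(try_positional_call())
```

The printed message ends in "takes 1 positional argument but 2 were given": `self` counts as the one allowed positional argument, and the prompt ids are the second.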
If I change the `generate` calls in `dpo_trainer.py` to pass the prompt as a named parameter instead (i.e. `inputs=batch["prompt_input_ids"]`), the crash goes away.
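A minimal sketch of that call-site change, using hypothetical stubs rather than the real libraries. It assumes the base model's `generate` takes `inputs` as its first parameter (as Hugging Face transformers' `generate` does) and that the PEFT wrapper forwards its keyword arguments to it. `get_batch_samples` here is a simplified stand-in, not trl's actual method, and `max_length` is just an illustrative extra argument.

```python
class StubBaseModel:
    """Hypothetical base model: generate() takes `inputs` as its first parameter."""

    def generate(self, inputs=None, **kwargs):
        return {"inputs": inputs, **kwargs}


class StubPeftModel:
    """Hypothetical PEFT-style wrapper: keyword-only generate() forwarded to the base model."""

    def __init__(self, base_model):
        self.base_model = base_model

    def generate(self, **kwargs):
        return self.base_model.generate(**kwargs)


def get_batch_samples(model, batch, max_length=16):
    """Simplified, hypothetical call site: the prompt ids are passed by name, not position."""
    return model.generate(
        inputs=batch["prompt_input_ids"],
        max_length=max_length,
    )


if __name__ == "__main__":
    batch = {"prompt_input_ids": [[101, 2023, 2003]]}  # made-up token ids
    print(get_batch_samples(StubPeftModel(StubBaseModel()), batch))
    print(get_batch_samples(StubBaseModel(), batch))
```

Passing by name works for both the wrapped and the unwrapped stub, which suggests the same call site keeps working when no adapter is attached.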
I ran into the same problem when calling `generate` myself. Passing the input IDs explicitly as a keyword argument, e.g. `model.generate(input_ids=[your input_ids], ...)`, solves it.
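If editing the installed `dpo_trainer.py` is not an option, one possible local workaround along the lines of that advice is to wrap the model's bound `generate` so that a single positional argument is forwarded as `input_ids=`. This is a hypothetical sketch exercised only against a stub class: `forward_inputs_by_keyword` is not part of trl or peft, and whether reassigning `generate` on a real PEFT model interacts cleanly with the trainer is untested.

```python
import functools


def forward_inputs_by_keyword(generate):
    """Hypothetical workaround (not part of trl or peft): turn one positional arg into input_ids=."""

    @functools.wraps(generate)
    def wrapper(*args, **kwargs):
        if len(args) > 1:
            raise TypeError(f"expected at most 1 positional argument, got {len(args)}")
        if args:
            if "input_ids" in kwargs:
                raise TypeError("input_ids given both positionally and by keyword")
            kwargs["input_ids"] = args[0]
        # The wrapped generate only ever sees keyword arguments.
        return generate(**kwargs)

    return wrapper


class StubPeftModel:
    """Hypothetical stand-in with the keyword-only signature implied by the error message."""

    def generate(self, **kwargs):
        return kwargs


model = StubPeftModel()
# Shadow the bound method on this one instance only; the class itself is left untouched.
model.generate = forward_inputs_by_keyword(model.generate)
```

With the wrapper in place, a positional call such as `model.generate(prompt_ids, max_length=16)` reaches the stub as `generate(input_ids=prompt_ids, max_length=16)`, while keyword-only callers are unaffected.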
Reported against trl 0.7.3.dev0 and peft 0.5.0.