Wrong invocation of PeftModelForCausalLM.generate in DPOTrainer? #877

Closed
jploski opened this issue Oct 16, 2023 · 3 comments · Fixed by #941

Comments

jploski commented Oct 16, 2023

Reporting against trl 0.7.3.dev0 and peft 0.5.0:

Apparently, PeftModelForCausalLM.generate does not accept the inputs as a positional argument, which causes the following traceback:

    dpo_trainer.train()
  File "/mnt/seagate/miniconda3/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
    return inner_training_loop(
  File "/mnt/seagate/miniconda3/lib/python3.10/site-packages/transformers/trainer.py", line 1984, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/mnt/seagate/miniconda3/lib/python3.10/site-packages/transformers/trainer.py", line 2328, in _maybe_log_save_evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/mnt/seagate/miniconda3/lib/python3.10/site-packages/transformers/trainer.py", line 3066, in evaluate
    output = eval_loop(
  File "/mnt/seagate/miniconda3/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 617, in evaluation_loop
    policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch)
  File "/mnt/seagate/miniconda3/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 514, in get_batch_samples
    policy_output = model.generate(
TypeError: PeftModelForCausalLM.generate() takes 1 positional argument but 2 were given

If I change the generate calls in dpo_trainer.py to pass the inputs as a keyword argument instead (i.e. inputs=batch["prompt_input_ids"]), the crash goes away.
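A minimal sketch of the failure mode, using hypothetical stand-in classes (FakeBaseModel, FakePeftWrapper) instead of the real peft/transformers objects. It assumes, as the error message implies, that the wrapper's generate() accepts no positional arguments besides self (e.g. a signature like generate(self, **kwargs)), while the underlying model's generate() does accept the inputs positionally:

```python
class FakeBaseModel:
    """Hypothetical stand-in for a base model whose generate() accepts inputs positionally."""

    def generate(self, inputs=None, **kwargs):
        # Echo the arguments back so the call can be inspected.
        return {"inputs": inputs, **kwargs}


class FakePeftWrapper:
    """Hypothetical stand-in for a wrapper whose generate() only accepts keyword arguments."""

    def __init__(self, base_model):
        self.base_model = base_model

    def generate(self, **kwargs):
        # Only keyword arguments are accepted; they are forwarded unchanged.
        return self.base_model.generate(**kwargs)


model = FakePeftWrapper(FakeBaseModel())
prompt_ids = [[101, 2023, 2003]]  # placeholder token IDs

try:
    model.generate(prompt_ids)  # positional call: raises TypeError, as in the traceback
    positional_error = None
except TypeError as exc:
    positional_error = str(exc)

result = model.generate(inputs=prompt_ids, max_length=16)  # keyword call: works
print(positional_error)
print(result)
```

Passing the inputs by keyword works because the wrapper simply forwards **kwargs to the wrapped model, so the name the base model expects (inputs here) reaches it intact.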

NitCoh commented Oct 18, 2023

Yep, I'm having the same issue.
I believe it only happens with PeftModelForCausalLM.

@ZixuanLiu4869

I ran into the same problem when calling generate myself. Passing the input IDs explicitly as a keyword argument, like model.generate(input_ids=[your input_ids], ...), solves it.
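Along the same lines, here is a sketch of the keyword-only call pattern suggested above, wrapped in a hypothetical helper generate_from_batch. KeywordOnlyModel is a stand-in for the PEFT-wrapped model that just echoes its kwargs back, and the "prompt_attention_mask" batch key is an assumption (only "prompt_input_ids" appears in this thread):

```python
from typing import Any, Dict


class KeywordOnlyModel:
    """Hypothetical stand-in for a model whose generate() rejects positional inputs."""

    def generate(self, **kwargs: Any) -> Dict[str, Any]:
        # Echo the keyword arguments so the call can be inspected.
        return kwargs


def generate_from_batch(model: Any, batch: Dict[str, Any], **gen_kwargs: Any) -> Any:
    """Hypothetical helper: pass the prompt tensors to generate() strictly by keyword."""
    return model.generate(
        input_ids=batch["prompt_input_ids"],
        attention_mask=batch["prompt_attention_mask"],  # assumed batch key
        **gen_kwargs,
    )


batch = {
    "prompt_input_ids": [[1, 2, 3]],  # placeholder token IDs
    "prompt_attention_mask": [[1, 1, 1]],
}
out = generate_from_batch(KeywordOnlyModel(), batch, max_length=8, do_sample=False)
print(out)
```

Because every argument is named, the same call shape works whether or not generate() accepts positional inputs.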

@younesbelkada
Contributor

Hi there! Makes sense; I just opened #941 to fix this.

4 participants