Add positional args to PeftModelForCausalLM.generate (#1393)
* add positional args

* update tests
SumanthRH authored Jan 30, 2024
1 parent 1a7f3e3 commit 9d94367
Showing 7 changed files with 34 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/peft/peft_model.py
@@ -1130,14 +1130,14 @@ def forward(
             inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
             return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
 
-    def generate(self, **kwargs):
+    def generate(self, *args, **kwargs):
         self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
         if hasattr(self.base_model, "model"):
             self.base_model.model.generation_config = self.generation_config
         else:
             self.base_model.generation_config = self.generation_config
         try:
-            outputs = self.base_model.generate(**kwargs)
+            outputs = self.base_model.generate(*args, **kwargs)
         except:
             self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
             raise

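In practice, the change above lets `PeftModelForCausalLM.generate` accept the input tensor positionally, matching the plain `transformers` calling convention. A minimal usage sketch, assuming a LoRA adapter; the "gpt2" checkpoint and the prompt are illustrative choices, not part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

tokenizer = AutoTokenizer.from_pretrained("gpt2")
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
# get_peft_model returns a PeftModelForCausalLM for task_type="CAUSAL_LM"
peft_model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))

inputs = tokenizer("Hello, PEFT", return_tensors="pt")

# Keyword-only form, which already worked before this commit:
out_kw = peft_model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

# Positional input_ids, enabled by forwarding *args to base_model.generate:
out_pos = peft_model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"])

print(tokenizer.decode(out_pos[0], skip_special_tokens=True))
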
5 changes: 2 additions & 3 deletions tests/test_adaption_prompt.py
@@ -267,9 +267,8 @@ def test_generate(self) -> None:
         # check if `generate` works
         _ = model.generate(input_ids=input_ids, attention_mask=attention_mask)
 
-        with self.assertRaises(TypeError):
-            # check if `generate` raises an error if no positional arguments are passed
-            _ = model.generate(input_ids, attention_mask=attention_mask)
+        # check if `generate` works if positional arguments are passed
+        _ = model.generate(input_ids, attention_mask=attention_mask)
 
     def test_sequence_adapter_ops(self) -> None:
         """Test sequence of adapter operations."""

5 changes: 5 additions & 0 deletions tests/test_decoder_models.py
@@ -195,6 +195,11 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs):
     def test_generate(self, test_name, model_id, config_cls, config_kwargs):
         self._test_generate(model_id, config_cls, config_kwargs)
 
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+    def test_generate_pos_args(self, test_name, model_id, config_cls, config_kwargs):
+        # positional args are supported for PeftModelForCausalLM
+        self._test_generate_pos_args(model_id, config_cls, config_kwargs, raises_err=False)
+
     @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
     def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):
         self._test_merge_layers_fp16(model_id, config_cls, config_kwargs)

6 changes: 6 additions & 0 deletions tests/test_encoder_decoder_models.py
@@ -104,6 +104,12 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
     def test_generate(self, test_name, model_id, config_cls, config_kwargs):
         self._test_generate(model_id, config_cls, config_kwargs)
 
+    # skip non lora models - generate does not work for prefix tuning, prompt tuning
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+    def test_generate_pos_args(self, test_name, model_id, config_cls, config_kwargs):
+        # positional arguments are not supported for PeftModelForSeq2SeqLM
+        self._test_generate_pos_args(model_id, config_cls, config_kwargs, raises_err=True)
+
     @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
     def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs):
         self._test_generate_half_prec(model_id, config_cls, config_kwargs)

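For contrast with the causal-LM case, the seq2seq wrapper is untouched by this commit: `PeftModelForSeq2SeqLM.generate` still accepts keyword arguments only, which is what the `raises_err=True` test above asserts. A hedged sketch of that behavior; the "t5-small" checkpoint and the prompt are illustrative, not part of the commit:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

tokenizer = AutoTokenizer.from_pretrained("t5-small")
base_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
peft_model = get_peft_model(base_model, LoraConfig(task_type="SEQ_2_SEQ_LM"))

inputs = tokenizer("translate English to German: Hello", return_tensors="pt")

# Keyword arguments work as before:
_ = peft_model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

# A positional input_ids is still rejected, since the seq2seq generate
# signature remains generate(self, **kwargs):
try:
    _ = peft_model.generate(inputs["input_ids"])
except TypeError:
    print("PeftModelForSeq2SeqLM.generate requires keyword arguments")
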
5 changes: 2 additions & 3 deletions tests/test_multitask_prompt_tuning.py
@@ -214,9 +214,8 @@ def test_generate(self) -> None:
         # check if `generate` works
         _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)
 
-        with self.assertRaises(TypeError):
-            # check if `generate` raises an error if no positional arguments are passed
-            _ = model.generate(input_ids, attention_mask=attention_mask)
+        # check if `generate` works if positional arguments are passed
+        _ = model.generate(input_ids, attention_mask=attention_mask, task_ids=task_ids)
 
     def test_use_cache(self) -> None:
         """Test that MultiTaskPromptTuning works when Llama config use_cache=True."""

2 changes: 1 addition & 1 deletion tests/test_tuners_utils.py
@@ -24,11 +24,11 @@
 
 from peft import IA3Config, LoHaConfig, LoraConfig, get_peft_model
 from peft.tuners.tuners_utils import (
-    INCLUDE_LINEAR_LAYERS_SHORTHAND,
     _maybe_include_all_linear_layers,
     check_target_module_exists,
     inspect_matched_modules,
 )
+from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
 
 from .testing_utils import require_bitsandbytes, require_torch_gpu
 

22 changes: 16 additions & 6 deletions tests/testing_common.py
@@ -650,8 +650,22 @@ def _test_generate(self, model_id, config_cls, config_kwargs):
         # check if `generate` works
         _ = model.generate(**inputs)
 
-        with self.assertRaises(TypeError):
-            # check if `generate` raises an error if no positional arguments are passed
+    def _test_generate_pos_args(self, model_id, config_cls, config_kwargs, raises_err: bool):
+        model = self.transformers_class.from_pretrained(model_id)
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        model = get_peft_model(model, config)
+        model = model.to(self.torch_device)
+
+        inputs = self.prepare_inputs_for_testing()
+        if raises_err:
+            with self.assertRaises(TypeError):
+                # check if `generate` raises an error if positional arguments are passed
+                _ = model.generate(inputs["input_ids"])
+        else:
+            # check if `generate` works if positional arguments are passed
             _ = model.generate(inputs["input_ids"])
 
     def _test_generate_half_prec(self, model_id, config_cls, config_kwargs):
@@ -672,10 +686,6 @@ def _test_generate_half_prec(self, model_id, config_cls, config_kwargs):
         # check if `generate` works
         _ = model.generate(input_ids=input_ids, attention_mask=attention_mask)
 
-        with self.assertRaises(TypeError):
-            # check if `generate` raises an error if no positional arguments are passed
-            _ = model.generate(input_ids, attention_mask=attention_mask)
-
     def _test_prefix_tuning_half_prec_conversion(self, model_id, config_cls, config_kwargs):
         if config_cls not in (PrefixTuningConfig,):
             return

