32 | 32 |
33 | 33 |
34 | 34 | if is_peft_available(): |
35 | | - from peft import LoraConfig, PeftModel, PromptEncoderConfig, TaskType, get_peft_model |
| 35 | + from peft import ( |
| 36 | + LoraConfig, |
| 37 | + PeftModel, |
| 38 | + PrefixTuningConfig, |
| 39 | + PromptEncoderConfig, |
| 40 | + PromptTuningConfig, |
| 41 | + TaskType, |
| 42 | + get_peft_model, |
| 43 | + ) |
36 | 44 |
37 | 45 |
38 | 46 | class TestDFTLoss(TrlTestCase): |
@@ -453,7 +461,7 @@ def test_train_model_dtype(self): |
453 | 461 | assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" |
454 | 462 |
455 | 463 | @require_peft |
456 | | - def test_train_dense_with_peft_config(self): |
| 464 | + def test_train_dense_with_peft_config_lora(self): |
457 | 465 | # Get the base model parameter names |
458 | 466 | model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" |
459 | 467 | model = AutoModelForCausalLM.from_pretrained(model_id) |
@@ -489,6 +497,66 @@ def test_train_dense_with_peft_config(self): |
489 | 497 | elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) |
490 | 498 | assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" |
491 | 499 |
| 500 | + @parameterized.expand( |
| 501 | + [ |
| 502 | + ("prompt_tuning",), |
| 503 | + ("prefix_tuning",), |
| 504 | + ("prompt_encoder",), |
| 505 | + ] |
| 506 | + ) |
| 507 | + @require_peft |
| 508 | + def test_train_with_peft_config_prompt_tuning(self, peft_type): |
| 509 | + # Get the base model parameter names |
| 510 | + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" |
| 511 | + model = AutoModelForCausalLM.from_pretrained(model_id) |
| 512 | + base_param_names = [f"base_model.{n}" for n, _ in model.named_parameters()] |
| 513 | + |
| 514 | + # Get the dataset |
| 515 | + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") |
| 516 | + |
| 517 | + # Initialize the trainer; p-tuning doesn't support gradient checkpointing |
| 518 | + training_args = SFTConfig(bf16=False, output_dir=self.tmp_dir, report_to="none", gradient_checkpointing=False) |
| 519 | + if peft_type == "prompt_tuning": |
| 520 | + peft_config = PromptTuningConfig( |
| 521 | + task_type=TaskType.CAUSAL_LM, |
| 522 | + num_virtual_tokens=4, |
| 523 | + tokenizer_name_or_path="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", |
| 524 | + ) |
| 525 | + elif peft_type == "prefix_tuning": |
| 526 | + peft_config = PrefixTuningConfig( |
| 527 | + task_type=TaskType.CAUSAL_LM, |
| 528 | + num_virtual_tokens=4, |
| 529 | + ) |
| 530 | + elif peft_type == "prompt_encoder": |
| 531 | + peft_config = PromptEncoderConfig( |
| 532 | + task_type=TaskType.CAUSAL_LM, |
| 533 | + num_virtual_tokens=4, |
| 534 | + encoder_hidden_size=model.config.hidden_size, # This will be overwritten below |
| 535 | + ) |
| 536 | + trainer = SFTTrainer( |
| 537 | + model=model_id, |
| 538 | + args=training_args, |
| 539 | + train_dataset=dataset, |
| 540 | + peft_config=peft_config, |
| 541 | + ) |
| 542 | + |
| 543 | + # Save the initial parameters to compare them later |
| 544 | + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} |
| 545 | + |
| 546 | + # Train the model |
| 547 | + trainer.train() |
| 548 | + |
| 549 | + # Check that the training loss is not None |
| 550 | + assert trainer.state.log_history[-1]["train_loss"] is not None |
| 551 | + |
| 552 | + # Check the peft params have changed and the base model params have not changed |
| 553 | + for n, param in previous_trainable_params.items(): |
| 554 | + new_param = trainer.model.get_parameter(n) |
| 555 | + if n in base_param_names: # We expect the base model parameters to be the same |
| 556 | + assert torch.allclose(param, new_param), f"Parameter {n} has changed" |
| 557 | + else: # We expect the peft parameters to be different |
| 558 | + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" |
| 559 | + |
492 | 560 | @require_peft |
493 | 561 | def test_train_moe_with_peft_config(self): |
494 | 562 | # Get the base model parameter names |