
Commit f15399d

Fix entropy and accuracy calculation for prompt_tuning techniques. (#4196)
1 parent cc578b6 commit f15399d
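
This commit fixes how the SFT trainer handles PEFT prompt-style adapters when computing the logged entropy and token accuracy. With prompt tuning and p-tuning, the wrapped model's logits include positions for the prepended virtual tokens; prefix tuning injects its virtual tokens through the key/value cache, so its logits cover only the real input. Previously the trainer padded the attention mask with ones for the virtual positions before averaging entropy, and sliced the logits before the accuracy computation regardless of adapter type. It now trims the virtual positions from the per-token entropies and applies the accuracy slice only when the active adapter is not prefix tuning. Below is a minimal sketch of the underlying shape mismatch, assuming PEFT prompt tuning around the tiny test model used in the diff; it is an illustration, not code from the commit.

import torch
from peft import PromptTuningConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

num_virtual_tokens = 4
model = get_peft_model(
    model,
    PromptTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=num_virtual_tokens,
        tokenizer_name_or_path=model_id,
    ),
)

inputs = tokenizer("The sky is", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Prompt tuning prepends virtual tokens, so the logits carry extra leading
# positions that have no matching entry in inputs["attention_mask"] or labels.
seq_len = inputs["input_ids"].shape[1]
assert logits.shape[1] == seq_len + num_virtual_tokens

# Dropping those positions realigns the logits with the real tokens; prefix
# tuning never emits them, hence the PeftType.PREFIX_TUNING check in the diff.
logits = logits[:, num_virtual_tokens:, :]
assert logits.shape[1] == seq_len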

File tree

2 files changed: +84, −11 lines


tests/test_sft_trainer.py

Lines changed: 70 additions & 2 deletions
@@ -32,7 +32,15 @@


 if is_peft_available():
-    from peft import LoraConfig, PeftModel, PromptEncoderConfig, TaskType, get_peft_model
+    from peft import (
+        LoraConfig,
+        PeftModel,
+        PrefixTuningConfig,
+        PromptEncoderConfig,
+        PromptTuningConfig,
+        TaskType,
+        get_peft_model,
+    )


 class TestDFTLoss(TrlTestCase):
@@ -453,7 +461,7 @@ def test_train_model_dtype(self):
             assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"

     @require_peft
-    def test_train_dense_with_peft_config(self):
+    def test_train_dense_with_peft_config_lora(self):
         # Get the base model parameter names
         model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
         model = AutoModelForCausalLM.from_pretrained(model_id)
@@ -489,6 +497,66 @@ def test_train_dense_with_peft_config(self):
             elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
                 assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"

+    @parameterized.expand(
+        [
+            ("prompt_tuning",),
+            ("prefix_tuning",),
+            ("prompt_encoder",),
+        ]
+    )
+    @require_peft
+    def test_train_with_peft_config_prompt_tuning(self, peft_type):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        base_param_names = [f"base_model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        # Initialize the trainer, p-tuning doesn't support gradient checkpointing
+        training_args = SFTConfig(bf16=False, output_dir=self.tmp_dir, report_to="none", gradient_checkpointing=False)
+        if peft_type == "prompt_tuning":
+            peft_config = PromptTuningConfig(
+                task_type=TaskType.CAUSAL_LM,
+                num_virtual_tokens=4,
+                tokenizer_name_or_path="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            )
+        elif peft_type == "prefix_tuning":
+            peft_config = PrefixTuningConfig(
+                task_type=TaskType.CAUSAL_LM,
+                num_virtual_tokens=4,
+            )
+        elif peft_type == "prompt_encoder":
+            peft_config = PromptEncoderConfig(
+                task_type=TaskType.CAUSAL_LM,
+                num_virtual_tokens=4,
+                encoder_hidden_size=model.config.hidden_size,  # This will be overwritten below
+            )
+        trainer = SFTTrainer(
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=peft_config,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                assert torch.allclose(param, new_param), f"Parameter {n} has changed"
+            else:  # We expect the peft parameters to be different
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
     @require_peft
     def test_train_moe_with_peft_config(self):
         # Get the base model parameter names

trl/trainer/sft_trainer.py

Lines changed: 14 additions & 9 deletions
@@ -61,7 +61,7 @@


 if is_peft_available():
-    from peft import PeftConfig, PeftModel
+    from peft import PeftConfig, PeftModel, PeftType


 logger = logging.get_logger(__name__)
@@ -1090,13 +1090,15 @@ def compute_loss(
         if not self.args.use_liger_kernel:  # liger doesn't return logits
             with torch.no_grad():
                 per_token_entropy = entropy_from_logits(outputs.logits)
+                # When using Prompt Tuning, skip the virtual tokens in logits before entropy computation, since they
+                # do not correspond to actual input tokens.
+                if (
+                    self.num_virtual_tokens > 0
+                    and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING
+                ):
+                    per_token_entropy = per_token_entropy[:, self.num_virtual_tokens :]
                 if "attention_mask" in inputs:
                     attention_mask = inputs["attention_mask"]
-                    # When using Prompt Tuning, we need to add attention for the virtual tokens (all set to 1).
-                    virtual_attention_mask = torch.ones(
-                        attention_mask.size(0), self.num_virtual_tokens, device=attention_mask.device
-                    )
-                    attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1)
                     entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum()
                 elif "position_ids" in inputs:
                     entropy = torch.mean(per_token_entropy)
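
The removed lines above padded the attention mask with ones so it matched logits that still contained the virtual-token positions; the new branch instead drops those positions from the per-token entropies, so the mask only ever covers real tokens, and prefix tuning (whose logits never include them) is skipped via the PeftType check. A standalone sketch of the realigned masked average, with made-up shapes and random logits rather than the trainer's internals:

import torch
import torch.nn.functional as F

num_virtual_tokens, seq_len, vocab = 4, 6, 32
logits = torch.randn(2, num_virtual_tokens + seq_len, vocab)  # prompt-tuning-style output
attention_mask = torch.ones(2, seq_len, dtype=torch.long)     # covers real tokens only

# Per-token entropy H = -sum_v p_v * log p_v, computed from the logits.
logp = F.log_softmax(logits, dim=-1)
per_token_entropy = -(logp.exp() * logp).sum(dim=-1)

# Drop the virtual-token positions so entropies and mask line up by length,
# instead of padding the mask with ones as the removed code did.
per_token_entropy = per_token_entropy[:, num_virtual_tokens:]
entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum()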
@@ -1131,9 +1133,12 @@ def compute_loss(
                 shift_logits = outputs.logits[..., :-1, :].contiguous()
                 shift_labels = labels[..., 1:].contiguous()

-                # When using Prompt Tuning, skip the virtual tokens in logits before accuracy computation, since they do
-                # not correspond to actual input labels.
-                shift_logits = shift_logits[:, self.num_virtual_tokens :, :]
+                # Prompt Tuning and P-Tuning output logits for virtual tokens but Prefix-Tuning does not.
+                if (
+                    self.num_virtual_tokens > 0
+                    and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING
+                ):
+                    shift_logits = shift_logits[:, self.num_virtual_tokens :, :]

                 # Get predictions
                 predictions = shift_logits.argmax(dim=-1)
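
Here shift_labels is built from the real labels and never contains virtual positions, while shift_logits does under prompt tuning and p-tuning, so the slice is now guarded by the same PeftType check rather than applied unconditionally (which would misalign predictions for prefix tuning). A toy shape check with assumed dimensions, not trainer code:

import torch

num_virtual_tokens, batch, seq_len, vocab = 4, 2, 6, 32
logits = torch.randn(batch, num_virtual_tokens + seq_len, vocab)  # prompt-tuning-style output
labels = torch.randint(0, vocab, (batch, seq_len))                # no virtual positions

shift_logits = logits[..., :-1, :]  # position t predicts token t+1
shift_labels = labels[..., 1:]

# Slice off the virtual positions only when the logits actually carry them.
shift_logits = shift_logits[:, num_virtual_tokens:, :]
predictions = shift_logits.argmax(dim=-1)
assert predictions.shape == shift_labels.shape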
