Skip to content

Commit cf431db

Browse files
authored
Fix PPO example (#4556)
1 parent cac9f1d commit cf431db

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

examples/scripts/ppo/ppo_tldr.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
"""
4646
python examples/scripts/ppo/ppo_tldr.py \
47-
--dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
47+
--dataset_name trl-lib/tldr \
4848
--dataset_test_split validation \
4949
--learning_rate 3e-6 \
5050
--output_dir pythia-1b-deduped-tldr-preference-sft-trl-style-ppo \
@@ -62,7 +62,7 @@
6262
6363
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
6464
examples/scripts/ppo/ppo_tldr.py \
65-
--dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
65+
--dataset_name trl-lib/tldr \
6666
--dataset_test_split validation \
6767
--output_dir pythia-1b-deduped-tldr-preference-sft-trl-style-ppo \
6868
--learning_rate 3e-6 \
@@ -134,11 +134,7 @@ def prepare_dataset(dataset, tokenizer):
134134
"""pre-tokenize the dataset before training; only collate during training"""
135135

136136
def tokenize(element):
137-
input_ids = tokenizer.apply_chat_template(
138-
element["messages"][:1],
139-
padding=False,
140-
add_generation_prompt=True,
141-
)
137+
input_ids = tokenizer(element["prompt"], padding=False)["input_ids"]
142138
return {"input_ids": input_ids, "lengths": len(input_ids)}
143139

144140
return dataset.map(

0 commit comments

Comments (0)