File tree Expand file tree Collapse file tree 1 file changed +3
-7
lines changed Expand file tree Collapse file tree 1 file changed +3
-7
lines changed Original file line number Diff line number Diff line change 4444
4545"""
4646python examples/scripts/ppo/ppo_tldr.py \
47- --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
47+ --dataset_name trl-lib/tldr \
4848 --dataset_test_split validation \
4949 --learning_rate 3e-6 \
5050 --output_dir pythia-1b-deduped-tldr-preference-sft-trl-style-ppo \
6262
6363accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
6464 examples/scripts/ppo/ppo_tldr.py \
65- --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
65+ --dataset_name trl-lib/tldr \
6666 --dataset_test_split validation \
6767 --output_dir pythia-1b-deduped-tldr-preference-sft-trl-style-ppo \
6868 --learning_rate 3e-6 \
@@ -134,11 +134,7 @@ def prepare_dataset(dataset, tokenizer):
134134 """pre-tokenize the dataset before training; only collate during training"""
135135
136136 def tokenize (element ):
137- input_ids = tokenizer .apply_chat_template (
138- element ["messages" ][:1 ],
139- padding = False ,
140- add_generation_prompt = True ,
141- )
137+ input_ids = tokenizer (element ["prompt" ], padding = False )["input_ids" ]
142138 return {"input_ids" : input_ids , "lengths" : len (input_ids )}
143139
144140 return dataset .map (
You can’t perform that action at this time.
0 commit comments