|
16 | 16 | from transformers import AutoTokenizer |
17 | 17 |
|
18 | 18 | from trl.experimental.bco import BCOConfig, BCOTrainer |
| 19 | +from trl.experimental.orpo import ORPOConfig, ORPOTrainer |
19 | 20 |
|
20 | 21 | from ..testing_utils import TrlTestCase, require_sklearn |
21 | 22 |
|
@@ -68,3 +69,30 @@ def test_bco(self): |
68 | 69 | assert trainer.args.prompt_sample_size == 512 |
69 | 70 | assert trainer.args.min_density_ratio == 0.2 |
70 | 71 | assert trainer.args.max_density_ratio == 20.0 |
| 72 | + |
def test_orpo(self):
    """Check that values given to ``ORPOConfig`` are readable from ``ORPOTrainer.args``.

    The config is built with explicit values, passed to the trainer, and each
    value that was set is then asserted on ``trainer.args`` so that none of the
    configured arguments goes unchecked.
    """
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
    training_args = ORPOConfig(
        self.tmp_dir,
        max_length=256,
        max_prompt_length=64,
        max_completion_length=64,
        beta=0.5,
        disable_dropout=False,
        label_pad_token_id=-99,
        padding_value=-99,
        truncation_mode="keep_start",
        # generate_during_eval=True,  # ignore this one, it requires wandb
        is_encoder_decoder=True,
        model_init_kwargs={"trust_remote_code": True},
        dataset_num_proc=4,
    )
    trainer = ORPOTrainer(model=model_id, args=training_args, train_dataset=dataset, processing_class=tokenizer)

    assert trainer.args.max_length == 256
    assert trainer.args.max_prompt_length == 64
    assert trainer.args.max_completion_length == 64
    assert trainer.args.beta == 0.5
    assert not trainer.args.disable_dropout
    assert trainer.args.label_pad_token_id == -99
    # The remaining configured values were previously set but never verified.
    assert trainer.args.padding_value == -99
    assert trainer.args.truncation_mode == "keep_start"
    assert trainer.args.is_encoder_decoder
    # Check the key rather than whole-dict equality, in case the trainer adds
    # entries to this dict while loading the model from its id.
    assert trainer.args.model_init_kwargs["trust_remote_code"] is True
    assert trainer.args.dataset_num_proc == 4
0 commit comments