diff --git a/orpo_training.py b/orpo_training.py index 4a7055c..49dcbd3 100644 --- a/orpo_training.py +++ b/orpo_training.py @@ -45,6 +45,11 @@ class ScriptArguments: The name of the Casual LM model we wish to fine with DPO """ # Model arguments + + max_length: Optional[int] = field(default=512, + metadata={"help": "Maximum total input sequence length after tokenization."}) + max_prompt_length: Optional[int] = field(default=128, metadata={"help": "Maximum length of prompt sequences."}) + model_type: str = field( default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())} @@ -415,6 +420,8 @@ def main(): model.config.use_cache = True training_args = ORPOConfig( + max_length=args.max_length, + max_prompt_length=args.max_prompt_length, per_device_train_batch_size=args.per_device_train_batch_size, per_device_eval_batch_size=args.per_device_eval_batch_size, max_steps=args.max_steps,