diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 5ff8ca5885..ccf1fcc5b1 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -106,7 +106,9 @@ class CPOTrainer(Trainer): _tag_names = ["trl", "cpo"] - @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True) + @deprecate_kwarg( + "tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True + ) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 4880c231df..1856e1899a 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -188,7 +188,9 @@ class DPOTrainer(Trainer): _tag_names = ["trl", "dpo"] - @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.16.0", raise_if_both_names=True) + @deprecate_kwarg( + "tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True + ) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 629486674e..303b65e715 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -317,7 +317,9 @@ class KTOTrainer(Trainer): _tag_names = ["trl", "kto"] - @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True) + @deprecate_kwarg( + "tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True + ) def __init__( self, model: Union[PreTrainedModel, nn.Module, str] = None, diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index c615cd6349..7014ba6926 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -127,7 +127,9 @@ class OnlineDPOTrainer(Trainer): _tag_names = ["trl", "online-dpo"] - @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True) + @deprecate_kwarg( + "tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True + ) def __init__( self, model: Union[PreTrainedModel, nn.Module], diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 45fe7009ff..529baed769 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -117,7 +117,9 @@ class ORPOTrainer(Trainer): _tag_names = ["trl", "orpo"] - @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True) + @deprecate_kwarg( + "tokenizer", "0.15.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True + ) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index d9281d4e27..5dd7c53efd 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -81,7 +81,9 @@ def _tokenize(batch: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizerBase") class RewardTrainer(Trainer): _tag_names = ["trl", "reward-trainer"] - @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True) + @deprecate_kwarg( + "tokenizer", "0.15.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True + ) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module]] = None, diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 466874e91c..ea8d4cf713 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -72,7 +72,9 @@ class RLOOTrainer(Trainer): _tag_names = ["trl", "rloo"] - @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True) + @deprecate_kwarg( + "tokenizer", "0.14.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True + ) def __init__( self, config: RLOOConfig, diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 1e1a6781bc..56c4ad0991 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -106,7 +106,9 @@ class SFTTrainer(Trainer): _tag_names = ["trl", "sft"] - @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.16.0", raise_if_both_names=True) + @deprecate_kwarg( + "tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True + ) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,