From 460e78026526bb5e5351922f6fdc1fc58eb0053c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 10 Dec 2024 12:51:20 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=AF=20Standardize=20`model=5Fargs`=20(?= =?UTF-8?q?#2442)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * `model_config` -> `model_args` * sort --- docs/source/sft_trainer.mdx | 22 ++++++++--------- examples/scripts/cpo.py | 8 +++---- examples/scripts/dpo.py | 20 +++++++--------- examples/scripts/dpo_online.py | 24 +++++++++---------- examples/scripts/dpo_vlm.py | 26 ++++++++++----------- examples/scripts/gkd.py | 30 ++++++++++++------------ examples/scripts/nash_md.py | 22 ++++++++--------- examples/scripts/orpo.py | 8 +++---- examples/scripts/ppo/ppo.py | 26 ++++++++++----------- examples/scripts/ppo/ppo_tldr.py | 26 ++++++++++----------- examples/scripts/reward_modeling.py | 18 +++++++------- examples/scripts/rloo/rloo.py | 12 +++++----- examples/scripts/rloo/rloo_tldr.py | 12 +++++----- examples/scripts/sft.py | 18 +++++++------- examples/scripts/sft_video_llm.py | 14 +++++------ examples/scripts/sft_vlm.py | 18 +++++++------- examples/scripts/sft_vlm_smol_vlm.py | 18 +++++++------- examples/scripts/xpo.py | 22 ++++++++--------- tests/test_utils.py | 8 +++---- trl/trainer/utils.py | 35 +++++++++++++++------------- 20 files changed, 184 insertions(+), 203 deletions(-) diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx index 5b7827fe26..c45069d18c 100644 --- a/docs/source/sft_trainer.mdx +++ b/docs/source/sft_trainer.mdx @@ -468,30 +468,30 @@ We included a utility function to create your model. ```python from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config -model_config = ModelConfig( +model_args = ModelConfig( model_name_or_path="facebook/opt-350m" attn_implementation=None, # or "flash_attention_2" ) torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) ) -quantization_config = get_quantization_config(model_config) +quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - trust_remote_code=model_config.trust_remote_code, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) -model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs) +model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) trainer = SFTTrainer( ..., - model=model_config.model_name_or_path, - peft_config=get_peft_config(model_config), + model=model_args.model_name_or_path, + peft_config=get_peft_config(model_args), ) ``` diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py index 20ea85c925..1132e9b573 100644 --- a/examples/scripts/cpo.py +++ b/examples/scripts/cpo.py @@ -64,16 +64,16 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig)) - 
script_args, training_args, model_config = parser.parse_args_into_dataclasses() + script_args, training_args, model_args = parser.parse_args_into_dataclasses() ################ # Model & Tokenizer ################ model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token @@ -94,7 +94,7 @@ train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) # train and save the model diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index cbaba95a1a..08b0f18db7 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -66,37 +66,35 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() ################ # Model & Tokenizer ################### torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) - peft_config = get_peft_config(model_config) + peft_config = get_peft_config(model_args) if peft_config is None: ref_model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) else: ref_model = None tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index 4859056259..185343e611 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -65,18 +65,16 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig)) - script_args, training_args, 
model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, @@ -84,19 +82,19 @@ ) model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) if training_args.reward_model_path is not None: reward_model = AutoModelForSequenceClassification.from_pretrained( training_args.reward_model_path, num_labels=1, - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, **model_kwargs, ) reward_tokenizer = AutoTokenizer.from_pretrained( training_args.reward_model_path, - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, truncation=True, truncation_side="left", # since we judge the completion, truncating left is more appropriate ) @@ -111,9 +109,9 @@ judge = None tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, + model_args.model_name_or_path, padding_side="left", - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, **model_kwargs, ) if tokenizer.chat_template is None: @@ -132,7 +130,7 @@ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, reward_processing_class=reward_tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) if training_args.eval_strategy != "no": diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py index a58dfc4152..e093aa4d9d 100644 --- a/examples/scripts/dpo_vlm.py +++ b/examples/scripts/dpo_vlm.py @@ -45,42 +45,40 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() ################ # Model & Tokenizer ################ torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, 
device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) model = AutoModelForVision2Seq.from_pretrained( - model_config.model_name_or_path, - trust_remote_code=model_config.trust_remote_code, + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, **model_kwargs, ) - peft_config = get_peft_config(model_config) + peft_config = get_peft_config(model_args) if peft_config is None: ref_model = AutoModelForVision2Seq.from_pretrained( - model_config.model_name_or_path, - trust_remote_code=model_config.trust_remote_code, + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, **model_kwargs, ) else: ref_model = None processor = AutoProcessor.from_pretrained( - model_config.model_name_or_path, - trust_remote_code=model_config.trust_remote_code, + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, do_image_splitting=False, ) tokenizer = processor.tokenizer diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index bac694b9be..4408c2dfee 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -64,17 +64,17 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() ################ # Model & Tokenizer ################ - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - trust_remote_code=model_config.trust_remote_code, - attn_implementation=model_config.attn_implementation, - torch_dtype=model_config.torch_dtype, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, @@ -82,10 +82,10 @@ training_args.model_init_kwargs = model_kwargs teacher_model_kwargs = dict( - revision=model_config.model_revision, - trust_remote_code=model_config.trust_remote_code, - attn_implementation=model_config.attn_implementation, - torch_dtype=model_config.torch_dtype, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.torch_dtype, use_cache=True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, @@ -93,9 +93,9 @@ training_args.teacher_model_init_kwargs = teacher_model_kwargs tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, - revision=model_config.model_revision, - trust_remote_code=model_config.trust_remote_code, + model_args.model_name_or_path, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, padding_side="left", ) if tokenizer.pad_token is None: @@ -118,13 +118,13 @@ # Training ################ trainer = GKDTrainer( - model=model_config.model_name_or_path, + model=model_args.model_name_or_path, teacher_model=training_args.teacher_model_name_or_path, args=training_args, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if 
training_args.eval_strategy != "no" else None, processing_class=tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) if training_args.eval_strategy != "no": diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index bde51b32e3..71430bc536 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -70,18 +70,16 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, NashMDConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, @@ -89,17 +87,17 @@ ) model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) ref_model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) if training_args.reward_model_path is not None: reward_model = AutoModelForSequenceClassification.from_pretrained( training_args.reward_model_path, num_labels=1, - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, **model_kwargs, ) else: @@ -112,9 +110,9 @@ judge = None tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, + model_args.model_name_or_path, padding_side="left", - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index 2d0fefd494..82578f99be 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -64,16 +64,16 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_into_dataclasses() + script_args, training_args, model_args = parser.parse_args_into_dataclasses() ################ # Model & Tokenizer ################ model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) if tokenizer.pad_token 
is None: tokenizer.pad_token = tokenizer.eos_token @@ -94,7 +94,7 @@ train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) # train and save the model diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 05b5870dae..9de5135635 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -69,7 +69,7 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_into_dataclasses() + script_args, training_args, model_args = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -77,41 +77,39 @@ # Model & Tokenizer ################ torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, + model_args.model_name_or_path, padding_side="left", - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE value_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 ) reward_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 ) policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) - peft_config = get_peft_config(model_config) + peft_config = get_peft_config(model_args) if peft_config is None: ref_policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) else: ref_policy = None diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index fabacf5b98..d0cd399a89 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -76,7 +76,7 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_into_dataclasses() + 
script_args, training_args, model_args = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -84,41 +84,39 @@ # Model & Tokenizer ################ torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, + model_args.model_name_or_path, padding_side="left", - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE value_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 ) reward_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 ) policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) - peft_config = get_peft_config(model_config) + peft_config = get_peft_config(model_args) if peft_config is None: ref_policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) else: ref_policy = None diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index 3a2e311800..a3f299266b 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -65,30 +65,28 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, RewardConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_into_dataclasses() + script_args, training_args, model_args = parser.parse_args_into_dataclasses() training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) ################ # Model & Tokenizer ################ torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, + revision=model_args.model_revision, device_map=get_kbit_device_map() if 
quantization_config is not None else None, quantization_config=quantization_config, use_cache=False if training_args.gradient_checkpointing else True, torch_dtype=torch_dtype, ) tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True ) model = AutoModelForSequenceClassification.from_pretrained( - model_config.model_name_or_path, num_labels=1, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, num_labels=1, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) # Align padding tokens between tokenizer and model model.config.pad_token_id = tokenizer.pad_token_id @@ -97,7 +95,7 @@ if tokenizer.chat_template is None: model, tokenizer = setup_chat_format(model, tokenizer) - if model_config.use_peft and model_config.lora_task_type != "SEQ_CLS": + if model_args.use_peft and model_args.lora_task_type != "SEQ_CLS": warnings.warn( "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs" " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT.", @@ -118,7 +116,7 @@ args=training_args, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) trainer.train() diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index 6a56aac8c8..95eff811d4 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -63,7 +63,7 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, RLOOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_into_dataclasses() + script_args, training_args, model_args = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -71,21 +71,21 @@ # Model & Tokenizer ################ tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, + model_args.model_name_or_path, padding_side="left", - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE reward_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 ) ref_policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) ################ # Dataset diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index 8e89570963..6ac7a6e86c 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -65,7 +65,7 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, RLOOConfig, 
ModelConfig)) - script_args, training_args, model_config = parser.parse_args_into_dataclasses() + script_args, training_args, model_args = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -73,21 +73,21 @@ # Model & Tokenizer ################ tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, + model_args.model_name_or_path, padding_side="left", - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE reward_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 ) ref_policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) ################ # Dataset diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index 751be63771..4a73268977 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -66,24 +66,24 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() ################ # Model init kwargs & Tokenizer ################ - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - trust_remote_code=model_config.trust_remote_code, - attn_implementation=model_config.attn_implementation, - torch_dtype=model_config.torch_dtype, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) training_args.model_init_kwargs = model_kwargs tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True ) tokenizer.pad_token = tokenizer.eos_token @@ -96,12 +96,12 @@ # Training ################ trainer = SFTTrainer( - model=model_config.model_name_or_path, + model=model_args.model_name_or_path, args=training_args, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) trainer.train() diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py index 2a3f08812b..4a85114d4f 100644 --- a/examples/scripts/sft_video_llm.py +++ 
b/examples/scripts/sft_video_llm.py @@ -158,7 +158,7 @@ class CustomScriptArguments(ScriptArguments): if __name__ == "__main__": # Parse arguments parser = TrlParser((CustomScriptArguments, SFTConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() # Configure training args training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) @@ -170,9 +170,7 @@ class CustomScriptArguments(ScriptArguments): # Setup model torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) # Quantization configuration for 4-bit training @@ -185,14 +183,14 @@ class CustomScriptArguments(ScriptArguments): # Model initialization model_kwargs = dict( - revision=model_config.model_revision, - trust_remote_code=model_config.trust_remote_code, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, torch_dtype=torch_dtype, device_map=get_kbit_device_map(), quantization_config=bnb_config, ) - model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs) + model = AutoModelForVision2Seq.from_pretrained(model_args.model_name_or_path, **model_kwargs) peft_config = LoraConfig( task_type="CAUSAL_LM", @@ -210,7 +208,7 @@ class CustomScriptArguments(ScriptArguments): model.enable_input_require_grads() processor = AutoProcessor.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) # Prepare dataset diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py index ca17ec8a09..497bb69b66 100644 --- a/examples/scripts/sft_vlm.py +++ b/examples/scripts/sft_vlm.py @@ -53,7 +53,7 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) training_args.remove_unused_columns = False training_args.dataset_kwargs = {"skip_prepare_dataset": True} @@ -62,24 +62,22 @@ # Model, Tokenizer & Processor ################ torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) processor = AutoProcessor.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) model = AutoModelForVision2Seq.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, 
**model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) ################ @@ -121,7 +119,7 @@ def collate_fn(examples): train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=processor.tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) trainer.train() diff --git a/examples/scripts/sft_vlm_smol_vlm.py b/examples/scripts/sft_vlm_smol_vlm.py index eb08a8d7da..278a38621f 100644 --- a/examples/scripts/sft_vlm_smol_vlm.py +++ b/examples/scripts/sft_vlm_smol_vlm.py @@ -60,7 +60,7 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) training_args.remove_unused_columns = False training_args.dataset_kwargs = {"skip_prepare_dataset": True} @@ -69,24 +69,22 @@ # Model, Tokenizer & Processor ################ torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) processor = AutoProcessor.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) model = AutoModelForVision2Seq.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) ################ @@ -133,7 +131,7 @@ def collate_fn(examples): train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=processor.tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) trainer.train() diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index 2ddc532374..b30241cb02 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -55,18 +55,16 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, XPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + 
quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, @@ -74,17 +72,17 @@ ) model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) ref_model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) if training_args.reward_model_path is not None: reward_model = AutoModelForSequenceClassification.from_pretrained( training_args.reward_model_path, num_labels=1, - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, **model_kwargs, ) else: @@ -97,9 +95,9 @@ judge = None tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, + model_args.model_name_or_path, padding_side="left", - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token diff --git a/tests/test_utils.py b/tests/test_utils.py index 4488cdc309..f27240748f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -91,8 +91,8 @@ def test_pad_2_dim_right_multidim(self): class TestGetPEFTConfig(unittest.TestCase): def test_create_peft_config_use_peft_false(self): """Test that when use_peft is False, the function returns None.""" - model_config = ModelConfig(use_peft=False) - peft_config = get_peft_config(model_config) + model_args = ModelConfig(use_peft=False) + peft_config = get_peft_config(model_args) self.assertIsNone(peft_config) def test_create_peft_config_use_peft_true(self): @@ -107,8 +107,8 @@ def test_create_peft_config_use_peft_true(self): "lora_target_modules": ["up_proj", "down_proj"], "lora_modules_to_save": ["up_proj"], } - model_config = ModelConfig(use_peft=True, **peft_kwargs) - peft_config = get_peft_config(model_config) + model_args = ModelConfig(use_peft=True, **peft_kwargs) + peft_config = get_peft_config(model_args) self.assertTrue(isinstance(peft_config, LoraConfig)) for arg, value in peft_kwargs.items(): # Test that lists of modules are converted to sets diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 2473b9a8e2..881f3d3873 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -48,6 +48,7 @@ is_torch_npu_available, is_torch_xpu_available, ) +from transformers.utils.deprecation import deprecate_kwarg from ..import_utils import is_unsloth_available from ..trainer.model_config import ModelConfig @@ -871,16 +872,17 @@ def trl_sanitze_kwargs_for_tagging(model, tag_names, kwargs=None): return kwargs -def get_quantization_config(model_config: ModelConfig) -> Optional[BitsAndBytesConfig]: - if model_config.load_in_4bit: +@deprecate_kwarg("model_config", "0.14.0", "model_args", warn_if_greater_or_equal_version=True) +def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]: + if model_args.load_in_4bit: quantization_config = 
BitsAndBytesConfig( load_in_4bit=True, - bnb_4bit_compute_dtype=model_config.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype` - bnb_4bit_quant_type=model_config.bnb_4bit_quant_type, - bnb_4bit_use_double_quant=model_config.use_bnb_nested_quant, - bnb_4bit_quant_storage=model_config.torch_dtype, + bnb_4bit_compute_dtype=model_args.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype` + bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, + bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, + bnb_4bit_quant_storage=model_args.torch_dtype, ) - elif model_config.load_in_8bit: + elif model_args.load_in_8bit: quantization_config = BitsAndBytesConfig( load_in_8bit=True, ) @@ -899,8 +901,9 @@ def get_kbit_device_map() -> Optional[dict[str, int]]: return None -def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]": - if model_config.use_peft is False: +@deprecate_kwarg("model_config", "0.14.0", "model_args", warn_if_greater_or_equal_version=True) +def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]": + if model_args.use_peft is False: return None if not is_peft_available(): @@ -910,14 +913,14 @@ def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]": ) peft_config = LoraConfig( - task_type=model_config.lora_task_type, - r=model_config.lora_r, - target_modules=model_config.lora_target_modules, - lora_alpha=model_config.lora_alpha, - lora_dropout=model_config.lora_dropout, + task_type=model_args.lora_task_type, + r=model_args.lora_r, + target_modules=model_args.lora_target_modules, + lora_alpha=model_args.lora_alpha, + lora_dropout=model_args.lora_dropout, bias="none", - use_rslora=model_config.use_rslora, - modules_to_save=model_config.lora_modules_to_save, + use_rslora=model_args.use_rslora, + modules_to_save=model_args.lora_modules_to_save, ) return peft_config
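
After this patch the helpers behave exactly as before; only the parameter name changes, and the old `model_config` keyword remains accepted via `deprecate_kwarg` until v0.14.0. Below is a minimal usage sketch (illustrative only, not part of the diff): it reuses the `ModelConfig` fields that appear in the hunks above and assumes `peft` is installed, since `get_peft_config` raises otherwise.

```python
from trl import ModelConfig, get_peft_config, get_quantization_config

model_args = ModelConfig(
    model_name_or_path="facebook/opt-350m",
    use_peft=True,
    lora_r=16,
    lora_alpha=16,
)

# Returns None here because neither load_in_4bit nor load_in_8bit is set
quantization_config = get_quantization_config(model_args)

# Builds a peft LoraConfig from the lora_* fields of ModelConfig
peft_config = get_peft_config(model_args)

# The old keyword still works but emits a deprecation warning
# (scheduled for removal in v0.14.0)
peft_config = get_peft_config(model_config=model_args)
```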