7 changes: 4 additions & 3 deletions trl/trainer/base_trainer.py
@@ -17,7 +17,7 @@
 
 from transformers import Trainer, is_wandb_available
 
-from .utils import generate_model_card, get_comet_experiment_url
+from .utils import generate_model_card, get_comet_experiment_url, get_config_model_id
 
 
 if is_wandb_available():
@@ -50,8 +50,9 @@ def create_model_card(
         if not self.is_world_process_zero():
             return
 
-        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-            base_model = self.model.config._name_or_path
+        model_name_or_path = get_config_model_id(self.model.config)
+        if model_name_or_path and not os.path.isdir(model_name_or_path):
+            base_model = model_name_or_path
         else:
             base_model = None
 
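For context, a minimal sketch (not the trainer code itself) of how the model-card base model is resolved after this change. It assumes a TRL build that already includes the `get_config_model_id` helper added in `trl/trainer/utils.py` below; the function name `resolve_base_model` is purely illustrative.

# Minimal sketch of the new base-model resolution in `create_model_card`.
# Assumes a TRL build containing this PR; `resolve_base_model` is a hypothetical name.
import os

from trl.trainer.utils import get_config_model_id

def resolve_base_model(config):
    model_name_or_path = get_config_model_id(config)  # returns "" when no id is recorded
    # Treat the id as a Hub model id only when it is non-empty and not a local directory;
    # this replaces the previous hasattr(config, "_name_or_path") guard.
    if model_name_or_path and not os.path.isdir(model_name_or_path):
        return model_name_or_path
    return None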
4 changes: 2 additions & 2 deletions trl/trainer/callbacks.py
@@ -40,7 +40,7 @@
 from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf
 from ..models.utils import unwrap_model_for_generation
 from .judges import BasePairwiseJudge
-from .utils import log_table_to_comet_experiment
+from .utils import get_config_model_id, log_table_to_comet_experiment
 
 
 if is_rich_available():
@@ -821,7 +821,7 @@ def _merge_and_maybe_push(self, output_dir, global_step, model):
         checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}")
         self.merge_config.policy_model_path = checkpoint_path
         if self.merge_config.target_model_path is None:
-            self.merge_config.target_model_path = model.config._name_or_path
+            self.merge_config.target_model_path = get_config_model_id(model.config)
         merge_path = os.path.join(checkpoint_path, "merged")
 
         merge_models(self.merge_config.create(), merge_path)
5 changes: 3 additions & 2 deletions trl/trainer/dpo_trainer.py
@@ -65,6 +65,7 @@
     empty_cache,
     flush_left,
     flush_right,
+    get_config_model_id,
     log_table_to_comet_experiment,
     pad,
     pad_to_length,
@@ -286,7 +287,7 @@ def __init__(
     ):
         # Args
        if args is None:
-            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model if isinstance(model, str) else get_config_model_id(model.config)
             model_name = model_name.split("/")[-1]
             args = DPOConfig(f"{model_name}-DPO")
 
@@ -299,7 +300,7 @@
                     "You passed `model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. "
                     "The `model_init_kwargs` will be ignored."
                 )
-            model_id = model.config._name_or_path
+            model_id = get_config_model_id(model.config)
         if isinstance(ref_model, str):
             ref_model = create_model_from_path(ref_model, **args.ref_model_init_kwargs or {})
         else:
11 changes: 6 additions & 5 deletions trl/trainer/grpo_trainer.py
@@ -65,6 +65,7 @@
     disable_dropout_in_model,
     ensure_master_addr_port,
     entropy_from_logits,
+    get_config_model_id,
     identity,
     nanmax,
     nanmin,
@@ -233,7 +234,7 @@ def __init__(
     ):
         # Args
         if args is None:
-            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model if isinstance(model, str) else get_config_model_id(model.config)
             model_name = model_name.split("/")[-1]
             args = GRPOConfig(f"{model_name}-GRPO")
 
@@ -258,7 +259,7 @@
             architecture = getattr(transformers, config.architectures[0])
             model = architecture.from_pretrained(model_id, **model_init_kwargs)
         else:
-            model_id = model.config._name_or_path
+            model_id = get_config_model_id(model.config)
             if args.model_init_kwargs is not None:
                 logger.warning(
                     "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
@@ -278,7 +279,7 @@
 
         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
+            processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config), truncation_side="left")
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):
@@ -305,7 +306,7 @@
                     reward_func, num_labels=1, **model_init_kwargs
                 )
             if isinstance(reward_funcs[i], nn.Module):  # Use Module over PretrainedModel for compat w/ compiled models
-                self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
+                self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1])
             else:
                 self.reward_func_names.append(reward_funcs[i].__name__)
         self.reward_funcs = reward_funcs
@@ -335,7 +336,7 @@
         for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
             if isinstance(reward_func, PreTrainedModel):
                 if reward_processing_class is None:
-                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
+                    reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config))
                 if reward_processing_class.pad_token_id is None:
                     reward_processing_class.pad_token = reward_processing_class.eos_token
                 # The reward model computes the reward for the latest non-padded token in the input sequence.
3 changes: 2 additions & 1 deletion trl/trainer/online_dpo_trainer.py
@@ -74,6 +74,7 @@
     disable_dropout_in_model,
     empty_cache,
     ensure_master_addr_port,
+    get_config_model_id,
     pad,
     truncate_right,
 )
@@ -243,7 +244,7 @@ def __init__(
                     reward_func, num_labels=1, **model_init_kwargs
                 )
             if isinstance(reward_funcs[i], nn.Module):
-                self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
+                self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1])
             else:
                 self.reward_func_names.append(reward_funcs[i].__name__)
         self.reward_funcs = reward_funcs
6 changes: 3 additions & 3 deletions trl/trainer/reward_trainer.py
@@ -44,7 +44,7 @@
 from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model
 from .base_trainer import BaseTrainer
 from .reward_config import RewardConfig
-from .utils import disable_dropout_in_model, pad, remove_none_values
+from .utils import disable_dropout_in_model, get_config_model_id, pad, remove_none_values
 
 
 if is_peft_available():
@@ -273,7 +273,7 @@ def __init__(
     ):
         # Args
         if args is None:
-            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model if isinstance(model, str) else get_config_model_id(model.config)
             model_name = model_name.split("/")[-1]
             args = RewardConfig(f"{model_name}-Reward")
 
@@ -294,7 +294,7 @@
             with suppress_from_pretrained_warning(transformers.modeling_utils.logger):
                 model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs)
         else:
-            model_id = model.config._name_or_path
+            model_id = get_config_model_id(model.config)
             if args.model_init_kwargs is not None:
                 logger.warning(
                     "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "
11 changes: 6 additions & 5 deletions trl/trainer/rloo_trainer.py
@@ -65,6 +65,7 @@
     disable_dropout_in_model,
     ensure_master_addr_port,
     entropy_from_logits,
+    get_config_model_id,
     identity,
     nanmax,
     nanmin,
@@ -240,7 +241,7 @@ def __init__(
 
         # Args
         if args is None:
-            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model if isinstance(model, str) else get_config_model_id(model.config)
             model_name = model_name.split("/")[-1]
             args = RLOOConfig(f"{model_name}-RLOO")
 
@@ -265,7 +266,7 @@
             architecture = getattr(transformers, config.architectures[0])
             model = architecture.from_pretrained(model_id, **model_init_kwargs)
         else:
-            model_id = model.config._name_or_path
+            model_id = get_config_model_id(model.config)
             if args.model_init_kwargs is not None:
                 logger.warning(
                     "You passed `model_init_kwargs` to the `RLOOConfig`, but your model is already instantiated. "
@@ -285,7 +286,7 @@
 
         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
+            processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config), truncation_side="left")
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):
@@ -312,7 +313,7 @@
                     reward_func, num_labels=1, **model_init_kwargs
                 )
             if isinstance(reward_funcs[i], nn.Module):  # Use Module over PretrainedModel for compat w/ compiled models
-                self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
+                self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1])
             else:
                 self.reward_func_names.append(reward_funcs[i].__name__)
         self.reward_funcs = reward_funcs
@@ -342,7 +343,7 @@
         for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
             if isinstance(reward_func, PreTrainedModel):
                 if reward_processing_class is None:
-                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
+                    reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config))
                 if reward_processing_class.pad_token_id is None:
                     reward_processing_class.pad_token = reward_processing_class.eos_token
                 # The reward model computes the reward for the latest non-padded token in the input sequence.
6 changes: 3 additions & 3 deletions trl/trainer/sft_trainer.py
@@ -54,6 +54,7 @@
     create_model_from_path,
     entropy_from_logits,
     flush_left,
+    get_config_model_id,
     pad,
     remove_none_values,
     selective_log_softmax,
@@ -590,7 +591,7 @@ def __init__(
     ):
         # Args
         if args is None:
-            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model if isinstance(model, str) else get_config_model_id(model.config)
             model_name = model_name.split("/")[-1]
             args = SFTConfig(f"{model_name}-SFT")
         elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
@@ -608,11 +609,10 @@
                     "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. "
                     "The `model_init_kwargs` will be ignored."
                 )
-            model_id = model.config._name_or_path
 
         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model_id)
+            processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config))
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):
17 changes: 17 additions & 0 deletions trl/trainer/utils.py
@@ -41,6 +41,7 @@
     BitsAndBytesConfig,
     EvalPrediction,
     GenerationConfig,
+    PretrainedConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
     TrainerState,
@@ -1962,3 +1963,19 @@ def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel:
     architecture = getattr(transformers, config.architectures[0])
     model = architecture.from_pretrained(model_id, **kwargs)
     return model
+
+
+def get_config_model_id(config: PretrainedConfig) -> str:
+    """
+    Retrieve the model identifier from a given model configuration.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            Configuration from which to extract the model identifier.
+
+    Returns:
+        `str`:
+            The model identifier associated with the model configuration.
+    """
+    # Fall back to `config.text_config._name_or_path` if `config._name_or_path` is missing: Qwen2-VL and Qwen2.5-VL. See GH-4323
+    return getattr(config, "_name_or_path", "") or getattr(getattr(config, "text_config", None), "_name_or_path", "")
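A rough usage sketch of the new helper follows; the configs are built by hand so no Hub access is needed, the model ids are illustrative, and it assumes a TRL build that includes this PR.

# Rough usage sketch of get_config_model_id (ids below are illustrative).
from transformers import PretrainedConfig

from trl.trainer.utils import get_config_model_id

# Typical case: the id is recorded directly on the config.
config = PretrainedConfig(name_or_path="org/model")
print(get_config_model_id(config))  # -> "org/model"

# Qwen2-VL-style case: no id on the top-level config, but the nested text_config
# carries one, so the helper falls back to config.text_config._name_or_path.
composite = PretrainedConfig()
composite.text_config = PretrainedConfig(name_or_path="org/vl-model")
print(get_config_model_id(composite))  # -> "org/vl-model"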