From d07f5e024fbf22c716a19055dc766783c5d2c809 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 5 Nov 2024 17:30:51 +0000 Subject: [PATCH] Update lm_head_namings in AutoModelForCausalLMWithValueHead --- trl/models/modeling_value_head.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trl/models/modeling_value_head.py b/trl/models/modeling_value_head.py index 75a8fcd157..0797794013 100644 --- a/trl/models/modeling_value_head.py +++ b/trl/models/modeling_value_head.py @@ -70,8 +70,8 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This should be set to `transformers.AutoModelForCausalLM` for this class. - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the - wrapped model. This is set to `("lm_head", "embed_out")` for this class but can be changed for other models - in the future + wrapped model. This is set to `("lm_head", "embed_out", "output_layer")` for this class but can be changed + for other models in the future - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported by the `ValueHead` class. Currently, the supported args are: - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the @@ -86,7 +86,7 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): """ transformers_parent_class = AutoModelForCausalLM - lm_head_namings = ["lm_head", "embed_out"] + lm_head_namings = ["lm_head", "embed_out", "output_layer"] supported_args = ( "summary_dropout_prob", "v_head_initializer_range",