diff --git a/trl/models/modeling_value_head.py b/trl/models/modeling_value_head.py index 75a8fcd157..0797794013 100644 --- a/trl/models/modeling_value_head.py +++ b/trl/models/modeling_value_head.py @@ -70,8 +70,8 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This should be set to `transformers.AutoModelForCausalLM` for this class. - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the - wrapped model. This is set to `("lm_head", "embed_out")` for this class but can be changed for other models - in the future + wrapped model. This is set to `("lm_head", "embed_out", "output_layer")` for this class but can be changed + for other models in the future - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported by the `ValueHead` class. Currently, the supported args are: - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the @@ -86,7 +86,7 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): """ transformers_parent_class = AutoModelForCausalLM - lm_head_namings = ["lm_head", "embed_out"] + lm_head_namings = ["lm_head", "embed_out", "output_layer"] supported_args = ( "summary_dropout_prob", "v_head_initializer_range",