🧞 Add output_layer to the list of lm_head_namings in AutoModelForCausalLMWithValueHead #2328

Merged
merged 1 commit into from
Nov 11, 2024
6 changes: 3 additions & 3 deletions trl/models/modeling_value_head.py
@@ -70,8 +70,8 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
- **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This
should be set to `transformers.AutoModelForCausalLM` for this class.
- **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the
- wrapped model. This is set to `("lm_head", "embed_out")` for this class but can be changed for other models
- in the future
+ wrapped model. This is set to `("lm_head", "embed_out", "output_layer")` for this class but can be changed
+ for other models in the future
- **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported
by the `ValueHead` class. Currently, the supported args are:
- **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the
@@ -86,7 +86,7 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
"""

transformers_parent_class = AutoModelForCausalLM
- lm_head_namings = ["lm_head", "embed_out"]
+ lm_head_namings = ["lm_head", "embed_out", "output_layer"]
supported_args = (
"summary_dropout_prob",
"v_head_initializer_range",
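To illustrate why the extra entry matters, here is a minimal sketch of a name-based head lookup. It assumes the wrapper detects the language model head by checking the wrapped model for an attribute named in `lm_head_namings`; `ToyModel` and `find_lm_head_name` are hypothetical stand-ins for illustration, not TRL code.

```python
# Hypothetical illustration; not TRL's actual implementation.
OLD_NAMINGS = ["lm_head", "embed_out"]
NEW_NAMINGS = ["lm_head", "embed_out", "output_layer"]


class ToyModel:
    """Toy model that exposes its LM head under the name `output_layer`."""

    def __init__(self):
        self.output_layer = object()  # placeholder for a projection layer


def find_lm_head_name(model, namings):
    """Return the first candidate name present on the model, else None."""
    for name in namings:
        if hasattr(model, name):
            return name
    return None


model = ToyModel()
old_result = find_lm_head_name(model, OLD_NAMINGS)  # head is not found
new_result = find_lm_head_name(model, NEW_NAMINGS)  # head is found
```

Under that assumption, a model whose head is called `output_layer` would be rejected with the old list and accepted with the new one.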