Only cast logits to float when computing loss (#34147)
* Only cast logits to float when computing loss

Covers a few models that were missed in #31292 and #33902

* Move logits.float() into the existing "if labels is not None" branch
ringohoffman authored Oct 18, 2024
1 parent e46e3bc commit 816f442
Showing 9 changed files with 16 additions and 21 deletions.
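Every file below gets the same treatment: the unconditional logits.float() after the LM head goes away, and the upcast happens only inside the labels branch where the loss is actually computed, so inference keeps the model's dtype. A minimal standalone sketch of the pattern, with toy sizes and names that are not taken from any of the touched files:

import torch
import torch.nn as nn

# Toy stand-ins for a causal LM head running in bfloat16.
lm_head = nn.Linear(64, 128, bias=False, dtype=torch.bfloat16)
hidden_states = torch.randn(2, 8, 64, dtype=torch.bfloat16)
labels = torch.randint(0, 128, (2, 8))

logits = lm_head(hidden_states)  # stays in bf16; no unconditional .float()

loss = None
if labels is not None:
    # Upcast only when computing the loss, to avoid precision issues in cross-entropy
    logits = logits.float()
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss = nn.functional.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )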
3 changes: 2 additions & 1 deletion src/transformers/models/chameleon/modeling_chameleon.py
@@ -1602,14 +1602,15 @@ def forward(

         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
-        logits = logits.float()

         # Disallow image tokens which does not include special begin-image and end-image tokens
         image_tokens = self.model.vocabulary_mapping.image_tokens
         logits[:, :, image_tokens] = torch.finfo(logits.dtype).min

         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
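The image-token masking kept just above (unchanged by this commit) works by setting disallowed vocabulary ids to the smallest representable value for the logits dtype, so they can never win the argmax or be sampled. A tiny illustration with made-up sizes; the real code uses self.model.vocabulary_mapping.image_tokens:

import torch

logits = torch.randn(1, 4, 10, dtype=torch.bfloat16)
image_tokens = [7, 8, 9]  # assumed stand-in for the model's image-token ids
logits[:, :, image_tokens] = torch.finfo(logits.dtype).min  # these ids can never be predicted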
3 changes: 2 additions & 1 deletion src/transformers/models/granite/modeling_granite.py
@@ -1101,10 +1101,11 @@ def forward(
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits / self.config.logits_scaling
-        logits = logits.float()

         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
3 changes: 2 additions & 1 deletion src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -1345,10 +1345,11 @@ def forward(
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits / self.config.logits_scaling
-        logits = logits.float()

         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
3 changes: 2 additions & 1 deletion src/transformers/models/idefics3/modeling_idefics3.py
@@ -1210,10 +1210,11 @@ def forward(

         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
-        logits = logits.float()

         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             labels = labels.to(logits.device)
             # Shift so that tokens < n predict n
             if attention_mask is not None:
3 changes: 2 additions & 1 deletion src/transformers/models/paligemma/modeling_paligemma.py
@@ -526,9 +526,10 @@ def forward(
         )

         logits = outputs.logits
-        logits = logits.float()
         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             shift_logits = logits[..., :-1, :]
             shift_labels = labels[..., 1:]
             if attention_mask is not None:
8 changes: 1 addition & 7 deletions src/transformers/models/phimoe/modeling_phimoe.py
@@ -38,7 +38,6 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -1463,13 +1462,8 @@ def forward(
         )

         hidden_states = outputs[0]
-        if labels is None and not is_torchdynamo_compiling():
-            logger.warning_once(
-                "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
-            )
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        # TODO: remove the float() operation in v4.46
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

         loss = None
         if labels is not None:
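In phimoe (and zamba further down) the change also drops the pre-v4.46 warning path: the head is applied only to the last num_logits_to_keep positions and the result is no longer upcast. A hedged sketch of that slicing, with the variable name taken from the diff and toy sizes assumed:

import torch
import torch.nn as nn

lm_head = nn.Linear(64, 128, bias=False, dtype=torch.bfloat16)
hidden_states = torch.randn(1, 10, 64, dtype=torch.bfloat16)

num_logits_to_keep = 1  # during generation only the final position is needed
logits = lm_head(hidden_states[:, -num_logits_to_keep:, :])  # shape (1, 1, 128), still bf16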
3 changes: 2 additions & 1 deletion src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1760,10 +1760,11 @@ def forward(

         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
-        logits = logits.float()

         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
3 changes: 2 additions & 1 deletion src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -870,9 +870,10 @@ def forward(
         cap = self.config.logits_soft_cap
         logits = nn.functional.tanh(logits / cap) * cap

-        logits = logits.float()
         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
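The hunk above comes from a model with logit soft-capping: logits are squashed through tanh so they stay inside (-cap, cap), and only then, if labels are given, upcast for the loss. A minimal sketch with an assumed cap value; the model itself reads it from config.logits_soft_cap:

import torch
import torch.nn as nn

logits = torch.randn(2, 5, 100)
cap = 30.0  # assumed value; the real one comes from config.logits_soft_cap
logits = nn.functional.tanh(logits / cap) * cap  # bounded to (-cap, cap)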
8 changes: 1 addition & 7 deletions src/transformers/models/zamba/modeling_zamba.py
@@ -51,7 +51,6 @@
 from ...utils.import_utils import (
     is_causal_conv1d_available,
     is_mamba_ssm_available,
-    is_torchdynamo_compiling,
 )
 from .configuration_zamba import ZambaConfig

@@ -1473,13 +1472,8 @@ def forward(
         )

         hidden_states = outputs[0]
-        if labels is None and not is_torchdynamo_compiling():
-            logger.warning_once(
-                "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
-            )
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        # TODO: remove the float() operation in v4.46
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

         loss = None
         if labels is not None:
