Commit

Create the return value on device to avoid unnecessary copying from C…
mksit authored and MKhalusova committed Sep 19, 2023
1 parent 88b5c8a commit 50e805c
Showing 1 changed file with 1 addition and 1 deletion.
@@ -779,7 +779,7 @@ def forward(
         if isinstance(hidden_states, tuple):
             hidden_states, router_tuple = hidden_states
         else:
-            router_tuple = (torch.tensor([0], device=hidden_states.device),)
+            router_tuple = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),)
 
         # clamp inf values to enable fp16 training
         if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
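
The one-line change above can be checked in isolation. The following is a minimal sketch, not code from the repository: the helper names `default_router_tuple_old` and `default_router_tuple_new` and the tensor shape are assumptions made for illustration. It compares the removed expression with the added one and suggests that both yield the same value, shape, dtype, and device, so the change should affect only how the tensor is created. The explicit `dtype=torch.int64` is needed because `torch.zeros` otherwise uses the default floating-point dtype, whereas `torch.tensor([0])` infers `int64` from the Python integer.

```python
import torch


def default_router_tuple_old(hidden_states: torch.Tensor) -> tuple:
    # Hypothetical helper wrapping the removed line: the value is built
    # from a Python list and then placed on the target device.
    return (torch.tensor([0], device=hidden_states.device),)


def default_router_tuple_new(hidden_states: torch.Tensor) -> tuple:
    # Hypothetical helper wrapping the added line: the tensor is allocated
    # and zero-filled on the target device, with an explicit integer dtype.
    return (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),)


# Use a GPU when one is available; the comparison also runs on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
hidden_states = torch.randn(2, 4, 8, device=device)  # illustrative shape only

old = default_router_tuple_old(hidden_states)[0]
new = default_router_tuple_new(hidden_states)[0]

same_value = torch.equal(old, new)
same_dtype = old.dtype == new.dtype
same_device = old.device == new.device
print(same_value, same_dtype, same_device)
```

Dropping `dtype=torch.int64` from the new form would produce a `float32` tensor, so downstream code that expects integer router indices could behave differently.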