Skip to content

Commit

Permalink
Bugfix: LLaMA layer norm incorrectly changes input type and consumers…
Browse files Browse the repository at this point in the history
… lots of memory (#23535)

* Fixed bug where LLaMA layer norm would change input type.

* make fix-copies

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
  • Loading branch information
TimDettmers and younesbelkada authored May 22, 2023
1 parent fe34486 commit 4ddd9de
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 10 deletions.
7 changes: 2 additions & 5 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,11 @@ def __init__(self, hidden_size, eps=1e-6):
self.variance_epsilon = eps

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states
return (self.weight * hidden_states).to(input_dtype)


class LlamaRotaryEmbedding(torch.nn.Module):
Expand Down
7 changes: 2 additions & 5 deletions src/transformers/models/open_llama/modeling_open_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,11 @@ def __init__(self, hidden_size, eps=1e-6):
self.variance_epsilon = eps

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states
return (self.weight * hidden_states).to(input_dtype)


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama
Expand Down

0 comments on commit 4ddd9de

Please sign in to comment.