
[t5/t0/mt5 models] faster/leaner custom layer norm #14656

Merged · 11 commits · Feb 16, 2022
22 changes: 20 additions & 2 deletions src/transformers/models/t5/modeling_t5.py
@@ -237,14 +237,19 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
        # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
        # half-precision inputs is done in fp32

        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

@@ -255,6 +260,19 @@ def forward(self, hidden_states):
        return self.weight * hidden_states
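
For reference, the scaling above is Root Mean Square Layer Normalization: divide by the root mean square of the features (no mean subtraction, no bias), then scale by a learned weight, with the reduction accumulated in fp32 even for half-precision inputs. A minimal standalone sketch of the same math (the shapes and dtype below are illustrative, not taken from the PR):

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # variance is computed without subtracting the mean, accumulated in fp32
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # cast back to the parameter dtype for half-precision models
    if weight.dtype in (torch.float16, torch.bfloat16):
        x = x.to(weight.dtype)
    return weight * x


hidden = torch.randn(2, 4, 8, dtype=torch.float16)
weight = torch.ones(8, dtype=torch.float16)
print(rms_norm(hidden, weight).dtype)  # torch.float16
```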


try:
    from apex.normalization import FusedRMSNorm

    # rebind the module-level name so the t5/t0/mt5 layers pick up the fused kernel
    T5LayerNorm = FusedRMSNorm  # noqa

    print("XXX: using FusedRMSNorm")
except ImportError:
    # apex is not installed, keep the plain T5LayerNorm defined above
    print("XXX: using T5LayerNorm")
except Exception:
    print("XXX: using T5LayerNorm: unknown exception")
    pass
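
Rebinding `T5LayerNorm` to `apex.normalization.FusedRMSNorm` works as a drop-in because the fused module takes the same `(hidden_size, eps)` arguments and performs the same scale-only normalization in a fused CUDA kernel. A rough consistency check, assuming apex is installed with its CUDA extensions and a GPU is available (sizes and the all-ones weight are illustrative):

```python
import torch

try:
    from apex.normalization import FusedRMSNorm
except ImportError:
    FusedRMSNorm = None

if FusedRMSNorm is not None and torch.cuda.is_available():
    hidden_size, eps = 64, 1e-6
    x = torch.randn(2, 10, hidden_size, device="cuda")

    fused = FusedRMSNorm(hidden_size, eps=eps).cuda()

    # manual RMSNorm, same math as T5LayerNorm.forward above (weight is all ones here)
    variance = x.float().pow(2).mean(-1, keepdim=True)
    manual = fused.weight * (x * torch.rsqrt(variance + eps))

    print("max abs diff:", (fused(x) - manual).abs().max().item())
else:
    print("apex FusedRMSNorm not available, the plain T5LayerNorm would be used")
```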


class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()