From cd66d327d965841db22a068fa4a3df07893fc536 Mon Sep 17 00:00:00 2001
From: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Date: Mon, 2 Sep 2024 09:39:39 +0200
Subject: [PATCH] Add GraniteRMSNorm (#33177)

* Add GraniteRMSNorm

* [run_slow] granite
---
 .../models/granite/modeling_granite.py | 31 +++++++++++++++++--
 src/transformers/pytorch_utils.py      |  2 +-
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py
index aee62fd249f350..90aa345b0eef43 100644
--- a/src/transformers/models/granite/modeling_granite.py
+++ b/src/transformers/models/granite/modeling_granite.py
@@ -30,6 +30,7 @@
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -99,6 +100,30 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
     return causal_mask
 
 
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Granite
+class GraniteRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        GraniteRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+ALL_LAYERNORM_LAYERS.append(GraniteRMSNorm)
+
+
 class GraniteRotaryEmbedding(nn.Module):
     def __init__(self, config: GraniteConfig):
         super().__init__()
@@ -534,8 +559,8 @@ def __init__(self, config: GraniteConfig, layer_idx: int):
         self.self_attn = GRANITE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
         self.mlp = GraniteMLP(config)
-        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.residual_multiplier = config.residual_multiplier
 
@@ -749,7 +774,7 @@ def __init__(self, config: GraniteConfig):
         self.layers = nn.ModuleList(
             [GraniteDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
 
         self.embedding_multiplier = config.embedding_multiplier
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index 8c02e0781092d7..f3663c09902f52 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -24,7 +24,7 @@
 from .utils import is_torch_xla_available, logging
 
 
-ALL_LAYERNORM_LAYERS = [nn.LayerNorm, nn.RMSNorm]
+ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
 
 
 logger = logging.get_logger(__name__)
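
Not part of the patch: a minimal standalone sketch of what the new GraniteRMSNorm module does, useful for a quick local check. The class body mirrors the hunk above (docstring and extra_repr omitted for brevity); hidden_size=8, the fp16 input, and the tensor shape are arbitrary illustrative choices, not values taken from the Granite config. After this change the Granite model no longer relies on torch.nn.RMSNorm being available in the installed PyTorch build.

    # Standalone sketch: exercises GraniteRMSNorm to show the fp32 upcast
    # inside forward() and the restore of the caller's dtype on the way out.
    import torch
    from torch import nn


    class GraniteRMSNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        def forward(self, hidden_states):
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            return self.weight * hidden_states.to(input_dtype)


    # Cast the module to fp16 as a half-precision checkpoint would; the
    # variance is still computed in fp32 inside forward().
    norm = GraniteRMSNorm(hidden_size=8).to(torch.float16)
    x = torch.randn(2, 4, 8, dtype=torch.float16)  # (batch, seq_len, hidden_size)
    y = norm(x)
    print(y.dtype, y.shape)  # torch.float16 torch.Size([2, 4, 8])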