replace apex.normalization.FusedLayerNorm with torch.nn.LayerNorm #9386

Merged 1 commit on Jan 4, 2021
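Context for the change: each of the touched model files previously tried to import apex's FusedLayerNorm and fell back to torch.nn.LayerNorm when apex was not installed. The diff below removes that indirection and constructs torch.nn.LayerNorm directly. The sketch that follows is a simplified illustration of the before/after pattern, not the full module code; the helper name BartLayerNorm comes from the bart diff below, and the embed_dim value is made up for the example.

import torch
from torch.nn import LayerNorm  # the replacement: stock PyTorch layer norm

# Pattern removed by this PR: prefer apex's fused kernel when available,
# otherwise fall back to the stock implementation.
def BartLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
    try:
        from apex.normalization import FusedLayerNorm
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    except ImportError:
        pass
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)

# Pattern after this PR: construct the stock module directly.
embed_dim = 1024  # illustrative; the real code reads sizes from the model config
old_style = BartLayerNorm(embed_dim)
new_style = LayerNorm(embed_dim)
print(type(old_style).__name__, type(new_style).__name__)  # "LayerNorm LayerNorm" when apex is absent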
30 changes: 10 additions & 20 deletions src/transformers/models/bart/modeling_bart.py
@@ -22,7 +22,7 @@
import torch
import torch.nn.functional as F
from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, LayerNorm

from ...activations import ACT2FN
from ...file_utils import (
@@ -109,16 +109,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)


-def BartLayerNorm(normalized_shape: torch.Size, eps: float = 1e-5, elementwise_affine: bool = True):
-    try:
-        from apex.normalization import FusedLayerNorm
-
-        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
-    except ImportError:
-        pass
-    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
-
-
class BartLearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
@@ -321,13 +311,13 @@ def __init__(self, config: BartConfig):
dropout=config.attention_dropout,
)
self.normalize_before = config.normalize_before
-self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
+self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-self.final_layer_norm = BartLayerNorm(self.embed_dim)
+self.final_layer_norm = LayerNorm(self.embed_dim)

def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
"""
@@ -380,17 +370,17 @@ def __init__(self, config: BartConfig):
self.activation_dropout = config.activation_dropout
self.normalize_before = config.normalize_before

-self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
+self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.encoder_attn = BartAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
-self.encoder_attn_layer_norm = BartLayerNorm(self.embed_dim)
+self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
-self.final_layer_norm = BartLayerNorm(self.embed_dim)
+self.final_layer_norm = LayerNorm(self.embed_dim)

def forward(
self,
@@ -672,9 +662,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
config.extra_pos_embeddings,
)
self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
-self.layernorm_embedding = BartLayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
+self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
# mbart has one extra layer_norm
-self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
+self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None

self.init_weights()

@@ -812,8 +802,8 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
config.extra_pos_embeddings,
)
self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
-self.layernorm_embedding = BartLayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
-self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
+self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
+self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None

self.init_weights()

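A note on the bart change above: the removed BartLayerNorm helper passed eps=1e-5 and elementwise_affine=True, which are also torch.nn.LayerNorm's defaults, so a call like LayerNorm(self.embed_dim) builds the same module the helper returned whenever apex was unavailable. A small illustrative check, using plain PyTorch only:

import torch
from torch.nn import LayerNorm

ln = LayerNorm(16)  # defaults: eps=1e-05, elementwise_affine=True
print(ln.eps, ln.elementwise_affine)                            # 1e-05 True
print({k: tuple(v.shape) for k, v in ln.state_dict().items()})  # {'weight': (16,), 'bias': (16,)}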
12 changes: 1 addition & 11 deletions src/transformers/models/fsmt/modeling_fsmt.py
@@ -34,7 +34,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor, nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, LayerNorm

from ...activations import ACT2FN
from ...file_utils import (
@@ -264,16 +264,6 @@
"""


-have_fused_layer_norm = False
-try:
-    from apex.normalization import FusedLayerNorm
-
-    have_fused_layer_norm = True
-except ImportError:
-    pass
-LayerNorm = FusedLayerNorm if have_fused_layer_norm else torch.nn.LayerNorm
-
-
def invert_mask(attention_mask):
"""Turns 1->0, 0->1, False->True, True-> False"""
assert attention_mask.dim() == 2
25 changes: 8 additions & 17 deletions src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -23,6 +23,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor, nn
+from torch.nn import LayerNorm

from ...activations import ACT2FN
from ...file_utils import (
@@ -510,16 +511,6 @@ class ProphetNetDecoderLMOutput(ModelOutput):
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


-def ProphetNetLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
-    try:
-        from apex.normalization import FusedLayerNorm
-
-        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
-    except ImportError:
-        pass
-    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
-
-
class ProphetNetPreTrainedModel(PreTrainedModel):
config_class = ProphetNetConfig
base_model_prefix = "prophetnet"
@@ -1044,11 +1035,11 @@ def __init__(self, config: ProphetNetConfig):
super().__init__()
# 1st residual block
self.self_attn = ProphetNetSelfAttention(config, config.num_encoder_attention_heads)
-self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+self.self_attn_layer_norm = LayerNorm(config.hidden_size)

# 2nd residual block
self.feed_forward = ProhpetNetFeedForward(config, config.encoder_ffn_dim)
-self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

def forward(self, hidden_states, attention_mask):
# 1st residual block
@@ -1073,16 +1064,16 @@ def __init__(self, config: ProphetNetConfig):
super().__init__()
# 1st residual block
self.self_attn = ProphetNetNgramProphetNetSelfAttention(config)
-self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+self.self_attn_layer_norm = LayerNorm(config.hidden_size)

# 2nd residual block
if config.add_cross_attention:
self.cross_attn = ProphetNetSelfAttention(config, config.num_decoder_attention_heads)
-self.cross_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+self.cross_attn_layer_norm = LayerNorm(config.hidden_size)

# 3rd residual block
self.feed_forward = ProhpetNetFeedForward(config, config.decoder_ffn_dim)
-self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

def forward(
self,
@@ -1154,7 +1145,7 @@ def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = Non
else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
)
self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
-self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+self.embeddings_layer_norm = LayerNorm(config.hidden_size)

self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])

@@ -1274,7 +1265,7 @@ def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = Non

self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
-self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+self.embeddings_layer_norm = LayerNorm(config.hidden_size)

self.init_weights()

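Across all three files the change only affects how the normalization module is constructed; the forward pass is the same layer-norm math either way. A self-contained sanity check of the replacement module, assuming nothing beyond stock PyTorch:

import torch
from torch.nn import LayerNorm

layer = LayerNorm(32)
x = torch.randn(2, 5, 32)
y = layer(x)

print(y.shape)                                                        # torch.Size([2, 5, 32])
print(y.mean(dim=-1).abs().max().item() < 1e-5)                       # True: per-position mean is ~0
print((y.var(dim=-1, unbiased=False) - 1).abs().max().item() < 1e-3)  # True: per-position variance is ~1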