diff --git a/modeling.py b/modeling.py
index 66b0de68d9b1..53243e5eb435 100644
--- a/modeling.py
+++ b/modeling.py
@@ -25,6 +25,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
+from six import string_types
 
 def gelu(x):
     """Implementation of the gelu activation function.
@@ -34,6 +35,13 @@ def gelu(x):
     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
 
+def swish(x):
+    return x * torch.sigmoid(x)
+
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
+
 class BertConfig(object):
     """Configuration class to store the configuration of a `BertModel`.
     """
@@ -60,7 +68,7 @@ def __init__(self,
             intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                 layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
-                encoder and pooler.
+                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
             hidden_dropout_prob: The dropout probabilitiy for all fully connected
                 layers in the embeddings, encoder, and pooler.
             attention_probs_dropout_prob: The dropout ratio for the attention
@@ -237,7 +245,8 @@ class BERTIntermediate(nn.Module):
     def __init__(self, config):
         super(BERTIntermediate, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.intermediate_act_fn = gelu
+        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
+            if isinstance(config.hidden_act, string_types) else config.hidden_act
 
     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
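
To illustrate what the new ACT2FN dispatch buys, here is a minimal, runnable sketch of the same pattern. It is not taken from the patched file: the Intermediate class and its constructor arguments are illustrative stand-ins for BERTIntermediate, and it checks against the built-in str rather than six.string_types, which assumes Python 3.

import math

import torch
import torch.nn as nn


def gelu(x):
    # Exact gelu via the error function, as in the patch.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


# String name -> activation callable, mirroring the ACT2FN table in the diff.
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Intermediate(nn.Module):
    """Toy stand-in for BERTIntermediate: a linear layer plus a configurable activation."""

    def __init__(self, hidden_size, intermediate_size, hidden_act="gelu"):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)
        # Same dispatch as the patch: string names are looked up in ACT2FN,
        # while a callable passed directly is used as-is.
        self.act_fn = ACT2FN[hidden_act] if isinstance(hidden_act, str) else hidden_act

    def forward(self, hidden_states):
        return self.act_fn(self.dense(hidden_states))


if __name__ == "__main__":
    x = torch.randn(2, 4, 8)
    # Both string names and an arbitrary callable work as hidden_act.
    for act in ("gelu", "relu", "swish", torch.tanh):
        layer = Intermediate(8, 32, hidden_act=act)
        print(act, layer(x).shape)

Accepting either a string or a callable keeps configs serializable (a name like "swish" round-trips through JSON) while still letting callers plug in a custom activation function without touching the lookup table.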