
Commit

Pass an embedding layer to the constructor of the BertModel class (#1135)
zhangguanheng66 committed Feb 9, 2021
1 parent e368dc9 commit 911744e
Showing 3 changed files with 17 additions and 13 deletions.
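
In short: BertModel no longer constructs its own BertEmbedding internally. The caller builds the embedding layer and passes it to the constructor, and BertModel.forward now takes a single (src, token_type_input) tuple instead of two arguments. A minimal sketch of the new calling convention (the hyperparameter values are illustrative only, and the imports assume the script runs inside examples/BERT):

    import torch
    from model import BertModel, BertEmbedding

    ntoken, ninp, nhead, nhid, nlayers = 1000, 64, 2, 128, 2  # illustrative sizes
    embed_layer = BertEmbedding(ntoken, ninp)
    model = BertModel(ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5)

    src = torch.randint(0, ntoken, (35, 8))              # (S, N) token ids
    token_type = torch.zeros((35, 8), dtype=torch.long)  # single-segment input
    output = model((src, token_type))                    # forward now takes one tuple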
20 changes: 11 additions & 9 deletions examples/BERT/model.py
@@ -43,7 +43,8 @@ def __init__(self, ntoken, ninp, dropout=0.5):
         self.norm = LayerNorm(ninp)
         self.dropout = Dropout(dropout)
 
-    def forward(self, src, token_type_input):
+    def forward(self, seq_inputs):
+        src, token_type_input = seq_inputs
         src = self.embed(src) + self.pos_embed(src) \
             + self.tok_type_embed(src, token_type_input)
         return self.dropout(self.norm(src))
@@ -99,16 +100,16 @@ def forward(self, src, src_mask=None, src_key_padding_mask=None):
 class BertModel(nn.Module):
     """Contain a transformer encoder."""
 
-    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
+    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5):
         super(BertModel, self).__init__()
         self.model_type = 'Transformer'
-        self.bert_embed = BertEmbedding(ntoken, ninp)
+        self.bert_embed = embed_layer
         encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
         self.ninp = ninp
 
-    def forward(self, src, token_type_input):
-        src = self.bert_embed(src, token_type_input)
+    def forward(self, seq_inputs):
+        src = self.bert_embed(seq_inputs)
         output = self.transformer_encoder(src)
         return output

@@ -118,15 +119,16 @@ class MLMTask(nn.Module):

     def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
         super(MLMTask, self).__init__()
-        self.bert_model = BertModel(ntoken, ninp, nhead, nhid, nlayers, dropout=0.5)
+        embed_layer = BertEmbedding(ntoken, ninp)
+        self.bert_model = BertModel(ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5)
         self.mlm_span = Linear(ninp, ninp)
         self.activation = F.gelu
         self.norm_layer = LayerNorm(ninp, eps=1e-12)
         self.mlm_head = Linear(ninp, ntoken)
 
     def forward(self, src, token_type_input=None):
         src = src.transpose(0, 1)  # Wrap up by nn.DataParallel
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         output = self.mlm_span(output)
         output = self.activation(output)
         output = self.norm_layer(output)
@@ -147,7 +149,7 @@ def __init__(self, bert_model):

     def forward(self, src, token_type_input):
         src = src.transpose(0, 1)  # Wrap up by nn.DataParallel
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         # Send the first <'cls'> seq to a classifier
         output = self.activation(self.linear_layer(output[0]))
         output = self.ns_span(output)
@@ -164,7 +166,7 @@ def __init__(self, bert_model):
         self.qa_span = Linear(bert_model.ninp, 2)
 
     def forward(self, src, token_type_input):
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         # transpose output (S, N, E) to (N, S, E)
         output = output.transpose(0, 1)
         output = self.activation(output)
5 changes: 3 additions & 2 deletions examples/BERT/ns_task.py
@@ -5,7 +5,7 @@
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
-from model import NextSentenceTask, BertModel
+from model import NextSentenceTask, BertModel, BertEmbedding
 from utils import run_demo, run_ddp, wrap_up


@@ -149,7 +149,8 @@ def run_main(args, rank=None):
     if args.checkpoint != 'None':
         model = torch.load(args.checkpoint)
     else:
-        pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
+        embed_layer = BertEmbedding(len(vocab), args.emsize)
+        pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, embed_layer, args.dropout)
         pretrained_bert.load_state_dict(torch.load(args.bert_model))
         model = NextSentenceTask(pretrained_bert)

5 changes: 3 additions & 2 deletions examples/BERT/qa_task.py
@@ -9,7 +9,7 @@
 from model import QuestionAnswerTask
 from metrics import compute_qa_exact, compute_qa_f1
 from utils import print_loss_log
-from model import BertModel
+from model import BertModel, BertEmbedding


def process_raw_data(data):
@@ -174,7 +174,8 @@ def train():
     train_dataset = process_raw_data(train_dataset)
     dev_dataset = process_raw_data(dev_dataset)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
+    embed_layer = BertEmbedding(len(vocab), args.emsize)
+    pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, embed_layer, args.dropout)
     pretrained_bert.load_state_dict(torch.load(args.bert_model))
     model = QuestionAnswerTask(pretrained_bert).to(device)
     criterion = nn.CrossEntropyLoss()
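A plausible reading of the design choice (my inference, not stated in the commit message): injecting the embedding layer decouples it from the encoder stack, so a caller can swap in a different embedding implementation or share one embedding module across several models, e.g.:

    shared_embed = BertEmbedding(ntoken, ninp)
    bert_a = BertModel(ntoken, ninp, nhead, nhid, nlayers, shared_embed)
    bert_b = BertModel(ntoken, ninp, nhead, nhid, nlayers, shared_embed)  # ties embedding weights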
