Skip to content

Commit

Permalink
BertEmbedding now accepts a single input tuple in its forward function
Browse files Browse the repository at this point in the history
  • Loading branch information
Guanheng Zhang committed Jan 28, 2021
1 parent 8bcffe6 commit beeab99
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions examples/BERT/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def __init__(self, ntoken, ninp, dropout=0.5):
self.norm = LayerNorm(ninp)
self.dropout = Dropout(dropout)

def forward(self, src, token_type_input):
def forward(self, seq_inputs):
    """Combine the three embeddings for a sequence, then normalize and drop out.

    Args:
        seq_inputs: a ``(src, token_type_input)`` tuple. ``src`` is fed to
            ``self.embed`` and ``self.pos_embed``; both elements go to
            ``self.tok_type_embed``.

    Returns:
        ``self.dropout(self.norm(...))`` applied to the element-wise sum of
        the three embedding outputs.
    """
    tokens, token_types = seq_inputs
    combined = (
        self.embed(tokens)
        + self.pos_embed(tokens)
        + self.tok_type_embed(tokens, token_types)
    )
    return self.dropout(self.norm(combined))
Expand Down Expand Up @@ -107,8 +108,8 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
self.ninp = ninp

def forward(self, src, token_type_input):
src = self.bert_embed(src, token_type_input)
def forward(self, seq_inputs):
    """Embed the inputs and run them through the transformer encoder.

    Args:
        seq_inputs: passed straight through to ``self.bert_embed``
            (presumably a ``(src, token_type_input)`` tuple — confirm
            against the embedding layer's forward).

    Returns:
        The output of ``self.transformer_encoder`` on the embedded sequence.
    """
    embedded = self.bert_embed(seq_inputs)
    return self.transformer_encoder(embedded)

Expand All @@ -127,7 +128,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):

def forward(self, src, token_type_input=None):
src = src.transpose(0, 1) # Wrap up by nn.DataParallel
output = self.bert_model(src, token_type_input)
output = self.bert_model((src, token_type_input))
output = self.mlm_span(output)
output = self.activation(output)
output = self.norm_layer(output)
Expand All @@ -148,7 +149,7 @@ def __init__(self, bert_model):

def forward(self, src, token_type_input):
src = src.transpose(0, 1) # Wrap up by nn.DataParallel
output = self.bert_model(src, token_type_input)
output = self.bert_model((src, token_type_input))
# Send the first <'cls'> seq to a classifier
output = self.activation(self.linear_layer(output[0]))
output = self.ns_span(output)
Expand All @@ -165,7 +166,7 @@ def __init__(self, bert_model):
self.qa_span = Linear(bert_model.ninp, 2)

def forward(self, src, token_type_input):
output = self.bert_model(src, token_type_input)
output = self.bert_model((src, token_type_input))
# transpose output (S, N, E) to (N, S, E)
output = output.transpose(0, 1)
output = self.activation(output)
Expand Down

0 comments on commit beeab99

Please sign in to comment.