# BertModel.py (forked from varshakishore/dsi)
from torch import nn
from transformers import (
    BertConfig,
    BertModel,
    BertTokenizer,
)


class HFBertEncoder(BertModel):
    """BERT encoder whose pooled output is the [CLS] token's hidden state."""

    def __init__(self, config):
        BertModel.__init__(self, config)
        assert config.hidden_size > 0, "Encoder hidden_size can't be zero"
        self.init_weights()
    @classmethod
    def init_encoder(cls, dropout: float = 0.1):
        # Load the pretrained config and optionally override both dropout rates.
        cfg = BertConfig.from_pretrained("bert-base-uncased")
        if dropout != 0:
            cfg.attention_probs_dropout_prob = dropout
            cfg.hidden_dropout_prob = dropout
        return cls.from_pretrained("bert-base-uncased", config=cfg)
    def forward(self, input_ids, attention_mask):
        hidden_states = None
        sequence_output, pooled_output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        # Replace BERT's pooler output with the [CLS] token's hidden state.
        pooled_output = sequence_output[:, 0, :]
        return sequence_output, pooled_output, hidden_states
    def get_out_size(self):
        # `encode_proj` is an optional projection head; this class never
        # defines one, so guard the lookup and fall back to the hidden size.
        if getattr(self, "encode_proj", None) is not None:
            return self.encode_proj.out_features
        return self.config.hidden_size

class QueryClassifier(nn.Module):
    """Query/question side of a bi-encoder: a BERT question encoder followed
    by a linear classification head. See the usage sketch at the end of the
    file.
    """

    def __init__(self, class_num):
        super(QueryClassifier, self).__init__()
        # Note: only the question encoder is used here; there is no context encoder.
        self.question_model = HFBertEncoder.init_encoder()
        self.classifier = nn.Linear(self.question_model.config.hidden_size, class_num, bias=False)

    def query_emb(self, input_ids, attention_mask):
        sequence_output, pooled_output, hidden_states = self.question_model(input_ids, attention_mask)
        return pooled_output

    def forward(self, query_ids, attention_mask_q, return_hidden_emb=False):
        q_embs = self.query_emb(query_ids, attention_mask_q)
        if return_hidden_emb:
            return q_embs
        logits = self.classifier(q_embs)
        return logits

class DocClassifier(nn.Module):
    """Context/passage side of a bi-encoder: a BERT context encoder followed
    by a linear classification head.
    """

    def __init__(self, class_num):
        super(DocClassifier, self).__init__()
        # Note: only the context encoder is used here; there is no question encoder.
        self.ctx_model = HFBertEncoder.init_encoder()
        self.classifier = nn.Linear(self.ctx_model.config.hidden_size, class_num, bias=False)

    def doc_emb(self, input_ids, attention_mask):
        sequence_output, pooled_output, hidden_states = self.ctx_model(input_ids, attention_mask)
        return pooled_output

    def forward(self, doc_ids, attention_mask_d, return_hidden_emb=False):
        doc_embs = self.doc_emb(doc_ids, attention_mask_d)
        if return_hidden_emb:
            return doc_embs
        logits = self.classifier(doc_embs)
        return logits
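

# Minimal usage sketch (illustrative, not part of the original repo): the
# class count of 10 and the example query below are placeholder assumptions.
# It tokenizes a query with the matching BERT tokenizer and runs it through
# QueryClassifier; DocClassifier is used the same way on passage text.
if __name__ == "__main__":
    import torch

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    query_clf = QueryClassifier(class_num=10)
    query_clf.eval()

    batch = tokenizer(
        ["where was bert pretrained?"],
        return_tensors="pt", padding=True, truncation=True,
    )
    with torch.no_grad():
        logits = query_clf(batch["input_ids"], batch["attention_mask"])
        embs = query_clf(batch["input_ids"], batch["attention_mask"],
                         return_hidden_emb=True)
    print(logits.shape)  # torch.Size([1, 10]): one score per class
    print(embs.shape)    # torch.Size([1, 768]): the [CLS] embedding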