# bert_models.py
import sys; sys.path.append('../common')
# Explicit imports so the file stands alone (the star imports below may
# already provide some of these names).
import torch, torch.nn as nn, torch.nn.functional as F
from transformers import BertConfig, BertModel, BertPreTrainedModel
from helper import *
from base_models import *
import model_clinicalBERT  # unused in this file; presumably used elsewhere in the pipeline
# BERT + BiLSTM fine-tuning model: each document arrives as a stack of
# fixed-length chunks; BERT encodes every chunk, a BiLSTM runs over the
# resulting chunk embeddings, and a linear head emits per-chunk logits.
class BertBiLSTM(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.rnn_dim = 512
        # Bidirectional, so each direction gets rnn_dim // 2 hidden units.
        self.lstm = nn.LSTM(config.hidden_size, self.rnn_dim // 2, num_layers=1,
                            bidirectional=True, dropout=0.0, batch_first=True)
        self.classifier = nn.Linear(self.rnn_dim, config.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
        # Flatten (batch, chunks, seq_len) -> (batch * chunks, seq_len) so BERT
        # encodes each chunk independently. token_type_ids/position_ids are
        # passed through as-is and are expected to be None for chunked inputs.
        outputs = self.bert(
            input_ids.view(-1, input_ids.shape[-1]),
            attention_mask=attention_mask.view(-1, input_ids.shape[-1]),
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        bert_out = outputs[1]  # pooled [CLS] representation per chunk
        bert_out = self.dropout(bert_out)
        # Restore the chunk axis: (batch, chunks, hidden_size).
        rnn_in = bert_out.view(input_ids.shape[0], input_ids.shape[1], -1)
        lstm_out, _ = self.lstm(rnn_in)
        logits = self.classifier(lstm_out)
        # Multi-label objective; only compute the loss when labels are given.
        loss = None
        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(logits, labels.float())
        return loss, {'med_class': logits}
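
# Usage sketch (not part of the original file): a minimal smoke test for
# BertBiLSTM with a deliberately tiny BertConfig and random tensors. All
# sizes here are illustrative assumptions; call it manually to check shapes.
def _smoke_test_bert_bilstm():
    config = BertConfig(hidden_size=64, num_hidden_layers=2,
                        num_attention_heads=2, intermediate_size=128,
                        num_labels=5)
    model = BertBiLSTM(config)
    batch, chunks, seq_len = 2, 3, 16
    input_ids = torch.randint(0, config.vocab_size, (batch, chunks, seq_len))
    attention_mask = torch.ones(batch, chunks, seq_len, dtype=torch.long)
    labels = torch.randint(0, 2, (batch, chunks, config.num_labels))
    loss, out = model(input_ids=input_ids, attention_mask=attention_mask,
                      labels=labels)
    assert out['med_class'].shape == (batch, chunks, config.num_labels)
    return loss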

# BERT / BERT-FT model: like BertBiLSTM but without the recurrent layer; each
# chunk's pooled embedding feeds the linear multi-label classifier directly.
class BertPlainNew(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
        # Flatten (batch, chunks, seq_len) -> (batch * chunks, seq_len).
        outputs = self.bert(
            input_ids.view(-1, input_ids.shape[-1]),
            attention_mask=attention_mask.view(-1, input_ids.shape[-1]),
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        bert_out = outputs[1]  # pooled [CLS] representation per chunk
        bert_out = self.dropout(bert_out)
        # Restore the chunk axis: (batch, chunks, hidden_size).
        chunk_embs = bert_out.view(input_ids.shape[0], input_ids.shape[1], -1)
        logits = self.classifier(chunk_embs)
        loss = None
        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(logits, labels.float())
        return loss, {'med_class': logits}
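
# Usage sketch (assumption, not from the original file): BertPlainNew would
# normally be initialized from a pretrained checkpoint so that only the
# classifier head starts untrained. 'bert-base-uncased' is the stock
# HuggingFace checkpoint; num_labels is task-dependent and arbitrary here.
def _load_bert_plain(num_labels=10):
    return BertPlainNew.from_pretrained('bert-base-uncased',
                                        num_labels=num_labels)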

# Clinical BioBERT-FT model: architecturally identical to BertPlainNew; the
# distinction is the clinical-domain pretrained checkpoint it is loaded from.
class ClinicalBertPlainNew(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
        # Flatten (batch, chunks, seq_len) -> (batch * chunks, seq_len).
        outputs = self.bert(
            input_ids.view(-1, input_ids.shape[-1]),
            attention_mask=attention_mask.view(-1, input_ids.shape[-1]),
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        bert_out = outputs[1]  # pooled [CLS] representation per chunk
        bert_out = self.dropout(bert_out)
        # Restore the chunk axis: (batch, chunks, hidden_size).
        chunk_embs = bert_out.view(input_ids.shape[0], input_ids.shape[1], -1)
        logits = self.classifier(chunk_embs)
        loss = None
        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(logits, labels.float())
        return loss, {'med_class': logits}
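
# Usage sketch (assumption): the clinical variant differs only in which
# weights it loads. 'emilyalsentzer/Bio_ClinicalBERT' is a common Clinical
# BioBERT checkpoint on the HuggingFace Hub; the original pipeline may load
# a local copy via model_clinicalBERT instead.
def _load_clinical_bert_plain(num_labels=10):
    return ClinicalBertPlainNew.from_pretrained('emilyalsentzer/Bio_ClinicalBERT',
                                                num_labels=num_labels)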