# global_pointer_bert.py
from transformers import BertModel

from ark_nlp.nn.base.bert import BertForTokenClassification
from ark_nlp.nn.layer.global_pointer_block import GlobalPointer, EfficientGlobalPointer


class GlobalPointerBert(BertForTokenClassification):
    """
    GlobalPointer + BERT named entity recognition model.

    Args:
        config: the model configuration object
        encoder_trained (:obj:`bool`, optional): whether the pretrained encoder parameters are trainable
        head_size (:obj:`int`, optional): inner size of each GlobalPointer head

    Reference:
        [1] https://www.kexue.fm/archives/8373
    """  # noqa
    def __init__(
        self,
        config,
        encoder_trained=True,
        head_size=64
    ):
        super(GlobalPointerBert, self).__init__(config)

        self.num_labels = config.num_labels

        self.bert = BertModel(config)

        # Freeze or fine-tune the BERT encoder according to `encoder_trained`
        for param in self.bert.parameters():
            param.requires_grad = encoder_trained

        self.global_pointer = GlobalPointer(
            self.num_labels,
            head_size,
            config.hidden_size
        )

        self.init_weights()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        **kwargs
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
            output_hidden_states=True
        ).hidden_states

        # The last entry of `hidden_states` is the output of the final encoder layer
        sequence_output = outputs[-1]

        # Span scores of shape [batch_size, num_labels, seq_len, seq_len]
        logits = self.global_pointer(sequence_output, mask=attention_mask)

        return logits
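
# For a minimal, hedged usage sketch covering this model and the Efficient
# variant below, see the `__main__` smoke test at the end of this file.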

class EfficientGlobalPointerBert(BertForTokenClassification):
    """
    EfficientGlobalPointer + BERT named entity recognition model.

    Args:
        config: the model configuration object
        encoder_trained (:obj:`bool`, optional): whether the pretrained encoder parameters are trainable
        head_size (:obj:`int`, optional): inner size of each EfficientGlobalPointer head

    Reference:
        [1] https://www.kexue.fm/archives/8877
        [2] https://github.com/powerycy/Efficient-GlobalPointer
    """  # noqa
    def __init__(
        self,
        config,
        encoder_trained=True,
        head_size=64
    ):
        super(EfficientGlobalPointerBert, self).__init__(config)

        self.num_labels = config.num_labels

        self.bert = BertModel(config)

        for param in self.bert.parameters():
            param.requires_grad = encoder_trained

        self.efficient_global_pointer = EfficientGlobalPointer(
            self.num_labels,
            head_size,
            config.hidden_size
        )

        self.init_weights()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        **kwargs
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
            output_hidden_states=True
        ).hidden_states

        sequence_output = outputs[-1]

        logits = self.efficient_global_pointer(sequence_output, mask=attention_mask)

        return logits
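

# ---------------------------------------------------------------------------
# Minimal smoke test: a hedged usage sketch, not part of the original ark_nlp
# module. It assumes `torch` and `transformers` are installed and that the
# ark_nlp base class accepts a plain `transformers.BertConfig`. The tiny config
# values below are arbitrary, chosen only so the randomly initialised models
# run quickly on CPU; real usage would load a pretrained checkpoint instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from transformers import BertConfig

    # Hypothetical toy configuration; `num_labels` is the number of entity types
    config = BertConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        num_labels=3
    )

    # Random token ids standing in for a tokenised batch of 2 sequences of length 16
    input_ids = torch.randint(0, 100, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    for model_cls in (GlobalPointerBert, EfficientGlobalPointerBert):
        model = model_cls(config, head_size=8)
        model.eval()
        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask)
        # Expected span-score shape: [batch_size, num_labels, seq_len, seq_len]
        print(model_cls.__name__, tuple(logits.shape))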