import torch
from torch import nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from modeling_bart import BartScorer
from modeling_pegasus import PegasusScorer


def RankingLoss(score, summary_score=None, margin=0, gold_margin=0, gold_weight=1, no_gold=False, no_cand=False):
    # Contrastive ranking loss over candidate summaries. Candidates are assumed to be
    # sorted from best to worst along dim 1; pairs are pushed apart by a margin that
    # grows with their rank distance, and the gold summary should outscore them all.
    ones = torch.ones_like(score)
    loss_func = torch.nn.MarginRankingLoss(0.0)
    TotalLoss = loss_func(score, score, ones)  # zero-valued init that keeps device/dtype
    # candidate loss: compare every candidate against all lower-ranked candidates
    n = score.size(1)
    if not no_cand:
        for i in range(1, n):
            pos_score = score[:, :-i]
            neg_score = score[:, i:]
            pos_score = pos_score.contiguous().view(-1)
            neg_score = neg_score.contiguous().view(-1)
            ones = torch.ones_like(pos_score)
            loss_func = torch.nn.MarginRankingLoss(margin * i)
            loss = loss_func(pos_score, neg_score, ones)
            TotalLoss += loss
    if no_gold:
        return TotalLoss
    # gold summary loss: the reference summary should outscore every candidate
    pos_score = summary_score.unsqueeze(-1).expand_as(score)
    neg_score = score
    pos_score = pos_score.contiguous().view(-1)
    neg_score = neg_score.contiguous().view(-1)
    ones = torch.ones_like(pos_score)
    loss_func = torch.nn.MarginRankingLoss(gold_margin)
    TotalLoss += gold_weight * loss_func(pos_score, neg_score, ones)
    return TotalLoss
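

# A minimal usage sketch of RankingLoss (illustrative only; the tensor values and
# hyperparameters below are assumptions, not taken from this repo's training
# configuration):
#
#     scores = torch.tensor([[0.9, 0.7, 0.4]])   # candidate scores, best first
#     gold = torch.tensor([1.2])                  # score of the reference summary
#     loss = RankingLoss(scores, gold, margin=0.001, gold_margin=0, gold_weight=1)
#
# Each lower-ranked candidate must trail a higher-ranked one by at least
# margin * (rank gap); the gold term additionally requires the reference summary
# to outscore every candidate by gold_margin.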


class BRIO(nn.Module):
    def __init__(self, mname, pad_token_id, is_pegasus=False):
        super(BRIO, self).__init__()
        # Wrap a seq2seq backbone that can both generate summaries and score candidates.
        if is_pegasus:
            self.model = PegasusScorer.from_pretrained(mname, cache_dir="./local_cache")
        else:
            self.model = BartScorer.from_pretrained(mname, cache_dir="./local_cache")
        self.pad_token_id = pad_token_id

    def forward(self, text_id, candidate_id, normalize=True, score_mode="base", length_penalty=1, require_gold=True, adding=0):
        # text_id: [bz, src_len]; candidate_id: [bz, cand_num, seq_len], with the
        # gold summary at index 0 when require_gold is True.
        batch_size = text_id.size(0)
        input_mask = text_id != self.pad_token_id
        cand_mask = candidate_id != self.pad_token_id
        cand_mask[:, :, 0] = 1
        output = self.model(
            input_ids=text_id,
            attention_mask=input_mask,
            decoder_input_ids=candidate_id,
            decoder_attention_mask=cand_mask,
            output_hidden_states=True
        )
        output = output[0]  # [bz x cand_num, seq_len, word_dim]
        output = output.view(batch_size, -1, output.size(1), output.size(2))  # [bz, cand_num, seq_len, word_dim]
        probs = output[:, 0]
        output = output[:, :, :-1]  # truncate last token
        candidate_id = candidate_id[:, :, 1:]  # shift right
        cand_mask = candidate_id != self.pad_token_id
        candidate_id = candidate_id.unsqueeze(-1)
        if normalize:
            if score_mode == "log":
                _output = F.log_softmax(output, dim=3)
            else:
                _output = F.softmax(output, dim=3)
            scores = torch.gather(_output, 3, candidate_id).squeeze(-1)  # [bz, cand_num, seq_len]
        else:
            scores = torch.gather(output, 3, candidate_id).squeeze(-1)  # [bz, cand_num, seq_len]
        cand_mask = cand_mask.float()
        # length-normalized sequence score per candidate
        scores = torch.mul(scores, cand_mask).sum(-1) / ((cand_mask.sum(-1) + adding) ** length_penalty)  # [bz, cand_num]
        if require_gold:
            output = {'score': scores[:, 1:], "summary_score": scores[:, 0], "probs": probs}
        else:
            output = {'score': scores, "probs": probs}
        return output

    def scoring_mode(self):
        # switch the backbone to scoring mode (score multiple candidates per source in one pass)
        self.model.model.scoring_mode()

    def generation_mode(self):
        # switch the backbone back to standard autoregressive generation
        self.model.model.generation_mode()

    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        encoder_no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        max_time: Optional[float] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        num_beam_groups: Optional[int] = None,
        diversity_penalty: Optional[float] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        remove_invalid_values: Optional[bool] = None,
        synced_gpus: Optional[bool] = None,
        **model_kwargs,
    ):
        # thin pass-through to the underlying Hugging Face generate()
        return self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            min_length=min_length,
            do_sample=do_sample,
            early_stopping=early_stopping,
            num_beams=num_beams,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            bad_words_ids=bad_words_ids,
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
            max_time=max_time,
            decoder_start_token_id=decoder_start_token_id,
            use_cache=use_cache,
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            forced_bos_token_id=forced_bos_token_id,
            forced_eos_token_id=forced_eos_token_id,
            remove_invalid_values=remove_invalid_values,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )
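

# A minimal end-to-end sketch of candidate scoring with BRIO. This is illustrative
# only: the checkpoint name, tokenizer, example texts, and hyperparameters below
# are assumptions, and the repo's training/evaluation scripts build the candidate
# tensors from preprocessed data rather than ad hoc strings like these.
if __name__ == "__main__":
    from transformers import BartTokenizer

    mname = "facebook/bart-large-cnn"  # assumed BART checkpoint
    tok = BartTokenizer.from_pretrained(mname)
    model = BRIO(mname, tok.pad_token_id, is_pegasus=False)
    model.scoring_mode()  # score multiple candidates per source in one pass

    src = tok(["the quick brown fox jumps over the lazy dog"], return_tensors="pt")
    # Index 0 holds the reference summary; the rest are system candidates,
    # assumed to be sorted from best to worst.
    cands = ["a fox jumps over a dog", "a fox jumps over something", "a fox"]
    cand_ids = tok(cands, return_tensors="pt", padding=True).input_ids.unsqueeze(0)  # [1, cand_num, seq_len]

    out = model(src.input_ids, cand_ids, score_mode="log", length_penalty=2.0)
    # out["score"]: [1, cand_num - 1]; out["summary_score"]: [1]
    loss = RankingLoss(out["score"], out["summary_score"], margin=0.001)
    print(out["score"].shape, loss.item())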