summarization.py

import os
import argparse
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
from transformers.optimization import get_linear_schedule_with_warmup, Adafactor
import nlp
from rouge_score import rouge_scorer
import pytorch_lightning as pl
from pytorch_lightning.logging import TestTubeLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from longformer import LongformerEncoderDecoderForConditionalGeneration, LongformerEncoderDecoderConfig
from longformer.sliding_chunks import pad_to_window_size
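
# Fine-tunes and evaluates a seq2seq summarization model (BART, Pegasus, or a Longformer
# encoder-decoder "long" checkpoint) on the arXiv subset of the `scientific_papers` dataset,
# using PyTorch Lightning for training and rouge_score for evaluation.
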
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """From fairseq"""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
        count = (~pad_mask).sum()
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
        count = nll_loss.numel()
    nll_loss = nll_loss.sum() / count
    smooth_loss = smooth_loss.sum() / count
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss

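# For reference, with smoothing factor epsilon and vocabulary size V, the loss above is
#   (1 - epsilon) * NLL(target) + (epsilon / V) * sum_v(-log p_v),
# i.e. epsilon of the probability mass is spread uniformly over the vocabulary;
# it reduces to the plain NLL loss when epsilon == 0.
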
class SummarizationDataset(Dataset):
    def __init__(self, hf_dataset, tokenizer, max_input_len, max_output_len):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        entry = self.hf_dataset[idx]
        input_ids = self.tokenizer.encode(entry['article'], truncation=True, max_length=self.max_input_len)
        output_ids = self.tokenizer.encode(entry['abstract'], truncation=True, max_length=self.max_output_len)
        if self.tokenizer.bos_token_id is None:  # pegasus has no <s>; use <pad> as the decoder start token
            output_ids = [self.tokenizer.pad_token_id] + output_ids
        return torch.tensor(input_ids), torch.tensor(output_ids)

    @staticmethod
    def collate_fn(batch):
        # A hack to know if this is bart or pegasus. DDP doesn't like global variables nor class-level member
        # variables, so infer the tokenizer from the last token of the first input sequence: BART ends with
        # eos id 2 (pad id 1), pegasus ends with eos id 1 (pad id 0).
        if batch[0][0][-1].item() == 2:
            pad_token_id = 1  # AutoTokenizer.from_pretrained('facebook/bart-base').pad_token_id
        elif batch[0][0][-1].item() == 1:
            pad_token_id = 0  # AutoTokenizer.from_pretrained('google/pegasus-large').pad_token_id
        else:
            assert False
        input_ids, output_ids = list(zip(*batch))
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
        output_ids = torch.nn.utils.rnn.pad_sequence(output_ids, batch_first=True, padding_value=pad_token_id)
        return input_ids, output_ids

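# Summarizer wraps the seq2seq model in a pl.LightningModule: it instantiates either a
# Longformer encoder-decoder or a standard Hugging Face seq2seq model, computes the
# (optionally label-smoothed) training loss, and reports validation loss plus ROUGE
# scores on greedily decoded summaries.
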
class Summarizer(pl.LightningModule):
    def __init__(self, params):
        super().__init__()
        self.args = params
        self.hparams = params
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer, use_fast=True)
        if 'long' in self.args.model_path:
            config = LongformerEncoderDecoderConfig.from_pretrained(self.args.model_path)
            config.attention_dropout = self.args.attention_dropout
            config.gradient_checkpointing = self.args.grad_ckpt
            config.attention_mode = self.args.attention_mode
            config.attention_window = [self.args.attention_window] * config.encoder_layers
            self.model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(
                self.args.model_path, config=config)
        else:
            config = AutoConfig.from_pretrained(self.args.model_path)
            config.attention_dropout = self.args.attention_dropout
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.args.model_path, config=config)
        self.train_dataloader_object = self.val_dataloader_object = self.test_dataloader_object = None

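    # _prepare_input builds the encoder attention mask: 0 marks padding, 1 local attention,
    # and 2 requests Longformer global attention (here only on the first token). For the
    # Longformer model it also pads the batch so the sequence length lines up with the
    # chunked self-attention of the selected attention mode (hence the two half_padding_mod values).
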
    def _prepare_input(self, input_ids):
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
        attention_mask[input_ids == self.tokenizer.pad_token_id] = 0
        if isinstance(self.model, LongformerEncoderDecoderForConditionalGeneration):
            attention_mask[:, 0] = 2  # global attention on one token for all model params to be used, which is important for gradient checkpointing to work
            if self.args.attention_mode == 'sliding_chunks':
                half_padding_mod = self.model.config.attention_window[0]
            elif self.args.attention_mode == 'sliding_chunks_no_overlap':
                half_padding_mod = self.model.config.attention_window[0] / 2
            else:
                raise NotImplementedError
            input_ids, attention_mask = pad_to_window_size(  # ideally, should be moved inside the LongformerModel
                input_ids, attention_mask, half_padding_mod, self.tokenizer.pad_token_id)
        return input_ids, attention_mask

    def forward(self, input_ids, output_ids):
        input_ids, attention_mask = self._prepare_input(input_ids)
        # standard teacher forcing: the decoder sees the target shifted right and predicts the next token
        decoder_input_ids = output_ids[:, :-1]
        decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id)
        labels = output_ids[:, 1:].clone()
        outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                use_cache=False,)
        lm_logits = outputs[0]
        if self.args.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
            assert lm_logits.shape[-1] == self.model.config.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, labels, self.args.label_smoothing, ignore_index=self.tokenizer.pad_token_id
            )
        return [loss]

    def training_step(self, batch, batch_nb):
        output = self.forward(*batch)
        loss = output[0]
        lr = loss.new_zeros(1) + self.trainer.optimizers[0].param_groups[0]['lr']
        tensorboard_logs = {'train_loss': loss, 'lr': lr,
                            'input_size': batch[0].numel(),
                            'output_size': batch[1].numel(),
                            'mem': torch.cuda.memory_allocated(loss.device) / 1024 ** 3 if torch.cuda.is_available() else 0}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        for p in self.model.parameters():
            p.requires_grad = False
        outputs = self.forward(*batch)
        vloss = outputs[0]
        input_ids, output_ids = batch
        input_ids, attention_mask = self._prepare_input(input_ids)
        generated_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                            use_cache=True, max_length=self.args.max_output_len,
                                            num_beams=1)
        generated_str = self.tokenizer.batch_decode(generated_ids.tolist(), skip_special_tokens=True)
        gold_str = self.tokenizer.batch_decode(output_ids.tolist(), skip_special_tokens=True)
        scorer = rouge_scorer.RougeScorer(rouge_types=['rouge1', 'rouge2', 'rougeL', 'rougeLsum'], use_stemmer=False)
        rouge1 = rouge2 = rougel = rougelsum = 0.0
        for ref, pred in zip(gold_str, generated_str):
            score = scorer.score(ref, pred)
            rouge1 += score['rouge1'].fmeasure
            rouge2 += score['rouge2'].fmeasure
            rougel += score['rougeL'].fmeasure
            rougelsum += score['rougeLsum'].fmeasure
        rouge1 /= len(generated_str)
        rouge2 /= len(generated_str)
        rougel /= len(generated_str)
        rougelsum /= len(generated_str)
        return {'vloss': vloss,
                'rouge1': vloss.new_zeros(1) + rouge1,
                'rouge2': vloss.new_zeros(1) + rouge2,
                'rougeL': vloss.new_zeros(1) + rougel,
                'rougeLsum': vloss.new_zeros(1) + rougelsum, }

    def validation_epoch_end(self, outputs):
        for p in self.model.parameters():
            p.requires_grad = True
        names = ['vloss', 'rouge1', 'rouge2', 'rougeL', 'rougeLsum']
        metrics = []
        for name in names:
            metric = torch.stack([x[name] for x in outputs]).mean()
            if self.trainer.use_ddp:
                torch.distributed.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
                metric /= self.trainer.world_size
            metrics.append(metric)
        logs = dict(zip(*[names, metrics]))
        print(logs)
        return {'avg_val_loss': logs['vloss'], 'log': logs, 'progress_bar': logs}

    def test_step(self, batch, batch_nb):
        return self.validation_step(batch, batch_nb)

    def test_epoch_end(self, outputs):
        result = self.validation_epoch_end(outputs)
        print(result)

    def configure_optimizers(self):
        if self.args.adafactor:
            optimizer = Adafactor(self.model.parameters(), lr=self.args.lr, scale_parameter=False, relative_step=False)
        else:
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
        if self.args.debug:
            return optimizer  # const LR
        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
        num_steps = self.args.dataset_size * self.args.epochs / num_gpus / self.args.grad_accum / self.args.batch_size
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.args.warmup, num_training_steps=num_steps
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

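    # Example of the num_steps arithmetic above: with the hardcoded dataset_size of 203037,
    # epochs=5, 8 GPUs, grad_accum=1 and batch_size=4, num_steps = 203037 * 5 / 8 / 1 / 4,
    # i.e. roughly 31,700 optimizer steps, over which the learning rate decays linearly
    # after the warmup.
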
    def _get_dataloader(self, current_dataloader, split_name, is_train):
        if current_dataloader is not None:
            return current_dataloader
        dataset = SummarizationDataset(hf_dataset=self.hf_datasets[split_name], tokenizer=self.tokenizer,
                                       max_input_len=self.args.max_input_len, max_output_len=self.args.max_output_len)
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train) if self.trainer.use_ddp else None
        return DataLoader(dataset, batch_size=self.args.batch_size, shuffle=(sampler is None),
                          num_workers=self.args.num_workers, sampler=sampler,
                          collate_fn=SummarizationDataset.collate_fn)

    @pl.data_loader
    def train_dataloader(self):
        self.train_dataloader_object = self._get_dataloader(self.train_dataloader_object, 'train', is_train=True)
        return self.train_dataloader_object

    @pl.data_loader
    def val_dataloader(self):
        self.val_dataloader_object = self._get_dataloader(self.val_dataloader_object, 'validation', is_train=False)
        return self.val_dataloader_object

    @pl.data_loader
    def test_dataloader(self):
        self.test_dataloader_object = self._get_dataloader(self.test_dataloader_object, 'test', is_train=False)
        return self.test_dataloader_object

    def configure_ddp(self, model, device_ids):
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            find_unused_parameters=False
        )
        return model

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        parser.add_argument("--save_dir", type=str, default='summarization')
        parser.add_argument("--save_prefix", type=str, default='test')
        parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
        parser.add_argument("--grad_accum", type=int, default=1, help="Number of gradient accumulation steps")
        parser.add_argument("--gpus", type=int, default=-1,
                            help="Number of gpus. 0 for CPU")
        parser.add_argument("--warmup", type=int, default=1000, help="Number of warmup steps")
        parser.add_argument("--lr", type=float, default=0.00003, help="Maximum learning rate")
        parser.add_argument("--val_every", type=float, default=1.0, help="How often to run validation, as a fraction of a training epoch")
        parser.add_argument("--val_percent_check", default=1.00, type=float, help="Fraction of validation data used")
        parser.add_argument("--num_workers", type=int, default=0, help="Number of data loader workers")
        parser.add_argument("--seed", type=int, default=1234, help="Seed")
        parser.add_argument("--epochs", type=int, default=5, help="Number of epochs")
        parser.add_argument("--disable_checkpointing", action='store_true', help="No logging or checkpointing")
        parser.add_argument("--max_output_len", type=int, default=256,
                            help="Maximum num of wordpieces in the summary. Used for training and testing")
        parser.add_argument("--max_input_len", type=int, default=512,
                            help="Maximum num of wordpieces in the input document. Used for training and testing")
        parser.add_argument("--test", action='store_true', help="Test only, no training")
        parser.add_argument("--model_path", type=str, default='facebook/bart-base',
                            help="Path to the checkpoint directory or model name")
        parser.add_argument("--tokenizer", type=str, default='facebook/bart-base')
        parser.add_argument("--no_progress_bar", action='store_true', help="No progress bar. Good for printing")
        parser.add_argument("--fp32", action='store_true', help="Default is fp16. Use --fp32 to switch to fp32")
        parser.add_argument("--debug", action='store_true', help="Debug run")
        parser.add_argument("--resume_ckpt", type=str, help="Path of a checkpoint to resume from")
        parser.add_argument("--from_pretrained", type=str, default=None,
                            help="Path to a checkpoint to load model weights but not training state")
        parser.add_argument('--grad_ckpt', action='store_true', help='Enable gradient checkpointing to save memory')
        parser.add_argument("--attention_dropout", type=float, default=0.1, help="Attention dropout")
        parser.add_argument("--attention_mode", type=str, default='sliding_chunks', help="Longformer attention mode")
        parser.add_argument("--attention_window", type=int, default=512, help="Attention window")
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--adafactor", action='store_true', help="Use adafactor optimizer")
        return parser

def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    if args.from_pretrained is not None:
        model = Summarizer.load_from_checkpoint(args.from_pretrained, args)
    else:
        model = Summarizer(args)
    model.hf_datasets = nlp.load_dataset('scientific_papers', 'arxiv')

    logger = TestTubeLogger(
        save_dir=args.save_dir,
        name=args.save_prefix,
        version=0  # always use version=0
    )
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(args.save_dir, args.save_prefix, "checkpoints"),
        save_top_k=5,
        verbose=True,
        monitor='avg_val_loss',
        mode='min',
        period=-1,
        prefix=''
    )

    print(args)
    args.dataset_size = 203037  # hardcode dataset size. Needed to compute number of steps for the lr scheduler

    trainer = pl.Trainer(gpus=args.gpus, distributed_backend='ddp' if torch.cuda.is_available() else None,
                         track_grad_norm=-1,
                         max_epochs=args.epochs if not args.debug else 100,
                         max_steps=None if not args.debug else 1,
                         replace_sampler_ddp=False,
                         accumulate_grad_batches=args.grad_accum,
                         val_check_interval=args.val_every if not args.debug else 1,
                         num_sanity_val_steps=2 if not args.debug else 0,
                         check_val_every_n_epoch=1 if not args.debug else 1,
                         val_percent_check=args.val_percent_check,
                         test_percent_check=args.val_percent_check,
                         logger=logger,
                         checkpoint_callback=checkpoint_callback if not args.disable_checkpointing else False,
                         show_progress_bar=not args.no_progress_bar,
                         use_amp=not args.fp32, amp_level='O2',
                         resume_from_checkpoint=args.resume_ckpt,
                         )
    if not args.test:
        trainer.fit(model)
    trainer.test(model)


if __name__ == "__main__":
    main_arg_parser = argparse.ArgumentParser(description="summarization")
    parser = Summarizer.add_model_specific_args(main_arg_parser, os.getcwd())
    args = parser.parse_args()
    main(args)

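# Example invocation (illustrative; the Longformer checkpoint path is a placeholder and the
# flags can be adjusted as needed):
#   python summarization.py \
#       --model_path /path/to/longformer-encdec-base-converted \
#       --tokenizer facebook/bart-base \
#       --max_input_len 4096 --max_output_len 256 \
#       --batch_size 2 --grad_accum 4 --epochs 5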