-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
74 lines (63 loc) · 2.79 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# SPDX-FileCopyrightText: 2024 Idiap Research Institute
#
# SPDX-License-Identifier: MIT
""" Main script to finetune seq2seq models. """
from pytorch_lightning import Trainer, loggers, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor, GPUStatsMonitor
from bart import BartSummarizer
from dataloader import SummarizationDataModule
from t5 import T5Summarizer
# Registry mapping the CLI `--model` choice to its summarizer class.
MODELS = dict(
    bart=BartSummarizer,
    t5=T5Summarizer,
)
def main(args):
    """Finetune a seq2seq summarization model selected via the CLI.

    Seeds all RNGs, builds the data module and the model class chosen by
    ``args.model``, wires up checkpointing / early-stopping / LR-monitoring
    callbacks and a TensorBoard logger, then runs ``Trainer.fit``.

    Args:
        args: Parsed argparse namespace carrying model, data and
            pytorch_lightning Trainer options (see the ``__main__`` block).
    """
    seed_everything(args.seed)
    data_module = SummarizationDataModule(args)
    model_class = MODELS[args.model]
    model = model_class(args)
    # ROUGE is a score (higher is better); loss is minimized.
    monitor_mode = 'max' if args.monitor == 'val_rouge' else 'min'
    # Keep only the single best checkpoint according to the monitored metric.
    model_checkpoint = ModelCheckpoint(
        dirpath=args.model_dir,
        filename=args.model + '-{epoch}-{' + args.monitor + ':.2f}',
        monitor=args.monitor,
        save_top_k=1,
        mode=monitor_mode,
    )
    early_stopping = EarlyStopping(args.monitor, mode=monitor_mode, patience=5)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    callbacks = [model_checkpoint, early_stopping, lr_monitor]
    # Only attach the GPU monitor when GPUs are actually requested:
    # `--gpus 0` still parses to an int, and GPUStatsMonitor raises a
    # MisconfigurationException at train start when no GPU device is used.
    if isinstance(args.gpus, int) and args.gpus > 0:
        gpu_monitor = GPUStatsMonitor(intra_step_time=True, inter_step_time=True)
        callbacks.append(gpu_monitor)
    # name='' avoids an extra subdirectory level under model_dir.
    logger = loggers.TensorBoardLogger(
        save_dir=args.model_dir,
        name='',
    )
    trainer = Trainer.from_argparse_args(
        args,
        callbacks=callbacks,
        logger=logger,
    )
    trainer.fit(model, datamodule=data_module)
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Generates interpretations from filtered announcements.')
    # Parse `--model` first (parse_known_args tolerates the not-yet-registered
    # flags) so the chosen model class can register its own arguments before
    # the full parse below.
    parser.add_argument('--model', default='bart', choices=MODELS.keys(), help='Model name')
    temp_args, _ = parser.parse_known_args()
    model_class = MODELS[temp_args.model]
    parser = model_class.add_model_specific_args(parser)
    # data args
    parser.add_argument('--data_dir', default='data', help='Path to data directory')
    parser.add_argument('--filter_model', default='oracle', choices=['oracle', 'lead'],
                        help='Model used for filtering input data')
    parser.add_argument('--num_workers', type=int, default=0, help='Num workers for data loading')
    parser.add_argument('--batch_size', type=int, default=5, help='Train batch size')
    # trainer args (gpus, max_epochs, precision, ... added by Lightning)
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument('--model_dir', default='models', help='Path to model directory')
    parser.add_argument('--monitor', default='val_rouge', choices=['val_rouge', 'val_loss'],
                        help='Monitor variable to select the best model')
    # type=int so a user-supplied seed reaches seed_everything as an int
    # rather than a string (argparse defaults to str for CLI-provided values).
    parser.add_argument('--seed', type=int, default=1, help='Random seed')
    main(parser.parse_args())