main.py
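"""Entry point for training and evaluating the meeting summarization model.

Example invocations (paths are illustrative placeholders):

    python main.py --mode train --save_path checkpoints/
    python main.py --mode eval --model_path checkpoints/checkpoint_10.pth
"""
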
import argparse
import os
import logging
import collections
from datetime import datetime
from config.hparams import *
from train import Summarization
import torch
from torch.utils.tensorboard import SummaryWriter


def init_logger(path):
    """Log INFO to the console and info.log, and full DEBUG output to debug.log."""
    if not os.path.exists(path):
        os.makedirs(path)
    logger = logging.getLogger()
    logger.handlers = []  # drop any handlers left over from previous runs
    logger.setLevel(logging.DEBUG)
    debug_fh = logging.FileHandler(os.path.join(path, "debug.log"))
    debug_fh.setLevel(logging.DEBUG)
    info_fh = logging.FileHandler(os.path.join(path, "info.log"))
    info_fh.setLevel(logging.INFO)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    info_formatter = logging.Formatter('%(asctime)s | %(levelname)-8s | %(message)s')
    debug_formatter = logging.Formatter('%(asctime)s | %(levelname)-8s | %(message)s | %(lineno)d:%(funcName)s')
    ch.setFormatter(info_formatter)
    info_fh.setFormatter(info_formatter)
    debug_fh.setFormatter(debug_formatter)
    logger.addHandler(ch)
    logger.addHandler(debug_fh)
    logger.addHandler(info_fh)
    return logger
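

# PARAMS (imported from config.hparams) is assumed to be a plain dict of default
# hyperparameters; the namedtuple wrapper below freezes it, so per-run settings
# are overridden via _replace.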
def train_model(args):
    hparams = PARAMS
    hparams = collections.namedtuple("HParams", sorted(hparams.keys()))(**hparams)
    save_path = args.save_path
    if save_path == '':
        raise ValueError("Must provide save_path!")
    hparams = hparams._replace(save_dirpath=save_path)
    hparams = hparams._replace(use_role=args.use_role)
    hparams = hparams._replace(use_pos=args.use_pos)
    print('hparams.save_dirpath: ', hparams.save_dirpath)
    summarization = Summarization(hparams, mode='train')
    summarization.train()


def evaluate_model(args):
    hparams = PARAMS
    hparams = collections.namedtuple("HParams", sorted(hparams.keys()))(**hparams)
    model_path = args.model_path
    if model_path == '':
        raise ValueError('Must provide model_path!')
    save_dirpath = '/'.join(model_path.split('/')[:-1]) + '/'
    hparams = hparams._replace(save_dirpath=save_dirpath)
    gen_max_length = args.gen_max_length
    print('gen_max_length: ', gen_max_length)
    hparams = hparams._replace(gen_max_length=gen_max_length)
    hparams = hparams._replace(use_role=args.use_role)
    hparams = hparams._replace(use_pos=args.use_pos)
    epoch = hparams.start_eval_epoch
    print('\n ========= [ Evaluation Start Epoch:', epoch, '] ========= ')
    # Evaluate every saved checkpoint from start_eval_epoch onward.
    for i in range(int(epoch), 100):
        load_pthpath = save_dirpath + 'checkpoint_' + str(i) + '.pth'
        if not os.path.exists(load_pthpath):
            break  # stop once the checkpoints run out
        hparams = hparams._replace(load_pthpath=load_pthpath)
        print('hparams.load_pthpath: ', hparams.load_pthpath)
        summarization = Summarization(hparams, mode='eval')
        summarization.predictor.evaluate(epoch=i,
                                         test_dataloader=summarization.test_dataloader,
                                         eval_path=load_pthpath)
        del summarization
        print('\n')


if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser(
        description="End-to-End Meeting Summarization (PyTorch)")
    arg_parser.add_argument("--mode", dest="mode", type=str, default="",
                            help="train or eval")
    arg_parser.add_argument("--model_path", dest="model_path", type=str, default="",
                            help="trained model path (eval mode)")
    arg_parser.add_argument("--save_path", dest="save_path", type=str, default="",
                            help="path to save the trained model (train mode)")
    arg_parser.add_argument("--gen_max_length", dest="gen_max_length", type=int,
                            default=500, help="maximum length of generated summaries")
    # argparse's type=bool would treat any non-empty string (even "False") as
    # True, so the boolean options are plain store_true flags instead.
    arg_parser.add_argument("--use_role", dest="use_role", action="store_true",
                            help="use role information")
    arg_parser.add_argument("--use_pos", dest="use_pos", action="store_true",
                            help="use POS information")
    args = arg_parser.parse_args()

    if args.mode == 'train':
        train_model(args)
    elif args.mode == 'eval':
        evaluate_model(args)
    else:
        arg_parser.error("--mode must be 'train' or 'eval'")