-
Notifications
You must be signed in to change notification settings - Fork 7
/
main.py
105 lines (82 loc) · 4.01 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import torch
import random
import numpy as np
import argparse
# Fix the seeds of the torch, `random` and numpy generators to 0 for repeatable runs.
# NOTE(review): this happens before the project imports below, presumably so that
# any randomness triggered while importing them is already seeded — confirm.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
from models import Model
from pathlib import Path
from utils import create_dir_if_not_exists, read_json, get_n_params
from dataset import Dataset, Ontology
def create_parser():
    """Build the command-line argument parser for this script.

    Returns:
        argparse.ArgumentParser: Parser with the required, training and
        miscellaneous options registered.
    """
    cli = argparse.ArgumentParser()

    # Required parameters
    cli.add_argument('--data_dir', required=True, type=Path)
    cli.add_argument(
        '--bert_model',
        required=True,
        type=str,
        choices=['bert-base-uncased', 'bert-large-uncased'],
    )
    cli.add_argument('--output_dir', required=True, type=Path)

    # Training parameters, registered in order as (flag, type, default)
    valued_options = (
        ('--epochs', int, 25),
        ('--batch_size', int, 16),
        ('--learning_rate', float, 2e-5),
        ('--warmup_proportion', float, 0.1),
        ('--gradient_accumulation_steps', int, 8),
    )
    for flag, value_type, default_value in valued_options:
        cli.add_argument(flag, type=value_type, default=default_value)
    cli.add_argument('--random_oversampling', action='store_true')

    # Other parameters
    cli.add_argument('--no_cuda', action='store_true')
    cli.add_argument('--do_train', action='store_true', help='Whether to run training.')
    cli.add_argument('--do_eval', action='store_true', help='Whether to run evaluation.')

    return cli
def load_dataset(base_path):
    """Read the train/dev/test splits and the ontology found under `base_path`.

    Args:
        base_path: Directory holding `train.json`, `dev.json`, `test.json`
            and `ontology.json`; it is combined with the file names using `/`.

    Returns:
        tuple: A dict keyed by 'train', 'dev' and 'test' with the result of
        `Dataset.from_dict` for each split, and the result of
        `Ontology.from_dict` for the ontology file.
    """
    splits = {
        split: Dataset.from_dict(read_json(base_path / '{}.json'.format(split)))
        for split in ('train', 'dev', 'test')
    }
    ontology = Ontology.from_dict(read_json(base_path / 'ontology.json'))
    return splits, ontology
def _dir_has_entries(path):
    """Return True when `path` exists and contains at least one entry."""
    return os.path.exists(path) and bool(os.listdir(path))


def main(opts):
    """Train and/or evaluate the model as requested by the parsed options.

    Args:
        opts: Parsed command-line options. The attributes read here are
            `do_train`, `do_eval`, `output_dir`, `no_cuda`, `data_dir` and
            `bert_model`. `opts.device` and `opts.n_gpus` are set on it
            before it is handed to the model.

    Raises:
        ValueError: If neither `do_train` nor `do_eval` is set, if training is
            requested while the output directory already has content, or if
            evaluation is requested while the output directory is missing or
            empty.
    """
    if not opts.do_train and not opts.do_eval:
        raise ValueError('At least one of `do_train` or `do_eval` must be True.')

    output_populated = _dir_has_entries(opts.output_dir)
    if opts.do_train and output_populated:
        raise ValueError('Output directory ({}) already exists and is not empty.'.format(opts.output_dir))
    # Evaluation without training needs an existing trained model: fail fast,
    # before creating the directory or spending time loading the datasets.
    if not opts.do_train and not output_populated:
        raise ValueError('Output directory ({}) is empty. Cannot do evaluation'.format(opts.output_dir))

    # Create the output directory (if not exists)
    create_dir_if_not_exists(opts.output_dir)

    # Decide device type
    use_cuda = torch.cuda.is_available() and not opts.no_cuda
    opts.device = torch.device('cuda' if use_cuda else 'cpu')
    opts.n_gpus = torch.cuda.device_count() if use_cuda else 0
    print('Device Type: {} | Number of GPUs: {}'.format(opts.device, opts.n_gpus), flush=True)

    # Load Datasets and Ontology
    dataset, ontology = load_dataset(opts.data_dir)
    print('Loaded Datasets and Ontology', flush=True)
    print('Number of Train Dialogues: {}'.format(len(dataset['train'])), flush=True)
    print('Number of Dev Dialogues: {}'.format(len(dataset['dev'])), flush=True)
    print('Number of Test Dialogues: {}'.format(len(dataset['test'])), flush=True)

    if opts.do_train:
        # Load model from scratch
        model = Model.from_scratch(opts.bert_model)
        model.move_to_device(opts)
        print('Number of model parameters is: {}'.format(get_n_params(model)), flush=True)

        # Start Training
        print('Start Training', flush=True)
        model.run_train(dataset, ontology, opts)

        # Drop the training model and release PyTorch's cached GPU memory
        # before the evaluation model is loaded.
        del model
        torch.cuda.empty_cache()

    if opts.do_eval:
        # Re-checked here because, after a training run, the directory must
        # now contain the saved model.
        if not _dir_has_entries(opts.output_dir):
            raise ValueError('Output directory ({}) is empty. Cannot do evaluation'.format(opts.output_dir))

        # Load trained model
        model = Model.from_model_path(opts.output_dir)
        model.move_to_device(opts)
        print('Number of model parameters is: {}'.format(get_n_params(model)), flush=True)

        # Start evaluating
        print('Start Evaluating', flush=True)
        print(model.run_dev(dataset, ontology, opts), flush=True)
        print(model.run_test(dataset, ontology, opts), flush=True)
if __name__ == '__main__':
    # Script entry point: parse the command line and pass the options to main().
    main(create_parser().parse_args())