-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmain_complex.py
39 lines (29 loc) · 1.11 KB
/
main_complex.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import sys

from data_loader.data_generator import DataGenerator
from models.complex_conv_model import ComplexConvModel
from trainers.multi_label_conv_model_trainer import MultiLabelConvModelTrainer
from utils.config import process_config
from utils.dirs import create_dirs
from utils.utils import get_args
def main():
    """Run the complex-conv-model training pipeline end to end.

    Steps, in order:
      1. Read the run arguments and process the JSON configuration file
         named by ``args.config``.
      2. Create the experiment directories (``config.summary_dir`` and
         ``config.checkpoint_dir``).
      3. Build the data generator, the ``ComplexConvModel`` (given the
         generator's word index) and the ``MultiLabelConvModelTrainer``
         (given the wrapped ``model.model`` and the generator's training
         data).
      4. Train, then visualize the losses.

    If reading the arguments or processing the configuration raises an
    ``Exception``, a message is printed to stderr and the process exits
    with status 1.
    """
    # Capture the config path from the run arguments, then process the JSON
    # configuration file. ``Exception`` is caught instead of using a bare
    # ``except:`` so that ``SystemExit`` (e.g. from a help request or usage
    # error during argument parsing -- TODO confirm get_args behavior) and
    # ``KeyboardInterrupt`` propagate with their own status instead of being
    # reported as "missing or invalid arguments".
    try:
        args = get_args()
        config = process_config(args.config)
    except Exception as err:
        print("missing or invalid arguments", file=sys.stderr)
        print("  %s: %s" % (type(err).__name__, err), file=sys.stderr)
        # Non-zero status so callers see the failure (the old exit(0)
        # reported success). sys.exit is used because the bare ``exit``
        # helper is injected by the ``site`` module and is not guaranteed
        # to exist.
        sys.exit(1)

    # Create the experiment directories named in the config.
    create_dirs([config.summary_dir, config.checkpoint_dir])

    print('Create the data generator.')
    data_generator = DataGenerator(config)

    print('Create the model.')
    # The model is built with the word index exposed by the data generator.
    model = ComplexConvModel(config, data_generator.get_word_index())

    print('Create the trainer')
    # The trainer receives the wrapped model object (``model.model``), not
    # the ComplexConvModel wrapper itself.
    trainer = MultiLabelConvModelTrainer(
        model.model, data_generator.get_train_data(), config
    )

    print('Start training the model.')
    trainer.train()

    print('Visualize the losses')
    trainer.visualize()
if __name__ == "__main__":
    # Start the training pipeline only when this file is executed as a
    # script; importing the module does not trigger a run.
    main()