forked from penny4860/tf2-eager-yolo3
-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_eager.py
36 lines (27 loc) · 928 Bytes
/
train_eager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# -*- coding: utf-8 -*-
import tensorflow as tf
import argparse
from yolo.train import train_fn
from yolo.config import ConfigParser
# Command-line interface for this training script. It exposes a single
# optional flag, ``-c`` / ``--config``, giving the path of the config file
# (defaults to ``configs/svhn.json``).
argparser = argparse.ArgumentParser(description='train yolo-v3 network')
argparser.add_argument('-c', '--config',
                       help='config file',
                       default="configs/svhn.json")
def main(argv=None):
    """Run the full training pipeline described by a config file.

    Steps: parse the command line, build a ``ConfigParser`` from the chosen
    config path, create the train/validation generators and the model, then
    hand everything to ``train_fn`` with the training parameters read from
    the config.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` means "use ``sys.argv[1:]``", which is what
            ``argparse`` does by default, so running the script is unchanged.
    """
    args = argparser.parse_args(argv)
    config_parser = ConfigParser(args.config)

    # 1. create generator
    train_generator, valid_generator = config_parser.create_generator()

    # 2. create model
    model = config_parser.create_model()

    # 3. training
    learning_rate, save_dname, n_epoches = config_parser.get_train_params()
    # Keyword names (including the ``num_epoches`` spelling) must match
    # train_fn's signature — NOTE(review): confirm against yolo.train.
    train_fn(model,
             train_generator,
             valid_generator,
             learning_rate=learning_rate,
             save_dname=save_dname,
             num_epoches=n_epoches)


if __name__ == '__main__':
    main()