-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
36 lines (28 loc) · 904 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""
Parse command-line options and train a model with ``Trainer``.
"""
########################################################################
## Uncomment if you want to disable TensorFlow debugging information. ##
# import os
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
########################################################################
import tensorflow as tf
from utils.pretty_print import *
from trainer_tester.trainer import Trainer
from trainer_tester.options import parse_args
def _enable_gpu_memory_growth():
    """Enable on-demand memory allocation on every visible GPU.

    Best-effort: problems are reported with ``WARN`` and execution continues,
    so training can still run (e.g. on CPU) when no GPU is configurable.
    """
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if not gpus:
        # Previously gpus[0] raised IndexError here and produced a misleading
        # "Invalid device" warning.
        WARN("No GPU found; memory growth not enabled.")
        return
    # TensorFlow rejects mixed memory-growth settings across GPUs, so apply
    # the setting to every device rather than only the first one.
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except (ValueError, RuntimeError):
            # ValueError: invalid device.
            # RuntimeError: runtime/devices already initialized.
            WARN("Invalid device or cannot modify virtual devices once initialized.")
            return


def main():
    """Parse command-line arguments, configure GPU memory growth, and train."""
    # Parse arguments
    args = parse_args()
    # Must run before any op initializes the GPU runtime.
    _enable_gpu_memory_growth()
    # Initialize trainer and train
    trainer = Trainer(args)
    trainer.fit()
# Start training only when this file is executed directly, not on import.
if __name__ == '__main__':
    main()