-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
executable file
·86 lines (61 loc) · 2.13 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
'''
main.py

Entry script of the project. At import/run time it:
  1. parses command-line arguments (utils.initArgParser),
  2. prepares output folders/files and starts a TensorBoard process for
     edit.version / edit.subversion,
  3. loads the models listed in edit.modelList (args.load is forwarded to
     utils.loadModels),
  4. then either runs edit.inferenceEpoch once (when args.inferenceTest is
     set) or alternates edit.trainEpoch / edit.validationEpoch up to
     Config.param.train.step.maxEpoch.
'''
# Version tag of this script. NOTE(review): the "210604" suffix looks like a
# YYMMDD date — confirm before relying on that format.
mainversion = "1.2.210604"
# From the Python standard library.
# NOTE(review): inspect, time, warnings and import_module are not referenced in
# this file; verify nothing depends on them before removing.
import inspect
import os
import time
import warnings
from importlib import import_module
# From PyTorch / torchvision (F and save_image are not referenced in this file).
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
# Project configuration; imported early because the GPU selection below reads it.
from backbone.config import Config
# Restrict the GPU(s) visible to this process to the id(s) in the project config.
# CUDA reads CUDA_VISIBLE_DEVICES when it is first initialized, so this must run
# before any CUDA call. PyTorch initializes CUDA lazily, so doing it after
# `import torch` is fine. It is kept above the project imports below, presumably
# because some of them may touch CUDA at import time — confirm before reordering.
os.environ["CUDA_VISIBLE_DEVICES"] = str(Config.param.general.GPUNum)
# From this project.
# NOTE(review): module, structure, vision, model, Epoch and ModelList are not
# referenced in this file; some may be imported only for their import-time side
# effects — verify before removing.
import backbone.module.module as module
import backbone.structure as structure
import backbone.utils as utils
import backbone.vision as vision
import edit
import model
from backbone.structure import Epoch
from edit import ModelList
# --- Command line, output folders, TensorBoard -------------------------------
# `args` is consumed by the model loading below and by the run section later on.
parser, args = utils.initArgParser()
utils.initFolderAndFiles(edit.version, edit.subversion)  # per-version output layout
utils.initTensorboardProcess(edit.version)  # launch TensorBoard for this version

# --- Banner ------------------------------------------------------------------
print(" Version : " + edit.version)
print(" sub Version : " + edit.subversion)

# --- Model loading -----------------------------------------------------------
# print("\n") emits exactly two newlines, i.e. two blank lines.
print("\n")
print("Load below models...")
# utils.loadModels returns the epoch to resume from plus the metadata object
# that is threaded through every Epoch.run call afterwards.
startEpoch, metaData = utils.loadModels(
    edit.modelList,
    edit.version,
    edit.subversion,
    args.load,
)
print(f"All model loaded. Last Epoch: {startEpoch}")
print("\n")
print("Init Epoch...")
# Decide the run mode exactly once, so the epoch objects fetched here and the
# dispatch below always take the same branch. Testing `== False` and `== True`
# separately lets a non-bool value (e.g. None) fetch the inference epoch and
# then enter the training loop, which fails with a NameError on trainEpoch.
inferenceMode = bool(args.inferenceTest)

# Fetch the epoch runners configured in `edit` for the selected mode.
if inferenceMode:
    inferenceEpoch = edit.inferenceEpoch
else:
    trainEpoch = edit.trainEpoch
    validationEpoch = edit.validationEpoch

print("")
print("")
print("Running...")

if inferenceMode:
    # One inference pass reported as epoch 0, with score calculation and model
    # saving disabled and do_resultSave='EVERY'.
    metaData = inferenceEpoch.run(
        currentEpoch=0,
        metaData=metaData,
        do_calculateScore=False,
        do_modelSave=False,
        do_resultSave='EVERY',
    )
else:
    # Train from the epoch returned by utils.loadModels up to (but excluding)
    # maxEpoch. Validation runs after every validationStep-th epoch, counting
    # epochs from 1 (hence `e + 1`). metaData is threaded through every call.
    for e in range(startEpoch, Config.param.train.step.maxEpoch):
        metaData = trainEpoch.run(currentEpoch=e, metaData=metaData)
        if (e + 1) % Config.param.train.step.validationStep == 0:
            metaData = validationEpoch.run(
                currentEpoch=e,
                metaData=metaData,
                do_calculateScore='DETAIL',
                do_modelSave=False,
                do_resultSave='EVERY',
            )