-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
124 lines (94 loc) · 3.93 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, regularizers
import json
from os import path
def plot_sample(X, rows, cols, tensor=False):
    '''Plot the first ``rows * cols`` images of ``X`` on a grid of subplots.

    Args:
        X: Indexable collection of images; ``X[k]`` is handed to ``imshow``.
            It must hold at least ``rows * cols`` entries.
        rows: Number of subplot rows.
        cols: Number of subplot columns.
        tensor: If true, ``X`` is first converted with ``X.numpy()``
            (presumably a TensorFlow tensor -- confirm against callers).

    The figure is displayed with ``plt.show()``; nothing is returned.
    '''
    # If the input is a tensor, convert it to an np.array before plotting.
    if tensor:
        X = X.numpy()

    # squeeze=False keeps ``axs`` two-dimensional even when rows == 1 or
    # cols == 1. With the default, the result collapses to a 1-D array (or a
    # single Axes) and the two-index access below would fail.
    _, axs = plt.subplots(rows, cols, figsize=(8, 8), squeeze=False)

    for i in range(rows):
        for j in range(cols):
            ax = axs[i, j]
            # Hide the tick labels; they carry no information for images.
            ax.xaxis.set_ticklabels([])
            ax.yaxis.set_ticklabels([])
            # Images are consumed in row-major order.
            ax.imshow(X[i * cols + j])

    # The layout only needs to be computed once, after every cell is filled.
    plt.tight_layout()
    plt.show()
def _save_curve_figure(epochs, train_values, val_values, train_label,
                       val_label, y_label, out_file):
    '''Draw one train/validation curve pair on a new figure and save it.'''
    plt.figure(figsize=(12, 10))  # figure size is (width, height) in inches
    plt.plot(epochs, train_values, c="red", label=train_label)
    plt.plot(epochs, val_values, c="blue", label=val_label)
    plt.xlabel("Epochs")
    plt.ylabel(y_label)
    plt.legend(fontsize=11)
    # savefig writes the current figure, i.e. the one created just above.
    plt.savefig(out_file)


def plot_training_curves(history, save_path):
    '''Plot training learning curves for both train and validation.

    Args:
        history: Nested sequence indexed as ``history[metric][split]``:
            metric 0 is the loss and metric 1 the accuracy; split 0 is the
            training series and split 1 the validation series.
        save_path: Directory the image files are written to.

    Two files are written into ``save_path``:
        * ``training_curve.png`` -- the accuracy curves.
        * ``training_loss_curve.png`` -- the loss curves. Previously this
          figure was drawn but never saved.

    Side effects: the global matplotlib ``font.size`` rcParam is set to 22,
    and both figures are left open, so a later ``plt.show()`` still displays
    them.
    '''
    # Metrics to plot.
    train_loss, val_loss = history[0][0], history[0][1]
    train_acc, val_acc = history[1][0], history[1][1]

    # Range for the X axis: one point per recorded loss value.
    epochs = range(len(train_loss))

    # Configure the font size before any axes text is created.
    plt.rcParams.update({'font.size': 22})

    # Loss figure.
    _save_curve_figure(
        epochs, train_loss, val_loss,
        "Training Loss", "Validation Loss", "Loss",
        path.join(save_path, 'training_loss_curve.png'),
    )

    # Accuracy figure.
    _save_curve_figure(
        epochs, train_acc, val_acc,
        "Training Acc", "Validation Acc", "Accuracy",
        path.join(save_path, 'training_curve.png'),
    )
class job_receiver:
    '''
    Basic job receiver class. Receives the path of a json file and
    produces the parsed content (a dictionary for the job files used
    by this script) when the instance is called.
    '''

    def __init__(self, path: str):
        # Location of the json job description; it is re-read on every call.
        self.path = path

    def __call__(self):
        '''Parse the json file at ``self.path`` and return its content.'''
        # JSON text is UTF-8 by specification, so decode it explicitly
        # instead of relying on the platform's default encoding.
        with open(self.path, 'r', encoding='utf-8') as file:
            return json.load(file)
if __name__ == '__main__':
    # Run this script to understand the implemented functionality.
    # NOTE(review): RotNet_constructor and PredNet_constructor are not defined
    # or imported in the visible part of this file -- confirm they are brought
    # into scope before running this demo.

    ###########################################################################
    # RotNet example.
    ###########################################################################
    # Path of the json config for RotNet construction and training.
    rotnet_path = './rotnet_config_example.json'
    # RotNet job (dictionary form).
    rotnet_job = job_receiver(rotnet_path)()
    # Build the RotNet model from the 'build_instructions' entry of the job;
    # see the json file at rotnet_path for the expected content.
    RotNet = RotNet_constructor(rotnet_job['build_instructions'])
    # summary() prints the table itself and returns None (assuming a Keras
    # model, as it is saved through the Keras API below), so wrapping it in
    # print() would only add a stray "None" line.
    RotNet.summary()
    # Here we imagine training to happen, with options taken from
    # rotnet_job['training'].
    # Save the whole model.
    tf.keras.models.save_model(RotNet, rotnet_job['save_path'])

    ###########################################################################
    # PredNet example.
    ###########################################################################
    # Path of the json config for PredNet construction and training.
    prednet_path = './prednet_config_example.json'
    # PredNet job (dictionary form).
    prednet_job = job_receiver(prednet_path)()
    # Build the PredNet model from the 'build_instructions' entry of the job;
    # see the json file at prednet_path for the expected content.
    PredNet = PredNet_constructor(prednet_job['build_instructions'])
    PredNet.summary()
    # Here we imagine training to happen, with options taken from
    # prednet_job['training'].
    # Save only the weights for this model.
    PredNet.save_weights(prednet_job['save_path'])