-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnote_train.py
90 lines (73 loc) · 2.65 KB
/
note_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import msvcrt
import os

try:
    import _pickle as pickle
except ImportError:
    import cPickle as pickle

import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np

import data_util
import NN_util
# Size of one flattened input sample; presumably 30x50-pixel note images
# (TODO confirm against data_util / NN_util.create_net).
input_size = 30*50
datapath = 'C:/Users/ykane/Documents/music symbols datasets/notes'

# Load the note dataset once at import time; train_NN() reads these globals.
du = data_util.DataUtil()
[train_set, train_labels], [test_set, test_labels] = du.read_dataset(datapath)

# keras.utils.to_categorical(y, num_classes) requires every label to be
# < num_classes. Counting distinct training labels only works when the ids are
# exactly 0..K-1, so size the output layer from the largest label seen in
# either split instead (identical result for contiguous 0-based labels).
output_classes = int(max(np.max(train_labels), np.max(test_labels))) + 1
print('num of classes: %d' % output_classes)

# Convert to float and scale by 1/255 (presumably 8-bit pixel values -> [0, 1]).
train_set = train_set.astype('float32')
test_set = test_set.astype('float32')
train_set /= 255
test_set /= 255
print(train_set.shape)
def train_NN(reload=False):
if not reload:
network = NN_util.create_net(input_size, output_classes)
accuracies = []
nets = []
acc = 0.0
iteration = 0
else:
with open('network_list.pickle', 'rb') as f:
nets = pickle.load(f)
if len(nets) == 0:
raise RuntimeError('networks not found!')
network = NN_util.net_from_data(nets[-1][0], nets[-1][1])
with open('accuracies.pickle', 'rb') as f:
accuracies = pickle.load(f)
acc = 0.0 if len(accuracies) == 0 else accuracies[-1]
iteration = len(nets)
while acc < 0.965 and iteration < 100:
print('start iteration %d' % (iteration + 1))
train_res = network.fit(train_set, keras.utils.to_categorical(train_labels, num_classes=output_classes),
epochs=10, batch_size=50,
validation_data=(test_set, keras.utils.to_categorical(test_labels, num_classes=output_classes)))
train_res = train_res.history
acc = train_res['val_acc']
accuracies += acc if type(acc) is list else [acc]
nets.append(NN_util.network_data(network))
iteration += 1
acc = acc[-1] if type(acc) is list else acc
print('eval acc: ' + str(acc))
if msvcrt.kbhit():
cmd = msvcrt.getch()
if cmd == b'x':
print('are you sure you want to stop? (y/n)')
if input() == 'y':
break
with open('network_list.pickle', 'wb') as f:
pickle.dump(nets, f)
with open('accuracies.pickle', 'wb') as f:
pickle.dump(accuracies, f)
du.plot(accuracies)
if __name__ == '__main__':
    # Resume whenever any saved state exists (a half-missing pair still fails
    # loudly rather than being silently overwritten); start a fresh network
    # only when neither file exists, e.g. on the very first run.
    has_saved_state = (os.path.exists('network_list.pickle')
                       or os.path.exists('accuracies.pickle'))
    train_NN(reload=has_saved_state)

# Disabled manual shuffle, kept for reference. Keras Model.fit already shuffles
# the training data before each epoch by default (shuffle=True):
# permutation = np.arange(len(train_set))
# np.random.shuffle(permutation)
# train_set = train_set[permutation]
# print(type(permutation))
# print(permutation)
# train_labels = train_labels[permutation]