-
Notifications
You must be signed in to change notification settings - Fork 1
/
mnist_dataset.py
114 lines (77 loc) · 2.74 KB
/
mnist_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import datasets.mnist.mnist_loader as mnist
from nn import NerualNetwork
import random
from utils import printProgressBar
from random import randrange
import pickle
import signal
import sys
from multiprocessing import Process
import keyboard
import time
# If the MNIST files are not on disk yet, run mnist.init() once before loading.
# mnist.init()

# Loaded once at import time; every function below reads from it.
# Presumably (train images, train labels, test images, test labels) — TODO
# confirm against datasets.mnist.mnist_loader.
data = mnist.load()

# Training hyper-parameters.
LEARNING_RATE = 0.1
ITERATIONS = 100000
def _one_hot(label, size=10):
    """Return a list of *size* ints with 1 at index *label* and 0 elsewhere."""
    return [1 if j == label else 0 for j in range(size)]


def _random_training_sample():
    """Draw one random training example as ``(inputs, one_hot_target)``.

    Pixel values from ``data[0]`` are divided by 255; the matching label from
    ``data[1]`` is one-hot encoded over 10 classes.
    """
    # Sample from the whole training set; a hard-coded upper bound silently
    # ignores every example beyond it. Assumes data[0] and data[1] have the
    # same length — TODO confirm against mnist_loader.
    index = randrange(len(data[0]))
    inputs = [x / 255 for x in data[0][index]]
    return inputs, _one_hot(data[1][index])


def _save_model(nn):
    """Pickle *nn* to models/model1.pkl, creating the models/ directory first."""
    import os

    # Without this, a missing directory only raises after training has
    # finished, and the trained network is lost.
    os.makedirs('models', exist_ok=True)
    with open('models/model1.pkl', 'wb') as output:
        pickle.dump(nn, output, pickle.HIGHEST_PROTOCOL)


def _run_training(nn, start):
    """Train *nn* for iterations ``start`` .. ITERATIONS - 1, then save it.

    Each iteration trains on one random sample and redraws the progress bar.
    Holding the 'x' key stops training early; the model is saved either way.
    """
    # Start the bar where training actually begins rather than always at 0.
    done = min(start, ITERATIONS)
    printProgressBar(done, ITERATIONS, prefix='Progress:',
                     suffix='Complete | ' + str(done) + ' Iterations', length=50)
    for i in range(start, ITERATIONS):
        printProgressBar(i + 1, ITERATIONS, prefix='Progress:',
                         suffix='Complete | ' + str(i + 1) + ' Iterations', length=50)
        inputs, targets = _random_training_sample()
        nn.train(inputs, targets)
        if keyboard.is_pressed('x'):
            break
    _save_model(nn)


def trainModel():
    """Train a fresh network from iteration 0 and pickle it to models/model1.pkl.

    The network is built as ``NerualNetwork(784, 64, 10, LEARNING_RATE)``
    (presumably input / hidden / output sizes — see nn.NerualNetwork).
    """
    nn = NerualNetwork(784, 64, 10, LEARNING_RATE)
    _run_training(nn, 0)


def resumeTraining():
    """Load models/model1.pkl and continue training from ``nn.iterations``.

    The updated model is written back to the same file.
    """
    with open('models/model1.pkl', 'rb') as model_file:
        nn = pickle.load(model_file)
    # Presumably nn.train() advances nn.iterations — verify in nn.NerualNetwork.
    _run_training(nn, nn.iterations)
def test(samples=1000):
    """Print the saved model's iteration count and its test-set accuracy.

    Loads the pickled network from models/model1.pkl. It feeds the first
    *samples* test images (pixel values divided by 255) through
    ``nn.feedforward``. A prediction counts as correct when the index of the
    largest output equals the corresponding label in ``data[3]``.

    Args:
        samples: number of test examples to score. Defaults to 1000 and is
            clamped to the number of available test images/labels.

    Raises:
        ValueError: if there are no test examples to score.
    """
    with open('models/model1.pkl', 'rb') as model_file:
        nn = pickle.load(model_file)
    print('Iterations : ', nn.iterations)

    # Never index past the end of the test set.
    total = min(samples, len(data[2]), len(data[3]))
    if total <= 0:
        raise ValueError('no test samples to evaluate')

    correct = 0
    for image, label in zip(data[2][:total], data[3][:total]):
        outputs = nn.feedforward([x / 255 for x in image])
        # Predicted digit is the index of the largest output (first on ties).
        # Unlike a running max seeded with 0, this still picks a class when
        # every output is <= 0. -1 is returned only for an empty output vector.
        digit = max(range(len(outputs)), key=lambda r: outputs[r], default=-1)
        if digit == label:
            correct += 1

    # Multiply before dividing so e.g. 580 correct prints as 58.0, not
    # 57.99999999999999.
    print('Accuracy : ', str(correct * 100 / total) + '%')
# Choose what the script does by calling the matching function here:
# trainModel() trains a new network from scratch, resumeTraining() continues
# from models/model1.pkl, and test() reports accuracy on the test set.
test()