diff --git a/ann.py b/ann.py
index d87fcb5..3a3afb8 100644
--- a/ann.py
+++ b/ann.py
@@ -1,8 +1,7 @@
 import numpy as np
 import matplotlib.pyplot as plt
 
-from util import getData, softmax, cost2, y2indicator, error_rate, relu
-from sklearn.utils import shuffle
+from util import getData, softmax, cost2, y2indicator, error_rate, relu,splitTrainTestFromLast
 
 
 class ANN(object):
@@ -11,10 +10,8 @@ def __init__(self, M):
 
     # learning rate 10e-6 is too large
     def fit(self, X, Y, learning_rate=10e-7, reg=10e-7, epochs=10000, show_fig=False):
-        X, Y = shuffle(X, Y)
-        Xvalid, Yvalid = X[-1000:], Y[-1000:]
         # Tvalid = y2indicator(Yvalid)
-        X, Y = X[:-1000], Y[:-1000]
+        Xvalid,Yvalid,X,Y = splitTrainTestFromLast(X,Y,1000)
 
         N, D = X.shape
         K = len(set(Y))
@@ -70,7 +67,7 @@ def score(self, X, Y):
 
 def main():
     X, Y = getData()
-    
+
     model = ANN(200)
     model.fit(X, Y, reg=0, show_fig=True)
     print model.score(X, Y)
diff --git a/util.py b/util.py
index fb5ab46..3b28a16 100644
--- a/util.py
+++ b/util.py
@@ -1,6 +1,6 @@
 import numpy as np
 import pandas as pd
-
+from sklearn.utils import shuffle
 
 def init_weight_and_bias(M1, M2):
     W = np.random.randn(M1, M2) / np.sqrt(M1 + M2)
@@ -121,3 +121,7 @@ def crossValidation(model, X, Y, K=5):
        errors.append(err)
        print "errors:", errors
    return np.mean(errors)
+
+def splitTrainTestFromLast(X,Y,N):
+    X,Y = shuffle(X,Y)
+    return (X[-N:],Y[-N:],X[:-N],Y[:-N])
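Note: for reference, a minimal standalone sketch of the new splitTrainTestFromLast helper and how the updated fit() call site consumes it. This assumes X and Y are NumPy arrays with the same number of rows; the sample shapes and class count below are illustrative only and not taken from the repo.

    import numpy as np
    from sklearn.utils import shuffle

    def splitTrainTestFromLast(X, Y, N):
        # Shuffle X and Y together (keeping rows aligned), then hold out
        # the last N rows as the validation set; the rest is training data.
        X, Y = shuffle(X, Y)
        return (X[-N:], Y[-N:], X[:-N], Y[:-N])

    # Illustrative usage: 5000 samples, 10 features, 7 classes, 1000 held out.
    X = np.random.randn(5000, 10)
    Y = np.random.randint(0, 7, size=5000)
    Xvalid, Yvalid, X, Y = splitTrainTestFromLast(X, Y, 1000)
    print(Xvalid.shape, X.shape)  # (1000, 10) (4000, 10)

The return order is validation set first, then training set, which is why the call in ANN.fit unpacks as Xvalid, Yvalid, X, Y.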