-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path monk3_test_hsp.py
72 lines (63 loc) · 1.89 KB
/
monk3_test_hsp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import sys
from os import path
sys.path.insert(0, "./ISANet/")
sys.path.insert(0, "./")
from isanet.model import Mlp
from isanet.optimizer import SGD, NCG, LBFGS
from isanet.optimizer.utils import l_norm
from isanet.datasets.monk import load_monk
from isanet.utils.model_utils import printMSE, printAcc, plotHistory
import isanet.metrics as metrics
import numpy as np
import matplotlib.pyplot as plt
#############################
# Experiment configuration
#############################
# --- problem / model settings ---
monk = "2"        # MONK problem id passed to load_monk and shown in the plot title
# NOTE(review): this file is named monk3_test_hsp.py but selects MONK "2";
# confirm which problem is intended.
seed = 189        # value given to np.random.seed before the model is built
reg = 1e-4        # passed as kernel_regularizer to every layer
# --- optimizer / training settings ---
max_iter = 1000   # passed to model.fit as epochs
ng_eps = 3e-5     # passed to NCG as norm_g_eps (presumably a gradient-norm tolerance)
l_eps = 1e-6      # passed to NCG as l_eps
verbose = 2       # passed to model.fit as verbose
##############################
#########################################
# Build the objective: load the MONK training split selected by `monk`
# and assemble the Mlp. The NumPy seed is fixed first so that the initial
# weights w0 (presumably drawn through np.random) are reproducible.
#########################################
np.random.seed(seed=seed)
print("Load Monk DataSet")
X_train, Y_train = load_monk(monk, "train")
print("Build the model")
model = Mlp()
# Both layers use the same initializer value and regularization coefficient.
layer_opts = {"kernel_initializer": 0.003, "kernel_regularizer": reg}
model.add(4, input=17, **layer_opts)   # first layer, 17 input features
model.add(1, **layer_opts)             # last layer (1 unit)
#############################
# Nonlinear conjugate gradient, beta rule "hs+"
#############################
beta_method = "hs+"
c1 = 1e-4          # line-search constant c1 (presumably the Armijo/Wolfe c1)
c2 = 0.3           # line-search constant c2 (presumably the curvature/Wolfe c2)
restart = None
ln_maxiter = 100   # passed to NCG as ln_maxiter (presumably max line-search steps)
#############################
optimizer = NCG(
    beta_method=beta_method,
    c1=c1,
    c2=c2,
    restart=restart,
    ln_maxiter=ln_maxiter,
    norm_g_eps=ng_eps,
    l_eps=l_eps,
)
model.set_optimizer(optimizer)
print("Start the optimization process:")
model.fit(X_train, Y_train, epochs=max_iter, verbose=verbose)
# Loss trace recorded by fit under the "loss_mse_reg" key
# (presumably the regularized MSE per iteration).
# NOTE(review): the "_fr" suffix suggests Fletcher-Reeves, but this run uses "hs+".
f_fr = model.history["loss_mse_reg"]
##############################
# plot
##############################
# Plot the optimality gap f_k - f^*, approximating f^* by the last recorded
# loss value, on a logarithmic y axis.
figsize = (12, 4)
# np.asarray makes the subtraction work whether history holds a list or an array.
loss_hist = np.asarray(f_fr, dtype=float)
# The last entry of f_k - f^* is exactly 0 by construction. A zero cannot be
# drawn on a log-scaled axis, so that final point is left out.
gap = loss_hist[:-1] - loss_hist[-1]
plt.figure(figsize=figsize)
plt.plot(gap, linestyle='-')
plt.title('Monk{} - seed={} - (f_k - f^*)'.format(monk, seed))
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.grid()
plt.yscale('log')
# Label the curve with the beta rule actually used by the optimizer
# (previously hard-coded as 'PR+' although the run uses beta_method "hs+").
plt.legend(['NCG - {}'.format(beta_method.upper())], loc='upper right', fontsize='large')
plt.show()