GPO.py
# -*- coding: utf-8 -*-
"""Implementation of GPO (Shang et al., 2019)
"""
# Author: Wenjie Li <li3549@purdue.edu>
# License: MIT

import numpy as np

from PyXAB.algos.Algo import Algorithm
from PyXAB.partition.BinaryPartition import BinaryPartition


class GPO(Algorithm):
    """
    Implementation of the General Parallel Optimization (GPO) algorithm (Shang et al., 2019)
    """

    def __init__(
        self,
        numax=1.0,
        rhomax=0.9,
        rounds=1000,
        domain=None,
        partition=BinaryPartition,
        algo=None,
    ):
        """
        Initialization of the wrapper algorithm

        Parameters
        ----------
        numax: float
            parameter nu_max in the algorithm (default 1.0)
        rhomax: float
            parameter rho_max in the algorithm, the maximum rho used (default 0.9)
        rounds: int
            the number of rounds/budget (default 1000)
        domain: list(list)
            the domain of the objective function
        partition:
            the partition used in the optimization process
        algo:
            the baseline algorithm used by the wrapper, such as T_HOO or HCT
        """
        super(GPO, self).__init__()
        if domain is None:
            raise ValueError("Parameter space is not given.")
        if partition is None:
            raise ValueError("Partition of the parameter space is not given.")
        if algo is None:
            raise ValueError("Algorithm for GPO is not given.")
        if (
            algo.__name__ != "T_HOO"
            and algo.__name__ != "HCT"
            and algo.__name__ != "VHCT"
        ):
            raise NotImplementedError(
                "GPO has not yet included implementations for this algorithm"
            )
        self.rounds = rounds
        self.rhomax = rhomax
        self.numax = numax
        self.Dmax = np.log(2) / np.log(1 / rhomax)
        self.domain = domain
        self.partition = partition
        self.algo = algo

        # The big-N in the algorithm
        self.N = np.ceil(
            0.5 * self.Dmax * np.log((self.rounds / 2) / np.log(self.rounds / 2))
        )
        # phase number
        self.phase = 1
        # Starts with a none algorithm
        self.curr_algo = None
        self.half_phase_length = np.floor(self.rounds / (2 * self.N))
        self.counter = 0
        self.goodx = None
        # The cross-validation list
        self.V_x = []
        self.V_reward = []
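
    # Phase schedule used by pull() and receive_reward() below, derived from the
    # quantities computed in __init__ (T denotes the total budget `rounds`):
    #   number of phases       N     = ceil(0.5 * Dmax * ln((T / 2) / ln(T / 2)))
    #   half-phase length      H     = floor(T / (2 * N))
    #   smoothness in phase i  rho_i = rhomax ** (2 * N / (2 * i + 1)),  i = 1, ..., N
    # Each phase runs the base algorithm with parameters (numax, rho_i) for H rounds,
    # then replays the last point it pulled for another H rounds to cross-validate it.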
    def pull(self, time):
        """
        The pull function of GPO that returns a point to be evaluated

        Parameters
        ----------
        time: int
            The time step of the online process.

        Returns
        -------
        point: list
            The point chosen by the GPO algorithm
        """
        if self.phase > self.N:  # If already finished
            return self.goodx
        else:
            if self.counter == 0:
                rho = self.rhomax ** (2 * self.N / (2 * self.phase + 1))
                if self.algo.__name__ == "T_HOO":
                    self.curr_algo = self.algo(
                        nu=self.numax,
                        rho=rho,
                        rounds=self.rounds,
                        domain=self.domain,
                        partition=self.partition,
                    )
                elif self.algo.__name__ == "HCT" or self.algo.__name__ == "VHCT":
                    self.curr_algo = self.algo(
                        nu=self.numax,
                        rho=rho,
                        domain=self.domain,
                        partition=self.partition,
                    )
                # TODO: add more algorithms that do not need nu or rho
            if self.counter < self.half_phase_length:
                point = self.curr_algo.pull(time)
                self.goodx = point
            elif self.counter == self.half_phase_length:
                point = self.goodx
                self.V_x.append(point)
                self.V_reward.append(0)
            else:
                point = self.goodx
            if self.counter >= 2 * self.half_phase_length:
                self.phase += 1
                self.counter = 0
            return point
    def receive_reward(self, time, reward):
        """
        The receive_reward function of GPO to receive the reward for the chosen point

        Parameters
        ----------
        time: int
            The time step of the online process.
        reward: float
            The (stochastic) reward of the pulled point

        Returns
        -------
        """
        if self.phase > self.N:  # If already finished
            pass
        elif self.phase == self.N:
            maxind = np.argmax(np.array(self.V_reward))
            self.goodx = self.V_x[maxind]
        else:
            if self.counter < self.half_phase_length:
                self.curr_algo.receive_reward(time, reward)
            else:
                # Running average of the rewards of the replayed point in this phase
                self.V_reward[self.phase - 1] = (
                    self.V_reward[self.phase - 1]
                    * (self.counter - self.half_phase_length)
                    + reward
                ) / (self.counter - self.half_phase_length + 1)
        self.counter += 1
    def get_last_point(self):
        """
        The function to get the last point in GPO

        Returns
        -------
        point: list
            The point with the highest cross-validation reward
        """
        return self.V_x[np.argmax(np.array(self.V_reward))]
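

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of wrapping
# HCT with GPO on a toy 1-D objective, following the pull/receive_reward loop
# defined above. The import path `PyXAB.algos.HCT` and the noise model are
# assumptions for illustration only; adapt them to your installation.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PyXAB.algos.HCT import HCT  # assumed import path for the HCT baseline

    T = 1000
    domain = [[0.0, 1.0]]  # 1-D domain; points returned by pull() are lists, e.g. [x]

    def objective(point):
        # Toy target with its maximum at x = 0.5
        return 1.0 - (point[0] - 0.5) ** 2

    rng = np.random.default_rng(0)
    gpo = GPO(rounds=T, domain=domain, partition=BinaryPartition, algo=HCT)
    for t in range(1, T + 1):
        point = gpo.pull(t)
        reward = objective(point) + 0.1 * rng.normal()  # noisy evaluation
        gpo.receive_reward(t, reward)

    print("GPO recommendation:", gpo.get_last_point())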