Skip to content

Commit faf00f9

Browse files
committed
ES exps
1 parent 3218b0d commit faf00f9

File tree

3 files changed

+160
-78
lines changed

3 files changed

+160
-78
lines changed

es.py

+36-78
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,51 @@
1+
from __future__ import print_function
12
import numpy as np
2-
import backtrader as bt
3+
import multiprocessing as mp
34

4-
from evostra import EvolutionStrategy
5-
from keras.models import Model, Input, Sequential
6-
from keras.layers import Dense, Activation
75

6+
class EvolutionStrategy(object):
87

9-
class ESStrategy(bt.Strategy):
10-
params = {
11-
'model': None
12-
}
8+
def __init__(self, weights, get_reward_func, population_size=50, sigma=0.1, learning_rate=0.001):
9+
np.random.seed(0)
10+
self.weights = weights
11+
self.get_reward = get_reward_func
12+
self.POPULATION_SIZE = population_size
13+
self.SIGMA = sigma
14+
self.LEARNING_RATE = learning_rate
1315

14-
def __init__(self):
15-
self.order = None
16-
self.dataclose = self.datas[0].close
17-
self.datavol = self.datas[0].volume
1816

19-
def stop(self):
20-
cash = self.broker.getvalue()
21-
print('Result cash: {}'.format(cash))
17+
def _get_weights_try(self, w, p):
18+
weights_try = []
19+
for index, i in enumerate(p):
20+
jittered = self.SIGMA*i
21+
weights_try.append(w[index] + jittered)
22+
return weights_try
2223

23-
def notify_order(self, order):
24-
if order.status in [order.Submitted, order.Accepted]:
25-
return
2624

27-
self.order = None
25+
def get_weights(self):
26+
return self.weights
2827

29-
def next(self):
30-
if self.order:
31-
return
3228

33-
input_data = []
34-
for i in range(7):
35-
input_data.append(self.dataclose[i-6])
36-
input_data.append(self.datavol[i-6])
37-
inp = np.asanyarray(input_data)
38-
inp = np.expand_dims(inp, 0)
29+
def run(self, iterations, print_step=10):
30+
for iteration in range(iterations):
3931

40-
predict = self.p.model.predict(inp)[0]
41-
predict = np.argmax(predict)
32+
if iteration % print_step == 0:
33+
print('iter %d. reward: %f' % (iteration, self.get_reward(self.weights)))
4234

43-
if not self.position:
44-
if predict == 0:
45-
self.order = self.buy()
46-
else:
47-
if predict == 1:
48-
self.order = self.sell()
35+
population = []
36+
rewards = np.zeros(self.POPULATION_SIZE)
37+
for i in range(self.POPULATION_SIZE):
38+
x = []
39+
for w in self.weights:
40+
x.append(np.random.randn(*w.shape))
41+
population.append(x)
4942

50-
if not self.position:
51-
if predict == 1:
52-
self.order = self.sell()
53-
else:
54-
if predict == 0:
55-
self.order = self.buy()
43+
for i in range(self.POPULATION_SIZE):
44+
weights_try = self._get_weights_try(self.weights, population[i])
45+
rewards[i] = self.get_reward(weights_try)
5646

57-
model = Sequential()
58-
model.add(Dense(128, input_dim=14, activation='relu'))
59-
model.add(Dense(256, activation='relu'))
60-
model.add(Dense(512, activation='relu'))
61-
model.add(Dense(1024, activation='relu'))
62-
model.add(Dense(2, activation='relu'))
47+
rewards = (rewards - np.mean(rewards)) / np.std(rewards)
6348

64-
model.compile(optimizer='Adam', loss='mse')
65-
66-
data = bt.feeds.GenericCSVData(
67-
dataname='eur_usd_1d.csv',
68-
separator=',',
69-
dtformat=('%Y%m%d'),
70-
tmformat=('%H%M00'),
71-
datetime=0,
72-
time=1,
73-
open=2,
74-
high=3,
75-
low=4,
76-
close=5,
77-
volume=6,
78-
openinterest=-1
79-
)
80-
81-
def get_reward(weights):
82-
model.set_weights(weights)
83-
cerebro = bt.Cerebro()
84-
cerebro.addstrategy(ESStrategy, model=model)
85-
cerebro.adddata(data)
86-
cerebro.broker.setcash(1000)
87-
cerebro.addsizer(bt.sizers.FixedSize, stake=50)
88-
89-
cerebro.run()
90-
return cerebro.broker.getvalue()
91-
92-
es = EvolutionStrategy(model.get_weights(), get_reward, population_size=50, sigma=0.1, learning_rate=0.001)
93-
es.run(1000, print_step=100)
49+
for index, w in enumerate(self.weights):
50+
A = np.array([p[index] for p in population])
51+
self.weights[index] = w + self.LEARNING_RATE/(self.POPULATION_SIZE*self.SIGMA) * np.dot(A.T, rewards).T

simple_es.py

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import numpy as np
2+
import backtrader as bt
3+
4+
from es import EvolutionStrategy
5+
from keras.models import Model, Input, Sequential
6+
from keras.layers import Dense, Activation
7+
8+
9+
class ESStrategy(bt.Strategy):
10+
params = {
11+
'model': None,
12+
'ma_period': 200,
13+
'rsi_period': 14
14+
}
15+
16+
def __init__(self):
17+
self.order = None
18+
self.dataclose = self.datas[0].close
19+
self.datavol = self.datas[0].volume
20+
self.sma = bt.indicators.SimpleMovingAverage(
21+
self.datas[0],
22+
period=self.p.ma_period
23+
)
24+
self.rsi = bt.indicators.RelativeStrengthIndex(
25+
period=self.p.rsi_period
26+
)
27+
28+
# def stop(self):
29+
# cash = self.broker.getvalue()
30+
# print('Result cash: {}'.format(cash))
31+
32+
def notify_order(self, order):
33+
if order.status in [order.Submitted, order.Accepted]:
34+
return
35+
36+
self.order = None
37+
38+
def next(self):
39+
if self.order:
40+
return
41+
42+
input_data = []
43+
for i in range(7):
44+
input_data.append(self.dataclose[i - 6])
45+
for i in range(7):
46+
input_data.append(self.datavol[i - 6])
47+
for i in range(7):
48+
input_data.append(self.sma[i - 6])
49+
for i in range(7):
50+
input_data.append(self.rsi[i-6])
51+
inp = np.asanyarray(input_data)
52+
inp = np.expand_dims(inp, 0)
53+
54+
predict = self.p.model.predict(inp)[0]
55+
predict = np.argmax(predict)
56+
57+
if not self.position:
58+
if predict == 0:
59+
self.order = self.buy()
60+
else:
61+
if predict == 1:
62+
self.order = self.sell()
63+
64+
if not self.position:
65+
if predict == 1:
66+
self.order = self.sell()
67+
else:
68+
if predict == 0:
69+
self.order = self.buy()
70+
71+
model = Sequential()
72+
model.add(Dense(128, input_dim=28, activation='relu'))
73+
model.add(Dense(256, activation='relu'))
74+
model.add(Dense(128, activation='relu'))
75+
model.add(Dense(2, activation='relu'))
76+
77+
model.compile(optimizer='Adam', loss='mse')
78+
79+
data = bt.feeds.GenericCSVData(
80+
dataname='eur_usd_1d.csv',
81+
separator=',',
82+
dtformat=('%Y%m%d'),
83+
tmformat=('%H%M00'),
84+
datetime=0,
85+
time=1,
86+
open=2,
87+
high=3,
88+
low=4,
89+
close=5,
90+
volume=6,
91+
openinterest=-1
92+
)
93+
94+
def get_reward(weights):
95+
model.set_weights(weights)
96+
cerebro = bt.Cerebro()
97+
cerebro.addstrategy(ESStrategy, model=model)
98+
cerebro.adddata(data)
99+
cerebro.broker.setcash(1000)
100+
cerebro.addsizer(bt.sizers.FixedSize, stake=50)
101+
102+
cerebro.run()
103+
return cerebro.broker.getvalue()
104+
105+
es = EvolutionStrategy(model.get_weights(), get_reward, population_size=50, sigma=0.2, learning_rate=0.001)
106+
es.run(1000, print_step=1)

test_mp.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import time
2+
import multiprocessing as mp
3+
4+
def worker(num):
5+
"""worker function"""
6+
# print('Worker: ', num)
7+
return num*num
8+
9+
if __name__ == '__main__':
10+
data = range(1000)
11+
pool_start = time.time()
12+
pool = mp.Pool()
13+
results = pool.map(worker, data)
14+
pool.close()
15+
pool.join()
16+
17+
print('Pool: {} sec'.format(time.time() - pool_start))
18+
print(results)

0 commit comments

Comments
 (0)