Skip to content

Commit d738fc6

Browse files
committed
add ES exps
1 parent 860b952 commit d738fc6

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

simple_es.py

+24-16
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ def __init__(self):
2525
period=self.p.rsi_period
2626
)
2727

28-
# def stop(self):
29-
# cash = self.broker.getvalue()
30-
# print('Result cash: {}'.format(cash))
28+
def stop(self):
29+
cash = self.broker.getvalue()
30+
print('Result cash: {}'.format(cash))
3131

3232
def notify_order(self, order):
3333
if order.status in [order.Submitted, order.Accepted]:
@@ -39,11 +39,11 @@ def next(self):
3939
if self.order:
4040
return
4141

42-
input_data = []
43-
for i in range(7):
44-
input_data.append(self.dataclose[i - 6])
45-
for i in range(7):
46-
input_data.append(self.datavol[i - 6])
42+
input_data = [self.dataclose[0], self.datavol[0]]
43+
# for i in range(7):
44+
# input_data.append(self.dataclose[i - 6])
45+
# for i in range(7):
46+
# input_data.append(self.datavol[i - 6])
4747
# for i in range(7):
4848
# input_data.append(self.sma[i - 6])
4949
# for i in range(7):
@@ -54,6 +54,9 @@ def next(self):
5454
predict = self.p.model.predict(inp)[0]
5555
predict = np.argmax(predict)
5656

57+
if predict == 2:
58+
return
59+
5760
if not self.position:
5861
if predict == 0:
5962
self.order = self.buy()
@@ -69,13 +72,6 @@ def next(self):
6972
self.order = self.buy()
7073

7174

72-
model = Sequential()
73-
model.add(Dense(128, input_dim=14, activation='relu'))
74-
model.add(Dense(256, activation='relu'))
75-
model.add(Dense(2, activation='relu'))
76-
77-
model.compile(optimizer='Adam', loss='mse')
78-
7975
data = bt.feeds.GenericCSVData(
8076
dataname='eur_usd_1d.csv',
8177
separator=',',
@@ -91,9 +87,19 @@ def next(self):
9187
openinterest=-1
9288
)
9389

90+
def get_model():
91+
model = Sequential()
92+
model.add(Dense(128, input_dim=2, activation='relu'))
93+
# model.add(Dense(256, activation='relu'))
94+
model.add(Dense(3, activation='relu'))
95+
96+
model.compile(optimizer='Adam', loss='mse')
97+
return model
9498

9599
def get_reward(weights):
100+
model = get_model()
96101
model.set_weights(weights)
102+
97103
cerebro = bt.Cerebro()
98104
cerebro.addstrategy(ESStrategy, model=model)
99105
cerebro.adddata(data)
@@ -104,5 +110,7 @@ def get_reward(weights):
104110
return cerebro.broker.getvalue() - 5000.0
105111

106112

107-
es = EvolutionStrategy(model.get_weights(), get_reward, population_size=50, sigma=0.1, learning_rate=0.1)
113+
model = get_model()
114+
115+
es = EvolutionStrategy(model.get_weights(), get_reward, population_size=50, sigma=0.2, learning_rate=0.01)
108116
es.run(1000, print_step=1)

0 commit comments

Comments
 (0)