Skip to content

Commit

Permalink
Merge pull request #11 from zsrl/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
zsrl authored Jan 17, 2025
2 parents 923b66e + 93c188a commit 9cf96c1
Show file tree
Hide file tree
Showing 7 changed files with 357 additions and 40 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,4 @@ cython_debug/
#.idea/

xtquant
*.ipynb
poetry.lock
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"editor.tabSize": 4
}
153 changes: 153 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import backtrader as bt
from qmtbt import QMTStore
from datetime import datetime
from xtquant import xtdata
import math

class BuyCondition(bt.Indicator):
'''买入条件'''
lines = ('buy_signal',)

params = (
('up_days', 10), # 连续上涨的天数
)

def __init__(self):
self.lines.buy_signal = bt.If(self.data.close > self.data.close(-250), 1, 0)

def next(self):
# 检查250线斜率是否恰好连续向上self.params.up_days个交易日,再往前一个交易日斜率下降
if len(self) >= self.params.up_days + 1:
slope_up = all(self.data.close[-i] > self.data.close[-i-1] for i in range(1, self.params.up_days + 1))
slope_down_before = self.data.close[-self.params.up_days - 1] < self.data.close[-self.params.up_days - 2]
if slope_up and slope_down_before:
self.lines.buy_signal[0] = 1
else:
self.lines.buy_signal[0] = 0

class SellCondition(bt.Indicator):
'''卖出条件'''
lines = ('sell_signal',)

params = (
('hold_days', 20), # 持有天数
)

def __init__(self):
self.hold_days = 0

def next(self):
# 持有self.params.hold_days个交易日卖出
if self.hold_days >= self.params.hold_days:
self.lines.sell_signal[0] = 1
self.hold_days = 0
else:
self.lines.sell_signal[0] = 0
self.hold_days += 1

class Sizer(bt.Sizer):
'''仓位控制'''
params = (
('buy_count', 1), # 最大持仓股票个数
)

def __init__(self):
pass

def _getsizing(self, comminfo, cash, data, isbuy):
if isbuy:
# 如果是买入,平均分配仓位
commission_rate = comminfo.p.commission
size = math.floor(cash * (1 - commission_rate) / data.close[0] / self.params.buy_count / 100) * 100
else:
# 如果是卖出,全部卖出
position = self.broker.getposition(data)
size = position.size

return size

class DemoStrategy(bt.Strategy):
params = (
('max_positions', 5), # 最大持仓股票个数
('up_days', 10), # 连续上涨的天数
('hold_days', 20), # 持有天数
)

def log(self, txt, dt=None):
""" 记录交易日志 """
dt = dt or self.datas[0].datetime.date(0)
print(f'{dt.isoformat()}, {txt}')

def __init__(self):
# 初始化函数
self.sizer = Sizer()
self.buy_condition = {d: BuyCondition(d, up_days=self.params.up_days) for d in self.datas}
self.sell_condition = {d: SellCondition(d, hold_days=self.params.hold_days) for d in self.datas}

def next(self):
# 先收集所有需要买入和卖出的股票
buy_list = []
sell_list = []

for i, d in enumerate(self.datas):
pos = self.getposition(d).size

if pos and self.sell_condition[d].lines.sell_signal[0] > 0:
sell_list.append(d)

if self.buy_condition[d].lines.buy_signal[0] > 0:
buy_list.append(d)

# 动态设置Sizer的buy_count参数
self.sizer.params.buy_count = len(buy_list)

# 先执行卖出操作
for d in sell_list:
self.sell(data=d)

# 再执行买入操作
for d in buy_list:
self.buy(data=d)

if __name__ == '__main__':


store = QMTStore()

code_list = xtdata.get_stock_list_in_sector('沪深300')

# 添加数据
datas = store.getdatas(code_list=code_list, timeframe=bt.TimeFrame.Days, fromdate=datetime(2022, 7, 1))

for d in datas:
# print(len(d))
cerebro = bt.Cerebro(maxcpus=16)

cerebro.adddata(d)

# 添加策略
# buy_date = datetime(2022, 8, 1).date() # 设置固定买入日期
cerebro.addstrategy(DemoStrategy)

# cerebro.optstrategy

# # 设置初始资金
cerebro.broker.setcash(1000000.0)

# 设置佣金
cerebro.broker.setcommission(commission=0.001)

# 运行回测
# print('Starting Portfolio Value: %.2f' % cerebro.broker.getvalue())
cerebro.run()
if cerebro.broker.getvalue() != 1000000.0:
print('Final Portfolio Value: %.2f' % cerebro.broker.getvalue())

# data.test(1)
# data.test(2)
# data.test(3)
# data.test(4)
# xtdata.run()

# 绘制结果
# cerebro.plot()
38 changes: 37 additions & 1 deletion qmtbt/qmtbroker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from backtrader import BrokerBase, OrderBase, Order
from backtrader.position import Position
from backtrader.utils.py3 import queue, with_metaclass

import random
from xtquant.xttrader import XtQuantTrader
from xtquant.xttype import StockAccount
from .qmtstore import QMTStore

class QMTOrder(OrderBase):
Expand All @@ -28,15 +30,49 @@ class QMTBroker(BrokerBase, metaclass=MetaQMTBroker):
def __init__(self, **kwargs):

self.store = QMTStore(**kwargs)
self.mini_qmt_path = kwargs.get('mini_qmt_path')
self.account_id = kwargs.get('account_id')

session_id = int(random.randint(100000, 999999))

xt_trader = XtQuantTrader(self.mini_qmt_path, session_id)
# 启动交易对象
xt_trader.start()
# 连接客户端
connect_result = xt_trader.connect()

if connect_result == 0:
print('连接成功')

account = StockAccount(self.account_id)
# 订阅账号
res = xt_trader.subscribe(account)

self.xt_trader = xt_trader
self.account = account


def getcash(self):
res = self.query_stock_asset(self.account)

self.cash = res.cash

return self.cash

def getvalue(self, datas=None):

res = self.query_stock_asset(self.account)

self.value = res.market_value

return self.value

def getposition(self, data, clone=True):

xt_position = self.xt_trader.query_stock_position(self.account, data._dataname)
pos = Position(size=xt_position.volume, price=xt_position.avg_price)
return pos

def get_notification(self):
try:
return self.notifs.get(False)
Expand Down
90 changes: 75 additions & 15 deletions qmtbt/qmtfeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
unicode_literals)

from collections import deque
from datetime import datetime
import datetime
import backtrader as bt
from backtrader.feed import DataBase
import time
import threading
import random

from .qmtstore import QMTStore

Expand All @@ -26,7 +29,7 @@ class QMTFeed(DataBase, metaclass=MetaQMTFeed):
- ``historical`` (default: ``False``)
"""

lines = ('lastClose', 'amount', 'pvolume', 'stockStatus', 'openInt', 'lastSettlementPrice', 'settlementPrice', 'transactionNum', 'askPrice1', 'askPrice2', 'askPrice3', 'askPrice4', 'askPrice5', 'bidPrice1', 'bidPrice2', 'bidPrice3', 'bidPrice4', 'bidPrice5', 'askVol1', 'askVol2', 'askVol3', 'askVol4', 'askVol5', 'bidVol1', 'bidVol2', 'bidVol3', 'bidVol4', 'bidVol5', )
lines = ('lastClose', 'amount', 'pvolume', 'stockStatus', 'openInt', 'lastSettlementPrice', 'settlementPrice', 'transactionNum', 'askPrice1', 'askPrice2', 'askPrice3', 'askPrice4', 'askPrice5', 'bidPrice1', 'bidPrice2', 'bidPrice3', 'bidPrice4', 'bidPrice5', 'askVol1', 'askVol2', 'askVol3', 'askVol4', 'askVol5', 'bidVol1', 'bidVol2', 'bidVol3', 'bidVol4', 'bidVol5', 'openInterest', 'dr', 'totaldr', 'preClose', 'suspendFlag', 'settelementPrice', 'pe' )

params = (
('live', False), # only historical download
Expand All @@ -37,9 +40,13 @@ def __init__(self, **kwargs):
self._timeframe = self.p.timeframe
self._compression = 1
self.store = kwargs['store']
# self.cerebro = kwargs['cerebro']
self._data = deque() # data queue for price data
self._seq = None

# def __len__(self):
# return len(self._data)

def start(self, ):
DataBase.start(self)

Expand All @@ -53,34 +60,43 @@ def start(self, ):
self._history_data(period=period_map[self.p.timeframe])
print(f'{self.p.dataname}历史数据装载成功!')
else:
self._history_data(period=period_map[self.p.timeframe])
self._live_data(period=period_map[self.p.timeframe])
print(f'{self.p.dataname}实时数据装载成功!')
# self._live_data(period=period_map[self.p.timeframe])

def stop(self):
DataBase.stop(self)

if self.p.live:
self.store._unsubscribe_live(self._seq)

def _load(self):
while len(self._data):

current = self._data.popleft()
def _get_datetime(self, value):
dtime = datetime.datetime.fromtimestamp(value // 1000)
return bt.date2num(dtime)

for key in current.keys():
def _load_current(self, current):
for key in current.keys():
try:
value = current[key]
if key == 'time':
dtime = datetime.fromtimestamp(value // 1000)
self.lines.datetime[0] = bt.date2num(dtime)
self.lines.datetime[0] = self._get_datetime(value)
elif key == 'lastPrice' and self.p.timeframe == bt.TimeFrame.Ticks:
self.lines.close[0] = value
else:
attr = getattr(self.lines, key)
attr[0] = value
except:
except Exception as e:
print(e)
pass
# print(current, 'current')
self.put_notification(int(random.randint(100000, 999999)))

def _load(self, replace=False):
if len(self._data) > 0:

current = self._data.popleft()

self._load_current(current)

return True
return None

Expand Down Expand Up @@ -111,12 +127,56 @@ def _history_data(self, period):
res = self.store._fetch_history(symbol=self.p.dataname, period=period, start_time=start_time, end_time=end_time)
result = res.to_dict('records')
for item in result:
self._data.append(item)
if item.get('close') != 0 and item.get('lastPrice') != 0:
self._data.append(item)

def _live_data(self, period):

start_time = self._format_datetime(self.p.fromdate, period)

def on_data(res):
self._data.append(res.iloc[0].to_dict())
print(self.lines.datetime)
# current = res[self.p.dataname][0]
# if self._get_datetime(current['time']) == self.lines.datetime[0]:
# self._load_current(current)
# else:
# self._data.append(current)


self._seq = self.store._subscribe_live(symbol=self.p.dataname, period=period, start_time=start_time, callback=on_data)

res = self.store._fetch_history(symbol=self.p.dataname, period=period, start_time=start_time, download=False)
result = res.to_dict('records')
for item in result:
self._data.append(item)

# def test(self, close):
# # 获取当前日期
# current_date = datetime.datetime.now().date()

# # 将当前日期转换为datetime对象,并将时间设置为00:00:00
# start_of_day = datetime.datetime.combine(current_date, datetime.time.min)

# # 将开始时间转换为Unix时间戳(以毫秒为单位)
# start_of_day_unix_ms = int(start_of_day.timestamp() * 1000)
# current = {
# 'time': start_of_day_unix_ms,
# 'close': close,
# 'open': 1.0,
# 'high': 1.0,
# 'low': 1.0,
# 'volume': 1.0,
# }
# self._load_current(current)

# # 定义一个函数,这个函数将在每次循环中执行
# def my_function():
# while True:
# self.test(1)
# time.sleep(3) # 暂停3秒

# # 创建并启动线程
# thread = threading.Thread(target=my_function)
# thread.daemon = True # 设置为守护线程,这样主线程结束时,子线程也会结束
# thread.start()

self._seq = self.store._subscribe_live(symbol=self.p.dataname, period=period, callback=on_data)
Loading

0 comments on commit 9cf96c1

Please sign in to comment.