From 9022da074e0590de6ae8f1c51fdd74cfbf5dd36a Mon Sep 17 00:00:00 2001 From: caimeng <862786917@qq.com> Date: Wed, 10 Jul 2019 13:37:57 +0800 Subject: [PATCH 1/3] keep choice for run terminal --- vnpy/applications/VnTerminal/__init__.py | 5 +-- vnpy/applications/VnTerminal/run.py | 43 +++++++++++++++--------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/vnpy/applications/VnTerminal/__init__.py b/vnpy/applications/VnTerminal/__init__.py index 2227904..1e5f9dd 100644 --- a/vnpy/applications/VnTerminal/__init__.py +++ b/vnpy/applications/VnTerminal/__init__.py @@ -6,5 +6,6 @@ @click.option('-m', '--monitor', is_flag=True) -def app_cli(monitor=False): - main(monitor=monitor) +@click.option('-k', '--keep', is_flag=True) +def app_cli(monitor=False, keep=False): + main(monitor=monitor, keep=keep) diff --git a/vnpy/applications/VnTerminal/run.py b/vnpy/applications/VnTerminal/run.py index ec0eeb6..b1f5d3f 100644 --- a/vnpy/applications/VnTerminal/run.py +++ b/vnpy/applications/VnTerminal/run.py @@ -141,6 +141,13 @@ def run(self, monitor=False): le.info(u"开始所有策略") cta.startAll() + def activeStrategyCount(self): + count = 0 + for name, strategy in self.cta.strategyDict.items(): + if strategy.trading: + count += 1 + return count + def join(self): while self.running: sleep(1) @@ -160,7 +167,7 @@ def stop(self): logging.info(u"交易程序正常退出") except Exception as e: logging.exception(e) - + self.me.exit() class DaemonApp(object): def __init__(self): @@ -184,18 +191,11 @@ def run(self, monitor=False): self.process = None # 子进程句柄 - def join(self): + def join(self, keep=False): while self.running: currentTime = datetime.now().time() recording = True - # TODO: 设置交易时段 - # 判断当前处于的时间段 - # if ((currentTime >= DAY_START and currentTime <= DAY_END) or - # (currentTime >= NIGHT_START) or - # (currentTime <= NIGHT_END)): - # recording = True - # 记录时间则需要启动子进程 if recording and self.process is None: # TODO: 可能多次启动,可能要在启动前对pipe进行清理或重新创建 @@ -203,15 +203,24 @@ def join(self): self.process = multiprocessing.Process( target=self._run_child, args=(self.pchild, ), - kwargs={"monitor": self._run_with_monitor}) + kwargs={"monitor": self._run_with_monitor}, + daemon=True + ) self.process.start() logging.info(u'子进程启动成功') - + if self.process: + if not self.process.is_alive(): + if keep: + self._stop_child() + else: + self.stop() # 非记录时间则退出子进程 - if not recording and self.process is not None: - self._stop_child() + # if not recording and self.process is not None: + # self._stop_child() sleep(5) + logging.info("停止CTA策略守护父进程") + @staticmethod def _run_child(p, monitor=False): @@ -231,6 +240,8 @@ def interrupt(signal, event): p.recv() raise KeyboardInterrupt else: + if not app.activeStrategyCount(): + app.stop() continue except KeyboardInterrupt: app.stop() @@ -260,11 +271,11 @@ def _stop_child(self): logging.info(u'子进程关闭成功') def stop(self): - self.runing = False + self.running = False self._stop_child() -def main(monitor=False): +def main(monitor=False, keep=False): import signal import logging @@ -276,6 +287,6 @@ def interrupt(signal, event): app = DaemonApp() try: app.run(monitor=monitor) - app.join() + app.join(keep) except KeyboardInterrupt: app.stop() From 8abae9ba0d5a9f78ade2ad38650da65a69301838 Mon Sep 17 00:00:00 2001 From: caimeng <862786917@qq.com> Date: Wed, 10 Jul 2019 17:01:06 +0800 Subject: [PATCH 2/3] engine auto exit --- vnpy/applications/VnTerminal/run.py | 5 +- vnpy/event/eventEngine.py | 716 ++++---- .../trader/gateway/okexGateway/okexGateway.py | 710 ++++---- vnpy/trader/vtEngine.py | 1516 +++++++++-------- vnpy/trader/vtEvent.py | 41 +- 5 files changed, 1503 insertions(+), 1485 deletions(-) diff --git a/vnpy/applications/VnTerminal/run.py b/vnpy/applications/VnTerminal/run.py index b1f5d3f..5d3e0a8 100644 --- a/vnpy/applications/VnTerminal/run.py +++ b/vnpy/applications/VnTerminal/run.py @@ -240,8 +240,11 @@ def interrupt(signal, event): p.recv() raise KeyboardInterrupt else: - if not app.activeStrategyCount(): + if not app.ee.isActive: app.stop() + elif not app.activeStrategyCount(): + app.stop() + continue except KeyboardInterrupt: app.stop() diff --git a/vnpy/event/eventEngine.py b/vnpy/event/eventEngine.py index b401b27..670facf 100644 --- a/vnpy/event/eventEngine.py +++ b/vnpy/event/eventEngine.py @@ -1,356 +1,362 @@ -# encoding: UTF-8 - -# 系统模块 -from queue import Queue, Empty -from threading import Thread -from time import sleep -from collections import defaultdict - -# 第三方模块 -from qtpy.QtCore import QTimer - -# 自己开发的模块 -from .eventType import * - - -######################################################################## -class EventEngine(object): - """ - 事件驱动引擎 - 事件驱动引擎中所有的变量都设置为了私有,这是为了防止不小心 - 从外部修改了这些变量的值或状态,导致bug。 - - 变量说明 - __queue:私有变量,事件队列 - __active:私有变量,事件引擎开关 - __thread:私有变量,事件处理线程 - __timer:私有变量,计时器 - __handlers:私有变量,事件处理函数字典 - - - 方法说明 - __run: 私有方法,事件处理线程连续运行用 - __process: 私有方法,处理事件,调用注册在引擎中的监听函数 - __onTimer:私有方法,计时器固定事件间隔触发后,向事件队列中存入计时器事件 - start: 公共方法,启动引擎 - stop:公共方法,停止引擎 - register:公共方法,向引擎中注册监听函数 - unregister:公共方法,向引擎中注销监听函数 - put:公共方法,向事件队列中存入新的事件 - - 事件监听函数必须定义为输入参数仅为一个event对象,即: - - 函数 - def func(event) - ... - - 对象方法 - def method(self, event) - ... - - """ - - #---------------------------------------------------------------------- - def __init__(self): - """初始化事件引擎""" - # 事件队列 - self.__queue = Queue() - - # 事件引擎开关 - self.__active = False - - # 事件处理线程 - self.__thread = Thread(target = self.__run) - - # 计时器,用于触发计时器事件 - self.__timer = QTimer() - self.__timer.timeout.connect(self.__onTimer) - - # 这里的__handlers是一个字典,用来保存对应的事件调用关系 - # 其中每个键对应的值是一个列表,列表中保存了对该事件进行监听的函数功能 - self.__handlers = defaultdict(list) - - # __generalHandlers是一个列表,用来保存通用回调函数(所有事件均调用) - self.__generalHandlers = [] - - #---------------------------------------------------------------------- - def __run(self): - """引擎运行""" - while self.__active == True: - try: - event = self.__queue.get(block = True, timeout = 1) # 获取事件的阻塞时间设为1秒 - self.__process(event) - except Empty: - pass - - #---------------------------------------------------------------------- - def __process(self, event): - """处理事件""" - # 检查是否存在对该事件进行监听的处理函数 - if event.type_ in self.__handlers: - # 若存在,则按顺序将事件传递给处理函数执行 - [handler(event) for handler in self.__handlers[event.type_]] - - # 以上语句为Python列表解析方式的写法,对应的常规循环写法为: - #for handler in self.__handlers[event.type_]: - #handler(event) - - # 调用通用处理函数进行处理 - if self.__generalHandlers: - [handler(event) for handler in self.__generalHandlers] - - #---------------------------------------------------------------------- - def __onTimer(self): - """向事件队列中存入计时器事件""" - # 创建计时器事件 - event = Event(type_=EVENT_TIMER) - - # 向队列中存入计时器事件 - self.put(event) - - #---------------------------------------------------------------------- - def start(self, timer=True): - """ - 引擎启动 - timer:是否要启动计时器 - """ - # 将引擎设为启动 - self.__active = True - - # 启动事件处理线程 - self.__thread.start() - - # 启动计时器,计时器事件间隔默认设定为1秒 - if timer: - self.__timer.start(1000) - - #---------------------------------------------------------------------- - def stop(self): - """停止引擎""" - # 将引擎设为停止 - self.__active = False - - # 停止计时器 - self.__timer.stop() - - # 等待事件处理线程退出 - self.__thread.join() - - #---------------------------------------------------------------------- - def register(self, type_, handler): - """注册事件处理函数监听""" - # 尝试获取该事件类型对应的处理函数列表,若无defaultDict会自动创建新的list - handlerList = self.__handlers[type_] - - # 若要注册的处理器不在该事件的处理器列表中,则注册该事件 - if handler not in handlerList: - handlerList.append(handler) - - #---------------------------------------------------------------------- - def unregister(self, type_, handler): - """注销事件处理函数监听""" - # 尝试获取该事件类型对应的处理函数列表,若无则忽略该次注销请求 - handlerList = self.__handlers[type_] - - # 如果该函数存在于列表中,则移除 - if handler in handlerList: - handlerList.remove(handler) - - # 如果函数列表为空,则从引擎中移除该事件类型 - if not handlerList: - del self.__handlers[type_] - - #---------------------------------------------------------------------- - def put(self, event): - """向事件队列中存入事件""" - self.__queue.put(event) - - #---------------------------------------------------------------------- - def registerGeneralHandler(self, handler): - """注册通用事件处理函数监听""" - if handler not in self.__generalHandlers: - self.__generalHandlers.append(handler) - - #---------------------------------------------------------------------- - def unregisterGeneralHandler(self, handler): - """注销通用事件处理函数监听""" - if handler in self.__generalHandlers: - self.__generalHandlers.remove(handler) - - - -######################################################################## -class EventEngine2(object): - """ - 计时器使用python线程的事件驱动引擎 - """ - - #---------------------------------------------------------------------- - def __init__(self): - """初始化事件引擎""" - # 事件队列 - self.__queue = Queue() - - # 事件引擎开关 - self.__active = False - - # 事件处理线程 - self.__thread = Thread(target = self.__run) - - # 计时器,用于触发计时器事件 - self.__timer = Thread(target = self.__runTimer) - self.__timerActive = False # 计时器工作状态 - self.__timerSleep = 1 # 计时器触发间隔(默认1秒) - - # 这里的__handlers是一个字典,用来保存对应的事件调用关系 - # 其中每个键对应的值是一个列表,列表中保存了对该事件进行监听的函数功能 - self.__handlers = defaultdict(list) - - # __generalHandlers是一个列表,用来保存通用回调函数(所有事件均调用) - self.__generalHandlers = [] - - #---------------------------------------------------------------------- - def __run(self): - """引擎运行""" - while self.__active == True: - try: - event = self.__queue.get(block = True, timeout = 1) # 获取事件的阻塞时间设为1秒 - self.__process(event) - except Empty: - pass - - #---------------------------------------------------------------------- - def __process(self, event): - """处理事件""" - # 检查是否存在对该事件进行监听的处理函数 - if event.type_ in self.__handlers: - # 若存在,则按顺序将事件传递给处理函数执行 - [handler(event) for handler in self.__handlers[event.type_]] - - # 以上语句为Python列表解析方式的写法,对应的常规循环写法为: - #for handler in self.__handlers[event.type_]: - #handler(event) - - # 调用通用处理函数进行处理 - if self.__generalHandlers: - [handler(event) for handler in self.__generalHandlers] - - #---------------------------------------------------------------------- - def __runTimer(self): - """运行在计时器线程中的循环函数""" - while self.__timerActive: - # 创建计时器事件 - event = Event(type_=EVENT_TIMER) - - # 向队列中存入计时器事件 - self.put(event) - - # 等待 - sleep(self.__timerSleep) - - #---------------------------------------------------------------------- - def start(self, timer=True): - """ - 引擎启动 - timer:是否要启动计时器 - """ - # 将引擎设为启动 - self.__active = True - - # 启动事件处理线程 - self.__thread.start() - - # 启动计时器,计时器事件间隔默认设定为1秒 - if timer: - self.__timerActive = True - self.__timer.start() - - #---------------------------------------------------------------------- - def stop(self): - """停止引擎""" - # 将引擎设为停止 - self.__active = False - - # 停止计时器 - self.__timerActive = False - self.__timer.join() - - # 等待事件处理线程退出 - self.__thread.join() - - #---------------------------------------------------------------------- - def register(self, type_, handler): - """注册事件处理函数监听""" - # 尝试获取该事件类型对应的处理函数列表,若无defaultDict会自动创建新的list - handlerList = self.__handlers[type_] - - # 若要注册的处理器不在该事件的处理器列表中,则注册该事件 - if handler not in handlerList: - handlerList.append(handler) - - #---------------------------------------------------------------------- - def unregister(self, type_, handler): - """注销事件处理函数监听""" - # 尝试获取该事件类型对应的处理函数列表,若无则忽略该次注销请求 - handlerList = self.__handlers[type_] - - # 如果该函数存在于列表中,则移除 - if handler in handlerList: - handlerList.remove(handler) - - # 如果函数列表为空,则从引擎中移除该事件类型 - if not handlerList: - del self.__handlers[type_] - - #---------------------------------------------------------------------- - def put(self, event): - """向事件队列中存入事件""" - self.__queue.put(event) - - #---------------------------------------------------------------------- - def registerGeneralHandler(self, handler): - """注册通用事件处理函数监听""" - if handler not in self.__generalHandlers: - self.__generalHandlers.append(handler) - - #---------------------------------------------------------------------- - def unregisterGeneralHandler(self, handler): - """注销通用事件处理函数监听""" - if handler in self.__generalHandlers: - self.__generalHandlers.remove(handler) - - -######################################################################## -class Event: - """事件对象""" - - #---------------------------------------------------------------------- - def __init__(self, type_=None): - """Constructor""" - self.type_ = type_ # 事件类型 - self.dict_ = {} # 字典用于保存具体的事件数据 - - -#---------------------------------------------------------------------- -def test(): - """测试函数""" - import sys - from datetime import datetime - from PyQt4.QtCore import QCoreApplication - - def simpletest(event): - print('处理每秒触发的计时器事件:%s' % str(datetime.now())) - - app = QCoreApplication(sys.argv) - - ee = EventEngine2() - #ee.register(EVENT_TIMER, simpletest) - ee.registerGeneralHandler(simpletest) - ee.start() - - app.exec_() - - -# 直接运行脚本可以进行测试 -if __name__ == '__main__': +# encoding: UTF-8 + +# 系统模块 +from queue import Queue, Empty +from threading import Thread +from time import sleep +from collections import defaultdict + +# 第三方模块 +from qtpy.QtCore import QTimer + +# 自己开发的模块 +from .eventType import * + + +######################################################################## +class EventEngine(object): + """ + 事件驱动引擎 + 事件驱动引擎中所有的变量都设置为了私有,这是为了防止不小心 + 从外部修改了这些变量的值或状态,导致bug。 + + 变量说明 + __queue:私有变量,事件队列 + __active:私有变量,事件引擎开关 + __thread:私有变量,事件处理线程 + __timer:私有变量,计时器 + __handlers:私有变量,事件处理函数字典 + + + 方法说明 + __run: 私有方法,事件处理线程连续运行用 + __process: 私有方法,处理事件,调用注册在引擎中的监听函数 + __onTimer:私有方法,计时器固定事件间隔触发后,向事件队列中存入计时器事件 + start: 公共方法,启动引擎 + stop:公共方法,停止引擎 + register:公共方法,向引擎中注册监听函数 + unregister:公共方法,向引擎中注销监听函数 + put:公共方法,向事件队列中存入新的事件 + + 事件监听函数必须定义为输入参数仅为一个event对象,即: + + 函数 + def func(event) + ... + + 对象方法 + def method(self, event) + ... + + """ + + #---------------------------------------------------------------------- + def __init__(self): + """初始化事件引擎""" + # 事件队列 + self.__queue = Queue() + + # 事件引擎开关 + self.__active = False + + # 事件处理线程 + self.__thread = Thread(target = self.__run) + + # 计时器,用于触发计时器事件 + self.__timer = QTimer() + self.__timer.timeout.connect(self.__onTimer) + + # 这里的__handlers是一个字典,用来保存对应的事件调用关系 + # 其中每个键对应的值是一个列表,列表中保存了对该事件进行监听的函数功能 + self.__handlers = defaultdict(list) + + # __generalHandlers是一个列表,用来保存通用回调函数(所有事件均调用) + self.__generalHandlers = [] + + #---------------------------------------------------------------------- + def __run(self): + """引擎运行""" + while self.__active == True: + try: + event = self.__queue.get(block = True, timeout = 1) # 获取事件的阻塞时间设为1秒 + self.__process(event) + except Empty: + pass + + #---------------------------------------------------------------------- + def __process(self, event): + """处理事件""" + # 检查是否存在对该事件进行监听的处理函数 + if event.type_ in self.__handlers: + # 若存在,则按顺序将事件传递给处理函数执行 + [handler(event) for handler in self.__handlers[event.type_]] + + # 以上语句为Python列表解析方式的写法,对应的常规循环写法为: + #for handler in self.__handlers[event.type_]: + #handler(event) + + # 调用通用处理函数进行处理 + if self.__generalHandlers: + [handler(event) for handler in self.__generalHandlers] + + #---------------------------------------------------------------------- + def __onTimer(self): + """向事件队列中存入计时器事件""" + # 创建计时器事件 + event = Event(type_=EVENT_TIMER) + + # 向队列中存入计时器事件 + self.put(event) + + #---------------------------------------------------------------------- + def start(self, timer=True): + """ + 引擎启动 + timer:是否要启动计时器 + """ + # 将引擎设为启动 + self.__active = True + + # 启动事件处理线程 + self.__thread.start() + + # 启动计时器,计时器事件间隔默认设定为1秒 + if timer: + self.__timer.start(1000) + + #---------------------------------------------------------------------- + def stop(self): + """停止引擎""" + # 将引擎设为停止 + self.__active = False + + # 停止计时器 + self.__timer.stop() + + # 等待事件处理线程退出 + self.__thread.join() + + #---------------------------------------------------------------------- + def register(self, type_, handler): + """注册事件处理函数监听""" + # 尝试获取该事件类型对应的处理函数列表,若无defaultDict会自动创建新的list + handlerList = self.__handlers[type_] + + # 若要注册的处理器不在该事件的处理器列表中,则注册该事件 + if handler not in handlerList: + handlerList.append(handler) + + #---------------------------------------------------------------------- + def unregister(self, type_, handler): + """注销事件处理函数监听""" + # 尝试获取该事件类型对应的处理函数列表,若无则忽略该次注销请求 + handlerList = self.__handlers[type_] + + # 如果该函数存在于列表中,则移除 + if handler in handlerList: + handlerList.remove(handler) + + # 如果函数列表为空,则从引擎中移除该事件类型 + if not handlerList: + del self.__handlers[type_] + + #---------------------------------------------------------------------- + def put(self, event): + """向事件队列中存入事件""" + self.__queue.put(event) + + #---------------------------------------------------------------------- + def registerGeneralHandler(self, handler): + """注册通用事件处理函数监听""" + if handler not in self.__generalHandlers: + self.__generalHandlers.append(handler) + + #---------------------------------------------------------------------- + def unregisterGeneralHandler(self, handler): + """注销通用事件处理函数监听""" + if handler in self.__generalHandlers: + self.__generalHandlers.remove(handler) + + + +######################################################################## +class EventEngine2(object): + """ + 计时器使用python线程的事件驱动引擎 + """ + + #---------------------------------------------------------------------- + def __init__(self): + """初始化事件引擎""" + # 事件队列 + self.__queue = Queue() + + # 事件引擎开关 + self.__active = False + + # 事件处理线程 + self.__thread = Thread(target = self.__run) + + # 计时器,用于触发计时器事件 + self.__timer = Thread(target = self.__runTimer) + self.__timerActive = False # 计时器工作状态 + self.__timerSleep = 1 # 计时器触发间隔(默认1秒) + + # 这里的__handlers是一个字典,用来保存对应的事件调用关系 + # 其中每个键对应的值是一个列表,列表中保存了对该事件进行监听的函数功能 + self.__handlers = defaultdict(list) + + # __generalHandlers是一个列表,用来保存通用回调函数(所有事件均调用) + self.__generalHandlers = [] + + #---------------------------------------------------------------------- + def __run(self): + """引擎运行""" + while self.__active == True: + try: + event = self.__queue.get(block = True, timeout = 1) # 获取事件的阻塞时间设为1秒 + self.__process(event) + except Empty: + pass + + #---------------------------------------------------------------------- + def __process(self, event): + """处理事件""" + # 检查是否存在对该事件进行监听的处理函数 + if event.type_ in self.__handlers: + # 若存在,则按顺序将事件传递给处理函数执行 + [handler(event) for handler in self.__handlers[event.type_]] + + # 以上语句为Python列表解析方式的写法,对应的常规循环写法为: + #for handler in self.__handlers[event.type_]: + #handler(event) + + # 调用通用处理函数进行处理 + if self.__generalHandlers: + [handler(event) for handler in self.__generalHandlers] + + #---------------------------------------------------------------------- + def __runTimer(self): + """运行在计时器线程中的循环函数""" + while self.__timerActive: + # 创建计时器事件 + event = Event(type_=EVENT_TIMER) + + # 向队列中存入计时器事件 + self.put(event) + + # 等待 + sleep(self.__timerSleep) + + #---------------------------------------------------------------------- + def start(self, timer=True): + """ + 引擎启动 + timer:是否要启动计时器 + """ + # 将引擎设为启动 + self.__active = True + + # 启动事件处理线程 + self.__thread.start() + + # 启动计时器,计时器事件间隔默认设定为1秒 + if timer: + self.__timerActive = True + self.__timer.start() + + #---------------------------------------------------------------------- + def stop(self): + """停止引擎""" + # 将引擎设为停止 + self.__active = False + + # 停止计时器 + self.__timerActive = False + self.__timer.join() + + # 等待事件处理线程退出 + self.__thread.join() + + #---------------------------------------------------------------------- + def register(self, type_, handler): + """注册事件处理函数监听""" + # 尝试获取该事件类型对应的处理函数列表,若无defaultDict会自动创建新的list + handlerList = self.__handlers[type_] + + # 若要注册的处理器不在该事件的处理器列表中,则注册该事件 + if handler not in handlerList: + handlerList.append(handler) + + #---------------------------------------------------------------------- + def unregister(self, type_, handler): + """注销事件处理函数监听""" + # 尝试获取该事件类型对应的处理函数列表,若无则忽略该次注销请求 + handlerList = self.__handlers[type_] + + # 如果该函数存在于列表中,则移除 + if handler in handlerList: + handlerList.remove(handler) + + # 如果函数列表为空,则从引擎中移除该事件类型 + if not handlerList: + del self.__handlers[type_] + + #---------------------------------------------------------------------- + def put(self, event): + """向事件队列中存入事件""" + self.__queue.put(event) + + #---------------------------------------------------------------------- + def registerGeneralHandler(self, handler): + """注册通用事件处理函数监听""" + if handler not in self.__generalHandlers: + self.__generalHandlers.append(handler) + + #---------------------------------------------------------------------- + def unregisterGeneralHandler(self, handler): + """注销通用事件处理函数监听""" + if handler in self.__generalHandlers: + self.__generalHandlers.remove(handler) + + def inactivate(self): + self.__active = False + + @property + def isActive(self): + return self.__active + +######################################################################## +class Event: + """事件对象""" + + #---------------------------------------------------------------------- + def __init__(self, type_=None): + """Constructor""" + self.type_ = type_ # 事件类型 + self.dict_ = {} # 字典用于保存具体的事件数据 + + +#---------------------------------------------------------------------- +def test(): + """测试函数""" + import sys + from datetime import datetime + from PyQt4.QtCore import QCoreApplication + + def simpletest(event): + print('处理每秒触发的计时器事件:%s' % str(datetime.now())) + + app = QCoreApplication(sys.argv) + + ee = EventEngine2() + #ee.register(EVENT_TIMER, simpletest) + ee.registerGeneralHandler(simpletest) + ee.start() + + app.exec_() + + +# 直接运行脚本可以进行测试 +if __name__ == '__main__': test() \ No newline at end of file diff --git a/vnpy/trader/gateway/okexGateway/okexGateway.py b/vnpy/trader/gateway/okexGateway/okexGateway.py index e634036..0954e83 100644 --- a/vnpy/trader/gateway/okexGateway/okexGateway.py +++ b/vnpy/trader/gateway/okexGateway/okexGateway.py @@ -1,353 +1,357 @@ -import os -import json -import time -import logging -from datetime import datetime, timezone, timedelta - -from vnpy.api.rest import RestClient, Request -from vnpy.api.websocket import WebsocketClient -from vnpy.trader.vtGateway import * -from vnpy.trader.vtConstant import constant -from vnpy.trader.vtFunction import getJsonPath, getTempPath -from .future import OkexfRestApi, OkexfWebsocketApi -from .swap import OkexSwapRestApi, OkexSwapWebsocketApi -from .spot import OkexSpotRestApi, OkexSpotWebsocketApi -from .util import ISO_DATETIME_FORMAT, granularityMap - -REST_HOST = os.environ.get('OKEX_REST_URL', 'https://www.okex.com') -WEBSOCKET_HOST = os.environ.get('OKEX_WEBSOCKET_URL', 'wss://real.okex.com:10442/ws/v3') - -######################################################################## -class OkexGateway(VtGateway): - """OKEX V3 接口""" - - #---------------------------------------------------------------------- - def __init__(self, eventEngine, gatewayName=''): - """Constructor""" - super(OkexGateway, self).__init__(eventEngine, gatewayName) - - self.qryEnabled = False # 是否要启动循环查询 - - self.fileName = self.gatewayName + '_connect.json' - self.filePath = getJsonPath(self.fileName, __file__) - - self.apiKey = '' - self.apiSecret = '' - self.passphrase = '' - - self.symbolTypeMap = {} - self.gatewayMap = {} - self.stgMap = {} - - self.orderID = 1 - self.tradeID = 0 - self.loginTime = int(datetime.now().strftime('%y%m%d%H%M%S')) * 100 - - #---------------------------------------------------------------------- - def connect(self): - """连接""" - try: - f = open(self.filePath) - except IOError: - self.writeLog(u"读取连接配置出错,请检查配置文件", logging.ERROR) - return - - # 解析connect.json文件 - setting = json.load(f) - f.close() - - try: - self.apiKey = str(setting['apiKey']) - self.apiSecret = str(setting['apiSecret']) - self.passphrase = str(setting['passphrase']) - sessionCount = int(setting['sessionCount']) - subscrib_symbols = setting['symbols'] - except KeyError as e: - self.writeLog(f"{self.gatewayName} 连接配置缺少字段,请检查{e}", logging.ERROR) - return - - # 记录订阅的交易品种类型 - contract_list = [] - swap_list = [] - spot_list = [] - for symbol in subscrib_symbols: - if "WEEK" in symbol or "QUARTER" in symbol: - self.symbolTypeMap[symbol] = "FUTURE" - contract_list.append(symbol) - elif "SWAP" in symbol: - self.symbolTypeMap[symbol] = "SWAP" - swap_list.append(symbol) - else: - self.symbolTypeMap[symbol] = "SPOT" - spot_list.append(symbol) - - # 创建行情和交易接口对象 - future_leverage = setting.get('future_leverage', 10) - swap_leverage = setting.get('swap_leverage', 1) - margin_token = setting.get('margin_token', 0) - - # 实例化对应品种类别的API - gateway_type = set(self.symbolTypeMap.values()) - if "FUTURE" in gateway_type: - restfutureApi = OkexfRestApi(self) - wsfutureApi = OkexfWebsocketApi(self) - self.gatewayMap['FUTURE'] = {"REST":restfutureApi, "WS":wsfutureApi, "leverage":future_leverage, "symbols":contract_list} - if "SWAP" in gateway_type: - restSwapApi = OkexSwapRestApi(self) - wsSwapApi = OkexSwapWebsocketApi(self) - self.gatewayMap['SWAP'] = {"REST":restSwapApi, "WS":wsSwapApi, "leverage":swap_leverage, "symbols":swap_list} - if "SPOT" in gateway_type: - restSpotApi = OkexSpotRestApi(self) - wsSpotApi = OkexSpotWebsocketApi(self) - self.gatewayMap['SPOT'] = {"REST":restSpotApi, "WS":wsSpotApi, "leverage":margin_token, "symbols":spot_list} - - self.connectSubGateway(sessionCount) - - setQryEnabled = setting.get('setQryEnabled', None) - self.setQryEnabled(setQryEnabled) - - setQryFreq = setting.get('setQryFreq', 60) - self.initQuery(setQryFreq) - - #---------------------------------------------------------------------- - def connectSubGateway(self, sessionCount): - for subGateway in self.gatewayMap.values(): - subGateway["REST"].connect(REST_HOST, subGateway["leverage"], sessionCount) - subGateway["WS"].connect(WEBSOCKET_HOST) - - def subscribe(self, subscribeReq): - """订阅行情""" - # symbolType = self.symbolTypeMap.get(subscribeReq.symbol, None) - # if not symbolType: - # self.writeLog(f"{self.gatewayName} does not have this symbol:{subscribeReq.symbol}", logging.ERROR) - # else: - # self.gatewayMap[symbolType]["WS"].subscribe(subscribeReq.symbol) - - #---------------------------------------------------------------------- - def sendOrder(self, orderReq): - """发单""" - strategy_name = self.stgMap.get(orderReq.byStrategy, None) - if not strategy_name: - # 规定策略名称长度和合法字符 - alpha='abcdefghijklmnopqrstuvwxyz' - filter_text = "0123456789" + alpha + alpha.upper() - new_name = filter(lambda ch: ch in filter_text, orderReq.byStrategy) - name = ''.join(list(new_name))[:13] - self.stgMap.update({strategy_name:name}) - strategy_name = name - - symbolType = self.symbolTypeMap.get(orderReq.symbol, None) - if not symbolType: - self.writeLog(f"{self.gatewayName} does not have this symbol:{orderReq.symbol}", logging.ERROR) - else: - self.orderID += 1 - order_id = f"{strategy_name}{symbolType[:4]}{str(self.loginTime + self.orderID)}" - return self.gatewayMap[symbolType]["REST"].sendOrder(orderReq, order_id) - - #---------------------------------------------------------------------- - def cancelOrder(self, cancelOrderReq): - """撤单""" - symbolType = self.symbolTypeMap.get(cancelOrderReq.symbol, None) - if not symbolType: - self.writeLog(f"{self.gatewayName} does not have this symbol:{cancelOrderReq.symbol}", logging.ERROR) - else: - self.gatewayMap[symbolType]["REST"].cancelOrder(cancelOrderReq) - - # ---------------------------------------------------------------------- - def cancelAll(self, symbols=None, orders=None): - """全撤""" - ids = [] - if not symbols: - symbols = list(self.symbolTypeMap.keys()) - for sym in symbols: - symbolType = self.symbolTypeMap.get(sym, None) - vtOrderIDs = self.gatewayMap[symbolType]["REST"].cancelAll(symbol = sym, orders=orders) - ids.extend(vtOrderIDs) - - print("全部撤单结果", ids) - return ids - - # ---------------------------------------------------------------------- - def closeAll(self, symbols, direction=None, standard_token = "USDT"): - """全平""" - ids = [] - if not symbols: - symbols = list(self.symbolTypeMap.keys()) - for sym in symbols: - symbolType = self.symbolTypeMap.get(sym, None) - if symbolType == "SPOT": - vtOrderIDs = self.gatewayMap[symbolType]["REST"].closeAll(symbol = sym, standard_token = standard_token) - else: - vtOrderIDs = self.gatewayMap[symbolType]["REST"].closeAll(symbol = sym, direction = direction) - ids.extend(vtOrderIDs) - - print("全部平仓结果", ids) - return ids - - #---------------------------------------------------------------------- - def close(self): - """关闭""" - for gateway in self.gatewayMap.values(): - gateway["REST"].stop() - gateway["WS"].stop() - #---------------------------------------------------------------------- - def initQuery(self, freq = 60): - """初始化连续查询""" - if self.qryEnabled: - # 需要循环的查询函数列表 - self.qryFunctionList = [self.queryInfo] - - self.qryCount = 0 # 查询触发倒计时 - self.qryTrigger = freq # 查询触发点 - self.qryNextFunction = 0 # 上次运行的查询函数索引 - - self.startQuery() - - #---------------------------------------------------------------------- - def query(self, event): - """注册到事件处理引擎上的查询函数""" - self.qryCount += 1 - - if self.qryCount > self.qryTrigger: - # 清空倒计时 - self.qryCount = 0 - - # 执行查询函数 - function = self.qryFunctionList[self.qryNextFunction] - function() - - # 计算下次查询函数的索引,如果超过了列表长度,则重新设为0 - self.qryNextFunction += 1 - if self.qryNextFunction == len(self.qryFunctionList): - self.qryNextFunction = 0 - - #---------------------------------------------------------------------- - def startQuery(self): - """启动连续查询""" - self.eventEngine.register(EVENT_TIMER, self.query) - - #---------------------------------------------------------------------- - def setQryEnabled(self, qryEnabled): - """设置是否要启动循环查询""" - self.qryEnabled = qryEnabled - - #---------------------------------------------------------------------- - def queryInfo(self): - """""" - for subGateway in self.gatewayMap.values(): - subGateway["REST"].queryMonoAccount(subGateway['symbols']) - subGateway["REST"].queryMonoPosition(subGateway['symbols']) - subGateway["REST"].queryOrder() - - def initPosition(self, vtSymbol): - symbol = vtSymbol.split(constant.VN_SEPARATOR)[0] - symbolType = self.symbolTypeMap.get(symbol, None) - if not symbolType: - self.writeLog(f"{self.gatewayName} does not have this symbol:{symbol}", logging.ERROR) - else: - self.gatewayMap[symbolType]["REST"].queryMonoPosition([symbol]) - self.gatewayMap[symbolType]["REST"].queryMonoAccount([symbol]) - - def qryAllOrders(self, vtSymbol, order_id, status=None): - pass - - def loadHistoryBar(self, vtSymbol, type_, size = None, since = None, end = None): - import pandas as pd - symbol = vtSymbol.split(constant.VN_SEPARATOR)[0] - symbolType = self.symbolTypeMap.get(symbol, None) - granularity = granularityMap[type_] - - if not symbolType: - self.writeLog(f"{self.gatewayName} does not have this symbol:{symbol}", logging.ERROR) - return [] - else: - subGateway = self.gatewayMap[symbolType]["REST"] - - if end: - end = datetime.utcfromtimestamp(datetime.timestamp(datetime.strptime(end,'%Y%m%d'))) - else: - end = datetime.utcfromtimestamp(datetime.timestamp(datetime.now())) - - if since: - start = datetime.utcfromtimestamp(datetime.timestamp(datetime.strptime(since,'%Y%m%d'))) - bar_count = (end -start).total_seconds()/ granularity - - if size: - bar_count = size - - req = {"granularity":granularity} - - df = pd.DataFrame([]) - loop = min(10, int(bar_count // 200 + 1)) - for i in range(loop): - rotate_end = end.isoformat().split('.')[0]+'Z' - rotate_start = end - timedelta(seconds = granularity * 200) - if (i+1) == loop: - rotate_start = end - timedelta(seconds = granularity * (bar_count % 200)) - rotate_start = rotate_start.isoformat().split('.')[0]+'Z' - - req["start"] = rotate_start - req["end"] = rotate_end - data = subGateway.loadHistoryBar(REST_HOST, symbol, req) - - end = datetime.strptime(rotate_start, "%Y-%m-%dT%H:%M:%SZ") - df = pd.concat([df, data]) - - df["datetime"] = df["time"].map(lambda x: datetime.strptime(x, ISO_DATETIME_FORMAT).replace(tzinfo=timezone(timedelta()))) - df = df[["datetime", "open", "high", "low", "close", "volume"]] - df["datetime"] = df["datetime"].map(lambda x: datetime.fromtimestamp(x.timestamp())) - df[['open','high','low','close','volume']] = df[['open','high','low','close','volume']].applymap(lambda x: float(x)) - df.sort_values(by=['datetime'], axis = 0, ascending =True, inplace = True) - return df - - def writeLog(self, content, level = logging.INFO): - """发出日志""" - log = VtLogData() - log.gatewayName = self.gatewayName - log.logContent = content - log.logLevel = level - self.onLog(log) - - def newOrderObject(self, data): - order = VtOrderData() - order.gatewayName = self.gatewayName - order.symbol = data['instrument_id'] - order.exchange = 'OKEX' - order.vtSymbol = constant.VN_SEPARATOR.join([order.symbol, order.gatewayName]) - - order.orderID = data.get("client_oid", None) - if not order.orderID: - order.orderID = str(data['order_id']) - self.writeLog(f"order by other source, symbol:{order.symbol}, exchange_id: {order.orderID}") - - order.vtOrderID = constant.VN_SEPARATOR.join([self.gatewayName, order.orderID]) - return order - - def newTradeObject(self, order): - self.tradeID += 1 - trade = VtTradeData() - trade.gatewayName = order.gatewayName - trade.symbol = order.symbol - trade.exchange = order.exchange - trade.vtSymbol = order.vtSymbol - - trade.orderID = order.orderID - trade.vtOrderID = order.vtOrderID - trade.tradeID = str(self.tradeID) - trade.vtTradeID = constant.VN_SEPARATOR.join([self.gatewayName, trade.tradeID]) - - trade.direction = order.direction - trade.offset = order.offset - trade.volume = order.thisTradedVolume - trade.price = order.price_avg - trade.tradeDatetime = datetime.now() - trade.tradeTime = trade.tradeDatetime.strftime('%Y%m%d %H:%M:%S') - self.onTrade(trade) - - def convertDatetime(self, timestring): - dt = datetime.strptime(timestring, ISO_DATETIME_FORMAT) - dt = dt.replace(tzinfo=timezone(timedelta())) - local_dt = datetime.fromtimestamp(dt.timestamp()) - date_string = local_dt.strftime('%Y%m%d') - time_string = local_dt.strftime('%H:%M:%S.%f') - return local_dt, date_string, time_string \ No newline at end of file +import os +import json +import time +import logging +from datetime import datetime, timezone, timedelta + +from vnpy.api.rest import RestClient, Request +from vnpy.api.websocket import WebsocketClient +from vnpy.trader.vtGateway import * +from vnpy.trader.vtConstant import constant +from vnpy.trader.vtFunction import getJsonPath, getTempPath +from .future import OkexfRestApi, OkexfWebsocketApi +from .swap import OkexSwapRestApi, OkexSwapWebsocketApi +from .spot import OkexSpotRestApi, OkexSpotWebsocketApi +from .util import ISO_DATETIME_FORMAT, granularityMap + +REST_HOST = os.environ.get('OKEX_REST_URL', 'https://www.okex.com') +WEBSOCKET_HOST = os.environ.get('OKEX_WEBSOCKET_URL', 'wss://real.okex.com:10442/ws/v3') + +######################################################################## +class OkexGateway(VtGateway): + """OKEX V3 接口""" + + #---------------------------------------------------------------------- + def __init__(self, eventEngine, gatewayName=''): + """Constructor""" + super(OkexGateway, self).__init__(eventEngine, gatewayName) + + self.qryEnabled = False # 是否要启动循环查询 + + self.fileName = self.gatewayName + '_connect.json' + self.filePath = getJsonPath(self.fileName, __file__) + + self.apiKey = '' + self.apiSecret = '' + self.passphrase = '' + + self.symbolTypeMap = {} + self.gatewayMap = {} + self.stgMap = {} + + self.orderID = 1 + self.tradeID = 0 + self.loginTime = int(datetime.now().strftime('%y%m%d%H%M%S')) * 100 + + #---------------------------------------------------------------------- + def connect(self): + """连接""" + try: + f = open(self.filePath) + except IOError: + self.writeLog(u"读取连接配置出错,请检查配置文件", logging.ERROR) + return + + # 解析connect.json文件 + setting = json.load(f) + f.close() + + try: + self.apiKey = str(setting['apiKey']) + self.apiSecret = str(setting['apiSecret']) + self.passphrase = str(setting['passphrase']) + sessionCount = int(setting['sessionCount']) + subscrib_symbols = setting['symbols'] + except KeyError as e: + self.writeLog(f"{self.gatewayName} 连接配置缺少字段,请检查{e}", logging.ERROR) + return + + # 记录订阅的交易品种类型 + contract_list = [] + swap_list = [] + spot_list = [] + for symbol in subscrib_symbols: + if "WEEK" in symbol or "QUARTER" in symbol: + self.symbolTypeMap[symbol] = "FUTURE" + contract_list.append(symbol) + elif "SWAP" in symbol: + self.symbolTypeMap[symbol] = "SWAP" + swap_list.append(symbol) + else: + self.symbolTypeMap[symbol] = "SPOT" + spot_list.append(symbol) + + # 创建行情和交易接口对象 + future_leverage = setting.get('future_leverage', 10) + swap_leverage = setting.get('swap_leverage', 1) + margin_token = setting.get('margin_token', 0) + + # 实例化对应品种类别的API + gateway_type = set(self.symbolTypeMap.values()) + if "FUTURE" in gateway_type: + restfutureApi = OkexfRestApi(self) + wsfutureApi = OkexfWebsocketApi(self) + self.gatewayMap['FUTURE'] = {"REST":restfutureApi, "WS":wsfutureApi, "leverage":future_leverage, "symbols":contract_list} + if "SWAP" in gateway_type: + restSwapApi = OkexSwapRestApi(self) + wsSwapApi = OkexSwapWebsocketApi(self) + self.gatewayMap['SWAP'] = {"REST":restSwapApi, "WS":wsSwapApi, "leverage":swap_leverage, "symbols":swap_list} + if "SPOT" in gateway_type: + restSpotApi = OkexSpotRestApi(self) + wsSpotApi = OkexSpotWebsocketApi(self) + self.gatewayMap['SPOT'] = {"REST":restSpotApi, "WS":wsSpotApi, "leverage":margin_token, "symbols":spot_list} + + self.connectSubGateway(sessionCount) + + setQryEnabled = setting.get('setQryEnabled', None) + self.setQryEnabled(setQryEnabled) + + setQryFreq = setting.get('setQryFreq', 60) + self.initQuery(setQryFreq) + + #---------------------------------------------------------------------- + def connectSubGateway(self, sessionCount): + for subGateway in self.gatewayMap.values(): + subGateway["REST"].connect(REST_HOST, subGateway["leverage"], sessionCount) + subGateway["WS"].connect(WEBSOCKET_HOST) + + def subscribe(self, subscribeReq): + """订阅行情""" + # symbolType = self.symbolTypeMap.get(subscribeReq.symbol, None) + # if not symbolType: + # self.writeLog(f"{self.gatewayName} does not have this symbol:{subscribeReq.symbol}", logging.ERROR) + # else: + # self.gatewayMap[symbolType]["WS"].subscribe(subscribeReq.symbol) + + #---------------------------------------------------------------------- + def sendOrder(self, orderReq): + """发单""" + strategy_name = self.stgMap.get(orderReq.byStrategy, None) + if not strategy_name: + # 规定策略名称长度和合法字符 + alpha='abcdefghijklmnopqrstuvwxyz' + filter_text = "0123456789" + alpha + alpha.upper() + new_name = filter(lambda ch: ch in filter_text, orderReq.byStrategy) + name = ''.join(list(new_name))[:13] + self.stgMap.update({strategy_name:name}) + strategy_name = name + + symbolType = self.symbolTypeMap.get(orderReq.symbol, None) + if not symbolType: + self.writeLog(f"{self.gatewayName} does not have this symbol:{orderReq.symbol}", logging.ERROR) + else: + self.orderID += 1 + order_id = f"{strategy_name}{symbolType[:4]}{str(self.loginTime + self.orderID)}" + return self.gatewayMap[symbolType]["REST"].sendOrder(orderReq, order_id) + + #---------------------------------------------------------------------- + def cancelOrder(self, cancelOrderReq): + """撤单""" + symbolType = self.symbolTypeMap.get(cancelOrderReq.symbol, None) + if not symbolType: + self.writeLog(f"{self.gatewayName} does not have this symbol:{cancelOrderReq.symbol}", logging.ERROR) + else: + self.gatewayMap[symbolType]["REST"].cancelOrder(cancelOrderReq) + + # ---------------------------------------------------------------------- + def cancelAll(self, symbols=None, orders=None): + """全撤""" + ids = [] + if not symbols: + symbols = list(self.symbolTypeMap.keys()) + for sym in symbols: + symbolType = self.symbolTypeMap.get(sym, None) + vtOrderIDs = self.gatewayMap[symbolType]["REST"].cancelAll(symbol = sym, orders=orders) + ids.extend(vtOrderIDs) + + print("全部撤单结果", ids) + return ids + + # ---------------------------------------------------------------------- + def closeAll(self, symbols, direction=None, standard_token = "USDT"): + """全平""" + ids = [] + if not symbols: + symbols = list(self.symbolTypeMap.keys()) + for sym in symbols: + symbolType = self.symbolTypeMap.get(sym, None) + if symbolType == "SPOT": + vtOrderIDs = self.gatewayMap[symbolType]["REST"].closeAll(symbol = sym, standard_token = standard_token) + else: + vtOrderIDs = self.gatewayMap[symbolType]["REST"].closeAll(symbol = sym, direction = direction) + ids.extend(vtOrderIDs) + + print("全部平仓结果", ids) + return ids + + #---------------------------------------------------------------------- + def close(self): + """关闭""" + for gateway in self.gatewayMap.values(): + gateway["REST"].stop() + gateway["WS"].stop() + #---------------------------------------------------------------------- + def initQuery(self, freq = 60): + """初始化连续查询""" + if self.qryEnabled: + # 需要循环的查询函数列表 + self.qryFunctionList = [self.queryInfo] + + self.qryCount = 0 # 查询触发倒计时 + self.qryTrigger = freq # 查询触发点 + self.qryNextFunction = 0 # 上次运行的查询函数索引 + + self.startQuery() + + #---------------------------------------------------------------------- + def query(self, event): + """注册到事件处理引擎上的查询函数""" + self.qryCount += 1 + + if self.qryCount > self.qryTrigger: + # 清空倒计时 + self.qryCount = 0 + + # 执行查询函数 + function = self.qryFunctionList[self.qryNextFunction] + function() + + # 计算下次查询函数的索引,如果超过了列表长度,则重新设为0 + self.qryNextFunction += 1 + if self.qryNextFunction == len(self.qryFunctionList): + self.qryNextFunction = 0 + + #---------------------------------------------------------------------- + def startQuery(self): + """启动连续查询""" + self.eventEngine.register(EVENT_TIMER, self.query) + + #---------------------------------------------------------------------- + def setQryEnabled(self, qryEnabled): + """设置是否要启动循环查询""" + self.qryEnabled = qryEnabled + + #---------------------------------------------------------------------- + def queryInfo(self): + """""" + for subGateway in self.gatewayMap.values(): + subGateway["REST"].queryMonoAccount(subGateway['symbols']) + subGateway["REST"].queryMonoPosition(subGateway['symbols']) + subGateway["REST"].queryOrder() + + def initPosition(self, vtSymbol): + symbol = vtSymbol.split(constant.VN_SEPARATOR)[0] + symbolType = self.symbolTypeMap.get(symbol, None) + if not symbolType: + self.writeLog(f"{self.gatewayName} does not have this symbol:{symbol}", logging.ERROR) + else: + self.gatewayMap[symbolType]["REST"].queryMonoPosition([symbol]) + self.gatewayMap[symbolType]["REST"].queryMonoAccount([symbol]) + + def qryAllOrders(self, vtSymbol, order_id, status=None): + pass + + def loadHistoryBar(self, vtSymbol, type_, size = None, since = None, end = None): + import pandas as pd + symbol = vtSymbol.split(constant.VN_SEPARATOR)[0] + symbolType = self.symbolTypeMap.get(symbol, None) + granularity = granularityMap[type_] + + if not symbolType: + self.writeLog(f"{self.gatewayName} does not have this symbol:{symbol}", logging.ERROR) + return [] + else: + subGateway = self.gatewayMap[symbolType]["REST"] + + if end: + end = datetime.utcfromtimestamp(datetime.timestamp(datetime.strptime(end,'%Y%m%d'))) + else: + end = datetime.utcfromtimestamp(datetime.timestamp(datetime.now())) + + if since: + start = datetime.utcfromtimestamp(datetime.timestamp(datetime.strptime(since,'%Y%m%d'))) + bar_count = (end -start).total_seconds()/ granularity + + if size: + bar_count = size + + req = {"granularity":granularity} + + df = pd.DataFrame([]) + loop = min(10, int(bar_count // 200 + 1)) + for i in range(loop): + rotate_end = end.isoformat().split('.')[0]+'Z' + rotate_start = end - timedelta(seconds = granularity * 200) + if (i+1) == loop: + rotate_start = end - timedelta(seconds = granularity * (bar_count % 200)) + rotate_start = rotate_start.isoformat().split('.')[0]+'Z' + + req["start"] = rotate_start + req["end"] = rotate_end + data = subGateway.loadHistoryBar(REST_HOST, symbol, req) + + end = datetime.strptime(rotate_start, "%Y-%m-%dT%H:%M:%SZ") + df = pd.concat([df, data]) + + df["datetime"] = df["time"].map(lambda x: datetime.strptime(x, ISO_DATETIME_FORMAT).replace(tzinfo=timezone(timedelta()))) + df = df[["datetime", "open", "high", "low", "close", "volume"]] + df["datetime"] = df["datetime"].map(lambda x: datetime.fromtimestamp(x.timestamp())) + df[['open','high','low','close','volume']] = df[['open','high','low','close','volume']].applymap(lambda x: float(x)) + df.sort_values(by=['datetime'], axis = 0, ascending =True, inplace = True) + return df + + def writeLog(self, content, level = logging.INFO): + """发出日志""" + log = VtLogData() + log.gatewayName = self.gatewayName + log.logContent = content + log.logLevel = level + self.onLog(log) + + def newOrderObject(self, data): + order = VtOrderData() + order.gatewayName = self.gatewayName + order.symbol = data['instrument_id'] + order.exchange = 'OKEX' + order.vtSymbol = constant.VN_SEPARATOR.join([order.symbol, order.gatewayName]) + + order.orderID = data.get("client_oid", None) + if not order.orderID: + order.orderID = str(data['order_id']) + self.writeLog(f"order by other source, symbol:{order.symbol}, exchange_id: {order.orderID}") + + order.vtOrderID = constant.VN_SEPARATOR.join([self.gatewayName, order.orderID]) + return order + + def newTradeObject(self, order): + self.tradeID += 1 + trade = VtTradeData() + trade.gatewayName = order.gatewayName + trade.symbol = order.symbol + trade.exchange = order.exchange + trade.vtSymbol = order.vtSymbol + + trade.orderID = order.orderID + trade.vtOrderID = order.vtOrderID + trade.tradeID = str(self.tradeID) + trade.vtTradeID = constant.VN_SEPARATOR.join([self.gatewayName, trade.tradeID]) + + trade.direction = order.direction + trade.offset = order.offset + trade.volume = order.thisTradedVolume + trade.price = order.price_avg + trade.tradeDatetime = datetime.now() + trade.tradeTime = trade.tradeDatetime.strftime('%Y%m%d %H:%M:%S') + self.onTrade(trade) + + def convertDatetime(self, timestring): + dt = datetime.strptime(timestring, ISO_DATETIME_FORMAT) + dt = dt.replace(tzinfo=timezone(timedelta())) + local_dt = datetime.fromtimestamp(dt.timestamp()) + date_string = local_dt.strftime('%Y%m%d') + time_string = local_dt.strftime('%H:%M:%S.%f') + return local_dt, date_string, time_string + + def sendExit(self): + event = Event(EVENT_EXIT) + self.eventEngine.put(event) \ No newline at end of file diff --git a/vnpy/trader/vtEngine.py b/vnpy/trader/vtEngine.py index 7de1089..d686db0 100644 --- a/vnpy/trader/vtEngine.py +++ b/vnpy/trader/vtEngine.py @@ -1,757 +1,761 @@ -# encoding: UTF-8 - -import os -import shelve -import logging -from logging import handlers -from collections import OrderedDict -from datetime import datetime -from copy import copy - -# from pymongo import MongoClient, ASCENDING -# from pymongo.errors import ConnectionFailure - -from vnpy.event import Event -from vnpy.trader.vtGlobal import globalSetting -from vnpy.trader.vtEvent import * -from vnpy.trader.vtGateway import * -from vnpy.trader.language import text -from vnpy.trader.vtFunction import getTempPath - - -######################################################################## -class MainEngine(object): - """主引擎""" - - #---------------------------------------------------------------------- - def __init__(self, eventEngine): - """Constructor""" - # 记录今日日期 - self.todayDate = datetime.now().strftime('%Y%m%d') - - # 绑定事件引擎 - self.eventEngine = eventEngine - self.eventEngine.start() - - # 创建数据引擎 - self.dataEngine = DataEngine(self.eventEngine) - - # MongoDB数据库相关 - self.dbClient = None # MongoDB客户端对象 - - # 接口实例 - self.gatewayDict = OrderedDict() - self.gatewayDetailList = [] - - # 应用模块实例 - self.appDict = OrderedDict() - self.appDetailList = [] - - # 风控引擎实例(特殊独立对象) - self.rmEngine = None - - # 日志引擎实例 - self.logEngine = None - self.initLogEngine() - - #---------------------------------------------------------------------- - def addGateway(self, gatewayModule): - """添加底层接口""" - gatewayName = gatewayModule.gatewayName - gatewayTypeMap = {} - - # 创建接口实例 - if type(gatewayName) == list: - for i in range(len(gatewayName)): - self.gatewayDict[gatewayName[i]] = gatewayModule.gatewayClass( - self.eventEngine, gatewayName[i]) - - # 设置接口轮询 - if gatewayModule.gatewayQryEnabled: - self.gatewayDict[gatewayName[i]].setQryEnabled( - gatewayModule.gatewayQryEnabled) - - # 保存接口详细信息 - d = { - 'gatewayName': gatewayModule.gatewayName[i], - 'gatewayDisplayName': gatewayModule.gatewayDisplayName[i], - 'gatewayType': gatewayModule.gatewayType - } - self.gatewayDetailList.append(d) - else: - self.gatewayDict[gatewayName] = gatewayModule.gatewayClass( - self.eventEngine, gatewayName) - - # 设置接口轮询 - if gatewayModule.gatewayQryEnabled: - self.gatewayDict[gatewayName].setQryEnabled( - gatewayModule.gatewayQryEnabled) - - # 保存接口详细信息 - d = { - 'gatewayName': gatewayModule.gatewayName, - 'gatewayDisplayName': gatewayModule.gatewayDisplayName, - 'gatewayType': gatewayModule.gatewayType - } - self.gatewayDetailList.append(d) - - for i in range(len(self.gatewayDetailList)): - s = self.gatewayDetailList[i]['gatewayName'].split( - '_connect.json')[0] - gatewayTypeMap[s] = self.gatewayDetailList[i]['gatewayType'] - - path = os.getcwd() - # 遍历当前目录下的所有文件 - for root, subdirs, files in os.walk(path): - for name in files: - # 只有文件名中包含_connect.json的文件,才是密钥配置文件 - if '_connect.json' in name: - gw = name.replace('_connect.json', '') - if not gw in gatewayTypeMap.keys(): - for existnames in list(gatewayTypeMap.keys()): - if existnames in gw and existnames != gw: - d = { - 'gatewayName': gw, - 'gatewayDisplayName': gw, - 'gatewayType': gatewayTypeMap[existnames] - } - self.gatewayDetailList.append(d) - self.gatewayDict[ - gw] = gatewayModule.gatewayClass( - self.eventEngine, gw) - - #---------------------------------------------------------------------- - def addApp(self, appModule): - """添加上层应用""" - appName = appModule.appName - - # 创建应用实例 - self.appDict[appName] = appModule.appEngine(self, self.eventEngine) - - # 将应用引擎实例添加到主引擎的属性中 - self.__dict__[appName] = self.appDict[appName] - - # 保存应用信息 - d = { - 'appName': appModule.appName, - 'appDisplayName': appModule.appDisplayName, - 'appWidget': appModule.appWidget, - 'appIco': appModule.appIco - } - self.appDetailList.append(d) - - #---------------------------------------------------------------------- - def getGateway(self, gatewayName): - """获取接口""" - if gatewayName in self.gatewayDict: - return self.gatewayDict[gatewayName] - else: - self.writeLog(text.GATEWAY_NOT_EXIST.format(gateway=gatewayName)) - self.writeLog(gatewayName) - return None - - #---------------------------------------------------------------------- - def connect(self, gatewayName): - """连接特定名称的接口""" - gateway = self.getGateway(gatewayName) - - if gateway: - gateway.connect() - - # 接口连接后自动执行数据库连接的任务 - # self.dbConnect() - - #---------------------------------------------------------------------- - def subscribe(self, subscribeReq, gatewayName): - """订阅特定接口的行情""" - gateway = self.getGateway(gatewayName) - - if gateway: - gateway.subscribe(subscribeReq) - - #---------------------------------------------------------------------- - def sendOrder(self, orderReq, gatewayName): - """对特定接口发单""" - # 如果创建了风控引擎,且风控检查失败则不发单 - if self.rmEngine and not self.rmEngine.checkRisk( - orderReq, gatewayName): - return '' - - gateway = self.getGateway(gatewayName) - if gateway: - vtOrderID = gateway.sendOrder(orderReq) - # self.dataEngine.updateOrderReq(orderReq, vtOrderID) # 更新发出的委托请求到数据引擎中 - return vtOrderID - else: - return '' - - #---------------------------------------------------------------------- - def cancelOrder(self, cancelOrderReq, gatewayName): - """对特定接口撤单""" - gateway = self.getGateway(gatewayName) - - if gateway: - gateway.cancelOrder(cancelOrderReq) - - def batchCancelOrder(self, cancelOrderReqList, gatewayName): - gateway = self.getGateway(gatewayName) - - if gateway: - gateway.batchCancelOrder(cancelOrderReqList) - - #---------------------------------------------------------------------- - def qryAccount(self, gatewayName): - """查询特定接口的账户""" - gateway = self.getGateway(gatewayName) - - if gateway: - gateway.qryAccount() - - #---------------------------------------------------------------------- - def qryPosition(self, gatewayName): - """查询特定接口的持仓""" - gateway = self.getGateway(gatewayName) - - if gateway: - gateway.qryPosition() - - #------------------------------------------------ - def initPosition(self, vtSymbol): - """策略初始化时查询特定接口的持仓""" - contract = self.getContract(vtSymbol) - if contract: - gatewayName = contract.gatewayName - gateway = self.getGateway(gatewayName) - if gateway: - gateway.initPosition(vtSymbol) - else: - self.writeLog( - 'we don\'t have this symbol %s, Please check symbolList in ctaSetting.json' - % vtSymbol) - return None - - def loadHistoryBar(self, vtSymbol, type_, size=None, since=None): - """策略初始化时下载历史数据""" - contract = self.getContract(vtSymbol) - gatewayName = contract.gatewayName - gateway = self.getGateway(gatewayName) - if gateway: - data = gateway.loadHistoryBar(vtSymbol, type_, size, since) - return data - - def qryAllOrders(self, vtSymbol, orderId, status=None): - contract = self.getContract(vtSymbol) - gatewayName = contract.gatewayName - gateway = self.getGateway(gatewayName) - if gateway: - gateway.qryAllOrders(vtSymbol, orderId, status) - - #---------------------------------------------------------------------- - def exit(self): - """退出程序前调用,保证正常退出""" - # 安全关闭所有接口 - for gateway in list(self.gatewayDict.values()): - gateway.close() - - # 停止事件引擎 - self.eventEngine.stop() - - # 停止上层应用引擎 - for appEngine in list(self.appDict.values()): - appEngine.stop() - - # 保存数据引擎里的合约数据到硬盘 - self.dataEngine.saveContracts() - - #---------------------------------------------------------------------- - def writeLog(self, content): - """快速发出日志事件""" - log = VtLogData() - log.logContent = content - log.gatewayName = 'MAIN_ENGINE' - event = Event(type_=EVENT_LOG) - event.dict_['data'] = log - self.eventEngine.put(event) - - #---------------------------------------------------------------------- - def getContract(self, vtSymbol): - """查询合约""" - return self.dataEngine.getContract(vtSymbol) - - #---------------------------------------------------------------------- - def getAllContracts(self): - """查询所有合约(返回列表)""" - return self.dataEngine.getAllContracts() - - #---------------------------------------------------------------------- - def getOrder(self, vtOrderID): - """查询委托""" - return self.dataEngine.getOrder(vtOrderID) - - #---------------------------------------------------------------------- - def getAllWorkingOrders(self): - """查询所有的活跃的委托(返回列表)""" - return self.dataEngine.getAllWorkingOrders() - - #---------------------------------------------------------------------- - def getAllOrders(self): - """查询所有委托""" - return self.dataEngine.getAllOrders() - - #---------------------------------------------------------------------- - def getAllTrades(self): - """查询所有成交""" - return self.dataEngine.getAllTrades() - - #---------------------------------------------------------------------- - def getAllAccounts(self): - """查询所有账户""" - return self.dataEngine.getAllAccounts() - - def getAllPositions(self): - """查询所有持仓""" - return self.dataEngine.getAllPositions() - - #---------------------------------------------------------------------- - def getAllPositionDetails(self): - """查询本地持仓缓存细节""" - return self.dataEngine.getAllPositionDetails() - - #---------------------------------------------------------------------- - def getAllGatewayDetails(self): - """查询引擎中所有底层接口的信息""" - return self.gatewayDetailList - - #---------------------------------------------------------------------- - def getAllAppDetails(self): - """查询引擎中所有上层应用的信息""" - return self.appDetailList - - #---------------------------------------------------------------------- - def getApp(self, appName): - """获取APP引擎对象""" - return self.appDict[appName] - - #---------------------------------------------------------------------- - def initLogEngine(self): - """初始化日志引擎""" - if not globalSetting["logActive"]: - return - - # 创建引擎 - self.logEngine = LogEngine() - - # 设置日志级别 - levelDict = { - "debug": LogEngine.LEVEL_DEBUG, - "info": LogEngine.LEVEL_INFO, - "warn": LogEngine.LEVEL_WARN, - "error": LogEngine.LEVEL_ERROR, - "critical": LogEngine.LEVEL_CRITICAL, - } - level = levelDict.get(globalSetting["logLevel"], - LogEngine.LEVEL_CRITICAL) - self.logEngine.setLogLevel(level) - stream_setting = globalSetting.get("streamLevel", "info") - streamLevel = levelDict.get(stream_setting, LogEngine.LEVEL_CRITICAL) - self.logEngine.setStreamLevel(streamLevel) - - # 设置输出 - if globalSetting['logConsole']: - self.logEngine.addConsoleHandler() - - if globalSetting['logFile']: - self.logEngine.addFileHandler() - - # 注册事件监听 - self.registerLogEvent(EVENT_LOG) - - #---------------------------------------------------------------------- - def registerLogEvent(self, eventType): - """注册日志事件监听""" - if self.logEngine: - self.eventEngine.register(eventType, - self.logEngine.processLogEvent) - - #---------------------------------------------------------------------- - # def convertOrderReq(self, req): - # """转换委托请求""" - # return self.dataEngine.convertOrderReq(req) - - #---------------------------------------------------------------------- - def getLog(self): - """查询日志""" - return self.dataEngine.getLog() - - #---------------------------------------------------------------------- - def getError(self): - """查询错误""" - return self.dataEngine.getError() - - -######################################################################## - - -class DataEngine(object): - """数据引擎""" - contractFileName = 'ContractData.vt' - contractFilePath = getTempPath(contractFileName) - - FINISHED_STATUS = [STATUS_ALLTRADED, STATUS_REJECTED, STATUS_CANCELLED] - - #---------------------------------------------------------------------- - def __init__(self, eventEngine): - """Constructor""" - self.eventEngine = eventEngine - - # 保存数据的字典和列表 - self.tickDict = {} - self.contractDict = {} - self.orderDict = {} - self.workingOrderDict = {} # 可撤销委托 - self.tradeDict = {} - self.accountDict = {} - self.positionDict = {} - self.logList = [] - self.errorList = [] - - # 持仓细节相关 - # self.detailDict = {} # vtSymbol:PositionDetail - self.tdPenaltyList = globalSetting['tdPenalty'] # 平今手续费惩罚的产品代码列表 - - # 读取保存在硬盘的合约数据 - self.loadContracts() - - # 注册事件监听 - self.registerEvent() - - #---------------------------------------------------------------------- - def registerEvent(self): - """注册事件监听""" - self.eventEngine.register(EVENT_TICK, self.processTickEvent) - self.eventEngine.register(EVENT_CONTRACT, self.processContractEvent) - self.eventEngine.register(EVENT_ORDER, self.processOrderEvent) - self.eventEngine.register(EVENT_TRADE, self.processTradeEvent) - self.eventEngine.register(EVENT_POSITION, self.processPositionEvent) - self.eventEngine.register(EVENT_ACCOUNT, self.processAccountEvent) - self.eventEngine.register(EVENT_LOG, self.processLogEvent) - self.eventEngine.register(EVENT_ERROR, self.processErrorEvent) - - #---------------------------------------------------------------------- - def processTickEvent(self, event): - """处理成交事件""" - tick = event.dict_['data'] - self.tickDict[tick.vtSymbol] = tick - - #---------------------------------------------------------------------- - def processContractEvent(self, event): - """处理合约事件""" - contract = event.dict_['data'] - self.contractDict[contract.vtSymbol] = contract - self.contractDict[contract.symbol] = contract # 使用常规代码(不包括交易所)可能导致重复 - - #---------------------------------------------------------------------- - def processOrderEvent(self, event): - """处理委托事件""" - order = event.dict_['data'] - self.orderDict[order.vtOrderID] = order - - # 如果订单的状态是全部成交或者撤销,则需要从workingOrderDict中移除 - if order.status in self.FINISHED_STATUS: - if order.vtOrderID in self.workingOrderDict: - del self.workingOrderDict[order.vtOrderID] - # 否则则更新字典中的数据 - else: - self.workingOrderDict[order.vtOrderID] = order - - # 更新到持仓细节中 - # detail = self.getPositionDetail(order.vtSymbol) - # detail.updateOrder(order) - - #---------------------------------------------------------------------- - def processTradeEvent(self, event): - """处理成交事件""" - trade = event.dict_['data'] - - self.tradeDict[trade.vtTradeID] = trade - - # 更新到持仓细节中 - # detail = self.getPositionDetail(trade.vtSymbol) - # detail.updateTrade(trade) - - #---------------------------------------------------------------------- - def processPositionEvent(self, event): - """处理持仓事件""" - pos = event.dict_['data'] - - self.positionDict[pos.vtPositionName] = pos - - # 更新到持仓细节中 - # detail = self.getPositionDetail(pos.vtSymbol) - # detail.updatePosition(pos) - - #---------------------------------------------------------------------- - def processAccountEvent(self, event): - """处理账户事件""" - account = event.dict_['data'] - self.accountDict[account.vtAccountID] = account - - #---------------------------------------------------------------------- - def processLogEvent(self, event): - """处理日志事件""" - log = event.dict_['data'] - self.logList.append(log) - - #---------------------------------------------------------------------- - def processErrorEvent(self, event): - """处理错误事件""" - error = event.dict_['data'] - self.errorList.append(error) - - #---------------------------------------------------------------------- - def getTick(self, vtSymbol): - """查询行情对象""" - try: - return self.tickDict[vtSymbol] - except KeyError: - return None - - #---------------------------------------------------------------------- - def getContract(self, vtSymbol): - """查询合约对象""" - try: - return self.contractDict[vtSymbol] - except KeyError: - return None - - #---------------------------------------------------------------------- - def getAllContracts(self): - """查询所有合约对象(返回列表)""" - return self.contractDict.values() - - #---------------------------------------------------------------------- - def saveContracts(self): - """保存所有合约对象到硬盘""" - f = shelve.open(self.contractFilePath) - f['data'] = self.contractDict - f.close() - - #---------------------------------------------------------------------- - def loadContracts(self): - """从硬盘读取合约对象""" - f = shelve.open(self.contractFilePath) - if 'data' in f: - d = f['data'] - for key, value in d.items(): - self.contractDict[key] = value - f.close() - - #---------------------------------------------------------------------- - def getOrder(self, vtOrderID): - """查询委托""" - try: - return self.orderDict[vtOrderID] - except KeyError: - return None - - #---------------------------------------------------------------------- - def getAllWorkingOrders(self): - """查询所有活动委托(返回列表)""" - return self.workingOrderDict.values() - - #---------------------------------------------------------------------- - def getAllOrders(self): - """获取所有委托""" - return self.orderDict.values() - - #---------------------------------------------------------------------- - def getAllTrades(self): - """获取所有成交""" - return self.tradeDict.values() - - #---------------------------------------------------------------------- - def getAllPositions(self): - """获取所有持仓""" - return self.positionDict.values() - - #---------------------------------------------------------------------- - def getAllAccounts(self): - """获取所有资金""" - return self.accountDict.values() - - # #---------------------------------------------------------------------- - # def getPositionDetail(self, vtSymbol): - # """查询持仓细节""" - # if vtSymbol in self.detailDict: - # detail = self.detailDict[vtSymbol] - # else: - # contract = self.getContract(vtSymbol) - # detail = PositionDetail(vtSymbol, contract) - # self.detailDict[vtSymbol] = detail - - # # 设置持仓细节的委托转换模式 - # contract = self.getContract(vtSymbol) - - # if contract: - # detail.exchange = contract.exchange - - # # 上期所合约 - # if contract.exchange == EXCHANGE_SHFE: - # detail.mode = detail.MODE_SHFE - - # # 检查是否有平今惩罚 - # for productID in self.tdPenaltyList: - # if str(productID) in contract.symbol: - # detail.mode = detail.MODE_TDPENALTY - - # return detail - - #---------------------------------------------------------------------- - # def getAllPositionDetails(self): - # """查询所有本地持仓缓存细节""" - # return self.detailDict.values() - - # #---------------------------------------------------------------------- - # def updateOrderReq(self, req, vtOrderID): - # """委托请求更新""" - # vtSymbol = req.vtSymbol - - # detail = self.getPositionDetail(vtSymbol) - # detail.updateOrderReq(req, vtOrderID) - - # #---------------------------------------------------------------------- - # def convertOrderReq(self, req): - # """根据规则转换委托请求""" - # detail = self.detailDict.get(req.vtSymbol, None) - # if not detail: - # return [req] - # else: - # return detail.convertOrderReq(req) - - #---------------------------------------------------------------------- - def getLog(self): - """获取日志""" - return self.logList - - #---------------------------------------------------------------------- - def getError(self): - """获取错误""" - return self.errorList - - -######################################################################## -class LogEngine(object): - """日志引擎""" - format = '%(asctime)s %(levelname)s: %(message)s' - # 日志级别 - LEVEL_DEBUG = logging.DEBUG - LEVEL_INFO = logging.INFO - LEVEL_WARN = logging.WARN - LEVEL_ERROR = logging.ERROR - LEVEL_CRITICAL = logging.CRITICAL - - # 单例对象 - instance = None - - #---------------------------------------------------------------------- - def __new__(cls, *args, **kwargs): - """创建对象,保证单例""" - if not cls.instance: - cls.instance = super(LogEngine, cls).__new__(cls, *args, **kwargs) - return cls.instance - - #---------------------------------------------------------------------- - def __init__(self): - """Constructor""" - self.logger = logging.getLogger() - # TODO: may be we should put vnpy log in an independant logger. - self.logger.handlers = [] - self.formatter = logging.Formatter(self.format) - self.level = self.LEVEL_CRITICAL - self.streamLevel = self.LEVEL_CRITICAL - - self.consoleHandler = None - self.fileHandler = None - - # 添加NullHandler防止无handler的错误输出 - nullHandler = logging.NullHandler() - self.logger.addHandler(nullHandler) - - # 日志级别函数映射 - self.levelFunctionDict = { - self.LEVEL_DEBUG: self.debug, - self.LEVEL_INFO: self.info, - self.LEVEL_WARN: self.warn, - self.LEVEL_ERROR: self.error, - self.LEVEL_CRITICAL: self.critical, - } - - #---------------------------------------------------------------------- - def setLogLevel(self, level): - """设置日志级别""" - self.logger.setLevel(level) - self.level = level - - def setStreamLevel(self, level): - self.streamLevel = level - - #---------------------------------------------------------------------- - def addConsoleHandler(self): - """添加终端输出""" - if not self.consoleHandler: - self.consoleHandler = logging.StreamHandler() - self.consoleHandler.setLevel(self.streamLevel) - self.consoleHandler.setFormatter(self.formatter) - self.logger.addHandler(self.consoleHandler) - - #---------------------------------------------------------------------- - def addFileHandler(self): - """添加文件输出""" - if not self.fileHandler: - filename = 'vt_' + datetime.now().strftime('%Y%m%d') + '.log' - filepath = getTempPath(filename) - # self.fileHandler = logging.FileHandler(filepath) # 引擎原有的handler - # 限制日志文件大小为20M,一天最多 400 MB - self.fileHandler = logging.handlers.RotatingFileHandler( - filepath, maxBytes=20971520, backupCount=20, encoding="utf-8") - self.fileHandler.setLevel(self.level) - self.fileHandler.setFormatter(self.formatter) - self.logger.addHandler(self.fileHandler) - - #---------------------------------------------------------------------- - def debug(self, msg): - """开发时用""" - self.logger.debug(msg) - - #---------------------------------------------------------------------- - def info(self, msg): - """正常输出""" - self.logger.info(msg) - - #---------------------------------------------------------------------- - def warn(self, msg): - """警告信息""" - self.logger.warn(msg) - - #---------------------------------------------------------------------- - def error(self, msg): - """报错输出""" - self.logger.error(msg) - - #---------------------------------------------------------------------- - def exception(self, msg): - """报错输出+记录异常信息""" - self.logger.exception(msg) - - #---------------------------------------------------------------------- - def critical(self, msg): - """影响程序运行的严重错误""" - self.logger.critical(msg) - - #---------------------------------------------------------------------- - def processLogEvent(self, event): - """处理日志事件""" - log = event.dict_['data'] - function = self.levelFunctionDict[log.logLevel] # 获取日志级别对应的处理函数 - msg = '\t'.join([log.gatewayName, log.logContent]) +# encoding: UTF-8 + +import os +import shelve +import logging +from logging import handlers +from collections import OrderedDict +from datetime import datetime +from copy import copy + +# from pymongo import MongoClient, ASCENDING +# from pymongo.errors import ConnectionFailure + +from vnpy.event import Event +from vnpy.trader.vtGlobal import globalSetting +from vnpy.trader.vtEvent import * +from vnpy.trader.vtGateway import * +from vnpy.trader.language import text +from vnpy.trader.vtFunction import getTempPath + + +######################################################################## +class MainEngine(object): + """主引擎""" + + #---------------------------------------------------------------------- + def __init__(self, eventEngine): + """Constructor""" + # 记录今日日期 + self.todayDate = datetime.now().strftime('%Y%m%d') + + # 绑定事件引擎 + self.eventEngine = eventEngine + self.eventEngine.start() + + # 创建数据引擎 + self.dataEngine = DataEngine(self.eventEngine) + + # MongoDB数据库相关 + self.dbClient = None # MongoDB客户端对象 + + # 接口实例 + self.gatewayDict = OrderedDict() + self.gatewayDetailList = [] + + # 应用模块实例 + self.appDict = OrderedDict() + self.appDetailList = [] + + # 风控引擎实例(特殊独立对象) + self.rmEngine = None + + # 日志引擎实例 + self.logEngine = None + self.initLogEngine() + self.eventEngine.register(EVENT_EXIT, self.processExit) + + def processExit(self, event): + self.eventEngine.inactivate() + + #---------------------------------------------------------------------- + def addGateway(self, gatewayModule): + """添加底层接口""" + gatewayName = gatewayModule.gatewayName + gatewayTypeMap = {} + + # 创建接口实例 + if type(gatewayName) == list: + for i in range(len(gatewayName)): + self.gatewayDict[gatewayName[i]] = gatewayModule.gatewayClass( + self.eventEngine, gatewayName[i]) + + # 设置接口轮询 + if gatewayModule.gatewayQryEnabled: + self.gatewayDict[gatewayName[i]].setQryEnabled( + gatewayModule.gatewayQryEnabled) + + # 保存接口详细信息 + d = { + 'gatewayName': gatewayModule.gatewayName[i], + 'gatewayDisplayName': gatewayModule.gatewayDisplayName[i], + 'gatewayType': gatewayModule.gatewayType + } + self.gatewayDetailList.append(d) + else: + self.gatewayDict[gatewayName] = gatewayModule.gatewayClass( + self.eventEngine, gatewayName) + + # 设置接口轮询 + if gatewayModule.gatewayQryEnabled: + self.gatewayDict[gatewayName].setQryEnabled( + gatewayModule.gatewayQryEnabled) + + # 保存接口详细信息 + d = { + 'gatewayName': gatewayModule.gatewayName, + 'gatewayDisplayName': gatewayModule.gatewayDisplayName, + 'gatewayType': gatewayModule.gatewayType + } + self.gatewayDetailList.append(d) + + for i in range(len(self.gatewayDetailList)): + s = self.gatewayDetailList[i]['gatewayName'].split( + '_connect.json')[0] + gatewayTypeMap[s] = self.gatewayDetailList[i]['gatewayType'] + + path = os.getcwd() + # 遍历当前目录下的所有文件 + for root, subdirs, files in os.walk(path): + for name in files: + # 只有文件名中包含_connect.json的文件,才是密钥配置文件 + if '_connect.json' in name: + gw = name.replace('_connect.json', '') + if not gw in gatewayTypeMap.keys(): + for existnames in list(gatewayTypeMap.keys()): + if existnames in gw and existnames != gw: + d = { + 'gatewayName': gw, + 'gatewayDisplayName': gw, + 'gatewayType': gatewayTypeMap[existnames] + } + self.gatewayDetailList.append(d) + self.gatewayDict[ + gw] = gatewayModule.gatewayClass( + self.eventEngine, gw) + + #---------------------------------------------------------------------- + def addApp(self, appModule): + """添加上层应用""" + appName = appModule.appName + + # 创建应用实例 + self.appDict[appName] = appModule.appEngine(self, self.eventEngine) + + # 将应用引擎实例添加到主引擎的属性中 + self.__dict__[appName] = self.appDict[appName] + + # 保存应用信息 + d = { + 'appName': appModule.appName, + 'appDisplayName': appModule.appDisplayName, + 'appWidget': appModule.appWidget, + 'appIco': appModule.appIco + } + self.appDetailList.append(d) + + #---------------------------------------------------------------------- + def getGateway(self, gatewayName): + """获取接口""" + if gatewayName in self.gatewayDict: + return self.gatewayDict[gatewayName] + else: + self.writeLog(text.GATEWAY_NOT_EXIST.format(gateway=gatewayName)) + self.writeLog(gatewayName) + return None + + #---------------------------------------------------------------------- + def connect(self, gatewayName): + """连接特定名称的接口""" + gateway = self.getGateway(gatewayName) + + if gateway: + gateway.connect() + + # 接口连接后自动执行数据库连接的任务 + # self.dbConnect() + + #---------------------------------------------------------------------- + def subscribe(self, subscribeReq, gatewayName): + """订阅特定接口的行情""" + gateway = self.getGateway(gatewayName) + + if gateway: + gateway.subscribe(subscribeReq) + + #---------------------------------------------------------------------- + def sendOrder(self, orderReq, gatewayName): + """对特定接口发单""" + # 如果创建了风控引擎,且风控检查失败则不发单 + if self.rmEngine and not self.rmEngine.checkRisk( + orderReq, gatewayName): + return '' + + gateway = self.getGateway(gatewayName) + if gateway: + vtOrderID = gateway.sendOrder(orderReq) + # self.dataEngine.updateOrderReq(orderReq, vtOrderID) # 更新发出的委托请求到数据引擎中 + return vtOrderID + else: + return '' + + #---------------------------------------------------------------------- + def cancelOrder(self, cancelOrderReq, gatewayName): + """对特定接口撤单""" + gateway = self.getGateway(gatewayName) + + if gateway: + gateway.cancelOrder(cancelOrderReq) + + def batchCancelOrder(self, cancelOrderReqList, gatewayName): + gateway = self.getGateway(gatewayName) + + if gateway: + gateway.batchCancelOrder(cancelOrderReqList) + + #---------------------------------------------------------------------- + def qryAccount(self, gatewayName): + """查询特定接口的账户""" + gateway = self.getGateway(gatewayName) + + if gateway: + gateway.qryAccount() + + #---------------------------------------------------------------------- + def qryPosition(self, gatewayName): + """查询特定接口的持仓""" + gateway = self.getGateway(gatewayName) + + if gateway: + gateway.qryPosition() + + #------------------------------------------------ + def initPosition(self, vtSymbol): + """策略初始化时查询特定接口的持仓""" + contract = self.getContract(vtSymbol) + if contract: + gatewayName = contract.gatewayName + gateway = self.getGateway(gatewayName) + if gateway: + gateway.initPosition(vtSymbol) + else: + self.writeLog( + 'we don\'t have this symbol %s, Please check symbolList in ctaSetting.json' + % vtSymbol) + return None + + def loadHistoryBar(self, vtSymbol, type_, size=None, since=None): + """策略初始化时下载历史数据""" + contract = self.getContract(vtSymbol) + gatewayName = contract.gatewayName + gateway = self.getGateway(gatewayName) + if gateway: + data = gateway.loadHistoryBar(vtSymbol, type_, size, since) + return data + + def qryAllOrders(self, vtSymbol, orderId, status=None): + contract = self.getContract(vtSymbol) + gatewayName = contract.gatewayName + gateway = self.getGateway(gatewayName) + if gateway: + gateway.qryAllOrders(vtSymbol, orderId, status) + + #---------------------------------------------------------------------- + def exit(self): + """退出程序前调用,保证正常退出""" + # 安全关闭所有接口 + for gateway in list(self.gatewayDict.values()): + gateway.close() + + # 停止事件引擎 + self.eventEngine.stop() + + # 停止上层应用引擎 + for appEngine in list(self.appDict.values()): + appEngine.stop() + + # 保存数据引擎里的合约数据到硬盘 + self.dataEngine.saveContracts() + + #---------------------------------------------------------------------- + def writeLog(self, content): + """快速发出日志事件""" + log = VtLogData() + log.logContent = content + log.gatewayName = 'MAIN_ENGINE' + event = Event(type_=EVENT_LOG) + event.dict_['data'] = log + self.eventEngine.put(event) + + #---------------------------------------------------------------------- + def getContract(self, vtSymbol): + """查询合约""" + return self.dataEngine.getContract(vtSymbol) + + #---------------------------------------------------------------------- + def getAllContracts(self): + """查询所有合约(返回列表)""" + return self.dataEngine.getAllContracts() + + #---------------------------------------------------------------------- + def getOrder(self, vtOrderID): + """查询委托""" + return self.dataEngine.getOrder(vtOrderID) + + #---------------------------------------------------------------------- + def getAllWorkingOrders(self): + """查询所有的活跃的委托(返回列表)""" + return self.dataEngine.getAllWorkingOrders() + + #---------------------------------------------------------------------- + def getAllOrders(self): + """查询所有委托""" + return self.dataEngine.getAllOrders() + + #---------------------------------------------------------------------- + def getAllTrades(self): + """查询所有成交""" + return self.dataEngine.getAllTrades() + + #---------------------------------------------------------------------- + def getAllAccounts(self): + """查询所有账户""" + return self.dataEngine.getAllAccounts() + + def getAllPositions(self): + """查询所有持仓""" + return self.dataEngine.getAllPositions() + + #---------------------------------------------------------------------- + def getAllPositionDetails(self): + """查询本地持仓缓存细节""" + return self.dataEngine.getAllPositionDetails() + + #---------------------------------------------------------------------- + def getAllGatewayDetails(self): + """查询引擎中所有底层接口的信息""" + return self.gatewayDetailList + + #---------------------------------------------------------------------- + def getAllAppDetails(self): + """查询引擎中所有上层应用的信息""" + return self.appDetailList + + #---------------------------------------------------------------------- + def getApp(self, appName): + """获取APP引擎对象""" + return self.appDict[appName] + + #---------------------------------------------------------------------- + def initLogEngine(self): + """初始化日志引擎""" + if not globalSetting["logActive"]: + return + + # 创建引擎 + self.logEngine = LogEngine() + + # 设置日志级别 + levelDict = { + "debug": LogEngine.LEVEL_DEBUG, + "info": LogEngine.LEVEL_INFO, + "warn": LogEngine.LEVEL_WARN, + "error": LogEngine.LEVEL_ERROR, + "critical": LogEngine.LEVEL_CRITICAL, + } + level = levelDict.get(globalSetting["logLevel"], + LogEngine.LEVEL_CRITICAL) + self.logEngine.setLogLevel(level) + stream_setting = globalSetting.get("streamLevel", "info") + streamLevel = levelDict.get(stream_setting, LogEngine.LEVEL_CRITICAL) + self.logEngine.setStreamLevel(streamLevel) + + # 设置输出 + if globalSetting['logConsole']: + self.logEngine.addConsoleHandler() + + if globalSetting['logFile']: + self.logEngine.addFileHandler() + + # 注册事件监听 + self.registerLogEvent(EVENT_LOG) + + #---------------------------------------------------------------------- + def registerLogEvent(self, eventType): + """注册日志事件监听""" + if self.logEngine: + self.eventEngine.register(eventType, + self.logEngine.processLogEvent) + + #---------------------------------------------------------------------- + # def convertOrderReq(self, req): + # """转换委托请求""" + # return self.dataEngine.convertOrderReq(req) + + #---------------------------------------------------------------------- + def getLog(self): + """查询日志""" + return self.dataEngine.getLog() + + #---------------------------------------------------------------------- + def getError(self): + """查询错误""" + return self.dataEngine.getError() + + +######################################################################## + + +class DataEngine(object): + """数据引擎""" + contractFileName = 'ContractData.vt' + contractFilePath = getTempPath(contractFileName) + + FINISHED_STATUS = [STATUS_ALLTRADED, STATUS_REJECTED, STATUS_CANCELLED] + + #---------------------------------------------------------------------- + def __init__(self, eventEngine): + """Constructor""" + self.eventEngine = eventEngine + + # 保存数据的字典和列表 + self.tickDict = {} + self.contractDict = {} + self.orderDict = {} + self.workingOrderDict = {} # 可撤销委托 + self.tradeDict = {} + self.accountDict = {} + self.positionDict = {} + self.logList = [] + self.errorList = [] + + # 持仓细节相关 + # self.detailDict = {} # vtSymbol:PositionDetail + self.tdPenaltyList = globalSetting['tdPenalty'] # 平今手续费惩罚的产品代码列表 + + # 读取保存在硬盘的合约数据 + self.loadContracts() + + # 注册事件监听 + self.registerEvent() + + #---------------------------------------------------------------------- + def registerEvent(self): + """注册事件监听""" + self.eventEngine.register(EVENT_TICK, self.processTickEvent) + self.eventEngine.register(EVENT_CONTRACT, self.processContractEvent) + self.eventEngine.register(EVENT_ORDER, self.processOrderEvent) + self.eventEngine.register(EVENT_TRADE, self.processTradeEvent) + self.eventEngine.register(EVENT_POSITION, self.processPositionEvent) + self.eventEngine.register(EVENT_ACCOUNT, self.processAccountEvent) + self.eventEngine.register(EVENT_LOG, self.processLogEvent) + self.eventEngine.register(EVENT_ERROR, self.processErrorEvent) + + #---------------------------------------------------------------------- + def processTickEvent(self, event): + """处理成交事件""" + tick = event.dict_['data'] + self.tickDict[tick.vtSymbol] = tick + + #---------------------------------------------------------------------- + def processContractEvent(self, event): + """处理合约事件""" + contract = event.dict_['data'] + self.contractDict[contract.vtSymbol] = contract + self.contractDict[contract.symbol] = contract # 使用常规代码(不包括交易所)可能导致重复 + + #---------------------------------------------------------------------- + def processOrderEvent(self, event): + """处理委托事件""" + order = event.dict_['data'] + self.orderDict[order.vtOrderID] = order + + # 如果订单的状态是全部成交或者撤销,则需要从workingOrderDict中移除 + if order.status in self.FINISHED_STATUS: + if order.vtOrderID in self.workingOrderDict: + del self.workingOrderDict[order.vtOrderID] + # 否则则更新字典中的数据 + else: + self.workingOrderDict[order.vtOrderID] = order + + # 更新到持仓细节中 + # detail = self.getPositionDetail(order.vtSymbol) + # detail.updateOrder(order) + + #---------------------------------------------------------------------- + def processTradeEvent(self, event): + """处理成交事件""" + trade = event.dict_['data'] + + self.tradeDict[trade.vtTradeID] = trade + + # 更新到持仓细节中 + # detail = self.getPositionDetail(trade.vtSymbol) + # detail.updateTrade(trade) + + #---------------------------------------------------------------------- + def processPositionEvent(self, event): + """处理持仓事件""" + pos = event.dict_['data'] + + self.positionDict[pos.vtPositionName] = pos + + # 更新到持仓细节中 + # detail = self.getPositionDetail(pos.vtSymbol) + # detail.updatePosition(pos) + + #---------------------------------------------------------------------- + def processAccountEvent(self, event): + """处理账户事件""" + account = event.dict_['data'] + self.accountDict[account.vtAccountID] = account + + #---------------------------------------------------------------------- + def processLogEvent(self, event): + """处理日志事件""" + log = event.dict_['data'] + self.logList.append(log) + + #---------------------------------------------------------------------- + def processErrorEvent(self, event): + """处理错误事件""" + error = event.dict_['data'] + self.errorList.append(error) + + #---------------------------------------------------------------------- + def getTick(self, vtSymbol): + """查询行情对象""" + try: + return self.tickDict[vtSymbol] + except KeyError: + return None + + #---------------------------------------------------------------------- + def getContract(self, vtSymbol): + """查询合约对象""" + try: + return self.contractDict[vtSymbol] + except KeyError: + return None + + #---------------------------------------------------------------------- + def getAllContracts(self): + """查询所有合约对象(返回列表)""" + return self.contractDict.values() + + #---------------------------------------------------------------------- + def saveContracts(self): + """保存所有合约对象到硬盘""" + f = shelve.open(self.contractFilePath) + f['data'] = self.contractDict + f.close() + + #---------------------------------------------------------------------- + def loadContracts(self): + """从硬盘读取合约对象""" + f = shelve.open(self.contractFilePath) + if 'data' in f: + d = f['data'] + for key, value in d.items(): + self.contractDict[key] = value + f.close() + + #---------------------------------------------------------------------- + def getOrder(self, vtOrderID): + """查询委托""" + try: + return self.orderDict[vtOrderID] + except KeyError: + return None + + #---------------------------------------------------------------------- + def getAllWorkingOrders(self): + """查询所有活动委托(返回列表)""" + return self.workingOrderDict.values() + + #---------------------------------------------------------------------- + def getAllOrders(self): + """获取所有委托""" + return self.orderDict.values() + + #---------------------------------------------------------------------- + def getAllTrades(self): + """获取所有成交""" + return self.tradeDict.values() + + #---------------------------------------------------------------------- + def getAllPositions(self): + """获取所有持仓""" + return self.positionDict.values() + + #---------------------------------------------------------------------- + def getAllAccounts(self): + """获取所有资金""" + return self.accountDict.values() + + # #---------------------------------------------------------------------- + # def getPositionDetail(self, vtSymbol): + # """查询持仓细节""" + # if vtSymbol in self.detailDict: + # detail = self.detailDict[vtSymbol] + # else: + # contract = self.getContract(vtSymbol) + # detail = PositionDetail(vtSymbol, contract) + # self.detailDict[vtSymbol] = detail + + # # 设置持仓细节的委托转换模式 + # contract = self.getContract(vtSymbol) + + # if contract: + # detail.exchange = contract.exchange + + # # 上期所合约 + # if contract.exchange == EXCHANGE_SHFE: + # detail.mode = detail.MODE_SHFE + + # # 检查是否有平今惩罚 + # for productID in self.tdPenaltyList: + # if str(productID) in contract.symbol: + # detail.mode = detail.MODE_TDPENALTY + + # return detail + + #---------------------------------------------------------------------- + # def getAllPositionDetails(self): + # """查询所有本地持仓缓存细节""" + # return self.detailDict.values() + + # #---------------------------------------------------------------------- + # def updateOrderReq(self, req, vtOrderID): + # """委托请求更新""" + # vtSymbol = req.vtSymbol + + # detail = self.getPositionDetail(vtSymbol) + # detail.updateOrderReq(req, vtOrderID) + + # #---------------------------------------------------------------------- + # def convertOrderReq(self, req): + # """根据规则转换委托请求""" + # detail = self.detailDict.get(req.vtSymbol, None) + # if not detail: + # return [req] + # else: + # return detail.convertOrderReq(req) + + #---------------------------------------------------------------------- + def getLog(self): + """获取日志""" + return self.logList + + #---------------------------------------------------------------------- + def getError(self): + """获取错误""" + return self.errorList + + +######################################################################## +class LogEngine(object): + """日志引擎""" + format = '%(asctime)s %(levelname)s: %(message)s' + # 日志级别 + LEVEL_DEBUG = logging.DEBUG + LEVEL_INFO = logging.INFO + LEVEL_WARN = logging.WARN + LEVEL_ERROR = logging.ERROR + LEVEL_CRITICAL = logging.CRITICAL + + # 单例对象 + instance = None + + #---------------------------------------------------------------------- + def __new__(cls, *args, **kwargs): + """创建对象,保证单例""" + if not cls.instance: + cls.instance = super(LogEngine, cls).__new__(cls, *args, **kwargs) + return cls.instance + + #---------------------------------------------------------------------- + def __init__(self): + """Constructor""" + self.logger = logging.getLogger() + # TODO: may be we should put vnpy log in an independant logger. + self.logger.handlers = [] + self.formatter = logging.Formatter(self.format) + self.level = self.LEVEL_CRITICAL + self.streamLevel = self.LEVEL_CRITICAL + + self.consoleHandler = None + self.fileHandler = None + + # 添加NullHandler防止无handler的错误输出 + nullHandler = logging.NullHandler() + self.logger.addHandler(nullHandler) + + # 日志级别函数映射 + self.levelFunctionDict = { + self.LEVEL_DEBUG: self.debug, + self.LEVEL_INFO: self.info, + self.LEVEL_WARN: self.warn, + self.LEVEL_ERROR: self.error, + self.LEVEL_CRITICAL: self.critical, + } + + #---------------------------------------------------------------------- + def setLogLevel(self, level): + """设置日志级别""" + self.logger.setLevel(level) + self.level = level + + def setStreamLevel(self, level): + self.streamLevel = level + + #---------------------------------------------------------------------- + def addConsoleHandler(self): + """添加终端输出""" + if not self.consoleHandler: + self.consoleHandler = logging.StreamHandler() + self.consoleHandler.setLevel(self.streamLevel) + self.consoleHandler.setFormatter(self.formatter) + self.logger.addHandler(self.consoleHandler) + + #---------------------------------------------------------------------- + def addFileHandler(self): + """添加文件输出""" + if not self.fileHandler: + filename = 'vt_' + datetime.now().strftime('%Y%m%d') + '.log' + filepath = getTempPath(filename) + # self.fileHandler = logging.FileHandler(filepath) # 引擎原有的handler + # 限制日志文件大小为20M,一天最多 400 MB + self.fileHandler = logging.handlers.RotatingFileHandler( + filepath, maxBytes=20971520, backupCount=20, encoding="utf-8") + self.fileHandler.setLevel(self.level) + self.fileHandler.setFormatter(self.formatter) + self.logger.addHandler(self.fileHandler) + + #---------------------------------------------------------------------- + def debug(self, msg): + """开发时用""" + self.logger.debug(msg) + + #---------------------------------------------------------------------- + def info(self, msg): + """正常输出""" + self.logger.info(msg) + + #---------------------------------------------------------------------- + def warn(self, msg): + """警告信息""" + self.logger.warn(msg) + + #---------------------------------------------------------------------- + def error(self, msg): + """报错输出""" + self.logger.error(msg) + + #---------------------------------------------------------------------- + def exception(self, msg): + """报错输出+记录异常信息""" + self.logger.exception(msg) + + #---------------------------------------------------------------------- + def critical(self, msg): + """影响程序运行的严重错误""" + self.logger.critical(msg) + + #---------------------------------------------------------------------- + def processLogEvent(self, event): + """处理日志事件""" + log = event.dict_['data'] + function = self.levelFunctionDict[log.logLevel] # 获取日志级别对应的处理函数 + msg = '\t'.join([log.gatewayName, log.logContent]) function(msg) \ No newline at end of file diff --git a/vnpy/trader/vtEvent.py b/vnpy/trader/vtEvent.py index 4c5c67e..77e7fec 100644 --- a/vnpy/trader/vtEvent.py +++ b/vnpy/trader/vtEvent.py @@ -1,20 +1,21 @@ -# encoding: UTF-8 - -''' -本文件基于vnpy.event.eventType,并添加更多字段 -''' - -from vnpy.event.eventType import * - -# 系统相关 -EVENT_TIMER = 'eTimer' # 计时器事件,每隔1秒发送一次 -EVENT_LOG = 'eLog' # 日志事件,全局通用 - -# Gateway相关 -EVENT_TICK = 'eTick.' # TICK行情事件,可后接具体的vtSymbol -EVENT_TRADE = 'eTrade.' # 成交回报事件 -EVENT_ORDER = 'eOrder.' # 报单回报事件 -EVENT_POSITION = 'ePosition.' # 持仓回报事件 -EVENT_ACCOUNT = 'eAccount.' # 账户回报事件 -EVENT_CONTRACT = 'eContract.' # 合约基础信息回报事件 -EVENT_ERROR = 'eError.' # 错误回报事件 \ No newline at end of file +# encoding: UTF-8 + +''' +本文件基于vnpy.event.eventType,并添加更多字段 +''' + +from vnpy.event.eventType import * + +# 系统相关 +EVENT_TIMER = 'eTimer' # 计时器事件,每隔1秒发送一次 +EVENT_LOG = 'eLog' # 日志事件,全局通用 + +# Gateway相关 +EVENT_TICK = 'eTick.' # TICK行情事件,可后接具体的vtSymbol +EVENT_TRADE = 'eTrade.' # 成交回报事件 +EVENT_ORDER = 'eOrder.' # 报单回报事件 +EVENT_POSITION = 'ePosition.' # 持仓回报事件 +EVENT_ACCOUNT = 'eAccount.' # 账户回报事件 +EVENT_CONTRACT = 'eContract.' # 合约基础信息回报事件 +EVENT_ERROR = 'eError.' # 错误回报事件 +EVENT_EXIT = "eExit." \ No newline at end of file From d89980216aa191c278447e7eac5e4317b2cc2bfe Mon Sep 17 00:00:00 2001 From: caimeng <862786917@qq.com> Date: Thu, 11 Jul 2019 18:32:10 +0800 Subject: [PATCH 3/3] gateway unreach --- vnpy/trader/gateway/okexGateway/future.py | 60 +++++++++++++++++------ 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/vnpy/trader/gateway/okexGateway/future.py b/vnpy/trader/gateway/okexGateway/future.py index 807197b..ef6144c 100644 --- a/vnpy/trader/gateway/okexGateway/future.py +++ b/vnpy/trader/gateway/okexGateway/future.py @@ -67,6 +67,8 @@ def __init__(self, gateway): } self.order_queue = queue.Queue() self.orderThread = threading.Thread(target=self.getQueue) + self._health = False + self._firstFailTime = 0 def runOrderThread(self): if not self.orderThread.is_alive(): @@ -78,6 +80,7 @@ def runOrderThread(self): #---------------------------------------------------------------------- def connect(self, REST_HOST, leverage, sessionCount): """连接服务器""" + self._health = True self.leverage = leverage self.init(REST_HOST) @@ -176,12 +179,17 @@ def sendOrder(self, orderReq, orderID):# type: (VtOrderReq)->str self.orderDict[orderID] = order self.unfinished_orders[orderID] = order - self.addRequest('POST', '/api/futures/v3/order', - callback=self.onSendOrder, - data=data, - extra=order, - onFailed=self.onSendOrderFailed, - onError=self.onSendOrderError) + if self._health: + self.addRequest('POST', '/api/futures/v3/order', + callback=self.onSendOrder, + data=data, + extra=order, + onFailed=self.onSendOrderFailed, + onError=self.onSendOrderError) + else: + order.status = constant.STATUS_REJECTED + order.rejectedInfo = "Gateway unreachable" + self.gateway.onOrder(order) return vtOrderID @@ -198,8 +206,10 @@ def cancelOrder(self, cancelOrderReq): #---------------------------------------------------------------------- def queryContract(self): """限速规则:20次/2s""" + print("query /api/futures/v3/instruments") self.addRequest('GET', '/api/futures/v3/instruments', callback=self.onQueryContract) + #---------------------------------------------------------------------- def queryMonoAccount(self, symbolList): @@ -613,6 +623,7 @@ def onQueryMonoOrder(self, d, request): "filled_qty":"0","fee":"0","order_id":"2522410732495872","price":"55","price_avg":"0","status":"0", "type":"1","contract_val":"10","leverage":"20","client_oid":"BarFUTU19032211220110001","pnl":"0", "order_type":"0"}""" + self.unLockRequests() if d: # self.order_queue.put(d) self.putOrderQueue(d, self.ORDER_INSTANCE) @@ -625,17 +636,36 @@ def onQueryOrder(self, d, request): 'filled_qty': '0', 'fee': '0', 'order_id': '2398983698358272', 'price': '50', 'price_avg': '0', 'status': '0', 'type': '1', 'contract_val': '10', 'leverage': '20', 'client_oid': '', 'pnl': '0', 'order_type': '0'}]} """ + self.unLockRequests() for data in d['order_info']: self.putOrderQueue(data, self.ORDER_INSTANCE) def onQueryMonoOrderFailed(self, data, request): - self.putOrderQueue({ - "client_oid": request.extra, - "message": "Order not exists" - }, self.ORDER_REJECT) - oid = request.extra - self.gateway.writeLog(f'Query order failed: {oid} | result: {data}', logging.ERROR) - + if request.response.status_code == 404: + self.putOrderQueue({ + "client_oid": request.extra, + "message": "Order not exists" + }, self.ORDER_REJECT) + oid = request.extra + self.gateway.writeLog(f'Query order failed: {oid} | result: {data}', logging.ERROR) + self.unLockRequests() + else: + self.lockRequests(data, request) + + def lockRequests(self, data, request): + if self._health: + self._health = False + self._firstFailTime = datetime.now().timestamp() + logging.error( + f"Bad response received and lock sending orders: status_code={request.response.status_code}, data={data}" + ) + else: + if datetime.now().timestamp() - self._firstFailTime > 300: + self.gateway.sendExit() + + def unLockRequests(self): + self._health = True + #---------------------------------------------------------------------- def onSendOrderFailed(self, data, request): """ @@ -663,7 +693,7 @@ def onSendOrderError(self, exceptionType, exceptionValue, tb, request): def onSendOrder(self, data, request): """{'result': True, 'error_message': '', 'error_code': 0, 'client_oid': '181129173533', 'order_id': '1878377147147264'}""" - + self.unLockRequests() # success if data.get("result", False): self.putOrderQueue( @@ -722,6 +752,8 @@ def onError(self, exceptionType, exceptionValue, tb, request): """ Python内部错误处理:默认行为是仍给excepthook """ + if isinstance(exceptionValue, ConnectionError): + self.lockRequests({"message": "ConnectionError"}, request) e = VtErrorData() e.gatewayName = self.gatewayName e.errorID = exceptionType