Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate NeuTrader to Qlib RL #1169

Merged
merged 46 commits into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
b184cc4
Refine previous version RL codes
lihuoran Jun 16, 2022
92d4ec4
Polish utils/__init__.py
lihuoran Jun 16, 2022
7535d60
Draft
lihuoran Jun 20, 2022
15340ff
Merge branch 'main' into huoran/qlib_rl
lihuoran Jun 24, 2022
e23504c
Use | instead of Union
lihuoran Jun 24, 2022
9348401
Simulator & action interpreter
lihuoran Jun 27, 2022
a2f7383
Test passed
lihuoran Jun 27, 2022
47252a4
Merge branch 'main' into huoran/qlib_rl
lihuoran Jun 28, 2022
d8858ba
Migrate to SAOEState & new qlib interpreter
lihuoran Jul 8, 2022
09f5106
Black format
lihuoran Jul 8, 2022
11ee76e
. Revert file_storage change
lihuoran Jul 14, 2022
3294e4d
Refactor file structure & renaming functions
lihuoran Jul 14, 2022
a44fbf5
Enrich test cases
lihuoran Jul 15, 2022
aeb54cb
Add QlibIntradayBacktestData
lihuoran Jul 15, 2022
5ff6407
Test interpreter
lihuoran Jul 15, 2022
7d46689
Black format
lihuoran Jul 19, 2022
3ab9df2
.
lihuoran Jul 21, 2022
fae0f77
Merge branch 'main' into huoran/qlib_rl
lihuoran Jul 21, 2022
036e593
Rename receive_execute_result()
lihuoran Jul 21, 2022
53dde51
Use indicator to simplify state update
lihuoran Jul 22, 2022
00def78
Format code
lihuoran Jul 22, 2022
0536672
Modify data path
lihuoran Jul 25, 2022
77966c2
Adjust file structure
lihuoran Jul 26, 2022
85a2cb3
Minor change
lihuoran Jul 26, 2022
a573768
Merge branch 'main' into huoran/qlib_rl
lihuoran Jul 26, 2022
ecb385a
Add copyright message
lihuoran Jul 26, 2022
80b2006
Format code
lihuoran Jul 26, 2022
e864bba
Rename util functions
lihuoran Jul 26, 2022
bad1ae5
Add CI
lihuoran Jul 26, 2022
0caa9a4
Pylint issue
lihuoran Jul 26, 2022
83d8f00
Remove useless code to pass pylint
lihuoran Jul 26, 2022
ccc3f96
Pass mypy
lihuoran Jul 26, 2022
f269274
Mypy issue
lihuoran Jul 26, 2022
e453290
mypy issue
lihuoran Jul 26, 2022
8eb1b01
mypy issue
lihuoran Jul 27, 2022
e2a72b6
Revert "mypy issue"
lihuoran Jul 27, 2022
59e0b80
mypy issue
lihuoran Jul 27, 2022
2fcadfe
mypy issue
lihuoran Jul 27, 2022
54231b1
Fix the numpy version incompatible bug
you-n-g Jul 27, 2022
cbb767e
Fix a minor typing issue
lihuoran Jul 27, 2022
87ef47f
Try to skip python 3.7 test for qlib simulator
lihuoran Jul 27, 2022
c495798
Resolve PR comments by Yuge; solve several CI issues.
lihuoran Jul 27, 2022
362c3ab
Black issue
lihuoran Jul 27, 2022
eb8593b
Fix a low-level type error
lihuoran Jul 28, 2022
8ae62fc
Change data name
lihuoran Jul 28, 2022
a6aa367
Resolve PR comments. Leave TODOs in the code base.
lihuoran Jul 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions qlib/backtest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_exchange(
close_cost: float = 0.0025,
min_cost: float = 5.0,
limit_threshold: Union[Tuple[str, str], float, None] = None,
deal_price: Union[str, Tuple[str], List[str]] = None,
deal_price: Union[str, Tuple[str, str], List[str]] = None,
**kwargs: Any,
) -> Exchange:
"""get_exchange
Expand Down Expand Up @@ -70,10 +70,10 @@ def get_exchange(
min_cost : float
min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.
e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount.
deal_price: Union[str, Tuple[str], List[str]]
deal_price: Union[str, Tuple[str, str], List[str]]
The `deal_price` supports following two types of input
- <deal_price> : str
- (<buy_price>, <sell_price>): Tuple[str] or List[str]
- (<buy_price>, <sell_price>): Tuple[str, str] or List[str]

<deal_price>, <buy_price> or <sell_price> := <price>
<price> := str
Expand Down
25 changes: 15 additions & 10 deletions qlib/backtest/decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from __future__ import annotations

from abc import abstractmethod
from datetime import time
from enum import IntEnum

# try to fix circular imports when enabling type hints
from typing import Generic, List, TYPE_CHECKING, Any, ClassVar, Optional, Tuple, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Generic, List, Optional, Tuple, TypeVar, Union, cast

from qlib.backtest.utils import TradeCalendarManager
from qlib.data.data import Cal
Expand All @@ -23,7 +24,6 @@
import numpy as np
import pandas as pd


DecisionType = TypeVar("DecisionType")


Expand Down Expand Up @@ -182,8 +182,8 @@ def create(
return Order(
stock_id=code,
amount=amount,
start_time=start_time if start_time is not None else pd.Timestamp(start_time),
end_time=end_time if end_time is not None else pd.Timestamp(end_time),
start_time=None if start_time is None else pd.Timestamp(start_time),
end_time=None if end_time is None else pd.Timestamp(end_time),
direction=direction,
)

Expand Down Expand Up @@ -249,7 +249,7 @@ def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> T
class TradeRangeByTime(TradeRange):
"""This is a helper function for make decisions"""

def __init__(self, start_time: str, end_time: str) -> None:
def __init__(self, start_time: str | time, end_time: str | time) -> None:
"""
This is a callable class.

Expand All @@ -259,13 +259,13 @@ def __init__(self, start_time: str, end_time: str) -> None:

Parameters
----------
start_time : str
start_time : str | time
e.g. "9:30"
end_time : str
end_time : str | time
e.g. "14:30"
"""
self.start_time = pd.Timestamp(start_time).time()
self.end_time = pd.Timestamp(end_time).time()
self.start_time = pd.Timestamp(start_time).time() if isinstance(start_time, str) else start_time
self.end_time = pd.Timestamp(end_time).time() if isinstance(end_time, str) else end_time
assert self.start_time < self.end_time

def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
Expand Down Expand Up @@ -535,7 +535,12 @@ class TradeDecisionWO(BaseTradeDecision[Order]):
Besides, the time_range is also included.
"""

def __init__(self, order_list: List[object], strategy: BaseStrategy, trade_range: Tuple[int, int] = None) -> None:
def __init__(
self,
order_list: List[Order],
strategy: BaseStrategy,
trade_range: Union[Tuple[int, int], TradeRange] = None,
) -> None:
super().__init__(strategy, trade_range=trade_range)
self.order_list = cast(List[Order], order_list)
start, end = strategy.trade_calendar.get_step_time()
Expand Down
17 changes: 10 additions & 7 deletions qlib/backtest/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
start_time: Union[pd.Timestamp, str] = None,
end_time: Union[pd.Timestamp, str] = None,
codes: Union[list, str] = "all",
deal_price: Union[str, Tuple[str], List[str]] = None,
deal_price: Union[str, Tuple[str, str], List[str]] = None,
subscribe_fields: list = [],
limit_threshold: Union[Tuple[str, str], float, None] = None,
volume_threshold: Union[tuple, dict] = None,
Expand Down Expand Up @@ -448,9 +448,9 @@ def get_volume(
start_time: pd.Timestamp,
end_time: pd.Timestamp,
method: Optional[str] = "sum",
) -> float:
) -> Union[None, int, float, bool, IndexData]:
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
return cast(float, self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method))
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)

def get_deal_price(
self,
Expand All @@ -459,7 +459,7 @@ def get_deal_price(
end_time: pd.Timestamp,
direction: OrderDir,
method: Optional[str] = "ts_data_last",
) -> float:
) -> Union[None, int, float, bool, IndexData]:
if direction == OrderDir.SELL:
pstr = self.sell_price
elif direction == OrderDir.BUY:
Expand All @@ -472,7 +472,7 @@ def get_deal_price(
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
self.logger.warning(f"setting deal_price to close price")
deal_price = self.get_close(stock_id, start_time, end_time, method)
return cast(float, deal_price)
return deal_price

def get_factor(
self,
Expand Down Expand Up @@ -832,8 +832,11 @@ def _calc_trade_info_by_order(
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
:return: trade_price, trade_val, trade_cost
"""
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price
trade_price = cast(
float,
self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction),
)
total_trade_val = cast(float, self.get_volume(order.stock_id, order.start_time, order.end_time)) * trade_price
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
order.deal_amount = order.amount # set to full amount and clip it step by step
# Clipping amount first
Expand Down
1 change: 1 addition & 0 deletions qlib/backtest/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ def post_inner_exe_step(self, inner_exe_res: List[object]) -> None:
inner_exe_res :
the execution result of inner task
"""
self.inner_strategy.post_exe_step(inner_exe_res)

def get_all_executors(self) -> List[BaseExecutor]:
"""get all executors, including self and inner_executor.get_all_executors()"""
Expand Down
10 changes: 5 additions & 5 deletions qlib/contrib/data/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def use(x):
fields += ["Rsquare($close, %d)" % d for d in windows]
names += ["RSQR%d" % d for d in windows]
if use("RESI"):
# The redisdual for linear regression for the past d days, represent the trend linearity for past d days.
# The redisdual for linear regression for the past d days, represent the trend linearity for past d days.
fields += ["Resi($close, %d)/$close" % d for d in windows]
names += ["RESI%d" % d for d in windows]
if use("MAX"):
Expand All @@ -297,15 +297,15 @@ def use(x):
names += ["MIN%d" % d for d in windows]
if use("QTLU"):
# The 80% quantile of past d day's close price, divided by latest close price to remove unit
# Used with MIN and MAX
# Used with MIN and MAX
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
names += ["QTLU%d" % d for d in windows]
if use("QTLD"):
# The 20% quantile of past d day's close price, divided by latest close price to remove unit
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
names += ["QTLD%d" % d for d in windows]
if use("RANK"):
# Get the percentile of current close price in past d day's close price.
# Get the percentile of current close price in past d day's close price.
# Represent the current price level comparing to past N days, add additional information to moving average.
fields += ["Rank($close, %d)" % d for d in windows]
names += ["RANK%d" % d for d in windows]
Expand All @@ -316,14 +316,14 @@ def use(x):
if use("IMAX"):
# The number of days between current date and previous highest price date.
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
# The indicator measures the time between highs and the time between lows over a time period.
# The indicator measures the time between highs and the time between lows over a time period.
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
names += ["IMAX%d" % d for d in windows]
if use("IMIN"):
# The number of days between current date and previous lowest price date.
# Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp
# The indicator measures the time between highs and the time between lows over a time period.
# The indicator measures the time between highs and the time between lows over a time period.
# The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows.
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
names += ["IMIN%d" % d for d in windows]
Expand Down
17 changes: 14 additions & 3 deletions qlib/data/storage/file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,22 @@ def _freq_file(self) -> str:
self._freq_file_cache = freq
return self._freq_file_cache

def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
def _read_calendar(self) -> List[CalVT]:
# NOTE:
# if we want to accelerate partial reading calendar
# we can add parameters like `skip_rows: int = 0, n_rows: int = None` to the interface.
# Currently, it is not supported for the txt-based calendar

if not self.uri.exists():
self._write_calendar(values=[])
with self.uri.open("rb") as fp:
return [str(x) for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, encoding="utf-8")]

with self.uri.open("r") as fp:
res = []
for line in fp.readlines():
line = line.strip()
if len(line) > 0:
res.append(line)
return res

def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
with self.uri.open(mode=mode) as fp:
Expand Down
4 changes: 2 additions & 2 deletions qlib/rl/aux_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

from typing import Generic, TYPE_CHECKING, TypeVar
from typing import Optional, TYPE_CHECKING, Generic, TypeVar

from qlib.typehint import final

Expand All @@ -21,7 +21,7 @@
class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]):
"""Override this class to collect customized auxiliary information from environment."""

env: EnvWrapper | None = None
env: Optional[EnvWrapper] = None
lihuoran marked this conversation as resolved.
Show resolved Hide resolved

@final
def __call__(self, simulator_state: StateType) -> AuxInfoType:
Expand Down
58 changes: 58 additions & 0 deletions qlib/rl/data/exchange_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from typing import cast
lihuoran marked this conversation as resolved.
Show resolved Hide resolved

import pandas as pd

from qlib.backtest import Exchange, Order
from .pickle_styled import IntradayBacktestData


class QlibIntradayBacktestData(IntradayBacktestData):
"""Backtest data for Qlib simulator"""

def __init__(self, order: Order, exchange: Exchange, start_time: pd.Timestamp, end_time: pd.Timestamp) -> None:
super(QlibIntradayBacktestData, self).__init__()
self._order = order
self._exchange = exchange
self._start_time = start_time
self._end_time = end_time

self._deal_price = cast(
pd.Series,
self._exchange.get_deal_price(
self._order.stock_id,
self._start_time,
self._end_time,
direction=self._order.direction,
method=None,
),
)
self._volume = cast(
pd.Series,
self._exchange.get_volume(
self._order.stock_id,
self._start_time,
self._end_time,
method=None,
),
)

def __repr__(self) -> str:
return (
f"Order: {self._order}, Exchange: {self._exchange}, "
f"Start time: {self._start_time}, End time: {self._end_time}"
)

def __len__(self) -> int:
return len(self._deal_price)

def get_deal_price(self) -> pd.Series:
return self._deal_price

def get_volume(self) -> pd.Series:
return self._volume

def get_time_index(self) -> pd.DatetimeIndex:
return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])
Loading