diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index d3f4d72402..e8fe73c5a2 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -345,4 +345,4 @@ def format_decisions( return res -__all__ = ["Order", "backtest"] +__all__ = ["Order", "backtest", "get_strategy_executor"] diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py index e476550691..f79622bff6 100644 --- a/qlib/backtest/backtest.py +++ b/qlib/backtest/backtest.py @@ -83,7 +83,9 @@ def collect_data_loop( while not trade_executor.finished(): _trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result) _execute_result = yield from trade_executor.collect_data(_trade_decision, level=0) + trade_strategy.post_exe_step(_execute_result) bar.update(1) + trade_strategy.post_upper_level_exe_step() if return_value is not None: all_executors = trade_executor.get_all_executors() diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index 4828478c7e..042b73fea8 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -135,6 +135,21 @@ def parse_dir(direction: Union[str, int, np.integer, OrderDir, np.ndarray]) -> U else: raise NotImplementedError(f"This type of input is not supported") + @property + def key_by_day(self) -> tuple: + """A hashable & unique key to identify this order, under the granularity in day.""" + return self.stock_id, self.date, self.direction + + @property + def key(self) -> tuple: + """A hashable & unique key to identify this order.""" + return self.stock_id, self.start_time, self.end_time, self.direction + + @property + def date(self) -> pd.Timestamp: + """Date of the order.""" + return pd.Timestamp(self.start_time.replace(hour=0, minute=0, second=0)) + class OrderHelper: """ diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index 13af7aea71..664f33a3cd 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -114,7 +114,7 @@ def __init__( self.track_data = track_data self._trade_exchange = trade_exchange self.level_infra = LevelInfrastructure() - self.level_infra.reset_infra(common_infra=common_infra) + self.level_infra.reset_infra(common_infra=common_infra, executor=self) self._settle_type = settle_type self.reset(start_time=start_time, end_time=end_time, common_infra=common_infra) if common_infra is None: @@ -134,6 +134,8 @@ def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_acco else: self.common_infra.update(common_infra) + self.level_infra.reset_infra(common_infra=self.common_infra) + if common_infra.has("trade_account"): # NOTE: there is a trick in the code. # shallow copy is used instead of deepcopy. @@ -256,6 +258,7 @@ def collect_data( object trade decision """ + if self.track_data: yield trade_decision @@ -296,6 +299,7 @@ def collect_data( if return_value is not None: return_value.update({"execute_result": res}) + return res def get_all_executors(self) -> List[BaseExecutor]: @@ -396,7 +400,7 @@ def _update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTrade trade_decision = updated_trade_decision # NEW UPDATE # create a hook for inner strategy to update outer decision - self.inner_strategy.alter_outer_trade_decision(trade_decision) + trade_decision = self.inner_strategy.alter_outer_trade_decision(trade_decision) return trade_decision def _collect_data( @@ -473,6 +477,9 @@ def _collect_data( # do nothing and just step forward sub_cal.step() + # Let inner strategy know that the outer level execution is done. + self.inner_strategy.post_upper_level_exe_step() + return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list} def post_inner_exe_step(self, inner_exe_res: List[object]) -> None: diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index db35dc4820..f815d10554 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -3,9 +3,8 @@ from __future__ import annotations -import bisect from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Set, Tuple, Union +from typing import Any, Set, Tuple, TYPE_CHECKING, Union import numpy as np @@ -184,8 +183,8 @@ def get_range_idx(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tup Tuple[int, int]: the index of the range. **the left and right are closed** """ - left = bisect.bisect_right(list(self._calendar), start_time) - 1 - right = bisect.bisect_right(list(self._calendar), end_time) - 1 + left = np.searchsorted(self._calendar, start_time, side="right") - 1 + right = np.searchsorted(self._calendar, end_time, side="right") - 1 left -= self.start_index right -= self.start_index @@ -248,7 +247,7 @@ def get_support_infra(self) -> Set[str]: sub_level_infra: - **NOTE**: this will only work after _init_sub_trading !!! """ - return {"trade_calendar", "sub_level_infra", "common_infra"} + return {"trade_calendar", "sub_level_infra", "common_infra", "executor"} def reset_cal( self, diff --git a/qlib/constant.py b/qlib/constant.py index 458890957d..ac6c76ae22 100644 --- a/qlib/constant.py +++ b/qlib/constant.py @@ -2,6 +2,11 @@ # Licensed under the MIT License. # REGION CONST +from typing import TypeVar + +import numpy as np +import pandas as pd + REG_CN = "cn" REG_US = "us" REG_TW = "tw" @@ -10,4 +15,8 @@ EPS = 1e-12 # Infinity in integer -INF = 10**18 +INF = int(1e18) +ONE_DAY = pd.Timedelta("1day") +ONE_MIN = pd.Timedelta("1min") +EPS_T = pd.Timedelta("1s") # use 1 second to exclude the right interval point +float_or_ndarray = TypeVar("float_or_ndarray", float, np.ndarray) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index c74092de34..5e98bfc97a 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -615,4 +615,4 @@ def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler: return tsds -__all__ = ["Optional"] +__all__ = ["Optional", "Dataset", "DatasetH"] diff --git a/qlib/rl/aux_info.py b/qlib/rl/aux_info.py index 9ab0834511..1fd581544e 100644 --- a/qlib/rl/aux_info.py +++ b/qlib/rl/aux_info.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import Optional, TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Generic, Optional, TypeVar from qlib.typehint import final diff --git a/qlib/rl/data/exchange_wrapper.py b/qlib/rl/data/exchange_wrapper.py index bc36fa11b8..94bb1dcbbd 100644 --- a/qlib/rl/data/exchange_wrapper.py +++ b/qlib/rl/data/exchange_wrapper.py @@ -3,21 +3,33 @@ from typing import cast +import cachetools import pandas as pd from qlib.backtest import Exchange, Order -from .pickle_styled import IntradayBacktestData +from qlib.backtest.decision import TradeRange, TradeRangeByTime +from qlib.constant import ONE_DAY, EPS_T +from qlib.rl.order_execution.utils import get_ticks_slice +from qlib.utils.index_data import IndexData +from .pickle_styled import BaseIntradayBacktestData -class QlibIntradayBacktestData(IntradayBacktestData): +class IntradayBacktestData(BaseIntradayBacktestData): """Backtest data for Qlib simulator""" - def __init__(self, order: Order, exchange: Exchange, start_time: pd.Timestamp, end_time: pd.Timestamp) -> None: - super(QlibIntradayBacktestData, self).__init__() + def __init__( + self, + order: Order, + exchange: Exchange, + ticks_index: pd.DatetimeIndex, + ticks_for_order: pd.DatetimeIndex, + ) -> None: self._order = order self._exchange = exchange - self._start_time = start_time - self._end_time = end_time + self._start_time = ticks_for_order[0] + self._end_time = ticks_for_order[-1] + self.ticks_index = ticks_index + self.ticks_for_order = ticks_for_order self._deal_price = cast( pd.Series, @@ -56,3 +68,43 @@ def get_volume(self) -> pd.Series: def get_time_index(self) -> pd.DatetimeIndex: return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)]) + + +@cachetools.cached( # type: ignore + cache=cachetools.LRUCache(100), + key=lambda order, _, __: order.key_by_day, +) +def load_qlib_backtest_data( + order: Order, + trade_exchange: Exchange, + trade_range: TradeRange, +) -> IntradayBacktestData: + data = cast( + IndexData, + trade_exchange.get_deal_price( + stock_id=order.stock_id, + start_time=order.date, + end_time=order.date + ONE_DAY - EPS_T, + direction=order.direction, + method=None, + ), + ) + + ticks_index = pd.DatetimeIndex(data.index) + if isinstance(trade_range, TradeRangeByTime): + ticks_for_order = get_ticks_slice( + ticks_index, + trade_range.start_time, + trade_range.end_time, + include_end=True, + ) + else: + ticks_for_order = None # FIXME: implement this logic + + backtest_data = IntradayBacktestData( + order=order, + exchange=trade_exchange, + ticks_index=ticks_index, + ticks_for_order=ticks_for_order, + ) + return backtest_data diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index aa0ba38fff..43fe9dd5ad 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -86,7 +86,7 @@ def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame: return pd.read_pickle(_find_pickle(filename_without_suffix)) -class IntradayBacktestData: +class BaseIntradayBacktestData: """ Raw market data that is often used in backtesting (thus called BacktestData). @@ -115,7 +115,7 @@ def get_time_index(self) -> pd.DatetimeIndex: raise NotImplementedError -class SimpleIntradayBacktestData(IntradayBacktestData): +class SimpleIntradayBacktestData(BaseIntradayBacktestData): """Backtest data for simple simulator""" def __init__( diff --git a/qlib/rl/from_neutrader/config.py b/qlib/rl/from_neutrader/config.py deleted file mode 100644 index d9a681b32d..0000000000 --- a/qlib/rl/from_neutrader/config.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from dataclasses import dataclass -from pathlib import Path -from typing import Optional, Tuple, Union - - -# TODO: In the future we should merge the dataclass-based config with Qlib's dict-based config. -@dataclass -class ExchangeConfig: - limit_threshold: Union[float, Tuple[str, str]] - deal_price: Union[str, Tuple[str, str]] - volume_threshold: dict - open_cost: float = 0.0005 - close_cost: float = 0.0015 - min_cost: float = 5.0 - trade_unit: Optional[float] = 100.0 - cash_limit: Optional[Union[Path, float]] = None - generate_report: bool = False diff --git a/qlib/rl/from_neutrader/feature.py b/qlib/rl/from_neutrader/feature.py deleted file mode 100644 index ca42af24c9..0000000000 --- a/qlib/rl/from_neutrader/feature.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import collections -from typing import List, Optional - -import pandas as pd - -import qlib -from qlib.config import REG_CN -from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select -from qlib.data.dataset import DatasetH - - -class LRUCache: - def __init__(self, pool_size: int = 200): - self.pool_size = pool_size - self.contents: dict = {} - self.keys: collections.deque = collections.deque() - - def put(self, key, item): - if self.has(key): - self.keys.remove(key) - self.keys.append(key) - self.contents[key] = item - while len(self.contents) > self.pool_size: - self.contents.pop(self.keys.popleft()) - - def get(self, key): - return self.contents[key] - - def has(self, key): - return key in self.contents - - -class DataWrapper: - def __init__( - self, - feature_dataset: DatasetH, - backtest_dataset: DatasetH, - columns_today: List[str], - columns_yesterday: List[str], - _internal: bool = False, - ): - assert _internal, "Init function of data wrapper is for internal use only." - - self.feature_dataset = feature_dataset - self.backtest_dataset = backtest_dataset - self.columns_today = columns_today - self.columns_yesterday = columns_yesterday - - # TODO: We might have the chance to merge them. - self.feature_cache = LRUCache() - self.backtest_cache = LRUCache() - - def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame: - start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59) - - if backtest: - dataset = self.backtest_dataset - cache = self.backtest_cache - else: - dataset = self.feature_dataset - cache = self.feature_cache - - if cache.has((start_time, end_time, stock_id)): - return cache.get((start_time, end_time, stock_id)) - data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None) - cache.put((start_time, end_time, stock_id), data) - return data - - -def init_qlib(config: dict, part: Optional[str] = None) -> None: - provider_uri_map = { - "day": config["provider_uri_day"].as_posix(), - "1min": config["provider_uri_1min"].as_posix(), - } - qlib.init( - region=REG_CN, - auto_mount=False, - custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum], - expression_cache=None, - calendar_provider={ - "class": "LocalCalendarProvider", - "module_path": "qlib.data.data", - "kwargs": { - "backend": { - "class": "FileCalendarStorage", - "module_path": "qlib.data.storage.file_storage", - "kwargs": {"provider_uri_map": provider_uri_map}, - }, - }, - }, - feature_provider={ - "class": "LocalFeatureProvider", - "module_path": "qlib.data.data", - "kwargs": { - "backend": { - "class": "FileFeatureStorage", - "module_path": "qlib.data.storage.file_storage", - "kwargs": {"provider_uri_map": provider_uri_map}, - }, - }, - }, - provider_uri=provider_uri_map, - kernels=1, - redis_port=-1, - clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance - ) diff --git a/qlib/rl/order_execution/integration.py b/qlib/rl/order_execution/integration.py new file mode 100644 index 0000000000..07ca381613 --- /dev/null +++ b/qlib/rl/order_execution/integration.py @@ -0,0 +1,163 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +TODO: This file is used to integrate NeuTrader with Qlib to run the existing projects. +TODO: The implementation here is kind of adhoc. It is better to design a more uniformed & general implementation. +""" + +from __future__ import annotations + +import pickle +from pathlib import Path +from typing import List + +import cachetools +import numpy as np +import pandas as pd +import qlib +from qlib.constant import REG_CN +from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select +from qlib.data.dataset import DatasetH + +dataset = None + + +class DataWrapper: + def __init__( + self, + feature_dataset: DatasetH, + backtest_dataset: DatasetH, + columns_today: List[str], + columns_yesterday: List[str], + _internal: bool = False, + ): + assert _internal, "Init function of data wrapper is for internal use only." + + self.feature_dataset = feature_dataset + self.backtest_dataset = backtest_dataset + self.columns_today = columns_today + self.columns_yesterday = columns_yesterday + + @cachetools.cached( # type: ignore + cache=cachetools.LRUCache(100), + key=lambda stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest), + ) + def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame: + start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59) + dataset = self.backtest_dataset if backtest else self.feature_dataset + return dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None) + + +def init_qlib(qlib_config: dict, part: str = None) -> None: + """Initialize necessary resource to launch the workflow, including data direction, feature columns, etc.. + + Parameters + ---------- + qlib_config: + Qlib configuration. + + Example:: + + { + "provider_uri_day": DATA_ROOT_DIR / "qlib_1d", + "provider_uri_1min": DATA_ROOT_DIR / "qlib_1min", + "feature_root_dir": DATA_ROOT_DIR / "qlib_handler_stock", + "feature_columns_today": [ + "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume", + "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5", + ], + "feature_columns_yesterday": [ + "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1", + "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1", + ], + } + part + Identifying which part (stock / date) to load. + """ + + global dataset # pylint: disable=W0603 + + def _convert_to_path(path: str | Path) -> Path: + return path if isinstance(path, Path) else Path(path) + + provider_uri_map = { + "day": _convert_to_path(qlib_config["provider_uri_day"]).as_posix(), + "1min": _convert_to_path(qlib_config["provider_uri_1min"]).as_posix(), + } + qlib.init( + region=REG_CN, + auto_mount=False, + custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum], + expression_cache=None, + calendar_provider={ + "class": "LocalCalendarProvider", + "module_path": "qlib.data.data", + "kwargs": { + "backend": { + "class": "FileCalendarStorage", + "module_path": "qlib.data.storage.file_storage", + "kwargs": {"provider_uri_map": provider_uri_map}, + }, + }, + }, + feature_provider={ + "class": "LocalFeatureProvider", + "module_path": "qlib.data.data", + "kwargs": { + "backend": { + "class": "FileFeatureStorage", + "module_path": "qlib.data.storage.file_storage", + "kwargs": {"provider_uri_map": provider_uri_map}, + }, + }, + }, + provider_uri=provider_uri_map, + kernels=1, + redis_port=-1, + clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance + ) + + if part == "skip": + return + + # this won't work if it's put outside in case of multiprocessing + from qlib.data import D # noqa pylint: disable=C0415,W0611 + + if part is None: + feature_path = Path(qlib_config["feature_root_dir"]) / "feature.pkl" + backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest.pkl" + else: + feature_path = Path(qlib_config["feature_root_dir"]) / "feature" / (part + ".pkl") + backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest" / (part + ".pkl") + + with feature_path.open("rb") as f: + feature_dataset = pickle.load(f) + with backtest_path.open("rb") as f: + backtest_dataset = pickle.load(f) + + dataset = DataWrapper( + feature_dataset, + backtest_dataset, + qlib_config["feature_columns_today"], + qlib_config["feature_columns_yesterday"], + _internal=True, + ) + + +def fetch_features(stock_id: str, date: pd.Timestamp, yesterday: bool = False, backtest: bool = False) -> pd.DataFrame: + assert dataset is not None, "You must call init_qlib() before doing this." + + if backtest: + fields = ["$close", "$volume"] + else: + fields = dataset.columns_yesterday if yesterday else dataset.columns_today + + data = dataset.get(stock_id, date, backtest) + if data is None or len(data) == 0: + # create a fake index, but RL doesn't care about index + data = pd.DataFrame(0.0, index=np.arange(240), columns=fields, dtype=np.float32) # FIXME: hardcode here + else: + data = data.rename(columns={c: c.rstrip("0") for c in data.columns}) + data = data[fields] + return data diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py index 602a15e54e..089fc553cf 100644 --- a/qlib/rl/order_execution/interpreter.py +++ b/qlib/rl/order_execution/interpreter.py @@ -14,15 +14,15 @@ from qlib.constant import EPS from qlib.rl.data import pickle_styled from qlib.rl.interpreter import ActionInterpreter, StateInterpreter +from qlib.rl.order_execution.state import SAOEState from qlib.typehint import TypedDict -from .simulator_simple import SAOEState - __all__ = [ "FullHistoryStateInterpreter", "CurrentStepStateInterpreter", "CategoricalActionInterpreter", "TwapRelativeActionInterpreter", + "FullHistoryObs", ] diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py index 18c2e4f175..cfd3181ca2 100644 --- a/qlib/rl/order_execution/policy.py +++ b/qlib/rl/order_execution/policy.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + from __future__ import annotations from pathlib import Path diff --git a/qlib/rl/order_execution/reward.py b/qlib/rl/order_execution/reward.py index f15a152c66..99a88f8e44 100644 --- a/qlib/rl/order_execution/reward.py +++ b/qlib/rl/order_execution/reward.py @@ -7,10 +7,9 @@ import numpy as np +from qlib.rl.order_execution.state import SAOEMetrics, SAOEState from qlib.rl.reward import Reward -from .simulator_simple import SAOEMetrics, SAOEState - __all__ = ["PAPenaltyReward"] diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index c75793f586..3002fd333e 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -3,381 +3,102 @@ from __future__ import annotations -from typing import Any, Callable, cast, Generator, List, Optional, Tuple +from typing import Generator, Optional -import numpy as np import pandas as pd - -from qlib.backtest.decision import BaseTradeDecision, Order, OrderHelper, TradeDecisionWO, TradeRange, TradeRangeByTime -from qlib.backtest.executor import BaseExecutor, NestedExecutor -from qlib.backtest.utils import CommonInfrastructure -from qlib.constant import EPS -from qlib.rl.data.exchange_wrapper import QlibIntradayBacktestData -from qlib.rl.from_neutrader.config import ExchangeConfig -from qlib.rl.from_neutrader.feature import init_qlib -from qlib.rl.order_execution.simulator_simple import SAOEMetrics, SAOEState -from qlib.rl.order_execution.utils import ( - dataframe_append, - get_common_infra, - get_portfolio_and_indicator, - get_ticks_slice, - price_advantage, -) +from qlib.backtest import collect_data_loop, get_strategy_executor +from qlib.backtest.decision import Order +from qlib.backtest.executor import NestedExecutor from qlib.rl.simulator import Simulator -from qlib.strategy.base import BaseStrategy - - -class DecomposedStrategy(BaseStrategy): - def __init__(self) -> None: - super().__init__() - - self.execute_order: Optional[Order] = None - self.execute_result: List[Tuple[Order, float, float, float]] = [] - - def generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]: - # Once the following line is executed, this DecomposedStrategy (self) will be yielded to the outside - # of the entire executor, and the execution will be suspended. When the execution is resumed by `send()`, - # the sent item will be captured by `exec_vol`. The outside policy could communicate with the inner - # level strategy through this way. - exec_vol = yield self - - oh = self.trade_exchange.get_order_helper() - order = oh.create(self._order.stock_id, exec_vol, self._order.direction) - - self.execute_order = order - - return TradeDecisionWO([order], self) - - def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision: - return outer_trade_decision - - def post_exe_step(self, execute_result: list) -> None: - self.execute_result = execute_result - - def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs: Any) -> None: - super().reset(outer_trade_decision=outer_trade_decision, **kwargs) - if outer_trade_decision is not None: - order_list = outer_trade_decision.order_list - assert len(order_list) == 1 - self._order = order_list[0] - - -class SingleOrderStrategy(BaseStrategy): - # this logic is copied from FileOrderStrategy - def __init__( - self, - common_infra: CommonInfrastructure, - order: Order, - trade_range: TradeRange, - instrument: str, - ) -> None: - super().__init__(common_infra=common_infra) - self._order = order - self._trade_range = trade_range - self._instrument = instrument - - def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision: - return outer_trade_decision - - def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO: - oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper() - order_list = [ - oh.create( - code=self._instrument, - amount=self._order.amount, - direction=self._order.direction, - ), - ] - return TradeDecisionWO(order_list, self, self._trade_range) - - -# TODO: move these to the configuration files -FINEST_GRANULARITY = "1min" -COARSEST_GRANULARITY = "1day" - - -class StateMaintainer: - """ - Maintain states of the environment. - - Example usage:: - - maintainer = StateMaintainer(...) # in reset - maintainer.update(...) # in step - # get states in get_state from maintainer - """ - - def __init__(self, order: Order, time_per_step: str, tick_index: pd.DatetimeIndex, twap_price: float) -> None: - super().__init__() - - self.position = order.amount - self._order = order - self._time_per_step = time_per_step - self._tick_index = tick_index - self._twap_price = twap_price - - metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member - self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime") - self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime") - self.metrics: Optional[SAOEMetrics] = None - - def update( - self, - inner_executor: BaseExecutor, - inner_strategy: DecomposedStrategy, - done: bool, - all_indicators: dict, - ) -> None: - execute_order = inner_strategy.execute_order - execute_result = inner_strategy.execute_result - exec_vol = np.array([e[0].deal_amount for e in execute_result]) - num_step = len(execute_result) - - assert execute_order is not None - - if num_step == 0: - market_volume = np.array([]) - market_price = np.array([]) - datetime_list = pd.DatetimeIndex([]) - else: - market_volume = np.array( - inner_executor.trade_exchange.get_volume( - execute_order.stock_id, - execute_result[0][0].start_time, - execute_result[-1][0].start_time, - method=None, - ), - ) - - trade_value = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["value"].values - deal_amount = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["deal_amount"].values - market_price = trade_value / deal_amount - - datetime_list = all_indicators[FINEST_GRANULARITY].index[-num_step:] - assert market_price.shape == market_volume.shape == exec_vol.shape +from .integration import init_qlib +from .state import SAOEState, SAOEStateAdapter +from .strategy import SAOEStrategy - self.history_exec = dataframe_append( - self.history_exec, - self._collect_multi_order_metric( - order=self._order, - datetime=datetime_list, - market_vol=market_volume, - market_price=market_price, - exec_vol=exec_vol, - pa=all_indicators[self._time_per_step].iloc[-1]["pa"], - ), - ) - - self.history_steps = dataframe_append( - self.history_steps, - [ - self._collect_single_order_metric( - execute_order, - execute_order.start_time, - market_volume, - market_price, - exec_vol.sum(), - exec_vol, - ), - ], - ) - - if done: - self.metrics = self._collect_single_order_metric( - self._order, - self._tick_index[0], # start time - self.history_exec["market_volume"], - self.history_exec["market_price"], - self.history_steps["amount"].sum(), - self.history_exec["deal_amount"], - ) - - # TODO: check whether we need this. Can we get this information from Account? - # Do this at the end - self.position -= exec_vol.sum() - - def _collect_multi_order_metric( - self, - order: Order, - datetime: pd.Timestamp, - market_vol: np.ndarray, - market_price: np.ndarray, - exec_vol: np.ndarray, - pa: float, - ) -> SAOEMetrics: - return SAOEMetrics( - # It should have the same keys with SAOEMetrics, - # but the values do not necessarily have the annotated type. - # Some values could be vectorized (e.g., exec_vol). - stock_id=order.stock_id, - datetime=datetime, - direction=order.direction, - market_volume=market_vol, - market_price=market_price, - amount=exec_vol, - inner_amount=exec_vol, - deal_amount=exec_vol, - trade_price=market_price, - trade_value=market_price * exec_vol, - position=self.position - np.cumsum(exec_vol), - ffr=exec_vol / order.amount, - pa=pa, - ) - - def _collect_single_order_metric( - self, - order: Order, - datetime: pd.Timestamp, - market_vol: np.ndarray, - market_price: np.ndarray, - amount: float, # intended to trade such amount - exec_vol: np.ndarray, - ) -> SAOEMetrics: - assert len(market_vol) == len(market_price) == len(exec_vol) - - if np.abs(np.sum(exec_vol)) < EPS: - exec_avg_price = 0.0 - else: - exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan - if hasattr(exec_avg_price, "item"): # could be numpy scalar - exec_avg_price = exec_avg_price.item() # type: ignore - - exec_sum = exec_vol.sum() - return SAOEMetrics( - stock_id=order.stock_id, - datetime=datetime, - direction=order.direction, - market_volume=market_vol.sum(), - market_price=market_price.mean() if len(market_price) > 0 else np.nan, - amount=amount, - inner_amount=exec_sum, - deal_amount=exec_sum, # in this simulator, there's no other restrictions - trade_price=exec_avg_price, - trade_value=float(np.sum(market_price * exec_vol)), - position=self.position - exec_sum, - ffr=float(exec_sum / order.amount), - pa=price_advantage(exec_avg_price, self._twap_price, order.direction), - ) - -class SingleAssetOrderExecutionQlib(Simulator[Order, SAOEState, float]): +class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): """Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools. Parameters ---------- - order (Order): + order The seed to start an SAOE simulator is an order. - time_per_step (str): - A string to describe the time granularity of each step. Current support "1min", "30min", and "1day" - qlib_config (dict): - Configuration used to initialize Qlib. - inner_executor_fn (Callable[[str, CommonInfrastructure], BaseExecutor]): - Function used to get the inner level executor. - exchange_config (ExchangeConfig): - Configuration used to create the Exchange instance. + strategy_config + Strategy configuration + executor_config + Executor configuration + exchange_config + Exchange configuration + qlib_config + Configuration used to initialize Qlib. If it is None, Qlib will not be initialized. """ def __init__( self, order: Order, - time_per_step: str, # "1min", "30min", "1day" - qlib_config: dict, - inner_executor_fn: Callable[[str, CommonInfrastructure], BaseExecutor], - exchange_config: ExchangeConfig, + strategy_config: dict, + executor_config: dict, + exchange_config: dict, + qlib_config: dict = None, ) -> None: - assert time_per_step in ("1min", "30min", "1day") - super().__init__(initial=order) assert order.start_time.date() == order.end_time.date(), "Start date and end date must be the same." - self._order = order - self._order_date = pd.Timestamp(order.start_time.date()) - self._trade_range = TradeRangeByTime(order.start_time.time(), order.end_time.time()) - self._qlib_config = qlib_config - self._inner_executor_fn = inner_executor_fn - self._exchange_config = exchange_config - - self._time_per_step = time_per_step - self._ticks_per_step = int(pd.Timedelta(time_per_step).total_seconds() // 60) - - self._executor: Optional[NestedExecutor] = None self._collect_data_loop: Optional[Generator] = None + self.reset(order, strategy_config, executor_config, exchange_config, qlib_config) - self._done = False - - self._inner_strategy = DecomposedStrategy() - - self.reset(self._order) - - def reset(self, order: Order) -> None: - instrument = order.stock_id - - # TODO: Check this logic. Make sure we need to do this every time we reset the simulator. - init_qlib(self._qlib_config, instrument) - - common_infra = get_common_infra( - self._exchange_config, - trade_date=pd.Timestamp(self._order_date), - codes=[instrument], + def reset( + self, + order: Order, + strategy_config: dict, + executor_config: dict, + exchange_config: dict, + qlib_config: dict = None, + ) -> None: + if qlib_config is not None: + init_qlib(qlib_config, part="skip") + + strategy, self._executor = get_strategy_executor( + start_time=order.date, + end_time=order.date + pd.DateOffset(1), + strategy=strategy_config, + executor=executor_config, + benchmark=order.stock_id, + account=1e12, + exchange_kwargs=exchange_config, + pos_type="InfPosition", ) - # TODO: We can leverage interfaces like (https://tinyurl.com/y8f8fhv4) to create trading environment. - # TODO: By aligning the interface to create environments with Qlib, it will be easier to share the config and - # TODO: code between backtesting and training. - self._inner_executor = self._inner_executor_fn(self._time_per_step, common_infra) - self._executor = NestedExecutor( - time_per_step=COARSEST_GRANULARITY, - inner_executor=self._inner_executor, - inner_strategy=self._inner_strategy, - track_data=True, - common_infra=common_infra, - ) + assert isinstance(self._executor, NestedExecutor) - exchange = self._inner_executor.trade_exchange - self._ticks_index = pd.DatetimeIndex([e[1] for e in list(exchange.quote_df.index)]) - self._ticks_for_order = get_ticks_slice( - self._ticks_index, - self._order.start_time, - self._order.end_time, - include_end=True, - ) - - self._backtest_data = QlibIntradayBacktestData( - order=self._order, - exchange=exchange, - start_time=self._ticks_for_order[0], - end_time=self._ticks_for_order[-1], + self._collect_data_loop = collect_data_loop( + start_time=order.date, + end_time=order.date, + trade_strategy=strategy, + trade_executor=self._executor, ) + assert isinstance(self._collect_data_loop, Generator) - self.twap_price = self._backtest_data.get_deal_price().mean() + self._last_yielded_saoe_strategy = self._iter_strategy(action=None) - top_strategy = SingleOrderStrategy(common_infra, order, self._trade_range, instrument) - self._executor.reset(start_time=pd.Timestamp(self._order_date), end_time=pd.Timestamp(self._order_date)) - top_strategy.reset(level_infra=self._executor.get_level_infra()) + self._order = order - self._collect_data_loop = self._executor.collect_data(top_strategy.generate_trade_decision(), level=0) - assert isinstance(self._collect_data_loop, Generator) + def _get_adapter(self) -> SAOEStateAdapter: + return self._last_yielded_saoe_strategy.adapter_dict[self._order.key_by_day] - self._iter_strategy(action=None) - self._done = False + @property + def twap_price(self) -> float: + return self._get_adapter().twap_price - self._maintainer = StateMaintainer( - order=self._order, - time_per_step=self._time_per_step, - tick_index=self._ticks_index, - twap_price=self.twap_price, - ) - - def _iter_strategy(self, action: float = None) -> DecomposedStrategy: - """Iterate the _collect_data_loop until we get the next yield DecomposedStrategy.""" + def _iter_strategy(self, action: float = None) -> SAOEStrategy: + """Iterate the _collect_data_loop until we get the next yield SAOEStrategy.""" assert self._collect_data_loop is not None strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) - while not isinstance(strategy, DecomposedStrategy): + while not isinstance(strategy, SAOEStrategy): strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) - assert isinstance(strategy, DecomposedStrategy) + assert isinstance(strategy, SAOEStrategy) return strategy def step(self, action: float) -> None: @@ -389,36 +110,17 @@ def step(self, action: float) -> None: The amount you wish to deal. The simulator doesn't guarantee all the amount to be successfully dealt. """ - assert not self._done, "Simulator has already done!" + assert not self.done(), "Simulator has already done!" try: - self._iter_strategy(action=action) + self._last_yielded_saoe_strategy = self._iter_strategy(action=action) except StopIteration: - self._done = True + pass assert self._executor is not None - _, all_indicators = get_portfolio_and_indicator(self._executor) - - self._maintainer.update( - inner_executor=self._inner_executor, - inner_strategy=self._inner_strategy, - done=self._done, - all_indicators=all_indicators, - ) def get_state(self) -> SAOEState: - return SAOEState( - order=self._order, - cur_time=self._inner_executor.trade_calendar.get_step_time()[0], - position=self._maintainer.position, - history_exec=self._maintainer.history_exec, - history_steps=self._maintainer.history_steps, - metrics=self._maintainer.metrics, - backtest_data=self._backtest_data, - ticks_per_step=self._ticks_per_step, - ticks_index=self._ticks_index, - ticks_for_order=self._ticks_for_order, - ) + return self._get_adapter().saoe_state def done(self) -> bool: - return self._done + return self._executor.finished() diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py index 6d49457841..f95aeebad0 100644 --- a/qlib/rl/order_execution/simulator_simple.py +++ b/qlib/rl/order_execution/simulator_simple.py @@ -4,107 +4,21 @@ from __future__ import annotations from pathlib import Path -from typing import Any, NamedTuple, Optional, TypeVar, cast +from typing import Any, cast, Optional import numpy as np import pandas as pd - from qlib.backtest.decision import Order, OrderDir -from qlib.constant import EPS -from qlib.rl.data.pickle_styled import DealPriceType, IntradayBacktestData, load_simple_intraday_backtest_data +from qlib.constant import EPS, EPS_T, float_or_ndarray +from qlib.rl.data.pickle_styled import DealPriceType, load_simple_intraday_backtest_data from qlib.rl.simulator import Simulator from qlib.rl.utils import LogLevel -from qlib.typehint import TypedDict - -# TODO: Integrating Qlib's native data with simulator_simple - -__all__ = ["SAOEMetrics", "SAOEState", "SingleAssetOrderExecution"] - -ONE_SEC = pd.Timedelta("1s") # use 1 second to exclude the right interval point +from .state import SAOEMetrics, SAOEState -class SAOEMetrics(TypedDict): - """Metrics for SAOE accumulated for a "period". - It could be accumulated for a day, or a period of time (e.g., 30min), or calculated separately for every minute. - - Warnings - -------- - The type hints are for single elements. In lots of times, they can be vectorized. - For example, ``market_volume`` could be a list of float (or ndarray) rather tahn a single float. - """ - - stock_id: str - """Stock ID of this record.""" - datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this - """Datetime of this record (this is index in the dataframe).""" - direction: int - """Direction of the order. 0 for sell, 1 for buy.""" - - # Market information. - market_volume: np.ndarray | float - """(total) market volume traded in the period.""" - market_price: np.ndarray | float - """Deal price. If it's a period of time, this is the average market deal price.""" - - # Strategy records. - - amount: np.ndarray | float - """Total amount (volume) strategy intends to trade.""" - inner_amount: np.ndarray | float - """Total amount that the lower-level strategy intends to trade - (might be larger than amount, e.g., to ensure ffr).""" - - deal_amount: np.ndarray | float - """Amount that successfully takes effect (must be less than inner_amount).""" - trade_price: np.ndarray | float - """The average deal price for this strategy.""" - trade_value: np.ndarray | float - """Total worth of trading. In the simple simulation, trade_value = deal_amount * price.""" - position: np.ndarray | float - """Position left after this "period".""" - - # Accumulated metrics - - ffr: np.ndarray | float - """Completed how much percent of the daily order.""" - - pa: np.ndarray | float - """Price advantage compared to baseline (i.e., trade with baseline market price). - The baseline is trade price when using TWAP strategy to execute this order. - Please note that there could be data leak here). - Unit is BP (basis point, 1/10000).""" - - -class SAOEState(NamedTuple): - """Data structure holding a state for SAOE simulator.""" - - order: Order - """The order we are dealing with.""" - cur_time: pd.Timestamp - """Current time, e.g., 9:30.""" - position: float - """Current remaining volume to execute.""" - history_exec: pd.DataFrame - """See :attr:`SingleAssetOrderExecution.history_exec`.""" - history_steps: pd.DataFrame - """See :attr:`SingleAssetOrderExecution.history_steps`.""" +# TODO: Integrating Qlib's native data with simulator_simple - metrics: Optional[SAOEMetrics] - """Daily metric, only available when the trading is in "done" state.""" - - backtest_data: IntradayBacktestData - """Backtest data is included in the state. - Actually, only the time index of this data is needed, at this moment. - I include the full data so that algorithms (e.g., VWAP) that relies on the raw data can be implemented. - Interpreter can use this as they wish, but they should be careful not to leak future data. - """ - - ticks_per_step: int - """How many ticks for each step.""" - ticks_index: pd.DatetimeIndex - """Trading ticks in all day, NOT sliced by order (defined in data). e.g., [9:30, 9:31, ..., 14:59].""" - ticks_for_order: pd.DatetimeIndex - """Trading ticks sliced by order, e.g., [9:45, 9:46, ..., 14:44].""" +__all__ = ["SingleAssetOrderExecution"] class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): @@ -326,8 +240,8 @@ def _split_exec_vol(self, exec_vol_sum: float) -> np.ndarray: next_time = self._next_time() # get the backtest data for next interval - self.market_vol = self.backtest_data.get_volume().loc[self.cur_time : next_time - ONE_SEC].to_numpy() - self.market_price = self.backtest_data.get_deal_price().loc[self.cur_time : next_time - ONE_SEC].to_numpy() + self.market_vol = self.backtest_data.get_volume().loc[self.cur_time : next_time - EPS_T].to_numpy() + self.market_price = self.backtest_data.get_deal_price().loc[self.cur_time : next_time - EPS_T].to_numpy() assert self.market_vol is not None and self.market_price is not None @@ -380,7 +294,7 @@ def _metrics_collect( def _get_ticks_slice(self, start: pd.Timestamp, end: pd.Timestamp, include_end: bool = False) -> pd.DatetimeIndex: if not include_end: - end = end - ONE_SEC + end = end - EPS_T return self.ticks_index[self.ticks_index.slice_indexer(start, end)] @staticmethod @@ -391,14 +305,11 @@ def _dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame: return pd.concat([df, other_df], axis=0) -_float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray) - - def price_advantage( - exec_price: _float_or_ndarray, + exec_price: float_or_ndarray, baseline_price: float, direction: OrderDir | int, -) -> _float_or_ndarray: +) -> float_or_ndarray: if baseline_price == 0: # something is wrong with data. Should be nan here if isinstance(exec_price, float): return 0.0 @@ -414,4 +325,4 @@ def price_advantage( if res_wo_nan.size == 1: return res_wo_nan.item() else: - return cast(_float_or_ndarray, res_wo_nan) + return cast(float_or_ndarray, res_wo_nan) diff --git a/qlib/rl/order_execution/state.py b/qlib/rl/order_execution/state.py new file mode 100644 index 0000000000..d6bbeaea5a --- /dev/null +++ b/qlib/rl/order_execution/state.py @@ -0,0 +1,334 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import cast, NamedTuple, Optional, Tuple + +import numpy as np +import pandas as pd +from qlib.backtest import Exchange, Order +from qlib.backtest.executor import BaseExecutor +from qlib.constant import EPS, ONE_MIN, REG_CN +from qlib.rl.data.exchange_wrapper import IntradayBacktestData +from qlib.rl.data.pickle_styled import BaseIntradayBacktestData +from qlib.rl.order_execution.utils import dataframe_append, price_advantage +from qlib.utils.time import get_day_min_idx_range +from typing_extensions import TypedDict + + +def _get_all_timestamps( + start: pd.Timestamp, + end: pd.Timestamp, + granularity: pd.Timedelta = ONE_MIN, + include_end: bool = True, +) -> pd.DatetimeIndex: + ret = [] + while start <= end: + ret.append(start) + start += granularity + + if ret[-1] > end: + ret.pop() + if ret[-1] == end and not include_end: + ret.pop() + return pd.DatetimeIndex(ret) + + +class SAOEStateAdapter: + """ + Maintain states of the environment. SAOEStateAdapter accepts execution results and update its internal state + according to the execution results with additional information acquired from executors & exchange. For example, + it gets the dealt order amount from execution results, and get the corresponding market price / volume from + exchange. + + Example usage:: + + adapter = SAOEStateAdapter(...) + adapter.update(...) + state = adapter.saoe_state + """ + + def __init__( + self, + order: Order, + executor: BaseExecutor, + exchange: Exchange, + ticks_per_step: int, + backtest_data: IntradayBacktestData, + ) -> None: + self.position = order.amount + self.order = order + self.executor = executor + self.exchange = exchange + self.backtest_data = backtest_data + + self.twap_price = self.backtest_data.get_deal_price().mean() + + metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member + self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime") + self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime") + self.metrics: Optional[SAOEMetrics] = None + + self.cur_time = max(backtest_data.ticks_for_order[0], order.start_time) + self.ticks_per_step = ticks_per_step + + def _next_time(self) -> pd.Timestamp: + current_loc = self.backtest_data.ticks_index.get_loc(self.cur_time) + next_loc = current_loc + self.ticks_per_step + next_loc = next_loc - next_loc % self.ticks_per_step + if ( + next_loc < len(self.backtest_data.ticks_index) + and self.backtest_data.ticks_index[next_loc] < self.order.end_time + ): + return self.backtest_data.ticks_index[next_loc] + else: + return self.order.end_time + + def update( + self, + execute_result: list, + last_step_range: Tuple[int, int], + ) -> None: + last_step_size = last_step_range[1] - last_step_range[0] + 1 + start_time = self.backtest_data.ticks_index[last_step_range[0]] + end_time = self.backtest_data.ticks_index[last_step_range[1]] + + exec_vol = np.zeros(last_step_size) + for order, _, __, ___ in execute_result: + idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN) + exec_vol[idx - last_step_range[0]] = order.deal_amount + + if exec_vol.sum() > self.position and exec_vol.sum() > 0.0: + assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large" + exec_vol *= self.position / (exec_vol.sum()) + + market_volume = np.array( + self.exchange.get_volume( + self.order.stock_id, + pd.Timestamp(start_time), + pd.Timestamp(end_time), + method=None, + ), + ).reshape(-1) + + market_price = np.array( + self.exchange.get_deal_price( + self.order.stock_id, + pd.Timestamp(start_time), + pd.Timestamp(end_time), + method=None, + direction=self.order.direction, + ), + ).reshape(-1) + + assert market_price.shape == market_volume.shape == exec_vol.shape + + # Get data from the current level executor's indicator + current_trade_account = self.executor.trade_account + current_df = current_trade_account.get_trade_indicator().generate_trade_indicators_dataframe() + self.history_exec = dataframe_append( + self.history_exec, + self._collect_multi_order_metric( + order=self.order, + datetime=_get_all_timestamps(start_time, end_time, include_end=True), + market_vol=market_volume, + market_price=market_price, + exec_vol=exec_vol, + pa=current_df.iloc[-1]["pa"], + ), + ) + + self.history_steps = dataframe_append( + self.history_steps, + [ + self._collect_single_order_metric( + self.order, + self.cur_time, + market_volume, + market_price, + exec_vol.sum(), + exec_vol, + ), + ], + ) + + # TODO: check whether we need this. Can we get this information from Account? + # Do this at the end + self.position -= exec_vol.sum() + + self.cur_time = self._next_time() + + def generate_metrics_after_done(self) -> None: + """Generate metrics once the upper level execution is done""" + + self.metrics = self._collect_single_order_metric( + self.order, + self.backtest_data.ticks_index[0], # start time + self.history_exec["market_volume"], + self.history_exec["market_price"], + self.history_steps["amount"].sum(), + self.history_exec["deal_amount"], + ) + + def _collect_multi_order_metric( + self, + order: Order, + datetime: pd.DatetimeIndex, + market_vol: np.ndarray, + market_price: np.ndarray, + exec_vol: np.ndarray, + pa: float, + ) -> SAOEMetrics: + return SAOEMetrics( + # It should have the same keys with SAOEMetrics, + # but the values do not necessarily have the annotated type. + # Some values could be vectorized (e.g., exec_vol). + stock_id=order.stock_id, + datetime=datetime, + direction=order.direction, + market_volume=market_vol, + market_price=market_price, + amount=exec_vol, + inner_amount=exec_vol, + deal_amount=exec_vol, + trade_price=market_price, + trade_value=market_price * exec_vol, + position=self.position - np.cumsum(exec_vol), + ffr=exec_vol / order.amount, + pa=pa, + ) + + def _collect_single_order_metric( + self, + order: Order, + datetime: pd.Timestamp, + market_vol: np.ndarray, + market_price: np.ndarray, + amount: float, # intended to trade such amount + exec_vol: np.ndarray, + ) -> SAOEMetrics: + assert len(market_vol) == len(market_price) == len(exec_vol) + + if np.abs(np.sum(exec_vol)) < EPS: + exec_avg_price = 0.0 + else: + exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan + if hasattr(exec_avg_price, "item"): # could be numpy scalar + exec_avg_price = exec_avg_price.item() # type: ignore + + exec_sum = exec_vol.sum() + return SAOEMetrics( + stock_id=order.stock_id, + datetime=datetime, + direction=order.direction, + market_volume=market_vol.sum(), + market_price=market_price.mean() if len(market_price) > 0 else np.nan, + amount=amount, + inner_amount=exec_sum, + deal_amount=exec_sum, # in this simulator, there's no other restrictions + trade_price=exec_avg_price, + trade_value=float(np.sum(market_price * exec_vol)), + position=self.position - exec_sum, + ffr=float(exec_sum / order.amount), + pa=price_advantage(exec_avg_price, self.twap_price, order.direction), + ) + + @property + def saoe_state(self) -> SAOEState: + return SAOEState( + order=self.order, + cur_time=self.cur_time, + position=self.position, + history_exec=self.history_exec, + history_steps=self.history_steps, + metrics=self.metrics, + backtest_data=self.backtest_data, + ticks_per_step=self.ticks_per_step, + ticks_index=self.backtest_data.ticks_index, + ticks_for_order=self.backtest_data.ticks_for_order, + ) + + +class SAOEMetrics(TypedDict): + """Metrics for SAOE accumulated for a "period". + It could be accumulated for a day, or a period of time (e.g., 30min), or calculated separately for every minute. + + Warnings + -------- + The type hints are for single elements. In lots of times, they can be vectorized. + For example, ``market_volume`` could be a list of float (or ndarray) rather tahn a single float. + """ + + stock_id: str + """Stock ID of this record.""" + datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this + """Datetime of this record (this is index in the dataframe).""" + direction: int + """Direction of the order. 0 for sell, 1 for buy.""" + + # Market information. + market_volume: np.ndarray | float + """(total) market volume traded in the period.""" + market_price: np.ndarray | float + """Deal price. If it's a period of time, this is the average market deal price.""" + + # Strategy records. + + amount: np.ndarray | float + """Total amount (volume) strategy intends to trade.""" + inner_amount: np.ndarray | float + """Total amount that the lower-level strategy intends to trade + (might be larger than amount, e.g., to ensure ffr).""" + + deal_amount: np.ndarray | float + """Amount that successfully takes effect (must be less than inner_amount).""" + trade_price: np.ndarray | float + """The average deal price for this strategy.""" + trade_value: np.ndarray | float + """Total worth of trading. In the simple simulation, trade_value = deal_amount * price.""" + position: np.ndarray | float + """Position left after this "period".""" + + # Accumulated metrics + + ffr: np.ndarray | float + """Completed how much percent of the daily order.""" + + pa: np.ndarray | float + """Price advantage compared to baseline (i.e., trade with baseline market price). + The baseline is trade price when using TWAP strategy to execute this order. + Please note that there could be data leak here). + Unit is BP (basis point, 1/10000).""" + + +class SAOEState(NamedTuple): + """Data structure holding a state for SAOE simulator.""" + + order: Order + """The order we are dealing with.""" + cur_time: pd.Timestamp + """Current time, e.g., 9:30.""" + position: float + """Current remaining volume to execute.""" + history_exec: pd.DataFrame + """See :attr:`SingleAssetOrderExecution.history_exec`.""" + history_steps: pd.DataFrame + """See :attr:`SingleAssetOrderExecution.history_steps`.""" + + metrics: Optional[SAOEMetrics] + """Daily metric, only available when the trading is in "done" state.""" + + backtest_data: BaseIntradayBacktestData + """Backtest data is included in the state. + Actually, only the time index of this data is needed, at this moment. + I include the full data so that algorithms (e.g., VWAP) that relies on the raw data can be implemented. + Interpreter can use this as they wish, but they should be careful not to leak future data. + """ + + ticks_per_step: int + """How many ticks for each step.""" + ticks_index: pd.DatetimeIndex + """Trading ticks in all day, NOT sliced by order (defined in data). e.g., [9:30, 9:31, ..., 14:59].""" + ticks_for_order: pd.DatetimeIndex + """Trading ticks sliced by order, e.g., [9:45, 9:46, ..., 14:44].""" diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py new file mode 100644 index 0000000000..4a85bc76ed --- /dev/null +++ b/qlib/rl/order_execution/strategy.py @@ -0,0 +1,148 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import collections +from types import GeneratorType +from typing import Any, Optional, Union, cast, Dict, Generator + +import pandas as pd + +from qlib.backtest import CommonInfrastructure, Order +from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWO, TradeRange +from qlib.backtest.utils import LevelInfrastructure +from qlib.constant import ONE_MIN +from qlib.rl.data.exchange_wrapper import load_qlib_backtest_data +from qlib.rl.order_execution.state import SAOEStateAdapter, SAOEState +from qlib.strategy.base import RLStrategy + + +class SAOEStrategy(RLStrategy): + """RL-based strategies that use SAOEState as state.""" + + def __init__( + self, + policy: object, # TODO: add accurate typehint later. + outer_trade_decision: BaseTradeDecision = None, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, + **kwargs: Any, + ) -> None: + super(SAOEStrategy, self).__init__( + policy=policy, + outer_trade_decision=outer_trade_decision, + level_infra=level_infra, + common_infra=common_infra, + **kwargs, + ) + + self.adapter_dict: Dict[tuple, SAOEStateAdapter] = {} + self._last_step_range = (0, 0) + + def _create_qlib_backtest_adapter(self, order: Order, trade_range: TradeRange) -> SAOEStateAdapter: + backtest_data = load_qlib_backtest_data(order, self.trade_exchange, trade_range) + + return SAOEStateAdapter( + order=order, + executor=self.executor, + exchange=self.trade_exchange, + ticks_per_step=int(pd.Timedelta(self.trade_calendar.get_freq()) / ONE_MIN), + backtest_data=backtest_data, + ) + + def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None: + super(SAOEStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) + + self.adapter_dict = {} + self._last_step_range = (0, 0) + + if outer_trade_decision is not None and not outer_trade_decision.empty(): + trade_range = outer_trade_decision.trade_range + assert trade_range is not None + + self.adapter_dict = {} + for decision in outer_trade_decision.get_decision(): + order = cast(Order, decision) + self.adapter_dict[order.key_by_day] = self._create_qlib_backtest_adapter(order, trade_range) + + def get_saoe_state_by_order(self, order: Order) -> SAOEState: + return self.adapter_dict[order.key_by_day].saoe_state + + def post_upper_level_exe_step(self) -> None: + for adapter in self.adapter_dict.values(): + adapter.generate_metrics_after_done() + + def post_exe_step(self, execute_result: Optional[list]) -> None: + last_step_length = self._last_step_range[1] - self._last_step_range[0] + if last_step_length <= 0: + assert not execute_result + return + + results = collections.defaultdict(list) + if execute_result is not None: + for e in execute_result: + results[e[0].key_by_day].append(e) + + for key, adapter in self.adapter_dict.items(): + adapter.update(results[key], self._last_step_range) + + def generate_trade_decision( + self, + execute_result: list = None, + ) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]: + """ + For SAOEStrategy, we need to update the `self._last_step_range` every time a decision is generated. + This operation should be invisible to developers, so we implement it in `generate_trade_decision()` + The concrete logic to generate decisions should be implemented in `_generate_trade_decision()`. + In other words, all subclass of `SAOEStrategy` should overwrite `_generate_trade_decision()` instead of + `generate_trade_decision()`. + """ + self._last_step_range = self.get_data_cal_avail_range(rtype="step") + + decision = self._generate_trade_decision(execute_result) + if isinstance(decision, GeneratorType): + decision = yield from decision + + return decision + + def _generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]: + raise NotImplementedError + + +class ProxySAOEStrategy(SAOEStrategy): + """Proxy strategy that uses SAOEState. It is called a 'proxy' strategy because it does not make any decisions + by itself. Instead, when the strategy is required to generate a decision, it will yield the environment's + information and let the outside agents to make the decision. Please refer to `_generate_trade_decision` for + more details. + """ + + def __init__( + self, + outer_trade_decision: BaseTradeDecision = None, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, + **kwargs: Any, + ) -> None: + super().__init__(None, outer_trade_decision, level_infra, common_infra, **kwargs) + + def _generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]: + # Once the following line is executed, this ProxySAOEStrategy (self) will be yielded to the outside + # of the entire executor, and the execution will be suspended. When the execution is resumed by `send()`, + # the item will be captured by `exec_vol`. The outside policy could communicate with the inner + # level strategy through this way. + exec_vol = yield self + + oh = self.trade_exchange.get_order_helper() + order = oh.create(self._order.stock_id, exec_vol, self._order.direction) + + return TradeDecisionWO([order], self) + + def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None: + super().reset(outer_trade_decision=outer_trade_decision, **kwargs) + + assert isinstance(outer_trade_decision, TradeDecisionWO) + if outer_trade_decision is not None: + order_list = outer_trade_decision.order_list + assert len(order_list) == 1 + self._order = order_list[0] diff --git a/qlib/rl/order_execution/utils.py b/qlib/rl/order_execution/utils.py index e2d0de9812..43517fe744 100644 --- a/qlib/rl/order_execution/utils.py +++ b/qlib/rl/order_execution/utils.py @@ -3,52 +3,14 @@ from __future__ import annotations -from typing import Any, List, Tuple, cast +from typing import Any, cast import numpy as np import pandas as pd -from qlib.backtest import CommonInfrastructure, get_exchange -from qlib.backtest.account import Account from qlib.backtest.decision import OrderDir -from qlib.backtest.executor import BaseExecutor -from qlib.rl.from_neutrader.config import ExchangeConfig -from qlib.rl.order_execution.simulator_simple import ONE_SEC, _float_or_ndarray -from qlib.utils.time import Freq - - -def get_common_infra( - config: ExchangeConfig, - trade_date: pd.Timestamp, - codes: List[str], - cash_limit: float = None, -) -> CommonInfrastructure: - # need to specify a range here for acceleration - if cash_limit is None: - trade_account = Account(init_cash=int(1e12), benchmark_config={}, pos_type="InfPosition") - else: - trade_account = Account( - init_cash=cash_limit, - benchmark_config={}, - pos_type="Position", - position_dict={code: {"amount": 1e12, "price": 1.0} for code in codes}, - ) - - exchange = get_exchange( - codes=codes, - freq="1min", - limit_threshold=config.limit_threshold, - deal_price=config.deal_price, - open_cost=config.open_cost, - close_cost=config.close_cost, - min_cost=config.min_cost if config.trade_unit is not None else 0, - start_time=trade_date, - end_time=trade_date + pd.DateOffset(1), - trade_unit=config.trade_unit, - volume_threshold=config.volume_threshold, - ) - - return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange) +from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor +from qlib.constant import EPS_T, float_or_ndarray def get_ticks_slice( @@ -58,7 +20,7 @@ def get_ticks_slice( include_end: bool = False, ) -> pd.DatetimeIndex: if not include_end: - end = end - ONE_SEC + end = end - EPS_T return ticks_index[ticks_index.slice_indexer(start, end)] @@ -72,10 +34,10 @@ def dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame: def price_advantage( - exec_price: _float_or_ndarray, + exec_price: float_or_ndarray, baseline_price: float, direction: OrderDir | int, -) -> _float_or_ndarray: +) -> float_or_ndarray: if baseline_price == 0: # something is wrong with data. Should be nan here if isinstance(exec_price, float): return 0.0 @@ -91,21 +53,11 @@ def price_advantage( if res_wo_nan.size == 1: return res_wo_nan.item() else: - return cast(_float_or_ndarray, res_wo_nan) - - -def get_portfolio_and_indicator(executor: BaseExecutor) -> Tuple[dict, dict]: - all_executors = executor.get_all_executors() - all_portfolio_metrics = { - "{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics() - for _executor in all_executors - if _executor.trade_account.is_port_metr_enabled() - } + return cast(float_or_ndarray, res_wo_nan) - all_indicators = {} - for _executor in all_executors: - key = "{}{}".format(*Freq.parse(_executor.time_per_step)) - all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe() - all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator() - return all_portfolio_metrics, all_indicators +def get_simulator_executor(executor: BaseExecutor) -> SimulatorExecutor: + while isinstance(executor, NestedExecutor): + executor = executor.inner_executor + assert isinstance(executor, SimulatorExecutor) + return executor diff --git a/qlib/rl/from_neutrader/__init__.py b/qlib/rl/strategy/__init__.py similarity index 52% rename from qlib/rl/from_neutrader/__init__.py rename to qlib/rl/strategy/__init__.py index 765bdee0c1..59e481eb93 100644 --- a/qlib/rl/from_neutrader/__init__.py +++ b/qlib/rl/strategy/__init__.py @@ -1,4 +1,2 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - -# TODO: find a better way to organize contents under this module. diff --git a/qlib/rl/strategy/single_order.py b/qlib/rl/strategy/single_order.py new file mode 100644 index 0000000000..9d8e396ce0 --- /dev/null +++ b/qlib/rl/strategy/single_order.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from qlib.backtest import Order +from qlib.backtest.decision import OrderHelper, TradeDecisionWO, TradeRange +from qlib.strategy.base import BaseStrategy + + +class SingleOrderStrategy(BaseStrategy): + """Strategy used to generate a trade decision with exactly one order.""" + + def __init__( + self, + order: Order, + trade_range: TradeRange = None, + ) -> None: + super().__init__() + + self._order = order + self._trade_range = trade_range + + def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO: + oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper() + order_list = [ + oh.create( + code=self._order.stock_id, + amount=self._order.amount, + direction=self._order.direction, + ), + ] + return TradeDecisionWO(order_list, self, self._trade_range) diff --git a/qlib/rl/utils/finite_env.py b/qlib/rl/utils/finite_env.py index 309b34e6dd..87f0900e16 100644 --- a/qlib/rl/utils/finite_env.py +++ b/qlib/rl/utils/finite_env.py @@ -11,13 +11,14 @@ import copy import warnings from contextlib import contextmanager -from typing import Any, Callable, cast, Dict, Generator, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union, cast import gym import numpy as np from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv from qlib.typehint import Literal + from .log import LogWriter __all__ = [ diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 27df347fc5..532e88452e 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: from qlib.backtest.exchange import Exchange from qlib.backtest.position import BasePosition + from qlib.backtest.executor import BaseExecutor from typing import Tuple @@ -55,6 +56,10 @@ def __init__( self._reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision) self._trade_exchange = trade_exchange + @property + def executor(self) -> BaseExecutor: + return self.level_infra.get("executor") + @property def trade_calendar(self) -> TradeCalendarManager: return self.level_infra.get("trade_calendar") @@ -85,7 +90,7 @@ def reset( level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, outer_trade_decision: BaseTradeDecision = None, - **kwargs, # TODO: remove this? + **kwargs, ) -> None: """ - reset `level_infra`, used to reset trade calendar, .etc @@ -136,6 +141,41 @@ def generate_trade_decision( """ raise NotImplementedError("generate_trade_decision is not implemented!") + # helper methods: not necessary but for convenience + def get_data_cal_avail_range(self, rtype: str = "full") -> Tuple[int, int]: + """ + return data calendar's available decision range for `self` strategy + the range consider following factors + - data calendar in the charge of `self` strategy + - trading range limitation from the decision of outer strategy + + + related methods + - TradeCalendarManager.get_data_cal_range + - BaseTradeDecision.get_data_cal_range_limit + + Parameters + ---------- + rtype: str + - "full": return the available data index range of the strategy from `start_time` to `end_time` + - "step": return the available data index range of the strategy of current step + + Returns + ------- + Tuple[int, int]: + the available range both sides are closed + """ + cal_range = self.trade_calendar.get_data_cal_range(rtype=rtype) + if self.outer_trade_decision is None: + raise ValueError(f"There is not limitation for strategy {self}") + range_limit = self.outer_trade_decision.get_data_cal_range_limit(rtype=rtype) + return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1]) + + """ + The following methods are used to do cross-level communications in nested execution. + You do not need to care about them if you are implementing a single-level execution. + """ + @staticmethod def update_trade_decision( trade_decision: BaseTradeDecision, @@ -158,7 +198,6 @@ def update_trade_decision( # default to return None, which indicates that the trade decision is not changed return None - # FIXME: do not define this method as an abstract one since it is never implemented def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision: """ A method for updating the outer_trade_decision. @@ -175,39 +214,15 @@ def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> """ # default to reset the decision directly # NOTE: normally, user should do something to the strategy due to the change of outer decision - raise NotImplementedError(f"Please implement the `alter_outer_trade_decision` method") + return outer_trade_decision - # helper methods: not necessary but for convenience - def get_data_cal_avail_range(self, rtype: str = "full") -> Tuple[int, int]: + def post_upper_level_exe_step(self) -> None: """ - return data calendar's available decision range for `self` strategy - the range consider following factors - - data calendar in the charge of `self` strategy - - trading range limitation from the decision of outer strategy - - - related methods - - TradeCalendarManager.get_data_cal_range - - BaseTradeDecision.get_data_cal_range_limit - - Parameters - ---------- - rtype: str - - "full": return the available data index range of the strategy from `start_time` to `end_time` - - "step": return the available data index range of the strategy of current step - - Returns - ------- - Tuple[int, int]: - the available range both sides are closed + A hook for doing sth after the upper level executor finished its execution (for example, finalize + the metrics collection). """ - cal_range = self.trade_calendar.get_data_cal_range(rtype=rtype) - if self.outer_trade_decision is None: - raise ValueError(f"There is not limitation for strategy {self}") - range_limit = self.outer_trade_decision.get_data_cal_range_limit(rtype=rtype) - return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1]) - def post_exe_step(self, execute_result: list) -> None: + def post_exe_step(self, execute_result: Optional[list]) -> None: """ A hook for doing sth after the corresponding executor finished its execution. diff --git a/tests/rl/test_qlib_simulator.py b/tests/rl/test_qlib_simulator.py index ca7820645f..b7d548e9ea 100644 --- a/tests/rl/test_qlib_simulator.py +++ b/tests/rl/test_qlib_simulator.py @@ -1,17 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + import sys from pathlib import Path +from typing import Tuple import pandas as pd import pytest - -from qlib.backtest.decision import Order, OrderDir -from qlib.backtest.executor import NestedExecutor, SimulatorExecutor -from qlib.backtest.utils import CommonInfrastructure -from qlib.contrib.strategy import TWAPStrategy +from qlib.backtest.decision import Order, OrderDir, TradeRangeByTime +from qlib.backtest.executor import SimulatorExecutor from qlib.rl.order_execution import CategoricalActionInterpreter -from qlib.rl.order_execution.simulator_qlib import ExchangeConfig, SingleAssetOrderExecutionQlib +from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution TOTAL_POSITION = 2100.0 @@ -32,23 +31,71 @@ def get_order() -> Order: ) -def get_simulator(order: Order) -> SingleAssetOrderExecutionQlib: - def _inner_executor_fn(time_per_step: str, common_infra: CommonInfrastructure) -> NestedExecutor: - return NestedExecutor( - time_per_step=time_per_step, - inner_strategy=TWAPStrategy(), - inner_executor=SimulatorExecutor( - time_per_step="1min", - verbose=False, - trade_type=SimulatorExecutor.TT_SERIAL, - generate_report=False, - common_infra=common_infra, - track_data=True, - ), - common_infra=common_infra, - track_data=True, - ) +def get_configs(order: Order) -> Tuple[dict, dict, dict]: + strategy_config = { + "class": "SingleOrderStrategy", + "module_path": "qlib.rl.strategy.single_order", + "kwargs": { + "order": order, + "trade_range": TradeRangeByTime(order.start_time.time(), order.end_time.time()), + }, + } + + executor_config = { + "class": "NestedExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "time_per_step": "1day", + "inner_strategy": {"class": "ProxySAOEStrategy", "module_path": "qlib.rl.order_execution.strategy"}, + "track_data": True, + "inner_executor": { + "class": "NestedExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "time_per_step": "30min", + "inner_strategy": { + "class": "TWAPStrategy", + "module_path": "qlib.contrib.strategy.rule_strategy", + }, + "inner_executor": { + "class": "SimulatorExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "time_per_step": "1min", + "verbose": False, + "trade_type": SimulatorExecutor.TT_SERIAL, + "generate_report": False, + "track_data": True, + }, + }, + "track_data": True, + }, + }, + "start_time": pd.Timestamp(order.start_time.date()), + "end_time": pd.Timestamp(order.start_time.date()), + }, + } + + exchange_config = { + "freq": "1min", + "codes": [order.stock_id], + "limit_threshold": ("$ask == 0", "$bid == 0"), + "deal_price": ("If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"), + "volume_threshold": { + "all": ("cum", "0.2 * DayCumsum($volume, '9:30', '14:29')"), + "buy": ("current", "$askV1"), + "sell": ("current", "$bidV1"), + }, + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5.0, + "trade_unit": None, + } + + return strategy_config, executor_config, exchange_config + +def get_simulator(order: Order) -> SingleAssetOrderExecution: DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "qlib_simulator" # fmt: off @@ -67,27 +114,13 @@ def _inner_executor_fn(time_per_step: str, common_infra: CommonInfrastructure) - } # fmt: on - exchange_config = ExchangeConfig( - limit_threshold=("$ask == 0", "$bid == 0"), - deal_price=("If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"), - volume_threshold={ - "all": ("cum", "0.2 * DayCumsum($volume, '9:30', '14:29')"), - "buy": ("current", "$askV1"), - "sell": ("current", "$bidV1"), - }, - open_cost=0.0005, - close_cost=0.0015, - min_cost=5.0, - trade_unit=None, - cash_limit=None, - generate_report=False, - ) + strategy_config, executor_config, exchange_config = get_configs(order) - return SingleAssetOrderExecutionQlib( + return SingleAssetOrderExecution( order=order, - time_per_step="30min", qlib_config=qlib_config, - inner_executor_fn=_inner_executor_fn, + strategy_config=strategy_config, + executor_config=executor_config, exchange_config=exchange_config, ) @@ -115,12 +148,12 @@ def test_simulator_first_step(): assert is_close(state.history_exec["trade_price"].iloc[0], 149.566483) assert is_close(state.history_exec["trade_value"].iloc[0], 1495.664825) assert is_close(state.history_exec["position"].iloc[0], TOTAL_POSITION - AMOUNT / 30) - # assert state.history_exec["ffr"].iloc[0] == 1 / 60 # FIXME + assert is_close(state.history_exec["ffr"].iloc[0], AMOUNT / TOTAL_POSITION / 30) assert is_close(state.history_steps["market_volume"].iloc[0], 1254848.5756835938) assert state.history_steps["amount"].iloc[0] == AMOUNT assert state.history_steps["deal_amount"].iloc[0] == AMOUNT - assert state.history_steps["ffr"].iloc[0] == 1.0 + assert state.history_steps["ffr"].iloc[0] == AMOUNT / TOTAL_POSITION assert is_close( state.history_steps["pa"].iloc[0] * (1.0 if order.direction == OrderDir.SELL else -1.0), (state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000, @@ -169,9 +202,3 @@ def test_interpreter() -> None: position_history.append(state.position) assert position_history[-1] == max(TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0) - - -if __name__ == "__main__": - test_simulator_first_step() - test_simulator_stop_twap() - test_interpreter()