From 08f725cf406277785f7db0bcdd5053cf21cad05a Mon Sep 17 00:00:00 2001 From: Default Date: Mon, 19 Sep 2022 15:01:06 +0800 Subject: [PATCH 1/8] RL backtest with simulator --- qlib/backtest/decision.py | 13 ++ qlib/rl/contrib/backtest.py | 181 +++++++++++++++++----- qlib/rl/order_execution/simulator_qlib.py | 65 +++++--- qlib/rl/order_execution/strategy.py | 28 +++- tests/rl/test_qlib_simulator.py | 16 +- 5 files changed, 232 insertions(+), 71 deletions(-) diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index 042b73fea8..115823eadb 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -576,3 +576,16 @@ def __repr__(self) -> str: f"trade_range: {self.trade_range}; " f"order_list[{len(self.order_list)}]" ) + + +class TradeDecisionWithDetails(TradeDecisionWO): + def __init__( + self, + order_list: List[Order], + strategy: BaseStrategy, + trade_range: Optional[Tuple[int, int]] = None, + details: Optional[Any] = None, + ) -> None: + super().__init__(order_list, strategy, trade_range) + + self.details = details diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index 709c050dfb..bffa850b17 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -2,11 +2,12 @@ # Licensed under the MIT License. from __future__ import annotations +import argparse import copy import pickle -import sys +from collections import defaultdict from pathlib import Path -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -14,12 +15,13 @@ from joblib import Parallel, delayed from qlib.backtest import collect_data_loop, get_strategy_executor -from qlib.backtest.decision import TradeRangeByTime +from qlib.backtest.decision import Order, OrderDir, TradeRangeByTime from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor from qlib.backtest.high_performance_ds import BaseOrderIndicator from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile from qlib.rl.contrib.utils import read_order_file from qlib.rl.data.integration import init_qlib +from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper @@ -90,26 +92,109 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]: return records -def _generate_report(decisions: list, report_dict: dict) -> dict: +def _generate_report(decisions: list, report_dicts: List[dict]) -> dict: + indicator_dict = defaultdict(list) + indicator_his = defaultdict(list) + for report_dict in report_dicts: + for key, value in report_dict["indicator"].items(): + if key.endswith("_obj"): + indicator_his[key].append(value.order_indicator_his) + else: + indicator_dict[key].append(value) + report = {} decision_details = pd.concat([d.details for d in decisions if hasattr(d, "details")]) - for key in ["1minute", "5minute", "30minute", "1day"]: - if key not in report_dict["indicator"]: + for key in ["1min", "5min", "30min", "1day"]: + if key not in indicator_dict: continue - report[key] = report_dict["indicator"][key] - report[key + "_obj"] = _convert_indicator_to_dataframe( - report_dict["indicator"][key + "_obj"].order_indicator_his - ) - cur_details = decision_details[decision_details.freq == key.rstrip("ute")].set_index(["instrument", "datetime"]) + + report[key] = pd.concat(indicator_dict[key]) + report[key + "_obj"] = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key + 
"_obj"]]) + + cur_details = decision_details[decision_details.freq == key].set_index(["instrument", "datetime"]) if len(cur_details) > 0: cur_details.pop("freq") report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer") - if "1minute" in report_dict["report"]: - report["simulator"] = report_dict["report"]["1minute"][0] + return report -def single( +def single_with_simulator( + backtest_config: dict, + orders: pd.DataFrame, + split: str = "stock", + cash_limit: float = None, + generate_report: bool = False, +) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]: + if split == "stock": + stock_id = orders.iloc[0].instrument + init_qlib(backtest_config["qlib"], part=stock_id) + else: + day = orders.iloc[0].datetime + init_qlib(backtest_config["qlib"], part=day) + + stocks = orders.instrument.unique().tolist() + + reports = [] + decisions = [] + for _, row in orders.iterrows(): + date = pd.Timestamp(row["datetime"]) + start_time = pd.Timestamp(backtest_config["start_time"]).replace(year=date.year, month=date.month, day=date.day) + end_time = pd.Timestamp(backtest_config["end_time"]).replace(year=date.year, month=date.month, day=date.day) + order = Order( + stock_id=row["instrument"], + amount=row["amount"], + direction=OrderDir(row["direction"]), + start_time=start_time, + end_time=end_time, + ) + + executor_config = _get_multi_level_executor_config( + strategy_config=backtest_config["strategies"], + cash_limit=cash_limit, + generate_report=generate_report, + ) + + exchange_config = copy.deepcopy(backtest_config["exchange"]) + exchange_config.update( + { + "codes": stocks, + "freq": "1min", + } + ) + + simulator = SingleAssetOrderExecution( + order=order, + executor_config=executor_config, + exchange_config=exchange_config, + qlib_config=None, + cash_limit=None, + backtest_mode=True, + ) + + reports.append(simulator.report_dict) + decisions += simulator.decisions + + indicator = {k: v for report in reports for k, v in report["indicator"]["1day_obj"].order_indicator_his.items()} + records = _convert_indicator_to_dataframe(indicator) + assert records is None or not np.isnan(records["ffr"]).any() + + if generate_report: + report = _generate_report(decisions, reports) + + if split == "stock": + stock_id = orders.iloc[0].instrument + report = {stock_id: report} + else: + day = orders.iloc[0].datetime + report = {day: report} + + return records, report + else: + return records + + +def single_with_collect_data_loop( backtest_config: dict, orders: pd.DataFrame, split: str = "stock", @@ -127,7 +212,7 @@ def single( trade_end_time = orders["datetime"].max() stocks = orders.instrument.unique().tolist() - top_strategy_config = { + strategy_config = { "class": "FileOrderStrategy", "module_path": "qlib.contrib.strategy.rule_strategy", "kwargs": { @@ -139,14 +224,14 @@ def single( }, } - top_executor_config = _get_multi_level_executor_config( + executor_config = _get_multi_level_executor_config( strategy_config=backtest_config["strategies"], cash_limit=cash_limit, generate_report=generate_report, ) - tmp_backtest_config = copy.deepcopy(backtest_config["exchange"]) - tmp_backtest_config.update( + exchange_config = copy.deepcopy(backtest_config["exchange"]) + exchange_config.update( { "codes": stocks, "freq": "1min", @@ -156,11 +241,11 @@ def single( strategy, executor = get_strategy_executor( start_time=pd.Timestamp(trade_start_time), end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1), - strategy=top_strategy_config, - executor=top_executor_config, + strategy=strategy_config, + 
executor=executor_config, benchmark=None, account=cash_limit if cash_limit is not None else int(1e12), - exchange_kwargs=tmp_backtest_config, + exchange_kwargs=exchange_config, pos_type="Position" if cash_limit is not None else "InfPosition", ) _set_env_for_all_strategy(executor=executor) @@ -172,7 +257,7 @@ def single( assert records is None or not np.isnan(records["ffr"]).any() if generate_report: - report = _generate_report(decisions, report_dict) + report = _generate_report(decisions, [report_dict]) if split == "stock": stock_id = orders.iloc[0].instrument report = {stock_id: report} @@ -184,7 +269,7 @@ def single( return records -def backtest(backtest_config: dict) -> pd.DataFrame: +def backtest(backtest_config: dict, parallel_mode: bool = False, with_simulator: bool = False) -> pd.DataFrame: order_df = read_order_file(backtest_config["order_file"]) cash_limit = backtest_config["exchange"].pop("cash_limit") @@ -193,18 +278,33 @@ def backtest(backtest_config: dict) -> pd.DataFrame: stock_pool = order_df["instrument"].unique().tolist() stock_pool.sort() - mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"} - torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199 - res = Parallel(**mp_config)( - delayed(single)( - backtest_config=backtest_config, - orders=order_df[order_df["instrument"] == stock].copy(), - split="stock", - cash_limit=cash_limit, - generate_report=generate_report, + stock_pool = stock_pool + + single = single_with_simulator if with_simulator else single_with_collect_data_loop + if parallel_mode: + mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"} + torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199 + res = Parallel(**mp_config)( + delayed(single)( + backtest_config=backtest_config, + orders=order_df[order_df["instrument"] == stock].copy(), + split="stock", + cash_limit=cash_limit, + generate_report=generate_report, + ) + for stock in stock_pool ) - for stock in stock_pool - ) + else: + res = [ + single( + backtest_config=backtest_config, + orders=order_df[order_df["instrument"] == stock].copy(), + split="stock", + cash_limit=cash_limit, + generate_report=generate_report, + ) + for stock in stock_pool + ] output_path = Path(backtest_config["output_dir"]) if generate_report: @@ -227,5 +327,14 @@ def backtest(backtest_config: dict) -> pd.DataFrame: warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=RuntimeWarning) - path = sys.argv[1] - backtest(get_backtest_config_fromfile(path)) + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") + parser.add_argument("--parallel", action="store_true", help="Whether to run pipelines in parallel") + parser.add_argument("--use_simulator", action="store_true", help="Whether to use simulator as the backend") + args = parser.parse_args() + + backtest( + backtest_config=get_backtest_config_fromfile(args.config_path), + parallel_mode=args.parallel, + with_simulator=args.use_simulator, + ) diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index 718c2ba572..7fc94a52b8 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -3,17 +3,18 @@ from __future__ import annotations -from typing import Generator, Optional +from typing import Generator, List, Optional import pandas as pd 
-from qlib.backtest import collect_data_loop, get_strategy_executor -from qlib.backtest.decision import Order -from qlib.backtest.executor import NestedExecutor -from qlib.rl.simulator import Simulator +from qlib.backtest import collect_data_loop, get_strategy_executor +from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime +from qlib.backtest.executor import BaseExecutor, NestedExecutor from qlib.rl.data.integration import init_qlib +from qlib.rl.simulator import Simulator from .state import SAOEState, SAOEStateAdapter from .strategy import SAOEStrategy +from ..utils.env_wrapper import CollectDataEnvWrapper class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): @@ -23,30 +24,42 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): ---------- order The seed to start an SAOE simulator is an order. - strategy_config - Strategy configuration executor_config Executor configuration exchange_config Exchange configuration qlib_config Configuration used to initialize Qlib. If it is None, Qlib will not be initialized. + cash_limit: + Cash limit. + backtest_mode + Whether the simulator is under backtest mode. """ def __init__( self, order: Order, - strategy_config: dict, executor_config: dict, exchange_config: dict, qlib_config: dict = None, + cash_limit: Optional[float] = None, + backtest_mode: bool = False, ) -> None: super().__init__(initial=order) assert order.start_time.date() == order.end_time.date(), "Start date and end date must be the same." + strategy_config = { + "class": "SingleOrderStrategy", + "module_path": "qlib.rl.strategy.single_order", + "kwargs": { + "order": order, + "trade_range": TradeRangeByTime(order.start_time.time(), order.end_time.time()), + }, + } + self._collect_data_loop: Optional[Generator] = None - self.reset(order, strategy_config, executor_config, exchange_config, qlib_config) + self.reset(order, strategy_config, executor_config, exchange_config, qlib_config, cash_limit, backtest_mode) def reset( self, @@ -55,6 +68,8 @@ def reset( executor_config: dict, exchange_config: dict, qlib_config: dict = None, + cash_limit: Optional[float] = None, + backtest_mode: bool = False, ) -> None: if qlib_config is not None: init_qlib(qlib_config, part="skip") @@ -65,22 +80,32 @@ def reset( strategy=strategy_config, executor=executor_config, benchmark=order.stock_id, - account=1e12, + account=cash_limit if cash_limit is not None else int(1e12), exchange_kwargs=exchange_config, - pos_type="InfPosition", + pos_type="Position" if cash_limit is not None else "InfPosition", ) assert isinstance(self._executor, NestedExecutor) + self.report_dict: dict = {} + self.decisions: List[BaseTradeDecision] = [] self._collect_data_loop = collect_data_loop( start_time=order.date, end_time=order.date, trade_strategy=strategy, trade_executor=self._executor, + return_value=self.report_dict, ) assert isinstance(self._collect_data_loop, Generator) - self._last_yielded_saoe_strategy = self._iter_strategy(action=None) + if backtest_mode: + executor: BaseExecutor = self._executor + while isinstance(executor, NestedExecutor): + if hasattr(executor.inner_strategy, "set_env"): + executor.inner_strategy.set_env(CollectDataEnvWrapper()) + executor = executor.inner_executor + + self.step(action=None) self._order = order @@ -91,17 +116,19 @@ def _get_adapter(self) -> SAOEStateAdapter: def twap_price(self) -> float: return self._get_adapter().twap_price - def _iter_strategy(self, action: float = None) -> SAOEStrategy: + def _iter_strategy(self, action: 
Optional[float] = None) -> SAOEStrategy: """Iterate the _collect_data_loop until we get the next yield SAOEStrategy.""" assert self._collect_data_loop is not None - strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) - while not isinstance(strategy, SAOEStrategy): - strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) - assert isinstance(strategy, SAOEStrategy) - return strategy + obj = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) + while not isinstance(obj, SAOEStrategy): + if isinstance(obj, BaseTradeDecision): + self.decisions.append(obj) + obj = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) + assert isinstance(obj, SAOEStrategy) + return obj - def step(self, action: float) -> None: + def step(self, action: Optional[float]) -> None: """Execute one step or SAOE. Parameters diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py index ecc879bf51..663b8e8ff4 100644 --- a/qlib/rl/order_execution/strategy.py +++ b/qlib/rl/order_execution/strategy.py @@ -5,15 +5,16 @@ import collections from types import GeneratorType -from typing import Any, cast, Dict, Generator, Optional, Union +from typing import Any, cast, Dict, Generator, List, Optional, Union +import numpy as np import pandas as pd import torch from tianshou.data import Batch from tianshou.policy import BasePolicy from qlib.backtest import CommonInfrastructure, Order -from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWO, TradeRange +from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWithDetails, TradeDecisionWO, TradeRange from qlib.backtest.utils import LevelInfrastructure from qlib.constant import ONE_MIN from qlib.rl.data.native import load_backtest_data @@ -235,6 +236,23 @@ def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) - if self._backtest: self._env.reset() + def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame: + assert hasattr(self.outer_trade_decision, "order_list") + + trade_details = [] + for a, v, o in zip(act, exec_vols, getattr(self.outer_trade_decision, "order_list")): + trade_details.append( + { + "instrument": o.stock_id, + "datetime": self.trade_calendar.get_step_time()[0], + "freq": self.trade_calendar.get_freq(), + "rl_exec_vol": v, + } + ) + if a is not None: + trade_details[-1]["rl_action"] = a + return pd.DataFrame.from_records(trade_details) + def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision: states = [] obs_batch = [] @@ -261,4 +279,8 @@ def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDeci order = cast(Order, decision) order_list.append(oh.create(order.stock_id, exec_vol, order.direction)) - return TradeDecisionWO(order_list=order_list, strategy=self) + return TradeDecisionWithDetails( + order_list=order_list, + strategy=self, + details=self._generate_trade_details(act, exec_vols), + ) diff --git a/tests/rl/test_qlib_simulator.py b/tests/rl/test_qlib_simulator.py index 14bf8b5a11..92ad9c0583 100644 --- a/tests/rl/test_qlib_simulator.py +++ b/tests/rl/test_qlib_simulator.py @@ -32,16 +32,7 @@ def get_order() -> Order: ) -def get_configs(order: Order) -> Tuple[dict, dict, dict]: - strategy_config = { - "class": "SingleOrderStrategy", - "module_path": "qlib.rl.strategy.single_order", - "kwargs": { - "order": order, - "trade_range": 
TradeRangeByTime(order.start_time.time(), order.end_time.time()), - }, - } - +def get_configs(order: Order) -> Tuple[dict, dict]: executor_config = { "class": "NestedExecutor", "module_path": "qlib.backtest.executor", @@ -93,7 +84,7 @@ def get_configs(order: Order) -> Tuple[dict, dict, dict]: "trade_unit": None, } - return strategy_config, executor_config, exchange_config + return executor_config, exchange_config def get_simulator(order: Order) -> SingleAssetOrderExecution: @@ -115,12 +106,11 @@ def get_simulator(order: Order) -> SingleAssetOrderExecution: } # fmt: on - strategy_config, executor_config, exchange_config = get_configs(order) + executor_config, exchange_config = get_configs(order) return SingleAssetOrderExecution( order=order, qlib_config=qlib_config, - strategy_config=strategy_config, executor_config=executor_config, exchange_config=exchange_config, ) From 19073729a3d3942a408f914bf77b91411d572279 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Tue, 27 Sep 2022 14:12:44 +0800 Subject: [PATCH 2/8] Minor modification in init_qlib --- qlib/rl/data/integration.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/qlib/rl/data/integration.py b/qlib/rl/data/integration.py index d32ce49c82..af5025c843 100644 --- a/qlib/rl/data/integration.py +++ b/qlib/rl/data/integration.py @@ -81,10 +81,12 @@ def init_qlib(qlib_config: dict, part: str = None) -> None: def _convert_to_path(path: str | Path) -> Path: return path if isinstance(path, Path) else Path(path) - provider_uri_map = { - "day": _convert_to_path(qlib_config["provider_uri_day"]).as_posix(), - "1min": _convert_to_path(qlib_config["provider_uri_1min"]).as_posix(), - } + provider_uri_map = {} + if "provider_uri_day" in qlib_config: + provider_uri_map["day"] = _convert_to_path(qlib_config["provider_uri_day"]).as_posix() + if "provider_uri_1min" in qlib_config: + provider_uri_map["1min"] = _convert_to_path(qlib_config["provider_uri_1min"]).as_posix() + qlib.init( region=REG_CN, auto_mount=False, From 22cb8ee2fa20e72b8e1fc8e52ed9ebfbb13583db Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Tue, 27 Sep 2022 14:13:19 +0800 Subject: [PATCH 3/8] Cherry pick PR 1302 --- qlib/contrib/data/highfreq_handler.py | 126 ++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/qlib/contrib/data/highfreq_handler.py b/qlib/contrib/data/highfreq_handler.py index 4898725da9..a27f0ce370 100644 --- a/qlib/contrib/data/highfreq_handler.py +++ b/qlib/contrib/data/highfreq_handler.py @@ -110,6 +110,92 @@ def get_normalized_price_feature(price_field, shift=0): return fields, names +class HighFreqGeneralHandler(HighFreqHandler): + def __init__( + self, + instruments="csi300", + start_time=None, + end_time=None, + infer_processors=[], + learn_processors=[], + fit_start_time=None, + fit_end_time=None, + drop_raw=True, + day_length=240, + ): + self.day_length = day_length + super().__init__( + instruments=instruments, + start_time=start_time, + end_time=end_time, + infer_processors=infer_processors, + learn_processors=learn_processors, + fit_start_time=fit_start_time, + fit_end_time=fit_end_time, + drop_raw=drop_raw, + ) + + def get_feature_config(self): + fields = [] + names = [] + + template_if = "If(IsNull({1}), {0}, {1})" + template_paused = f"Cut({{0}}, {self.day_length * 2}, None)" + + def get_normalized_price_feature(price_field, shift=0): + # norm with the close price of 237th minute of yesterday. 
+ if shift == 0: + template_norm = f"{{0}}/DayLast(Ref({{1}}, {self.day_length * 2}))" + else: + template_norm = f"Ref({{0}}, " + str(shift) + f")/DayLast(Ref({{1}}, {self.day_length}))" + + template_fillnan = "FFillNan({0})" + # calculate -> ffill -> remove paused + feature_ops = template_paused.format( + template_fillnan.format( + template_norm.format(template_if.format("$close", price_field), template_fillnan.format("$close")) + ) + ) + return feature_ops + + fields += [get_normalized_price_feature("$open", 0)] + fields += [get_normalized_price_feature("$high", 0)] + fields += [get_normalized_price_feature("$low", 0)] + fields += [get_normalized_price_feature("$close", 0)] + fields += [get_normalized_price_feature("$vwap", 0)] + names += ["$open", "$high", "$low", "$close", "$vwap"] + + fields += [get_normalized_price_feature("$open", self.day_length)] + fields += [get_normalized_price_feature("$high", self.day_length)] + fields += [get_normalized_price_feature("$low", self.day_length)] + fields += [get_normalized_price_feature("$close", self.day_length)] + fields += [get_normalized_price_feature("$vwap", self.day_length)] + names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"] + + # calculate and fill nan with 0 + fields += [ + template_paused.format( + "If(IsNull({0}), 0, {0})".format( + f"{{0}}/Ref(DayLast(Mean({{0}}, {self.day_length * 30})), {self.day_length})".format("$volume") + ) + ) + ] + names += ["$volume"] + + fields += [ + template_paused.format( + "If(IsNull({0}), 0, {0})".format( + f"Ref({{0}}, {self.day_length})/Ref(DayLast(Mean({{0}}, {self.day_length * 30})), {self.day_length})".format( + "$volume" + ) + ) + ) + ] + names += ["$volume_1"] + + return fields, names + + class HighFreqBacktestHandler(DataHandler): def __init__( self, @@ -163,6 +249,45 @@ def get_feature_config(self): return fields, names +class HighFreqGeneralBacktestHandler(HighFreqBacktestHandler): + def __init__( + self, + instruments="csi300", + start_time=None, + end_time=None, + day_length=240, + ): + self.day_length = day_length + super().__init__( + instruments=instruments, + start_time=start_time, + end_time=end_time, + ) + + def get_feature_config(self): + fields = [] + names = [] + + template_paused = f"Cut({{0}}, {self.day_length * 2}, None)" + # template_paused = "{0}" + template_fillnan = "FFillNan({0})" + template_if = "If(IsNull({1}), {0}, {1})" + fields += [ + template_paused.format(template_fillnan.format("$close")), + ] + names += ["$close0"] + + fields += [ + template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")), + ] + names += ["$vwap0"] + + fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))] + names += ["$volume0"] + + return fields, names + + class HighFreqOrderHandler(DataHandlerLP): def __init__( self, @@ -407,3 +532,4 @@ def get_feature_config(self): names += ["$lowmarket0"] return fields, names + \ No newline at end of file From 321691d2cdf7f1b20cbfb9ec2276563aee6fb8b1 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Thu, 6 Oct 2022 12:11:36 +0800 Subject: [PATCH 4/8] Resolve PR comments --- qlib/contrib/data/highfreq_handler.py | 1 - qlib/rl/contrib/backtest.py | 118 +++++++++++++++------- qlib/rl/order_execution/simulator_qlib.py | 2 + 3 files changed, 84 insertions(+), 37 deletions(-) diff --git a/qlib/contrib/data/highfreq_handler.py b/qlib/contrib/data/highfreq_handler.py index a27f0ce370..1895830c2d 100644 --- a/qlib/contrib/data/highfreq_handler.py +++ b/qlib/contrib/data/highfreq_handler.py @@ -532,4 
+532,3 @@ def get_feature_config(self): names += ["$lowmarket0"] return fields, names - \ No newline at end of file diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index bffa850b17..b46e5578ab 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -7,7 +7,7 @@ import pickle from collections import defaultdict from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import pandas as pd @@ -15,7 +15,7 @@ from joblib import Parallel, delayed from qlib.backtest import collect_data_loop, get_strategy_executor -from qlib.backtest.decision import Order, OrderDir, TradeRangeByTime +from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor from qlib.backtest.high_performance_ds import BaseOrderIndicator from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile @@ -92,18 +92,30 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]: return records -def _generate_report(decisions: list, report_dicts: List[dict]) -> dict: +def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List[dict]) -> dict: + """Generate backtest reports + + Parameters + ---------- + decisions: + List of trade decisions. + report_indicators + List of indicator reports. + Returns + ------- + + """ indicator_dict = defaultdict(list) indicator_his = defaultdict(list) - for report_dict in report_dicts: - for key, value in report_dict["indicator"].items(): + for report_indicator in report_indicators: + for key, value in report_indicator.items(): if key.endswith("_obj"): indicator_his[key].append(value.order_indicator_his) else: indicator_dict[key].append(value) report = {} - decision_details = pd.concat([d.details for d in decisions if hasattr(d, "details")]) + decision_details = pd.concat([getattr(d, "details") for d in decisions if hasattr(d, "details")]) for key in ["1min", "5min", "30min", "1day"]: if key not in indicator_dict: continue @@ -122,10 +134,34 @@ def _generate_report(decisions: list, report_dicts: List[dict]) -> dict: def single_with_simulator( backtest_config: dict, orders: pd.DataFrame, - split: str = "stock", + split: Literal["stock", "day"] = "stock", cash_limit: float = None, generate_report: bool = False, ) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]: + """Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day. + A new simulator will be created and used for every single-day order. + + Parameters + ---------- + backtest_config: + Backtest config + orders: + Orders to be executed. Example format: + datetime instrument amount direction + 0 2020-06-01 INST 600.0 0 + 1 2020-06-02 INST 700.0 1 + ... + split + Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date. + cash_limit + Limitation of cash. + generate_report + Whether to generate reports. + + Returns + ------- + If generate_report is True, return execution records and the generated report. Otherwise, return only records. 
+ """ if split == "stock": stock_id = orders.iloc[0].instrument init_qlib(backtest_config["qlib"], part=stock_id) @@ -180,7 +216,7 @@ def single_with_simulator( assert records is None or not np.isnan(records["ffr"]).any() if generate_report: - report = _generate_report(decisions, reports) + report = _generate_report(decisions, [report["indicator"] for report in reports]) if split == "stock": stock_id = orders.iloc[0].instrument @@ -197,10 +233,34 @@ def single_with_simulator( def single_with_collect_data_loop( backtest_config: dict, orders: pd.DataFrame, - split: str = "stock", + split: Literal["stock", "day"] = "stock", cash_limit: float = None, generate_report: bool = False, ) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]: + """Run backtest in a single thread with collect_data_loop. + + Parameters + ---------- + backtest_config: + Backtest config + orders: + Orders to be executed. Example format: + datetime instrument amount direction + 0 2020-06-01 INST 600.0 0 + 1 2020-06-02 INST 700.0 1 + ... + split + Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date. + cash_limit + Limitation of cash. + generate_report + Whether to generate reports. + + Returns + ------- + If generate_report is True, return execution records and the generated report. Otherwise, return only records. + """ + if split == "stock": stock_id = orders.iloc[0].instrument init_qlib(backtest_config["qlib"], part=stock_id) @@ -257,7 +317,7 @@ def single_with_collect_data_loop( assert records is None or not np.isnan(records["ffr"]).any() if generate_report: - report = _generate_report(decisions, [report_dict]) + report = _generate_report(decisions, [report_dict["indicator"]]) if split == "stock": stock_id = orders.iloc[0].instrument report = {stock_id: report} @@ -269,7 +329,7 @@ def single_with_collect_data_loop( return records -def backtest(backtest_config: dict, parallel_mode: bool = False, with_simulator: bool = False) -> pd.DataFrame: +def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame: order_df = read_order_file(backtest_config["order_file"]) cash_limit = backtest_config["exchange"].pop("cash_limit") @@ -281,30 +341,18 @@ def backtest(backtest_config: dict, parallel_mode: bool = False, with_simulator: stock_pool = stock_pool single = single_with_simulator if with_simulator else single_with_collect_data_loop - if parallel_mode: - mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"} - torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199 - res = Parallel(**mp_config)( - delayed(single)( - backtest_config=backtest_config, - orders=order_df[order_df["instrument"] == stock].copy(), - split="stock", - cash_limit=cash_limit, - generate_report=generate_report, - ) - for stock in stock_pool + mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"} + torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199 + res = Parallel(**mp_config)( + delayed(single)( + backtest_config=backtest_config, + orders=order_df[order_df["instrument"] == stock].copy(), + split="stock", + cash_limit=cash_limit, + generate_report=generate_report, ) - else: - res = [ - single( - backtest_config=backtest_config, - orders=order_df[order_df["instrument"] == stock].copy(), - split="stock", - cash_limit=cash_limit, - generate_report=generate_report, - ) - for stock in stock_pool - ] + for stock in stock_pool + ) output_path = 
Path(backtest_config["output_dir"]) if generate_report: @@ -329,12 +377,10 @@ def backtest(backtest_config: dict, parallel_mode: bool = False, with_simulator: parser = argparse.ArgumentParser() parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") - parser.add_argument("--parallel", action="store_true", help="Whether to run pipelines in parallel") parser.add_argument("--use_simulator", action="store_true", help="Whether to use simulator as the backend") args = parser.parse_args() backtest( backtest_config=get_backtest_config_fromfile(args.config_path), - parallel_mode=args.parallel, with_simulator=args.use_simulator, ) diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index 7fc94a52b8..693af8004b 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -98,6 +98,8 @@ def reset( ) assert isinstance(self._collect_data_loop, Generator) + # TODO: backtest_mode is not a necessary parameter if we carefully design it. + # TODO: It should disappear with CollectDataEnvWrapper in the future. if backtest_mode: executor: BaseExecutor = self._executor while isinstance(executor, NestedExecutor): From 35a19aa2c7faaaa4a7faa34bbf461bbc4a33bfa5 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Thu, 6 Oct 2022 15:40:33 +0800 Subject: [PATCH 5/8] Fix missing data processing --- qlib/rl/contrib/backtest.py | 6 ++---- qlib/rl/contrib/naive_config_parser.py | 3 ++- qlib/rl/data/native.py | 26 +++++++++++----------- qlib/rl/order_execution/state.py | 30 ++++++++++++++++++++------ 4 files changed, 41 insertions(+), 24 deletions(-) diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index b46e5578ab..cf389694e9 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -43,7 +43,7 @@ def _get_multi_level_executor_config( } freqs = list(strategy_config.keys()) - freqs.sort(key=lambda x: pd.Timedelta(x)) + freqs.sort(key=pd.Timedelta) for freq in freqs: executor_config = { "class": "NestedExecutor", @@ -75,7 +75,7 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]: # HACK: for qlib v0.8 value_dict = value_dict.to_series() try: - value_dict = {k: v for k, v in value_dict.items()} + value_dict = copy.deepcopy(value_dict) if value_dict["ffr"].empty: continue except Exception: @@ -338,8 +338,6 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram stock_pool = order_df["instrument"].unique().tolist() stock_pool.sort() - stock_pool = stock_pool - single = single_with_simulator if with_simulator else single_with_collect_data_loop mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"} torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199 diff --git a/qlib/rl/contrib/naive_config_parser.py b/qlib/rl/contrib/naive_config_parser.py index eaf62636cc..3f3d2eeadc 100644 --- a/qlib/rl/contrib/naive_config_parser.py +++ b/qlib/rl/contrib/naive_config_parser.py @@ -53,7 +53,8 @@ def parse_backtest_config(path: str) -> dict: del sys.modules[tmp_module_name] else: - config = yaml.safe_load(open(tmp_config_file.name)) + with open(tmp_config_file.name) as input_stream: + config = yaml.safe_load(input_stream) if "_base_" in config: base_file_name = config.pop("_base_") diff --git a/qlib/rl/data/native.py b/qlib/rl/data/native.py index eb612cf64e..f18e0f257b 100644 --- a/qlib/rl/data/native.py +++ b/qlib/rl/data/native.py @@ -9,12 +9,11 @@ from 
qlib.backtest import Exchange, Order from qlib.backtest.decision import TradeRange, TradeRangeByTime -from qlib.constant import EPS_T, ONE_DAY from qlib.rl.order_execution.utils import get_ticks_slice -from qlib.utils.index_data import IndexData from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider from .integration import fetch_features +from ...data import D class IntradayBacktestData(BaseIntradayBacktestData): @@ -82,18 +81,16 @@ def load_backtest_data( trade_exchange: Exchange, trade_range: TradeRange, ) -> IntradayBacktestData: - data = cast( - IndexData, - trade_exchange.get_deal_price( - stock_id=order.stock_id, - start_time=order.date, - end_time=order.date + ONE_DAY - EPS_T, - direction=order.direction, - method=None, - ), + tmp_data = D.features( + trade_exchange.codes, + trade_exchange.all_fields, + trade_exchange.start_time, + trade_exchange.end_time, + freq=trade_exchange.freq, + disk_cache=True, ) - ticks_index = pd.DatetimeIndex(data.index) + ticks_index = pd.DatetimeIndex(tmp_data.reset_index()["datetime"]) if isinstance(trade_range, TradeRangeByTime): ticks_for_order = get_ticks_slice( ticks_index, @@ -122,7 +119,10 @@ def __init__( date: pd.Timestamp, ) -> None: def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame: - return df.reset_index().drop(columns=["instrument"]).set_index(["datetime"]) + df = df.reset_index() + if "instrument" in df.columns: + df = df.drop(columns=["instrument"]) + return df.set_index(["datetime"]) self.today = _drop_stock_id(fetch_features(stock_id, date)) self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True)) diff --git a/qlib/rl/order_execution/state.py b/qlib/rl/order_execution/state.py index a46928ee89..a38bb32620 100644 --- a/qlib/rl/order_execution/state.py +++ b/qlib/rl/order_execution/state.py @@ -4,7 +4,7 @@ from __future__ import annotations import typing -from typing import cast, NamedTuple, Optional, Tuple +from typing import cast, Callable, List, NamedTuple, Optional, Tuple import numpy as np import pandas as pd @@ -13,6 +13,7 @@ from qlib.constant import EPS, ONE_MIN, REG_CN from qlib.rl.order_execution.utils import dataframe_append, price_advantage from qlib.typehint import TypedDict +from qlib.utils.index_data import IndexData from qlib.utils.time import get_day_min_idx_range if typing.TYPE_CHECKING: @@ -38,6 +39,18 @@ def _get_all_timestamps( return pd.DatetimeIndex(ret) +def fill_missing_data( + original_data: np.ndarray, + total_time_list: List[pd.Timestamp], + found_time_list: List[pd.Timestamp], + fill_method: Callable = np.median, +) -> np.ndarray: + assert len(original_data) == len(found_time_list) + tmp = dict(zip(found_time_list, original_data)) + fill_val = fill_method(original_data) + return np.array([tmp.get(t, fill_val) for t in total_time_list]) + + class SAOEStateAdapter: """ Maintain states of the environment. 
SAOEStateAdapter accepts execution results and update its internal state @@ -106,16 +119,17 @@ def update( assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large" exec_vol *= self.position / (exec_vol.sum()) - market_volume = np.array( + market_volume = cast( + IndexData, self.exchange.get_volume( self.order.stock_id, pd.Timestamp(start_time), pd.Timestamp(end_time), method=None, ), - ).reshape(-1) - - market_price = np.array( + ) + market_price = cast( + IndexData, self.exchange.get_deal_price( self.order.stock_id, pd.Timestamp(start_time), @@ -123,7 +137,11 @@ def update( method=None, direction=self.order.direction, ), - ).reshape(-1) + ) + found_time_list = [pd.Timestamp(e) for e in list(market_volume.index)] + total_time_list = _get_all_timestamps(start_time, end_time) + market_price = fill_missing_data(np.array(market_price).reshape(-1), total_time_list, found_time_list) + market_volume = fill_missing_data(np.array(market_volume).reshape(-1), total_time_list, found_time_list) assert market_price.shape == market_volume.shape == exec_vol.shape From 773188b0658d0abcfc77a4c156a515c21c0c0781 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Thu, 6 Oct 2022 16:10:11 +0800 Subject: [PATCH 6/8] Minor bugfix --- qlib/rl/data/native.py | 3 +++ qlib/rl/data/pickle_styled.py | 8 ++++---- qlib/rl/order_execution/policy.py | 3 +++ 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/qlib/rl/data/native.py b/qlib/rl/data/native.py index f18e0f257b..b7e9c78d67 100644 --- a/qlib/rl/data/native.py +++ b/qlib/rl/data/native.py @@ -91,6 +91,9 @@ def load_backtest_data( ) ticks_index = pd.DatetimeIndex(tmp_data.reset_index()["datetime"]) + ticks_index = ticks_index[order.start_time <= ticks_index] + ticks_index = ticks_index[ticks_index <= order.end_time] + if isinstance(trade_range, TradeRangeByTime): ticks_for_order = get_ticks_slice( ticks_index, diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index ed62a4180d..3af1e24839 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -91,7 +91,7 @@ class SimpleIntradayBacktestData(BaseIntradayBacktestData): def __init__( self, - data_dir: Path, + data_dir: Path | str, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", @@ -99,7 +99,7 @@ def __init__( ) -> None: super(SimpleIntradayBacktestData, self).__init__() - backtest = _read_pickle(data_dir / stock_id) + backtest = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id) backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]] # No longer need for pandas >= 1.4 @@ -154,13 +154,13 @@ class IntradayProcessedData(BaseIntradayProcessedData): def __init__( self, - data_dir: Path, + data_dir: Path | str, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index, ) -> None: - proc = _read_pickle(data_dir / stock_id) + proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id) # We have to infer the names here because, # unfortunately they are not included in the original data. cnames = _infer_processed_data_column_names(feature_dim) diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py index cfd3181ca2..bee13757be 100644 --- a/qlib/rl/order_execution/policy.py +++ b/qlib/rl/order_execution/policy.py @@ -163,6 +163,9 @@ def auto_device(module: nn.Module) -> torch.device: def load_weight(policy: nn.Module, path: Path) -> None: assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight." 
loaded_weight = torch.load(path, map_location="cpu") + + if "vessel" in loaded_weight: + loaded_weight = loaded_weight["vessel"]["policy"] try: policy.load_state_dict(loaded_weight) except RuntimeError: From e8b4165d55981c0b27008b820806c3b9a5a0b615 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Sun, 9 Oct 2022 10:13:51 +0800 Subject: [PATCH 7/8] Add TODOs and docs --- qlib/backtest/decision.py | 2 ++ qlib/rl/contrib/backtest.py | 2 ++ qlib/rl/data/native.py | 1 + qlib/rl/order_execution/policy.py | 3 +++ qlib/rl/order_execution/state.py | 19 +++++++++++++++++++ 5 files changed, 27 insertions(+) diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index 115823eadb..4b1d8db7b7 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -579,6 +579,8 @@ def __repr__(self) -> str: class TradeDecisionWithDetails(TradeDecisionWO): + """Decision with detail information. Detail information is used to generate execution reports. + """ def __init__( self, order_list: List[Order], diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index cf389694e9..4d3d3cf4b7 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -92,6 +92,8 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]: return records +# TODO: there should be richer annotation for the input (e.g. report) and the returned report +# TODO: For example, @ dataclass with typed fields and detailed docstrings. def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List[dict]) -> dict: """Generate backtest reports diff --git a/qlib/rl/data/native.py b/qlib/rl/data/native.py index b7e9c78d67..9417534f86 100644 --- a/qlib/rl/data/native.py +++ b/qlib/rl/data/native.py @@ -81,6 +81,7 @@ def load_backtest_data( trade_exchange: Exchange, trade_range: TradeRange, ) -> IntradayBacktestData: + # TODO: making exchange return data without missing will make it more elegant. Fix this in the future. tmp_data = D.features( trade_exchange.codes, trade_exchange.all_fields, diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py index bee13757be..7f7a98e9a7 100644 --- a/qlib/rl/order_execution/policy.py +++ b/qlib/rl/order_execution/policy.py @@ -164,6 +164,9 @@ def load_weight(policy: nn.Module, path: Path) -> None: assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight." loaded_weight = torch.load(path, map_location="cpu") + # TODO: this should be handled by whoever calls load_weight. + # TODO: For example, when the outer class receives a weight, it should first unpack it, + # TODO: and send the corresponding part to individual component. if "vessel" in loaded_weight: loaded_weight = loaded_weight["vessel"]["policy"] try: diff --git a/qlib/rl/order_execution/state.py b/qlib/rl/order_execution/state.py index a38bb32620..f417173e52 100644 --- a/qlib/rl/order_execution/state.py +++ b/qlib/rl/order_execution/state.py @@ -45,6 +45,25 @@ def fill_missing_data( found_time_list: List[pd.Timestamp], fill_method: Callable = np.median, ) -> np.ndarray: + """Fill missing data. We need this function to deal with data that have missing values in some minutes. + + TODO: making exchange return data without missing will make it more elegant. Fix this in the future. + + Parameters + ---------- + original_data + Original data without missing values. + total_time_list + All timestamps that required. + found_time_list + Timestamps found in the original data. + fill_method + Method used to fill the missing data. 
+ + Returns + ------- + The filled data. + """ assert len(original_data) == len(found_time_list) tmp = dict(zip(found_time_list, original_data)) fill_val = fill_method(original_data) From c82d05eb9bcd011ab7e699ae41b2507c1151ad26 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Wed, 12 Oct 2022 16:38:41 +0800 Subject: [PATCH 8/8] Add a comment --- qlib/rl/order_execution/simulator_qlib.py | 1 + 1 file changed, 1 insertion(+) diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index 693af8004b..c9702b1e48 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -107,6 +107,7 @@ def reset( executor.inner_strategy.set_env(CollectDataEnvWrapper()) executor = executor.inner_executor + # Call `step()` with None action to initialize the internal generator. self.step(action=None) self._order = order
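
Usage note: the patches above change SingleAssetOrderExecution so that it builds its own SingleOrderStrategy internally (the old strategy_config argument is gone, while cash_limit and backtest_mode are added), and they switch qlib/rl/contrib/backtest.py to an argparse interface with a required --config_path and an optional --use_simulator flag that selects single_with_simulator over single_with_collect_data_loop. Below is a minimal sketch of driving one single-day order through the new constructor, mirroring single_with_simulator(); it is illustrative only -- the instrument, the timestamps, and the two config dictionaries are placeholders that would normally come from the backtest config file (see _get_multi_level_executor_config() and tests/rl/test_qlib_simulator.py), and Qlib is assumed to have been initialized already via init_qlib().

    import pandas as pd

    from qlib.backtest.decision import Order, OrderDir
    from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution

    # One single-day order, built the same way as in single_with_simulator().
    order = Order(
        stock_id="SH600000",                             # placeholder instrument
        amount=1000.0,
        direction=OrderDir.BUY,
        start_time=pd.Timestamp("2020-06-01 09:30:00"),  # placeholder trading day
        end_time=pd.Timestamp("2020-06-01 14:57:00"),
    )

    # The executor config corresponds to the "strategies" section of the backtest
    # config; the exchange config corresponds to the "exchange" section plus the
    # per-order codes/freq set in single_with_simulator(). Both are placeholders.
    executor_config = {
        "class": "NestedExecutor",
        "module_path": "qlib.backtest.executor",
        "kwargs": {},  # placeholder: time_per_step, inner_executor, inner_strategy, ...
    }
    exchange_config = {
        "codes": [order.stock_id],
        "freq": "1min",
        # deal_price, costs, limit_threshold, ... come from the config file in practice
    }

    simulator = SingleAssetOrderExecution(
        order=order,
        executor_config=executor_config,
        exchange_config=exchange_config,
        qlib_config=None,   # None means Qlib is not (re-)initialized here
        cash_limit=None,    # None selects InfPosition, as in single_with_simulator()
        backtest_mode=True, # wraps inner strategies with CollectDataEnvWrapper
    )

    # As in single_with_simulator(), the backtest results and the collected trade
    # decisions are read off the simulator afterwards.
    report_dict = simulator.report_dict
    decisions = simulator.decisions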