diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index 622b07d356..d3f4d72402 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -42,7 +42,7 @@ def get_exchange( close_cost: float = 0.0025, min_cost: float = 5.0, limit_threshold: Union[Tuple[str, str], float, None] = None, - deal_price: Union[str, Tuple[str], List[str]] = None, + deal_price: Union[str, Tuple[str, str], List[str]] = None, **kwargs: Any, ) -> Exchange: """get_exchange @@ -70,10 +70,10 @@ def get_exchange( min_cost : float min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount. e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount. - deal_price: Union[str, Tuple[str], List[str]] + deal_price: Union[str, Tuple[str, str], List[str]] The `deal_price` supports following two types of input - : str - - (, ): Tuple[str] or List[str] + - (, ): Tuple[str, str] or List[str] , or := := str diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index 42e798c6dd..4828478c7e 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -4,10 +4,11 @@ from __future__ import annotations from abc import abstractmethod +from datetime import time from enum import IntEnum # try to fix circular imports when enabling type hints -from typing import Generic, List, TYPE_CHECKING, Any, ClassVar, Optional, Tuple, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Generic, List, Optional, Tuple, TypeVar, Union, cast from qlib.backtest.utils import TradeCalendarManager from qlib.data.data import Cal @@ -23,7 +24,6 @@ import numpy as np import pandas as pd - DecisionType = TypeVar("DecisionType") @@ -182,8 +182,8 @@ def create( return Order( stock_id=code, amount=amount, - start_time=start_time if start_time is not None else pd.Timestamp(start_time), - end_time=end_time if end_time is not None else pd.Timestamp(end_time), + start_time=None if start_time is None else pd.Timestamp(start_time), + end_time=None if end_time is None else pd.Timestamp(end_time), direction=direction, ) @@ -249,7 +249,7 @@ def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> T class TradeRangeByTime(TradeRange): """This is a helper function for make decisions""" - def __init__(self, start_time: str, end_time: str) -> None: + def __init__(self, start_time: str | time, end_time: str | time) -> None: """ This is a callable class. @@ -259,13 +259,13 @@ def __init__(self, start_time: str, end_time: str) -> None: Parameters ---------- - start_time : str + start_time : str | time e.g. "9:30" - end_time : str + end_time : str | time e.g. "14:30" """ - self.start_time = pd.Timestamp(start_time).time() - self.end_time = pd.Timestamp(end_time).time() + self.start_time = pd.Timestamp(start_time).time() if isinstance(start_time, str) else start_time + self.end_time = pd.Timestamp(end_time).time() if isinstance(end_time, str) else end_time assert self.start_time < self.end_time def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]: @@ -535,7 +535,12 @@ class TradeDecisionWO(BaseTradeDecision[Order]): Besides, the time_range is also included. """ - def __init__(self, order_list: List[object], strategy: BaseStrategy, trade_range: Tuple[int, int] = None) -> None: + def __init__( + self, + order_list: List[Order], + strategy: BaseStrategy, + trade_range: Union[Tuple[int, int], TradeRange] = None, + ) -> None: super().__init__(strategy, trade_range=trade_range) self.order_list = cast(List[Order], order_list) start, end = strategy.trade_calendar.get_step_time() diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 7e4210fe79..16cd8815f9 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -32,7 +32,7 @@ def __init__( start_time: Union[pd.Timestamp, str] = None, end_time: Union[pd.Timestamp, str] = None, codes: Union[list, str] = "all", - deal_price: Union[str, Tuple[str], List[str]] = None, + deal_price: Union[str, Tuple[str, str], List[str]] = None, subscribe_fields: list = [], limit_threshold: Union[Tuple[str, str], float, None] = None, volume_threshold: Union[tuple, dict] = None, @@ -448,9 +448,9 @@ def get_volume( start_time: pd.Timestamp, end_time: pd.Timestamp, method: Optional[str] = "sum", - ) -> float: + ) -> Union[None, int, float, bool, IndexData]: """get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)""" - return cast(float, self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)) + return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method) def get_deal_price( self, @@ -459,7 +459,7 @@ def get_deal_price( end_time: pd.Timestamp, direction: OrderDir, method: Optional[str] = "ts_data_last", - ) -> float: + ) -> Union[None, int, float, bool, IndexData]: if direction == OrderDir.SELL: pstr = self.sell_price elif direction == OrderDir.BUY: @@ -472,7 +472,7 @@ def get_deal_price( self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!") self.logger.warning(f"setting deal_price to close price") deal_price = self.get_close(stock_id, start_time, end_time, method) - return cast(float, deal_price) + return deal_price def get_factor( self, @@ -832,8 +832,11 @@ def _calc_trade_info_by_order( :param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float} :return: trade_price, trade_val, trade_cost """ - trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction) - total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price + trade_price = cast( + float, + self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction), + ) + total_trade_val = cast(float, self.get_volume(order.stock_id, order.start_time, order.end_time)) * trade_price order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time) order.deal_amount = order.amount # set to full amount and clip it step by step # Clipping amount first diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index ef507e1a03..13af7aea71 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -484,6 +484,7 @@ def post_inner_exe_step(self, inner_exe_res: List[object]) -> None: inner_exe_res : the execution result of inner task """ + self.inner_strategy.post_exe_step(inner_exe_res) def get_all_executors(self) -> List[BaseExecutor]: """get all executors, including self and inner_executor.get_all_executors()""" diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index d9d09dbfff..5eae8b89d1 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -284,7 +284,7 @@ def use(x): fields += ["Rsquare($close, %d)" % d for d in windows] names += ["RSQR%d" % d for d in windows] if use("RESI"): - # The redisdual for linear regression for the past d days, represent the trend linearity for past d days. + # The redisdual for linear regression for the past d days, represent the trend linearity for past d days. fields += ["Resi($close, %d)/$close" % d for d in windows] names += ["RESI%d" % d for d in windows] if use("MAX"): @@ -297,7 +297,7 @@ def use(x): names += ["MIN%d" % d for d in windows] if use("QTLU"): # The 80% quantile of past d day's close price, divided by latest close price to remove unit - # Used with MIN and MAX + # Used with MIN and MAX fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows] names += ["QTLU%d" % d for d in windows] if use("QTLD"): @@ -305,7 +305,7 @@ def use(x): fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows] names += ["QTLD%d" % d for d in windows] if use("RANK"): - # Get the percentile of current close price in past d day's close price. + # Get the percentile of current close price in past d day's close price. # Represent the current price level comparing to past N days, add additional information to moving average. fields += ["Rank($close, %d)" % d for d in windows] names += ["RANK%d" % d for d in windows] @@ -316,14 +316,14 @@ def use(x): if use("IMAX"): # The number of days between current date and previous highest price date. # Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp - # The indicator measures the time between highs and the time between lows over a time period. + # The indicator measures the time between highs and the time between lows over a time period. # The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows. fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows] names += ["IMAX%d" % d for d in windows] if use("IMIN"): # The number of days between current date and previous lowest price date. # Part of Aroon Indicator https://www.investopedia.com/terms/a/aroon.asp - # The indicator measures the time between highs and the time between lows over a time period. + # The indicator measures the time between highs and the time between lows over a time period. # The idea is that strong uptrends will regularly see new highs, and strong downtrends will regularly see new lows. fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows] names += ["IMIN%d" % d for d in windows] diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py index cfac8d12bd..288500c555 100644 --- a/qlib/data/storage/file_storage.py +++ b/qlib/data/storage/file_storage.py @@ -102,11 +102,22 @@ def _freq_file(self) -> str: self._freq_file_cache = freq return self._freq_file_cache - def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]: + def _read_calendar(self) -> List[CalVT]: + # NOTE: + # if we want to accelerate partial reading calendar + # we can add parameters like `skip_rows: int = 0, n_rows: int = None` to the interface. + # Currently, it is not supported for the txt-based calendar + if not self.uri.exists(): self._write_calendar(values=[]) - with self.uri.open("rb") as fp: - return [str(x) for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, encoding="utf-8")] + + with self.uri.open("r") as fp: + res = [] + for line in fp.readlines(): + line = line.strip() + if len(line) > 0: + res.append(line) + return res def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"): with self.uri.open(mode=mode) as fp: diff --git a/qlib/rl/aux_info.py b/qlib/rl/aux_info.py index 65cd95d5dd..9ab0834511 100644 --- a/qlib/rl/aux_info.py +++ b/qlib/rl/aux_info.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import Generic, TYPE_CHECKING, TypeVar +from typing import Optional, TYPE_CHECKING, Generic, TypeVar from qlib.typehint import final @@ -21,7 +21,7 @@ class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]): """Override this class to collect customized auxiliary information from environment.""" - env: EnvWrapper | None = None + env: Optional[EnvWrapper] = None @final def __call__(self, simulator_state: StateType) -> AuxInfoType: diff --git a/qlib/rl/data/exchange_wrapper.py b/qlib/rl/data/exchange_wrapper.py new file mode 100644 index 0000000000..bc36fa11b8 --- /dev/null +++ b/qlib/rl/data/exchange_wrapper.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import cast + +import pandas as pd + +from qlib.backtest import Exchange, Order +from .pickle_styled import IntradayBacktestData + + +class QlibIntradayBacktestData(IntradayBacktestData): + """Backtest data for Qlib simulator""" + + def __init__(self, order: Order, exchange: Exchange, start_time: pd.Timestamp, end_time: pd.Timestamp) -> None: + super(QlibIntradayBacktestData, self).__init__() + self._order = order + self._exchange = exchange + self._start_time = start_time + self._end_time = end_time + + self._deal_price = cast( + pd.Series, + self._exchange.get_deal_price( + self._order.stock_id, + self._start_time, + self._end_time, + direction=self._order.direction, + method=None, + ), + ) + self._volume = cast( + pd.Series, + self._exchange.get_volume( + self._order.stock_id, + self._start_time, + self._end_time, + method=None, + ), + ) + + def __repr__(self) -> str: + return ( + f"Order: {self._order}, Exchange: {self._exchange}, " + f"Start time: {self._start_time}, End time: {self._end_time}" + ) + + def __len__(self) -> int: + return len(self._deal_price) + + def get_deal_price(self) -> pd.Series: + return self._deal_price + + def get_volume(self) -> pd.Series: + return self._volume + + def get_time_index(self) -> pd.DatetimeIndex: + return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)]) diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index e2d0382b1a..aa0ba38fff 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -19,19 +19,19 @@ from __future__ import annotations +from abc import abstractmethod from functools import lru_cache -from typing import List, Sequence, cast from pathlib import Path +from typing import List, Sequence, cast import cachetools import numpy as np import pandas as pd from cachetools.keys import hashkey -from qlib.backtest.decision import OrderDir, Order +from qlib.backtest.decision import Order, OrderDir from qlib.typehint import Literal - DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"] """Several ad-hoc deal price. ``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``. @@ -40,7 +40,7 @@ """ -def _infer_processed_data_column_names(shape: int) -> list[str]: +def _infer_processed_data_column_names(shape: int) -> List[str]: if shape == 16: return [ "$open", @@ -87,7 +87,36 @@ def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame: class IntradayBacktestData: - """Raw market data that is often used in backtesting (thus called BacktestData).""" + """ + Raw market data that is often used in backtesting (thus called BacktestData). + + Base class for all types of backtest data. Currently, each type of simulator has its corresponding backtest + data type. + """ + + @abstractmethod + def __repr__(self) -> str: + raise NotImplementedError + + @abstractmethod + def __len__(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_deal_price(self) -> pd.Series: + raise NotImplementedError + + @abstractmethod + def get_volume(self) -> pd.Series: + raise NotImplementedError + + @abstractmethod + def get_time_index(self) -> pd.DatetimeIndex: + raise NotImplementedError + + +class SimpleIntradayBacktestData(IntradayBacktestData): + """Backtest data for simple simulator""" def __init__( self, @@ -95,8 +124,10 @@ def __init__( stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", - order_dir: int | None = None, - ): + order_dir: int = None, + ) -> None: + super(SimpleIntradayBacktestData, self).__init__() + backtest = _read_pickle(data_dir / stock_id) backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]] @@ -105,13 +136,13 @@ def __init__( self.data: pd.DataFrame = backtest self.deal_price_type: DealPriceType = deal_price - self.order_dir: int | None = order_dir + self.order_dir = order_dir - def __repr__(self): + def __repr__(self) -> str: with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"): return f"{self.__class__.__name__}({self.data})" - def __len__(self): + def __len__(self) -> int: return len(self.data) def get_deal_price(self) -> pd.Series: @@ -162,7 +193,14 @@ class IntradayProcessedData: """Processed data for "yesterday". Number of records must be ``time_length``, and columns must be ``feature_dim``.""" - def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index): + def __init__( + self, + data_dir: Path, + stock_id: str, + date: pd.Timestamp, + feature_dim: int, + time_index: pd.Index, + ) -> None: proc = _read_pickle(data_dir / stock_id) # We have to infer the names here because, # unfortunately they are not included in the original data. @@ -190,16 +228,20 @@ def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_di assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim assert len(self.today) == len(self.yesterday) == time_length - def __repr__(self): + def __repr__(self) -> str: with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"): return f"{self.__class__.__name__}({self.today}, {self.yesterday})" @lru_cache(maxsize=100) # 100 * 50K = 5MB -def load_intraday_backtest_data( - data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None -) -> IntradayBacktestData: - return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir) +def load_simple_intraday_backtest_data( + data_dir: Path, + stock_id: str, + date: pd.Timestamp, + deal_price: DealPriceType = "close", + order_dir: int = None, +) -> SimpleIntradayBacktestData: + return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir) @cachetools.cached( # type: ignore @@ -207,13 +249,19 @@ def load_intraday_backtest_data( key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date), ) def load_intraday_processed_data( - data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index + data_dir: Path, + stock_id: str, + date: pd.Timestamp, + feature_dim: int, + time_index: pd.Index, ) -> IntradayProcessedData: return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index) def load_orders( - order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None + order_path: Path, + start_time: pd.Timestamp = None, + end_time: pd.Timestamp = None, ) -> Sequence[Order]: """Load orders, and set start time and end time for the orders.""" @@ -251,7 +299,7 @@ def load_orders( OrderDir(int(row["order_type"])), row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second), row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second), - ) + ), ) return orders diff --git a/qlib/rl/from_neutrader/__init__.py b/qlib/rl/from_neutrader/__init__.py new file mode 100644 index 0000000000..765bdee0c1 --- /dev/null +++ b/qlib/rl/from_neutrader/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# TODO: find a better way to organize contents under this module. diff --git a/qlib/rl/from_neutrader/config.py b/qlib/rl/from_neutrader/config.py new file mode 100644 index 0000000000..d9a681b32d --- /dev/null +++ b/qlib/rl/from_neutrader/config.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple, Union + + +# TODO: In the future we should merge the dataclass-based config with Qlib's dict-based config. +@dataclass +class ExchangeConfig: + limit_threshold: Union[float, Tuple[str, str]] + deal_price: Union[str, Tuple[str, str]] + volume_threshold: dict + open_cost: float = 0.0005 + close_cost: float = 0.0015 + min_cost: float = 5.0 + trade_unit: Optional[float] = 100.0 + cash_limit: Optional[Union[Path, float]] = None + generate_report: bool = False diff --git a/qlib/rl/from_neutrader/feature.py b/qlib/rl/from_neutrader/feature.py new file mode 100644 index 0000000000..ca42af24c9 --- /dev/null +++ b/qlib/rl/from_neutrader/feature.py @@ -0,0 +1,109 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import collections +from typing import List, Optional + +import pandas as pd + +import qlib +from qlib.config import REG_CN +from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select +from qlib.data.dataset import DatasetH + + +class LRUCache: + def __init__(self, pool_size: int = 200): + self.pool_size = pool_size + self.contents: dict = {} + self.keys: collections.deque = collections.deque() + + def put(self, key, item): + if self.has(key): + self.keys.remove(key) + self.keys.append(key) + self.contents[key] = item + while len(self.contents) > self.pool_size: + self.contents.pop(self.keys.popleft()) + + def get(self, key): + return self.contents[key] + + def has(self, key): + return key in self.contents + + +class DataWrapper: + def __init__( + self, + feature_dataset: DatasetH, + backtest_dataset: DatasetH, + columns_today: List[str], + columns_yesterday: List[str], + _internal: bool = False, + ): + assert _internal, "Init function of data wrapper is for internal use only." + + self.feature_dataset = feature_dataset + self.backtest_dataset = backtest_dataset + self.columns_today = columns_today + self.columns_yesterday = columns_yesterday + + # TODO: We might have the chance to merge them. + self.feature_cache = LRUCache() + self.backtest_cache = LRUCache() + + def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame: + start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59) + + if backtest: + dataset = self.backtest_dataset + cache = self.backtest_cache + else: + dataset = self.feature_dataset + cache = self.feature_cache + + if cache.has((start_time, end_time, stock_id)): + return cache.get((start_time, end_time, stock_id)) + data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None) + cache.put((start_time, end_time, stock_id), data) + return data + + +def init_qlib(config: dict, part: Optional[str] = None) -> None: + provider_uri_map = { + "day": config["provider_uri_day"].as_posix(), + "1min": config["provider_uri_1min"].as_posix(), + } + qlib.init( + region=REG_CN, + auto_mount=False, + custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum], + expression_cache=None, + calendar_provider={ + "class": "LocalCalendarProvider", + "module_path": "qlib.data.data", + "kwargs": { + "backend": { + "class": "FileCalendarStorage", + "module_path": "qlib.data.storage.file_storage", + "kwargs": {"provider_uri_map": provider_uri_map}, + }, + }, + }, + feature_provider={ + "class": "LocalFeatureProvider", + "module_path": "qlib.data.data", + "kwargs": { + "backend": { + "class": "FileFeatureStorage", + "module_path": "qlib.data.storage.file_storage", + "kwargs": {"provider_uri_map": provider_uri_map}, + }, + }, + }, + provider_uri=provider_uri_map, + kernels=1, + redis_port=-1, + clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance + ) diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py index 3835b5b923..61c9b83819 100644 --- a/qlib/rl/interpreter.py +++ b/qlib/rl/interpreter.py @@ -3,13 +3,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, TypeVar, Generic, Any +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar import numpy as np from qlib.typehint import final -from .simulator import StateType, ActType +from .simulator import ActType, StateType if TYPE_CHECKING: from .utils.env_wrapper import EnvWrapper @@ -40,7 +40,7 @@ class Interpreter: class StateInterpreter(Generic[StateType, ObsType], Interpreter): """State Interpreter that interpret execution result of qlib executor into rl env state""" - env: EnvWrapper | None = None + env: Optional[EnvWrapper] = None @property def observation_space(self) -> gym.Space: @@ -74,7 +74,7 @@ def interpret(self, simulator_state: StateType) -> ObsType: class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter): """Action Interpreter that interpret rl agent action into qlib orders""" - env: "EnvWrapper" | None = None + env: Optional[EnvWrapper] = None @property def action_space(self) -> gym.Space: @@ -141,10 +141,10 @@ def _gym_space_contains(space: gym.Space, x: Any) -> None: class GymSpaceValidationError(Exception): - def __init__(self, message: str, space: gym.Space, x: Any): + def __init__(self, message: str, space: gym.Space, x: Any) -> None: self.message = message self.space = space self.x = x - def __str__(self): + def __str__(self) -> str: return f"{self.message}\n Space: {self.space}\n Sample: {self.x}" diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py index 9bb5dc2cf1..602a15e54e 100644 --- a/qlib/rl/order_execution/interpreter.py +++ b/qlib/rl/order_execution/interpreter.py @@ -5,15 +5,15 @@ import math from pathlib import Path -from typing import Any, cast +from typing import Any, List, cast import numpy as np import pandas as pd from gym import spaces from qlib.constant import EPS -from qlib.rl.interpreter import StateInterpreter, ActionInterpreter from qlib.rl.data import pickle_styled +from qlib.rl.interpreter import ActionInterpreter, StateInterpreter from qlib.typehint import TypedDict from .simulator_simple import SAOEState @@ -99,18 +99,18 @@ def interpret(self, state: SAOEState) -> FullHistoryObs: "data_processed": self._mask_future_info(processed.today, state.cur_time), "data_processed_prev": processed.yesterday, "acquiring": state.order.direction == state.order.BUY, - "cur_tick": min(np.sum(state.ticks_index < state.cur_time), self.data_ticks - 1), + "cur_tick": min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1), "cur_step": min(self.env.status["cur_step"], self.max_step - 1), "num_step": self.max_step, "target": state.order.amount, "position": state.position, "position_history": position_history[: self.max_step], - } + }, ), ) @property - def observation_space(self): + def observation_space(self) -> spaces.Dict: space = { "data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)), "data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)), @@ -147,11 +147,11 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]): The key list is not full. You can add more if more information is needed by your policy. """ - def __init__(self, max_step: int): + def __init__(self, max_step: int) -> None: self.max_step = max_step @property - def observation_space(self): + def observation_space(self) -> spaces.Dict: space = { "acquiring": spaces.Discrete(2), "cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32), @@ -165,13 +165,11 @@ def interpret(self, state: SAOEState) -> CurrentStateObs: assert self.env is not None assert self.env.status["cur_step"] <= self.max_step obs = CurrentStateObs( - { - "acquiring": state.order.direction == state.order.BUY, - "cur_step": self.env.status["cur_step"], - "num_step": self.max_step, - "target": state.order.amount, - "position": state.position, - } + acquiring=state.order.direction == state.order.BUY, + cur_step=self.env.status["cur_step"], + num_step=self.max_step, + target=state.order.amount, + position=state.position, ) return obs @@ -188,7 +186,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]): i.e., $[0, 1/n, 2/n, \\ldots, n/n]$. """ - def __init__(self, values: int | list[float]): + def __init__(self, values: int | List[float]) -> None: if isinstance(values, int): values = [i / values for i in range(0, values + 1)] self.action_values = values @@ -203,7 +201,7 @@ def interpret(self, state: SAOEState, action: int) -> float: class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]): - """Convert a continous ratio to deal amount. + """Convert a continuous ratio to deal amount. The ratio is relative to TWAP on the remainder of the day. For example, there are 5 steps left, and the left position is 300. diff --git a/qlib/rl/order_execution/network.py b/qlib/rl/order_execution/network.py index 908f96130f..3d0279559e 100644 --- a/qlib/rl/order_execution/network.py +++ b/qlib/rl/order_execution/network.py @@ -3,13 +3,14 @@ from __future__ import annotations -from typing import cast +from typing import List, Tuple, cast import torch import torch.nn as nn from tianshou.data import Batch from qlib.typehint import Literal + from .interpreter import FullHistoryObs __all__ = ["Recurrent"] @@ -18,7 +19,7 @@ class Recurrent(nn.Module): """The network architecture proposed in `OPD `_. - At every timestep the input of policy network is divided into two parts, + At every time step the input of policy network is divided into two parts, the public variables and the private variables. which are handled by ``raw_rnn`` and ``pri_rnn`` in this network, respectively. @@ -33,7 +34,7 @@ def __init__( output_dim: int = 32, rnn_type: Literal["rnn", "lstm", "gru"] = "gru", rnn_num_layers: int = 1, - ): + ) -> None: super().__init__() self.hidden_dim = hidden_dim @@ -62,10 +63,10 @@ def __init__( nn.ReLU(), ) - def _init_extra_branches(self): + def _init_extra_branches(self) -> None: pass - def _source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]: + def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]: bs, _, data_dim = obs["data_processed"].size() data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1) cur_step = obs["cur_step"].long() diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py index f95a53c758..18c2e4f175 100644 --- a/qlib/rl/order_execution/policy.py +++ b/qlib/rl/order_execution/policy.py @@ -1,16 +1,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations from pathlib import Path -from typing import Optional, cast +from typing import Any, Dict, Generator, Iterable, Optional, Tuple, cast -import numpy as np import gym +import numpy as np import torch import torch.nn as nn from gym.spaces import Discrete -from tianshou.data import Batch, to_torch -from tianshou.policy import PPOPolicy, BasePolicy +from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.policy import BasePolicy, PPOPolicy __all__ = ["AllOne", "PPO"] @@ -18,29 +19,39 @@ # baselines # -class NonlearnablePolicy(BasePolicy): +class NonLearnablePolicy(BasePolicy): """Tianshou's BasePolicy with empty ``learn`` and ``process_fn``. This could be moved outside in future. """ - def __init__(self, obs_space: gym.Space, action_space: gym.Space): + def __init__(self, obs_space: gym.Space, action_space: gym.Space) -> None: super().__init__() - def learn(self, batch, batch_size, repeat): + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]: pass - def process_fn(self, batch, buffer, indice): + def process_fn( + self, + batch: Batch, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> Batch: pass -class AllOne(NonlearnablePolicy): +class AllOne(NonLearnablePolicy): """Forward returns a batch full of 1. Useful when implementing some baselines (e.g., TWAP). """ - def forward(self, batch, state=None, **kwargs): + def forward( + self, + batch: Batch, + state: dict | Batch | np.ndarray = None, + **kwargs: Any, + ) -> Batch: return Batch(act=np.full(len(batch), 1.0), state=state) @@ -48,24 +59,34 @@ def forward(self, batch, state=None, **kwargs): class PPOActor(nn.Module): - def __init__(self, extractor: nn.Module, action_dim: int): + def __init__(self, extractor: nn.Module, action_dim: int) -> None: super().__init__() self.extractor = extractor self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1)) - def forward(self, obs, state=None, info={}): + def forward( + self, + obs: torch.Tensor, + state: torch.Tensor = None, + info: dict = {}, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: feature = self.extractor(to_torch(obs, device=auto_device(self))) out = self.layer_out(feature) return out, state class PPOCritic(nn.Module): - def __init__(self, extractor: nn.Module): + def __init__(self, extractor: nn.Module) -> None: super().__init__() self.extractor = extractor self.value_out = nn.Linear(cast(int, extractor.output_dim), 1) - def forward(self, obs, state=None, info={}): + def forward( + self, + obs: torch.Tensor, + state: torch.Tensor = None, + info: dict = {}, + ) -> torch.Tensor: feature = self.extractor(to_torch(obs, device=auto_device(self))) return self.value_out(feature).squeeze(dim=-1) @@ -93,18 +114,20 @@ def __init__( max_grad_norm: float = 100.0, reward_normalization: bool = True, eps_clip: float = 0.3, - value_clip: float = True, + value_clip: bool = True, vf_coef: float = 1.0, gae_lambda: float = 1.0, - max_batchsize: int = 256, + max_batch_size: int = 256, deterministic_eval: bool = True, weight_file: Optional[Path] = None, - ): + ) -> None: assert isinstance(action_space, Discrete) actor = PPOActor(network, action_space.n) critic = PPOCritic(network) optimizer = torch.optim.Adam( - chain_dedup(actor.parameters(), critic.parameters()), lr=lr, weight_decay=weight_decay + chain_dedup(actor.parameters(), critic.parameters()), + lr=lr, + weight_decay=weight_decay, ) super().__init__( actor, @@ -118,7 +141,7 @@ def __init__( value_clip=value_clip, vf_coef=vf_coef, gae_lambda=gae_lambda, - max_batchsize=max_batchsize, + max_batchsize=max_batch_size, deterministic_eval=deterministic_eval, observation_space=obs_space, action_space=action_space, @@ -136,7 +159,7 @@ def auto_device(module: nn.Module) -> torch.device: return torch.device("cpu") # fallback to cpu -def load_weight(policy, path): +def load_weight(policy: nn.Module, path: Path) -> None: assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight." loaded_weight = torch.load(path, map_location="cpu") try: @@ -149,7 +172,7 @@ def load_weight(policy, path): policy.load_state_dict(loaded_weight) -def chain_dedup(*iterables): +def chain_dedup(*iterables: Iterable) -> Generator[Any, None, None]: seen = set() for iterable in iterables: for i in iterable: diff --git a/qlib/rl/order_execution/reward.py b/qlib/rl/order_execution/reward.py index 43015407db..f15a152c66 100644 --- a/qlib/rl/order_execution/reward.py +++ b/qlib/rl/order_execution/reward.py @@ -6,9 +6,10 @@ from typing import cast import numpy as np + from qlib.rl.reward import Reward -from .simulator_simple import SAOEState, SAOEMetrics +from .simulator_simple import SAOEMetrics, SAOEState __all__ = ["PAPenaltyReward"] diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index 71aaa222be..c75793f586 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -1,4 +1,424 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""Placeholder for qlib-based simulator.""" +from __future__ import annotations + +from typing import Any, Callable, cast, Generator, List, Optional, Tuple + +import numpy as np +import pandas as pd + +from qlib.backtest.decision import BaseTradeDecision, Order, OrderHelper, TradeDecisionWO, TradeRange, TradeRangeByTime +from qlib.backtest.executor import BaseExecutor, NestedExecutor +from qlib.backtest.utils import CommonInfrastructure +from qlib.constant import EPS +from qlib.rl.data.exchange_wrapper import QlibIntradayBacktestData +from qlib.rl.from_neutrader.config import ExchangeConfig +from qlib.rl.from_neutrader.feature import init_qlib +from qlib.rl.order_execution.simulator_simple import SAOEMetrics, SAOEState +from qlib.rl.order_execution.utils import ( + dataframe_append, + get_common_infra, + get_portfolio_and_indicator, + get_ticks_slice, + price_advantage, +) +from qlib.rl.simulator import Simulator +from qlib.strategy.base import BaseStrategy + + +class DecomposedStrategy(BaseStrategy): + def __init__(self) -> None: + super().__init__() + + self.execute_order: Optional[Order] = None + self.execute_result: List[Tuple[Order, float, float, float]] = [] + + def generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]: + # Once the following line is executed, this DecomposedStrategy (self) will be yielded to the outside + # of the entire executor, and the execution will be suspended. When the execution is resumed by `send()`, + # the sent item will be captured by `exec_vol`. The outside policy could communicate with the inner + # level strategy through this way. + exec_vol = yield self + + oh = self.trade_exchange.get_order_helper() + order = oh.create(self._order.stock_id, exec_vol, self._order.direction) + + self.execute_order = order + + return TradeDecisionWO([order], self) + + def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision: + return outer_trade_decision + + def post_exe_step(self, execute_result: list) -> None: + self.execute_result = execute_result + + def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs: Any) -> None: + super().reset(outer_trade_decision=outer_trade_decision, **kwargs) + if outer_trade_decision is not None: + order_list = outer_trade_decision.order_list + assert len(order_list) == 1 + self._order = order_list[0] + + +class SingleOrderStrategy(BaseStrategy): + # this logic is copied from FileOrderStrategy + def __init__( + self, + common_infra: CommonInfrastructure, + order: Order, + trade_range: TradeRange, + instrument: str, + ) -> None: + super().__init__(common_infra=common_infra) + self._order = order + self._trade_range = trade_range + self._instrument = instrument + + def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision: + return outer_trade_decision + + def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO: + oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper() + order_list = [ + oh.create( + code=self._instrument, + amount=self._order.amount, + direction=self._order.direction, + ), + ] + return TradeDecisionWO(order_list, self, self._trade_range) + + +# TODO: move these to the configuration files +FINEST_GRANULARITY = "1min" +COARSEST_GRANULARITY = "1day" + + +class StateMaintainer: + """ + Maintain states of the environment. + + Example usage:: + + maintainer = StateMaintainer(...) # in reset + maintainer.update(...) # in step + # get states in get_state from maintainer + """ + + def __init__(self, order: Order, time_per_step: str, tick_index: pd.DatetimeIndex, twap_price: float) -> None: + super().__init__() + + self.position = order.amount + self._order = order + self._time_per_step = time_per_step + self._tick_index = tick_index + self._twap_price = twap_price + + metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member + self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime") + self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime") + self.metrics: Optional[SAOEMetrics] = None + + def update( + self, + inner_executor: BaseExecutor, + inner_strategy: DecomposedStrategy, + done: bool, + all_indicators: dict, + ) -> None: + execute_order = inner_strategy.execute_order + execute_result = inner_strategy.execute_result + exec_vol = np.array([e[0].deal_amount for e in execute_result]) + num_step = len(execute_result) + + assert execute_order is not None + + if num_step == 0: + market_volume = np.array([]) + market_price = np.array([]) + datetime_list = pd.DatetimeIndex([]) + else: + market_volume = np.array( + inner_executor.trade_exchange.get_volume( + execute_order.stock_id, + execute_result[0][0].start_time, + execute_result[-1][0].start_time, + method=None, + ), + ) + + trade_value = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["value"].values + deal_amount = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["deal_amount"].values + market_price = trade_value / deal_amount + + datetime_list = all_indicators[FINEST_GRANULARITY].index[-num_step:] + + assert market_price.shape == market_volume.shape == exec_vol.shape + + self.history_exec = dataframe_append( + self.history_exec, + self._collect_multi_order_metric( + order=self._order, + datetime=datetime_list, + market_vol=market_volume, + market_price=market_price, + exec_vol=exec_vol, + pa=all_indicators[self._time_per_step].iloc[-1]["pa"], + ), + ) + + self.history_steps = dataframe_append( + self.history_steps, + [ + self._collect_single_order_metric( + execute_order, + execute_order.start_time, + market_volume, + market_price, + exec_vol.sum(), + exec_vol, + ), + ], + ) + + if done: + self.metrics = self._collect_single_order_metric( + self._order, + self._tick_index[0], # start time + self.history_exec["market_volume"], + self.history_exec["market_price"], + self.history_steps["amount"].sum(), + self.history_exec["deal_amount"], + ) + + # TODO: check whether we need this. Can we get this information from Account? + # Do this at the end + self.position -= exec_vol.sum() + + def _collect_multi_order_metric( + self, + order: Order, + datetime: pd.Timestamp, + market_vol: np.ndarray, + market_price: np.ndarray, + exec_vol: np.ndarray, + pa: float, + ) -> SAOEMetrics: + return SAOEMetrics( + # It should have the same keys with SAOEMetrics, + # but the values do not necessarily have the annotated type. + # Some values could be vectorized (e.g., exec_vol). + stock_id=order.stock_id, + datetime=datetime, + direction=order.direction, + market_volume=market_vol, + market_price=market_price, + amount=exec_vol, + inner_amount=exec_vol, + deal_amount=exec_vol, + trade_price=market_price, + trade_value=market_price * exec_vol, + position=self.position - np.cumsum(exec_vol), + ffr=exec_vol / order.amount, + pa=pa, + ) + + def _collect_single_order_metric( + self, + order: Order, + datetime: pd.Timestamp, + market_vol: np.ndarray, + market_price: np.ndarray, + amount: float, # intended to trade such amount + exec_vol: np.ndarray, + ) -> SAOEMetrics: + assert len(market_vol) == len(market_price) == len(exec_vol) + + if np.abs(np.sum(exec_vol)) < EPS: + exec_avg_price = 0.0 + else: + exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan + if hasattr(exec_avg_price, "item"): # could be numpy scalar + exec_avg_price = exec_avg_price.item() # type: ignore + + exec_sum = exec_vol.sum() + return SAOEMetrics( + stock_id=order.stock_id, + datetime=datetime, + direction=order.direction, + market_volume=market_vol.sum(), + market_price=market_price.mean() if len(market_price) > 0 else np.nan, + amount=amount, + inner_amount=exec_sum, + deal_amount=exec_sum, # in this simulator, there's no other restrictions + trade_price=exec_avg_price, + trade_value=float(np.sum(market_price * exec_vol)), + position=self.position - exec_sum, + ffr=float(exec_sum / order.amount), + pa=price_advantage(exec_avg_price, self._twap_price, order.direction), + ) + + +class SingleAssetOrderExecutionQlib(Simulator[Order, SAOEState, float]): + """Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools. + + Parameters + ---------- + order (Order): + The seed to start an SAOE simulator is an order. + time_per_step (str): + A string to describe the time granularity of each step. Current support "1min", "30min", and "1day" + qlib_config (dict): + Configuration used to initialize Qlib. + inner_executor_fn (Callable[[str, CommonInfrastructure], BaseExecutor]): + Function used to get the inner level executor. + exchange_config (ExchangeConfig): + Configuration used to create the Exchange instance. + """ + + def __init__( + self, + order: Order, + time_per_step: str, # "1min", "30min", "1day" + qlib_config: dict, + inner_executor_fn: Callable[[str, CommonInfrastructure], BaseExecutor], + exchange_config: ExchangeConfig, + ) -> None: + assert time_per_step in ("1min", "30min", "1day") + + super().__init__(initial=order) + + assert order.start_time.date() == order.end_time.date(), "Start date and end date must be the same." + + self._order = order + self._order_date = pd.Timestamp(order.start_time.date()) + self._trade_range = TradeRangeByTime(order.start_time.time(), order.end_time.time()) + self._qlib_config = qlib_config + self._inner_executor_fn = inner_executor_fn + self._exchange_config = exchange_config + + self._time_per_step = time_per_step + self._ticks_per_step = int(pd.Timedelta(time_per_step).total_seconds() // 60) + + self._executor: Optional[NestedExecutor] = None + self._collect_data_loop: Optional[Generator] = None + + self._done = False + + self._inner_strategy = DecomposedStrategy() + + self.reset(self._order) + + def reset(self, order: Order) -> None: + instrument = order.stock_id + + # TODO: Check this logic. Make sure we need to do this every time we reset the simulator. + init_qlib(self._qlib_config, instrument) + + common_infra = get_common_infra( + self._exchange_config, + trade_date=pd.Timestamp(self._order_date), + codes=[instrument], + ) + + # TODO: We can leverage interfaces like (https://tinyurl.com/y8f8fhv4) to create trading environment. + # TODO: By aligning the interface to create environments with Qlib, it will be easier to share the config and + # TODO: code between backtesting and training. + self._inner_executor = self._inner_executor_fn(self._time_per_step, common_infra) + self._executor = NestedExecutor( + time_per_step=COARSEST_GRANULARITY, + inner_executor=self._inner_executor, + inner_strategy=self._inner_strategy, + track_data=True, + common_infra=common_infra, + ) + + exchange = self._inner_executor.trade_exchange + self._ticks_index = pd.DatetimeIndex([e[1] for e in list(exchange.quote_df.index)]) + self._ticks_for_order = get_ticks_slice( + self._ticks_index, + self._order.start_time, + self._order.end_time, + include_end=True, + ) + + self._backtest_data = QlibIntradayBacktestData( + order=self._order, + exchange=exchange, + start_time=self._ticks_for_order[0], + end_time=self._ticks_for_order[-1], + ) + + self.twap_price = self._backtest_data.get_deal_price().mean() + + top_strategy = SingleOrderStrategy(common_infra, order, self._trade_range, instrument) + self._executor.reset(start_time=pd.Timestamp(self._order_date), end_time=pd.Timestamp(self._order_date)) + top_strategy.reset(level_infra=self._executor.get_level_infra()) + + self._collect_data_loop = self._executor.collect_data(top_strategy.generate_trade_decision(), level=0) + assert isinstance(self._collect_data_loop, Generator) + + self._iter_strategy(action=None) + self._done = False + + self._maintainer = StateMaintainer( + order=self._order, + time_per_step=self._time_per_step, + tick_index=self._ticks_index, + twap_price=self.twap_price, + ) + + def _iter_strategy(self, action: float = None) -> DecomposedStrategy: + """Iterate the _collect_data_loop until we get the next yield DecomposedStrategy.""" + assert self._collect_data_loop is not None + + strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) + while not isinstance(strategy, DecomposedStrategy): + strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) + assert isinstance(strategy, DecomposedStrategy) + return strategy + + def step(self, action: float) -> None: + """Execute one step or SAOE. + + Parameters + ---------- + action (float): + The amount you wish to deal. The simulator doesn't guarantee all the amount to be successfully dealt. + """ + + assert not self._done, "Simulator has already done!" + + try: + self._iter_strategy(action=action) + except StopIteration: + self._done = True + + assert self._executor is not None + _, all_indicators = get_portfolio_and_indicator(self._executor) + + self._maintainer.update( + inner_executor=self._inner_executor, + inner_strategy=self._inner_strategy, + done=self._done, + all_indicators=all_indicators, + ) + + def get_state(self) -> SAOEState: + return SAOEState( + order=self._order, + cur_time=self._inner_executor.trade_calendar.get_step_time()[0], + position=self._maintainer.position, + history_exec=self._maintainer.history_exec, + history_steps=self._maintainer.history_steps, + metrics=self._maintainer.metrics, + backtest_data=self._backtest_data, + ticks_per_step=self._ticks_per_step, + ticks_index=self._ticks_index, + ticks_for_order=self._ticks_for_order, + ) + + def done(self) -> bool: + return self._done diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py index 51357dfdfa..6d49457841 100644 --- a/qlib/rl/order_execution/simulator_simple.py +++ b/qlib/rl/order_execution/simulator_simple.py @@ -4,18 +4,20 @@ from __future__ import annotations from pathlib import Path -from typing import NamedTuple, Any, TypeVar, cast +from typing import Any, NamedTuple, Optional, TypeVar, cast import numpy as np import pandas as pd from qlib.backtest.decision import Order, OrderDir from qlib.constant import EPS +from qlib.rl.data.pickle_styled import DealPriceType, IntradayBacktestData, load_simple_intraday_backtest_data from qlib.rl.simulator import Simulator -from qlib.rl.data.pickle_styled import IntradayBacktestData, load_intraday_backtest_data, DealPriceType from qlib.rl.utils import LogLevel from qlib.typehint import TypedDict +# TODO: Integrating Qlib's native data with simulator_simple + __all__ = ["SAOEMetrics", "SAOEState", "SingleAssetOrderExecution"] ONE_SEC = pd.Timedelta("1s") # use 1 second to exclude the right interval point @@ -33,40 +35,40 @@ class SAOEMetrics(TypedDict): stock_id: str """Stock ID of this record.""" - datetime: pd.Timestamp + datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this """Datetime of this record (this is index in the dataframe).""" direction: int """Direction of the order. 0 for sell, 1 for buy.""" # Market information. - market_volume: float + market_volume: np.ndarray | float """(total) market volume traded in the period.""" - market_price: float + market_price: np.ndarray | float """Deal price. If it's a period of time, this is the average market deal price.""" # Strategy records. - amount: float + amount: np.ndarray | float """Total amount (volume) strategy intends to trade.""" - inner_amount: float + inner_amount: np.ndarray | float """Total amount that the lower-level strategy intends to trade (might be larger than amount, e.g., to ensure ffr).""" - deal_amount: float + deal_amount: np.ndarray | float """Amount that successfully takes effect (must be less than inner_amount).""" - trade_price: float + trade_price: np.ndarray | float """The average deal price for this strategy.""" - trade_value: float - """Total worth of trading. In the simple simulaton, trade_value = deal_amount * price.""" - position: float + trade_value: np.ndarray | float + """Total worth of trading. In the simple simulation, trade_value = deal_amount * price.""" + position: np.ndarray | float """Position left after this "period".""" # Accumulated metrics - ffr: float + ffr: np.ndarray | float """Completed how much percent of the daily order.""" - pa: float + pa: np.ndarray | float """Price advantage compared to baseline (i.e., trade with baseline market price). The baseline is trade price when using TWAP strategy to execute this order. Please note that there could be data leak here). @@ -87,7 +89,7 @@ class SAOEState(NamedTuple): history_steps: pd.DataFrame """See :attr:`SingleAssetOrderExecution.history_steps`.""" - metrics: SAOEMetrics | None + metrics: Optional[SAOEMetrics] """Daily metric, only available when the trading is in "done" state.""" backtest_data: IntradayBacktestData @@ -114,13 +116,13 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): If such fine granularity is not needed, use ``ticks_per_step`` to lengthen the ticks for each step. - In each step, the traded amount are "equally" splitted to each tick, - then bounded by volume maximum exeuction volume (i.e., ``vol_threshold``), + In each step, the traded amount are "equally" separated to each tick, + then bounded by volume maximum execution volume (i.e., ``vol_threshold``), and if it's the last step, try to ensure all the amount to be executed. Parameters ---------- - initial + order The seed to start an SAOE simulator is an order. ticks_per_step How many ticks per step. @@ -140,7 +142,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): See :class:`SAOEMetrics` for available columns. Index is ``datetime``, which is the **starting** time of each step.""" - metrics: SAOEMetrics | None + metrics: Optional[SAOEMetrics] """Metrics. Only available when done.""" twap_price: float @@ -159,15 +161,21 @@ def __init__( data_dir: Path, ticks_per_step: int = 30, deal_price_type: DealPriceType = "close", - vol_threshold: float | None = None, + vol_threshold: Optional[float] = None, ) -> None: + super().__init__(initial=order) + self.order = order self.ticks_per_step: int = ticks_per_step self.deal_price_type = deal_price_type self.vol_threshold = vol_threshold self.data_dir = data_dir - self.backtest_data = load_intraday_backtest_data( - self.data_dir, order.stock_id, pd.Timestamp(order.start_time.date()), self.deal_price_type, order.direction + self.backtest_data = load_simple_intraday_backtest_data( + self.data_dir, + order.stock_id, + pd.Timestamp(order.start_time.date()), + self.deal_price_type, + order.direction, ) self.ticks_index = self.backtest_data.get_time_index() @@ -188,9 +196,9 @@ def __init__( self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime") self.metrics = None - self.market_price: np.ndarray | None = None - self.market_vol: np.ndarray | None = None - self.market_vol_limit: np.ndarray | None = None + self.market_price: Optional[np.ndarray] = None + self.market_vol: Optional[np.ndarray] = None + self.market_vol_limit: Optional[np.ndarray] = None def step(self, amount: float) -> None: """Execute one step or SAOE. @@ -205,7 +213,8 @@ def step(self, amount: float) -> None: self.market_price = self.market_vol = None # avoid misuse exec_vol = self._split_exec_vol(amount) - assert self.market_price is not None and self.market_vol is not None + assert self.market_price is not None + assert self.market_vol is not None ticks_position = self.position - np.cumsum(exec_vol) @@ -363,7 +372,7 @@ def _metrics_collect( inner_amount=exec_vol.sum(), deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions trade_price=exec_avg_price, - trade_value=np.sum(market_price * exec_vol), + trade_value=float(np.sum(market_price * exec_vol)), position=self.position, ffr=float(exec_vol.sum() / self.order.amount), pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction), @@ -386,7 +395,9 @@ def _dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame: def price_advantage( - exec_price: _float_or_ndarray, baseline_price: float, direction: OrderDir | int + exec_price: _float_or_ndarray, + baseline_price: float, + direction: OrderDir | int, ) -> _float_or_ndarray: if baseline_price == 0: # something is wrong with data. Should be nan here if isinstance(exec_price, float): diff --git a/qlib/rl/order_execution/utils.py b/qlib/rl/order_execution/utils.py new file mode 100644 index 0000000000..e2d0de9812 --- /dev/null +++ b/qlib/rl/order_execution/utils.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Any, List, Tuple, cast + +import numpy as np +import pandas as pd + +from qlib.backtest import CommonInfrastructure, get_exchange +from qlib.backtest.account import Account +from qlib.backtest.decision import OrderDir +from qlib.backtest.executor import BaseExecutor +from qlib.rl.from_neutrader.config import ExchangeConfig +from qlib.rl.order_execution.simulator_simple import ONE_SEC, _float_or_ndarray +from qlib.utils.time import Freq + + +def get_common_infra( + config: ExchangeConfig, + trade_date: pd.Timestamp, + codes: List[str], + cash_limit: float = None, +) -> CommonInfrastructure: + # need to specify a range here for acceleration + if cash_limit is None: + trade_account = Account(init_cash=int(1e12), benchmark_config={}, pos_type="InfPosition") + else: + trade_account = Account( + init_cash=cash_limit, + benchmark_config={}, + pos_type="Position", + position_dict={code: {"amount": 1e12, "price": 1.0} for code in codes}, + ) + + exchange = get_exchange( + codes=codes, + freq="1min", + limit_threshold=config.limit_threshold, + deal_price=config.deal_price, + open_cost=config.open_cost, + close_cost=config.close_cost, + min_cost=config.min_cost if config.trade_unit is not None else 0, + start_time=trade_date, + end_time=trade_date + pd.DateOffset(1), + trade_unit=config.trade_unit, + volume_threshold=config.volume_threshold, + ) + + return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange) + + +def get_ticks_slice( + ticks_index: pd.DatetimeIndex, + start: pd.Timestamp, + end: pd.Timestamp, + include_end: bool = False, +) -> pd.DatetimeIndex: + if not include_end: + end = end - ONE_SEC + return ticks_index[ticks_index.slice_indexer(start, end)] + + +def dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame: + # dataframe.append is deprecated + other_df = pd.DataFrame(other).set_index("datetime") + other_df.index.name = "datetime" + + res = pd.concat([df, other_df], axis=0) + return res + + +def price_advantage( + exec_price: _float_or_ndarray, + baseline_price: float, + direction: OrderDir | int, +) -> _float_or_ndarray: + if baseline_price == 0: # something is wrong with data. Should be nan here + if isinstance(exec_price, float): + return 0.0 + else: + return np.zeros_like(exec_price) + if direction == OrderDir.BUY: + res = (1 - exec_price / baseline_price) * 10000 + elif direction == OrderDir.SELL: + res = (exec_price / baseline_price - 1) * 10000 + else: + raise ValueError(f"Unexpected order direction: {direction}") + res_wo_nan: np.ndarray = np.nan_to_num(res, nan=0.0) + if res_wo_nan.size == 1: + return res_wo_nan.item() + else: + return cast(_float_or_ndarray, res_wo_nan) + + +def get_portfolio_and_indicator(executor: BaseExecutor) -> Tuple[dict, dict]: + all_executors = executor.get_all_executors() + all_portfolio_metrics = { + "{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics() + for _executor in all_executors + if _executor.trade_account.is_port_metr_enabled() + } + + all_indicators = {} + for _executor in all_executors: + key = "{}{}".format(*Freq.parse(_executor.time_per_step)) + all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe() + all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator() + + return all_portfolio_metrics, all_indicators diff --git a/qlib/rl/reward.py b/qlib/rl/reward.py index 20d9858742..fd0dbdc86e 100644 --- a/qlib/rl/reward.py +++ b/qlib/rl/reward.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import Generic, Any, TypeVar, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Tuple, TypeVar from qlib.typehint import final @@ -20,7 +20,7 @@ class Reward(Generic[SimulatorState]): Subclass should implement ``reward(simulator_state)`` to implement their own reward calculation recipe. """ - env: EnvWrapper | None = None + env: Optional[EnvWrapper] = None @final def __call__(self, simulator_state: SimulatorState) -> float: @@ -30,14 +30,15 @@ def reward(self, simulator_state: SimulatorState) -> float: """Implement this method for your own reward.""" raise NotImplementedError("Implement reward calculation recipe in `reward()`.") - def log(self, name, value): + def log(self, name: str, value: Any) -> None: + assert self.env is not None self.env.logger.add_scalar(name, value) class RewardCombination(Reward): """Combination of multiple reward.""" - def __init__(self, rewards: dict[str, tuple[Reward, float]]): + def __init__(self, rewards: Dict[str, Tuple[Reward, float]]) -> None: self.rewards = rewards def reward(self, simulator_state: Any) -> float: diff --git a/qlib/rl/simulator.py b/qlib/rl/simulator.py index 56fc12042c..72e74b64fa 100644 --- a/qlib/rl/simulator.py +++ b/qlib/rl/simulator.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TypeVar, Generic, Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from .seed import InitialStateType @@ -49,7 +49,7 @@ class Simulator(Generic[InitialStateType, StateType, ActType]): Simulators are discouraged to use this, because it's prone to induce errors. """ - env: EnvWrapper | None = None + env: Optional[EnvWrapper] = None def __init__(self, initial: InitialStateType, **kwargs: Any) -> None: pass diff --git a/qlib/rl/trainer/api.py b/qlib/rl/trainer/api.py index 65abbd88d0..e9f48df249 100644 --- a/qlib/rl/trainer/api.py +++ b/qlib/rl/trainer/api.py @@ -3,17 +3,17 @@ from __future__ import annotations -from typing import Callable, Sequence, cast, Any +from typing import Any, Callable, Sequence, cast from tianshou.policy import BasePolicy -from qlib.rl.simulator import InitialStateType, Simulator -from qlib.rl.interpreter import StateInterpreter, ActionInterpreter +from qlib.rl.interpreter import ActionInterpreter, StateInterpreter from qlib.rl.reward import Reward +from qlib.rl.simulator import InitialStateType, Simulator from qlib.rl.utils import FiniteEnvType, LogWriter -from .vessel import TrainingVessel from .trainer import Trainer +from .vessel import TrainingVessel def train( diff --git a/qlib/rl/trainer/callbacks.py b/qlib/rl/trainer/callbacks.py index 72e2df99af..c76b674c6e 100644 --- a/qlib/rl/trainer/callbacks.py +++ b/qlib/rl/trainer/callbacks.py @@ -12,7 +12,7 @@ import time from datetime import datetime from pathlib import Path -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np import torch diff --git a/qlib/rl/trainer/trainer.py b/qlib/rl/trainer/trainer.py index c44419e055..f8f4c548d7 100644 --- a/qlib/rl/trainer/trainer.py +++ b/qlib/rl/trainer/trainer.py @@ -6,13 +6,13 @@ import copy from contextlib import AbstractContextManager, contextmanager from pathlib import Path -from typing import Any, Iterable, TypeVar, Sequence, cast +from typing import Any, Iterable, Sequence, TypeVar, cast import torch -from qlib.rl.simulator import InitialStateType -from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogCollector, LogWriter, LogBuffer, vectorize_env, LogLevel from qlib.log import get_module_logger +from qlib.rl.simulator import InitialStateType +from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogBuffer, LogCollector, LogLevel, LogWriter, vectorize_env from qlib.rl.utils.finite_env import FiniteVectorEnv from qlib.typehint import Literal diff --git a/qlib/rl/trainer/vessel.py b/qlib/rl/trainer/vessel.py index 9c0879ce02..e1ad0cb98e 100644 --- a/qlib/rl/trainer/vessel.py +++ b/qlib/rl/trainer/vessel.py @@ -4,7 +4,7 @@ from __future__ import annotations import weakref -from typing import Callable, ContextManager, Generic, Iterable, TYPE_CHECKING, Sequence, Any, TypeVar, cast, Dict +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast import numpy as np from tianshou.data import Collector, VectorReplayBuffer @@ -12,12 +12,11 @@ from tianshou.policy import BasePolicy from qlib.constant import INF -from qlib.rl.interpreter import StateType, ActType, ObsType, PolicyActType -from qlib.rl.simulator import InitialStateType, Simulator -from qlib.rl.interpreter import StateInterpreter, ActionInterpreter +from qlib.log import get_module_logger +from qlib.rl.interpreter import ActionInterpreter, ActType, ObsType, PolicyActType, StateInterpreter, StateType from qlib.rl.reward import Reward +from qlib.rl.simulator import InitialStateType, Simulator from qlib.rl.utils import DataQueue -from qlib.log import get_module_logger from qlib.rl.utils.finite_env import FiniteVectorEnv if TYPE_CHECKING: @@ -209,6 +208,9 @@ def _random_subset(name: str, collection: Sequence[T], size: int | None) -> Sequ order = np.random.permutation(len(collection)) res = [collection[o] for o in order[:size]] _logger.info( - "Fast running in development mode. Cut %s initial states from %d to %d.", name, len(collection), len(res) + "Fast running in development mode. Cut %s initial states from %d to %d.", + name, + len(collection), + len(res), ) return res diff --git a/qlib/rl/utils/__init__.py b/qlib/rl/utils/__init__.py index 4a1fa9d905..7c7ba205d8 100644 --- a/qlib/rl/utils/__init__.py +++ b/qlib/rl/utils/__init__.py @@ -1,7 +1,21 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .data_queue import * -from .env_wrapper import * -from .finite_env import * -from .log import * +from .data_queue import DataQueue +from .env_wrapper import EnvWrapper, EnvWrapperStatus +from .finite_env import FiniteEnvType, vectorize_env +from .log import ConsoleWriter, CsvWriter, LogBuffer, LogCollector, LogLevel, LogWriter + +__all__ = [ + "LogLevel", + "DataQueue", + "EnvWrapper", + "FiniteEnvType", + "LogCollector", + "LogWriter", + "vectorize_env", + "ConsoleWriter", + "CsvWriter", + "EnvWrapperStatus", + "LogBuffer", +] diff --git a/qlib/rl/utils/data_queue.py b/qlib/rl/utils/data_queue.py index c1f7f3ab04..8282888715 100644 --- a/qlib/rl/utils/data_queue.py +++ b/qlib/rl/utils/data_queue.py @@ -1,13 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import os +from __future__ import annotations + import multiprocessing +import os import threading import time import warnings from queue import Empty -from typing import TypeVar, Generic, Sequence, cast +from typing import Any, Generator, Generic, Sequence, TypeVar, cast from qlib.log import get_module_logger @@ -60,7 +62,7 @@ def __init__( shuffle: bool = True, producer_num_workers: int = 0, queue_maxsize: int = 0, - ): + ) -> None: if queue_maxsize == 0: if os.cpu_count() is not None: queue_maxsize = cast(int, os.cpu_count()) @@ -78,14 +80,14 @@ def __init__( self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize) self._done = multiprocessing.Value("i", 0) - def __enter__(self): + def __enter__(self) -> DataQueue: self.activate() return self def __exit__(self, exc_type, exc_val, exc_tb): self.cleanup() - def cleanup(self): + def cleanup(self) -> None: with self._done.get_lock(): self._done.value += 1 for repeat in range(500): @@ -105,7 +107,7 @@ def cleanup(self): break _logger.debug(f"Remaining items in queue collection done. Empty: {self._queue.empty()}") - def get(self, block=True): + def get(self, block: bool = True) -> Any: if not hasattr(self, "_first_get"): self._first_get = True if self._first_get: @@ -120,17 +122,17 @@ def get(self, block=True): if self._done.value: raise StopIteration # pylint: disable=raise-missing-from - def put(self, obj, block=True, timeout=None): - return self._queue.put(obj, block=block, timeout=timeout) + def put(self, obj: Any, block: bool = True, timeout: int = None) -> None: + self._queue.put(obj, block=block, timeout=timeout) - def mark_as_done(self): + def mark_as_done(self) -> None: with self._done.get_lock(): self._done.value = 1 - def done(self): + def done(self) -> int: return self._done.value - def activate(self): + def activate(self) -> DataQueue: if self._activated: raise ValueError("DataQueue can not activate twice.") thread = threading.Thread(target=self._producer, daemon=True) @@ -138,20 +140,20 @@ def activate(self): self._activated = True return self - def __del__(self): + def __del__(self) -> None: _logger.debug(f"__del__ of {__name__}.DataQueue") self.cleanup() - def __iter__(self): + def __iter__(self) -> Generator[Any, None, None]: if not self._activated: raise ValueError( "Need to call activate() to launch a daemon worker " "to produce data into data queue before using it. " - "You probably have forgotten to use the DataQueue in a with block." + "You probably have forgotten to use the DataQueue in a with block.", ) return self._consumer() - def _consumer(self): + def _consumer(self) -> Generator[Any, None, None]: while True: try: yield self.get() @@ -159,7 +161,7 @@ def _consumer(self): _logger.debug("Data consumer timed-out from get.") return - def _producer(self): + def _producer(self) -> None: # pytorch dataloader is used here only because we need its sampler and multi-processing from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel diff --git a/qlib/rl/utils/env_wrapper.py b/qlib/rl/utils/env_wrapper.py index f343e5b9b4..529bfe5973 100644 --- a/qlib/rl/utils/env_wrapper.py +++ b/qlib/rl/utils/env_wrapper.py @@ -4,14 +4,15 @@ from __future__ import annotations import weakref -from typing import Callable, Any, Iterable, Iterator, Generic, cast +from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, cast import gym +from gym import Space from qlib.rl.aux_info import AuxiliaryInfoCollector -from qlib.rl.simulator import Simulator, InitialStateType, StateType, ActType -from qlib.rl.interpreter import StateInterpreter, ActionInterpreter, PolicyActType, ObsType +from qlib.rl.interpreter import ActionInterpreter, ObsType, PolicyActType, StateInterpreter from qlib.rl.reward import Reward +from qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType from qlib.typehint import TypedDict from .finite_env import generate_nan_observation @@ -28,7 +29,7 @@ class InfoDict(TypedDict): aux_info: dict """Any information depends on auxiliary info collector.""" - log: dict[str, Any] + log: Dict[str, Any] """Collected by LogCollector.""" @@ -42,14 +43,15 @@ class EnvWrapperStatus(TypedDict): cur_step: int done: bool - initial_state: Any | None + initial_state: Optional[Any] obs_history: list action_history: list reward_history: list class EnvWrapper( - gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType] + gym.Env[ObsType, PolicyActType], + Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType], ): """Qlib-based RL environment, subclassing ``gym.Env``. A wrapper of components, including simulator, state-interpreter, action-interpreter, reward. @@ -97,11 +99,11 @@ def __init__( simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]], state_interpreter: StateInterpreter[StateType, ObsType], action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType], - seed_iterator: Iterable[InitialStateType] | None, - reward_fn: Reward | None = None, - aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None, - logger: LogCollector | None = None, - ): + seed_iterator: Optional[Iterable[InitialStateType]], + reward_fn: Reward = None, + aux_info_collector: AuxiliaryInfoCollector[StateType, Any] = None, + logger: LogCollector = None, + ) -> None: # Assign weak reference to wrapper. # # Use weak reference here, because: @@ -135,11 +137,11 @@ def __init__( self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None) @property - def action_space(self): + def action_space(self) -> Space: return self.action_interpreter.action_space @property - def observation_space(self): + def observation_space(self) -> Space: return self.state_interpreter.observation_space def reset(self, **kwargs: Any) -> ObsType: @@ -191,7 +193,7 @@ def reset(self, **kwargs: Any) -> ObsType: self.seed_iterator = None return generate_nan_observation(self.observation_space) - def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, float, bool, InfoDict]: + def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, float, bool, InfoDict]: """Environment step. See the code along with comments to get a sequence of things happening here. @@ -245,5 +247,5 @@ def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, fl info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info) return obs, rew, done, info_dict - def render(self): + def render(self, mode: str = "human") -> None: raise NotImplementedError("Render is not implemented in EnvWrapper.") diff --git a/qlib/rl/utils/finite_env.py b/qlib/rl/utils/finite_env.py index 6d7b0e2096..309b34e6dd 100644 --- a/qlib/rl/utils/finite_env.py +++ b/qlib/rl/utils/finite_env.py @@ -11,11 +11,10 @@ import copy import warnings from contextlib import contextmanager +from typing import Any, Callable, cast, Dict, Generator, List, Optional, Set, Tuple, Type, Union import gym import numpy as np -from typing import Any, Set, Callable, Type - from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv from qlib.typehint import Literal @@ -32,11 +31,11 @@ "vectorize_env", ] - FiniteEnvType = Literal["dummy", "subproc", "shmem"] +T = Union[dict, list, tuple, np.ndarray] -def fill_invalid(obj): +def fill_invalid(obj: int | float | bool | T) -> T: if isinstance(obj, (int, float, bool)): return fill_invalid(np.array(obj)) if hasattr(obj, "dtype"): @@ -55,11 +54,11 @@ def fill_invalid(obj): raise ValueError(f"Unsupported value to fill with invalid: {obj}") -def is_invalid(arr): - if hasattr(arr, "dtype"): +def is_invalid(arr: int | float | bool | T) -> bool: + if isinstance(arr, np.ndarray): if np.issubdtype(arr.dtype, np.floating): return np.isnan(arr).all() - return (np.iinfo(arr.dtype).max == arr).all() + return cast(bool, cast(np.ndarray, np.iinfo(arr.dtype).max == arr).all()) if isinstance(arr, dict): return all(is_invalid(o) for o in arr.values()) if isinstance(arr, (list, tuple)): @@ -140,44 +139,44 @@ def __init__( self._collector_guarded: bool = False - def _reset_alive_envs(self): + def _reset_alive_envs(self) -> None: if not self._alive_env_ids: # starting or running out self._alive_env_ids = set(range(self.env_num)) # to workaround with tianshou's buffer and batch - def _set_default_obs(self, obs): + def _set_default_obs(self, obs: Any) -> None: if obs is not None and self._default_obs is None: self._default_obs = copy.deepcopy(obs) - def _set_default_info(self, info): + def _set_default_info(self, info: Any) -> None: if info is not None and self._default_info is None: self._default_info = copy.deepcopy(info) - def _set_default_rew(self, rew): + def _set_default_rew(self, rew: Any) -> None: if rew is not None and self._default_rew is None: self._default_rew = copy.deepcopy(rew) - def _get_default_obs(self): + def _get_default_obs(self) -> Any: return copy.deepcopy(self._default_obs) - def _get_default_info(self): + def _get_default_info(self) -> Any: return copy.deepcopy(self._default_info) - def _get_default_rew(self): + def _get_default_rew(self) -> Any: return copy.deepcopy(self._default_rew) # END @staticmethod - def _postproc_env_obs(obs): + def _postproc_env_obs(obs: Any) -> Optional[Any]: # reserved for shmem vector env to restore empty observation if obs is None or check_nan_observation(obs): return None return obs @contextmanager - def collector_guard(self): + def collector_guard(self) -> Generator[FiniteVectorEnv, None, None]: """Guard the collector. Recommended to guard every collect. This guard is for two purposes. @@ -207,7 +206,10 @@ def collector_guard(self): for logger in self._logger: logger.on_env_all_done() - def reset(self, id=None): + def reset( + self, + id: int | List[int] | np.ndarray | None = None, + ) -> np.ndarray: assert not self._zombie # Check whether it's guarded by collector_guard() @@ -219,23 +221,23 @@ def reset(self, id=None): RuntimeWarning, ) - id = self._wrap_id(id) + wrapped_id = self._wrap_id(id) self._reset_alive_envs() # ask super to reset alive envs and remap to current index - request_id = list(filter(lambda i: i in self._alive_env_ids, id)) - obs = [None] * len(id) - id2idx = {i: k for k, i in enumerate(id)} + request_id = [i for i in wrapped_id if i in self._alive_env_ids] + obs = [None] * len(wrapped_id) + id2idx = {i: k for k, i in enumerate(wrapped_id)} if request_id: for i, o in zip(request_id, super().reset(request_id)): obs[id2idx[i]] = self._postproc_env_obs(o) - for i, o in zip(id, obs): + for i, o in zip(wrapped_id, obs): if o is None and i in self._alive_env_ids: self._alive_env_ids.remove(i) # logging - for i, o in zip(id, obs): + for i, o in zip(wrapped_id, obs): if i in self._alive_env_ids: for logger in self._logger: logger.on_env_reset(i, obs) @@ -248,19 +250,23 @@ def reset(self, id=None): obs[i] = self._get_default_obs() if not self._alive_env_ids: - # comment this line so that the env becomes indisposable + # comment this line so that the env becomes indispensable # self.reset() self._zombie = True raise StopIteration return np.stack(obs) - def step(self, action, id=None): + def step( + self, + action: np.ndarray, + id: int | List[int] | np.ndarray | None = None, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: assert not self._zombie - id = self._wrap_id(id) - id2idx = {i: k for k, i in enumerate(id)} - request_id = list(filter(lambda i: i in self._alive_env_ids, id)) - result = [[None, None, False, None] for _ in range(len(id))] + wrapped_id = self._wrap_id(id) + id2idx = {i: k for k, i in enumerate(wrapped_id)} + request_id = list(filter(lambda i: i in self._alive_env_ids, wrapped_id)) + result = [[None, None, False, None] for _ in range(len(wrapped_id))] # ask super to step alive envs and remap to current index if request_id: @@ -270,7 +276,7 @@ def step(self, action, id=None): result[id2idx[i]][0] = self._postproc_env_obs(result[id2idx[i]][0]) # logging - for i, r in zip(id, result): + for i, r in zip(wrapped_id, result): if i in self._alive_env_ids: for logger in self._logger: logger.on_env_step(i, *r) @@ -287,7 +293,8 @@ def step(self, action, id=None): if r[3] is None: result[i][3] = self._get_default_info() - return list(map(np.stack, zip(*result))) + ret = list(map(np.stack, zip(*result))) + return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret) class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv): @@ -306,7 +313,7 @@ def vectorize_env( env_factory: Callable[..., gym.Env], env_type: FiniteEnvType, concurrency: int, - logger: LogWriter | list[LogWriter], + logger: LogWriter | List[LogWriter], ) -> FiniteVectorEnv: """Helper function to create a vector env. Can be used to replace usual VectorEnv. @@ -350,7 +357,7 @@ def vectorize_env( def env_factory(): ... vectorize_env(env_factory, ...) """ - env_type_cls_mapping: dict[str, Type[FiniteVectorEnv]] = { + env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = { "dummy": FiniteDummyVectorEnv, "subproc": FiniteSubprocVectorEnv, "shmem": FiniteShmemVectorEnv, diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py index 409a48a768..e15bf7b54b 100644 --- a/qlib/rl/utils/log.py +++ b/qlib/rl/utils/log.py @@ -21,7 +21,7 @@ from collections import defaultdict from enum import IntEnum from pathlib import Path -from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence, Callable +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, List, Sequence, Set, Tuple, TypeVar import numpy as np import pandas as pd @@ -65,13 +65,13 @@ class LogCollector: ``min_loglevel`` is for optimization purposes: to avoid too much traffic on networks / in pipe. """ - _logged: dict[str, tuple[int, Any]] + _logged: Dict[str, Tuple[int, Any]] _min_loglevel: int - def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC): + def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: self._min_loglevel = int(min_loglevel) - def reset(self): + def reset(self) -> None: """Clear all collected contents.""" self._logged = {} @@ -104,7 +104,10 @@ def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel self._add_metric(name, scalar, loglevel) def add_array( - self, name: str, array: np.ndarray | pd.DataFrame | pd.Series, loglevel: int | LogLevel = LogLevel.PERIODIC + self, + name: str, + array: np.ndarray | pd.DataFrame | pd.Series, + loglevel: int | LogLevel = LogLevel.PERIODIC, ) -> None: """Add an array with name into logging.""" if loglevel < self._min_loglevel: @@ -127,7 +130,7 @@ def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIO self._add_metric(name, obj, loglevel) - def logs(self) -> dict[str, np.ndarray]: + def logs(self) -> Dict[str, np.ndarray]: return {key: np.asanyarray(value, dtype="object") for key, value in self._logged.items()} @@ -154,16 +157,16 @@ class LogWriter(Generic[ObsType, ActType]): active_env_ids: Set[int] """Active environment ids in vector env.""" - episode_lengths: dict[int, int] + episode_lengths: Dict[int, int] """Map from environment id to episode length.""" - episode_rewards: dict[int, list[float]] + episode_rewards: Dict[int, List[float]] """Map from environment id to episode total reward.""" - episode_logs: dict[int, list] + episode_logs: Dict[int, list] """Map from environment id to episode logs.""" - def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC): + def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: self.loglevel = loglevel self.global_step = 0 @@ -207,11 +210,12 @@ def load_state_dict(self, state_dict: dict) -> None: # These are runtime infos. # Though they are loaded, I don't think it really helps. self.active_env_ids = state_dict["active_env_ids"] - self.episode_lenghts = state_dict["episode_lengths"] + self.episode_lengths = state_dict["episode_lengths"] self.episode_rewards = state_dict["episode_rewards"] self.episode_logs = state_dict["episode_logs"] - def aggregation(self, array: Sequence[Any], name: str | None = None) -> Any: + @staticmethod + def aggregation(array: Sequence[Any], name: str | None = None) -> Any: """Aggregation function from step-wise to episode-wise. If it's a sequence of float, take the mean. @@ -229,7 +233,7 @@ def aggregation(self, array: Sequence[Any], name: str | None = None) -> Any: else: return array[0] - def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None: + def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None: """This is triggered at the end of each trajectory. Parameters @@ -242,7 +246,7 @@ def log_episode(self, length: int, rewards: list[float], contents: list[dict[str Logged contents for every steps. """ - def log_step(self, reward: float, contents: dict[str, Any]) -> None: + def log_step(self, reward: float, contents: Dict[str, Any]) -> None: """This is triggered at each step. Parameters @@ -265,7 +269,7 @@ def on_env_step(self, env_id: int, obs: ObsType, rew: float, done: bool, info: I # TODO: reward can be a list of list for MARL self.episode_rewards[env_id].append(rew) - values: dict[str, Any] = {} + values: Dict[str, Any] = {} for key, (loglevel, value) in info["log"].items(): if loglevel >= self.loglevel: # FIXME: this is actually incorrect (see last FIXME) @@ -393,11 +397,11 @@ class ConsoleWriter(LogWriter): def __init__( self, log_every_n_episode: int = 20, - total_episodes: int | None = None, + total_episodes: int = None, float_format: str = ":.4f", counter_format: str = ":4d", loglevel: int | LogLevel = LogLevel.PERIODIC, - ): + ) -> None: super().__init__(loglevel) # TODO: support log_every_n_step self.log_every_n_episode = log_every_n_episode @@ -412,15 +416,15 @@ def __init__( # FIXME: save & reload - def clear(self): + def clear(self) -> None: super().clear() # Clear average meters - self.metric_counts: dict[str, int] = defaultdict(int) - self.metric_sums: dict[str, float] = defaultdict(float) + self.metric_counts: Dict[str, int] = defaultdict(int) + self.metric_sums: Dict[str, float] = defaultdict(float) - def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None: + def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None: # Aggregate step-wise to episode-wise - episode_wise_contents: dict[str, list] = defaultdict(list) + episode_wise_contents: Dict[str, list] = defaultdict(list) for step_contents in contents: for name, value in step_contents.items(): @@ -429,7 +433,7 @@ def log_episode(self, length: int, rewards: list[float], contents: list[dict[str # Generate log contents and track them in average-meter. # This should be done at every step, regardless of periodic or not. - logs: dict[str, float] = {} + logs: Dict[str, float] = {} for name, values in episode_wise_contents.items(): logs[name] = self.aggregation(values, name) # type: ignore @@ -441,7 +445,7 @@ def log_episode(self, length: int, rewards: list[float], contents: list[dict[str # Only log periodically or at the end self.console_logger.info(self.generate_log_message(logs)) - def generate_log_message(self, logs: dict[str, float]) -> str: + def generate_log_message(self, logs: Dict[str, float]) -> str: if self.prefix: msg_prefix = self.prefix + " " else: @@ -471,29 +475,29 @@ class CsvWriter(LogWriter): SUPPORTED_TYPES = (float, str, pd.Timestamp) - all_records: list[dict[str, Any]] + all_records: List[Dict[str, Any]] # FIXME: save & reload - def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC): + def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: super().__init__(loglevel) self.output_dir = output_dir self.output_dir.mkdir(exist_ok=True) - def clear(self): + def clear(self) -> None: super().clear() self.all_records = [] - def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None: + def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None: # FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup - episode_wise_contents: dict[str, list] = defaultdict(list) + episode_wise_contents: Dict[str, list] = defaultdict(list) for step_contents in contents: for name, value in step_contents.items(): if isinstance(value, self.SUPPORTED_TYPES): episode_wise_contents[name].append(value) - logs: dict[str, float] = {} + logs: Dict[str, float] = {} for name, values in episode_wise_contents.items(): logs[name] = self.aggregation(values, name) # type: ignore diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 37998a4afb..27df347fc5 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -2,14 +2,14 @@ # Licensed under the MIT License. from __future__ import annotations -from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Generator, Optional +from abc import ABCMeta, abstractmethod +from typing import Any, Generator, Optional, TYPE_CHECKING, Union if TYPE_CHECKING: from qlib.backtest.exchange import Exchange from qlib.backtest.position import BasePosition -from typing import Tuple, Union +from typing import Tuple from ..backtest.decision import BaseTradeDecision from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager @@ -207,8 +207,18 @@ def get_data_cal_avail_range(self, rtype: str = "full") -> Tuple[int, int]: range_limit = self.outer_trade_decision.get_data_cal_range_limit(rtype=rtype) return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1]) + def post_exe_step(self, execute_result: list) -> None: + """ + A hook for doing sth after the corresponding executor finished its execution. + + Parameters + ---------- + execute_result : + the execution result + """ + -class RLStrategy(BaseStrategy): +class RLStrategy(BaseStrategy, metaclass=ABCMeta): """RL-based strategy""" def __init__( @@ -229,14 +239,14 @@ def __init__( self.policy = policy -class RLIntStrategy(RLStrategy): +class RLIntStrategy(RLStrategy, metaclass=ABCMeta): """(RL)-based (Strategy) with (Int)erpreter""" def __init__( self, policy, - state_interpreter: Union[dict, StateInterpreter], - action_interpreter: Union[dict, ActionInterpreter], + state_interpreter: dict | StateInterpreter, + action_interpreter: dict | ActionInterpreter, outer_trade_decision: BaseTradeDecision = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index 9f1aab4fed..ea935dcaa5 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -271,7 +271,7 @@ def __getitem__(self, indexing): if isinstance(_indexing, IndexData): _indexing = _indexing.data assert _indexing.ndim == 1 - if _indexing.dtype != np.bool: + if _indexing.dtype != bool: _indexing = np.array(list(index.index(i) for i in _indexing)) else: _indexing = index.index(_indexing) @@ -431,7 +431,7 @@ def sort_index(self, axis=0, inplace=True): # The code below could be simpler like methods in __getattribute__ def __invert__(self): - return self.__class__(~self.data.astype(np.bool), *self.indices) + return self.__class__(~self.data.astype(bool), *self.indices) def abs(self): """get the abs of data except np.NaN.""" diff --git a/tests/rl/test_logger.py b/tests/rl/test_logger.py index 2cf149a75f..c8ceca92ad 100644 --- a/tests/rl/test_logger.py +++ b/tests/rl/test_logger.py @@ -5,6 +5,8 @@ from pathlib import Path import re +from typing import Any, Tuple + import gym import numpy as np import pandas as pd @@ -24,16 +26,16 @@ class SimpleEnv(gym.Env[int, int]): - def __init__(self): + def __init__(self) -> None: self.logger = LogCollector() self.observation_space = gym.spaces.Discrete(2) self.action_space = gym.spaces.Discrete(2) - def reset(self): + def reset(self, *args: Any, **kwargs: Any) -> int: self.step_count = 0 return 0 - def step(self, action: int): + def step(self, action: int) -> Tuple[int, float, bool, dict]: self.logger.reset() self.logger.add_scalar("reward", 42.0) @@ -53,6 +55,9 @@ def step(self, action: int): return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={}) + def render(self, mode: str = "human") -> None: + pass + class AnyPolicy(BasePolicy): def forward(self, batch, state=None): @@ -86,7 +91,8 @@ def test_simple_env_logger(caplog): class SimpleSimulator(Simulator[int, float, float]): - def __init__(self, initial: int, **kwargs) -> None: + def __init__(self, initial: int, **kwargs: Any) -> None: + super(SimpleSimulator, self).__init__(initial, **kwargs) self.initial = float(initial) def step(self, action: float) -> None: diff --git a/tests/rl/test_qlib_simulator.py b/tests/rl/test_qlib_simulator.py new file mode 100644 index 0000000000..ca7820645f --- /dev/null +++ b/tests/rl/test_qlib_simulator.py @@ -0,0 +1,177 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import sys +from pathlib import Path + +import pandas as pd +import pytest + +from qlib.backtest.decision import Order, OrderDir +from qlib.backtest.executor import NestedExecutor, SimulatorExecutor +from qlib.backtest.utils import CommonInfrastructure +from qlib.contrib.strategy import TWAPStrategy +from qlib.rl.order_execution import CategoricalActionInterpreter +from qlib.rl.order_execution.simulator_qlib import ExchangeConfig, SingleAssetOrderExecutionQlib + +TOTAL_POSITION = 2100.0 + +python_version_request = pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") + + +def is_close(a: float, b: float, epsilon: float = 1e-4) -> bool: + return abs(a - b) <= epsilon + + +def get_order() -> Order: + return Order( + stock_id="SH600000", + amount=TOTAL_POSITION, + direction=OrderDir.BUY, + start_time=pd.Timestamp("2019-03-04 09:30:00"), + end_time=pd.Timestamp("2019-03-04 14:29:00"), + ) + + +def get_simulator(order: Order) -> SingleAssetOrderExecutionQlib: + def _inner_executor_fn(time_per_step: str, common_infra: CommonInfrastructure) -> NestedExecutor: + return NestedExecutor( + time_per_step=time_per_step, + inner_strategy=TWAPStrategy(), + inner_executor=SimulatorExecutor( + time_per_step="1min", + verbose=False, + trade_type=SimulatorExecutor.TT_SERIAL, + generate_report=False, + common_infra=common_infra, + track_data=True, + ), + common_infra=common_infra, + track_data=True, + ) + + DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "qlib_simulator" + + # fmt: off + qlib_config = { + "provider_uri_day": DATA_ROOT_DIR / "qlib_1d", + "provider_uri_1min": DATA_ROOT_DIR / "qlib_1min", + "feature_root_dir": DATA_ROOT_DIR / "qlib_handler_stock", + "feature_columns_today": [ + "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume", + "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5", + ], + "feature_columns_yesterday": [ + "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1", + "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1", + ], + } + # fmt: on + + exchange_config = ExchangeConfig( + limit_threshold=("$ask == 0", "$bid == 0"), + deal_price=("If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"), + volume_threshold={ + "all": ("cum", "0.2 * DayCumsum($volume, '9:30', '14:29')"), + "buy": ("current", "$askV1"), + "sell": ("current", "$bidV1"), + }, + open_cost=0.0005, + close_cost=0.0015, + min_cost=5.0, + trade_unit=None, + cash_limit=None, + generate_report=False, + ) + + return SingleAssetOrderExecutionQlib( + order=order, + time_per_step="30min", + qlib_config=qlib_config, + inner_executor_fn=_inner_executor_fn, + exchange_config=exchange_config, + ) + + +@python_version_request +def test_simulator_first_step(): + order = get_order() + simulator = get_simulator(order) + state = simulator.get_state() + assert state.cur_time == pd.Timestamp("2019-03-04 09:30:00") + assert state.position == TOTAL_POSITION + + AMOUNT = 300.0 + simulator.step(AMOUNT) + state = simulator.get_state() + assert state.cur_time == pd.Timestamp("2019-03-04 10:00:00") + assert state.position == TOTAL_POSITION - AMOUNT + assert len(state.history_exec) == 30 + assert state.history_exec.index[0] == pd.Timestamp("2019-03-04 09:30:00") + + assert is_close(state.history_exec["market_volume"].iloc[0], 109382.382812) + assert is_close(state.history_exec["market_price"].iloc[0], 149.566483) + assert (state.history_exec["amount"] == AMOUNT / 30).all() + assert (state.history_exec["deal_amount"] == AMOUNT / 30).all() + assert is_close(state.history_exec["trade_price"].iloc[0], 149.566483) + assert is_close(state.history_exec["trade_value"].iloc[0], 1495.664825) + assert is_close(state.history_exec["position"].iloc[0], TOTAL_POSITION - AMOUNT / 30) + # assert state.history_exec["ffr"].iloc[0] == 1 / 60 # FIXME + + assert is_close(state.history_steps["market_volume"].iloc[0], 1254848.5756835938) + assert state.history_steps["amount"].iloc[0] == AMOUNT + assert state.history_steps["deal_amount"].iloc[0] == AMOUNT + assert state.history_steps["ffr"].iloc[0] == 1.0 + assert is_close( + state.history_steps["pa"].iloc[0] * (1.0 if order.direction == OrderDir.SELL else -1.0), + (state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000, + ) + + +@python_version_request +def test_simulator_stop_twap() -> None: + order = get_order() + simulator = get_simulator(order) + NUM_STEPS = 7 + for i in range(NUM_STEPS): + simulator.step(TOTAL_POSITION / NUM_STEPS) + + HISTORY_STEP_LENGTH = 30 * NUM_STEPS + state = simulator.get_state() + assert len(state.history_exec) == HISTORY_STEP_LENGTH + + assert (state.history_exec["deal_amount"] == TOTAL_POSITION / HISTORY_STEP_LENGTH).all() + assert is_close(state.history_steps["position"].iloc[0], TOTAL_POSITION * (NUM_STEPS - 1) / NUM_STEPS) + assert is_close(state.history_steps["position"].iloc[-1], 0.0) + assert is_close(state.position, 0.0) + assert is_close(state.metrics["ffr"], 1.0) + + assert is_close(state.metrics["market_price"], state.backtest_data.get_deal_price().mean()) + assert is_close(state.metrics["market_volume"], state.backtest_data.get_volume().sum()) + assert is_close(state.metrics["trade_price"], state.metrics["market_price"]) + assert is_close(state.metrics["pa"], 0.0) + + assert simulator.done() + + +@python_version_request +def test_interpreter() -> None: + NUM_EXECUTION = 3 + order = get_order() + simulator = get_simulator(order) + interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION) + + NUM_STEPS = 7 + state = simulator.get_state() + position_history = [] + for i in range(NUM_STEPS): + simulator.step(interpreter_action(state, 1)) + state = simulator.get_state() + position_history.append(state.position) + + assert position_history[-1] == max(TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0) + + +if __name__ == "__main__": + test_simulator_first_step() + test_simulator_stop_twap() + test_interpreter() diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py index 98e5dd9817..78df41690a 100644 --- a/tests/rl/test_saoe_simple.py +++ b/tests/rl/test_saoe_simple.py @@ -9,7 +9,6 @@ import numpy as np import pandas as pd import pytest - import torch from tianshou.data import Batch @@ -17,8 +16,8 @@ from qlib.config import C from qlib.log import set_log_with_config from qlib.rl.data import pickle_styled -from qlib.rl.trainer import backtest, train from qlib.rl.order_execution import * +from qlib.rl.trainer import backtest, train from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8") @@ -38,7 +37,7 @@ def test_pickle_data_inspect(): - data = pickle_styled.load_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0) + data = pickle_styled.load_simple_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0) assert len(data) == 390 data = pickle_styled.load_intraday_processed_data(