diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 5846263c89..3f77a7273f 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -49,6 +49,15 @@ jobs:
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while installing locally won't
pip install -e .
+ - name: Test data downloads
+ run: |
+ if [ "$RUNNER_OS" == "Windows" ]; then
+ $CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
+ else
+ $CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
+ fi
+ shell: bash
+
- name: Install test dependencies
run: |
pip install --upgrade pip
diff --git a/.github/workflows/test_macos.yml b/.github/workflows/test_macos.yml
index 40c71c6a5f..0db80610a6 100644
--- a/.github/workflows/test_macos.yml
+++ b/.github/workflows/test_macos.yml
@@ -31,6 +31,7 @@ jobs:
python -m pip install black
python -m black qlib -l 120 --check --diff
# Test Qlib installed with pip
+
- name: Install Qlib with pip
run: |
python -m pip install numpy==1.19.5
diff --git a/.gitignore b/.gitignore
index 33a2a25303..a563ed5c7f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,6 +20,7 @@ dist/
.nvimrc
.vscode
+qlib/VERSION.txt
qlib/data/_libs/expanding.cpp
qlib/data/_libs/rolling.cpp
examples/estimator/estimator_example/
diff --git a/README.md b/README.md
index db0b6124e9..6ceb26e66d 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ Recent released features
Features released before 2021 are not listed here.
-
+
@@ -70,7 +70,7 @@ Your feedbacks about the features are very important.
# Framework of Qlib
-
+
@@ -247,19 +247,19 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports
- Forecasting signal (model prediction) analysis
- Cumulative Return of groups
- ![Cumulative Return](http://fintech.msra.cn/images_v060/analysis/analysis_model_cumulative_return.png?v=0.1)
+ ![Cumulative Return](http://fintech.msra.cn/images_v070/analysis/analysis_model_cumulative_return.png?v=0.1)
- Return distribution
- ![long_short](http://fintech.msra.cn/images_v060/analysis/analysis_model_long_short.png?v=0.1)
+ ![long_short](http://fintech.msra.cn/images_v070/analysis/analysis_model_long_short.png?v=0.1)
- Information Coefficient (IC)
- ![Information Coefficient](http://fintech.msra.cn/images_v060/analysis/analysis_model_IC.png?v=0.1)
- ![Monthly IC](http://fintech.msra.cn/images_v060/analysis/analysis_model_monthly_IC.png?v=0.1)
- ![IC](http://fintech.msra.cn/images_v060/analysis/analysis_model_NDQ.png?v=0.1)
+ ![Information Coefficient](http://fintech.msra.cn/images_v070/analysis/analysis_model_IC.png?v=0.1)
+ ![Monthly IC](http://fintech.msra.cn/images_v070/analysis/analysis_model_monthly_IC.png?v=0.1)
+ ![IC](http://fintech.msra.cn/images_v070/analysis/analysis_model_NDQ.png?v=0.1)
- Auto Correlation of forecasting signal (model prediction)
- ![Auto Correlation](http://fintech.msra.cn/images_v060/analysis/analysis_model_auto_correlation.png?v=0.1)
+ ![Auto Correlation](http://fintech.msra.cn/images_v070/analysis/analysis_model_auto_correlation.png?v=0.1)
- Portfolio analysis
- Backtest return
- ![Report](http://fintech.msra.cn/images_v060/analysis/report.png?v=0.1)
+ ![Report](http://fintech.msra.cn/images_v070/analysis/report.png?v=0.1)
+
+
\ No newline at end of file
diff --git a/docs/component/backtest.rst b/docs/component/backtest.rst
index 88e01e2de4..e83e1023a8 100644
--- a/docs/component/backtest.rst
+++ b/docs/component/backtest.rst
@@ -30,7 +30,7 @@ The simple example of the default strategy is as follows.
from qlib.contrib.evaluate import backtest
# pred_score is the prediction score
- report, positions = backtest(pred_score, topk=50, n_drop=0.5, verbose=False, limit_threshold=0.0095)
+ report, positions = backtest(pred_score, topk=50, n_drop=0.5, limit_threshold=0.0095)
To know more about backtesting with a specific ``Strategy``, please refer to `Portfolio Strategy `_.
diff --git a/docs/component/highfreq.rst b/docs/component/highfreq.rst
new file mode 100644
index 0000000000..13ebb959de
--- /dev/null
+++ b/docs/component/highfreq.rst
@@ -0,0 +1,120 @@
+.. _highfreq:
+
+==================================================
+Design of hierarchical order execution framework
+==================================================
+.. currentmodule:: qlib
+
+Introduction
+===================
+
+In order to support reinforcement learning algorithms for high-frequency trading, a corresponding framework is required. None of the publicly available high-frequency trading frameworks currently considers multi-layer trading mechanisms, so newly designed algorithms cannot directly use existing frameworks.
+Besides supporting basic intraday multi-layer trading, the linkage with daily-frequency strategies is also a factor that affects the performance evaluation of a strategy. Different daily strategies generate different order distributions and different patterns on different stocks. To verify that high-frequency trading strategies perform well on real trading orders, it is necessary to support linked trading across daily-frequency and high-frequency levels. Moreover, beyond enabling more accurate backtesting of high-frequency trading algorithms, taking the distribution of daily-frequency orders into account when training a high-frequency trading model allows the algorithm to be further optimized for product-specific daily-frequency orders.
+Therefore, innovation in the high-frequency trading framework is necessary to solve the problems mentioned above. To this end, we designed a hierarchical order execution framework that can link daily-frequency and intraday trading at different granularities.
+
+.. image:: ../_static/img/framework.svg
+
+The design of the framework is shown in the figure above. Each layer consists of a Trading Agent and an Execution Env. The Trading Agent has its own data processing module (Information Extractor), forecasting module (Forecast Model) and decision generator (Decision Generator). The Decision Generator produces decisions based on the forecast signals output by the Forecast Model, and these decisions are passed to the Execution Env, which returns the execution results. The frequency of the trading algorithm, the decision content and the execution environment can all be customized by users (e.g. intraday trading, daily-frequency trading, weekly-frequency trading), and the execution environment can nest a finer-grained trading algorithm and execution environment inside it (i.e. the sub-workflow in the figure; e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day, as the example below shows). The hierarchy division and the decision frequency of the framework are user-defined, which makes it easy for users to explore the effects of combining trading algorithms at different levels and to break down the barriers between optimizing trading algorithms at different levels.
+In addition to this innovation in the framework itself, the hierarchical order execution framework also takes various details of the real backtesting environment into account, minimizing the differences from the final production environment. The framework is also designed to unify the online and offline interfaces (e.g. the data pre-processing level supports using the same code to process both offline and online data) so as to reduce the cost of bringing a strategy online as much as possible.
+
+Prepare Data
+===================
+Please refer to ``../../examples/highfreq/README.md`` for instructions on preparing the high-frequency data.
+
+
+Example
+===========================
+
+Here is an example of highfreq execution.
+
+.. code-block:: python
+
+    import qlib
+    from qlib.data import D
+
+    # init qlib with both daily and 1min data
+    provider_uri_day = "~/.qlib/qlib_data/cn_data"
+    provider_uri_1min = "~/.qlib/qlib_data/cn_data_1min"
+    provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
+    qlib.init(provider_uri=provider_uri_map, expression_cache=None, dataset_cache=None)
+
+    # data freq and backtest time range
+    freq = "1min"
+    inst_list = D.list_instruments(D.instruments("all"), as_list=True)
+    start_time = "2020-01-01"
+    end_time = "2020-01-31"
+
+When initializing qlib with the default data, both the daily and the 1min frequency data need to be passed in via ``provider_uri_map``, as shown above.
+
+.. code-block:: python
+
+    from qlib.backtest.decision import TradeRangeByTime
+
+    # instrument pool for the random orders (assumed to be "all" here)
+    market = "all"
+
+    # random order strategy config
+    strategy_config = {
+        "class": "RandomOrderStrategy",
+        "module_path": "qlib.contrib.strategy.rule_strategy",
+        "kwargs": {
+            "trade_range": TradeRangeByTime("9:30", "15:00"),
+            "sample_ratio": 1.0,
+            "volume_ratio": 0.01,
+            "market": market,
+        },
+    }
+
+.. code-block:: python
+
+ # backtest config
+ backtest_config = {
+ "start_time": start_time,
+ "end_time": end_time,
+ "account": 100000000,
+ "benchmark": None,
+ "exchange_kwargs": {
+ "freq": freq,
+ "limit_threshold": 0.095,
+ "deal_price": "close",
+ "open_cost": 0.0005,
+ "close_cost": 0.0015,
+ "min_cost": 5,
+ "codes": market,
+ },
+ "pos_type": "InfPosition", # Position with infinitive position
+ }
+
+For more details about the backtest configuration, please refer to ``../../qlib/backtest``.
+
+.. code-block:: python
+
+    # executor config
+ executor_config = {
+ "class": "NestedExecutor",
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": "day",
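+            # the outer executor advances one trading day per step; each step is delegated to the inner executor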
+ "inner_executor": {
+ "class": "SimulatorExecutor",
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": freq,
+ "generate_portfolio_metrics": True,
+ "verbose": False,
+ # "verbose": True,
+ "indicator_config": {
+ "show_indicator": False,
+ },
+ },
+ },
+ "inner_strategy": {
+ "class": "TWAPStrategy",
+ "module_path": "qlib.contrib.strategy.rule_strategy",
+ },
+ "track_data": True,
+ "generate_portfolio_metrics": True,
+ "indicator_config": {
+ "show_indicator": True,
+ },
+ },
+ }
+
+``NestedExecutor`` indicates that the current executor is not the innermost layer, so its initialization parameters must contain ``inner_executor`` and ``inner_strategy``. ``SimulatorExecutor`` indicates that the current executor is the innermost layer. The innermost strategy used here is the TWAP strategy; the framework currently also supports a VWAP strategy.
+
+.. code-block:: python
+
+    from qlib.backtest import backtest
+
+    # run the nested backtest
+    portfolio_metrics_dict, indicator_dict = backtest(executor=executor_config, strategy=strategy_config, **backtest_config)
+
+The backtest metrics are returned in ``portfolio_metrics_dict`` and ``indicator_dict``.
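+
+A minimal sketch of inspecting the daily-level results (assuming the outer ``day`` executor produces a ``"1day"`` key and that each entry of ``portfolio_metrics_dict`` is a ``(report, positions)`` pair):
+
+.. code-block:: python
+
+    # both dicts are keyed by trading frequency; "1day" is an assumption here
+    portfolio_metrics, positions = portfolio_metrics_dict.get("1day")
+    print(portfolio_metrics.head())
+
+    # indicator_dict groups the execution indicators by the same frequency keys
+    print(indicator_dict.keys())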
diff --git a/docs/component/recorder.rst b/docs/component/recorder.rst
index cc425fa8e6..5a7d195d64 100644
--- a/docs/component/recorder.rst
+++ b/docs/component/recorder.rst
@@ -123,7 +123,6 @@ Here is a simple example of what is done in ``PortAnaRecord``, which users can r
"n_drop": 5,
}
BACKTEST_CONFIG = {
- "verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": BENCHMARK,
diff --git a/docs/component/strategy.rst b/docs/component/strategy.rst
index e4a5a94d15..c9d002ca1b 100644
--- a/docs/component/strategy.rst
+++ b/docs/component/strategy.rst
@@ -93,7 +93,6 @@ Usage & Example
"n_drop": 5,
}
BACKTEST_CONFIG = {
- "verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": BENCHMARK,
diff --git a/docs/component/workflow.rst b/docs/component/workflow.rst
index 2b7ec19ad3..84522af998 100644
--- a/docs/component/workflow.rst
+++ b/docs/component/workflow.rst
@@ -54,7 +54,6 @@ Below is a typical config file of ``qrun``.
topk: 50
n_drop: 5
backtest:
- verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
@@ -242,7 +241,6 @@ The following script is the configuration of `backtest` and the `strategy` used
topk: 50
n_drop: 5
backtest:
- verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
diff --git a/docs/hidden/tuner.rst b/docs/hidden/tuner.rst
index 6d62f899f5..8abf2ec7c0 100644
--- a/docs/hidden/tuner.rst
+++ b/docs/hidden/tuner.rst
@@ -93,7 +93,6 @@ We write a simple configuration example as following,
fend_time: 2018-12-11
backtest:
normal_backtest_args:
- verbose: False
limit_threshold: 0.095
account: 500000
benchmark: SH000905
@@ -306,7 +305,6 @@ About the data and backtest
fend_time: 2018-12-11
backtest:
normal_backtest_args:
- verbose: False
limit_threshold: 0.095
account: 500000
benchmark: SH000905
diff --git a/docs/introduction/introduction.rst b/docs/introduction/introduction.rst
index 06fac46faf..a55edd5eca 100644
--- a/docs/introduction/introduction.rst
+++ b/docs/introduction/introduction.rst
@@ -15,7 +15,7 @@ With ``Qlib``, users can easily try their ideas to create better Quant investmen
Framework
===================
-.. image:: ../_static/img/framework.png
+.. image:: ../_static/img/framework.svg
:align: center
diff --git a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml
index 878e3c0654..039040d8f8 100755
--- a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml
+++ b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml
@@ -34,19 +34,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: ALSTM
@@ -81,7 +85,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml
index 6226cdaf26..88c6fcd07e 100644
--- a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml
+++ b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml
@@ -26,19 +26,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: ALSTM
@@ -71,7 +75,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml
index af556dc871..18e19bd0f8 100644
--- a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml
+++ b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml
@@ -12,19 +12,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: CatBoostModel
@@ -53,7 +57,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml
index f7dc26f5d8..a6cdd18829 100644
--- a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml
+++ b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml
@@ -19,19 +19,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: CatBoostModel
@@ -60,7 +64,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml
index a12df802da..fb8cce74d8 100644
--- a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml
+++ b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml
@@ -12,19 +12,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: DEnsembleModel
@@ -75,16 +79,18 @@ task:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
- record:
+ record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
- kwargs:
+ kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
- kwargs:
+ kwargs:
config: *port_analysis_config
diff --git a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml
index 415448f0be..d1fbd78075 100644
--- a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml
+++ b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml
@@ -19,19 +19,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: DEnsembleModel
@@ -82,10 +86,12 @@ task:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
- record:
+ record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
@@ -93,5 +99,5 @@ task:
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
- kwargs:
+ kwargs:
config: *port_analysis_config
\ No newline at end of file
diff --git a/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml b/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml
index 5fb7d5cc10..5387adc248 100644
--- a/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml
+++ b/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml
@@ -33,19 +33,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: GATs
@@ -79,7 +83,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml b/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml
index 86ce510185..1ffd6780e4 100644
--- a/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml
+++ b/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml
@@ -26,19 +26,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: GATs
@@ -71,7 +75,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml b/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml
index d3078314c1..82c6908890 100755
--- a/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml
+++ b/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml
@@ -34,19 +34,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: GRU
@@ -80,7 +84,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml b/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml
index 2494d40f08..02c81c8507 100644
--- a/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml
+++ b/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml
@@ -26,19 +26,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: GRU
@@ -70,7 +74,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml
index 14dd69d0a0..f4412c2623 100755
--- a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml
+++ b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml
@@ -34,19 +34,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: LSTM
@@ -80,7 +84,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml
index 2aa5fd061c..10a1dc5dff 100644
--- a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml
+++ b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml
@@ -26,19 +26,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: LSTM
@@ -70,7 +74,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/LightGBM/features_resample_N.py b/examples/benchmarks/LightGBM/features_resample_N.py
new file mode 100644
index 0000000000..13061513cb
--- /dev/null
+++ b/examples/benchmarks/LightGBM/features_resample_N.py
@@ -0,0 +1,18 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import pandas as pd
+
+from qlib.data.inst_processor import InstProcessor
+from qlib.utils.resam import resam_calendar
+
+
+class ResampleNProcessor(InstProcessor):
+ def __init__(self, target_frq: str, **kwargs):
+ self.target_frq = target_frq
+
+ def __call__(self, df: pd.DataFrame, *args, **kwargs):
+ df.index = pd.to_datetime(df.index)
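+        # resam_calendar converts the raw 1min calendar into the target-frequency calendar used for reindexing below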
+ res_index = resam_calendar(df.index, "1min", self.target_frq)
+ df = df.resample(self.target_frq).last().reindex(res_index)
+ return df
diff --git a/examples/benchmarks/LightGBM/multi_freq_handler.py b/examples/benchmarks/LightGBM/multi_freq_handler.py
new file mode 100644
index 0000000000..07d7ac27c4
--- /dev/null
+++ b/examples/benchmarks/LightGBM/multi_freq_handler.py
@@ -0,0 +1,135 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import pandas as pd
+
+from qlib.data.dataset.loader import QlibDataLoader
+from qlib.contrib.data.handler import DataHandlerLP, _DEFAULT_LEARN_PROCESSORS, check_transform_proc
+
+
+class Avg15minLoader(QlibDataLoader):
+ def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
+ df = super(Avg15minLoader, self).load(instruments, start_time, end_time)
+ if self.is_group:
+            # rename both feature_day (day freq) and feature_15min (1min freq, averaged every 15 minutes) to the single group "feature"
+ df.columns = df.columns.map(lambda x: ("feature", x[1]) if x[0].startswith("feature") else x)
+ return df
+
+
+class Avg15minHandler(DataHandlerLP):
+ def __init__(
+ self,
+ instruments="csi500",
+ start_time=None,
+ end_time=None,
+ freq="day",
+ infer_processors=[],
+ learn_processors=_DEFAULT_LEARN_PROCESSORS,
+ fit_start_time=None,
+ fit_end_time=None,
+ process_type=DataHandlerLP.PTYPE_A,
+ filter_pipe=None,
+ inst_processor=None,
+ **kwargs,
+ ):
+ infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
+ learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
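+        # Avg15minLoader merges the day-freq and 1min-freq feature groups into a single "feature" group (see its load method above)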
+ data_loader = Avg15minLoader(
+ config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processor=inst_processor
+ )
+ super().__init__(
+ instruments=instruments,
+ start_time=start_time,
+ end_time=end_time,
+ data_loader=data_loader,
+ infer_processors=infer_processors,
+ learn_processors=learn_processors,
+ process_type=process_type,
+ )
+
+ def loader_config(self):
+
+ # Results for dataset: df: pd.DataFrame
+ # len(df.columns) == 6 + 6 * 16, len(df.index.get_level_values(level="datetime").unique()) == T
+ # df.columns: close0, close1, ..., close16, open0, ..., open16, ..., vwap16
+ # freq == day:
+ # close0, open0, low0, high0, volume0, vwap0
+ # freq == 1min:
+ # close1, ..., close16, ..., vwap1, ..., vwap16
+ # df.index.name == ["datetime", "instrument"]: pd.MultiIndex
+ # Example:
+ # feature ... label
+ # close0 open0 low0 ... vwap1 vwap16 LABEL0
+ # datetime instrument ...
+ # 2020-10-09 SH600000 11.794546 11.819587 11.769505 ... NaN NaN -0.005214
+ # 2020-10-15 SH600000 12.044961 11.944795 11.932274 ... NaN NaN -0.007202
+ # ... ... ... ... ... ... ... ...
+ # 2021-05-28 SZ300676 6.369684 6.495406 6.306568 ... NaN NaN -0.001321
+ # 2021-05-31 SZ300676 6.601626 6.465643 6.465130 ... NaN NaN -0.023428
+
+ # features day: len(columns) == 6, freq = day
+ # $close is the closing price of the current trading day:
+ # if the user needs to get the `close` before the last T days, use Ref($close, T-1), for example:
+ # $close Ref($close, 1) Ref($close, 2) Ref($close, 3) Ref($close, 4)
+ # instrument datetime
+ # SH600519 2021-06-01 244.271530
+ # 2021-06-02 242.205917 244.271530
+ # 2021-06-03 242.229889 242.205917 244.271530
+ # 2021-06-04 245.421524 242.229889 242.205917 244.271530
+ # 2021-06-07 247.547089 245.421524 242.229889 242.205917 244.271530
+
+ # WARNING: Ref($close, N), if N == 0, Ref($close, N) ==> $close
+
+ fields = ["$close", "$open", "$low", "$high", "$volume", "$vwap"]
+ # names: close0, open0, ..., vwap0
+ names = list(map(lambda x: x.strip("$") + "0", fields))
+
+ config = {"feature_day": (fields, names)}
+
+ # features 15min: len(columns) == 6 * 16, freq = 1min
+ # $close is the closing price of the current trading day:
+ # if the user gets 'close' for the i-th 15min of the last T days, use `Ref(Mean($close, 15), (T-1) * 240 + i * 15)`, for example:
+ # Ref(Mean($close, 15), 225) Ref(Mean($close, 15), 465) Ref(Mean($close, 15), 705)
+ # instrument datetime
+ # SH600519 2021-05-31 241.769897 243.077942 244.712997
+ # 2021-06-01 244.271530 241.769897 243.077942
+ # 2021-06-02 242.205917 244.271530 241.769897
+
+ # WARNING: Ref(Mean($close, 15), N), if N == 0, Ref(Mean($close, 15), N) ==> Mean($close, 15)
+
+ # Results of the current script:
+ # time: 09:00 --> 09:14, ..., 14:45 --> 14:59
+ # fields: Ref(Mean($close, 15), 225), ..., Mean($close, 15)
+ # name: close1, ..., close16
+ #
+
+ # Expression description: take close as an example
+ # Mean($close, 15) ==> df["$close"].rolling(15, min_periods=1).mean()
+ # Ref(Mean($close, 15), 15) ==> df["$close"].rolling(15, min_periods=1).mean().shift(15)
+
+ # NOTE: The last data of each trading day, which is the average of the i-th 15 minutes
+
+ # Average:
+        # Average of the i-th 15-minute period of each trading day: 1 <= i <= 240 // 15
+ # Avg(15minutes): Ref(Mean($close, 15), 240 - i * 15)
+ #
+ # Average of the first 15 minutes of each trading day; i = 1
+ # Avg(09:00 --> 09:14), df.index.loc["09:14"]: Ref(Mean($close, 15), 240- 1 * 15) ==> Ref(Mean($close, 15), 225)
+ # Average of the last 15 minutes of each trading day; i = 16
+ # Avg(14:45 --> 14:59), df.index.loc["14:59"]: Ref(Mean($close, 15), 240 - 16 * 15) ==> Ref(Mean($close, 15), 0) ==> Mean($close, 15)
+
+ # 15min resample to day
+ # df.resample("1d").last()
+ tmp_fields = []
+ tmp_names = []
+ for i, _f in enumerate(fields):
+ _fields = [f"Ref(Mean({_f}, 15), {j * 15})" for j in range(1, 240 // 15)]
+ _names = [f"{names[i][:-1]}{int(names[i][-1])+j}" for j in range(240 // 15 - 1, 0, -1)]
+ _fields.append(f"Mean({_f}, 15)")
+ _names.append(f"{names[i][:-1]}{int(names[i][-1])+240 // 15}")
+ tmp_fields += _fields
+ tmp_names += _names
+ config["feature_15min"] = (tmp_fields, tmp_names)
+ # label
+ config["label"] = (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
+ return config
diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
index 6caa4b2a59..2bb21d41dd 100644
--- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
+++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
@@ -12,19 +12,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: LGBModel
@@ -54,7 +58,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml
index 92df151331..b8af19ec1b 100644
--- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml
+++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml
@@ -19,19 +19,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: LGBModel
@@ -61,7 +65,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml
index 78f567eb31..a92f342a17 100644
--- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml
+++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml
@@ -27,19 +27,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: LGBModel
@@ -69,7 +73,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml
new file mode 100644
index 0000000000..829c871159
--- /dev/null
+++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml
@@ -0,0 +1,86 @@
+qlib_init:
+ provider_uri:
+ day: "~/.qlib/qlib_data/cn_data"
+ 1min: "~/.qlib/qlib_data/cn_data_1min"
+ region: cn
+ dataset_cache: null
+ maxtasksperchild: null
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ # 1min closing time is 15:00:00
+ end_time: "2020-08-01 15:00:00"
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ freq:
+ label: day
+ feature_15min: 1min
+ feature_day: day
+ # with label as reference
+ inst_processor:
+ feature_15min:
+ - class: ResampleNProcessor
+ module_path: features_resample_N.py
+ kwargs:
+ target_frq: 1d
+
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy
+ kwargs:
+ model:
+ dataset:
+ topk: 50
+ n_drop: 5
+    backtest:
+        start_time: 2017-01-01
+        end_time: 2020-08-01
+        account: 100000000
+        benchmark: *benchmark
+        exchange_kwargs:
+            limit_threshold: 0.095
+            deal_price: close
+            open_cost: 0.0005
+            close_cost: 0.0015
+            min_cost: 5
+task:
+ model:
+ class: LGBModel
+ module_path: qlib.contrib.model.gbdt
+ kwargs:
+ loss: mse
+ colsample_bytree: 0.8879
+ learning_rate: 0.2
+ subsample: 0.8789
+ lambda_l1: 205.6999
+ lambda_l2: 580.9768
+ max_depth: 8
+ num_leaves: 210
+ num_threads: 20
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: Avg15minHandler
+ module_path: multi_freq_handler.py
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ model:
+ dataset:
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
diff --git a/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml b/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml
index ef2fee4c55..9f055a62cc 100644
--- a/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml
+++ b/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml
@@ -26,19 +26,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: LinearModel
@@ -57,16 +61,18 @@ task:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
- record:
+ record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
- kwargs:
+ kwargs:
ana_long_short: True
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
- kwargs:
+ kwargs:
config: *port_analysis_config
diff --git a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml
index d7e9673333..cd31ecd1e0 100644
--- a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml
+++ b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml
@@ -34,19 +34,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: LocalformerModel
@@ -70,13 +74,15 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
- ana_long_short: False
- ann_scaler: 252
+ ana_long_short: False
+ ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
- config: *port_analysis_config
+ config: *port_analysis_config
diff --git a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml
index 1c8489461c..f9cc091fde 100644
--- a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml
+++ b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml
@@ -26,19 +26,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: LocalformerModel
@@ -59,15 +63,17 @@ task:
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- - class: SignalRecord
- module_path: qlib.workflow.record_temp
- kwargs: {}
- - class: SigAnaRecord
- module_path: qlib.workflow.record_temp
- kwargs:
- ana_long_short: False
- ann_scaler: 252
- - class: PortAnaRecord
- module_path: qlib.workflow.record_temp
- kwargs:
- config: *port_analysis_config
\ No newline at end of file
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ model:
+ dataset:
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
diff --git a/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml b/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml
index b177a810be..8303f3945c 100644
--- a/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml
+++ b/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml
@@ -39,19 +39,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: DNNModelPytorch
@@ -83,7 +87,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml b/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml
index 18920399f8..f52c5930db 100644
--- a/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml
+++ b/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml
@@ -27,19 +27,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: DNNModelPytorch
@@ -70,7 +74,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/SFM/requirements.txt b/examples/benchmarks/SFM/requirements.txt
index 6a3d13097a..16de0a4384 100644
--- a/examples/benchmarks/SFM/requirements.txt
+++ b/examples/benchmarks/SFM/requirements.txt
@@ -1,4 +1,4 @@
pandas==1.1.2
numpy==1.17.4
scikit_learn==0.23.2
-torch==1.7.0
\ No newline at end of file
+torch==1.7.0
diff --git a/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml b/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml
index a23fe3854c..5c66400bbe 100644
--- a/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml
+++ b/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml
@@ -26,19 +26,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: SFM
@@ -73,7 +77,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/TCTS/requirements.txt b/examples/benchmarks/TCTS/requirements.txt
new file mode 100644
index 0000000000..6a3d13097a
--- /dev/null
+++ b/examples/benchmarks/TCTS/requirements.txt
@@ -0,0 +1,4 @@
+pandas==1.1.2
+numpy==1.17.4
+scikit_learn==0.23.2
+torch==1.7.0
\ No newline at end of file
diff --git a/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml b/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
index 89c66f992a..7ca6e937f6 100644
--- a/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
+++ b/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
@@ -28,19 +28,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: TCTS
@@ -80,7 +84,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/TFT/tft.py b/examples/benchmarks/TFT/tft.py
index 3908b27773..a854c2dd9e 100644
--- a/examples/benchmarks/TFT/tft.py
+++ b/examples/benchmarks/TFT/tft.py
@@ -304,6 +304,7 @@ def to_pickle(self, path: Union[Path, str]):
path : Union[Path, str]
the target path to be dumped
"""
+ # FIXME: implementing saving tensorflow models
# save tensorflow model
# path = Path(path)
# path.mkdir(parents=True)
@@ -311,4 +312,4 @@ def to_pickle(self, path: Union[Path, str]):
# save qlib model wrapper
self.model = None
- super(TFTModel, self).to_pickle(path / "qlib_model")
+ super(TFTModel, self).to_pickle(path)
diff --git a/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml b/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml
index dba37ab637..0508ce676c 100644
--- a/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml
+++ b/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml
@@ -14,19 +14,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: TFTModel
@@ -46,7 +50,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/TRA/README.md b/examples/benchmarks/TRA/README.md
index 6d3e7a4769..5ff5b480e3 100644
--- a/examples/benchmarks/TRA/README.md
+++ b/examples/benchmarks/TRA/README.md
@@ -69,6 +69,7 @@ After running the scripts, you can find result files in path `./output`:
* `pred.pkl` - the prediction scores and output for inference.
Evaluation metrics reported in the paper:
+These results were generated with qlib==0.7.1.
| Methods | MSE| MAE| IC | ICIR | AR | AV | SR | MDD |
|-------|-------|------|-----|-----|-----|-----|-----|-----|
diff --git a/examples/benchmarks/TRA/requirements.txt b/examples/benchmarks/TRA/requirements.txt
new file mode 100644
index 0000000000..ab819ec1c4
--- /dev/null
+++ b/examples/benchmarks/TRA/requirements.txt
@@ -0,0 +1,5 @@
+pandas==1.1.2
+numpy==1.17.4
+scikit_learn==0.23.2
+torch==1.7.0
+seaborn
diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml
index 09ff8893b1..f273f62eeb 100644
--- a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml
+++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml
@@ -53,21 +53,25 @@ model_config: &model_config
dropout: 0.0
port_analysis_config: &port_analysis_config
- strategy:
- class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
- kwargs:
- topk: 50
- n_drop: 5
- backtest:
- verbose: False
- limit_threshold: 0.095
- account: 100000000
- benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy
+ kwargs:
+ model:
+ dataset:
+ topk: 50
+ n_drop: 5
+ backtest:
+ start_time: 2017-01-01
+ end_time: 2020-08-01
+ account: 100000000
+ benchmark: *benchmark
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
@@ -117,13 +121,15 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
- kwargs:
+ kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
- kwargs:
+ kwargs:
config: *port_analysis_config
diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml
index dd413b00a6..8dc82cb999 100644
--- a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml
+++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml
@@ -47,21 +47,25 @@ model_config: &model_config
dropout: 0.2
port_analysis_config: &port_analysis_config
- strategy:
- class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
- kwargs:
- topk: 50
- n_drop: 5
- backtest:
- verbose: False
- limit_threshold: 0.095
- account: 100000000
- benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy
+ kwargs:
+ model:
+ dataset:
+ topk: 50
+ n_drop: 5
+ backtest:
+ start_time: 2017-01-01
+ end_time: 2020-08-01
+ account: 100000000
+ benchmark: *benchmark
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
@@ -111,10 +115,12 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
- kwargs:
+ kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml
index 84dee5d72a..bd5b132ee2 100644
--- a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml
+++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml
@@ -47,21 +47,25 @@ model_config: &model_config
dropout: 0.0
port_analysis_config: &port_analysis_config
- strategy:
- class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
- kwargs:
- topk: 50
- n_drop: 5
- backtest:
- verbose: False
- limit_threshold: 0.095
- account: 100000000
- benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy
+ kwargs:
+ model:
+ dataset:
+ topk: 50
+ n_drop: 5
+ backtest:
+ start_time: 2017-01-01
+ end_time: 2020-08-01
+ account: 100000000
+ benchmark: *benchmark
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
@@ -111,10 +115,12 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
- kwargs:
+ kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
diff --git a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml
index cf8ef7411e..0fa1b23d52 100644
--- a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml
+++ b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml
@@ -26,19 +26,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: TabnetModel
@@ -63,7 +67,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml
index 5023e9b3da..0c798ae304 100644
--- a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml
+++ b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml
@@ -26,19 +26,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: TabnetModel
@@ -63,7 +67,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml
index 54707386f9..6174abf2eb 100644
--- a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml
+++ b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml
@@ -34,19 +34,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: TransformerModel
@@ -70,7 +74,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml
index e568a1b307..883c18cdcc 100644
--- a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml
+++ b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml
@@ -26,19 +26,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: TransformerModel
@@ -61,7 +65,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
@@ -70,4 +76,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
- config: *port_analysis_config
\ No newline at end of file
+ config: *port_analysis_config
diff --git a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml
index 4caaa6f622..502a5e73c5 100644
--- a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml
+++ b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml
@@ -12,19 +12,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: XGBModel
@@ -52,7 +56,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml
index 7887a25a6c..a2e40eefbc 100644
--- a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml
+++ b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml
@@ -19,19 +19,23 @@ data_handler_config: &data_handler_config
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
- module_path: qlib.contrib.strategy.strategy
+ module_path: qlib.contrib.strategy
kwargs:
+ model:
+ dataset:
topk: 50
n_drop: 5
backtest:
- verbose: False
- limit_threshold: 0.095
+ start_time: 2017-01-01
+ end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
+ exchange_kwargs:
+ limit_threshold: 0.095
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
task:
model:
class: XGBModel
@@ -59,7 +63,9 @@ task:
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
- kwargs: {}
+ kwargs:
+ model:
+ dataset:
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
diff --git a/examples/highfreq/highfreq_ops.py b/examples/highfreq/highfreq_ops.py
index ef784b34cb..175f4f66be 100644
--- a/examples/highfreq/highfreq_ops.py
+++ b/examples/highfreq/highfreq_ops.py
@@ -5,30 +5,7 @@
from qlib.config import C
from qlib.data.cache import H
from qlib.data.data import Cal
-
-
-def get_calendar_day(freq="day", future=False):
- """Load High-Freq Calendar Date Using Memcache.
-
- Parameters
- ----------
- freq : str
- frequency of read calendar file.
- future : bool
- whether including future trading day.
-
- Returns
- -------
- _calendar:
- array of date.
- """
- flag = f"{freq}_future_{future}_day"
- if flag in H["c"]:
- _calendar = H["c"][flag]
- else:
- _calendar = np.array(list(map(lambda x: pd.Timestamp(x.date()), Cal.load_calendar(freq, future))))
- H["c"][flag] = _calendar
- return _calendar
+from qlib.contrib.ops.high_freq import get_calendar_day
class DayLast(ElemOperator):
diff --git a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml
index 45c59c6705..93d9dde56e 100644
--- a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml
+++ b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml
@@ -59,7 +59,7 @@ task:
record:
- class: "SignalRecord"
module_path: "qlib.workflow.record_temp"
- kwargs: {}
+ kwargs:
- class: "HFSignalRecord"
module_path: "qlib.workflow.record_temp"
kwargs: {}
\ No newline at end of file
diff --git a/examples/nested_decision_execution/README.md b/examples/nested_decision_execution/README.md
new file mode 100644
index 0000000000..382e5a3206
--- /dev/null
+++ b/examples/nested_decision_execution/README.md
@@ -0,0 +1,30 @@
+# Nested Decision Execution
+
+This workflow is an example of nested decision execution in backtesting. Qlib supports nested decision execution, which means that users can use different strategies to make trade decisions at different frequencies.
+
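+As a minimal sketch of such a configuration (illustrative only; the classes and module paths below are the ones used by `workflow.py` in this folder), an outer `NestedExecutor` steps at one frequency and delegates order execution to an inner strategy and executor at a finer frequency:
+
+```python
+# illustrative nested executor config; see workflow.py for the full, tuned version
+executor_config = {
+    "class": "NestedExecutor",
+    "module_path": "qlib.backtest.executor",
+    "kwargs": {
+        "time_per_step": "day",  # outer decision frequency
+        "inner_strategy": {
+            "class": "TWAPStrategy",
+            "module_path": "qlib.contrib.strategy.rule_strategy",
+        },
+        "inner_executor": {
+            "class": "SimulatorExecutor",
+            "module_path": "qlib.backtest.executor",
+            "kwargs": {"time_per_step": "30min", "generate_portfolio_metrics": True},
+        },
+    },
+}
+```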
+## Weekly Portfolio Generation and Daily Order Execution
+
+This workflow provides an example that uses a TopkDropoutStrategy (a strategy based on the daily-frequency LightGBM model) at a weekly frequency for portfolio generation and uses SBBStrategyEMA (a rule-based strategy that uses EMA for decision-making) to execute orders at a daily frequency.
+
+### Usage
+
+Start backtesting by running the following command:
+```bash
+ python workflow.py backtest
+```
+
+Start collecting data by running the following command:
+```bash
+ python workflow.py collect_data
+```
+
+## Daily Portfolio Generation and Minutely Order Execution
+
+This workflow also provides a high-frequency example that uses a TopkDropoutStrategy for portfolio generation at a daily frequency and uses SBBStrategyEMA to execute orders at a minutely frequency.
+
+### Usage
+
+Start backtesting by running the following command:
+```bash
+ python workflow.py backtest_highfreq
+```
\ No newline at end of file
diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py
new file mode 100644
index 0000000000..ef6906018a
--- /dev/null
+++ b/examples/nested_decision_execution/workflow.py
@@ -0,0 +1,206 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+import qlib
+import fire
+from qlib.config import REG_CN, HIGH_FREQ_CONFIG
+from qlib.data import D
+from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
+from qlib.workflow import R
+from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
+from qlib.tests.data import GetData
+from qlib.backtest import collect_data
+
+
+class NestedDecisionExecutionWorkflow:
+
+ market = "csi300"
+ benchmark = "SH000300"
+ data_handler_config = {
+ "start_time": "2008-01-01",
+ "end_time": "2021-05-31",
+ "fit_start_time": "2008-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": market,
+ }
+
+ task = {
+ "model": {
+ "class": "LGBModel",
+ "module_path": "qlib.contrib.model.gbdt",
+ "kwargs": {
+ "loss": "mse",
+ "colsample_bytree": 0.8879,
+ "learning_rate": 0.0421,
+ "subsample": 0.8789,
+ "lambda_l1": 205.6999,
+ "lambda_l2": 580.9768,
+ "max_depth": 8,
+ "num_leaves": 210,
+ "num_threads": 20,
+ },
+ },
+ "dataset": {
+ "class": "DatasetH",
+ "module_path": "qlib.data.dataset",
+ "kwargs": {
+ "handler": {
+ "class": "Alpha158",
+ "module_path": "qlib.contrib.data.handler",
+ "kwargs": data_handler_config,
+ },
+ "segments": {
+ "train": ("2007-01-01", "2014-12-31"),
+ "valid": ("2015-01-01", "2016-12-31"),
+ "test": ("2020-01-01", "2021-05-31"),
+ },
+ },
+ },
+ }
+
+ port_analysis_config = {
+ "executor": {
+ "class": "NestedExecutor",
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": "day",
+ "inner_executor": {
+ "class": "NestedExecutor",
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": "30min",
+ "inner_executor": {
+ "class": "SimulatorExecutor",
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": "5min",
+ "generate_portfolio_metrics": True,
+ "verbose": True,
+ "indicator_config": {
+ "show_indicator": True,
+ },
+ },
+ },
+ "inner_strategy": {
+ "class": "TWAPStrategy",
+ "module_path": "qlib.contrib.strategy.rule_strategy",
+ },
+ "generate_portfolio_metrics": True,
+ "indicator_config": {
+ "show_indicator": True,
+ },
+ },
+ },
+ "inner_strategy": {
+ "class": "SBBStrategyEMA",
+ "module_path": "qlib.contrib.strategy.rule_strategy",
+ "kwargs": {
+ "instruments": market,
+ "freq": "1min",
+ },
+ },
+ "track_data": True,
+ "generate_portfolio_metrics": True,
+ "indicator_config": {
+ "show_indicator": True,
+ },
+ },
+ },
+ "backtest": {
+ "start_time": "2020-09-20",
+ "end_time": "2021-05-20",
+ "account": 100000000,
+ "exchange_kwargs": {
+ "freq": "1min",
+ "limit_threshold": 0.095,
+ "deal_price": "close",
+ "open_cost": 0.0005,
+ "close_cost": 0.0015,
+ "min_cost": 5,
+ },
+ },
+ }
+
+ def _init_qlib(self):
+ """initialize qlib"""
+ provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir
+ GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True)
+ provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri")
+ GetData().qlib_data(
+ target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True
+ )
+ provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
+ qlib.init(provider_uri=provider_uri_map, dataset_cache=None, expression_cache=None)
+
+ def _train_model(self, model, dataset):
+ with R.start(experiment_name="train"):
+ R.log_params(**flatten_dict(self.task))
+ model.fit(dataset)
+ R.save_objects(**{"params.pkl": model})
+
+ # prediction
+ recorder = R.get_recorder()
+ sr = SignalRecord(model, dataset, recorder)
+ sr.generate()
+
+ def backtest(self):
+ self._init_qlib()
+ model = init_instance_by_config(self.task["model"])
+ dataset = init_instance_by_config(self.task["dataset"])
+ self._train_model(model, dataset)
+ strategy_config = {
+ "class": "TopkDropoutStrategy",
+ "module_path": "qlib.contrib.strategy.model_strategy",
+ "kwargs": {
+ "model": model,
+ "dataset": dataset,
+ "topk": 50,
+ "n_drop": 5,
+ },
+ }
+ self.port_analysis_config["strategy"] = strategy_config
+ self.port_analysis_config["backtest"]["benchmark"] = self.benchmark
+
+ with R.start(experiment_name="backtest"):
+
+ recorder = R.get_recorder()
+ par = PortAnaRecord(
+ recorder,
+ self.port_analysis_config,
+ risk_analysis_freq=["day", "30min", "5min"],
+ indicator_analysis_freq=["day", "30min", "5min"],
+ indicator_analysis_method="value_weighted",
+ )
+ par.generate()
+
+ # users could use the following methods to analyze the position
+ # report_normal_df = recorder.load_object("portfolio_analysis/report_normal_1day.pkl")
+ # from qlib.contrib.report import analysis_position
+ # analysis_position.report_graph(report_normal_df)
+
+ def collect_data(self):
+ self._init_qlib()
+ model = init_instance_by_config(self.task["model"])
+ dataset = init_instance_by_config(self.task["dataset"])
+ self._train_model(model, dataset)
+ executor_config = self.port_analysis_config["executor"]
+ backtest_config = self.port_analysis_config["backtest"]
+ backtest_config["benchmark"] = self.benchmark
+ strategy_config = {
+ "class": "TopkDropoutStrategy",
+ "module_path": "qlib.contrib.strategy.model_strategy",
+ "kwargs": {
+ "model": model,
+ "dataset": dataset,
+ "topk": 50,
+ "n_drop": 5,
+ },
+ }
+ data_generator = collect_data(executor=executor_config, strategy=strategy_config, **backtest_config)
+ for trade_decision in data_generator:
+ print(trade_decision)
+
+
+if __name__ == "__main__":
+ fire.Fire(NestedDecisionExecutionWorkflow)
diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py
index 387d5cde70..12579175f3 100644
--- a/examples/rolling_process_data/workflow.py
+++ b/examples/rolling_process_data/workflow.py
@@ -21,7 +21,6 @@ class RollingDataWorkflow:
def _init_qlib(self):
"""initialize qlib"""
- # use yahoo_cn_1min data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
diff --git a/examples/run_all_model.py b/examples/run_all_model.py
index 1284d8e995..41aba091e4 100644
--- a/examples/run_all_model.py
+++ b/examples/run_all_model.py
@@ -6,6 +6,7 @@
import fire
import time
import glob
+import yaml
import shutil
import signal
import inspect
@@ -23,22 +24,6 @@
from qlib.workflow import R
from qlib.tests.data import GetData
-# init qlib
-provider_uri = "~/.qlib/qlib_data/cn_data"
-exp_folder_name = "run_all_model_records"
-exp_path = str(Path(os.getcwd()).resolve() / exp_folder_name)
-exp_manager = {
- "class": "MLflowExpManager",
- "module_path": "qlib.workflow.expm",
- "kwargs": {
- "uri": "file:" + exp_path,
- "default_exp_name": "Experiment",
- },
-}
-
-GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
-qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
-
# decorator to check the arguments
def only_allow_defined_args(function_to_decorate):
@@ -88,11 +73,11 @@ def create_env():
sys.stderr.write("\n")
# get anaconda activate path
conda_activate = Path(os.environ["CONDA_PREFIX"]) / "bin" / "activate" # TODO: FIX ME!
- return env_path, python_path, conda_activate
+ return temp_dir, env_path, python_path, conda_activate
# function to execute the cmd
-def execute(cmd, wait_when_err=False):
+def execute(cmd, wait_when_err=False, raise_err=True):
print("Running CMD:", cmd)
with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
for line in p.stdout:
@@ -105,6 +90,8 @@ def execute(cmd, wait_when_err=False):
if p.returncode != 0:
if wait_when_err:
input("Press Enter to Continue")
+ if raise_err:
+ raise RuntimeError(f"Error when executing command: {cmd}")
return p.stderr
else:
return None
@@ -134,14 +121,23 @@ def get_all_folders(models, exclude) -> dict:
def get_all_files(folder_path, dataset) -> (str, str):
yaml_path = str(Path(f"{folder_path}") / f"*{dataset}*.yaml")
req_path = str(Path(f"{folder_path}") / f"*.txt")
- return glob.glob(yaml_path)[0], glob.glob(req_path)[0]
+ yaml_file = glob.glob(yaml_path)
+ req_file = glob.glob(req_path)
+ if len(yaml_file) == 0:
+ return None, None
+ else:
+ return yaml_file[0], req_file[0]
# function to retrieve all the results
def get_all_results(folders) -> dict:
results = dict()
for fn in folders:
- exp = R.get_exp(experiment_name=fn, create=False)
+ try:
+ exp = R.get_exp(experiment_name=fn, create=False)
+ except ValueError:
+ # No experiment results
+ continue
recorders = exp.list_recorders()
result = dict()
result["annualized_return_with_cost"] = list()
@@ -155,9 +151,9 @@ def get_all_results(folders) -> dict:
if recorders[recorder_id].status == "FINISHED":
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
metrics = recorder.list_metrics()
- result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
- result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
- result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
+ result["annualized_return_with_cost"].append(metrics["1day.excess_return_with_cost.annualized_return"])
+ result["information_ratio_with_cost"].append(metrics["1day.excess_return_with_cost.information_ratio"])
+ result["max_drawdown_with_cost"].append(metrics["1day.excess_return_with_cost.max_drawdown"])
result["ic"].append(metrics["IC"])
result["icir"].append(metrics["ICIR"])
result["rank_ic"].append(metrics["Rank IC"])
@@ -185,6 +181,25 @@ def gen_and_save_md_table(metrics, dataset):
return table
+# read yaml, remove seed kwargs of model, and then save file in the temp_dir
+def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
+ with open(yaml_path, "r") as fp:
+ config = yaml.safe_load(fp)
+ try:
+ del config["task"]["model"]["kwargs"]["seed"]
+ except KeyError:
+ # If the key does not exist, use the original yaml
+ # NOTE: this is very important if the model must run in its original path (when sys.rel_path is used)
+ return yaml_path
+ else:
+ # otherwise, generating a new yaml without random seed
+ file_name = yaml_path.split("/")[-1]
+ temp_path = os.path.join(temp_dir, file_name)
+ with open(temp_path, "w") as fp:
+ yaml.dump(config, fp)
+ return temp_path
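+
+
+# A hedged usage sketch (mirrors the call in `run()` below):
+#   yaml_path = gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir)
+# returns the original path when no seed is set, otherwise the path of a seed-free copy in temp_dir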
+
+
# function to run all the models
@only_allow_defined_args
def run(
@@ -193,12 +208,13 @@ def run(
dataset="Alpha360",
exclude=False,
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
+ exp_folder_name: str = "run_all_model_records",
wait_before_rm_env: bool = False,
wait_when_err: bool = False,
):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
- Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parrallel running the same model
+ Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parallel running the same model
for multiple times, and this will be fixed in the future development.
Parameters:
@@ -214,6 +230,8 @@ def run(
qlib_uri : str
the uri to install qlib with pip
it could be a url on the web or a local path
+ exp_folder_name: str
+ the name of the experiment folder
wait_before_rm_env : bool
wait before remove environment.
wait_when_err : bool
@@ -240,26 +258,58 @@ def run(
# Case 5 - run specific models for one time
python run_all_model.py --models=[mlp,lightgbm]
- # Case 6 - run other models except those are given as aruments for one time
+ # Case 6 - run other models except those are given as arguments for one time
python run_all_model.py --models=[mlp,tft,sfm] --exclude=True
"""
+ # init qlib
+ GetData().qlib_data(exists_skip=True)
+ qlib.init(
+ exp_manager={
+ "class": "MLflowExpManager",
+ "module_path": "qlib.workflow.expm",
+ "kwargs": {
+ "uri": "file:" + str(Path(os.getcwd()).resolve() / exp_folder_name),
+ "default_exp_name": "Experiment",
+ },
+ }
+ )
+
# get all folders
folders = get_all_folders(models, exclude)
# init error messages:
errors = dict()
# run all the model for iterations
for fn in folders:
- # create env by anaconda
- env_path, python_path, conda_activate = create_env()
# get all files
sys.stderr.write("Retrieving files...\n")
yaml_path, req_path = get_all_files(folders[fn], dataset)
+ if yaml_path is None:
+ sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}\n")
+ continue
sys.stderr.write("\n")
+ # create env by anaconda
+ temp_dir, env_path, python_path, conda_activate = create_env()
+
# install requirements.txt
sys.stderr.write("Installing requirements.txt...\n")
- execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
+ with open(req_path) as f:
+ content = f.read()
+ if "torch" in content:
+ # automatically install pytorch according to nvidia's version
+ execute(
+ f"{python_path} -m pip install light-the-torch", wait_when_err=wait_when_err
+ ) # for automatically installing torch according to the nvidia driver
+ execute(
+ f"{env_path / 'bin' / 'ltt'} install --install-cmd '{python_path} -m pip install {{packages}}' -- -r {req_path}",
+ wait_when_err=wait_when_err,
+ )
+ else:
+ execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
sys.stderr.write("\n")
+
+ # read yaml, remove seed kwargs of model, and then save file in the temp_dir
+ yaml_path = gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir)
# setup gpu for tft
if fn == "TFT":
execute(
@@ -302,19 +352,20 @@ def run(
# getting all results
sys.stderr.write(f"Retrieving results...\n")
results = get_all_results(folders)
- # calculating the mean and std
- sys.stderr.write(f"Calculating the mean and std of results...\n")
- results = cal_mean_std(results)
- # generating md table
- sys.stderr.write(f"Generating markdown table...\n")
- gen_and_save_md_table(results, dataset)
- sys.stderr.write("\n")
- # print erros
+ if len(results) > 0:
+ # calculating the mean and std
+ sys.stderr.write(f"Calculating the mean and std of results...\n")
+ results = cal_mean_std(results)
+ # generating md table
+ sys.stderr.write(f"Generating markdown table...\n")
+ gen_and_save_md_table(results, dataset)
+ sys.stderr.write("\n")
+ # print errors
sys.stderr.write(f"Here are some of the errors of the models...\n")
pprint(errors)
sys.stderr.write("\n")
# move results folder
- shutil.move(exp_path, exp_path + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
+ shutil.move(exp_folder_name, exp_folder_name + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md")
diff --git a/examples/workflow_by_code.ipynb b/examples/workflow_by_code.ipynb
index af374b350c..907245adef 100644
--- a/examples/workflow_by_code.ipynb
+++ b/examples/workflow_by_code.ipynb
@@ -20,9 +20,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "scrolled": true
- },
+ "metadata": {},
"outputs": [],
"source": [
"import sys, site\n",
@@ -66,7 +64,6 @@
"from qlib.config import REG_CN\n",
"from qlib.contrib.model.gbdt import LGBModel\n",
"from qlib.contrib.data.handler import Alpha158\n",
- "from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
"from qlib.contrib.evaluate import (\n",
" backtest as normal_backtest,\n",
" risk_analysis,\n",
@@ -197,27 +194,40 @@
"# prediction, backtest & analysis\n",
"###################################\n",
"port_analysis_config = {\n",
+ " \"executor\": {\n",
+ " \"class\": \"SimulatorExecutor\",\n",
+ " \"module_path\": \"qlib.backtest.executor\",\n",
+ " \"kwargs\": {\n",
+ " \"time_per_step\": \"day\",\n",
+ " \"generate_portfolio_metrics\": True,\n",
+ " },\n",
+ " },\n",
" \"strategy\": {\n",
" \"class\": \"TopkDropoutStrategy\",\n",
- " \"module_path\": \"qlib.contrib.strategy.strategy\",\n",
+ " \"module_path\": \"qlib.contrib.strategy.model_strategy\",\n",
" \"kwargs\": {\n",
+ " \"model\": model,\n",
+ " \"dataset\": dataset,\n",
" \"topk\": 50,\n",
" \"n_drop\": 5,\n",
" },\n",
" },\n",
" \"backtest\": {\n",
- " \"verbose\": False,\n",
- " \"limit_threshold\": 0.095,\n",
+ " \"start_time\": \"2017-01-01\",\n",
+ " \"end_time\": \"2020-08-01\",\n",
" \"account\": 100000000,\n",
" \"benchmark\": benchmark,\n",
- " \"deal_price\": \"close\",\n",
- " \"open_cost\": 0.0005,\n",
- " \"close_cost\": 0.0015,\n",
- " \"min_cost\": 5,\n",
+ " \"exchange_kwargs\": {\n",
+ " \"freq\": \"day\",\n",
+ " \"limit_threshold\": 0.095,\n",
+ " \"deal_price\": \"close\",\n",
+ " \"open_cost\": 0.0005,\n",
+ " \"close_cost\": 0.0015,\n",
+ " \"min_cost\": 5,\n",
+ " },\n",
" },\n",
"}\n",
"\n",
- "\n",
"# backtest and analysis\n",
"with R.start(experiment_name=\"backtest_analysis\"):\n",
" recorder = R.get_recorder(recorder_id=rid, experiment_name=\"train_model\")\n",
@@ -230,7 +240,7 @@
" sr.generate()\n",
"\n",
" # backtest & analysis\n",
- " par = PortAnaRecord(recorder, port_analysis_config)\n",
+ " par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n",
" par.generate()\n"
]
},
@@ -250,11 +260,12 @@
"from qlib.contrib.report import analysis_model, analysis_position\n",
"from qlib.data import D\n",
"recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n",
+ "print(recorder)\n",
"pred_df = recorder.load_object(\"pred.pkl\")\n",
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
- "report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal.pkl\")\n",
- "positions = recorder.load_object(\"portfolio_analysis/positions_normal.pkl\")\n",
- "analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis.pkl\")"
+ "report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal_1day.pkl\")\n",
+ "positions = recorder.load_object(\"portfolio_analysis/positions_normal_1day.pkl\")\n",
+ "analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis_1day.pkl\")"
]
},
{
@@ -349,7 +360,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -362,8 +373,7 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.3"
+ "pygments_lexer": "ipython3"
},
"toc": {
"base_numbering": 1,
@@ -381,4 +391,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
-}
\ No newline at end of file
+}
diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py
index 1cdf2ac80f..486e694a75 100644
--- a/examples/workflow_by_code.py
+++ b/examples/workflow_by_code.py
@@ -17,32 +17,44 @@
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
+ model = init_instance_by_config(CSI300_GBDT_TASK["model"])
+ dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
+
port_analysis_config = {
+ "executor": {
+ "class": "SimulatorExecutor",
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": "day",
+ "generate_portfolio_metrics": True,
+ },
+ },
"strategy": {
"class": "TopkDropoutStrategy",
- "module_path": "qlib.contrib.strategy.strategy",
+ "module_path": "qlib.contrib.strategy.model_strategy",
"kwargs": {
+ "model": model,
+ "dataset": dataset,
"topk": 50,
"n_drop": 5,
},
},
"backtest": {
- "verbose": False,
- "limit_threshold": 0.095,
+ "start_time": "2017-01-01",
+ "end_time": "2020-08-01",
"account": 100000000,
"benchmark": CSI300_BENCH,
- "deal_price": "close",
- "open_cost": 0.0005,
- "close_cost": 0.0015,
- "min_cost": 5,
- "return_order": True,
+ "exchange_kwargs": {
+ "freq": "day",
+ "limit_threshold": 0.095,
+ "deal_price": "close",
+ "open_cost": 0.0005,
+ "close_cost": 0.0015,
+ "min_cost": 5,
+ },
},
}
- # model initialization
- model = init_instance_by_config(CSI300_GBDT_TASK["model"])
- dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
-
# NOTE: This line is optional
# It demonstrates that the dataset can be used standalone.
example_df = dataset.prepare("train")
@@ -61,5 +73,5 @@
# backtest. If users want to use backtest based on their own prediction,
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
- par = PortAnaRecord(recorder, port_analysis_config)
+ par = PortAnaRecord(recorder, port_analysis_config, "day")
par.generate()
diff --git a/qlib/__init__.py b/qlib/__init__.py
index 7d23df2685..efa89b1537 100644
--- a/qlib/__init__.py
+++ b/qlib/__init__.py
@@ -54,14 +54,15 @@ def init(default_conf="client", **kwargs):
if "flask_server" in C:
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
logger.info("qlib successfully initialized based on %s settings." % default_conf)
- data_path = {_freq: C.dpm.get_data_path(_freq) for _freq in C.dpm.provider_uri.keys()}
+ data_path = {_freq: C.dpm.get_data_uri(_freq) for _freq in C.dpm.provider_uri.keys()}
logger.info(f"data_path={data_path}")
def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
LOG = get_module_logger("mount nfs", level=logging.INFO)
-
+ if mount_path is None:
+ raise ValueError(f"Invalid mount path: {mount_path}!")
# FIXME: the C["provider_uri"] is modified in this function
# If it is not modified, we can pass only provider_uri or mount_path instead of C
mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py
new file mode 100644
index 0000000000..38541d7680
--- /dev/null
+++ b/qlib/backtest/__init__.py
@@ -0,0 +1,322 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+import copy
+from typing import List, Tuple, Union, TYPE_CHECKING
+
+from .account import Account
+
+if TYPE_CHECKING:
+ from ..strategy.base import BaseStrategy
+ from .executor import BaseExecutor
+ from .decision import BaseTradeDecision
+from .position import Position
+from .exchange import Exchange
+from .backtest import backtest_loop
+from .backtest import collect_data_loop
+from .utils import CommonInfrastructure
+from .decision import Order
+from ..utils import init_instance_by_config
+from ..log import get_module_logger
+from ..config import C
+
+# make import more user-friendly by adding `from qlib.backtest import STH`
+
+
+logger = get_module_logger("backtest caller")
+
+
+def get_exchange(
+ exchange=None,
+ freq="day",
+ start_time=None,
+ end_time=None,
+ codes="all",
+ subscribe_fields=[],
+ open_cost=0.0015,
+ close_cost=0.0025,
+ min_cost=5.0,
+ limit_threshold=None,
+ deal_price: Union[str, Tuple[str], List[str]] = None,
+ **kwargs,
+):
+ """get_exchange
+
+ Parameters
+ ----------
+
+ # exchange related arguments
+ exchange : Exchange, optional
+ an existing Exchange instance or a config for one; if None, a new Exchange will be created from the other arguments.
+ subscribe_fields: list
+ subscribe fields.
+ open_cost : float
+ open transaction cost.
+ close_cost : float
+ close transaction cost.
+ min_cost : float
+ min transaction cost.
+ trade_unit : int
+ Included in kwargs. Please refer to the docs of `__init__` of `Exchange`
+ deal_price: Union[str, Tuple[str], List[str]]
+ The `deal_price` supports following two types of input
+ - <deal_price> : str
+ - (<buy_price>, <sell_price>): Tuple[str] or List[str]
+
+ <deal_price>, <buy_price> or <sell_price> := <price>
+ <price> := str
+ - for example '$close', '$open', '$vwap' ("close" is OK. `Exchange` will help to prepend
+ "$" to the expression)
+ limit_threshold : float
+ the limit of price movement, e.g. 0.1 (10%); long and short use the same limit.
+
+ Returns
+ -------
+ :class: Exchange
+ an initialized Exchange object
+ """
+
+ if limit_threshold is None:
+ limit_threshold = C.limit_threshold
+ if exchange is None:
+ logger.info("Create new exchange")
+
+ exchange = Exchange(
+ freq=freq,
+ start_time=start_time,
+ end_time=end_time,
+ codes=codes,
+ deal_price=deal_price,
+ subscribe_fields=subscribe_fields,
+ limit_threshold=limit_threshold,
+ open_cost=open_cost,
+ close_cost=close_cost,
+ min_cost=min_cost,
+ **kwargs,
+ )
+ return exchange
+ else:
+ return init_instance_by_config(exchange, accept_types=Exchange)
+
+
+def create_account_instance(
+ start_time, end_time, benchmark: str, account: Union[float, int, dict], pos_type: str = "Position"
+) -> Account:
+ """
+ # TODO: it is very strange to pass benchmark_config into the account (maybe for the report)
+ # There should be a post-step to process the report.
+
+ Parameters
+ ----------
+ start_time
+ start time of the benchmark
+ end_time
+ end time of the benchmark
+ benchmark : str
+ the benchmark for reporting
+ account : Union[
+ float,
+ {
+ "cash": float,
+ "stock1": Union[
+ int, # it is equal to {"amount": int}
+ {"amount": int, "price"(optional): float},
+ ]
+ },
+ ]
+ information describing how to create the account
+ For `float`:
+ Using Account with only initial cash
+ For `dict`:
+ key "cash" means initial cash.
+ key "stock1" means the information of first stock with amount and price(optional).
+ ...
+ """
+ if isinstance(account, (int, float)):
+ pos_kwargs = {"init_cash": account}
+ elif isinstance(account, dict):
+ init_cash = account["cash"]
+ del account["cash"]
+ pos_kwargs = {
+ "init_cash": init_cash,
+ "position_dict": account,
+ }
+ else:
+ raise ValueError("account must be in (int, float, dict)")
+
+ kwargs = {
+ "init_cash": account,
+ "benchmark_config": {
+ "benchmark": benchmark,
+ "start_time": start_time,
+ "end_time": end_time,
+ },
+ "pos_type": pos_type,
+ }
+ kwargs.update(pos_kwargs)
+ return Account(**kwargs)
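+
+
+# A hedged usage sketch of the accepted `account` formats (illustrative values):
+#   create_account_instance("2020-01-01", "2020-12-31", "SH000300", account=100000000)
+#   create_account_instance("2020-01-01", "2020-12-31", "SH000300",
+#                           account={"cash": 1e6, "SH600000": 100})  # cash plus 100 shares of SH600000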
+
+
+def get_strategy_executor(
+ start_time,
+ end_time,
+ strategy: BaseStrategy,
+ executor: BaseExecutor,
+ benchmark: str = "SH000300",
+ account: Union[float, int, Position] = 1e9,
+ exchange_kwargs: dict = {},
+ pos_type: str = "Position",
+):
+
+ # NOTE:
+ # - for avoiding recursive import
+ # - typing annotations is not reliable
+ from ..strategy.base import BaseStrategy
+ from .executor import BaseExecutor
+
+ trade_account = create_account_instance(
+ start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type
+ )
+
+ exchange_kwargs = copy.copy(exchange_kwargs)
+ if "start_time" not in exchange_kwargs:
+ exchange_kwargs["start_time"] = start_time
+ if "end_time" not in exchange_kwargs:
+ exchange_kwargs["end_time"] = end_time
+ trade_exchange = get_exchange(**exchange_kwargs)
+
+ common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
+ trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy, common_infra=common_infra)
+ trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor, common_infra=common_infra)
+
+ return trade_strategy, trade_executor
+
+
+def backtest(
+ start_time,
+ end_time,
+ strategy,
+ executor,
+ benchmark="SH000300",
+ account=1e9,
+ exchange_kwargs={},
+ pos_type: str = "Position",
+):
+ """initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and executor in the nested decision execution
+
+ Parameters
+ ----------
+ start_time : pd.Timestamp|str
+ closed start time for backtest
+ **NOTE**: This will be applied to the outermost executor's calendar.
+ end_time : pd.Timestamp|str
+ closed end time for backtest
+ **NOTE**: This will be applied to the outermost executor's calendar.
+ E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
+ strategy : Union[str, dict, BaseStrategy]
+ for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more information.
+ executor : Union[str, dict, BaseExecutor]
+ for initializing the outermost executor.
+ benchmark: str
+ the benchmark for reporting.
+ account : Union[float, int, Position]
+ information describing how to create the account
+ For `float` or `int`:
+ Using Account with only initial cash
+ For `Position`:
+ Using Account with a Position
+ exchange_kwargs : dict
+ the kwargs for initializing Exchange
+ pos_type : str
+ the type of Position.
+
+ Returns
+ -------
+ portfolio_metrics_dict: Dict[PortfolioMetrics]
+ it records the trading portfolio_metrics information
+ indicator_dict: Dict[Indicator]
+ it computes the trading indicator
+ It is organized in a dict format
+
+ """
+ trade_strategy, trade_executor = get_strategy_executor(
+ start_time,
+ end_time,
+ strategy,
+ executor,
+ benchmark,
+ account,
+ exchange_kwargs,
+ pos_type=pos_type,
+ )
+ portfolio_metrics, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
+ return portfolio_metrics, indicator
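+
+
+# A minimal call sketch (strategy/executor config dicts as in
+# examples/nested_decision_execution/workflow.py; values are illustrative):
+#   portfolio_metrics, indicator = backtest(
+#       "2017-01-01", "2020-08-01", strategy_config, executor_config,
+#       benchmark="SH000300", account=100000000,
+#       exchange_kwargs={"freq": "day", "limit_threshold": 0.095, "deal_price": "close"},
+#   )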
+
+
+def collect_data(
+ start_time,
+ end_time,
+ strategy,
+ executor,
+ benchmark="SH000300",
+ account=1e9,
+ exchange_kwargs={},
+ pos_type: str = "Position",
+ return_value: dict = None,
+):
+ """initialize the strategy and executor, then collect the trade decision data for rl training
+
+ please refer to the docs of the backtest for the explanation of the parameters
+
+ Yields
+ -------
+ object
+ trade decision
+ """
+ trade_strategy, trade_executor = get_strategy_executor(
+ start_time,
+ end_time,
+ strategy,
+ executor,
+ benchmark,
+ account,
+ exchange_kwargs,
+ pos_type=pos_type,
+ )
+ yield from collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value=return_value)
+
+
+def format_decisions(
+ decisions: List[BaseTradeDecision],
+) -> Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]:
+ """
+ format the decisions collected by `qlib.backtest.collect_data`
+ The decisions will be organized into a tree-like structure.
+
+ Parameters
+ ----------
+ decisions : List[BaseTradeDecision]
+ decisions collected by `qlib.backtest.collect_data`
+
+ Returns
+ -------
+ Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]:
+
+ reformat the list of decisions into a more user-friendly format
+ <decisions> := Tuple[<freq>, List[Tuple[<decision>, <sub decisions>]]]
+ - <sub decisions> := `<decisions> in lower level` | None
+ - <freq> := "day" | "30min" | "1min" | ...
+ - <decision> := <instance of BaseTradeDecision>
+ """
+ if len(decisions) == 0:
+ return None
+
+ cur_freq = decisions[0].strategy.trade_calendar.get_freq()
+
+ res = (cur_freq, [])
+ last_dec_idx = 0
+ for i, dec in enumerate(decisions[1:], 1):
+ if dec.strategy.trade_calendar.get_freq() == cur_freq:
+ res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 : i])))
+ last_dec_idx = i
+ res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 :])))
+ return res
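+
+
+# A hedged usage sketch: decisions yielded by `collect_data` can be reorganized by frequency,
+# e.g. for a daily strategy nested over a 30min executor (shape is illustrative):
+#   decisions = list(collect_data(start_time, end_time, strategy, executor))
+#   res = format_decisions(decisions)
+#   # res == ("day", [(day_decision, ("30min", [...])), ...])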
diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py
new file mode 100644
index 0000000000..aa503ebc27
--- /dev/null
+++ b/qlib/backtest/account.py
@@ -0,0 +1,377 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+import copy
+from typing import Dict, List, Tuple, TYPE_CHECKING
+from qlib.utils import init_instance_by_config
+import pandas as pd
+
+from .position import BasePosition, InfPosition, Position
+from .report import PortfolioMetrics, Indicator
+from .decision import BaseTradeDecision, Order
+from .exchange import Exchange
+
+"""
+rtn & earning in the Account
+ rtn:
+ from order's view
+ 1. it changes if any order is executed, either a sell order or a buy order
+ 2. it changes at the end of today: (today_close - stock_price) * amount
+ earning
+ from the value of the current position
+ earning will be updated at the end of the trade date
+ earning = today_value - pre_value
+ **it considers cost**
+ since earning is the difference of two position values, it considers cost, so it is the true return rate
+ the concrete implementation of rtn does not consider cost; in other words, rtn - cost = earning
+
+"""
+
+
+class AccumulatedInfo:
+ """accumulated trading info, including accumulated return/cost/turnover"""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.rtn = 0 # accumulated return, do not consider cost
+ self.cost = 0 # accumulated cost
+ self.to = 0 # accumulated turnover
+
+ def add_return_value(self, value):
+ self.rtn += value
+
+ def add_cost(self, value):
+ self.cost += value
+
+ def add_turnover(self, value):
+ self.to += value
+
+ @property
+ def get_return(self):
+ return self.rtn
+
+ @property
+ def get_cost(self):
+ return self.cost
+
+ @property
+ def get_turnover(self):
+ return self.to
+
+
+class Account:
+ def __init__(
+ self,
+ init_cash: float = 1e9,
+ position_dict: dict = {},
+ freq: str = "day",
+ benchmark_config: dict = {},
+ pos_type: str = "Position",
+ port_metr_enabled: bool = True,
+ ):
+ """the trade account of backtest.
+
+ Parameters
+ ----------
+ init_cash : float, optional
+ initial cash, by default 1e9
+ position_dict : Dict[
+ stock_id,
+ Union[
+ int, # it is equal to {"amount": int}
+ {"amount": int, "price"(optional): float},
+ ]
+ ]
+ initial stocks with parameters amount and price,
+ if there is no price key in the dict of stocks, it will be filled by _fill_stock_value.
+ by default {}.
+ """
+
+ self._pos_type = pos_type
+ self._port_metr_enabled = port_metr_enabled
+ self.benchmark_config = None # avoid no attribute error
+ self.init_vars(init_cash, position_dict, freq, benchmark_config)
+
+ def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
+ self.init_cash = init_cash
+ self.current_position: BasePosition = init_instance_by_config(
+ {
+ "class": self._pos_type,
+ "kwargs": {
+ "cash": init_cash,
+ "position_dict": position_dict,
+ },
+ "module_path": "qlib.backtest.position",
+ }
+ )
+ self.portfolio_metrics = None
+ self.hist_positions = {}
+ self.reset(freq=freq, benchmark_config=benchmark_config)
+
+ def is_port_metr_enabled(self):
+ """
+ Is portfolio-based metrics enabled.
+ """
+ return self._port_metr_enabled and not self.current_position.skip_update()
+
+ def reset_report(self, freq, benchmark_config):
+ # portfolio related metrics
+ if self.is_port_metr_enabled():
+ self.accum_info = AccumulatedInfo()
+ self.portfolio_metrics = PortfolioMetrics(freq, benchmark_config)
+ self.hist_positions = {}
+
+ # fill stock value
+ # The frequency of account may not align with the trading frequency.
+ # This may result in obscure bugs when data quality is low.
+ if isinstance(self.benchmark_config, dict) and self.benchmark_config.get("start_time") is not None:
+ self.current_position.fill_stock_value(self.benchmark_config["start_time"], self.freq)
+
+ # trading related metrics(e.g. high-frequency trading)
+ self.indicator = Indicator()
+
+ def reset(self, freq=None, benchmark_config=None, port_metr_enabled: bool = None):
+ """reset freq and report of account
+
+ Parameters
+ ----------
+ freq : str, optional
+ frequency of account & report, by default None
+ benchmark_config : {}, optional
+ benchmark config of report, by default None
+ """
+ if freq is not None:
+ self.freq = freq
+ if benchmark_config is not None:
+ self.benchmark_config = benchmark_config
+ if port_metr_enabled is not None:
+ self._port_metr_enabled = port_metr_enabled
+
+ self.reset_report(self.freq, self.benchmark_config)
+
+ def get_hist_positions(self):
+ return self.hist_positions
+
+ def get_cash(self):
+ return self.current_position.get_cash()
+
+ def _update_state_from_order(self, order, trade_val, cost, trade_price):
+ if self.is_port_metr_enabled():
+ # update turnover
+ self.accum_info.add_turnover(trade_val)
+ # update cost
+ self.accum_info.add_cost(cost)
+
+ # update return from order
+ trade_amount = trade_val / trade_price
+ if order.direction == Order.SELL: # 0 for sell
+ # when sell stock, get profit from price change
+ profit = trade_val - self.current_position.get_stock_price(order.stock_id) * trade_amount
+ self.accum_info.add_return_value(profit) # note here do not consider cost
+
+ elif order.direction == Order.BUY: # 1 for buy
+ # when buying a stock, we compute a profit for the rtn computing method
+ # the profit in a buy order is defined so that rtn stays consistent with earning at the end of the bar
+ profit = self.current_position.get_stock_price(order.stock_id) * trade_amount - trade_val
+ self.accum_info.add_return_value(profit) # note here do not consider cost
+
+ def update_order(self, order, trade_val, cost, trade_price):
+ if self.current_position.skip_update():
+ # TODO: supporting polymorphism for account
+ # updating order for infinite position is meaningless
+ return
+
+ # if a stock is sold out, there is no stock price information left in Position, so we should update the account first, then update the current position
+ # if a stock is bought, it is not yet in the current position, so update the current position first, then update the account
+ # The cost will be subtracted from the cash at last, so the trading logic can ignore the cost calculation
+ if order.direction == Order.SELL:
+ # sell stock
+ self._update_state_from_order(order, trade_val, cost, trade_price)
+ # update current position
+ # because the order may sell all holdings of stock_id
+ self.current_position.update_order(order, trade_val, cost, trade_price)
+ else:
+ # buy stock
+ # deal order, then update state
+ self.current_position.update_order(order, trade_val, cost, trade_price)
+ self._update_state_from_order(order, trade_val, cost, trade_price)
+
+ def update_current_position(self, trade_start_time, trade_end_time, trade_exchange):
+ """update current to make rtn consistent with earning at the end of bar, and update holding bar count of stock"""
+ # update price for stock in the position and the profit from changed_price
+ # NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
+ if not self.current_position.skip_update():
+ stock_list = self.current_position.get_stock_list()
+ for code in stock_list:
+ # if suspended, there is no new price to update and the profit is 0
+ if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
+ continue
+ bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
+ self.current_position.update_stock_price(stock_id=code, price=bar_close)
+ # update holding day count
+ # NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
+ self.current_position.add_count_all(bar=self.freq)
+
+ def update_portfolio_metrics(self, trade_start_time, trade_end_time):
+ """update portfolio_metrics"""
+ # calculate earning
+ # account_value - last_account_value
+ # for the first trade date, account_value - init_cash
+ # self.portfolio_metrics.is_empty() to judge is_first_trade_date
+ # get last_account_value, last_total_cost, last_total_turnover
+ if self.portfolio_metrics.is_empty():
+ last_account_value = self.init_cash
+ last_total_cost = 0
+ last_total_turnover = 0
+ else:
+ last_account_value = self.portfolio_metrics.get_latest_account_value()
+ last_total_cost = self.portfolio_metrics.get_latest_total_cost()
+ last_total_turnover = self.portfolio_metrics.get_latest_total_turnover()
+ # get now_account_value, now_stock_value, now_earning, now_cost, now_turnover
+ now_account_value = self.current_position.calculate_value()
+ now_stock_value = self.current_position.calculate_stock_value()
+ now_earning = now_account_value - last_account_value
+ now_cost = self.accum_info.get_cost - last_total_cost
+ now_turnover = self.accum_info.get_turnover - last_total_turnover
+ # update portfolio_metrics for today
+ # judge whether the trading has begun,
+ # and don't add the initial account state into portfolio_metrics, since we don't have excess return in those days.
+ self.portfolio_metrics.update_portfolio_metrics_record(
+ trade_start_time=trade_start_time,
+ trade_end_time=trade_end_time,
+ account_value=now_account_value,
+ cash=self.current_position.position["cash"],
+ return_rate=(now_earning + now_cost) / last_account_value,
+ # here we use earning to calculate return: from the position's view, earning considers cost, so it is the true return
+ # in order to keep the same definition as the original backtest in evaluate.py
+ total_turnover=self.accum_info.get_turnover,
+ turnover_rate=now_turnover / last_account_value,
+ total_cost=self.accum_info.get_cost,
+ cost_rate=now_cost / last_account_value,
+ stock_value=now_stock_value,
+ )
+
+ def update_hist_positions(self, trade_start_time):
+ """update history position"""
+ now_account_value = self.current_position.calculate_value()
+ # set now_account_value to position
+ self.current_position.position["now_account_value"] = now_account_value
+ self.current_position.update_weight_all()
+ # update hist_positions
+ # note use deepcopy
+ self.hist_positions[trade_start_time] = copy.deepcopy(self.current_position)
+
+ def update_indicator(
+ self,
+ trade_start_time: pd.Timestamp,
+ trade_exchange: Exchange,
+ atomic: bool,
+ outer_trade_decision: BaseTradeDecision,
+ trade_info: list = None,
+ inner_order_indicators: List[Dict[str, pd.Series]] = None,
+ decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
+ indicator_config: dict = {},
+ ):
+ """update trade indicators and order indicators in each bar end"""
+ # TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`
+
+ # indicator is trading (e.g. high-frequency order execution) related analysis
+ self.indicator.reset()
+
+ # aggregate the information for each order
+ if atomic:
+ self.indicator.update_order_indicators(trade_info)
+ else:
+ self.indicator.agg_order_indicators(
+ inner_order_indicators,
+ decision_list=decision_list,
+ outer_trade_decision=outer_trade_decision,
+ trade_exchange=trade_exchange,
+ indicator_config=indicator_config,
+ )
+
+ # aggregate all the order metrics in a single step
+ self.indicator.cal_trade_indicators(trade_start_time, self.freq, indicator_config)
+
+ # record the metrics
+ self.indicator.record(trade_start_time)
+
+ def update_bar_end(
+ self,
+ trade_start_time: pd.Timestamp,
+ trade_end_time: pd.Timestamp,
+ trade_exchange: Exchange,
+ atomic: bool,
+ outer_trade_decision: BaseTradeDecision,
+ trade_info: list = None,
+ inner_order_indicators: List[Dict[str, pd.Series]] = None,
+ decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
+ indicator_config: dict = {},
+ ):
+ """update account at each trading bar step
+
+ Parameters
+ ----------
+ trade_start_time : pd.Timestamp
+ closed start time of step
+ trade_end_time : pd.Timestamp
+ closed end time of step
+ trade_exchange : Exchange
+ trading exchange, used to update current
+ atomic : bool
+ whether the trading executor is atomic, which means there is no higher-frequency trading executor inside it
+ - if atomic is True, calculate the indicators with trade_info
+ - else, aggregate indicators with inner indicators
+ trade_info : List[(Order, float, float, float)], optional
+ trading information, by default None
+ - necessary if atomic is True
+ - list of tuple(order, trade_val, trade_cost, trade_price)
+ inner_order_indicators : Indicator, optional
+ indicators of inner executor, by default None
+ - necessary if atomic is False
+ - used to aggregate outer indicators
+ decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
+ The decision list of the inner level: List[Tuple[<decision>, <start_time>, <end_time>]]
+ The inner level
+ indicator_config : dict, optional
+ config of calculating indicators, by default {}
+ """
+ if atomic is True and trade_info is None:
+ raise ValueError("trade_info is necessary in atomic executor")
+ elif atomic is False and inner_order_indicators is None:
+ raise ValueError("inner_order_indicators is necessary in un-atomic executor")
+
+ # update the current position and hold bar count at each bar end
+ self.update_current_position(trade_start_time, trade_end_time, trade_exchange)
+
+ if self.is_port_metr_enabled():
+ # portfolio_metrics is portfolio related analysis
+ self.update_portfolio_metrics(trade_start_time, trade_end_time)
+ self.update_hist_positions(trade_start_time)
+
+ # update indicator in each bar end
+ self.update_indicator(
+ trade_start_time=trade_start_time,
+ trade_exchange=trade_exchange,
+ atomic=atomic,
+ outer_trade_decision=outer_trade_decision,
+ trade_info=trade_info,
+ inner_order_indicators=inner_order_indicators,
+ decision_list=decision_list,
+ indicator_config=indicator_config,
+ )
+
+ def get_portfolio_metrics(self):
+ """get the history portfolio_metrics and postions instance"""
+ if self.is_port_metr_enabled():
+ _portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
+ _positions = self.get_hist_positions()
+ return _portfolio_metrics, _positions
+ else:
+ raise ValueError("generate_portfolio_metrics should be True if you want to generate portfolio_metrics")
+
+ def get_trade_indicator(self) -> Indicator:
+ """get the trade indicator instance, which has pa/pos/ffr info."""
+ return self.indicator
diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py
new file mode 100644
index 0000000000..fa4063bc92
--- /dev/null
+++ b/qlib/backtest/backtest.py
@@ -0,0 +1,81 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+from qlib.backtest.decision import BaseTradeDecision
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from qlib.strategy.base import BaseStrategy
+ from qlib.backtest.executor import BaseExecutor
+from ..utils.time import Freq
+from tqdm.auto import tqdm
+
+
+def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor):
+ """backtest funciton for the interaction of the outermost strategy and executor in the nested decision execution
+
+ please refer to the docs of `collect_data_loop`
+
+ Returns
+ -------
+ portfolio_metrics: PortfolioMetrics
+ it records the trading portfolio_metrics information
+ indicator: Indicator
+ it computes the trading indicator
+ """
+ return_value = {}
+ for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
+ pass
+ return return_value.get("portfolio_metrics"), return_value.get("indicator")
+
+
+def collect_data_loop(
+ start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor, return_value: dict = None
+):
+ """Generator for collecting the trade decision data for rl training
+
+ Parameters
+ ----------
+ start_time : pd.Timestamp|str
+ closed start time for backtest
+ **NOTE**: This will be applied to the outermost executor's calendar.
+ end_time : pd.Timestamp|str
+ closed end time for backtest
+ **NOTE**: This will be applied to the outermost executor's calendar.
+ E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
+ trade_strategy : BaseStrategy
+ the outermost portfolio strategy
+ trade_executor : BaseExecutor
+ the outermost executor
+ return_value : dict
+ used for backtest_loop
+
+ Yields
+ -------
+ object
+ trade decision
+ """
+ trade_executor.reset(start_time=start_time, end_time=end_time)
+ trade_strategy.reset(level_infra=trade_executor.get_level_infra())
+
+ with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
+ _execute_result = None
+ while not trade_executor.finished():
+ _trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)
+ _execute_result = yield from trade_executor.collect_data(_trade_decision, level=0)
+ bar.update(1)
+
+ if return_value is not None:
+ all_executors = trade_executor.get_all_executors()
+ all_portfolio_metrics = {
+ "{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
+ for _executor in all_executors
+ if _executor.trade_account.is_port_metr_enabled()
+ }
+ all_indicators = {}
+ for _executor in all_executors:
+ key = "{}{}".format(*Freq.parse(_executor.time_per_step))
+ all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
+ all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
+ return_value.update({"portfolio_metrics": all_portfolio_metrics, "indicator": all_indicators})
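+
+
+# A minimal usage sketch (hedged; `executor_config` and `strategy_config` are
+# hypothetical placeholders for the usual qlib config dicts, not names defined here):
+#
+#   from qlib.utils import init_instance_by_config
+#   executor = init_instance_by_config(executor_config) # -> BaseExecutor
+#   strategy = init_instance_by_config(strategy_config) # -> BaseStrategy
+#   metrics, indicators = backtest_loop("2020-01-01", "2020-12-31", strategy, executor)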
diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py
new file mode 100644
index 0000000000..049e56c005
--- /dev/null
+++ b/qlib/backtest/decision.py
@@ -0,0 +1,548 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+from enum import IntEnum
+from qlib.data.data import Cal
+from qlib.utils.time import concat_date_time, epsilon_change
+from qlib.log import get_module_logger
+
+# try to fix circular imports when enabling type hints
+from typing import Callable, TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from qlib.strategy.base import BaseStrategy
+ from qlib.backtest.exchange import Exchange
+from qlib.backtest.utils import TradeCalendarManager
+import warnings
+import numpy as np
+import pandas as pd
+from dataclasses import dataclass, field
+from typing import ClassVar, Optional, Union, List, Set, Tuple
+
+
+class OrderDir(IntEnum):
+ # Order direction
+ SELL = 0
+ BUY = 1
+
+
+@dataclass
+class Order:
+ """
+ stock_id : str
+ amount : float
+ start_time : pd.Timestamp
+ closed start time for order trading
+ end_time : pd.Timestamp
+ closed end time for order trading
+ direction : int
+ Order.SELL for sell; Order.BUY for buy
+ factor : float
+ represents the weight factor assigned in Exchange()
+ """
+
+ # 1) time invariant values
+ # - they are set by users and is time-invariant.
+ stock_id: str
+ amount: float # `amount` is a non-negative and adjusted value
+ direction: int
+
+ # 2) time variant values:
+ # - Users may want to set these values when using lower level APIs
+ # - If users don't, TradeDecisionWO will help users to set them
+ # The time interval which the order belongs to (NOTE: this is not the expected order dealing time range)
+ start_time: pd.Timestamp
+ end_time: pd.Timestamp
+
+ # 3) results
+ # - users should not care about these values
+ # - they are set by the backtest system after the order has been executed.
+ # Expected values in different cases:
+ # - not tradable: deal_amount == 0, factor is None
+ # - the stock is suspended and the entire order fails. No cost for this order
+ # - dealt or partially dealt: deal_amount >= 0 and factor is not None
+ deal_amount: Optional[float] = None # `deal_amount` is a non-negative value
+ factor: Optional[float] = None
+
+ # TODO:
+ # a status field to indicate the dealing result of the order
+
+ # FIXME:
+ # for compatible now.
+ # Please remove them in the future
+ SELL: ClassVar[OrderDir] = OrderDir.SELL
+ BUY: ClassVar[OrderDir] = OrderDir.BUY
+
+ def __post_init__(self):
+ if self.direction not in {Order.SELL, Order.BUY}:
+ raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
+ self.deal_amount = 0
+ self.factor = None
+
+ @property
+ def amount_delta(self) -> float:
+ """
+ return the delta of amount.
+ - Positive value indicates buying `amount` of share
+ - Negative value indicates selling `amount` of share
+ """
+ return self.amount * self.sign
+
+ @property
+ def deal_amount_delta(self) -> float:
+ """
+ return the delta of deal_amount.
+ - Positive value indicates buying `deal_amount` of share
+ - Negative value indicates selling `deal_amount` of share
+ """
+ return self.deal_amount * self.sign
+
+ @property
+ def sign(self) -> float:
+ """
+ return the sign of trading
+ - `+1` indicates buying
+ - `-1` indicates selling
+ """
+ return self.direction * 2 - 1
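+ # worked example: OrderDir.BUY == 1 -> 1 * 2 - 1 == +1; OrderDir.SELL == 0 -> 0 * 2 - 1 == -1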
+
+ @staticmethod
+ def parse_dir(direction: Union[str, int, np.integer, OrderDir, np.ndarray]) -> Union[OrderDir, np.ndarray]:
+ if isinstance(direction, OrderDir):
+ return direction
+ elif isinstance(direction, (int, float, np.integer, np.floating)):
+ if direction > 0:
+ return Order.BUY
+ else:
+ return Order.SELL
+ elif isinstance(direction, str):
+ dl = direction.lower()
+ if dl.strip() == "sell":
+ return OrderDir.SELL
+ elif dl.strip() == "buy":
+ return OrderDir.BUY
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+ elif isinstance(direction, np.ndarray):
+ direction_array = direction.copy()
+ direction_array[direction_array > 0] = Order.BUY
+ direction_array[direction_array <= 0] = Order.SELL
+ return direction_array
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
+
+class OrderHelper:
+ """
+ Motivation
+ - Make generating orders easier
+ - Users may have no knowledge about the adjust-factor information of the system.
+ - It involves too much interaction with the exchange when generating orders.
+ """
+
+ def __init__(self, exchange: Exchange):
+ self.exchange = exchange
+
+ def create(
+ self,
+ code: str,
+ amount: float,
+ direction: OrderDir,
+ start_time: Union[str, pd.Timestamp] = None,
+ end_time: Union[str, pd.Timestamp] = None,
+ ) -> Order:
+ """
+ help to create an order
+
+ # TODO: create order for unadjusted amount order
+
+ Parameters
+ ----------
+ code : str
+ the id of the instrument
+ amount : float
+ **adjusted trading amount**
+ direction : OrderDir
+ trading direction
+ start_time : Union[str, pd.Timestamp] (optional)
+ The time interval which the order belongs to
+ end_time : Union[str, pd.Timestamp] (optional)
+ The time interval which the order belongs to
+
+ Returns
+ -------
+ Order:
+ The created order
+ """
+ if start_time is not None:
+ start_time = pd.Timestamp(start_time)
+ if end_time is not None:
+ end_time = pd.Timestamp(end_time)
+ # NOTE: factor is a value that belongs to the results section. Users don't have to care about it when creating orders
+ return Order(
+ stock_id=code,
+ amount=amount,
+ start_time=start_time,
+ end_time=end_time,
+ direction=direction,
+ )
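+
+ # Usage sketch (hedged; assumes an `Exchange` instance named `exchange` and an
+ # illustrative instrument code):
+ #   helper = exchange.get_order_helper()
+ #   order = helper.create(code="SH600000", amount=100, direction=OrderDir.BUY)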
+
+
+class TradeRange:
+ def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
+ """
+ This method will be called in the following way:
+
+ The outer strategy gives a decision with a `TradeRange`.
+ The decision will be checked by the inner executor.
+ The inner executor will pass its trade_calendar as a parameter when getting the trading range.
+ - The framework's step is integer-index based.
+
+ Parameters
+ ----------
+ trade_calendar : TradeCalendarManager
+ the trade_calendar is from inner strategy
+
+ Returns
+ -------
+ Tuple[int, int]:
+ the start index and end index which are tradable
+
+ Raises
+ ------
+ NotImplementedError:
+ Raised when there is no range limitation
+ """
+ raise NotImplementedError(f"Please implement the `__call__` method")
+
+ def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
+ """
+ Parameters
+ ----------
+ start_time : pd.Timestamp
+ end_time : pd.Timestamp
+ Both sides (start_time, end_time) are closed
+
+ Returns
+ -------
+ Tuple[pd.Timestamp, pd.Timestamp]:
+ The tradable time range.
+ - It is intersection of [start_time, end_time] and the rule of TradeRange itself
+ """
+ raise NotImplementedError(f"Please implement the `clip_time_range` method")
+
+
+class IdxTradeRange(TradeRange):
+ def __init__(self, start_idx: int, end_idx: int):
+ self._start_idx = start_idx
+ self._end_idx = end_idx
+
+ def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
+ return self._start_idx, self._end_idx
+
+
+class TradeRangeByTime(TradeRange):
+ """This is a helper function for make decisions"""
+
+ def __init__(self, start_time: str, end_time: str):
+ """
+ This is a callable class.
+
+ **NOTE**:
+ - It is designed for minute bars in intraday trading!!!!!
+ - Both start_time and end_time are **closed** in the range
+
+ Parameters
+ ----------
+ start_time : str
+ e.g. "9:30"
+ end_time : str
+ e.g. "14:30"
+ """
+ self.start_time = pd.Timestamp(start_time).time()
+ self.end_time = pd.Timestamp(end_time).time()
+ assert self.start_time < self.end_time
+
+ def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
+ if trade_calendar is None:
+ raise NotImplementedError("trade_calendar is necessary for getting TradeRangeByTime.")
+ start = trade_calendar.start_time
+ val_start, val_end = concat_date_time(start.date(), self.start_time), concat_date_time(
+ start.date(), self.end_time
+ )
+ return trade_calendar.get_range_idx(val_start, val_end)
+
+ def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
+ start_date = start_time.date()
+ val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time)
+ # NOTE: `end_date` should not be used. Because the `end_date` is for slicing. It may be in the next day
+ # Assumption: start_time and end_time is for intraday trading. So it is OK for only using start_date
+ return max(val_start, start_time), min(val_end, end_time)
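+
+# Illustrative example (hedged): TradeRangeByTime("9:30", "14:30") limits each day's
+# execution to the bars between 09:30 and 14:30 (both closed); for instance,
+# clip_time_range(2020-01-02 09:00, 2020-01-02 15:00) returns
+# (2020-01-02 09:30, 2020-01-02 14:30).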
+
+
+class BaseTradeDecision:
+ """
+ Trade decisions are made by the strategy and executed by the executor
+
+ Motivation:
+ Here are several typical scenarios for `BaseTradeDecision`
+
+ Case 1:
+ 1. The outer strategy makes a decision. The decision is not available at the start of the current interval.
+ 2. After a period of time, the decision is updated and becomes available.
+ 3. The inner strategy tries to get the decision and starts to execute it according to `get_range_limit`.
+ Case 2:
+ 1. The outer strategy's decision is available at the start of the interval
+ 2. Same as `case 1.3`
+ """
+
+ def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None):
+ """
+ Parameters
+ ----------
+ strategy : BaseStrategy
+ The strategy that makes the decision
+ trade_range: Union[Tuple[int, int], Callable] (optional)
+ The index range for underlying strategy.
+
+ Here are two examples of trade_range for each type
+
+ 1) Tuple[int, int]
+ start_index and end_index of the underlying strategy (both sides are closed)
+
+ 2) TradeRange
+
+ """
+ self.strategy = strategy
+ self.start_time, self.end_time = strategy.trade_calendar.get_step_time()
+ self.total_step = None # upper strategy has no knowledge about the sub executor before `_init_sub_trading`
+ if isinstance(trade_range, Tuple):
+ # for Tuple[int, int]
+ trade_range = IdxTradeRange(*trade_range)
+ self.trade_range: TradeRange = trade_range
+
+ def get_decision(self) -> List[object]:
+ """
+ get the **concrete decision** (e.g. execution orders)
+ This will be called by the inner strategy
+
+ Returns
+ -------
+ List[object]:
+ The decision result. Typically it is some orders
+ Example:
+ []:
+ Decision not available
+ [concrete_decision]:
+ available
+ """
+ raise NotImplementedError("Please implement the `get_decision` method")
+
+ def update(self, trade_calendar: TradeCalendarManager) -> Union["BaseTradeDecision", None]:
+ """
+ Called at the **start** of each step.
+
+ This function is designed for the following purposes:
+ 1) Leaving a hook for the strategy that made the `self` decision to update the decision itself
+ 2) Updating some information from the inner executor calendar
+
+ Parameters
+ ----------
+ trade_calendar : TradeCalendarManager
+ The calendar of the **inner strategy**!!!!!
+
+ Returns
+ -------
+ None:
+ No update, use the previous decision (or it is unavailable)
+ BaseTradeDecision:
+ New update, use new decision
+ """
+ # purpose 2)
+ self.total_step = trade_calendar.get_trade_len()
+
+ # purpose 1)
+ return self.strategy.update_trade_decision(self, trade_calendar)
+
+ def _get_range_limit(self, **kwargs) -> Tuple[int, int]:
+ if self.trade_range is not None:
+ return self.trade_range(trade_calendar=kwargs.get("inner_calendar"))
+ else:
+ raise NotImplementedError("The decision didn't provide an index range")
+
+ def get_range_limit(self, **kwargs) -> Tuple[int, int]:
+ """
+ Return the expected step range for limiting the decision execution time.
+ Both left and right are **closed**.
+
+ If no trade_range is available, `default_value` will be returned.
+
+ It is only used in `NestedExecutor`
+ - The outermost strategy will not follow any range limit (but it may give a range_limit)
+ - The innermost strategy's range_limit will be useless because atomic executors don't have such
+ a feature.
+
+ **NOTE**:
+ 1) This function must be called after `self.update` in the following cases (ensured by NestedExecutor):
+ - the user relies on the auto-clip feature of `self.update`
+
+ 2) This function will be called after _init_sub_trading in NestedExecutor.
+
+ Parameters
+ ----------
+ **kwargs:
+ {
+ "default_value": , # using dict is for distinguish no value provided or None provided
+ "inner_calendar":
+ # because the range limit will control the step range of inner strategy, inner calendar will be a
+ # important parameter when trade_range is callable
+ }
+
+ Returns
+ -------
+ Tuple[int, int]:
+
+ Raises
+ ------
+ NotImplementedError:
+ If the following criteria are met
+ 1) the decision can't provide a unified start and end
+ 2) default_value is not provided
+ """
+ try:
+ _start_idx, _end_idx = self._get_range_limit(**kwargs)
+ except NotImplementedError:
+ if "default_value" in kwargs:
+ return kwargs["default_value"]
+ else:
+ # no default value is provided; raise an error
+ raise NotImplementedError("The decision didn't provide an index range")
+
+ # clip index
+ if getattr(self, "total_step", None) is not None:
+ # if `self.update` is called.
+ # Then the _start_idx, _end_idx should be clipped
+ if _start_idx < 0 or _end_idx >= self.total_step:
+ logger = get_module_logger("decision")
+ logger.warning(
+ f"[{_start_idx},{_end_idx}] go beyoud the total_step({self.total_step}), it will be clipped"
+ )
+ _start_idx, _end_idx = max(0, _start_idx), min(self.total_step - 1, _end_idx)
+ return _start_idx, _end_idx
+
+ def get_data_cal_range_limit(self, rtype: str = "full", raise_error: bool = False) -> Tuple[int, int]:
+ """
+ get the range limit based on data calendar
+
+ NOTE: it is **total** range limit instead of a single step
+
+ The following assumptions are made
+ 1) The frequency of the exchange in common_infra is the same as the data calendar
+ 2) Users want the index mod by **day** (i.e. 240 min)
+
+ Parameters
+ ----------
+ rtype: str
+ - "full": return the full limitation of the deicsion in the day
+ - "step": return the limitation of current step
+
+ raise_error: bool
+ True: raise error if no trade_range is set
+ False: return full trade calendar.
+
+ It is useful in the following cases
+ - users want to follow the order-specific trading time range when the decision-level trade range is not
+ available. Raising NotImplementedError indicates that the range limit is not available
+
+ Returns
+ -------
+ Tuple[int, int]:
+ the range limit in data calendar
+
+ Raises
+ ------
+ NotImplementedError:
+ If the following criteria are met
+ 1) the decision can't provide a unified start and end
+ 2) raise_error is True
+ """
+ # potential performance issue
+ day_start = pd.Timestamp(self.start_time.date())
+ day_end = epsilon_change(day_start + pd.Timedelta(days=1))
+ freq = self.strategy.trade_exchange.freq
+ _, _, day_start_idx, day_end_idx = Cal.locate_index(day_start, day_end, freq=freq)
+ if self.trade_range is None:
+ if raise_error:
+ raise NotImplementedError(f"There is no trade_range in this case")
+ else:
+ return 0, day_end_idx - day_start_idx
+ else:
+ if rtype == "full":
+ val_start, val_end = self.trade_range.clip_time_range(day_start, day_end)
+ elif rtype == "step":
+ val_start, val_end = self.trade_range.clip_time_range(self.start_time, self.end_time)
+ else:
+ raise ValueError(f"This type of input {rtype} is not supported")
+ _, _, start_idx, end_idx = Cal.locate_index(val_start, val_end, freq=freq)
+ return start_idx - day_start_idx, end_idx - day_start_idx
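+
+ # Illustrative note (hedged): with 1-minute data (240 bars per trading day in the
+ # Chinese A-share market), the returned indices are relative to the first bar of the
+ # day, e.g. a range covering the whole day would be (0, 239).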
+
+ def empty(self) -> bool:
+ for obj in self.get_decision():
+ if isinstance(obj, Order):
+ # Zero amount order will be treated as empty
+ if obj.amount > 1e-6:
+ return False
+ else:
+ return True
+ return True
+
+ def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision):
+ """
+
+ This method will be called on the inner_trade_decision after it is generated.
+ `inner_trade_decision` will be changed **in place**.
+
+ Motivation of the `mod_inner_decision`
+ - Leave a hook for the outer decision to affect the decision generated by the inner strategy
+ - e.g. the outermost strategy generates a time range for trading. But the upper layer can only affect the
+ nearest layer in the original design. With `mod_inner_decision`, the decision can be passed through multiple
+ layers
+
+ Parameters
+ ----------
+ inner_trade_decision : BaseTradeDecision
+ """
+ # the base class provides a default behaviour to modify inner_trade_decision
+ # trade_range should be propagated when inner trade_range is not set
+ if inner_trade_decision.trade_range is None:
+ inner_trade_decision.trade_range = self.trade_range
+
+
+class EmptyTradeDecision(BaseTradeDecision):
+ def empty(self) -> bool:
+ return True
+
+
+class TradeDecisionWO(BaseTradeDecision):
+ """
+ Trade Decision (W)ith (O)rder.
+ Besides, the time_range is also included.
+ """
+
+ def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Tuple[int, int] = None):
+ super().__init__(strategy, trade_range=trade_range)
+ self.order_list = order_list
+ start, end = strategy.trade_calendar.get_step_time()
+ for o in order_list:
+ if o.start_time is None:
+ o.start_time = start
+ if o.end_time is None:
+ o.end_time = end
+
+ def get_decision(self) -> List[object]:
+ return self.order_list
+
+ def __repr__(self) -> str:
+ return f"class: {self.__class__.__name__}; strategy: {self.strategy}; trade_range: {self.trade_range}; order_list[{len(self.order_list)}]"
diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py
new file mode 100644
index 0000000000..9e40e18773
--- /dev/null
+++ b/qlib/backtest/exchange.py
@@ -0,0 +1,786 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+from collections import defaultdict
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from .account import Account
+
+from qlib.backtest.position import BasePosition, Position
+import random
+from typing import List, Optional, Tuple, Union
+import numpy as np
+import pandas as pd
+
+from ..data.data import D
+from ..config import C, REG_CN
+from ..log import get_module_logger
+from .decision import Order, OrderDir, OrderHelper
+from .high_performance_ds import BaseQuote, PandasQuote, NumpyQuote
+
+
+class Exchange:
+ def __init__(
+ self,
+ freq="day",
+ start_time=None,
+ end_time=None,
+ codes="all",
+ deal_price: Union[str, Tuple[str], List[str]] = None,
+ subscribe_fields=[],
+ limit_threshold: Union[Tuple[str, str], float, None] = None,
+ volume_threshold=None,
+ open_cost=0.0015,
+ close_cost=0.0025,
+ min_cost=5,
+ extra_quote=None,
+ quote_cls=NumpyQuote,
+ **kwargs,
+ ):
+ """__init__
+ :param freq: frequency of data
+ :param start_time: closed start time for backtest
+ :param end_time: closed end time for backtest
+ :param codes: a stock_id list, or a string naming an instrument universe (e.g. all, csi500, sse50)
+ :param deal_price: Union[str, Tuple[str, str], List[str]]
+ The `deal_price` supports the following two types of input
+ - <deal_price> : str
+ - (<buy_price>, <sell_price>): Tuple[str] or List[str]
+ <deal_price>, <buy_price> or <sell_price> := <price>
+ <price> := str
+ - for example '$close', '$open', '$vwap' ("close" is OK. `Exchange` will help to prepend
+ "$" to the expression)
+ :param subscribe_fields: list, subscribe fields. This expressions will be added to the query and `self.quote`.
+ It is useful when users want more fields to be queried
+ :param limit_threshold: Union[Tuple[str, str], float, None]
+ 1) `None`: no limitation
+ 2) float, 0.1 for example, default None
+ 3) Tuple[str, str]: (<the expression for buying stock limitation>,
+ <the expression for selling stock limitation>)
+ `False` value indicates the stock is tradable
+ `True` value indicates the stock is limited and not tradable
+ :param volume_threshold: Union[
+ Dict[
+ "all": ("cum" or "current", limit_str),
+ "buy": ("cum" or "current", limit_str),
+ "sell":("cum" or "current", limit_str),
+ ],
+ ("cum" or "current", limit_str),
+ ]
+ 1) ("cum" or "current", limit_str) denotes a single volume limit.
+ - limit_str is qlib data expression which is allowed to define your own Operator.
+ Please refer to qlib/contrib/ops/high_freq.py, where there are some custom operators for high frequency,
+ such as DayCumsum. !!!NOTE: if you want to use a custom operator, you need to
+ register it in qlib_init.
+ - "cum" means that this is a cumulative value over time, such as cumulative market volume.
+ So when it is used as a volume limit, it is necessary to subtract the dealt amount.
+ - "current" means that this is a real-time value and will not accumulate over time,
+ so it can be directly used as a capacity limit.
+ e.g. ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"), ("current", "$bidV1")
+ 2) "all" means the volume limits are both buying and selling.
+ "buy" means the volume limits of buying. "sell" means the volume limits of selling.
+ Different volume limits will be aggregated with min(). If volume_threshold is only
+ ("cum" or "current", limit_str) instead of a dict, the volume limits are for
+ both by default. In other words, it is the same as {"all": ("cum" or "current", limit_str)}.
+ 3) e.g. "volume_threshold": {
+ "all": ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"),
+ "buy": ("current", "$askV1"),
+ "sell": ("current", "$bidV1"),
+ }
+ :param open_cost: cost rate for open, default 0.0015
+ :param close_cost: cost rate for close, default 0.0025
+ :param trade_unit: trade unit, 100 for China A market.
+ None for disable trade unit.
+ **NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must
+ distinguish `not set` and `disable trade_unit`
+ :param min_cost: min cost, default 5
+ :param extra_quote: pandas, dataframe consists of
+ columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy'].
+ The limit indicates that the etf is tradable on a specific day.
+ Necessary fields:
+ $close is for calculating the total value at end of each day.
+ Optional fields:
+ $volume is only necessary when we limit the trade amount or calculate the PA(vwap) indicator
+ $vwap is only necessary when we use the $vwap price as the deal price
+ $factor is for rounding to the trading unit
+ limit_sell will be set to False by default (False indicates we can sell this
+ target on this day).
+ limit_buy will be set to False by default (False indicates we can buy this
+ target on this day).
+ index: MultiIndex(instrument, pd.Datetime)
+ """
+ self.freq = freq
+ self.start_time = start_time
+ self.end_time = end_time
+
+ self.trade_unit = kwargs.pop("trade_unit", C.trade_unit)
+ if len(kwargs) > 0:
+ raise ValueError(f"Get Unexpected arguments {kwargs}")
+
+ if limit_threshold is None:
+ limit_threshold = C.limit_threshold
+ if deal_price is None:
+ deal_price = C.deal_price
+
+ # we have some verbose information here, so logging is enabled
+ self.logger = get_module_logger("online operator")
+
+ # TODO: the quote, trade_dates, codes are not necessary.
+ # It is just for performance consideration.
+ self.limit_type = self._get_limit_type(limit_threshold)
+ if limit_threshold is None:
+ if C.region == REG_CN:
+ self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
+ elif self.limit_type == self.LT_FLT and abs(limit_threshold) > 0.1:
+ if C.region == REG_CN:
+ self.logger.warning(f"limit_threshold may not be set to a reasonable value")
+
+ if isinstance(deal_price, str):
+ if deal_price[0] != "$":
+ deal_price = "$" + deal_price
+ self.buy_price = self.sell_price = deal_price
+ elif isinstance(deal_price, (tuple, list)):
+ self.buy_price, self.sell_price = deal_price
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
+ if isinstance(codes, str):
+ codes = D.instruments(codes)
+ self.codes = codes
+ # Necessary fields
+ # $close is for calculating the total value at end of each day.
+ # $factor is for rounding to the trading unit
+ # $change is for calculating the limit of the stock
+
+ # get the volume limit from `volume_threshold`
+ self.buy_vol_limit, self.sell_vol_limit, vol_lt_fields = self._get_vol_limit(volume_threshold)
+
+ necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
+ if self.limit_type == self.LT_TP_EXP:
+ for exp in limit_threshold:
+ necessary_fields.add(exp)
+ all_fields = necessary_fields | vol_lt_fields
+ all_fields = list(all_fields | set(subscribe_fields))
+
+ self.all_fields = all_fields
+ self.open_cost = open_cost
+ self.close_cost = close_cost
+ self.min_cost = min_cost
+ self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold
+ self.volume_threshold = volume_threshold
+ self.extra_quote = extra_quote
+ self.get_quote_from_qlib()
+
+ # init quote by quote_df
+ self.quote_cls = quote_cls
+ self.quote: BaseQuote = self.quote_cls(self.quote_df, freq)
+
+ def get_quote_from_qlib(self):
+ # get stock data from qlib
+ if len(self.codes) == 0:
+ self.codes = D.instruments()
+ self.quote_df = D.features(
+ self.codes, self.all_fields, self.start_time, self.end_time, freq=self.freq, disk_cache=True
+ ).dropna(subset=["$close"])
+ self.quote_df.columns = self.all_fields
+
+ # check buy_price data and sell_price data
+ for attr in "buy_price", "sell_price":
+ pstr = getattr(self, attr) # price string
+ if self.quote_df[pstr].isna().any():
+ self.logger.warning("{} field data contains nan.".format(pstr))
+
+ # update trade_w_adj_price
+ if self.quote_df["$factor"].isna().any():
+ # The 'factor.day.bin' file does not exist, and the `factor` field contains `nan`
+ # Use adjusted price
+ self.trade_w_adj_price = True
+ self.logger.warning("factor.day.bin file does not exist or factor contains `nan`. Orders will use adjusted_price.")
+ if self.trade_unit is not None:
+ self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.")
+ else:
+ # The `factor.day.bin` file exists, and the `close` and `factor` data contain no `nan`
+ # Use normal price
+ self.trade_w_adj_price = False
+ # update limit
+ self._update_limit(self.limit_threshold)
+
+ # concat extra_quote
+ if self.extra_quote is not None:
+ # process extra_quote
+ if "$close" not in self.extra_quote:
+ raise ValueError("$close is necessray in extra_quote")
+ for attr in "buy_price", "sell_price":
+ pstr = getattr(self, attr) # price string
+ if pstr not in self.extra_quote.columns:
+ self.extra_quote[pstr] = self.extra_quote["$close"]
+ self.logger.warning(f"No {pstr} set for extra_quote. Use $close as {pstr}.")
+ if "$factor" not in self.extra_quote.columns:
+ self.extra_quote["$factor"] = 1.0
+ self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.")
+ if "limit_sell" not in self.extra_quote.columns:
+ self.extra_quote["limit_sell"] = False
+ self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.")
+ if "limit_buy" not in self.extra_quote.columns:
+ self.extra_quote["limit_buy"] = False
+ self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.")
+ assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"}
+ self.quote_df = pd.concat([self.quote_df, self.extra_quote], sort=False, axis=0)
+
+ LT_TP_EXP = "(exp)" # Tuple[str, str]
+ LT_FLT = "float" # float
+ LT_NONE = "none" # none
+
+ def _get_limit_type(self, limit_threshold):
+ """get limit type"""
+ if isinstance(limit_threshold, Tuple):
+ return self.LT_TP_EXP
+ elif isinstance(limit_threshold, float):
+ return self.LT_FLT
+ elif limit_threshold is None:
+ return self.LT_NONE
+ else:
+ raise NotImplementedError(f"This type of `limit_threshold` is not supported")
+
+ def _update_limit(self, limit_threshold):
+ # check limit_threshold
+ limit_type = self._get_limit_type(limit_threshold)
+ if limit_type == self.LT_NONE:
+ self.quote_df["limit_buy"] = False
+ self.quote_df["limit_sell"] = False
+ elif limit_type == self.LT_TP_EXP:
+ # set limit
+ self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
+ self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
+ elif limit_type == self.LT_FLT:
+ self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
+ self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
+
+ def _get_vol_limit(self, volume_threshold):
+ """
+ Preprocess the volume limit.
+ Get the fields that need to be fetched from qlib.
+ Get the volume limit lists of buying and selling, which are composed of all limits.
+ Parameters
+ ----------
+ volume_threshold :
+ please refer to the doc of exchange.
+ Returns
+ -------
+ buy_vol_limit: List[Tuple[str]]
+ all volume limits of buying.
+ sell_vol_limit: List[Tuple[str]]
+ all volume limits of selling.
+ fields: set
+ the fields that need to be fetched from qlib.
+ Raises
+ ------
+ ValueError
+ the format of volume_threshold is not supported.
+ """
+ if volume_threshold is None:
+ return None, None, set()
+
+ fields = set()
+ buy_vol_limit = []
+ sell_vol_limit = []
+ if isinstance(volume_threshold, tuple):
+ volume_threshold = {"all": volume_threshold}
+
+ assert isinstance(volume_threshold, dict)
+ for key in volume_threshold:
+ vol_limit = volume_threshold[key]
+ assert isinstance(vol_limit, tuple)
+ fields.add(vol_limit[1])
+
+ if key in ("buy", "all"):
+ buy_vol_limit.append(vol_limit)
+ if key in ("sell", "all"):
+ sell_vol_limit.append(vol_limit)
+
+ return buy_vol_limit, sell_vol_limit, fields
+
+ def check_stock_limit(self, stock_id, start_time, end_time, direction=None):
+ """
+ Parameters
+ ----------
+ direction : int, optional
+ trade direction, by default None
+ - if direction is None, check whether the stock is limited for both buying and selling.
+ - if direction == Order.BUY, check the buy limit.
+ - if direction == Order.SELL, check the sell limit.
+ """
+ if direction is None:
+ buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
+ sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
+ return buy_limit or sell_limit
+ elif direction == Order.BUY:
+ return self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
+ elif direction == Order.SELL:
+ return self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
+ else:
+ raise ValueError(f"direction {direction} is not supported!")
+
+ def check_stock_suspended(self, stock_id, start_time, end_time):
+ # is suspended
+ if stock_id in self.quote.get_all_stock():
+ return self.quote.get_data(stock_id, start_time, end_time, "$close") is None
+ else:
+ return True
+
+ def is_stock_tradable(self, stock_id, start_time, end_time, direction=None):
+ # check if stock can be traded
+ # same as check in check_order
+ if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(
+ stock_id, start_time, end_time, direction
+ ):
+ return False
+ else:
+ return True
+
+ def check_order(self, order):
+ # check limit and suspended
+ if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit(
+ order.stock_id, order.start_time, order.end_time, order.direction
+ ):
+ return False
+ else:
+ return True
+
+ def deal_order(
+ self,
+ order,
+ trade_account: Account = None,
+ position: BasePosition = None,
+ dealt_order_amount: defaultdict = None,
+ ):
+ """
+ Deal order when the actual transaction
+ the results section in `Order` will be changed.
+ :param order: Deal the order.
+ :param trade_account: Trade account to be updated after dealing the order.
+ :param position: position to be updated after dealing the order.
+ :param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
+ :return: trade_val, trade_cost, trade_price
+ """
+ # avoid a shared mutable default argument; create a fresh defaultdict per call
+ if dealt_order_amount is None:
+ dealt_order_amount = defaultdict(float)
+ # check order first.
+ if self.check_order(order) is False:
+ order.deal_amount = 0.0
+ # using np.nan instead of None to make it more convenient to show the value in a format string
+ self.logger.debug(f"Order failed due to trading limitation: {order}")
+ return 0.0, 0.0, np.nan
+
+ if trade_account is not None and position is not None:
+ raise ValueError("trade_account and position can only choose one")
+
+ # NOTE: order will be changed in this function
+ trade_price, trade_val, trade_cost = self._calc_trade_info_by_order(
+ order, trade_account.current_position if trade_account else position, dealt_order_amount
+ )
+ if trade_val > 1e-5:
+ # If the order can only be dealt with 0 value, nothing is to be updated.
+ # Updating in that case would result in
+ # 1) some stock with 0 value in the position
+ # 2) `trade_unit` of trade_cost being lost from the user account
+ if trade_account:
+ trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
+ elif position:
+ position.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
+
+ return trade_val, trade_cost, trade_price
+
+ def get_quote_info(self, stock_id, start_time, end_time, method="ts_data_last"):
+ return self.quote.get_data(stock_id, start_time, end_time, method=method)
+
+ def get_close(self, stock_id, start_time, end_time, method="ts_data_last"):
+ return self.quote.get_data(stock_id, start_time, end_time, field="$close", method=method)
+
+ def get_volume(self, stock_id, start_time, end_time):
+ """get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
+ return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method="sum")
+
+ def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method="ts_data_last"):
+ if direction == OrderDir.SELL:
+ pstr = self.sell_price
+ elif direction == OrderDir.BUY:
+ pstr = self.buy_price
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+ deal_price = self.quote.get_data(stock_id, start_time, end_time, field=pstr, method=method)
+ if method is not None and (deal_price is None or np.isnan(deal_price) or deal_price <= 1e-08):
+ self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
+ self.logger.warning(f"setting deal_price to close price")
+ deal_price = self.get_close(stock_id, start_time, end_time, method)
+ return deal_price
+
+ def get_factor(self, stock_id, start_time, end_time) -> Union[float, None]:
+ """
+ Returns
+ -------
+ Union[float, None]:
+ `None`: if the stock is suspended `None` may be returned
+ `float`: return factor if the factor exists
+ """
+ assert start_time is not None and end_time is not None, "the time range must be given"
+ if stock_id not in self.quote.get_all_stock():
+ return None
+ return self.quote.get_data(stock_id, start_time, end_time, field="$factor", method="ts_data_last")
+
+ def generate_amount_position_from_weight_position(
+ self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY
+ ):
+ """
+ Generate the target position according to the weights and the cash.
+ NOTE: All the cash will be assigned to the tradable stocks.
+ Parameter:
+ weight_position : dict {stock_id : weight}; allocate cash by weight_position
+ among them, each weight must be in this range: 0 < weight < 1
+ cash : cash
+ start_time : the start time point of the step
+ end_time : the end time point of the step
+ direction : the direction of the deal price for estimating the amount
+ # NOTE: this function is used for calculating target position. So the default direction is buy
+ """
+
+ # calculate the total weight of tradable value
+ tradable_weight = 0.0
+ for stock_id in weight_position:
+ if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
+ # weight_position must be greater than 0 and less than 1
+ if weight_position[stock_id] < 0 or weight_position[stock_id] > 1:
+ raise ValueError(
+ "weight_position is {}, "
+ "weight_position is not in the range of (0, 1).".format(weight_position[stock_id])
+ )
+ tradable_weight += weight_position[stock_id]
+
+ if tradable_weight - 1.0 >= 1e-5:
+ raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight))
+
+ amount_dict = {}
+ for stock_id in weight_position:
+ if weight_position[stock_id] > 0.0 and self.is_stock_tradable(
+ stock_id=stock_id, start_time=start_time, end_time=end_time
+ ):
+ amount_dict[stock_id] = (
+ cash
+ * weight_position[stock_id]
+ / tradable_weight
+ // self.get_deal_price(
+ stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
+ )
+ )
+ return amount_dict
+
+ def get_real_deal_amount(self, current_amount, target_amount, factor):
+ """
+ Calculate the real adjusted deal amount when considering the trading unit
+ :param current_amount:
+ :param target_amount:
+ :param factor:
+ :return real_deal_amount; Positive deal_amount indicates buying more stock.
+ """
+ if current_amount == target_amount:
+ return 0
+ elif current_amount < target_amount:
+ deal_amount = target_amount - current_amount
+ deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
+ return deal_amount
+ else:
+ if target_amount == 0:
+ return -current_amount
+ else:
+ deal_amount = current_amount - target_amount
+ deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
+ return -deal_amount
+
+ def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time):
+ """
+ Note: some future information is used in this function
+ Parameter:
+ target_position : dict { stock_id : amount }
+ current_position : dict { stock_id : amount }
+ trade_unit : trade_unit
+ down sample : for amount 321 and trade_unit 100, deal_amount is 300
+ deal order on trade_date
+ """
+ # split buy and sell for further use
+ buy_order_list = []
+ sell_order_list = []
+ # three parts: kept stock_id, dropped stock_id, new stock_id
+ # handle kept stock_id
+
+ # because the iteration order of a set is not fixed, the trading order of the stocks would differ, which makes backtest results differ across runs with the same parameters;
+ # so here we sort the stock_ids first, and then randomly shuffle their order;
+ # because a fixed random seed is used, the final stock_id order is deterministic
+ sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys())))
+ random.seed(0)
+ random.shuffle(sorted_ids)
+ for stock_id in sorted_ids:
+
+ # Do not generate order for the nontradable stocks
+ if not self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
+ continue
+
+ target_amount = target_position.get(stock_id, 0)
+ current_amount = current_position.get(stock_id, 0)
+ factor = self.get_factor(stock_id, start_time=start_time, end_time=end_time)
+
+ deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor)
+ if deal_amount == 0:
+ continue
+ elif deal_amount > 0:
+ # buy stock
+ buy_order_list.append(
+ Order(
+ stock_id=stock_id,
+ amount=deal_amount,
+ direction=Order.BUY,
+ start_time=start_time,
+ end_time=end_time,
+ factor=factor,
+ )
+ )
+ else:
+ # sell stock
+ sell_order_list.append(
+ Order(
+ stock_id=stock_id,
+ amount=abs(deal_amount),
+ direction=Order.SELL,
+ start_time=start_time,
+ end_time=end_time,
+ factor=factor,
+ )
+ )
+ # return order_list : buy + sell
+ return sell_order_list + buy_order_list
+
+ def calculate_amount_position_value(
+ self, amount_dict, start_time, end_time, only_tradable=False, direction=OrderDir.SELL
+ ):
+ """Parameter
+ amount_dict : {stock_id : amount}
+ direction : the direction of the deal price for estimating the amount
+ # NOTE:
+ This function is used for calculating current position value.
+ So the default direction is sell.
+ """
+ value = 0
+ for stock_id in amount_dict:
+ if not only_tradable or (
+ self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
+ and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
+ ):
+ value += (
+ self.get_deal_price(
+ stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
+ )
+ * amount_dict[stock_id]
+ )
+ return value
+
+ def _get_factor_or_raise_error(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
+ """Please refer to the docs of get_amount_of_trade_unit"""
+ if factor is None:
+ if stock_id is not None and start_time is not None and end_time is not None:
+ factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time)
+ else:
+ raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
+ return factor
+
+ def get_amount_of_trade_unit(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
+ """
+ get the trade unit of amount based on **factor**
+ the factor can be given directly or calculated from the given time range and stock id.
+ `factor` has higher priority than `stock_id`, `start_time` and `end_time`
+ Parameters
+ ----------
+ factor : float
+ the adjusted factor
+ stock_id : str
+ the id of the stock
+ start_time :
+ the start time of trading range
+ end_time :
+ the end time of trading range
+ """
+ if not self.trade_w_adj_price and self.trade_unit is not None:
+ factor = self._get_factor_or_raise_error(
+ factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
+ )
+ return self.trade_unit / factor
+ else:
+ return None
+
+ def round_amount_by_trade_unit(
+ self, deal_amount, factor: float = None, stock_id: str = None, start_time=None, end_time=None
+ ):
+ """Parameter
+ Please refer to the docs of get_amount_of_trade_unit
+ deal_amount : float, adjusted amount
+ factor : float, adjusted factor
+ return : float, the rounded deal_amount (still an adjusted amount)
+ """
+ if not self.trade_w_adj_price and self.trade_unit is not None:
+ # the minimal real amount unit is 1 share; add 0.1 to avoid floating-point precision problems.
+ factor = self._get_factor_or_raise_error(
+ factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
+ )
+ return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
+ return deal_amount
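+
+ # Worked example (illustrative): with factor=2.0 and trade_unit=100, an adjusted
+ # deal_amount of 160.5 corresponds to 321 real shares; (321 + 0.1) // 100 * 100
+ # == 300 real shares, which is returned as 150.0 in adjusted amount.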
+
+ def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Optional[float]:
+ """parse the capacity limit string and return the actual amount of orders that can be executed.
+ NOTE:
+ this function will change the order.deal_amount **inplace**
+ - This will make the order info more accurate
+ Parameters
+ ----------
+ order : Order
+ the order to be executed.
+ dealt_order_amount : dict
+ the dealt order amount dict with the format of {stock_id: float}
+ """
+ if order.direction == Order.BUY:
+ vol_limit = self.buy_vol_limit
+ elif order.direction == Order.SELL:
+ vol_limit = self.sell_vol_limit
+
+ if vol_limit is None:
+ return order.deal_amount
+
+ vol_limit_num = []
+ for limit in vol_limit:
+ assert isinstance(limit, tuple)
+ if limit[0] == "current":
+ limit_value = self.quote.get_data(
+ order.stock_id,
+ order.start_time,
+ order.end_time,
+ field=limit[1],
+ method="sum",
+ )
+ vol_limit_num.append(limit_value)
+ elif limit[0] == "cum":
+ limit_value = self.quote.get_data(
+ order.stock_id,
+ order.start_time,
+ order.end_time,
+ field=limit[1],
+ method="ts_data_last",
+ )
+ vol_limit_num.append(limit_value - dealt_order_amount[order.stock_id])
+ else:
+ raise ValueError(f"{limit[0]} is not supported")
+ vol_limit_min = min(vol_limit_num)
+ orig_deal_amount = order.deal_amount
+ order.deal_amount = max(min(vol_limit_min, orig_deal_amount), 0)
+ if vol_limit_min < orig_deal_amount:
+ self.logger.debug(
+ f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}"
+ )
+
+ def _get_buy_amount_by_cash_limit(self, trade_price, cash):
+ """return the real order amount after cash limit for buying.
+ Parameters
+ ----------
+ trade_price : float
+ cash : float, the available cash
+ Return
+ ----------
+ float
+ the real order amount after cash limit for buying.
+ """
+ max_trade_amount = 0
+ if cash >= self.min_cost:
+ # critical_price means the stock transaction price when the service fee is equal to min_cost.
+ critical_price = self.min_cost / self.open_cost + self.min_cost
+ if cash >= critical_price:
+ # the service fee is equal to open_cost * trade_amount
+ max_trade_amount = cash / (1 + self.open_cost) / trade_price
+ else:
+ # the service fee is equal to min_cost
+ max_trade_amount = (cash - self.min_cost) / trade_price
+ return max_trade_amount
+
+ def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amount):
+ """
+ Calculation of trade info
+ **NOTE**: Order will be changed in this function
+ :param order:
+ :param position: Position
+ :param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
+ :return: trade_price, trade_val, trade_cost
+ """
+ trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
+ order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
+ order.deal_amount = order.amount # set to full amount and clip it step by step
+ # Clipping amount first
+ # - It simulates that the order is rejected directly by the exchange due to large order
+ # Another choice is placing it after rounding the order
+ # - It simulates that the large order is submitted, but partial is dealt regardless of rounding by trading unit.
+ self._clip_amount_by_volume(order, dealt_order_amount)
+
+ if order.direction == Order.SELL:
+ cost_ratio = self.close_cost
+ # sell
+ # if we don't know current position, we choose to sell all
+ # Otherwise, we clip the amount based on current position
+ if position is not None:
+ current_amount = (
+ position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0
+ )
+ if not np.isclose(order.deal_amount, current_amount):
+ # when not selling the last of the stock, rounding is necessary
+ order.deal_amount = self.round_amount_by_trade_unit(
+ min(current_amount, order.deal_amount), order.factor
+ )
+
+ # in case of negative value of cash
+ if position.get_cash() + order.deal_amount * trade_price < max(
+ order.deal_amount * trade_price * cost_ratio,
+ self.min_cost,
+ ):
+ order.deal_amount = 0
+ self.logger.debug(f"Order clipped due to cash limitation: {order}")
+
+ elif order.direction == Order.BUY:
+ cost_ratio = self.open_cost
+ # buy
+ if position is not None:
+ cash = position.get_cash()
+ trade_val = order.deal_amount * trade_price
+ if cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
+ # The money is not enough
+ max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash)
+ order.deal_amount = self.round_amount_by_trade_unit(
+ min(max_buy_amount, order.deal_amount), order.factor
+ )
+ self.logger.debug(f"Order clipped due to cash limitation: {order}")
+ else:
+ # The money is enough
+ order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
+ else:
+ # Unknown amount of money. Just round the amount
+ order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor)
+
+ else:
+ raise NotImplementedError("order direction {} error".format(order.direction))
+
+ trade_val = order.deal_amount * trade_price
+ trade_cost = max(trade_val * cost_ratio, self.min_cost)
+ if trade_val <= 1e-5:
+ # if dealing is not successful, the trade_cost should be zero.
+ trade_cost = 0
+ return trade_price, trade_val, trade_cost
+
+ def get_order_helper(self) -> OrderHelper:
+ if not hasattr(self, "_order_helper"):
+ # cache to avoid recreate the same instance
+ self._order_helper = OrderHelper(self)
+ return self._order_helper
diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py
new file mode 100644
index 0000000000..44f3e8db03
--- /dev/null
+++ b/qlib/backtest/executor.py
@@ -0,0 +1,541 @@
+from abc import abstractmethod
+import copy
+from qlib.backtest.position import BasePosition
+from qlib.log import get_module_logger
+from types import GeneratorType
+from qlib.backtest.account import Account
+import warnings
+import pandas as pd
+from typing import List, Tuple, Union
+from collections import defaultdict
+
+from qlib.backtest.report import Indicator
+
+from .decision import EmptyTradeDecision, Order, BaseTradeDecision
+from .exchange import Exchange
+from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx
+
+from ..utils import init_instance_by_config
+from ..utils.time import Freq
+from ..strategy.base import BaseStrategy
+
+
+class BaseExecutor:
+ """Base executor for trading"""
+
+ def __init__(
+ self,
+ time_per_step: str,
+ start_time: Union[str, pd.Timestamp] = None,
+ end_time: Union[str, pd.Timestamp] = None,
+ indicator_config: dict = {},
+ generate_portfolio_metrics: bool = False,
+ verbose: bool = False,
+ track_data: bool = False,
+ trade_exchange: Exchange = None,
+ common_infra: CommonInfrastructure = None,
+ settle_type=BasePosition.ST_NO,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ time_per_step : str
+ trade time per trading step, used for generating the trade calendar
+ indicator_config: dict, optional
+ config for calculating trade indicator, including the following fields:
+ - 'show_indicator': whether to show indicators, optional, default by False. The indicators includes
+ - 'pa', the price advantage
+ - 'pos', the positive rate
+ - 'ffr', the fulfill rate
+ - 'pa_config': config for calculating price advantage(pa), optional
+ - 'base_price': the base price against which the trading price advantage is measured, optional, default 'twap'
+ - If 'base_price' is 'twap', the base price is the time-weighted average price
+ - If 'base_price' is 'vwap', the base price is the volume-weighted average price
+ - 'weight_method': the weighting method when aggregating different orders' pa into the total trading pa in each step, optional, default 'mean'
+ - If 'weight_method' is 'mean', calculating mean value of different orders' pa
+ - If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' pa
+ - If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' pa
+ - 'ffr_config': config for calculating fulfill rate(ffr), optional
+ - 'weight_method': the weighting method when aggregating different orders' ffr into the total trading ffr in each step, optional, default 'mean'
+ - If 'weight_method' is 'mean', calculating mean value of different orders' ffr
+ - If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' ffr
+ - If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' ffr
+ Example:
+ {
+ 'show_indicator': True,
+ 'pa_config': {
+ "agg": "twap", # "vwap"
+ "price": "$close", # default to use deal price of the exchange
+ },
+ 'ffr_config':{
+ 'weight_method': 'value_weighted',
+ }
+ }
+ generate_portfolio_metrics : bool, optional
+ whether to generate portfolio_metrics, by default False
+ verbose : bool, optional
+ whether to print trading info, by default False
+ track_data : bool, optional
+ whether to generate trade_decision, will be used when training rl agent
+ - If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
+ - Else, `trade_decision` will not be generated
+
+ trade_exchange : Exchange
+ exchange that provides market info, used to generate portfolio_metrics
+ - If generate_portfolio_metrics is False, trade_exchange will be ignored
+ - Else If `trade_exchange` is None, self.trade_exchange will be set with common_infra
+
+ common_infra : CommonInfrastructure, optional:
+ common infrastructure for backtesting, which may include:
+ - trade_account : Account, optional
+ trade account for trading
+ - trade_exchange : Exchange, optional
+ exchange that provides market info
+
+ settle_type : str
+ Please refer to the docs of BasePosition.settle_start
+ """
+ self.time_per_step = time_per_step
+ self.indicator_config = indicator_config
+ self.generate_portfolio_metrics = generate_portfolio_metrics
+ self.verbose = verbose
+ self.track_data = track_data
+ self._trade_exchange = trade_exchange
+ self.level_infra = LevelInfrastructure()
+ self.level_infra.reset_infra(common_infra=common_infra)
+ self._settle_type = settle_type
+ self.reset(start_time=start_time, end_time=end_time, common_infra=common_infra)
+ if common_infra is None:
+ get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}")
+
+ # record deal order amount in one day
+ self.dealt_order_amount = defaultdict(float)
+ self.deal_day = None
+
+ def reset_common_infra(self, common_infra):
+ """
+ reset infrastructure for trading
+ - reset trade_account
+ """
+ if not hasattr(self, "common_infra"):
+ self.common_infra = common_infra
+ else:
+ self.common_infra.update(common_infra)
+
+ if common_infra.has("trade_account"):
+ # NOTE: there is a trick in the code.
+ # copy is used instead of deepcopy. So positions are shared
+ self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
+ self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
+
+ @property
+ def trade_exchange(self) -> Exchange:
+ """get trade exchange in a prioritized order"""
+ return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange")
+
+ @property
+ def trade_calendar(self) -> TradeCalendarManager:
+ """
+ Though the trade calendar can be accessed from multiple sources, managing it in a centralized way makes the
+ code easier
+ """
+ return self.level_infra.get("trade_calendar")
+
+ def reset(self, common_infra: CommonInfrastructure = None, **kwargs):
+ """
+ - reset `start_time` and `end_time`, used in trade calendar
+ - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, etc.
+ """
+
+ if "start_time" in kwargs or "end_time" in kwargs:
+ start_time = kwargs.get("start_time")
+ end_time = kwargs.get("end_time")
+ self.level_infra.reset_cal(freq=self.time_per_step, start_time=start_time, end_time=end_time)
+ if common_infra is not None:
+ self.reset_common_infra(common_infra)
+
+ def get_level_infra(self):
+ return self.level_infra
+
+ def finished(self):
+ return self.trade_calendar.finished()
+
+ def execute(self, trade_decision: BaseTradeDecision, level: int = 0):
+ """execute the trade decision and return the executed result
+
+ NOTE: this function is never used directly in the framework. Should we delete it?
+
+ Parameters
+ ----------
+ trade_decision : BaseTradeDecision
+
+ level : int
+ the level of current executor
+
+ Returns
+ ----------
+ execute_result : List[object]
+ the executed result for trade decision
+ """
+ return_value = {}
+ for _decision in self.collect_data(trade_decision, return_value=return_value, level=level):
+ pass
+ return return_value.get("execute_result")
+
+ @abstractmethod
+ def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
+ """
+ Please refer to the doc of collect_data
+ The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
+ collect_data
+
+ Parameters
+ ----------
+ Please refer to the doc of collect_data
+
+
+ Returns
+ -------
+ Tuple[List[object], dict]:
+ (<the executed result for the trade decision>, <the extra kwargs for `self.trade_account.update_bar_end`>)
+ """
+
+ def collect_data(
+ self, trade_decision: BaseTradeDecision, return_value: dict = None, level: int = 0
+ ) -> List[object]:
+ """Generator for collecting the trade decision data for rl training
+
+ This function will make one step forward
+
+ Parameters
+ ----------
+ trade_decision : BaseTradeDecision
+
+ level : int
+ the level of current executor. 0 indicates the top level
+
+ return_value : dict
+ a dict used to return values by reference
+ e.g. {"return_value": <the executed result>}
+
+ Returns
+ ----------
+ execute_result : List[object]
+ the executed result for trade decision.
+ ** NOTE!!!! **:
+ 1) This is necessary: the return value of the generator will be used in NestedExecutor
+ 2) Please note that the executed results are not merged.
+
+ Yields
+ -------
+ object
+ trade decision
+ """
+ if self.track_data:
+ yield trade_decision
+
+ atomic = not issubclass(self.__class__, NestedExecutor) # issubclass(A, A) is True
+
+ if atomic and trade_decision.get_range_limit(default_value=None) is not None:
+ raise ValueError("atomic executor doesn't support specify `range_limit`")
+
+ if self._settle_type != BasePosition.ST_NO:
+ self.trade_account.current_position.settle_start(self._settle_type)
+
+ obj = self._collect_data(trade_decision=trade_decision, level=level)
+
+ if isinstance(obj, GeneratorType):
+ res, kwargs = yield from obj
+ else:
+ # Some concrete executors don't have inner decisions
+ res, kwargs = obj
+
+ trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
+ # Account will not be changed in this function
+ self.trade_account.update_bar_end(
+ trade_start_time,
+ trade_end_time,
+ self.trade_exchange,
+ atomic=atomic,
+ outer_trade_decision=trade_decision,
+ indicator_config=self.indicator_config,
+ **kwargs,
+ )
+
+ self.trade_calendar.step()
+
+ if self._settle_type != BasePosition.ST_NO:
+ self.trade_account.current_position.settle_commit()
+
+ if return_value is not None:
+ return_value.update({"execute_result": res})
+ return res
+
+ def get_all_executors(self):
+ """get all executors"""
+ return [self]
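+
+# A minimal usage sketch (illustrative only, not part of the framework): how a
+# caller could drive an executor once it is constructed; `my_strategy` is a
+# hypothetical BaseStrategy following the interfaces used above.
+def _example_run_executor(my_executor, my_strategy):
+    my_executor.reset(start_time="2020-01-01", end_time="2020-12-31")
+    execute_result = None
+    while not my_executor.finished():
+        # the strategy sees the execution results of the previous step
+        trade_decision = my_strategy.generate_trade_decision(execute_result)
+        execute_result = my_executor.execute(trade_decision)
+    return execute_result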
+
+
+class NestedExecutor(BaseExecutor):
+ """
+ Nested Executor with inner strategy and executor
+ - Each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` in a higher-frequency env.
+ """
+
+ def __init__(
+ self,
+ time_per_step: str,
+ inner_executor: Union[BaseExecutor, dict],
+ inner_strategy: Union[BaseStrategy, dict],
+ start_time: Union[str, pd.Timestamp] = None,
+ end_time: Union[str, pd.Timestamp] = None,
+ indicator_config: dict = {},
+ generate_portfolio_metrics: bool = False,
+ verbose: bool = False,
+ track_data: bool = False,
+ skip_empty_decision: bool = True,
+ align_range_limit: bool = True,
+ common_infra: CommonInfrastructure = None,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ inner_executor : BaseExecutor
+ trading env in each trading bar.
+ inner_strategy : BaseStrategy
+ trading strategy in each trading bar
+ skip_empty_decision: bool
+ Whether the executor will skip calling the inner loop when the decision is empty.
+ It should be False in the following cases:
+ - The decisions may be updated by steps
+ - The inner executor may not follow the decisions from the outer strategy
+ align_range_limit: bool
+ force to align the trade_range decision
+ It only applies to the nested executor, because the range_limit is given by the outer strategy
+ """
+ self.inner_executor: BaseExecutor = init_instance_by_config(
+ inner_executor, common_infra=common_infra, accept_types=BaseExecutor
+ )
+ self.inner_strategy: BaseStrategy = init_instance_by_config(
+ inner_strategy, common_infra=common_infra, accept_types=BaseStrategy
+ )
+
+ self._skip_empty_decision = skip_empty_decision
+ self._align_range_limit = align_range_limit
+
+ super(NestedExecutor, self).__init__(
+ time_per_step=time_per_step,
+ start_time=start_time,
+ end_time=end_time,
+ indicator_config=indicator_config,
+ generate_portfolio_metrics=generate_portfolio_metrics,
+ verbose=verbose,
+ track_data=track_data,
+ common_infra=common_infra,
+ **kwargs,
+ )
+
+ def reset_common_infra(self, common_infra):
+ """
+ reset infrastructure for trading
+ - reset the common infra of inner_strategy and inner_executor
+ """
+ super(NestedExecutor, self).reset_common_infra(common_infra)
+
+ self.inner_executor.reset_common_infra(common_infra)
+ self.inner_strategy.reset_common_infra(common_infra)
+
+ def _init_sub_trading(self, trade_decision):
+ trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
+ self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
+ sub_level_infra = self.inner_executor.get_level_infra()
+ self.level_infra.set_sub_level_infra(sub_level_infra)
+ self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision)
+
+ def _update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTradeDecision:
+ # the outer strategy has a chance to update the decision at each iteration
+ updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar)
+ if updated_trade_decision is not None:
+ trade_decision = updated_trade_decision
+ # NEW UPDATE
+ # create a hook for the inner strategy to update the outer decision
+ self.inner_strategy.alter_outer_trade_decision(trade_decision)
+ return trade_decision
+
+ def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
+ execute_result = []
+ inner_order_indicators = []
+ decision_list = []
+ # NOTE:
+ # - this is necessary for calculating the steps at the sub level
+ # - more detailed information will be set into the trade decision
+ self._init_sub_trading(trade_decision)
+
+ _inner_execute_result = None
+ while not self.inner_executor.finished():
+ trade_decision = self._update_trade_decision(trade_decision)
+
+ if trade_decision.empty() and self._skip_empty_decision:
+ # give the outer strategy one chance to update the decision
+ # - e.g. for updating some information in the sub executor (the strategy has no knowledge of the inner
+ # executor when generating the decision)
+ break
+
+ sub_cal: TradeCalendarManager = self.inner_executor.trade_calendar
+
+ # NOTE: make sure get_start_end_idx is after `self._update_trade_decision`
+ start_idx, end_idx = get_start_end_idx(sub_cal, trade_decision)
+ if not self._align_range_limit or start_idx <= sub_cal.get_trade_step() <= end_idx:
+ # if force align the range limit, skip the steps outside the decision range limit
+
+ _inner_trade_decision: BaseTradeDecision = self.inner_strategy.generate_trade_decision(
+ _inner_execute_result
+ )
+ trade_decision.mod_inner_decision(_inner_trade_decision) # propagate part of decision information
+
+ # NOTE sub_cal.get_step_time() must be called before collect_data in case of step shifting
+ decision_list.append((_inner_trade_decision, *sub_cal.get_step_time()))
+
+ # NOTE: Trade Calendar will step forward in the following line
+ _inner_execute_result = yield from self.inner_executor.collect_data(
+ trade_decision=_inner_trade_decision, level=level + 1
+ )
+ execute_result.extend(_inner_execute_result)
+
+ inner_order_indicators.append(
+ self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True)
+ )
+ else:
+ # do nothing and just step forward
+ sub_cal.step()
+
+ return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list}
+
+ def get_all_executors(self):
+ """get all executors, including self and inner_executor.get_all_executors()"""
+ return [self, *self.inner_executor.get_all_executors()]
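+
+# An illustrative config sketch (class/module names below are examples, not
+# fixed by this class): a NestedExecutor is usually built via
+# `init_instance_by_config`, nesting an inner executor and an inner strategy
+# that run at a higher frequency than the outer one.
+_EXAMPLE_NESTED_CONFIG = {
+    "class": "NestedExecutor",
+    "module_path": "qlib.backtest.executor",
+    "kwargs": {
+        "time_per_step": "day",
+        "inner_executor": {
+            "class": "SimulatorExecutor",
+            "module_path": "qlib.backtest.executor",
+            "kwargs": {"time_per_step": "30min"},
+        },
+        "inner_strategy": {
+            "class": "TWAPStrategy",  # hypothetical inner strategy
+            "module_path": "qlib.contrib.strategy.rule_strategy",
+        },
+    },
+}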
+
+
+class SimulatorExecutor(BaseExecutor):
+ """Executor that simulate the true market"""
+
+ # TODO: TT_SERIAL & TT_PARAL will be replaced by the feature `fix_pos` now.
+ # Please remove them in the future.
+
+ # available trade_types
+ TT_SERIAL = "serial"
+ ## The orders will be executed serially in a sequence
+ # In each trading step, it is possible that users sell instruments first and use the money to buy new instruments
+ TT_PARAL = "parallel"
+ ## The orders will be executed in parallel
+ # In each trading step, users can't sell instruments first and use the money to buy new instruments;
+ # trying to do so will fail due to insufficient cash
+
+ def __init__(
+ self,
+ time_per_step: str,
+ start_time: Union[str, pd.Timestamp] = None,
+ end_time: Union[str, pd.Timestamp] = None,
+ indicator_config: dict = {},
+ generate_portfolio_metrics: bool = False,
+ verbose: bool = False,
+ track_data: bool = False,
+ common_infra: CommonInfrastructure = None,
+ trade_type: str = TT_SERIAL,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ trade_type: str
+ please refer to the doc of `TT_SERIAL` & `TT_PARAL`
+ """
+ super(SimulatorExecutor, self).__init__(
+ time_per_step=time_per_step,
+ start_time=start_time,
+ end_time=end_time,
+ indicator_config=indicator_config,
+ generate_portfolio_metrics=generate_portfolio_metrics,
+ verbose=verbose,
+ track_data=track_data,
+ common_infra=common_infra,
+ **kwargs,
+ )
+
+ self.trade_type = trade_type
+
+ def _get_order_iterator(self, trade_decision: BaseTradeDecision) -> List[Order]:
+ """
+
+ Parameters
+ ----------
+ trade_decision : BaseTradeDecision
+ the trade decision given by the strategy
+
+ Returns
+ -------
+ List[Order]:
+ get a list of orders according to `self.trade_type`
+ """
+ orders = trade_decision.get_decision()
+
+ if self.trade_type == self.TT_SERIAL:
+ # Orders will be traded serially, in the sequence given by the decision
+ order_it = orders
+ elif self.trade_type == self.TT_PARAL:
+ # NOTE: !!!!!!!
+ # Assumption: there will not be orders in different trading directions in a single step of a strategy !!!!
+ # A parallel trading failure can be caused only by a conflict over money.
+ # Therefore, making the buy orders go first ensures that any such conflict actually happens.
+ # It is equivalent to parallel trading after sorting the orders by direction
+ order_it = sorted(orders, key=lambda order: -order.direction)
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+ return order_it
+
+ def _update_dealt_order_amount(self, order):
+ """update date and dealt order amount in the day."""
+
+ now_deal_day = self.trade_calendar.get_step_time()[0].floor(freq="D")
+ if self.deal_day is None or now_deal_day > self.deal_day:
+ self.dealt_order_amount = defaultdict(float)
+ self.deal_day = now_deal_day
+ self.dealt_order_amount[order.stock_id] += order.deal_amount
+
+ def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
+
+ trade_start_time, _ = self.trade_calendar.get_step_time()
+ execute_result = []
+
+ for order in self._get_order_iterator(trade_decision):
+ # execute the order.
+ # NOTE: The trade_account will be changed in this function
+ trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
+ order,
+ trade_account=self.trade_account,
+ dealt_order_amount=self.dealt_order_amount,
+ )
+ execute_result.append((order, trade_val, trade_cost, trade_price))
+ self._update_dealt_order_amount(order)
+ if self.verbose:
+ print(
+ "[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}, cash {:.2f}.".format(
+ trade_start_time,
+ "sell" if order.direction == Order.SELL else "buy",
+ order.stock_id,
+ trade_price,
+ order.amount,
+ order.deal_amount,
+ order.factor,
+ trade_val,
+ self.trade_account.get_cash(),
+ )
+ )
+ return execute_result, {"trade_info": execute_result}
diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py
new file mode 100644
index 0000000000..235bd054b1
--- /dev/null
+++ b/qlib/backtest/high_performance_ds.py
@@ -0,0 +1,629 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from functools import lru_cache
+import logging
+from typing import List, Text, Union, Callable, Iterable, Dict
+from collections import OrderedDict
+
+import inspect
+import pandas as pd
+import numpy as np
+
+from ..utils.index_data import IndexData, SingleData
+from ..utils.resam import resam_ts_data, ts_data_last
+from ..log import get_module_logger
+from ..utils.time import is_single_value, Freq
+import qlib.utils.index_data as idd
+
+
+class BaseQuote:
+ def __init__(self, quote_df: pd.DataFrame, freq):
+ self.logger = get_module_logger("online operator", level=logging.INFO)
+
+ def get_all_stock(self) -> Iterable:
+ """return all stock codes
+
+ Return
+ ------
+ Iterable
+ all stock codes
+ """
+
+ raise NotImplementedError(f"Please implement the `get_all_stock` method")
+
+ def get_data(
+ self,
+ stock_id: str,
+ start_time: Union[pd.Timestamp, str],
+ end_time: Union[pd.Timestamp, str],
+ field: str,
+ method: Union[str, None] = None,
+ ) -> Union[None, int, float, bool, IndexData]:
+ """get the specific field of stock data during start time and end_time,
+ and apply method to the data.
+
+ Example:
+ .. code-block::
+ $close $volume
+ instrument datetime
+ SH600000 2010-01-04 86.778313 16162960.0
+ 2010-01-05 87.433578 28117442.0
+ 2010-01-06 85.713585 23632884.0
+ 2010-01-07 83.788803 20813402.0
+ 2010-01-08 84.730675 16044853.0
+
+ SH600655 2010-01-04 2699.567383 158193.328125
+ 2010-01-08 2612.359619 77501.406250
+ 2010-01-11 2712.982422 160852.390625
+ 2010-01-12 2788.688232 164587.937500
+ 2010-01-13 2790.604004 145460.453125
+
+ this function is used for two cases:
+
+ 1. method is not None. It returns int/float/bool/None.
+ - It will return None in one case: when the method returns None
+
+ print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-06", field="$close", method="last"))
+
+ 85.713585
+
+ 2. method is None. It returns IndexData.
+ print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-06", field="$close", method=None))
+
+ IndexData([86.778313, 87.433578, 85.713585], [2010-01-04, 2010-01-05, 2010-01-06])
+
+ Parameters
+ ----------
+ stock_id: str
+ start_time : Union[pd.Timestamp, str]
+ closed start time for backtest
+ end_time : Union[pd.Timestamp, str]
+ closed end time for backtest
+ field : str
+ the columns of data to fetch
+ method : Union[str, None]
+ the method apply to data.
+ e.g. [None, "last", "all", "sum", "mean", "ts_data_last"]
+
+ Return
+ ----------
+ Union[None, int, float, bool, IndexData]
+ it will return None in the following cases
+ - There is no stock data meeting the query criterion from the data source.
+ - The `method` returns None
+ """
+
+ raise NotImplementedError(f"Please implement the `get_data` method")
+
+
+class PandasQuote(BaseQuote):
+ def __init__(self, quote_df: pd.DataFrame, freq):
+ super().__init__(quote_df=quote_df, freq=freq)
+ quote_dict = {}
+ for stock_id, stock_val in quote_df.groupby(level="instrument"):
+ quote_dict[stock_id] = stock_val.droplevel(level="instrument")
+ self.data = quote_dict
+
+ def get_all_stock(self):
+ return self.data.keys()
+
+ def get_data(self, stock_id, start_time, end_time, field, method=None):
+ if method == "ts_data_last":
+ method = ts_data_last
+ stock_data = resam_ts_data(self.data[stock_id][field], start_time, end_time, method=method)
+ if stock_data is None:
+ return None
+ elif isinstance(stock_data, (bool, np.bool_, int, float, np.number)):
+ return stock_data
+ elif isinstance(stock_data, pd.Series):
+ return idd.SingleData(stock_data)
+ else:
+ raise ValueError(f"stock data from resam_ts_data must be a number, pd.Series or pd.DataFrame")
+
+
+class NumpyQuote(BaseQuote):
+ def __init__(self, quote_df: pd.DataFrame, freq, region="cn"):
+ """NumpyQuote
+
+ Parameters
+ ----------
+ quote_df : pd.DataFrame
+ the init dataframe from qlib.
+ self.data : Dict(stock_id, IndexData.DataFrame)
+ """
+ super().__init__(quote_df=quote_df, freq=freq)
+ quote_dict = {}
+ for stock_id, stock_val in quote_df.groupby(level="instrument"):
+ quote_dict[stock_id] = idd.MultiData(stock_val.droplevel(level="instrument"))
+ quote_dict[stock_id].sort_index() # To support more flexible slicing, we must sort data first
+ self.data = quote_dict
+
+ n, unit = Freq.parse(freq)
+ if unit in Freq.SUPPORT_CAL_LIST:
+ self.freq = Freq.get_timedelta(1, unit)
+ else:
+ raise ValueError(f"{freq} is not supported in NumpyQuote")
+ self.region = region
+
+ def get_all_stock(self):
+ return self.data.keys()
+
+ @lru_cache(maxsize=512)
+ def get_data(self, stock_id, start_time, end_time, field, method=None):
+ # check stock id
+ if stock_id not in self.get_all_stock():
+ return None
+
+ # single data
+ # If the single-value case were not handled separately, the query would consume a lot of time.
+ if is_single_value(start_time, end_time, self.freq, self.region):
+ # this is a very special case.
+ # skip aggregating function to speed-up the query calculation
+ try:
+ return self.data[stock_id].loc[start_time, field]
+ except KeyError:
+ return None
+ else:
+ data = self.data[stock_id].loc[start_time:end_time, field]
+ if data.empty:
+ return None
+ if method is not None:
+ data = self._agg_data(data, method)
+ return data
+
+ def _agg_data(self, data: IndexData, method):
+ """Agg data by specific method."""
+ # FIXME: why not call the method of data directly?
+ if method == "sum":
+ return np.nansum(data)
+ elif method == "mean":
+ return np.nanmean(data)
+ elif method == "last":
+ # FIXME: I've never seen that this method was called.
+ # Please merge it with "ts_data_last"
+ return data[-1]
+ elif method == "all":
+ return data.all()
+ elif method == "ts_data_last":
+ valid_data = data.loc[~data.isna().data.astype(bool)]
+ if len(valid_data) == 0:
+ return None
+ else:
+ return valid_data.iloc[-1]
+ else:
+ raise ValueError(f"{method} is not supported")
+
+
+class BaseSingleMetric:
+ """
+ The data structure of the single metric.
+ The following methods are used for computing metrics in one indicator.
+ """
+
+ def __init__(self, metric: Union[dict, pd.Series]):
+ """Single data structure for each metric.
+
+ Parameters
+ ----------
+ metric : Union[dict, pd.Series]
+ keys/index is stock_id, value is the metric value.
+ for example:
+ SH600068 NaN
+ SH600079 1.0
+ SH600266 NaN
+ ...
+ SZ300692 NaN
+ SZ300719 NaN,
+ """
+ raise NotImplementedError(f"Please implement the `__init__` method")
+
+ def __add__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
+ raise NotImplementedError(f"Please implement the `__add__` method")
+
+ def __radd__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
+ return self + other
+
+ def __sub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
+ raise NotImplementedError(f"Please implement the `__sub__` method")
+
+ def __rsub__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
+ raise NotImplementedError(f"Please implement the `__rsub__` method")
+
+ def __mul__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
+ raise NotImplementedError(f"Please implement the `__mul__` method")
+
+ def __truediv__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
+ raise NotImplementedError(f"Please implement the `__truediv__` method")
+
+ def __eq__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
+ raise NotImplementedError(f"Please implement the `__eq__` method")
+
+ def __gt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
+ raise NotImplementedError(f"Please implement the `__gt__` method")
+
+ def __lt__(self, other: Union["BaseSingleMetric", int, float]) -> "BaseSingleMetric":
+ raise NotImplementedError(f"Please implement the `__lt__` method")
+
+ def __len__(self) -> int:
+ raise NotImplementedError(f"Please implement the `__len__` method")
+
+ def sum(self) -> float:
+ raise NotImplementedError(f"Please implement the `sum` method")
+
+ def mean(self) -> float:
+ raise NotImplementedError(f"Please implement the `mean` method")
+
+ def count(self) -> int:
+ """Return the count of the single metric, NaN is not included."""
+
+ raise NotImplementedError(f"Please implement the `count` method")
+
+ def abs(self) -> "BaseSingleMetric":
+ raise NotImplementedError(f"Please implement the `abs` method")
+
+ @property
+ def empty(self) -> bool:
+ """If metric is empty, return True."""
+
+ raise NotImplementedError(f"Please implement the `empty` method")
+
+ def add(self, other: "BaseSingleMetric", fill_value: float = None) -> "BaseSingleMetric":
+ """Replace np.NaN with fill_value in two metrics and add them."""
+
+ raise NotImplementedError(f"Please implement the `add` method")
+
+ def replace(self, replace_dict: dict) -> "BaseSingleMetric":
+ """Replace the value of metric according to replace_dict."""
+
+ raise NotImplementedError(f"Please implement the `replace` method")
+
+ def apply(self, func: dict) -> "BaseSingleMetric":
+ """Replace the value of metric with func(metric).
+ Currently, the func is only qlib/backtest/order/Order.parse_dir.
+ """
+
+ raise NotImplementedError(f"Please implement the 'apply' method")
+
+
+class BaseOrderIndicator:
+ """
+ The data structure of order indicator.
+ !!!NOTE: There are two ways to organize the data structure. Please choose a better way.
+ 1. One way is using BaseSingleMetric to represent each metric. For example, the data
+ structure of PandasOrderIndicator is Dict[str, PandasSingleMetric]. It uses
+ PandasSingleMetric based on pd.Series to represent each metric.
+ 2. The other way doesn't use BaseSingleMetric to represent each metric. The data
+ structure of the order indicator is a whole matrix. It means it is not necessary
+ to inherit from BaseSingleMetric.
+ """
+
+ def __init__(self, data):
+ self.data = data
+ self.logger = get_module_logger("online operator")
+
+ def assign(self, col: str, metric: Union[dict, pd.Series]):
+ """assign one metric.
+
+ Parameters
+ ----------
+ col : str
+ the metric name of one metric.
+ metric : Union[dict, pd.Series]
+ one metric with stock_id index, such as deal_amount, ffr, etc.
+ for example:
+ SH600068 NaN
+ SH600079 1.0
+ SH600266 NaN
+ ...
+ SZ300692 NaN
+ SZ300719 NaN,
+ """
+
+ raise NotImplementedError(f"Please implement the 'assign' method")
+
+ def transfer(self, func: Callable, new_col: str = None) -> Union[None, BaseSingleMetric]:
+ """compute new metric with existing metrics.
+
+ Parameters
+ ----------
+ func : Callable
+ the func of computing new metric.
+ the kwargs of func will be replaced with metric data by name in this function.
+ e.g.
+ def func(pa):
+ return (pa > 0).sum() / pa.count()
+ new_col : str, optional
+ New metric will be assigned in the data if new_col is not None, by default None.
+
+ Return
+ ----------
+ BaseSingleMetric
+ new metric.
+ """
+ func_sig = inspect.signature(func).parameters.keys()
+ func_kwargs = {sig: self.data[sig] for sig in func_sig}
+ tmp_metric = func(**func_kwargs)
+ if new_col is not None:
+ self.data[new_col] = tmp_metric
+ else:
+ return tmp_metric
+
+ def get_metric_series(self, metric: str) -> pd.Series:
+ """return the single metric with pd.Series format.
+
+ Parameters
+ ----------
+ metric : str
+ the metric name.
+
+ Return
+ ----------
+ pd.Series
+ the single metric.
+ If there is no metric name in the data, return pd.Series().
+ """
+
+ raise NotImplementedError(f"Please implement the 'get_metric_series' method")
+
+ def get_index_data(self, metric) -> SingleData:
+ """get one metric with the format of SingleData
+
+ Parameters
+ ----------
+ metric : str
+ the metric name.
+
+ Return
+ ------
+ SingleData
+ one metric with the format of SingleData
+ """
+
+ raise NotImplementedError(f"Please implement the 'get_index_data' method")
+
+ @staticmethod
+ def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None):
+ """sum indicators with the same metrics.
+ and assign to the order_indicator(BaseOrderIndicator).
+ NOTE: indicators could be a empty list when orders in lower level all fail.
+
+ Parameters
+ ----------
+ order_indicator : BaseOrderIndicator
+ the order indicator to assign.
+ indicators : List[BaseOrderIndicator]
+ the list of all inner indicators.
+ metrics : Union[str, List[str]]
+ all metrics that need to be summed.
+ fill_value : float, optional
+ fill np.NaN with value. By default None.
+ """
+
+ raise NotImplementedError(f"Please implement the 'sum_all_indicators' method")
+
+ def to_series(self) -> Dict[Text, pd.Series]:
+ """return the metrics as pandas series
+
+ for example: { "ffr":
+ SH600068 NaN
+ SH600079 1.0
+ SH600266 NaN
+ ...
+ SZ300692 NaN
+ SZ300719 NaN,
+ ...
+ }
+ """
+ raise NotImplementedError(f"Please implement the `to_series` method")
+
+
+class SingleMetric(BaseSingleMetric):
+ def __init__(self, metric):
+ self.metric = metric
+
+ def __add__(self, other):
+ if isinstance(other, (int, float)):
+ return self.__class__(self.metric + other)
+ elif isinstance(other, self.__class__):
+ return self.__class__(self.metric + other.metric)
+ else:
+ return NotImplemented
+
+ def __sub__(self, other):
+ if isinstance(other, (int, float)):
+ return self.__class__(self.metric - other)
+ elif isinstance(other, self.__class__):
+ return self.__class__(self.metric - other.metric)
+ else:
+ return NotImplemented
+
+ def __rsub__(self, other):
+ if isinstance(other, (int, float)):
+ return self.__class__(other - self.metric)
+ elif isinstance(other, self.__class__):
+ return self.__class__(other.metric - self.metric)
+ else:
+ return NotImplemented
+
+ def __mul__(self, other):
+ if isinstance(other, (int, float)):
+ return self.__class__(self.metric * other)
+ elif isinstance(other, self.__class__):
+ return self.__class__(self.metric * other.metric)
+ else:
+ return NotImplemented
+
+ def __truediv__(self, other):
+ if isinstance(other, (int, float)):
+ return self.__class__(self.metric / other)
+ elif isinstance(other, self.__class__):
+ return self.__class__(self.metric / other.metric)
+ else:
+ return NotImplemented
+
+ def __eq__(self, other):
+ if isinstance(other, (int, float)):
+ return self.__class__(self.metric == other)
+ elif isinstance(other, self.__class__):
+ return self.__class__(self.metric == other.metric)
+ else:
+ return NotImplemented
+
+ def __gt__(self, other):
+ if isinstance(other, (int, float)):
+ return self.__class__(self.metric > other)
+ elif isinstance(other, self.__class__):
+ return self.__class__(self.metric > other.metric)
+ else:
+ return NotImplemented
+
+ def __lt__(self, other):
+ if isinstance(other, (int, float)):
+ return self.__class__(self.metric < other)
+ elif isinstance(other, self.__class__):
+ return self.__class__(self.metric < other.metric)
+ else:
+ return NotImplemented
+
+ def __len__(self):
+ return len(self.metric)
+
+
+class PandasSingleMetric(SingleMetric):
+ """Each SingleMetric is based on pd.Series."""
+
+ def __init__(self, metric: Union[dict, pd.Series] = {}):
+ if isinstance(metric, dict):
+ self.metric = pd.Series(metric)
+ elif isinstance(metric, pd.Series):
+ self.metric = metric
+ else:
+ raise ValueError(f"metric must be dict or pd.Series")
+
+ def sum(self):
+ return self.metric.sum()
+
+ def mean(self):
+ return self.metric.mean()
+
+ def count(self):
+ return self.metric.count()
+
+ def abs(self):
+ return self.__class__(self.metric.abs())
+
+ @property
+ def empty(self):
+ return self.metric.empty
+
+ @property
+ def index(self):
+ return list(self.metric.index)
+
+ def add(self, other, fill_value=None):
+ return self.__class__(self.metric.add(other.metric, fill_value=fill_value))
+
+ def replace(self, replace_dict: dict):
+ return self.__class__(self.metric.replace(replace_dict))
+
+ def apply(self, func: Callable):
+ return self.__class__(self.metric.apply(func))
+
+ def reindex(self, index, fill_value):
+ return self.__class__(self.metric.reindex(index, fill_value=fill_value))
+
+ def __repr__(self):
+ return repr(self.metric)
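+
+# An illustrative sketch of metric arithmetic (made-up values): the operators
+# combine metrics element-wise and return new PandasSingleMetric objects.
+def _example_metric_arithmetic():
+    pa = PandasSingleMetric({"SH600000": 0.1, "SH600655": -0.2})
+    ffr = PandasSingleMetric({"SH600000": 1.0, "SH600655": 0.5})
+    # a fill-rate-weighted price advantage, reduced to a float
+    return (pa * ffr).sum() / ffr.sum()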
+
+
+class PandasOrderIndicator(BaseOrderIndicator):
+ """
+ The data structure is OrderedDict(str: PandasSingleMetric).
+ Each PandasSingleMetric based on pd.Series is one metric.
+ Str is the name of metric.
+ """
+
+ def __init__(self):
+ self.data: Dict[str, PandasSingleMetric] = OrderedDict()
+
+ def assign(self, col: str, metric: Union[dict, pd.Series]):
+ self.data[col] = PandasSingleMetric(metric)
+
+ def get_index_data(self, metric):
+ if metric in self.data:
+ return idd.SingleData(self.data[metric].metric)
+ else:
+ return idd.SingleData()
+
+ def get_metric_series(self, metric: str) -> pd.Series:
+ if metric in self.data:
+ return self.data[metric].metric
+ else:
+ return pd.Series()
+
+ def to_series(self):
+ return {k: v.metric for k, v in self.data.items()}
+
+ @staticmethod
+ def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
+ if isinstance(metrics, str):
+ metrics = [metrics]
+ for metric in metrics:
+ tmp_metric = PandasSingleMetric({})
+ for indicator in indicators:
+ tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)
+ order_indicator.assign(metric, tmp_metric.metric)
+
+ def __repr__(self):
+ return repr(self.data)
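+
+# An illustrative usage sketch (made-up values): `transfer` binds the
+# function's parameter names to the stored metrics of the same name.
+def _example_transfer():
+    oi = PandasOrderIndicator()
+    oi.assign("pa", {"SH600000": 0.1, "SH600655": -0.2})
+
+    def win_rate(pa):
+        return (pa > 0).sum() / pa.count()
+
+    # returns a plain value because no `new_col` is given
+    return oi.transfer(win_rate)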
+
+
+class NumpyOrderIndicator(BaseOrderIndicator):
+ """
+ The data structure is OrderedDict(str: SingleData).
+ Each idd.SingleData is one metric.
+ Str is the name of metric.
+ """
+
+ def __init__(self):
+ self.data: Dict[str, SingleData] = OrderedDict()
+
+ def assign(self, col: str, metric: dict):
+ self.data[col] = idd.SingleData(metric)
+
+ def get_index_data(self, metric):
+ if metric in self.data:
+ return self.data[metric]
+ else:
+ return idd.SingleData()
+
+ def get_metric_series(self, metric: str) -> pd.Series:
+ return self.data[metric].to_series()
+
+ def to_series(self) -> Dict[str, pd.Series]:
+ tmp_metric_dict = {}
+ for metric in self.data:
+ tmp_metric_dict[metric] = self.get_metric_series(metric)
+ return tmp_metric_dict
+
+ @staticmethod
+ def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value=0):
+ # NOTE: convert str to list first; otherwise `metrics[0]` below would index a single character
+ if isinstance(metrics, str):
+ metrics = [metrics]
+
+ # get all index(stock_id)
+ stocks = set()
+ for indicator in indicators:
+ # set(np.ndarray.tolist()) is faster than set(np.ndarray)
+ stocks = stocks | set(indicator.data[metrics[0]].index.tolist())
+ stocks = list(stocks)
+ stocks.sort()
+
+ # add metric by index
+ for metric in metrics:
+ order_indicator.data[metric] = idd.sum_by_index(
+ [indicator.data[metric] for indicator in indicators], stocks, fill_value
+ )
+
+ def __repr__(self):
+ return repr(self.data)
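+
+# An illustrative aggregation sketch: summing one metric across the indicators
+# of several inner executors into a fresh outer indicator (with the
+# string-to-list conversion above, a single metric name may be passed as str).
+def _example_sum(inner_indicators):
+    outer = NumpyOrderIndicator()
+    NumpyOrderIndicator.sum_all_indicators(outer, inner_indicators, "deal_amount", fill_value=0)
+    return outer.get_index_data("deal_amount")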
diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py
new file mode 100644
index 0000000000..2bfb208938
--- /dev/null
+++ b/qlib/backtest/position.py
@@ -0,0 +1,544 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+import copy
+import pathlib
+from typing import Dict, List, Union
+
+import pandas as pd
+from datetime import timedelta
+import numpy as np
+
+from .decision import Order
+from ..data.data import D
+
+
+class BasePosition:
+ """
+ The Position is designed to maintain the position like a dictionary
+ Please refer to the `Position` class for a concrete implementation
+ """
+
+ def __init__(self, cash=0.0, *args, **kwargs):
+ self._settle_type = self.ST_NO
+
+ def skip_update(self) -> bool:
+ """
+ Should we skip updating operation for this position
+ For example, updating is meaningless for InfPosition
+
+ Returns
+ -------
+ bool:
+ should we skip the updating operator
+ """
+ return False
+
+ def check_stock(self, stock_id: str) -> bool:
+ """
+ check whether the stock is in the position
+
+ Parameters
+ ----------
+ stock_id : str
+ the id of the stock
+
+ Returns
+ -------
+ bool:
+ whether the stock is in the position
+ """
+ raise NotImplementedError(f"Please implement the `check_stock` method")
+
+ def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
+ """
+ Parameters
+ ----------
+ order : Order
+ the order to update the position
+ trade_val : float
+ the trade value(money) of dealing results
+ cost : float
+ the trade cost of the dealing results
+ trade_price : float
+ the trade price of the dealing results
+ """
+ raise NotImplementedError(f"Please implement the `update_order` method")
+
+ def update_stock_price(self, stock_id, price: float):
+ """
+ Update the latest price of the stock
+ This is useful when clearing the balance at the end of each bar
+
+ Parameters
+ ----------
+ stock_id :
+ the id of the stock
+ price : float
+ the price to be updated
+ """
+ raise NotImplementedError(f"Please implement the `update stock price` method")
+
+ def calculate_stock_value(self) -> float:
+ """
+ calculate the value of all assets except cash in the position
+
+ Returns
+ -------
+ float:
+ the value(money) of all the stock
+ """
+ raise NotImplementedError(f"Please implement the `calculate_stock_value` method")
+
+ def get_stock_list(self) -> List:
+ """
+ Get the list of stocks in the position.
+ """
+ raise NotImplementedError(f"Please implement the `get_stock_list` method")
+
+ def get_stock_price(self, code) -> float:
+ """
+ get the latest price of the stock
+
+ Parameters
+ ----------
+ code :
+ the code of the stock
+ """
+ raise NotImplementedError(f"Please implement the `get_stock_price` method")
+
+ def get_stock_amount(self, code) -> float:
+ """
+ get the amount of the stock
+
+ Parameters
+ ----------
+ code :
+ the code of the stock
+
+ Returns
+ -------
+ float:
+ the amount of the stock
+ """
+ raise NotImplementedError(f"Please implement the `get_stock_amount` method")
+
+ def get_cash(self, include_settle: bool = False) -> float:
+ """
+
+ Returns
+ -------
+ float:
+ the available(tradable) cash in position
+ include_settle:
+ will the unsettled(delayed) cash included
+ Default: not include those unavailable cash
+ """
+ raise NotImplementedError(f"Please implement the `get_cash` method")
+
+ def get_stock_amount_dict(self) -> Dict:
+ """
+ generate stock amount dict {stock_id : amount of stock}
+
+ Returns
+ -------
+ Dict:
+ {stock_id : amount of stock}
+ """
+ raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method")
+
+ def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
+ """
+ generate stock weight dict {stock_id : value weight of stock in the position}
+ it is meaningful in the beginning or the end of each trade step
+ - During execution of each trading step, the weight may not be consistent with the portfolio value
+
+ Parameters
+ ----------
+ only_stock : bool
+ If only_stock=True, the weight of each stock in total stock will be returned
+ If only_stock=False, the weight of each stock in total assets(stock + cash) will be returned
+
+ Returns
+ -------
+ Dict:
+ {stock_id : value weight of stock in the position}
+ """
+ raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method")
+
+ def add_count_all(self, bar):
+ """
+ Will be called at the end of each bar on each level
+
+ Parameters
+ ----------
+ bar :
+ The level to be updated
+ """
+ raise NotImplementedError(f"Please implement the `add_count_all` method")
+
+ def update_weight_all(self):
+ """
+ Updating the position weight;
+
+ # TODO: this function is a little weird. The weight data in the position is in a wrong state after dealing order
+ # and before updating weight.
+
+ Parameters
+ ----------
+ bar :
+ The level to be updated
+ """
+ raise NotImplementedError(f"Please implement the `add_count_all` method")
+
+ ST_CASH = "cash"
+ ST_NO = None
+
+ def settle_start(self, settle_type: str):
+ """
+ settlement start
+ It will act like start and commit a transaction
+
+ Parameters
+ ----------
+ settle_type : str
+ Should we delay the settlement in each execution (each execution will make the executor a step forward)
+ - "cash": make the cash settlement delayed.
+ - The cash you get can't be used in current step (e.g. you can't sell a stock to get cash to buy another
+ stock)
+ - None: no settlement mechanism
+ - TODO: other assets will be supported in the future.
+ """
+ raise NotImplementedError(f"Please implement the `settle_conf` method")
+
+ def settle_commit(self):
+ """
+ settlement commit
+
+ Parameters
+ ----------
+ settle_type : str
+ please refer to the documents of Executor
+ """
+ raise NotImplementedError(f"Please implement the `settle_commit` method")
+
+
+class Position(BasePosition):
+ """Position
+
+ current state of position
+ a typical example is :{
+ <instrument_id>: {
+ 'count': <how many days the security has been held>,
+ 'amount': <the amount of the security>,
+ 'price': <the close price of the security in the latest trading day>,
+ 'weight': <the security weight of the total position value>,
+ },
+ }
+ """
+
+ def __init__(self, cash: float = 0, position_dict: Dict[str, Dict[str, float]] = {}):
+ """Init position by cash and position_dict.
+
+ Parameters
+ ----------
+ cash : float, optional
+ initial cash in account, by default 0
+ position_dict : Dict[
+ stock_id,
+ Union[
+ int, # it is equal to {"amount": int}
+ {"amount": int, "price"(optional): float},
+ ]
+ ]
+ initial stocks with parameters amount and price,
+ if there is no price key in the dict of stocks, it will be filled by fill_stock_value.
+ by default {}.
+ """
+ super().__init__()
+
+ # NOTE: The position dict must be copied!!!
+ # Otherwise the initial position_dict passed in from outside will be modified in place
+ self.init_cash = cash
+ self.position = position_dict.copy()
+ for stock in self.position:
+ if isinstance(self.position[stock], int):
+ self.position[stock] = {"amount": self.position[stock]}
+ self.position["cash"] = cash
+
+ # If the stock price information is missing, the account value will not be calculated temporarily
+ try:
+ self.position["now_account_value"] = self.calculate_value()
+ except KeyError:
+ pass
+
+ def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30):
+ """fill the stock value by the close price of latest last_days from qlib.
+
+ Parameters
+ ----------
+ start_time :
+ the start time of backtest.
+ freq : str
+ the frequency of the close price data to query.
+ last_days : int, optional
+ the days to get the latest close price, by default 30.
+ """
+ stock_list = []
+ for stock in self.position:
+ if not isinstance(self.position[stock], dict):
+ continue
+ if ("price" not in self.position[stock]) or (self.position[stock]["price"] is None):
+ stock_list.append(stock)
+
+ if len(stock_list) == 0:
+ return
+
+ start_time = pd.Timestamp(start_time)
+ # note that start time is 2020-01-01 00:00:00 if raw start time is "2020-01-01"
+ price_end_time = start_time
+ price_start_time = start_time - timedelta(days=last_days)
+ price_df = D.features(
+ stock_list, ["$close"], price_start_time, price_end_time, freq=freq, disk_cache=True
+ ).dropna()
+ price_dict = price_df.groupby(["instrument"]).tail(1).reset_index(level=1, drop=True)["$close"].to_dict()
+
+ if len(price_dict) < len(stock_list):
+ lack_stock = set(stock_list) - set(price_dict)
+ raise ValueError(f"{lack_stock} doesn't have close price in qlib in the latest {last_days} days")
+
+ for stock in stock_list:
+ self.position[stock]["price"] = price_dict[stock]
+ self.position["now_account_value"] = self.calculate_value()
+
+ def _init_stock(self, stock_id, amount, price=None):
+ """
+ initialize the stock in the current position
+
+ Parameters
+ ----------
+ stock_id :
+ the id of the stock
+ amount : float
+ the amount of the stock
+ price :
+ the price when buying the init stock
+ """
+ self.position[stock_id] = {}
+ self.position[stock_id]["amount"] = amount
+ self.position[stock_id]["price"] = price
+ self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
+
+ def _buy_stock(self, stock_id, trade_val, cost, trade_price):
+ trade_amount = trade_val / trade_price
+ if stock_id not in self.position:
+ self._init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
+ else:
+ # exist, add amount
+ self.position[stock_id]["amount"] += trade_amount
+
+ self.position["cash"] -= trade_val + cost
+
+ def _sell_stock(self, stock_id, trade_val, cost, trade_price):
+ trade_amount = trade_val / trade_price
+ if stock_id not in self.position:
+ raise KeyError("{} not in current position".format(stock_id))
+ else:
+ # decrease the amount of stock
+ self.position[stock_id]["amount"] -= trade_amount
+ # check if to delete
+ if self.position[stock_id]["amount"] < -1e-5:
+ raise ValueError(
+ "only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
+ )
+ elif abs(self.position[stock_id]["amount"]) <= 1e-5:
+ self._del_stock(stock_id)
+
+ new_cash = trade_val - cost
+ if self._settle_type == self.ST_CASH:
+ self.position["cash_delay"] += new_cash
+ elif self._settle_type == self.ST_NO:
+ self.position["cash"] += new_cash
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
+ def _del_stock(self, stock_id):
+ del self.position[stock_id]
+
+ def check_stock(self, stock_id):
+ return stock_id in self.position
+
+ def update_order(self, order, trade_val, cost, trade_price):
+ # handle order, order is a order class, defined in exchange.py
+ if order.direction == Order.BUY:
+ # BUY
+ self._buy_stock(order.stock_id, trade_val, cost, trade_price)
+ elif order.direction == Order.SELL:
+ # SELL
+ self._sell_stock(order.stock_id, trade_val, cost, trade_price)
+ else:
+ raise NotImplementedError("do not support order direction {}".format(order.direction))
+
+ def update_stock_price(self, stock_id, price):
+ self.position[stock_id]["price"] = price
+
+ def update_stock_count(self, stock_id, bar, count):
+ self.position[stock_id][f"count_{bar}"] = count
+
+ def update_stock_weight(self, stock_id, weight):
+ self.position[stock_id]["weight"] = weight
+
+ def calculate_stock_value(self):
+ stock_list = self.get_stock_list()
+ value = 0
+ for stock_id in stock_list:
+ value += self.position[stock_id]["amount"] * self.position[stock_id]["price"]
+ return value
+
+ def calculate_value(self):
+ value = self.calculate_stock_value()
+ value += self.position["cash"] + self.position.get("cash_delay", 0.0)
+ return value
+
+ def get_stock_list(self):
+ stock_list = list(set(self.position.keys()) - {"cash", "now_account_value", "cash_delay"})
+ return stock_list
+
+ def get_stock_price(self, code):
+ return self.position[code]["price"]
+
+ def get_stock_amount(self, code):
+ return self.position[code]["amount"] if code in self.position else 0
+
+ def get_stock_count(self, code, bar):
+ """the days the account has been hold, it may be used in some special strategies"""
+ if f"count_{bar}" in self.position[code]:
+ return self.position[code][f"count_{bar}"]
+ else:
+ return 0
+
+ def get_stock_weight(self, code):
+ return self.position[code]["weight"]
+
+ def get_cash(self, include_settle=False):
+ cash = self.position["cash"]
+ if include_settle:
+ cash += self.position.get("cash_delay", 0.0)
+ return cash
+
+ def get_stock_amount_dict(self):
+ """generate stock amount dict {stock_id : amount of stock}"""
+ d = {}
+ stock_list = self.get_stock_list()
+ for stock_code in stock_list:
+ d[stock_code] = self.get_stock_amount(code=stock_code)
+ return d
+
+ def get_stock_weight_dict(self, only_stock=False):
+ """get_stock_weight_dict
+ generate stock weight dict {stock_id : value weight of stock in the position}
+ it is meaningful in the beginning or the end of each trade date
+
+ :param only_stock: If only_stock=True, the weight of each stock in total stock will be returned
+ If only_stock=False, the weight of each stock in total assets(stock + cash) will be returned
+ """
+ if only_stock:
+ position_value = self.calculate_stock_value()
+ else:
+ position_value = self.calculate_value()
+ d = {}
+ stock_list = self.get_stock_list()
+ for stock_code in stock_list:
+ d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value
+ return d
+
+ def add_count_all(self, bar):
+ stock_list = self.get_stock_list()
+ for code in stock_list:
+ if f"count_{bar}" in self.position[code]:
+ self.position[code][f"count_{bar}"] += 1
+ else:
+ self.position[code][f"count_{bar}"] = 1
+
+ def update_weight_all(self):
+ weight_dict = self.get_stock_weight_dict()
+ for stock_code, weight in weight_dict.items():
+ self.update_stock_weight(stock_code, weight)
+
+ def settle_start(self, settle_type):
+ assert self._settle_type == self.ST_NO, "Currently, settlement can't be nested!!!!!"
+ self._settle_type = settle_type
+ if settle_type == self.ST_CASH:
+ self.position["cash_delay"] = 0.0
+
+ def settle_commit(self):
+ if self._settle_type != self.ST_NO:
+ if self._settle_type == self.ST_CASH:
+ self.position["cash"] += self.position["cash_delay"]
+ del self.position["cash_delay"]
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+ self._settle_type = self.ST_NO
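+
+# A minimal, self-contained example of the dict-like position state above
+# (made-up numbers):
+def _example_position():
+    pos = Position(cash=10000.0, position_dict={"SH600000": {"amount": 1000, "price": 10.0}})
+    assert pos.get_stock_amount("SH600000") == 1000
+    assert pos.calculate_value() == 20000.0  # 1000 * 10.0 + 10000 cash
+    return pos.get_stock_weight_dict()       # {"SH600000": 0.5}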
+
+
+class InfPosition(BasePosition):
+ """
+ Position with infinite cash and amount.
+
+ This is useful for generating random orders.
+ """
+
+ def skip_update(self) -> bool:
+ """Updating state is meaningless for InfPosition"""
+ return True
+
+ def check_stock(self, stock_id: str) -> bool:
+ # InfPosition always holds any stock
+ return True
+
+ def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
+ pass
+
+ def update_stock_price(self, stock_id, price: float):
+ pass
+
+ def calculate_stock_value(self) -> float:
+ """
+ Returns
+ -------
+ float:
+ infinite stock value
+ """
+ return np.inf
+
+ def get_stock_list(self) -> List:
+ raise NotImplementedError(f"InfPosition doesn't support stock list position")
+
+ def get_stock_price(self, code) -> float:
+ """the price of the inf position is meaningless"""
+ return np.nan
+
+ def get_stock_amount(self, code) -> float:
+ return np.inf
+
+ def get_cash(self, include_settle=False) -> float:
+ return np.inf
+
+ def get_stock_amount_dict(self) -> Dict:
+ raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
+
+ def get_stock_weight_dict(self, only_stock: bool) -> Dict:
+ raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
+
+ def add_count_all(self, bar):
+ raise NotImplementedError(f"InfPosition doesn't support add_count_all")
+
+ def update_weight_all(self):
+ raise NotImplementedError(f"InfPosition doesn't support update_weight_all")
+
+ def settle_start(self, settle_type: str):
+ pass
+
+ def settle_commit(self):
+ pass
diff --git a/qlib/contrib/backtest/profit_attribution.py b/qlib/backtest/profit_attribution.py
similarity index 98%
rename from qlib/contrib/backtest/profit_attribution.py
rename to qlib/backtest/profit_attribution.py
index df5dd965d4..895f5c78bb 100644
--- a/qlib/contrib/backtest/profit_attribution.py
+++ b/qlib/backtest/profit_attribution.py
@@ -1,12 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-
+"""
+This module is not well maintained.
+"""
import numpy as np
import pandas as pd
from .position import Position
-from ...data import D
-from ...config import C
+from ..data import D
+from ..config import C
import datetime
from pathlib import Path
@@ -35,7 +37,7 @@ def get_benchmark_weight(
"""
if not path:
- path = Path(C.dpm.get_data_path(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
+ path = Path(C.dpm.get_data_uri(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv"
# TODO: the storage of weights should be implemented in a more elegant way
# TODO: The benchmark is not consistent with the filename in instruments.
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"])
diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py
new file mode 100644
index 0000000000..03fb85344c
--- /dev/null
+++ b/qlib/backtest/report.py
@@ -0,0 +1,617 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+from collections import OrderedDict
+import pathlib
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
+import pandas as pd
+
+from qlib.backtest.exchange import Exchange
+from .decision import IdxTradeRange
+from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir
+from qlib.backtest.utils import TradeCalendarManager
+from .high_performance_ds import BaseOrderIndicator, PandasOrderIndicator, NumpyOrderIndicator, SingleMetric
+from ..data import D
+from ..tests.config import CSI300_BENCH
+from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
+import qlib.utils.index_data as idd
+
+
+class PortfolioMetrics:
+ """
+ Motivation:
+ PortfolioMetrics is for supporting portfolio related metrics.
+
+ Implementation:
+
+ daily portfolio metrics of the account
+ contain those followings: return, cost, turnover, account, cash, bench, value
+ For each step(bar/day/minute), each column represents
+ - return: the return of the portfolio generated by strategy **without transaction fee**.
+ - cost: the transaction fee and slippage.
+ - account: the total value of assets(cash and securities are both included) in user account based on the close price of each step.
+ - cash: the amount of cash in user's account.
+ - bench: the return of the benchmark
+ - value: the total value of securities/stocks/instruments (cash is excluded).
+
+ """
+
+ def __init__(self, freq: str = "day", benchmark_config: dict = {}):
+ """
+ Parameters
+ ----------
+ freq : str
+ frequency of trading bar, used for updating hold count of trading bar
+ benchmark_config : dict
+ config of benchmark, which may include the following arguments:
+ - benchmark : Union[str, list, pd.Series]
+ - If `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
+ example:
+ print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
+ 2017-01-04 0.011693
+ 2017-01-05 0.000721
+ 2017-01-06 -0.004322
+ 2017-01-09 0.006874
+ 2017-01-10 -0.003350
+ - If `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
+ - If `benchmark` is str, will use the daily change as the 'bench'.
+ benchmark code; the default is SH000300 (CSI300)
+ - start_time : Union[str, pd.Timestamp], optional
+ - If `benchmark` is pd.Series, it will be ignored
+ - Else, it represents the start time of the benchmark, by default None
+ - end_time : Union[str, pd.Timestamp], optional
+ - If `benchmark` is pd.Series, it will be ignored
+ - Else, it represents the end time of the benchmark, by default None
+
+ """
+
+ self.init_vars()
+ self.init_bench(freq=freq, benchmark_config=benchmark_config)
+
+ def init_vars(self):
+ self.accounts = OrderedDict() # account position value for each trade time
+ self.returns = OrderedDict() # daily return rate for each trade time
+ self.total_turnovers = OrderedDict() # total turnover for each trade time
+ self.turnovers = OrderedDict() # turnover for each trade time
+ self.total_costs = OrderedDict() # total trade cost for each trade time
+ self.costs = OrderedDict() # trade cost rate for each trade time
+ self.values = OrderedDict() # value for each trade time
+ self.cashes = OrderedDict()
+ self.benches = OrderedDict()
+ self.latest_pm_time = None # pd.Timestamp
+
+ def init_bench(self, freq=None, benchmark_config=None):
+ if freq is not None:
+ self.freq = freq
+ self.benchmark_config = benchmark_config
+ self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
+
+ def _cal_benchmark(self, benchmark_config, freq):
+ if benchmark_config is None:
+ return None
+ benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
+ if benchmark is None:
+ return None
+
+ if isinstance(benchmark, pd.Series):
+ return benchmark
+ else:
+ start_time = benchmark_config.get("start_time", None)
+ end_time = benchmark_config.get("end_time", None)
+
+ if freq is None:
+ raise ValueError("benchmark freq can't be None!")
+ _codes = benchmark if isinstance(benchmark, (list, dict)) else [benchmark]
+ fields = ["$close/Ref($close,1)-1"]
+ _temp_result, _ = get_higher_eq_freq_feature(_codes, fields, start_time, end_time, freq=freq)
+ if len(_temp_result) == 0:
+ raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
+ return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
+
+ def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
+ if self.bench is None:
+ return None
+
+ def cal_change(x):
+ return (x + 1).prod()
+
+ _ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change)
+ return 0.0 if _ret is None else _ret - 1
+
+ def is_empty(self):
+ return len(self.accounts) == 0
+
+ def get_latest_date(self):
+ return self.latest_pm_time
+
+ def get_latest_account_value(self):
+ return self.accounts[self.latest_pm_time]
+
+ def get_latest_total_cost(self):
+ return self.total_costs[self.latest_pm_time]
+
+ def get_latest_total_turnover(self):
+ return self.total_turnovers[self.latest_pm_time]
+
+ def update_portfolio_metrics_record(
+ self,
+ trade_start_time=None,
+ trade_end_time=None,
+ account_value=None,
+ cash=None,
+ return_rate=None,
+ total_turnover=None,
+ turnover_rate=None,
+ total_cost=None,
+ cost_rate=None,
+ stock_value=None,
+ bench_value=None,
+ ):
+ # check data
+ if None in [
+ trade_start_time,
+ account_value,
+ cash,
+ return_rate,
+ total_turnover,
+ turnover_rate,
+ total_cost,
+ cost_rate,
+ stock_value,
+ ]:
+ raise ValueError(
+ "None in [trade_start_time, account_value, cash, return_rate, total_turnover, turnover_rate, total_cost, cost_rate, stock_value]"
+ )
+
+ if trade_end_time is None and bench_value is None:
+ raise ValueError("Both trade_end_time and bench_value is None, benchmark is not usable.")
+ elif bench_value is None:
+ bench_value = self._sample_benchmark(self.bench, trade_start_time, trade_end_time)
+
+ # update pm data
+ self.accounts[trade_start_time] = account_value
+ self.returns[trade_start_time] = return_rate
+ self.total_turnovers[trade_start_time] = total_turnover
+ self.turnovers[trade_start_time] = turnover_rate
+ self.total_costs[trade_start_time] = total_cost
+ self.costs[trade_start_time] = cost_rate
+ self.values[trade_start_time] = stock_value
+ self.cashes[trade_start_time] = cash
+ self.benches[trade_start_time] = bench_value
+ # update pm
+ self.latest_pm_time = trade_start_time
+ # finish pm update in each step
+
+ def generate_portfolio_metrics_dataframe(self):
+ pm = pd.DataFrame()
+ pm["account"] = pd.Series(self.accounts)
+ pm["return"] = pd.Series(self.returns)
+ pm["total_turnover"] = pd.Series(self.total_turnovers)
+ pm["turnover"] = pd.Series(self.turnovers)
+ pm["total_cost"] = pd.Series(self.total_costs)
+ pm["cost"] = pd.Series(self.costs)
+ pm["value"] = pd.Series(self.values)
+ pm["cash"] = pd.Series(self.cashes)
+ pm["bench"] = pd.Series(self.benches)
+ pm.index.name = "datetime"
+ return pm
+
+ def save_portfolio_metrics(self, path):
+ r = self.generate_portfolio_metrics_dataframe()
+ r.to_csv(path)
+
+ def load_portfolio_metrics(self, path):
+ """load pm from a file
+ should have format like
+ columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench']
+ :param
+ path: str/ pathlib.Path()
+ """
+ path = pathlib.Path(path)
+ r = pd.read_csv(path, index_col=0) # let pandas manage the file handle
+ r.index = pd.DatetimeIndex(r.index)
+
+ index = r.index
+ self.init_vars()
+ for trade_start_time in index:
+ self.update_portfolio_metrics_record(
+ trade_start_time=trade_start_time,
+ account_value=r.loc[trade_start_time]["account"],
+ cash=r.loc[trade_start_time]["cash"],
+ return_rate=r.loc[trade_start_time]["return"],
+ total_turnover=r.loc[trade_start_time]["total_turnover"],
+ turnover_rate=r.loc[trade_start_time]["turnover"],
+ total_cost=r.loc[trade_start_time]["total_cost"],
+ cost_rate=r.loc[trade_start_time]["cost"],
+ stock_value=r.loc[trade_start_time]["value"],
+ bench_value=r.loc[trade_start_time]["bench"],
+ )
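+
+# An illustrative sketch: recording one step of portfolio metrics by hand
+# (made-up numbers; no benchmark is configured here):
+def _example_portfolio_metrics():
+    pm = PortfolioMetrics(freq="day", benchmark_config={"benchmark": None})
+    pm.update_portfolio_metrics_record(
+        trade_start_time=pd.Timestamp("2020-01-02"),
+        account_value=1000000.0,
+        cash=200000.0,
+        return_rate=0.001,
+        total_turnover=50000.0,
+        turnover_rate=0.05,
+        total_cost=25.0,
+        cost_rate=0.000025,
+        stock_value=800000.0,
+        bench_value=0.0,
+    )
+    return pm.generate_portfolio_metrics_dataframe()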
+
+
+class Indicator:
+ """
+ `Indicator` is implemented in an aggregate way.
+ All the metrics are calculated aggregately.
+ All the metrics are calculated for a separate stock, in a specific step, on a specific level.
+
+ | indicator | desc. |
+ |--------------+--------------------------------------------------------------|
+ | amount | the *target* amount given by the outer strategy |
+ | deal_amount | the real deal amount |
+ | inner_amount | the total *target* amount of inner strategy |
+ | trade_price | the average deal price |
+ | trade_value | the total trade value |
+ | trade_cost | the total trade cost (the base price needs direction) |
+ | trade_dir | the trading direction |
+ | ffr | full fill rate |
+ | pa | price advantage |
+ | pos | win rate |
+ | base_price | the price of baseline |
+ | base_volume | the volume of baseline (for weighted aggregating base_price) |
+
+ **NOTE**:
+ The `base_price` and `base_volume` can't be NaN when there is no trading in that step. Otherwise
+ aggregating gets wrong results.
+
+ So `base_price` will not be calculated in an aggregate way!!
+
+ """
+
+ def __init__(self, order_indicator_cls=NumpyOrderIndicator):
+ self.order_indicator_cls = order_indicator_cls
+
+ # order indicator is metrics for a single order for a specific step
+ self.order_indicator_his = OrderedDict()
+ self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
+
+ # trade indicator is metrics for all orders for a specific step
+ self.trade_indicator_his = OrderedDict()
+ self.trade_indicator: Dict[str, float] = OrderedDict()
+
+ self._trade_calendar = None
+
+ # def reset(self, trade_calendar: TradeCalendarManager):
+ def reset(self):
+ self.order_indicator: BaseOrderIndicator = self.order_indicator_cls()
+ self.trade_indicator = OrderedDict()
+ # self._trade_calendar = trade_calendar
+
+ def record(self, trade_start_time):
+ self.order_indicator_his[trade_start_time] = self.get_order_indicator()
+ self.trade_indicator_his[trade_start_time] = self.get_trade_indicator()
+
+ def _update_order_trade_info(self, trade_info: list):
+ amount = dict()
+ deal_amount = dict()
+ trade_price = dict()
+ trade_value = dict()
+ trade_cost = dict()
+ trade_dir = dict()
+ pa = dict()
+
+ for order, _trade_val, _trade_cost, _trade_price in trade_info:
+ amount[order.stock_id] = order.amount_delta
+ deal_amount[order.stock_id] = order.deal_amount_delta
+ trade_price[order.stock_id] = _trade_price
+ trade_value[order.stock_id] = _trade_val * order.sign
+ trade_cost[order.stock_id] = _trade_cost
+ trade_dir[order.stock_id] = order.direction
+            # The PA in the innermost layer is meaningless
+ pa[order.stock_id] = 0
+
+ self.order_indicator.assign("amount", amount)
+ self.order_indicator.assign("inner_amount", amount)
+ self.order_indicator.assign("deal_amount", deal_amount)
+        # NOTE: trade_price and the baseline price will be the same at the lowest level
+ self.order_indicator.assign("trade_price", trade_price)
+ self.order_indicator.assign("trade_value", trade_value)
+ self.order_indicator.assign("trade_cost", trade_cost)
+ self.order_indicator.assign("trade_dir", trade_dir)
+ self.order_indicator.assign("pa", pa)
+
+ def _update_order_fulfill_rate(self):
+ def func(deal_amount, amount):
+ # deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0.
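+            # e.g. amount = {"SH600000": 100}, deal_amount = {"SH600000": 60}
+            # (hypothetical values) gives ffr = 0.6; a missing deal_amount gives 0.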
+ tmp_deal_amount = deal_amount.reindex(amount.index, 0)
+ tmp_deal_amount = tmp_deal_amount.replace({np.NaN: 0})
+ return tmp_deal_amount / amount
+
+ self.order_indicator.transfer(func, "ffr")
+
+ def update_order_indicators(self, trade_info: list):
+ self._update_order_trade_info(trade_info=trade_info)
+ self._update_order_fulfill_rate()
+
+ def _agg_order_trade_info(self, inner_order_indicators: List[Dict[str, pd.Series]]):
+        # convert each inner indicator's average trade_price into a total trade
+        # value (deal_amount * price) so it can be summed, then divide back below.
+ def trade_amount_func(deal_amount, trade_price):
+ return deal_amount * trade_price
+
+ for indicator in inner_order_indicators:
+ indicator.transfer(trade_amount_func, "trade_price")
+
+ # sum inner order indicators with same metric.
+ all_metric = ["inner_amount", "deal_amount", "trade_price", "trade_value", "trade_cost", "trade_dir"]
+ self.order_indicator_cls.sum_all_indicators(
+ self.order_indicator, inner_order_indicators, all_metric, fill_value=0
+ )
+
+ def func(trade_price, deal_amount):
+ # trade_price is np.NaN instead of inf when deal_amount is zero.
+ tmp_deal_amount = deal_amount.replace({0: np.NaN})
+ return trade_price / tmp_deal_amount
+
+ self.order_indicator.transfer(func, "trade_price")
+
+ def func_apply(trade_dir):
+ return trade_dir.apply(Order.parse_dir)
+
+ self.order_indicator.transfer(func_apply, "trade_dir")
+
+ def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision):
+        # NOTE: these indicators are designed for order execution, so the target
+        # amount here comes from the decision of the outer strategy.
+ decision: List[Order] = outer_trade_decision.get_decision()
+ if len(decision) == 0:
+ self.order_indicator.assign("amount", {})
+ else:
+ self.order_indicator.assign("amount", {order.stock_id: order.amount_delta for order in decision})
+
+ def _get_base_vol_pri(
+ self,
+ inst: str,
+ trade_start_time: pd.Timestamp,
+ trade_end_time: pd.Timestamp,
+ direction: OrderDir,
+ decision: BaseTradeDecision,
+ trade_exchange: Exchange,
+ pa_config: dict = {},
+ ):
+ """
+ Get the base volume and price information
+ All the base price values are rooted from this function
+ """
+
+ agg = pa_config.get("agg", "twap").lower()
+ price = pa_config.get("price", "deal_price").lower()
+
+ if decision.trade_range is not None:
+ trade_start_time, trade_end_time = decision.trade_range.clip_time_range(
+ start_time=trade_start_time, end_time=trade_end_time
+ )
+
+ if price == "deal_price":
+ price_s = trade_exchange.get_deal_price(
+ inst, trade_start_time, trade_end_time, direction=direction, method=None
+ )
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
+ # if there is no stock data during the time period
+ if price_s is None:
+ return None, None
+
+ if isinstance(price_s, (int, float, np.number)):
+ price_s = idd.SingleData(price_s, [trade_start_time])
+ elif isinstance(price_s, idd.SingleData):
+ pass
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
+        # NOTE: there are some zeros in the trading price. These cases are known
+        # to be meaningless; to align with the previous logic, remove them.
+        # remove zero and negative values.
+        price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)]
+        # NOTE: ~(price_s < 1e-08) is different from price_s >= 1e-8
+        # ~(np.NaN < 1e-8) -> ~(False) -> True
+
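+        # Worked example with hypothetical bars: price_s = [10, 11, 12];
+        # "twap" weights each bar equally  -> base_price = 11,    base_volume = 3;
+        # "vwap" with volume_s = [1, 1, 2] -> base_price = 11.25, base_volume = 4.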
+ if agg == "vwap":
+ volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None)
+ if isinstance(volume_s, (int, float, np.number)):
+ volume_s = idd.SingleData(volume_s, [trade_start_time])
+ volume_s = volume_s.reindex(price_s.index)
+ elif agg == "twap":
+ volume_s = idd.SingleData(1, price_s.index)
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
+ base_volume = volume_s.sum()
+ base_price = (price_s * volume_s).sum() / base_volume
+ return base_price, base_volume
+
+ def _agg_base_price(
+ self,
+ inner_order_indicators: List[Dict[str, Union[SingleMetric, idd.SingleData]]],
+ decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
+ trade_exchange: Exchange,
+ pa_config: dict = {},
+ ):
+ """
+        # NOTE: strong assumption!
+        # The correctness of the base_price relies on the **same** exchange being
+        # used across all levels.
+
+ Parameters
+ ----------
+ inner_order_indicators : List[Dict[str, pd.Series]]
+ the indicators of account of inner executor
+ decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
+ a list of decisions according to inner_order_indicators
+ trade_exchange : Exchange
+ for retrieving trading price
+ pa_config : dict
+ For example
+ {
+ "agg": "twap", # "vwap"
+ "price": "$close", # TODO: this is not supported now!!!!!
+ # default to use deal price of the exchange
+ }
+ """
+
+        # TODO: I think there is potential for optimization here
+ trade_dir = self.order_indicator.get_index_data("trade_dir")
+ if len(trade_dir) > 0:
+ bp_all, bv_all = [], []
+ for oi, (dec, start, end) in zip(inner_order_indicators, decision_list):
+ bp_s = oi.get_index_data("base_price").reindex(trade_dir.index)
+ bv_s = oi.get_index_data("base_volume").reindex(trade_dir.index)
+
+ bp_new, bv_new = {}, {}
+ for pr, v, (inst, direction) in zip(bp_s.data, bv_s.data, zip(trade_dir.index, trade_dir.data)):
+ if np.isnan(pr):
+ bp_tmp, bv_tmp = self._get_base_vol_pri(
+ inst,
+ start,
+ end,
+ decision=dec,
+ direction=direction,
+ trade_exchange=trade_exchange,
+ pa_config=pa_config,
+ )
+ if (bp_tmp is not None) and (bv_tmp is not None):
+ bp_new[inst], bv_new[inst] = bp_tmp, bv_tmp
+ else:
+ bp_new[inst], bv_new[inst] = pr, v
+
+ bp_new = idd.SingleData(bp_new)
+ bv_new = idd.SingleData(bv_new)
+ bp_all.append(bp_new)
+ bv_all.append(bv_new)
+ bp_all = idd.concat(bp_all, axis=1)
+ bv_all = idd.concat(bv_all, axis=1)
+
+ base_volume = bv_all.sum(axis=1)
+ self.order_indicator.assign("base_volume", base_volume.to_dict())
+ self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict())
+
+ def _agg_order_price_advantage(self):
+ def if_empty_func(trade_price):
+ return trade_price.empty
+
+ if_empty = self.order_indicator.transfer(if_empty_func)
+ if not if_empty:
+
+ def func(trade_dir, trade_price, base_price):
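+                # Order.SELL == 0 -> sign = +1; Order.BUY == 1 -> sign = -1, so
+                # buying below / selling above the base price yields a positive pa.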
+ sign = 1 - trade_dir * 2
+ return sign * (trade_price / base_price - 1)
+
+ self.order_indicator.transfer(func, "pa")
+ else:
+ self.order_indicator.assign("pa", {})
+
+ def agg_order_indicators(
+ self,
+ inner_order_indicators: List[Dict[str, pd.Series]],
+ decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]],
+ outer_trade_decision: BaseTradeDecision,
+ trade_exchange: Exchange,
+ indicator_config={},
+ ):
+ self._agg_order_trade_info(inner_order_indicators)
+ self._update_trade_amount(outer_trade_decision)
+ self._update_order_fulfill_rate()
+ pa_config = indicator_config.get("pa_config", {})
+ self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config) # TODO
+ self._agg_order_price_advantage()
+
+ def _cal_trade_fulfill_rate(self, method="mean"):
+ if method == "mean":
+
+ def func(ffr):
+ return ffr.mean()
+
+ elif method == "amount_weighted":
+
+ def func(ffr, deal_amount):
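+                # e.g. ffr = [1.0, 0.5] with deal_amount = [100, 300] (hypothetical)
+                # -> (1.0 * 100 + 0.5 * 300) / 400 = 0.625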
+ return (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum())
+
+ elif method == "value_weighted":
+
+ def func(ffr, trade_value):
+ return (ffr * trade_value.abs()).sum() / (trade_value.abs().sum())
+
+ else:
+ raise ValueError(f"method {method} is not supported!")
+ return self.order_indicator.transfer(func)
+
+ def _cal_trade_price_advantage(self, method="mean"):
+ if method == "mean":
+
+ def func(pa):
+ return pa.mean()
+
+ elif method == "amount_weighted":
+
+ def func(pa, deal_amount):
+ return (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum())
+
+ elif method == "value_weighted":
+
+ def func(pa, trade_value):
+ return (pa * trade_value.abs()).sum() / (trade_value.abs().sum())
+
+ else:
+ raise ValueError(f"method {method} is not supported!")
+ return self.order_indicator.transfer(func)
+
+ def _cal_trade_positive_rate(self):
+ def func(pa):
+ return (pa > 0).sum() / pa.count()
+
+ return self.order_indicator.transfer(func)
+
+ def _cal_deal_amount(self):
+ def func(deal_amount):
+ return deal_amount.abs().sum()
+
+ return self.order_indicator.transfer(func)
+
+ def _cal_trade_value(self):
+ def func(trade_value):
+ return trade_value.abs().sum()
+
+ return self.order_indicator.transfer(func)
+
+ def _cal_trade_order_count(self):
+ def func(amount):
+ return amount.count()
+
+ return self.order_indicator.transfer(func)
+
+ def cal_trade_indicators(self, trade_start_time, freq, indicator_config={}):
+ show_indicator = indicator_config.get("show_indicator", False)
+ ffr_config = indicator_config.get("ffr_config", {})
+ pa_config = indicator_config.get("pa_config", {})
+ fulfill_rate = self._cal_trade_fulfill_rate(method=ffr_config.get("weight_method", "mean"))
+ price_advantage = self._cal_trade_price_advantage(method=pa_config.get("weight_method", "mean"))
+ positive_rate = self._cal_trade_positive_rate()
+ deal_amount = self._cal_deal_amount()
+ trade_value = self._cal_trade_value()
+ order_count = self._cal_trade_order_count()
+ self.trade_indicator["ffr"] = fulfill_rate
+ self.trade_indicator["pa"] = price_advantage
+ self.trade_indicator["pos"] = positive_rate
+ self.trade_indicator["deal_amount"] = deal_amount
+ self.trade_indicator["value"] = trade_value
+ self.trade_indicator["count"] = order_count
+ if show_indicator:
+ print(
+ "[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format(
+ freq, trade_start_time, fulfill_rate, price_advantage, positive_rate
+ )
+ )
+
+ def get_order_indicator(self, raw: bool = True):
+ if raw:
+ return self.order_indicator
+ return self.order_indicator.to_series()
+
+ def get_trade_indicator(self):
+ return self.trade_indicator
+
+ def generate_trade_indicators_dataframe(self):
+ return pd.DataFrame.from_dict(self.trade_indicator_his, orient="index")
diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py
new file mode 100644
index 0000000000..51130712d0
--- /dev/null
+++ b/qlib/backtest/utils.py
@@ -0,0 +1,269 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+import bisect
+from qlib.utils.time import epsilon_change
+from typing import TYPE_CHECKING, Tuple, Union
+
+if TYPE_CHECKING:
+ from qlib.backtest.decision import BaseTradeDecision
+
+import pandas as pd
+import warnings
+
+from ..data.data import Cal
+
+
+class TradeCalendarManager:
+ """
+ Manager for trading calendar
+ - BaseStrategy and BaseExecutor will use it
+ """
+
+ def __init__(
+ self,
+ freq: str,
+ start_time: Union[str, pd.Timestamp] = None,
+ end_time: Union[str, pd.Timestamp] = None,
+ level_infra: "LevelInfrastructure" = None,
+ ):
+ """
+ Parameters
+ ----------
+ freq : str
+ frequency of trading calendar, also trade time per trading step
+ start_time : Union[str, pd.Timestamp], optional
+ closed start of the trading calendar, by default None
+ If `start_time` is None, it must be reset before trading.
+ end_time : Union[str, pd.Timestamp], optional
+ closed end of the trade time range, by default None
+ If `end_time` is None, it must be reset before trading.
+ """
+ self.level_infra = level_infra
+ self.reset(freq=freq, start_time=start_time, end_time=end_time)
+
+ def reset(self, freq, start_time, end_time):
+ """
+ Please refer to the docs of `__init__`
+
+ Reset the trade calendar
+        - self.trade_len : the total number of trading steps
+        - self.trade_step : the number of trading steps finished; it can be [0, 1, 2, ..., self.trade_len - 1]
+ """
+ self.freq = freq
+ self.start_time = pd.Timestamp(start_time) if start_time else None
+ self.end_time = pd.Timestamp(end_time) if end_time else None
+
+ _calendar = Cal.calendar(freq=freq)
+ self._calendar = _calendar
+ _, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq)
+ self.start_index = _start_index
+ self.end_index = _end_index
+ self.trade_len = _end_index - _start_index + 1
+ self.trade_step = 0
+
+ def finished(self):
+ """
+        Check if the trading is finished
+        - Should be checked before calling strategy.generate_decisions and executor.execute
+        - If self.trade_step >= self.trade_len, the trading is finished
+        - If self.trade_step < self.trade_len, then self.trade_step trading steps have finished
+ """
+ return self.trade_step >= self.trade_len
+
+ def step(self):
+ if self.finished():
+ raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!")
+ self.trade_step = self.trade_step + 1
+
+ def get_freq(self):
+ return self.freq
+
+ def get_trade_len(self):
+ """get the total step length"""
+ return self.trade_len
+
+ def get_trade_step(self):
+ return self.trade_step
+
+ def get_step_time(self, trade_step=None, shift=0):
+ """
+ Get the left and right endpoints of the trade_step'th trading interval
+
+ About the endpoints:
+        - Qlib uses closed intervals in time-series data selection, which behaves the same as pandas.Series.loc
+        # - The returned right endpoint is reduced by 1 second because of the closed-interval representation in Qlib.
+        #   Note: Qlib supports decision execution down to minute frequency, so 1 second is shorter than any trading interval.
+
+ Parameters
+ ----------
+ trade_step : int, optional
+            the number of trading steps finished; by default None, indicating the current step
+        shift : int, optional
+            number of bars to shift, by default 0
+
+ Returns
+ -------
+        Tuple[pd.Timestamp, pd.Timestamp]
+ - If shift == 0, return the trading time range
+ - If shift > 0, return the trading time range of the earlier shift bars
+ - If shift < 0, return the trading time range of the later shift bar
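+            - e.g. with a daily calendar (illustrative dates), trade_step == 0 may return
+              (Timestamp("2020-01-02 00:00:00"), Timestamp("2020-01-02 23:59:59")):
+              the right endpoint is pulled back by epsilon to keep the interval closed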
+ """
+ if trade_step is None:
+ trade_step = self.get_trade_step()
+ trade_step = trade_step - shift
+ calendar_index = self.start_index + trade_step
+ return self._calendar[calendar_index], epsilon_change(self._calendar[calendar_index + 1])
+
+ def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]:
+ """
+ get the calendar range
+ The following assumptions are made
+ 1) The frequency of the exchange in common_infra is the same as the data calendar
+        2) Users want the **data index** relative to the start of the **day** (i.e. 240 min)
+
+ Parameters
+ ----------
+ rtype: str
+ - "full": return the full limitation of the deicsion in the day
+ - "step": return the limitation of current step
+
+ Returns
+ -------
+ Tuple[int, int]:
+ """
+ # potential performance issue
+ day_start = pd.Timestamp(self.start_time.date())
+ day_end = epsilon_change(day_start + pd.Timedelta(days=1))
+ freq = self.level_infra.get("common_infra").get("trade_exchange").freq
+ _, _, day_start_idx, _ = Cal.locate_index(day_start, day_end, freq=freq)
+
+ if rtype == "full":
+ _, _, start_idx, end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq)
+ elif rtype == "step":
+ _, _, start_idx, end_index = Cal.locate_index(*self.get_step_time(), freq=freq)
+ else:
+ raise ValueError(f"This type of input {rtype} is not supported")
+
+ return start_idx - day_start_idx, end_index - day_start_idx
+
+ def get_all_time(self):
+ """Get the start_time and end_time for trading"""
+ return self.start_time, self.end_time
+
+ # helper functions
+ def get_range_idx(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[int, int]:
+ """
+        get the range index which involves start_time~end_time (both sides are closed)
+
+ Parameters
+ ----------
+ start_time : pd.Timestamp
+ end_time : pd.Timestamp
+
+ Returns
+ -------
+ Tuple[int, int]:
+ the index of the range. **the left and right are closed**
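+
+            e.g. for a calendar whose steps cover [T0, T1, T2] (illustrative),
+            get_range_idx(T1, T2) returns (1, 2); timestamps outside the calendar
+            are clipped into [0, self.trade_len - 1]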
+ """
+ left, right = (
+ bisect.bisect_right(self._calendar, start_time) - 1,
+ bisect.bisect_right(self._calendar, end_time) - 1,
+ )
+ left -= self.start_index
+ right -= self.start_index
+
+ def clip(idx):
+ return min(max(0, idx), self.trade_len - 1)
+
+ return clip(left), clip(right)
+
+ def __repr__(self) -> str:
+ return f"class: {self.__class__.__name__}; {self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: [{self.trade_step}/{self.trade_len}]"
+
+
+class BaseInfrastructure:
+ def __init__(self, **kwargs):
+ self.reset_infra(**kwargs)
+
+ def get_support_infra(self):
+ raise NotImplementedError("`get_support_infra` is not implemented!")
+
+ def reset_infra(self, **kwargs):
+ support_infra = self.get_support_infra()
+ for k, v in kwargs.items():
+ if k in support_infra:
+ setattr(self, k, v)
+ else:
+ warnings.warn(f"{k} is ignored in `reset_infra`!")
+
+ def get(self, infra_name):
+ if hasattr(self, infra_name):
+ return getattr(self, infra_name)
+ else:
+ warnings.warn(f"infra {infra_name} is not found!")
+
+ def has(self, infra_name):
+ if infra_name in self.get_support_infra() and hasattr(self, infra_name):
+ return True
+ else:
+ return False
+
+ def update(self, other):
+ support_infra = other.get_support_infra()
+ infra_dict = {_infra: getattr(other, _infra) for _infra in support_infra if hasattr(other, _infra)}
+ self.reset_infra(**infra_dict)
+
+
+class CommonInfrastructure(BaseInfrastructure):
+ def get_support_infra(self):
+ return ["trade_account", "trade_exchange"]
+
+
+class LevelInfrastructure(BaseInfrastructure):
+ """level instrastructure is created by executor, and then shared to strategies on the same level"""
+
+ def get_support_infra(self):
+ """
+ Descriptions about the infrastructure
+
+ sub_level_infra:
+ - **NOTE**: this will only work after _init_sub_trading !!!
+ """
+ return ["trade_calendar", "sub_level_infra", "common_infra"]
+
+ def reset_cal(self, freq, start_time, end_time):
+ """reset trade calendar manager"""
+ if self.has("trade_calendar"):
+ self.get("trade_calendar").reset(freq, start_time=start_time, end_time=end_time)
+ else:
+ self.reset_infra(
+ trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self)
+ )
+
+ def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure):
+ """this will make the calendar access easier when acrossing multi-levels"""
+ self.reset_infra(sub_level_infra=sub_level_infra)
+
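+# A minimal composition sketch (illustrative; `account` and `exchange` stand in
+# for real trade account / exchange objects):
+#
+#     common = CommonInfrastructure(trade_account=account, trade_exchange=exchange)
+#     level = LevelInfrastructure(common_infra=common)
+#     level.reset_cal(freq="day", start_time="2020-01-01", end_time="2020-12-31")
+#     level.get("trade_calendar").get_trade_len()
+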
+
+def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Tuple[int, int]:
+ """
+ A helper function for getting the decision-level index range limitation for inner strategy
+ - NOTE: this function is not applicable to order-level
+
+ Parameters
+ ----------
+ trade_calendar : TradeCalendarManager
+ outer_trade_decision : BaseTradeDecision
+ the trade decision made by outer strategy
+
+ Returns
+ -------
+    Tuple[int, int]:
+ start index and end index
+ """
+ try:
+ return outer_trade_decision.get_range_limit(inner_calendar=trade_calendar)
+ except NotImplementedError:
+ return 0, trade_calendar.get_trade_len() - 1
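+
+
+# Usage sketch (illustrative): an inner strategy can bound its index range with
+#
+#     start_idx, end_idx = get_start_end_idx(inner_calendar, outer_decision)
+#
+# falling back to the full range [0, trade_len - 1] when the outer decision does
+# not implement `get_range_limit`.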
diff --git a/qlib/config.py b/qlib/config.py
index 7d96f51f6e..029434a886 100644
--- a/qlib/config.py
+++ b/qlib/config.py
@@ -109,6 +109,8 @@ def set_conf_from_C(self, config_c):
"kernels": NUM_USABLE_CPU,
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
"maxtasksperchild": None,
+ # If joblib_backend is None, use loky
+ "joblib_backend": "multiprocessing",
"default_disk_cache": 1, # 0:skip/1:use
"mem_cache_size_limit": 500,
# memory cache expire second, only in used 'DatasetURICache' and 'client D.calendar'
@@ -165,6 +167,10 @@ def set_conf_from_C(self, config_c):
"task_url": "mongodb://localhost:27017/",
"task_db_name": "default_task_db",
},
+    # Shift minutes for high-frequency minute data, used in backtest
+ # if min_data_shift == 0, use default market time [9:30, 11:29, 1:00, 2:59]
+ # if min_data_shift != 0, use shifted market time [9:30, 11:29, 1:00, 2:59] - shift*minute
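+    # e.g. min_data_shift = 1 -> shifted market time [9:29, 11:28, 12:59, 2:58]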
+ "min_data_shift": 0,
}
MODE_CONF = {
@@ -271,7 +277,7 @@ def get_uri_type(uri: Union[str, Path]):
else:
return QlibConfig.LOCAL_URI
- def get_data_path(self, freq: str = None) -> Path:
+ def get_data_uri(self, freq: str = None) -> Path:
if freq is None or freq not in self.provider_uri:
freq = QlibConfig.DEFAULT_FREQ
_provider_uri = self.provider_uri[freq]
@@ -328,11 +334,41 @@ def resolve_path(self):
if _mount_path[_freq] is None
else str(Path(_mount_path[_freq]).expanduser().resolve())
)
-
self["provider_uri"] = _provider_uri
self["mount_path"] = _mount_path
- def set(self, default_conf="client", **kwargs):
+ def get_uri_type(self):
+ path = self["provider_uri"]
+ if isinstance(path, Path):
+ path = str(path)
+ is_win = re.match("^[a-zA-Z]:.*", path) is not None # such as 'C:\\data', 'D:'
+ is_nfs_or_win = (
+ re.match("^[^/]+:.+", path) is not None
+ ) # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
+
+ if is_nfs_or_win and not is_win:
+ return QlibConfig.NFS_URI
+ else:
+ return QlibConfig.LOCAL_URI
+
+ def set(self, default_conf: str = "client", **kwargs):
+ """
+ configure qlib based on the input parameters
+
+        The config will act like a dictionary.
+
+        Normally, it literally replaces the values according to the keys.
+        However, sometimes it is hard for users to set the config when the configuration is nested and complicated.
+
+        So this API provides some special parameters for users to set the keys in a more convenient way.
+        - region: REG_CN, REG_US
+          - several region-related configs will be changed
+
+ Parameters
+ ----------
+ default_conf : str
+ the default config template chosen by user: "server", "client"
+ """
from .utils import set_log_with_config, get_module_logger, can_use_cache
self.reset()
diff --git a/qlib/contrib/backtest/__init__.py b/qlib/contrib/backtest/__init__.py
deleted file mode 100644
index aa24ffb0cf..0000000000
--- a/qlib/contrib/backtest/__init__.py
+++ /dev/null
@@ -1,324 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-from .order import Order
-from .account import Account
-from .position import Position
-from .exchange import Exchange
-from .report import Report
-from .backtest import backtest as backtest_func, get_date_range
-
-import numpy as np
-import inspect
-from ...utils import init_instance_by_config
-from ...log import get_module_logger
-from ...config import C
-
-logger = get_module_logger("backtest caller")
-
-
-def get_strategy(
- strategy=None,
- topk=50,
- margin=0.5,
- n_drop=5,
- risk_degree=0.95,
- str_type="dropout",
- adjust_dates=None,
-):
- """get_strategy
-
-    There will be 3 ways to return a strategy. Please follow the code.
-
-
- Parameters
- ----------
-
- strategy : Strategy()
- strategy used in backtest.
- topk : int (Default value: 50)
- top-N stocks to buy.
-    margin : int or float (Default value: 0.5)
- - if isinstance(margin, int):
-
- sell_limit = margin
-
- - else:
-
- sell_limit = pred_in_a_day.count() * margin
-
-        buffer margin: in single score_mode, continue holding a stock if it is in nlargest(sell_limit).
- sell_limit should be no less than topk.
- n_drop : int
- number of stocks to be replaced in each trading date.
- risk_degree: float
- 0-1, 0.95 for example, use 95% money to trade.
- str_type: 'amount', 'weight' or 'dropout'
-        strategy type: TopkAmountStrategy, TopkWeightStrategy or TopkDropoutStrategy.
-
- Returns
- -------
- :class: Strategy
- an initialized strategy object
- """
-
- # There will be 3 ways to return a strategy.
- if strategy is None:
- # 1) create strategy with param `strategy`
- str_cls_dict = {
- "amount": "TopkAmountStrategy",
- "weight": "TopkWeightStrategy",
- "dropout": "TopkDropoutStrategy",
- }
- logger.info("Create new strategy ")
- from .. import strategy as strategy_pool
-
- str_cls = getattr(strategy_pool, str_cls_dict.get(str_type))
- strategy = str_cls(
- topk=topk,
- buffer_margin=margin,
- n_drop=n_drop,
- risk_degree=risk_degree,
- adjust_dates=adjust_dates,
- )
- elif isinstance(strategy, (dict, str)):
- # 2) create strategy with init_instance_by_config
- logger.info("Create new strategy ")
- strategy = init_instance_by_config(strategy)
-
- from ..strategy.strategy import BaseStrategy
-
- # else: nothing happens. 3) Use the strategy directly
- if not isinstance(strategy, BaseStrategy):
- raise TypeError("Strategy not supported")
- return strategy
-
-
-def get_exchange(
- pred,
- exchange=None,
- subscribe_fields=[],
- open_cost=0.0015,
- close_cost=0.0025,
- min_cost=5.0,
- trade_unit=None,
- limit_threshold=None,
- deal_price=None,
- extract_codes=False,
- shift=1,
-):
- """get_exchange
-
- Parameters
- ----------
-
- # exchange related arguments
- exchange: Exchange().
- subscribe_fields: list
- subscribe fields.
- open_cost : float
- open transaction cost.
- close_cost : float
- close transaction cost.
- min_cost : float
- min transaction cost.
- trade_unit : int
- 100 for China A.
- deal_price: str
- dealing price type: 'close', 'open', 'vwap'.
- limit_threshold : float
- limit move 0.1 (10%) for example, long and short with same limit.
- extract_codes: bool
- will we pass the codes extracted from the pred to the exchange.
- NOTE: This will be faster with offline qlib.
-
- Returns
- -------
- :class: Exchange
- an initialized Exchange object
- """
-
- if trade_unit is None:
- trade_unit = C.trade_unit
- if limit_threshold is None:
- limit_threshold = C.limit_threshold
- if deal_price is None:
- deal_price = C.deal_price
- if exchange is None:
- logger.info("Create new exchange")
- # handle exception for deal_price
- if deal_price[0] != "$":
- deal_price = "$" + deal_price
- if extract_codes:
- codes = sorted(pred.index.get_level_values("instrument").unique())
- else:
- codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
-
- dates = sorted(pred.index.get_level_values("datetime").unique())
- dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift))
-
- exchange = Exchange(
- trade_dates=dates,
- codes=codes,
- deal_price=deal_price,
- subscribe_fields=subscribe_fields,
- limit_threshold=limit_threshold,
- open_cost=open_cost,
- close_cost=close_cost,
- min_cost=min_cost,
- trade_unit=trade_unit,
- )
- return exchange
-
-
-def get_executor(
- executor=None,
- trade_exchange=None,
- verbose=True,
-):
- """get_executor
-
-    There will be 3 ways to return an executor. Please follow the code.
-
- Parameters
- ----------
-
- executor : BaseExecutor
- executor used in backtest.
- trade_exchange : Exchange
- exchange used in executor
- verbose : bool
- whether to print log.
-
- Returns
- -------
- :class: BaseExecutor
- an initialized BaseExecutor object
- """
-
-    # There will be 3 ways to return an executor.
- if executor is None:
- # 1) create executor with param `executor`
- logger.info("Create new executor ")
- from ..online.executor import SimulatorExecutor
-
- executor = SimulatorExecutor(trade_exchange=trade_exchange, verbose=verbose)
- elif isinstance(executor, (dict, str)):
- # 2) create executor with config
- logger.info("Create new executor ")
- executor = init_instance_by_config(executor)
-
- from ..online.executor import BaseExecutor
-
- # 3) Use the executor directly
- if not isinstance(executor, BaseExecutor):
- raise TypeError("Executor not supported")
- return executor
-
-
-# This is the API for compatibility for legacy code
-def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, return_order=False, **kwargs):
- """This function will help you set a reasonable Exchange and provide default value for strategy
- Parameters
- ----------
-
-    - **backtest workflow related or common arguments**
-
- pred : pandas.DataFrame
-        the prediction should have an index and one `score` column.
- account : float
- init account value.
- shift : int
- whether to shift prediction by one day.
- benchmark : str
- benchmark code, default is SH000905 CSI 500.
- verbose : bool
- whether to print log.
- return_order : bool
- whether to return order list
-
- - **strategy related arguments**
-
- strategy : Strategy()
- strategy used in backtest.
- topk : int (Default value: 50)
- top-N stocks to buy.
-    margin : int or float (Default value: 0.5)
- - if isinstance(margin, int):
-
- sell_limit = margin
-
- - else:
-
- sell_limit = pred_in_a_day.count() * margin
-
-        buffer margin: in single score_mode, continue holding a stock if it is in nlargest(sell_limit).
- sell_limit should be no less than topk.
- n_drop : int
- number of stocks to be replaced in each trading date.
- risk_degree: float
- 0-1, 0.95 for example, use 95% money to trade.
- str_type: 'amount', 'weight' or 'dropout'
-        strategy type: TopkAmountStrategy, TopkWeightStrategy or TopkDropoutStrategy.
-
- - **exchange related arguments**
-
- exchange: Exchange()
- pass the exchange for speeding up.
- subscribe_fields: list
- subscribe fields.
- open_cost : float
-        open transaction cost. The default value is 0.002 (0.2%).
-    close_cost : float
-        close transaction cost. The default value is 0.002 (0.2%).
- min_cost : float
- min transaction cost.
- trade_unit : int
- 100 for China A.
- deal_price: str
- dealing price type: 'close', 'open', 'vwap'.
- limit_threshold : float
- limit move 0.1 (10%) for example, long and short with same limit.
- extract_codes: bool
- will we pass the codes extracted from the pred to the exchange.
-
- .. note:: This will be faster with offline qlib.
-
- - **executor related arguments**
-
- executor : BaseExecutor()
- executor used in backtest.
- verbose : bool
- whether to print log.
-
- """
- # check strategy:
- spec = inspect.getfullargspec(get_strategy)
- str_args = {k: v for k, v in kwargs.items() if k in spec.args}
- strategy = get_strategy(**str_args)
-
- # init exchange:
- spec = inspect.getfullargspec(get_exchange)
- ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
- trade_exchange = get_exchange(pred, **ex_args)
-
- # init executor:
- executor = get_executor(executor=kwargs.get("executor"), trade_exchange=trade_exchange, verbose=verbose)
-
- # run backtest
- report_dict = backtest_func(
- pred=pred,
- strategy=strategy,
- executor=executor,
- trade_exchange=trade_exchange,
- shift=shift,
- verbose=verbose,
- account=account,
- benchmark=benchmark,
- return_order=return_order,
- )
- # for compatibility of the old API. return the dict positions
-
- positions = report_dict.get("positions")
- report_dict.update({"positions": {k: p.position for k, p in positions.items()}})
- return report_dict
diff --git a/qlib/contrib/backtest/account.py b/qlib/contrib/backtest/account.py
deleted file mode 100644
index a614f08b67..0000000000
--- a/qlib/contrib/backtest/account.py
+++ /dev/null
@@ -1,169 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-
-import copy
-
-from .position import Position
-from .report import Report
-from .order import Order
-
-
-"""
-rtn & earning in the Account
- rtn:
- from order's view
-        1. changes if any order is executed, sell order or buy order
-        2. changes at the end of today: (today_close - stock_price) * amount
-    earning
-        from the value of the current position
-        earning will be updated at the end of the trade date
-        earning = today_value - pre_value
-        **considers cost**
-        since earning is the difference between two position values, it considers cost; it is the true return rate
-        the concrete computation of rtn, in contrast, does not consider cost; in other words, rtn - cost = earning
-"""
-
-
-class Account:
- def __init__(self, init_cash, last_trade_date=None):
- self.init_vars(init_cash, last_trade_date)
-
- def init_vars(self, init_cash, last_trade_date=None):
- # init cash
- self.init_cash = init_cash
- self.current = Position(cash=init_cash)
- self.positions = {}
- self.rtn = 0
- self.ct = 0
- self.to = 0
- self.val = 0
- self.report = Report()
- self.earning = 0
- self.last_trade_date = last_trade_date
-
- def get_positions(self):
- return self.positions
-
- def get_cash(self):
- return self.current.position["cash"]
-
- def update_state_from_order(self, order, trade_val, cost, trade_price):
- # update turnover
- self.to += trade_val
- # update cost
- self.ct += cost
- # update return
- # update self.rtn from order
- trade_amount = trade_val / trade_price
- if order.direction == Order.SELL: # 0 for sell
- # when sell stock, get profit from price change
- profit = trade_val - self.current.get_stock_price(order.stock_id) * trade_amount
- self.rtn += profit # note here do not consider cost
- elif order.direction == Order.BUY: # 1 for buy
- # when buy stock, we get return for the rtn computing method
- # profit in buy order is to make self.rtn is consistent with self.earning at the end of date
- profit = self.current.get_stock_price(order.stock_id) * trade_amount - trade_val
- self.rtn += profit
-
- def update_order(self, order, trade_val, cost, trade_price):
- # if stock is sold out, no stock price information in Position, then we should update account first, then update current position
- # if stock is bought, there is no stock in current position, update current, then update account
- # The cost will be substracted from the cash at last. So the trading logic can ignore the cost calculation
- trade_amount = trade_val / trade_price
- if order.direction == Order.SELL:
- # sell stock
- self.update_state_from_order(order, trade_val, cost, trade_price)
- # update current position
- # for may sell all of stock_id
- self.current.update_order(order, trade_val, cost, trade_price)
- else:
- # buy stock
- # deal order, then update state
- self.current.update_order(order, trade_val, cost, trade_price)
- self.update_state_from_order(order, trade_val, cost, trade_price)
-
- def update_daily_end(self, today, trader):
- """
-        today: pd.Timestamp
-        quote: pd.DataFrame (code, date), columns
-        at the end of the trade date
-        - update rtn
-        - update price for each asset
-        - update value for this account
-        - update earning (2nd view of return)
-        - update holding day, count of stock
-        - update position history
- - update report
- :return: None
- """
- # update price for stock in the position and the profit from changed_price
- stock_list = self.current.get_stock_list()
- profit = 0
- for code in stock_list:
- # if suspend, no new price to be updated, profit is 0
- if trader.check_stock_suspended(code, today):
- continue
- today_close = trader.get_close(code, today)
- profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
- self.current.update_stock_price(stock_id=code, price=today_close)
- self.rtn += profit
- # update holding day count
- self.current.add_count_all()
- # update value
- self.val = self.current.calculate_value()
- # update earning (2nd view of return)
- # account_value - last_account_value
- # for the first trade date, account_value - init_cash
- # self.report.is_empty() to judge is_first_trade_date
- # get last_account_value, today_account_value, today_stock_value
- if self.report.is_empty():
- last_account_value = self.init_cash
- else:
- last_account_value = self.report.get_latest_account_value()
- today_account_value = self.current.calculate_value()
- today_stock_value = self.current.calculate_stock_value()
- self.earning = today_account_value - last_account_value
- # update report for today
-        # judge whether the trading has begun,
-        # and don't add the init account state into the report, since we don't have excess return on those days.
- self.report.update_report_record(
- trade_date=today,
- account_value=today_account_value,
- cash=self.current.position["cash"],
- return_rate=(self.earning + self.ct) / last_account_value,
- # here use earning to calculate return, position's view, earning consider cost, true return
- # in order to make same definition with original backtest in evaluate.py
- turnover_rate=self.to / last_account_value,
- cost_rate=self.ct / last_account_value,
- stock_value=today_stock_value,
- )
- # set today_account_value to position
- self.current.position["today_account_value"] = today_account_value
- self.current.update_weight_all()
- # update positions
- # note use deepcopy
- self.positions[today] = copy.deepcopy(self.current)
-
-        # finish today's update
- # reset the daily variables
- self.rtn = 0
- self.ct = 0
- self.to = 0
- self.last_trade_date = today
-
- def load_account(self, account_path):
- report = Report()
- position = Position()
- last_trade_date = position.load_position(account_path / "position.xlsx")
- report.load_report(account_path / "report.csv")
-
- # assign values
- self.init_vars(position.init_cash)
- self.current = position
- self.report = report
- self.last_trade_date = last_trade_date if last_trade_date else None
-
- def save_account(self, account_path):
- self.current.save_position(account_path / "position.xlsx", self.last_trade_date)
- self.report.save_report(account_path / "report.csv")
diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py
deleted file mode 100644
index fc30065fd3..0000000000
--- a/qlib/contrib/backtest/backtest.py
+++ /dev/null
@@ -1,146 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-
-import numpy as np
-import pandas as pd
-from ...utils import get_date_by_shift, get_date_range
-from ...data import D
-from .account import Account
-from ...config import C
-from ...log import get_module_logger
-from ...data.dataset.utils import get_level_index
-
-LOG = get_module_logger("backtest")
-
-
-def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
- """
- Parameters
- ----------
- pred : pandas.DataFrame
-        the prediction should have an index and one `score` column
-        Qlib wants to support multi-signal strategies in the future, so pd.Series is not used.
- strategy : Strategy()
- strategy part for backtest
- trade_exchange : Exchange()
-        exchange for backtest
- shift : int
- whether to shift prediction by one day
- verbose : bool
- whether to print log
- account : float
- init account value
- benchmark : str/list/pd.Series
- `benchmark` is pd.Series, `index` is trading date; the value T is the change from T-1 to T.
- example:
- print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
- 2017-01-04 0.011693
- 2017-01-05 0.000721
- 2017-01-06 -0.004322
- 2017-01-09 0.006874
- 2017-01-10 -0.003350
-
- `benchmark` is list, will use the daily average change of the stock pool in the list as the 'bench'.
- `benchmark` is str, will use the daily change as the 'bench'.
- benchmark code, default is SH000905 CSI500
- """
- # Convert format if the input format is not expected
- if get_level_index(pred, level="datetime") == 1:
- pred = pred.swaplevel().sort_index()
- if isinstance(pred, pd.Series):
- pred = pred.to_frame("score")
-
- trade_account = Account(init_cash=account)
- _pred_dates = pred.index.get_level_values(level="datetime")
- predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
- if isinstance(benchmark, pd.Series):
- bench = benchmark
- else:
- _codes = benchmark if isinstance(benchmark, list) else [benchmark]
- _temp_result = D.features(
- _codes,
- ["$close/Ref($close,1)-1"],
- predict_dates[0],
- get_date_by_shift(predict_dates[-1], shift=shift),
- disk_cache=1,
- )
- if len(_temp_result) == 0:
- raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
- bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
-
- trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
- if return_order:
- multi_order_list = []
- # trading apart
- for pred_date, trade_date in zip(predict_dates, trade_dates):
- # for loop predict date and trading date
- # print
- if verbose:
- LOG.info("[I {:%Y-%m-%d}]: trade begin.".format(trade_date))
-
- # 1. Load the score_series at pred_date
- try:
- score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate
- score_series = score.reset_index(level="datetime", drop=True)[
- "score"
- ] # pd.Series(index:stock_id, data: score)
- except KeyError:
- LOG.warning("No score found on predict date[{:%Y-%m-%d}]".format(trade_date))
- score_series = None
-
- if score_series is not None and score_series.count() > 0: # in case of the scores are all None
- # 2. Update your strategy (and model)
- strategy.update(score_series, pred_date, trade_date)
-
- # 3. Generate order list
- order_list = strategy.generate_order_list(
- score_series=score_series,
- current=trade_account.current,
- trade_exchange=trade_exchange,
- pred_date=pred_date,
- trade_date=trade_date,
- )
- else:
- order_list = []
- if return_order:
- multi_order_list.append((trade_account, order_list, trade_date))
- # 4. Get result after executing order list
- # NOTE: The following operation will modify order.amount.
- # NOTE: If it is buy and the cash is insufficient, the tradable amount will be recalculated
- trade_info = executor.execute(trade_account, order_list, trade_date)
-
- # 5. Update account information according to transaction
- update_account(trade_account, trade_info, trade_exchange, trade_date)
-
- # generate backtest report
- report_df = trade_account.report.generate_report_dataframe()
- report_df["bench"] = bench
- positions = trade_account.get_positions()
-
- report_dict = {"report_df": report_df, "positions": positions}
- if return_order:
- report_dict.update({"order_list": multi_order_list})
- return report_dict
-
-
-def update_account(trade_account, trade_info, trade_exchange, trade_date):
- """
- Update the account and strategy
-
- Parameters
- ----------
- trade_account : Account()
- trade_info : list of [Order(), float, float, float]
- (order, trade_val, trade_cost, trade_price), trade_info with out factor
- trade_exchange : Exchange()
- used to get the $close_price at trade_date to update account
- trade_date : pd.Timestamp
- """
- # update account
- for [order, trade_val, trade_cost, trade_price] in trade_info:
- if order.deal_amount == 0:
- continue
- trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
- # at the end of trade date, update the account based the $close_price of stocks.
- trade_account.update_daily_end(today=trade_date, trader=trade_exchange)
diff --git a/qlib/contrib/backtest/exchange.py b/qlib/contrib/backtest/exchange.py
deleted file mode 100644
index 178950eebe..0000000000
--- a/qlib/contrib/backtest/exchange.py
+++ /dev/null
@@ -1,425 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-
-import random
-import logging
-
-import numpy as np
-import pandas as pd
-
-from ...data import D
-from .order import Order
-from ...config import C, REG_CN
-from ...log import get_module_logger
-
-
-class Exchange:
- def __init__(
- self,
- trade_dates=None,
- codes="all",
- deal_price=None,
- subscribe_fields=[],
- limit_threshold=None,
- open_cost=0.0015,
- close_cost=0.0025,
- trade_unit=None,
- min_cost=5,
- extra_quote=None,
- ):
- """__init__
-
- :param trade_dates: list of pd.Timestamp
- :param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50)
- :param deal_price: str, 'close', 'open', 'vwap'
- :param subscribe_fields: list, subscribe fields
- :param limit_threshold: float, 0.1 for example, default None
- :param open_cost: cost rate for open, default 0.0015
- :param close_cost: cost rate for close, default 0.0025
- :param trade_unit: trade unit, 100 for China A market
- :param min_cost: min cost, default 5
- :param extra_quote: pandas, dataframe consists of
- columns: like ['$vwap', '$close', '$factor', 'limit'].
- The limit indicates that the etf is tradable on a specific day.
- Necessary fields:
- $close is for calculating the total value at end of each day.
- Optional fields:
- $vwap is only necessary when we use the $vwap price as the deal price
- $factor is for rounding to the trading unit
- limit will be set to False by default(False indicates we can buy this
- target on this day).
- index: MultipleIndex(instrument, pd.Datetime)
- """
- if trade_unit is None:
- trade_unit = C.trade_unit
- if limit_threshold is None:
- limit_threshold = C.limit_threshold
- if deal_price is None:
- deal_price = C.deal_price
-
- self.logger = get_module_logger("online operator", level=logging.INFO)
-
- self.trade_unit = trade_unit
-
-        # TODO: the quote, trade_dates, codes are not necessary.
- # It is just for performance consideration.
- if limit_threshold is None:
- if C.region == REG_CN:
- self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
- elif abs(limit_threshold) > 0.1:
- if C.region == REG_CN:
- self.logger.warning(f"limit_threshold may not be set to a reasonable value")
-
- if deal_price[0] != "$":
- self.deal_price = "$" + deal_price
- else:
- self.deal_price = deal_price
- if isinstance(codes, str):
- codes = D.instruments(codes)
- self.codes = codes
- # Necessary fields
- # $close is for calculating the total value at end of each day.
- # $factor is for rounding to the trading unit
- # $change is for calculating the limit of the stock
-
- necessary_fields = {self.deal_price, "$close", "$change", "$factor"}
- subscribe_fields = list(necessary_fields | set(subscribe_fields))
- all_fields = list(necessary_fields | set(subscribe_fields))
- self.all_fields = all_fields
- self.open_cost = open_cost
- self.close_cost = close_cost
- self.min_cost = min_cost
- self.limit_threshold = limit_threshold
-        # TODO: the quote, trade_dates, codes are not necessary.
- # It is just for performance consideration.
- if trade_dates is not None and len(trade_dates):
- start_date, end_date = trade_dates[0], trade_dates[-1]
- else:
- self.logger.warning("trade_dates have not been assigned, all dates will be loaded")
- start_date, end_date = None, None
-
- self.extra_quote = extra_quote
- self.set_quote(codes, start_date, end_date)
-
- def set_quote(self, codes, start_date, end_date):
- if len(codes) == 0:
- codes = D.instruments()
- self.quote = D.features(codes, self.all_fields, start_date, end_date, disk_cache=True).dropna(subset=["$close"])
- self.quote.columns = self.all_fields
-
- if self.quote[self.deal_price].isna().any():
- self.logger.warning("{} field data contains nan.".format(self.deal_price))
-
- if self.quote["$factor"].isna().any():
- # The 'factor.day.bin' file not exists, and `factor` field contains `nan`
- # Use adjusted price
- self.trade_w_adj_price = True
- self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.")
- else:
- # The `factor.day.bin` file exists and all data `close` and `factor` are not `nan`
- # Use normal price
- self.trade_w_adj_price = False
- # update limit
- # check limit_threshold
- if self.limit_threshold is None:
- self.quote["limit"] = False
- else:
- # set limit
- self._update_limit(buy_limit=self.limit_threshold, sell_limit=self.limit_threshold)
-
- quote_df = self.quote
- if self.extra_quote is not None:
- # process extra_quote
- if "$close" not in self.extra_quote:
- raise ValueError("$close is necessray in extra_quote")
- if self.deal_price not in self.extra_quote.columns:
- self.extra_quote[self.deal_price] = self.extra_quote["$close"]
- self.logger.warning("No deal_price set for extra_quote. Use $close as deal_price.")
- if "$factor" not in self.extra_quote.columns:
- self.extra_quote["$factor"] = 1.0
- self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.")
- if "limit" not in self.extra_quote.columns:
- self.extra_quote["limit"] = False
- self.logger.warning("No limit set for extra_quote. All stock will be tradable.")
- assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"}
- quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
-
- # update quote: pd.DataFrame to dict, for search use
- self.quote = quote_df.to_dict("index")
-
- def _update_limit(self, buy_limit, sell_limit):
- self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit, inclusive=False)
-
- def check_stock_limit(self, stock_id, trade_date):
- """Parameter
- stock_id
- trade_date
-        is limited
- """
- return self.quote[(stock_id, trade_date)]["limit"]
-
- def check_stock_suspended(self, stock_id, trade_date):
- # is suspended
- return (stock_id, trade_date) not in self.quote
-
- def is_stock_tradable(self, stock_id, trade_date):
- # check if stock can be traded
- # same as check in check_order
- if self.check_stock_suspended(stock_id, trade_date) or self.check_stock_limit(stock_id, trade_date):
- return False
- else:
- return True
-
- def check_order(self, order):
- # check limit and suspended
- if self.check_stock_suspended(order.stock_id, order.trade_date) or self.check_stock_limit(
- order.stock_id, order.trade_date
- ):
- return False
- else:
- return True
-
- def deal_order(self, order, trade_account=None, position=None):
- """
- Deal order when the actual transaction
-
- :param order: Deal the order.
- :param trade_account: Trade account to be updated after dealing the order.
- :param position: position to be updated after dealing the order.
- :return: trade_val, trade_cost, trade_price
- """
- # need to check order first
- # TODO: check the order unit limit in the exchange!!!!
- # The order limit is related to the adj factor and the cur_amount.
- # factor = self.quote[(order.stock_id, order.trade_date)]['$factor']
- # cur_amount = trade_account.current.get_stock_amount(order.stock_id)
- if self.check_order(order) is False:
- raise AttributeError("need to check order first")
- if trade_account is not None and position is not None:
- raise ValueError("trade_account and position can only choose one")
-
- trade_price = self.get_deal_price(order.stock_id, order.trade_date)
- trade_val, trade_cost = self._calc_trade_info_by_order(
- order, trade_account.current if trade_account else position
- )
- # update account
- if trade_val > 0:
- # If the order can only be deal 0 trade_val. Nothing to be updated
- # Otherwise, it will result some stock with 0 amount in the position
- if trade_account:
- trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
- elif position:
- position.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
-
- return trade_val, trade_cost, trade_price
-
- def get_quote_info(self, stock_id, trade_date):
- return self.quote[(stock_id, trade_date)]
-
- def get_close(self, stock_id, trade_date):
- return self.quote[(stock_id, trade_date)]["$close"]
-
- def get_deal_price(self, stock_id, trade_date):
- deal_price = self.quote[(stock_id, trade_date)][self.deal_price]
- if np.isclose(deal_price, 0.0) or np.isnan(deal_price):
- self.logger.warning(f"(stock_id:{stock_id}, trade_date:{trade_date}, {self.deal_price}): {deal_price}!!!")
- self.logger.warning(f"setting deal_price to close price")
- deal_price = self.get_close(stock_id, trade_date)
- return deal_price
-
- def get_factor(self, stock_id, trade_date):
- return self.quote[(stock_id, trade_date)]["$factor"]
-
- def generate_amount_position_from_weight_position(self, weight_position, cash, trade_date):
- """
-        Generate the target position according to the weights and the cash.
-        NOTE: All the cash will be assigned to the tradable stocks.
-
- Parameter:
- weight_position : dict {stock_id : weight}; allocate cash by weight_position
-            each weight must be in the range: 0 < weight < 1
- cash : cash
- trade_date : trade date
- """
-
- # calculate the total weight of tradable value
- tradable_weight = 0.0
- for stock_id in weight_position:
- if self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):
- # weight_position must be greater than 0 and less than 1
- if weight_position[stock_id] < 0 or weight_position[stock_id] > 1:
- raise ValueError(
- "weight_position is {}, "
- "weight_position is not in the range of (0, 1).".format(weight_position[stock_id])
- )
- tradable_weight += weight_position[stock_id]
-
- if tradable_weight - 1.0 >= 1e-5:
- raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight))
-
- amount_dict = {}
- for stock_id in weight_position:
- if weight_position[stock_id] > 0.0 and self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):
- amount_dict[stock_id] = (
- cash
- * weight_position[stock_id]
- / tradable_weight
- // self.get_deal_price(stock_id=stock_id, trade_date=trade_date)
- )
- return amount_dict
-
- def get_real_deal_amount(self, current_amount, target_amount, factor):
- """
- Calculate the real adjust deal amount when considering the trading unit
-
- :param current_amount:
- :param target_amount:
- :param factor:
- :return real_deal_amount; Positive deal_amount indicates buying more stock.
- """
- if current_amount == target_amount:
- return 0
- elif current_amount < target_amount:
- deal_amount = target_amount - current_amount
- deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
- return deal_amount
- else:
- if target_amount == 0:
- return -current_amount
- else:
- deal_amount = current_amount - target_amount
- deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
- return -deal_amount
-
- def generate_order_for_target_amount_position(self, target_position, current_position, trade_date):
- """Parameter:
- target_position : dict { stock_id : amount }
-        current_position : dict { stock_id : amount}
- trade_unit : trade_unit
- down sample : for amount 321 and trade_unit 100, deal_amount is 300
- deal order on trade_date
- """
- # split buy and sell for further use
- buy_order_list = []
- sell_order_list = []
- # three parts: kept stock_id, dropped stock_id, new stock_id
- # handle kept stock_id
-
-        # because the iteration order of a set is not fixed, the trading order of the stocks would differ, making backtest results differ for the same parameters;
- # so here we sort stock_id, and then randomly shuffle the order of stock_id
- # because the same random seed is used, the final stock_id order is fixed
- sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys())))
- random.seed(0)
- random.shuffle(sorted_ids)
- for stock_id in sorted_ids:
-
- # Do not generate order for the nontradable stocks
- if not self.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):
- continue
-
- target_amount = target_position.get(stock_id, 0)
- current_amount = current_position.get(stock_id, 0)
- factor = self.quote[(stock_id, trade_date)]["$factor"]
-
- deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor)
- if deal_amount == 0:
- continue
- elif deal_amount > 0:
- # buy stock
- buy_order_list.append(
- Order(
- stock_id=stock_id,
- amount=deal_amount,
- direction=Order.BUY,
- trade_date=trade_date,
- factor=factor,
- )
- )
- else:
- # sell stock
- sell_order_list.append(
- Order(
- stock_id=stock_id,
- amount=abs(deal_amount),
- direction=Order.SELL,
- trade_date=trade_date,
- factor=factor,
- )
- )
- # return order_list : buy + sell
- return sell_order_list + buy_order_list
-
- def calculate_amount_position_value(self, amount_dict, trade_date, only_tradable=False):
- """Parameter
- position : Position()
- amount_dict : {stock_id : amount}
- """
- value = 0
- for stock_id in amount_dict:
- if (
- self.check_stock_suspended(stock_id=stock_id, trade_date=trade_date) is False
- and self.check_stock_limit(stock_id=stock_id, trade_date=trade_date) is False
- ):
- value += self.get_deal_price(stock_id=stock_id, trade_date=trade_date) * amount_dict[stock_id]
- return value
-
- def round_amount_by_trade_unit(self, deal_amount, factor):
- """Parameter
- deal_amount : float, adjusted amount
- factor : float, adjusted factor
- return : float, real amount
- """
- if not self.trade_w_adj_price:
-            # the minimal amount is 1. Add 0.1 to avoid precision problems.
- return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
- return deal_amount
-
- def _calc_trade_info_by_order(self, order, position):
- """
- Calculation of trade info
-
- :param order:
- :param position: Position
- :return: trade_val, trade_cost
- """
-
- trade_price = self.get_deal_price(order.stock_id, order.trade_date)
- if order.direction == Order.SELL:
- # sell
- if position is not None:
- if np.isclose(order.amount, position.get_stock_amount(order.stock_id)):
- # when selling the last of a stock, the amount doesn't need rounding
- order.deal_amount = order.amount
- else:
- order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
- else:
- # TODO: We don't know current position.
- # We choose to sell all
- order.deal_amount = order.amount
-
- trade_val = order.deal_amount * trade_price
- trade_cost = max(trade_val * self.close_cost, self.min_cost)
- elif order.direction == Order.BUY:
- # buy
- if position is not None:
- cash = position.get_cash()
- trade_val = order.amount * trade_price
- if cash < trade_val * (1 + self.open_cost):
- # The money is not enough
- order.deal_amount = self.round_amount_by_trade_unit(
- cash / (1 + self.open_cost) / trade_price, order.factor
- )
- else:
- # The money is enough
- order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
- else:
- # Unknown amount of money. Just round the amount
- order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
-
- trade_val = order.deal_amount * trade_price
- trade_cost = trade_val * self.open_cost
- else:
- raise NotImplementedError("order type {} error".format(order.type))
-
- return trade_val, trade_cost
diff --git a/qlib/contrib/backtest/order.py b/qlib/contrib/backtest/order.py
deleted file mode 100644
index 740773b2fd..0000000000
--- a/qlib/contrib/backtest/order.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-
-class Order:
-
- SELL = 0
- BUY = 1
-
- def __init__(self, stock_id, amount, trade_date, direction, factor):
- """Parameter
- direction : Order.SELL for sell; Order.BUY for buy
- stock_id : str
- amount : float
- trade_date : pd.Timestamp
- factor : float
- presents the weight factor assigned in Exchange()
- """
- # check direction
- if direction not in {Order.SELL, Order.BUY}:
- raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
- self.stock_id = stock_id
- # amount of generated orders
- self.amount = amount
- # amount of successfully completed orders
- self.deal_amount = 0
- self.trade_date = trade_date
- self.direction = direction
- self.factor = factor
diff --git a/qlib/contrib/backtest/position.py b/qlib/contrib/backtest/position.py
deleted file mode 100644
index 8a4e137ca3..0000000000
--- a/qlib/contrib/backtest/position.py
+++ /dev/null
@@ -1,213 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import copy
-import pathlib
-import pandas as pd
-import numpy as np
-from .order import Order
-
-"""
-Position module
-"""
-
-"""
-current state of position
- a typical example is: {
- <stock_id>: {
- 'count': <holding period count>,
- 'amount': <stock amount>,
- 'price': <stock price>,
- 'weight': <value weight of the stock in total assets>,
- },
-}
-
-"""
-
-
-class Position:
- """Position"""
-
- def __init__(self, cash=0, position_dict={}, today_account_value=0):
- # NOTE: The position dict must be copied!!!
- # Otherwise the initial position_dict would be shared and modified unexpectedly.
- self.init_cash = cash
- self.position = position_dict.copy()
- self.position["cash"] = cash
- self.position["today_account_value"] = today_account_value
-
- def init_stock(self, stock_id, amount, price=None):
- self.position[stock_id] = {}
- self.position[stock_id]["count"] = 0 # update count in the end of this date
- self.position[stock_id]["amount"] = amount
- self.position[stock_id]["price"] = price
- self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
-
- def buy_stock(self, stock_id, trade_val, cost, trade_price):
- trade_amount = trade_val / trade_price
- if stock_id not in self.position:
- self.init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
- else:
- # exist, add amount
- self.position[stock_id]["amount"] += trade_amount
-
- self.position["cash"] -= trade_val + cost
-
- def sell_stock(self, stock_id, trade_val, cost, trade_price):
- trade_amount = trade_val / trade_price
- if stock_id not in self.position:
- raise KeyError("{} not in current position".format(stock_id))
- else:
- # decrease the amount of stock
- self.position[stock_id]["amount"] -= trade_amount
- # check if to delete
- if self.position[stock_id]["amount"] < -1e-5:
- raise ValueError(
- "only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
- )
- elif abs(self.position[stock_id]["amount"]) <= 1e-5:
- self.del_stock(stock_id)
-
- self.position["cash"] += trade_val - cost
-
- def del_stock(self, stock_id):
- del self.position[stock_id]
-
- def update_order(self, order, trade_val, cost, trade_price):
- # handle order, order is a order class, defined in exchange.py
- if order.direction == Order.BUY:
- # BUY
- self.buy_stock(order.stock_id, trade_val, cost, trade_price)
- elif order.direction == Order.SELL:
- # SELL
- self.sell_stock(order.stock_id, trade_val, cost, trade_price)
- else:
- raise NotImplementedError("do not suppotr order direction {}".format(order.direction))
-
- def update_stock_price(self, stock_id, price):
- self.position[stock_id]["price"] = price
-
- def update_stock_count(self, stock_id, count):
- self.position[stock_id]["count"] = count
-
- def update_stock_weight(self, stock_id, weight):
- self.position[stock_id]["weight"] = weight
-
- def update_cash(self, cash):
- self.position["cash"] = cash
-
- def calculate_stock_value(self):
- stock_list = self.get_stock_list()
- value = 0
- for stock_id in stock_list:
- value += self.position[stock_id]["amount"] * self.position[stock_id]["price"]
- return value
-
- def calculate_value(self):
- value = self.calculate_stock_value()
- value += self.position["cash"]
- return value
-
- def get_stock_list(self):
- stock_list = list(set(self.position.keys()) - {"cash", "today_account_value"})
- return stock_list
-
- def get_stock_price(self, code):
- return self.position[code]["price"]
-
- def get_stock_amount(self, code):
- return self.position[code]["amount"]
-
- def get_stock_count(self, code):
- return self.position[code]["count"]
-
- def get_stock_weight(self, code):
- return self.position[code]["weight"]
-
- def get_cash(self):
- return self.position["cash"]
-
- def get_stock_amount_dict(self):
- """generate stock amount dict {stock_id : amount of stock}"""
- d = {}
- stock_list = self.get_stock_list()
- for stock_code in stock_list:
- d[stock_code] = self.get_stock_amount(code=stock_code)
- return d
-
- def get_stock_weight_dict(self, only_stock=False):
- """get_stock_weight_dict
- generate stock weight dict {stock_id : value weight of stock in the position}
- it is meaningful in the beginning or the end of each trade date
-
- :param only_stock: If only_stock=True, the weight of each stock relative to the total stock value will be returned
- If only_stock=False, the weight of each stock relative to total assets (stock + cash) will be returned
- """
- if only_stock:
- position_value = self.calculate_stock_value()
- else:
- position_value = self.calculate_value()
- d = {}
- stock_list = self.get_stock_list()
- for stock_code in stock_list:
- d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value
- return d
-
- def add_count_all(self):
- stock_list = self.get_stock_list()
- for code in stock_list:
- self.position[code]["count"] += 1
-
- def update_weight_all(self):
- weight_dict = self.get_stock_weight_dict()
- for stock_code, weight in weight_dict.items():
- self.update_stock_weight(stock_code, weight)
-
- def save_position(self, path, last_trade_date):
- path = pathlib.Path(path)
- p = copy.deepcopy(self.position)
- cash = pd.Series(dtype=float)
- cash["init_cash"] = self.init_cash
- cash["cash"] = p["cash"]
- cash["today_account_value"] = p["today_account_value"]
- cash["last_trade_date"] = str(last_trade_date.date()) if last_trade_date else None
- del p["cash"]
- del p["today_account_value"]
- positions = pd.DataFrame.from_dict(p, orient="index")
- with pd.ExcelWriter(path) as writer:
- positions.to_excel(writer, sheet_name="position")
- cash.to_excel(writer, sheet_name="info")
-
- def load_position(self, path):
- """load position information from a file
- should have format below
- sheet "position"
- columns: ['stock', 'count', 'amount', 'price', 'weight']
- 'count': <holding period count>
- 'amount': <stock amount>
- 'price': <stock price>
- 'weight': <value weight of the stock>
-
- sheet "info"
- index: ['init_cash', 'cash', 'today_account_value', 'last_trade_date']
- 'init_cash': <initial cash>
- 'cash': <current cash>
- 'today_account_value': <account value on the last trade date>
- """
- path = pathlib.Path(path)
- positions = pd.read_excel(open(path, "rb"), sheet_name="position", index_col=0)
- cash_record = pd.read_excel(open(path, "rb"), sheet_name="info", index_col=0)
- positions = positions.to_dict(orient="index")
- init_cash = cash_record.loc["init_cash"].values[0]
- cash = cash_record.loc["cash"].values[0]
- today_account_value = cash_record.loc["today_account_value"].values[0]
- last_trade_date = cash_record.loc["last_trade_date"].values[0]
-
- # assign values
- self.position = {}
- self.init_cash = init_cash
- self.position = positions
- self.position["cash"] = cash
- self.position["today_account_value"] = today_account_value
-
- return None if pd.isna(last_trade_date) else pd.Timestamp(last_trade_date)
diff --git a/qlib/contrib/backtest/report.py b/qlib/contrib/backtest/report.py
deleted file mode 100644
index beb9759d0d..0000000000
--- a/qlib/contrib/backtest/report.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-
-from collections import OrderedDict
-import pandas as pd
-import pathlib
-
-
-class Report:
- # daily report of the account
- # contains the following: returns, costs, turnovers, accounts, cash, bench, value
- # update report
- def __init__(self):
- self.init_vars()
-
- def init_vars(self):
- self.accounts = OrderedDict()  # account position value for each trade date
- self.returns = OrderedDict() # daily return rate for each trade date
- self.turnovers = OrderedDict() # turnover for each trade date
- self.costs = OrderedDict() # trade cost for each trade date
- self.values = OrderedDict() # value for each trade date
- self.cashes = OrderedDict()
- self.latest_report_date = None  # pd.Timestamp
-
- def is_empty(self):
- return len(self.accounts) == 0
-
- def get_latest_date(self):
- return self.latest_report_date
-
- def get_latest_account_value(self):
- return self.accounts[self.latest_report_date]
-
- def update_report_record(
- self,
- trade_date=None,
- account_value=None,
- cash=None,
- return_rate=None,
- turnover_rate=None,
- cost_rate=None,
- stock_value=None,
- ):
- # check data
- if None in [
- trade_date,
- account_value,
- cash,
- return_rate,
- turnover_rate,
- cost_rate,
- stock_value,
- ]:
- raise ValueError(
- "None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
- )
- # update report data
- self.accounts[trade_date] = account_value
- self.returns[trade_date] = return_rate
- self.turnovers[trade_date] = turnover_rate
- self.costs[trade_date] = cost_rate
- self.values[trade_date] = stock_value
- self.cashes[trade_date] = cash
- # update latest_report_date
- self.latest_report_date = trade_date
- # finish daily report update
-
- def generate_report_dataframe(self):
- report = pd.DataFrame()
- report["account"] = pd.Series(self.accounts)
- report["return"] = pd.Series(self.returns)
- report["turnover"] = pd.Series(self.turnovers)
- report["cost"] = pd.Series(self.costs)
- report["value"] = pd.Series(self.values)
- report["cash"] = pd.Series(self.cashes)
- report.index.name = "date"
- return report
-
- def save_report(self, path):
- r = self.generate_report_dataframe()
- r.to_csv(path)
-
- def load_report(self, path):
- """load report from a file
- should have format like
- columns = ['account', 'return', 'turnover', 'cost', 'value', 'cash']
- :param
- path: str/ pathlib.Path()
- """
- path = pathlib.Path(path)
- r = pd.read_csv(open(path, "rb"), index_col=0)
- r.index = pd.DatetimeIndex(r.index)
-
- index = r.index
- self.init_vars()
- for date in index:
- self.update_report_record(
- trade_date=date,
- account_value=r.loc[date]["account"],
- cash=r.loc[date]["cash"],
- return_rate=r.loc[date]["return"],
- turnover_rate=r.loc[date]["turnover"],
- cost_rate=r.loc[date]["cost"],
- stock_value=r.loc[date]["value"],
- )
diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py
index 4aa5b55156..31d24d8f50 100644
--- a/qlib/contrib/evaluate.py
+++ b/qlib/contrib/evaluate.py
@@ -3,13 +3,15 @@
from __future__ import division
from __future__ import print_function
+from logging import warn
import numpy as np
import pandas as pd
import warnings
from ..log import get_module_logger
-from .backtest import get_exchange, backtest as backtest_func
-from .backtest.backtest import get_date_range
+from ..backtest import get_exchange, backtest as backtest_func
+from ..utils import get_date_range
+from ..utils.resam import Freq
from ..data import D
from ..config import C
@@ -19,7 +21,7 @@
logger = get_module_logger("Evaluate")
-def risk_analysis(r, N=252):
+def risk_analysis(r, N: int = None, freq: str = "day"):
"""Risk Analysis
Parameters
@@ -27,8 +29,29 @@ def risk_analysis(r, N=252):
r : pandas.Series
return series (daily by default).
N: int
- scaler for annualizing information_ratio (day: 250, week: 50, month: 12).
+ scaler for annualizing information_ratio (day: 252, week: 50, month: 12), at least one of `N` and `freq` should exist
+ freq: str
+ analysis frequency used for calculating the scaler, at least one of `N` and `freq` should exist
"""
+
+ def cal_risk_analysis_scaler(freq):
+ _count, _freq = Freq.parse(freq)
+ # len(D.calendar(start_time='2010-01-01', end_time='2019-12-31', freq='day')) = 2384, i.e. ~238 trading days per year
+ _freq_scaler = {
+ Freq.NORM_FREQ_MINUTE: 240 * 238,
+ Freq.NORM_FREQ_DAY: 238,
+ Freq.NORM_FREQ_WEEK: 50,
+ Freq.NORM_FREQ_MONTH: 12,
+ }
+ return _freq_scaler[_freq] / _count
+
+ if N is None and freq is None:
+ raise ValueError("at least one of `N` and `freq` should exist")
+ if N is not None and freq is not None:
+ warnings.warn("risk_analysis freq will be ignored")
+ if N is None:
+ N = cal_risk_analysis_scaler(freq)
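+ # Worked examples (illustrative, following the `_freq_scaler` table above,
+ # which assumes ~238 trading days per year): freq="week" gives N = 50;
+ # freq="5min" gives N = (240 * 238) / 5, since Freq.parse splits "5min"
+ # into a count of 5 and the base frequency "minute".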
+
mean = r.mean()
std = r.std(ddof=1)
annualized_return = mean * N
@@ -41,7 +64,55 @@ def risk_analysis(r, N=252):
"information_ratio": information_ratio,
"max_drawdown": max_drawdown,
}
- res = pd.Series(data, index=data.keys()).to_frame("risk")
+ res = pd.Series(data).to_frame("risk")
+ return res
+
+
+def indicator_analysis(df, method="mean"):
+ """analyze statistical time-series indicators of trading
+
+ Parameters
+ ----------
+ df : pandas.DataFrame
+ columns: like ['pa', 'pos', 'ffr', 'deal_amount', 'value'].
+ Necessary fields:
+ - 'pa' is the price advantage in trade indicators
+ - 'pos' is the positive rate in trade indicators
+ - 'ffr' is the fulfill rate in trade indicators
+ Optional fields:
+ - 'deal_amount' is the total deal amount, only necessary when method is 'amount_weighted'
+ - 'value' is the total trade value, only necessary when method is 'value_weighted'
+
+ index: Index(datetime)
+ method : str, optional
+ statistics method of pa/ffr, by default "mean"
+ - if method is 'mean', compute the mean statistical value of each trade indicator
+ - if method is 'amount_weighted', compute the deal_amount-weighted mean statistical value of each trade indicator
+ - if method is 'value_weighted', compute the value-weighted mean statistical value of each trade indicator
+ Note: statistics method of pos is always "mean"
+
+ Returns
+ -------
+ pd.DataFrame
+ statistical values of each trade indicator
+ """
+ weights_dict = {
+ "mean": df["count"],
+ "amount_weighted": df["deal_amount"].abs(),
+ "value_weighted": df["value"].abs(),
+ }
+ if method not in weights_dict:
+ raise ValueError(f"indicator_analysis method {method} is not supported!")
+
+ # statistic pa/ffr indicator
+ indicators_df = df[["ffr", "pa"]]
+ weights = weights_dict.get(method)
+ res = indicators_df.mul(weights, axis=0).sum() / weights.sum()
+
+ # statistic pos
+ weights = weights_dict.get("mean")
+ res.loc["pos"] = df["pos"].mul(weights).sum() / weights.sum()
+ res = res.to_frame("value")
return res
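+
+
+ # Minimal usage sketch (hypothetical numbers; note that, as implemented above,
+ # the `count`, `deal_amount` and `value` columns must all be present because
+ # the weights dict is built eagerly):
+ #
+ #   df = pd.DataFrame(
+ #       {"pa": [0.01, 0.02], "pos": [0.6, 0.4], "ffr": [0.9, 1.0],
+ #        "count": [10, 20], "deal_amount": [1e4, 2e4], "value": [1e6, 2e6]},
+ #       index=pd.to_datetime(["2020-01-02", "2020-01-03"]),
+ #   )
+ #   indicator_analysis(df)  # -> a DataFrame with a "value" column indexed by ["ffr", "pa", "pos"]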
@@ -119,9 +190,7 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k
whether to print log.
"""
- warnings.warn(
- "this function is deprecated, please use backtest function in qlib.contrib.backtest", DeprecationWarning
- )
+ warnings.warn("this function is deprecated, please use backtest function in qlib.backtest", DeprecationWarning)
report_dict = backtest_func(
pred=pred, account=account, shift=shift, benchmark=benchmark, verbose=verbose, return_order=False, **kwargs
)
diff --git a/qlib/contrib/online/executor.py b/qlib/contrib/online/executor.py
deleted file mode 100644
index 2bd0937a03..0000000000
--- a/qlib/contrib/online/executor.py
+++ /dev/null
@@ -1,291 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-
-import re
-import json
-import copy
-import pathlib
-import pandas as pd
-from ...data import D
-from ...utils import get_date_in_file_name
-from ...utils import get_pre_trading_date
-from ..backtest.order import Order
-
-
-class BaseExecutor:
- """
- Base class of executors. Following the strategy framework, a concrete
- executor should subclass it, e.g.
-
- class Executor(BaseExecutor):
- """
-
- def execute(self, trade_account, order_list, trade_date):
- """
- return the executed result (trade_info) after trading at trade_date.
- NOTICE: trade_account will not be modified after executing.
- Parameter
- ---------
- trade_account : Account()
- order_list : list
- [Order()]
- trade_date : pd.Timestamp
- Return
- ---------
- trade_info : list
- [Order(), float, float, float]
- """
- raise NotImplementedError("get_execute_result for this model is not implemented.")
-
- def save_executed_file_from_trade_info(self, trade_info, user_path, trade_date):
- """
- Save the trade_info to a .csv transaction file on disk.
- The columns of the result file are
- ['date', 'stock_id', 'direction', 'trade_val', 'trade_cost', 'trade_price', 'factor']
- Parameter
- ---------
- trade_info : list of [Order(), float, float, float]
- (order, trade_val, trade_cost, trade_price), trade_info without factor
- user_path: str / pathlib.Path()
- the sub folder to save user data
-
- transaction_path : string / pathlib.Path()
- """
- YYYY, MM, DD = str(trade_date.date()).split("-")
- folder_path = pathlib.Path(user_path) / "trade" / YYYY / MM
- if not folder_path.exists():
- folder_path.mkdir(parents=True)
- transaction_path = folder_path / "transaction_{}.csv".format(str(trade_date.date()))
- columns = [
- "date",
- "stock_id",
- "direction",
- "amount",
- "trade_val",
- "trade_cost",
- "trade_price",
- "factor",
- ]
- data = []
- for [order, trade_val, trade_cost, trade_price] in trade_info:
- data.append(
- [
- trade_date,
- order.stock_id,
- order.direction,
- order.amount,
- trade_val,
- trade_cost,
- trade_price,
- order.factor,
- ]
- )
- df = pd.DataFrame(data, columns=columns)
- df.to_csv(transaction_path, index=False)
-
- def load_trade_info_from_executed_file(self, user_path, trade_date):
- YYYY, MM, DD = str(trade_date.date()).split("-")
- file_path = pathlib.Path(user_path) / "trade" / YYYY / MM / "transaction_{}.csv".format(str(trade_date.date()))
- if not file_path.exists():
- raise ValueError("File {} not exists!".format(file_path))
-
- filedate = get_date_in_file_name(file_path)
- transaction = pd.read_csv(file_path)
- trade_info = []
- for i in range(len(transaction)):
- date = transaction.loc[i]["date"]
- if not date == filedate:
- continue
- # raise ValueError("date in transaction file {} not equal to it's file date{}".format(date, filedate))
- order = Order(
- stock_id=transaction.loc[i]["stock_id"],
- amount=transaction.loc[i]["amount"],
- trade_date=transaction.loc[i]["date"],
- direction=transaction.loc[i]["direction"],
- factor=transaction.loc[i]["factor"],
- )
- trade_val = transaction.loc[i]["trade_val"]
- trade_cost = transaction.loc[i]["trade_cost"]
- trade_price = transaction.loc[i]["trade_price"]
- trade_info.append([order, trade_val, trade_cost, trade_price])
- return trade_info
-
-
-class SimulatorExecutor(BaseExecutor):
- def __init__(self, trade_exchange, verbose=False):
- self.trade_exchange = trade_exchange
- self.verbose = verbose
- self.order_list = []
-
- def execute(self, trade_account, order_list, trade_date):
- """
- execute the order list, doing the trading with the exchange at the given date.
- Will not modify the trade_account.
- Parameter
- trade_account : Account()
- order_list : list
- list of orders
- trade_date : pd.Timestamp
- :return:
- trade_info : list of [Order(), float, float, float]
- (order, trade_val, trade_cost, trade_price), trade_info without factor
- """
- account = copy.deepcopy(trade_account)
- trade_info = []
-
- for order in order_list:
- # check holding thresh is done in strategy
- # if order.direction==0: # sell order
- # # checking holding thresh limit for sell order
- # if trade_account.current.get_stock_count(order.stock_id) < thresh:
- # # can not sell this code
- # continue
- # is order executable
- # check order
- if self.trade_exchange.check_order(order) is True:
- # execute the order
- trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(order, trade_account=account)
- trade_info.append([order, trade_val, trade_cost, trade_price])
- if self.verbose:
- if order.direction == Order.SELL: # sell
- print(
- "[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, value {:.2f}.".format(
- trade_date,
- order.stock_id,
- trade_price,
- order.deal_amount,
- trade_val,
- )
- )
- else:
- print(
- "[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, value {:.2f}.".format(
- trade_date,
- order.stock_id,
- trade_price,
- order.deal_amount,
- trade_val,
- )
- )
-
- else:
- if self.verbose:
- print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_date, order.stock_id))
- # do nothing
- pass
- return trade_info
-
-
-def save_score_series(score_series, user_path, trade_date):
- """Save the score_series into a .csv file.
- The columns of the saved file are
- [stock_id, score]
-
- Parameter
- ---------
- score_series: pd.Series
- the scores to save, indexed by stock_id
- trade_date: pd.Timestamp
- the date to save the scores for
- user_path: str / pathlib.Path()
- the sub folder to save user data
- """
- user_path = pathlib.Path(user_path)
- YYYY, MM, DD = str(trade_date.date()).split("-")
- folder_path = user_path / "score" / YYYY / MM
- if not folder_path.exists():
- folder_path.mkdir(parents=True)
- file_path = folder_path / "score_{}.csv".format(str(trade_date.date()))
- score_series.to_csv(file_path)
-
-
-def load_score_series(user_path, trade_date):
- """Save the score_series into a .csv file.
- The columns of saved file is
- [stock_id, score]
-
- Parameter
- ---------
- order_list: [Order()]
- list of Order()
- date: pd.Timestamp
- the date to save the order list
- user_path: str / pathlib.Path()
- the sub folder to save user data
- """
- user_path = pathlib.Path(user_path)
- YYYY, MM, DD = str(trade_date.date()).split("-")
- folder_path = user_path / "score" / YYYY / MM
- if not folder_path.exists():
- folder_path.mkdir(parents=True)
- file_path = folder_path / "score_{}.csv".format(str(trade_date.date()))
- score_series = pd.read_csv(file_path, index_col=0, header=None, names=["instrument", "score"])
- return score_series
-
-
-def save_order_list(order_list, user_path, trade_date):
- """
- Save the order list into a json file.
- Will calculate the real amount in order according to factors at date.
-
- The format in json file like
- {"sell": {"stock_id": amount, ...}
- ,"buy": {"stock_id": amount, ...}}
-
- :param
- order_list: [Order()]
- list of Order()
- date: pd.Timestamp
- the date to save the order list
- user_path: str / pathlib.Path()
- the sub folder to save user data
- """
- user_path = pathlib.Path(user_path)
- YYYY, MM, DD = str(trade_date.date()).split("-")
- folder_path = user_path / "trade" / YYYY / MM
- if not folder_path.exists():
- folder_path.mkdir(parents=True)
- sell = {}
- buy = {}
- for order in order_list:
- if order.direction == 0: # sell
- sell[order.stock_id] = [order.amount, order.factor]
- else:
- buy[order.stock_id] = [order.amount, order.factor]
- order_dict = {"sell": sell, "buy": buy}
- file_path = folder_path / "orderlist_{}.json".format(str(trade_date.date()))
- with file_path.open("w") as fp:
- json.dump(order_dict, fp)
-
-
-def load_order_list(user_path, trade_date):
- user_path = pathlib.Path(user_path)
- YYYY, MM, DD = str(trade_date.date()).split("-")
- path = user_path / "trade" / YYYY / MM / "orderlist_{}.json".format(str(trade_date.date()))
- if not path.exists():
- raise ValueError("File {} not exists!".format(path))
- # get orders
- with path.open("r") as fp:
- order_dict = json.load(fp)
- order_list = []
- for stock_id in order_dict["sell"]:
- amount, factor = order_dict["sell"][stock_id]
- order = Order(
- stock_id=stock_id,
- amount=amount,
- trade_date=pd.Timestamp(trade_date),
- direction=Order.SELL,
- factor=factor,
- )
- order_list.append(order)
- for stock_id in order_dict["buy"]:
- amount, factor = order_dict["buy"][stock_id]
- order = Order(
- stock_id=stock_id,
- amount=amount,
- trade_date=pd.Timestamp(trade_date),
- direction=Order.BUY,
- factor=factor,
- )
- order_list.append(order)
- return order_list
diff --git a/qlib/contrib/online/operator.py b/qlib/contrib/online/operator.py
index d2307dad54..971dcda75b 100644
--- a/qlib/contrib/online/operator.py
+++ b/qlib/contrib/online/operator.py
@@ -118,9 +118,9 @@ def generate(self, date, path):
user.strategy.update(score_series, pred_date, trade_date)
# generate and save order list
- order_list = user.strategy.generate_order_list(
+ order_list = user.strategy.generate_trade_decision(
score_series=score_series,
- current=user.account.current,
+ current=user.account.current_position,
trade_exchange=trade_exchange,
trade_date=trade_date,
)
@@ -202,13 +202,13 @@ def update(self, date, path, type="SIM"):
score_series = load_score_series((pathlib.Path(path) / user_id), trade_date)
update_account(user.account, trade_info, trade_exchange, trade_date)
- report = user.account.report.generate_report_dataframe()
- self.logger.info(report)
+ portfolio_metrics = user.account.portfolio_metrics.generate_portfolio_metrics_dataframe()
+ self.logger.info(portfolio_metrics)
um.save_user_data(user_id)
self.logger.info("Update account state {} for {}".format(trade_date, user_id))
def simulate(self, id, config, exchange_config, start, end, path, bench="SH000905"):
- """Run the ( generate_order_list -> execute_order_list -> update_account) process everyday
+ """Run the ( generate_trade_decision -> execute_order_list -> update_account) process everyday
from start date to end date.
Parameters
@@ -256,9 +256,9 @@ def simulate(self, id, config, exchange_config, start, end, path, bench="SH00090
user.strategy.update(score_series, pred_date, trade_date)
# 3. generate and save order list
- order_list = user.strategy.generate_order_list(
+ order_list = user.strategy.generate_trade_decision(
score_series=score_series,
- current=user.account.current,
+ current=user.account.current_position,
trade_exchange=trade_exchange,
trade_date=trade_date,
)
@@ -273,8 +273,8 @@ def simulate(self, id, config, exchange_config, start, end, path, bench="SH00090
# 5. update account state
trade_info = executor.load_trade_info_from_executed_file(user_path=user_path, trade_date=trade_date)
update_account(user.account, trade_info, trade_exchange, trade_date)
- report = user.account.report.generate_report_dataframe()
- self.logger.info(report)
+ portfolio_metrics = user.account.portfolio_metrics.generate_portfolio_metrics_dataframe()
+ self.logger.info(portfolio_metrics)
um.save_user_data(id)
self.show(id, path, bench)
@@ -295,12 +295,12 @@ def show(self, id, path, bench="SH000905"):
if id not in um.users:
raise ValueError("Cannot find user ".format(id))
bench = D.features([bench], ["$change"]).loc[bench, "$change"]
- report = um.users[id].account.report.generate_report_dataframe()
- report["bench"] = bench
+ portfolio_metrics = um.users[id].account.portfolio_metrics.generate_portfolio_metrics_dataframe()
+ portfolio_metrics["bench"] = bench
analysis_result = {}
- r = (report["return"] - report["bench"]).dropna()
+ r = (portfolio_metrics["return"] - portfolio_metrics["bench"]).dropna()
analysis_result["excess_return_without_cost"] = risk_analysis(r)
- r = (report["return"] - report["bench"] - report["cost"]).dropna()
+ r = (portfolio_metrics["return"] - portfolio_metrics["bench"] - portfolio_metrics["cost"]).dropna()
analysis_result["excess_return_with_cost"] = risk_analysis(r)
print("Result:")
print("excess_return_without_cost:")
diff --git a/qlib/contrib/online/user.py b/qlib/contrib/online/user.py
index 9b33ec24cd..a7a8654d1a 100644
--- a/qlib/contrib/online/user.py
+++ b/qlib/contrib/online/user.py
@@ -59,16 +59,16 @@ def showReport(self, benchmark="SH000905"):
bench that to be compared, 'SH000905' for csi500
"""
bench = D.features([benchmark], ["$change"], disk_cache=True).loc[benchmark, "$change"]
- report = self.account.report.generate_report_dataframe()
- report["bench"] = bench
+ portfolio_metrics = self.account.portfolio_metrics.generate_portfolio_metrics_dataframe()
+ portfolio_metrics["bench"] = bench
analysis_result = {"pred": {}, "excess_return_without_cost": {}, "excess_return_with_cost": {}}
- r = (report["return"] - report["bench"]).dropna()
+ r = (portfolio_metrics["return"] - portfolio_metrics["bench"]).dropna()
analysis_result["excess_return_without_cost"][0] = risk_analysis(r)
- r = (report["return"] - report["bench"] - report["cost"]).dropna()
+ r = (portfolio_metrics["return"] - portfolio_metrics["bench"] - portfolio_metrics["cost"]).dropna()
analysis_result["excess_return_with_cost"][0] = risk_analysis(r)
self.logger.info("Result of porfolio:")
self.logger.info("excess_return_without_cost:")
self.logger.info(analysis_result["excess_return_without_cost"][0])
self.logger.info("excess_return_with_cost:")
self.logger.info(analysis_result["excess_return_with_cost"][0])
- return report
+ return portfolio_metrics
diff --git a/qlib/contrib/ops/__init__.py b/qlib/contrib/ops/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/qlib/contrib/ops/high_freq.py b/qlib/contrib/ops/high_freq.py
new file mode 100644
index 0000000000..3ce5c961fa
--- /dev/null
+++ b/qlib/contrib/ops/high_freq.py
@@ -0,0 +1,89 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from pathlib import Path
+import numpy as np
+import pandas as pd
+from datetime import datetime
+
+import qlib
+from qlib.data import D
+from qlib.data.cache import H
+from qlib.data.data import Cal
+from qlib.data.ops import ElemOperator
+from qlib.utils.time import time_to_day_index
+
+
+def get_calendar_day(freq="1min", future=False):
+ """
+ Load High-Freq Calendar Date Using Memcache.
+ !!!NOTE: Loading the calendar is quite slow, so load it before starting multiprocessing to make it faster.
+
+ Parameters
+ ----------
+ freq : str
+ frequency of read calendar file.
+ future : bool
+ whether including future trading day.
+
+ Returns
+ -------
+ _calendar:
+ array of date.
+ """
+ flag = f"{freq}_future_{future}_day"
+ if flag in H["c"]:
+ _calendar = H["c"][flag]
+ else:
+ _calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future))))
+ H["c"][flag] = _calendar
+ return _calendar
+
+
+class DayCumsum(ElemOperator):
+ """DayCumsum Operator during start time and end time.
+
+ Parameters
+ ----------
+ feature : Expression
+ feature instance
+ start : str
+ the start time of backtest in one day.
+ !!!NOTE: "9:30" means the time period of (9:30, 9:31) is in transaction.
+ end : str
+ the end time of backtest in one day.
+ !!!NOTE: "14:59" means the time period of (14:59, 15:00) is in transaction,
+ but (15:00, 15:01) is not.
+ So start="9:30" and end="14:59" means trading all day.
+
+ Returns
+ ----------
+ feature:
+ a series in which each value equals the cumulative sum between the start time and the end time;
+ outside that window, the value is zero.
+ """
+
+ def __init__(self, feature, start: str = "9:30", end: str = "14:59"):
+ self.feature = feature
+ self.start = datetime.strptime(start, "%H:%M")
+ self.end = datetime.strptime(end, "%H:%M")
+
+ self.morning_open = datetime.strptime("9:30", "%H:%M")
+ self.morning_close = datetime.strptime("11:30", "%H:%M")
+ self.noon_open = datetime.strptime("13:00", "%H:%M")
+ self.noon_close = datetime.strptime("15:00", "%H:%M")
+
+ self.start_id = time_to_day_index(self.start)
+ self.end_id = time_to_day_index(self.end)
+
+ def period_cusum(self, df):
+ df = df.copy()
+ assert len(df) == 240
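+ # 240 is the number of 1-minute bars in an A-share trading day
+ # (9:30-11:30 and 13:00-15:00, matching the session bounds set in __init__).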
+ df.iloc[0 : self.start_id] = 0
+ df = df.cumsum()
+ df.iloc[self.end_id + 1 : 240] = 0
+ return df
+
+ def _load_internal(self, instrument, start_index, end_index, freq):
+ _calendar = get_calendar_day(freq=freq)
+ series = self.feature.load(instrument, start_index, end_index, freq)
+ return series.groupby(_calendar[series.index]).transform(self.period_cusum)
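+
+
+ # Minimal usage sketch (assumes a 1-minute dataset and that this operator is
+ # registered as a custom op; the provider path is illustrative):
+ #
+ #   qlib.init(provider_uri="~/.qlib/qlib_data/cn_data_1min", custom_ops=[DayCumsum])
+ #   D.features(["SH600519"], ["DayCumsum($volume, '9:30', '11:29')"], freq="1min")
+ #
+ # Within (start, end) each bar holds the running cumulative sum; bars outside
+ # that window are zero.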
diff --git a/qlib/contrib/report/analysis_position/parse_position.py b/qlib/contrib/report/analysis_position/parse_position.py
index fe1d611370..1373d902f0 100644
--- a/qlib/contrib/report/analysis_position/parse_position.py
+++ b/qlib/contrib/report/analysis_position/parse_position.py
@@ -4,7 +4,7 @@
import pandas as pd
-from ...backtest.profit_attribution import get_stock_weight_df
+from ....backtest.profit_attribution import get_stock_weight_df
def parse_position(position: dict = None) -> pd.DataFrame:
@@ -41,7 +41,7 @@ def parse_position(position: dict = None) -> pd.DataFrame:
for _trading_date, _value in position.items():
# pd_date type: pd.Timestamp
_cash = _value.pop("cash")
- for _item in ["today_account_value"]:
+ for _item in ["now_account_value"]:
if _item in _value:
_value.pop(_item)
diff --git a/qlib/contrib/report/analysis_position/rank_label.py b/qlib/contrib/report/analysis_position/rank_label.py
index 77743b10c1..2927f12a29 100644
--- a/qlib/contrib/report/analysis_position/rank_label.py
+++ b/qlib/contrib/report/analysis_position/rank_label.py
@@ -97,7 +97,7 @@ def rank_label_graph(
qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
- :param position: position data; **qlib.contrib.backtest.backtest.backtest** result.
+ :param position: position data; **qlib.backtest.backtest** result.
:param label_data: **D.features** result; index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[label]**.
**The label T is the change from T to T+1**, it is recommended to use ``close``, example: `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])`.
diff --git a/qlib/contrib/report/graph.py b/qlib/contrib/report/graph.py
index 2d4f546e82..edb7e018d4 100644
--- a/qlib/contrib/report/graph.py
+++ b/qlib/contrib/report/graph.py
@@ -3,7 +3,6 @@
import math
import importlib
-from pathlib import Path
from typing import Iterable
import pandas as pd
@@ -14,8 +13,6 @@
from plotly.subplots import make_subplots
from plotly.figure_factory import create_distplot
-from ...utils import get_module_by_module_path
-
class BaseGraph:
""" """
@@ -138,7 +135,7 @@ def figure(self) -> go.Figure:
:return:
"""
_figure = go.Figure(data=self.data, layout=self._get_layout())
- # NOTE: using default 3.x theme
+ # NOTE: Use the default theme from plotly version 3.x, template=None
_figure["layout"].update(template=None)
return _figure
@@ -378,8 +375,9 @@ def _init_figure(self):
for k, v in self._sub_graph_layout.items():
self._figure["layout"][k].update(v)
- # NOTE: using default 3.x theme
- self._figure["layout"].update(self._layout, template=None)
+ # NOTE: Use the default theme from plotly version 3.x: template=None
+ self._figure["layout"].update(template=None)
+ self._figure["layout"].update(self._layout)
@property
def figure(self):
diff --git a/qlib/contrib/strategy/__init__.py b/qlib/contrib/strategy/__init__.py
index 6c2e4ceede..e308c1a058 100644
--- a/qlib/contrib/strategy/__init__.py
+++ b/qlib/contrib/strategy/__init__.py
@@ -2,8 +2,15 @@
# Licensed under the MIT License.
-from .strategy import (
+from .model_strategy import (
TopkDropoutStrategy,
- BaseStrategy,
WeightStrategyBase,
)
+
+from .rule_strategy import (
+ TWAPStrategy,
+ SBBStrategyBase,
+ SBBStrategyEMA,
+)
+
+from .cost_control import SoftTopkStrategy
diff --git a/qlib/contrib/strategy/cost_control.py b/qlib/contrib/strategy/cost_control.py
index dd90437b03..b45c03ae9c 100644
--- a/qlib/contrib/strategy/cost_control.py
+++ b/qlib/contrib/strategy/cost_control.py
@@ -1,13 +1,30 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+"""
+This strategy is not well maintained
+"""
-from .strategy import StrategyWrapper, WeightStrategyBase
+from .order_generator import OrderGenWInteract
+from .model_strategy import WeightStrategyBase
import copy
class SoftTopkStrategy(WeightStrategyBase):
- def __init__(self, topk, max_sold_weight=1.0, risk_degree=0.95, buy_method="first_fill"):
+ def __init__(
+ self,
+ model,
+ dataset,
+ topk,
+ order_generator_cls_or_obj=OrderGenWInteract,
+ max_sold_weight=1.0,
+ risk_degree=0.95,
+ buy_method="first_fill",
+ trade_exchange=None,
+ level_infra=None,
+ common_infra=None,
+ **kwargs,
+ ):
"""Parameter
topk : int
top-N stocks to buy
@@ -17,13 +34,15 @@ def __init__(self, topk, max_sold_weight=1.0, risk_degree=0.95, buy_method="firs
rank_fill: assign weight to the top-ranked stocks first (1/topk max)
average_fill: assign the weight evenly among the top-ranked stocks.
"""
- super().__init__()
+ super(SoftTopkStrategy, self).__init__(
+ model, dataset, order_generator_cls_or_obj, trade_exchange, level_infra, common_infra, **kwargs
+ )
self.topk = topk
self.max_sold_weight = max_sold_weight
self.risk_degree = risk_degree
self.buy_method = buy_method
- def get_risk_degree(self, date):
+ def get_risk_degree(self, trade_step=None):
"""get_risk_degree
Return the proportion of your total value you will use in investment.
A dynamic risk_degree will result in market timing
@@ -31,7 +50,7 @@ def get_risk_degree(self, date):
# It will use 95% of your total value by default
return self.risk_degree
- def generate_target_weight_position(self, score, current, trade_date):
+ def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):
"""Parameter:
score : pred score for this trade date, pd.Series, index is stock_id, contain 'score' column
current : current position, use Position() class
diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py
new file mode 100644
index 0000000000..1d22153a7a
--- /dev/null
+++ b/qlib/contrib/strategy/model_strategy.py
@@ -0,0 +1,320 @@
+import copy
+from qlib.backtest.position import Position
+import warnings
+import numpy as np
+import pandas as pd
+
+from ...utils.resam import resam_ts_data
+from ...strategy.base import ModelStrategy
+from ...backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO
+
+from .order_generator import OrderGenWInteract
+
+
+class TopkDropoutStrategy(ModelStrategy):
+ # TODO:
+ # 1. Supporting leverage the get_range_limit result from the decision
+ # 2. Supporting alter_outer_trade_decision
+ # 3. Supporting checking the availability of trade decision
+ def __init__(
+ self,
+ model,
+ dataset,
+ topk,
+ n_drop,
+ method_sell="bottom",
+ method_buy="top",
+ risk_degree=0.95,
+ hold_thresh=1,
+ only_tradable=False,
+ trade_exchange=None,
+ level_infra=None,
+ common_infra=None,
+ **kwargs,
+ ):
+ """
+ Parameters
+ -----------
+ topk : int
+ the number of stocks in the portfolio.
+ n_drop : int
+ number of stocks to be replaced in each trading date.
+ method_sell : str
+ method to sell the dropped stocks: random/bottom.
+ method_buy : str
+ method to buy the new stocks: random/top.
+ risk_degree : float
+ position percentage of total value.
+ hold_thresh : int
+ minimum holding days
+ before selling a stock, the strategy checks current.get_stock_count(order.stock_id) >= self.hold_thresh.
+ only_tradable : bool
+ whether the strategy only considers tradable stocks when buying and selling.
+ if only_tradable:
+ the strategy will only make buy/sell decisions among tradable stocks.
+ else:
+ the strategy will make decisions without checking the tradable state of the stocks.
+ trade_exchange : Exchange
+ exchange that provides market info, used to deal order and generate report
+ - If `trade_exchange` is None, self.trade_exchange will be set with common_infra
+ - It allows using different trade_exchanges in different executions.
+ - For example:
+ - In daily execution, both the daily and the minutely exchange are usable, but the daily exchange is recommended because it runs faster.
+ - In minutely execution, the daily exchange is not usable; only the minutely exchange is recommended.
+
+ """
+ super(TopkDropoutStrategy, self).__init__(
+ model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
+ )
+ self.topk = topk
+ self.n_drop = n_drop
+ self.method_sell = method_sell
+ self.method_buy = method_buy
+ self.risk_degree = risk_degree
+ self.hold_thresh = hold_thresh
+ self.only_tradable = only_tradable
+
+ def get_risk_degree(self, trade_step=None):
+ """get_risk_degree
+ Return the proportion of your total value you will use in investment.
+ A dynamic risk_degree will result in market timing.
+ """
+ # It will use 95% of your total value by default
+ return self.risk_degree
+
+ def generate_trade_decision(self, execute_result=None):
+ # get the number of trading steps finished; trade_step can be [0, 1, 2, ..., trade_len - 1]
+ trade_step = self.trade_calendar.get_trade_step()
+ trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
+ pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
+ pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
+ if pred_score is None:
+ return TradeDecisionWO([], self)
+ if self.only_tradable:
+ # If the strategy only considers tradable stocks when making decisions,
+ # it needs the following helpers to filter stocks
+ def get_first_n(l, n, reverse=False):
+ cur_n = 0
+ res = []
+ for si in reversed(l) if reverse else l:
+ if self.trade_exchange.is_stock_tradable(
+ stock_id=si, start_time=trade_start_time, end_time=trade_end_time
+ ):
+ res.append(si)
+ cur_n += 1
+ if cur_n >= n:
+ break
+ return res[::-1] if reverse else res
+
+ def get_last_n(l, n):
+ return get_first_n(l, n, reverse=True)
+
+ def filter_stock(l):
+ return [
+ si
+ for si in l
+ if self.trade_exchange.is_stock_tradable(
+ stock_id=si, start_time=trade_start_time, end_time=trade_end_time
+ )
+ ]
+
+ else:
+ # Otherwise, the strategy will make decisions without the tradable-state info
+ def get_first_n(l, n):
+ return list(l)[:n]
+
+ def get_last_n(l, n):
+ return list(l)[-n:]
+
+ def filter_stock(l):
+ return l
+
+ current_temp = copy.deepcopy(self.trade_position)
+ # generate order list for this adjust date
+ sell_order_list = []
+ buy_order_list = []
+ # load score
+ cash = current_temp.get_cash()
+ current_stock_list = current_temp.get_stock_list()
+ # last position (sorted by score)
+ last = pred_score.reindex(current_stock_list).sort_values(ascending=False).index
+ # The new stocks we want to buy today, **at most**
+ if self.method_buy == "top":
+ today = get_first_n(
+ pred_score[~pred_score.index.isin(last)].sort_values(ascending=False).index,
+ self.n_drop + self.topk - len(last),
+ )
+ elif self.method_buy == "random":
+ topk_candi = get_first_n(pred_score.sort_values(ascending=False).index, self.topk)
+ candi = list(filter(lambda x: x not in last, topk_candi))
+ n = self.n_drop + self.topk - len(last)
+ try:
+ today = np.random.choice(candi, n, replace=False)
+ except ValueError:
+ today = candi
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+ # combine (new stocks + last stocks); we will drop stocks from this list
+ # to avoid dropping a higher-score stock while buying a lower-score one.
+ comb = pred_score.reindex(last.union(pd.Index(today))).sort_values(ascending=False).index
+
+ # Get the stock list we really want to sell (after filtering out the sell-high-buy-low cases)
+ if self.method_sell == "bottom":
+ sell = last[last.isin(get_last_n(comb, self.n_drop))]
+ elif self.method_sell == "random":
+ candi = filter_stock(last)
+ try:
+ sell = pd.Index(np.random.choice(candi, self.n_drop, replace=False) if len(last) else [])
+ except ValueError:  # Not enough candidates
+ sell = candi
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
+ # Get the stock list we really want to buy
+ buy = today[: len(sell) + self.topk - len(last)]
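+ # Worked example (illustrative): with topk=3, n_drop=1,
+ # last=[A, B, C] (ranked by score) and today=[D], comb ranks {A, B, C, D};
+ # if D outranks C, then sell=[C] and buy=[D], keeping the portfolio size at topk.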
+ for code in current_stock_list:
+ if not self.trade_exchange.is_stock_tradable(
+ stock_id=code, start_time=trade_start_time, end_time=trade_end_time
+ ):
+ continue
+ if code in sell:
+ # check hold limit
+ time_per_step = self.trade_calendar.get_freq()
+ if current_temp.get_stock_count(code, bar=time_per_step) < self.hold_thresh:
+ continue
+ # sell order
+ sell_amount = current_temp.get_stock_amount(code=code)
+ factor = self.trade_exchange.get_factor(
+ stock_id=code, start_time=trade_start_time, end_time=trade_end_time
+ )
+ # sell_amount = self.trade_exchange.round_amount_by_trade_unit(sell_amount, factor)
+ sell_order = Order(
+ stock_id=code,
+ amount=sell_amount,
+ start_time=trade_start_time,
+ end_time=trade_end_time,
+ direction=Order.SELL, # 0 for sell, 1 for buy
+ )
+ # is order executable
+ if self.trade_exchange.check_order(sell_order):
+ sell_order_list.append(sell_order)
+ trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
+ sell_order, position=current_temp
+ )
+ # update cash
+ cash += trade_val - trade_cost
+ # buy new stock
+ # note that the current position (current_temp) has been changed
+ current_stock_list = current_temp.get_stock_list()
+ value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0
+
+ # open_cost should be considered in a real trading environment, but the backtest in evaluate.py does not
+ # consider it; since this demo aims to reproduce the same strategy as evaluate.py, the line below stays commented out
+ # value = value / (1 + self.trade_exchange.open_cost)  # set open_cost limit
+ for code in buy:
+ # skip the stock if it is not tradable (suspended or limit-hit)
+ if not self.trade_exchange.is_stock_tradable(
+ stock_id=code, start_time=trade_start_time, end_time=trade_end_time
+ ):
+ continue
+ # buy order
+ buy_price = self.trade_exchange.get_deal_price(
+ stock_id=code, start_time=trade_start_time, end_time=trade_end_time, direction=OrderDir.BUY
+ )
+ buy_amount = value / buy_price
+ factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time)
+ buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor)
+ buy_order = Order(
+ stock_id=code,
+ amount=buy_amount,
+ start_time=trade_start_time,
+ end_time=trade_end_time,
+ direction=Order.BUY, # 1 for buy
+ )
+ buy_order_list.append(buy_order)
+ return TradeDecisionWO(sell_order_list + buy_order_list, self)
+
+
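+ # A typical configuration sketch for TopkDropoutStrategy (illustrative values,
+ # in line with the common qlib benchmark settings):
+ #
+ #   strategy = TopkDropoutStrategy(model=model, dataset=dataset, topk=50, n_drop=5)
+ #
+ # Each trading step the strategy holds `topk` stocks and replaces up to
+ # `n_drop` low-score holdings with higher-score candidates.
+
+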
+class WeightStrategyBase(ModelStrategy):
+ # TODO:
+ # 1. Supporting leverage the get_range_limit result from the decision
+ # 2. Supporting alter_outer_trade_decision
+ # 3. Supporting checking the availability of trade decision
+ def __init__(
+ self,
+ model,
+ dataset,
+ order_generator_cls_or_obj=OrderGenWInteract,
+ trade_exchange=None,
+ level_infra=None,
+ common_infra=None,
+ **kwargs,
+ ):
+ """
+ trade_exchange : Exchange
+ exchange that provides market info, used to deal order and generate report
+ - If `trade_exchange` is None, self.trade_exchange will be set with common_infra
+ - It allows using different trade_exchanges in different executions.
+ - For example:
+ - In daily execution, both the daily and the minutely exchange are usable, but the daily exchange is recommended because it runs faster.
+ - In minutely execution, the daily exchange is not usable; only the minutely exchange is recommended.
+ """
+ super(WeightStrategyBase, self).__init__(
+ model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
+ )
+ if isinstance(order_generator_cls_or_obj, type):
+ self.order_generator = order_generator_cls_or_obj()
+ else:
+ self.order_generator = order_generator_cls_or_obj
+
+ def get_risk_degree(self, trade_step=None):
+ """get_risk_degree
+ Return the proportion of your total value you will use in investment.
+ A dynamic risk_degree will result in market timing.
+ """
+ # It will use 95% of your total value by default
+ return 0.95
+
+ def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):
+ """
+ Generate the target position from the score for this date and the current position. Cash is not considered in the position.
+ Parameters
+ -----------
+ score : pd.Series
+ pred score for this trade date, index is stock_id, contain 'score' column.
+ current : Position()
+ current position.
+ trade_start_time : pd.Timestamp
+ start time of the trade step.
+ trade_end_time : pd.Timestamp
+ end time of the trade step.
+ """
+ raise NotImplementedError()
+
+ def generate_trade_decision(self, execute_result=None):
+ # Call generate_target_weight_position() and then
+ # generate_order_list_from_target_weight_position() to generate the order list.
+
+ # get the number of trading steps finished; trade_step can be [0, 1, 2, ..., trade_len - 1]
+ trade_step = self.trade_calendar.get_trade_step()
+ trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
+ pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
+ pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
+ if pred_score is None:
+ return TradeDecisionWO([], self)
+ current_temp = copy.deepcopy(self.trade_position)
+ assert isinstance(current_temp, Position) # Avoid InfPosition
+
+ target_weight_position = self.generate_target_weight_position(
+ score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time
+ )
+ order_list = self.order_generator.generate_order_list_from_target_weight_position(
+ current=current_temp,
+ trade_exchange=self.trade_exchange,
+ risk_degree=self.get_risk_degree(trade_step),
+ target_weight_position=target_weight_position,
+ pred_start_time=pred_start_time,
+ pred_end_time=pred_end_time,
+ trade_start_time=trade_start_time,
+ trade_end_time=trade_end_time,
+ )
+ return TradeDecisionWO(order_list, self)
diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py
index 494981ecc0..eff938dd77 100644
--- a/qlib/contrib/strategy/order_generator.py
+++ b/qlib/contrib/strategy/order_generator.py
@@ -4,8 +4,10 @@
"""
This order generator is for strategies based on WeightStrategyBase
"""
-from ..backtest.position import Position
-from ..backtest.exchange import Exchange
+from ...backtest.position import Position
+from ...backtest.exchange import Exchange
+from ...backtest.decision import BaseTradeDecision, TradeDecisionWO
+
import pandas as pd
import copy
@@ -17,8 +19,10 @@ def generate_order_list_from_target_weight_position(
trade_exchange: Exchange,
target_weight_position: dict,
risk_degree: float,
- pred_date: pd.Timestamp,
- trade_date: pd.Timestamp,
+ pred_start_time: pd.Timestamp,
+ pred_end_time: pd.Timestamp,
+ trade_start_time: pd.Timestamp,
+ trade_end_time: pd.Timestamp,
) -> list:
"""generate_order_list_from_target_weight_position
@@ -49,8 +53,10 @@ def generate_order_list_from_target_weight_position(
trade_exchange: Exchange,
target_weight_position: dict,
risk_degree: float,
- pred_date: pd.Timestamp,
- trade_date: pd.Timestamp,
+ pred_start_time: pd.Timestamp,
+ pred_end_time: pd.Timestamp,
+ trade_start_time: pd.Timestamp,
+ trade_end_time: pd.Timestamp,
) -> list:
"""generate_order_list_from_target_weight_position
@@ -77,10 +83,16 @@ def generate_order_list_from_target_weight_position(
# calculate current_tradable_value
current_amount_dict = current.get_stock_amount_dict()
current_total_value = trade_exchange.calculate_amount_position_value(
- amount_dict=current_amount_dict, trade_date=trade_date, only_tradable=False
+ amount_dict=current_amount_dict,
+ trade_start_time=trade_start_time,
+ trade_end_time=trade_end_time,
+ only_tradable=False,
)
current_tradable_value = trade_exchange.calculate_amount_position_value(
- amount_dict=current_amount_dict, trade_date=trade_date, only_tradable=True
+ amount_dict=current_amount_dict,
+ trade_start_time=trade_start_time,
+ trade_end_time=trade_end_time,
+ only_tradable=True,
)
# add cash
current_tradable_value += current.get_cash()
@@ -93,7 +105,9 @@ def generate_order_list_from_target_weight_position(
# value. Then just sell all the stocks
target_amount_dict = copy.deepcopy(current_amount_dict.copy())
for stock_id in list(target_amount_dict.keys()):
- if trade_exchange.is_stock_tradable(stock_id, trade_date):
+ if trade_exchange.is_stock_tradable(
+ stock_id, start_time=trade_start_time, end_time=trade_end_time
+ ):
del target_amount_dict[stock_id]
else:
# consider cost rate
@@ -104,14 +118,16 @@ def generate_order_list_from_target_weight_position(
target_amount_dict = trade_exchange.generate_amount_position_from_weight_position(
weight_position=target_weight_position,
cash=current_tradable_value,
- trade_date=trade_date,
+ trade_start_time=trade_start_time,
+ trade_end_time=trade_end_time,
)
order_list = trade_exchange.generate_order_for_target_amount_position(
target_position=target_amount_dict,
current_position=current_amount_dict,
- trade_date=trade_date,
+ trade_start_time=trade_start_time,
+ trade_end_time=trade_end_time,
)
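+ # NOTE: return a plain order list; the calling strategy (e.g.
+ # WeightStrategyBase.generate_trade_decision) wraps it into a TradeDecisionWO.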
return order_list
class OrderGenWOInteract(OrderGenerator):
@@ -123,8 +139,10 @@ def generate_order_list_from_target_weight_position(
trade_exchange: Exchange,
target_weight_position: dict,
risk_degree: float,
- pred_date: pd.Timestamp,
- trade_date: pd.Timestamp,
+ pred_start_time: pd.Timestamp,
+ pred_end_time: pd.Timestamp,
+ trade_start_time: pd.Timestamp,
+ trade_end_time: pd.Timestamp,
) -> list:
"""generate_order_list_from_target_weight_position
@@ -153,9 +171,13 @@ def generate_order_list_from_target_weight_position(
amount_dict = {}
for stock_id in target_weight_position:
# Current rule will ignore the stock that not hold and cannot be traded at predict date
- if trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=pred_date):
+ if trade_exchange.is_stock_tradable(
+ stock_id=stock_id, start_time=trade_start_time, end_time=trade_end_time
+ ):
amount_dict[stock_id] = (
- risk_total_value * target_weight_position[stock_id] / trade_exchange.get_close(stock_id, pred_date)
+ risk_total_value
+ * target_weight_position[stock_id]
+ / trade_exchange.get_close(stock_id, trade_start_time=pred_start_time, trade_end_time=pred_end_time)
)
elif stock_id in current_stock:
amount_dict[stock_id] = (
@@ -166,6 +188,7 @@ def generate_order_list_from_target_weight_position(
order_list = trade_exchange.generate_order_for_target_amount_position(
target_position=amount_dict,
current_position=current.get_stock_amount_dict(),
- trade_date=trade_date,
+ trade_start_time=trade_start_time,
+ trade_end_time=trade_end_time,
)
return order_list
diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py
new file mode 100644
index 0000000000..23fdd29913
--- /dev/null
+++ b/qlib/contrib/strategy/rule_strategy.py
@@ -0,0 +1,669 @@
+from pathlib import Path
+import warnings
+import numpy as np
+import pandas as pd
+from typing import IO, List, Tuple, Union
+from qlib.data.dataset.utils import convert_index_format
+
+from qlib.utils import lazy_sort_index
+
+from ...utils.resam import resam_ts_data, ts_data_last
+from ...data.data import D
+from ...strategy.base import BaseStrategy
+from ...backtest.decision import BaseTradeDecision, Order, TradeDecisionWO, TradeRange
+from ...backtest.exchange import Exchange, OrderHelper
+from ...backtest.utils import CommonInfrastructure, LevelInfrastructure
+from qlib.utils.file import get_io_object
+from qlib.backtest.utils import get_start_end_idx
+
+
+class TWAPStrategy(BaseStrategy):
+ """TWAP Strategy for trading
+
+ NOTE:
+ - This TWAP strategy rounds up (ceiling) when trading. This makes the TWAP strategy produce orders
+ earlier when the total number of trade units is less than the number of trading steps
+ """
+
+ def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
+ """
+ Parameters
+ ----------
+ outer_trade_decision : BaseTradeDecision, optional
+ """
+
+ super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
+ if outer_trade_decision is not None:
+ self.trade_amount_remain = {}
+ for order in outer_trade_decision.get_decision():
+ self.trade_amount_remain[order.stock_id] = order.amount
+
+ def generate_trade_decision(self, execute_result=None):
+        # NOTE: corner cases!!!
+        # - If using upper-bound rounding, don't sell the amount that should be sold in a later step
+        # - coordinating the amounts between steps is hard to handle within the same level. It
+        #   is easier to handle in the upper-level steps
+
+        # the outer trade decision is empty; give an empty decision
+ if len(self.outer_trade_decision.get_decision()) == 0:
+ return TradeDecisionWO(order_list=[], strategy=self)
+
+ # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
+ trade_step = self.trade_calendar.get_trade_step()
+ # get the total count of trading step
+ start_idx, end_idx = get_start_end_idx(self.trade_calendar, self.outer_trade_decision)
+ trade_len = end_idx - start_idx + 1
+
+ if trade_step < start_idx or trade_step > end_idx:
+ # It is not time to start trading or trading has ended.
+ return TradeDecisionWO(order_list=[], strategy=self)
+
+        rel_trade_step = trade_step - start_idx  # trade_step relative to start_idx (the number of steps that have already passed)
+
+ # update the order amount
+ if execute_result is not None:
+ for order, _, _, _ in execute_result:
+ self.trade_amount_remain[order.stock_id] -= order.deal_amount
+
+ trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
+ order_list = []
+ for order in self.outer_trade_decision.get_decision():
+            # Don't peek at future information, so we use check_stock_suspended instead of is_stock_tradable
+            # necessity of this check:
+            # - if a stock is suspended, its quote values are NaN. The following code would raise an error when
+            #   encountering a NaN factor
+ if self.trade_exchange.check_stock_suspended(
+ stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
+ ):
+ continue
+
+            # the expected cumulative trade amount after the current step
+            amount_expect = order.amount / trade_len * (rel_trade_step + 1)
+
+            # remaining amount
+            amount_remain = self.trade_amount_remain[order.stock_id]
+
+            # the amount that has already been traded
+            amount_finished = order.amount - amount_remain
+
+            # the expected amount to trade in the current step
+            amount_delta = amount_expect - amount_finished
+
+ _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(
+ stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
+ )
+
+            # round amount_delta by the trade unit and clip it by the remaining amount
+            # NOTE: after rounding, this could exceed the expected amount
+ if _amount_trade_unit is None:
+ # divide the order into equal parts, and trade one part
+ amount_delta_target = amount_delta
+ else:
+ amount_delta_target = min(
+ np.round(amount_delta / _amount_trade_unit) * _amount_trade_unit, amount_remain
+ )
+
+ # handle last step to make sure all positions have gone
+            # necessity: the last step may not be roundable to a whole unit (e.g. remainder < 0.5 unit)
+ if rel_trade_step == trade_len - 1:
+ amount_delta_target = amount_remain
+
+ if amount_delta_target > 1e-5:
+ _order = Order(
+ stock_id=order.stock_id,
+ amount=amount_delta_target,
+ start_time=trade_start_time,
+ end_time=trade_end_time,
+ direction=order.direction, # 1 for buy
+ )
+ order_list.append(_order)
+ return TradeDecisionWO(order_list=order_list, strategy=self)
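For reference, a minimal self-contained sketch of the scheduling rule above (pure Python with illustrative numbers; the real strategy additionally handles suspension checks and per-order time windows):

```python
import numpy as np

# Toy TWAP schedule: split `total_amount` over `trade_len` steps, round each
# step to a whole `trade_unit`, and flush whatever remains on the last step.
def twap_schedule(total_amount, trade_len, trade_unit):
    remain = total_amount
    plan = []
    for step in range(trade_len):
        expect = total_amount / trade_len * (step + 1)  # cumulative target
        finished = total_amount - remain
        delta = expect - finished                       # this step's target
        target = min(np.round(delta / trade_unit) * trade_unit, remain)
        if step == trade_len - 1:                       # last step: flush the remainder
            target = remain
        if target > 1e-5:
            plan.append(target)
            remain -= target
    return plan

print(twap_schedule(1050, 4, 100))  # [300.0, 200.0, 300.0, 250.0]
```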
+
+
+class SBBStrategyBase(BaseStrategy):
+ """
+ (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
+ """
+
+ TREND_MID = 0
+ TREND_SHORT = 1
+ TREND_LONG = 2
+
+ # TODO:
+    # 1. Support leveraging the get_range_limit result from the decision
+    # 2. Support alter_outer_trade_decision
+    # 3. Support checking the availability of trade decision
+
+ def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
+ """
+ Parameters
+ ----------
+ outer_trade_decision : BaseTradeDecision, optional
+ """
+ super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
+ if outer_trade_decision is not None:
+ self.trade_trend = {}
+ self.trade_amount = {}
+ # init the trade amount of order and predicted trade trend
+ for order in outer_trade_decision.get_decision():
+ self.trade_trend[order.stock_id] = self.TREND_MID
+ self.trade_amount[order.stock_id] = order.amount
+
+ def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
+ raise NotImplementedError("pred_price_trend method is not implemented!")
+
+ def generate_trade_decision(self, execute_result=None):
+ # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
+ trade_step = self.trade_calendar.get_trade_step()
+ # get the total count of trading step
+ trade_len = self.trade_calendar.get_trade_len()
+
+ # update the order amount
+ if execute_result is not None:
+ for order, _, _, _ in execute_result:
+ self.trade_amount[order.stock_id] -= order.deal_amount
+
+ trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
+ pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
+ order_list = []
+        # for each order in self.outer_trade_decision
+ for order in self.outer_trade_decision.get_decision():
+ # get the price trend
+ if trade_step % 2 == 0:
+ # in the first of two adjacent bars, predict the price trend
+ _pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
+ else:
+ # in the second of two adjacent bars, use the trend predicted in the first one
+ _pred_trend = self.trade_trend[order.stock_id]
+ # if not tradable, continue
+ if not self.trade_exchange.is_stock_tradable(
+ stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
+ ):
+ if trade_step % 2 == 0:
+ self.trade_trend[order.stock_id] = _pred_trend
+ continue
+ # get amount of one trade unit
+ _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(
+ stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
+ )
+ if _pred_trend == self.TREND_MID:
+ _order_amount = None
+                # without considering trade unit
+ if _amount_trade_unit is None:
+ # divide the order into equal parts, and trade one part
+ _order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
+                # considering trade unit
+ else:
+ # divide the order into equal parts, and trade one part
+ # calculate the total count of trade units to trade
+ trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
+ # calculate the amount of one part, ceil the amount
+ # floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
+ _order_amount = (
+ (trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit
+ )
+ if order.direction == order.SELL:
+ # sell all amount at last
+ if self.trade_amount[order.stock_id] > 1e-5 and (
+ _order_amount < 1e-5 or trade_step == trade_len - 1
+ ):
+ _order_amount = self.trade_amount[order.stock_id]
+
+ _order_amount = min(_order_amount, self.trade_amount[order.stock_id])
+
+ if _order_amount > 1e-5:
+ _order = Order(
+ stock_id=order.stock_id,
+ amount=_order_amount,
+ start_time=trade_start_time,
+ end_time=trade_end_time,
+ direction=order.direction,
+ )
+ order_list.append(_order)
+
+ else:
+ _order_amount = None
+                # without considering trade unit
+ if _amount_trade_unit is None:
+                    # N trading steps left: divide the order into N + 1 parts, and trade 2 parts
+ _order_amount = 2 * self.trade_amount[order.stock_id] / (trade_len - trade_step + 1)
+                # considering trade unit
+ else:
+                    # calculate the total count of trade units to trade
+ trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
+                    # N trading steps left: divide the order into N + 1 parts, and trade 2 parts
+ _order_amount = (
+ (trade_unit_cnt + trade_len - trade_step)
+ // (trade_len - trade_step + 1)
+ * 2
+ * _amount_trade_unit
+ )
+ if order.direction == order.SELL:
+ # sell all amount at last
+ if self.trade_amount[order.stock_id] > 1e-5 and (
+ _order_amount < 1e-5 or trade_step == trade_len - 1
+ ):
+ _order_amount = self.trade_amount[order.stock_id]
+
+ _order_amount = min(_order_amount, self.trade_amount[order.stock_id])
+
+ if _order_amount > 1e-5:
+ if trade_step % 2 == 0:
+ # in the first one of two adjacent bars
+ # if look short on the price, sell the stock more
+ # if look long on the price, buy the stock more
+ if (
+ _pred_trend == self.TREND_SHORT
+ and order.direction == order.SELL
+ or _pred_trend == self.TREND_LONG
+ and order.direction == order.BUY
+ ):
+ _order = Order(
+ stock_id=order.stock_id,
+ amount=_order_amount,
+ start_time=trade_start_time,
+ end_time=trade_end_time,
+ direction=order.direction, # 1 for buy
+ )
+ order_list.append(_order)
+ else:
+ # in the second one of two adjacent bars
+ # if look short on the price, buy the stock more
+ # if look long on the price, sell the stock more
+ if (
+ _pred_trend == self.TREND_SHORT
+ and order.direction == order.BUY
+ or _pred_trend == self.TREND_LONG
+ and order.direction == order.SELL
+ ):
+ _order = Order(
+ stock_id=order.stock_id,
+ amount=_order_amount,
+ start_time=trade_start_time,
+ end_time=trade_end_time,
+ direction=order.direction, # 1 for buy
+ )
+ order_list.append(_order)
+
+ if trade_step % 2 == 0:
+ # in the first one of two adjacent bars, store the trend for the second one to use
+ self.trade_trend[order.stock_id] = _pred_trend
+
+ return TradeDecisionWO(order_list, self)
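The `(cnt + n - 1) // n` comments above use the standard ceiling-via-floor-division identity; a quick sketch with illustrative numbers:

```python
import math

# ceil(cnt / n) computed with integer arithmetic only
def ceil_div(cnt, n):
    return (cnt + n - 1) // n

for cnt, n in [(7, 3), (9, 3), (1, 4)]:
    assert ceil_div(cnt, n) == math.ceil(cnt / n)

# With a trade unit of 100 shares, 7 units left and 3 steps remaining,
# the TREND_MID branch trades ceil(7 / 3) = 3 units = 300 shares this step.
print(ceil_div(7, 3) * 100)  # 300
```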
+
+
+class SBBStrategyEMA(SBBStrategyBase):
+ """
+ (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA) signal.
+ """
+
+ # TODO:
+    # 1. Support leveraging the get_range_limit result from the decision
+    # 2. Support alter_outer_trade_decision
+    # 3. Support checking the availability of trade decision
+
+ def __init__(
+ self,
+ outer_trade_decision: BaseTradeDecision = None,
+ instruments: Union[List, str] = "csi300",
+ freq: str = "day",
+ trade_exchange: Exchange = None,
+ level_infra: LevelInfrastructure = None,
+ common_infra: CommonInfrastructure = None,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ instruments : Union[List, str], optional
+ instruments of EMA signal, by default "csi300"
+ freq : str, optional
+ freq of EMA signal, by default "day"
+ Note: `freq` may be different from `time_per_step`
+ """
+        if instruments is None:
+            warnings.warn("`instruments` is not set, will load all stocks")
+            instruments = "all"
+        # NOTE: D.instruments returns a list input unchanged, so both str and list are handled here
+        self.instruments = D.instruments(instruments) if isinstance(instruments, str) else instruments
+ self.freq = freq
+ super(SBBStrategyEMA, self).__init__(
+ outer_trade_decision, level_infra, common_infra, trade_exchange=trade_exchange, **kwargs
+ )
+
+ def _reset_signal(self):
+ trade_len = self.trade_calendar.get_trade_len()
+ fields = ["EMA($close, 10)-EMA($close, 20)"]
+ signal_start_time, _ = self.trade_calendar.get_step_time(trade_step=0, shift=1)
+ _, signal_end_time = self.trade_calendar.get_step_time(trade_step=trade_len - 1, shift=1)
+ signal_df = D.features(
+ self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
+ )
+ signal_df.columns = ["signal"]
+ self.signal = {}
+
+ if not signal_df.empty:
+ for stock_id, stock_val in signal_df.groupby(level="instrument"):
+ self.signal[stock_id] = stock_val["signal"].droplevel(level="instrument")
+
+ def reset_level_infra(self, level_infra):
+ """
+ reset level-shared infra
+        - After the trade calendar is reset, the signal will be changed
+ """
+ super().reset_level_infra(level_infra)
+ self._reset_signal()
+
+ def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
+ # if no signal, return mid trend
+ if stock_id not in self.signal:
+ return self.TREND_MID
+ else:
+ _sample_signal = resam_ts_data(
+ self.signal[stock_id],
+ pred_start_time,
+ pred_end_time,
+ method=ts_data_last,
+ )
+ # if EMA signal == 0 or None, return mid trend
+ if _sample_signal is None or np.isnan(_sample_signal) or _sample_signal == 0:
+ return self.TREND_MID
+ # if EMA signal > 0, return long trend
+ elif _sample_signal > 0:
+ return self.TREND_LONG
+ # if EMA signal < 0, return short trend
+ else:
+ return self.TREND_SHORT
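A toy illustration of how the EMA gap above maps to a trend label (hand-made prices; pandas `ewm` is used as a stand-in and may differ in detail from qlib's `EMA` operator):

```python
import pandas as pd

close = pd.Series(range(100, 130), index=pd.date_range("2020-01-01", periods=30))
signal = close.ewm(span=10).mean() - close.ewm(span=20).mean()

# mimic resam_ts_data(..., method=ts_data_last): take the last value in the window
last = signal.loc["2020-01-10":"2020-01-20"].iloc[-1]
trend = "MID" if pd.isna(last) or last == 0 else ("LONG" if last > 0 else "SHORT")
print(trend)  # a steadily rising series gives a positive gap, i.e. "LONG"
```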
+
+
+class ACStrategy(BaseStrategy):
+ # TODO:
+    # 1. Support leveraging the get_range_limit result from the decision
+    # 2. Support alter_outer_trade_decision
+    # 3. Support checking the availability of trade decision
+ def __init__(
+ self,
+ lamb: float = 1e-6,
+ eta: float = 2.5e-6,
+ window_size: int = 20,
+ outer_trade_decision: BaseTradeDecision = None,
+ instruments: Union[List, str] = "csi300",
+ freq: str = "day",
+ trade_exchange: Exchange = None,
+ level_infra: LevelInfrastructure = None,
+ common_infra: CommonInfrastructure = None,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ instruments : Union[List, str], optional
+ instruments of Volatility, by default "csi300"
+ freq : str, optional
+ freq of Volatility, by default "day"
+ Note: `freq` may be different from `time_per_step`
+ """
+ self.lamb = lamb
+ self.eta = eta
+ self.window_size = window_size
+        if instruments is None:
+            warnings.warn("`instruments` is not set, will load all stocks")
+            instruments = "all"
+        # NOTE: D.instruments returns a list input unchanged, so both str and list are handled here
+        self.instruments = D.instruments(instruments) if isinstance(instruments, str) else instruments
+ self.freq = freq
+ super(ACStrategy, self).__init__(
+ outer_trade_decision, level_infra, common_infra, trade_exchange=trade_exchange, **kwargs
+ )
+
+ def _reset_signal(self):
+ trade_len = self.trade_calendar.get_trade_len()
+ fields = [
+ f"Power(Sum(Power(Log($close/Ref($close, 1)), 2), {self.window_size})/{self.window_size - 1}-Power(Sum(Log($close/Ref($close, 1)), {self.window_size}), 2)/({self.window_size}*{self.window_size - 1}), 0.5)"
+ ]
+ signal_start_time, _ = self.trade_calendar.get_step_time(trade_step=0, shift=1)
+ _, signal_end_time = self.trade_calendar.get_step_time(trade_step=trade_len - 1, shift=1)
+ signal_df = D.features(
+ self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
+ )
+ signal_df.columns = ["volatility"]
+ self.signal = {}
+
+ if not signal_df.empty:
+ for stock_id, stock_val in signal_df.groupby(level="instrument"):
+ self.signal[stock_id] = stock_val["volatility"].droplevel(level="instrument")
+
+ def reset_level_infra(self, level_infra):
+ """
+ reset level-shared infra
+        - After the trade calendar is reset, the signal will be changed
+ """
+ super().reset_level_infra(level_infra)
+ self._reset_signal()
+
+ def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
+ """
+ Parameters
+ ----------
+ outer_trade_decision : BaseTradeDecision, optional
+ """
+ super(ACStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
+ if outer_trade_decision is not None:
+ self.trade_amount = {}
+ # init the trade amount of order and predicted trade trend
+ for order in outer_trade_decision.get_decision():
+ self.trade_amount[order.stock_id] = order.amount
+
+ def generate_trade_decision(self, execute_result=None):
+ # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
+ trade_step = self.trade_calendar.get_trade_step()
+ # get the total count of trading step
+ trade_len = self.trade_calendar.get_trade_len()
+
+ # update the order amount
+ if execute_result is not None:
+ for order, _, _, _ in execute_result:
+ self.trade_amount[order.stock_id] -= order.deal_amount
+
+ trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
+ pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
+ order_list = []
+ for order in self.outer_trade_decision.get_decision():
+ # if not tradable, continue
+ if not self.trade_exchange.is_stock_tradable(
+ stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
+ ):
+ continue
+ _order_amount = None
+
+            # sample the volatility signal in the prediction window
+ sig_sam = (
+ resam_ts_data(self.signal[order.stock_id], pred_start_time, pred_end_time, method=ts_data_last)
+ if order.stock_id in self.signal
+ else None
+ )
+
+ if sig_sam is None or np.isnan(sig_sam):
+ # no signal, TWAP
+ _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(
+ stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
+ )
+ if _amount_trade_unit is None:
+ # divide the order into equal parts, and trade one part
+ _order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
+ else:
+ # divide the order into equal parts, and trade one part
+ # calculate the total count of trade units to trade
+ trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
+ # calculate the amount of one part, ceil the amount
+ # floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
+ _order_amount = (
+ (trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit
+ )
+ else:
+ # VA strategy
+ kappa_tild = self.lamb / self.eta * sig_sam * sig_sam
+ kappa = np.arccosh(kappa_tild / 2 + 1)
+ amount_ratio = (
+ np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1))
+ ) / np.sinh(kappa * trade_len)
+ _order_amount = order.amount * amount_ratio
+ _order_amount = self.trade_exchange.round_amount_by_trade_unit(
+ _order_amount, stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time
+ )
+
+ if order.direction == order.SELL:
+ # sell all amount at last
+ if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or trade_step == trade_len - 1):
+ _order_amount = self.trade_amount[order.stock_id]
+
+ _order_amount = min(_order_amount, self.trade_amount[order.stock_id])
+
+ if _order_amount > 1e-5:
+
+ _order = Order(
+ stock_id=order.stock_id,
+ amount=_order_amount,
+ start_time=trade_start_time,
+ end_time=trade_end_time,
+ direction=order.direction, # 1 for buy
+ factor=order.factor,
+ )
+ order_list.append(_order)
+ return TradeDecisionWO(order_list, self)
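A numeric sketch of the VA (Almgren-Chriss style) schedule above, with made-up `lamb`/`eta`/volatility values; the sinh differences telescope, so the per-step ratios sum to exactly 1 and the whole order is covered over `trade_len` steps:

```python
import numpy as np

lamb, eta, sigma = 1e-6, 2.5e-6, 0.02  # illustrative values
trade_len = 4

kappa_tild = lamb / eta * sigma * sigma
kappa = np.arccosh(kappa_tild / 2 + 1)

ratios = [
    (np.sinh(kappa * (trade_len - t)) - np.sinh(kappa * (trade_len - t - 1)))
    / np.sinh(kappa * trade_len)
    for t in range(trade_len)
]
print(ratios, sum(ratios))  # telescoping sum == 1.0
```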
+
+
+class RandomOrderStrategy(BaseStrategy):
+ def __init__(
+ self,
+ trade_range: Union[Tuple[int, int], TradeRange], # The range is closed on both left and right.
+ sample_ratio: float = 1.0,
+ volume_ratio: float = 0.01,
+ market: str = "all",
+ direction: int = Order.BUY,
+ *args,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ trade_range : Tuple
+ please refer to the `trade_range` parameter of BaseStrategy
+        sample_ratio : float
+            the ratio of instruments that are sampled to generate orders
+        volume_ratio : float
+            the ratio of the total volume of a specific day used as the order amount
+ market : str
+ stock pool for sampling
+ """
+
+ super().__init__(*args, **kwargs)
+ self.sample_ratio = sample_ratio
+ self.volume_ratio = volume_ratio
+ self.market = market
+ self.direction = direction
+ exch: Exchange = self.common_infra.get("trade_exchange")
+ # TODO: this can't be online
+ self.volume = D.features(
+ D.instruments(market), ["Mean(Ref($volume, 1), 10)"], start_time=exch.start_time, end_time=exch.end_time
+ )
+ self.volume_df = self.volume.iloc[:, 0].unstack()
+ self.trade_range = trade_range
+
+ def generate_trade_decision(self, execute_result=None):
+ trade_step = self.trade_calendar.get_trade_step()
+ step_time_start, step_time_end = self.trade_calendar.get_step_time(trade_step)
+
+ order_list = []
+ if step_time_start in self.volume_df:
+ for stock_id, volume in self.volume_df[step_time_start].dropna().sample(frac=self.sample_ratio).items():
+ order_list.append(
+ self.common_infra.get("trade_exchange")
+ .get_order_helper()
+ .create(
+ code=stock_id,
+ amount=volume * self.volume_ratio,
+ direction=self.direction,
+ )
+ )
+ return TradeDecisionWO(order_list, self, self.trade_range)
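A pure-pandas sketch of the sampling rule above (toy volume table, not real data):

```python
import pandas as pd

# one column per bar; values play the role of Mean(Ref($volume, 1), 10)
volume_df = pd.DataFrame(
    {"2020-01-02": [1e6, 2e6, 3e6]},
    index=["SH600000", "SH600519", "SZ000001"],
)
sample_ratio, volume_ratio = 1.0, 0.01
for stock_id, vol in volume_df["2020-01-02"].dropna().sample(frac=sample_ratio).items():
    print(stock_id, vol * volume_ratio)  # order amount per sampled stock
```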
+
+
+class FileOrderStrategy(BaseStrategy):
+ """
+ Motivation:
+    - This class provides an interface for users to read orders from csv files.
+ """
+
+ def __init__(
+ self,
+ file: Union[IO, str, Path, pd.DataFrame],
+ trade_range: Union[Tuple[int, int], TradeRange] = None,
+ *args,
+ **kwargs,
+ ):
+ """
+
+ Parameters
+ ----------
+ file : Union[IO, str, Path, pd.DataFrame]
+            this parameter specifies the info of the expected orders
+
+ Here is an example of the content
+
+ 1) Amount (**adjusted**) based strategy
+
+ datetime,instrument,amount,direction
+ 20200102, SH600519, 1000, sell
+ 20200103, SH600519, 1000, buy
+ 20200106, SH600519, 1000, sell
+
+        trade_range : Tuple[int, int]
+            the intra-day time index range of the orders;
+            both the left and right ends are closed.
+
+ If you want to get the trade_range in intra-day
+            - `qlib/utils/time.py:def get_day_min_idx_range` can help you create the index range more easily
+ # TODO: this is a trade_range level limitation. We'll implement a more detailed limitation later.
+
+ """
+ super().__init__(*args, **kwargs)
+ if isinstance(file, pd.DataFrame):
+ self.order_df = file
+ else:
+ with get_io_object(file) as f:
+ self.order_df = pd.read_csv(f, dtype={"datetime": np.str})
+
+ self.order_df["datetime"] = self.order_df["datetime"].apply(pd.Timestamp)
+ self.order_df = self.order_df.set_index(["datetime", "instrument"])
+
+ # make sure the datetime is the first level for fast indexing
+ self.order_df = lazy_sort_index(convert_index_format(self.order_df, level="datetime"))
+ self.trade_range = trade_range
+
+ def generate_trade_decision(self, execute_result=None) -> TradeDecisionWO:
+ """
+ Parameters
+ ----------
+ execute_result :
+ execute_result will be ignored in FileOrderStrategy
+ """
+ oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
+ start, _ = self.trade_calendar.get_step_time()
+        # CONVENTION: the bar is indexed by its time
+ try:
+ df = self.order_df.loc(axis=0)[start]
+ except KeyError:
+ return TradeDecisionWO([], self)
+ else:
+ order_list = []
+ for idx, row in df.iterrows():
+ order_list.append(
+ oh.create(
+ code=idx,
+ amount=row["amount"],
+ direction=Order.parse_dir(row["direction"]),
+ )
+ )
+ return TradeDecisionWO(order_list, self, self.trade_range)
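A hedged usage sketch of `FileOrderStrategy` fed with a DataFrame instead of a csv (codes and amounts are illustrative; the commented-out instantiation would additionally need the backtest infrastructure wired in):

```python
import pandas as pd

orders = pd.DataFrame(
    {
        "datetime": ["20200102", "20200103"],
        "instrument": ["SH600519", "SH600519"],
        "amount": [1000, 1000],
        "direction": ["sell", "buy"],
    }
)
# strategy = FileOrderStrategy(file=orders, trade_range=(0, 239))
# trade_range=(0, 239) would cover a full day of 1-minute bars, both ends closed.
```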
diff --git a/qlib/contrib/strategy/strategy.py b/qlib/contrib/strategy/strategy.py
deleted file mode 100644
index 1f4aa7b06f..0000000000
--- a/qlib/contrib/strategy/strategy.py
+++ /dev/null
@@ -1,412 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-
-import copy
-import numpy as np
-import pandas as pd
-
-from ..backtest.order import Order
-from .order_generator import OrderGenWInteract
-
-
-# TODO: The base strategies will be moved out of contrib to core code
-class BaseStrategy:
- def __init__(self):
- pass
-
- def get_risk_degree(self, date):
- """get_risk_degree
- Return the proportion of your total value you will used in investment.
- Dynamically risk_degree will result in Market timing
- """
- # It will use 95% amount of your total value by default
- return 0.95
-
- def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
- """
- DO NOT directly change the state of current
-
- Parameters
- -----------
- score_series : pd.Series
- stock_id , score.
- current : Position()
- current state of position.
- DO NOT directly change the state of current.
- trade_exchange : Exchange()
- trade exchange.
- pred_date : pd.Timestamp
- predict date.
- trade_date : pd.Timestamp
- trade date.
- """
- pass
-
- def update(self, score_series, pred_date, trade_date):
- """User can use this method to update strategy state each trade date.
- Parameters
- -----------
- score_series : pd.Series
- stock_id , score.
- pred_date : pd.Timestamp
- oredict date.
- trade_date : pd.Timestamp
- trade date.
- """
- pass
-
- def init(self, **kwargs):
- """Some strategy need to be initial after been implemented,
- User can use this method to init his strategy with parameters needed.
- """
- pass
-
- def get_init_args_from_model(self, model, init_date):
- """
- This method only be used in 'online' module, it will generate the *args to initial the strategy.
- :param
- mode : model used in 'online' module.
- """
- return {}
-
-
-class StrategyWrapper:
- """
- StrategyWrapper is a wrapper of another strategy.
- By overriding some methods to make some changes on the basic strategy
- Cost control and risk control will base on this class.
- """
-
- def __init__(self, inner_strategy):
- """__init__
-
- :param inner_strategy: set the inner strategy.
- """
- self.inner_strategy = inner_strategy
-
- def __getattr__(self, name):
- """__getattr__
-
- :param name: If no implementation in this method. Call the method in the innter_strategy by default.
- """
- return getattr(self.inner_strategy, name)
-
-
-class AdjustTimer:
- """AdjustTimer
- Responsible for timing of position adjusting
-
- This is designed as multiple inheritance mechanism due to:
- - the is_adjust may need access to the internel state of a strategy.
-
- - it can be reguard as a enhancement to the existing strategy.
- """
-
- # adjust position in each trade date
- def is_adjust(self, trade_date):
- """is_adjust
- Return if the strategy can adjust positions on `trade_date`
- Will normally be used in strategy do trading with trade frequency
- """
- return True
-
-
-class ListAdjustTimer(AdjustTimer):
- def __init__(self, adjust_dates=None):
- """__init__
-
- :param adjust_dates: an iterable object, it will return a timelist for trading dates
- """
- if adjust_dates is None:
- # None indicates that all dates is OK for adjusting
- self.adjust_dates = None
- else:
- self.adjust_dates = {pd.Timestamp(dt) for dt in adjust_dates}
-
- def is_adjust(self, trade_date):
- if self.adjust_dates is None:
- return True
- return pd.Timestamp(trade_date) in self.adjust_dates
-
-
-class WeightStrategyBase(BaseStrategy, AdjustTimer):
- def __init__(self, order_generator_cls_or_obj=OrderGenWInteract, *args, **kwargs):
- super().__init__(*args, **kwargs)
- if isinstance(order_generator_cls_or_obj, type):
- self.order_generator = order_generator_cls_or_obj()
- else:
- self.order_generator = order_generator_cls_or_obj
-
- def generate_target_weight_position(self, score, current, trade_date):
- """
- Generate target position from score for this date and the current position.The cash is not considered in the position
-
- Parameters
- -----------
- score : pd.Series
- pred score for this trade date, index is stock_id, contain 'score' column.
- current : Position()
- current position.
- trade_date : pd.Timestamp
- trade date.
- """
- raise NotImplementedError()
-
- def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
- """
- Parameters
- -----------
- score_series : pd.Seires
- stock_id , score.
- current : Position()
- current of account.
- trade_exchange : Exchange()
- exchange.
- trade_date : pd.Timestamp
- date.
- """
- # judge if to adjust
- if not self.is_adjust(trade_date):
- return []
- # generate_order_list
- # generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
- current_temp = copy.deepcopy(current)
- target_weight_position = self.generate_target_weight_position(
- score=score_series, current=current_temp, trade_date=trade_date
- )
-
- order_list = self.order_generator.generate_order_list_from_target_weight_position(
- current=current_temp,
- trade_exchange=trade_exchange,
- risk_degree=self.get_risk_degree(trade_date),
- target_weight_position=target_weight_position,
- pred_date=pred_date,
- trade_date=trade_date,
- )
- return order_list
-
-
-class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
- def __init__(
- self,
- topk,
- n_drop,
- method_sell="bottom",
- method_buy="top",
- risk_degree=0.95,
- thresh=1,
- hold_thresh=1,
- only_tradable=False,
- **kwargs,
- ):
- """
- Parameters
- -----------
- topk : int
- the number of stocks in the portfolio.
- n_drop : int
- number of stocks to be replaced in each trading date.
- method_sell : str
- dropout method_sell, random/bottom.
- method_buy : str
- dropout method_buy, random/top.
- risk_degree : float
- position percentage of total value.
- thresh : int
- minimun holding days since last buy singal of the stock.
- hold_thresh : int
- minimum holding days
- before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh.
- only_tradable : bool
- will the strategy only consider the tradable stock when buying and selling.
- if only_tradable:
- the strategy will peek at the information in the short future to avoid untradable stocks (untradable stocks include stocks that meet suspension, or hit limit up or limit down).
- else:
- the strategy will generate orders without peeking any information in the future, so the order generated by the strategies may fail.
- """
- super(TopkDropoutStrategy, self).__init__()
- ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None))
- self.topk = topk
- self.n_drop = n_drop
- self.method_sell = method_sell
- self.method_buy = method_buy
- self.risk_degree = risk_degree
- self.thresh = thresh
- # self.stock_count['code'] will be the days the stock has been hold
- # since last buy signal. This is designed for thresh
- self.stock_count = {}
-
- self.hold_thresh = hold_thresh
- self.only_tradable = only_tradable
-
- def get_risk_degree(self, date):
- """get_risk_degree
- Return the proportion of your total value you will used in investment.
- Dynamically risk_degree will result in Market timing.
- """
- # It will use 95% amoutn of your total value by default
- return self.risk_degree
-
- def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
- """
- Generate order list according to score_series at trade_date, will not change current.
-
- Parameters
- -----------
- score_series : pd.Series
- stock_id , score.
- current : Position()
- current of account.
- trade_exchange : Exchange()
- exchange.
- pred_date : pd.Timestamp
- predict date.
- trade_date : pd.Timestamp
- trade date.
- """
- if not self.is_adjust(trade_date):
- return []
-
- if self.only_tradable:
- # If The strategy only consider tradable stock when make decision
- # It needs following actions to filter stocks
- def get_first_n(l, n, reverse=False):
- cur_n = 0
- res = []
- for si in reversed(l) if reverse else l:
- if trade_exchange.is_stock_tradable(stock_id=si, trade_date=trade_date):
- res.append(si)
- cur_n += 1
- if cur_n >= n:
- break
- return res[::-1] if reverse else res
-
- def get_last_n(l, n):
- return get_first_n(l, n, reverse=True)
-
- def filter_stock(l):
- return [si for si in l if trade_exchange.is_stock_tradable(stock_id=si, trade_date=trade_date)]
-
- else:
- # Otherwise, the stock will make decision with out the stock tradable info
- def get_first_n(l, n):
- return list(l)[:n]
-
- def get_last_n(l, n):
- return list(l)[-n:]
-
- def filter_stock(l):
- return l
-
- current_temp = copy.deepcopy(current)
- # generate order list for this adjust date
- sell_order_list = []
- buy_order_list = []
- # load score
- cash = current_temp.get_cash()
- current_stock_list = current_temp.get_stock_list()
- # last position (sorted by score)
- last = score_series.reindex(current_stock_list).sort_values(ascending=False).index
- # The new stocks today want to buy **at most**
- if self.method_buy == "top":
- today = get_first_n(
- score_series[~score_series.index.isin(last)].sort_values(ascending=False).index,
- self.n_drop + self.topk - len(last),
- )
- elif self.method_buy == "random":
- topk_candi = get_first_n(score_series.sort_values(ascending=False).index, self.topk)
- candi = list(filter(lambda x: x not in last, topk_candi))
- n = self.n_drop + self.topk - len(last)
- try:
- today = np.random.choice(candi, n, replace=False)
- except ValueError:
- today = candi
- else:
- raise NotImplementedError(f"This type of input is not supported")
- # combine(new stocks + last stocks), we will drop stocks from this list
- # In case of dropping higher score stock and buying lower score stock.
- comb = score_series.reindex(last.union(pd.Index(today))).sort_values(ascending=False).index
-
- # Get the stock list we really want to sell (After filtering the case that we sell high and buy low)
- if self.method_sell == "bottom":
- sell = last[last.isin(get_last_n(comb, self.n_drop))]
- elif self.method_sell == "random":
- candi = filter_stock(last)
- try:
- sell = pd.Index(np.random.choice(candi, self.n_drop, replace=False) if len(last) else [])
- except ValueError: # No enough candidates
- sell = candi
- else:
- raise NotImplementedError(f"This type of input is not supported")
-
- # Get the stock list we really want to buy
- buy = today[: len(sell) + self.topk - len(last)]
-
- # buy singal: if a stock falls into topk, it appear in the buy_sinal
- buy_signal = score_series.sort_values(ascending=False).iloc[: self.topk].index
-
- for code in current_stock_list:
- if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
- continue
- if code in sell:
- # check hold limit
- if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code) < self.hold_thresh:
- # can not sell this code
- # no buy signal, but the stock is kept
- self.stock_count[code] += 1
- continue
- # sell order
- sell_amount = current_temp.get_stock_amount(code=code)
- sell_order = Order(
- stock_id=code,
- amount=sell_amount,
- trade_date=trade_date,
- direction=Order.SELL, # 0 for sell, 1 for buy
- factor=trade_exchange.get_factor(code, trade_date),
- )
- # is order executable
- if trade_exchange.check_order(sell_order):
- sell_order_list.append(sell_order)
- trade_val, trade_cost, trade_price = trade_exchange.deal_order(sell_order, position=current_temp)
- # update cash
- cash += trade_val - trade_cost
- # sold
- del self.stock_count[code]
- else:
- # no buy signal, but the stock is kept
- self.stock_count[code] += 1
- elif code in buy_signal:
- # NOTE: This is different from the original version
- # get new buy signal
- # Only the stock fall in to topk will produce buy signal
- self.stock_count[code] = 1
- else:
- self.stock_count[code] += 1
- # buy new stock
- # note the current has been changed
- current_stock_list = current_temp.get_stock_list()
- value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0
-
- # open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not
- # consider it as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
- # value = value / (1+trade_exchange.open_cost) # set open_cost limit
- for code in buy:
- # check is stock suspended
- if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
- continue
- # buy order
- buy_price = trade_exchange.get_deal_price(stock_id=code, trade_date=trade_date)
- buy_amount = value / buy_price
- factor = trade_exchange.quote[(code, trade_date)]["$factor"]
- buy_amount = trade_exchange.round_amount_by_trade_unit(buy_amount, factor)
- buy_order = Order(
- stock_id=code,
- amount=buy_amount,
- trade_date=trade_date,
- direction=Order.BUY, # 1 for buy
- factor=factor,
- )
- buy_order_list.append(buy_order)
- self.stock_count[code] = 1
- return sell_order_list + buy_order_list
diff --git a/qlib/data/cache.py b/qlib/data/cache.py
index fa1142e4f2..362270b619 100644
--- a/qlib/data/cache.py
+++ b/qlib/data/cache.py
@@ -319,7 +319,7 @@ def clear_cache(cache_path: Union[str, Path]):
@staticmethod
def get_cache_dir(dir_name: str, freq: str = None) -> Path:
- cache_dir = Path(C.dpm.get_data_path(freq)).joinpath(dir_name)
+ cache_dir = Path(C.dpm.get_data_uri(freq)).joinpath(dir_name)
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir
@@ -544,7 +544,7 @@ def _expression(self, instrument, field, start_time=None, end_time=None, freq="d
series = self.provider.expression(instrument, field, _calendar[0], _calendar[-1], freq)
if not series.empty:
# This expresion is empty, we don't generate any cache for it.
- with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:expression-{_cache_uri}"):
+ with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri(freq))}:expression-{_cache_uri}"):
self.gen_expression_cache(
expression_data=series,
cache_path=cache_path,
@@ -589,7 +589,7 @@ def update(self, sid, cache_uri, freq: str = "day"):
self.clear_cache(cp_cache_uri)
return 2
- with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:expression-{cache_uri}"):
+ with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri())}:expression-{cache_uri}"):
with meta_path.open("rb") as f:
d = pickle.load(f)
instrument = d["info"]["instrument"]
@@ -724,7 +724,7 @@ def _dataset(
if self.check_cache_exists(cache_path):
if disk_cache == 1:
# use cache
- with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
+ with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}"):
CacheUtils.visit(cache_path)
features = self.read_data_from_cache(cache_path, start_time, end_time, fields)
elif disk_cache == 2:
@@ -734,7 +734,7 @@ def _dataset(
if gen_flag:
# cache unavailable, generate the cache
- with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
+ with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}"):
features = self.gen_dataset_cache(
cache_path=cache_path,
instruments=instruments,
@@ -775,12 +775,12 @@ def _dataset_uri(
if self.check_cache_exists(cache_path):
self.logger.debug(f"The cache dataset has already existed {cache_path}. Return the uri directly")
- with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
+ with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}"):
CacheUtils.visit(cache_path)
return _cache_uri
else:
# cache unavailable, generate the cache
- with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path(freq))}:dataset-{_cache_uri}"):
+ with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}"):
self.gen_dataset_cache(
cache_path=cache_path,
instruments=instruments,
@@ -958,7 +958,7 @@ def update(self, cache_uri, freq: str = "day"):
return 2
im = DiskDatasetCache.IndexManager(cp_cache_uri)
- with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_path())}:dataset-{cache_uri}"):
+ with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri())}:dataset-{cache_uri}"):
with meta_path.open("rb") as f:
d = pickle.load(f)
instruments = d["info"]["instruments"]
@@ -1154,7 +1154,7 @@ def dataset(
instruments, fields, None, None, freq, disk_cache=disk_cache, inst_processors=inst_processors
)
value, expire = MemCacheExpire.get_cache(H["f"], feature_uri)
- mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri)
+ mnt_feature_uri = C.dpm.get_data_uri(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri)
if value is None or expire or not mnt_feature_uri.exists():
df, uri = self.provider.dataset(
instruments,
diff --git a/qlib/data/data.py b/qlib/data/data.py
index 621ccf9009..115d381703 100644
--- a/qlib/data/data.py
+++ b/qlib/data/data.py
@@ -14,6 +14,10 @@
import pandas as pd
from multiprocessing import Pool
from typing import Iterable, Union
+from typing import List
+
+# For supporting multiprocessing in outer code, joblib is used
+from joblib import delayed
from .cache import H
from ..config import C
@@ -22,6 +26,8 @@
from .inst_processor import InstProcessor
from ..log import get_module_logger
+from ..utils.time import Freq
+from ..utils.resam import resam_calendar
from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import (
Wrapper,
@@ -33,6 +39,7 @@
normalize_cache_fields,
code_to_fname,
)
+from ..utils.paral import ParallelExt
class ProviderBackendMixin:
@@ -61,7 +68,22 @@ def backend_obj(self, **kwargs):
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {})
freq = kwargs.get("freq", "day")
if freq not in provider_uri_map:
- provider_uri_map[freq] = C.dpm.get_data_path(freq)
+ # NOTE: uri
+ # 1. If `freq` in C.dpm.provider_uri.keys(), uri = C.dpm.provider_uri[freq]
+ # 2. If `freq` not in C.dpm.provider_uri.keys()
+ # - Get the `min_freq` closest to `freq` from C.dpm.provider_uri.keys(), uri = C.dpm.provider_uri[min_freq]
+ # NOTE: In Storage, only CalendarStorage is supported
+ # 1. If `uri` does not exist
+ # - Get the `min_uri` of the closest `freq` under the same "directory" as the `uri`
+ # - Read data from `min_uri` and resample to `freq`
+ try:
+ _uri = C.dpm.get_data_uri(freq)
+ except KeyError:
+ # provider_uri is dict and freq not in list(provider_uri.keys())
+ # use the nearest freq greater than 0
+ min_freq = Freq.get_recent_freq(freq, C.dpm.provider_uri.keys())
+ _uri = C.dpm.get_data_uri(freq) if min_freq is None else C.dpm.get_data_uri(min_freq)
+ provider_uri_map[freq] = _uri
backend_kwargs["provider_uri"] = provider_uri_map[freq]
backend.setdefault("kwargs", {}).update(**kwargs)
return init_instance_by_config(backend)
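A sketch of the fallback described in the NOTE above, with a stand-in `nearest_freq` helper (the real logic lives in `qlib.utils.time.Freq.get_recent_freq` and may differ in detail):

```python
provider_uri = {"day": "~/.qlib/qlib_data/cn_data", "1min": "~/.qlib/qlib_data/cn_data_1min"}

def nearest_freq(freq, available):
    # toy ordering by bar length in minutes ("day" treated as one 240-minute bar)
    def minutes(f):
        return 240 if f == "day" else int(f.rstrip("min"))
    candidates = [f for f in available if minutes(f) <= minutes(freq)]
    return max(candidates, key=minutes) if candidates else None

# "5min" data is not configured, so fall back to the closest finer freq and resample
print(nearest_freq("5min", provider_uri.keys()))  # -> "1min"
```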
@@ -76,7 +98,6 @@ class CalendarProvider(abc.ABC, ProviderBackendMixin):
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
- @abc.abstractmethod
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
"""Get calendar of certain market in given time range.
@@ -96,9 +117,28 @@ def calendar(self, start_time=None, end_time=None, freq="day", future=False):
list
calendar list
"""
- raise NotImplementedError("Subclass of CalendarProvider must implement `calendar` method")
+ _calendar, _calendar_index = self._get_calendar(freq, future)
+ if start_time == "None":
+ start_time = None
+ if end_time == "None":
+ end_time = None
+ # strip
+ if start_time:
+ start_time = pd.Timestamp(start_time)
+ if start_time > _calendar[-1]:
+ return np.array([])
+ else:
+ start_time = _calendar[0]
+ if end_time:
+ end_time = pd.Timestamp(end_time)
+ if end_time < _calendar[0]:
+ return np.array([])
+ else:
+ end_time = _calendar[-1]
+ _, _, si, ei = self.locate_index(start_time, end_time, freq, future)
+ return _calendar[si : ei + 1]
- def locate_index(self, start_time, end_time, freq, future):
+ def locate_index(self, start_time, end_time, freq, future=False):
"""Locate the start time index and end time index in a calendar under certain frequency.
Parameters
@@ -157,18 +197,32 @@ def _get_calendar(self, freq, future):
dict composed by timestamp as key and index as value for fast search.
"""
flag = f"{freq}_future_{future}"
- if flag in H["c"]:
- _calendar, _calendar_index = H["c"][flag]
- else:
+ if flag not in H["c"]:
_calendar = np.array(self.load_calendar(freq, future))
_calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search
H["c"][flag] = _calendar, _calendar_index
- return _calendar, _calendar_index
+ return H["c"][flag]
def _uri(self, start_time, end_time, freq, future=False):
"""Get the uri of calendar generation task."""
return hash_args(start_time, end_time, freq, future)
+ def load_calendar(self, freq, future):
+ """Load original calendar timestamp from file.
+
+ Parameters
+ ----------
+ freq : str
+ frequency of read calendar file.
+ future: bool
+
+ Returns
+ ----------
+ list
+ list of timestamps
+ """
+ raise NotImplementedError("Subclass of CalendarProvider must implement `load_calendar` method")
+
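A toy illustration of the slicing now done in the base-class `calendar` method (simplified; the real `locate_index` handles more edge cases):

```python
import numpy as np
import pandas as pd

_calendar = np.array(pd.date_range("2020-01-01", periods=10, freq="D"))
start, end = np.datetime64("2020-01-03"), np.datetime64("2020-01-07")
si = np.searchsorted(_calendar, start)                   # first bar >= start
ei = np.searchsorted(_calendar, end, side="right") - 1   # last bar <= end
print(_calendar[si : ei + 1])  # 2020-01-03 ... 2020-01-07, both ends included
```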
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
"""Instrument provider base class
@@ -180,19 +234,22 @@ def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
@staticmethod
- def instruments(market="all", filter_pipe=None):
+ def instruments(market: Union[List, str] = "all", filter_pipe: Union[List, None] = None):
"""Get the general config dictionary for a base market adding several dynamic filters.
Parameters
----------
- market : str
- market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500.
+ market : Union[List, str]
+ str:
+ market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500.
+ list:
+ ["ID1", "ID2"]. A list of stocks
filter_pipe : list
the list of dynamic filters.
Returns
----------
- dict
+ dict: if isinstance(market, str)
dict of stockpool config.
{`market`=>base market name, `filter_pipe`=>list of filters}
@@ -210,7 +267,13 @@ def instruments(market="all", filter_pipe=None):
'name_rule_re': 'SH[0-9]{4}55',
'filter_start_time': None,
'filter_end_time': None}]}
+
+ list: if isinstance(market, list)
+ just return the original list directly.
+ NOTE: this will make the instruments compatible with more cases. The user code will be simpler.
"""
+ if isinstance(market, list):
+ return market
from .filter import SeriesDFilter
if filter_pipe is None:
@@ -466,58 +529,45 @@ def dataset_processor(instruments_d, column_names, start_time, end_time, freq, i
"""
normalize_column_names = normalize_cache_fields(column_names)
- data = dict()
# One process for one task, so that the memory will be freed quicker.
- workers = min(C.kernels, len(instruments_d))
- if C.maxtasksperchild is None:
- p = Pool(processes=workers)
- else:
- p = Pool(processes=workers, maxtasksperchild=C.maxtasksperchild)
+ workers = max(min(C.kernels, len(instruments_d)), 1)
+
+ # create iterator
if isinstance(instruments_d, dict):
- for inst, spans in instruments_d.items():
- data[inst] = p.apply_async(
- DatasetProvider.expression_calculator,
- args=(
- inst,
- start_time,
- end_time,
- freq,
- normalize_column_names,
- spans,
- C,
- inst_processors,
- ),
- )
+ it = instruments_d.items()
else:
- for inst in instruments_d:
- data[inst] = p.apply_async(
- DatasetProvider.expression_calculator,
- args=(
- inst,
- start_time,
- end_time,
- freq,
- normalize_column_names,
- None,
- C,
- inst_processors,
- ),
+ it = zip(instruments_d, [None] * len(instruments_d))
+
+ inst_l = []
+ task_l = []
+ for inst, spans in it:
+ inst_l.append(inst)
+ task_l.append(
+ delayed(DatasetProvider.expression_calculator)(
+ inst, start_time, end_time, freq, normalize_column_names, spans, C, inst_processors
)
+ )
- p.close()
- p.join()
+ data = dict(
+ zip(
+ inst_l,
+ ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)(task_l),
+ )
+ )
new_data = dict()
for inst in sorted(data.keys()):
- if len(data[inst].get()) > 0:
+ if len(data[inst]) > 0:
# NOTE: Python version >= 3.6; in versions after python3.6, dict will always guarantee the insertion order
- new_data[inst] = data[inst].get()
+ new_data[inst] = data[inst]
if len(new_data) > 0:
data = pd.concat(new_data, names=["instrument"], sort=False)
data = DiskDatasetCache.cache_to_origin_data(data, column_names)
else:
- data = pd.DataFrame(columns=column_names)
+ data = pd.DataFrame(
+ index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names
+ )
return data
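A minimal illustration of the `Pool` → joblib migration above, using plain `joblib.Parallel` in place of qlib's `ParallelExt` wrapper (which layers `maxtasksperchild` support on the same delayed-task pattern):

```python
from joblib import Parallel, delayed

def expression_calculator(inst):
    return inst.lower()  # stand-in for the real per-instrument computation

inst_l = ["SH600519", "SZ000001"]
task_l = [delayed(expression_calculator)(inst) for inst in inst_l]
data = dict(zip(inst_l, Parallel(n_jobs=2)(task_l)))
print(data)  # {'SH600519': 'sh600519', 'SZ000001': 'sz000001'}
```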
@@ -578,13 +628,12 @@ def load_calendar(self, freq, future):
----------
freq : str
frequency of read calendar file.
-
+ future: bool
Returns
----------
list
list of timestamps
"""
-
try:
backend_obj = self.backend_obj(freq=freq, future=future).data
except ValueError:
@@ -601,28 +650,6 @@ def load_calendar(self, freq, future):
return [pd.Timestamp(x) for x in backend_obj]
- def calendar(self, start_time=None, end_time=None, freq="day", future=False):
- _calendar, _calendar_index = self._get_calendar(freq, future)
- if start_time == "None":
- start_time = None
- if end_time == "None":
- end_time = None
- # strip
- if start_time:
- start_time = pd.Timestamp(start_time)
- if start_time > _calendar[-1]:
- return np.array([])
- else:
- start_time = _calendar[0]
- if end_time:
- end_time = pd.Timestamp(end_time)
- if end_time < _calendar[0]:
- return np.array([])
- else:
- end_time = _calendar[-1]
- _, _, si, ei = self.locate_index(start_time, end_time, freq, future)
- return _calendar[si : ei + 1]
-
class LocalInstrumentProvider(InstrumentProvider):
"""Local instrument data provider class
@@ -695,7 +722,7 @@ def expression(self, instrument, field, start_time=None, end_time=None, freq="da
expression = self.get_expression_instance(field)
start_time = pd.Timestamp(start_time)
end_time = pd.Timestamp(end_time)
- _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False)
+ _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq=freq, future=False)
lft_etd, rght_etd = expression.get_extended_window_size()
series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq)
# Ensure that each column type is consistent
@@ -735,7 +762,9 @@ def dataset(
column_names = self.get_column_names(fields)
cal = Cal.calendar(start_time, end_time, freq)
if len(cal) == 0:
- return pd.DataFrame(columns=column_names)
+ return pd.DataFrame(
+ index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names
+ )
start_time = cal[0]
end_time = cal[-1]
@@ -759,26 +788,12 @@ def multi_cache_walker(instruments, fields, start_time=None, end_time=None, freq
return
start_time = cal[0]
end_time = cal[-1]
- workers = min(C.kernels, len(instruments_d))
- if C.maxtasksperchild is None:
- p = Pool(processes=workers)
- else:
- p = Pool(processes=workers, maxtasksperchild=C.maxtasksperchild)
-
- for inst in instruments_d:
- p.apply_async(
- LocalDatasetProvider.cache_walker,
- args=(
- inst,
- start_time,
- end_time,
- freq,
- column_names,
- ),
- )
+ workers = max(min(C.kernels, len(instruments_d)), 1)
- p.close()
- p.join()
+ ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)(
+ delayed(LocalDatasetProvider.cache_walker)(inst, start_time, end_time, freq, column_names)
+ for inst in instruments_d
+ )
@staticmethod
def cache_walker(inst, start_time, end_time, freq, column_names):
@@ -807,12 +822,7 @@ def set_conn(self, conn):
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
self.conn.send_request(
request_type="calendar",
- request_content={
- "start_time": str(start_time),
- "end_time": str(end_time),
- "freq": freq,
- "future": future,
- },
+ request_content={"start_time": str(start_time), "end_time": str(end_time), "freq": freq, "future": future},
msg_queue=self.queue,
msg_proc_func=lambda response_content: [pd.Timestamp(c) for c in response_content],
)
@@ -920,7 +930,10 @@ def dataset(
column_names = self.get_column_names(fields)
cal = Cal.calendar(start_time, end_time, freq)
if len(cal) == 0:
- return pd.DataFrame(columns=column_names)
+ return pd.DataFrame(
+ index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")),
+ columns=column_names,
+ )
start_time = cal[0]
end_time = cal[-1]
@@ -963,7 +976,7 @@ def dataset(
get_module_logger("data").debug("get result")
try:
# pre-mound nfs, used for demo
- mnt_feature_uri = C.dpm.get_data_path(freq).joinpath(C.dataset_cache_dir_name, feature_uri)
+ mnt_feature_uri = C.dpm.get_data_uri(freq).joinpath(C.dataset_cache_dir_name, feature_uri)
df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields)
get_module_logger("data").debug("finish slicing data")
if return_uri:
@@ -1063,7 +1076,7 @@ class ClientProvider(BaseProvider):
- Instruments (with filters): Respond a list/dict of instruments
- Features : Respond a cache uri
The general workflow is described as follows:
- When the user use client provider to propose a request, the client provider will connect the server and send the request. The client will start to wait for the response. The response will be made instantly indicating whether the cache is available. The waiting procedure will terminate only when the client get the reponse saying `feature_available` is true.
+ When the user use client provider to propose a request, the client provider will connect the server and send the request. The client will start to wait for the response. The response will be made instantly indicating whether the cache is available. The waiting procedure will terminate only when the client get the response saying `feature_available` is true.
`BUG` : Everytime we make request for certain data we need to connect to the server, wait for the response and disconnect from it. We can't make a sequence of requests within one connection. You can refer to https://python-socketio.readthedocs.io/en/latest/client.html for documentation of python-socketIO client.
"""
@@ -1143,13 +1156,13 @@ def register_all_wrappers(C):
if getattr(C, "expression_cache", None) is not None:
_eprovider = init_instance_by_config(C.expression_cache, module, provider=_eprovider)
register_wrapper(ExpressionD, _eprovider, "qlib.data")
- logger.debug(f"registering ExpressioneD {C.expression_provider}-{C.expression_cache}")
+ logger.debug(f"registering ExpressionD {C.expression_provider}-{C.expression_cache}")
_dprovider = init_instance_by_config(C.dataset_provider, module)
if getattr(C, "dataset_cache", None) is not None:
_dprovider = init_instance_by_config(C.dataset_cache, module, provider=_dprovider)
register_wrapper(DatasetD, _dprovider, "qlib.data")
- logger.debug(f"registering DataseteD {C.dataset_provider}-{C.dataset_cache}")
+ logger.debug(f"registering DatasetD {C.dataset_provider}-{C.dataset_cache}")
register_wrapper(D, C.provider, "qlib.data")
logger.debug(f"registering D {C.provider}")
diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py
index 92e73de048..507e5ea81c 100644
--- a/qlib/data/dataset/handler.py
+++ b/qlib/data/dataset/handler.py
@@ -17,7 +17,7 @@
from ...config import C
from ...utils import parse_config, transform_end_date, init_instance_by_config
from ...utils.serial import Serializable
-from .utils import fetch_df_by_index
+from .utils import fetch_df_by_index, fetch_df_by_col
from ...utils import lazy_sort_index
from pathlib import Path
from .loader import DataLoader
@@ -154,14 +154,6 @@ def setup_data(self, enable_cache: bool = False):
CS_ALL = "__all" # return all columns with single-level index column
CS_RAW = "__raw" # return raw data with multi-level index column
- def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
- if not isinstance(df.columns, pd.MultiIndex) or col_set == self.CS_RAW:
- return df
- elif col_set == self.CS_ALL:
- return df.droplevel(axis=1, level=0)
- else:
- return df.loc(axis=1)[col_set]
-
def fetch(
self,
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
@@ -185,7 +177,7 @@ def fetch(
select a set of meaningful columns.(e.g. features, columns)
- if cal_set == CS_RAW:
+ if col_set == CS_RAW:
the raw dataset will be returned.
- if isinstance(col_set, List[str]):
@@ -207,23 +199,41 @@ def fetch(
-------
pd.DataFrame.
"""
- if proc_func is None:
- df = self._data
+ from .storage import BaseHandlerStorage
+
+ data_storage = self._data
+ if isinstance(data_storage, pd.DataFrame):
+ data_df = data_storage
+ if proc_func is not None:
+ # FIXME: fetching by time first will be more friendly to `proc_func`
+ # Copy in case of `proc_func` changing the data inplace....
+ data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
+ data_df = fetch_df_by_col(data_df, col_set)
+ else:
+ # Fetch column first will be more friendly to SepDataFrame
+ data_df = fetch_df_by_col(data_df, col_set)
+ data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
+ elif isinstance(data_storage, BaseHandlerStorage):
+ if not data_storage.is_proc_func_supported():
+ if proc_func is not None:
+ raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
+ data_df = data_storage.fetch(
+ selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
+ )
+ else:
+ data_df = data_storage.fetch(
+ selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
+ )
else:
- # FIXME: fetching by time first will be more friendly to `proc_func`
- # Copy in case of `proc_func` changing the data inplace....
- df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
+ raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
- # Fetch column first will be more friendly to SepDataFrame
- df = self._fetch_df_by_col(df, col_set)
- df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
if squeeze:
# squeeze columns
- df = df.squeeze()
+ data_df = data_df.squeeze()
# squeeze index
if isinstance(selector, (str, pd.Timestamp)):
- df = df.reset_index(level=level, drop=True)
- return df
+ data_df = data_df.reset_index(level=level, drop=True)
+ return data_df
def get_cols(self, col_set=CS_ALL) -> list:
"""
@@ -240,7 +250,7 @@ def get_cols(self, col_set=CS_ALL) -> list:
list of column names
"""
df = self._data.head()
- df = self._fetch_df_by_col(df, col_set)
+ df = fetch_df_by_col(df, col_set)
return df.columns.to_list()
def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice:
@@ -564,14 +574,36 @@ def fetch(
-------
pd.DataFrame:
"""
- df = self._get_df_by_key(data_key)
- if proc_func is not None:
- # FIXME: fetch by time first will be more friendly to proc_func
- # Copy incase of `proc_func` changing the data inplace....
- df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
- # Fetch column first will be more friendly to SepDataFrame
- df = self._fetch_df_by_col(df, col_set)
- return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
+ from .storage import BaseHandlerStorage
+
+ data_storage = self._get_df_by_key(data_key)
+ if isinstance(data_storage, pd.DataFrame):
+ data_df = data_storage
+ if proc_func is not None:
+ # FIXME: fetch by time first will be more friendly to proc_func
+ # Copy in case of `proc_func` changing the data in place
+ data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
+ data_df = fetch_df_by_col(data_df, col_set)
+ else:
+ # Fetch column first will be more friendly to SepDataFrame
+ data_df = fetch_df_by_col(data_df, col_set)
+ data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
+
+ elif isinstance(data_storage, BaseHandlerStorage):
+ if not data_storage.is_proc_func_supported():
+ if proc_func is not None:
+ raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
+ data_df = data_storage.fetch(
+ selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
+ )
+ else:
+ data_df = data_storage.fetch(
+ selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
+ )
+ else:
+ raise TypeError(f"data_storage should be pd.DataFrame|BaseHandlerStorage, not {type(data_storage)}")
+
+ return data_df
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
"""
@@ -590,5 +622,5 @@ def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
list of column names
"""
df = self._get_df_by_key(data_key).head()
- df = self._fetch_df_by_col(df, col_set)
+ df = fetch_df_by_col(df, col_set)
return df.columns.to_list()
diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py
index 9114ac918e..6de9183466 100644
--- a/qlib/data/dataset/processor.py
+++ b/qlib/data/dataset/processor.py
@@ -327,3 +327,12 @@ def __call__(self, df):
cols = get_group_columns(df, self.fields_group)
df[cols] = df[cols].groupby("datetime").apply(lambda x: x.fillna(x.mean()))
return df
+
+
+class HashStockFormat(Processor):
+ """Process the storage of from df into hasing stock format"""
+
+ def __call__(self, df: pd.DataFrame):
+ from .storage import HashingStockStorage
+
+ return HashingStockStorage.from_df(df)
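+
+
+# Illustrative usage sketch (an assumption about typical configs, not prescribed by this diff):
+# appending {"class": "HashStockFormat"} as the last processor of a handler makes the
+# handler keep its data as a HashingStockStorage instead of a pd.DataFrame.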
diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py
new file mode 100644
index 0000000000..7f556f497d
--- /dev/null
+++ b/qlib/data/dataset/storage.py
@@ -0,0 +1,157 @@
+import pandas as pd
+import numpy as np
+
+from .handler import DataHandler
+from typing import Tuple, Union, List, Callable
+
+from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col
+
+
+class BaseHandlerStorage:
+ """Base data storage for datahandler
+ - pd.DataFrame is the default data storage format in Qlib datahandler
+ - If users want to use a custom data storage, they should define a subclass that inherits BaseHandlerStorage and implement the following methods
+ """
+
+ def fetch(
+ self,
+ selector: Union[pd.Timestamp, slice, str, list] = slice(None, None),
+ level: Union[str, int] = "datetime",
+ col_set: Union[str, List[str]] = DataHandler.CS_ALL,
+ fetch_orig: bool = True,
+ proc_func: Callable = None,
+ **kwargs,
+ ) -> pd.DataFrame:
+ """fetch data from the data storage
+
+ Parameters
+ ----------
+ selector : Union[pd.Timestamp, slice, str]
+ describe how to select data by index
+ level : Union[str, int]
+ which index level to select the data
+ - if level is None, apply selector to df directly
+ col_set : Union[str, List[str]]
+ - if isinstance(col_set, str):
+ select a set of meaningful columns (e.g. features, columns)
+ if col_set == DataHandler.CS_RAW:
+ the raw dataset will be returned.
+ - if isinstance(col_set, List[str]):
+ select several sets of meaningful columns; the returned data has multiple column levels
+ fetch_orig : bool
+ Return the original data instead of copy if possible.
+ proc_func: Callable
+ please refer to the doc of DataHandler.fetch
+
+ Returns
+ -------
+ pd.DataFrame
+ the dataframe fetched
+ """
+ raise NotImplementedError("fetch method is not implemented!")
+
+ @staticmethod
+ def from_df(df: pd.DataFrame):
+ raise NotImplementedError("from_df method is not implemented!")
+
+ def is_proc_func_supported(self):
+ """whether the arg `proc_func` in `fetch` method is supported."""
+ raise NotImplementedError("is_proc_func_supported method is not implemented!")
+
+
+class HashingStockStorage(BaseHandlerStorage):
+ """Hashing data storage for datahandler
+ - The default data storage pandas.DataFrame is too slow when randomly accessing one stock's data
+ - HashingStockStorage hashes the multiple stocks' data (pandas.DataFrame) by the key `stock_id`.
+ - HashingStockStorage hashes the pandas.DataFrame into a dict, whose key is the stock_id (str) and whose value is that stock's data (pandas.DataFrame); it has the following format:
+ {
+ stock1_id: stock1_data,
+ stock2_id: stock2_data,
+ ...
+ stockn_id: stockn_data,
+ }
+ - By the `fetch` method, users can access any stock's data with a much lower time cost than the default data storage
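+
+ An illustrative usage sketch (assuming `df` has a 2-level MultiIndex with an "instrument" level):
+
+ .. code-block:: python
+
+ storage = HashingStockStorage.from_df(df)
+ # a single stock is fetched via a dict lookup instead of a full-index scan
+ sub_df = storage.fetch(selector="SH600000", level="instrument")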
+ """
+
+ def __init__(self, df):
+ self.hash_df = dict()
+ self.stock_level = get_level_index(df, "instrument")
+ for k, v in df.groupby(level="instrument"):
+ self.hash_df[k] = v
+ self.columns = df.columns
+
+ @staticmethod
+ def from_df(df):
+ return HashingStockStorage(df)
+
+ def _fetch_hash_df_by_stock(self, selector, level):
+ """fetch the data with stock selector
+
+ Parameters
+ ----------
+ selector : Union[pd.Timestamp, slice, str]
+ describe how to select data by index
+ level : Union[str, int]
+ which index level to select the data
+ - if level is None, apply selector to df directly
+ - the `_fetch_hash_df_by_stock` will parse the stock selector in arg `selector`
+
+ Returns
+ -------
+ Dict
+ The dict whose key is stock_id, value is the stock's data
+ """
+
+ stock_selector = slice(None)
+
+ if level is None:
+ if isinstance(selector, tuple) and self.stock_level < len(selector):
+ stock_selector = selector[self.stock_level]
+ elif isinstance(selector, (list, str)) and self.stock_level == 0:
+ stock_selector = selector
+ elif level == "instrument" or level == self.stock_level:
+ if isinstance(selector, tuple):
+ stock_selector = selector[0]
+ elif isinstance(selector, (list, str)):
+ stock_selector = selector
+
+ if not isinstance(stock_selector, (list, str)) and stock_selector != slice(None):
+ raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}")
+
+ if stock_selector == slice(None):
+ return self.hash_df
+
+ if isinstance(stock_selector, str):
+ stock_selector = [stock_selector]
+
+ select_dict = dict()
+ for each_stock in sorted(stock_selector):
+ if each_stock in self.hash_df:
+ select_dict[each_stock] = self.hash_df[each_stock]
+ return select_dict
+
+ def fetch(
+ self,
+ selector: Union[pd.Timestamp, slice, str] = slice(None, None),
+ level: Union[str, int] = "datetime",
+ col_set: Union[str, List[str]] = DataHandler.CS_ALL,
+ fetch_orig: bool = True,
+ ) -> pd.DataFrame:
+ fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values())
+ for _index, stock_df in enumerate(fetch_stock_df_list):
+ fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set)
+ fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level, fetch_orig=fetch_orig)
+ fetch_stock_df_list[_index] = fetch_index_df
+ if len(fetch_stock_df_list) == 0:
+ index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument")
+ return pd.DataFrame(
+ index=pd.MultiIndex.from_arrays([[], []], names=index_names), columns=self.columns, dtype=np.float32
+ )
+ elif len(fetch_stock_df_list) == 1:
+ return fetch_stock_df_list[0]
+ else:
+ return pd.concat(fetch_stock_df_list, sort=False, copy=not fetch_orig)
+
+ def is_proc_func_supported(self):
+ """the arg `proc_func` in `fetch` method is not supported in HasingStockStorage"""
+ return False
diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py
index feda190446..c6b3d97b62 100644
--- a/qlib/data/dataset/utils.py
+++ b/qlib/data/dataset/utils.py
@@ -1,5 +1,8 @@
-from typing import Union
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
import pandas as pd
+from typing import Union, List
def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
@@ -70,3 +73,41 @@ def fetch_df_by_index(
return df.loc[
pd.IndexSlice[idx_slc],
]
+
+
+def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame:
+ from .handler import DataHandler
+
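+ # select columns by `col_set`: CS_RAW keeps the raw multi-level columns, CS_ALL
+ # drops the column-group level, and a named group (e.g. "feature") selects that group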
+ if not isinstance(df.columns, pd.MultiIndex) or col_set == DataHandler.CS_RAW:
+ return df
+ elif col_set == DataHandler.CS_ALL:
+ return df.droplevel(axis=1, level=0)
+ else:
+ return df.loc(axis=1)[col_set]
+
+
+def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datetime") -> Union[pd.DataFrame, pd.Series]:
+ """
+ Convert the format of df.MultiIndex according to the following rules:
+ - If `level` is the first level of df.MultiIndex, do nothing
+ - If `level` is the second level of df.MultiIndex, swap the level of index.
+
+ NOTE:
+ the number of levels of df.MultiIndex should be 2
+
+ Parameters
+ ----------
+ df : Union[pd.DataFrame, pd.Series]
+ raw DataFrame/Series
+ level : str, optional
+ the level that will be converted to the first one, by default "datetime"
+
+ Returns
+ -------
+ Union[pd.DataFrame, pd.Series]
+ converted DataFrame/Series
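+
+ Example (illustrative, `pred` being a Series/DataFrame with a 2-level MultiIndex):
+
+ .. code-block:: python
+
+ # a Series indexed by (instrument, datetime) is swapped to (datetime, instrument)
+ pred = convert_index_format(pred, level="datetime")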
+ """
+
+ if get_level_index(df, level=level) == 1:
+ df = df.swaplevel().sort_index()
+ return df
diff --git a/qlib/data/ops.py b/qlib/data/ops.py
index a34b2ed354..532072f89d 100644
--- a/qlib/data/ops.py
+++ b/qlib/data/ops.py
@@ -1405,7 +1405,7 @@ def __init__(self, feature_left, feature_right, N):
super(Corr, self).__init__(feature_left, feature_right, N, "corr")
def _load_internal(self, instrument, start_index, end_index, freq):
- res = super(Corr, self)._load_internal(instrument, start_index, end_index, freq)
+ res: pd.Series = super(Corr, self)._load_internal(instrument, start_index, end_index, freq)
# NOTE: Load uses MemCache, so calling load again will not cause performance degradation
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py
index a2b145c4df..5114ced3ca 100644
--- a/qlib/data/storage/file_storage.py
+++ b/qlib/data/storage/file_storage.py
@@ -8,8 +8,11 @@
import numpy as np
import pandas as pd
+from qlib.utils.time import Freq
+from qlib.utils.resam import resam_calendar
from qlib.log import get_module_logger
from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT
+from qlib.data.cache import H
logger = get_module_logger("file_storage")
@@ -39,6 +42,7 @@ def check(self):
class FileCalendarStorage(FileStorageMixin, CalendarStorage):
def __init__(self, freq: str, future: bool, **kwargs):
super(FileCalendarStorage, self).__init__(freq, future, **kwargs)
+ self.future = future
self.file_name = f"{freq}_future.txt" if future else f"{freq}.txt".lower()
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
@@ -56,8 +60,31 @@ def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
@property
def data(self) -> List[CalVT]:
- self.check()
- return self._read_calendar()
+ # NOTE: uri
+ # 1. If `uri` does not exist
+ # - Get the `min_uri` of the closest `freq` under the same "directory" as the `uri`
+ # - Read data from `min_uri` and resample to `freq`
+ try:
+ self.check()
+ _calendar = self._read_calendar()
+ except ValueError:
+ freq_list = self._get_storage_freq()
+ _freq = Freq.get_recent_freq(self.freq, freq_list)
+ if _freq is None:
+ raise ValueError(f"can't find a freq from {freq_list} that can resample to {self.freq}!")
+ self.file_name = f"{_freq}_future.txt" if self.future else f"{_freq}.txt".lower()
+ # The cache is useful for the following cases
+ # - multiple frequencies are sampled from the same calendar
+ cache_key = self.uri
+ if cache_key not in H["c"]:
+ H["c"][cache_key] = self._read_calendar()
+ _calendar = H["c"][cache_key]
+ _calendar = resam_calendar(np.array(list(map(pd.Timestamp, _calendar))), _freq, self.freq)
+
+ return _calendar
+
+ def _get_storage_freq(self) -> List[str]:
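+ # infer the available freqs from the calendar file names in the same directory,
+ # e.g. "day.txt" -> "day", "1min_future.txt" -> "1min"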
+ return sorted(set(map(lambda x: x.stem.split("_")[0], self.uri.parent.glob("*.txt"))))
def extend(self, values: Iterable[CalVT]) -> None:
self._write_calendar(values, mode="ab")
diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py
index fb731196d1..9c6866823f 100644
--- a/qlib/model/trainer.py
+++ b/qlib/model/trainer.py
@@ -12,6 +12,8 @@
"""
import socket
+import time
+import re
from typing import Callable, List
from qlib.data.dataset import Dataset
@@ -44,6 +46,47 @@ def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str
return recorder
+def fill_placeholder(config: dict, config_extend: dict):
+ """
+ Detect placeholders in config and fill them with values from config_extend.
+ The items of the dict must be single values (int, str, etc.), dicts, or lists. Tuples are not supported.
+
+ Parameters
+ ----------
+ config : dict
+ the parameter dict will be filled
+ config_extend : dict
+ the value of all placeholders
+
+ Returns
+ -------
+ dict
+ the parameter dict
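+
+ Example (illustrative):
+
+ .. code-block:: python
+
+ # `model` is assumed to be an already-trained model object
+ config = {"record": {"kwargs": {"model": "<MODEL>"}}}
+ fill_placeholder(config, {"<MODEL>": model})
+ # -> {"record": {"kwargs": {"model": model}}}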
+ """
+ # check the format of config_extend
+ for placeholder in config_extend.keys():
+ assert re.match(r"<[^<>]+>", placeholder)
+
+ # bfs: traverse the nested dicts/lists breadth-first and replace any value that
+ # matches a placeholder key in config_extend
+ top = 0
+ tail = 1
+ item_queue = [config]
+ while top < tail:
+ now_item = item_queue[top]
+ top += 1
+ if isinstance(now_item, list):
+ item_keys = range(len(now_item))
+ elif isinstance(now_item, dict):
+ item_keys = now_item.keys()
+ for key in item_keys:
+ if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
+ item_queue.append(now_item[key])
+ tail += 1
+ elif now_item[key] in config_extend.keys():
+ now_item[key] = config_extend[now_item[key]]
+ return config
+
+
def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
"""
Finish task training with real model fitting and saving.
@@ -66,19 +109,16 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
# this dataset is saved for online inference. So the concrete data should not be dumped
dataset.config(dump_all=False, recursive=True)
R.save_objects(**{"dataset": dataset})
+ # fill placeholder
+ placeholder_value = {"<MODEL>": model, "<DATASET>": dataset}
+ task_config = fill_placeholder(task_config, placeholder_value)
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
if isinstance(records, dict): # prevent only one dict
records = [records]
for record in records:
- cls, kwargs = get_callable_kwargs(record, default_module="qlib.workflow.record_temp")
- if cls is SignalRecord:
- rconf = {"model": model, "dataset": dataset, "recorder": rec}
- else:
- rconf = {"recorder": rec}
- r = cls(**kwargs, **rconf)
+ r = init_instance_by_config(record, recorder=rec)
r.generate()
-
return rec
diff --git a/qlib/rl/__init__.py b/qlib/rl/__init__.py
new file mode 100644
index 0000000000..59e481eb93
--- /dev/null
+++ b/qlib/rl/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
diff --git a/qlib/rl/env.py b/qlib/rl/env.py
new file mode 100644
index 0000000000..3a77d22954
--- /dev/null
+++ b/qlib/rl/env.py
@@ -0,0 +1,95 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from typing import Union
+
+
+from ..backtest.executor import BaseExecutor
+from .interpreter import StateInterpreter, ActionInterpreter
+from ..utils import init_instance_by_config
+from .interpreter import BaseInterpreter
+
+
+class BaseRLEnv:
+ """Base environment for reinforcement learning"""
+
+ def reset(self, **kwargs):
+ raise NotImplementedError("reset is not implemented!")
+
+ def step(self, action):
+ """
+ step method of rl env
+ Parameters
+ ----------
+ action :
+ action from rl policy
+
+ Returns
+ -------
+ env state to rl policy
+ """
+ raise NotImplementedError("step is not implemented!")
+
+
+class QlibRLEnv:
+ """qlib-based RL env"""
+
+ def __init__(
+ self,
+ executor: BaseExecutor,
+ ):
+ """
+ Parameters
+ ----------
+ executor : BaseExecutor
+ qlib multi-level/single-level executor, which can be regarded as the gamecore in RL
+ """
+ self.executor = executor
+
+ def reset(self, **kwargs):
+ self.executor.reset(**kwargs)
+
+
+class QlibIntRLEnv(QlibRLEnv):
+ """(Qlib)-based RL (Env) with (Interpreter)"""
+
+ def __init__(
+ self,
+ executor: BaseExecutor,
+ state_interpreter: Union[dict, StateInterpreter],
+ action_interpreter: Union[dict, ActionInterpreter],
+ ):
+ """
+
+ Parameters
+ ----------
+ state_interpreter : Union[dict, StateInterpreter]
+ interpreter that interprets the qlib execution result into the rl env state.
+
+ action_interpreter : Union[dict, ActionInterpreter]
+ interpreter that interprets the rl agent action into a qlib order list
+ """
+ super(QlibIntRLEnv, self).__init__(executor=executor)
+ self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
+ self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
+
+ def step(self, action):
+ """
+ step method of rl env; it runs the following steps:
+ - Use `action_interpreter.interpret` method to interpret the agent action into order list
+ - Execute the order list with qlib executor, and get the executed result
+ - Use `state_interpreter.interpret` method to interpret the executed result into env state
+
+ Parameters
+ ----------
+ action :
+ action from rl policy
+
+ Returns
+ -------
+ env state to rl policy
+ """
+ _interpret_decision = self.action_interpreter.interpret(action=action)
+ _execute_result = self.executor.execute(trade_decision=_interpret_decision)
+ _interpret_state = self.state_interpreter.interpret(execute_result=_execute_result)
+ return _interpret_state
diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py
new file mode 100644
index 0000000000..c711b83808
--- /dev/null
+++ b/qlib/rl/interpreter.py
@@ -0,0 +1,47 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+class BaseInterpreter:
+ """Base Interpreter"""
+
+ def interpret(self, **kwargs):
+ raise NotImplementedError("interpret is not implemented!")
+
+
+class ActionInterpreter(BaseInterpreter):
+ """Action Interpreter that interpret rl agent action into qlib orders"""
+
+ def interpret(self, action, **kwargs):
+ """interpret method
+
+ Parameters
+ ----------
+ action :
+ rl agent action
+
+ Returns
+ -------
+ qlib orders
+
+ """
+
+ raise NotImplementedError("interpret is not implemented!")
+
+
+class StateInterpreter(BaseInterpreter):
+ """State Interpreter that interpret execution result of qlib executor into rl env state"""
+
+ def interpret(self, execute_result, **kwargs):
+ """interpret method
+
+ Parameters
+ ----------
+ execute_result :
+ qlib execution result
+
+ Returns
+ ----------
+ rl env state
+ """
+ raise NotImplementedError("interpret is not implemented!")
diff --git a/qlib/strategy/__init__.py b/qlib/strategy/__init__.py
new file mode 100644
index 0000000000..59e481eb93
--- /dev/null
+++ b/qlib/strategy/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py
new file mode 100644
index 0000000000..f707f7ff5b
--- /dev/null
+++ b/qlib/strategy/base.py
@@ -0,0 +1,289 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from qlib.backtest.exchange import Exchange
+ from qlib.backtest.position import BasePosition
+from typing import List, Tuple, Union
+
+from ..model.base import BaseModel
+from ..data.dataset import DatasetH
+from ..data.dataset.utils import convert_index_format
+from ..rl.interpreter import ActionInterpreter, StateInterpreter
+from ..utils import init_instance_by_config
+from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
+from ..backtest.decision import BaseTradeDecision
+
+__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"]
+
+
+class BaseStrategy:
+ """Base strategy for trading"""
+
+ def __init__(
+ self,
+ outer_trade_decision: BaseTradeDecision = None,
+ level_infra: LevelInfrastructure = None,
+ common_infra: CommonInfrastructure = None,
+ trade_exchange: Exchange = None,
+ ):
+ """
+ Parameters
+ ----------
+ outer_trade_decision : BaseTradeDecision, optional
+ the trade decision of the outer strategy on which this strategy relies; it will be traded in [start_time, end_time], by default None
+ - If the strategy is used to split trade decision, it will be used
+ - If the strategy is used for portfolio management, it can be ignored
+ level_infra : LevelInfrastructure, optional
+ level shared infrastructure for backtesting, including trade calendar
+ common_infra : CommonInfrastructure, optional
+ common infrastructure for backtesting, including trade_account, trade_exchange, etc.
+
+ trade_exchange : Exchange
+ exchange that provides market info, used to deal order and generate report
+ - If `trade_exchange` is None, self.trade_exchange will be set with common_infra
+ - It allows different trade_exchanges to be used in different executions.
+ - For example:
+ - In daily execution, both the daily exchange and the minutely exchange are usable, but the daily exchange is recommended because it runs faster.
+ - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
+ """
+
+ self._reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision)
+ self._trade_exchange = trade_exchange
+
+ @property
+ def trade_calendar(self) -> TradeCalendarManager:
+ return self.level_infra.get("trade_calendar")
+
+ @property
+ def trade_position(self) -> BasePosition:
+ return self.common_infra.get("trade_account").current_position
+
+ @property
+ def trade_exchange(self) -> Exchange:
+ """get trade exchange in a prioritized order"""
+ return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange")
+
+ def reset_level_infra(self, level_infra: LevelInfrastructure):
+ if not hasattr(self, "level_infra"):
+ self.level_infra = level_infra
+ else:
+ self.level_infra.update(level_infra)
+
+ def reset_common_infra(self, common_infra: CommonInfrastructure):
+ if not hasattr(self, "common_infra"):
+ self.common_infra: CommonInfrastructure = common_infra
+ else:
+ self.common_infra.update(common_infra)
+
+ def reset(
+ self,
+ level_infra: LevelInfrastructure = None,
+ common_infra: CommonInfrastructure = None,
+ outer_trade_decision=None,
+ **kwargs,
+ ):
+ """
+ - reset `level_infra`, used to reset the trade calendar, etc.
+ - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, etc.
+ - reset `outer_trade_decision`, used to make split decision
+
+ **NOTE**:
+ splitting this function into `reset` and `_reset` will make the following cases more convenient
+ 1. Users want to initialize their strategy by overriding `reset`, but they don't want to affect the `_reset` called
+ during initialization
+ """
+ self._reset(
+ level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision, **kwargs
+ )
+
+ def _reset(
+ self,
+ level_infra: LevelInfrastructure = None,
+ common_infra: CommonInfrastructure = None,
+ outer_trade_decision=None,
+ ):
+ """
+ Please refer to the docs of `reset`
+ """
+ if level_infra is not None:
+ self.reset_level_infra(level_infra)
+
+ if common_infra is not None:
+ self.reset_common_infra(common_infra)
+
+ if outer_trade_decision is not None:
+ self.outer_trade_decision = outer_trade_decision
+
+ def generate_trade_decision(self, execute_result=None):
+ """Generate trade decision in each trading bar
+
+ Parameters
+ ----------
+ execute_result : List[object], optional
+ the executed result for trade decision, by default None
+ - When generate_trade_decision is called for the first time, `execute_result` could be None
+ """
+ raise NotImplementedError("generate_trade_decision is not implemented!")
+
+ def update_trade_decision(
+ self, trade_decision: BaseTradeDecision, trade_calendar: TradeCalendarManager
+ ) -> Union[BaseTradeDecision, None]:
+ """
+ update the trade decision in each step of the inner execution; this method enables updating all orders of the decision
+
+ Parameters
+ ----------
+ trade_decision : BaseTradeDecision
+ the trade decision that will be updated
+ trade_calendar : TradeCalendarManager
+ The calendar of the **inner strategy**!
+
+ Returns
+ -------
+ BaseTradeDecision:
+ """
+ # default to return None, which indicates that the trade decision is not changed
+ return None
+
+ def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision):
+ """
+ A method for updating the outer_trade_decision.
+ The outer strategy may change its decision during updating.
+
+ Parameters
+ ----------
+ outer_trade_decision : BaseTradeDecision
+ the decision updated by the outer strategy
+ """
+ # default to reset the decision directly
+ # NOTE: normally, user should do something to the strategy due to the change of outer decision
+ raise NotImplementedError(f"Please implement the `alter_outer_trade_decision` method")
+
+ # helper methods: not necessary but for convenience
+ def get_data_cal_avail_range(self, rtype: str = "full") -> Tuple[int, int]:
+ """
+ return the data calendar's available decision range for the `self` strategy
+ the range considers the following factors
+ - the data calendar in the charge of the `self` strategy
+ - the trading range limitation from the decision of the outer strategy
+
+
+ related methods
+ - TradeCalendarManager.get_data_cal_range
+ - BaseTradeDecision.get_data_cal_range_limit
+
+ Parameters
+ ----------
+ rtype: str
+ - "full": return the available data index range of the strategy from `start_time` to `end_time`
+ - "step": return the available data index range of the strategy of current step
+
+ Returns
+ -------
+ Tuple[int, int]:
+ the available range; both sides are closed
+ """
+ cal_range = self.trade_calendar.get_data_cal_range(rtype=rtype)
+ if self.outer_trade_decision is None:
+ raise ValueError(f"There is no limitation for strategy {self}")
+ range_limit = self.outer_trade_decision.get_data_cal_range_limit(rtype=rtype)
+ return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
+
+
+class ModelStrategy(BaseStrategy):
+ """Model-based trading strategy, use model to make predictions for trading"""
+
+ def __init__(
+ self,
+ model: BaseModel,
+ dataset: DatasetH,
+ outer_trade_decision: BaseTradeDecision = None,
+ level_infra: LevelInfrastructure = None,
+ common_infra: CommonInfrastructure = None,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ model : BaseModel
+ the model used when making predictions
+ dataset : DatasetH
+ provide test data for model
+ kwargs : dict
+ arguments that will be passed into `reset` method
+ """
+ super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
+ self.model = model
+ self.dataset = dataset
+ self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
+
+ def _update_model(self):
+ """
+ When using online data, update the model in each bar with the following steps:
+ - update the dataset with online data; the dataset should support online updates
+ - make the latest prediction scores for the new bar
+ - merge the new scores into the latest predictions
+ """
+ raise NotImplementedError("_update_model is not implemented!")
+
+
+class RLStrategy(BaseStrategy):
+ """RL-based strategy"""
+
+ def __init__(
+ self,
+ policy,
+ outer_trade_decision: BaseTradeDecision = None,
+ level_infra: LevelInfrastructure = None,
+ common_infra: CommonInfrastructure = None,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ policy :
+ RL policy for generating actions
+ """
+ super(RLStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
+ self.policy = policy
+
+
+class RLIntStrategy(RLStrategy):
+ """(RL)-based (Strategy) with (Int)erpreter"""
+
+ def __init__(
+ self,
+ policy,
+ state_interpreter: Union[dict, StateInterpreter],
+ action_interpreter: Union[dict, ActionInterpreter],
+ outer_trade_decision: BaseTradeDecision = None,
+ level_infra: LevelInfrastructure = None,
+ common_infra: CommonInfrastructure = None,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ state_interpreter : Union[dict, StateInterpreter]
+ interpreter that interprets the qlib execution result into the rl env state
+ action_interpreter : Union[dict, ActionInterpreter]
+ interpreter that interprets the rl agent action into a qlib order list
+ start_time : Union[str, pd.Timestamp], optional
+ start time of trading, by default None
+ end_time : Union[str, pd.Timestamp], optional
+ end time of trading, by default None
+ """
+ super(RLIntStrategy, self).__init__(policy, outer_trade_decision, level_infra, common_infra, **kwargs)
+
+ self.policy = policy
+ self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
+ self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
+
+ def generate_trade_decision(self, execute_result=None):
+ _interpret_state = self.state_interpreter.interpret(execute_result=execute_result)
+ _action = self.policy.step(_interpret_state)
+ _trade_decision = self.action_interpreter.interpret(action=_action)
+ return _trade_decision
diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py
index 7f43cd99ac..cc452ae0f3 100644
--- a/qlib/tests/__init__.py
+++ b/qlib/tests/__init__.py
@@ -8,17 +8,72 @@ class TestAutoData(unittest.TestCase):
_setup_kwargs = {}
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
+ provider_uri_1day = "~/.qlib/qlib_data/cn_data" # target_dir
+ provider_uri_1min = "~/.qlib/qlib_data/cn_data_1min"
@classmethod
- def setUpClass(cls) -> None:
+ def setUpClass(cls, enable_1d_type="simple", enable_1min=False) -> None:
# use default data
+ if enable_1d_type == "simple":
+ provider_uri_day = cls.provider_uri
+ name_day = "qlib_data_simple"
+ elif enable_1d_type == "full":
+ provider_uri_day = cls.provider_uri_1day
+ name_day = "qlib_data"
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
GetData().qlib_data(
- name="qlib_data_simple",
+ name=name_day,
region=REG_CN,
interval="1d",
- target_dir=cls.provider_uri,
+ target_dir=provider_uri_day,
delete_old=False,
exists_skip=True,
)
- init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs)
+
+ if enable_1min:
+ GetData().qlib_data(
+ name="qlib_data",
+ region=REG_CN,
+ interval="1min",
+ target_dir=cls.provider_uri_1min,
+ delete_old=False,
+ exists_skip=True,
+ )
+
+ provider_uri_map = {"1min": cls.provider_uri_1min, "day": provider_uri_day}
+
+ client_config = {
+ "calendar_provider": {
+ "class": "LocalCalendarProvider",
+ "module_path": "qlib.data.data",
+ "kwargs": {
+ "backend": {
+ "class": "FileCalendarStorage",
+ "module_path": "qlib.data.storage.file_storage",
+ "kwargs": {"provider_uri_map": provider_uri_map},
+ }
+ },
+ },
+ "feature_provider": {
+ "class": "LocalFeatureProvider",
+ "module_path": "qlib.data.data",
+ "kwargs": {
+ "backend": {
+ "class": "FileFeatureStorage",
+ "module_path": "qlib.data.storage.file_storage",
+ "kwargs": {"provider_uri_map": provider_uri_map},
+ }
+ },
+ },
+ }
+ init(
+ provider_uri=cls.provider_uri,
+ region=REG_CN,
+ expression_cache=None,
+ dataset_cache=None,
+ **client_config,
+ **cls._setup_kwargs,
+ )
diff --git a/qlib/tests/data.py b/qlib/tests/data.py
index 2bfe435906..b38fd7eee3 100644
--- a/qlib/tests/data.py
+++ b/qlib/tests/data.py
@@ -14,7 +14,7 @@
class GetData:
- DATASET_VERSION = "v1"
+ DATASET_VERSION = "v2"
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
QLIB_DATA_NAME = "{dataset_name}_{region}_{interval}_{qlib_version}.zip"
diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py
index 7504f8d611..e247ea23ba 100644
--- a/qlib/utils/__init__.py
+++ b/qlib/utils/__init__.py
@@ -17,6 +17,7 @@
import shutil
import difflib
import hashlib
+import warnings
import datetime
import requests
import tempfile
@@ -210,10 +211,12 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod
the class/func object and it's arguments.
"""
if isinstance(config, dict):
- module = get_module_by_module_path(config.get("module_path", default_module))
-
- # raise AttributeError
- _callable = getattr(module, config["class" if "class" in config else "func"])
+ key = "class" if "class" in config else "func"
+ if isinstance(config[key], str):
+ module = get_module_by_module_path(config.get("module_path", default_module))
+ # raise AttributeError
+ _callable = getattr(module, config[key])
+ else:
+ _callable = config[key] # the class type itself is passed in
kwargs = config.get("kwargs", {})
elif isinstance(config, str):
module = get_module_by_module_path(default_module)
@@ -225,6 +228,9 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod
return _callable, kwargs
+get_cls_kwargs = get_callable_kwargs # NOTE: this is for compatibility for the previous version
+
+
def init_instance_by_config(
config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
) -> Any:
@@ -235,11 +241,17 @@ def init_instance_by_config(
----------
config : Union[str, dict, object]
dict example.
+ case 1)
{
'class': 'ClassName',
'kwargs': dict, # It is optional. {} will be used if not given
'model_path': path, # It is optional if module is given
}
+ case 2)
+ {
+ 'class': <the class itself>,
+ 'kwargs': dict, # It is optional. {} will be used if not given
+ }
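+ e.g. (an illustrative sketch)
+ {"class": SignalRecord, "kwargs": {"model": model, "dataset": dataset, "recorder": rec}}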
str example.
1) specify a pickle object
- path like 'file:////obj.pkl'
diff --git a/qlib/utils/exceptions.py b/qlib/utils/exceptions.py
index dad12506b7..dd9b3eaf63 100644
--- a/qlib/utils/exceptions.py
+++ b/qlib/utils/exceptions.py
@@ -1,17 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
+
# Base exception class
class QlibException(Exception):
def __init__(self, message):
super(QlibException, self).__init__(message)
-# Error type for reinitialization when starting an experiment
class RecorderInitializationError(QlibException):
+ """Error type for re-initialization when starting an experiment"""
+
pass
-# Error type for Recorder when can not load object
class LoadObjectError(QlibException):
+ """Error type for Recorder when can not load object"""
+
pass
diff --git a/qlib/utils/file.py b/qlib/utils/file.py
new file mode 100644
index 0000000000..611260c86d
--- /dev/null
+++ b/qlib/utils/file.py
@@ -0,0 +1,37 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+# TODO: move file related utils into this module
+import contextlib
+from typing import IO, Union
+from pathlib import Path
+
+
+@contextlib.contextmanager
+def get_io_object(file: Union[IO, str, Path], *args, **kwargs) -> IO:
+ """
+ provide an easy interface to get an IO object
+
+ Parameters
+ ----------
+ file : Union[IO, str, Path]
+ an object representing the file
+
+ Returns
+ -------
+ IO:
+ an IO-like object
+
+ Raises
+ ------
+ NotImplementedError:
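+
+ Example (illustrative):
+
+ .. code-block:: python
+
+ with get_io_object("some_file.txt", "r") as f:
+ content = f.read()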
+ """
+ if isinstance(file, IO):
+ yield file
+ else:
+ if isinstance(file, str):
+ file = Path(file)
+ if not isinstance(file, Path):
+ raise NotImplementedError(f"This type[{type(file)}] of input is not supported")
+ with file.open(*args, **kwargs) as f:
+ yield f
diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py
new file mode 100644
index 0000000000..5e3942db53
--- /dev/null
+++ b/qlib/utils/index_data.py
@@ -0,0 +1,636 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""
+Motivation of index_data
+- Pandas has a lot of user-friendly interfaces. However, integrating too many features into a single tool brings too much overhead and makes it much slower than numpy.
+ Some users just want a simple numpy dataframe with indices and don't want such a complicated tool.
+ Such users are the target of `index_data`
+
+`index_data` tries to behave like pandas (some APIs will be different because we try to be simpler and more intuitive) but doesn't compromise performance. It provides the basic numpy data and simple indexing features. If users call APIs which may compromise performance, index_data will raise Errors.
+"""
+
+from typing import Dict, Tuple, Union, Callable, List
+import bisect
+
+import numpy as np
+import pandas as pd
+
+
+def concat(data_list: List["SingleData"], axis=0) -> "MultiData":
+ """concat all SingleData by index.
+ TODO: now just for SingleData.
+
+ Parameters
+ ----------
+ data_list : List[SingleData]
+ the list of all SingleData to concat.
+
+ Returns
+ -------
+ MultiData
+ the MultiData with ndim == 2
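+
+ Example (illustrative):
+
+ .. code-block:: python
+
+ # indices are unioned and sorted; missing cells are filled with NaN
+ concat([SingleData({"a": 1.0}), SingleData({"a": 2.0, "b": 3.0})], axis=1)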
+ """
+ if axis == 0:
+ raise NotImplementedError(f"please implement this func when axis == 0")
+ elif axis == 1:
+ # get all index and row
+ all_index = set()
+ for index_data in data_list:
+ all_index = all_index | set(index_data.index)
+ all_index = list(all_index)
+ all_index.sort()
+ all_index_map = dict(zip(all_index, range(len(all_index))))
+
+ # concat all
+ tmp_data = np.full((len(all_index), len(data_list)), np.NaN)
+ for data_id, index_data in enumerate(data_list):
+ assert isinstance(index_data, SingleData)
+ now_data_map = [all_index_map[index] for index in index_data.index]
+ tmp_data[now_data_map, data_id] = index_data.data
+ return MultiData(tmp_data, all_index)
+ else:
+ raise ValueError(f"axis must be 0 or 1")
+
+
+def sum_by_index(data_list: List["SingleData"], new_index: list, fill_value=0) -> "SingleData":
+ """concat all SingleData by new index.
+
+ Parameters
+ ----------
+ data_list : List[SingleData]
+ the list of all SingleData to sum.
+ new_index : list
+ the new_index of new SingleData.
+ fill_value : float
+ fill the missing values or replace np.NaN.
+
+ Returns
+ -------
+ SingleData
+ the SingleData with new_index and values after sum.
+ """
+ data_list = [data.to_dict() for data in data_list]
+ data_sum = {}
+ for id in new_index:
+ item_sum = 0
+ for data in data_list:
+ if id in data and not np.isnan(data[id]):
+ item_sum += data[id]
+ else:
+ item_sum += fill_value
+ data_sum[id] = item_sum
+ return SingleData(data_sum)
+
+
+class Index:
+ """
+ This is for indexing(rows or columns)
+
+ Read-only operations have higher priority than others.
+ So this class is designed in a **read-only** way to share data for queries.
+ Modifications will result in a new Index.
+
+ NOTE: the indexing has the following flaws
+ - duplicated index values are not well supported (only the first appearance will be considered)
+ - The order of the index is not considered! So slicing will not behave like pandas when the indices are ordered
+ """
+
+ def __init__(self, idx_list: Union[List, pd.Index, "Index", int]):
+ self.idx_list: np.ndarray = None # using array type for index list will make things easier
+ if isinstance(idx_list, Index):
+ # Fast read-only copy
+ self.idx_list = idx_list.idx_list
+ self.index_map = idx_list.index_map
+ self._is_sorted = idx_list._is_sorted
+ elif isinstance(idx_list, int):
+ self.index_map = self.idx_list = np.arange(idx_list)
+ self._is_sorted = True
+ else:
+ self.idx_list = np.array(idx_list)
+ # NOTE: only the first appearance is indexed
+ self.index_map = dict(zip(self.idx_list, range(len(self))))
+ self._is_sorted = False
+
+ def __getitem__(self, i: int):
+ return self.idx_list[i]
+
+ def _convert_type(self, item):
+ """
+
+ After the user creates indices with Type A, the user may query data with other types carrying the same info.
+ This method tries to convert the type and make the query sane rather than strictly raising KeyError
+
+ Parameters
+ ----------
+ item :
+ The item to query index
+ """
+
+ if self.idx_list.dtype.type is np.datetime64:
+ if isinstance(item, pd.Timestamp):
+ # This happens often when creating index based on pandas.DatetimeIndex and query with pd.Timestamp
+ return item.to_numpy()
+ return item
+
+ def index(self, item) -> int:
+ """
+ Given the index value, get the integer index
+
+ Parameters
+ ----------
+ item :
+ The item to query
+
+ Returns
+ -------
+ int:
+ The index of the item
+
+ Raises
+ ------
+ KeyError:
+ If the query item does not exist
+ """
+ try:
+ return self.index_map[self._convert_type(item)]
+ except (IndexError, KeyError): # index_map may be a dict (raising KeyError) or an arange array (raising IndexError)
+ raise KeyError(f"{item} can't be found in {self}")
+
+ def __or__(self, other: "Index"):
+ return Index(idx_list=list(set(self.idx_list) | set(other.idx_list)))
+
+ def __eq__(self, other: "Index"):
+ # NOTE: np.nan is not supported in the index
+ if self.idx_list.shape != other.idx_list.shape:
+ return False
+ return (self.idx_list == other.idx_list).all()
+
+ def __len__(self):
+ return len(self.idx_list)
+
+ def is_sorted(self):
+ return self._is_sorted
+
+ def sort(self) -> Tuple["Index", np.ndarray]:
+ """
+ sort the index
+
+ Returns
+ -------
+ Tuple["Index", np.ndarray]:
+ the sorted Index and the changed index
+ """
+ sorted_idx = np.argsort(self.idx_list)
+ idx = Index(self.idx_list[sorted_idx])
+ idx._is_sorted = True
+ return idx, sorted_idx
+
+ def tolist(self):
+ """return the index with the format of list."""
+ return self.idx_list.tolist()
+
+
+class LocIndexer:
+ """
+ `Indexer` will behave like the `LocIndexer` in Pandas
+
+ Read-only operations have higher priority than others.
+ So this class is designed in a read-only way to share data for queries.
+ Modifications will result in a new Index.
+ """
+
+ def __init__(self, index_data: "IndexData", indices: List[Index], int_loc: bool = False):
+ self._indices: List[Index] = indices
+ self._bind_id = index_data # bind index data
+ self._int_loc = int_loc
+ assert self._bind_id.data.ndim == len(self._indices)
+
+ @staticmethod
+ def proc_idx_l(indices: List[Union[List, pd.Index, Index]], data_shape: Tuple = None) -> List[Index]:
+ """process the indices from user and output a list of `Index`"""
+ res = []
+ for i, idx in enumerate(indices):
+ res.append(Index(data_shape[i] if len(idx) == 0 else idx))
+ return res
+
+ def _slc_convert(self, index: Index, indexing: slice) -> slice:
+ """
+ convert value-based indexing to integer-based indexing.
+
+ Parameters
+ ----------
+ index : Index
+ index data.
+ indexing : slice
+ value based indexing data with slice type for indexing.
+
+ Returns
+ -------
+ slice:
+ the integer based slicing
+ """
+ if index.is_sorted():
+ int_start = None if indexing.start is None else bisect.bisect_left(index, indexing.start)
+ int_stop = None if indexing.stop is None else bisect.bisect_right(index, indexing.stop)
+ else:
+ int_start = None if indexing.start is None else index.index(indexing.start)
+ int_stop = None if indexing.stop is None else index.index(indexing.stop) + 1
+ return slice(int_start, int_stop)
+
+ def __getitem__(self, indexing):
+ """
+
+ Parameters
+ ----------
+ indexing :
+ query for data
+
+ Raises
+ ------
+ KeyError:
+ If the non-slice index is queried but does not exist, `KeyError` is raised.
+ """
+ # 1) convert slices to int loc
+ if not isinstance(indexing, tuple):
+ # NOTE: a tuple is always interpreted as multi-dimensional indexing, so a tuple can't be used as a single key
+ indexing = (indexing,)
+
+ # TODO: create a subclass for single value query
+ assert len(indexing) <= len(self._indices)
+
+ int_indexing = []
+ for dim, index in enumerate(self._indices):
+ if dim < len(indexing):
+ _indexing = indexing[dim]
+ if not self._int_loc: # type converting is only necessary when it is not `iloc`
+ if isinstance(_indexing, slice):
+ _indexing = self._slc_convert(index, _indexing)
+ elif isinstance(_indexing, (IndexData, np.ndarray)):
+ if isinstance(_indexing, IndexData):
+ _indexing = _indexing.data
+ assert _indexing.ndim == 1
+ if _indexing.dtype != np.bool:
+ _indexing = np.array(list(index.index(i) for i in _indexing))
+ else:
+ _indexing = index.index(_indexing)
+ else:
+ # Default to select all when user input is not given
+ _indexing = slice(None)
+ int_indexing.append(_indexing)
+
+ # 2) select data and index
+ new_data = self._bind_id.data[tuple(int_indexing)]
+ # return directly if it is scalar
+ if new_data.ndim == 0:
+ return new_data
+ # otherwise we go on to the index part
+ new_indices = [idx[indexing] for idx, indexing in zip(self._indices, int_indexing)]
+
+ # 3) squash dimensions
+ new_indices = [
+ idx for idx in new_indices if isinstance(idx, np.ndarray) and idx.ndim > 0
+ ] # squash the zero dim indexing
+
+ if new_data.ndim == 1:
+ cls = SingleData
+ elif new_data.ndim == 2:
+ cls = MultiData
+ else:
+ raise ValueError("Not supported")
+ return cls(new_data, *new_indices)
+
+
+class BinaryOps:
+ def __init__(self, method_name):
+ self.method_name = method_name
+
+ def __get__(self, obj, *args):
+ # bind object
+ self.obj = obj
+ return self
+
+ def __call__(self, other):
+ self_data_method = getattr(self.obj.data, self.method_name)
+
+ if isinstance(other, (int, float, np.number)):
+ return self.obj.__class__(self_data_method(other), *self.obj.indices)
+ elif isinstance(other, self.obj.__class__):
+ other_aligned = self.obj._align_indices(other)
+ return self.obj.__class__(self_data_method(other_aligned.data), *self.obj.indices)
+ else:
+ return NotImplemented
+
+
+def index_data_ops_creator(*args, **kwargs):
+ """
+ meta class for auto generating operations for index data.
+ """
+ for method_name in ["__add__", "__sub__", "__rsub__", "__mul__", "__truediv__", "__eq__", "__gt__", "__lt__"]:
+ args[2][method_name] = BinaryOps(method_name=method_name)
+ return type(*args)
+
+
+class IndexData(metaclass=index_data_ops_creator):
+ """
+ Base data structure of SingleData and MultiData.
+
+ NOTE:
+ - For performance reasons, only **np.floating** is supported in the underlying data!
+ - Booleans based on np.floating are also supported. Here are some examples
+
+ .. code-block:: python
+
+ np.array([ np.nan]).any() -> True
+ np.array([ np.nan]).all() -> True
+ np.array([1. , 0.]).any() -> True
+ np.array([1. , 0.]).all() -> False
+ """
+
+ loc_idx_cls = LocIndexer
+
+ def __init__(self, data: np.ndarray, *indices: Union[List, pd.Index, Index]):
+
+ self.data = data
+ self.indices = indices
+
+ # get the expected data shape
+ # - The index has higher priority
+ self.data = np.array(data)
+
+ expected_dim = max(self.data.ndim, len(indices))
+
+ data_shape = []
+ for i in range(expected_dim):
+ idx_l = indices[i] if len(indices) > i else []
+ if len(idx_l) == 0:
+ data_shape.append(self.data.shape[i])
+ else:
+ data_shape.append(len(idx_l))
+ data_shape = tuple(data_shape)
+
+ # broadcast the data to expected shape
+ if self.data.shape != data_shape:
+ self.data = np.broadcast_to(self.data, data_shape)
+
+ self.data = self.data.astype(np.float64)
+ # Please notice following cases when converting the type
+ # - np.array([None, 1]).astype(np.float64) -> array([nan, 1.])
+
+ # create index from user's index data.
+ self.indices: List[Index] = self.loc_idx_cls.proc_idx_l(indices, data_shape)
+
+ for dim in range(expected_dim):
+ assert self.data.shape[dim] == len(self.indices[dim])
+
+ self.ndim = expected_dim
+
+ # indexing related methods
+ @property
+ def loc(self):
+ return self.loc_idx_cls(index_data=self, indices=self.indices)
+
+ @property
+ def iloc(self):
+ return self.loc_idx_cls(index_data=self, indices=self.indices, int_loc=True)
+
+ @property
+ def index(self):
+ return self.indices[0]
+
+ @property
+ def columns(self):
+ return self.indices[1]
+
+ def _align_indices(self, other: "IndexData") -> "IndexData":
+ """
+ Align all indices of `other` to `self` before performing the arithmetic operations.
+ This function will return a new IndexData rather than changing data in `other` inplace
+
+ Parameters
+ ----------
+ other : "IndexData"
+ the index in `other` is to be changed
+
+ Returns
+ -------
+ IndexData:
+ the data in `other` with index aligned to `self`
+ """
+ raise NotImplementedError(f"please implement _align_indices func")
+
+ def sort_index(self, axis=0, inplace=True):
+ assert inplace, "Only support sorting inplace now"
+ self.indices[axis], sorted_idx = self.indices[axis].sort()
+ self.data = np.take(self.data, sorted_idx, axis=axis)
+
+ # The code below could be simpler like methods in __getattribute__
+ def __invert__(self):
+ return self.__class__(~self.data.astype(np.bool), *self.indices)
+
+ def abs(self):
+ """get the abs of data except np.NaN."""
+ tmp_data = np.absolute(self.data)
+ return self.__class__(tmp_data, *self.indices)
+
+ def replace(self, to_replace: Dict[np.number, np.number]):
+ assert isinstance(to_replace, dict)
+ tmp_data = self.data.copy()
+ for num in to_replace:
+ if num in tmp_data:
+ tmp_data[self.data == num] = to_replace[num]
+ return self.__class__(tmp_data, *self.indices)
+
+ def apply(self, func: Callable):
+ """apply a function to data."""
+ tmp_data = func(self.data)
+ return self.__class__(tmp_data, *self.indices)
+
+ def __len__(self):
+ """the length of the data.
+
+ Returns
+ -------
+ int
+ the length of the data.
+ """
+ return len(self.data)
+
+ def sum(self, axis=None):
+ # FIXME: weird logic and not general
+ if axis is None:
+ return np.nansum(self.data)
+ elif axis == 0:
+ tmp_data = np.nansum(self.data, axis=0)
+ return SingleData(tmp_data, self.columns)
+ elif axis == 1:
+ tmp_data = np.nansum(self.data, axis=1)
+ return SingleData(tmp_data, self.index)
+ else:
+ raise ValueError(f"axis must be None, 0 or 1")
+
+ def mean(self, axis=None):
+ # FIXME: weird logic and not general
+ if axis is None:
+ return np.nanmean(self.data)
+ elif axis == 0:
+ tmp_data = np.nanmean(self.data, axis=0)
+ return SingleData(tmp_data, self.columns)
+ elif axis == 1:
+ tmp_data = np.nanmean(self.data, axis=1)
+ return SingleData(tmp_data, self.index)
+ else:
+ raise ValueError(f"axis must be None, 0 or 1")
+
+ def isna(self):
+ return self.__class__(np.isnan(self.data), *self.indices)
+
+ def fillna(self, value=0.0, inplace: bool = False):
+ if inplace:
+ self.data = np.nan_to_num(self.data, nan=value)
+ else:
+ return self.__class__(np.nan_to_num(self.data, nan=value), *self.indices)
+
+ def count(self):
+ return len(self.data[~np.isnan(self.data)])
+
+ def all(self):
+ # the underlying data is float64 (see `__init__`), so missing values are
+ # np.nan (which is truthy) rather than None
+ return self.data.all()
+
+ @property
+ def empty(self):
+ return len(self.data) == 0
+
+ @property
+ def values(self):
+ return self.data
+
+
+class SingleData(IndexData):
+ def __init__(
+ self, data: Union[int, float, np.number, list, dict, pd.Series] = [], index: Union[List, pd.Index, Index] = []
+ ):
+ """A data structure of index and numpy data.
+ It's used to replace pd.Series for higher speed.
+
+ Parameters
+ ----------
+ data : Union[int, float, np.number, list, dict, pd.Series]
+ the input data
+ index : Union[list, pd.Index]
+ the index of data.
+ an empty list indicates that the index will be auto-filled to the length of data
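+
+ Example (illustrative):
+
+ .. code-block:: python
+
+ sd = SingleData([1.0, 2.0], index=["a", "b"])
+ sd.loc["a"] # -> 1.0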
+ """
+ # for special data type
+ if isinstance(data, dict):
+ assert len(index) == 0
+ if len(data) > 0:
+ index, data = zip(*data.items())
+ else:
+ index, data = [], []
+ elif isinstance(data, pd.Series):
+ assert len(index) == 0
+ index, data = data.index, data.values
+ elif isinstance(data, (int, float, np.number)):
+ data = [data]
+ super().__init__(data, index)
+ assert self.ndim == 1
+
+ def _align_indices(self, other):
+ if self.index == other.index:
+ return other
+ elif set(self.index) == set(other.index):
+ return other.reindex(self.index)
+ else:
+ raise ValueError(
+ f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
+ )
+
+ def reindex(self, index: Index, fill_value=np.NaN):
+ """reindex data and fill the missing value with np.NaN.
+
+ Parameters
+ ----------
+ index : Index
+ new index
+ fill_value:
+ what value to fill if index is missing
+
+ Returns
+ -------
+ SingleData
+ the reindexed data
+ """
+ # TODO: This method can be more general
+ if self.index == index:
+ return self
+ tmp_data = np.full(len(index), fill_value, dtype=np.float64)
+ for index_id, index_item in enumerate(index):
+ try:
+ tmp_data[index_id] = self.loc[index_item]
+ except KeyError:
+ pass
+ return SingleData(tmp_data, index)
+
+ def add(self, other: "SingleData", fill_value=0):
+ # TODO: add and __add__ are a little confusing.
+ # This could be made more general
+ common_index = self.index | other.index
+ common_index, _ = common_index.sort()
+ tmp_data1 = self.reindex(common_index, fill_value)
+ tmp_data2 = other.reindex(common_index, fill_value)
+ return tmp_data1.fillna(fill_value) + tmp_data2.fillna(fill_value)
+
+ def to_dict(self):
+ """convert SingleData to dict.
+
+ Returns
+ -------
+ dict
+ data with the dict format.
+ """
+ return dict(zip(self.index, self.data.tolist()))
+
+ def to_series(self):
+ return pd.Series(self.data, index=self.index)
+
+ def __repr__(self) -> str:
+ return str(pd.Series(self.data, index=self.index))
+
+
+class MultiData(IndexData):
+ def __init__(
+ self,
+ data: Union[int, float, np.number, list, np.ndarray, pd.DataFrame] = [],
+ index: Union[List, pd.Index, Index] = [],
+ columns: Union[List, pd.Index, Index] = [],
+ ):
+ """A data structure of index and numpy data.
+ It's used to replace pd.DataFrame for higher speed.
+
+ Parameters
+ ----------
+ data : Union[list, np.ndarray, pd.DataFrame]
+ the dim of data must be 2.
+ index : Union[List, pd.Index, Index]
+ the index of data.
+ columns: Union[List, pd.Index, Index]
+ the columns of data.
+ """
+ if isinstance(data, pd.DataFrame):
+ index, columns, data = data.index, data.columns, data.values
+ super().__init__(data, index, columns)
+ assert self.ndim == 2
+
+ def _align_indices(self, other):
+ if self.indices == other.indices:
+ return other
+ else:
+ raise ValueError(
+ f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
+ )
+
+ def __repr__(self) -> str:
+ return str(pd.DataFrame(self.data, index=self.index, columns=self.columns))
diff --git a/qlib/utils/paral.py b/qlib/utils/paral.py
index a640b04ea6..075a1adb84 100644
--- a/qlib/utils/paral.py
+++ b/qlib/utils/paral.py
@@ -1,8 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-from joblib import Parallel, delayed
import pandas as pd
+from joblib import Parallel, delayed
+from joblib._parallel_backends import MultiprocessingBackend
+
+
+class ParallelExt(Parallel):
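+ """A joblib `Parallel` that forwards `maxtasksperchild` to the multiprocessing backend.
+
+ Illustrative usage: ParallelExt(n_jobs=2, maxtasksperchild=10)(delayed(f)(i) for i in range(4))
+ """
+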
+ def __init__(self, *args, **kwargs):
+ maxtasksperchild = kwargs.pop("maxtasksperchild", None)
+ super(ParallelExt, self).__init__(*args, **kwargs)
+ if isinstance(self._backend, MultiprocessingBackend):
+ self._backend_args["maxtasksperchild"] = maxtasksperchild
def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_rule="M", n_jobs=-1, skip_group=False):
@@ -31,7 +40,7 @@ def _naive_group_apply(df):
return df.groupby(axis=axis, level=level).apply(apply_func)
if n_jobs != 1:
- dfs = Parallel(n_jobs=n_jobs)(
+ dfs = ParallelExt(n_jobs=n_jobs)(
delayed(_naive_group_apply)(sub_df) for idx, sub_df in df.resample(resample_rule, axis=axis, level=level)
)
return pd.concat(dfs, axis=axis).sort_index()
diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py
new file mode 100644
index 0000000000..c81a15e45a
--- /dev/null
+++ b/qlib/utils/resam.py
@@ -0,0 +1,232 @@
+import numpy as np
+import pandas as pd
+
+from functools import partial
+from typing import Union, Callable
+
+from . import lazy_sort_index
+from .time import Freq, cal_sam_minute
+
+
+def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
+ """
+ Resample the calendar with frequency freq_raw into the calendar with frequency freq_sam
+ Assumption:
+ - Fixed length (240) of the calendar in each day.
+
+ Parameters
+ ----------
+ calendar_raw : np.ndarray
+ The calendar with frequency freq_raw
+ freq_raw : str
+ Frequency of the raw calendar
+ freq_sam : str
+ Sample frequency
+
+ Returns
+ -------
+ np.ndarray
+ The calendar with frequency freq_sam
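+
+ Example (illustrative, `daily_cal` being an np.ndarray of pd.Timestamp):
+
+ .. code-block:: python
+
+ # resample a daily calendar into a weekly one, keeping the first
+ # trading day of each week
+ weekly_cal = resam_calendar(daily_cal, "day", "1w")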
+ """
+ raw_count, freq_raw = Freq.parse(freq_raw)
+ sam_count, freq_sam = Freq.parse(freq_sam)
+ if not len(calendar_raw):
+ return calendar_raw
+
+ # if freq_sam is xminute, divide each trading day into several bars evenly
+ if freq_sam == Freq.NORM_FREQ_MINUTE:
+ if freq_raw != Freq.NORM_FREQ_MINUTE:
+ raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
+ else:
+ if raw_count > sam_count:
+ raise ValueError("raw freq must be higher than sampling freq")
+ _calendar_minute = np.unique(list(map(lambda x: cal_sam_minute(x, sam_count), calendar_raw)))
+ return _calendar_minute
+
+ # else, convert the raw calendar into day calendar, and divide the whole calendar into several bars evenly
+ else:
+ _calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
+ if freq_sam == Freq.NORM_FREQ_DAY:
+ return _calendar_day[::sam_count]
+
+ elif freq_sam == Freq.NORM_FREQ_WEEK:
+ _day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
+ _calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0]
+ return _calendar_week[::sam_count]
+
+ elif freq_sam == Freq.NORM_FREQ_MONTH:
+ _day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
+ _calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0]
+ return _calendar_month[::sam_count]
+ else:
+ raise ValueError("sampling freq must be xmin, xd, xw, xm")
+
+
+def get_higher_eq_freq_feature(instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
+ """get the feature with higher or equal frequency than `freq`.
+ Returns
+ -------
+ pd.DataFrame
+ the feature with higher or equal frequency
+ """
+
+ from ..data.data import D
+
+ try:
+ _result = D.features(instruments, fields, start_time, end_time, freq=freq, disk_cache=disk_cache)
+ _freq = freq
+ except (ValueError, KeyError):
+ _, norm_freq = Freq.parse(freq)
+ if norm_freq in [Freq.NORM_FREQ_MONTH, Freq.NORM_FREQ_WEEK, Freq.NORM_FREQ_DAY]:
+ try:
+ _result = D.features(instruments, fields, start_time, end_time, freq="day", disk_cache=disk_cache)
+ _freq = "day"
+ except (ValueError, KeyError):
+ _result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache)
+ _freq = "1min"
+ elif norm_freq == Freq.NORM_FREQ_MINUTE:
+ _result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache)
+ _freq = "1min"
+ else:
+ raise ValueError(f"freq {freq} is not supported")
+ return _result, _freq
+
+
+def resam_ts_data(
+ ts_feature: Union[pd.DataFrame, pd.Series],
+ start_time: Union[str, pd.Timestamp] = None,
+ end_time: Union[str, pd.Timestamp] = None,
+ method: Union[str, Callable] = "last",
+ method_kwargs: dict = {},
+):
+ """
+ Resample value from time-series data
+
+    - If `feature` has MultiIndex[instrument, datetime], apply `method` to each instrument's data with datetime in [start_time, end_time]
+ Example:
+
+ .. code-block::
+
+ print(feature)
+ $close $volume
+ instrument datetime
+ SH600000 2010-01-04 86.778313 16162960.0
+ 2010-01-05 87.433578 28117442.0
+ 2010-01-06 85.713585 23632884.0
+ 2010-01-07 83.788803 20813402.0
+ 2010-01-08 84.730675 16044853.0
+
+ SH600655 2010-01-04 2699.567383 158193.328125
+ 2010-01-08 2612.359619 77501.406250
+ 2010-01-11 2712.982422 160852.390625
+ 2010-01-12 2788.688232 164587.937500
+ 2010-01-13 2790.604004 145460.453125
+
+ print(resam_ts_data(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
+ $close $volume
+ instrument
+ SH600000 87.433578 28117442.0
+ SH600655 2699.567383 158193.328125
+
+    - Otherwise, `feature` should have a single-level Index[datetime]; apply `method` to `feature` directly
+ Example:
+
+        .. code-block::
+
+            print(feature)
+ $close $volume
+ datetime
+ 2010-01-04 86.778313 16162960.0
+ 2010-01-05 87.433578 28117442.0
+ 2010-01-06 85.713585 23632884.0
+ 2010-01-07 83.788803 20813402.0
+ 2010-01-08 84.730675 16044853.0
+
+ print(resam_ts_data(feature, start_time="2010-01-04", end_time="2010-01-05", method="last"))
+
+ $close 87.433578
+ $volume 28117442.0
+
+ print(resam_ts_data(feature['$close'], start_time="2010-01-04", end_time="2010-01-05", method="last"))
+
+ 87.433578
+
+ Parameters
+ ----------
+ ts_feature : Union[pd.DataFrame, pd.Series]
+ Raw time-series feature to be resampled
+ start_time : Union[str, pd.Timestamp], optional
+ start sampling time, by default None
+ end_time : Union[str, pd.Timestamp], optional
+ end sampling time, by default None
+ method : Union[str, Callable], optional
+        sample method, applied to each stock's series data, by default "last"
+        - If method is a str, it should be an attribute of SeriesGroupBy or DataFrameGroupBy, and the corresponding groupby method is applied to the sliced time-series data
+        - If method is a callable function, it is applied to the sliced time-series data directly
+        - If method is None, do nothing with the sliced time-series data.
+ method_kwargs : dict, optional
+ arguments of method, by default {}
+
+ Returns
+ -------
+    The resampled DataFrame/Series/value; None is returned when the resampled data is empty.
+ """
+
+ selector_datetime = slice(start_time, end_time)
+
+ from ..data.dataset.utils import get_level_index
+
+ feature = lazy_sort_index(ts_feature)
+
+ datetime_level = get_level_index(feature, level="datetime") == 0
+ if datetime_level:
+ feature = feature.loc[selector_datetime]
+ else:
+ feature = feature.loc(axis=0)[(slice(None), selector_datetime)]
+
+ if feature.empty:
+ return None
+ if isinstance(feature.index, pd.MultiIndex):
+ if callable(method):
+ method_func = method
+ return feature.groupby(level="instrument").apply(lambda x: method_func(x, **method_kwargs))
+ elif isinstance(method, str):
+ return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
+ else:
+ if callable(method):
+ method_func = method
+ return method_func(feature, **method_kwargs)
+ elif isinstance(method, str):
+ return getattr(feature, method)(**method_kwargs)
+ return feature
+
+
+def get_valid_value(series, last=True):
+ """get the first/last not nan value of pd.Series with single level index
+ Parameters
+ ----------
+    series : pd.Series
+        series should not be empty
+    last : bool, optional
+        whether to get the last valid value, by default True
+        - if last is True, get the last valid value
+        - else, get the first valid value
+
+    Returns
+    -------
+    float (may be NaN)
+ the first/last valid value
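+
+    Example (a sketch):
+
+    .. code-block::
+
+        s = pd.Series([np.nan, 1.0, 2.0, np.nan])
+        get_valid_value(s)              # 2.0, the last valid value
+        get_valid_value(s, last=False)  # 1.0, the first valid value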
+ """
+ return series.fillna(method="ffill").iloc[-1] if last else series.fillna(method="bfill").iloc[0]
+
+
+def _ts_data_valid(ts_feature, last=False):
+ """get the first/last not nan value of pd.Series|DataFrame with single level index"""
+ if isinstance(ts_feature, pd.DataFrame):
+ return ts_feature.apply(lambda column: get_valid_value(column, last=last))
+ elif isinstance(ts_feature, pd.Series):
+ return get_valid_value(ts_feature, last=last)
+ else:
+ raise TypeError(f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}")
+
+
+ts_data_last = partial(_ts_data_valid, last=True)
+ts_data_first = partial(_ts_data_valid, last=False)
diff --git a/qlib/utils/time.py b/qlib/utils/time.py
new file mode 100644
index 0000000000..48ce658c49
--- /dev/null
+++ b/qlib/utils/time.py
@@ -0,0 +1,316 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""
+Time-related utils are collected in this script
+"""
+import bisect
+from datetime import datetime, time, date
+from typing import List, Tuple, Union
+import functools
+import re
+
+import pandas as pd
+
+from qlib.config import C
+
+
+@functools.lru_cache(maxsize=240)
+def get_min_cal(shift: int = 0) -> List[time]:
+ """
+ get the minute level calendar in day period
+
+ Parameters
+ ----------
+    shift : int
+        the shift direction follows the pandas convention:
+        series.shift(1) replaces the value at the `i`-th position with the one at the `(i-1)`-th
+
+    Returns
+    -------
+    List[time]:
+        the minute-level calendar within a trading day (CN market sessions), shifted by `shift` minutes
+
+ """
+ cal = []
+ for ts in list(pd.date_range("9:30", "11:29", freq="1min") - pd.Timedelta(minutes=shift)) + list(
+ pd.date_range("13:00", "14:59", freq="1min") - pd.Timedelta(minutes=shift)
+ ):
+ cal.append(ts.time())
+ return cal
+
+
+def is_single_value(start_time, end_time, freq, region="cn"):
+ """Is there only one piece of data for stock market.
+
+ Parameters
+ ----------
+ start_time : Union[pd.Timestamp, str]
+ closed start time for data.
+ end_time : Union[pd.Timestamp, str]
+ closed end time for data.
+    freq : pd.Timedelta
+        the data frequency.
+    region : str
+        the market region, by default "cn".
+
+ Returns
+ -------
+ bool
+ True means one piece of data to obtain.
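+
+    Example (a sketch; the second call hits the 11:29 morning-close rule):
+
+    .. code-block::
+
+        is_single_value(pd.Timestamp("2020-01-06 09:30"),
+                        pd.Timestamp("2020-01-06 09:30"), pd.Timedelta("1min"))  # True
+        is_single_value(pd.Timestamp("2020-01-06 11:29"),
+                        pd.Timestamp("2020-01-06 13:05"), pd.Timedelta("5min"))  # True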
+ """
+ if region == "cn":
+ if end_time - start_time < freq:
+ return True
+ if start_time.hour == 11 and start_time.minute == 29 and start_time.second == 0:
+ return True
+ if start_time.hour == 14 and start_time.minute == 59 and start_time.second == 0:
+ return True
+ return False
+ else:
+ raise NotImplementedError(f"please implement the is_single_value func for {region}")
+
+
+class Freq:
+ NORM_FREQ_MONTH = "month"
+ NORM_FREQ_WEEK = "week"
+ NORM_FREQ_DAY = "day"
+ NORM_FREQ_MINUTE = "minute"
+ SUPPORT_CAL_LIST = [NORM_FREQ_MINUTE, NORM_FREQ_DAY] # FIXME: this list should from data
+
+ MIN_CAL = get_min_cal()
+
+ def __init__(self, freq: str) -> None:
+ self.count, self.base = self.parse(freq)
+
+ @staticmethod
+ def parse(freq: str) -> Tuple[int, str]:
+ """
+ Parse freq into a unified format
+
+ Parameters
+ ----------
+ freq : str
+ Raw freq, supported freq should match the re '^([0-9]*)(month|mon|week|w|day|d|minute|min)$'
+
+ Returns
+ -------
+ freq: Tuple[int, str]
+ Unified freq, including freq count and unified freq unit. The freq unit should be '[month|week|day|minute]'.
+ Example:
+
+ .. code-block::
+
+ print(Freq.parse("day"))
+ (1, "day" )
+ print(Freq.parse("2mon"))
+ (2, "month")
+ print(Freq.parse("10w"))
+ (10, "week")
+
+ """
+ freq = freq.lower()
+ match_obj = re.match("^([0-9]*)(month|mon|week|w|day|d|minute|min)$", freq)
+ if match_obj is None:
+ raise ValueError(
+ "freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min"
+ )
+ _count = int(match_obj.group(1)) if match_obj.group(1) else 1
+ _freq = match_obj.group(2)
+ _freq_format_dict = {
+ "month": Freq.NORM_FREQ_MONTH,
+ "mon": Freq.NORM_FREQ_MONTH,
+ "week": Freq.NORM_FREQ_WEEK,
+ "w": Freq.NORM_FREQ_WEEK,
+ "day": Freq.NORM_FREQ_DAY,
+ "d": Freq.NORM_FREQ_DAY,
+ "minute": Freq.NORM_FREQ_MINUTE,
+ "min": Freq.NORM_FREQ_MINUTE,
+ }
+ return _count, _freq_format_dict[_freq]
+
+ @staticmethod
+ def get_timedelta(n: int, freq: str) -> pd.Timedelta:
+ """
+        get a pd.Timedelta object
+
+ Parameters
+ ----------
+ n : int
+ freq : str
+ Typically, they are the return value of Freq.parse
+
+ Returns
+ -------
+ pd.Timedelta:
+ """
+ return pd.Timedelta(f"{n}{freq}")
+
+ @staticmethod
+    def get_min_delta(left_freq: str, right_freq: str):
+        """Calculate the delta between two freqs, in minutes
+
+        Parameters
+        ----------
+        left_freq: str
+        right_freq: str
+
+ Returns
+ -------
+        int
+            left freq minus right freq, expressed in minutes
+
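+        Example (a sketch; minute counts follow the `minutes_map` below):
+
+        .. code-block::
+
+            print(Freq.get_min_delta("day", "30min"))
+            1410  # 1 * 60 * 24 - 30
+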
+ """
+ minutes_map = {
+ Freq.NORM_FREQ_MINUTE: 1,
+ Freq.NORM_FREQ_DAY: 60 * 24,
+ Freq.NORM_FREQ_WEEK: 7 * 60 * 24,
+ Freq.NORM_FREQ_MONTH: 30 * 7 * 60 * 24,
+ }
+        left_freq = Freq.parse(left_freq)
+ left_minutes = left_freq[0] * minutes_map[left_freq[1]]
+ right_freq = Freq.parse(right_freq)
+ right_minutes = right_freq[0] * minutes_map[right_freq[1]]
+ return left_minutes - right_minutes
+
+ @staticmethod
+ def get_recent_freq(base_freq: str, freq_list: List[str]) -> str:
+ """Get the closest freq to base_freq from freq_list
+
+ Parameters
+ ----------
+ base_freq
+ freq_list
+
+ Returns
+ -------
+        str or None
+            the freq in `freq_list` closest to `base_freq` among those not coarser than it;
+            None if every freq in the list is coarser
+
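+        Example (a sketch):
+
+        .. code-block::
+
+            print(Freq.get_recent_freq("day", ["1min", "30min", "day", "week"]))
+            day
+            print(Freq.get_recent_freq("week", ["day", "month"]))
+            day
+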
+ """
+        # keep only freqs that are not coarser than base_freq (non-negative delta), then take the nearest one
+ _freq_minutes = []
+ min_freq = None
+ for _freq in freq_list:
+ _min_delta = Freq.get_min_delta(base_freq, _freq)
+ if _min_delta < 0:
+ continue
+ if min_freq is None:
+ min_freq = (_min_delta, _freq)
+ continue
+ min_freq = min_freq if min_freq[0] <= _min_delta else (_min_delta, _freq)
+ return min_freq[1] if min_freq else None
+
+
+CN_TIME = [
+ datetime.strptime("9:30", "%H:%M"),
+ datetime.strptime("11:30", "%H:%M"),
+ datetime.strptime("13:00", "%H:%M"),
+ datetime.strptime("15:00", "%H:%M"),
+]
+US_TIME = [datetime.strptime("9:30", "%H:%M"), datetime.strptime("16:00", "%H:%M")]
+
+
+def time_to_day_index(time_obj: Union[str, datetime], region: str = "cn"):
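+    """Map an intraday time to its minute index within the trading day.
+
+    Example (a sketch): for the "cn" region, "9:30" -> 0, "10:30" -> 60, "13:00" -> 120.
+    """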
+ if isinstance(time_obj, str):
+ time_obj = datetime.strptime(time_obj, "%H:%M")
+
+ if region == "cn":
+ if CN_TIME[0] <= time_obj < CN_TIME[1]:
+ return int((time_obj - CN_TIME[0]).total_seconds() / 60)
+ elif CN_TIME[2] <= time_obj < CN_TIME[3]:
+ return int((time_obj - CN_TIME[2]).total_seconds() / 60) + 120
+ else:
+ raise ValueError(f"{time_obj} is not the opening time of the {region} stock market")
+ elif region == "us":
+ if US_TIME[0] <= time_obj < US_TIME[1]:
+ return int((time_obj - US_TIME[0]).total_seconds() / 60)
+ else:
+ raise ValueError(f"{time_obj} is not the opening time of the {region} stock market")
+ else:
+ raise ValueError(f"{region} is not supported")
+
+
+def get_day_min_idx_range(start: str, end: str, freq: str) -> Tuple[int, int]:
+ """
+    get the min-bar index in a day for a time range (both left and right are closed) given a fixed frequency
+
+    Parameters
+ ----------
+ start : str
+ e.g. "9:30"
+ end : str
+ e.g. "14:30"
+ freq : str
+ "1min"
+
+ Returns
+ -------
+ Tuple[int, int]:
+ The index of start and end in the calendar. Both left and right are **closed**
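+
+    Example (a sketch, using the default CN minute calendar):
+
+    .. code-block::
+
+        print(get_day_min_idx_range("9:30", "11:29", "1min"))
+        (0, 119)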
+ """
+ start = pd.Timestamp(start).time()
+ end = pd.Timestamp(end).time()
+ freq = Freq(freq)
+ in_day_cal = Freq.MIN_CAL[:: freq.count]
+ left_idx = bisect.bisect_left(in_day_cal, start)
+ right_idx = bisect.bisect_right(in_day_cal, end) - 1
+ return left_idx, right_idx
+
+
+def concat_date_time(date_obj: date, time_obj: time) -> pd.Timestamp:
+ return pd.Timestamp(
+ datetime(
+ date_obj.year,
+ month=date_obj.month,
+ day=date_obj.day,
+ hour=time_obj.hour,
+ minute=time_obj.minute,
+ second=time_obj.second,
+ microsecond=time_obj.microsecond,
+ )
+ )
+
+
+def cal_sam_minute(x: pd.Timestamp, sam_minutes: int) -> pd.Timestamp:
+ """
+ align the minute-level data to a down sampled calendar
+
+    e.g. align 10:38 to 10:35 at the 5-minute level (10:30 at the 10-minute level)
+
+ Parameters
+ ----------
+ x : pd.Timestamp
+ datetime to be aligned
+ sam_minutes : int
+ align to `sam_minutes` minute-level calendar
+
+ Returns
+ -------
+ pd.Timestamp:
+        the aligned datetime
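+
+    Example (a sketch, assuming the default `C.min_data_shift == 0`):
+
+    .. code-block::
+
+        print(cal_sam_minute(pd.Timestamp("2021-01-18 10:38"), 5))
+        2021-01-18 10:35:00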
+ """
+ cal = get_min_cal(C.min_data_shift)[::sam_minutes]
+ idx = bisect.bisect_right(cal, x.time()) - 1
+ _date, new_time = x.date(), cal[idx]
+ return concat_date_time(_date, new_time)
+
+
+def epsilon_change(date_time: pd.Timestamp, direction: str = "backward") -> pd.Timestamp:
+ """
+    change the time by an infinitesimal amount.
+
+ Parameters
+ ----------
+ date_time : pd.Timestamp
+ the original time
+ direction : str
+        the direction in which the time shifts
+ - "backward" for going to history
+ - "forward" for going to the future
+
+ Returns
+ -------
+ pd.Timestamp:
+ the shifted time
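+
+    Example (a sketch):
+
+    .. code-block::
+
+        print(epsilon_change(pd.Timestamp("2021-01-18 10:00")))
+        2021-01-18 09:59:59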
+ """
+ if direction == "backward":
+ return date_time - pd.Timedelta(seconds=1)
+ elif direction == "forward":
+ return date_time + pd.Timedelta(seconds=1)
+ else:
+ raise ValueError("Wrong input")
+
+
+if __name__ == "__main__":
+ print(get_day_min_idx_range("8:30", "14:59", "10min"))
diff --git a/qlib/workflow/cli.py b/qlib/workflow/cli.py
index 16e5b62963..f0be8f4e8d 100644
--- a/qlib/workflow/cli.py
+++ b/qlib/workflow/cli.py
@@ -41,7 +41,7 @@ def sys_config(config, config_path):
sys.path.append(str(Path(config_path).parent.resolve().absolute() / p))
-# worflow handler function
+# workflow handler function
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
with open(config_path) as fp:
config = yaml.safe_load(fp)
@@ -57,7 +57,7 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
recorder.save_objects(config=config)
-# function to run worklflow by config
+# function to run workflow by config
def run():
fire.Fire(workflow)
diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py
index e374ecef93..bae14d642d 100644
--- a/qlib/workflow/record_temp.py
+++ b/qlib/workflow/record_temp.py
@@ -1,20 +1,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-import re, logging
+import re
+import logging
+import warnings
import pandas as pd
from pathlib import Path
from pprint import pprint
-from ..contrib.evaluate import risk_analysis
-from ..contrib.backtest import backtest as normal_backtest
+from typing import Union, List
+from ..contrib.evaluate import indicator_analysis, risk_analysis
from ..data.dataset import DatasetH
from ..data.dataset.handler import DataHandlerLP
+from ..backtest import backtest as normal_backtest
from ..utils import init_instance_by_config, get_module_by_module_path
from ..log import get_module_logger
from ..utils import flatten_dict
+from ..utils.time import Freq
+from ..strategy.base import BaseStrategy
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
-from ..contrib.strategy.strategy import BaseStrategy
+
logger = get_module_logger("workflow", logging.INFO)
@@ -84,15 +90,15 @@ def load(self, name):
def list(self):
"""
- List the stored records.
+ List the supported artifacts.
Return
------
- A list of all the stored records.
+ A list of all the supported artifacts.
"""
return []
- def check(self, parent=False):
+ def check(self, cls="self"):
"""
Check if the records is properly generated and saved.
@@ -101,11 +107,9 @@ def check(self, parent=False):
FileExistsError: whether the records are stored properly.
"""
artifacts = set(self.recorder.list_artifacts())
- if parent:
- # Downcasting have to be done here instead of using `super`
- flist = self.__class__.__base__.list(self) # pylint: disable=E1101
- else:
- flist = self.list()
+ if cls == "self":
+ cls = self
+ flist = cls.list()
for item in flist:
if item not in artifacts:
raise FileExistsError(item)
@@ -163,7 +167,8 @@ def generate(self, **kwargs):
raw_label = self.generate_label(self.dataset)
self.recorder.save_objects(**{"label.pkl": raw_label})
- def list(self):
+ @staticmethod
+ def list():
return ["pred.pkl", "label.pkl"]
def load(self, name="pred.pkl"):
@@ -224,24 +229,22 @@ def list(self):
return paths
-class SigAnaRecord(SignalRecord):
+class SigAnaRecord(RecordTemp):
"""
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.
"""
artifact_path = "sig_analysis"
+ pre_class = SignalRecord
- def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, **kwargs):
- super().__init__(recorder=recorder, **kwargs)
+ def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0):
+ super().__init__(recorder=recorder)
self.ana_long_short = ana_long_short
self.ann_scaler = ann_scaler
self.label_col = label_col
def generate(self, **kwargs):
- try:
- self.check(parent=True)
- except FileExistsError:
- super().generate()
+ self.check(self.pre_class)
pred = self.load("pred.pkl")
label = self.load("label.pkl")
@@ -283,7 +286,7 @@ def list(self):
return paths
-class PortAnaRecord(SignalRecord):
+class PortAnaRecord(RecordTemp):
"""
This is the Portfolio Analysis Record class that generates the analysis results such as those of backtest. This class inherits the ``RecordTemp`` class.
@@ -295,70 +298,164 @@ class PortAnaRecord(SignalRecord):
artifact_path = "portfolio_analysis"
- def __init__(self, recorder, config, **kwargs):
+ def __init__(
+ self,
+ recorder,
+ config,
+ risk_analysis_freq: Union[List, str] = None,
+ indicator_analysis_freq: Union[List, str] = None,
+ indicator_analysis_method=None,
+ **kwargs,
+ ):
"""
config["strategy"] : dict
define the strategy class as well as the kwargs.
+ config["executor"] : dict
+ define the executor class as well as the kwargs.
config["backtest"] : dict
define the backtest kwargs.
+        risk_analysis_freq : str|List[str]
+            the freq (or list of freqs) at which the risk analysis report is generated
+        indicator_analysis_freq : str|List[str]
+            the freq (or list of freqs) at which the indicator analysis report is generated
+        indicator_analysis_method : str, optional, default None
+            the candidate values include 'mean', 'amount_weighted', 'value_weighted'
"""
super().__init__(recorder=recorder, **kwargs)
self.strategy_config = config["strategy"]
+ _default_executor_config = {
+ "class": "SimulatorExecutor",
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": "day",
+ "generate_portfolio_metrics": True,
+ },
+ }
+ self.executor_config = config.get("executor", _default_executor_config)
self.backtest_config = config["backtest"]
- self.strategy = init_instance_by_config(self.strategy_config, accept_types=BaseStrategy)
- def generate(self, **kwargs):
- # check previously stored prediction results
- try:
- self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
- except FileExistsError:
- super().generate()
+ self.all_freq = self._get_report_freq(self.executor_config)
+ if risk_analysis_freq is None:
+ risk_analysis_freq = [self.all_freq[0]]
+ if indicator_analysis_freq is None:
+ indicator_analysis_freq = [self.all_freq[0]]
+
+ if isinstance(risk_analysis_freq, str):
+ risk_analysis_freq = [risk_analysis_freq]
+ if isinstance(indicator_analysis_freq, str):
+ indicator_analysis_freq = [indicator_analysis_freq]
+
+ self.risk_analysis_freq = [
+ "{0}{1}".format(*Freq.parse(_analysis_freq)) for _analysis_freq in risk_analysis_freq
+ ]
+ self.indicator_analysis_freq = [
+ "{0}{1}".format(*Freq.parse(_analysis_freq)) for _analysis_freq in indicator_analysis_freq
+ ]
+ self.indicator_analysis_method = indicator_analysis_method
+
+ def _get_report_freq(self, executor_config):
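+        """Collect the report freqs of the executor and, recursively, of its inner executors (`sub_env`).
+
+        Only executors with `generate_portfolio_metrics=True` contribute a freq.
+        """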
+ ret_freq = []
+ if executor_config["kwargs"].get("generate_portfolio_metrics", False):
+ _count, _freq = Freq.parse(executor_config["kwargs"]["time_per_step"])
+ ret_freq.append(f"{_count}{_freq}")
+ if "sub_env" in executor_config["kwargs"]:
+ ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["sub_env"]))
+ return ret_freq
+ def generate(self, **kwargs):
# custom strategy and get backtest
- pred_score = super().load("pred.pkl")
- report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
- report_normal = report_dict.get("report_df")
- positions_normal = report_dict.get("positions")
- self.recorder.save_objects(
- **{"report_normal.pkl": report_normal},
- artifact_path=PortAnaRecord.get_path(),
- )
- self.recorder.save_objects(
- **{"positions_normal.pkl": positions_normal},
- artifact_path=PortAnaRecord.get_path(),
+ portfolio_metric_dict, indicator_dict = normal_backtest(
+ executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
)
- order_normal = report_dict.get("order_list")
- if order_normal:
+ for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
+ self.recorder.save_objects(
+ **{f"report_normal_{_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()
+ )
self.recorder.save_objects(
- **{"order_normal.pkl": order_normal},
- artifact_path=PortAnaRecord.get_path(),
+ **{f"positions_normal_{_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()
)
- # analysis
- analysis = dict()
- analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
- analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"]
- )
- # save portfolio analysis results
- analysis_df = pd.concat(analysis) # type: pd.DataFrame
- # log metrics
- self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
- # save results
- self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path())
- logger.info(
- f"Portfolio analysis record 'port_analysis.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
- )
- # print out results
- pprint("The following are analysis results of the excess return without cost.")
- pprint(analysis["excess_return_without_cost"])
- pprint("The following are analysis results of the excess return with cost.")
- pprint(analysis["excess_return_with_cost"])
+ for _freq, indicators_normal in indicator_dict.items():
+ self.recorder.save_objects(
+ **{f"indicators_normal_{_freq}.pkl": indicators_normal}, artifact_path=PortAnaRecord.get_path()
+ )
+
+ for _analysis_freq in self.risk_analysis_freq:
+ if _analysis_freq not in portfolio_metric_dict:
+                warnings.warn(
+                    f"the freq {_analysis_freq} report is not found, please set `generate_portfolio_metrics=True` on the corresponding executor"
+                )
+ else:
+ report_normal, _ = portfolio_metric_dict.get(_analysis_freq)
+ analysis = dict()
+ analysis["excess_return_without_cost"] = risk_analysis(
+ report_normal["return"] - report_normal["bench"], freq=_analysis_freq
+ )
+ analysis["excess_return_with_cost"] = risk_analysis(
+ report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=_analysis_freq
+ )
+
+ analysis_df = pd.concat(analysis) # type: pd.DataFrame
+ # log metrics
+ analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
+ self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
+ # save results
+ self.recorder.save_objects(
+ **{f"port_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
+ )
+ logger.info(
+ f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
+ )
+ # print out results
+ pprint(f"The following are analysis results of benchmark return({_analysis_freq}).")
+ pprint(risk_analysis(report_normal["bench"], freq=_analysis_freq))
+ pprint(f"The following are analysis results of the excess return without cost({_analysis_freq}).")
+ pprint(analysis["excess_return_without_cost"])
+ pprint(f"The following are analysis results of the excess return with cost({_analysis_freq}).")
+ pprint(analysis["excess_return_with_cost"])
+
+ for _analysis_freq in self.indicator_analysis_freq:
+ if _analysis_freq not in indicator_dict:
+ warnings.warn(f"the freq {_analysis_freq} indicator is not found")
+ else:
+ indicators_normal = indicator_dict.get(_analysis_freq)
+ if self.indicator_analysis_method is None:
+ analysis_df = indicator_analysis(indicators_normal)
+ else:
+ analysis_df = indicator_analysis(indicators_normal, method=self.indicator_analysis_method)
+ # log metrics
+ analysis_dict = analysis_df["value"].to_dict()
+ self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
+ # save results
+ self.recorder.save_objects(
+ **{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
+ )
+ logger.info(
+ f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
+ )
+ pprint(f"The following are analysis results of indicators({_analysis_freq}).")
+ pprint(analysis_df)
def list(self):
- return [
- PortAnaRecord.get_path("report_normal.pkl"),
- PortAnaRecord.get_path("positions_normal.pkl"),
- PortAnaRecord.get_path("port_analysis.pkl"),
- ]
+ list_path = []
+ for _freq in self.all_freq:
+ list_path.extend(
+ [
+ PortAnaRecord.get_path(f"report_normal_{_freq}.pkl"),
+ PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"),
+ ]
+ )
+ for _analysis_freq in self.risk_analysis_freq:
+ if _analysis_freq in self.all_freq:
+ list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl"))
+ else:
+ warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
+
+ for _analysis_freq in self.indicator_analysis_freq:
+ if _analysis_freq in self.all_freq:
+ list_path.append(PortAnaRecord.get_path(f"indicator_analysis_{_analysis_freq}.pkl"))
+ else:
+ warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
+
+ return list_path
diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py
index 18b389c734..48289d9bc6 100644
--- a/qlib/workflow/task/manage.py
+++ b/qlib/workflow/task/manage.py
@@ -314,6 +314,8 @@ def query(self, query={}, decode=True):
Query task in collection.
This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator
+        Example of querying from the command line:
+
+        python -m qlib.workflow.task.manage -t query '{"_id": "615498be837d0053acbc5d58"}'
+
Parameters
----------
query: dict
diff --git a/setup.py b/setup.py
index 21d56371e2..d87610b8a9 100644
--- a/setup.py
+++ b/setup.py
@@ -51,7 +51,7 @@
"fire>=0.3.1",
"statsmodels",
"xlrd>=1.0.0",
- "plotly==4.12.0",
+ "plotly>=4.12.0",
"matplotlib>=3.3",
"tables>=3.6.1",
"pyyaml>=5.3.1",
@@ -65,6 +65,7 @@
"pymongo==3.7.2", # For task management
"scikit-learn>=0.22",
"dill",
+ "dataclasses;python_version<'3.7'",
"filelock",
]
diff --git a/tests/backtest/test_file_strategy.py b/tests/backtest/test_file_strategy.py
new file mode 100644
index 0000000000..83b1793976
--- /dev/null
+++ b/tests/backtest/test_file_strategy.py
@@ -0,0 +1,108 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import unittest
+from qlib.backtest import backtest, decision
+from qlib.tests import TestAutoData
+import pandas as pd
+from pathlib import Path
+
+DIRNAME = Path(__file__).absolute().resolve().parent
+
+
+class FileStrTest(TestAutoData):
+
+ TEST_INST = "SH600519"
+
+ EXAMPLE_FILE = DIRNAME / "order_example.csv"
+
+ DEAL_NUM_FOR_1000 = 123.47105436976445
+
+ def _gen_orders(self) -> pd.DataFrame:
+ headers = [
+ "datetime",
+ "instrument",
+ "amount",
+ "direction",
+ ]
+ orders = [
+ # test cash limit for buying
+ ["20200103", self.TEST_INST, "1000", "buy"],
+ # test min_cost for buying
+ ["20200106", self.TEST_INST, "1", "buy"],
+ # test held stock limit for selling
+ ["20200107", self.TEST_INST, "1000", "sell"],
+ # test cash limit for buying
+ ["20200108", self.TEST_INST, "1000", "buy"],
+ # test min_cost for selling
+ ["20200109", self.TEST_INST, "1", "sell"],
+ # test selling all stocks
+ ["20200110", self.TEST_INST, str(self.DEAL_NUM_FOR_1000), "sell"],
+ ]
+ return pd.DataFrame(orders, columns=headers).set_index(["datetime", "instrument"])
+
+ def test_file_str(self):
+
+ orders = self._gen_orders()
+ print(orders)
+ orders.to_csv(self.EXAMPLE_FILE)
+
+ orders = pd.read_csv(self.EXAMPLE_FILE, index_col=["datetime", "instrument"])
+
+ strategy_config = {
+ "class": "FileOrderStrategy",
+ "module_path": "qlib.contrib.strategy.rule_strategy",
+ "kwargs": {"file": self.EXAMPLE_FILE},
+ }
+
+ freq = "day"
+ start_time = "2020-01-01"
+ end_time = "2020-01-16"
+ codes = [self.TEST_INST]
+
+ backtest_config = {
+ "start_time": start_time,
+ "end_time": end_time,
+ "account": 30000,
+ "benchmark": None, # benchmark is not required here for trading
+ "exchange_kwargs": {
+ "freq": freq,
+ "limit_threshold": 0.095,
+ "deal_price": "close",
+ "open_cost": 0.0005,
+ "close_cost": 0.0015,
+ "min_cost": 500,
+ "codes": codes,
+ "trade_unit": 2,
+ },
+ # "pos_type": "InfPosition" # Position with infinitive position
+ }
+ executor_config = {
+ "class": "SimulatorExecutor",
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": freq,
+ "generate_portfolio_metrics": False,
+ "verbose": True,
+ "indicator_config": {
+ "show_indicator": False,
+ },
+ },
+ }
+ report_dict, indicator_dict = backtest(executor=executor_config, strategy=strategy_config, **backtest_config)
+
+ # ffr valid
+ ffr_dict = indicator_dict["1day"]["ffr"].to_dict()
+ ffr_dict = {str(date).split()[0]: ffr_dict[date] for date in ffr_dict}
+ assert ffr_dict["2020-01-03"] == self.DEAL_NUM_FOR_1000 / 1000
+ assert ffr_dict["2020-01-06"] == 0
+ assert ffr_dict["2020-01-07"] == self.DEAL_NUM_FOR_1000 / 1000
+ assert ffr_dict["2020-01-08"] == self.DEAL_NUM_FOR_1000 / 1000
+ assert ffr_dict["2020-01-09"] == 0
+ assert ffr_dict["2020-01-10"] == 1
+
+ self.EXAMPLE_FILE.unlink()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/backtest/test_high_freq_trading.py b/tests/backtest/test_high_freq_trading.py
new file mode 100644
index 0000000000..550bddac31
--- /dev/null
+++ b/tests/backtest/test_high_freq_trading.py
@@ -0,0 +1,133 @@
+from typing import List, Tuple, Union
+from qlib.backtest.position import Position
+from qlib.backtest import collect_data, format_decisions
+from qlib.backtest.decision import BaseTradeDecision, TradeRangeByTime
+import qlib
+from qlib.tests import TestAutoData
+import unittest
+from qlib.config import REG_CN, HIGH_FREQ_CONFIG
+import pandas as pd
+
+
+@unittest.skip("This test takes a lot of time due to the large size of high-frequency data")
+class TestHFBacktest(TestAutoData):
+ @classmethod
+ def setUpClass(cls) -> None:
+ super().setUpClass(enable_1min=True, enable_1d_type="full")
+
+ def _gen_orders(self, inst, date, pos) -> pd.DataFrame:
+ headers = [
+ "datetime",
+ "instrument",
+ "amount",
+ "direction",
+ ]
+ orders = [
+ [date, inst, pos, "sell"],
+ ]
+ return pd.DataFrame(orders, columns=headers)
+
+ def test_trading(self):
+
+ # date = "2020-02-03"
+ # inst = "SH600068"
+ # pos = 2.0167
+ pos = 100000
+ inst, date = "SH600519", "2021-01-18"
+ market = [inst]
+
+ start_time = f"{date}"
+ end_time = f"{date} 15:00" # include the high-freq data on the end day
+ freq_l0 = "day"
+ freq_l1 = "30min"
+ freq_l2 = "1min"
+
+ orders = self._gen_orders(inst=inst, date=date, pos=pos * 0.90)
+
+ strategy_config = {
+ "class": "FileOrderStrategy",
+ "module_path": "qlib.contrib.strategy.rule_strategy",
+ "kwargs": {
+ "trade_range": TradeRangeByTime("10:45", "14:44"),
+ "file": orders,
+ },
+ }
+ backtest_config = {
+ "start_time": start_time,
+ "end_time": end_time,
+ "account": {
+ "cash": 0,
+ inst: pos,
+ },
+ "benchmark": None, # benchmark is not required here for trading
+ "exchange_kwargs": {
+ "freq": freq_l2, # use the most fine-grained data as the exchange
+ "limit_threshold": 0.095,
+ "deal_price": "close",
+ "open_cost": 0.0005,
+ "close_cost": 0.0015,
+ "min_cost": 5,
+ "codes": market,
+ "trade_unit": 100,
+ },
+ # "pos_type": "InfPosition" # Position with infinitive position
+ }
+ executor_config = {
+ "class": "NestedExecutor", # Level 1 Order execution
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": freq_l0,
+ "inner_executor": {
+ "class": "NestedExecutor", # Leve 2 Order Execution
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": freq_l1,
+ "inner_executor": {
+ "class": "SimulatorExecutor",
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": freq_l2,
+ "generate_portfolio_metrics": False,
+ "verbose": True,
+ "indicator_config": {
+ "show_indicator": False,
+ },
+ "track_data": True,
+ },
+ },
+ "inner_strategy": {
+ "class": "TWAPStrategy",
+ "module_path": "qlib.contrib.strategy.rule_strategy",
+ },
+ "generate_portfolio_metrics": False,
+ "indicator_config": {
+ "show_indicator": True,
+ },
+ "track_data": True,
+ },
+ },
+ "inner_strategy": {
+ "class": "TWAPStrategy",
+ "module_path": "qlib.contrib.strategy.rule_strategy",
+ },
+ "generate_portfolio_metrics": False,
+ "indicator_config": {
+ "show_indicator": True,
+ },
+ "track_data": True,
+ },
+ }
+
+ ret_val = {}
+ decisions = list(
+ collect_data(executor=executor_config, strategy=strategy_config, **backtest_config, return_value=ret_val)
+ )
+ report, indicator = ret_val["report"], ret_val["indicator"]
+ # NOTE: please refer to the docs of format_decisions
+        # NOTE: `"track_data": True` is required for collecting the decisions
+ f_dec = format_decisions(decisions)
+ print(indicator["1day"])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/misc/test_get_multi_proc.py b/tests/misc/test_get_multi_proc.py
new file mode 100644
index 0000000000..7e27781b6e
--- /dev/null
+++ b/tests/misc/test_get_multi_proc.py
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import unittest
+
+import qlib
+from qlib.data import D
+from qlib.tests import TestAutoData
+from multiprocessing import Pool
+
+
+def get_features(fields):
+ qlib.init(provider_uri=TestAutoData.provider_uri, expression_cache=None, dataset_cache=None, joblib_backend="loky")
+ return D.features(D.instruments("csi300"), fields)
+
+
+class TestGetData(TestAutoData):
+ FIELDS = "$open,$close,$high,$low,$volume,$factor,$change".split(",")
+
+ def test_multi_proc(self):
+ """
+        Test that fetching features from multiple processes does not raise errors
+ """
+ iter_n = 2
+ pool = Pool(iter_n)
+
+ res = []
+ for _ in range(iter_n):
+ res.append(pool.apply_async(get_features, (self.FIELDS,), {}))
+
+ for r in res:
+ print(r.get())
+
+ pool.close()
+ pool.join()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/misc/test_index_data.py b/tests/misc/test_index_data.py
new file mode 100644
index 0000000000..3cd819a0f9
--- /dev/null
+++ b/tests/misc/test_index_data.py
@@ -0,0 +1,120 @@
+import numpy as np
+import pandas as pd
+
+import qlib.utils.index_data as idd
+
+import unittest
+
+
+class IndexDataTest(unittest.TestCase):
+ def test_index_single_data(self):
+ # Auto broadcast for scalar
+ sd = idd.SingleData(0, index=["foo", "bar"])
+ print(sd)
+
+ # Support empty value
+ sd = idd.SingleData()
+ print(sd)
+
+ # Bad case: the input is not aligned
+ with self.assertRaises(ValueError):
+ idd.SingleData(range(10), index=["foo", "bar"])
+
+ # test indexing
+ sd = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
+ print(sd)
+ print(sd.iloc[1]) # get second row
+
+ # Bad case: it is not in the index
+ with self.assertRaises(KeyError):
+ print(sd.loc[1])
+
+ print(sd.loc["foo"])
+
+ # Test slicing
+ print(sd.loc[:"bar"])
+
+ print(sd.iloc[:3])
+
+ def test_index_multi_data(self):
+ # Auto broadcast for scalar
+ sd = idd.MultiData(0, index=["foo", "bar"], columns=["f", "g"])
+ print(sd)
+
+ # Bad case: the input is not aligned
+ with self.assertRaises(ValueError):
+ idd.MultiData(range(10), index=["foo", "bar"], columns=["f", "g"])
+
+ # test indexing
+ sd = idd.MultiData(np.arange(4).reshape(2, 2), index=["foo", "bar"], columns=["f", "g"])
+ print(sd)
+ print(sd.iloc[1]) # get second row
+
+ # Bad case: it is not in the index
+ with self.assertRaises(KeyError):
+ print(sd.loc[1])
+
+ print(sd.loc["foo"])
+
+ # Test slicing
+
+ print(sd.loc[:"foo"])
+
+ print(sd.loc[:, "g":])
+
+ def test_sorting(self):
+ sd = idd.MultiData(np.arange(4).reshape(2, 2), index=["foo", "bar"], columns=["f", "g"])
+ print(sd)
+ sd.sort_index()
+
+ print(sd)
+ print(sd.loc[:"c"])
+
+ def test_corner_cases(self):
+        sd = idd.MultiData([[1, 2], [3, np.nan]], index=["foo", "bar"], columns=["f", "g"])
+ print(sd)
+
+ self.assertTrue(np.isnan(sd.loc["bar", "g"]))
+
+ # support slicing
+ print(sd.loc[~sd.loc[:, "g"].isna().data.astype(np.bool)])
+
+ print(self.assertTrue(idd.SingleData().index == idd.SingleData().index))
+
+ # empty dict
+ print(idd.SingleData({}))
+ print(idd.SingleData(pd.Series()))
+
+ sd = idd.SingleData()
+ with self.assertRaises(KeyError):
+ sd.loc["foo"]
+
+ # replace
+ sd = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
+ sd = sd.replace(dict(zip(range(1, 5), range(2, 6))))
+ print(sd)
+ self.assertTrue(sd.iloc[0] == 2)
+
+ def test_ops(self):
+ sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
+ sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
+ print(sd1 + sd2)
+ new_sd = sd2 * 2
+ self.assertTrue(new_sd.index == sd2.index)
+
+ sd1 = idd.SingleData([1, 2, None, 4], index=["foo", "bar", "f", "g"])
+ sd2 = idd.SingleData([1, 2, 3, None], index=["foo", "bar", "f", "g"])
+ self.assertTrue(np.isnan((sd1 + sd2).iloc[3]))
+ self.assertTrue(sd1.add(sd2).sum() == 13)
+
+ self.assertTrue(idd.sum_by_index([sd1, sd2], sd1.index, fill_value=0.0).sum() == 13)
+
+ def test_todo(self):
+ pass
+        # here are some examples that do not affect the current system, but it is odd not to support them
+ # sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
+ # 2 * sd2
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/misc/test_utils.py b/tests/misc/test_utils.py
new file mode 100644
index 0000000000..4dabf5ed89
--- /dev/null
+++ b/tests/misc/test_utils.py
@@ -0,0 +1,89 @@
+from unittest.case import TestCase
+import unittest
+import pandas as pd
+import numpy as np
+from datetime import datetime
+from qlib import init
+from qlib.config import C
+from qlib.log import TimeInspector
+from qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal
+
+
+def cal_sam_minute(x, sam_minutes):
+ """
+    Sample the raw calendar into a calendar with sam_minutes freq; `shift` represents how many minutes the market times are shifted:
+ - open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)]
+ - mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)]
+ - mid open time of stock market is [13:00 - shift*pd.Timedelta(minutes=1)]
+ - close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)]
+ """
+    # TODO: actually, this version is much faster when there is no cache or optimization
+ day_time = pd.Timestamp(x.date())
+ shift = C.min_data_shift
+
+ open_time = day_time + pd.Timedelta(hours=9, minutes=30) - shift * pd.Timedelta(minutes=1)
+ mid_close_time = day_time + pd.Timedelta(hours=11, minutes=29) - shift * pd.Timedelta(minutes=1)
+ mid_open_time = day_time + pd.Timedelta(hours=13, minutes=00) - shift * pd.Timedelta(minutes=1)
+ close_time = day_time + pd.Timedelta(hours=14, minutes=59) - shift * pd.Timedelta(minutes=1)
+
+ if open_time <= x <= mid_close_time:
+ minute_index = (x - open_time).seconds // 60
+ elif mid_open_time <= x <= close_time:
+ minute_index = (x - mid_open_time).seconds // 60 + 120
+ else:
+ raise ValueError("datetime of calendar is out of range")
+ minute_index = minute_index // sam_minutes * sam_minutes
+
+ if 0 <= minute_index < 120:
+ return open_time + minute_index * pd.Timedelta(minutes=1)
+ elif 120 <= minute_index < 240:
+ return mid_open_time + (minute_index - 120) * pd.Timedelta(minutes=1)
+ else:
+ raise ValueError("calendar minute_index error, check `min_data_shift` in qlib.config.C")
+
+
+class TimeUtils(TestCase):
+ @classmethod
+ def setUpClass(cls):
+ init()
+
+ def test_cal_sam_minute(self):
+ # test the correctness of the code
+ random_n = 1000
+ cal = get_min_cal()
+
+ def gen_args():
+ for time in np.random.choice(cal, size=random_n, replace=True):
+ sam_minutes = np.random.choice([1, 2, 3, 4, 5, 6])
+ dt = pd.Timestamp(
+ datetime(
+ 2021,
+ month=3,
+ day=3,
+ hour=time.hour,
+ minute=time.minute,
+ second=time.second,
+ microsecond=time.microsecond,
+ )
+ )
+ args = dt, sam_minutes
+ yield args
+
+ for args in gen_args():
+ assert cal_sam_minute(*args) == cal_sam_minute_new(*args)
+
+ # test the performance of the code
+
+ args_l = list(gen_args())
+
+ with TimeInspector.logt():
+ for args in args_l:
+ cal_sam_minute(*args)
+
+ with TimeInspector.logt():
+ for args in args_l:
+ cal_sam_minute_new(*args)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py
index 13420335b2..606a0ea3ba 100644
--- a/tests/test_all_pipeline.py
+++ b/tests/test_all_pipeline.py
@@ -14,27 +14,6 @@
from qlib.tests import TestAutoData
from qlib.tests.config import CSI300_GBDT_TASK, CSI300_BENCH
-port_analysis_config = {
- "strategy": {
- "class": "TopkDropoutStrategy",
- "module_path": "qlib.contrib.strategy.strategy",
- "kwargs": {
- "topk": 50,
- "n_drop": 5,
- },
- },
- "backtest": {
- "verbose": False,
- "limit_threshold": 0.095,
- "account": 100000000,
- "benchmark": CSI300_BENCH,
- "deal_price": "close",
- "open_cost": 0.0005,
- "close_cost": 0.0015,
- "min_cost": 5,
- },
-}
-
def train(uri_path: str = None):
"""train model
@@ -58,7 +37,7 @@ def train(uri_path: str = None):
with R.start(experiment_name="workflow", uri=uri_path):
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
-
+ R.save_objects(trained_model=model)
# prediction
recorder = R.get_recorder()
# To test __repr__
@@ -68,7 +47,6 @@ def train(uri_path: str = None):
rid = recorder.id
sr = SignalRecord(model, dataset, recorder)
sr.generate()
- pred_score = sr.load()
# calculate ic and ric
sar = SigAnaRecord(recorder)
@@ -76,7 +54,7 @@ def train(uri_path: str = None):
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
- return pred_score, {"ic": ic, "ric": ric}, rid
+ return {"ic": ic, "ric": ric}, rid
def train_with_sigana(uri_path: str = None):
@@ -102,10 +80,9 @@ def train_with_sigana(uri_path: str = None):
sar.generate()
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
- pred_score = sar.load("pred.pkl")
uri_path = R.get_uri()
- return pred_score, {"ic": ic, "ric": ric}, uri_path
+ return {"ic": ic, "ric": ric}, uri_path
def fake_experiment():
@@ -134,8 +111,6 @@ def backtest_analysis(pred, rid, uri_path: str = None):
Parameters
----------
- pred : pandas.DataFrame
- predict scores
rid : str
the id of the recorder to be used in this function
uri_path: str
@@ -147,18 +122,53 @@ def backtest_analysis(pred, rid, uri_path: str = None):
the analysis result
"""
- with R.start(experiment_name="workflow", recorder_id=rid, uri=uri_path):
- recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
+ recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
+ dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
+ model = recorder.load_object("trained_model")
+
+ port_analysis_config = {
+ "executor": {
+ "class": "SimulatorExecutor",
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": "day",
+ "generate_portfolio_metrics": True,
+ },
+ },
+ "strategy": {
+ "class": "TopkDropoutStrategy",
+ "module_path": "qlib.contrib.strategy.model_strategy",
+ "kwargs": {
+ "model": model,
+ "dataset": dataset,
+ "topk": 50,
+ "n_drop": 5,
+ },
+ },
+ "backtest": {
+ "start_time": "2017-01-01",
+ "end_time": "2020-08-01",
+ "account": 100000000,
+ "benchmark": CSI300_BENCH,
+ "exchange_kwargs": {
+ "freq": "day",
+ "limit_threshold": 0.095,
+ "deal_price": "close",
+ "open_cost": 0.0005,
+ "close_cost": 0.0015,
+ "min_cost": 5,
+ },
+ },
+ }
# backtest
- par = PortAnaRecord(recorder, port_analysis_config)
+ par = PortAnaRecord(recorder, port_analysis_config, risk_analysis_freq="day")
par.generate()
- analysis_df = par.load(par.get_path("port_analysis.pkl"))
+ analysis_df = par.load(par.get_path("port_analysis_1day.pkl"))
print(analysis_df)
return analysis_df
class TestAllFlow(TestAutoData):
- PRED_SCORE = None
REPORT_NORMAL = None
POSITIONS = None
RID = None
diff --git a/tests/test_handler_storage.py b/tests/test_handler_storage.py
new file mode 100644
index 0000000000..056595063b
--- /dev/null
+++ b/tests/test_handler_storage.py
@@ -0,0 +1,114 @@
+import unittest
+import time
+import numpy as np
+from qlib.data import D
+from qlib.tests import TestAutoData
+
+from qlib.data.dataset.handler import DataHandlerLP
+from qlib.contrib.data.handler import check_transform_proc
+from qlib.log import TimeInspector
+
+
+class TestHandler(DataHandlerLP):
+ def __init__(
+ self,
+ instruments="csi300",
+ start_time=None,
+ end_time=None,
+ infer_processors=[],
+ learn_processors=[],
+ fit_start_time=None,
+ fit_end_time=None,
+ drop_raw=True,
+ ):
+
+ infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
+ learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
+
+ data_loader = {
+ "class": "QlibDataLoader",
+ "kwargs": {
+ "freq": "day",
+ "config": self.get_feature_config(),
+ "swap_level": False,
+ },
+ }
+
+ super().__init__(
+ instruments=instruments,
+ start_time=start_time,
+ end_time=end_time,
+ data_loader=data_loader,
+ infer_processors=infer_processors,
+ learn_processors=learn_processors,
+ drop_raw=drop_raw,
+ )
+
+ def get_feature_config(self):
+ fields = ["Ref($open, 1)", "Ref($close, 1)", "Ref($volume, 1)", "$open", "$close", "$volume"]
+ names = ["open_0", "close_0", "volume_0", "open_1", "close_1", "volume_1"]
+ return fields, names
+
+
+class TestHandlerStorage(TestAutoData):
+
+ market = "all"
+
+ start_time = "2010-01-01"
+ end_time = "2020-12-31"
+ train_end_time = "2015-12-31"
+ test_start_time = "2016-01-01"
+
+ data_handler_kwargs = {
+ "start_time": start_time,
+ "end_time": end_time,
+ "fit_start_time": start_time,
+ "fit_end_time": train_end_time,
+ "instruments": market,
+ }
+
+ def test_handler_storage(self):
+ # init data handler
+ data_handler = TestHandler(**self.data_handler_kwargs)
+
+        # init data handler with hashing storage
+ data_handler_hs = TestHandler(**self.data_handler_kwargs, infer_processors=["HashStockFormat"])
+
+ fetch_start_time = "2019-01-01"
+ fetch_end_time = "2019-12-31"
+ instruments = D.instruments(market=self.market)
+ instruments = D.list_instruments(
+ instruments=instruments, start_time=fetch_start_time, end_time=fetch_end_time, as_list=True
+ )
+
+ with TimeInspector.logt("random fetch with DataFrame Storage"):
+
+ # single stock
+ for i in range(100):
+ random_index = np.random.randint(len(instruments), size=1)[0]
+ fetch_stock = instruments[random_index]
+ data_handler.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None)
+
+ # multi stocks
+ for i in range(100):
+                random_indexes = np.random.randint(len(instruments), size=5)
+                fetch_stocks = [instruments[_index] for _index in random_indexes]
+ data_handler.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None)
+
+ with TimeInspector.logt("random fetch with HasingStock Storage"):
+
+ # single stock
+ for i in range(100):
+ random_index = np.random.randint(len(instruments), size=1)[0]
+ fetch_stock = instruments[random_index]
+ data_handler_hs.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None)
+
+ # multi stocks
+ for i in range(100):
+                random_indexes = np.random.randint(len(instruments), size=5)
+                fetch_stocks = [instruments[_index] for _index in random_indexes]
+ data_handler_hs.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None)
+
+
+if __name__ == "__main__":
+ unittest.main()