diff --git a/examples/benchmarks/README.md b/examples/benchmarks/README.md index ee2c0a8331..c05aa02897 100644 --- a/examples/benchmarks/README.md +++ b/examples/benchmarks/README.md @@ -25,6 +25,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of | TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 | | Transformer (Ashish Vaswani, et al.)| Alpha360 | 0.0141±0.00 | 0.0917±0.02| 0.0331±0.00 | 0.2357±0.03 | -0.0259±0.03 | -0.3323±0.43| -0.1763±0.07 | | Localformer (Juyong Jiang, et al.)| Alpha360 | 0.0408±0.00 | 0.2988±0.03| 0.0538±0.00 | 0.4105±0.02 | 0.0275±0.03 | 0.3464±0.37| -0.1182±0.03 | +| TRA (Hengxu Lin, et al.)| Alpha360 | 0.0500±0.00 | 0.3966±0.04 | 0.0594±0.00 | 0.4856±0.03 | 0.1000±0.02 | 1.3425±0.31 | -0.0845±0.02 | ## Alpha158 dataset | Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown | @@ -43,6 +44,8 @@ The numbers shown below demonstrate the performance of the entire `workflow` of | TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 | | Transformer (Ashish Vaswani, et al.)| Alpha158 | 0.0274±0.00 | 0.2166±0.04| 0.0409±0.00 | 0.3342±0.04 | 0.0204±0.03 | 0.2888±0.40| -0.1216±0.04 | | Localformer (Juyong Jiang, et al.)| Alpha158 | 0.0355±0.00 | 0.2747±0.04| 0.0466±0.00 | 0.3762±0.03 | 0.0506±0.02 | 0.7447±0.34| -0.0875±0.02 | +| TRA (Hengxu Lin, et al.)| Alpha158 (with 20 selected features) | 0.0440±0.00 | 0.3592±0.03 | 0.0500±0.00 | 0.4256±0.02 | 0.0747±0.03 | 1.1281±0.49 | -0.0813±0.03 | +| TRA (Hengxu Lin, et al.)| Alpha158 | 0.0474±0.00 | 0.3653±0.03 | 0.0573±0.00 | 0.4494±0.02 | 0.0770±0.02 | 1.1342±0.38 | -0.0852±0.03 | - The selected 20 features are based on the feature importance of a lightgbm-based model. - The base model of DoubleEnsemble is LGBM. diff --git a/examples/benchmarks/TRA/README.md b/examples/benchmarks/TRA/README.md index 070527ddb4..6d3e7a4769 100644 --- a/examples/benchmarks/TRA/README.md +++ b/examples/benchmarks/TRA/README.md @@ -1,53 +1,77 @@ # Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport -This code provides a PyTorch implementation for TRA (Temporal Routing Adaptor), as described in the paper [Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport](http://arxiv.org/abs/2106.12950). +Temporal Routing Adaptor (TRA) is designed to capture multiple trading patterns in stock market data. Please refer to [our paper](http://arxiv.org/abs/2106.12950) for more details. -* TRA (Temporal Routing Adaptor) is a lightweight module that consists of a set of independent predictors for learning multiple patterns as well as a router to dispatch samples to different predictors. -* We also design a learning algorithm based on Optimal Transport (OT) to obtain the optimal sample to predictor assignment and effectively optimize the router with such assignment through an auxiliary loss term.
+If you find our work useful in your research, please cite: +``` +@inproceedings{HengxuKDD2021, + author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian}, + title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport}, + booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining}, + series = {KDD '21}, + year = {2021}, + publisher = {ACM}, +} + +@article{yang2020qlib, + title={Qlib: An AI-oriented Quantitative Investment Platform}, + author={Yang, Xiao and Liu, Weiqing and Zhou, Dong and Bian, Jiang and Liu, Tie-Yan}, + journal={arXiv preprint arXiv:2009.11189}, + year={2020} +} +``` + +## Usage (Recommended) + +**Update**: `TRA` has been moved to `qlib.contrib.model.pytorch_tra` to support other `Qlib` components such as `qlib.workflow` and the `Alpha158`/`Alpha360` datasets. + +Please follow the official [doc](https://qlib.readthedocs.io/en/latest/component/workflow.html) to use `TRA` with `workflow`. Here we also provide several example config files: + +- `workflow_config_tra_Alpha360.yaml`: running `TRA` with the `Alpha360` dataset +- `workflow_config_tra_Alpha158.yaml`: running `TRA` with the `Alpha158` dataset (with feature subsampling) +- `workflow_config_tra_Alpha158_full.yaml`: running `TRA` with the `Alpha158` dataset (without feature subsampling) + +The performance of `TRA` is reported in [Benchmarks](https://github.com/microsoft/qlib/tree/main/examples/benchmarks). -# Running TRA +## Usage (Not Maintained) -## Requirements -- Install `Qlib` main branch +This section describes how to reproduce the results in the paper. -## Running +### Running We attach our running scripts for the paper in `run.sh`. And here are two ways to run the model: * Running from scripts with default parameters - You can directly run from Qlib command `qrun`: - ``` - qrun configs/config_alstm.yaml - ``` + + You can run directly with the Qlib command `qrun`: + ``` + qrun configs/config_alstm.yaml + ``` * Running from code with self-defined parameters - Setting different parameters is also allowed. See codes in `example.py`: - ``` - python example.py --config_file configs/config_alstm.yaml - ``` -Here we trained TRA on a pretrained backbone model. Therefore we run `*_init.yaml` before TRA's scipts. + Setting different parameters is also allowed; see the code in `example.py`: + ``` + python example.py --config_file configs/config_alstm.yaml + ``` -# Results +Here we train TRA on a pretrained backbone model; therefore we run the `*_init.yaml` configs before TRA's scripts. -## Outputs +### Results After running the scripts, you can find result files in path `./output`: -`info.json` - config settings and result metrics. - -`log.csv` - running logs. +* `info.json` - config settings and result metrics. +* `log.csv` - running logs. +* `model.bin` - the model parameter dictionary. +* `pred.pkl` - the prediction scores and output for inference. -`model.bin` - the model parameter dictionary. +Evaluation metrics reported in the paper: -`pred.pkl` - the prediction scores and output for inference.
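A minimal sketch of how these outputs might be inspected, assuming `info.json` is plain JSON and `pred.pkl` is a pickled pandas object, as the descriptions above suggest (no particular key names inside `info.json` are guaranteed):

```python
import json
import pandas as pd

# Config settings and result metrics; the exact keys depend on the run,
# so list them before relying on any particular field.
with open("./output/info.json") as f:
    info = json.load(f)
print(list(info.keys()))

# Prediction scores for inference, typically indexed by
# (datetime, instrument) following Qlib conventions.
pred = pd.read_pickle("./output/pred.pkl")
print(pred.head())
```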
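The Common Issues section below notes that the loss can become `NaN` when the Sinkhorn `epsilon` is mismatched with the input scale. For intuition, here is a minimal, generic sketch of Sinkhorn normalization (an illustration of the algorithm family, not necessarily TRA's exact implementation):

```python
import torch

def sinkhorn(Q: torch.Tensor, n_iters: int = 3, epsilon: float = 0.01) -> torch.Tensor:
    """Turn scores Q into a doubly-normalized soft assignment matrix."""
    # If epsilon is much smaller than Q's scale, exp(Q / epsilon) overflows
    # to inf (roughly when Q / epsilon exceeds ~88 in float32); the row and
    # column normalizations below then produce NaN, which is why epsilon
    # should be tuned to the scale of the inputs.
    Q = torch.exp(Q / epsilon)
    for _ in range(n_iters):
        Q = Q / Q.sum(dim=0, keepdim=True)  # normalize columns
        Q = Q / Q.sum(dim=1, keepdim=True)  # normalize rows
    return Q
```

Increasing `epsilon` (or rescaling the logits fed into it) keeps `Q / epsilon` in a range where `exp` stays finite, which is usually enough to avoid the `NaN` losses described below.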
- -## Our Results | Methods | MSE| MAE| IC | ICIR | AR | AV | SR | MDD | -|-------------------|-------------------|---------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------| +|-------|-------|------|-----|-----|-----|-----|-----|-----| |Linear|0.163|0.327|0.020|0.132|-3.2%|16.8%|-0.191|32.1%| |LightGBM|0.160(0.000)|0.323(0.000)|0.041|0.292|7.8%|15.5%|0.503|25.7%| |MLP|0.160(0.002)|0.323(0.003)|0.037|0.273|3.7%|15.3%|0.264|26.2%| @@ -61,21 +85,8 @@ After running the scripts, you can find result files in path `./output`: A more detailed demo for our experiment results in the paper can be found in `Report.ipynb`. -# Common Issues +## Common Issues For help or issues using TRA, please submit a GitHub issue. -Sometimes we might encounter situation where the loss is `NaN`, please check the `epsilon` parameter in the sinkhorn algorithm, adjusting the `epsilon` according to input's scale is important. - -# Citation -If you find this repository useful in your research, please cite: -``` -@inproceedings{HengxuKDD2021, - author = {Hengxu Lin and Dong Zhou and Weiqing Liu and Jiang Bian}, - title = {Learning Multiple Stock Trading Patterns with Temporal Routing Adaptor and Optimal Transport}, - booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining}, - series = {KDD '21}, - year = {2021}, - publisher = {ACM}, -} -``` +Sometimes the loss may become `NaN`; in that case, check the `epsilon` parameter of the Sinkhorn algorithm (see the sketch above), since adjusting `epsilon` according to the input's scale is important. diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml new file mode 100644 index 0000000000..bf4dcb7d8e --- /dev/null +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml @@ -0,0 +1,126 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn + +market: &market csi300 +benchmark: &benchmark SH000300 + +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: FilterCol + kwargs: + fields_group: feature + col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", + "ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5", + "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"] + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +num_states: &num_states 3 + +memory_mode: &memory_mode sample + +tra_config: &tra_config + num_states: *num_states + hidden_size: 16 + tau: 1.0 + src_info: LR_TPE + +model_config: &model_config + input_size: 20 + hidden_size: 64 + num_layers: 2 + rnn_arch: LSTM + use_attn: True + dropout: 0.0 + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 + +task: + model: + class: TRAModel + module_path: qlib.contrib.model.pytorch_tra + kwargs: + tra_config: *tra_config + model_config: *model_config
+ model_type: RNN + lr: 1e-3 + n_epochs: 100 + max_steps_per_epoch: 100 + early_stop: 10 + smooth_steps: 5 + seed: 0 + logdir: + lamb: 1.0 + rho: 1.0 + transport_method: router + memory_mode: *memory_mode + eval_train: False + eval_test: True + pretrain: True + init_state: + freeze_model: False + freeze_predictors: False + dataset: + class: MTSDatasetH + module_path: qlib.contrib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + seq_len: 60 + horizon: 2 + input_size: + num_states: *num_states + batch_size: 1024 + n_samples: + memory_mode: *memory_mode + drop_last: True + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: {} + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml new file mode 100644 index 0000000000..8d3c8e582f --- /dev/null +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml @@ -0,0 +1,120 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn + +market: &market csi300 +benchmark: &benchmark SH000300 + +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +num_states: &num_states 3 + +memory_mode: &memory_mode sample + +tra_config: &tra_config + num_states: *num_states + hidden_size: 16 + tau: 1.0 + src_info: LR_TPE + +model_config: &model_config + input_size: 158 + hidden_size: 256 + num_layers: 2 + rnn_arch: LSTM + use_attn: True + dropout: 0.2 + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 + +task: + model: + class: TRAModel + module_path: qlib.contrib.model.pytorch_tra + kwargs: + tra_config: *tra_config + model_config: *model_config + model_type: RNN + lr: 1e-3 + n_epochs: 100 + max_steps_per_epoch: 100 + early_stop: 10 + smooth_steps: 5 + seed: 0 + logdir: + lamb: 1.0 + rho: 1.0 + transport_method: router + memory_mode: *memory_mode + eval_train: False + eval_test: True + pretrain: True + init_state: + freeze_model: False + freeze_predictors: False + dataset: + class: MTSDatasetH + module_path: qlib.contrib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + seq_len: 60 + horizon: 2 + input_size: + num_states: *num_states + batch_size: 1024 + n_samples: + memory_mode: *memory_mode + drop_last: True + record: + - class: SignalRecord + module_path: 
qlib.workflow.record_temp + kwargs: {} + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml new file mode 100644 index 0000000000..dbdeaf060e --- /dev/null +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml @@ -0,0 +1,120 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn + +market: &market csi300 +benchmark: &benchmark SH000300 + +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +num_states: &num_states 3 + +memory_mode: &memory_mode sample + +tra_config: &tra_config + num_states: *num_states + hidden_size: 16 + tau: 1.0 + src_info: LR_TPE + +model_config: &model_config + input_size: 6 + hidden_size: 64 + num_layers: 2 + rnn_arch: LSTM + use_attn: True + dropout: 0.0 + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 + +task: + model: + class: TRAModel + module_path: qlib.contrib.model.pytorch_tra + kwargs: + tra_config: *tra_config + model_config: *model_config + model_type: RNN + lr: 1e-3 + n_epochs: 100 + max_steps_per_epoch: 100 + early_stop: 10 + smooth_steps: 5 + logdir: + seed: 0 + lamb: 1.0 + rho: 1.0 + transport_method: router + memory_mode: *memory_mode + eval_train: False + eval_test: True + pretrain: True + init_state: + freeze_model: False + freeze_predictors: False + dataset: + class: MTSDatasetH + module_path: qlib.contrib.data.dataset + kwargs: + handler: + class: Alpha360 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + seq_len: 60 + horizon: 2 + input_size: 6 + num_states: *num_states + batch_size: 1024 + n_samples: + memory_mode: *memory_mode + drop_last: True + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: {} + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/qlib/contrib/data/dataset.py b/qlib/contrib/data/dataset.py new file mode 100644 index 0000000000..af4893acff --- /dev/null +++ b/qlib/contrib/data/dataset.py @@ -0,0 +1,346 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
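+# NOTE: the module summary below is a sketch inferred from the helpers
+# defined in this file and from the `MTSDatasetH` references in the TRA
+# workflow configs; treat it as orientation rather than a specification.
+"""Time-series dataset utilities backing `MTSDatasetH` for TRA.
+
+Given features indexed by a sorted (instrument, datetime) `pd.MultiIndex`,
+`_create_ts_slices` builds one [start, stop) slice per sample,
+`_get_date_parse_fn` normalizes date arguments to the index's date type,
+and `_maybe_padding` pads sequences shorter than `seq_len`.
+"""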
+ +import copy +import torch +import warnings +import numpy as np +import pandas as pd + +from qlib.utils import init_instance_by_config +from qlib.data.dataset import DatasetH, DataHandler + + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def _to_tensor(x): + if not isinstance(x, torch.Tensor): + return torch.tensor(x, dtype=torch.float, device=device) + return x + + +def _create_ts_slices(index, seq_len): + """ + create time series slices from pandas index + + Args: + index (pd.MultiIndex): pandas multiindex with order + seq_len (int): sequence length + """ + assert isinstance(index, pd.MultiIndex), "unsupported index type" + assert seq_len > 0, "sequence length should be larger than 0" + assert index.is_monotonic_increasing, "index should be sorted" + + # number of dates for each instrument + sample_count_by_insts = index.to_series().groupby(level=0).size().values + + # start index for each instrument + start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1) + start_index_of_insts[0] = 0 + + # all the [start, stop) indices of features + # features between [start, stop) will be used to predict label at `stop - 1` + slices = [] + for cur_loc, cur_cnt in zip(start_index_of_insts, sample_count_by_insts): + for stop in range(1, cur_cnt + 1): + end = cur_loc + stop + start = max(end - seq_len, 0) + slices.append(slice(start, end)) + slices = np.array(slices, dtype="object") + + assert len(slices) == len(index) # the i-th slice = index[i] + + return slices + + +def _get_date_parse_fn(target): + """get date parse function + + This method is used to parse date arguments as target type. + + Example: + get_date_parse_fn('20120101')('2017-01-01') => '20170101' + get_date_parse_fn(20120101)('2017-01-01') => 20170101 + """ + if isinstance(target, pd.Timestamp): + _fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01') + elif isinstance(target, int): + _fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201 + elif isinstance(target, str) and len(target) == 8: + _fn = lambda x: str(x).replace("-", "")[:8] # '20200201' + else: + _fn = lambda x: x # '2021-01-01' + return _fn + + +def _maybe_padding(x, seq_len, zeros=None): + """padding 2d