-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
init version of online serving and rolling
- Loading branch information
Showing
7 changed files
with
623 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
""" | ||
Task related workflow is implemented in this folder | ||
A typical task workflow | ||
| Step | Description | | ||
|-----------------------+------------------------------------------------| | ||
| TaskGen | Generating tasks. | | ||
| TaskManager(optional) | Manage generated tasks | | ||
| run task | retrive tasks from TaskManager and run tasks. | | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from qlib.workflow import R | ||
import pandas as pd | ||
from typing import Union | ||
from tqdm.auto import tqdm | ||
|
||
|
||
class RollingEnsemble: | ||
''' | ||
Rolling Models Ensemble based on (R)ecord | ||
This shares nothing with Ensemble | ||
''' | ||
# TODO: 这边还可以加加速 | ||
def __init__(self, get_key_func, flt_func=None): | ||
self.get_key_func = get_key_func | ||
self.flt_func = flt_func | ||
|
||
def __call__(self, exp_name) -> Union[pd.Series, dict]: | ||
# TODO; | ||
# Should we split the scripts into several sub functions? | ||
exp = R.get_exp(experiment_name=exp_name) | ||
|
||
# filter records | ||
recs = exp.list_recorders() | ||
|
||
recs_flt = {} | ||
for rid, rec in tqdm(recs.items(), desc="Loading data"): | ||
# rec = exp.get_recorder(recorder_id=rid) | ||
params = rec.load_object("param") | ||
if rec.status == rec.STATUS_FI: | ||
if self.flt_func is None or self.flt_func(params): | ||
rec.params = params | ||
recs_flt[rid] = rec | ||
|
||
# group | ||
recs_group = {} | ||
for _, rec in recs_flt.items(): | ||
params = rec.params | ||
group_key = self.get_key_func(params) | ||
recs_group.setdefault(group_key, []).append(rec) | ||
|
||
# reduce group | ||
reduce_group = {} | ||
for k, rec_l in recs_group.items(): | ||
pred_l = [] | ||
for rec in rec_l: | ||
pred_l.append(rec.load_object('pred.pkl').iloc[:, 0]) | ||
pred = pd.concat(pred_l).sort_index() | ||
reduce_group[k] = pred | ||
|
||
return reduce_group | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
''' | ||
this is a task generator | ||
''' | ||
import abc | ||
import copy | ||
import typing | ||
from .utils import TimeAdjuster | ||
|
||
|
||
class TaskGen(metaclass=abc.ABCMeta): | ||
@abc.abstractmethod | ||
def __call__(self, *args, **kwargs) -> typing.List[dict]: | ||
""" | ||
generate | ||
Parameters | ||
---------- | ||
args, kwargs: | ||
The info for generating tasks | ||
Example 1): | ||
input: a specific task template | ||
output: rolling version of the tasks | ||
Example 2): | ||
input: a specific task template | ||
output: a set of tasks with different losses | ||
Returns | ||
------- | ||
typing.List[dict]: | ||
A list of tasks | ||
""" | ||
pass | ||
|
||
|
||
class RollingGen(TaskGen): | ||
|
||
ROLL_EX = TimeAdjuster.SHIFT_EX | ||
ROLL_SD = TimeAdjuster.SHIFT_SD | ||
|
||
def __init__(self, step: int = 40, rtype: str = ROLL_EX): | ||
""" | ||
Generate tasks for rolling | ||
Parameters | ||
---------- | ||
step : int | ||
step to rolling | ||
rtype : str | ||
rolling type (expanding, rolling) | ||
""" | ||
self.step = step | ||
self.rtype = rtype | ||
self.ta = TimeAdjuster(future=True) # 为了保证test最后的日期不是None, 所以这边要改一改 | ||
|
||
self.test_key = 'test' | ||
self.train_key = 'train' | ||
|
||
def __call__(self, task: dict): | ||
""" | ||
Converting the task into a rolling task | ||
Parameters | ||
---------- | ||
task : dict | ||
A dict describing a task. For example. | ||
DEFAULT_TASK = { | ||
"model": { | ||
"class": "LGBModel", | ||
"module_path": "qlib.contrib.model.gbdt", | ||
}, | ||
"dataset": { | ||
"class": "DatasetH", | ||
"module_path": "qlib.data.dataset", | ||
"kwargs": { | ||
"handler": { | ||
"class": "Alpha158", | ||
"module_path": "qlib.contrib.data.handler", | ||
"kwargs": data_handler_config, | ||
}, | ||
"segments": { | ||
"train": ("2008-01-01", "2014-12-31"), | ||
"valid": ("2015-01-01", "2016-12-20"), # Please avoid leaking the future test data into validation | ||
"test": ("2017-01-01", "2020-08-01"), | ||
}, | ||
}, | ||
}, | ||
# You shoud record the data in specific sequence | ||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'], | ||
} | ||
""" | ||
res = [] | ||
|
||
prev_seg = None | ||
test_end = None | ||
while True: | ||
t = copy.deepcopy(task) | ||
|
||
# calculate segments | ||
if prev_seg is None: | ||
# First rolling | ||
# 1) prepare the end porint | ||
segments = copy.deepcopy(self.ta.align_seg(t['dataset']['kwargs']['segments'])) | ||
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1] | ||
# 2) and the init test segments | ||
test_start_idx = self.ta.align_idx(segments[self.test_key][0]) | ||
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1)) | ||
else: | ||
segments = {} | ||
try: | ||
for k, seg in prev_seg.items(): | ||
# 决定怎么shift | ||
if k == self.train_key and self.rtype == self.ROLL_EX: | ||
rtype = self.ta.SHIFT_EX | ||
else: | ||
rtype = self.ta.SHIFT_SD | ||
# 整段数据做shift | ||
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype) | ||
if segments[self.test_key][0] > test_end: | ||
break | ||
except KeyError: | ||
# We reach the end of tasks | ||
# No more rolling | ||
break | ||
|
||
t['dataset']['kwargs']['segments'] = copy.deepcopy(segments) | ||
prev_seg = segments | ||
res.append(t) | ||
return res | ||
|
||
|
Oops, something went wrong.