Skip to content

Commit

Permalink
init version of online serving and rolling
Browse files Browse the repository at this point in the history
  • Loading branch information
you-n-g committed Feb 26, 2021
1 parent fa8f1cb commit 1e5cf1c
Show file tree
Hide file tree
Showing 7 changed files with 623 additions and 0 deletions.
Empty file added qlib/model/ens/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions qlib/workflow/task/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Task related workflow is implemented in this folder
A typical task workflow
| Step | Description |
|-----------------------+------------------------------------------------|
| TaskGen | Generating tasks. |
| TaskManager(optional) | Manage generated tasks |
| run task | retrieve tasks from TaskManager and run tasks. |
"""
52 changes: 52 additions & 0 deletions qlib/workflow/task/collect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from qlib.workflow import R
import pandas as pd
from typing import Union
from tqdm.auto import tqdm


class RollingEnsemble:
    """
    Rolling Models Ensemble based on (R)ecord

    This shares nothing with Ensemble.

    Calling an instance with an experiment name collects the predictions of
    the finished recorders of that experiment, groups them by a user supplied
    key and concatenates the predictions of every group.
    """

    # TODO: this could be sped up further
    def __init__(self, get_key_func, flt_func=None):
        """
        Parameters
        ----------
        get_key_func : callable
            Called with the object a recorder stored under "param"; its return
            value is used as a dict key to group the recorders.
        flt_func : callable, optional
            Called with the same "param" object; a recorder is kept only when
            the result is truthy. When None, every finished recorder is kept.
        """
        self.get_key_func = get_key_func
        self.flt_func = flt_func

    def __call__(self, exp_name) -> Union[pd.Series, dict]:
        """
        Collect and merge the predictions stored in an experiment.

        Parameters
        ----------
        exp_name :
            Passed to ``R.get_exp`` as ``experiment_name``.

        Returns
        -------
        dict
            Maps each group key to the concatenation (sorted by index) of the
            first column of the "pred.pkl" object of every recorder in the group.
        """
        # TODO: Should we split the scripts into several sub functions?
        exp = R.get_exp(experiment_name=exp_name)

        # filter records
        kept = []
        for rec in tqdm(exp.list_recorders().values(), desc="Loading data"):
            # Check the status first: there is no point in loading artifacts of
            # a recorder that did not finish, and they may not even exist.
            if rec.status != rec.STATUS_FI:
                continue
            params = rec.load_object("param")
            if self.flt_func is None or self.flt_func(params):
                # Keep the params next to the recorder instead of attaching
                # an ad-hoc attribute to the recorder object.
                kept.append((rec, params))

        # group
        recs_group = {}
        for rec, params in kept:
            recs_group.setdefault(self.get_key_func(params), []).append(rec)

        # reduce group
        return {
            key: pd.concat(
                [rec.load_object("pred.pkl").iloc[:, 0] for rec in members]
            ).sort_index()
            for key, members in recs_group.items()
        }

133 changes: 133 additions & 0 deletions qlib/workflow/task/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
'''
This module defines task generators, which produce a list of tasks from the given information (e.g. a task template).
'''
import abc
import copy
import typing
from .utils import TimeAdjuster


class TaskGen(metaclass=abc.ABCMeta):
    """
    Abstract base of all task generators.

    A concrete generator implements ``__call__`` and returns a list of tasks,
    each task being a dict.
    """

    @abc.abstractmethod
    def __call__(self, *args, **kwargs) -> typing.List[dict]:
        """
        Generate tasks from the given information.

        Parameters
        ----------
        args, kwargs:
            The info for generating tasks

            Example 1):
                input: a specific task template
                output: rolling version of the tasks
            Example 2):
                input: a specific task template
                output: a set of tasks with different losses

        Returns
        -------
        typing.List[dict]:
            A list of tasks
        """


class RollingGen(TaskGen):
    """
    Task generator that turns one task template into a list of rolling tasks.

    Every generated task is a deep copy of the template whose dataset
    segments have been moved forward by a multiple of ``step``.
    """

    ROLL_EX = TimeAdjuster.SHIFT_EX
    ROLL_SD = TimeAdjuster.SHIFT_SD

    def __init__(self, step: int = 40, rtype: str = ROLL_EX):
        """
        Generate tasks for rolling

        Parameters
        ----------
        step : int
            step to rolling
        rtype : str
            rolling type (expanding, rolling); one of ``ROLL_EX`` / ``ROLL_SD``.
            With ``ROLL_EX`` the train segment is shifted with
            ``TimeAdjuster.SHIFT_EX``; every other segment is always shifted
            with ``TimeAdjuster.SHIFT_SD``.
        """
        self.step = step
        self.rtype = rtype
        # Original author's note: `future=True` is used so that the last date
        # of the test segment is not None.
        self.ta = TimeAdjuster(future=True)

        # names of the segments that get special treatment while rolling
        self.test_key = "test"
        self.train_key = "train"

    def __call__(self, task: dict):
        """
        Converting the task into a rolling task

        Parameters
        ----------
        task : dict
            A dict describing a task. For example.

            DEFAULT_TASK = {
                "model": {
                    "class": "LGBModel",
                    "module_path": "qlib.contrib.model.gbdt",
                },
                "dataset": {
                    "class": "DatasetH",
                    "module_path": "qlib.data.dataset",
                    "kwargs": {
                        "handler": {
                            "class": "Alpha158",
                            "module_path": "qlib.contrib.data.handler",
                            "kwargs": data_handler_config,
                        },
                        "segments": {
                            "train": ("2008-01-01", "2014-12-31"),
                            "valid": ("2015-01-01", "2016-12-20"),  # Please avoid leaking the future test data into validation
                            "test": ("2017-01-01", "2020-08-01"),
                        },
                    },
                },
                # You should record the data in specific sequence
                # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
            }

        Returns
        -------
        list
            Deep copies of `task`, one per rolling step, each with its own
            ``task['dataset']['kwargs']['segments']``.
        """
        rolled_tasks = []

        # ---- first rolling -------------------------------------------------
        new_task = copy.deepcopy(task)
        segments = copy.deepcopy(self.ta.align_seg(new_task["dataset"]["kwargs"]["segments"]))
        # 1) remember where rolling has to stop: the end of the original test
        #    segment, or the adjuster's maximum when that end is open (None)
        orig_test = segments[self.test_key]
        test_end = self.ta.max() if orig_test[1] is None else orig_test[1]
        # 2) the first test segment covers `step` entries starting at the test start
        first_idx = self.ta.align_idx(orig_test[0])
        segments[self.test_key] = (
            self.ta.get(first_idx),
            self.ta.get(first_idx + self.step - 1),
        )

        while True:
            # store the segments computed for this rolling step
            new_task["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
            rolled_tasks.append(new_task)

            # ---- following rollings: shift every segment as a whole --------
            new_task = copy.deepcopy(task)
            try:
                # decide how to shift: only the train segment may use SHIFT_EX
                segments = {
                    name: self.ta.shift(
                        seg,
                        step=self.step,
                        rtype=(
                            self.ta.SHIFT_EX
                            if name == self.train_key and self.rtype == self.ROLL_EX
                            else self.ta.SHIFT_SD
                        ),
                    )
                    for name, seg in segments.items()
                }
                if segments[self.test_key][0] > test_end:
                    # the shifted test segment starts after the requested end
                    break
            except KeyError:
                # We reach the end of tasks
                # No more rolling
                break

        return rolled_tasks


Loading

0 comments on commit 1e5cf1c

Please sign in to comment.