From 9ad6b5564ae61c4ab7a60aa2e2c5a91a8a3206d7 Mon Sep 17 00:00:00 2001 From: Young Date: Thu, 29 Jul 2021 04:40:27 +0000 Subject: [PATCH] refactor online serving rolling api --- ..._config_lightgbm_configurable_dataset.yaml | 2 +- examples/model_rolling/requirements.txt | 1 + qlib/utils/__init__.py | 4 +- qlib/workflow/online/strategy.py | 24 ++-- qlib/workflow/online/update.py | 2 + qlib/workflow/task/gen.py | 104 +++++++++++------- 6 files changed, 82 insertions(+), 55 deletions(-) create mode 100644 examples/model_rolling/requirements.txt diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml index 335dc20934..78f567eb31 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml @@ -78,4 +78,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/model_rolling/requirements.txt b/examples/model_rolling/requirements.txt new file mode 100644 index 0000000000..10ddd5b71e --- /dev/null +++ b/examples/model_rolling/requirements.txt @@ -0,0 +1 @@ +xgboost diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 778d0e17a5..2fe9eafed2 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -570,9 +570,11 @@ def get_pre_trading_date(trading_date, future=False): def transform_end_date(end_date=None, freq="day"): - """get previous trading date + """handle the end date with various format + If end_date is -1, None, or end_date is greater than the maximum trading day, the last trading date is returned. Otherwise, returns the end_date + ---------- end_date: str end trading date diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index 1e8e85c0f9..7a923ebadd 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -10,6 +10,7 @@ from qlib.data.data import D from qlib.log import get_module_logger from qlib.model.ens.group import RollingGroup +from qlib.utils import transform_end_date from qlib.workflow.online.utils import OnlineTool, OnlineToolR from qlib.workflow.recorder import Recorder from qlib.workflow.task.collect import Collector, RecorderCollector @@ -118,6 +119,7 @@ def __init__( task_template = [task_template] self.task_template = task_template self.rg = rolling_gen + assert issubclass(self.rg.__class__, RollingGen), "The rolling strategy relies on the feature if RollingGen" self.tool = OnlineToolR(self.exp_name) self.ta = TimeAdjuster() @@ -174,28 +176,20 @@ def prepare_tasks(self, cur_time) -> List[dict]: Returns: List[dict]: a list of new tasks. """ + # TODO: filter recorders by latest test segments is not a necessary latest_records, max_test = self._list_latest(self.tool.online_models()) if max_test is None: self.logger.warn(f"No latest online recorders, no new tasks.") return [] - calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time + calendar_latest = transform_end_date(cur_time) self.logger.info( f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}" ) - if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step: - old_tasks = [] - tasks_tmp = [] - for rec in latest_records: - task = rec.load_object("task") - old_tasks.append(deepcopy(task)) - test_begin = task["dataset"]["kwargs"]["segments"]["test"][0] - # modify the test segment to generate new tasks - task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest) - tasks_tmp.append(task) - new_tasks_tmp = task_generator(tasks_tmp, self.rg) - new_tasks = [task for task in new_tasks_tmp if task not in old_tasks] - return new_tasks - return [] + res = [] + for rec in latest_records: + task = rec.load_object("task") + res.extend(self.rg.gen_following_tasks(task, calendar_latest)) + return res def _list_latest(self, rec_list: List[Recorder]): """ diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index e5dbd413e9..f2135a27a6 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -105,6 +105,8 @@ def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day" if to_date == None: to_date = D.calendar(freq=freq)[-1] self.to_date = pd.Timestamp(to_date) + # FIXME: it will raise error when running routine with delay trainer + # should we use another predicition updater for delay trainer? self.old_pred = record.load_object("pred.pkl") self.last_end = self.old_pred.index.get_level_values("datetime").max() diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index ca7b8ae7fd..e60fa4755b 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -5,6 +5,7 @@ """ import abc import copy +import pandas as pd from typing import List, Union, Callable from qlib.utils import transform_end_date @@ -139,6 +140,53 @@ def __init__(self, step: int = 40, rtype: str = ROLL_EX, ds_extra_mod_func: Unio self.test_key = "test" self.train_key = "train" + def _update_task_segs(self, task, segs): + # update segments of this task + task["dataset"]["kwargs"]["segments"] = copy.deepcopy(segs) + if self.ds_extra_mod_func is not None: + self.ds_extra_mod_func(task, self) + + def gen_following_tasks(self, task: dict, test_end: pd.Timestamp) -> List[dict]: + """ + generating following rolling tasks for `task` until test_end + + Parameters + ---------- + task : dict + Qlib task format + test_end : pd.Timestamp + the latest rolling task includes `test_end` + + Returns + ------- + List[dict]: + the following tasks of `task`(`task` itself is excluded) + """ + t = copy.deepcopy(task) + prev_seg = t["dataset"]["kwargs"]["segments"] + while True: + segments = {} + try: + for k, seg in prev_seg.items(): + # decide how to shift + # expanding only for train data, the segments size of test data and valid data won't change + if k == self.train_key and self.rtype == self.ROLL_EX: + rtype = self.ta.SHIFT_EX + else: + rtype = self.ta.SHIFT_SD + # shift the segments data + segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype) + if segments[self.test_key][0] > test_end: + break + except KeyError: + # We reach the end of tasks + # No more rolling + break + + prev_seg = segments + self._update_task_segs(t, segments) + yield t + def generate(self, task: dict) -> List[dict]: """ Converting the task into a rolling task. @@ -191,43 +239,23 @@ def generate(self, task: dict) -> List[dict]: """ res = [] - prev_seg = None - test_end = None - while True: - t = copy.deepcopy(task) - - # calculate segments - if prev_seg is None: - # First rolling - # 1) prepare the end point - segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"])) - test_end = transform_end_date(segments[self.test_key][1]) - # 2) and init test segments - test_start_idx = self.ta.align_idx(segments[self.test_key][0]) - segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1)) - else: - segments = {} - try: - for k, seg in prev_seg.items(): - # decide how to shift - # expanding only for train data, the segments size of test data and valid data won't change - if k == self.train_key and self.rtype == self.ROLL_EX: - rtype = self.ta.SHIFT_EX - else: - rtype = self.ta.SHIFT_SD - # shift the segments data - segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype) - if segments[self.test_key][0] > test_end: - break - except KeyError: - # We reach the end of tasks - # No more rolling - break + t = copy.deepcopy(task) - # update segments of this task - t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments) - prev_seg = segments - if self.ds_extra_mod_func is not None: - self.ds_extra_mod_func(t, self) - res.append(t) + # calculate segments + + # First rolling + # 1) prepare the end point + segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"])) + test_end = transform_end_date(segments[self.test_key][1]) + # 2) and init test segments + test_start_idx = self.ta.align_idx(segments[self.test_key][0]) + segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1)) + + # update segments of this task + self._update_task_segs(t, segments) + + res.append(t) + + # Update the following rolling + res.extend(self.gen_following_tasks(t, test_end)) return res