Skip to content

Commit

Permalink
Merge pull request #1 from you-n-g/online_srv
Browse files Browse the repository at this point in the history
qlib auto init basedon project & black format
  • Loading branch information
lzh222333 authored Mar 2, 2021
2 parents 1e5cf1c + 24024d5 commit c4733f6
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 66 deletions.
75 changes: 73 additions & 2 deletions qlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@


__version__ = "0.6.3.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version


import os
import yaml
import logging
import platform
import subprocess
from pathlib import Path
from .log import get_module_logger


# init qlib
def init(default_conf="client", **kwargs):
from .config import C
from .log import get_module_logger
from .data.cache import H

H.clear()
Expand Down Expand Up @@ -48,7 +50,6 @@ def init(default_conf="client", **kwargs):


def _mount_nfs_uri(C):
from .log import get_module_logger

LOG = get_module_logger("mount nfs", level=logging.INFO)

Expand Down Expand Up @@ -151,3 +152,73 @@ def init_from_yaml_conf(conf_path, **kwargs):
config.update(kwargs)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)


def get_project_path(config_name="config.yaml") -> Path:
"""
If users are building a project follow the following pattern.
- Qlib is a sub folder in project path
- There is a file named `config.yaml` in qlib.
For example:
If your project file system stucuture follows such a pattern
<project_path>/
- config.yaml
- ...some folders...
- qlib/
This folder will return <project_path>
NOTE: link is not supported here.
This method is often used when
- user want to use a relative config path instead of hard-coding qlib config path in code
Raises
------
FileNotFoundError:
If project path is not found
"""
cur_path = Path(__file__).absolute().resolve()
while True:
if (cur_path / config_name).exists():
return cur_path
if cur_path == cur_path.parent:
raise FileNotFoundError("We can't find the project path")
cur_path = cur_path.parent


def auto_init(**kwargs):
"""
This function will init qlib automatically with following priority
- Find the project configuration and init qlib
- The parsing process will be affected by the `conf_type` of the configuration file
- Init qlib with default config
"""

try:
pp = get_project_path()
except FileNotFoundError:
init(**kwargs)
else:

conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)

conf_type = conf.get("conf_type", "origin")
if conf_type == "origin":
# The type of config is just like original qlib config
init_from_yaml_conf(conf_pp, **kwargs)
elif conf_type == "ref":
# This config type will be more convenient in following scenario
# - There is a shared configure file and you don't want to edit it inplace.
# - The shared configure may be updated later and you don't want to copy it.
# - You have some customized config.
qlib_conf_path = conf["qlib_cfg"]
qlib_conf_update = conf.get("qlib_cfg_update")
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
logger = get_module_logger("Initialization")
logger.info(f"Auto load project config: {conf_pp}")
17 changes: 17 additions & 0 deletions qlib/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def __getattr__(self, attr):

raise AttributeError(f"No such {attr} in self._config")

def get(self, key, default=None):
return self.__dict__["_config"].get(key, default)

def __setitem__(self, key, value):
self.__dict__["_config"][key] = value

Expand Down Expand Up @@ -310,8 +313,22 @@ def register(self):
# clean up experiment when python program ends
experiment_exit_handler()

# Supporting user reset qlib version (useful when user want to connect to qlib server with old version)
self.reset_qlib_version()

self._registered = True

def reset_qlib_version(self):
import qlib

reset_version = self.get("qlib_reset_version", None)
if reset_version is not None:
qlib.__version__ = reset_version
else:
qlib.__version__ = getattr(qlib, "__version__bak")
# Due to a bug? that converting __version__ to _QlibConfig__version__bak
# Using __version__bak instead of __version__

@property
def registered(self):
return self._registered
Expand Down
8 changes: 4 additions & 4 deletions qlib/workflow/task/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@


class RollingEnsemble:
'''
"""
Rolling Models Ensemble based on (R)ecord
This shares nothing with Ensemble
'''
"""

# TODO: 这边还可以加加速
def __init__(self, get_key_func, flt_func=None):
self.get_key_func = get_key_func
Expand Down Expand Up @@ -44,9 +45,8 @@ def __call__(self, exp_name) -> Union[pd.Series, dict]:
for k, rec_l in recs_group.items():
pred_l = []
for rec in rec_l:
pred_l.append(rec.load_object('pred.pkl').iloc[:, 0])
pred_l.append(rec.load_object("pred.pkl").iloc[:, 0])
pred = pd.concat(pred_l).sort_index()
reduce_group[k] = pred

return reduce_group

16 changes: 7 additions & 9 deletions qlib/workflow/task/gen.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
'''
"""
this is a task generator
'''
"""
import abc
import copy
import typing
Expand Down Expand Up @@ -54,8 +54,8 @@ def __init__(self, step: int = 40, rtype: str = ROLL_EX):
self.rtype = rtype
self.ta = TimeAdjuster(future=True) # 为了保证test最后的日期不是None, 所以这边要改一改

self.test_key = 'test'
self.train_key = 'train'
self.test_key = "test"
self.train_key = "train"

def __call__(self, task: dict):
"""
Expand Down Expand Up @@ -102,7 +102,7 @@ def __call__(self, task: dict):
if prev_seg is None:
# First rolling
# 1) prepare the end porint
segments = copy.deepcopy(self.ta.align_seg(t['dataset']['kwargs']['segments']))
segments = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
# 2) and the init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
Expand All @@ -120,14 +120,12 @@ def __call__(self, task: dict):
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
if segments[self.test_key][0] > test_end:
break
except KeyError:
except KeyError:
# We reach the end of tasks
# No more rolling
break

t['dataset']['kwargs']['segments'] = copy.deepcopy(segments)
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
prev_seg = segments
res.append(t)
return res


Loading

0 comments on commit c4733f6

Please sign in to comment.