diff --git a/qlib/__init__.py b/qlib/__init__.py
index 99035e5014..816e5a5852 100644
--- a/qlib/__init__.py
+++ b/qlib/__init__.py
@@ -3,6 +3,7 @@
 __version__ = "0.6.3.99"
+__version__bak = __version__  # This version is a backup for QlibConfig.reset_qlib_version
 
 import os
@@ -10,12 +11,13 @@
 import logging
 import platform
 import subprocess
+from pathlib import Path
+from .log import get_module_logger
 
 
 # init qlib
 def init(default_conf="client", **kwargs):
     from .config import C
-    from .log import get_module_logger
     from .data.cache import H
 
     H.clear()
@@ -48,7 +50,6 @@ def init(default_conf="client", **kwargs):
 
 
 def _mount_nfs_uri(C):
-    from .log import get_module_logger
 
     LOG = get_module_logger("mount nfs", level=logging.INFO)
 
@@ -151,3 +152,73 @@ def init_from_yaml_conf(conf_path, **kwargs):
     config.update(kwargs)
     default_conf = config.pop("default_conf", "client")
     init(default_conf, **config)
+
+
+def get_project_path(config_name="config.yaml") -> Path:
+    """
+    If users are building a project that follows this pattern:
+    - Qlib is a subfolder of the project path.
+    - There is a file named `config.yaml` in the project path.
+
+    For example, if your project file system structure looks like
+
+        <project_path>/
+          - config.yaml
+          - ...some folders...
+          - qlib/
+
+    then this function will return <project_path>.
+
+    NOTE: symbolic links are not supported here.
+
+    This method is often used when
+    - users want to use a relative config path instead of hard-coding the qlib config path in code
+
+    Raises
+    ------
+    FileNotFoundError:
+        If the project path is not found
+    """
+    cur_path = Path(__file__).absolute().resolve()
+    while True:
+        if (cur_path / config_name).exists():
+            return cur_path
+        if cur_path == cur_path.parent:
+            raise FileNotFoundError("We can't find the project path")
+        cur_path = cur_path.parent
+
+
+def auto_init(**kwargs):
+    """
+    This function will init qlib automatically with the following priority:
+    - Find the project configuration and init qlib
+        - The parsing process will be affected by the `conf_type` of the configuration file
+    - Init qlib with the default config
+    """
+
+    try:
+        pp = get_project_path()
+    except FileNotFoundError:
+        init(**kwargs)
+    else:
+        conf_pp = pp / "config.yaml"
+        with conf_pp.open() as f:
+            conf = yaml.safe_load(f)
+
+        conf_type = conf.get("conf_type", "origin")
+        if conf_type == "origin":
+            # This config is in the same format as the original qlib config
+            init_from_yaml_conf(conf_pp, **kwargs)
+        elif conf_type == "ref":
+            # This config type is more convenient in the following scenarios:
+            # - There is a shared config file and you don't want to edit it in place.
+            # - The shared config may be updated later and you don't want to copy it.
+            # - You have some customized config.
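For reference, here is a minimal sketch of the layout `auto_init` searches for and the two `config.yaml` styles it understands. The paths and override values are hypothetical; only `conf_type`, `qlib_cfg`, and `qlib_cfg_update` are keys actually read by the code above, and the layout must match `get_project_path` (this qlib checkout sits inside `<project_path>`):

```python
# A minimal sketch; all paths and config values below are hypothetical.
#
# "origin" style <project_path>/config.yaml -- a plain qlib config passed
# straight to init_from_yaml_conf():
#
#     provider_uri: ~/.qlib/qlib_data/cn_data
#     region: cn
#
# "ref" style -- point at a shared qlib config and overlay local updates:
#
#     conf_type: ref
#     qlib_cfg: /shared/team_qlib_config.yaml
#     qlib_cfg_update:
#       mongo:
#         task_url: mongodb://localhost:27017/
#         task_db_name: rolling_tasks

import qlib

qlib.auto_init()  # walks up from qlib/__init__.py looking for config.yaml; falls back to init()
```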
+ qlib_conf_path = conf["qlib_cfg"] + qlib_conf_update = conf.get("qlib_cfg_update") + init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs) + logger = get_module_logger("Initialization") + logger.info(f"Auto load project config: {conf_pp}") diff --git a/qlib/config.py b/qlib/config.py index 52b05568d5..b245cc1df5 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -33,6 +33,9 @@ def __getattr__(self, attr): raise AttributeError(f"No such {attr} in self._config") + def get(self, key, default=None): + return self.__dict__["_config"].get(key, default) + def __setitem__(self, key, value): self.__dict__["_config"][key] = value @@ -310,8 +313,22 @@ def register(self): # clean up experiment when python program ends experiment_exit_handler() + # Supporting user reset qlib version (useful when user want to connect to qlib server with old version) + self.reset_qlib_version() + self._registered = True + def reset_qlib_version(self): + import qlib + + reset_version = self.get("qlib_reset_version", None) + if reset_version is not None: + qlib.__version__ = reset_version + else: + qlib.__version__ = getattr(qlib, "__version__bak") + # Due to a bug? that converting __version__ to _QlibConfig__version__bak + # Using __version__bak instead of __version__ + @property def registered(self): return self._registered diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 13b5869de5..7cdca30fae 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -5,11 +5,12 @@ class RollingEnsemble: - ''' + """ Rolling Models Ensemble based on (R)ecord This shares nothing with Ensemble - ''' + """ + # TODO: 这边还可以加加速 def __init__(self, get_key_func, flt_func=None): self.get_key_func = get_key_func @@ -44,9 +45,8 @@ def __call__(self, exp_name) -> Union[pd.Series, dict]: for k, rec_l in recs_group.items(): pred_l = [] for rec in rec_l: - pred_l.append(rec.load_object('pred.pkl').iloc[:, 0]) + pred_l.append(rec.load_object("pred.pkl").iloc[:, 0]) pred = pd.concat(pred_l).sort_index() reduce_group[k] = pred return reduce_group - diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index 66529f3a5a..9b031435ed 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py
index 66529f3a5a..9b031435ed 100644
--- a/qlib/workflow/task/gen.py
+++ b/qlib/workflow/task/gen.py
@@ -1,8 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-'''
+"""
 this is a task generator
-'''
+"""
 import abc
 import copy
 import typing
@@ -54,8 +54,8 @@ def __init__(self, step: int = 40, rtype: str = ROLL_EX):
         self.rtype = rtype
         self.ta = TimeAdjuster(future=True)  # future=True ensures that the last date of the test segment is not None
 
-        self.test_key = 'test'
-        self.train_key = 'train'
+        self.test_key = "test"
+        self.train_key = "train"
 
     def __call__(self, task: dict):
         """
@@ -102,7 +102,7 @@ def __call__(self, task: dict):
             if prev_seg is None:
                 # First rolling
                 # 1) prepare the end point
-                segments = copy.deepcopy(self.ta.align_seg(t['dataset']['kwargs']['segments']))
+                segments = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
                 test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
                 # 2) and the init test segments
                 test_start_idx = self.ta.align_idx(segments[self.test_key][0])
@@ -120,14 +120,12 @@ def __call__(self, task: dict):
                         segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
                     if segments[self.test_key][0] > test_end:
                         break
-            except KeyError:
+            except KeyError:  # we have reached the end of the rolling tasks
                 # No more rolling
                 break
 
-            t['dataset']['kwargs']['segments'] = copy.deepcopy(segments)
+            t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
             prev_seg = segments
             res.append(t)
         return res
-
-
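To make the rolling behavior concrete, here is a hedged sketch of driving the generator above. The class name `RollingGen` and the task template are assumptions (the hunks only show `__init__` and `__call__`), and qlib must be initialized first so the trading calendar behind `TimeAdjuster` is available:

```python
from qlib.workflow.task.gen import RollingGen  # assumed class name in gen.py

task_template = {
    "model": {"class": "LGBModel"},  # hypothetical model section
    "dataset": {
        "class": "DatasetH",
        "kwargs": {
            "segments": {
                "train": ("2008-01-01", "2014-12-31"),
                "valid": ("2015-01-01", "2016-12-31"),
                # a None end date is aligned to the last calendar day (self.ta.max())
                "test": ("2017-01-01", None),
            }
        },
    },
}

gen = RollingGen(step=40)   # shift every segment by 40 trading days per roll
tasks = gen(task_template)  # a list of task dicts, each with rolled `segments`
```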
diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py
index 6407279f06..3bcac83600 100644
--- a/qlib/workflow/task/manage.py
+++ b/qlib/workflow/task/manage.py
@@ -19,6 +19,8 @@
 import concurrent
 import pymongo
 from qlib.config import C
+from .utils import get_mongodb
+from qlib import auto_init
 
 
 class TaskManager:
@@ -41,12 +43,13 @@ class TaskManager:
     NOTE:
     - Assumption: everything stored in the DB is encoded; everything taken out of it is decoded
     """
-    STATUS_WAITING = 'waiting'
-    STATUS_RUNNING = 'running'
-    STATUS_DONE = 'done'
-    STATUS_PART_DONE = 'part_done'
-    ENCODE_FIELDS_PREFIX = ['def', 'res']
+    STATUS_WAITING = "waiting"
+    STATUS_RUNNING = "running"
+    STATUS_DONE = "done"
+    STATUS_PART_DONE = "part_done"
+
+    ENCODE_FIELDS_PREFIX = ["def", "res"]
 
     def __init__(self, task_pool=None):
         self.mdb = get_mongodb()
@@ -73,7 +76,7 @@ def _get_task_pool(self, task_pool=None):
         if task_pool is None:
             task_pool = self.task_pool
         if task_pool is None:
-            raise ValueError('You must specify a task pool.')
+            raise ValueError("You must specify a task pool.")
         if isinstance(task_pool, str):
             return getattr(self.mdb, task_pool)
         return task_pool
@@ -85,11 +88,11 @@ def replace_task(self, task, new_task, task_pool=None):
         # The assumption here: everything returned by the interface is decoded; everything inside it is encoded
         new_task = self._encode_task(new_task)
         task_pool = self._get_task_pool(task_pool)
-        query = {'_id': ObjectId(task['_id'])}
+        query = {"_id": ObjectId(task["_id"])}
         try:
             task_pool.replace_one(query, new_task)
         except InvalidDocument:
-            task['filter'] = self._dict_to_str(task['filter'])
+            task["filter"] = self._dict_to_str(task["filter"])
             task_pool.replace_one(query, new_task)
 
     def insert_task(self, task, task_pool=None):
@@ -97,16 +100,18 @@ def insert_task(self, task, task_pool=None):
         try:
             task_pool.insert_one(task)
         except InvalidDocument:
-            task['filter'] = self._dict_to_str(task['filter'])
+            task["filter"] = self._dict_to_str(task["filter"])
             task_pool.insert_one(task)
 
     def insert_task_def(self, task_def, task_pool=None):
         task_pool = self._get_task_pool(task_pool)
-        task = self._encode_task({
-            'def': task_def,
-            'filter': task_def,  # FIXME: catch the raised error
-            'status': self.STATUS_WAITING,
-        })
+        task = self._encode_task(
+            {
+                "def": task_def,
+                "filter": task_def,  # FIXME: catch the raised error
+                "status": self.STATUS_WAITING,
+            }
+        )
         self.insert_task(task, task_pool)
 
     def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False):
@@ -114,9 +119,9 @@ def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False)
         new_tasks = []
         for t in task_def_l:
             try:
-                r = task_pool.find_one({'filter': t})
+                r = task_pool.find_one({"filter": t})
             except InvalidDocument:
-                r = task_pool.find_one({'filter': self._dict_to_str(t)})
+                r = task_pool.find_one({"filter": self._dict_to_str(t)})
             if r is None:
                 new_tasks.append(t)
         print("Total Tasks, New Tasks:", len(task_def_l), len(new_tasks))
@@ -134,17 +139,16 @@ def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False)
     def fetch_task(self, query={}, task_pool=None):
         task_pool = self._get_task_pool(task_pool)
         query = query.copy()
-        if '_id' in query:
-            query['_id'] = ObjectId(query['_id'])
-        query.update({'status': self.STATUS_WAITING})
-        task = task_pool.find_one_and_update(query, {'$set': {
-            'status': self.STATUS_RUNNING
-        }},
-                                             sort=[('priority', pymongo.DESCENDING)])
+        if "_id" in query:
+            query["_id"] = ObjectId(query["_id"])
+        query.update({"status": self.STATUS_WAITING})
+        task = task_pool.find_one_and_update(
+            query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
+        )
         # priority must mean "a higher number is a higher priority" here, because nulls are sorted first under ASCENDING
         if task is None:
             return None
-        task['status'] = self.STATUS_RUNNING
+        task["status"] = self.STATUS_RUNNING
         return self._decode_task(task)
 
     @contextmanager
@@ -154,9 +158,9 @@ def safe_fetch_task(self, query={}, task_pool=None):
             yield task
         except Exception:
             if task is not None:
-                logger.info('Returning task before raising error')
+                logger.info("Returning task before raising error")
                 self.return_task(task)
-                logger.info('Task returned')
+                logger.info("Task returned")
             raise
 
     def task_fetcher_iter(self, query={}, task_pool=None):
@@ -175,8 +179,8 @@ def query(self, query={}, decode=True, task_pool=None):
         :param task_pool:
         """
         query = query.copy()
-        if '_id' in query:
-            query['_id'] = ObjectId(query['_id'])
+        if "_id" in query:
+            query["_id"] = ObjectId(query["_id"])
         task_pool = self._get_task_pool(task_pool)
         for t in task_pool.find(query):
             yield self._decode_task(t)
 
@@ -186,44 +190,44 @@ def commit_task_res(self, task, res, status=None, task_pool=None):
         # A workaround to use the class attribute.
         if status is None:
             status = TaskManager.STATUS_DONE
-        task_pool.update_one({"_id": task['_id']}, {'$set': {'status': status, 'res': Binary(pickle.dumps(res))}})
+        task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
 
     def return_task(self, task, status=None, task_pool=None):
         task_pool = self._get_task_pool(task_pool)
         if status is None:
             status = TaskManager.STATUS_WAITING
-        update_dict = {'$set': {'status': status}}
-        task_pool.update_one({"_id": task['_id']}, update_dict)
+        update_dict = {"$set": {"status": status}}
+        task_pool.update_one({"_id": task["_id"]}, update_dict)
 
     def remove(self, query={}, task_pool=None):
         query = query.copy()
         task_pool = self._get_task_pool(task_pool)
-        if '_id' in query:
-            query['_id'] = ObjectId(query['_id'])
+        if "_id" in query:
+            query["_id"] = ObjectId(query["_id"])
         task_pool.delete_many(query)
 
     def task_stat(self, query={}, task_pool=None):
         query = query.copy()
-        if '_id' in query:
-            query['_id'] = ObjectId(query['_id'])
+        if "_id" in query:
+            query["_id"] = ObjectId(query["_id"])
         tasks = self.query(task_pool=task_pool, query=query, decode=False)
         status_stat = {}
         for t in tasks:
-            status_stat[t['status']] = status_stat.get(t['status'], 0) + 1
+            status_stat[t["status"]] = status_stat.get(t["status"], 0) + 1
         return status_stat
 
     def reset_waiting(self, query={}, task_pool=None):
         query = query.copy()
         # default query
-        if 'status' not in query:
-            query['status'] = self.STATUS_RUNNING
+        if "status" not in query:
+            query["status"] = self.STATUS_RUNNING
         return self.reset_status(query=query, status=self.STATUS_WAITING, task_pool=task_pool)
 
     def reset_status(self, query, status, task_pool=None):
         query = query.copy()
         task_pool = self._get_task_pool(task_pool)
-        if '_id' in query:
-            query['_id'] = ObjectId(query['_id'])
+        if "_id" in query:
+            query["_id"] = ObjectId(query["_id"])
         print(task_pool.update_many(query, {"$set": {"status": status}}))
 
     def _get_undone_n(self, task_stat):
@@ -274,17 +278,18 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
         with tm.safe_fetch_task() as task:
             if task is None:
                 break
-            logger.info(task['def'])
+            logger.info(task["def"])
            if force_release:
                 with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
-                    res = executor.submit(task_func, task['def'], *args, **kwargs).result()
+                    res = executor.submit(task_func, task["def"], *args, **kwargs).result()
             else:
-                res = task_func(task['def'], *args, **kwargs)
+                res = task_func(task["def"], *args, **kwargs)
             tm.commit_task_res(task, res)
             ever_run = True
 
     return ever_run
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
+    auto_init()
     Fire(TaskManager)
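Putting `TaskManager` and `run_task` together, a hedged end-to-end sketch (the pool name, MongoDB URL, and `train` body are hypothetical; a reachable MongoDB server and a configured `C["mongo"]`, as read by `get_mongodb` below, are required):

```python
import qlib
from qlib.config import C
from qlib.workflow.task.manage import TaskManager, run_task

qlib.auto_init()
# hypothetical MongoDB settings; get_mongodb() reads these two keys
C["mongo"] = {"task_url": "mongodb://localhost:27017/", "task_db_name": "rolling_db"}

tm = TaskManager(task_pool="my_pool")
tm.insert_task_def({"model": {"class": "LGBModel"}})  # enqueued as STATUS_WAITING


def train(task_def):
    # run the task definition and return a picklable result
    # (results are stored with Binary(pickle.dumps(res)))
    return {"score": 0.0}


run_task(train, task_pool="my_pool")  # fetch -> run -> commit_task_res
print(tm.task_stat())                 # e.g. {"done": 1}
```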
diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py
index 3d8fe89964..d6089ff66d 100644
--- a/qlib/workflow/task/utils.py
+++ b/qlib/workflow/task/utils.py
@@ -10,17 +10,18 @@
 
 def get_mongodb():
     try:
-        cfg = C['mongo']
+        cfg = C["mongo"]
     except KeyError:
         get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager")
         raise
 
-    client = MongoClient(cfg['task_url'])
-    return client.get_database(name=cfg['task_db_name'])
+    client = MongoClient(cfg["task_url"])
+    return client.get_database(name=cfg["task_db_name"])
 
 
 class TimeAdjuster:
-    '''找到合适的日期,然后adjust date'''
+    """Find a suitable date, then adjust the date."""
+
     def __init__(self, future=False):
         self.cals = D.calendar(future=future)
 
@@ -45,9 +46,9 @@ def max(self):
 
     def align_idx(self, time_point, tp_type="start"):
         time_point = pd.Timestamp(time_point)
-        if tp_type == 'start':
== "start": idx = bisect.bisect_left(self.cals, time_point) - elif tp_type == 'end': + elif tp_type == "end": idx = bisect.bisect_right(self.cals, time_point) - 1 else: raise NotImplementedError(f"This type of input is not supported") @@ -91,7 +92,7 @@ def truncate(self, segment, test_start, days: int): new_seg = [] for time_point in segment: tp_idx = min(self.align_idx(time_point), test_idx - days) - assert (tp_idx > 0) + assert tp_idx > 0 new_seg.append(self.get(tp_idx)) return tuple(new_seg) else: