diff --git a/README.md b/README.md index 28ff004b39..3e2faa94da 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ Recent released features | Feature | Status | | -- | ------ | +| Point-in-Time database | :hammer: [Released](https://github.com/microsoft/qlib/pull/343) on Mar 10, 2022 | | Arctic Provider Backend & Orderbook data example | :hammer: [Released](https://github.com/microsoft/qlib/pull/744) on Jan 17, 2022 | | Meta-Learning-based framework & DDG-DA | :chart_with_upwards_trend: :hammer: [Released](https://github.com/microsoft/qlib/pull/743) on Jan 10, 2022 | | Planning-based portfolio optimization | :hammer: [Released](https://github.com/microsoft/qlib/pull/754) on Dec 28, 2021 | @@ -95,9 +97,8 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative # Plans New features under development (ordered by estimated release time). Your feedback about these features is very important. -| Feature | Status | -| -- | ------ | -| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 | + + # Framework of Qlib diff --git a/docs/advanced/PIT.rst b/docs/advanced/PIT.rst new file mode 100644 index 0000000000..f828a43e45 --- /dev/null +++ b/docs/advanced/PIT.rst @@ -0,0 +1,133 @@ +.. _pit: + +============================ +(P)oint-(I)n-(T)ime Database +============================ +.. currentmodule:: qlib + + +Introduction +------------ +Point-in-time data is a very important consideration when performing any sort of historical market analysis. +For example, let's say we are backtesting a trading strategy and we are using the past five years of historical data as our input. +Our model is assumed to trade once a day, at the market close, and we'll say we are calculating the trading signal for 1 January 2020 in our backtest. At that point, we should only have data for 1 January 2020, 31 December 2019, 30 December 2019, etc. +In financial data (especially financial reports), the same piece of data may be amended multiple times over time. If we only use the latest version for historical backtesting, data leakage will happen. +The point-in-time database is designed to solve this problem and to make sure users get the right version of data at any historical timestamp. It keeps the results of online trading and historical backtesting consistent. + + + +Data Preparation +---------------- + +Qlib provides a crawler to help users download financial data, and a converter to dump the data into Qlib format. +Please follow `scripts/data_collector/pit/README.md` to download and convert the data. + + +File-based design for PIT data +------------------------------ + +Qlib provides a file-based storage for PIT data. + +For each feature, the storage contains 4 columns, i.e. date, period, value, _next. +Each row corresponds to a statement. + +The meaning of each column in a file with a name like `XXX_a.data`: - `date`: the statement's date of publication. - `period`: the period covered by the statement. (e.g. it will be a quarterly frequency in most of the markets) - If it is an annual period, it will be an integer corresponding to the year - If it is a quarterly period, it will be an integer like `201902`: the last two decimal digits represent the index of the quarter, while the other digits represent the year. - `value`: the described value - `_next`: the byte index of the next occurrence (i.e. revision) of the field; the value `4294967295` (`0xFFFFFFFF`) means there is no further revision.
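These records are fixed-width binary rows whose field formats mirror ``pit_record_type`` in ``qlib/config.py`` (``date``, ``period`` and ``_next`` are uint32, ``value`` is float64), so a file can be loaded directly with a packed numpy structured dtype. A minimal sketch for inspecting such a file (the path below follows the ``financial/<instrument>/<field>.data`` layout used by ``LocalPITProvider``; the ``sh600519``/``roewa_q`` names are just examples and assume the PIT data has already been dumped):

.. code-block:: python

    from pathlib import Path

    import numpy as np

    # packed little-endian layout: uint32 date, uint32 period, float64 value, uint32 _next
    PIT_DTYPE = np.dtype([("date", "<u4"), ("period", "<u4"), ("value", "<f8"), ("_next", "<u4")])

    path = Path("~/.qlib/qlib_data/cn_data/financial/sh600519/roewa_q.data").expanduser()
    data = np.fromfile(path, dtype=PIT_DTYPE)  # each record is PIT_DTYPE.itemsize == 20 bytes
    print(data[:3])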
+ +Besides the feature data, an index file `XXX_a.index` is included to speed up the query performance. + +The statements are sorted by `date` in ascending order from the beginning of the file. + +.. code-block:: python + + # the data format from XXXX.data + array([(20070428, 200701, 0.090219 , 4294967295), + (20070817, 200702, 0.13933 , 4294967295), + (20071023, 200703, 0.24586301, 4294967295), + (20080301, 200704, 0.3479 , 80), + (20080313, 200704, 0.395989 , 4294967295), + (20080422, 200801, 0.100724 , 4294967295), + (20080828, 200802, 0.24996801, 4294967295), + (20081027, 200803, 0.33412001, 4294967295), + (20090325, 200804, 0.39011699, 4294967295), + (20090421, 200901, 0.102675 , 4294967295), + (20090807, 200902, 0.230712 , 4294967295), + (20091024, 200903, 0.30072999, 4294967295), + (20100402, 200904, 0.33546099, 4294967295), + (20100426, 201001, 0.083825 , 4294967295), + (20100812, 201002, 0.200545 , 4294967295), + (20101029, 201003, 0.260986 , 4294967295), + (20110321, 201004, 0.30739301, 4294967295), + (20110423, 201101, 0.097411 , 4294967295), + (20110831, 201102, 0.24825101, 4294967295), + (20111018, 201103, 0.318919 , 4294967295), + (20120323, 201104, 0.4039 , 420), + (20120411, 201104, 0.403925 , 4294967295), + (20120426, 201201, 0.112148 , 4294967295), + (20120810, 201202, 0.26484701, 4294967295), + (20121026, 201203, 0.370487 , 4294967295), + (20130329, 201204, 0.45004699, 4294967295), + (20130418, 201301, 0.099958 , 4294967295), + (20130831, 201302, 0.21044201, 4294967295), + (20131016, 201303, 0.30454299, 4294967295), + (20140325, 201304, 0.394328 , 4294967295), + (20140425, 201401, 0.083217 , 4294967295), + (20140829, 201402, 0.16450299, 4294967295), + (20141030, 201403, 0.23408499, 4294967295), + (20150421, 201404, 0.319612 , 4294967295), + (20150421, 201501, 0.078494 , 4294967295), + (20150828, 201502, 0.137504 , 4294967295), + (20151023, 201503, 0.201709 , 4294967295), + (20160324, 201504, 0.26420501, 4294967295), + (20160421, 201601, 0.073664 , 4294967295), + (20160827, 201602, 0.136576 , 4294967295), + (20161029, 201603, 0.188062 , 4294967295), + (20170415, 201604, 0.244385 , 4294967295), + (20170425, 201701, 0.080614 , 4294967295), + (20170728, 201702, 0.15151 , 4294967295), + (20171026, 201703, 0.25416601, 4294967295), + (20180328, 201704, 0.32954201, 4294967295), + (20180428, 201801, 0.088887 , 4294967295), + (20180802, 201802, 0.170563 , 4294967295), + (20181029, 201803, 0.25522 , 4294967295), + (20190329, 201804, 0.34464401, 4294967295), + (20190425, 201901, 0.094737 , 4294967295), + (20190713, 201902, 0. , 1040), + (20190718, 201902, 0.175322 , 4294967295), + (20191016, 201903, 0.25581899, 4294967295)], + dtype=[('date', '<u4'), ('period', '<u4'), ('value', '<f8'), ('_next', '<u4')])
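The ``_next`` column chains together successive revisions of the same period. In the sample above, the 200704 row published on 20080301 has ``_next == 80``: since each packed record is 20 bytes, byte offset 80 is the fifth record, i.e. the amended 200704 value published on 20080313. Below is a simplified sketch of how a query date resolves a period's value by walking this chain (``data`` is the array loaded in the earlier sketch; the real implementation is ``read_period_data`` in ``qlib/utils``, and the starting byte offset normally comes from the ``.index`` file rather than being hard-coded):

.. code-block:: python

    NAN_INDEX = 0xFFFFFFFF  # sentinel meaning "no further revision"

    def value_as_of(data, first_byte_index, cur_date_int):
        """Follow the `_next` chain from the record at `first_byte_index`,
        keeping the latest revision published on or before `cur_date_int`."""
        value = float("nan")
        byte_idx = first_byte_index
        while byte_idx != NAN_INDEX:
            rec = data[byte_idx // data.dtype.itemsize]  # byte offset -> record number
            if rec["date"] > cur_date_int:
                break
            value = rec["value"]
            byte_idx = rec["_next"]
        return value

    # period 200704 before and after its 2008-03-13 amendment
    # (its first record sits at byte 60, i.e. the fourth record)
    print(value_as_of(data, 60, 20080305))  # 0.3479
    print(value_as_of(data, 60, 20080315))  # 0.395989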
diff --git a/docs/index.rst b/docs/index.rst Serialization Task Management + Point-In-Time database .. toctree:: :maxdepth: 3 diff --git a/qlib/config.py b/qlib/config.py index 50b430bb45..bee1811338 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -92,6 +92,7 @@ def set_conf_from_C(self, config_c): "calendar_provider": "LocalCalendarProvider", "instrument_provider": "LocalInstrumentProvider", "feature_provider": "LocalFeatureProvider", + "pit_provider": "LocalPITProvider", "expression_provider": "LocalExpressionProvider", "dataset_provider": "LocalDatasetProvider", "provider": "LocalProvider", @@ -108,7 +109,6 @@ def set_conf_from_C(self, config_c): "provider_uri": "", # cache "expression_cache": None, - "dataset_cache": None, "calendar_cache": None, # for simple dataset cache "local_cache_path": None, @@ -171,6 +171,18 @@ def set_conf_from_C(self, config_c): "default_exp_name": "Experiment", }, }, + "pit_record_type": { + "date": "I", # uint32 + "period": "I", # uint32 + "value": "d", # float64 + "index": "I", # uint32 + }, + "pit_record_nan": { + "date": 0, + "period": 0, + "value": float("NAN"), + "index": 0xFFFFFFFF, + }, # Default config for MongoDB "mongo": { "task_url": "mongodb://localhost:27017/", @@ -184,20 +196,12 @@ MODE_CONF = { "server": { - # data provider config - "calendar_provider": "LocalCalendarProvider", - "instrument_provider": "LocalInstrumentProvider", - "feature_provider": "LocalFeatureProvider", - "expression_provider": "LocalExpressionProvider", - "dataset_provider": "LocalDatasetProvider", - "provider": "LocalProvider", # config it in qlib.init() "provider_uri": "", # redis "redis_host": "127.0.0.1", "redis_port": 6379, "redis_task_db": 1, - "kernels": NUM_USABLE_CPU, # cache "expression_cache": DISK_EXPRESSION_CACHE, "dataset_cache": DISK_DATASET_CACHE, "mount_path": None, }, "client": { - # data provider config - "calendar_provider": "LocalCalendarProvider", - "instrument_provider": "LocalInstrumentProvider", - "feature_provider": "LocalFeatureProvider", - "expression_provider": "LocalExpressionProvider", - "dataset_provider": "LocalDatasetProvider", - "provider": "LocalProvider", # config it in user's own code "provider_uri": "~/.qlib/qlib_data/cn_data", # cache # Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled. # Disable cache by default. Avoid introduce advanced features for beginners - "expression_cache": None, "dataset_cache": None, # SimpleDatasetCache directory "local_cache_path": Path("~/.cache/qlib_simple_cache").expanduser().resolve(), - "calendar_cache": None, # client config - "kernels": NUM_USABLE_CPU, "mount_path": None, "auto_mount": False, # The nfs is already mounted on our server[auto_mount: False]. 
# The nfs should be auto-mounted by qlib on other diff --git a/qlib/data/__init__.py b/qlib/data/__init__.py index ef5fe4708e..4baf6c72a0 100644 --- a/qlib/data/__init__.py +++ b/qlib/data/__init__.py @@ -15,6 +15,7 @@ LocalCalendarProvider, LocalInstrumentProvider, LocalFeatureProvider, + LocalPITProvider, LocalExpressionProvider, LocalDatasetProvider, ClientCalendarProvider, diff --git a/qlib/data/base.py b/qlib/data/base.py index bbadf68a3d..427c15e3ca 100644 --- a/qlib/data/base.py +++ b/qlib/data/base.py @@ -6,12 +6,20 @@ from __future__ import print_function import abc - +import pandas as pd from ..log import get_module_logger class Expression(abc.ABC): - """Expression base class""" + """ + Expression base class + + Expression is designed to handle the calculation of data in the format below: + data with two dimensions for each instrument, + - feature + - time: it could be observation time or period time. + - period time is designed for the Point-in-Time database. For example, the period time may be 2014Q4, and its value can be observed multiple times (different values may be observed at different times due to amendments). + """ def __str__(self): return type(self).__name__ @@ -124,8 +132,18 @@ def __ror__(self, other): return Or(other, self) - def load(self, instrument, start_index, end_index, freq): + def load(self, instrument, start_index, end_index, *args): """load feature + This function is responsible for loading the feature/expression based on the expression engine. + + The concrete implementation is separated into two parts: + 1) caching data, handling errors. + - This part is shared by all the expressions and implemented in Expression + 2) processing and calculating data based on the specific expression. + - This part is different in each expression and implemented in each expression + + The expression engine is shared by different data. + Different data will carry different extra information in `args`. Parameters ---------- instrument : str instrument code. start_index : str feature start index [in calendar]. end_index : str feature end index [in calendar]. - freq : str - feature frequency. + + *args may contain the following information: + 1) if it is used in the basic expression engine, it contains the following argument + freq : str + feature frequency. + + 2) if it is used in PIT data, it contains the following argument + cur_pit: + it is designed for point-in-time data. Returns ---------- pd.Series feature series: The index is the calendar index """ from .cache import H # pylint: disable=C0415 # cache - args = str(self), instrument, start_index, end_index, freq - if args in H["f"]: - return H["f"][args] + cache_key = str(self), instrument, start_index, end_index, *args + if cache_key in H["f"]: + return H["f"][cache_key] if start_index is not None and end_index is not None and start_index > end_index: raise ValueError("Invalid index range: {} {}".format(start_index, end_index)) try: - series = self._load_internal(instrument, start_index, end_index, freq) + series = self._load_internal(instrument, start_index, end_index, *args) except Exception as e: get_module_logger("data").debug( f"Loading data error: instrument={instrument}, expression={str(self)}, " - f"start_index={start_index}, end_index={end_index}, freq={freq}. " + f"start_index={start_index}, end_index={end_index}, args={args}. " f"error info: {str(e)}" ) raise series.name = str(self) - H["f"][args] = series + H["f"][cache_key] = series return series @abc.abstractmethod - def _load_internal(self, instrument, start_index, end_index, freq): + def _load_internal(self, instrument, start_index, end_index, *args) -> pd.Series: raise NotImplementedError("This function must be implemented in your newly defined feature") @abc.abstractmethod @@ -225,6 +250,16 @@ def get_extended_window_size(self): return 0, 0 +class PFeature(Feature): + def __str__(self): + return "$$" + self._name + + def _load_internal(self, instrument, start_index, end_index, cur_time): + from .data import PITD # pylint: disable=C0415 + + return PITD.period_feature(instrument, str(self), start_index, end_index, cur_time)
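The `$$` prefix is how an expression string reaches `PFeature`: `parse_field` in `qlib/utils/__init__.py` (see the diff further down) rewrites `$$name` to `PFeature("name")` before `$name` is rewritten to `Feature("name")`, and both classes are registered in `OpsList`, so the final `Operators.*` lookups resolve. A standalone copy of the rewrite rules to illustrate (the output string is what the real `parse_field` would produce under these rules):

```python
import re

# the substitution rules added in qlib/utils/__init__.py ($$ must be handled before $)
RULES = [
    (r"\$\$(\w+)", r'PFeature("\1")'),
    (r"\$(\w+)", r'Feature("\1")'),
    (r"(\w+\s*)\(", r"Operators.\1("),
]

def parse_field(field: str) -> str:
    for pattern, new in RULES:
        field = re.sub(pattern, new, field)
    return field

print(parse_field("P($$roewa_q) / $close"))
# Operators.P(Operators.PFeature("roewa_q")) / Operators.Feature("close")
```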
" f"error info: {str(e)}" ) raise series.name = str(self) - H["f"][args] = series + H["f"][cache_key] = series return series @abc.abstractmethod - def _load_internal(self, instrument, start_index, end_index, freq): + def _load_internal(self, instrument, start_index, end_index, *args) -> pd.Series: raise NotImplementedError("This function must be implemented in your newly defined feature") @abc.abstractmethod @@ -225,6 +250,16 @@ def get_extended_window_size(self): return 0, 0 +class PFeature(Feature): + def __str__(self): + return "$$" + self._name + + def _load_internal(self, instrument, start_index, end_index, cur_time): + from .data import PITD # pylint: disable=C0415 + + return PITD.period_feature(instrument, str(self), start_index, end_index, cur_time) + + class ExpressionOps(Expression): """Operator Expression diff --git a/qlib/data/data.py b/qlib/data/data.py index 8080eb66ca..cd8f7f77f6 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -34,6 +34,8 @@ code_to_fname, set_log_with_config, time_to_slc_point, + read_period_data, + get_period_list, ) from ..utils.paral import ParallelExt from .ops import Operators # pylint: disable=W0611 @@ -331,6 +333,38 @@ def feature(self, instrument, field, start_time, end_time, freq): raise NotImplementedError("Subclass of FeatureProvider must implement `feature` method") +class PITProvider(abc.ABC): + @abc.abstractmethod + def period_feature(self, instrument, field, start_index: int, end_index: int, cur_time: pd.Timestamp) -> pd.Series: + """ + get the historical periods data series between `start_index` and `end_index` + + Parameters + ---------- + start_index: int + start_index is a relative index to the latest period to cur_time + + end_index: int + end_index is a relative index to the latest period to cur_time + in most cases, the start_index and end_index will be a non-positive values + For example, start_index == -3 end_index == 0 and current period index is cur_idx, + then the data between [start_index + cur_idx, end_index + cur_idx] will be retrieved. + + Returns + ------- + pd.Series + The index will be integers to indicate the periods of the data + An typical examples will be + TODO + + Raises + ------ + FileNotFoundError + This exception will be raised if the queried data do not exist. + """ + raise NotImplementedError(f"Please implement the `period_feature` method") + + class ExpressionProvider(abc.ABC): """Expression provider class @@ -694,6 +728,89 @@ def feature(self, instrument, field, start_index, end_index, freq): return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1] +class LocalPITProvider(PITProvider): + # TODO: Add PIT backend file storage + # NOTE: This class is not multi-threading-safe!!!! + + def period_feature(self, instrument, field, start_index, end_index, cur_time): + if not isinstance(cur_time, pd.Timestamp): + raise ValueError( + f"Expected pd.Timestamp for `cur_time`, got '{cur_time}'. Advices: you can't query PIT data directly(e.g. '$$roewa_q'), you must use `P` operator to convert data to each day (e.g. 
'P($$roewa_q)')" + ) + + assert end_index <= 0 # PIT don't support querying future data + + DATA_RECORDS = [ + ("date", C.pit_record_type["date"]), + ("period", C.pit_record_type["period"]), + ("value", C.pit_record_type["value"]), + ("_next", C.pit_record_type["index"]), + ] + VALUE_DTYPE = C.pit_record_type["value"] + + field = str(field).lower()[2:] + instrument = code_to_fname(instrument) + + # {For acceleration + # start_index, end_index, cur_index = kwargs["info"] + # if cur_index == start_index: + # if not hasattr(self, "all_fields"): + # self.all_fields = [] + # self.all_fields.append(field) + # if not hasattr(self, "period_index"): + # self.period_index = {} + # if field not in self.period_index: + # self.period_index[field] = {} + # For acceleration} + + if not field.endswith("_q") and not field.endswith("_a"): + raise ValueError("period field must ends with '_q' or '_a'") + quarterly = field.endswith("_q") + index_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.index" + data_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.data" + if not (index_path.exists() and data_path.exists()): + raise FileNotFoundError("No file is found. Raise exception and ") + # NOTE: The most significant performance loss is here. + # Does the accelration that makes the program complicated really matters? + # - It make parameters parameters of the interface complicate + # - It does not performance in the optimal way (places all the pieces together, we may achieve higher performance) + # - If we design it carefully, we can go through for only once to get the historical evolution of the data. + # So I decide to deprecated previous implementation and keep the logic of the program simple + # Instead, I'll add a cache for the index file. + data = np.fromfile(data_path, dtype=DATA_RECORDS) + + # find all revision periods before `cur_time` + cur_time_int = int(cur_time.year) * 10000 + int(cur_time.month) * 100 + int(cur_time.day) + loc = np.searchsorted(data["date"], cur_time_int, side="right") + if loc <= 0: + return pd.Series() + last_period = data["period"][:loc].max() # return the latest quarter + first_period = data["period"][:loc].min() + + period_list = get_period_list(first_period, last_period, quarterly) + period_list = period_list[max(0, len(period_list) + start_index - 1) : len(period_list) + end_index] + value = np.full((len(period_list),), np.nan, dtype=VALUE_DTYPE) + for i, period in enumerate(period_list): + # last_period_index = self.period_index[field].get(period) # For acceleration + value[i], now_period_index = read_period_data( + index_path, data_path, period, cur_time_int, quarterly # , last_period_index # For acceleration + ) + # self.period_index[field].update({period: now_period_index}) # For acceleration + # NOTE: the index is period_list; So it may result in unexpected values(e.g. nan) + # when calculation between different features and only part of its financial indicator is published + series = pd.Series(value, index=period_list, dtype=VALUE_DTYPE) + + # {For acceleration + # if cur_index == end_index: + # self.all_fields.remove(field) + # if not len(self.all_fields): + # del self.all_fields + # del self.period_index + # For acceleration} + + return series + + class LocalExpressionProvider(ExpressionProvider): """Local expression data provider class @@ -1003,6 +1120,8 @@ def dataset( class BaseProvider: """Local provider class + It is a set of interface that allow users to access data. 
+ Because PITD is not exposed publicly to users, it is not included in the interface, to keep compatible with the old qlib provider. """ @@ -1126,6 +1245,7 @@ def is_instance_of_provider(instance: object, cls: type): CalendarProviderWrapper = Annotated[CalendarProvider, Wrapper] InstrumentProviderWrapper = Annotated[InstrumentProvider, Wrapper] FeatureProviderWrapper = Annotated[FeatureProvider, Wrapper] + PITProviderWrapper = Annotated[PITProvider, Wrapper] ExpressionProviderWrapper = Annotated[ExpressionProvider, Wrapper] DatasetProviderWrapper = Annotated[DatasetProvider, Wrapper] BaseProviderWrapper = Annotated[BaseProvider, Wrapper] @@ -1133,6 +1253,7 @@ def is_instance_of_provider(instance: object, cls: type): CalendarProviderWrapper = CalendarProvider InstrumentProviderWrapper = InstrumentProvider FeatureProviderWrapper = FeatureProvider + PITProviderWrapper = PITProvider ExpressionProviderWrapper = ExpressionProvider DatasetProviderWrapper = DatasetProvider BaseProviderWrapper = BaseProvider @@ -1140,6 +1261,7 @@ def is_instance_of_provider(instance: object, cls: type): Cal: CalendarProviderWrapper = Wrapper() Inst: InstrumentProviderWrapper = Wrapper() FeatureD: FeatureProviderWrapper = Wrapper() +PITD: PITProviderWrapper = Wrapper() ExpressionD: ExpressionProviderWrapper = Wrapper() DatasetD: DatasetProviderWrapper = Wrapper() D: BaseProviderWrapper = Wrapper() @@ -1165,6 +1287,11 @@ def register_all_wrappers(C): register_wrapper(FeatureD, feature_provider, "qlib.data") logger.debug(f"registering FeatureD {C.feature_provider}") + if getattr(C, "pit_provider", None) is not None: + pit_provider = init_instance_by_config(C.pit_provider, module) + register_wrapper(PITD, pit_provider, "qlib.data") + logger.debug(f"registering PITD {C.pit_provider}") + if getattr(C, "expression_provider", None) is not None: # This provider is unnecessary in client provider _eprovider = init_instance_by_config(C.expression_provider, module) diff --git a/qlib/data/ops.py b/qlib/data/ops.py index a0a0668170..bdc032c037 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -10,9 +10,7 @@ from typing import Union, List, Type from scipy.stats import percentileofscore - -from .base import Expression, ExpressionOps, Feature - +from .base import Expression, ExpressionOps, Feature, PFeature from ..log import get_module_logger from ..utils import get_callable_kwargs @@ -84,8 +82,8 @@ def __init__(self, feature, func): self.func = func super(NpElemOperator, self).__init__(feature) - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) return getattr(np, self.func)(series) @@ -124,11 +122,11 @@ class Sign(NpElemOperator): def __init__(self, feature): super(Sign, self).__init__(feature, "sign") - def _load_internal(self, instrument, start_index, end_index, freq): + def _load_internal(self, instrument, start_index, end_index, *args): """ To avoid errors raised by bool-type input, we transform the data into float32. 
""" - series = self.feature.load(instrument, start_index, end_index, freq) + series = self.feature.load(instrument, start_index, end_index, *args) # TODO: More precision types should be configurable series = series.astype(np.float32) return getattr(np, self.func)(series) @@ -173,8 +171,8 @@ def __init__(self, feature, exponent): def __str__(self): return "{}({},{})".format(type(self).__name__, self.feature, self.exponent) - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) return getattr(np, self.func)(series, self.exponent) @@ -201,8 +199,8 @@ def __init__(self, feature, instrument): def __str__(self): return "{}({},{})".format(type(self).__name__, self.feature, self.instrument.lower()) - def _load_internal(self, instrument, start_index, end_index, freq): - return self.feature.load(self.instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + return self.feature.load(self.instrument, start_index, end_index, *args) class Not(NpElemOperator): @@ -252,24 +250,24 @@ def __str__(self): return "{}({},{})".format(type(self).__name__, self.feature_left, self.feature_right) def get_longest_back_rolling(self): - if isinstance(self.feature_left, Expression): + if isinstance(self.feature_left, (Expression,)): left_br = self.feature_left.get_longest_back_rolling() else: left_br = 0 - if isinstance(self.feature_right, Expression): + if isinstance(self.feature_right, (Expression,)): right_br = self.feature_right.get_longest_back_rolling() else: right_br = 0 return max(left_br, right_br) def get_extended_window_size(self): - if isinstance(self.feature_left, Expression): + if isinstance(self.feature_left, (Expression,)): ll, lr = self.feature_left.get_extended_window_size() else: ll, lr = 0, 0 - if isinstance(self.feature_right, Expression): + if isinstance(self.feature_right, (Expression,)): rl, rr = self.feature_right.get_extended_window_size() else: rl, rr = 0, 0 @@ -298,16 +296,16 @@ def __init__(self, feature_left, feature_right, func): self.func = func super(NpPairOperator, self).__init__(feature_left, feature_right) - def _load_internal(self, instrument, start_index, end_index, freq): + def _load_internal(self, instrument, start_index, end_index, *args): assert any( - [isinstance(self.feature_left, Expression), self.feature_right, Expression] + [isinstance(self.feature_left, (Expression,)), self.feature_right, Expression] ), "at least one of two inputs is Expression instance" - if isinstance(self.feature_left, Expression): - series_left = self.feature_left.load(instrument, start_index, end_index, freq) + if isinstance(self.feature_left, (Expression,)): + series_left = self.feature_left.load(instrument, start_index, end_index, *args) else: series_left = self.feature_left # numeric value - if isinstance(self.feature_right, Expression): - series_right = self.feature_right.load(instrument, start_index, end_index, freq) + if isinstance(self.feature_right, (Expression,)): + series_right = self.feature_right.load(instrument, start_index, end_index, *args) else: series_right = self.feature_right check_length = isinstance(series_left, (np.ndarray, pd.Series)) and isinstance( @@ -637,48 +635,48 @@ def __init__(self, condition, feature_left, feature_right): def __str__(self): return "If({},{},{})".format(self.condition, 
self.feature_left, self.feature_right) - def _load_internal(self, instrument, start_index, end_index, freq): - series_cond = self.condition.load(instrument, start_index, end_index, freq) - if isinstance(self.feature_left, Expression): - series_left = self.feature_left.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series_cond = self.condition.load(instrument, start_index, end_index, *args) + if isinstance(self.feature_left, (Expression,)): + series_left = self.feature_left.load(instrument, start_index, end_index, *args) else: series_left = self.feature_left - if isinstance(self.feature_right, Expression): - series_right = self.feature_right.load(instrument, start_index, end_index, freq) + if isinstance(self.feature_right, (Expression,)): + series_right = self.feature_right.load(instrument, start_index, end_index, *args) else: series_right = self.feature_right series = pd.Series(np.where(series_cond, series_left, series_right), index=series_cond.index) return series def get_longest_back_rolling(self): - if isinstance(self.feature_left, Expression): + if isinstance(self.feature_left, (Expression,)): left_br = self.feature_left.get_longest_back_rolling() else: left_br = 0 - if isinstance(self.feature_right, Expression): + if isinstance(self.feature_right, (Expression,)): right_br = self.feature_right.get_longest_back_rolling() else: right_br = 0 - if isinstance(self.condition, Expression): + if isinstance(self.condition, (Expression,)): c_br = self.condition.get_longest_back_rolling() else: c_br = 0 return max(left_br, right_br, c_br) def get_extended_window_size(self): - if isinstance(self.feature_left, Expression): + if isinstance(self.feature_left, (Expression,)): ll, lr = self.feature_left.get_extended_window_size() else: ll, lr = 0, 0 - if isinstance(self.feature_right, Expression): + if isinstance(self.feature_right, (Expression,)): rl, rr = self.feature_right.get_extended_window_size() else: rl, rr = 0, 0 - if isinstance(self.condition, Expression): + if isinstance(self.condition, (Expression,)): cl, cr = self.condition.get_extended_window_size() else: cl, cr = 0, 0 @@ -719,8 +717,8 @@ def __init__(self, feature, N, func): def __str__(self): return "{}({},{})".format(type(self).__name__, self.feature, self.N) - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) # NOTE: remove all null check, # now it's user's responsibility to decide whether use features in null days # isnull = series.isnull() # NOTE: isnull = NaN, inf is not null @@ -777,8 +775,8 @@ class Ref(Rolling): def __init__(self, feature, N): super(Ref, self).__init__(feature, N, "ref") - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) # N = 0, return first day if series.empty: return series # Pandas bug, see: https://github.com/pandas-dev/pandas/issues/21049 @@ -967,8 +965,8 @@ class IdxMax(Rolling): def __init__(self, feature, N): super(IdxMax, self).__init__(feature, N, "idxmax") - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, 
start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) if self.N == 0: series = series.expanding(min_periods=1).apply(lambda x: x.argmax() + 1, raw=True) else: @@ -1015,8 +1013,8 @@ class IdxMin(Rolling): def __init__(self, feature, N): super(IdxMin, self).__init__(feature, N, "idxmin") - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) if self.N == 0: series = series.expanding(min_periods=1).apply(lambda x: x.argmin() + 1, raw=True) else: @@ -1047,8 +1045,8 @@ def __init__(self, feature, N, qscore): def __str__(self): return "{}({},{},{})".format(type(self).__name__, self.feature, self.N, self.qscore) - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) if self.N == 0: series = series.expanding(min_periods=1).quantile(self.qscore) else: @@ -1095,8 +1093,8 @@ class Mad(Rolling): def __init__(self, feature, N): super(Mad, self).__init__(feature, N, "mad") - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) # TODO: implement in Cython def mad(x): @@ -1129,8 +1127,8 @@ class Rank(Rolling): def __init__(self, feature, N): super(Rank, self).__init__(feature, N, "rank") - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) # TODO: implement in Cython def rank(x): @@ -1187,8 +1185,8 @@ class Delta(Rolling): def __init__(self, feature, N): super(Delta, self).__init__(feature, N, "delta") - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) if self.N == 0: series = series - series.iloc[0] else: @@ -1225,8 +1223,8 @@ class Slope(Rolling): def __init__(self, feature, N): super(Slope, self).__init__(feature, N, "slope") - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) if self.N == 0: series = pd.Series(expanding_slope(series.values), index=series.index) else: @@ -1253,8 +1251,8 @@ class Rsquare(Rolling): def __init__(self, feature, N): super(Rsquare, self).__init__(feature, N, "rsquare") - def _load_internal(self, instrument, start_index, end_index, freq): - _series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + _series = 
self.feature.load(instrument, start_index, end_index, *args) if self.N == 0: series = pd.Series(expanding_rsquare(_series.values), index=_series.index) else: @@ -1282,8 +1280,8 @@ class Resi(Rolling): def __init__(self, feature, N): super(Resi, self).__init__(feature, N, "resi") - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) if self.N == 0: series = pd.Series(expanding_resi(series.values), index=series.index) else: @@ -1310,8 +1308,8 @@ class WMA(Rolling): def __init__(self, feature, N): super(WMA, self).__init__(feature, N, "wma") - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) # TODO: implement in Cython def weighted_mean(x): @@ -1345,8 +1343,8 @@ class EMA(Rolling): def __init__(self, feature, N): super(EMA, self).__init__(feature, N, "ema") - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) def exp_weighted_mean(x): a = 1 - 2 / (1 + len(x)) @@ -1392,17 +1390,17 @@ def __init__(self, feature_left, feature_right, N, func): def __str__(self): return "{}({},{},{})".format(type(self).__name__, self.feature_left, self.feature_right, self.N) - def _load_internal(self, instrument, start_index, end_index, freq): + def _load_internal(self, instrument, start_index, end_index, *args): assert any( [isinstance(self.feature_left, Expression), self.feature_right, Expression] ), "at least one of two inputs is Expression instance" if isinstance(self.feature_left, Expression): - series_left = self.feature_left.load(instrument, start_index, end_index, freq) + series_left = self.feature_left.load(instrument, start_index, end_index, *args) else: series_left = self.feature_left # numeric value if isinstance(self.feature_right, Expression): - series_right = self.feature_right.load(instrument, start_index, end_index, freq) + series_right = self.feature_right.load(instrument, start_index, end_index, *args) else: series_right = self.feature_right @@ -1465,12 +1463,12 @@ class Corr(PairRolling): def __init__(self, feature_left, feature_right, N): super(Corr, self).__init__(feature_left, feature_right, N, "corr") - def _load_internal(self, instrument, start_index, end_index, freq): - res: pd.Series = super(Corr, self)._load_internal(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + res: pd.Series = super(Corr, self)._load_internal(instrument, start_index, end_index, *args) # NOTE: Load uses MemCache, so calling load again will not cause performance degradation - series_left = self.feature_left.load(instrument, start_index, end_index, freq) - series_right = self.feature_right.load(instrument, start_index, end_index, freq) + series_left = self.feature_left.load(instrument, start_index, end_index, *args) + series_right = self.feature_right.load(instrument, start_index, end_index, *args) res.loc[ np.isclose(series_left.rolling(self.N, min_periods=1).std(), 0, 
atol=2e-05) | np.isclose(series_right.rolling(self.N, min_periods=1).std(), 0, atol=2e-05) @@ -1529,8 +1527,8 @@ def __init__(self, feature, freq, func): def __str__(self): return "{}({},{})".format(type(self).__name__, self.feature, self.freq) - def _load_internal(self, instrument, start_index, end_index, freq): - series = self.feature.load(instrument, start_index, end_index, freq) + def _load_internal(self, instrument, start_index, end_index, *args): + series = self.feature.load(instrument, start_index, end_index, *args) if series.empty: return series @@ -1590,6 +1588,7 @@ def _load_internal(self, instrument, start_index, end_index, freq): IdxMin, If, Feature, + PFeature, ] + [TResample] @@ -1622,7 +1621,7 @@ def register(self, ops_list: List[Union[Type[ExpressionOps], dict]]): else: _ops_class = _operator - if not issubclass(_ops_class, Expression): + if not issubclass(_ops_class, (Expression,)): raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class)) if _ops_class.__name__ in self._ops: @@ -1644,8 +1643,10 @@ def register_all_ops(C): """register all operator""" logger = get_module_logger("ops") + from qlib.data.pit import P # pylint: disable=C0415 + Operators.reset() - Operators.register(OpsList) + Operators.register(OpsList + [P]) if getattr(C, "custom_ops", None) is not None: Operators.register(C.custom_ops) diff --git a/qlib/data/pit.py b/qlib/data/pit.py new file mode 100644 index 0000000000..ebe01eaf26 --- /dev/null +++ b/qlib/data/pit.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +""" +Qlib follows the logic below to support the point-in-time database: + +For each stock, the format of its data is <observe_time, feature>. The expression engine supports calculation on data in this format. + +To calculate the feature value f_t at a specific observe time t, data with the format <period_time, feature> will be used. +For example, the average earning of the last 4 quarters (period_time) on 20190719 (observe_time). + +The calculation of both <period_time, feature> and <observe_time, feature> data relies on the expression engine. It consists of 2 phases: +1) calculation at each observation time t, where the <period_time, feature> data will be collapsed into a point (just like a normal feature), and +2) concatenating all the collapsed data, so we get data with the format <observe_time, feature>. +Qlib will use the operator `P` to perform the collapse. +"""
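In user code this collapse is invoked simply by wrapping a `$$` field with `P` in an ordinary feature query; the sketch below mirrors the usage in `scripts/data_collector/pit/test_pit.py` further down (it assumes the PIT data for `sh600519` has been dumped as described in the README):

```python
import qlib
from qlib.data import D

qlib.init()  # assumes PIT data exists under the default provider_uri

df = D.features(
    ["sh600519"],
    ["P($$roewa_q)", "P($$yoyni_q)"],
    start_time="2019-01-01",
    end_time="2019-07-19",
    freq="day",
)
print(df.tail())
```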
+import numpy as np +import pandas as pd +from qlib.data.ops import ElemOperator +from qlib.log import get_module_logger +from .data import Cal + + +class P(ElemOperator): + def _load_internal(self, instrument, start_index, end_index, freq): + + _calendar = Cal.calendar(freq=freq) + resample_data = np.empty(end_index - start_index + 1, dtype="float32") + + for cur_index in range(start_index, end_index + 1): + cur_time = _calendar[cur_index] + # To load the expression accurately, more historical data are required + start_ws, end_ws = self.feature.get_extended_window_size() + if end_ws > 0: + raise ValueError( + "PIT database does not support referring to future periods (e.g. expressions like `Ref('$$roewa_q', -1)` are not supported)" + ) + + # The calculated value will always be the last element, so the end_offset is zero. + try: + s = self.feature.load(instrument, -start_ws, 0, cur_time) + resample_data[cur_index - start_index] = s.iloc[-1] if len(s) > 0 else np.nan + except FileNotFoundError: + get_module_logger("base").warning(f"WARN: period data not found for {str(self)}") + return pd.Series(dtype="float32", name=str(self)) + + resample_series = pd.Series( + resample_data, index=pd.RangeIndex(start_index, end_index + 1), dtype="float32", name=str(self) + ) + return resample_series + + def get_longest_back_rolling(self): + # The period data will collapse as a normal feature. So no extending and looking back + return 0 + + def get_extended_window_size(self): + # The period data will collapse as a normal feature. So no extending and looking back + return 0, 0 diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 4d37e14797..b7cdc79e07 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -15,6 +15,7 @@ import redis import bisect import shutil +import struct import difflib import inspect import hashlib @@ -28,7 +29,7 @@ import numpy as np import pandas as pd from pathlib import Path -from typing import Dict, Union, Tuple, Any, Text, Optional, Callable +from typing import List, Dict, Union, Tuple, Any, Text, Optional, Callable from types import ModuleType from urllib.parse import urlparse from .file import get_or_create_path, save_multiple_parts_file, unpack_archive_with_buffer, get_tmp_file_with_buffer @@ -62,6 +63,104 @@ def read_bin(file_path: Union[str, Path], start_index, end_index): return series +def get_period_list(first: int, last: int, quarterly: bool) -> List[int]: + """ + This method will be used in the PIT database. + It returns all the possible values between `first` and `last` (both `first` and `last` are included). + + Parameters + ---------- + quarterly : bool + whether to return a quarterly index or a yearly index. + + Returns + ------- + List[int] + the possible indices between [first, last] + """ + + if not quarterly: + assert all(1900 <= x <= 2099 for x in (first, last)), "invalid arguments" + return list(range(first, last + 1)) + else: + assert all(190000 <= x <= 209904 for x in (first, last)), "invalid arguments" + res = [] + for year in range(first // 100, last // 100 + 1): + for q in range(1, 5): + period = year * 100 + q + if first <= period <= last: + res.append(year * 100 + q) + return res + + +def get_period_offset(first_year, period, quarterly): + if quarterly: + offset = (period // 100 - first_year) * 4 + period % 100 - 1 + else: + offset = period - first_year + return offset
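For example, under these definitions the quarterly periods between 2018Q3 and 2019Q2 enumerate as follows (a doctest-style sketch of the two helpers above):

```python
from qlib.utils import get_period_list, get_period_offset

print(get_period_list(201803, 201902, quarterly=True))
# [201803, 201804, 201901, 201902]

# offset of 2019Q2 among quarterly records starting from year 2007
print(get_period_offset(2007, 201902, quarterly=True))
# 49  -> (2019 - 2007) * 4 + (2 - 1)
```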
+def read_period_data(index_path, data_path, period, cur_date_int: int, quarterly, last_period_index: int = None): + """ + At `cur_date` (e.g. 20190102), read the information at `period` (e.g. 201803). + Only the updating info published before or at cur_date will be used. + + Parameters + ---------- + period: int + date period represented by an integer, e.g. 201901 corresponds to the first quarter in 2019 + cur_date_int: int + date represented by an integer, e.g. 20190102 + last_period_index: int + an optional parameter; it is designed to avoid repeatedly accessing the .index data of the PIT database when + sequentially observing the data (because the latest index of a specific period certainly appears after the one in the last observation). + + Returns + ------- + the queried value and the byte index of the latest used revision record + """ + DATA_DTYPE = "".join( + [ + C.pit_record_type["date"], + C.pit_record_type["period"], + C.pit_record_type["value"], + C.pit_record_type["index"], + ] + ) + + PERIOD_DTYPE = C.pit_record_type["period"] + INDEX_DTYPE = C.pit_record_type["index"] + + NAN_VALUE = C.pit_record_nan["value"] + NAN_INDEX = C.pit_record_nan["index"] + + # find the first index of linked revisions + if last_period_index is None: + with open(index_path, "rb") as fi: + (first_year,) = struct.unpack(PERIOD_DTYPE, fi.read(struct.calcsize(PERIOD_DTYPE))) + all_periods = np.fromfile(fi, dtype=INDEX_DTYPE) + offset = get_period_offset(first_year, period, quarterly) + _next = all_periods[offset] + else: + _next = last_period_index + + # load data following the `_next` link + prev_value = NAN_VALUE + prev_next = _next + + with open(data_path, "rb") as fd: + while _next != NAN_INDEX: + fd.seek(_next) + date, period, value, new_next = struct.unpack(DATA_DTYPE, fd.read(struct.calcsize(DATA_DTYPE))) + if date > cur_date_int: + break + prev_next = _next + _next = new_next + prev_value = value + return prev_value, prev_next + + def np_ffill(arr: np.array): """ forward fill a 1D numpy array @@ -172,7 +271,11 @@ def parse_field(field): if not isinstance(field, str): field = str(field) - for pattern, new in [(r"\$(\w+)", rf'Feature("\1")'), (r"(\w+\s*)\(", r"Operators.\1(")]: # Features # Operators + for pattern, new in [ + (r"\$\$(\w+)", r'PFeature("\1")'), # $$ must be before $ + (r"\$(\w+)", rf'Feature("\1")'), + (r"(\w+\s*)\(", r"Operators.\1("), + ]: # Features # Operators field = re.sub(pattern, new, field) return field diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index 18d8955f7c..236d9cddfe 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -323,7 +323,7 @@ def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval= freq, value from [1min, 1d], default 1d """ if source_dir is None: - source_dir = Path(self.default_base_dir).joinpath("_source") + source_dir = Path(self.default_base_dir).joinpath("source") self.source_dir = Path(source_dir).expanduser().resolve() self.source_dir.mkdir(parents=True, exist_ok=True) @@ -359,6 +359,7 @@ def download_data( end=None, check_data_length: int = None, limit_nums=None, + **kwargs, ): """download data from Internet @@ -398,6 +399,7 @@ def download_data( interval=self.interval, check_data_length=check_data_length, limit_nums=limit_nums, + **kwargs, ).collector_data() def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs): diff --git a/scripts/data_collector/crypto/collector.py b/scripts/data_collector/crypto/collector.py index 0790aa6405..d25568a72a 100644 --- a/scripts/data_collector/crypto/collector.py +++ b/scripts/data_collector/crypto/collector.py @@ -13,7 +13,7 @@ CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) from data_collector.base import BaseCollector, BaseNormalize, BaseRun -from data_collector.utils import get_cg_crypto_symbols +from data_collector.utils import get_cg_crypto_symbols, deco_retry from pycoingecko import CoinGeckoAPI from time import mktime @@ -21,6 +21,40 @@ import time +_CG_CRYPTO_SYMBOLS = None + + +def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list: + """get crypto symbols from coingecko + + Returns + ------- + crypto symbols in the given exchanges list of coingecko + """ + global _CG_CRYPTO_SYMBOLS + + 
@deco_retry + def _get_coingecko(): + try: + cg = CoinGeckoAPI() + resp = pd.DataFrame(cg.get_coins_markets(vs_currency="usd")) + except: + raise ValueError("request error") + try: + _symbols = resp["id"].to_list() + except Exception as e: + logger.warning(f"request error: {e}") + raise + return _symbols + + if _CG_CRYPTO_SYMBOLS is None: + _all_symbols = _get_coingecko() + + _CG_CRYPTO_SYMBOLS = sorted(set(_all_symbols)) + + return _CG_CRYPTO_SYMBOLS + + class CryptoCollector(BaseCollector): + def __init__( + self, diff --git a/scripts/data_collector/pit/README.md b/scripts/data_collector/pit/README.md new file mode 100644 index 0000000000..e18dcd0c17 --- /dev/null +++ b/scripts/data_collector/pit/README.md @@ -0,0 +1,35 @@ +# Collect Point-in-Time Data + +> *Please pay **ATTENTION** that the data is collected from [baostock](http://baostock.com) and the data might not be perfect. We recommend users prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)* + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Collect Data + + +### Download Quarterly CN Data + +```bash +cd qlib/scripts/data_collector/pit/ +# download from baostock.com +python collector.py download_data --source_dir ./csv_pit --start 2000-01-01 --end 2020-01-01 --interval quarterly +``` + +Downloading the data for all stocks is very time-consuming. If you just want to run a quick test on a few stocks, you can run the command below +``` bash +python collector.py download_data --source_dir ./csv_pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_flt_regx "^(600519|000725).*" +``` + + + +### Dump Data into PIT Format + +```bash +cd qlib/scripts +# data_collector/pit/csv_pit is the data you just downloaded. +python dump_pit.py dump --csv_path data_collector/pit/csv_pit --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly +``` diff --git a/scripts/data_collector/pit/collector.py b/scripts/data_collector/pit/collector.py new file mode 100644 index 0000000000..45e1f984eb --- /dev/null +++ b/scripts/data_collector/pit/collector.py @@ -0,0 +1,334 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import re +import abc +import sys +import datetime +from abc import ABC +from pathlib import Path + +import fire +import numpy as np +import pandas as pd +import baostock as bs +from loguru import logger + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent.parent)) +from data_collector.base import BaseCollector, BaseRun +from data_collector.utils import get_calendar_list, get_hs_stock_symbols + + +class PitCollector(BaseCollector): + + DEFAULT_START_DATETIME_QUARTER = pd.Timestamp("2000-01-01") + DEFAULT_START_DATETIME_ANNUAL = pd.Timestamp("2000-01-01") + DEFAULT_END_DATETIME_QUARTER = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) + DEFAULT_END_DATETIME_ANNUAL = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) + + INTERVAL_quarterly = "quarterly" + INTERVAL_annual = "annual" + + def __init__( + self, + save_dir: [str, Path], + start=None, + end=None, + interval="quarterly", + max_workers=1, + max_collector_count=1, + delay=0, + check_data_length: bool = False, + limit_nums: int = None, + symbol_flt_regx=None, + ): + """ + + Parameters + ---------- + save_dir: str + pit save dir + interval: str + value from ['quarterly', 'annual'] + max_workers: int + workers, default 1 + max_collector_count: int + default 1 + delay: float + time.sleep(delay), default 0 + start: str + start datetime, default None + end: str + end datetime, default None + limit_nums: int + used for debugging, by default None + """ + if symbol_flt_regx is None: + self.symbol_flt_regx = None + else: + self.symbol_flt_regx = re.compile(symbol_flt_regx) + super(PitCollector, self).__init__( + save_dir=save_dir, + start=start, + end=end, + interval=interval, + max_workers=max_workers, + max_collector_count=max_collector_count, + delay=delay, + check_data_length=check_data_length, + limit_nums=limit_nums, + ) + + def normalize_symbol(self, symbol): + symbol_s = symbol.split(".") + symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}" + return symbol + + def get_instrument_list(self): + logger.info("get cn stock symbols......") + symbols = get_hs_stock_symbols() + logger.info(f"get {symbols[:10]}[{len(symbols)}] symbols.") + if self.symbol_flt_regx is not None: + s_flt = [] + for s in symbols: + m = self.symbol_flt_regx.match(s) + if m is not None: + s_flt.append(s) + logger.info(f"after filtering, it becomes {s_flt[:10]}[{len(s_flt)}] symbols") + return s_flt + + return symbols + + def _get_data_from_baostock(self, symbol, interval, start_datetime, end_datetime): + error_msg = f"{symbol}-{interval}-{start_datetime}-{end_datetime}" + + def _str_to_float(r): + try: + return float(r) + except Exception as e: + return np.nan + + try: + code, market = symbol.split(".") + market = {"ss": "sh"}.get(market, market) # baostock's API naming is different from the default symbol list + symbol = f"{market}.{code}" + rs_report = bs.query_performance_express_report( + code=symbol, start_date=str(start_datetime.date()), end_date=str(end_datetime.date()) + ) + report_list = [] + while (rs_report.error_code == "0") & rs_report.next(): + report_list.append(rs_report.get_row_data()) + # fetch one record at a time and merge the records together + + df_report = pd.DataFrame(report_list, columns=rs_report.fields) + if {"performanceExpPubDate", "performanceExpStatDate", "performanceExpressROEWa"} <= set(rs_report.fields): + df_report = df_report[["performanceExpPubDate", "performanceExpStatDate", "performanceExpressROEWa"]] + df_report.rename( + columns={ + "performanceExpPubDate": "date", + "performanceExpStatDate": 
"period", + "performanceExpressROEWa": "value", + }, + inplace=True, + ) + df_report["value"] = df_report["value"].apply(lambda r: _str_to_float(r) / 100.0) + df_report["field"] = "roeWa" + + profit_list = [] + for year in range(start_datetime.year - 1, end_datetime.year + 1): + for q_num in range(0, 4): + rs_profit = bs.query_profit_data(code=symbol, year=year, quarter=q_num + 1) + while (rs_profit.error_code == "0") & rs_profit.next(): + row_data = rs_profit.get_row_data() + if "pubDate" in rs_profit.fields: + pub_date = pd.Timestamp(row_data[rs_profit.fields.index("pubDate")]) + if pub_date >= start_datetime and pub_date <= end_datetime: + profit_list.append(row_data) + + df_profit = pd.DataFrame(profit_list, columns=rs_profit.fields) + if {"pubDate", "statDate", "roeAvg"} <= set(rs_profit.fields): + df_profit = df_profit[["pubDate", "statDate", "roeAvg"]] + df_profit.rename( + columns={"pubDate": "date", "statDate": "period", "roeAvg": "value"}, + inplace=True, + ) + df_profit["value"] = df_profit["value"].apply(_str_to_float) + df_profit["field"] = "roeWa" + + forecast_list = [] + rs_forecast = bs.query_forecast_report( + code=symbol, start_date=str(start_datetime.date()), end_date=str(end_datetime.date()) + ) + + while (rs_forecast.error_code == "0") & rs_forecast.next(): + forecast_list.append(rs_forecast.get_row_data()) + + df_forecast = pd.DataFrame(forecast_list, columns=rs_forecast.fields) + if { + "profitForcastExpPubDate", + "profitForcastExpStatDate", + "profitForcastChgPctUp", + "profitForcastChgPctDwn", + } <= set(rs_forecast.fields): + df_forecast = df_forecast[ + [ + "profitForcastExpPubDate", + "profitForcastExpStatDate", + "profitForcastChgPctUp", + "profitForcastChgPctDwn", + ] + ] + df_forecast.rename( + columns={ + "profitForcastExpPubDate": "date", + "profitForcastExpStatDate": "period", + }, + inplace=True, + ) + + df_forecast["profitForcastChgPctUp"] = df_forecast["profitForcastChgPctUp"].apply(_str_to_float) + df_forecast["profitForcastChgPctDwn"] = df_forecast["profitForcastChgPctDwn"].apply(_str_to_float) + df_forecast["value"] = ( + df_forecast["profitForcastChgPctUp"] + df_forecast["profitForcastChgPctDwn"] + ) / 200 + df_forecast["field"] = "YOYNI" + df_forecast.drop(["profitForcastChgPctUp", "profitForcastChgPctDwn"], axis=1, inplace=True) + + growth_list = [] + for year in range(start_datetime.year - 1, end_datetime.year + 1): + for q_num in range(0, 4): + rs_growth = bs.query_growth_data(code=symbol, year=year, quarter=q_num + 1) + while (rs_growth.error_code == "0") & rs_growth.next(): + row_data = rs_growth.get_row_data() + if "pubDate" in rs_growth.fields: + pub_date = pd.Timestamp(row_data[rs_growth.fields.index("pubDate")]) + if pub_date >= start_datetime and pub_date <= end_datetime: + growth_list.append(row_data) + + df_growth = pd.DataFrame(growth_list, columns=rs_growth.fields) + if {"pubDate", "statDate", "YOYNI"} <= set(rs_growth.fields): + df_growth = df_growth[["pubDate", "statDate", "YOYNI"]] + df_growth.rename( + columns={"pubDate": "date", "statDate": "period", "YOYNI": "value"}, + inplace=True, + ) + df_growth["value"] = df_growth["value"].apply(_str_to_float) + df_growth["field"] = "YOYNI" + df_merge = df_report.append([df_profit, df_forecast, df_growth]) + + return df_merge + except Exception as e: + logger.warning(f"{error_msg}:{e}") + + def _process_data(self, df, symbol, interval): + error_msg = f"{symbol}-{interval}" + + def _process_period(r): + _date = pd.Timestamp(r) + return _date.year if interval == self.INTERVAL_annual else 
+ + def _process_data(self, df, symbol, interval): + error_msg = f"{symbol}-{interval}" + + def _process_period(r): + _date = pd.Timestamp(r) + return _date.year if interval == self.INTERVAL_annual else _date.year * 100 + (_date.month - 1) // 3 + 1 + + try: + _date = df["period"].apply( + lambda x: ( + pd.to_datetime(x) + pd.DateOffset(days=(45 if interval == self.INTERVAL_quarterly else 90)) + ).date() + ) + df["date"] = df["date"].fillna(_date.astype(str)) + df["period"] = df["period"].apply(_process_period) + return df + except Exception as e: + logger.warning(f"{error_msg}:{e}") + + def get_data( + self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + ) -> [pd.DataFrame]: + + if interval == self.INTERVAL_quarterly: + _result = self._get_data_from_baostock(symbol, interval, start_datetime, end_datetime) + if _result is None or _result.empty: + return _result + else: + return self._process_data(_result, symbol, interval) + else: + raise ValueError(f"cannot support {interval}") + return self._process_data(_result, interval) + + @property + def min_numbers_trading(self): + pass + + class Run(BaseRun): + def __init__(self, source_dir=None, max_workers=1, interval="quarterly"): + """ + + Parameters + ---------- + source_dir: str + The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" + normalize_dir: str + Directory for normalized data, default "Path(__file__).parent/normalize" + max_workers: int + Concurrent number, default is 1 + interval: str + freq, value from [quarterly, annual], default quarterly + """ + super().__init__(source_dir=source_dir, max_workers=max_workers, interval=interval) + + @property + def collector_class_name(self): + return "PitCollector" + + @property + def default_base_dir(self) -> [Path, str]: + return CUR_DIR + + def download_data( + self, + max_collector_count=1, + delay=0, + start=None, + end=None, + interval="quarterly", + check_data_length=False, + limit_nums=None, + **kwargs, + ): + """download data from Internet + + Parameters + ---------- + max_collector_count: int + default 1 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [quarterly, annual], default quarterly + start: str + start datetime, default "2000-01-01" + end: str + end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` + check_data_length: bool # if this param useful? + check data length, by default False + limit_nums: int + used for debugging, by default None + + Examples + --------- + # get quarterly data + $ python collector.py download_data --source_dir ~/.qlib/cn_data/source/pit_quarter --start 2000-01-01 --end 2021-01-01 --interval quarterly + """ + + super(Run, self).download_data( + max_collector_count, delay, start, end, interval, check_data_length, limit_nums, **kwargs + ) + + def normalize_class_name(self): + pass + + +if __name__ == "__main__": + bs.login() + fire.Fire(Run) + bs.logout() diff --git a/scripts/data_collector/pit/requirements.txt b/scripts/data_collector/pit/requirements.txt new file mode 100644 index 0000000000..0cd9b42f9c --- /dev/null +++ b/scripts/data_collector/pit/requirements.txt @@ -0,0 +1,8 @@ +loguru +fire +tqdm +requests +pandas +lxml +baostock +yahooquery \ No newline at end of file diff --git a/scripts/data_collector/pit/test_pit.py b/scripts/data_collector/pit/test_pit.py new file mode 100644 index 0000000000..fa456670b0 --- /dev/null +++ b/scripts/data_collector/pit/test_pit.py @@ -0,0 +1,194 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import qlib +from qlib.data import D +import unittest + + +class TestPIT(unittest.TestCase): + """ + NOTE!!!!!! 
diff --git a/scripts/data_collector/pit/test_pit.py b/scripts/data_collector/pit/test_pit.py
new file mode 100644
index 0000000000..fa456670b0
--- /dev/null
+++ b/scripts/data_collector/pit/test_pit.py
@@ -0,0 +1,194 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import qlib
+from qlib.data import D
+import unittest
+
+
+class TestPIT(unittest.TestCase):
+    """
+    NOTE:
+    The asserts in this test assume that the user has run the cmd below and downloaded only these 2 stocks:
+    `python collector.py download_data --source_dir ./csv_pit --start 2000-01-01 --end 2020-01-01 --interval quarterly --symbol_flt_regx "^(600519|000725).*"`
+    """
+
+    def setUp(self):
+        # qlib.init(kernels=1)  # NOTE: setting kernels to 1 makes debugging easier
+        qlib.init()
+
+    def to_str(self, obj):
+        return "".join(str(obj).split())
+
+    def check_same(self, a, b):
+        self.assertEqual(self.to_str(a), self.to_str(b))
+
+    def test_query(self):
+        instruments = ["sh600519"]
+        fields = ["P($$roewa_q)", "P($$yoyni_q)"]
+        # Moutai published its 2019Q2 report on 2019-07-13 and an amended version on 2019-07-18
+        # - http://www.cninfo.com.cn/new/commonUrl/pageOfSearch?url=disclosure/list/search&lastPage=index
+        data = D.features(instruments, fields, start_time="2019-01-01", end_time="20190719", freq="day")
+
+        print(data)
+
+        res = """
+               P($$roewa_q)  P($$yoyni_q)
+        count    133.000000    133.000000
+        mean       0.196412      0.277930
+        std        0.097591      0.030262
+        min        0.000000      0.243892
+        25%        0.094737      0.243892
+        50%        0.255220      0.304181
+        75%        0.255220      0.305041
+        max        0.344644      0.305041
+        """
+        self.check_same(data.describe(), res)
+
+        res = """
+                               P($$roewa_q)  P($$yoyni_q)
+        instrument datetime
+        sh600519   2019-07-15      0.000000      0.305041
+                   2019-07-16      0.000000      0.305041
+                   2019-07-17      0.000000      0.305041
+                   2019-07-18      0.175322      0.252650
+                   2019-07-19      0.175322      0.252650
+        """
+        self.check_same(data.tail(), res)
+
+    def test_no_exist_data(self):
+        fields = ["P($$roewa_q)", "P($$yoyni_q)", "$close"]
+        data = D.features(["sh600519", "sh601988"], fields, start_time="2019-01-01", end_time="20190719", freq="day")
+        data["$close"] = 1  # normalize $close; different datasets may give different close values
+        print(data)
+        expect = """
+                               P($$roewa_q)  P($$yoyni_q)  $close
+        instrument datetime
+        sh600519   2019-01-02       0.25522      0.243892       1
+                   2019-01-03       0.25522      0.243892       1
+                   2019-01-04       0.25522      0.243892       1
+                   2019-01-07       0.25522      0.243892       1
+                   2019-01-08       0.25522      0.243892       1
+        ...                             ...           ...     ...
+        sh601988   2019-07-15           NaN           NaN       1
+                   2019-07-16           NaN           NaN       1
+                   2019-07-17           NaN           NaN       1
+                   2019-07-18           NaN           NaN       1
+                   2019-07-19           NaN           NaN       1
+
+        [266 rows x 3 columns]
+        """
+        self.check_same(data, expect)
+
+    def test_expr(self):
+        fields = [
+            "P(Mean($$roewa_q, 1))",
+            "P($$roewa_q)",
+            "P(Mean($$roewa_q, 2))",
+            "P(Ref($$roewa_q, 1))",
+            "P((Ref($$roewa_q, 1) +$$roewa_q) / 2)",
+        ]
+        instruments = ["sh600519"]
+        data = D.features(instruments, fields, start_time="2019-01-01", end_time="20190719", freq="day")
+        expect = """
+                               P(Mean($$roewa_q, 1))  P($$roewa_q)  P(Mean($$roewa_q, 2))  P(Ref($$roewa_q, 1))  P((Ref($$roewa_q, 1) +$$roewa_q) / 2)
+        instrument datetime
+        sh600519   2019-07-01               0.094737      0.094737               0.219691              0.344644                                0.219691
+                   2019-07-02               0.094737      0.094737               0.219691              0.344644                                0.219691
+                   2019-07-03               0.094737      0.094737               0.219691              0.344644                                0.219691
+                   2019-07-04               0.094737      0.094737               0.219691              0.344644                                0.219691
+                   2019-07-05               0.094737      0.094737               0.219691              0.344644                                0.219691
+                   2019-07-08               0.094737      0.094737               0.219691              0.344644                                0.219691
+                   2019-07-09               0.094737      0.094737               0.219691              0.344644                                0.219691
+                   2019-07-10               0.094737      0.094737               0.219691              0.344644                                0.219691
+                   2019-07-11               0.094737      0.094737               0.219691              0.344644                                0.219691
+                   2019-07-12               0.094737      0.094737               0.219691              0.344644                                0.219691
+                   2019-07-15               0.000000      0.000000               0.047369              0.094737                                0.047369
+                   2019-07-16               0.000000      0.000000               0.047369              0.094737                                0.047369
+                   2019-07-17               0.000000      0.000000               0.047369              0.094737                                0.047369
+                   2019-07-18               0.175322      0.175322               0.135029              0.094737                                0.135029
+                   2019-07-19               0.175322      0.175322               0.135029              0.094737                                0.135029
+        """
+        self.check_same(data.tail(15), expect)
+
+    def test_unlimit(self):
+        # fields = ["P(Mean($$roewa_q, 1))", "P($$roewa_q)", "P(Mean($$roewa_q, 2))", "P(Ref($$roewa_q, 1))", "P((Ref($$roewa_q, 1) +$$roewa_q) / 2)"]
+        fields = ["P($$roewa_q)"]
+        instruments = ["sh600519"]
+        _ = D.features(instruments, fields, freq="day")  # this should not raise error
+        data = D.features(instruments, fields, end_time="20200101", freq="day")  # this should not raise error
+        s = data.iloc[:, 0]
+        # You can check the expected value based on the content in `docs/advanced/PIT.rst`
+        expect = """
+        instrument  datetime
+        sh600519    1999-11-10         NaN
+                    2007-04-30    0.090219
+                    2007-08-17    0.139330
+                    2007-10-23    0.245863
+                    2008-03-03    0.347900
+                    2008-03-13    0.395989
+                    2008-04-22    0.100724
+                    2008-08-28    0.249968
+                    2008-10-27    0.334120
+                    2009-03-25    0.390117
+                    2009-04-21    0.102675
+                    2009-08-07    0.230712
+                    2009-10-26    0.300730
+                    2010-04-02    0.335461
+                    2010-04-26    0.083825
+                    2010-08-12    0.200545
+                    2010-10-29    0.260986
+                    2011-03-21    0.307393
+                    2011-04-25    0.097411
+                    2011-08-31    0.248251
+                    2011-10-18    0.318919
+                    2012-03-23    0.403900
+                    2012-04-11    0.403925
+                    2012-04-26    0.112148
+                    2012-08-10    0.264847
+                    2012-10-26    0.370487
+                    2013-03-29    0.450047
+                    2013-04-18    0.099958
+                    2013-09-02    0.210442
+                    2013-10-16    0.304543
+                    2014-03-25    0.394328
+                    2014-04-25    0.083217
+                    2014-08-29    0.164503
+                    2014-10-30    0.234085
+                    2015-04-21    0.078494
+                    2015-08-28    0.137504
+                    2015-10-26    0.201709
+                    2016-03-24    0.264205
+                    2016-04-21    0.073664
+                    2016-08-29    0.136576
+                    2016-10-31    0.188062
+                    2017-04-17    0.244385
+                    2017-04-25    0.080614
+                    2017-07-28    0.151510
+                    2017-10-26    0.254166
+                    2018-03-28    0.329542
+                    2018-05-02    0.088887
+                    2018-08-02    0.170563
+                    2018-10-29    0.255220
+                    2019-03-29    0.344644
+                    2019-04-25    0.094737
+                    2019-07-15    0.000000
+                    2019-07-18    0.175322
+                    2019-10-16    0.255819
+        Name: P($$roewa_q), dtype: float32
+        """
+
+        self.check_same(s[~s.duplicated().values], expect)
+
+    def test_expr2(self):
+        instruments = ["sh600519"]
+        fields = ["P($$roewa_q)", "P($$yoyni_q)"]
+        fields += ["P(($$roewa_q / $$yoyni_q) / Ref($$roewa_q / $$yoyni_q, 1) - 1)"]
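+        # ^ point-in-time period-over-period change of the roewa/yoyni ratio: P(...) evaluates
+        #   the inner expression as of each trading day, using only statements published by then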
+        fields += ["P(Sum($$yoyni_q, 4))"]
+        fields += ["$close", "P($$roewa_q) * $close"]
+        data = D.features(instruments, fields, start_time="2019-01-01", end_time="2020-01-01", freq="day")
+        print(data)
+        print(data.describe())
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py
index 33e3a047f5..19131ec29f 100644
--- a/scripts/data_collector/utils.py
+++ b/scripts/data_collector/utils.py
@@ -19,7 +19,6 @@
 from tqdm import tqdm
 from functools import partial
 from concurrent.futures import ProcessPoolExecutor
-from pycoingecko import CoinGeckoAPI
 
 HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
@@ -43,7 +42,6 @@
 _US_SYMBOLS = None
 _IN_SYMBOLS = None
 _EN_FUND_SYMBOLS = None
-_CG_CRYPTO_SYMBOLS = None
 
 _CALENDAR_MAP = {}
 
 # NOTE: Until 2020-10-20 20:00:00
@@ -379,37 +377,6 @@ def _get_eastmoney():
         return _EN_FUND_SYMBOLS
 
 
-def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list:
-    """get crypto symbols in coingecko
-
-    Returns
-    -------
-        crypto symbols in given exchanges list of coingecko
-    """
-    global _CG_CRYPTO_SYMBOLS
-
-    @deco_retry
-    def _get_coingecko():
-        try:
-            cg = CoinGeckoAPI()
-            resp = pd.DataFrame(cg.get_coins_markets(vs_currency="usd"))
-        except:
-            raise ValueError("request error")
-        try:
-            _symbols = resp["id"].to_list()
-        except Exception as e:
-            logger.warning(f"request error: {e}")
-            raise
-        return _symbols
-
-    if _CG_CRYPTO_SYMBOLS is None:
-        _all_symbols = _get_coingecko()
-
-        _CG_CRYPTO_SYMBOLS = sorted(set(_all_symbols))
-
-    return _CG_CRYPTO_SYMBOLS
-
-
 def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
     """symbol suffix to prefix
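Before reading `DumpPitData` below, it may help to see the on-disk layout from the consumer side. The following sketch is illustrative only: it assumes little-endian `uint32, uint32, float32, uint32` records for `(date, period, value, _next)` and `2**32 - 1` as the NA sentinel, which matches the `4294967295` values in the example dump in `docs/advanced/PIT.rst`; the authoritative definitions live in `C.pit_record_type` and `C.pit_record_nan`.

```python
import struct

NA_INDEX = 2**32 - 1  # assumed NA sentinel for `_next` and index entries
REC_FMT = "<IIfI"     # assumed record layout: date, period, value, _next
REC_SIZE = struct.calcsize(REC_FMT)  # 16 bytes per record


def read_revisions(data_path: str, offset: int):
    """Follow the `_next` chain starting at byte `offset`, yielding every revision of one period."""
    with open(data_path, "rb") as f:
        while offset != NA_INDEX:
            f.seek(offset)
            # the unpacked `_next` field becomes the offset of the following revision
            date, period, value, offset = struct.unpack(REC_FMT, f.read(REC_SIZE))
            yield date, period, value
```

Walking the chain for Moutai's 201902 period, for example, would yield the 2019-07-13 publication and then its 2019-07-18 amendment, which is why `P($$roewa_q)` in `test_query` above flips from 0.000000 to 0.175322 between 2019-07-17 and 2019-07-18.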
+""" + +import abc +import shutil +import struct +import traceback +from pathlib import Path +from typing import Iterable, List, Union +from functools import partial +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor + +import fire +import numpy as np +import pandas as pd +from tqdm import tqdm +from loguru import logger +from qlib.utils import fname_to_code, code_to_fname, get_period_offset +from qlib.config import C + + +class DumpPitData: + PIT_DIR_NAME = "financial" + PIT_CSV_SEP = "," + DATA_FILE_SUFFIX = ".data" + INDEX_FILE_SUFFIX = ".index" + + INTERVAL_quarterly = "quarterly" + INTERVAL_annual = "annual" + + PERIOD_DTYPE = C.pit_record_type["period"] + INDEX_DTYPE = C.pit_record_type["index"] + DATA_DTYPE = "".join( + [ + C.pit_record_type["date"], + C.pit_record_type["period"], + C.pit_record_type["value"], + C.pit_record_type["index"], + ] + ) + + NA_INDEX = C.pit_record_nan["index"] + + INDEX_DTYPE_SIZE = struct.calcsize(INDEX_DTYPE) + PERIOD_DTYPE_SIZE = struct.calcsize(PERIOD_DTYPE) + DATA_DTYPE_SIZE = struct.calcsize(DATA_DTYPE) + + UPDATE_MODE = "update" + ALL_MODE = "all" + + def __init__( + self, + csv_path: str, + qlib_dir: str, + backup_dir: str = None, + freq: str = "quarterly", + max_workers: int = 16, + date_column_name: str = "date", + period_column_name: str = "period", + value_column_name: str = "value", + field_column_name: str = "field", + file_suffix: str = ".csv", + exclude_fields: str = "", + include_fields: str = "", + limit_nums: int = None, + ): + """ + + Parameters + ---------- + csv_path: str + stock data path or directory + qlib_dir: str + qlib(dump) data director + backup_dir: str, default None + if backup_dir is not None, backup qlib_dir to backup_dir + freq: str, default "quarterly" + data frequency + max_workers: int, default None + number of threads + date_column_name: str, default "date" + the name of the date field in the csv + file_suffix: str, default ".csv" + file suffix + include_fields: tuple + dump fields + exclude_fields: tuple + fields not dumped + limit_nums: int + Use when debugging, default None + """ + csv_path = Path(csv_path).expanduser() + if isinstance(exclude_fields, str): + exclude_fields = exclude_fields.split(",") + if isinstance(include_fields, str): + include_fields = include_fields.split(",") + self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields))) + self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields))) + self.file_suffix = file_suffix + self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path]) + if limit_nums is not None: + self.csv_files = self.csv_files[: int(limit_nums)] + self.qlib_dir = Path(qlib_dir).expanduser() + self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser() + if backup_dir is not None: + self._backup_qlib_dir(Path(backup_dir).expanduser()) + + self.works = max_workers + self.date_column_name = date_column_name + self.period_column_name = period_column_name + self.value_column_name = value_column_name + self.field_column_name = field_column_name + + self._mode = self.ALL_MODE + + def _backup_qlib_dir(self, target_dir: Path): + shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve())) + + def get_source_data(self, file_path: Path) -> pd.DataFrame: + df = pd.read_csv(str(file_path.resolve()), low_memory=False) + df[self.value_column_name] = df[self.value_column_name].astype("float32") + df[self.date_column_name] = 
+        # df.drop_duplicates([self.date_field_name], inplace=True)
+        return df
+
+    def get_symbol_from_file(self, file_path: Path) -> str:
+        return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower())
+
+    def get_dump_fields(self, df: pd.DataFrame) -> Iterable[str]:
+        return (
+            set(self._include_fields)
+            if self._include_fields
+            else set(df[self.field_column_name]) - set(self._exclude_fields)
+            if self._exclude_fields
+            else set(df[self.field_column_name])
+        )
+
+    def get_filenames(self, symbol, field, interval):
+        dir_name = self.qlib_dir.joinpath(self.PIT_DIR_NAME, symbol)
+        dir_name.mkdir(parents=True, exist_ok=True)
+        return (
+            dir_name.joinpath(f"{field}_{interval[0]}{self.DATA_FILE_SUFFIX}".lower()),
+            dir_name.joinpath(f"{field}_{interval[0]}{self.INDEX_FILE_SUFFIX}".lower()),
+        )
+
+    def _dump_pit(
+        self,
+        file_path: Path,
+        interval: str = "quarterly",
+        overwrite: bool = False,
+    ):
+        """
+        dump data in the following format:
+            `/path/to/<field>.data`
+                [date, period, value, _next]
+                [date, period, value, _next]
+                [...]
+            `/path/to/<field>.index`
+                [first_year, index, index, ...]
+
+        `<field>.data` stores the records in point-in-time (PIT) order: the `value` of `period`
+        is published at `date`, and its next revised value can be found at `_next` (a linked list).
+
+        `<field>.index` stores the index of the first value for each period (quarter or year). To save
+        disk space, only `first_year` is stored explicitly, as the following periods can easily be inferred.
+
+        Parameters
+        ----------
+        file_path: Path
+            path to a single stock's source CSV
+        interval: str
+            data interval
+        overwrite: bool
+            whether to overwrite existing data or update only
+        """
+        symbol = self.get_symbol_from_file(file_path)
+        df = self.get_source_data(file_path)
+        if df.empty:
+            logger.warning(f"{symbol} file is empty")
+            return
+        for field in self.get_dump_fields(df):
+            df_sub = df.query(f'{self.field_column_name}=="{field}"').sort_values(self.date_column_name)
+            if df_sub.empty:
+                logger.warning(f"field {field} of {symbol} is empty")
+                continue
+            data_file, index_file = self.get_filenames(symbol, field, interval)
+
+            # calculate first & last period
+            start_year = df_sub[self.period_column_name].min()
+            end_year = df_sub[self.period_column_name].max()
+            if interval == self.INTERVAL_quarterly:
+                start_year //= 100
+                end_year //= 100
+
+            # adjust `first_year` if existing data found
+            if not overwrite and index_file.exists():
+                with open(index_file, "rb") as fi:
+                    (first_year,) = struct.unpack(self.PERIOD_DTYPE, fi.read(self.PERIOD_DTYPE_SIZE))
+                    n_years = len(fi.read()) // self.INDEX_DTYPE_SIZE
+                    if interval == self.INTERVAL_quarterly:
+                        n_years //= 4
+                    start_year = first_year + n_years
+            else:
+                with open(index_file, "wb") as f:
+                    f.write(struct.pack(self.PERIOD_DTYPE, start_year))
+                first_year = start_year
+
+            # if all periods are already indexed, continue to the next field
+            if start_year > end_year:
+                logger.warning(f"{symbol}-{field} data already exists, continue to the next field")
+                continue
+
+            # extend the index with NA slots for the new periods
+            with open(index_file, "ab") as fi:
+                for year in range(start_year, end_year + 1):
+                    if interval == self.INTERVAL_quarterly:
+                        fi.write(struct.pack(self.INDEX_DTYPE * 4, *[self.NA_INDEX] * 4))
+                    else:
+                        fi.write(struct.pack(self.INDEX_DTYPE, self.NA_INDEX))
+
+            # if data already exists, remove overlapped data
+            if not overwrite and data_file.exists():
+                with open(data_file, "rb") as fd:
+                    fd.seek(-self.DATA_DTYPE_SIZE, 2)  # the last record sits DATA_DTYPE_SIZE bytes before EOF
+                    last_date, _, _, _ = struct.unpack(self.DATA_DTYPE, fd.read())
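+                    # keep only rows strictly newer than the newest record already on disk;
+                    # everything dated on or before `last_date` was dumped by a previous run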
+                    df_sub = df_sub.query(f"{self.date_column_name}>{last_date}")
+            # otherwise,
+            #   1) truncate the existing file or create a new one with `wb+` if overwrite,
+            #   2) or append to the existing file or create a new one with `ab+` if not
+            else:
+                with open(data_file, "wb+" if overwrite else "ab+"):
+                    pass
+
+            with open(data_file, "rb+") as fd, open(index_file, "rb+") as fi:
+                fd.seek(0, 2)  # always append new records at the end of the data file
+
+                # update index if needed
+                for i, row in df_sub.iterrows():
+                    # get the index slot of this row's period
+                    offset = get_period_offset(first_year, row.period, interval == self.INTERVAL_quarterly)
+
+                    fi.seek(self.PERIOD_DTYPE_SIZE + self.INDEX_DTYPE_SIZE * offset)
+                    (cur_index,) = struct.unpack(self.INDEX_DTYPE, fi.read(self.INDEX_DTYPE_SIZE))
+
+                    # Case I: new period => point its index entry at the record about to be appended
+                    if cur_index == self.NA_INDEX:
+                        fi.seek(self.PERIOD_DTYPE_SIZE + self.INDEX_DTYPE_SIZE * offset)
+                        fi.write(struct.pack(self.INDEX_DTYPE, fd.tell()))
+                    # Case II: the period already has data => walk the `_next` chain and link its
+                    # last record to the new one
+                    else:
+                        _cur_fd = fd.tell()
+                        prev_index = self.NA_INDEX
+                        while cur_index != self.NA_INDEX:  # NOTE: cur_index != NA_INDEX is guaranteed on the first iteration
+                            fd.seek(cur_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE)
+                            prev_index = cur_index
+                            (cur_index,) = struct.unpack(self.INDEX_DTYPE, fd.read(self.INDEX_DTYPE_SIZE))
+                        fd.seek(prev_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE)
+                        fd.write(struct.pack(self.INDEX_DTYPE, _cur_fd))  # NOTE: set the `_next` pointer
+                        fd.seek(_cur_fd)
+
+                    # dump the record itself, with `_next` marked NA until a revision arrives
+                    fd.write(struct.pack(self.DATA_DTYPE, row.date, row.period, row.value, self.NA_INDEX))
+
+    def dump(self, interval="quarterly", overwrite=False):
+        logger.info("start dumping PIT data......")
+        _dump_func = partial(self._dump_pit, interval=interval, overwrite=overwrite)
+
+        with tqdm(total=len(self.csv_files)) as p_bar:
+            with ProcessPoolExecutor(max_workers=self.works) as executor:
+                for _ in executor.map(_dump_func, self.csv_files):
+                    p_bar.update()
+
+    def __call__(self, *args, **kwargs):
+        self.dump()
+
+
+if __name__ == "__main__":
+    fire.Fire(DumpPitData)
diff --git a/setup.py b/setup.py
index ab397e1cf5..2bd3f0410c 100644
--- a/setup.py
+++ b/setup.py
@@ -126,6 +126,14 @@
     },
     ext_modules=extensions,
     install_requires=REQUIRED,
+    extras_require={
+        "dev": [
+            "coverage",
+            "pytest>=3",
+            "sphinx",
+            "sphinx_rtd_theme",
+        ]
+    },
     include_package_data=True,
     classifiers=[  # Trove classifiers