From efaeb72975d08a55f1726d8024115b908dfce68c Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Wed, 3 Jul 2024 00:09:23 +0800
Subject: [PATCH 1/7] feat: add gene_physionet2012() back as a deprecated func;

---
 pypots/data/generating.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/pypots/data/generating.py b/pypots/data/generating.py
index c979ac27..f50b5276 100644
--- a/pypots/data/generating.py
+++ b/pypots/data/generating.py
@@ -9,11 +9,14 @@
 from typing import Optional, Tuple
 
 import numpy as np
+from benchpots.datasets import preprocess_physionet2012
 from pygrinder import mcar
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import StandardScaler
 from sklearn.utils import check_random_state
 
+from ..utils.logging import logger
+
 
 def gene_complete_random_walk(
     n_samples: int = 1000,
@@ -318,3 +321,18 @@ def gene_random_walk(
         data["test_X_indicating_mask"] = np.isnan(test_X_ori) ^ np.isnan(test_X)
 
     return data
+
+
+def gene_physionet2012(artificially_missing_rate: float = 0.1):
+    dataset_from_benchpots = preprocess_physionet2012(
+        subset="all", rate=artificially_missing_rate
+    )
+    logger.warning(
+        "🚨 Due to the full release of BenchPOTS package, "
+        "gene_physionet2012() has been deprecated and will be removed in pypots v0.8"
+    )
+    logger.info(
+        "🌟 Please refer to https://github.com/WenjieDu/BenchPOTS and "
+        "check out the func benchpots.datasets.preprocess_physionet2012()"
+    )
+    return dataset_from_benchpots
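Note: the shim above is a thin pass-through to BenchPOTS, so migrating caller code is a one-line change. A minimal before/after sketch (assuming benchpots>=0.2 is installed, per the requirement bump in the next patch):

    # deprecated, slated for removal in pypots v0.8
    from pypots.data.generating import gene_physionet2012
    data = gene_physionet2012(artificially_missing_rate=0.1)

    # equivalent direct call
    from benchpots.datasets import preprocess_physionet2012
    data = preprocess_physionet2012(subset="all", rate=0.1)

Both calls return the same preprocessed PhysioNet-2012 dictionary; the shim only adds the deprecation logs on top.
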
""" - sliding_len = window_len if sliding_len is None else sliding_len - total_len = time_series.shape[0] - start_indices = np.asarray(range(total_len // sliding_len)) * sliding_len - - # remove the last one if left length is not enough - if total_len - start_indices[-1] < window_len: - to_drop = math.ceil(window_len / sliding_len) - left_len = total_len - start_indices[-1] - start_indices = start_indices[:-to_drop] - logger.warning( - f"The last {to_drop} samples are dropped due to the left length {left_len} is not enough." - ) - - sample_collector = [] - for idx in start_indices: - sample_collector.append(time_series[idx : idx + window_len]) - samples = np.asarray(sample_collector).astype("float32") - - return samples + return benchpots.utils.sliding_window( + time_series, + window_len, + sliding_len, + ) def inverse_sliding_window(X, sliding_len): @@ -238,25 +226,8 @@ def inverse_sliding_window(X, sliding_len): The restored time-series data with shape of [total_length, n_features]. """ - assert len(X.shape) == 3, f"X should be a 3D array, but got {X.shape}" - n_samples, window_size, n_features = X.shape - - if sliding_len >= window_size: - if sliding_len > window_size: - logger.warning( - f"sliding_len {sliding_len} is larger than the window size {window_size}, " - f"hence there will be gaps between restored data." - ) - restored_data = X.reshape(n_samples * window_size, n_features) - else: - collector = [X[0][:sliding_len]] - overlap = X[0][sliding_len:] - for x in X[1:]: - overlap_avg = (overlap + x[:-sliding_len]) / 2 - collector.append(overlap_avg[:sliding_len]) - overlap = np.concatenate( - [overlap_avg[sliding_len:], x[-sliding_len:]], axis=0 - ) - collector.append(overlap) - restored_data = np.concatenate(collector, axis=0) - return restored_data + + return benchpots.utils.inverse_sliding_window( + X, + sliding_len, + ) diff --git a/requirements.txt b/requirements.txt index 63397d47..6f4d8112 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,4 @@ scikit-learn torch>=1.10.0 tsdb>=0.4 pygrinder>=0.6 -benchpots>=0.1 +benchpots>=0.2 diff --git a/setup.cfg b/setup.cfg index f25dd640..b027a9db 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ basic = torch>=1.10.0 tsdb>=0.4 pygrinder>=0.6 - benchpots>=0.1 + benchpots>=0.2 # dependencies that are optional, torch-geometric are only needed for model Raindrop # but its installation takes too much time diff --git a/setup.py b/setup.py index 13172658..37f3be90 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ "torch>=1.10.0", "tsdb>=0.4", "pygrinder>=0.6", - "benchpots>=0.1", + "benchpots>=0.2", ], python_requires=">=3.8.0", setup_requires=["setuptools>=38.6.0"], From 89931376c532433e6b377ea5a40e384b5c9e0fa1 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 4 Jul 2024 09:00:54 +0800 Subject: [PATCH 3/7] refactor: return None from pickle_dump(); --- pypots/data/saving/pickle.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pypots/data/saving/pickle.py b/pypots/data/saving/pickle.py index cb335d31..a9c50bf6 100644 --- a/pypots/data/saving/pickle.py +++ b/pypots/data/saving/pickle.py @@ -6,13 +6,12 @@ # License: BSD-3-Clause import pickle -from typing import Optional from ...utils.file import extract_parent_dir, create_dir_if_not_exist from ...utils.logging import logger -def pickle_dump(data: object, path: str) -> Optional[str]: +def pickle_dump(data: object, path: str) -> None: """Pickle the given object. 
From 89931376c532433e6b377ea5a40e384b5c9e0fa1 Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Thu, 4 Jul 2024 09:00:54 +0800
Subject: [PATCH 3/7] refactor: return None from pickle_dump();

---
 pypots/data/saving/pickle.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/pypots/data/saving/pickle.py b/pypots/data/saving/pickle.py
index cb335d31..a9c50bf6 100644
--- a/pypots/data/saving/pickle.py
+++ b/pypots/data/saving/pickle.py
@@ -6,13 +6,12 @@
 # License: BSD-3-Clause
 
 import pickle
-from typing import Optional
 
 from ...utils.file import extract_parent_dir, create_dir_if_not_exist
 from ...utils.logging import logger
 
 
-def pickle_dump(data: object, path: str) -> Optional[str]:
+def pickle_dump(data: object, path: str) -> None:
     """Pickle the given object.
 
     Parameters
@@ -39,7 +38,6 @@ def pickle_dump(data: object, path: str) -> None:
         )
         return None
     logger.info(f"Successfully saved to {path}")
-    return path
 
 
 def pickle_load(path: str) -> object:

From 4200c94be75d1b446635989ac0ea23c3322e1470 Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Fri, 5 Jul 2024 01:05:12 +0800
Subject: [PATCH 4/7] fix: return None if loading failed to avoid using var data before assignment;

---
 pypots/data/saving/pickle.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pypots/data/saving/pickle.py b/pypots/data/saving/pickle.py
index a9c50bf6..f9049b1b 100644
--- a/pypots/data/saving/pickle.py
+++ b/pypots/data/saving/pickle.py
@@ -59,4 +59,5 @@ def pickle_load(path: str) -> object:
             data = pickle.load(f)
     except Exception as e:
         logger.error(f"❌ Loading data failed. Operation aborted. See info below:\n{e}")
+        return None
     return data

From 499e0c317ea29730128e6a66f4ce17bdb06a59b4 Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Mon, 15 Jul 2024 23:17:02 +0800
Subject: [PATCH 5/7] docs: update the docs configs;

---
 docs/conf.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/docs/conf.py b/docs/conf.py
index c3656cf4..777d523a 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -33,6 +33,8 @@
 release = pypots.__version__
 
 # -- General configuration ---------------------------------------------------
+# Set canonical URL from the Read the Docs Domain
+html_baseurl = os.environ.get("READTHEDOCS_CANONICAL_URL", "")
 
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
@@ -98,6 +100,9 @@
 html_context = {
     "last_updated": f"{date_now.year}/{date_now.month}/{date_now.day}",
 }
+# Tell Jinja2 templates the build is running on Read the Docs
+if os.environ.get("READTHEDOCS", "") == "True":
+    html_context["READTHEDOCS"] = True
 
 html_favicon = (
     "https://raw.githubusercontent.com/"

From 379ed6b61ec8137d69d4a80e1267bd061f8ed44e Mon Sep 17 00:00:00 2001
From: Cole
Date: Tue, 16 Jul 2024 12:28:16 -0400
Subject: [PATCH 6/7] add lerp imputation method, lerp test, and modify imputation init file

---
 pypots/imputation/__init__.py      |   2 +
 pypots/imputation/lerp/__init__.py |  12 +++
 pypots/imputation/lerp/model.py    | 160 +++++++++++++++++++++++++++++
 tests/imputation/lerp.py           |  74 +++++++++++++
 4 files changed, 248 insertions(+)
 create mode 100644 pypots/imputation/lerp/__init__.py
 create mode 100644 pypots/imputation/lerp/model.py
 create mode 100644 tests/imputation/lerp.py

diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py
index dbf168e3..b27f9079 100644
--- a/pypots/imputation/__init__.py
+++ b/pypots/imputation/__init__.py
@@ -40,6 +40,7 @@
 from .locf import LOCF
 from .mean import Mean
 from .median import Median
+from .lerp import Lerp
 
 __all__ = [
     # neural network imputation methods
@@ -76,4 +77,5 @@
     "LOCF",
     "Mean",
     "Median",
+    "Lerp",
 ]
diff --git a/pypots/imputation/lerp/__init__.py b/pypots/imputation/lerp/__init__.py
new file mode 100644
index 00000000..0ca166fc
--- /dev/null
+++ b/pypots/imputation/lerp/__init__.py
@@ -0,0 +1,12 @@
+"""
+The package of the partially-observed time-series imputation method linear interpolation.
+
+"""
+
+# Created by Cole Sussmeier
+# License: BSD-3-Clause
+
+from .model import Lerp
+
+__all__ = [
+    "Lerp",
+]
\ No newline at end of file
diff --git a/pypots/imputation/lerp/model.py b/pypots/imputation/lerp/model.py
new file mode 100644
index 00000000..2f39dcdf
--- /dev/null
+++ b/pypots/imputation/lerp/model.py
@@ -0,0 +1,160 @@
+"""
+The implementation of linear interpolation for the partially-observed time-series imputation task.
+"""
+
+# Created by Cole Sussmeier
+# License: BSD-3-Clause
+
+import warnings
+from typing import Union, Optional
+
+import h5py
+import numpy as np
+import torch
+
+from ..base import BaseImputer
+
+
+class Lerp(BaseImputer):
+    """Linear interpolation (Lerp) imputation method.
+
+    Lerp will linearly interpolate missing values between the nearest non-missing values.
+    If there are missing values at the beginning or end of the series, they will be back-filled or forward-filled with the nearest non-missing value, respectively.
+    If an entire series is empty, all 'nan' values will be filled with zeros.
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+    def fit(
+        self,
+        train_set: Union[dict, str],
+        val_set: Optional[Union[dict, str]] = None,
+        file_type: str = "hdf5",
+    ) -> None:
+        """Train the imputer on the given data.
+
+        Warnings
+        --------
+        Linear interpolation class does not need to run fit().
+        Please run func ``predict()`` directly.
+        """
+        warnings.warn(
+            "Linear interpolation class has no parameter to train. "
+            "Please run func `predict()` directly."
+        )
+
+    def predict(
+        self,
+        test_set: Union[dict, str],
+        file_type: str = "hdf5",
+    ) -> dict:
+        """Make predictions for the input data with the trained model.
+
+        Parameters
+        ----------
+        test_set : dict or str
+            The dataset for model testing, should be a dictionary including the key 'X',
+            or a path string locating a data file supported by PyPOTS (e.g. h5 file).
+            If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
+            which is time-series data for testing and can contain missing values.
+            If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
+            key-value pairs like a dict, and it has to include the key 'X'.
+
+        file_type :
+            The type of the given file if test_set is a path string.
+
+        Returns
+        -------
+        result_dict: dict
+            Prediction results in a Python Dictionary for the given samples.
+            It is a dictionary including keys such as 'imputation', 'classification', 'clustering', and 'forecasting'.
+            Only the keys of tasks supported by the model will be returned.
+ """ + if isinstance(test_set, str): + with h5py.File(test_set, "r") as f: + X = f["X"][:] + else: + X = test_set["X"] + + assert len(X.shape) == 3, ( + f"Input X should have 3 dimensions [n_samples, n_steps, n_features], " + f"but the actual shape of X: {X.shape}" + ) + if isinstance(X, list): + X = np.asarray(X) + + def _interpolate_missing_values(X: np.ndarray): + nans = np.isnan(X) + nan_index = np.where(nans)[0] + index = np.where(~nans)[0] + if np.any(nans) and index.size > 1: + X[nans] = np.interp(nan_index, index, X[~nans]) + elif np.any(nans): + X[nans] = 0 + + if isinstance(X, np.ndarray): + + trans_X = X.transpose((0, 2, 1)) + n_samples, n_features, n_steps = trans_X.shape + reshaped_X = np.reshape(trans_X, (-1, n_steps)) + imputed_X = np.ones(reshaped_X.shape) + + for i, univariate_series in enumerate(reshaped_X): + t = np.copy(univariate_series) + _interpolate_missing_values(t) + imputed_X[i] = t + + imputed_trans_X = np.reshape(imputed_X, (n_samples, n_features, -1)) + imputed_data = imputed_trans_X.transpose((0, 2, 1)) + + elif isinstance(X, torch.Tensor): + + trans_X = X.permute(0, 2, 1) + n_samples, n_features, n_steps = trans_X.shape + reshaped_X = trans_X.reshape(-1, n_steps) + imputed_X = torch.ones_like(reshaped_X) + + for i, univariate_series in enumerate(reshaped_X): + t = univariate_series.clone().cpu().detach().numpy() + _interpolate_missing_values(t) + imputed_X[i] = torch.from_numpy(t) + + imputed_trans_X = imputed_X.reshape(n_samples, n_features, -1) + imputed_data = imputed_trans_X.permute(0, 2, 1) + + else: + raise ValueError() + + result_dict = { + "imputation": imputed_data, + } + return result_dict + + def impute( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Parameters + ---------- + test_set : + The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (n_steps), n_features], + Imputed data. + """ + + result_dict = self.predict(test_set, file_type=file_type) + return result_dict["imputation"] \ No newline at end of file diff --git a/tests/imputation/lerp.py b/tests/imputation/lerp.py new file mode 100644 index 00000000..d30909e8 --- /dev/null +++ b/tests/imputation/lerp.py @@ -0,0 +1,74 @@ +""" +Test cases for Linear Interpolation(Lerp) imputation method. +""" + +# Created by Cole Sussmeier +# License: BSD-3-Clause + + +import unittest + +import numpy as np +import pytest +import torch + +from pypots.imputation import Lerp +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + TEST_SET, + GENERAL_H5_TRAIN_SET_PATH, + GENERAL_H5_VAL_SET_PATH, + GENERAL_H5_TEST_SET_PATH, +) + + +class TestLerp(unittest.TestCase): + logger.info("Running tests for an imputation model Lerp...") + lerp = Lerp() + + @pytest.mark.xdist_group(name="imputation-lerp") + def test_0_impute(self): + # if input data is numpy ndarray + test_X_imputed = self.lerp.predict(TEST_SET)["imputation"] + assert not np.isnan( + test_X_imputed + ).any(), "Output still has missing values after running impute()." 
+        test_MSE = calc_mse(
+            test_X_imputed, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
+        )
+        logger.info(f"Lerp test_MSE: {test_MSE}")
+
+        # if input data is torch tensor
+        X = torch.from_numpy(np.copy(TEST_SET["X"]))
+        test_X_ori = torch.from_numpy(np.copy(DATA["test_X_ori"]))
+        test_X_indicating_mask = torch.from_numpy(
+            np.copy(DATA["test_X_indicating_mask"])
+        )
+
+        test_X_imputed = self.lerp.predict({"X": X})["imputation"]
+        assert not torch.isnan(
+            test_X_imputed
+        ).any(), "Output still has missing values after running impute()."
+        test_MSE = calc_mse(test_X_imputed, test_X_ori, test_X_indicating_mask)
+        logger.info(f"Lerp test_MSE: {test_MSE}")
+
+    @pytest.mark.xdist_group(name="imputation-lerp")
+    def test_4_lazy_loading(self):
+        self.lerp.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)
+        imputation_results = self.lerp.predict(GENERAL_H5_TEST_SET_PATH)
+        assert not np.isnan(
+            imputation_results["imputation"]
+        ).any(), "Output still has missing values after running impute()."
+
+        test_MSE = calc_mse(
+            imputation_results["imputation"],
+            DATA["test_X_ori"],
+            DATA["test_X_indicating_mask"],
+        )
+        logger.info(f"Lazy-loading Lerp test_MSE: {test_MSE}")
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file

From 0e7be497c0514cc2b54bdc9ccb633c7d497f66a1 Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Wed, 17 Jul 2024 12:30:44 +0800
Subject: [PATCH 7/7] refactor: fix some linting issues;

---
 pypots/imputation/lerp/__init__.py |  2 +-
 pypots/imputation/lerp/model.py    | 13 +++++++------
 tests/imputation/lerp.py           |  2 +-
 3 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/pypots/imputation/lerp/__init__.py b/pypots/imputation/lerp/__init__.py
index 0ca166fc..2d5c1155 100644
--- a/pypots/imputation/lerp/__init__.py
+++ b/pypots/imputation/lerp/__init__.py
@@ -9,4 +9,4 @@
 
 __all__ = [
     "Lerp",
-]
\ No newline at end of file
+]
diff --git a/pypots/imputation/lerp/model.py b/pypots/imputation/lerp/model.py
index 2f39dcdf..ffdd60db 100644
--- a/pypots/imputation/lerp/model.py
+++ b/pypots/imputation/lerp/model.py
@@ -19,10 +19,11 @@ class Lerp(BaseImputer):
     """Linear interpolation (Lerp) imputation method.
 
     Lerp will linearly interpolate missing values between the nearest non-missing values.
-    If there are missing values at the beginning or end of the series, they will be back-filled or forward-filled with the nearest non-missing value, respectively.
+    If there are missing values at the beginning or end of the series, they will be back-filled or
+    forward-filled with the nearest non-missing value, respectively.
    If an entire series is empty, all 'nan' values will be filled with zeros.
""" - + def __init__( self, ): @@ -95,14 +96,14 @@ def _interpolate_missing_values(X: np.ndarray): X[nans] = np.interp(nan_index, index, X[~nans]) elif np.any(nans): X[nans] = 0 - + if isinstance(X, np.ndarray): trans_X = X.transpose((0, 2, 1)) n_samples, n_features, n_steps = trans_X.shape reshaped_X = np.reshape(trans_X, (-1, n_steps)) imputed_X = np.ones(reshaped_X.shape) - + for i, univariate_series in enumerate(reshaped_X): t = np.copy(univariate_series) _interpolate_missing_values(t) @@ -133,7 +134,7 @@ def _interpolate_missing_values(X: np.ndarray): "imputation": imputed_data, } return result_dict - + def impute( self, test_set: Union[dict, str], @@ -157,4 +158,4 @@ def impute( """ result_dict = self.predict(test_set, file_type=file_type) - return result_dict["imputation"] \ No newline at end of file + return result_dict["imputation"] diff --git a/tests/imputation/lerp.py b/tests/imputation/lerp.py index d30909e8..41b396d6 100644 --- a/tests/imputation/lerp.py +++ b/tests/imputation/lerp.py @@ -71,4 +71,4 @@ def test_4_lazy_loading(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()