diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index dbf168e3..b27f9079 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -40,6 +40,7 @@ from .locf import LOCF from .mean import Mean from .median import Median +from .lerp import Lerp __all__ = [ # neural network imputation methods @@ -76,4 +77,5 @@ "LOCF", "Mean", "Median", + "Lerp", ] diff --git a/pypots/imputation/lerp/__init__.py b/pypots/imputation/lerp/__init__.py new file mode 100644 index 00000000..0ca166fc --- /dev/null +++ b/pypots/imputation/lerp/__init__.py @@ -0,0 +1,12 @@ +""" +The package of the partially-observed time-series imputation method linear interpolation. +""" + +# Created by Cole Sussmeier +# License: BSD-3-Clause + +from .model import Lerp + +__all__ = [ + "Lerp", +] \ No newline at end of file diff --git a/pypots/imputation/lerp/model.py b/pypots/imputation/lerp/model.py new file mode 100644 index 00000000..2f39dcdf --- /dev/null +++ b/pypots/imputation/lerp/model.py @@ -0,0 +1,160 @@ +""" +The implementation of linear interpolation for the partially-observed time-series imputation task. +""" + +# Created by Cole Sussmeier +# License: BSD-3-Clause + +import warnings +from typing import Union, Optional + +import h5py +import numpy as np +import torch + +from ..base import BaseImputer + + +class Lerp(BaseImputer): + """Linear interpolation (Lerp) imputation method. + + Lerp will linearly interpolate missing values between the nearest non-missing values. + If there are missing values at the beginning or end of the series, they will be back-filled or forward-filled with the nearest non-missing value, respectively. + If an entire series is empty, all 'nan' values will be filled with zeros. + """ + + def __init__( + self, + ): + super().__init__() + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "hdf5", + ) -> None: + """Train the imputer on the given data. + + Warnings + -------- + Linear interpolation class does not need to run fit(). + Please run func ``predict()`` directly. + """ + warnings.warn( + "Linear interpolation class has no parameter to train. " + "Please run func `predict()` directly." + ) + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : + The type of the given file if test_set is a path string. + + Returns + ------- + result_dict: dict + Prediction results in a Python Dictionary for the given samples. + It should be a dictionary including keys as 'imputation', 'classification', 'clustering', and 'forecasting'. + For sure, only the keys that relevant tasks are supported by the model will be returned. + """ + if isinstance(test_set, str): + with h5py.File(test_set, "r") as f: + X = f["X"][:] + else: + X = test_set["X"] + + assert len(X.shape) == 3, ( + f"Input X should have 3 dimensions [n_samples, n_steps, n_features], " + f"but the actual shape of X: {X.shape}" + ) + if isinstance(X, list): + X = np.asarray(X) + + def _interpolate_missing_values(X: np.ndarray): + nans = np.isnan(X) + nan_index = np.where(nans)[0] + index = np.where(~nans)[0] + if np.any(nans) and index.size > 1: + X[nans] = np.interp(nan_index, index, X[~nans]) + elif np.any(nans): + X[nans] = 0 + + if isinstance(X, np.ndarray): + + trans_X = X.transpose((0, 2, 1)) + n_samples, n_features, n_steps = trans_X.shape + reshaped_X = np.reshape(trans_X, (-1, n_steps)) + imputed_X = np.ones(reshaped_X.shape) + + for i, univariate_series in enumerate(reshaped_X): + t = np.copy(univariate_series) + _interpolate_missing_values(t) + imputed_X[i] = t + + imputed_trans_X = np.reshape(imputed_X, (n_samples, n_features, -1)) + imputed_data = imputed_trans_X.transpose((0, 2, 1)) + + elif isinstance(X, torch.Tensor): + + trans_X = X.permute(0, 2, 1) + n_samples, n_features, n_steps = trans_X.shape + reshaped_X = trans_X.reshape(-1, n_steps) + imputed_X = torch.ones_like(reshaped_X) + + for i, univariate_series in enumerate(reshaped_X): + t = univariate_series.clone().cpu().detach().numpy() + _interpolate_missing_values(t) + imputed_X[i] = torch.from_numpy(t) + + imputed_trans_X = imputed_X.reshape(n_samples, n_features, -1) + imputed_data = imputed_trans_X.permute(0, 2, 1) + + else: + raise ValueError() + + result_dict = { + "imputation": imputed_data, + } + return result_dict + + def impute( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Parameters + ---------- + test_set : + The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (n_steps), n_features], + Imputed data. + """ + + result_dict = self.predict(test_set, file_type=file_type) + return result_dict["imputation"] \ No newline at end of file diff --git a/tests/imputation/lerp.py b/tests/imputation/lerp.py new file mode 100644 index 00000000..d30909e8 --- /dev/null +++ b/tests/imputation/lerp.py @@ -0,0 +1,74 @@ +""" +Test cases for Linear Interpolation(Lerp) imputation method. +""" + +# Created by Cole Sussmeier +# License: BSD-3-Clause + + +import unittest + +import numpy as np +import pytest +import torch + +from pypots.imputation import Lerp +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + TEST_SET, + GENERAL_H5_TRAIN_SET_PATH, + GENERAL_H5_VAL_SET_PATH, + GENERAL_H5_TEST_SET_PATH, +) + + +class TestLerp(unittest.TestCase): + logger.info("Running tests for an imputation model Lerp...") + lerp = Lerp() + + @pytest.mark.xdist_group(name="imputation-lerp") + def test_0_impute(self): + # if input data is numpy ndarray + test_X_imputed = self.lerp.predict(TEST_SET)["imputation"] + assert not np.isnan( + test_X_imputed + ).any(), "Output still has missing values after running impute()." + test_MSE = calc_mse( + test_X_imputed, DATA["test_X_ori"], DATA["test_X_indicating_mask"] + ) + logger.info(f"Lerp test_MSE: {test_MSE}") + + # if input data is torch tensor + X = torch.from_numpy(np.copy(TEST_SET["X"])) + test_X_ori = torch.from_numpy(np.copy(DATA["test_X_ori"])) + test_X_indicating_mask = torch.from_numpy( + np.copy(DATA["test_X_indicating_mask"]) + ) + + test_X_imputed = self.lerp.predict({"X": X})["imputation"] + assert not torch.isnan( + test_X_imputed + ).any(), "Output still has missing values after running impute()." + test_MSE = calc_mse(test_X_imputed, test_X_ori, test_X_indicating_mask) + logger.info(f"Lerp test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="imputation-lerp") + def test_4_lazy_loading(self): + self.lerp.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH) + imputation_results = self.lerp.predict(GENERAL_H5_TEST_SET_PATH) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Lazy-loading Lerp test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file