Skip to content

Commit

Permalink
add lerp imputation method, lerp test, and modify imputation init file
Browse files Browse the repository at this point in the history
  • Loading branch information
colesussmeier committed Jul 16, 2024
1 parent 90aa00b commit 379ed6b
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .locf import LOCF
from .mean import Mean
from .median import Median
from .lerp import Lerp

__all__ = [
# neural network imputation methods
Expand Down Expand Up @@ -76,4 +77,5 @@
"LOCF",
"Mean",
"Median",
"Lerp",
]
12 changes: 12 additions & 0 deletions pypots/imputation/lerp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
The package of the partially-observed time-series imputation method linear interpolation.
"""

# Created by Cole Sussmeier <colesussmeier@gmail.com>
# License: BSD-3-Clause

from .model import Lerp

__all__ = [
"Lerp",
]
160 changes: 160 additions & 0 deletions pypots/imputation/lerp/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""
The implementation of linear interpolation for the partially-observed time-series imputation task.
"""

# Created by Cole Sussmeier <colesussmeier@gmail.com>
# License: BSD-3-Clause

import warnings
from typing import Union, Optional

import h5py
import numpy as np
import torch

from ..base import BaseImputer


class Lerp(BaseImputer):
"""Linear interpolation (Lerp) imputation method.
Lerp will linearly interpolate missing values between the nearest non-missing values.
If there are missing values at the beginning or end of the series, they will be back-filled or forward-filled with the nearest non-missing value, respectively.
If an entire series is empty, all 'nan' values will be filled with zeros.
"""

def __init__(
self,
):
super().__init__()

def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "hdf5",
) -> None:
"""Train the imputer on the given data.
Warnings
--------
Linear interpolation class does not need to run fit().
Please run func ``predict()`` directly.
"""
warnings.warn(
"Linear interpolation class has no parameter to train. "
"Please run func `predict()` directly."
)

def predict(
self,
test_set: Union[dict, str],
file_type: str = "hdf5",
) -> dict:
"""Make predictions for the input data with the trained model.
Parameters
----------
test_set : dict or str
The dataset for model validating, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
file_type :
The type of the given file if test_set is a path string.
Returns
-------
result_dict: dict
Prediction results in a Python Dictionary for the given samples.
It should be a dictionary including keys as 'imputation', 'classification', 'clustering', and 'forecasting'.
For sure, only the keys that relevant tasks are supported by the model will be returned.
"""
if isinstance(test_set, str):
with h5py.File(test_set, "r") as f:
X = f["X"][:]
else:
X = test_set["X"]

assert len(X.shape) == 3, (
f"Input X should have 3 dimensions [n_samples, n_steps, n_features], "
f"but the actual shape of X: {X.shape}"
)
if isinstance(X, list):
X = np.asarray(X)

def _interpolate_missing_values(X: np.ndarray):
nans = np.isnan(X)
nan_index = np.where(nans)[0]
index = np.where(~nans)[0]
if np.any(nans) and index.size > 1:
X[nans] = np.interp(nan_index, index, X[~nans])
elif np.any(nans):
X[nans] = 0

if isinstance(X, np.ndarray):

trans_X = X.transpose((0, 2, 1))
n_samples, n_features, n_steps = trans_X.shape
reshaped_X = np.reshape(trans_X, (-1, n_steps))
imputed_X = np.ones(reshaped_X.shape)

for i, univariate_series in enumerate(reshaped_X):
t = np.copy(univariate_series)
_interpolate_missing_values(t)
imputed_X[i] = t

imputed_trans_X = np.reshape(imputed_X, (n_samples, n_features, -1))
imputed_data = imputed_trans_X.transpose((0, 2, 1))

elif isinstance(X, torch.Tensor):

trans_X = X.permute(0, 2, 1)
n_samples, n_features, n_steps = trans_X.shape
reshaped_X = trans_X.reshape(-1, n_steps)
imputed_X = torch.ones_like(reshaped_X)

for i, univariate_series in enumerate(reshaped_X):
t = univariate_series.clone().cpu().detach().numpy()
_interpolate_missing_values(t)
imputed_X[i] = torch.from_numpy(t)

imputed_trans_X = imputed_X.reshape(n_samples, n_features, -1)
imputed_data = imputed_trans_X.permute(0, 2, 1)

else:
raise ValueError()

result_dict = {
"imputation": imputed_data,
}
return result_dict

def impute(
self,
test_set: Union[dict, str],
file_type: str = "hdf5",
) -> np.ndarray:
"""Impute missing values in the given data with the trained model.
Parameters
----------
test_set :
The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
n_features], or a path string locating a data file, e.g. h5 file.
file_type :
The type of the given file if X is a path string.
Returns
-------
array-like, shape [n_samples, sequence length (n_steps), n_features],
Imputed data.
"""

result_dict = self.predict(test_set, file_type=file_type)
return result_dict["imputation"]
74 changes: 74 additions & 0 deletions tests/imputation/lerp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
Test cases for Linear Interpolation(Lerp) imputation method.
"""

# Created by Cole Sussmeier <colesussmeier@gmail.com>
# License: BSD-3-Clause


import unittest

import numpy as np
import pytest
import torch

from pypots.imputation import Lerp
from pypots.utils.logging import logger
from pypots.utils.metrics import calc_mse
from tests.global_test_config import (
DATA,
TEST_SET,
GENERAL_H5_TRAIN_SET_PATH,
GENERAL_H5_VAL_SET_PATH,
GENERAL_H5_TEST_SET_PATH,
)


class TestLerp(unittest.TestCase):
logger.info("Running tests for an imputation model Lerp...")
lerp = Lerp()

@pytest.mark.xdist_group(name="imputation-lerp")
def test_0_impute(self):
# if input data is numpy ndarray
test_X_imputed = self.lerp.predict(TEST_SET)["imputation"]
assert not np.isnan(
test_X_imputed
).any(), "Output still has missing values after running impute()."
test_MSE = calc_mse(
test_X_imputed, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
)
logger.info(f"Lerp test_MSE: {test_MSE}")

# if input data is torch tensor
X = torch.from_numpy(np.copy(TEST_SET["X"]))
test_X_ori = torch.from_numpy(np.copy(DATA["test_X_ori"]))
test_X_indicating_mask = torch.from_numpy(
np.copy(DATA["test_X_indicating_mask"])
)

test_X_imputed = self.lerp.predict({"X": X})["imputation"]
assert not torch.isnan(
test_X_imputed
).any(), "Output still has missing values after running impute()."
test_MSE = calc_mse(test_X_imputed, test_X_ori, test_X_indicating_mask)
logger.info(f"Lerp test_MSE: {test_MSE}")

@pytest.mark.xdist_group(name="imputation-lerp")
def test_4_lazy_loading(self):
self.lerp.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)
imputation_results = self.lerp.predict(GENERAL_H5_TEST_SET_PATH)
assert not np.isnan(
imputation_results["imputation"]
).any(), "Output still has missing values after running impute()."

test_MSE = calc_mse(
imputation_results["imputation"],
DATA["test_X_ori"],
DATA["test_X_indicating_mask"],
)
logger.info(f"Lazy-loading Lerp test_MSE: {test_MSE}")


if __name__ == "__main__":
unittest.main()

0 comments on commit 379ed6b

Please sign in to comment.