Add Lerp as an imputation method and update the docs config #462

Merged · 9 commits · Jul 17, 2024
5 changes: 5 additions & 0 deletions docs/conf.py
@@ -33,6 +33,8 @@
release = pypots.__version__

# -- General configuration ---------------------------------------------------
# Set canonical URL from the Read the Docs Domain
html_baseurl = os.environ.get("READTHEDOCS_CANONICAL_URL", "")

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
@@ -98,6 +100,9 @@
html_context = {
"last_updated": f"{date_now.year}/{date_now.month}/{date_now.day}",
}
# Tell Jinja2 templates the build is running on Read the Docs
if os.environ.get("READTHEDOCS", "") == "True":
html_context["READTHEDOCS"] = True

html_favicon = (
"https://raw.githubusercontent.com/"
18 changes: 18 additions & 0 deletions pypots/data/generating.py
@@ -9,11 +9,14 @@
from typing import Optional, Tuple

import numpy as np
from benchpots.datasets import preprocess_physionet2012
from pygrinder import mcar
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state

from ..utils.logging import logger


def gene_complete_random_walk(
n_samples: int = 1000,
@@ -318,3 +321,18 @@ def gene_random_walk(
data["test_X_indicating_mask"] = np.isnan(test_X_ori) ^ np.isnan(test_X)

return data


def gene_physionet2012(artificially_missing_rate: float = 0.1):
dataset_from_benchpots = preprocess_physionet2012(
subset="all", rate=artificially_missing_rate
)
    logger.warning(
        "🚨 Due to the full release of the BenchPOTS package, "
        "gene_physionet2012() has been deprecated and will be removed in pypots v0.8"
    )
logger.info(
"🌟 Please refer to https://github.com/WenjieDu/BenchPOTS and "
"check out the func benchpots.datasets.preprocess_physionet2012()"
)
return dataset_from_benchpots
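
As the deprecation notice above suggests, callers can migrate to BenchPOTS directly. A minimal migration sketch, assuming benchpots is installed and exposes the signature used above:

# old (deprecated, to be removed in pypots v0.8)
from pypots.data.generating import gene_physionet2012
data = gene_physionet2012(artificially_missing_rate=0.1)

# new: call BenchPOTS directly (same preprocessing under the hood)
from benchpots.datasets import preprocess_physionet2012
data = preprocess_physionet2012(subset="all", rate=0.1)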
5 changes: 2 additions & 3 deletions pypots/data/saving/pickle.py
@@ -6,13 +6,12 @@
# License: BSD-3-Clause

import pickle
from typing import Optional

from ...utils.file import extract_parent_dir, create_dir_if_not_exist
from ...utils.logging import logger


def pickle_dump(data: object, path: str) -> Optional[str]:
def pickle_dump(data: object, path: str) -> None:
"""Pickle the given object.

Parameters
@@ -39,7 +38,6 @@ def pickle_dump(data: object, path: str) -> Optional[str]:
)
return None
logger.info(f"Successfully saved to {path}")
return path


def pickle_load(path: str) -> object:
@@ -61,4 +59,5 @@ def pickle_load(path: str) -> object:
data = pickle.load(f)
except Exception as e:
logger.error(f"❌ Loading data failed. Operation aborted. See info below:\n{e}")
return None
return data
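
For reference, a quick usage sketch of the updated helpers: pickle_dump no longer returns the saved path, and both functions log errors and return None on failure (assuming both are exported from pypots.data.saving):

from pypots.data.saving import pickle_dump, pickle_load

data = {"X": [[1.0, 2.0], [3.0, float("nan")]]}
pickle_dump(data, "serialized/example.pkl")       # logs success, now returns None
restored = pickle_load("serialized/example.pkl")  # the unpickled object, or None on failure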
69 changes: 20 additions & 49 deletions pypots/data/utils.py
@@ -5,14 +5,12 @@
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import math
from typing import Union

import benchpots
import numpy as np
import torch

from ..utils.logging import logger


def turn_data_into_specified_dtype(
data: Union[np.ndarray, torch.Tensor, list],
@@ -166,7 +164,11 @@ def parse_delta(
return delta


def sliding_window(time_series, window_len, sliding_len=None):
def sliding_window(
time_series: Union[np.ndarray, torch.Tensor],
window_len: int,
sliding_len: int = None,
) -> Union[np.ndarray, torch.Tensor]:
"""Generate time series samples with sliding window method, truncating windows from time-series data
with a given sequence length.

@@ -177,41 +179,27 @@ def sliding_window(time_series, window_len, sliding_len=None):

Parameters
----------
time_series : np.ndarray,
time_series :
time series data, len(shape)=2, [total_length, feature_num]

window_len : int,
window_len :
The length of the sliding window, i.e. the number of time steps in the generated data samples.

sliding_len : int, default = None,
sliding_len :
        The sliding length of the window for each moving step. It defaults to window_len if None.

Returns
-------
samples : np.ndarray,
samples :
        The generated time-series data samples, of shape [total_length//sliding_len, window_len, n_features].

"""
sliding_len = window_len if sliding_len is None else sliding_len
total_len = time_series.shape[0]
start_indices = np.asarray(range(total_len // sliding_len)) * sliding_len

# remove the last one if left length is not enough
if total_len - start_indices[-1] < window_len:
to_drop = math.ceil(window_len / sliding_len)
left_len = total_len - start_indices[-1]
start_indices = start_indices[:-to_drop]
logger.warning(
f"The last {to_drop} samples are dropped due to the left length {left_len} is not enough."
)

sample_collector = []
for idx in start_indices:
sample_collector.append(time_series[idx : idx + window_len])

samples = np.asarray(sample_collector).astype("float32")

return samples
return benchpots.utils.sliding_window(
time_series,
window_len,
sliding_len,
)


def inverse_sliding_window(X, sliding_len):
@@ -238,25 +226,8 @@ def inverse_sliding_window(X, sliding_len):
The restored time-series data with shape of [total_length, n_features].

"""
assert len(X.shape) == 3, f"X should be a 3D array, but got {X.shape}"
n_samples, window_size, n_features = X.shape

if sliding_len >= window_size:
if sliding_len > window_size:
logger.warning(
f"sliding_len {sliding_len} is larger than the window size {window_size}, "
f"hence there will be gaps between restored data."
)
restored_data = X.reshape(n_samples * window_size, n_features)
else:
collector = [X[0][:sliding_len]]
overlap = X[0][sliding_len:]
for x in X[1:]:
overlap_avg = (overlap + x[:-sliding_len]) / 2
collector.append(overlap_avg[:sliding_len])
overlap = np.concatenate(
[overlap_avg[sliding_len:], x[-sliding_len:]], axis=0
)
collector.append(overlap)
restored_data = np.concatenate(collector, axis=0)
return restored_data

return benchpots.utils.inverse_sliding_window(
X,
sliding_len,
)
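
Both helpers now delegate to benchpots.utils rather than keeping local implementations. A round-trip sketch, assuming the delegated functions keep the shapes documented above:

import numpy as np
from pypots.data.utils import sliding_window, inverse_sliding_window

series = np.random.randn(100, 7).astype("float32")          # [total_length, n_features]
samples = sliding_window(series, window_len=20)             # sliding_len defaults to window_len -> [5, 20, 7]
restored = inverse_sliding_window(samples, sliding_len=20)  # non-overlapping windows concatenate back -> [100, 7]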
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
@@ -40,6 +40,7 @@
from .locf import LOCF
from .mean import Mean
from .median import Median
from .lerp import Lerp

__all__ = [
# neural network imputation methods
@@ -76,4 +77,5 @@
"LOCF",
"Mean",
"Median",
"Lerp",
]
12 changes: 12 additions & 0 deletions pypots/imputation/lerp/__init__.py
@@ -0,0 +1,12 @@
"""
The package of the partially-observed time-series imputation method linear interpolation.
"""

# Created by Cole Sussmeier <colesussmeier@gmail.com>
# License: BSD-3-Clause

from .model import Lerp

__all__ = [
"Lerp",
]
161 changes: 161 additions & 0 deletions pypots/imputation/lerp/model.py
@@ -0,0 +1,161 @@
"""
The implementation of linear interpolation for the partially-observed time-series imputation task.
"""

# Created by Cole Sussmeier <colesussmeier@gmail.com>
# License: BSD-3-Clause

import warnings
from typing import Union, Optional

import h5py
import numpy as np
import torch

from ..base import BaseImputer


class Lerp(BaseImputer):
"""Linear interpolation (Lerp) imputation method.

Lerp will linearly interpolate missing values between the nearest non-missing values.
If there are missing values at the beginning or end of the series, they will be back-filled or
forward-filled with the nearest non-missing value, respectively.
    If a series has fewer than two observed values, its 'nan' values will be filled with zeros.
"""

def __init__(
self,
):
super().__init__()

def fit(
self,
train_set: Union[dict, str],
val_set: Optional[Union[dict, str]] = None,
file_type: str = "hdf5",
) -> None:
"""Train the imputer on the given data.

Warnings
--------
Linear interpolation class does not need to run fit().
Please run func ``predict()`` directly.
"""
warnings.warn(
"Linear interpolation class has no parameter to train. "
"Please run func `predict()` directly."
)

def predict(
self,
test_set: Union[dict, str],
file_type: str = "hdf5",
) -> dict:
"""Make predictions for the input data with the trained model.

Parameters
----------
        test_set : dict or str
            The dataset for model testing, should be a dictionary including the key 'X',
            or a path string locating a data file supported by PyPOTS (e.g. an h5 file).
            If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
            which is time-series data that may contain missing values.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict and has to include the key 'X'.

file_type :
The type of the given file if test_set is a path string.

Returns
-------
        result_dict : dict
            Prediction results in a Python dictionary for the given samples.
            It includes one key per task the model supports; for Lerp, only 'imputation' is returned.
"""
if isinstance(test_set, str):
with h5py.File(test_set, "r") as f:
X = f["X"][:]
else:
X = test_set["X"]

        # convert lists first, since Python lists have no .shape attribute
        if isinstance(X, list):
            X = np.asarray(X)

        assert len(X.shape) == 3, (
            f"Input X should have 3 dimensions [n_samples, n_steps, n_features], "
            f"but the actual shape of X: {X.shape}"
        )

def _interpolate_missing_values(X: np.ndarray):
nans = np.isnan(X)
nan_index = np.where(nans)[0]
index = np.where(~nans)[0]
if np.any(nans) and index.size > 1:
X[nans] = np.interp(nan_index, index, X[~nans])
elif np.any(nans):
X[nans] = 0
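            # note: np.interp clamps queries outside the observed index range to the
            # edge values, which produces the back-/forward-fill behavior described
            # in the class docstring, e.g. [1.0, nan, 3.0, nan] -> [1.0, 2.0, 3.0, 3.0]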

if isinstance(X, np.ndarray):

trans_X = X.transpose((0, 2, 1))
n_samples, n_features, n_steps = trans_X.shape
reshaped_X = np.reshape(trans_X, (-1, n_steps))
imputed_X = np.ones(reshaped_X.shape)

for i, univariate_series in enumerate(reshaped_X):
t = np.copy(univariate_series)
_interpolate_missing_values(t)
imputed_X[i] = t

imputed_trans_X = np.reshape(imputed_X, (n_samples, n_features, -1))
imputed_data = imputed_trans_X.transpose((0, 2, 1))

elif isinstance(X, torch.Tensor):

trans_X = X.permute(0, 2, 1)
n_samples, n_features, n_steps = trans_X.shape
reshaped_X = trans_X.reshape(-1, n_steps)
imputed_X = torch.ones_like(reshaped_X)

for i, univariate_series in enumerate(reshaped_X):
t = univariate_series.clone().cpu().detach().numpy()
_interpolate_missing_values(t)
imputed_X[i] = torch.from_numpy(t)

imputed_trans_X = imputed_X.reshape(n_samples, n_features, -1)
imputed_data = imputed_trans_X.permute(0, 2, 1)

        else:
            raise ValueError(f"X should be np.ndarray or torch.Tensor, but got {type(X)}")

result_dict = {
"imputation": imputed_data,
}
return result_dict

def impute(
self,
test_set: Union[dict, str],
file_type: str = "hdf5",
) -> np.ndarray:
"""Impute missing values in the given data with the trained model.

Parameters
----------
test_set :
The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
n_features], or a path string locating a data file, e.g. h5 file.

file_type :
The type of the given file if X is a path string.

Returns
-------
array-like, shape [n_samples, sequence length (n_steps), n_features],
Imputed data.
"""

result_dict = self.predict(test_set, file_type=file_type)
return result_dict["imputation"]
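
Putting the new model to work, a minimal end-to-end sketch using the package-level export added above (the random data is only for illustration):

import numpy as np
from pypots.imputation import Lerp

X = np.random.randn(10, 24, 5).astype("float32")  # [n_samples, n_steps, n_features]
X[X < -1.0] = np.nan                              # artificially inject missing values

lerp = Lerp()                      # no fit() needed; Lerp has no trainable parameters
X_imputed = lerp.impute({"X": X})  # -> [10, 24, 5] with NaNs linearly interpolated
assert not np.isnan(X_imputed).any()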
2 changes: 1 addition & 1 deletion requirements.txt
@@ -13,4 +13,4 @@ scikit-learn
torch>=1.10.0
tsdb>=0.4
pygrinder>=0.6
benchpots>=0.1
benchpots>=0.2
2 changes: 1 addition & 1 deletion setup.cfg
@@ -37,7 +37,7 @@ basic =
torch>=1.10.0
tsdb>=0.4
pygrinder>=0.6
benchpots>=0.1
benchpots>=0.2

# dependencies that are optional, torch-geometric are only needed for model Raindrop
# but its installation takes too much time