Commit
refactor: code refactoring;
WenjieDu committed Oct 19, 2023
1 parent 80c1479 commit 58a09ea
Showing 5 changed files with 60 additions and 26 deletions.
11 changes: 6 additions & 5 deletions pygrinder/__init__.py
@@ -21,10 +21,11 @@
#
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
__version__ = "0.1"
__version__ = "0.1.1"

try:
from pygrinder.mcar import mcar
from pygrinder.mar import mar_logistic
from pygrinder.mnar import mnar_x, mnar_t
from pygrinder.utils import (
cal_missing_rate,
@@ -33,12 +34,12 @@
except Exception as e:
print(e)


__all__ = [
"__version__",
"mcar",
"mar_logistic",
"mnar_x",
"mnar_t",
"cal_missing_rate",
"masked_fill",
"mcar",
"mnar_x"
"mnar_t"
]
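With this change, mar_logistic joins the package's public API and the version is bumped to 0.1.1. A quick import check based on the updated __all__ (illustrative only, not part of the commit):

from pygrinder import (
    __version__,
    mcar,
    mar_logistic,
    mnar_x,
    mnar_t,
    cal_missing_rate,
    masked_fill,
)

print(__version__)  # "0.1.1" after this commit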
6 changes: 5 additions & 1 deletion pygrinder/mar.py
@@ -8,9 +8,13 @@
from typing import Union, Tuple

import numpy as np
import torch
from scipy import optimize

try:
import torch
except ImportError:
pass


def mar_logistic(
X: Union[torch.Tensor, np.ndarray],
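mar.py now guards its torch import with try/except, matching mcar.py and utils.py, so torch stays an optional dependency. A minimal standalone sketch of that pattern (the torch = None fallback and the _to_numpy helper are illustrative, not code from this commit):

import numpy as np

try:
    import torch
except ImportError:
    torch = None  # torch is optional; numpy-only code paths still work

def _to_numpy(x):
    # Accept either a torch.Tensor or anything numpy can wrap, as the
    # Union[torch.Tensor, np.ndarray] annotation on mar_logistic suggests.
    if torch is not None and isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)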
29 changes: 23 additions & 6 deletions pygrinder/mcar.py
@@ -5,6 +5,8 @@
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: GPL-v3

from typing import Union, Tuple

import numpy as np

try:
@@ -13,12 +15,19 @@
pass


def mcar(X, p, nan=0):
def mcar(
X: Union[np.ndarray, torch.Tensor],
p: float,
nan: Union[float, int] = 0,
) -> Union[
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""Create completely random missing values (MCAR case).
Parameters
----------
X : array,
X :
Data vector. If X has any missing values, they should be numpy.nan.
p : float, in (0,1),
@@ -67,12 +76,16 @@ def mcar(X, p, nan=0):
return X_intact, X, missing_mask, indicating_mask


def _mcar_numpy(X: np.ndarray, rate: float, nan: float = 0):
def _mcar_numpy(
X: np.ndarray,
p: float,
nan: Union[float, int] = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
# clone X to ensure values of X out of this function not being affected
X = np.copy(X)

X_intact = np.copy(X) # keep a copy of originally observed values in X_intact
mcar_missing_mask = np.asarray(np.random.rand(np.product(X.shape)) < rate)
mcar_missing_mask = np.asarray(np.random.rand(np.product(X.shape)) < p)
mcar_missing_mask = mcar_missing_mask.reshape(X.shape)
X[mcar_missing_mask] = np.nan # mask values selected by mcar_missing_mask
indicating_mask = ((~np.isnan(X_intact)) ^ (~np.isnan(X))).astype(np.float32)
@@ -82,12 +95,16 @@ def _mcar_numpy(X: np.ndarray, rate: float, nan: float = 0):
return X_intact, X, missing_mask, indicating_mask


def _mcar_torch(X, rate: float, nan: float = 0):
def _mcar_torch(
X: torch.Tensor,
p: float,
nan: Union[float, int] = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# clone X to ensure values of X out of this function not being affected
X = torch.clone(X)

X_intact = torch.clone(X) # keep a copy of originally observed values in X_intact
mcar_missing_mask = torch.rand(X.shape) < rate
mcar_missing_mask = torch.rand(X.shape) < p
X[mcar_missing_mask] = torch.nan # mask values selected by mcar_missing_mask
indicating_mask = ((~torch.isnan(X_intact)) ^ (~torch.isnan(X))).type(torch.float32)
missing_mask = (~torch.isnan(X)).type(torch.float32)
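Per the annotated signature above, mcar takes an array or tensor X, a missing rate p, and a fill value nan, and returns (X_intact, X, missing_mask, indicating_mask). A short usage sketch with toy data (values illustrative):

import numpy as np
from pygrinder import mcar

X = np.random.randn(128, 10)  # fully observed toy data
X_intact, X_corrupted, missing_mask, indicating_mask = mcar(X, p=0.1)

# missing_mask is 1 where a value is still observed and 0 where it was removed,
# so for p=0.1 roughly 10% of its entries should be 0.
print(1 - missing_mask.mean())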
2 changes: 1 addition & 1 deletion pygrinder/mnar.py
@@ -25,7 +25,7 @@ def mnar_x(
]:
"""Create not-random missing values related to values themselves (MNAR-x case ot self-masking MNAR case).
This case follows the setting in Ipsen et al. "not-MIWAE: Deep Generative Modelling with Missing Not at Random Data"
:cite`ipsen2021notmiwae`.
:cite:`ipsen2021notmiwae`.
Parameters
----------
38 changes: 25 additions & 13 deletions pygrinder/utils.py
@@ -5,6 +5,8 @@
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: GPL-v3

from typing import Union

import numpy as np

try:
@@ -13,18 +15,19 @@
pass


def cal_missing_rate(X):
def cal_missing_rate(X: Union[np.ndarray, torch.Tensor]) -> float:
"""Calculate the originally missing rate of the raw data.
Parameters
----------
X : array-like,
X:
Data array that may contain missing values.
Returns
-------
originally_missing_rate, float,
The originally missing rate of the raw data.
originally_missing_rate,
The originally missing rate of the raw data. Its value should be in the range [0,1].
"""
if isinstance(X, list):
X = np.asarray(X)
@@ -42,24 +45,29 @@ def cal_missing_rate(X):
return originally_missing_rate


def masked_fill(X, mask, val):
def masked_fill(
X: Union[np.ndarray, torch.Tensor],
mask: Union[np.ndarray, torch.Tensor],
val: float,
) -> Union[np.ndarray, torch.Tensor]:
"""Like torch.Tensor.masked_fill(), fill elements in given `X` with `val` where `mask` is True.
Parameters
----------
X : array-like,
X:
The data vector.
mask : array-like,
mask:
The boolean mask.
val : float
val:
The value to fill in with.
Returns
-------
array,
mask
filled_X:
Mask filled X.
"""
assert X.shape == mask.shape, (
"Shapes of X and mask must match, "
@@ -74,14 +82,18 @@ def masked_fill(X, mask, val):
mask = np.asarray(mask)

if isinstance(X, np.ndarray):
filled_X = X.copy()
mask = mask.copy()
mask = mask.astype(bool)
X[mask] = val
filled_X[mask] = val
elif isinstance(X, torch.Tensor):
filled_X = torch.clone(X)
mask = torch.clone(mask)
mask = mask.type(torch.bool)
X[mask] = val
filled_X[mask] = val
else:
raise TypeError(
"X must be type of list/numpy.ndarray/torch.Tensor, " f"but got {type(X)}"
)

return X
return filled_X
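After this commit masked_fill no longer mutates its input: it copies X (and mask) and returns filled_X. A small sketch combining the two utilities (toy values, illustrative only):

import numpy as np
from pygrinder import cal_missing_rate, masked_fill

X = np.array([[1.0, np.nan, 3.0],
              [np.nan, 5.0, 6.0]])

print(cal_missing_rate(X))  # 2 of 6 entries are NaN, so about 0.33

mask = np.isnan(X)
X_filled = masked_fill(X, mask, 0.0)  # returns a filled copy; X keeps its NaNs
print(X_filled)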
