Skip to content

Commit

Permalink
Add docstrings for methods in OptiMask class to improve code document…
Browse files Browse the repository at this point in the history
…ation
  • Loading branch information
CyrilJl committed Dec 25, 2024
1 parent 6139340 commit 0e072ea
Showing 1 changed file with 56 additions and 5 deletions.
61 changes: 56 additions & 5 deletions optimask/_optimask.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ def _verbose(self, msg):
@staticmethod
@njit(uint32[:](uint32[:], uint32[:], uint32), boundscheck=False)
def groupby_max(a, b, n):
"""
Numba equivalent of:
size_a = len(a)
ret = np.zeros(n, dtype=np.uint32)
np.maximum.at(ret, a, b + 1)
return ret
"""
size_a = len(a)
ret = np.zeros(n, dtype=np.uint32)
for k in range(size_a):
Expand Down Expand Up @@ -83,6 +90,17 @@ def numba_apply_permutation_inplace(p, x):

@classmethod
def apply_permutation(cls, p, x, inplace: bool):
"""
Applies a permutation to an array.
Args:
p (np.ndarray): The permutation array.
x (np.ndarray): The array to be permuted.
inplace (bool): If True, applies the permutation in place; otherwise, returns a new permuted array.
Returns:
np.ndarray: The permuted array if inplace is False; otherwise, None.
"""
if inplace:
cls.numba_apply_permutation_inplace(p, x)
else:
Expand All @@ -108,6 +126,18 @@ def _get_largest_rectangle(heights, m, n):
@staticmethod
@njit(boundscheck=False)
def _preprocess(x):
"""
Preprocesses the input array to identify rows and columns containing NaNs.
Args:
x (np.ndarray): The input 2D array with NaN values.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
- iy, ix: Row and column indices of the NaN entries, i.e. the result of `np.isnan(x).nonzero()`.
- rows_with_nan: Rows that contain NaNs.
- cols_with_nan: Columns that contain NaNs.
"""
m, n = x.shape
iy, ix = np.empty(m*n, dtype=np.uint32), np.empty(m*n, dtype=np.uint32)
cols_index_mapper = -np.ones(n, dtype=np.int32)
Expand Down Expand Up @@ -175,7 +205,16 @@ def _trial(self, rng, m_nan, n_nan, iy, ix, m, n):
@njit(uint32[:](uint32, uint32[:], uint32[:], uint32), boundscheck=False)
def compute_to_keep(size, index_with_nan, permutation, split):
"""
Faster version of `np.setdiff1d(np.arange(size, dtype=np.uint32), index_with_nan[permutation[:split]])`.
Computes the indices to keep after removing a subset of indices with NaNs.
Args:
size (int): The total number of indices.
index_with_nan (np.ndarray): The indices that contain NaNs.
permutation (np.ndarray): The permutation array.
split (int): The split point in the permutation array.
Returns:
np.ndarray: The indices to keep after removing the subset with NaNs.
"""
mask = np.zeros(size, dtype=np.uint8)
for i in range(split):
Expand Down Expand Up @@ -209,7 +248,7 @@ def _solve(self, x):
if m <= n:
return np.array([]), np.arange(n)
else:
return np.arange(m), np.array([])
return np.arange(m), np.array([], dtype=np.uint32)

if len(rows_with_nan) == 1:
if n-n_nan <= n_nan*(m-m_nan):
Expand Down Expand Up @@ -243,6 +282,17 @@ def _solve(self, x):
@staticmethod
@njit(boundscheck=False)
def has_nan_in_subset(X, rows, cols):
"""
Checks if there are any NaN values in the specified subset of the array.
Args:
X (np.ndarray): The input 2D array.
rows (np.ndarray): The row indices of the subset.
cols (np.ndarray): The column indices of the subset.
Returns:
bool: True if there are NaN values in the subset, False otherwise.
"""
for i in range(rows.size):
for j in range(cols.size):
if np.isnan(X[rows[i], cols[j]]):
Expand All @@ -256,11 +306,12 @@ def solve(self, X: Union[np.ndarray, pd.DataFrame], return_data: bool = False, c
Args:
X (Union[np.ndarray, pd.DataFrame]): The input 2D array or DataFrame with NaN values.
return_data (bool): If True, returns the resulting data; otherwise, returns the indices.
check_result (bool): If True, checks if the computed submatrix contains NaNs. Disabled by default
as it can slow down the computation and the algorithm has proven to be reliable.
check_result (bool): If True, checks if the computed submatrix contains NaNs, for testing purposes.
Disabled by default as it can slow down the computation and the algorithm has proven to be reliable.
Returns:
Union[Tuple[np.ndarray, np.ndarray], Tuple[pd.Index, pd.Index]]: If return_data is True, returns the resulting 2D array or DataFrame; otherwise, returns the indices of rows and columns to retain.
Union[Tuple[np.ndarray, np.ndarray], Tuple[pd.Index, pd.Index]]: If return_data is True, returns
the resulting 2D array or DataFrame; otherwise, returns the indices of rows and columns to retain.
Raises:
InvalidDimensionError: If the input numpy array does not have ndim==2.
Expand Down

0 comments on commit 0e072ea

Please sign in to comment.