Skip to content

Commit

Permalink
reuse util to seed error_mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
se-jaeger committed Dec 5, 2024
1 parent aff5255 commit fde971a
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 23 deletions.
6 changes: 2 additions & 4 deletions tab_err/error_mechanism/_ear.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import warnings
from typing import TYPE_CHECKING

import numpy as np

from tab_err._utils import get_column, get_column_str

from ._error_mechanism import ErrorMechanism
Expand All @@ -23,7 +21,7 @@ def _sample(self: EAR, data: pd.DataFrame, column: str | int, error_rate: float,
if self.condition_to_column is None:
col = get_column_str(data, column)
column_selection = [x for x in data.columns if x != col]
condition_to_column = np.random.default_rng(self.seed).choice(column_selection)
condition_to_column = self._random_generator.choice(column_selection)
warnings.warn(
"The user did not specify 'condition_to_column', the column on which the EAR Mechanism conditions the error distribution. "
+ f"Randomly select column '{condition_to_column}'.",
Expand All @@ -46,7 +44,7 @@ def _sample(self: EAR, data: pd.DataFrame, column: str | int, error_rate: float,

# We offset the upper bound of lower_error_index by a) the number of errors already present in the column, and b) the number of errors to be generated.
upper_bound = len(se_data) - sum(se_mask) - n_errors
lower_error_index = np.random.default_rng(self.seed).integers(0, upper_bound) if upper_bound > 0 else 0
lower_error_index = self._random_generator.integers(0, upper_bound) if upper_bound > 0 else 0
error_index_range = range(lower_error_index, lower_error_index + n_errors)
selected_rows = data_column_error_free.sort_values(by=condition_to_column).iloc[error_index_range, :]

Expand Down
4 changes: 1 addition & 3 deletions tab_err/error_mechanism/_ecar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import warnings
from typing import TYPE_CHECKING

import numpy as np

from tab_err._utils import get_column

from ._error_mechanism import ErrorMechanism
Expand All @@ -30,6 +28,6 @@ def _sample(self: ECAR, data: pd.DataFrame, column: str | int, error_rate: float
raise ValueError(msg)

# randomly choose error-cells
error_indices = np.random.default_rng(seed=self.seed).choice(se_mask_error_free.index, n_errors, replace=False)
error_indices = self._random_generator.choice(se_mask_error_free.index, n_errors, replace=False)
se_mask[error_indices] = True
return error_mask
7 changes: 1 addition & 6 deletions tab_err/error_mechanism/_enar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import warnings
from typing import TYPE_CHECKING

import numpy as np

from tab_err._utils import get_column

from ._error_mechanism import ErrorMechanism
Expand Down Expand Up @@ -33,10 +31,7 @@ def _sample(self: ENAR, data: pd.DataFrame, column: str | int, error_rate: float
msg += f"However, only {len(se_data_error_free)} error-free cells are available."
raise ValueError(msg)

if len(se_data_error_free) != n_errors: # noqa: SIM108
lower_error_index = np.random.default_rng(seed=self.seed).integers(0, len(se_data_error_free) - n_errors)
else:
lower_error_index = 0
lower_error_index = self._random_generator.integers(0, len(se_data_error_free) - n_errors) if len(se_data_error_free) != n_errors else 0

error_index_range = range(lower_error_index, lower_error_index + n_errors)
selected_rows = se_data_error_free.sort_values().iloc[error_index_range]
Expand Down
26 changes: 16 additions & 10 deletions tab_err/error_mechanism/_error_mechanism.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import pandas as pd

from tab_err._utils import seed_randomness

if TYPE_CHECKING:
import numpy as np


class ErrorMechanism(ABC):
def __init__(self: ErrorMechanism, condition_to_column: int | str | None = None, seed: int | None = None) -> None:
if not (isinstance(seed, int) or seed is None):
msg = "'seed' need to be int or None."
raise TypeError(msg)

self.condition_to_column = condition_to_column
self.seed = seed

self._seed = seed
self._random_generator: np.random.Generator

def sample(
self: ErrorMechanism,
Expand All @@ -22,17 +34,10 @@ def sample(
error_rate_msg = "'error_rate' need to be float: 0 <= error_rate <= 1."
raise ValueError(error_rate_msg)

if not (isinstance(self.seed, int) or self.seed is None):
msg = "'seed' need to be int or None."
raise TypeError(msg)

data_msg = "'data' needs to be a non-empty DataFrame."
if not isinstance(data, pd.DataFrame):
if not isinstance(data, pd.DataFrame) or data.empty:
data_msg = "'data' needs to be a non-empty DataFrame."
raise TypeError(data_msg)

if data.empty:
raise ValueError(data_msg)

# At least two columns are necessary if we condition on another column
if self.condition_to_column is not None and len(data.columns) < 2: # noqa: PLR2004
msg = "'data' need at least 2 columns if 'condition_to_column' is given."
Expand All @@ -45,6 +50,7 @@ def sample(
if error_mask is None: # initialize empty error_mask
error_mask = pd.DataFrame(data=False, index=data.index, columns=data.columns)

self._random_generator = seed_randomness(self._seed)
return self._sample(data, column, error_rate, error_mask)

@abstractmethod
Expand Down

0 comments on commit fde971a

Please sign in to comment.