From fde971a319ff7e0ca3650c121822400c40822757 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20J=C3=A4ger?= Date: Thu, 5 Dec 2024 11:55:36 +0100 Subject: [PATCH] reuse util to seed error_mechanism --- tab_err/error_mechanism/_ear.py | 6 ++--- tab_err/error_mechanism/_ecar.py | 4 +--- tab_err/error_mechanism/_enar.py | 7 +----- tab_err/error_mechanism/_error_mechanism.py | 26 +++++++++++++-------- 4 files changed, 20 insertions(+), 23 deletions(-) diff --git a/tab_err/error_mechanism/_ear.py b/tab_err/error_mechanism/_ear.py index f1eab34..5f8cfe3 100644 --- a/tab_err/error_mechanism/_ear.py +++ b/tab_err/error_mechanism/_ear.py @@ -3,8 +3,6 @@ import warnings from typing import TYPE_CHECKING -import numpy as np - from tab_err._utils import get_column, get_column_str from ._error_mechanism import ErrorMechanism @@ -23,7 +21,7 @@ def _sample(self: EAR, data: pd.DataFrame, column: str | int, error_rate: float, if self.condition_to_column is None: col = get_column_str(data, column) column_selection = [x for x in data.columns if x != col] - condition_to_column = np.random.default_rng(self.seed).choice(column_selection) + condition_to_column = self._random_generator.choice(column_selection) warnings.warn( "The user did not specify 'condition_to_column', the column on which the EAR Mechanism conditions the error distribution. " + f"Randomly select column '{condition_to_column}'.", @@ -46,7 +44,7 @@ def _sample(self: EAR, data: pd.DataFrame, column: str | int, error_rate: float, # we offset the upper bound of the lower_error_index by a) the existing number of errors in the row, and b) the number of errors to-be generated. upper_bound = len(se_data) - sum(se_mask) - n_errors - lower_error_index = np.random.default_rng(self.seed).integers(0, upper_bound) if upper_bound > 0 else 0 + lower_error_index = self._random_generator.integers(0, upper_bound) if upper_bound > 0 else 0 error_index_range = range(lower_error_index, lower_error_index + n_errors) selected_rows = data_column_error_free.sort_values(by=condition_to_column).iloc[error_index_range, :] diff --git a/tab_err/error_mechanism/_ecar.py b/tab_err/error_mechanism/_ecar.py index 0257207..c481da4 100644 --- a/tab_err/error_mechanism/_ecar.py +++ b/tab_err/error_mechanism/_ecar.py @@ -3,8 +3,6 @@ import warnings from typing import TYPE_CHECKING -import numpy as np - from tab_err._utils import get_column from ._error_mechanism import ErrorMechanism @@ -30,6 +28,6 @@ def _sample(self: ECAR, data: pd.DataFrame, column: str | int, error_rate: float raise ValueError(msg) # randomly choose error-cells - error_indices = np.random.default_rng(seed=self.seed).choice(se_mask_error_free.index, n_errors, replace=False) + error_indices = self._random_generator.choice(se_mask_error_free.index, n_errors, replace=False) se_mask[error_indices] = True return error_mask diff --git a/tab_err/error_mechanism/_enar.py b/tab_err/error_mechanism/_enar.py index b2110cb..c16dad7 100644 --- a/tab_err/error_mechanism/_enar.py +++ b/tab_err/error_mechanism/_enar.py @@ -3,8 +3,6 @@ import warnings from typing import TYPE_CHECKING -import numpy as np - from tab_err._utils import get_column from ._error_mechanism import ErrorMechanism @@ -33,10 +31,7 @@ def _sample(self: ENAR, data: pd.DataFrame, column: str | int, error_rate: float msg += f"However, only {len(se_data_error_free)} error-free cells are available." raise ValueError(msg) - if len(se_data_error_free) != n_errors: # noqa: SIM108 - lower_error_index = np.random.default_rng(seed=self.seed).integers(0, len(se_data_error_free) - n_errors) - else: - lower_error_index = 0 + lower_error_index = self._random_generator.integers(0, len(se_data_error_free) - n_errors) if len(se_data_error_free) != n_errors else 0 error_index_range = range(lower_error_index, lower_error_index + n_errors) selected_rows = se_data_error_free.sort_values().iloc[error_index_range] diff --git a/tab_err/error_mechanism/_error_mechanism.py b/tab_err/error_mechanism/_error_mechanism.py index 43209a9..a4d49ff 100644 --- a/tab_err/error_mechanism/_error_mechanism.py +++ b/tab_err/error_mechanism/_error_mechanism.py @@ -1,14 +1,26 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import TYPE_CHECKING import pandas as pd +from tab_err._utils import seed_randomness + +if TYPE_CHECKING: + import numpy as np + class ErrorMechanism(ABC): def __init__(self: ErrorMechanism, condition_to_column: int | str | None = None, seed: int | None = None) -> None: + if not (isinstance(seed, int) or seed is None): + msg = "'seed' need to be int or None." + raise TypeError(msg) + self.condition_to_column = condition_to_column - self.seed = seed + + self._seed = seed + self._random_generator: np.random.Generator def sample( self: ErrorMechanism, @@ -22,17 +34,10 @@ def sample( error_rate_msg = "'error_rate' need to be float: 0 <= error_rate <= 1." raise ValueError(error_rate_msg) - if not (isinstance(self.seed, int) or self.seed is None): - msg = "'seed' need to be int or None." - raise TypeError(msg) - - data_msg = "'data' needs to be a non-empty DataFrame." - if not isinstance(data, pd.DataFrame): + if not isinstance(data, pd.DataFrame) or data.empty: + data_msg = "'data' needs to be a non-empty DataFrame." raise TypeError(data_msg) - if data.empty: - raise ValueError(data_msg) - # At least two columns are necessary if we condition to another if self.condition_to_column is not None and len(data.columns) < 2: # noqa: PLR2004 msg = "'data' need at least 2 columns if 'condition_to_column' is given." @@ -45,6 +50,7 @@ def sample( if error_mask is None: # initialize empty error_mask error_mask = pd.DataFrame(data=False, index=data.index, columns=data.columns) + self._random_generator = seed_randomness(self._seed) return self._sample(data, column, error_rate, error_mask) @abstractmethod