Skip to content

Commit

Permalink
store random seed as an object
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Feb 17, 2022
1 parent c70a35d commit 0383db2
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 49 deletions.
37 changes: 27 additions & 10 deletions copulas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,14 @@ def set_random_state(random_state, set_model_random_state):
"""
original_state = np.random.get_state()

if isinstance(random_state, int):
random_state = np.random.RandomState(seed=random_state).get_state()
elif isinstance(random_state, np.random.RandomState):
random_state = random_state.get_state()
elif not isinstance(random_state, tuple):
raise TypeError(f'RandomState {random_state} is an unexpected type. '
'Expected to be int, np.random.RandomState, or tuple.')

np.random.set_state(random_state)
np.random.set_state(random_state.get_state())

try:
yield
finally:
set_model_random_state(np.random.get_state())
current_random_state = np.random.RandomState()
current_random_state.set_state(np.random.get_state())
set_model_random_state(current_random_state)
np.random.set_state(original_state)


Expand All @@ -68,6 +62,29 @@ def wrapper(self, *args, **kwargs):
return wrapper


def validate_random_state(random_state):
    """Normalize a ``random_state`` argument into a ``RandomState`` object.

    Integer seeds, including NumPy integer scalars such as ``np.int64``,
    are turned into a freshly seeded ``numpy.random.RandomState``. An
    existing ``RandomState`` is returned as is (the same object, not a
    copy), and ``None`` is passed through unchanged.

    Args:
        random_state (int, numpy.integer, numpy.random.RandomState, or None):
            Seed or RandomState for the random generator.

    Returns:
        numpy.random.RandomState or None:
            The RandomState built from (or equal to) ``random_state``,
            or ``None`` if ``random_state`` is ``None``.

    Raises:
        TypeError:
            If ``random_state`` is not an integer, a
            ``numpy.random.RandomState`` or ``None``.
    """
    if random_state is None:
        return None

    if isinstance(random_state, np.random.RandomState):
        return random_state

    # ``np.integer`` covers NumPy scalar seeds (e.g. values taken out of an
    # array), which are not instances of the builtin ``int``.
    if isinstance(random_state, (int, np.integer)):
        return np.random.RandomState(seed=int(random_state))

    raise TypeError(
        f'`random_state` {random_state} expected to be an int '
        'or `np.random.RandomState` object.')


def get_instance(obj, **kwargs):
"""Create new instance of the ``obj`` argument.
Expand Down
6 changes: 3 additions & 3 deletions copulas/bivariate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from scipy import stats
from scipy.optimize import brentq

from copulas import EPSILON, NotFittedError, random_state
from copulas import EPSILON, NotFittedError, random_state, validate_random_state
from copulas.bivariate.utils import split_matrix


Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(self, copula_type=None, random_state=None):
random_state (int, np.random.RandomState, or None): Seed or RandomState
for the random generator.
"""
self.random_state = random_state
self.random_state = validate_random_state(random_state)

def check_theta(self):
"""Validate the computed theta against the copula specification.
Expand Down Expand Up @@ -352,7 +352,7 @@ def set_random_state(self, random_state):
random_state (int, np.random.RandomState, or None): Seed or RandomState
for the random generator.
"""
self.random_state = random_state
self.random_state = validate_random_state(random_state)

@random_state
def sample(self, n_samples):
Expand Down
20 changes: 10 additions & 10 deletions copulas/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd
from scipy import stats

from copulas import set_random_state
from copulas import set_random_state, validate_random_state


def _dummy_fn(state):
Expand All @@ -27,7 +27,7 @@ def sample_bivariate_age_income(size=1000, seed=42):
pandas.DataFrame:
DataFrame with two columns, ``age`` and ``income``.
"""
with set_random_state(seed, _dummy_fn):
with set_random_state(validate_random_state(seed), _dummy_fn):
age = stats.beta.rvs(a=2.0, b=6.0, loc=18, scale=100, size=size)
income = np.log(age) * 100
income += np.random.normal(loc=np.log(age) / 100, scale=10, size=size)
Expand Down Expand Up @@ -58,7 +58,7 @@ def sample_trivariate_xyz(size=1000, seed=42):
pandas.DataFrame:
DataFrame with three columns, ``x``, ``y`` and ``z``.
"""
with set_random_state(seed, _dummy_fn):
with set_random_state(validate_random_state(seed), _dummy_fn):
x = stats.beta.rvs(a=0.1, b=0.1, size=size)
y = stats.beta.rvs(a=0.1, b=0.5, size=size)
return pd.DataFrame({
Expand All @@ -84,7 +84,7 @@ def sample_univariate_bernoulli(size=1000, seed=42):
pandas.Series:
Series with the sampled values.
"""
with set_random_state(seed, _dummy_fn):
with set_random_state(validate_random_state(seed), _dummy_fn):
return pd.Series(np.random.random(size=size) < 0.3).astype(float)


Expand All @@ -104,7 +104,7 @@ def sample_univariate_bimodal(size=1000, seed=42):
pandas.Series:
Series with the sampled values.
"""
with set_random_state(seed, _dummy_fn):
with set_random_state(validate_random_state(seed), _dummy_fn):
bernoulli = sample_univariate_bernoulli(size, seed)
mode1 = np.random.normal(size=size) * bernoulli
mode2 = np.random.normal(size=size, loc=10) * (1.0 - bernoulli)
Expand All @@ -125,7 +125,7 @@ def sample_univariate_uniform(size=1000, seed=42):
pandas.Series:
Series with the sampled values.
"""
with set_random_state(seed, _dummy_fn):
with set_random_state(validate_random_state(seed), _dummy_fn):
return pd.Series(4.0 * np.random.random(size=size) - 1.0)


Expand All @@ -142,7 +142,7 @@ def sample_univariate_normal(size=1000, seed=42):
pandas.Series:
Series with the sampled values.
"""
with set_random_state(seed, _dummy_fn):
with set_random_state(validate_random_state(seed), _dummy_fn):
return pd.Series(np.random.normal(size=size, loc=1.0))


Expand All @@ -159,7 +159,7 @@ def sample_univariate_degenerate(size=1000, seed=42):
pandas.Series:
Series with the sampled values.
"""
with set_random_state(seed, _dummy_fn):
with set_random_state(validate_random_state(seed), _dummy_fn):
return pd.Series(np.full(size, np.random.random()))


Expand All @@ -176,7 +176,7 @@ def sample_univariate_exponential(size=1000, seed=42):
pandas.Series:
Series with the sampled values.
"""
with set_random_state(seed, _dummy_fn):
with set_random_state(validate_random_state(seed), _dummy_fn):
return pd.Series(np.random.exponential(size=size) + 3.0)


Expand All @@ -193,7 +193,7 @@ def sample_univariate_beta(size=1000, seed=42):
pandas.Series:
Series with the sampled values.
"""
with set_random_state(seed, _dummy_fn):
with set_random_state(validate_random_state(seed), _dummy_fn):
return pd.Series(stats.beta.rvs(a=3, b=1, loc=4, size=size))


Expand Down
6 changes: 3 additions & 3 deletions copulas/multivariate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from copulas import NotFittedError, get_instance
from copulas import NotFittedError, get_instance, validate_random_state


class Multivariate(object):
Expand All @@ -13,7 +13,7 @@ class Multivariate(object):
fitted = False

def __init__(self, random_state=None):
self.random_state = random_state
self.random_state = validate_random_state(random_state)

def fit(self, X):
"""Fit the model to table with values from multiple random variables.
Expand Down Expand Up @@ -116,7 +116,7 @@ def set_random_state(self, random_state):
random_state (int, np.random.RandomState, or None):
Seed or RandomState for the random generator.
"""
self.random_state = random_state
self.random_state = validate_random_state(random_state)

def sample(self, num_rows=1):
"""Sample values from this model.
Expand Down
5 changes: 3 additions & 2 deletions copulas/multivariate/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from scipy import stats

from copulas import (
EPSILON, check_valid_values, get_instance, get_qualified_name, random_state, store_args)
EPSILON, check_valid_values, get_instance, get_qualified_name, random_state, store_args,
validate_random_state)
from copulas.multivariate.base import Multivariate
from copulas.univariate import Univariate

Expand All @@ -33,7 +34,7 @@ class GaussianMultivariate(Multivariate):

@store_args
def __init__(self, distribution=DEFAULT_DISTRIBUTION, random_state=None):
self.random_state = random_state
self.random_state = validate_random_state(random_state)
self.distribution = distribution

def __repr__(self):
Expand Down
6 changes: 4 additions & 2 deletions copulas/multivariate/vine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import numpy as np
import pandas as pd

from copulas import EPSILON, check_valid_values, get_qualified_name, random_state, store_args
from copulas import (
EPSILON, check_valid_values, get_qualified_name, random_state, store_args,
validate_random_state)
from copulas.bivariate.base import Bivariate, CopulaTypes
from copulas.multivariate.base import Multivariate
from copulas.multivariate.tree import Tree, get_tree
Expand Down Expand Up @@ -73,7 +75,7 @@ def __init__(self, vine_type, random_state=None):
'produce wrong results. Please use Python 3.5, 3.6 or 3.7'
)

self.random_state = random_state
self.random_state = validate_random_state(random_state)
self.vine_type = vine_type
self.u_matrix = None

Expand Down
10 changes: 6 additions & 4 deletions copulas/univariate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

import numpy as np

from copulas import NotFittedError, get_instance, get_qualified_name, random_state, store_args
from copulas import (
NotFittedError, get_instance, get_qualified_name, random_state, store_args,
validate_random_state)
from copulas.univariate.selection import select_univariate


Expand Down Expand Up @@ -85,7 +87,7 @@ def _select_candidates(cls, parametric=None, bounded=None):
def __init__(self, candidates=None, parametric=None, bounded=None, random_state=None,
selection_sample_size=None):
self.candidates = candidates or self._select_candidates(parametric, bounded)
self.random_state = random_state
self.random_state = validate_random_state(random_state)
self.selection_sample_size = selection_sample_size

@classmethod
Expand Down Expand Up @@ -364,7 +366,7 @@ def set_random_state(self, random_state):
random_state (int, np.random.RandomState, or None):
Seed or RandomState for the random generator.
"""
self.random_state = random_state
self.random_state = validate_random_state(random_state)

def sample(self, n_samples=1):
"""Sample values from this model.
Expand Down Expand Up @@ -498,7 +500,7 @@ def __init__(self, random_state=None):
random_state (int, np.random.RandomState, or None): seed
or RandomState for random generator.
"""
self.random_state = random_state
self.random_state = validate_random_state(random_state)

def probability_density(self, X):
"""Compute the probability density for each point in X.
Expand Down
4 changes: 2 additions & 2 deletions copulas/univariate/gaussian_kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from scipy.special import ndtr
from scipy.stats import gaussian_kde

from copulas import EPSILON, random_state, store_args
from copulas import EPSILON, random_state, store_args, validate_random_state
from copulas.optimize import bisect, chandrupatla
from copulas.univariate.base import BoundedType, ParametricType, ScipyModel

Expand All @@ -29,7 +29,7 @@ class GaussianKDE(ScipyModel):

@store_args
def __init__(self, sample_size=None, random_state=None, bw_method=None, weights=None):
self.random_state = random_state
self.random_state = validate_random_state(random_state)
self._sample_size = sample_size
self.bw_method = bw_method
self.weights = weights
Expand Down
4 changes: 2 additions & 2 deletions copulas/univariate/truncated_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from scipy.optimize import fmin_slsqp
from scipy.stats import truncnorm

from copulas import EPSILON, store_args
from copulas import EPSILON, store_args, validate_random_state
from copulas.univariate.base import BoundedType, ParametricType, ScipyModel


Expand All @@ -20,7 +20,7 @@ class TruncatedGaussian(ScipyModel):

@store_args
def __init__(self, minimum=None, maximum=None, random_state=None):
self.random_state = random_state
self.random_state = validate_random_state(random_state)
self.min = minimum
self.max = maximum

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/bivariate/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ def setUp(self):
def test___init__random_state(self):
"""If random_state is passed as argument, will be set as attribute."""
# Setup
random_seed = 'random_seed'
random_seed = 42

# Run
instance = Bivariate(copula_type=CopulaTypes.CLAYTON, random_state=random_seed)

# Check
assert instance.random_state == 'random_seed'
assert instance.random_state is not None

def test_from_dict(self):
"""From_dict sets the values of a dictionary as attributes of the instance."""
Expand Down Expand Up @@ -130,4 +130,4 @@ def test_set_random_state(self):
instance.set_random_state(3)

# Check
assert instance.random_state == 3
assert instance.random_state is not None
4 changes: 3 additions & 1 deletion tests/unit/multivariate/test_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from copulas.multivariate.base import Multivariate


Expand All @@ -12,4 +14,4 @@ def test_set_random_state(self):
instance.set_random_state(3)

# Check
assert instance.random_state == 3
assert isinstance(instance.random_state, np.random.RandomState)
Loading

0 comments on commit 0383db2

Please sign in to comment.