Skip to content

Commit

Permalink
cr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Feb 9, 2022
1 parent 5152fd6 commit b2eadca
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 6 deletions.
13 changes: 8 additions & 5 deletions copulas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,19 @@ def set_random_state(random_state, set_model_random_state):
original_state = np.random.get_state()

if isinstance(random_state, int):
desired_state = np.random.RandomState(seed=random_state)
else:
desired_state = random_state
desired_state = np.random.RandomState(seed=random_state).get_state()
elif isinstance(random_state, np.random.RandomState):
desired_state = random_state.get_state()
elif not isinstance(random_state, tuple):
raise TypeError(f'RandomState {random_state} is an unexpected type. '
'Expected to be int, np.random.RandomState, or tuple.')

np.random.set_state(desired_state.get_state())
np.random.set_state(desired_state)

try:
yield
finally:
set_model_random_state(desired_state)
set_model_random_state(np.random.get_state())
np.random.set_state(original_state)


Expand Down
62 changes: 62 additions & 0 deletions tests/end-to-end/test___init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import numpy as np

from copulas import random_state


class TestRandomState:
"""Test class for the random state wrapper."""

def init(self, random_state=None):
self.random_state = random_state

def set_random_state(self, random_state):
self.random_state = random_state

@random_state
def sample(self):
pass


def test_random_state_decorator():
"""Test the ``random_state`` decorator end-to-end.
Expect that the random state wrapper leaves the global state
where it left off.
Setup:
- A global random state is initialized with a seed of 42.
- The state is advanced by generating two random values.
Input:
- A seed of 0 is given to the test instance.
Side Effects:
- Sampling two more random values after the test instance
method completes is expected to continue where the random
state left off.
"""
# Setup
original_state = np.random.get_state()

# Get the expected random sequence with a seed of 42.
_SEED = 42
np.random.seed(_SEED)
expected = np.random.random(size=4)

# Set the global random state.
new_state = np.random.RandomState(seed=_SEED).get_state()
np.random.set_state(new_state)

first_sequence = np.random.random(size=2)

# Run
instance = TestRandomState()
instance.set_random_state(0)
instance.sample()

second_sequence = np.random.random(size=2)

# Assert
np.testing.assert_array_equal(first_sequence, expected[:2])
np.testing.assert_array_equal(second_sequence, expected[2:])

# Cleanup
np.random.set_state(original_state)
3 changes: 2 additions & 1 deletion tests/unit/test___init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ def test_valid_random_state(self, random_mock):
my_function.assert_called_once_with(instance, *args, **kwargs)

instance.assert_not_called
random_mock.get_state.assert_called_once_with()
random_mock.get_state.assert_has_calls([call(), call()])
random_mock.get_state.call_count == 2
random_mock.RandomState.assert_called_once_with(seed=42)
random_mock.set_state.assert_has_calls(
[call('desired random state'), call('random state')])
Expand Down

0 comments on commit b2eadca

Please sign in to comment.