From b2eadcaa9e9827d011a69e545844bc6cdb4c1116 Mon Sep 17 00:00:00 2001 From: Katharine Xiao <2405771+katxiao@users.noreply.github.com> Date: Wed, 9 Feb 2022 17:50:49 -0500 Subject: [PATCH] cr comments --- copulas/__init__.py | 13 ++++--- tests/end-to-end/test___init__.py | 62 +++++++++++++++++++++++++++++++ tests/unit/test___init__.py | 3 +- 3 files changed, 72 insertions(+), 6 deletions(-) create mode 100644 tests/end-to-end/test___init__.py diff --git a/copulas/__init__.py b/copulas/__init__.py index 227283b7..19a89e6f 100644 --- a/copulas/__init__.py +++ b/copulas/__init__.py @@ -33,16 +33,19 @@ def set_random_state(random_state, set_model_random_state): original_state = np.random.get_state() if isinstance(random_state, int): - desired_state = np.random.RandomState(seed=random_state) - else: - desired_state = random_state + desired_state = np.random.RandomState(seed=random_state).get_state() + elif isinstance(random_state, np.random.RandomState): + desired_state = random_state.get_state() + elif not isinstance(random_state, tuple): + raise TypeError(f'RandomState {random_state} is an unexpected type. ' + 'Expected to be int, np.random.RandomState, or tuple.') - np.random.set_state(desired_state.get_state()) + np.random.set_state(desired_state) try: yield finally: - set_model_random_state(desired_state) + set_model_random_state(np.random.get_state()) np.random.set_state(original_state) diff --git a/tests/end-to-end/test___init__.py b/tests/end-to-end/test___init__.py new file mode 100644 index 00000000..db78a460 --- /dev/null +++ b/tests/end-to-end/test___init__.py @@ -0,0 +1,62 @@ +import numpy as np + +from copulas import random_state + + +class TestRandomState: + """Test class for the random state wrapper.""" + + def init(self, random_state=None): + self.random_state = random_state + + def set_random_state(self, random_state): + self.random_state = random_state + + @random_state + def sample(self): + pass + + +def test_random_state_decorator(): + """Test the ``random_state`` decorator end-to-end. + + Expect that the random state wrapper leaves the global state + where it left off. + + Setup: + - A global random state is initialized with a seed of 42. + - The state is advanced by generating two random values. + Input: + - A seed of 0 is given to the test instance. + Side Effects: + - Sampling two more random values after the test instance + method completes is expected to continue where the random + state left off. + """ + # Setup + original_state = np.random.get_state() + + # Get the expected random sequence with a seed of 42. + _SEED = 42 + np.random.seed(_SEED) + expected = np.random.random(size=4) + + # Set the global random state. + new_state = np.random.RandomState(seed=_SEED).get_state() + np.random.set_state(new_state) + + first_sequence = np.random.random(size=2) + + # Run + instance = TestRandomState() + instance.set_random_state(0) + instance.sample() + + second_sequence = np.random.random(size=2) + + # Assert + np.testing.assert_array_equal(first_sequence, expected[:2]) + np.testing.assert_array_equal(second_sequence, expected[2:]) + + # Cleanup + np.random.set_state(original_state) diff --git a/tests/unit/test___init__.py b/tests/unit/test___init__.py index 9aa4b2c6..28f2cc4e 100644 --- a/tests/unit/test___init__.py +++ b/tests/unit/test___init__.py @@ -246,7 +246,8 @@ def test_valid_random_state(self, random_mock): my_function.assert_called_once_with(instance, *args, **kwargs) instance.assert_not_called - random_mock.get_state.assert_called_once_with() + random_mock.get_state.assert_has_calls([call(), call()]) + random_mock.get_state.call_count == 2 random_mock.RandomState.assert_called_once_with(seed=42) random_mock.set_state.assert_has_calls( [call('desired random state'), call('random state')])