Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Feb 23, 2022
1 parent 990428c commit e809699
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 1 deletion.
53 changes: 53 additions & 0 deletions tests/integration/tabular/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,3 +400,56 @@ def test_conditional_sampling_constraint_uses_columns_model_reject_sampling(colu
sampled_data[['age_joined', 'age']],
expected_result[['age_joined', 'age']],
)


@pytest.mark.parametrize('model', MODELS)
def test_sampling_with_randomize_samples_True(model):
data = pd.DataFrame({
'column1': list(range(100)),
'column2': list(range(100)),
'column3': list(range(100))
})

model.fit(data)

sampled1 = model.sample(10, randomize_samples=True)
sampled2 = model.sample(10, randomize_samples=True)

assert not sampled1.equals(sampled2)


@pytest.mark.parametrize('model', MODELS)
def test_sampling_with_randomize_samples_False(model):
data = pd.DataFrame({
'column1': list(range(100)),
'column2': list(range(100)),
'column3': list(range(100))
})

model.fit(data)

sampled1 = model.sample(10, randomize_samples=False)
sampled2 = model.sample(10, randomize_samples=False)

pd.testing.assert_frame_equal(sampled1, sampled2)


@pytest.mark.parametrize('model', MODELS)
def test_sampling_with_randomize_samples_alternating(model):
data = pd.DataFrame({
'column1': list(range(100)),
'column2': list(range(100)),
'column3': list(range(100))
})

model.fit(data)

sampled_fixed1 = model.sample(10, randomize_samples=False)
sampled_random1 = model.sample(10, randomize_samples=True)
sampled_fixed2 = model.sample(10, randomize_samples=False)
sampled_random2 = model.sample(10, randomize_samples=True)

pd.testing.assert_frame_equal(sampled_fixed1, sampled_fixed2)
assert not sampled_random1.equals(sampled_fixed1)
assert not sampled_random1.equals(sampled_random2)
assert not sampled_random2.equals(sampled_fixed1)
39 changes: 38 additions & 1 deletion tests/unit/tabular/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sdv.metadata.table import Table
from sdv.sampling import Condition
from sdv.tabular.base import COND_IDX, BaseTabularModel
from sdv.tabular.base import COND_IDX, FIXED_RNG_SEED, BaseTabularModel
from sdv.tabular.copulagan import CopulaGAN
from sdv.tabular.copulas import GaussianCopula
from sdv.tabular.ctgan import CTGAN, TVAE
Expand Down Expand Up @@ -764,3 +764,40 @@ def test__make_condition_dfs_with_multiple_conditions_different_columns(model):
assert isinstance(result_conditions2, pd.DataFrame)
assert len(result_conditions2) == 3
assert all(result_conditions2 == expected_conditions2)


def test__randomize_samples_true():
"""Test that ``_randomize_samples`` sets the random state correctly.
Input:
- randomize_samples as True
Side Effect:
- random state is set
"""
# Setup
instance = Mock()
randomize_samples = True

# Run
BaseTabularModel._randomize_samples(instance, randomize_samples)

# Assert
assert instance._set_random_state.called_once_with(FIXED_RNG_SEED)


def test__randomize_samples_false():
"""Test that ``_randomize_samples`` is a no-op when user wants random samples.
Input:
- randomize_samples as False
"""
# Setup
instance = Mock()
randomize_samples = False

# Run
BaseTabularModel._randomize_samples(instance, randomize_samples)

# Assert
assert instance._set_random_state.called_once_with(None)

0 comments on commit e809699

Please sign in to comment.