Skip to content

Commit

Permalink
Fix seed to generate deterministic samples (#713)
Browse files Browse the repository at this point in the history
* Fix seed when randomize samples is false

* update tests

* update dep versions
  • Loading branch information
katxiao committed Mar 3, 2022
1 parent b2b817d commit 9505243
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 4 deletions.
25 changes: 25 additions & 0 deletions sdv/tabular/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

LOGGER = logging.getLogger(__name__)
COND_IDX = str(uuid.uuid4())
# Seed handed to ``_set_random_state`` when sampling with ``randomize_samples=False``.
FIXED_RNG_SEED = 73251


class NonParametricError(Exception):
Expand Down Expand Up @@ -423,6 +424,24 @@ def _validate_file_path(self, output_file_path):

return output_path

def _randomize_samples(self, randomize_samples):
    """Configure the random state of the underlying model before sampling.

    When ``randomize_samples`` is true, the random state is reset to ``None``.
    Otherwise it is set to ``FIXED_RNG_SEED`` so that the random number generator
    used by the underlying model starts from a fixed seed. Nothing is done when
    no model is set.

    Args:
        randomize_samples (bool):
            Whether or not to randomize the generated samples.
    """
    if self._model is not None:
        random_state = None if randomize_samples else FIXED_RNG_SEED
        self._set_random_state(random_state)

def sample(self, num_rows, randomize_samples=True, batch_size=None, output_file_path=None,
conditions=None):
"""Sample rows from this table.
Expand Down Expand Up @@ -458,6 +477,8 @@ def sample(self, num_rows, randomize_samples=True, batch_size=None, output_file_
if num_rows == 0:
return pd.DataFrame()

self._randomize_samples(randomize_samples)

output_file_path = self._validate_file_path(output_file_path)

batch_size = min(batch_size, num_rows) if batch_size else num_rows
Expand Down Expand Up @@ -612,6 +633,8 @@ def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None,
lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0)
conditions = self._make_condition_dfs(conditions)

self._randomize_samples(randomize_samples)

with tqdm(total=num_rows) as progress_bar:
sampled = pd.DataFrame()
for condition_dataframe in conditions:
Expand Down Expand Up @@ -661,6 +684,8 @@ def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_
"""
output_file_path = self._validate_file_path(output_file_path)

self._randomize_samples(randomize_samples)

with tqdm(total=len(known_columns)) as progress_bar:
sampled = self._sample_with_conditions(
known_columns, max_tries, batch_size_per_try, progress_bar, output_file_path)
Expand Down
9 changes: 9 additions & 0 deletions sdv/tabular/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,15 @@ def _sample(self, num_rows, conditions=None):
"""
return self._model.sample(num_rows, conditions=conditions)

def _set_random_state(self, random_state):
    """Forward ``random_state`` to the wrapped model's random number generator.

    Args:
        random_state (int, np.random.RandomState, or None):
            Seed or RandomState to use.
    """
    model = self._model
    model.set_random_state(random_state)

def get_likelihood(self, table_data):
"""Get the likelihood of each row belonging to this table."""
transformed = self._metadata.transform(table_data)
Expand Down
9 changes: 9 additions & 0 deletions sdv/tabular/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ def _sample(self, num_rows, conditions=None):

raise NotImplementedError(f"{self._MODEL_CLASS} doesn't support conditional sampling.")

def _set_random_state(self, random_state):
    """Forward ``random_state`` to the wrapped model's random number generator.

    Args:
        random_state (int, tuple[np.random.RandomState, torch.Generator], or None):
            Seed or tuple of random states to use.
    """
    model = self._model
    model.set_random_state(random_state)


class CTGAN(CTGANModel):
"""Model wrapping ``CTGANSynthesizer`` model.
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
"numpy>=1.20.0,<2;python_version>='3.7'",
'pandas>=1.1.3,<2',
'tqdm>=4.15,<5',
'copulas>=0.6.0,<0.7',
'ctgan>=0.5.0,<0.6',
'copulas>=0.6.1,<0.7',
'ctgan>=0.5.1,<0.6',
'deepecho>=0.3.0.post1,<0.4',
'rdt>=0.6.1,<0.7',
'rdt>=0.6.2,<0.7',
'sdmetrics>=0.4.1,<0.5',
]

Expand Down
53 changes: 53 additions & 0 deletions tests/integration/tabular/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,3 +400,56 @@ def test_conditional_sampling_constraint_uses_columns_model_reject_sampling(colu
sampled_data[['age_joined', 'age']],
expected_result[['age_joined', 'age']],
)


@pytest.mark.parametrize('model', MODELS)
def test_sampling_with_randomize_samples_True(model):
    """Two consecutive randomized draws from a fitted model must not be identical."""
    column_names = ['column1', 'column2', 'column3']
    training_data = pd.DataFrame({name: list(range(100)) for name in column_names})

    model.fit(training_data)

    first_draw = model.sample(10, randomize_samples=True)
    second_draw = model.sample(10, randomize_samples=True)

    assert not first_draw.equals(second_draw)


@pytest.mark.parametrize('model', MODELS)
def test_sampling_with_randomize_samples_False(model):
    """Two consecutive non-randomized draws from a fitted model must be equal."""
    column_names = ['column1', 'column2', 'column3']
    training_data = pd.DataFrame({name: list(range(100)) for name in column_names})

    model.fit(training_data)

    first_draw = model.sample(10, randomize_samples=False)
    second_draw = model.sample(10, randomize_samples=False)

    pd.testing.assert_frame_equal(first_draw, second_draw)


@pytest.mark.parametrize('model', MODELS)
def test_sampling_with_randomize_samples_alternating(model):
    """Interleave fixed and randomized draws and compare them.

    The two fixed draws must be equal to each other, while each randomized draw
    must differ from the fixed draw and from the other randomized draw.
    """
    column_names = ['column1', 'column2', 'column3']
    training_data = pd.DataFrame({name: list(range(100)) for name in column_names})

    model.fit(training_data)

    # The call order matters: fixed, random, fixed, random.
    fixed_a = model.sample(10, randomize_samples=False)
    random_a = model.sample(10, randomize_samples=True)
    fixed_b = model.sample(10, randomize_samples=False)
    random_b = model.sample(10, randomize_samples=True)

    pd.testing.assert_frame_equal(fixed_a, fixed_b)
    assert not random_a.equals(fixed_a)
    assert not random_a.equals(random_b)
    assert not random_b.equals(fixed_a)
39 changes: 38 additions & 1 deletion tests/unit/tabular/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sdv.metadata.table import Table
from sdv.sampling import Condition
from sdv.tabular.base import COND_IDX, BaseTabularModel
from sdv.tabular.base import COND_IDX, FIXED_RNG_SEED, BaseTabularModel
from sdv.tabular.copulagan import CopulaGAN
from sdv.tabular.copulas import GaussianCopula
from sdv.tabular.ctgan import CTGAN, TVAE
Expand Down Expand Up @@ -764,3 +764,40 @@ def test__make_condition_dfs_with_multiple_conditions_different_columns(model):
assert isinstance(result_conditions2, pd.DataFrame)
assert len(result_conditions2) == 3
assert all(result_conditions2 == expected_conditions2)


def test__randomize_samples_true():
    """Test that ``_randomize_samples`` clears the seed when samples should be random.

    Input:
    - randomize_samples as True

    Side Effect:
    - ``_set_random_state`` is called exactly once with ``None``
    """
    # Setup
    instance = Mock()
    randomize_samples = True

    # Run
    BaseTabularModel._randomize_samples(instance, randomize_samples)

    # Assert
    # ``assert_called_once_with`` is the real Mock assertion. The previous
    # ``assert mock.called_once_with(...)`` only touched an auto-created attribute
    # whose return value is always truthy, so it could never fail.
    instance._set_random_state.assert_called_once_with(None)


def test__randomize_samples_false():
    """Test that ``_randomize_samples`` fixes the seed when samples should not be random.

    Input:
    - randomize_samples as False

    Side Effect:
    - ``_set_random_state`` is called exactly once with ``FIXED_RNG_SEED``
    """
    # Setup
    instance = Mock()
    randomize_samples = False

    # Run
    BaseTabularModel._randomize_samples(instance, randomize_samples)

    # Assert
    # ``assert_called_once_with`` is the real Mock assertion. The previous
    # ``assert mock.called_once_with(...)`` only touched an auto-created attribute
    # whose return value is always truthy, so it could never fail.
    instance._set_random_state.assert_called_once_with(FIXED_RNG_SEED)

0 comments on commit 9505243

Please sign in to comment.