diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py
index 44a968904..777291649 100644
--- a/sdv/tabular/base.py
+++ b/sdv/tabular/base.py
@@ -1,6 +1,9 @@
 """Base Class for tabular models."""
 
+import functools
 import logging
+import math
+import os
 import pickle
 import uuid
 from collections import defaultdict
@@ -8,6 +11,7 @@
 
 import numpy as np
 import pandas as pd
+from tqdm import tqdm
 
 from sdv.errors import ConstraintsNotMetError
 from sdv.metadata import Table
@@ -278,7 +282,8 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None,
         return sampled, num_rows
 
     def _sample_batch(self, num_rows=None, max_tries=100, batch_size_per_try=None,
-                      conditions=None, transformed_conditions=None, float_rtol=0.01):
+                      conditions=None, transformed_conditions=None, float_rtol=0.01,
+                      progress_bar=None, output_file_path=None):
         """Sample a batch of rows with the given conditions.
 
         This will enter a reject-sampling loop in which rows will be sampled until
@@ -314,6 +319,12 @@ def _sample_batch(self, num_rows=None, max_tries=100, batch_size_per_try=None,
                 The dictionary of conditioning values transformed to the model format.
             float_rtol (float):
                 Maximum tolerance when considering a float match.
+            progress_bar (tqdm.tqdm or None):
+                The progress bar to update when sampling. If None, a new tqdm progress
+                bar will be created.
+            output_file_path (str or None):
+                The file to periodically write sampled rows to. If None, does not write
+                rows anywhere.
 
         Returns:
             pandas.DataFrame:
@@ -322,21 +333,42 @@ def _sample_batch(self, num_rows=None, max_tries=100, batch_size_per_try=None,
         if not batch_size_per_try:
             batch_size_per_try = num_rows * 10
 
-        sampled, num_valid = self._sample_rows(
-            num_rows, conditions, transformed_conditions, float_rtol)
+        if not progress_bar:
+            progress_bar = tqdm(total=num_rows)
+
+        counter = 0
+        num_valid = 0
+        prev_num_valid = None
+        remaining = num_rows
+        sampled = pd.DataFrame()
 
-        counter = 1
         while num_valid < num_rows:
             if counter >= max_tries:
                 break
 
-            remaining = num_rows - num_valid
-
-            LOGGER.info(f'{remaining} valid rows remaining. Resampling {batch_size_per_try} rows')
-
+            prev_num_valid = num_valid
             sampled, num_valid = self._sample_rows(
                 batch_size_per_try, conditions, transformed_conditions, float_rtol, sampled,
             )
+            num_increase = min(num_valid - prev_num_valid, remaining)
+            if num_increase > 0:
+                if output_file_path:
+                    append_kwargs = {'mode': 'a', 'header': False} if os.path.exists(
+                        output_file_path) else {}
+                    sampled.head(min(len(sampled), num_rows)).tail(num_increase).to_csv(
+                        output_file_path,
+                        index=False,
+                        **append_kwargs,
+                    )
+
+                progress_bar.update(num_increase)
+
+            remaining = num_rows - num_valid
+            if remaining > 0:
+                LOGGER.info(
+                    f'{remaining} valid rows remaining. Resampling {batch_size_per_try} rows')
+
            counter += 1
 
         return sampled.head(min(len(sampled), num_rows))
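Reviewer note: the header-once CSV append above is the subtle part of this hunk. Here is the same pattern in isolation as a minimal sketch, assuming a toy `append_rows` helper (the name and data are illustrative, not part of this PR):

```python
import os

import pandas as pd


def append_rows(df, path):
    """Write df to a CSV, appending without a header once the file exists."""
    # Mirrors the append_kwargs logic in _sample_batch: the first write creates
    # the file with a header; every later write appends rows only.
    append_kwargs = {'mode': 'a', 'header': False} if os.path.exists(path) else {}
    df.to_csv(path, index=False, **append_kwargs)


append_rows(pd.DataFrame({'a': [1, 2]}), 'out.csv')  # creates out.csv with a header
append_rows(pd.DataFrame({'a': [3, 4]}), 'out.csv')  # appends rows, no second header
```

Because only the `tail(num_increase)` of the newly valid rows is written per iteration, the file never contains duplicates from the reject-sampling loop.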
@@ -367,7 +399,8 @@ def _make_condition_dfs(self, conditions):
 
     def _conditionally_sample_rows(self, dataframe, condition, transformed_condition,
                                    max_tries=None, batch_size_per_try=None, float_rtol=0.01,
-                                   graceful_reject_sampling=True):
+                                   graceful_reject_sampling=True, progress_bar=None,
+                                   output_file_path=None):
         num_rows = len(dataframe)
         sampled_rows = self._sample_batch(
             num_rows,
@@ -376,6 +409,8 @@ def _conditionally_sample_rows(self, dataframe, condition, transformed_condition
             condition,
             transformed_condition,
             float_rtol,
+            progress_bar,
+            output_file_path,
         )
 
         num_sampled_rows = len(sampled_rows)
@@ -401,8 +436,17 @@ def _conditionally_sample_rows(self, dataframe, condition, transformed_condition
 
         return sampled_rows
 
+    def _validate_file_path(self, output_file_path):
+        output_path = None
+        if output_file_path:
+            output_path = os.path.abspath(output_file_path)
+            if os.path.exists(output_path):
+                raise AssertionError(f'{output_path} already exists.')
+
+        return output_path
+
     @validate_sample_args
-    def sample(self, num_rows, randomize_samples=True):
+    def sample(self, num_rows, randomize_samples=True, batch_size=None, output_file_path=None):
         """Sample rows from this table.
 
         Args:
@@ -411,6 +455,11 @@
             randomize_samples (bool):
                 Whether or not to use a fixed seed when sampling.
                 Defaults to True.
+            batch_size (int or None):
+                The batch size to sample. Defaults to `num_rows`, if None.
+            output_file_path (str or None):
+                The file to periodically write sampled rows to. If None, does not
+                write rows anywhere.
 
         Returns:
             pandas.DataFrame:
@@ -419,9 +468,30 @@
         if num_rows is None:
             raise ValueError('You must specify the number of rows to sample (e.g. num_rows=100).')
 
-        return self._sample_batch(num_rows)
+        if num_rows == 0:
+            return pd.DataFrame()
+
+        output_file_path = self._validate_file_path(output_file_path)
 
-    def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
+        batch_size = min(batch_size, num_rows) if batch_size else num_rows
+
+        sampled = []
+        with tqdm(total=num_rows) as progress_bar:
+            progress_bar.set_description(
+                f'Sampling {num_rows} rows of data in batches of size {batch_size}')
+            for step in range(math.ceil(num_rows / batch_size)):
+                sampled_rows = self._sample_batch(
+                    batch_size,
+                    batch_size_per_try=batch_size,
+                    progress_bar=progress_bar,
+                    output_file_path=output_file_path,
+                )
+                sampled.append(sampled_rows)
+
+        return pd.concat(sampled, ignore_index=True)
+
+    def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try,
+                                progress_bar=None, output_file_path=None):
         """Sample rows with conditions.
 
         Args:
@@ -432,6 +502,11 @@
             batch_size_per_try (int):
                 The batch size to use per attempt at sampling. Defaults to 10 times
                 the number of rows.
+            progress_bar (tqdm.tqdm or None):
+                The progress bar to update.
+            output_file_path (str or None):
+                The file to periodically write sampled rows to. Defaults to
+                a temporary file, if None.
 
         Returns:
             pandas.DataFrame:
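To make the new `sample` surface concrete: `batch_size` splits the request into `math.ceil(num_rows / batch_size)` calls to `_sample_batch`, all sharing a single progress bar and output file. A usage sketch, assuming a fitted model (the `GaussianCopula` and toy data here are illustrative):

```python
import pandas as pd

from sdv.tabular import GaussianCopula

model = GaussianCopula()
model.fit(pd.DataFrame({'column1': range(100), 'column2': range(100)}))

# Ten batches of 100 rows; each completed batch is appended to out.csv, so a
# long-running job can be inspected (or salvaged) midway. The path must not
# already exist, or _validate_file_path raises an AssertionError.
sampled = model.sample(num_rows=1000, batch_size=100, output_file_path='out.csv')
```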
@@ -480,6 +555,8 @@ def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
                     None,
                     max_tries,
                     batch_size_per_try,
+                    progress_bar=progress_bar,
+                    output_file_path=output_file_path,
                 )
                 all_sampled_rows.append(sampled_rows)
             else:
@@ -496,6 +573,8 @@ def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
                     transformed_condition,
                     max_tries,
                     batch_size_per_try,
+                    progress_bar=progress_bar,
+                    output_file_path=output_file_path,
                 )
                 all_sampled_rows.append(sampled_rows)
 
@@ -508,7 +587,7 @@ def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
         return all_sampled_rows
 
     def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None,
-                          randomize_samples=True):
+                          randomize_samples=True, output_file_path=None):
         """Sample rows from this table with the given conditions.
 
         Args:
@@ -524,6 +603,9 @@ def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None,
             randomize_samples (bool):
                 Whether or not to use a fixed seed when sampling.
                 Defaults to True.
+            output_file_path (str or None):
+                The file to periodically write sampled rows to. Defaults to
+                a temporary file, if None.
 
         Returns:
             pandas.DataFrame:
@@ -537,18 +619,28 @@ def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None,
                 * any of the conditions' columns are not valid.
                 * no rows could be generated.
         """
+        output_file_path = self._validate_file_path(output_file_path)
+
+        num_rows = functools.reduce(
+            lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0)
         conditions = self._make_condition_dfs(conditions)
 
-        sampled = pd.DataFrame()
-        for condition_dataframe in conditions:
-            sampled_for_condition = self._sample_with_conditions(
-                condition_dataframe, max_tries, batch_size_per_try)
-            sampled = pd.concat([sampled, sampled_for_condition], ignore_index=True)
+        with tqdm(total=num_rows) as progress_bar:
+            sampled = pd.DataFrame()
+            for condition_dataframe in conditions:
+                sampled_for_condition = self._sample_with_conditions(
+                    condition_dataframe,
+                    max_tries,
+                    batch_size_per_try,
+                    progress_bar,
+                    output_file_path,
+                )
+                sampled = pd.concat([sampled, sampled_for_condition], ignore_index=True)
 
         return sampled
 
     def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_try=None,
-                                 randomize_samples=True):
+                                 randomize_samples=True, output_file_path=None):
         """Sample rows from this table.
 
         Args:
@@ -564,6 +656,9 @@ def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_
             randomize_samples (bool):
                 Whether or not to use a fixed seed when sampling.
                 Defaults to True.
+            output_file_path (str or None):
+                The file to periodically write sampled rows to. Defaults to
+                a temporary file, if None.
 
         Returns:
             pandas.DataFrame:
@@ -577,7 +672,13 @@ def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_
                 * any of the conditions' columns are not valid.
                 * no rows could be generated.
         """
-        return self._sample_with_conditions(known_columns, max_tries, batch_size_per_try)
+        output_file_path = self._validate_file_path(output_file_path)
+
+        with tqdm(total=len(known_columns)) as progress_bar:
+            sampled = self._sample_with_conditions(
+                known_columns, max_tries, batch_size_per_try, progress_bar, output_file_path)
+
+        return sampled
 
     def _get_parameters(self):
         raise NonParametricError()
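For the conditional entry points, the progress-bar total is computed up front: `functools.reduce` sums `get_num_rows()` over the `Condition` objects before they are expanded into dataframes. A usage sketch, reusing the fitted model from the earlier example and assuming `Condition` is imported from `sdv.sampling` (its usual location; the column values are illustrative):

```python
from sdv.sampling import Condition

conditions = [
    Condition({'cola': 'a'}, num_rows=2),
    Condition({'colb': 'b'}, num_rows=3),
]

# reduce(lambda total, c: c.get_num_rows() + total, conditions, 0) == 5, so one
# shared tqdm bar is created with total=5 and updated across both conditions.
sampled = model.sample_conditions(conditions, output_file_path='conditional.csv')
```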
""" - return self._sample_with_conditions(known_columns, max_tries, batch_size_per_try) + output_file_path = self._validate_file_path(output_file_path) + + with tqdm(total=len(known_columns)) as progress_bar: + sampled = self._sample_with_conditions( + known_columns, max_tries, batch_size_per_try, progress_bar, output_file_path) + + return sampled def _get_parameters(self): raise NonParametricError() diff --git a/tests/integration/tabular/test_base.py b/tests/integration/tabular/test_base.py index 98416a2c0..b90d4f4ca 100644 --- a/tests/integration/tabular/test_base.py +++ b/tests/integration/tabular/test_base.py @@ -265,7 +265,6 @@ def test_conditional_sampling_constraint_uses_reject_sampling(gm_mock): }) sample_calls = model._model.sample.mock_calls assert len(sample_calls) == 2 - model._model.sample.assert_any_call(5, conditions=expected_transformed_conditions) model._model.sample.assert_any_call(50, conditions=expected_transformed_conditions) pd.testing.assert_frame_equal(sampled_data, expected_data) diff --git a/tests/unit/tabular/test_base.py b/tests/unit/tabular/test_base.py index 2caa902c1..b654797e8 100644 --- a/tests/unit/tabular/test_base.py +++ b/tests/unit/tabular/test_base.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock, call, patch +from unittest.mock import ANY, Mock, call, patch import pandas as pd import pytest @@ -69,6 +69,8 @@ def test__sample_with_conditions_no_transformed_columns(self): None, 100, None, + progress_bar=None, + output_file_path=None, ) pd.testing.assert_frame_equal(out, expected) @@ -171,6 +173,73 @@ def test_sample_num_rows_none(self): match=r'You must specify the number of rows to sample \(e.g. num_rows=100\)'): model.sample(num_rows) + def test_sample_batch_size(self): + """Test the `BaseTabularModel.sample` method with a valid `batch_size` argument. + + Expect that the expected calls to `_sample_batch` are made. + + Input: + - num_rows = 10 + - batch_size = 5 + Output: + - The requested number of sampled rows. + Side Effect: + - Call `_sample_batch` method twice with the expected number of rows. + """ + # Setup + gaussian_copula = Mock(spec_set=GaussianCopula) + sampled_data = pd.DataFrame({ + 'column1': [28, 28, 21, 1, 2], + 'column2': [37, 37, 1, 4, 5], + 'column3': [93, 93, 6, 4, 12], + }) + gaussian_copula._sample_batch.side_effect = [sampled_data, sampled_data] + + # Run + output = BaseTabularModel.sample(gaussian_copula, 10, batch_size=5) + + # Assert + assert gaussian_copula._sample_batch.has_calls([ + call(5, batch_size_per_try=5, progress_bar=ANY, output_file_path=None), + call(5, batch_size_per_try=5, progress_bar=ANY, output_file_path=None), + ]) + assert len(output) == 10 + + def test__sample_batch_with_batch_size_per_try(self): + """Test the `BaseTabularModel._sample_batch` method with `batch_size_per_try`. + + Expect that the expected calls to `_sample_rows` are made. + + Input: + - num_rows = 10 + - batch_size_per_try = 5 + Output: + - The requested number of sampled rows. + Side Effect: + - Call `_sample_rows` method twice with the expected number of rows. 
+ """ + # Setup + gaussian_copula = Mock(spec_set=GaussianCopula) + sampled_data = pd.DataFrame({ + 'column1': [28, 28, 21, 1, 2], + 'column2': [37, 37, 1, 4, 5], + 'column3': [93, 93, 6, 4, 12], + }) + gaussian_copula._sample_rows.side_effect = [ + (sampled_data, 5), + (sampled_data.append(sampled_data, ignore_index=False), 10), + ] + + # Run + output = BaseTabularModel._sample_batch(gaussian_copula, num_rows=10, batch_size_per_try=5) + + # Assert + assert gaussian_copula._sample_rows.has_calls([ + call(5, None, None, 0.01, DataFrameMatcher(pd.DataFrame())), + call(5, None, None, 0.01, DataFrameMatcher(sampled_data)), + ]) + assert len(output) == 10 + def test_sample_conditions_with_multiple_conditions(self): """Test the `BaseTabularModel.sample_conditions` method with multiple condtions. @@ -184,6 +253,7 @@ def test_sample_conditions_with_multiple_conditions(self): """ # Setup gaussian_copula = Mock(spec_set=GaussianCopula) + gaussian_copula._validate_file_path.return_value = None condition_values1 = {'cola': 'a'} condition1 = Condition(condition_values1, num_rows=2) @@ -212,8 +282,10 @@ def test_sample_conditions_with_multiple_conditions(self): # Asserts gaussian_copula._sample_with_conditions.assert_has_calls([ - call(DataFrameMatcher(pd.DataFrame([condition_values1] * 2)), 100, None), - call(DataFrameMatcher(pd.DataFrame([condition_values2] * 3)), 100, None), + call(DataFrameMatcher(pd.DataFrame([condition_values1] * 2)), 100, + None, ANY, None), + call(DataFrameMatcher(pd.DataFrame([condition_values2] * 3)), 100, + None, ANY, None), ]) pd.testing.assert_frame_equal(out, expected) @@ -232,6 +304,7 @@ def test_sample_remaining_columns(self): """ # Setup gaussian_copula = Mock(spec_set=GaussianCopula) + gaussian_copula._validate_file_path.return_value = None conditions = pd.DataFrame([{'cola': 'a'}] * 5) @@ -246,7 +319,7 @@ def test_sample_remaining_columns(self): # Asserts gaussian_copula._sample_with_conditions.assert_called_once_with( - DataFrameMatcher(conditions), 100, None) + DataFrameMatcher(conditions), 100, None, ANY, None) pd.testing.assert_frame_equal(out, sampled) def test__sample_with_conditions_invalid_column(self): @@ -275,6 +348,26 @@ def test__sample_with_conditions_invalid_column(self): 'Use a column name that was present in the original data.')): GaussianCopula._sample_with_conditions(gaussian_copula, conditions, 100, None) + @patch('sdv.tabular.base.os.path') + def test__validate_file_path(self, path_mock): + """Test the `BaseTabularModel._validate_file_path` method. + + Expect that an error is thrown if the file path already exists. + + Input: + - A file path that already exists. + Side Effects: + - An AssertionError. 
+ """ + # Setup + path_mock.exists.return_value = True + path_mock.abspath.return_value = 'path/to/file' + gaussian_copula = Mock(spec_set=GaussianCopula) + + # Run and Assert + with pytest.raises(AssertionError, match='path/to/file already exists'): + BaseTabularModel._validate_file_path(gaussian_copula, 'file_path') + @patch('sdv.tabular.base.Table', spec_set=Table) def test__init__passes_correct_parameters(metadata_mock): @@ -440,7 +533,7 @@ def test__sample_with_conditions_empty_transformed_conditions(): pd.testing.assert_series_equal(args[0]['column1'], conditions_series) assert kwargs['on_missing_column'] == 'drop' model._metadata.transform.assert_called_once() - model._sample_batch.assert_called_with(5, 100, None, conditions, None, 0.01) + model._sample_batch.assert_called_with(5, 100, None, conditions, None, 0.01, None, None) pd.testing.assert_frame_equal(output, expected_output) @@ -502,13 +595,13 @@ def test__sample_with_conditions_transform_conditions_correctly(): assert kwargs['on_missing_column'] == 'drop' model._metadata.transform.assert_called_once() model._sample_batch.assert_any_call( - 3, 100, None, {'column1': 25}, {'transformed_column': 50}, 0.01 + 3, 100, None, {'column1': 25}, {'transformed_column': 50}, 0.01, None, None, ) model._sample_batch.assert_any_call( - 1, 100, None, {'column1': 30}, {'transformed_column': 60}, 0.01 + 1, 100, None, {'column1': 30}, {'transformed_column': 60}, 0.01, None, None, ) model._sample_batch.assert_any_call( - 1, 100, None, {'column1': 30}, {'transformed_column': 70}, 0.01 + 1, 100, None, {'column1': 30}, {'transformed_column': 70}, 0.01, None, None, )