From 868a4b969ad9d73451d7cb678feea7636b086701 Mon Sep 17 00:00:00 2001
From: Katharine Xiao <2405771+katxiao@users.noreply.github.com>
Date: Wed, 9 Feb 2022 19:59:58 -0500
Subject: [PATCH 1/6] Add batch sampling and progress bar

---
 sdv/tabular/base.py | 65 ++++++++++++++++++++++++++++++++++++---------
 1 file changed, 52 insertions(+), 13 deletions(-)

diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py
index 44a968904..6fa5b804b 100644
--- a/sdv/tabular/base.py
+++ b/sdv/tabular/base.py
@@ -1,6 +1,8 @@
 """Base Class for tabular models."""

+import functools
 import logging
+import math
 import pickle
 import uuid
 from collections import defaultdict
@@ -8,6 +10,7 @@

 import numpy as np
 import pandas as pd
+from tqdm import tqdm

 from sdv.errors import ConstraintsNotMetError
 from sdv.metadata import Table
@@ -278,7 +281,8 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None,
         return sampled, num_rows

     def _sample_batch(self, num_rows=None, max_tries=100, batch_size_per_try=None,
-                      conditions=None, transformed_conditions=None, float_rtol=0.01):
+                      conditions=None, transformed_conditions=None, float_rtol=0.01,
+                      progress_bar=None):
         """Sample a batch of rows with the given conditions.

         This will enter a reject-sampling loop in which rows will be sampled until
@@ -314,6 +318,9 @@ def _sample_batch(self, num_rows=None, max_tries=100, batch_size_per_try=None,
                 The dictionary of conditioning values transformed to the model format.
             float_rtol (float):
                 Maximum tolerance when considering a float match.
+            progress_bar (tqdm.tqdm or None):
+                The progress bar to update when sampling. If None, a new tqdm progress
+                bar will be created.

         Returns:
             pandas.DataFrame:
@@ -322,21 +329,30 @@ def _sample_batch(self, num_rows=None, max_tries=100, batch_size_per_try=None,
         if not batch_size_per_try:
             batch_size_per_try = num_rows * 10

-        sampled, num_valid = self._sample_rows(
-            num_rows, conditions, transformed_conditions, float_rtol)
+        if not progress_bar:
+            progress_bar = tqdm(total=num_rows)
+
+        counter = 0
+        num_valid = 0
+        prev_num_valid = None
+        remaining = num_rows
+        sampled = pd.DataFrame()

-        counter = 1
         while num_valid < num_rows:
             if counter >= max_tries:
                 break

-            remaining = num_rows - num_valid
-
-            LOGGER.info(f'{remaining} valid rows remaining. Resampling {batch_size_per_try} rows')
+            prev_num_valid = num_valid
             sampled, num_valid = self._sample_rows(
                 batch_size_per_try, conditions, transformed_conditions, float_rtol, sampled,
             )
+            progress_bar.update(min(num_valid - prev_num_valid, remaining))
+            remaining = num_rows - num_valid
+            if remaining > 0:
+                LOGGER.info(
+                    f'{remaining} valid rows remaining. Resampling {batch_size_per_try} rows')
+
             counter += 1

         return sampled.head(min(len(sampled), num_rows))
@@ -367,7 +383,7 @@ def _make_condition_dfs(self, conditions):

     def _conditionally_sample_rows(self, dataframe, condition, transformed_condition,
                                    max_tries=None, batch_size_per_try=None, float_rtol=0.01,
-                                   graceful_reject_sampling=True):
+                                   graceful_reject_sampling=True, progress_bar=None):
         num_rows = len(dataframe)
         sampled_rows = self._sample_batch(
             num_rows,
@@ -376,6 +392,7 @@ def _conditionally_sample_rows(self, dataframe, condition, transformed_condition
             condition,
             transformed_condition,
             float_rtol,
+            progress_bar,
         )

         num_sampled_rows = len(sampled_rows)
@@ -402,7 +419,7 @@ def _conditionally_sample_rows(self, dataframe, condition, transformed_condition
         return sampled_rows

     @validate_sample_args
-    def sample(self, num_rows, randomize_samples=True):
+    def sample(self, num_rows, randomize_samples=True, batch_size=None):
         """Sample rows from this table.

         Args:
@@ -411,6 +428,8 @@ def sample(self, num_rows, randomize_samples=True):
             randomize_samples (bool):
                 Whether or not to use a fixed seed when sampling. Defaults
                 to True.
+            batch_size (int or None):
+                The batch size to sample. Defaults to `num_rows`, if None.

         Returns:
             pandas.DataFrame:
@@ -419,9 +438,21 @@ def sample(self, num_rows, randomize_samples=True):
         if num_rows is None:
             raise ValueError('You must specify the number of rows to sample (e.g. num_rows=100).')

-        return self._sample_batch(num_rows)
+        batch_size = min(batch_size, num_rows) if batch_size else num_rows
+        progress_bar = tqdm(total=num_rows)
+        progress_bar.set_description(
+            f'Sampling {num_rows} rows of data in batches of size {batch_size}')
+
+        sampled = []
+        for step in range(math.ceil(num_rows / batch_size)):
+            sampled_rows = self._sample_batch(
+                batch_size, batch_size_per_try=batch_size, progress_bar=progress_bar)
+            sampled.append(sampled_rows)
+
+        return pd.concat(sampled, ignore_index=True)

-    def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
+    def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try,
+                                progress_bar=None):
         """Sample rows with conditions.

         Args:
@@ -432,6 +463,8 @@ def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
             batch_size_per_try (int):
                 The batch size to use per attempt at sampling. Defaults to 10 times
                 the number of rows.
+            progress_bar (tqdm.tqdm or None):
+                The progress bar to update.

         Returns:
             pandas.DataFrame:
@@ -480,6 +513,7 @@ def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
                     None,
                     max_tries,
                     batch_size_per_try,
+                    progress_bar=progress_bar,
                 )
                 all_sampled_rows.append(sampled_rows)
             else:
@@ -496,6 +530,7 @@ def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
                     transformed_condition,
                     max_tries,
                     batch_size_per_try,
+                    progress_bar=progress_bar,
                 )
                 all_sampled_rows.append(sampled_rows)

@@ -537,12 +572,15 @@ def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None,
                 * any of the conditions' columns are not valid.
                 * no rows could be generated.
         """
+        num_rows = functools.reduce(
+            lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0)
+        progress_bar = tqdm(total=num_rows)
         conditions = self._make_condition_dfs(conditions)

         sampled = pd.DataFrame()
         for condition_dataframe in conditions:
             sampled_for_condition = self._sample_with_conditions(
-                condition_dataframe, max_tries, batch_size_per_try)
+                condition_dataframe, max_tries, batch_size_per_try, progress_bar)
             sampled = pd.concat([sampled, sampled_for_condition], ignore_index=True)

         return sampled
@@ -577,7 +615,8 @@ def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_
                 * any of the conditions' columns are not valid.
                 * no rows could be generated.
         """
-        return self._sample_with_conditions(known_columns, max_tries, batch_size_per_try)
+        return self._sample_with_conditions(
+            known_columns, max_tries, batch_size_per_try, tqdm(total=len(known_columns)))

     def _get_parameters(self):
         raise NonParametricError()
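Usage sketch for the batching API this commit introduces: `batch_size` splits `num_rows` across repeated `_sample_batch` calls while a single tqdm bar tracks the overall total. This is a minimal sketch, assuming a fitted tabular model and the SDV demo loader of this period; the dataset name is illustrative.

```python
# Minimal sketch of the new batched sampling call (assumes SDV at this commit).
from sdv.demo import load_tabular_demo
from sdv.tabular import GaussianCopula

data = load_tabular_demo('student_placements')
model = GaussianCopula()
model.fit(data)

# 1000 rows drawn as ten batches of 100; one progress bar spans the total.
synthetic = model.sample(num_rows=1000, batch_size=100)
```
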
From 4614e7622bd952945ac52dfe432f48ba986f349a Mon Sep 17 00:00:00 2001
From: Katharine Xiao <2405771+katxiao@users.noreply.github.com>
Date: Thu, 10 Feb 2022 16:09:58 -0500
Subject: [PATCH 2/6] Make sure to close progress bar

---
 sdv/tabular/base.py | 37 +++++++++++++++++++++----------------
 1 file changed, 21 insertions(+), 16 deletions(-)

diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py
index 6fa5b804b..4ce9fab35 100644
--- a/sdv/tabular/base.py
+++ b/sdv/tabular/base.py
@@ -347,7 +347,9 @@ def _sample_batch(self, num_rows=None, max_tries=100, batch_size_per_try=None,
                 batch_size_per_try, conditions, transformed_conditions, float_rtol, sampled,
             )
-            progress_bar.update(min(num_valid - prev_num_valid, remaining))
+            if num_valid > 0:
+                progress_bar.update(min(num_valid - prev_num_valid, remaining))
+
             remaining = num_rows - num_valid
             if remaining > 0:
                 LOGGER.info(
@@ -439,15 +441,15 @@ def sample(self, num_rows, randomize_samples=True, batch_size=None):
             raise ValueError('You must specify the number of rows to sample (e.g. num_rows=100).')

         batch_size = min(batch_size, num_rows) if batch_size else num_rows
-        progress_bar = tqdm(total=num_rows)
-        progress_bar.set_description(
-            f'Sampling {num_rows} rows of data in batches of size {batch_size}')

         sampled = []
-        for step in range(math.ceil(num_rows / batch_size)):
-            sampled_rows = self._sample_batch(
-                batch_size, batch_size_per_try=batch_size, progress_bar=progress_bar)
-            sampled.append(sampled_rows)
+        with tqdm(total=num_rows) as progress_bar:
+            progress_bar.set_description(
+                f'Sampling {num_rows} rows of data in batches of size {batch_size}')
+            for step in range(math.ceil(num_rows / batch_size)):
+                sampled_rows = self._sample_batch(
+                    batch_size, batch_size_per_try=batch_size, progress_bar=progress_bar)
+                sampled.append(sampled_rows)

         return pd.concat(sampled, ignore_index=True)

@@ -574,14 +576,14 @@ def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None,
         """
         num_rows = functools.reduce(
             lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0)
-        progress_bar = tqdm(total=num_rows)
         conditions = self._make_condition_dfs(conditions)

-        sampled = pd.DataFrame()
-        for condition_dataframe in conditions:
-            sampled_for_condition = self._sample_with_conditions(
-                condition_dataframe, max_tries, batch_size_per_try, progress_bar)
-            sampled = pd.concat([sampled, sampled_for_condition], ignore_index=True)
+        with tqdm(total=num_rows) as progress_bar:
+            sampled = pd.DataFrame()
+            for condition_dataframe in conditions:
+                sampled_for_condition = self._sample_with_conditions(
+                    condition_dataframe, max_tries, batch_size_per_try, progress_bar)
+                sampled = pd.concat([sampled, sampled_for_condition], ignore_index=True)

         return sampled

@@ -615,8 +617,11 @@ def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_
                 * any of the conditions' columns are not valid.
                 * no rows could be generated.
         """
-        return self._sample_with_conditions(
-            known_columns, max_tries, batch_size_per_try, tqdm(total=len(known_columns)))
+        with tqdm(total=len(known_columns)) as progress_bar:
+            sampled = self._sample_with_conditions(
+                known_columns, max_tries, batch_size_per_try, progress_bar)
+
+        return sampled

     def _get_parameters(self):
         raise NonParametricError()
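The fix above matters because a tqdm bar that is never closed keeps redrawing on top of later terminal output. A generic sketch of the pattern the commit adopts (not SDV code): constructing the bar in a `with` block guarantees `close()` runs whether the body succeeds or raises.

```python
# Generic sketch (not SDV code): tqdm.__exit__ calls close() even on error.
from tqdm import tqdm

def consume(batches):
    total = sum(len(batch) for batch in batches)
    results = []
    with tqdm(total=total) as progress_bar:
        for batch in batches:
            results.extend(batch)            # stand-in for real sampling work
            progress_bar.update(len(batch))
    return results                           # bar is closed here either way

consume([[1, 2, 3], [4, 5]])
```
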
""" + num_rows = functools.reduce( + lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0) + progress_bar = tqdm(total=num_rows) conditions = self._make_condition_dfs(conditions) sampled = pd.DataFrame() for condition_dataframe in conditions: sampled_for_condition = self._sample_with_conditions( - condition_dataframe, max_tries, batch_size_per_try) + condition_dataframe, max_tries, batch_size_per_try, progress_bar) sampled = pd.concat([sampled, sampled_for_condition], ignore_index=True) return sampled @@ -577,7 +615,8 @@ def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_ * any of the conditions' columns are not valid. * no rows could be generated. """ - return self._sample_with_conditions(known_columns, max_tries, batch_size_per_try) + return self._sample_with_conditions( + known_columns, max_tries, batch_size_per_try, tqdm(total=len(known_columns))) def _get_parameters(self): raise NonParametricError() From 4614e7622bd952945ac52dfe432f48ba986f349a Mon Sep 17 00:00:00 2001 From: Katharine Xiao <2405771+katxiao@users.noreply.github.com> Date: Thu, 10 Feb 2022 16:09:58 -0500 Subject: [PATCH 2/6] Make sure to close progress bar --- sdv/tabular/base.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py index 6fa5b804b..4ce9fab35 100644 --- a/sdv/tabular/base.py +++ b/sdv/tabular/base.py @@ -347,7 +347,9 @@ def _sample_batch(self, num_rows=None, max_tries=100, batch_size_per_try=None, batch_size_per_try, conditions, transformed_conditions, float_rtol, sampled, ) - progress_bar.update(min(num_valid - prev_num_valid, remaining)) + if num_valid > 0: + progress_bar.update(min(num_valid - prev_num_valid, remaining)) + remaining = num_rows - num_valid if remaining > 0: LOGGER.info( @@ -439,15 +441,15 @@ def sample(self, num_rows, randomize_samples=True, batch_size=None): raise ValueError('You must specify the number of rows to sample (e.g. 
num_rows=100).') batch_size = min(batch_size, num_rows) if batch_size else num_rows - progress_bar = tqdm(total=num_rows) - progress_bar.set_description( - f'Sampling {num_rows} rows of data in batches of size {batch_size}') sampled = [] - for step in range(math.ceil(num_rows / batch_size)): - sampled_rows = self._sample_batch( - batch_size, batch_size_per_try=batch_size, progress_bar=progress_bar) - sampled.append(sampled_rows) + with tqdm(total=num_rows) as progress_bar: + progress_bar.set_description( + f'Sampling {num_rows} rows of data in batches of size {batch_size}') + for step in range(math.ceil(num_rows / batch_size)): + sampled_rows = self._sample_batch( + batch_size, batch_size_per_try=batch_size, progress_bar=progress_bar) + sampled.append(sampled_rows) return pd.concat(sampled, ignore_index=True) @@ -574,14 +576,14 @@ def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None, """ num_rows = functools.reduce( lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0) - progress_bar = tqdm(total=num_rows) conditions = self._make_condition_dfs(conditions) - sampled = pd.DataFrame() - for condition_dataframe in conditions: - sampled_for_condition = self._sample_with_conditions( - condition_dataframe, max_tries, batch_size_per_try, progress_bar) - sampled = pd.concat([sampled, sampled_for_condition], ignore_index=True) + with tqdm(total=num_rows) as progress_bar: + sampled = pd.DataFrame() + for condition_dataframe in conditions: + sampled_for_condition = self._sample_with_conditions( + condition_dataframe, max_tries, batch_size_per_try, progress_bar) + sampled = pd.concat([sampled, sampled_for_condition], ignore_index=True) return sampled @@ -615,8 +617,11 @@ def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_ * any of the conditions' columns are not valid. * no rows could be generated. """ - return self._sample_with_conditions( - known_columns, max_tries, batch_size_per_try, tqdm(total=len(known_columns))) + with tqdm(total=len(known_columns)) as progress_bar: + sampled = self._sample_with_conditions( + known_columns, max_tries, batch_size_per_try, progress_bar) + + return sampled def _get_parameters(self): raise NonParametricError() From 19a0de9ba5e93707a823c4815683b270aeebbc81 Mon Sep 17 00:00:00 2001 From: Katharine Xiao <2405771+katxiao@users.noreply.github.com> Date: Thu, 10 Feb 2022 19:11:42 -0500 Subject: [PATCH 3/6] Periodically write to file --- sdv/tabular/base.py | 79 ++++++++++++++++++++++++++++----- tests/unit/tabular/test_base.py | 18 +++++--- 2 files changed, 79 insertions(+), 18 deletions(-) diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py index 4ce9fab35..294d7216d 100644 --- a/sdv/tabular/base.py +++ b/sdv/tabular/base.py @@ -3,6 +3,7 @@ import functools import logging import math +import os import pickle import uuid from collections import defaultdict @@ -282,7 +283,7 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None, def _sample_batch(self, num_rows=None, max_tries=100, batch_size_per_try=None, conditions=None, transformed_conditions=None, float_rtol=0.01, - progress_bar=None): + progress_bar=None, output_file_path=None): """Sample a batch of rows with the given conditions. This will enter a reject-sampling loop in which rows will be sampled until @@ -321,6 +322,9 @@ def _sample_batch(self, num_rows=None, max_tries=100, batch_size_per_try=None, progress_bar (tqdm.tqdm or None): The progress bar to update when sampling. 
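The core of the periodic writing added above is the incremental CSV append inside `_sample_batch`: the first write creates the file with a header, and each later batch of newly validated rows is appended without one. A standalone sketch of that logic, lifted from the hunk above (the file name is illustrative):

```python
# Standalone sketch of the append logic used in _sample_batch above.
import os
import pandas as pd

def write_increment(rows, output_file_path):
    # First write creates the file with a header; later writes append without.
    append_kwargs = {'mode': 'a', 'header': False} if os.path.exists(
        output_file_path) else {}
    rows.to_csv(output_file_path, index=False, **append_kwargs)

write_increment(pd.DataFrame({'a': [1, 2]}), 'sampled_rows.csv')  # creates
write_increment(pd.DataFrame({'a': [3, 4]}), 'sampled_rows.csv')  # appends
```

Writing only the `tail(num_increase)` of the validated head means a crash mid-run still leaves every row that cleared reject sampling on disk, with no duplicates.
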
From 0eef2d91db5a1c8a1927b2e2dfd6279777a2fac6 Mon Sep 17 00:00:00 2001
From: Katharine Xiao <2405771+katxiao@users.noreply.github.com>
Date: Fri, 11 Feb 2022 18:08:24 -0500
Subject: [PATCH 4/6] add unit tests

---
 sdv/tabular/base.py             |  5 +-
 tests/unit/tabular/test_base.py | 90 ++++++++++++++++++++++++++++++++-
 2 files changed, 93 insertions(+), 2 deletions(-)

diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py
index 294d7216d..48691aa72 100644
--- a/sdv/tabular/base.py
+++ b/sdv/tabular/base.py
@@ -471,6 +471,9 @@ def sample(self, num_rows, randomize_samples=True, batch_size=None, output_file_
         if num_rows is None:
             raise ValueError('You must specify the number of rows to sample (e.g. num_rows=100).')

+        if num_rows == 0:
+            return pd.DataFrame()
+
         output_file_path = self._validate_file_path(output_file_path)

         batch_size = min(batch_size, num_rows) if batch_size else num_rows
@@ -640,7 +643,7 @@ def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None,
         return sampled

     def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_try=None,
-                                 randomize_samples=True, outout_file_path=None):
+                                 randomize_samples=True, output_file_path=None):
         """Sample rows from this table.

diff --git a/tests/unit/tabular/test_base.py b/tests/unit/tabular/test_base.py
index a00af28ac..721fcb5bd 100644
--- a/tests/unit/tabular/test_base.py
+++ b/tests/unit/tabular/test_base.py
@@ -173,6 +173,73 @@ def test_sample_num_rows_none(self):
                 match=r'You must specify the number of rows to sample \(e.g. num_rows=100\)'):
             model.sample(num_rows)

+    def test_sample_batch_size(self):
+        """Test the `BaseTabularModel.sample` method with a valid `batch_size` argument.
+
+        Expect that the expected calls to `_sample_batch` are made.
+
+        Input:
+            - num_rows = 10
+            - batch_size = 5
+        Output:
+            - The requested number of sampled rows.
+        Side Effect:
+            - Call `_sample_batch` method twice with the expected number of rows.
+        """
+        # Setup
+        gaussian_copula = Mock(spec_set=GaussianCopula)
+        sampled_data = pd.DataFrame({
+            "column1": [28, 28, 21, 1, 2],
+            "column2": [37, 37, 1, 4, 5],
+            "column3": [93, 93, 6, 4, 12],
+        })
+        gaussian_copula._sample_batch.side_effect = [sampled_data, sampled_data]
+
+        # Run
+        output = BaseTabularModel.sample(gaussian_copula, 10, batch_size=5)
+
+        # Assert
+        assert gaussian_copula._sample_batch.has_calls([
+            call(5, batch_size_per_try=5, progress_bar=ANY, output_file_path=None),
+            call(5, batch_size_per_try=5, progress_bar=ANY, output_file_path=None),
+        ])
+        assert len(output) == 10
+
+    def test__sample_batch_with_batch_size_per_try(self):
+        """Test the `BaseTabularModel._sample_batch` method with `batch_size_per_try`.
+
+        Expect that the expected calls to `_sample_rows` are made.
+
+        Input:
+            - num_rows = 10
+            - batch_size_per_try = 5
+        Output:
+            - The requested number of sampled rows.
+        Side Effect:
+            - Call `_sample_rows` method twice with the expected number of rows.
+        """
+        # Setup
+        gaussian_copula = Mock(spec_set=GaussianCopula)
+        sampled_data = pd.DataFrame({
+            "column1": [28, 28, 21, 1, 2],
+            "column2": [37, 37, 1, 4, 5],
+            "column3": [93, 93, 6, 4, 12],
+        })
+        gaussian_copula._sample_rows.side_effect = [
+            (sampled_data, 5),
+            (sampled_data.append(sampled_data, ignore_index=False), 10),
+        ]
+
+        # Run
+        output = BaseTabularModel._sample_batch(gaussian_copula, num_rows=10, batch_size_per_try=5)
+
+        # Assert
+        assert gaussian_copula._sample_rows.has_calls([
+            call(5, None, None, 0.01, DataFrameMatcher(pd.DataFrame())),
+            call(5, None, None, 0.01, DataFrameMatcher(sampled_data)),
+        ])
+        assert len(output) == 10
+
     def test_sample_conditions_with_multiple_conditions(self):
         """Test the `BaseTabularModel.sample_conditions` method with multiple condtions.
@@ -186,6 +253,7 @@ def test_sample_conditions_with_multiple_conditions(self):
         """
         # Setup
         gaussian_copula = Mock(spec_set=GaussianCopula)
+        gaussian_copula._validate_file_path.return_value = None

         condition_values1 = {'cola': 'a'}
         condition1 = Condition(condition_values1, num_rows=2)
@@ -236,6 +304,7 @@ def test_sample_remaining_columns(self):
         """
         # Setup
         gaussian_copula = Mock(spec_set=GaussianCopula)
+        gaussian_copula._validate_file_path.return_value = None

         conditions = pd.DataFrame([{'cola': 'a'}] * 5)
@@ -250,7 +319,7 @@ def test_sample_remaining_columns(self):

         # Asserts
         gaussian_copula._sample_with_conditions.assert_called_once_with(
-            DataFrameMatcher(conditions), 100, None)
+            DataFrameMatcher(conditions), 100, None, ANY, None)
         pd.testing.assert_frame_equal(out, sampled)

     def test__sample_with_conditions_invalid_column(self):
@@ -279,6 +348,25 @@ def test__sample_with_conditions_invalid_column(self):
                 'Use a column name that was present in the original data.')):
             GaussianCopula._sample_with_conditions(gaussian_copula, conditions, 100, None)

+    @patch('sdv.tabular.base.os.path')
+    def test__validate_file_path(self, path_mock):
+        """Test the `BaseTabularModel._validate_file_path` method.
+
+        Expect that an error is thrown if the file path already exists.
+
+        Input:
+            - A file path that already exists.
+        Side Effects:
+            - An AssertionError.
+        """
+        # Setup
+        path_mock.exists.return_value = True
+        gaussian_copula = Mock(spec_set=GaussianCopula)
+
+        # Run and Assert
+        with pytest.raises(AssertionError):
+            BaseTabularModel._validate_file_path(gaussian_copula, 'file_path')
+
 @patch('sdv.tabular.base.Table', spec_set=Table)
 def test__init__passes_correct_parameters(metadata_mock):
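Two behaviors this commit pins down are worth a caller-side sketch: `sample(0)` now short-circuits to an empty frame, and `_validate_file_path` refuses to overwrite an existing output file. A minimal sketch, assuming the fitted `model` from the first example (the file name is illustrative):

```python
import pandas as pd

empty = model.sample(0)                        # short-circuits before sampling
assert isinstance(empty, pd.DataFrame) and empty.empty

model.sample(10, output_file_path='run.csv')   # rows stream into run.csv
try:
    model.sample(10, output_file_path='run.csv')
except AssertionError as error:
    print(error)                               # '<abs path>/run.csv already exists.'
```
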
""" + output_file_path = self._validate_file_path(output_file_path) + with tqdm(total=len(known_columns)) as progress_bar: sampled = self._sample_with_conditions( - known_columns, max_tries, batch_size_per_try, progress_bar) + known_columns, max_tries, batch_size_per_try, progress_bar, output_file_path) return sampled diff --git a/tests/unit/tabular/test_base.py b/tests/unit/tabular/test_base.py index 2caa902c1..a00af28ac 100644 --- a/tests/unit/tabular/test_base.py +++ b/tests/unit/tabular/test_base.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock, call, patch +from unittest.mock import ANY, Mock, call, patch import pandas as pd import pytest @@ -69,6 +69,8 @@ def test__sample_with_conditions_no_transformed_columns(self): None, 100, None, + progress_bar=None, + output_file_path=None, ) pd.testing.assert_frame_equal(out, expected) @@ -212,8 +214,10 @@ def test_sample_conditions_with_multiple_conditions(self): # Asserts gaussian_copula._sample_with_conditions.assert_has_calls([ - call(DataFrameMatcher(pd.DataFrame([condition_values1] * 2)), 100, None), - call(DataFrameMatcher(pd.DataFrame([condition_values2] * 3)), 100, None), + call(DataFrameMatcher(pd.DataFrame([condition_values1] * 2)), 100, + None, ANY, None), + call(DataFrameMatcher(pd.DataFrame([condition_values2] * 3)), 100, + None, ANY, None), ]) pd.testing.assert_frame_equal(out, expected) @@ -440,7 +444,7 @@ def test__sample_with_conditions_empty_transformed_conditions(): pd.testing.assert_series_equal(args[0]['column1'], conditions_series) assert kwargs['on_missing_column'] == 'drop' model._metadata.transform.assert_called_once() - model._sample_batch.assert_called_with(5, 100, None, conditions, None, 0.01) + model._sample_batch.assert_called_with(5, 100, None, conditions, None, 0.01, None, None) pd.testing.assert_frame_equal(output, expected_output) @@ -502,13 +506,13 @@ def test__sample_with_conditions_transform_conditions_correctly(): assert kwargs['on_missing_column'] == 'drop' model._metadata.transform.assert_called_once() model._sample_batch.assert_any_call( - 3, 100, None, {'column1': 25}, {'transformed_column': 50}, 0.01 + 3, 100, None, {'column1': 25}, {'transformed_column': 50}, 0.01, None, None, ) model._sample_batch.assert_any_call( - 1, 100, None, {'column1': 30}, {'transformed_column': 60}, 0.01 + 1, 100, None, {'column1': 30}, {'transformed_column': 60}, 0.01, None, None, ) model._sample_batch.assert_any_call( - 1, 100, None, {'column1': 30}, {'transformed_column': 70}, 0.01 + 1, 100, None, {'column1': 30}, {'transformed_column': 70}, 0.01, None, None, ) From 0eef2d91db5a1c8a1927b2e2dfd6279777a2fac6 Mon Sep 17 00:00:00 2001 From: Katharine Xiao <2405771+katxiao@users.noreply.github.com> Date: Fri, 11 Feb 2022 18:08:24 -0500 Subject: [PATCH 4/6] add unit tests --- sdv/tabular/base.py | 5 +- tests/unit/tabular/test_base.py | 90 ++++++++++++++++++++++++++++++++- 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py index 294d7216d..48691aa72 100644 --- a/sdv/tabular/base.py +++ b/sdv/tabular/base.py @@ -471,6 +471,9 @@ def sample(self, num_rows, randomize_samples=True, batch_size=None, output_file_ if num_rows is None: raise ValueError('You must specify the number of rows to sample (e.g. 
num_rows=100).') + if num_rows == 0: + return pd.DataFrame() + output_file_path = self._validate_file_path(output_file_path) batch_size = min(batch_size, num_rows) if batch_size else num_rows @@ -640,7 +643,7 @@ def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None, return sampled def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_try=None, - randomize_samples=True, outout_file_path=None): + randomize_samples=True, output_file_path=None): """Sample rows from this table. Args: diff --git a/tests/unit/tabular/test_base.py b/tests/unit/tabular/test_base.py index a00af28ac..721fcb5bd 100644 --- a/tests/unit/tabular/test_base.py +++ b/tests/unit/tabular/test_base.py @@ -173,6 +173,73 @@ def test_sample_num_rows_none(self): match=r'You must specify the number of rows to sample \(e.g. num_rows=100\)'): model.sample(num_rows) + def test_sample_batch_size(self): + """Test the `BaseTabularModel.sample` method with a valid `batch_size` argument. + + Expect that the expected calls to `_sample_batch` are made. + + Input: + - num_rows = 10 + - batch_size = 5 + Output: + - The requested number of sampled rows. + Side Effect: + - Call `_sample_batch` method twice with the expected number of rows. + """ + # Setup + gaussian_copula = Mock(spec_set=GaussianCopula) + sampled_data = pd.DataFrame({ + "column1": [28, 28, 21, 1, 2], + "column2": [37, 37, 1, 4, 5], + "column3": [93, 93, 6, 4, 12], + }) + gaussian_copula._sample_batch.side_effect = [sampled_data, sampled_data] + + # Run + output = BaseTabularModel.sample(gaussian_copula, 10, batch_size=5) + + # Assert + assert gaussian_copula._sample_batch.has_calls([ + call(5, batch_size_per_try=5, progress_bar=ANY, output_file_path=None), + call(5, batch_size_per_try=5, progress_bar=ANY, output_file_path=None), + ]) + assert len(output) == 10 + + def test__sample_batch_with_batch_size_per_try(self): + """Test the `BaseTabularModel._sample_batch` method with `batch_size_per_try`. + + Expect that the expected calls to `_sample_rows` are made. + + Input: + - num_rows = 10 + - batch_size_per_try = 5 + Output: + - The requested number of sampled rows. + Side Effect: + - Call `_sample_rows` method twice with the expected number of rows. + """ + # Setup + gaussian_copula = Mock(spec_set=GaussianCopula) + sampled_data = pd.DataFrame({ + "column1": [28, 28, 21, 1, 2], + "column2": [37, 37, 1, 4, 5], + "column3": [93, 93, 6, 4, 12], + }) + gaussian_copula._sample_rows.side_effect = [ + (sampled_data, 5), + (sampled_data.append(sampled_data, ignore_index=False), 10), + ] + + # Run + output = BaseTabularModel._sample_batch(gaussian_copula, num_rows=10, batch_size_per_try=5) + + # Assert + assert gaussian_copula._sample_rows.has_calls([ + call(5, None, None, 0.01, DataFrameMatcher(pd.DataFrame())), + call(5, None, None, 0.01, DataFrameMatcher(sampled_data)), + ]) + assert len(output) == 10 + def test_sample_conditions_with_multiple_conditions(self): """Test the `BaseTabularModel.sample_conditions` method with multiple condtions. 
@@ -186,6 +253,7 @@ def test_sample_conditions_with_multiple_conditions(self): """ # Setup gaussian_copula = Mock(spec_set=GaussianCopula) + gaussian_copula._validate_file_path.return_value = None condition_values1 = {'cola': 'a'} condition1 = Condition(condition_values1, num_rows=2) @@ -236,6 +304,7 @@ def test_sample_remaining_columns(self): """ # Setup gaussian_copula = Mock(spec_set=GaussianCopula) + gaussian_copula._validate_file_path.return_value = None conditions = pd.DataFrame([{'cola': 'a'}] * 5) @@ -250,7 +319,7 @@ def test_sample_remaining_columns(self): # Asserts gaussian_copula._sample_with_conditions.assert_called_once_with( - DataFrameMatcher(conditions), 100, None) + DataFrameMatcher(conditions), 100, None, ANY, None) pd.testing.assert_frame_equal(out, sampled) def test__sample_with_conditions_invalid_column(self): @@ -279,6 +348,25 @@ def test__sample_with_conditions_invalid_column(self): 'Use a column name that was present in the original data.')): GaussianCopula._sample_with_conditions(gaussian_copula, conditions, 100, None) + @patch('sdv.tabular.base.os.path') + def test__validate_file_path(self, path_mock): + """Test the `BaseTabularModel._validate_file_path` method. + + Expect that an error is thrown if the file path already exists. + + Input: + - A file path that already exists. + Side Effects: + - An AssertionError. + """ + # Setup + path_mock.exists.return_value = True + gaussian_copula = Mock(spec_set=GaussianCopula) + + # Run and Assert + with pytest.raises(AssertionError): + BaseTabularModel._validate_file_path(gaussian_copula, 'file_path') + @patch('sdv.tabular.base.Table', spec_set=Table) def test__init__passes_correct_parameters(metadata_mock): From e878568392fe351e5292de0bc90de16095170ec1 Mon Sep 17 00:00:00 2001 From: Katharine Xiao <2405771+katxiao@users.noreply.github.com> Date: Wed, 16 Feb 2022 17:09:04 -0500 Subject: [PATCH 5/6] cr comments --- sdv/tabular/base.py | 3 --- tests/unit/tabular/test_base.py | 15 ++++++++------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py index 48691aa72..777291649 100644 --- a/sdv/tabular/base.py +++ b/sdv/tabular/base.py @@ -438,16 +438,13 @@ def _conditionally_sample_rows(self, dataframe, condition, transformed_condition def _validate_file_path(self, output_file_path): output_path = None - if output_file_path: output_path = os.path.abspath(output_file_path) - if os.path.exists(output_path): raise AssertionError(f'{output_path} already exists.') return output_path - @validate_sample_args def sample(self, num_rows, randomize_samples=True, batch_size=None, output_file_path=None): """Sample rows from this table. 
diff --git a/tests/unit/tabular/test_base.py b/tests/unit/tabular/test_base.py index 721fcb5bd..b654797e8 100644 --- a/tests/unit/tabular/test_base.py +++ b/tests/unit/tabular/test_base.py @@ -189,9 +189,9 @@ def test_sample_batch_size(self): # Setup gaussian_copula = Mock(spec_set=GaussianCopula) sampled_data = pd.DataFrame({ - "column1": [28, 28, 21, 1, 2], - "column2": [37, 37, 1, 4, 5], - "column3": [93, 93, 6, 4, 12], + 'column1': [28, 28, 21, 1, 2], + 'column2': [37, 37, 1, 4, 5], + 'column3': [93, 93, 6, 4, 12], }) gaussian_copula._sample_batch.side_effect = [sampled_data, sampled_data] @@ -221,9 +221,9 @@ def test__sample_batch_with_batch_size_per_try(self): # Setup gaussian_copula = Mock(spec_set=GaussianCopula) sampled_data = pd.DataFrame({ - "column1": [28, 28, 21, 1, 2], - "column2": [37, 37, 1, 4, 5], - "column3": [93, 93, 6, 4, 12], + 'column1': [28, 28, 21, 1, 2], + 'column2': [37, 37, 1, 4, 5], + 'column3': [93, 93, 6, 4, 12], }) gaussian_copula._sample_rows.side_effect = [ (sampled_data, 5), @@ -361,10 +361,11 @@ def test__validate_file_path(self, path_mock): """ # Setup path_mock.exists.return_value = True + path_mock.abspath.return_value = 'path/to/file' gaussian_copula = Mock(spec_set=GaussianCopula) # Run and Assert - with pytest.raises(AssertionError): + with pytest.raises(AssertionError, match='path/to/file already exists'): BaseTabularModel._validate_file_path(gaussian_copula, 'file_path') From 37f962ed207ea475b03a1da0c5f33d1e3f6c9684 Mon Sep 17 00:00:00 2001 From: Katharine Xiao <2405771+katxiao@users.noreply.github.com> Date: Thu, 17 Feb 2022 17:45:55 -0500 Subject: [PATCH 6/6] fix test --- tests/integration/tabular/test_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/tabular/test_base.py b/tests/integration/tabular/test_base.py index 98416a2c0..b90d4f4ca 100644 --- a/tests/integration/tabular/test_base.py +++ b/tests/integration/tabular/test_base.py @@ -265,7 +265,6 @@ def test_conditional_sampling_constraint_uses_reject_sampling(gm_mock): }) sample_calls = model._model.sample.mock_calls assert len(sample_calls) == 2 - model._model.sample.assert_any_call(5, conditions=expected_transformed_conditions) model._model.sample.assert_any_call(50, conditions=expected_transformed_conditions) pd.testing.assert_frame_equal(sampled_data, expected_data)
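Taken together, the series leaves all three sampling entry points sharing one progress bar per call and one optional output file. An end-to-end sketch of the conditional APIs, again assuming the fitted `model` from the first example (`Condition` import per the SDV API of this period; column names and file name are illustrative):

```python
import pandas as pd
from sdv.sampling import Condition

conditions = [
    Condition({'gender': 'M'}, num_rows=2),
    Condition({'gender': 'F'}, num_rows=3),
]
# One bar totals 5 rows across both conditions; validated rows are appended
# to the CSV as each batch clears reject sampling.
sampled = model.sample_conditions(conditions, output_file_path='conditional.csv')

# Fill in every column not fixed by the known ones.
known = pd.DataFrame({'gender': ['M'] * 5})
completed = model.sample_remaining_columns(known)
```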