From 0fbd9fc979d7d76e4168b2657cfa54db7bfa023c Mon Sep 17 00:00:00 2001 From: Katharine Xiao <2405771+katxiao@users.noreply.github.com> Date: Fri, 11 Feb 2022 18:08:24 -0500 Subject: [PATCH] add unit tests --- sdv/tabular/base.py | 5 +- tests/unit/tabular/test_base.py | 90 ++++++++++++++++++++++++++++++++- 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/sdv/tabular/base.py b/sdv/tabular/base.py index b124f3a7d..2d9d24280 100644 --- a/sdv/tabular/base.py +++ b/sdv/tabular/base.py @@ -483,6 +483,9 @@ def sample(self, num_rows, randomize_samples=True, batch_size=None, output_file_ raise ValueError( 'Error: You must specify the number of rows to sample (eg. num_rows=100).') + if num_rows == 0: + return pd.DataFrame() + self._set_fixed_seed(randomize_samples) output_file_path = self._validate_file_path(output_file_path) @@ -655,7 +658,7 @@ def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None, return sampled def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_try=None, - randomize_samples=True, outout_file_path=None): + randomize_samples=True, output_file_path=None): """Sample rows from this table. Args: diff --git a/tests/unit/tabular/test_base.py b/tests/unit/tabular/test_base.py index 12af2c37e..b534318b0 100644 --- a/tests/unit/tabular/test_base.py +++ b/tests/unit/tabular/test_base.py @@ -169,6 +169,73 @@ def test_sample_num_rows_none(self): with pytest.raises(ValueError): model.sample(num_rows) + def test_sample_batch_size(self): + """Test the `BaseTabularModel.sample` method with a valid `batch_size` argument. + + Expect that the expected calls to `_sample_batch` are made. + + Input: + - num_rows = 10 + - batch_size = 5 + Output: + - The requested number of sampled rows. + Side Effect: + - Call `_sample_batch` method twice with the expected number of rows. + """ + # Setup + gaussian_copula = Mock(spec_set=GaussianCopula) + sampled_data = pd.DataFrame({ + "column1": [28, 28, 21, 1, 2], + "column2": [37, 37, 1, 4, 5], + "column3": [93, 93, 6, 4, 12], + }) + gaussian_copula._sample_batch.side_effect = [sampled_data, sampled_data] + + # Run + output = BaseTabularModel.sample(gaussian_copula, 10, batch_size=5) + + # Assert + assert gaussian_copula._sample_batch.has_calls([ + call(5, batch_size_per_try=5, progress_bar=ANY, output_file_path=None), + call(5, batch_size_per_try=5, progress_bar=ANY, output_file_path=None), + ]) + assert len(output) == 10 + + def test__sample_batch_with_batch_size_per_try(self): + """Test the `BaseTabularModel._sample_batch` method with `batch_size_per_try`. + + Expect that the expected calls to `_sample_rows` are made. + + Input: + - num_rows = 10 + - batch_size_per_try = 5 + Output: + - The requested number of sampled rows. + Side Effect: + - Call `_sample_rows` method twice with the expected number of rows. + """ + # Setup + gaussian_copula = Mock(spec_set=GaussianCopula) + sampled_data = pd.DataFrame({ + "column1": [28, 28, 21, 1, 2], + "column2": [37, 37, 1, 4, 5], + "column3": [93, 93, 6, 4, 12], + }) + gaussian_copula._sample_rows.side_effect = [ + (sampled_data, 5), + (sampled_data.append(sampled_data, ignore_index=False), 10), + ] + + # Run + output = BaseTabularModel._sample_batch(gaussian_copula, num_rows=10, batch_size_per_try=5) + + # Assert + assert gaussian_copula._sample_rows.has_calls([ + call(5, None, None, 0.01, DataFrameMatcher(pd.DataFrame())), + call(5, None, None, 0.01, DataFrameMatcher(sampled_data)), + ]) + assert len(output) == 10 + def test_sample_randomize_samples_true(self): """Test the `BaseTabularModel.sample` method with `randomize_samples` set to True. @@ -206,6 +273,7 @@ def test_sample_conditions_with_multiple_conditions(self): """ # Setup gaussian_copula = Mock(spec_set=GaussianCopula) + gaussian_copula._validate_file_path.return_value = None condition_values1 = {'cola': 'a'} condition1 = Condition(condition_values1, num_rows=2) @@ -256,6 +324,7 @@ def test_sample_remaining_columns(self): """ # Setup gaussian_copula = Mock(spec_set=GaussianCopula) + gaussian_copula._validate_file_path.return_value = None conditions = pd.DataFrame([{'cola': 'a'}] * 5) @@ -270,7 +339,7 @@ def test_sample_remaining_columns(self): # Asserts gaussian_copula._sample_with_conditions.assert_called_once_with( - DataFrameMatcher(conditions), 100, None) + DataFrameMatcher(conditions), 100, None, ANY, None) pd.testing.assert_frame_equal(out, sampled) def test__sample_with_conditions_invalid_column(self): @@ -297,6 +366,25 @@ def test__sample_with_conditions_invalid_column(self): with pytest.raises(ValueError): GaussianCopula._sample_with_conditions(gaussian_copula, conditions, 100, None) + @patch('sdv.tabular.base.os.path') + def test__validate_file_path(self, path_mock): + """Test the `BaseTabularModel._validate_file_path` method. + + Expect that an error is thrown if the file path already exists. + + Input: + - A file path that already exists. + Side Effects: + - An AssertionError. + """ + # Setup + path_mock.exists.return_value = True + gaussian_copula = Mock(spec_set=GaussianCopula) + + # Run and Assert + with pytest.raises(AssertionError): + BaseTabularModel._validate_file_path(gaussian_copula, 'file_path') + @patch('sdv.tabular.base.Table', spec_set=Table) def test__init__passes_correct_parameters(metadata_mock):