Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Feb 11, 2022
1 parent c44641f commit 30bd3cd
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 2 deletions.
2 changes: 1 addition & 1 deletion sdv/tabular/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None,
return sampled

def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_try=None,
randomize_samples=True, outout_file_path=None):
randomize_samples=True, output_file_path=None):
"""Sample rows from this table.
Args:
Expand Down
90 changes: 89 additions & 1 deletion tests/unit/tabular/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,73 @@ def test_sample_num_rows_none(self):
with pytest.raises(ValueError):
model.sample(num_rows)

def test_sample_batch_size(self):
    """Test the `BaseTabularModel.sample` method with a valid `batch_size` argument.

    Expect that the expected calls to `_sample_batch` are made.

    Input:
        - num_rows = 10
        - batch_size = 5
    Output:
        - The requested number of sampled rows.
    Side Effect:
        - Call `_sample_batch` method twice with the expected number of rows.
    """
    # Setup
    gaussian_copula = Mock(spec_set=GaussianCopula)
    # NOTE(review): mirrors the setup used by the other sampling tests in this
    # file; presumably `sample` routes `output_file_path` through
    # `_validate_file_path` -- confirm. Pinning it to None keeps the expected
    # `output_file_path=None` below meaningful instead of comparing to a Mock.
    gaussian_copula._validate_file_path.return_value = None
    sampled_data = pd.DataFrame({
        "column1": [28, 28, 21, 1, 2],
        "column2": [37, 37, 1, 4, 5],
        "column3": [93, 93, 6, 4, 12],
    })
    # Two batches of five rows each add up to the ten requested rows.
    gaussian_copula._sample_batch.side_effect = [sampled_data, sampled_data]

    # Run
    output = BaseTabularModel.sample(gaussian_copula, 10, batch_size=5)

    # Assert
    # `assert_has_calls` is the real Mock assertion. The previous
    # `assert mock.has_calls(...)` only auto-created a truthy child mock and
    # therefore could never fail.
    gaussian_copula._sample_batch.assert_has_calls([
        call(5, batch_size_per_try=5, progress_bar=ANY, output_file_path=None),
        call(5, batch_size_per_try=5, progress_bar=ANY, output_file_path=None),
    ])
    assert len(output) == 10

def test__sample_batch_with_batch_size_per_try(self):
    """Test the `BaseTabularModel._sample_batch` method with `batch_size_per_try`.

    Expect that the expected calls to `_sample_rows` are made.

    Input:
        - num_rows = 10
        - batch_size_per_try = 5
    Output:
        - The requested number of sampled rows.
    Side Effect:
        - Call `_sample_rows` method twice with the expected number of rows.
    """
    # Setup
    gaussian_copula = Mock(spec_set=GaussianCopula)
    sampled_data = pd.DataFrame({
        "column1": [28, 28, 21, 1, 2],
        "column2": [37, 37, 1, 4, 5],
        "column3": [93, 93, 6, 4, 12],
    })
    # `DataFrame.append` was removed in pandas 2.0; `pd.concat` builds the
    # same ten-row frame (original index labels kept, as before).
    doubled_data = pd.concat([sampled_data, sampled_data], ignore_index=False)
    # Each entry is the `(sampled, num_valid)` pair returned by one
    # `_sample_rows` call: five valid rows first, then ten.
    gaussian_copula._sample_rows.side_effect = [
        (sampled_data, 5),
        (doubled_data, 10),
    ]

    # Run
    output = BaseTabularModel._sample_batch(gaussian_copula, num_rows=10, batch_size_per_try=5)

    # Assert
    # `assert_has_calls` is the real Mock assertion. The previous
    # `assert mock.has_calls(...)` only auto-created a truthy child mock and
    # therefore could never fail.
    gaussian_copula._sample_rows.assert_has_calls([
        call(5, None, None, 0.01, DataFrameMatcher(pd.DataFrame())),
        call(5, None, None, 0.01, DataFrameMatcher(sampled_data)),
    ])
    assert len(output) == 10

def test_sample_randomize_samples_true(self):
"""Test the `BaseTabularModel.sample` method with `randomize_samples` set to True.
Expand Down Expand Up @@ -206,6 +273,7 @@ def test_sample_conditions_with_multiple_conditions(self):
"""
# Setup
gaussian_copula = Mock(spec_set=GaussianCopula)
gaussian_copula._validate_file_path.return_value = None

condition_values1 = {'cola': 'a'}
condition1 = Condition(condition_values1, num_rows=2)
Expand Down Expand Up @@ -256,6 +324,7 @@ def test_sample_remaining_columns(self):
"""
# Setup
gaussian_copula = Mock(spec_set=GaussianCopula)
gaussian_copula._validate_file_path.return_value = None

conditions = pd.DataFrame([{'cola': 'a'}] * 5)

Expand All @@ -270,7 +339,7 @@ def test_sample_remaining_columns(self):

# Asserts
gaussian_copula._sample_with_conditions.assert_called_once_with(
DataFrameMatcher(conditions), 100, None)
DataFrameMatcher(conditions), 100, None, ANY, None)
pd.testing.assert_frame_equal(out, sampled)

def test__sample_with_conditions_invalid_column(self):
Expand All @@ -297,6 +366,25 @@ def test__sample_with_conditions_invalid_column(self):
with pytest.raises(ValueError):
GaussianCopula._sample_with_conditions(gaussian_copula, conditions, 100, None)

@patch('sdv.tabular.base.os.path')
def test__validate_file_path(self, path_mock):
    """Check that `BaseTabularModel._validate_file_path` rejects an existing path.

    `os.path` inside `sdv.tabular.base` is patched so that `exists` reports
    True for any path.

    Input:
        - A file path that already exists.
    Side Effects:
        - An AssertionError.
    """
    # Setup
    model = Mock(spec_set=GaussianCopula)
    path_mock.exists.return_value = True

    # Run and Assert
    with pytest.raises(AssertionError):
        BaseTabularModel._validate_file_path(model, 'file_path')


@patch('sdv.tabular.base.Table', spec_set=Table)
def test__init__passes_correct_parameters(metadata_mock):
Expand Down

0 comments on commit 30bd3cd

Please sign in to comment.