Skip to content

Commit

Permalink
cr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Feb 17, 2022
1 parent ad146d3 commit 50b1585
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
3 changes: 0 additions & 3 deletions sdv/tabular/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,16 +444,13 @@ def _conditionally_sample_rows(self, dataframe, condition, transformed_condition

def _validate_file_path(self, output_file_path):
output_path = None

if output_file_path:
output_path = os.path.abspath(output_file_path)

if os.path.exists(output_path):
raise AssertionError(f'{output_path} already exists.')

return output_path


@validate_sample_args
def sample(self, num_rows, randomize_samples=True, batch_size=None, output_file_path=None):
"""Sample rows from this table.
Expand Down
15 changes: 8 additions & 7 deletions tests/unit/tabular/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,9 @@ def test_sample_batch_size(self):
# Setup
gaussian_copula = Mock(spec_set=GaussianCopula)
sampled_data = pd.DataFrame({
"column1": [28, 28, 21, 1, 2],
"column2": [37, 37, 1, 4, 5],
"column3": [93, 93, 6, 4, 12],
'column1': [28, 28, 21, 1, 2],
'column2': [37, 37, 1, 4, 5],
'column3': [93, 93, 6, 4, 12],
})
gaussian_copula._sample_batch.side_effect = [sampled_data, sampled_data]

Expand Down Expand Up @@ -221,9 +221,9 @@ def test__sample_batch_with_batch_size_per_try(self):
# Setup
gaussian_copula = Mock(spec_set=GaussianCopula)
sampled_data = pd.DataFrame({
"column1": [28, 28, 21, 1, 2],
"column2": [37, 37, 1, 4, 5],
"column3": [93, 93, 6, 4, 12],
'column1': [28, 28, 21, 1, 2],
'column2': [37, 37, 1, 4, 5],
'column3': [93, 93, 6, 4, 12],
})
gaussian_copula._sample_rows.side_effect = [
(sampled_data, 5),
Expand Down Expand Up @@ -359,10 +359,11 @@ def test__validate_file_path(self, path_mock):
"""
# Setup
path_mock.exists.return_value = True
path_mock.abspath.return_value = 'path/to/file'
gaussian_copula = Mock(spec_set=GaussianCopula)

# Run and Assert
with pytest.raises(AssertionError):
with pytest.raises(AssertionError, match='path/to/file already exists'):
BaseTabularModel._validate_file_path(gaussian_copula, 'file_path')


Expand Down

0 comments on commit 50b1585

Please sign in to comment.