Skip to content

Commit

Permalink
Add unit test for new functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Oct 31, 2024
1 parent f5d32da commit 2e55b66
Showing 1 changed file with 60 additions and 0 deletions.
60 changes: 60 additions & 0 deletions tests/unit/single_table/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,66 @@ def test__fit_mocked_instance(self):
instance._initialize_model.assert_called_once_with(numerical_distributions)
instance._fit_model.assert_called_once_with(processed_data)

def test__learn_num_rows(self):
"""Test that the `_learn_num_rows` method returns the correct number of rows."""
# Setup
metadata = Metadata()
instance = GaussianCopulaSynthesizer(metadata)
processed_data = pd.DataFrame({'a': range(5), 'b': range(5)})

# Run
result = instance._learn_num_rows(processed_data)

# Assert
assert result == 5

def test__get_numerical_distributions_with_existing_columns(self):
"""Test that `_get_numerical_distributions` returns correct distributions."""
# Setup
metadata = Metadata()
instance = GaussianCopulaSynthesizer(metadata)
instance._numerical_distributions = {'a': 'dist_a', 'b': 'dist_b'}
instance._default_distribution = 'default_dist'

processed_data = Mock()
processed_data.columns = ['a', 'b', 'c']

# Run
result = instance._get_numerical_distributions(processed_data)

# Assert
expected_result = {'a': 'dist_a', 'b': 'dist_b', 'c': 'default_dist'}
assert result == expected_result

@patch('sdv.single_table.copulas.multivariate.GaussianMultivariate')
def test__initialize_model(self, mock_gaussian_multivariate):
"""Test that `_initialize_model` calls the GaussianMultivariate with correct parameters."""
# Setup
metadata = Metadata()
instance = GaussianCopulaSynthesizer(metadata)
numerical_distributions = {'a': 'dist_a', 'b': 'dist_b'}

# Run
model = instance._initialize_model(numerical_distributions)

# Assert
mock_gaussian_multivariate.assert_called_once_with(distribution=numerical_distributions)
assert model == mock_gaussian_multivariate.return_value

def test__fit_model(self):
"""Test that `_fit_model` fits the model correctly."""
# Setup
metadata = Metadata()
instance = GaussianCopulaSynthesizer(metadata)
instance._model = Mock()
processed_data = Mock()

# Run
instance._fit_model(processed_data)

# Assert
instance._model.fit.assert_called_once_with(processed_data)

def test__get_nearest_correlation_matrix_valid(self):
"""Test ``_get_nearest_correlation_matrix`` with a psd input.
Expand Down

0 comments on commit 2e55b66

Please sign in to comment.