Skip to content

Commit

Permalink
Modularize GaussianCopulaSynthesizer fit
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Oct 31, 2024
1 parent 05f7494 commit f5d32da
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
16 changes: 14 additions & 2 deletions sdv/single_table/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,28 @@ def _fit(self, processed_data):
log_numerical_distributions_error(
self.numerical_distributions, processed_data.columns, LOGGER
)
self._num_rows = len(processed_data)
self._num_rows = self._learn_num_rows(processed_data)
numerical_distributions = self._get_numerical_distributions(processed_data)
self._model = self._initialize_model(numerical_distributions)
self._fit_model(processed_data)

def _learn_num_rows(self, processed_data):
return len(processed_data)

def _get_numerical_distributions(self, processed_data):
numerical_distributions = deepcopy(self._numerical_distributions)
for column in processed_data.columns:
if column not in numerical_distributions:
numerical_distributions[column] = self._numerical_distributions.get(
column, self._default_distribution
)
self._model = multivariate.GaussianMultivariate(distribution=numerical_distributions)

return numerical_distributions

def _initialize_model(self, numerical_distributions):
return multivariate.GaussianMultivariate(distribution=numerical_distributions)

def _fit_model(self, processed_data):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', module='scipy')
self._model.fit(processed_data)
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/single_table/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,23 @@ def test__fit(self, mock_multivariate, mock_warnings):
mock_warnings.catch_warnings.assert_called_once()
instance._num_rows == 10

def test__fit_mocked_instance(self):
"""Test that the `_fit` method calls the modularized functions."""
# Setup
instance = Mock(numerical_distributions={})
processed_data = Mock(columns=[])
numerical_distributions = Mock()
instance._get_numerical_distributions.return_value = numerical_distributions

# Run
GaussianCopulaSynthesizer._fit(instance, processed_data)

# Assert
instance._learn_num_rows.assert_called_once_with(processed_data)
instance._get_numerical_distributions.assert_called_once_with(processed_data)
instance._initialize_model.assert_called_once_with(numerical_distributions)
instance._fit_model.assert_called_once_with(processed_data)

def test__get_nearest_correlation_matrix_valid(self):
"""Test ``_get_nearest_correlation_matrix`` with a psd input.
Expand Down

0 comments on commit f5d32da

Please sign in to comment.