Skip to content

Commit

Permalink
Fix field distribution arg in GaussianCopula (#743)
Browse files Browse the repository at this point in the history
* Fix field distribution setting

* add unit test

* cr
  • Loading branch information
katxiao authored Mar 25, 2022
1 parent d8ff924 commit acdceb7
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
8 changes: 5 additions & 3 deletions sdv/tabular/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,11 @@ def _fit(self, table_data):
Data to be fitted.
"""
for column in table_data.columns:
distribution = self._field_distributions.get(column)
if not distribution:
self._field_distributions[column] = self._default_distribution
if column not in self._field_distributions:
# Check if the column is a derived column.
column_name = column.replace('.value', '')
self._field_distributions[column] = self._field_distributions.get(
column_name, self._default_distribution)

self._model = copulas.multivariate.GaussianMultivariate(
distribution=self._field_distributions)
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/tabular/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,62 @@ def test__fit(self, gm_mock):
pd.testing.assert_frame_equal(expected_data, passed_table_data)
gaussian_copula._update_metadata.assert_called_once_with()

@patch('sdv.tabular.copulas.copulas.multivariate.GaussianMultivariate',
spec_set=GaussianMultivariate)
def test__fit_with_transformed_columns(self, gm_mock):
"""Test the ``GaussianCopula._fit`` method with transformed columns.
The ``_fit`` method is expected to:
- Call the _get_distribution method to build the distributions dict.
- Set the output from _get_distribution method as self._distribution.
- Create a GaussianMultivriate object with the self._distribution value.
- Store the GaussianMultivariate instance in the self._model attribute.
- Fit the GaussianMultivariate instance with the given table data, unmodified.
- Call the _update_metadata method.
Setup:
- mock _get_distribution to return a distribution dict
Input:
- pandas.DataFrame
Expected Output:
- None
Side Effects:
- self._distribution is set to the output from _get_distribution
- GaussianMultivariate is called with self._distribution as input
- GaussianMultivariate output is stored as self._model
- self._model.fit is called with the input dataframe
- self._update_metadata is called without arguments
"""
# Setup
gaussian_copula = Mock(spec_set=GaussianCopula)
gaussian_copula._field_distributions = {'a': 'a_distribution'}

# Run
data = pd.DataFrame({
'a.value': [1, 2, 3]
})
out = GaussianCopula._fit(gaussian_copula, data)

# asserts
assert out is None
assert gaussian_copula._field_distributions == {
'a': 'a_distribution', 'a.value': 'a_distribution'}
gm_mock.assert_called_once_with(
distribution={'a': 'a_distribution', 'a.value': 'a_distribution'})

assert gaussian_copula._model == gm_mock.return_value
expected_data = pd.DataFrame({
'a.value': [1, 2, 3]
})
call_args = gaussian_copula._model.fit.call_args_list
passed_table_data = call_args[0][0][0]

pd.testing.assert_frame_equal(expected_data, passed_table_data)
gaussian_copula._update_metadata.assert_called_once_with()

def test__sample(self):
"""Test the ``GaussianCopula._sample`` method.
Expand Down

0 comments on commit acdceb7

Please sign in to comment.