From acdceb720430f826091d7a32da7bc73bce533ec5 Mon Sep 17 00:00:00 2001 From: Katharine Xiao <2405771+katxiao@users.noreply.github.com> Date: Thu, 24 Mar 2022 17:05:56 -0700 Subject: [PATCH] Fix field distribution arg in GaussianCopula (#743) * Fix field distribution setting * add unit test * cr --- sdv/tabular/copulas.py | 8 +++-- tests/unit/tabular/test_copulas.py | 56 ++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/sdv/tabular/copulas.py b/sdv/tabular/copulas.py index bd0b38074..de8f5244a 100644 --- a/sdv/tabular/copulas.py +++ b/sdv/tabular/copulas.py @@ -289,9 +289,11 @@ def _fit(self, table_data): Data to be fitted. """ for column in table_data.columns: - distribution = self._field_distributions.get(column) - if not distribution: - self._field_distributions[column] = self._default_distribution + if column not in self._field_distributions: + # Check if the column is a derived column. + column_name = column.replace('.value', '') + self._field_distributions[column] = self._field_distributions.get( + column_name, self._default_distribution) self._model = copulas.multivariate.GaussianMultivariate( distribution=self._field_distributions) diff --git a/tests/unit/tabular/test_copulas.py b/tests/unit/tabular/test_copulas.py index 4ffccb50c..9df6976f6 100644 --- a/tests/unit/tabular/test_copulas.py +++ b/tests/unit/tabular/test_copulas.py @@ -241,6 +241,62 @@ def test__fit(self, gm_mock): pd.testing.assert_frame_equal(expected_data, passed_table_data) gaussian_copula._update_metadata.assert_called_once_with() + @patch('sdv.tabular.copulas.copulas.multivariate.GaussianMultivariate', + spec_set=GaussianMultivariate) + def test__fit_with_transformed_columns(self, gm_mock): + """Test the ``GaussianCopula._fit`` method with transformed columns. + + The ``_fit`` method is expected to: + - Call the _get_distribution method to build the distributions dict. + - Set the output from _get_distribution method as self._distribution. 
+ - Create a GaussianMultivariate object with the self._distribution value. + - Store the GaussianMultivariate instance in the self._model attribute. + - Fit the GaussianMultivariate instance with the given table data, unmodified. + - Call the _update_metadata method. + + Setup: + - mock _get_distribution to return a distribution dict + + Input: + - pandas.DataFrame + + Expected Output: + - None + + Side Effects: + - self._distribution is set to the output from _get_distribution + - GaussianMultivariate is called with self._distribution as input + - GaussianMultivariate output is stored as self._model + - self._model.fit is called with the input dataframe + - self._update_metadata is called without arguments + """ + # Setup + gaussian_copula = Mock(spec_set=GaussianCopula) + gaussian_copula._field_distributions = {'a': 'a_distribution'} + + # Run + data = pd.DataFrame({ + 'a.value': [1, 2, 3] + }) + out = GaussianCopula._fit(gaussian_copula, data) + + # asserts + assert out is None + assert gaussian_copula._field_distributions == { + 'a': 'a_distribution', 'a.value': 'a_distribution'} + gm_mock.assert_called_once_with( + distribution={'a': 'a_distribution', 'a.value': 'a_distribution'}) + + assert gaussian_copula._model == gm_mock.return_value + expected_data = pd.DataFrame({ + 'a.value': [1, 2, 3] + }) + call_args = gaussian_copula._model.fit.call_args_list + passed_table_data = call_args[0][0][0] + + pd.testing.assert_frame_equal(expected_data, passed_table_data) + gaussian_copula._update_metadata.assert_called_once_with() + def test__sample(self): """Test the ``GaussianCopula._sample`` method.