From 9643a6a6cacf10f5d68cd4953b6c21846f19a736 Mon Sep 17 00:00:00 2001 From: Katharine Xiao <2405771+katxiao@users.noreply.github.com> Date: Wed, 16 Mar 2022 21:53:07 -0400 Subject: [PATCH 1/3] Fix field distribution setting --- sdv/tabular/copulas.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sdv/tabular/copulas.py b/sdv/tabular/copulas.py index bd0b38074..9b5a5ffaf 100644 --- a/sdv/tabular/copulas.py +++ b/sdv/tabular/copulas.py @@ -289,9 +289,11 @@ def _fit(self, table_data): Data to be fitted. """ for column in table_data.columns: - distribution = self._field_distributions.get(column) - if not distribution: - self._field_distributions[column] = self._default_distribution + if column not in self._field_distributions: + # Check if the column is a derived column. + column_name = column.replace('.value', '').replace('.is_null', '') + self._field_distributions[column] = self._field_distribution.get( + column_name, self._default_distribution) self._model = copulas.multivariate.GaussianMultivariate( distribution=self._field_distributions) From 30a47c81c87f9a05a256f16e5a51df0bc42ea43d Mon Sep 17 00:00:00 2001 From: Katharine Xiao <2405771+katxiao@users.noreply.github.com> Date: Wed, 23 Mar 2022 15:07:12 -0400 Subject: [PATCH 2/3] add unit test --- sdv/tabular/copulas.py | 2 +- tests/unit/tabular/test_copulas.py | 56 ++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/sdv/tabular/copulas.py b/sdv/tabular/copulas.py index 9b5a5ffaf..7fa2088de 100644 --- a/sdv/tabular/copulas.py +++ b/sdv/tabular/copulas.py @@ -292,7 +292,7 @@ def _fit(self, table_data): if column not in self._field_distributions: # Check if the column is a derived column. 
column_name = column.replace('.value', '').replace('.is_null', '') - self._field_distributions[column] = self._field_distribution.get( + self._field_distributions[column] = self._field_distributions.get( column_name, self._default_distribution) self._model = copulas.multivariate.GaussianMultivariate( diff --git a/tests/unit/tabular/test_copulas.py b/tests/unit/tabular/test_copulas.py index 4ffccb50c..9df6976f6 100644 --- a/tests/unit/tabular/test_copulas.py +++ b/tests/unit/tabular/test_copulas.py @@ -241,6 +241,62 @@ def test__fit(self, gm_mock): pd.testing.assert_frame_equal(expected_data, passed_table_data) gaussian_copula._update_metadata.assert_called_once_with() + @patch('sdv.tabular.copulas.copulas.multivariate.GaussianMultivariate', + spec_set=GaussianMultivariate) + def test__fit_with_transformed_columns(self, gm_mock): + """Test the ``GaussianCopula._fit`` method with transformed columns. + + The ``_fit`` method is expected to: + - Call the _get_distribution method to build the distributions dict. + - Set the output from _get_distribution method as self._distribution. + - Create a GaussianMultivariate object with the self._distribution value. + - Store the GaussianMultivariate instance in the self._model attribute. + - Fit the GaussianMultivariate instance with the given table data, unmodified. + - Call the _update_metadata method. 
+ + Setup: + - mock _get_distribution to return a distribution dict + + Input: + - pandas.DataFrame + + Expected Output: + - None + + Side Effects: + - self._distribution is set to the output from _get_distribution + - GaussianMultivariate is called with self._distribution as input + - GaussianMultivariate output is stored as self._model + - self._model.fit is called with the input dataframe + - self._update_metadata is called without arguments + """ + # Setup + gaussian_copula = Mock(spec_set=GaussianCopula) + gaussian_copula._field_distributions = {'a': 'a_distribution'} + + # Run + data = pd.DataFrame({ + 'a.value': [1, 2, 3] + }) + out = GaussianCopula._fit(gaussian_copula, data) + + # asserts + assert out is None + assert gaussian_copula._field_distributions == { + 'a': 'a_distribution', 'a.value': 'a_distribution'} + gm_mock.assert_called_once_with( + distribution={'a': 'a_distribution', 'a.value': 'a_distribution'}) + + assert gaussian_copula._model == gm_mock.return_value + expected_data = pd.DataFrame({ + 'a.value': [1, 2, 3] + }) + call_args = gaussian_copula._model.fit.call_args_list + passed_table_data = call_args[0][0][0] + + pd.testing.assert_frame_equal(expected_data, passed_table_data) + gaussian_copula._update_metadata.assert_called_once_with() + def test__sample(self): """Test the ``GaussianCopula._sample`` method. From cd92c10b209fbfc9bb670dbdb303a82223073787 Mon Sep 17 00:00:00 2001 From: Katharine Xiao <2405771+katxiao@users.noreply.github.com> Date: Wed, 23 Mar 2022 19:16:04 -0400 Subject: [PATCH 3/3] cr --- sdv/tabular/copulas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdv/tabular/copulas.py b/sdv/tabular/copulas.py index 7fa2088de..de8f5244a 100644 --- a/sdv/tabular/copulas.py +++ b/sdv/tabular/copulas.py @@ -291,7 +291,7 @@ def _fit(self, table_data): for column in table_data.columns: if column not in self._field_distributions: # Check if the column is a derived column. 
- column_name = column.replace('.value', '').replace('.is_null', '') + column_name = column.replace('.value', '') self._field_distributions[column] = self._field_distributions.get( column_name, self._default_distribution)