From 0f0cf5d5fd57ec058221b1703e5ef3aa054411ad Mon Sep 17 00:00:00 2001 From: Katharine Xiao <2405771+katxiao@users.noreply.github.com> Date: Wed, 30 Mar 2022 11:35:06 -0400 Subject: [PATCH] address cr --- sdv/tabular/copulagan.py | 21 ++++++++++++--------- tests/unit/tabular/test_copulagan.py | 17 ++++++++++++----- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/sdv/tabular/copulagan.py b/sdv/tabular/copulagan.py index 50befd204..283fa6399 100644 --- a/sdv/tabular/copulagan.py +++ b/sdv/tabular/copulagan.py @@ -210,16 +210,19 @@ def _fit(self, table_data): distributions = self._field_distributions default = self._default_distribution fields = self._metadata.get_fields() - transformers = { - field: GaussianCopulaTransformer( - distribution=distributions.get(field.replace('.value', ''), default) - ) - for field in table_data - if field.replace('.value', '') in fields and fields.get( - field.replace('.value', ''), + + transformers = {} + for field in table_data: + field_name = field.replace('.value', '') + + if field_name in fields and fields.get( + field_name, dict(), - ).get('type') != 'categorical' - } + ).get('type') != 'categorical': + transformers[field] = GaussianCopulaTransformer( + distribution=distributions.get(field.replace('.value', ''), default) + ) + self._ht = HyperTransformer(field_transformers=transformers) table_data = self._ht.fit_transform(table_data) diff --git a/tests/unit/tabular/test_copulagan.py b/tests/unit/tabular/test_copulagan.py index 99fac8b3b..a25becf36 100644 --- a/tests/unit/tabular/test_copulagan.py +++ b/tests/unit/tabular/test_copulagan.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock, patch +from unittest.mock import Mock, call, patch import pandas as pd from rdt import HyperTransformer @@ -18,7 +18,8 @@ def test__fit(self, gct_mock, ht_mock, ctgan_fit_mock): """Test the ``CopulaGAN._fit`` method. The ``_fit`` method is expected to: - - Build transformers for all the data columns based on the field distributions. + - Build transformers for all the non-categorical data columns based on the + field distributions. - Create a HyperTransformer with all the transformers. - Fit and transform the data with the HyperTransformer. - Call CTGAN fit. @@ -43,18 +44,24 @@ def test__fit(self, gct_mock, ht_mock, ctgan_fit_mock): model = Mock(spec_set=CopulaGAN) model._field_distributions = {'a': 'a_distribution'} model._default_distribution = 'default_distribution' - model._metadata.get_fields.return_value = {'a': {}} + model._metadata.get_fields.return_value = {'a': {}, 'b': {}, 'c': {'type': 'categorical'}} # Run data = pd.DataFrame({ - 'a': [1, 2, 3] + 'a': [1, 2, 3], + 'b': [5, 6, 7], + 'c': ['c', 'c', 'c'], }) out = CopulaGAN._fit(model, data) # asserts assert out is None assert model._field_distributions == {'a': 'a_distribution'} - gct_mock.assert_called_once_with(distribution='a_distribution') + gct_mock.assert_has_calls([ + call(distribution='a_distribution'), + call(distribution='default_distribution'), + ]) + assert gct_mock.call_count == 2 assert model._ht == ht_mock.return_value ht_mock.return_value.fit_transform.called_once_with(DataFrameMatcher(data))