diff --git a/sdv/tabular/copulagan.py b/sdv/tabular/copulagan.py
index ae95cb5c8..1d1d4641c 100644
--- a/sdv/tabular/copulagan.py
+++ b/sdv/tabular/copulagan.py
@@ -146,6 +146,9 @@ class CopulaGAN(CTGAN):
     """
 
     DEFAULT_DISTRIBUTION = 'parametric'
+    _field_distributions = None
+    _default_distribution = None
+    _ht = None
 
     def __init__(self, field_names=None, field_types=None, field_transformers=None,
                  anonymize_fields=None, primary_key=None, constraints=None, table_metadata=None,
@@ -205,16 +208,18 @@ def _fit(self, table_data):
                 Data to be learned.
         """
         distributions = self._field_distributions
-        default = self._default_distribution
         fields = self._metadata.get_fields()
-        transformers = {
-            field: GaussianCopulaTransformer(
-                distribution=distributions.get(field, default)
-            )
-            for field in table_data
-            if field.replace('.value', '')
-            in fields and fields.get(field, dict()).get('type') != 'categorical'
-        }
+
+        transformers = {}
+        for field in table_data:
+            # Columns may carry a ``.value`` marker; drop it to match the metadata field names.
+            field_name = field.replace('.value', '')
+
+            if field_name in fields and fields[field_name].get('type') != 'categorical':
+                transformers[field] = GaussianCopulaTransformer(
+                    distribution=distributions.get(field_name, self._default_distribution)
+                )
+
         self._ht = HyperTransformer(field_transformers=transformers)
         table_data = self._ht.fit_transform(table_data)
 
diff --git a/tests/unit/tabular/test_copulagan.py b/tests/unit/tabular/test_copulagan.py
new file mode 100644
index 000000000..a25becf36
--- /dev/null
+++ b/tests/unit/tabular/test_copulagan.py
@@ -0,0 +1,120 @@
+from unittest.mock import Mock, call, patch
+
+import pandas as pd
+from rdt import HyperTransformer
+from rdt.transformers import GaussianCopulaTransformer
+
+from sdv.tabular.copulagan import CopulaGAN
+from tests.utils import DataFrameMatcher
+
+
+class TestCopulaGAN:
+
+    @patch('sdv.tabular.copulagan.CTGAN._fit')
+    @patch('sdv.tabular.copulagan.HyperTransformer', spec_set=HyperTransformer)
+    @patch('sdv.tabular.copulagan.GaussianCopulaTransformer',
+           spec_set=GaussianCopulaTransformer)
+    def test__fit(self, gct_mock, ht_mock, ctgan_fit_mock):
+        """Test the ``CopulaGAN._fit`` method.
+
+        The ``_fit`` method is expected to:
+        - Build transformers for all the non-categorical data columns based on the
+          field distributions.
+        - Create a HyperTransformer with all the transformers.
+        - Fit and transform the data with the HyperTransformer.
+        - Call CTGAN fit.
+
+        Setup:
+            - mock _field_distributions and _default_distribution to return the desired
+              distribution values
+
+        Input:
+            - pandas.DataFrame
+
+        Expected Output:
+            - None
+
+        Side Effects:
+            - GaussianCopulaTransformer is called with the expected distributions.
+            - HyperTransformer is called to create a hyper transformer object.
+            - HyperTransformer fit_transform is called with the expected data.
+            - CTGAN's fit method is called with the expected data.
+        """
+        # Setup
+        model = Mock(spec_set=CopulaGAN)
+        model._field_distributions = {'a': 'a_distribution'}
+        model._default_distribution = 'default_distribution'
+        model._metadata.get_fields.return_value = {'a': {}, 'b': {}, 'c': {'type': 'categorical'}}
+
+        # Run
+        data = pd.DataFrame({
+            'a': [1, 2, 3],
+            'b': [5, 6, 7],
+            'c': ['c', 'c', 'c'],
+        })
+        out = CopulaGAN._fit(model, data)
+
+        # asserts
+        assert out is None
+        assert model._field_distributions == {'a': 'a_distribution'}
+        gct_mock.assert_has_calls([
+            call(distribution='a_distribution'),
+            call(distribution='default_distribution'),
+        ])
+        assert gct_mock.call_count == 2
+
+        assert model._ht == ht_mock.return_value
+        ht_mock.return_value.fit_transform.assert_called_once_with(DataFrameMatcher(data))
+        # CTGAN._fit receives the HyperTransformer output rather than the raw input.
+        ctgan_fit_mock.assert_called_once_with(ht_mock.return_value.fit_transform.return_value)
+
+    @patch('sdv.tabular.copulagan.CTGAN._fit')
+    @patch('sdv.tabular.copulagan.HyperTransformer', spec_set=HyperTransformer)
+    @patch('sdv.tabular.copulagan.GaussianCopulaTransformer',
+           spec_set=GaussianCopulaTransformer)
+    def test__fit_with_transformed_columns(self, gct_mock, ht_mock, ctgan_fit_mock):
+        """Test the ``CopulaGAN._fit`` method with transformed columns.
+
+        The ``_fit`` method is expected to:
+        - Build transformers for all the data columns based on the field distributions.
+        - Create a HyperTransformer with all the transformers.
+        - Fit and transform the data with the HyperTransformer.
+        - Call CTGAN fit.
+
+        Setup:
+            - mock _field_distributions and _default_distribution to return the desired
+              distribution values
+
+        Input:
+            - pandas.DataFrame
+
+        Expected Output:
+            - None
+
+        Side Effects:
+            - GaussianCopulaTransformer is called with the expected distributions.
+            - HyperTransformer is called to create a hyper transformer object.
+            - HyperTransformer fit_transform is called with the expected data.
+            - CTGAN's fit method is called with the expected data.
+        """
+        # Setup
+        model = Mock(spec_set=CopulaGAN)
+        model._field_distributions = {'a': 'a_distribution'}
+        model._default_distribution = 'default_distribution'
+        model._metadata.get_fields.return_value = {'a': {}}
+
+        # Run
+        data = pd.DataFrame({
+            'a.value': [1, 2, 3]
+        })
+        out = CopulaGAN._fit(model, data)
+
+        # asserts
+        assert out is None
+        assert model._field_distributions == {'a': 'a_distribution'}
+        gct_mock.assert_called_once_with(distribution='a_distribution')
+
+        assert model._ht == ht_mock.return_value
+        ht_mock.return_value.fit_transform.assert_called_once_with(DataFrameMatcher(data))
+        # CTGAN._fit receives the HyperTransformer output rather than the raw input.
+        ctgan_fit_mock.assert_called_once_with(ht_mock.return_value.fit_transform.return_value)