Skip to content

Commit

Permalink
address cr
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Mar 30, 2022
1 parent cf1d363 commit aa4291c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
21 changes: 12 additions & 9 deletions sdv/tabular/copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,19 @@ def _fit(self, table_data):
distributions = self._field_distributions
default = self._default_distribution
fields = self._metadata.get_fields()
transformers = {
field: GaussianCopulaTransformer(
distribution=distributions.get(field.replace('.value', ''), default)
)
for field in table_data
if field.replace('.value', '') in fields and fields.get(
field.replace('.value', ''),

transformers = {}
for field in table_data:
field_name = field.replace('.value', '')

if field_name in fields and fields.get(
field_name,
dict(),
).get('type') != 'categorical'
}
).get('type') != 'categorical':
transformers[field] = GaussianCopulaTransformer(
distribution=distributions.get(field.replace('.value', ''), default)
)

self._ht = HyperTransformer(field_transformers=transformers)
table_data = self._ht.fit_transform(table_data)

Expand Down
16 changes: 11 additions & 5 deletions tests/unit/tabular/test_copulagan.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import Mock, patch
from unittest.mock import Mock, call, patch

import pandas as pd
from rdt import HyperTransformer
Expand All @@ -18,7 +18,8 @@ def test__fit(self, gct_mock, ht_mock, ctgan_fit_mock):
"""Test the ``CopulaGAN._fit`` method.
The ``_fit`` method is expected to:
- Build transformers for all the data columns based on the field distributions.
- Build transformers for all the non-categorical data columns based on the
field distributions.
- Create a HyperTransformer with all the transformers.
- Fit and transform the data with the HyperTransformer.
- Call CTGAN fit.
Expand All @@ -43,18 +44,23 @@ def test__fit(self, gct_mock, ht_mock, ctgan_fit_mock):
model = Mock(spec_set=CopulaGAN)
model._field_distributions = {'a': 'a_distribution'}
model._default_distribution = 'default_distribution'
model._metadata.get_fields.return_value = {'a': {}}
model._metadata.get_fields.return_value = {'a': {}, 'b': {}, 'c': {'type': 'categorical'}}

# Run
data = pd.DataFrame({
'a': [1, 2, 3]
'a': [1, 2, 3],
'b': [5, 6, 7],
'c': ['c', 'c', 'c'],
})
out = CopulaGAN._fit(model, data)

# asserts
assert out is None
assert model._field_distributions == {'a': 'a_distribution'}
gct_mock.assert_called_once_with(distribution='a_distribution')
gct_mock.assert_has_calls([
call(distribution='a_distribution'),
call(distribution='default_distribution'),
])

assert model._ht == ht_mock.return_value
ht_mock.return_value.fit_transform.called_once_with(DataFrameMatcher(data))
Expand Down

0 comments on commit aa4291c

Please sign in to comment.