Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix field distribution arg in GaussianCopula #743

Merged
merged 3 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions sdv/tabular/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,11 @@ def _fit(self, table_data):
Data to be fitted.
"""
for column in table_data.columns:
distribution = self._field_distributions.get(column)
if not distribution:
self._field_distributions[column] = self._default_distribution
if column not in self._field_distributions:
# Check if the column is a derived column.
column_name = column.replace('.value', '')
self._field_distributions[column] = self._field_distributions.get(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing to note is that the _field_distributions dictionary will have some field names that match the RDT output, and some that match the input. For example, if null columns are created and there is a column 'a', then the dictionary will have

{
    'a': dist,
    'a.is_null': default_dist
}

Idk if it makes sense to have them all match the HyperTransformer output names or not (ie. keep the .value extension)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created #744 to track this question, not completely sure what we should do right now.

column_name, self._default_distribution)

self._model = copulas.multivariate.GaussianMultivariate(
distribution=self._field_distributions)
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/tabular/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,62 @@ def test__fit(self, gm_mock):
pd.testing.assert_frame_equal(expected_data, passed_table_data)
gaussian_copula._update_metadata.assert_called_once_with()

@patch('sdv.tabular.copulas.copulas.multivariate.GaussianMultivariate',
spec_set=GaussianMultivariate)
def test__fit_with_transformed_columns(self, gm_mock):
"""Test the ``GaussianCopula._fit`` method with transformed columns.

The ``_fit`` method is expected to:
- Call the _get_distribution method to build the distributions dict.
- Set the output from _get_distribution method as self._distribution.
- Create a GaussianMultivriate object with the self._distribution value.
- Store the GaussianMultivariate instance in the self._model attribute.
- Fit the GaussianMultivariate instance with the given table data, unmodified.
- Call the _update_metadata method.

Setup:
- mock _get_distribution to return a distribution dict

Input:
- pandas.DataFrame

Expected Output:
- None

Side Effects:
- self._distribution is set to the output from _get_distribution
- GaussianMultivariate is called with self._distribution as input
- GaussianMultivariate output is stored as self._model
- self._model.fit is called with the input dataframe
- self._update_metadata is called without arguments
"""
# Setup
gaussian_copula = Mock(spec_set=GaussianCopula)
gaussian_copula._field_distributions = {'a': 'a_distribution'}

# Run
data = pd.DataFrame({
'a.value': [1, 2, 3]
})
out = GaussianCopula._fit(gaussian_copula, data)

# asserts
assert out is None
assert gaussian_copula._field_distributions == {
'a': 'a_distribution', 'a.value': 'a_distribution'}
gm_mock.assert_called_once_with(
distribution={'a': 'a_distribution', 'a.value': 'a_distribution'})

assert gaussian_copula._model == gm_mock.return_value
expected_data = pd.DataFrame({
'a.value': [1, 2, 3]
})
call_args = gaussian_copula._model.fit.call_args_list
passed_table_data = call_args[0][0][0]

pd.testing.assert_frame_equal(expected_data, passed_table_data)
gaussian_copula._update_metadata.assert_called_once_with()

def test__sample(self):
"""Test the ``GaussianCopula._sample`` method.

Expand Down