Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate numerical_distributions parameter #1234

Merged
merged 6 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdv/single_table/copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from sdv.single_table.copulas import GaussianCopulaSynthesizer
from sdv.single_table.ctgan import CTGANSynthesizer
from sdv.single_table.utils import validate_numerical_distributions


class CopulaGANSynthesizer(CTGANSynthesizer):
Expand Down Expand Up @@ -135,6 +136,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
cuda=cuda,
)

validate_numerical_distributions(numerical_distributions, self.metadata._columns)
self.numerical_distributions = numerical_distributions or {}
self.default_distribution = default_distribution or 'beta'

Expand Down
8 changes: 3 additions & 5 deletions sdv/single_table/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from sdv.errors import NonParametricError
from sdv.single_table.base import BaseSingleTableSynthesizer
from sdv.single_table.utils import flatten_dict, unflatten_dict
from sdv.single_table.utils import flatten_dict, unflatten_dict, validate_numerical_distributions


class GaussianCopulaSynthesizer(BaseSingleTableSynthesizer):
Expand Down Expand Up @@ -90,11 +90,9 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
enforce_min_max_values=enforce_min_max_values,
enforce_rounding=enforce_rounding,
)
if numerical_distributions and not isinstance(numerical_distributions, dict):
raise TypeError('numerical_distributions can only be None or a dict instance')

self.default_distribution = default_distribution or 'beta'
validate_numerical_distributions(numerical_distributions, self.metadata._columns)
self.numerical_distributions = numerical_distributions or {}
self.default_distribution = default_distribution or 'beta'

self._default_distribution = self.get_distribution_class(self.default_distribution)
self._numerical_distributions = {
Expand Down
27 changes: 27 additions & 0 deletions sdv/single_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import numpy as np

from sdv.errors import SynthesizerInputError

TMP_FILE_NAME = '.sample.csv.temp'
DISABLE_TMP_FILE = 'disable'
IGNORED_DICT_KEYS = ['fitted', 'distribution', 'type']
Expand Down Expand Up @@ -281,3 +283,28 @@ def unflatten_dict(flat):
unflattened[key] = value

return unflattened


def validate_numerical_distributions(numerical_distributions, metadata_columns):
    """Validate ``numerical_distributions``.

    Raise an error if it's not None or dict, or if its columns are not present in the metadata.

    Args:
        numerical_distributions (dict or None):
            Dictionary that maps field names from the table that is being modeled with
            the distribution that needs to be used. ``None`` skips the validation.
        metadata_columns (iterable):
            Column names present in the metadata. Any iterable of names is accepted
            because it is converted to a ``set`` before the comparison.

    Raises:
        TypeError:
            If ``numerical_distributions`` is neither ``None`` nor a ``dict``.
        SynthesizerInputError:
            If any key of ``numerical_distributions`` is not one of ``metadata_columns``.
    """
    if numerical_distributions is None:
        return

    # Check the type explicitly instead of relying on truthiness, so that falsy
    # non-dict values such as ``[]``, ``''`` or ``0`` are rejected as documented.
    if not isinstance(numerical_distributions, dict):
        raise TypeError('numerical_distributions can only be None or a dict instance.')

    # An empty dict yields an empty difference and therefore passes.
    invalid_columns = numerical_distributions.keys() - set(metadata_columns)
    if invalid_columns:
        raise SynthesizerInputError(
            'Invalid column names found in the numerical_distributions dictionary '
            f'{invalid_columns}. The column names you provide must be present '
            'in the metadata.'
        )
28 changes: 28 additions & 0 deletions tests/unit/single_table/test_copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from copulas.univariate import BetaUnivariate, GammaUnivariate, UniformUnivariate
from rdt.transformers import GaussianNormalizer

from sdv.errors import SynthesizerInputError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table.copulagan import CopulaGANSynthesizer

Expand Down Expand Up @@ -53,6 +54,7 @@ def test___init__custom(self):
"""Test creating an instance of ``CopulaGANSynthesizer`` with custom parameters."""
# Setup
metadata = SingleTableMetadata()
metadata.add_column('field', sdtype='numerical')
enforce_min_max_values = False
enforce_rounding = False
embedding_dim = 64
Expand Down Expand Up @@ -117,6 +119,32 @@ def test___init__custom(self):
assert instance.default_distribution == 'uniform'
assert instance._default_distribution == UniformUnivariate

def test___init__incorrect_numerical_distributions(self):
    """Test it crashes when ``numerical_distributions`` receives a non-dictionary."""
    # Setup
    metadata = SingleTableMetadata()
    numerical_distributions = 'invalid'

    # Run and Assert
    # ``match`` is applied as a regular expression, so the message is escaped to make
    # the trailing period match literally instead of matching any character.
    err_msg = re.escape('numerical_distributions can only be None or a dict instance.')
    with pytest.raises(TypeError, match=err_msg):
        CopulaGANSynthesizer(metadata, numerical_distributions=numerical_distributions)

def test___init__invalid_column_numerical_distributions(self):
    """Test that unknown column names in ``numerical_distributions`` are rejected.

    The metadata has no columns, so any key in the dictionary is invalid and the
    synthesizer constructor is expected to raise a ``SynthesizerInputError``.
    """
    # Setup
    empty_metadata = SingleTableMetadata()
    unknown_column_distributions = {'totally_fake_column_name': 'beta'}
    expected_message = (
        'Invalid column names found in the numerical_distributions dictionary '
        "{'totally_fake_column_name'}. The column names you provide must be present "
        'in the metadata.'
    )

    # Run and Assert
    # The message contains regex metacharacters, so it is escaped before matching.
    with pytest.raises(SynthesizerInputError, match=re.escape(expected_message)):
        CopulaGANSynthesizer(
            empty_metadata,
            numerical_distributions=unknown_column_distributions,
        )

def test_get_params(self):
"""Test that inherited method ``get_params`` returns all the specific init parameters."""
# Setup
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/single_table/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import scipy
from copulas.univariate import BetaUnivariate, GammaUnivariate, UniformUnivariate

from sdv.errors import SynthesizerInputError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table.copulas import GaussianCopulaSynthesizer

Expand Down Expand Up @@ -64,6 +65,7 @@ def test___init__custom(self):
"""Test creating an instance of ``GaussianCopulaSynthesizer`` with custom parameters."""
# Setup
metadata = SingleTableMetadata()
metadata.add_column('field', sdtype='numerical')
enforce_min_max_values = False
enforce_rounding = False
numerical_distributions = {'field': 'gamma'}
Expand All @@ -86,6 +88,32 @@ def test___init__custom(self):
assert instance._default_distribution == UniformUnivariate
assert instance._numerical_distributions == {'field': GammaUnivariate}

def test___init__incorrect_numerical_distributions(self):
    """Test it crashes when ``numerical_distributions`` receives a non-dictionary."""
    # Setup
    metadata = SingleTableMetadata()
    numerical_distributions = 'invalid'

    # Run and Assert
    # ``match`` is applied as a regular expression, so the message is escaped to make
    # the trailing period match literally instead of matching any character.
    err_msg = re.escape('numerical_distributions can only be None or a dict instance.')
    with pytest.raises(TypeError, match=err_msg):
        GaussianCopulaSynthesizer(metadata, numerical_distributions=numerical_distributions)

def test___init__incorrect_column_numerical_distributions(self):
    """Test that unknown column names in ``numerical_distributions`` are rejected.

    The metadata has no columns, so any key in the dictionary is invalid and the
    synthesizer constructor is expected to raise a ``SynthesizerInputError``.
    """
    # Setup
    empty_metadata = SingleTableMetadata()
    unknown_column_distributions = {'totally_fake_column_name': 'beta'}
    expected_message = (
        'Invalid column names found in the numerical_distributions dictionary '
        "{'totally_fake_column_name'}. The column names you provide must be present "
        'in the metadata.'
    )

    # Run and Assert
    # The message contains regex metacharacters, so it is escaped before matching.
    with pytest.raises(SynthesizerInputError, match=re.escape(expected_message)):
        GaussianCopulaSynthesizer(
            empty_metadata,
            numerical_distributions=unknown_column_distributions,
        )

def test_get_parameters(self):
"""Test that inherited method ``get_parameters`` returns the specified init parameters."""
# Setup
Expand Down Expand Up @@ -114,6 +142,8 @@ def test__fit(self, mock_multivariate, mock_warnings):
"""
# Setup
metadata = SingleTableMetadata()
metadata.add_column('name', sdtype='numerical')
metadata.add_column('user.id', sdtype='numerical')
numerical_distributions = {'name': 'uniform', 'user.id': 'gamma'}

processed_data = pd.DataFrame({
Expand Down