Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate numerical_distrubutions parameter #1234

Merged
merged 6 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions sdv/single_table/copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import rdt

from sdv.errors import SynthesizerInputError
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from sdv.single_table.ctgan import CTGANSynthesizer

Expand Down Expand Up @@ -108,6 +109,19 @@ class CopulaGANSynthesizer(CTGANSynthesizer):

_gaussian_normalizer_hyper_transformer = None

def _validate_numerical_distributions(self, numerical_distributions):
if numerical_distributions:
if not isinstance(numerical_distributions, dict):
raise TypeError('numerical_distributions can only be None or a dict instance.')

invalid_columns = numerical_distributions.keys() - set(self.metadata._columns)
if invalid_columns:
raise SynthesizerInputError(
'Invalid column names found in the numerical_distributions dictionary '
f'{invalid_columns}. The column names you provide must be present '
'in the metadata.'
)
fealho marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
Expand Down Expand Up @@ -135,6 +149,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
cuda=cuda,
)

self._validate_numerical_distributions(numerical_distributions)
fealho marked this conversation as resolved.
Show resolved Hide resolved
self.numerical_distributions = numerical_distributions or {}
self.default_distribution = default_distribution or 'beta'

Expand Down
21 changes: 16 additions & 5 deletions sdv/single_table/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from copulas import multivariate
from rdt.transformers import OneHotEncoder

from sdv.errors import NonParametricError
from sdv.errors import NonParametricError, SynthesizerInputError
from sdv.single_table.base import BaseSingleTableSynthesizer
from sdv.single_table.utils import flatten_dict, unflatten_dict

Expand Down Expand Up @@ -83,18 +83,29 @@ def get_distribution_class(cls, distribution):

return cls._DISTRIBUTIONS[distribution]

def _validate_numerical_distributions(self, numerical_distributions):
if numerical_distributions:
if not isinstance(numerical_distributions, dict):
raise TypeError('numerical_distributions can only be None or a dict instance.')

invalid_columns = numerical_distributions.keys() - set(self.metadata._columns)
if invalid_columns:
raise SynthesizerInputError(
'Invalid column names found in the numerical_distributions dictionary '
f'{invalid_columns}. The column names you provide must be present '
'in the metadata.'
)

fealho marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
numerical_distributions=None, default_distribution=None):
super().__init__(
metadata,
enforce_min_max_values=enforce_min_max_values,
enforce_rounding=enforce_rounding,
)
if numerical_distributions and not isinstance(numerical_distributions, dict):
raise TypeError('numerical_distributions can only be None or a dict instance')

self.default_distribution = default_distribution or 'beta'
self._validate_numerical_distributions(numerical_distributions)
self.numerical_distributions = numerical_distributions or {}
self.default_distribution = default_distribution or 'beta'

self._default_distribution = self.get_distribution_class(self.default_distribution)
self._numerical_distributions = {
Expand Down
28 changes: 28 additions & 0 deletions tests/unit/single_table/test_copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from copulas.univariate import BetaUnivariate, GammaUnivariate, UniformUnivariate
from rdt.transformers import GaussianNormalizer

from sdv.errors import SynthesizerInputError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table.copulagan import CopulaGANSynthesizer

Expand Down Expand Up @@ -53,6 +54,7 @@ def test___init__custom(self):
"""Test creating an instance of ``CopulaGANSynthesizer`` with custom parameters."""
# Setup
metadata = SingleTableMetadata()
metadata.add_column('field', sdtype='numerical')
enforce_min_max_values = False
enforce_rounding = False
embedding_dim = 64
Expand Down Expand Up @@ -117,6 +119,32 @@ def test___init__custom(self):
assert instance.default_distribution == 'uniform'
assert instance._default_distribution == UniformUnivariate

def test___init__incorrect_numerical_distributions(self):
"""Test it crashes when ``numerical_distributions`` receives a non-dictionary."""
# Setup
metadata = SingleTableMetadata()
numerical_distributions = 'invalid'

# Run
err_msg = 'numerical_distributions can only be None or a dict instance.'
with pytest.raises(TypeError, match=err_msg):
CopulaGANSynthesizer(metadata, numerical_distributions=numerical_distributions)

def test___init__invalid_column_numerical_distributions(self):
"""Test it crashes when ``numerical_distributions`` includes invalid columns."""
# Setup
metadata = SingleTableMetadata()
numerical_distributions = {'totally_fake_column_name': 'beta'}

# Run
err_msg = re.escape(
'Invalid column names found in the numerical_distributions dictionary '
"{'totally_fake_column_name'}. The column names you provide must be present "
'in the metadata.'
)
with pytest.raises(SynthesizerInputError, match=err_msg):
CopulaGANSynthesizer(metadata, numerical_distributions=numerical_distributions)

def test_get_params(self):
"""Test that inherited method ``get_params`` returns all the specific init parameters."""
# Setup
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/single_table/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import scipy
from copulas.univariate import BetaUnivariate, GammaUnivariate, UniformUnivariate

from sdv.errors import SynthesizerInputError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table.copulas import GaussianCopulaSynthesizer

Expand Down Expand Up @@ -64,6 +65,7 @@ def test___init__custom(self):
"""Test creating an instance of ``GaussianCopulaSynthesizer`` with custom parameters."""
# Setup
metadata = SingleTableMetadata()
metadata.add_column('field', sdtype='numerical')
enforce_min_max_values = False
enforce_rounding = False
numerical_distributions = {'field': 'gamma'}
Expand All @@ -86,6 +88,32 @@ def test___init__custom(self):
assert instance._default_distribution == UniformUnivariate
assert instance._numerical_distributions == {'field': GammaUnivariate}

def test___init__incorrect_numerical_distributions(self):
"""Test it crashes when ``numerical_distributions`` receives a non-dictionary."""
# Setup
metadata = SingleTableMetadata()
numerical_distributions = 'invalid'

# Run
err_msg = 'numerical_distributions can only be None or a dict instance.'
with pytest.raises(TypeError, match=err_msg):
GaussianCopulaSynthesizer(metadata, numerical_distributions=numerical_distributions)

def test___init__incorrect_column_numerical_distributions(self):
"""Test it crashes when ``numerical_distributions`` includes invalid columns."""
# Setup
metadata = SingleTableMetadata()
numerical_distributions = {'totally_fake_column_name': 'beta'}

# Run
err_msg = re.escape(
'Invalid column names found in the numerical_distributions dictionary '
"{'totally_fake_column_name'}. The column names you provide must be present "
'in the metadata.'
)
with pytest.raises(SynthesizerInputError, match=err_msg):
GaussianCopulaSynthesizer(metadata, numerical_distributions=numerical_distributions)

def test_get_parameters(self):
"""Test that inherited method ``get_parameters`` returns the specified init parameters."""
# Setup
Expand Down Expand Up @@ -114,6 +142,8 @@ def test__fit(self, mock_multivariate, mock_warnings):
"""
# Setup
metadata = SingleTableMetadata()
metadata.add_column('name', sdtype='numerical')
metadata.add_column('user.id', sdtype='numerical')
numerical_distributions = {'name': 'uniform', 'user.id': 'gamma'}

processed_data = pd.DataFrame({
Expand Down