diff --git a/sdv/single_table/copulagan.py b/sdv/single_table/copulagan.py index 36348a914..27ac770a3 100644 --- a/sdv/single_table/copulagan.py +++ b/sdv/single_table/copulagan.py @@ -5,6 +5,7 @@ from sdv.single_table.copulas import GaussianCopulaSynthesizer from sdv.single_table.ctgan import CTGANSynthesizer +from sdv.single_table.utils import validate_numerical_distributions class CopulaGANSynthesizer(CTGANSynthesizer): @@ -135,6 +136,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, cuda=cuda, ) + validate_numerical_distributions(numerical_distributions, self.metadata._columns) self.numerical_distributions = numerical_distributions or {} self.default_distribution = default_distribution or 'beta' diff --git a/sdv/single_table/copulas.py b/sdv/single_table/copulas.py index 68b6b0b3f..4dd13c88e 100644 --- a/sdv/single_table/copulas.py +++ b/sdv/single_table/copulas.py @@ -12,7 +12,7 @@ from sdv.errors import NonParametricError from sdv.single_table.base import BaseSingleTableSynthesizer -from sdv.single_table.utils import flatten_dict, unflatten_dict +from sdv.single_table.utils import flatten_dict, unflatten_dict, validate_numerical_distributions class GaussianCopulaSynthesizer(BaseSingleTableSynthesizer): @@ -90,11 +90,9 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, ) - if numerical_distributions and not isinstance(numerical_distributions, dict): - raise TypeError('numerical_distributions can only be None or a dict instance') - - self.default_distribution = default_distribution or 'beta' + validate_numerical_distributions(numerical_distributions, self.metadata._columns) self.numerical_distributions = numerical_distributions or {} + self.default_distribution = default_distribution or 'beta' self._default_distribution = self.get_distribution_class(self.default_distribution) self._numerical_distributions = { diff --git a/sdv/single_table/utils.py b/sdv/single_table/utils.py index 953440e81..144eb3a68 100644 --- a/sdv/single_table/utils.py +++ b/sdv/single_table/utils.py @@ -5,6 +5,8 @@ import numpy as np +from sdv.errors import SynthesizerInputError + TMP_FILE_NAME = '.sample.csv.temp' DISABLE_TMP_FILE = 'disable' IGNORED_DICT_KEYS = ['fitted', 'distribution', 'type'] @@ -281,3 +283,28 @@ def unflatten_dict(flat): unflattened[key] = value return unflattened + + +def validate_numerical_distributions(numerical_distributions, metadata_columns): + """Validate ``numerical_distributions``. + + Raise an error if it's not None or dict, or if its columns are not present in the metadata. + + Args: + numerical_distributions (dict): + Dictionary that maps field names from the table that is being modeled with + the distribution that needs to be used. + metadata_columns (list): + Columns present in the metadata. + """ + if numerical_distributions: + if not isinstance(numerical_distributions, dict): + raise TypeError('numerical_distributions can only be None or a dict instance.') + + invalid_columns = numerical_distributions.keys() - set(metadata_columns) + if invalid_columns: + raise SynthesizerInputError( + 'Invalid column names found in the numerical_distributions dictionary ' + f'{invalid_columns}. The column names you provide must be present ' + 'in the metadata.' + ) diff --git a/tests/unit/single_table/test_copulagan.py b/tests/unit/single_table/test_copulagan.py index 4757ed75e..f3301c248 100644 --- a/tests/unit/single_table/test_copulagan.py +++ b/tests/unit/single_table/test_copulagan.py @@ -7,6 +7,7 @@ from copulas.univariate import BetaUnivariate, GammaUnivariate, UniformUnivariate from rdt.transformers import GaussianNormalizer +from sdv.errors import SynthesizerInputError from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.copulagan import CopulaGANSynthesizer @@ -53,6 +54,7 @@ def test___init__custom(self): """Test creating an instance of ``CopulaGANSynthesizer`` with custom parameters.""" # Setup metadata = SingleTableMetadata() + metadata.add_column('field', sdtype='numerical') enforce_min_max_values = False enforce_rounding = False embedding_dim = 64 @@ -117,6 +119,32 @@ def test___init__custom(self): assert instance.default_distribution == 'uniform' assert instance._default_distribution == UniformUnivariate + def test___init__incorrect_numerical_distributions(self): + """Test it crashes when ``numerical_distributions`` receives a non-dictionary.""" + # Setup + metadata = SingleTableMetadata() + numerical_distributions = 'invalid' + + # Run + err_msg = 'numerical_distributions can only be None or a dict instance.' + with pytest.raises(TypeError, match=err_msg): + CopulaGANSynthesizer(metadata, numerical_distributions=numerical_distributions) + + def test___init__invalid_column_numerical_distributions(self): + """Test it crashes when ``numerical_distributions`` includes invalid columns.""" + # Setup + metadata = SingleTableMetadata() + numerical_distributions = {'totally_fake_column_name': 'beta'} + + # Run + err_msg = re.escape( + 'Invalid column names found in the numerical_distributions dictionary ' + "{'totally_fake_column_name'}. The column names you provide must be present " + 'in the metadata.' + ) + with pytest.raises(SynthesizerInputError, match=err_msg): + CopulaGANSynthesizer(metadata, numerical_distributions=numerical_distributions) + def test_get_params(self): """Test that inherited method ``get_params`` returns all the specific init parameters.""" # Setup diff --git a/tests/unit/single_table/test_copulas.py b/tests/unit/single_table/test_copulas.py index 96ba54dbd..61fcf7b4e 100644 --- a/tests/unit/single_table/test_copulas.py +++ b/tests/unit/single_table/test_copulas.py @@ -7,6 +7,7 @@ import scipy from copulas.univariate import BetaUnivariate, GammaUnivariate, UniformUnivariate +from sdv.errors import SynthesizerInputError from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.copulas import GaussianCopulaSynthesizer @@ -64,6 +65,7 @@ def test___init__custom(self): """Test creating an instance of ``GaussianCopulaSynthesizer`` with custom parameters.""" # Setup metadata = SingleTableMetadata() + metadata.add_column('field', sdtype='numerical') enforce_min_max_values = False enforce_rounding = False numerical_distributions = {'field': 'gamma'} @@ -86,6 +88,32 @@ def test___init__custom(self): assert instance._default_distribution == UniformUnivariate assert instance._numerical_distributions == {'field': GammaUnivariate} + def test___init__incorrect_numerical_distributions(self): + """Test it crashes when ``numerical_distributions`` receives a non-dictionary.""" + # Setup + metadata = SingleTableMetadata() + numerical_distributions = 'invalid' + + # Run + err_msg = 'numerical_distributions can only be None or a dict instance.' + with pytest.raises(TypeError, match=err_msg): + GaussianCopulaSynthesizer(metadata, numerical_distributions=numerical_distributions) + + def test___init__incorrect_column_numerical_distributions(self): + """Test it crashes when ``numerical_distributions`` includes invalid columns.""" + # Setup + metadata = SingleTableMetadata() + numerical_distributions = {'totally_fake_column_name': 'beta'} + + # Run + err_msg = re.escape( + 'Invalid column names found in the numerical_distributions dictionary ' + "{'totally_fake_column_name'}. The column names you provide must be present " + 'in the metadata.' + ) + with pytest.raises(SynthesizerInputError, match=err_msg): + GaussianCopulaSynthesizer(metadata, numerical_distributions=numerical_distributions) + def test_get_parameters(self): """Test that inherited method ``get_parameters`` returns the specified init parameters.""" # Setup @@ -114,6 +142,8 @@ def test__fit(self, mock_multivariate, mock_warnings): """ # Setup metadata = SingleTableMetadata() + metadata.add_column('name', sdtype='numerical') + metadata.add_column('user.id', sdtype='numerical') numerical_distributions = {'name': 'uniform', 'user.id': 'gamma'} processed_data = pd.DataFrame({